In [1]:
# File setup
from os import chdir
from os.path import join

# Absolute location of the project checkout
BASE_PATH = "/workspaces/barco_skin_lesion_classification"
# Source directory inside the project
CODE_PATH = join(BASE_PATH, 'src/')

# Make the source directory the working directory
# (presumably so the project packages imported below resolve -- verify)
chdir(CODE_PATH)

In [2]:
# Imports
# Utils
# NOTE: the plotting API lives in matplotlib.pyplot; aliasing the top-level
# matplotlib package as `plt` would make calls such as plt.plot() fail.
import matplotlib.pyplot as plt
import numpy as np
import wandb
import sys
import importlib
from PIL import Image, ImageFile
# Allow PIL to load image files that are truncated instead of raising an error
ImageFile.LOAD_TRUNCATED_IMAGES = True
import datetime


# DL libraries
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader

# User libraries
from datasets.segmentationdataset import SegmentationDataset
from models.unet_model import UNet
from trainers.segmentation_model_trainer import train_segmentation_model
from validators.segmentation_model_validator import validate_segmentation_model
from util import config

  from .autonotebook import tqdm as notebook_tqdm


# Data

In [3]:
# Get the data
# Each dataset is built from a features path, a labels path and the
# transformations configured for that split.
train_segmentation_dataset = SegmentationDataset(
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TRAIN_FEATURES),
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TRAIN_LABELS),
    config.SEGMENTATION_TRAIN_TRANSFORMATIONS_BOTH
    )

test_segmentation_dataset = SegmentationDataset(
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TEST_FEATURES),
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TEST_LABELS),
    config.SEGMENTATION_TEST_TRANSFORMATIONS_BOTH
    )

# Place the datasets in dataloaders
# Reshuffle the training samples every epoch so the batches are not always
# made of the same samples in the same order.
# NOTE(review): assumes SegmentationDataset is a map-style dataset
# (shuffle is not supported for iterable-style datasets) -- confirm.
train_segmentation_dataloader = DataLoader(
    train_segmentation_dataset,
    batch_size=config.SEGMENTATION_BATCH_SIZE,
    shuffle=True
    )
# The test loader keeps a fixed order and yields one sample per batch
test_segmentation_dataloader = DataLoader(test_segmentation_dataset, batch_size=1)



# Setup

In [4]:
# Build the network and move it to the configured device
model = UNet(n_channels=3, n_classes=1)
model.to(config.DEVICE)

# Loss function: binary cross-entropy computed on raw logits
criteria = nn.BCEWithLogitsLoss()

# Adam optimizer over all model parameters
optimizer = optim.Adam(model.parameters(), lr=config.SEGMENTATION_LR)

# Gradient scaler handed to the training routine
grad_scaler = torch.cuda.amp.grad_scaler.GradScaler()


# Set up Weights & Biases
wandb.login()

# Hyperparameters recorded alongside the run
run_hyperparameters = {
    "learning_rate": config.SEGMENTATION_LR,
    "batch_size": config.SEGMENTATION_BATCH_SIZE,
    "epochs": config.SEGMENTATION_EPOCHS,
}

# Start the wandb run; the run name carries the current timestamp
wandb.init(
    settings=wandb.Settings(start_method="fork"),
    project="test-project",
    name=f"experiment_{datetime.datetime.now()}",
    config=run_hyperparameters,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrobberdg[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Training

In [5]:
# Set the variables to keep track of the best model
import copy  # local import: only needed to snapshot the best weights

# Start at +inf so the first epoch always registers as an improvement
best_validation_loss = float('inf')
# state_dict() returns references to the live tensors, so take a deep copy;
# otherwise this "best" snapshot would keep changing as training continues
best_model_state = copy.deepcopy(model.state_dict())

for epoch in range(config.SEGMENTATION_EPOCHS):
  # Set the model in training mode
  model.train()

  # Train the model for one pass over the training dataloader
  total_train_loss_this_epoch = train_segmentation_model(
      model,
      optimizer,
      criteria,
      grad_scaler,
      train_segmentation_dataloader
  )
  
  # Set the model in evaluation mode
  model.eval()

  # Validate the model; also returns a sample image array for logging
  total_val_loss_this_epoch, sample_image_array = validate_segmentation_model(
      model,
      criteria,
      test_segmentation_dataloader,
      test_segmentation_dataset
  )

  # Convert the image array to a real image object (8-bit grayscale, mode 'L')
  sample_image_array = sample_image_array.cpu()
  sample_image = Image.fromarray(np.uint8(sample_image_array) , 'L')

  # Calculate the loss values by dividing the totals by the number of samples
  # NOTE(review): if the trainer/validator accumulate per-batch mean losses,
  # dividing by the sample count makes train_loss (batched) and val_loss
  # (batch_size=1) not directly comparable -- confirm against their code.
  train_loss_this_epoch = total_train_loss_this_epoch/len(train_segmentation_dataloader.dataset)
  val_loss_this_epoch = total_val_loss_this_epoch/len(test_segmentation_dataloader.dataset)

  # Log the losses and the sample image for this epoch
  wandb.log({
      'train_loss': train_loss_this_epoch,
      'val_loss': val_loss_this_epoch,
      'sample_image': wandb.Image(sample_image)
  })

  print(f'epoch: {epoch}, train_loss: {train_loss_this_epoch}, val_loss: {val_loss_this_epoch}')

  # If this is the best performing model yet, save it
  if val_loss_this_epoch < best_validation_loss:
    # Update the best value
    best_validation_loss = val_loss_this_epoch

    # Update the best model with an independent copy of the current weights,
    # so later epochs cannot overwrite it in place
    best_model_state = copy.deepcopy(model.state_dict())

    # Save the best model under a timestamped file name
    checkpoint_path = join(
      BASE_PATH, 
      config.SEGMENTATION_MODEL_CHECKPOINT_PATH, 
      f'chechpoint_{datetime.datetime.now()}.pth'
    )
    torch.save(best_model_state, checkpoint_path)

100%|██████████| 78/78 [01:30<00:00,  1.17s/it]
100%|██████████| 100/100 [00:10<00:00,  9.43it/s]


epoch: 0, train_loss: 0.05720336062485634, val_loss: 1.7504621744155884


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.43it/s]


epoch: 1, train_loss: 0.048542023469853035, val_loss: 0.9512575268745422


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.32it/s]


epoch: 2, train_loss: 0.04465372947663809, val_loss: 0.8803255558013916


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.43it/s]


epoch: 3, train_loss: 0.04213714057097171, val_loss: 0.8333888053894043


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.43it/s]


epoch: 4, train_loss: 0.04023275821325582, val_loss: 0.7901238799095154


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.27it/s]


epoch: 5, train_loss: 0.03781630539760268, val_loss: 0.8197084069252014


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.38it/s]


epoch: 6, train_loss: 0.0357661241278614, val_loss: 0.7644237875938416


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.40it/s]


epoch: 7, train_loss: 0.0337447864395576, val_loss: 0.7026727795600891


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.19it/s]


epoch: 8, train_loss: 0.03220156518669251, val_loss: 0.669369637966156


100%|██████████| 78/78 [01:30<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.49it/s]


epoch: 9, train_loss: 0.030741682459854945, val_loss: 0.6134169101715088


100%|██████████| 78/78 [01:30<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.51it/s]


epoch: 10, train_loss: 0.030466526698951446, val_loss: 0.7008650898933411


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.30it/s]


epoch: 11, train_loss: 0.028843248903799173, val_loss: 0.6870133876800537


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.38it/s]


epoch: 12, train_loss: 0.028074813563057015, val_loss: 0.5742228031158447


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.35it/s]


epoch: 13, train_loss: 0.02622911032907659, val_loss: 0.5512775778770447


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.45it/s]


epoch: 14, train_loss: 0.024883774154742622, val_loss: 0.5598334670066833


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.40it/s]


epoch: 15, train_loss: 0.024072171476428376, val_loss: 0.6148641109466553


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.41it/s]


epoch: 16, train_loss: 0.023302145750405984, val_loss: 0.5691456198692322


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.50it/s]


epoch: 17, train_loss: 0.022250144096403572, val_loss: 0.5904369354248047


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.44it/s]


epoch: 18, train_loss: 0.021045770493142588, val_loss: 0.5633498430252075


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.44it/s]


epoch: 19, train_loss: 0.02035483481316157, val_loss: 0.5735278725624084


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.26it/s]


epoch: 20, train_loss: 0.019782895731658295, val_loss: 0.6224812865257263


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.29it/s]


epoch: 21, train_loss: 0.019205228581654136, val_loss: 0.5540577173233032


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.43it/s]


epoch: 22, train_loss: 0.018159808831494047, val_loss: 0.6215503215789795


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.23it/s]


epoch: 23, train_loss: 0.017442879240897147, val_loss: 0.6103328466415405


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.33it/s]


epoch: 24, train_loss: 0.017182936442309414, val_loss: 0.6284006237983704


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.30it/s]


epoch: 25, train_loss: 0.0170524168726723, val_loss: 0.6300985217094421


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.39it/s]


epoch: 26, train_loss: 0.016494014223480378, val_loss: 0.4839841425418854


100%|██████████| 78/78 [01:30<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.36it/s]


epoch: 27, train_loss: 0.016039496292184617, val_loss: 0.6122429966926575


100%|██████████| 78/78 [01:30<00:00,  1.17s/it]
100%|██████████| 100/100 [00:10<00:00,  9.41it/s]


epoch: 28, train_loss: 0.015258515751065497, val_loss: 0.6647056937217712


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.29it/s]


epoch: 29, train_loss: 0.01578505823587358, val_loss: 0.53523188829422


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.40it/s]


epoch: 30, train_loss: 0.014958252250526844, val_loss: 0.45847088098526


100%|██████████| 78/78 [01:30<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.39it/s]


epoch: 31, train_loss: 0.014032158872176861, val_loss: 0.4821939468383789


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.32it/s]


epoch: 32, train_loss: 0.013210050397236632, val_loss: 0.5287446975708008


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.38it/s]


epoch: 33, train_loss: 0.013385244090268396, val_loss: 0.48584747314453125


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.32it/s]


epoch: 34, train_loss: 0.012702202022601436, val_loss: 0.5108152627944946


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.27it/s]


epoch: 35, train_loss: 0.012883621341912576, val_loss: 0.5063464641571045


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.26it/s]


epoch: 36, train_loss: 0.012552473356509648, val_loss: 0.6562040448188782


100%|██████████| 78/78 [01:30<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.43it/s]


epoch: 37, train_loss: 0.01211382275403932, val_loss: 0.6905884146690369


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.33it/s]


epoch: 38, train_loss: 0.01160703683697709, val_loss: 0.7220958471298218


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.46it/s]


epoch: 39, train_loss: 0.011480785721095926, val_loss: 0.5331467390060425


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.46it/s]


epoch: 40, train_loss: 0.01088976460929479, val_loss: 0.5931357145309448


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.47it/s]


epoch: 41, train_loss: 0.01054911385943819, val_loss: 0.5438180565834045


100%|██████████| 78/78 [01:29<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.46it/s]


epoch: 42, train_loss: 0.01014353027798791, val_loss: 0.5632967948913574


100%|██████████| 78/78 [01:30<00:00,  1.15s/it]
100%|██████████| 100/100 [00:10<00:00,  9.42it/s]


epoch: 43, train_loss: 0.009962695503626811, val_loss: 0.5910130143165588


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
100%|██████████| 100/100 [00:10<00:00,  9.46it/s]


epoch: 44, train_loss: 0.0095821315633266, val_loss: 0.5358220338821411


100%|██████████| 78/78 [01:30<00:00,  1.16s/it]
 47%|████▋     | 47/100 [00:06<00:01, 27.07it/s]

In [None]:
# Write the best recorded model state to a timestamped checkpoint file
final_checkpoint_name = f'chechpoint_{datetime.datetime.now()}.pth'
checkpoint_path = join(BASE_PATH, config.SEGMENTATION_MODEL_CHECKPOINT_PATH, final_checkpoint_name)
torch.save(best_model_state, checkpoint_path)


In [None]:
# Mark the wandb run started in the setup cell as finished
wandb.finish()