In [1]:
%pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
def setup_file_system(in_colab):
    """Return the project's base directory, mounting Google Drive on Colab.

    Args:
        in_colab: Truthy when the notebook is running inside Google Colab.

    Returns:
        The base path of the project as a string. On Colab this is the
        project folder inside the mounted Google Drive; otherwise it is the
        fixed local workspace path.
    """
    if not in_colab:
        return "/workspaces/barco_skin_lesion_classification"

    # Imported locally so this function does not rely on another cell having
    # imported `join` first, and so `google.colab` is only needed on Colab.
    from os.path import join

    from google.colab import drive

    # Set the mount path and the base path inside the mounted drive
    mount_path_drive = '/content/drive'
    base_path = join(
        mount_path_drive,
        "MyDrive/barco_skin_lesion_classification"
    )

    # Mount the google drive
    drive.mount(mount_path_drive)

    return base_path

In [3]:
import sys
from os import chdir
from os.path import join

# True when the `google.colab` package is already loaded in this interpreter,
# which is how a Colab runtime is told apart from a local one.
IN_COLAB = "google.colab" in sys.modules

# Resolve the project root (setup_file_system mounts Google Drive on Colab)
BASE_PATH = setup_file_system(IN_COLAB)

# Make the project's source folder the current working directory
source_dir = join(BASE_PATH, "src/")
chdir(source_dir)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# Imports
# Utils
import datetime
import importlib
import sys

# `plt` is the conventional alias for the pyplot module, not for the
# top-level matplotlib package.
import matplotlib.pyplot as plt
import numpy as np
import wandb
from PIL import Image, ImageFile

# Let PIL load image files whose data is truncated instead of raising
ImageFile.LOAD_TRUNCATED_IMAGES = True

# DL libraries
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader

# User libraries
from datasets.segmentationdataset import SegmentationDataset
from models.unet_model import UNet
from trainers.segmentation_model_trainer import train_segmentation_model
from validators.segmentation_model_validator import validate_segmentation_model
from util import config, model_management

# Data

In [5]:
# Build the train and test datasets from the feature/label folders and the
# transformations configured for each split
train_segmentation_dataset = SegmentationDataset(
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TRAIN_FEATURES),
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TRAIN_LABELS),
    config.SEGMENTATION_TRAIN_TRANSFORMATIONS_BOTH
)

test_segmentation_dataset = SegmentationDataset(
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TEST_FEATURES),
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TEST_LABELS),
    config.SEGMENTATION_TEST_TRANSFORMATIONS_BOTH
)

# Place the datasets in dataloaders.
# The training loader reshuffles the samples every epoch so the model does not
# see the batches in the same fixed order each time; the test loader keeps the
# dataset order and yields one sample per batch.
train_segmentation_dataloader = DataLoader(
    train_segmentation_dataset,
    batch_size=config.SEGMENTATION_BATCH_SIZE,
    shuffle=True,
)
test_segmentation_dataloader = DataLoader(test_segmentation_dataset, batch_size=1)



# Setup

In [6]:
# Build the U-Net (3 input channels, 1 output class) and move it to the
# configured device
model = UNet(n_channels=3, n_classes=1)
model.to(config.DEVICE)

# Adam optimizer over all model parameters
optimizer = optim.Adam(model.parameters(), lr=config.SEGMENTATION_LR)

# Binary cross-entropy loss that takes raw logits
criteria = nn.BCEWithLogitsLoss()

# Gradient scaler handed to the training routine
grad_scaler = torch.cuda.amp.GradScaler()


# Log in to Weights & Biases
wandb.login()

# Timestamp used to give the wandb run a unique name
now = datetime.datetime.now()
run_name = f'experiment_{now.strftime("%m_%d_%Y_%H_%M_%S")}'

# Hyperparameters recorded alongside the run
run_config = {
    "learning_rate": config.SEGMENTATION_LR,
    "batch_size": config.SEGMENTATION_BATCH_SIZE,
    "epochs": config.SEGMENTATION_EPOCHS,
    "image_dims": f'h: {config.SEGMENTATION_IMAGE_HEIGHT}, w: {config.SEGMENTATION_IMAGE_WIDTH}',
    "start_from_artifact": config.SEGMENTATION_START_FROM_ARTIFACT,
    "start_artifact": config.SEGMENTATION_START_ARTIFACT,
}

# Start the wandb run
wandb.init(
    settings=wandb.Settings(start_method="fork"),
    project="segmentation",
    entity="dermapool",
    name=run_name,
    config=run_config,
)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrobberdg[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Currently logged in as: [33mrobberdg[0m ([33mdermapool[0m). Use [1m`wandb login --relogin`[0m to force relogin


# Training

In [7]:
# Track the lowest validation loss seen so far and the matching model state.
# Starting from +inf guarantees the first epoch always counts as an
# improvement, whatever the scale of the loss.
best_validation_loss = float("inf")
best_model_state = model.state_dict()

for epoch in range(config.SEGMENTATION_EPOCHS):
    # Set the model in training mode
    model.train()

    # Train the model
    total_train_loss_this_epoch = train_segmentation_model(
        model,
        optimizer,
        criteria,
        grad_scaler,
        train_segmentation_dataloader
    )

    # Set the model in evaluation mode
    model.eval()

    # Validate the model; the second return value is turned into the sample
    # image that gets logged below
    total_val_loss_this_epoch, sample_image_array = validate_segmentation_model(
        model,
        criteria,
        test_segmentation_dataloader,
        test_segmentation_dataset
    )

    # Convert the image array to a real image object (8-bit grayscale)
    sample_image_array = sample_image_array.cpu()
    sample_image = Image.fromarray(np.uint8(sample_image_array), 'L')

    # Divide the totals by the number of samples in each dataset.
    # NOTE(review): this is only comparable between train and val if both
    # helpers return per-sample sums; if they sum per-batch means, the train
    # value is scaled down by the batch size -- confirm against the helpers.
    train_loss_this_epoch = total_train_loss_this_epoch / len(train_segmentation_dataloader.dataset)
    val_loss_this_epoch = total_val_loss_this_epoch / len(test_segmentation_dataloader.dataset)

    # Log the losses and the sample image for this epoch
    wandb.log({
        'train_loss': train_loss_this_epoch,
        'val_loss': val_loss_this_epoch,
        'sample_image': wandb.Image(sample_image)
    })

    print(f'epoch: {epoch}, train_loss: {train_loss_this_epoch}, val_loss: {val_loss_this_epoch}')

    # If this is the best performing model yet, save it
    if val_loss_this_epoch < best_validation_loss:
        # Update the score
        best_validation_loss = val_loss_this_epoch

        # Timestamp that makes the checkpoint file name unique
        now = datetime.datetime.now()

        # Save the model
        checkpoint_path = join(
            BASE_PATH,
            config.SEGMENTATION_MODEL_CHECKPOINT_PATH,
            f'chechpoint_{now.strftime("%m_%d_%Y_%H_%M_%S")}.pth'
        )
        best_model_state = model_management.save_model(model, checkpoint_path, False)

    

100%|██████████| 156/156 [08:28<00:00,  3.26s/it]
100%|██████████| 100/100 [01:10<00:00,  1.42it/s]


epoch: 0, train_loss: 0.1094826250288519, val_loss: 1.0277013778686523


100%|██████████| 156/156 [02:08<00:00,  1.21it/s]
100%|██████████| 100/100 [00:21<00:00,  4.76it/s]


epoch: 1, train_loss: 0.09074228275462734, val_loss: 0.8165408968925476


100%|██████████| 156/156 [02:09<00:00,  1.20it/s]
100%|██████████| 100/100 [00:19<00:00,  5.12it/s]


epoch: 2, train_loss: 0.08266166870653295, val_loss: 0.7462659478187561


100%|██████████| 156/156 [02:09<00:00,  1.20it/s]
100%|██████████| 100/100 [00:19<00:00,  5.10it/s]


epoch: 3, train_loss: 0.07632924247093935, val_loss: 0.6704601049423218


100%|██████████| 156/156 [02:11<00:00,  1.18it/s]
100%|██████████| 100/100 [00:19<00:00,  5.12it/s]


epoch: 4, train_loss: 0.07061323932198017, val_loss: 0.6573135852813721


100%|██████████| 156/156 [02:09<00:00,  1.21it/s]
100%|██████████| 100/100 [00:21<00:00,  4.76it/s]


epoch: 5, train_loss: 0.06635663405646298, val_loss: 0.6608635783195496


100%|██████████| 156/156 [02:09<00:00,  1.20it/s]
100%|██████████| 100/100 [00:19<00:00,  5.02it/s]


epoch: 6, train_loss: 0.06319010489543343, val_loss: 0.6435630917549133


100%|██████████| 156/156 [02:10<00:00,  1.19it/s]
100%|██████████| 100/100 [00:19<00:00,  5.01it/s]


epoch: 7, train_loss: 0.05937043384446463, val_loss: 0.6046898365020752


100%|██████████| 156/156 [02:10<00:00,  1.20it/s]
100%|██████████| 100/100 [00:19<00:00,  5.13it/s]


epoch: 8, train_loss: 0.057544558427767845, val_loss: 0.6371056437492371


100%|██████████| 156/156 [02:11<00:00,  1.19it/s]
100%|██████████| 100/100 [00:19<00:00,  5.15it/s]


epoch: 9, train_loss: 0.05475859281867623, val_loss: 0.5751019716262817


100%|██████████| 156/156 [02:09<00:00,  1.20it/s]
100%|██████████| 100/100 [00:20<00:00,  4.78it/s]


epoch: 10, train_loss: 0.05289748760348047, val_loss: 0.5643982291221619


100%|██████████| 156/156 [02:10<00:00,  1.20it/s]
100%|██████████| 100/100 [00:20<00:00,  4.98it/s]


epoch: 11, train_loss: 0.05005164957084748, val_loss: 0.5865805149078369


100%|██████████| 156/156 [02:10<00:00,  1.20it/s]
100%|██████████| 100/100 [00:19<00:00,  5.12it/s]


epoch: 12, train_loss: 0.04735689010970002, val_loss: 0.6037932634353638


100%|██████████| 156/156 [02:09<00:00,  1.20it/s]
 17%|█▋        | 17/100 [00:07<00:35,  2.33it/s]


KeyboardInterrupt: ignored

In [8]:
# Timestamp that makes the final checkpoint's file name unique
now = datetime.datetime.now()
checkpoint_filename = f'chechpoint_{now.strftime("%m_%d_%Y_%H_%M_%S")}.pth'

# Full path of the checkpoint inside the project's checkpoint folder
checkpoint_path = join(
    BASE_PATH,
    config.SEGMENTATION_MODEL_CHECKPOINT_PATH,
    checkpoint_filename,
)

# Save the model as it is now, i.e. after the training loop stopped
best_model_state = model_management.save_model(model, checkpoint_path, True)

In [9]:
# Mark the Weights & Biases run started with wandb.init() as finished
wandb.finish()

0,1
train_loss,█▆▅▄▄▃▃▂▂▂▂▁▁
val_loss,█▅▄▃▂▂▂▂▂▁▁▁▂

0,1
train_loss,0.04736
val_loss,0.60379
