In [2]:
# Imports
# Utils
import matplotlib as plt
import numpy as np
import wandb
import sys
import importlib
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import datetime
from os.path import join
from os import chdir


# DL libraries
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader

# User libraries
from datasets.segmentationdataset import SegmentationDataset
from models.unet_model import UNet
from trainers.segmentation_model_trainer import train_segmentation_model
from validators.segmentation_model_validator import validate_segmentation_model
from util import config, model_management, file_management

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Resolve the project's base path (the helper is told whether we are running in Colab).
# NOTE(review): this cell shows `In [None]` in the saved notebook, yet the data cell below
# needs BASE_PATH — it must be run before the cells that follow.
BASE_PATH = file_management.setup_file_system(config.IN_COLAB)

# Change the working directory to the project's code directory
chdir(join(BASE_PATH, config.CODE_PATH))

# Data

In [3]:
# Build the train/test segmentation datasets from their feature and label directories.
# The *_TRANSFORMATIONS_BOTH config entries are presumably applied jointly to image and
# mask — confirm in SegmentationDataset.
train_segmentation_dataset = SegmentationDataset(
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TRAIN_FEATURES),
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TRAIN_LABELS),
    config.SEGMENTATION_TRAIN_TRANSFORMATIONS_BOTH,
)

test_segmentation_dataset = SegmentationDataset(
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TEST_FEATURES),
    join(BASE_PATH, config.SEGMENTATION_DATA_PATH_TEST_LABELS),
    config.SEGMENTATION_TEST_TRANSFORMATIONS_BOTH,
)

# Place the datasets in dataloaders.
# The training loader reshuffles every epoch so the model does not see the batches in the
# same fixed order each time; the test loader keeps a deterministic order, one sample per batch.
train_segmentation_dataloader = DataLoader(
    train_segmentation_dataset,
    batch_size=config.SEGMENTATION_BATCH_SIZE,
    shuffle=True,
)
test_segmentation_dataloader = DataLoader(test_segmentation_dataset, batch_size=1)



# Setup

In [4]:
# Build the U-Net: 3 input channels and 1 output channel (one logit map, paired with the
# BCE-with-logits loss below), then move it to the configured device
model = UNet(n_channels=3, n_classes=1)
model.to(config.DEVICE)

# Set the optimizer
optimizer = optim.Adam(model.parameters(), lr=config.SEGMENTATION_LR)

# Set the loss fn; BCEWithLogitsLoss applies the sigmoid internally, so it expects raw logits
criteria = nn.BCEWithLogitsLoss()

# Set the gradient scaler used for mixed-precision training.
# Use the public `torch.cuda.amp.GradScaler` entry point instead of reaching into the
# `torch.cuda.amp.grad_scaler` submodule.
grad_scaler = torch.cuda.amp.GradScaler()


# Log in to Weights & Biases
wandb.login()

# Timestamp used to give this wandb run a unique name
now = datetime.datetime.now()

# Start the wandb run and record this experiment's hyperparameters
wandb.init(
    settings=wandb.Settings(start_method="fork"),
    project="segmentation",
    entity="dermapool",
    name=f'experiment_{now.strftime("%m_%d_%Y_%H_%M_%S")}',
    config={
        "learning_rate": config.SEGMENTATION_LR,
        "batch_size": config.SEGMENTATION_BATCH_SIZE,
        "epochs": config.SEGMENTATION_EPOCHS,
        "image_dims": f'h: {config.SEGMENTATION_IMAGE_HEIGHT}, w: {config.SEGMENTATION_IMAGE_WIDTH}',
        "start_from_artifact": config.SEGMENTATION_START_FROM_ARTIFACT,
        "start_artifact": config.SEGMENTATION_START_ARTIFACT,
    }
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrobberdg[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Currently logged in as: [33mrobberdg[0m ([33mdermapool[0m). Use [1m`wandb login --relogin`[0m to force relogin


# Training

In [5]:
# Track the best validation loss seen so far. Starting at +inf (instead of an arbitrary
# sentinel like 10000) guarantees the first epoch is always treated as an improvement.
best_validation_loss = float("inf")
# NOTE(review): `state_dict()` returns references to the live parameter tensors, so this
# initial value keeps changing as training updates the weights; it is replaced by the
# return value of save_model whenever validation improves.
best_model_state = model.state_dict()

for epoch in range(config.SEGMENTATION_EPOCHS):
  # --- Training pass ---
  model.train()
  total_train_loss_this_epoch = train_segmentation_model(
      model,
      optimizer,
      criteria,
      grad_scaler,
      train_segmentation_dataloader
  )

  # --- Validation pass (also returns a sample prediction to visualise) ---
  model.eval()
  total_val_loss_this_epoch, sample_image_array = validate_segmentation_model(
      model,
      criteria,
      test_segmentation_dataloader,
      test_segmentation_dataset
  )

  # Convert the sample tensor into a grayscale PIL image for logging.
  # detach() makes the conversion safe even if the tensor is still attached to autograd.
  sample_image_array = sample_image_array.detach().cpu().numpy()
  sample_image = Image.fromarray(sample_image_array.astype(np.uint8), 'L')

  # Average the accumulated losses over the number of samples in each dataset.
  # NOTE(review): correctness depends on how the trainer/validator accumulate loss (not
  # visible here). The train loader uses SEGMENTATION_BATCH_SIZE while the test loader uses
  # batch_size=1; if both helpers sum per-batch *mean* losses, the train value is shrunk by
  # roughly the batch size relative to val — confirm, and divide by len(dataloader) if so.
  train_loss_this_epoch = total_train_loss_this_epoch/len(train_segmentation_dataloader.dataset)
  val_loss_this_epoch = total_val_loss_this_epoch/len(test_segmentation_dataloader.dataset)

  # Log the train/val losses and the sample prediction for this epoch
  wandb.log({
      'train_loss': train_loss_this_epoch,
      'val_loss': val_loss_this_epoch,
      'sample_image': wandb.Image(sample_image)
  })

  print(f'epoch: {epoch}, train_loss: {train_loss_this_epoch}, val_loss: {val_loss_this_epoch}')

  # If this is the best performing model yet, save it
  if val_loss_this_epoch < best_validation_loss:
    best_validation_loss = val_loss_this_epoch

    # Timestamp the file name so every improvement gets its own checkpoint.
    # NOTE(review): the "chechpoint_" prefix is misspelled; kept as-is because existing
    # checkpoints or loaders may rely on it.
    now = datetime.datetime.now()
    checkpoint_path = join(
      BASE_PATH,
      config.SEGMENTATION_MODEL_CHECKPOINT_PATH,
      f'chechpoint_{now.strftime("%m_%d_%Y_%H_%M_%S")}.pth'
    )
    # NOTE(review): the meaning of the third argument (False here, True for the final
    # save) is defined by model_management.save_model — confirm there.
    best_model_state = model_management.save_model(model, checkpoint_path, False)

    

100%|██████████| 156/156 [03:55<00:00,  1.51s/it]
100%|██████████| 100/100 [00:14<00:00,  6.73it/s]


epoch: 0, train_loss: 0.11502320831454842, val_loss: 1.5270304679870605


100%|██████████| 156/156 [03:53<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.67it/s]


epoch: 1, train_loss: 0.0989519625591486, val_loss: 1.6397271156311035


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:15<00:00,  6.53it/s]


epoch: 2, train_loss: 0.08908425853456989, val_loss: 1.2243667840957642


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.74it/s]


epoch: 3, train_loss: 0.08306730773416442, val_loss: 1.1008654832839966


100%|██████████| 156/156 [03:53<00:00,  1.50s/it]
100%|██████████| 100/100 [00:15<00:00,  6.66it/s]


epoch: 4, train_loss: 0.07837394417622802, val_loss: 0.8894386291503906


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:15<00:00,  6.65it/s]


epoch: 5, train_loss: 0.07365298309417945, val_loss: 0.8096705675125122


100%|██████████| 156/156 [03:53<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.67it/s]


epoch: 6, train_loss: 0.07016087249937111, val_loss: 0.8667360544204712


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.77it/s]


epoch: 7, train_loss: 0.06748566401893842, val_loss: 0.9088382720947266


100%|██████████| 156/156 [03:53<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.71it/s]


epoch: 8, train_loss: 0.06567549100855205, val_loss: 0.9360012412071228


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.68it/s]


epoch: 9, train_loss: 0.0631255609996814, val_loss: 0.7916746139526367


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.67it/s]


epoch: 10, train_loss: 0.06096053064921042, val_loss: 0.9098976850509644


100%|██████████| 156/156 [03:53<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.78it/s]


epoch: 11, train_loss: 0.058468086918354416, val_loss: 0.8926560282707214


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.70it/s]


epoch: 12, train_loss: 0.05606032196961122, val_loss: 0.834126353263855


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.68it/s]


epoch: 13, train_loss: 0.05397621474700062, val_loss: 0.7443171739578247


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:15<00:00,  6.60it/s]


epoch: 14, train_loss: 0.05138131786128857, val_loss: 0.7469706535339355


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.77it/s]


epoch: 15, train_loss: 0.04922072779396198, val_loss: 0.7087028622627258


100%|██████████| 156/156 [03:54<00:00,  1.51s/it]
100%|██████████| 100/100 [00:15<00:00,  6.64it/s]


epoch: 16, train_loss: 0.0472483929506568, val_loss: 0.6903791427612305


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.74it/s]


epoch: 17, train_loss: 0.045304106968158515, val_loss: 0.6708046793937683


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.72it/s]


epoch: 18, train_loss: 0.04371555308484993, val_loss: 0.6468225121498108


100%|██████████| 156/156 [03:55<00:00,  1.51s/it]
100%|██████████| 100/100 [00:15<00:00,  6.66it/s]


epoch: 19, train_loss: 0.04257354274117287, val_loss: 0.5980347394943237


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:15<00:00,  6.64it/s]


epoch: 20, train_loss: 0.04123706954330469, val_loss: 0.5794035196304321


100%|██████████| 156/156 [03:54<00:00,  1.51s/it]
100%|██████████| 100/100 [00:15<00:00,  6.62it/s]


epoch: 21, train_loss: 0.039733030338142045, val_loss: 0.5441393256187439


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:14<00:00,  6.68it/s]


epoch: 22, train_loss: 0.03853542662995667, val_loss: 0.5516947507858276


100%|██████████| 156/156 [03:55<00:00,  1.51s/it]
100%|██████████| 100/100 [00:15<00:00,  6.64it/s]


epoch: 23, train_loss: 0.03720968328386474, val_loss: 0.5368618369102478


100%|██████████| 156/156 [03:55<00:00,  1.51s/it]
100%|██████████| 100/100 [00:14<00:00,  6.67it/s]


epoch: 24, train_loss: 0.03649799458819958, val_loss: 0.5743924379348755


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:15<00:00,  6.66it/s]


epoch: 25, train_loss: 0.036728967480068694, val_loss: 0.5392946004867554


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:15<00:00,  6.60it/s]


epoch: 26, train_loss: 0.035600906302189005, val_loss: 0.5876485705375671


100%|██████████| 156/156 [03:54<00:00,  1.50s/it]
100%|██████████| 100/100 [00:15<00:00,  6.67it/s]


epoch: 27, train_loss: 0.03527972960147078, val_loss: 0.5822862982749939


100%|██████████| 156/156 [03:54<00:00,  1.51s/it]
100%|██████████| 100/100 [00:15<00:00,  6.67it/s]


epoch: 28, train_loss: 0.03334402310293774, val_loss: 0.5350676774978638


100%|██████████| 156/156 [03:56<00:00,  1.51s/it]
100%|██████████| 100/100 [00:15<00:00,  6.64it/s]


epoch: 29, train_loss: 0.03227471717974618, val_loss: 0.540036141872406


In [6]:
# Save the final (last-epoch) model. This is not necessarily the best-validation model,
# so keep its state under its own name instead of overwriting `best_model_state`, which
# the training loop assigns only when the validation loss improves.
now = datetime.datetime.now()
checkpoint_path = join(
    BASE_PATH,
    config.SEGMENTATION_MODEL_CHECKPOINT_PATH,
    f'chechpoint_{now.strftime("%m_%d_%Y_%H_%M_%S")}.pth'
)
final_model_state = model_management.save_model(model, checkpoint_path, True)

In [7]:
# Training and checkpointing are done: close the active wandb run
wandb.finish()

0,1
train_loss,█▇▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
val_loss,▇█▅▅▃▃▃▃▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,0.03227
val_loss,0.54004
