In [None]:
def setup_file_system(in_colab):
    """Return the project's base directory, mounting Google Drive on Colab.

    Args:
        in_colab: Truthy when the notebook is running inside Google Colab.

    Returns:
        The project root as a string: the project folder inside the mounted
        Google Drive when ``in_colab`` is truthy, otherwise the fixed local
        workspace path.
    """
    if not in_colab:
        return "/workspaces/barco_skin_lesion_classification"

    # Imported locally so this function does not depend on another cell's
    # imports having been executed first, and so that google.colab is only
    # required when actually running on Colab.
    from os.path import join

    from google.colab import drive

    # Set the base and mount path
    MOUNT_PATH_DRIVE = '/content/drive'
    BASE_PATH = join(
        MOUNT_PATH_DRIVE,
        "MyDrive/barco_skin_lesion_classification"
    )

    # Mount the google drive
    drive.mount(MOUNT_PATH_DRIVE)

    return BASE_PATH

In [None]:
import sys
from os import chdir
from os.path import join

# Detect Colab by checking whether the google.colab package is already loaded
# in this interpreter; otherwise the notebook is treated as running locally.
IN_COLAB = 'google.colab' in sys.modules

# Resolve the base path of the project (this mounts Google Drive on Colab)
BASE_PATH = setup_file_system(IN_COLAB)

# Change the working directory to the project's "src/" folder
chdir(join(BASE_PATH, "src/"))

# Data

# Setup

In [None]:
# Set the variables to keep track of the best model.
# The loss starts at +infinity (rather than an arbitrary large number) so the
# first epoch with a finite validation loss always counts as an improvement,
# whatever the scale of the loss.
best_validation_loss = float("inf")
# NOTE(review): state_dict() returns references to the live tensors, so this
# placeholder follows the model as it trains until save_model replaces it
# below -- confirm that a detached snapshot is not needed here.
best_model_state = model.state_dict()

for epoch in range(config.SEGMENTATION_EPOCHS):
    # Set the model in training mode
    model.train()

    # Train the model
    total_train_loss_this_epoch = train_segmentation_model(
        model,
        optimizer,
        criteria,
        grad_scaler,
        train_segmentation_dataloader
    )

    # Set the model in evaluation mode
    model.eval()

    # Validate the model
    total_val_loss_this_epoch, sample_image_array = validate_segmentation_model(
        model,
        criteria,
        test_segmentation_dataloader,
        test_segmentation_dataset
    )

    # Convert the image array to a real image object (8-bit grayscale, mode 'L')
    sample_image_array = sample_image_array.cpu()
    sample_image = Image.fromarray(np.uint8(sample_image_array), 'L')

    # Calculate the loss values: epoch totals divided by the dataset sizes
    train_loss_this_epoch = (
        total_train_loss_this_epoch / len(train_segmentation_dataloader.dataset)
    )
    val_loss_this_epoch = (
        total_val_loss_this_epoch / len(test_segmentation_dataloader.dataset)
    )

    # Log the losses and the sample image for this epoch
    wandb.log({
        'train_loss': train_loss_this_epoch,
        'val_loss': val_loss_this_epoch,
        'sample_image': wandb.Image(sample_image)
    })

    print(f'epoch: {epoch}, train_loss: {train_loss_this_epoch}, val_loss: {val_loss_this_epoch}')

    # If this is the best performing model yet, save it
    if val_loss_this_epoch < best_validation_loss:
        # Update the score
        best_validation_loss = val_loss_this_epoch

        # Timestamp the checkpoint so each improvement gets its own file
        now = datetime.datetime.now()

        # Save the model.
        # NOTE(review): the 'chechpoint_' prefix is misspelled; it is kept
        # as-is because changing it would rename the files this loop writes.
        checkpoint_path = join(
            BASE_PATH,
            config.SEGMENTATION_MODEL_CHECKPOINT_PATH,
            f'chechpoint_{now.strftime("%m_%d_%Y_%H_%M_%S")}.pth'
        )
        best_model_state = model_management.save_model(model, checkpoint_path, False)

    

In [None]:
# Mark the Weights & Biases run as finished now that the training loop is done
wandb.finish()