In [None]:
import sys
sys.path.append("..")

## Load Dataset

In [None]:
from datasets import load_dataset

# Fetch the skin-lesion segmentation dataset from the Hugging Face Hub.
# NOTE(review): `dataset` is not referenced by any later cell in this notebook —
# the data loaders are built by `get_loaders` from src.utils — so this cell may be
# redundant; confirm before removing it.
dataset = load_dataset(path="jtz18/skin-lesion")

## Seeding

In [None]:
import torch
import numpy as np
import random
import os

def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds, in order: Python's ``random``, NumPy's legacy global RNG, torch's
    CPU RNG and the RNGs of all CUDA devices. Also exports ``PYTHONHASHSEED``
    and switches cuDNN to deterministic, non-benchmarking mode.

    Note: assigning ``PYTHONHASHSEED`` at runtime only affects subprocesses
    started afterwards; the running interpreter's hash seed was fixed at startup.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Give up cuDNN autotuning in exchange for reproducible convolution results.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


set_seed(42)

## Init Model

In [None]:
from src.segFormer import Segformer
from src.maluNet import MALUNet
from src.unet import UNET

# Pick the fastest available backend: CUDA first, then Apple MPS, else the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

model = Segformer().to(device)

## Init Wandb

In [None]:
# Hyperparameters etc.
import wandb

LEARNING_RATE = 1e-5
DEVICE = device
BATCH_SIZE = 16
NUM_EPOCHS = 50
NUM_WORKERS = 0
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256

# Page-locked host memory only speeds up host -> CUDA copies; on MPS/CPU the
# DataLoader just warns that pinning is unused, so enable it for CUDA only.
PIN_MEMORY = DEVICE.type == "cuda"

LOAD_MODEL = False
# Checkpoint to resume from when LOAD_MODEL is True,
# e.g. "checkpoints/checkpoint_9.pth.tar"; ignored otherwise.
CHECKPOINT_FILENAME = ""
CLASS = "task1"
MODEL = type(model).__name__

# Initialize a new run. The device is logged as a plain string ("cuda"/"mps"/"cpu")
# so the config value is JSON-friendly rather than a torch.device object.
run = wandb.init(project="50.039-DL", config={
    "learning_rate": LEARNING_RATE,
    "device": str(DEVICE),
    "batch_size": BATCH_SIZE,
    "num_epochs": NUM_EPOCHS,
    "num_workers": NUM_WORKERS,
    "image_height": IMAGE_HEIGHT,
    "image_width": IMAGE_WIDTH,
    "pin_memory": PIN_MEMORY,
    "load_model": LOAD_MODEL,
    "checkpoint_filename": CHECKPOINT_FILENAME,
    "class": CLASS,
    "Model": MODEL,
})

## Define Training Function

In [None]:
from tqdm import tqdm
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch over ``loader``, using mixed precision on CUDA.

    Args:
        loader: iterable yielding ``(data, targets)`` batches.
        model: network being trained; switched to train mode before the epoch.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(predictions, targets)``.
        scaler: GradScaler used to scale the loss for the backward pass.

    Returns:
        The mean training loss over the epoch as a float, or ``nan`` if
        ``loader`` yielded no batches.
    """
    # The evaluation helpers called between epochs (check_accuracy,
    # save_predictions_as_imgs from src.utils) may leave the model in eval
    # mode — not visible here — so set train mode explicitly.
    model.train()
    # torch.cuda.amp.autocast() is deprecated and merely warns + disables itself
    # off-CUDA; gate autocast explicitly instead.
    use_amp = DEVICE.type == "cuda"

    loop = tqdm(loader)
    total_loss = 0.0
    num_batches = 0

    for data, targets in loop:
        data = data.to(device=DEVICE)
        # assumes targets are (N, H, W) masks; unsqueeze adds the channel dim
        # expected by loss_fn — TODO confirm against the dataset.
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.autocast(device_type="cuda", enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Convert once to a detached Python float and reuse it for logging,
        # instead of handing wandb a graph-attached tensor.
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        wandb.log({"Training Loss": loss_value})

        # update tqdm loop
        loop.set_postfix(loss=loss_value)

    return total_loss / num_batches if num_batches else float("nan")

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch import nn
from src.utils import *


def _resize_then_scale(*augmentations):
    """Build a pipeline: resize, optional augmentations, scale to [0, 1], to tensor."""
    return A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            *augmentations,
            # mean 0 / std 1 with max_pixel_value 255 just divides pixels by 255.
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )


# Training adds geometric augmentation; validation only resizes and scales.
train_transform = _resize_then_scale(
    A.Rotate(limit=35, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
)
val_transforms = _resize_then_scale()

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loader, val_loader, test_loader = get_loaders(
    BATCH_SIZE, train_transform, val_transforms, NUM_WORKERS, PIN_MEMORY,
)

## Training Loop

In [None]:
start_epoch = 0
if LOAD_MODEL:
    # map_location lets a checkpoint written on a CUDA machine be restored on
    # MPS/CPU (and vice versa) instead of failing on device deserialization.
    checkpoint = torch.load(CHECKPOINT_FILENAME, map_location=DEVICE)
    model, optimizer, start_epoch = load_checkpoint(checkpoint, model, optimizer)
    # NOTE(review): the saved "epoch" is the last *completed* epoch (see the
    # checkpoint dict in the loop below); unless load_checkpoint already adds 1,
    # resuming re-trains that epoch — confirm in src.utils.

# Baseline validation metrics before any (further) training.
check_accuracy(val_loader, model, device=DEVICE)
# Loss scaling only matters for CUDA AMP; disabling the scaler elsewhere turns
# scale()/step()/update() into plain pass-throughs without the
# "CUDA is not available" warning that GradScaler() emits on MPS/CPU.
scaler = torch.cuda.amp.GradScaler(enabled=DEVICE.type == "cuda")

for epoch in range(start_epoch, NUM_EPOCHS):
    print(f"epoch: {epoch}")
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    save_checkpoint(checkpoint, filename=f"checkpoint_{epoch}.pth.tar")

    # check accuracy
    # NOTE(review): the recorded run summary reports "IoU Score" ~ 5.8 next to
    # "Dice Score" ~ 0.84. IoU can never exceed Dice for the same masks, as a
    # fraction or as a percentage, so the IoU aggregation inside check_accuracy
    # looks wrong — check it.
    check_accuracy(val_loader, model, device=DEVICE)

    # print some examples to a folder
    save_predictions_as_imgs(
        val_loader, model, folder=f"saved_images/{epoch}", device=DEVICE
    )

In [66]:
# End the wandb run returned by `wandb.init(...)` in the "Init Wandb" cell.
run.finish()

0,1
Dice Score,▁▅▇▇▇▇▇███▇██▇██████████████████████████
IoU Score,▁▃▆▇▆▇▇▇▇▇▇█▇▇█▇▇▇████▇▇█▇▇██▇██████████
Pixel Accuracy,▁▃▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇█▇██▇▇██▇██▇█████████▇
Training Loss,█▅▂▄▂▃▃▂▅▃▂▃▂▂▂▂▂▂▂▂▂▂▂▁▅▂▂▂▂▁▂▃▂▂▄▂▂▂▃▂

0,1
Dice Score,0.83739
IoU Score,5.80156
Pixel Accuracy,91.52986
Training Loss,0.15081
