In [10]:
import sys
# Add the parent directory to the import search path; presumably this is what
# lets the later `from src...` imports resolve.
# NOTE(review): ".." is relative to the kernel's working directory, so this
# only works when the notebook is run from its own folder — confirm.
sys.path.append("..")

## Load Dataset

In [11]:
from datasets import load_dataset

# Load the dataset identified by "jtz18/skin-lesion".
# NOTE(review): the recorded output of this cell warns that passing
# `trust_remote_code=True` will become mandatory for this dataset in the next
# major `datasets` release. That flag presumably allows the dataset's own
# loading script to run, so review the script before opting in.
dataset = load_dataset("jtz18/skin-lesion")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


## Seeding

In [12]:
import torch
import numpy as np
import random
import os

def set_seed(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's `random`, NumPy's legacy global generator, and PyTorch
    (CPU and all CUDA devices), exports PYTHONHASHSEED, and switches cuDNN to
    deterministic mode with auto-tuning (benchmark) turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed(42)

## Init Model

In [13]:
from src.segFormer2 import SegFormer
from src.maluNet import MALUNet
from src.unet import UNET
from src.unetr import UNETR
# from src.unetr import UNETR

# Pick the compute device: CUDA when available, otherwise Apple MPS,
# otherwise fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Alternative model configurations, left commented out for reference:
#
# model = UNETR(img_shape=(256, 256), embed_dim=256, num_heads= 8, dropout=0.3).to(device)
#
# model = SegFormer(
#     in_channels=3,
#     widths=[64, 128, 256, 512],
#     depths=[3, 4, 6, 3],
#     all_num_heads=[1, 2, 4, 8],
#     patch_sizes=[7, 3, 3, 3],
#     overlap_sizes=[4, 2, 2, 2],
#     reduction_ratios=[8, 4, 2, 1],
#     mlp_expansions=[4, 4, 4, 4],
#     decoder_channels=256,
#     scale_factors=[8, 4, 2, 1],
#     num_classes=1,
#     drop_prob=0.3,
# ).to(device)

# Model actually trained by this notebook, moved onto the selected device.
model = UNET(kernel_size=7, dropout_rate=0).to(device)

## Init Wandb

In [14]:
# Hyperparameters etc.
import wandb

# Optimisation settings (a learning rate of 1e-5 was also tried).
LEARNING_RATE = 0.0001
BATCH_SIZE = 16
NUM_EPOCHS = 10

# Data loading settings.
DEVICE = device
NUM_WORKERS = 0
PIN_MEMORY = True
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256

# Checkpoint resume settings. CHECKPOINT_FILENAME is only read when
# LOAD_MODEL is True; leave it as an empty string otherwise,
# e.g. "checkpoints/checkpoint_9.pth.tar".
LOAD_MODEL = False
CHECKPOINT_FILENAME = ""

# Run metadata recorded alongside the hyperparameters.
CLASS = "task1"
MODEL = model.__class__.__name__

# Initialize a new run. The device is logged as a plain string so the stored
# config does not depend on how wandb serialises a `torch.device` object.
run = wandb.init(project="unet-skin-lesion", config={
    "learning_rate": LEARNING_RATE,
    "device": str(DEVICE),
    "batch_size": BATCH_SIZE,
    "num_epochs": NUM_EPOCHS,
    "num_workers": NUM_WORKERS,
    "image_height": IMAGE_HEIGHT,
    "image_width": IMAGE_WIDTH,
    "pin_memory": PIN_MEMORY,
    "load_model": LOAD_MODEL,
    "checkpoint_filename": CHECKPOINT_FILENAME,
    "class": CLASS,
    "Model": MODEL,
})

## Define Training Function

In [15]:
from tqdm import tqdm
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Train `model` for one pass over `loader`.

    Each batch is moved to the notebook-level DEVICE, run through the model
    (under CUDA autocast when DEVICE is a CUDA device), and used for one
    optimizer step driven through `scaler`. The per-batch loss is logged to
    wandb under "Training Loss" and shown on the tqdm progress bar.

    Args:
        loader: Iterable yielding `(data, targets)` batches.
        model: Network to train; called as `model(data)`.
        optimizer: Optimizer that updates `model`'s parameters.
        loss_fn: Callable invoked as `loss_fn(predictions, targets)`.
        scaler: Gradient scaler exposing `scale`, `step` and `update`.

    Returns:
        The mean batch loss as a float, or NaN if `loader` yielded no batches.
    """
    # Mixed precision via CUDA autocast only applies on CUDA devices; disable
    # it explicitly elsewhere instead of relying on a runtime warning.
    use_amp = torch.device(DEVICE).type == "cuda"

    # Validation runs between epochs and its implementation is not visible
    # here, so make sure the model is in training mode before updating it.
    model.train()

    loop = tqdm(loader)
    loss_sum = 0.0
    batch_count = 0

    for data, targets in loop:
        data = data.to(device=DEVICE)
        # unsqueeze(1) inserts a dimension at index 1 — presumably the channel
        # axis, so targets line up with the model output; confirm.
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.autocast(device_type="cuda", enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the loss once as a Python float and reuse it for logging,
        # the progress bar, and the running mean.
        loss_value = loss.item()
        wandb.log({"Training Loss": loss_value})
        loop.set_postfix(loss=loss_value)

        loss_sum += loss_value
        batch_count += 1

    return loss_sum / batch_count if batch_count else float("nan")

## Augmentations

In [16]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch import nn
from src.utils.utils import *


def _resize_step():
    """Return a fresh resize transform targeting IMAGE_HEIGHT x IMAGE_WIDTH."""
    return A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)


def _tensor_steps():
    """Return fresh normalisation and tensor-conversion steps.

    Both pipelines below end with the same pair: a Normalize configured with
    zero mean, unit std and max_pixel_value=255.0, followed by ToTensorV2.
    """
    return [
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]


# Training pipeline: resize, random rotation / flips, then normalise + tensor.
train_transform = A.Compose(
    [
        _resize_step(),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        *_tensor_steps(),
    ],
)

# Validation pipeline: no random augmentation, only resize + normalise + tensor.
val_transforms = A.Compose([_resize_step(), *_tensor_steps()])

# Loss on raw model outputs (logits) and the Adam optimizer over all parameters.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Build the three data loaders from the dataset and the two pipelines.
train_loader, val_loader, test_loader = get_loaders(
    dataset, BATCH_SIZE, train_transform, val_transforms, NUM_WORKERS, PIN_MEMORY,
)

## Training Loop

In [17]:
start_epoch = 0
if LOAD_MODEL:
    # map_location remaps the stored tensors onto the device selected for
    # this session, so a checkpoint saved on a GPU can be resumed on CPU/MPS.
    checkpoint = torch.load(CHECKPOINT_FILENAME, map_location=DEVICE)
    model, optimizer, start_epoch = load_checkpoint(checkpoint, model, optimizer)
    # NOTE(review): checkpoints below store the index of the epoch that just
    # finished; confirm that load_checkpoint returns the *next* epoch to run,
    # otherwise resuming repeats that epoch and overwrites its checkpoint.

# Baseline validation metrics before any (further) training.
check_accuracy(val_loader, model, device=DEVICE, loss_fn=loss_fn)

# Gradient scaling is only needed for CUDA mixed precision; create the scaler
# disabled on other devices instead of triggering a runtime warning.
scaler = torch.cuda.amp.GradScaler(enabled=torch.device(DEVICE).type == "cuda")

for epoch in range(start_epoch, NUM_EPOCHS):
    print(f"epoch: {epoch}")
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    save_checkpoint(checkpoint, filename=f"checkpoint_{epoch}.pth.tar")

    # check accuracy
    check_accuracy(val_loader, model, device=DEVICE, loss_fn=loss_fn)

    # print some examples to a folder; make sure the per-epoch folder exists
    # first, since the helper's own handling of missing folders is not
    # visible here.
    image_dir = f"saved_images/{epoch}"
    os.makedirs(image_dir, exist_ok=True)
    save_predictions_as_imgs(
        val_loader, model, folder=image_dir, device=DEVICE
    )

Validation Loss: 0.8379434432302203
Got 4795034/6553600 with acc 73.17
Dice score: 0.0
epoch: 0


100%|██████████| 163/163 [08:00<00:00,  2.95s/it, loss=0.421]


=> Saving checkpoint
Validation Loss: 0.7171069213322231
Got 5853581/6553600 with acc 89.32
Dice score: 0.8041993379592896
epoch: 1


100%|██████████| 163/163 [07:51<00:00,  2.89s/it, loss=0.211]


=> Saving checkpoint
Validation Loss: 0.6899064012936184
Got 5940812/6553600 with acc 90.65
Dice score: 0.8132959008216858
epoch: 2


100%|██████████| 163/163 [07:52<00:00,  2.90s/it, loss=0.2]  


=> Saving checkpoint
Validation Loss: 0.6770203028406415
Got 5989177/6553600 with acc 91.39
Dice score: 0.8247560858726501
epoch: 3


100%|██████████| 163/163 [07:45<00:00,  2.86s/it, loss=0.196]


=> Saving checkpoint
Validation Loss: 0.67695380960192
Got 5984516/6553600 with acc 91.32
Dice score: 0.8277658820152283
epoch: 4


100%|██████████| 163/163 [07:44<00:00,  2.85s/it, loss=0.123]


=> Saving checkpoint
Validation Loss: 0.701914199760982
Got 5865700/6553600 with acc 89.50
Dice score: 0.8217136859893799
epoch: 5


100%|██████████| 163/163 [07:43<00:00,  2.84s/it, loss=0.193]


=> Saving checkpoint
Validation Loss: 0.6823528409004211
Got 5959696/6553600 with acc 90.94
Dice score: 0.8312580585479736
epoch: 6


100%|██████████| 163/163 [07:44<00:00,  2.85s/it, loss=0.218]


=> Saving checkpoint
Validation Loss: 0.6784320729119437
Got 5970009/6553600 with acc 91.10
Dice score: 0.8359078764915466
epoch: 7


100%|██████████| 163/163 [07:42<00:00,  2.84s/it, loss=0.0968]


=> Saving checkpoint
Validation Loss: 0.6584079265594482
Got 6076763/6553600 with acc 92.72
Dice score: 0.8595824241638184
epoch: 8


100%|██████████| 163/163 [07:44<00:00,  2.85s/it, loss=0.274]


=> Saving checkpoint
Validation Loss: 0.6576228567532131
Got 5996387/6553600 with acc 91.50
Dice score: 0.8417826890945435
epoch: 9


100%|██████████| 163/163 [07:43<00:00,  2.84s/it, loss=0.0844]


=> Saving checkpoint
Validation Loss: 0.6406711425100055
Got 6095093/6553600 with acc 93.00
Dice score: 0.8580296039581299


In [18]:
# Mark the wandb run created by `wandb.init` above as finished.
run.finish()

0,1
Dice Score,▁██████████
Pixel Accuracy,▁▇▇▇▇▇▇▇█▇█
Training Loss,█▇▆▇▆▅▄▅▄▄▄▄▄▅▄▂▃▃▃▃▄▂▂▂▂▂▁▂▂▂▃▃▁▄▂▁▂▁▁▁
Validation Loss,█▄▃▂▂▃▂▂▂▂▁

0,1
Dice Score,0.85803
Pixel Accuracy,93.00374
Training Loss,0.08439
Validation Loss,0.64067
