In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import segmentation_models_pytorch as smp

from tqdm import tqdm
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from dice_loss import dice_coeff
from dataset import *

## Hyperparameters

In [None]:
# Training hyperparameters — every downstream cell reads these names.
BATCH_SIZE = 32   # samples per optimizer step (training loader only)
LR = 1e-4         # Adam learning rate
EPOCHS = 30       # full passes over the training set

# Seed the torch and NumPy RNGs (torch drives loader shuffling, the v2 random
# flips and weight init) so Restart & Run All reproduces the same curves.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

## Training setup and loop

In [None]:
# Data transforms: random flips augment the training split only; valid/test are untouched.
train_transform = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
])

# Datasets / dataloaders
train = CropDataset('./crop_set/train.hdf5', transform=train_transform)
valid = CropDataset('./crop_set/valid.hdf5')
test = CropDataset('./crop_set/test.hdf5')

train_loader = DataLoader(dataset=train, batch_size=BATCH_SIZE, shuffle=True)
# Valid/test are each evaluated as one full-size batch. Sample order is irrelevant
# for evaluation, so they are not shuffled.
# NOTE(review): a whole split in one batch may not fit in GPU memory as the
# dataset grows — lower batch_size here if that happens.
valid_loader = DataLoader(dataset=valid, batch_size=len(valid), shuffle=False)
test_loader = DataLoader(dataset=test, batch_size=len(test), shuffle=False)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model: U-Net with a ResNet-34 encoder (no pretrained weights), 5 input channels,
# 3 output classes. activation='softmax' makes the model output softmax
# probabilities rather than raw logits.
model = smp.Unet(
    encoder_name='resnet34',
    encoder_weights=None,
    in_channels=5,
    classes=3,
    activation='softmax',
).to(device)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LR)

# Loss. nn.CrossEntropyLoss applies log_softmax itself and expects raw logits.
# Feeding it the model's softmax output would apply softmax twice and squash the
# gradients. Because softmax(log p) == p, passing log-probabilities yields exactly
# the cross-entropy of the underlying logits, while the model output stays a
# probability map that the dice metric can consume.
_cross_entropy = nn.CrossEntropyLoss()
LOG_EPS = 1e-7  # clamp floor so log(0) never produces -inf / NaN


def loss_fn(probs, target):
    """Cross-entropy for a model whose output is already softmax probabilities.

    Args:
        probs: model output; class probabilities along dim 1.
        target: anything nn.CrossEntropyLoss accepts — class indices, or
            per-class probabilities / one-hot masks shaped like `probs`.

    Returns:
        Mean cross-entropy as a scalar tensor.
    """
    return _cross_entropy(torch.log(probs.clamp_min(LOG_EPS)), target)

In [None]:
def pred_step(samples, model):
    """Run a forward pass on L2-normalized inputs.

    Args:
        samples: input batch; presumably (N, C, H, W) float — TODO confirm
            against CropDataset.
        model: nn.Module (or any callable) mapping a batch to predictions.

    Returns:
        The model's predictions for the normalized batch.
    """
    # F.normalize with its defaults (p=2, dim=1) rescales each position's
    # vector along dim 1 to unit L2 norm. It is not a mean/std standardization.
    normalized = nn.functional.normalize(samples)
    # Call the module itself rather than .forward() so registered hooks still run.
    return model(normalized)

def eval_step(preds, labels, loss_fn, compute_dice=False):
    """Score a batch of predictions against ground truth.

    Args:
        preds: model output for the batch.
        labels: ground truth for the same batch (tensor or array-like). It is
            cast to float32 on preds' device, matching the float casts used by
            the validation/test cells. Presumably per-class masks shaped like
            `preds` — verify against CropDataset.
        loss_fn: callable (preds, labels) -> scalar loss tensor.
        compute_dice: if True, also compute the dice coefficient.

    Returns:
        loss, or (loss, dice) when compute_dice is True.
    """
    # torch.Tensor(x) is the legacy constructor. Its handling of an existing
    # tensor (non-float dtype, CUDA device) is version-dependent: it may raise,
    # or alias without casting. as_tensor with an explicit dtype/device does the
    # intended conversion everywhere.
    labels = torch.as_tensor(labels, dtype=torch.float32, device=preds.device)
    loss = loss_fn(preds, labels)

    if compute_dice:
        # Only `loss` is backpropagated; dice is a metric, so don't record an
        # autograd graph for it.
        with torch.no_grad():
            dice = dice_coeff(preds, labels)
        return loss, dice

    return loss


In [None]:
# Per-epoch history (one float per epoch) for the loss/dice curves below.
train_loss = []
train_dice = []
valid_loss = []
valid_dice = []
bar = tqdm(range(EPOCHS), position=0)

for epoch in bar:
    # --- Training pass ---
    model.train()
    train_loss_e = []
    train_dice_e = []

    for samples, labels in train_loader:
        # Device-agnostic: `device` may be the CPU, so avoid torch.cuda.* types / .cuda().
        samples = samples.to(device, dtype=torch.float32)
        labels = labels.to(device)

        optimizer.zero_grad()
        preds = pred_step(samples, model)
        loss, dice = eval_step(preds, labels, loss_fn, True)
        loss.backward()
        optimizer.step()

        train_loss_e.append(loss.item())
        train_dice_e.append(dice.item())

    # --- Validation pass: no gradients; aggregate over all batches ---
    model.eval()
    valid_loss_sum, valid_dice_sum, n_valid = 0.0, 0.0, 0
    with torch.no_grad():
        for samples, labels in valid_loader:
            samples = samples.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.float32)

            preds = pred_step(samples, model)
            loss, dice = eval_step(preds, labels, loss_fn, True)

            # Weight by batch size — assumes loss and dice are batch means.
            batch_n = samples.shape[0]
            valid_loss_sum += loss.item() * batch_n
            valid_dice_sum += dice.item() * batch_n
            n_valid += batch_n

    epoch_train_loss = float(np.mean(train_loss_e))
    epoch_train_dice = float(np.mean(train_dice_e))
    epoch_valid_loss = valid_loss_sum / n_valid
    epoch_valid_dice = valid_dice_sum / n_valid

    # One point per epoch for every curve, so train and valid share an x-axis.
    train_loss.append(epoch_train_loss)
    train_dice.append(epoch_train_dice)
    valid_loss.append(epoch_valid_loss)
    valid_dice.append(epoch_valid_dice)

    # tqdm.write prints above the progress bar instead of corrupting it like print().
    tqdm.write('train loss: {}, valid loss: {}, train dice: {}, valid dice: {}'.format(
        epoch_train_loss, epoch_valid_loss, epoch_train_dice, epoch_valid_dice))

In [None]:
# Plot loss (left) and dice (right) curves, train vs. validation.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))

ax[0].plot(train_loss, label='train loss')
ax[0].plot(valid_loss, label='valid loss')
ax[0].set(title='Loss', xlabel='epoch', ylabel='loss')
ax[0].legend()  # without this the `label=` arguments are never shown

ax[1].plot(train_dice, label='train dice')
ax[1].plot(valid_dice, label='valid dice')
ax[1].set(title='Dice coefficient', xlabel='epoch', ylabel='dice')
ax[1].legend()

fig.tight_layout()
plt.show()

## Test-set evaluation

In [None]:
# Evaluate the trained model on the held-out test split, aggregating over all batches.
# Set eval mode explicitly rather than relying on the training cell's last state.
model.eval()
test_loss_sum, test_dice_sum, n_test = 0.0, 0.0, 0
with torch.no_grad():
    for samples, labels in test_loader:
        # Device-agnostic: `device` may be the CPU, so avoid torch.cuda.* types / .cuda().
        samples = samples.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.float32)

        preds = pred_step(samples, model)
        loss, dice = eval_step(preds, labels, loss_fn, True)

        # Weight by batch size — assumes loss and dice are batch means.
        batch_n = samples.shape[0]
        test_loss_sum += loss.item() * batch_n
        test_dice_sum += dice.item() * batch_n
        n_test += batch_n

test_loss = test_loss_sum / n_test
test_dice = test_dice_sum / n_test
print(f'test loss: {test_loss}')
print(f'test dice: {test_dice}')