In [None]:
import os
import math

import matplotlib.pyplot as plt
import torchvision
import numpy as np
import torch
from HemaDataset import HemaDataset
from HemaModel import HemaModel
from DINOv2ForRadiology.dinov2.eval.segmentation.utils import UNetDecoder
from DINOv2ForRadiology.dinov2.data.transforms import make_segmentation_train_transforms, make_segmentation_eval_transforms

from monai.losses.dice import DiceLoss, DiceCELoss
from monai.metrics import DiceMetric

In [None]:
# Resolution passed as `resize_size` to both the train and the eval transform factories.
image_size = 448
# Dataset directory: <cwd>/data/Dataset011_Cell
root = os.sep.join([os.getcwd(), "data", "Dataset011_Cell"])

# Separate (image transform, target transform) pairs for training and for evaluation.
train_image_transform, train_target_transform = make_segmentation_train_transforms(resize_size=image_size)
eval_image_transform, eval_target_transform = make_segmentation_eval_transforms(resize_size=image_size)

# Both splits share the root and segmentation mode; only the split name and transforms differ.
dataset_kwargs = dict(root=root, seg_entire_cell=True)
train_dataset = HemaDataset(
    split="train",
    image_transform=train_image_transform,
    target_transform=train_target_transform,
    **dataset_kwargs,
)
val_dataset = HemaDataset(
    split="val",
    image_transform=eval_image_transform,
    target_transform=eval_target_transform,
    **dataset_kwargs,
)

In [None]:
# Training schedule. scheduler.step() is called once per batch in the training loop,
# so the scheduler horizon (max_iter) is measured in iterations, not epochs.
batch_size = 4
epochs = 100
epoch_length = math.ceil(len(train_dataset) / batch_size)  # batches per epoch; the last partial batch is kept (drop_last=False)
max_iter = epoch_length * epochs

# Seed torch's global RNG (decoder initialisation, loader shuffling) for reproducible runs.
SEED = 0
torch.manual_seed(SEED)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

# DINOv2 ViT-S/14 backbone loaded from torch.hub, plus a freshly constructed decoder with 3 output channels.
encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').cuda()
# Use the shared `image_size` config value rather than repeating the literal, so the
# decoder cannot silently drift out of sync with the dataset transforms.
decoder = UNetDecoder(in_channels=encoder.embed_dim, out_channels=3, image_size=image_size, resize_image=True).cuda()
model = HemaModel(encoder=encoder, decoder=decoder)

if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = torch.nn.DataParallel(model)

# Per-group learning rates: small for the hub-loaded backbone, larger for the new decoder.
optim_param_groups = [
    {"params": encoder.parameters(), "lr": 1e-5},
    {"params": decoder.parameters(), "lr": 1e-3},
]

# sigmoid=True: the loss applies a per-channel sigmoid to the raw model outputs (logits).
loss_fn = DiceCELoss(sigmoid=True)
optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0)
# Cosine decay from each group's initial LR down to 0 over max_iter scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)

In [None]:
def evaluate_dice(model, loader):
    """Return the mean Dice score of `model` over `loader`, excluding channel 0.

    Logits are binarized with sigmoid(output) > 0.5, matching the sigmoid
    activation used by the training loss (DiceCELoss(sigmoid=True)); DiceMetric
    expects binarized predictions, not raw logits. The model is put in eval mode
    with gradients disabled, and its previous train/eval mode is restored.

    Returns a 0-d tensor: per-image Dice averaged over classes, then over images.
    """
    # reduction="none" + manual averaging below: works around an issue the author observed with reduction="mean".
    eval_metric = DiceMetric(include_background=False, reduction="none")
    was_training = model.training
    model.eval()
    with torch.no_grad():
        for image, mask in loader:
            image = image.cuda(non_blocking=True)
            mask = mask.cuda(non_blocking=True)

            output = model(image)
            prediction = (torch.sigmoid(output) > 0.5).float()

            eval_metric(y_pred=prediction, y=mask)
    model.train(was_training)

    # aggregate() gives one row per image and one column per class: average classes first, then images.
    # NOTE(review): MONAI may report NaN for a class that is empty in a ground-truth mask;
    # a single NaN makes this mean NaN — consider nanmean if that happens.
    return eval_metric.aggregate().mean(axis=1).mean(axis=0)


model.train()
for epoch in range(epochs):
    for image, mask in train_loader:

        image = image.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)

        output = model(image)

        loss = loss_fn(output, mask)

        # compute the gradients
        optimizer.zero_grad()
        loss.backward()

        # step; the scheduler is advanced per iteration (its horizon is max_iter)
        optimizer.step()
        scheduler.step()

    # Validate every 10 epochs. This must sit inside the epoch loop; at top level it
    # would run once after training with epoch == epochs - 1 and never trigger.
    if epoch % 10 == 0:
        dice = evaluate_dice(model, val_loader)
        print(f"epoch: {epoch}, loss for last iteration in epoch {loss.item():.4f}, dice: {dice.item():.4f}")