In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torchvision import tv_tensors
import torch
from torch import nn
from torchvision.transforms import v2

import project.utils
%reload project
from project.config import PROJECT_ROOT

In [None]:
from project.data.cvc_clinic import ClinicDB

# Training-time preprocessing / augmentation pipeline (torchvision transforms v2).
# Previously tried and currently disabled: RandomZoomOut, RandomIoUCrop,
# SanitizeBoundingBoxes.
# NOTE(review): both random transforms are configured with p=1, so in
# particular every training image is mirrored -- confirm this is intended and
# not a leftover from visually checking the augmentations.
transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.RandomPhotometricDistort(p=1),
        v2.RandomHorizontalFlip(p=1),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Bare expression so the notebook renders the pipeline's repr.
transforms

In [None]:
# Training split with the augmentation pipeline attached. The trailing tuple
# displays the shape of the backing data and of the first element of sample 0.
dataset = ClinicDB(
    PROJECT_ROOT / "data",
    split="train",
    transforms=transforms,
)
(dataset.data.shape, dataset[0][0].shape)

In [None]:
# Loader used by the training loop: asks the dataset for shuffled batches of 5.
train_loader = dataset.get_loader(shuffle=True, batch_size=5)
train_loader

In [None]:
# Sanity check on the raw data: take one sample from the training split with
# no transforms applied, report its shape / types / dtype, and render the
# image followed by its mask.
base_loader = ClinicDB(
    PROJECT_ROOT / "data",
    split="train",
    transforms=None,
).get_loader(shuffle=True, batch_size=1)

for i, (data, target) in enumerate(base_loader):
    print(data.shape)
    print(type(data), type(target))
    print(data.dtype)

    # Move axis 0 to the end before plotting (the tensors are presumably
    # channel-first, while imshow expects channels last).
    image_hwc = np.transpose(data[0].numpy(), (1, 2, 0))
    plt.imshow(image_hwc)
    plt.show()

    mask_hwc = np.transpose(target["masks"][0].numpy(), (1, 2, 0))
    plt.imshow(mask_hwc, cmap="gray")
    plt.show()

    # Only the first batch is needed for this check.
    break

In [None]:
# Intersection over Union (IoU) loss
class IoULoss(nn.Module):
    """Soft Intersection-over-Union (Jaccard) loss computed over a whole batch.

    The raw model outputs are passed through a sigmoid, both tensors are
    flattened, and a single IoU value is computed across every element of the
    batch. The returned loss is ``(1 - IoU) * 100``.

    Args:
        weight: Accepted for signature compatibility; not used.
        size_average: Accepted for signature compatibility; not used.
        smooth: Constant added to both numerator and denominator so the ratio
            stays defined when prediction and target are both empty.
    """

    def __init__(self, weight=None, size_average=True, smooth=1e-8):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return the scaled soft-IoU loss as a scalar tensor.

        Args:
            inputs: Raw (pre-sigmoid) predictions.
            targets: Ground-truth mask with the same number of elements as
                ``inputs``.

        Returns:
            Zero-dimensional tensor holding ``(1 - IoU) * 100``.
        """
        # Sigmoid activation turns logits into per-element probabilities.
        probs = torch.sigmoid(inputs)

        # Flatten with reshape rather than view: view raises on
        # non-contiguous tensors (e.g. after a transpose / permute).
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        # Soft intersection and union (|A| + |B| - |A ∩ B|).
        intersection = (probs * targets).sum()
        total = (probs + targets).sum()
        union = total - intersection

        iou = (intersection + self.smooth) / (union + self.smooth)

        return (1 - iou) * 100  # Scale to 0-100 to make it more interpretable

In [None]:
%reload project
from project.models.unet import UNet
from project.models.vnet.vnet import BinaryDiceLoss

# Drop cached, unused GPU memory left over from earlier runs of this cell
# before allocating a fresh model.
torch.cuda.empty_cache()

# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing on the hard-coded "cuda" device string.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Three input channels, one output channel.
model = UNet(n_channels=3, n_classes=1).to(device)

# Alternative loss defined earlier in this notebook: criterion = IoULoss()
criterion = BinaryDiceLoss()

# AdamW with its default hyper-parameters.
optimizer = torch.optim.AdamW(model.parameters())

In [None]:
# Supervised training loop: one optimizer step per mini-batch, with the batch
# loss printed periodically.
NUM_EPOCHS = 100
LOG_EVERY = 20  # print the loss on every LOG_EVERY-th batch of an epoch

# Send batches to whatever device the model's weights live on, rather than
# hard-coding "cuda".
train_device = next(model.parameters()).device

# `epoch` is left as a notebook-level name on purpose: the checkpoint cell
# stores its final value.
for epoch in range(NUM_EPOCHS):
    model.train()
    for i, (data, target) in enumerate(train_loader):
        mask = target["masks"]
        data, mask = data.to(train_device), mask.to(train_device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()

        if i % LOG_EVERY == 0:
            print(f"Epoch {epoch}, Batch {i}, Loss: {loss.item()}")

In [None]:
# Qualitative evaluation: run the trained model on a few test samples and show
# input, thresholded prediction and ground-truth mask side by side.
model.eval()

# Test-time pipeline: image conversion and float32 scaling only. No random
# augmentation, and no Normalize here so the un-normalized image stays
# available for display.
test_transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

# Same statistics as the Normalize step of the training pipeline; applied to
# the batch right before the forward pass.
normalize = v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

test_set = ClinicDB(PROJECT_ROOT / "data", split="test", transforms=test_transforms)
test_loader = test_set.get_loader(batch_size=1, shuffle=True)

NUM_PREVIEWS = 6  # number of samples to visualise
THRESHOLD = 0.5   # probability above which a pixel is shown as foreground

# Follow the model's device instead of hard-coding "cuda".
eval_device = next(model.parameters()).device

# Inference only: disable autograd so no graph is built for the forward passes.
with torch.no_grad():
    for i, (data, target) in enumerate(test_loader):
        data, target = data.to(eval_device), target["masks"].to(eval_device)
        output = model(normalize(data))

        # Sigmoid activation, then binarise the probabilities.
        output = torch.sigmoid(output)
        output = (output > THRESHOLD).float()

        fig, (ax_input, ax_output, ax_target) = plt.subplots(1, 3, figsize=(15, 5))

        # Move axis 0 to the end for imshow (tensors are presumably channel-first).
        ax_input.imshow(data[0].cpu().numpy().transpose(1, 2, 0))
        ax_input.set_title("Input")
        ax_input.axis("off")

        ax_output.imshow(output[0].cpu().numpy().transpose(1, 2, 0).squeeze(), cmap="gray")
        ax_output.set_title("Output")
        ax_output.axis("off")

        ax_target.imshow(target[0].cpu().numpy().transpose(1, 2, 0).squeeze(), cmap="gray")
        ax_target.set_title("Target")
        ax_target.axis("off")

        plt.show()

        if i + 1 >= NUM_PREVIEWS:
            break


In [None]:
from project.config import PROJECT_ROOT

# Persist a checkpoint holding the model weights, the optimizer state and the
# index of the last completed epoch.
save_path = PROJECT_ROOT / "models" / "clinicdb_unet_distort.pth"

# torch.save does not create missing directories; make sure the target folder
# exists so a finished training run is not lost to a missing path.
# (Assumes PROJECT_ROOT is a pathlib.Path, as its use with "/" suggests.)
save_path.parent.mkdir(parents=True, exist_ok=True)

torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    },
    save_path,
)