# Experiment Pipeline

In [None]:
import os
from pathlib import Path
from configs import CONFIG_DIR
from figures import FIGURES_DIR

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from hubmap.data import DATA_DIR
from hubmap.dataset import transforms as T
from hubmap.dataset import TrainDataset, ValDataset

from hubmap.experiments.TransResUNet.utils import run
from hubmap.experiments.TransResUNet.utils import DiceBCELoss
from hubmap.experiments.TransResUNet.utils import visualize_detailed_results
from hubmap.experiments.TransResUNet.utils import visualize_detailed_results_overlay

from hubmap.training import LRScheduler
from hubmap.training import EarlyStopping

from hubmap.visualization import visualize_result

from hubmap.models.trans_res_u_net.model import TResUnet

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class ChannelWeightedDiceBCELoss(nn.Module):
    """Sum of a Dice loss and a binary cross-entropy loss, weighted per channel.

    Both terms are reduced over the last two dimensions of the input, scaled
    by a 1-D weight vector, and then averaged.

    Args:
        weight: Per-channel weights. Used when ``weights`` is not given, so
            callers that pass ``weight=...`` get the weighting they asked for.
        size_average: Unused; kept so existing call sites keep working.
        weights: Per-channel weights. Takes precedence over ``weight``.
            When neither is given, every one of four channels gets weight 1.
    """

    def __init__(self, weight=None, size_average=True, weights=None):
        super().__init__()
        if weights is None:
            # Fall back to `weight`, and only then to the all-ones default.
            weights = weight if weight is not None else torch.tensor([1, 1, 1, 1])
        # as_tensor leaves tensors untouched and also accepts plain sequences.
        self.weights = torch.as_tensor(weights)

    def forward(self, inputs, targets, smooth=1):
        """Return the weighted Dice + BCE loss as a scalar tensor.

        Args:
            inputs: Raw logits; a sigmoid is applied here.
                Assumed to be (batch, channels, height, width) - TODO confirm.
            targets: Tensor of the same shape as ``inputs`` with values in
                [0, 1], as required by ``F.binary_cross_entropy``.
            smooth: Constant added to the Dice numerator and denominator so
                the ratio stays defined when both tensors are all zero.
        """
        # One row of weights per sample, on the same device as the logits.
        w = self.weights.to(inputs.device).unsqueeze(0).expand(inputs.size(0), -1)
        probs = torch.sigmoid(inputs)

        # Dice term, computed over the last two dimensions.
        intersection = (probs * targets).sum((-2, -1))
        denominator = probs.sum((-2, -1)) + targets.sum((-2, -1)) + smooth
        dice_loss = 1 - (2.0 * intersection + smooth) / denominator
        dice_loss = (dice_loss * w).mean()

        # BCE term, averaged over the last two dimensions before weighting.
        bce = F.binary_cross_entropy(probs, targets, reduction="none")
        bce = (bce.mean(dim=(-2, -1)) * w).mean()

        return bce + dice_loss


In [None]:
# --- Run identity -----------------------------------------------------------
# Base name used for the checkpoint file and the figures directory.
CHECKPOINT = "DELETE_ME_pretrained_resnet50_trial_5_weighted_bce_dice"
CONTINUE_TRAINING = False

# --- Model ------------------------------------------------------------------
BACKBONE = "resnet50"
PRETRAINED = True

# --- Optimisation -----------------------------------------------------------
NUM_EPOCHS = 50
BATCH_SIZE = 16
LR = 1e-4
# Stored in the saved config under "patience".
PATIENCE = 20

# --- Loss -------------------------------------------------------------------
# Four values, handed to the loss criterion when it is constructed.
WEIGHT = torch.tensor([0.3, 0.3, 0.2, 0.2])

In [None]:
# Directory holding every figure written for this checkpoint; created on
# demand, together with any missing parent directories.
FIGURES_CHECKPOINT_PATH = Path(FIGURES_DIR) / "TransResUNet" / f"{CHECKPOINT}"
FIGURES_CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Checkpoint location relative to the checkpoint/config roots.
CHECKPOINT_FILE_NAME = f"{CHECKPOINT}.pt"
CHECKPOINT_NAME = Path("TransResUNet", CHECKPOINT_FILE_NAME)

# Snapshot of the settings used for this run.
config = dict(
    num_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    checkpoint_name=CHECKPOINT_NAME,
    patience=PATIENCE,
    lr=LR,
    backbone=BACKBONE,
    pretrained=PRETRAINED,
    figures_directory=FIGURES_CHECKPOINT_PATH,
    weight=WEIGHT,
)

# The config is stored under CONFIG_DIR at the same relative path as the
# checkpoint; make sure its directory exists before writing.
config_path = Path(CONFIG_DIR / CHECKPOINT_NAME)
config_path.parent.resolve().mkdir(parents=True, exist_ok=True)
torch.save(config, config_path)

In [None]:
# torchvision.transforms.ColorJitter
# 

# Spatial size every sample is resized to.
IMAGE_SIZE = (256, 256)

# Training pipeline: tensor conversion and resize, followed by random flips
# and a random crop.
# NOTE(review): the crop size equals the resize size, so RandomCrop presumably
# leaves the image unchanged - confirm whether a larger resize was intended.
train_steps = [
    T.ToTensor(),
    T.Resize(IMAGE_SIZE),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomCrop(IMAGE_SIZE),
]
train_transforms = T.Compose(train_steps)

# Validation pipeline: tensor conversion and resize only.
val_steps = [
    T.ToTensor(),
    T.Resize(IMAGE_SIZE),
]
val_transforms = T.Compose(val_steps)

In [None]:
# Build the training and validation datasets from DATA_DIR, each with its own
# transform pipeline and with_background=True.
train_set = TrainDataset(DATA_DIR, transform=train_transforms, with_background=True)
val_set = ValDataset(DATA_DIR, transform=val_transforms, with_background=True)

In [None]:
# Training batches are reshuffled at the start of every epoch so the optimiser
# does not see the samples in the same fixed order each time.
train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=16
)
# Validation keeps a fixed order; shuffling does not affect the metrics.
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=16)

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
from torch.optim.lr_scheduler import MultiStepLR
# Set the learning rate to 1e-3 for 10 epochs, then set to 1e-4
# lrs = MultiStepLR(optimizer, milestones=[10], gamma=0.1)

In [None]:
# Four-class TransResUNet on the configured backbone, moved to the chosen device.
model = TResUnet(num_classes=4, backbone=BACKBONE, pretrained=PRETRAINED)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=LR)
# `weights` is the argument ChannelWeightedDiceBCELoss stores as its
# per-channel weights, so WEIGHT is passed under that name to take effect.
criterion = ChannelWeightedDiceBCELoss(weights=WEIGHT.to(device))
lr_scheduler = LRScheduler(optimizer, patience=5)
# lrs = MultiStepLR(optimizer, milestones=[10], gamma=0.1)
# lr_scheduler = lambda _: lrs.step()
# NOTE(review): early stopping is disabled here although PATIENCE is written
# to the saved config - confirm this is intended.
early_stopping = None

# Train and validate; `result` is what the plotting cells consume afterwards.
result = run(
    num_epochs=NUM_EPOCHS,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    early_stopping=early_stopping,
    lr_scheduler=lr_scheduler,
    checkpoint_name=CHECKPOINT_NAME,
    continue_training=CONTINUE_TRAINING,
)

<br>

In [None]:
# Turn the training result into two figures: one for the loss, one for the
# benchmark metrics.
loss_fig, benchmark_fig = visualize_result(result)

In [None]:
# Write both result figures into this checkpoint's figures directory.
for figure, file_name in (
    (loss_fig, "results_loss.png"),
    (benchmark_fig, "results_accuracy.png"),
):
    figure.savefig(Path(FIGURES_CHECKPOINT_PATH, file_name))

In [None]:
# Iterator over the validation set; example samples are pulled from it with next().
data = iter(val_set)

In [None]:
# Take the next (image, target) pair from the validation iterator.
# Re-running this cell advances to the following sample.
image, target = next(data)

In [None]:
# Detailed figure for the selected sample, built from the model, the sample
# and the checkpoint name.
detailed = visualize_detailed_results(model, image, target, device, CHECKPOINT_NAME)

In [None]:
# Store the detailed example figure alongside the other figures of this run.
example_results_path = FIGURES_CHECKPOINT_PATH / "example_results.png"
detailed.savefig(example_results_path)

In [None]:
# Overlay variant of the detailed figure for the same sample.
detailed_overlay = visualize_detailed_results_overlay(
    model, image, target, device, CHECKPOINT_NAME
)

In [None]:
# Store the overlay figure alongside the other figures of this run.
example_overlay_path = FIGURES_CHECKPOINT_PATH / "example_overlay.png"
detailed_overlay.savefig(example_overlay_path)

<br>