# Experiment Notebook for the TransResU-Net

The [training script](./train.py) is based on this notebook.

You might need to adjust the batch size (`BATCH_SIZE` below) so that training fits into your GPU memory.

In [None]:
import os
from pathlib import Path
from configs import CONFIG_DIR
from figures import FIGURES_DIR

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from hubmap.data import DATA_DIR
from hubmap.dataset import transforms as T
from hubmap.dataset import TrainDataset, ValDataset

from hubmap.experiments.TransResUNet.utils import run
from hubmap.metrics.dice_score import DiceScore
from hubmap.losses.dice_bce_loss import DiceBCELoss
from hubmap.losses.channel_weighted_dice_bce_loss import ChannelWeightedDiceBCELoss
from hubmap.visualization import visualize_detailed_results
from hubmap.visualization import visualize_detailed_results_overlay

from hubmap.training import LRScheduler
from hubmap.training import EarlyStopping

from hubmap.visualization import visualize_result

from hubmap.models.trans_res_u_net.model import TResUnet, TResUnet512

In [None]:
# Experiment configuration.
NUM_EPOCHS = 2  # kept small for a demo run
BATCH_SIZE = 8  # lower this if training does not fit into GPU memory
CHECKPOINT = Path(".", "demo_trans_res_u_net")  # passed to `run` and the visualizers
CONTINUE_TRAINING = False
PATIENCE = 50  # early-stopping patience (epochs)
LR = 1e-4
BACKBONE = "resnext101_32x8d"
PRETRAINED = True

# Per-channel loss weights, one entry per output class (the model below is
# built with num_classes=4). Stored as floating point so the loss can combine
# them with float per-channel losses without dtype mismatches.
WEIGHT = torch.tensor([1.0, 1.0, 1.0, 1.0], dtype=torch.float32)

In [None]:
# Both pipelines convert to tensors and resize to a 512x512 input.
_size = (512, 512)

# Training pipeline: random horizontal/vertical flips for augmentation.
# NOTE(review): the RandomCrop uses the same size as the preceding Resize, so
# it presumably covers the whole image and has no effect — confirm whether a
# smaller crop (or a larger resize) was intended.
_train_steps = [
    T.ToTensor(),
    T.Resize(_size),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomCrop(_size),
]
train_transforms = T.Compose(_train_steps)

# Validation pipeline: deterministic, no augmentation.
val_transforms = T.Compose([T.ToTensor(), T.Resize(_size)])

In [None]:
# Training and validation splits, each with its own transform pipeline;
# `with_background=True` is requested for both.
train_set, val_set = (
    TrainDataset(DATA_DIR, transform=train_transforms, with_background=True),
    ValDataset(DATA_DIR, transform=val_transforms, with_background=True),
)

In [None]:
# Worker processes for data loading: at most 16, but never more than the
# machine reports (os.cpu_count() may return None, hence the fallback).
NUM_WORKERS = min(16, os.cpu_count() or 1)

# Reshuffle the training set every epoch so mini-batches are not drawn in the
# same fixed order; validation stays in a deterministic order.
train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
val_loader = DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Build the 4-class TransResU-Net (512 variant) and move it to the device.
model = TResUnet512(num_classes=4, backbone=BACKBONE, pretrained=PRETRAINED)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=LR)
# The loss weights must live on the same device as the model outputs.
criterion = ChannelWeightedDiceBCELoss(weights=WEIGHT.to(device))
lr_scheduler = LRScheduler(optimizer, patience=5)
# Use the configured PATIENCE instead of a duplicated hard-coded value, so
# changing the config cell actually changes early stopping.
early_stopping = EarlyStopping(patience=PATIENCE)

# Train/validate for NUM_EPOCHS; `run` also receives the checkpoint name and
# whether to resume from it (CONTINUE_TRAINING).
result = run(
    num_epochs=NUM_EPOCHS,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    early_stopping=early_stopping,
    lr_scheduler=lr_scheduler,
    checkpoint_name=CHECKPOINT,
    continue_training=CONTINUE_TRAINING,
)

<br>

In [None]:
# Plot the training result returned by `run`. The helper returns a pair; only
# the first element is kept (presumably the loss-curve figure — confirm in
# hubmap.visualization).
loss_fig, _ = visualize_result(result)

In [None]:
# Take the first sample yielded by iterating the validation set; it is used
# for the qualitative visualizations below.
data = iter(val_set)
first_sample = next(data)
image, target = first_sample

In [None]:
# Detailed visualization of the model's output for the sampled validation
# `image` / `target` pair.
# NOTE(review): CHECKPOINT is passed in, presumably so the helper loads the
# saved weights rather than using the in-memory model state — confirm.
detailed = visualize_detailed_results(model, image, target, device, CHECKPOINT)

In [None]:
# Overlay-style visualization of the model's output for the sampled validation
# `image` / `target` pair.
# NOTE(review): CHECKPOINT is passed in, presumably so the helper loads the
# saved weights — confirm in hubmap.visualization.
detailed_overlay = visualize_detailed_results_overlay(
    model, image, target, device, CHECKPOINT
)

<br>