### SETUP

In [None]:
%%capture
!pip install segmentation-models-pytorch
!pip install pytorch_lightning

In [None]:
# Mount Google Drive so the Kaggle API token stored there can be copied below.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
# Set up kaggle credentials.
!cp /content/drive/MyDrive/Kaggle/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the competition archive (hubmap-organ-segmentation.zip) into the current directory.
!kaggle competitions download -c hubmap-organ-segmentation

In [None]:
%%capture
# Unzip the data.
!unzip hubmap-organ-segmentation.zip
!rm hubmap-organ-segmentation.zip

### LIBRARIES

In [None]:
import os, gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Model & Modelling
import torch
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from sklearn.model_selection import StratifiedShuffleSplit

# Utils
from utils.dataset import HHHHBDataset
from utils.model import HHHHBModel
from utils.viz import plot_samples, plot_sample

### DATA
A preliminary EDA is available [here](https://www.kaggle.com/code/bomera/hhhhb-eda).

In [None]:
# Directory holding the competition CSV files and image folders.
data_path = '/content'

dataset = pd.read_csv(os.path.join(data_path, "train.csv"))
test = pd.read_csv(os.path.join(data_path, "test.csv"))
submission = pd.read_csv(os.path.join(data_path, "sample_submission.csv"))

In [None]:
# Attach the path of each training image, derived from its id: <data_path>/train_images/<id>.tiff
dataset['file_path'] = dataset['id'].astype(str).map(
    lambda image_id: f"{data_path}/train_images/{image_id}.tiff"
)

### DATASET

In [None]:
# Preview augmentations: the training pipeline minus Normalize, so pixel values stay
# in their original range for the side-by-side plots below.
transforms = A.Compose([
    A.Resize(640, 640),
    # Pass `p` by keyword: in albumentations releases whose first positional parameter
    # is `always_apply`, `VerticalFlip(0.5)` would flip every image instead of half.
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
    A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    ToTensorV2()
])

# Raw samples, no transforms.
hhhhb_dataset = HHHHBDataset(dataset, metadata=True)
# The same samples with the preview augmentations applied.
hhhhb_datasetT = HHHHBDataset(dataset, transforms=transforms, metadata=True)
len(hhhhb_datasetT), len(hhhhb_dataset)

In [None]:
plot_samples(hhhhb_dataset, indices=[1, 4, 5, 6], annotate=True, cols=4)

In [None]:
plot_samples(hhhhb_datasetT, indices=[1, 4, 5, 6], annotate=True, cols=4, is_transformed=True)

In [None]:
# Release the two preview datasets before training; gc.collect() returns the count of
# unreachable objects it found.
del hhhhb_dataset, hhhhb_datasetT
gc.collect()

#### VALIDATION SET

In [None]:
# Check the organ distribution (as proportions); the validation split below is stratified on it.
dataset['organ'].value_counts(normalize=True)

In [None]:
# Hold out 10% of the images for validation, stratified by organ so both splits
# keep the same organ mix.
split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
# With n_splits=1 the generator yields exactly one (train, val) pair of POSITIONAL indices.
train_idx, val_idx = next(split.split(dataset, dataset['organ']))
# .iloc, not .loc: the indices are positions, which only equal labels under a default RangeIndex.
train = dataset.iloc[train_idx].reset_index(drop=True)
val = dataset.iloc[val_idx].reset_index(drop=True)

In [None]:
# Organ proportions in (train, val); close values confirm the stratified split.
tuple(part['organ'].value_counts(normalize=True) for part in (train, val))

### TRAINING

In [None]:
num_workers = 2
batch_size = 4

# Input resolution and ImageNet channel statistics shared by both pipelines.
# NOTE(review): ImageNet normalisation presumes HHHHBModel uses an ImageNet-pretrained encoder — confirm.
IMG_SIZE = 640
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Train transforms: geometric + photometric augmentation, then normalisation.
# `p` is passed by keyword: in albumentations releases whose first positional
# parameter is `always_apply`, `VerticalFlip(0.5)` would flip every image.
train_transforms = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
    A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Val transforms: resize and the same normalisation only, no random augmentation.
val_transforms = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Train set: reshuffled on every pass over the loader.
train_loader = DataLoader(
    HHHHBDataset(data=train, transforms=train_transforms),
    batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=True,
)

# Valid set: no shuffling.
val_loader = DataLoader(
    HHHHBDataset(data=val, transforms=val_transforms),
    batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=True,
)

In [None]:
# Seed Python, NumPy and torch RNGs (plus dataloader workers) so weight init, shuffling
# and augmentations are reproducible; same seed as the validation split.
pl.seed_everything(42, workers=True)

# Set up model object ("unet" architecture, "resnet50" encoder, 3 input channels, 1 output class).
hhhhb_model = HHHHBModel("unet", "resnet50", in_channels=3, out_classes=1)

# Set up callbacks: keep the checkpoint with the highest "val_score" and stop after
# 4 validation checks without improvement.
# NOTE(review): assumes HHHHBModel logs a metric named "val_score" during validation — confirm.
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="val_score", mode="max", verbose=True)
early_stopping_callback = pl.callbacks.EarlyStopping(monitor="val_score", patience=4, mode="max", verbose=True)
model_summary_callback = pl.callbacks.ModelSummary(max_depth=1)

# Set up trainer. No max_epochs is given, so training length is governed by early
# stopping (or Lightning's default epoch limit).
trainer = pl.Trainer(
    accelerator="auto",
    callbacks=[checkpoint_callback, early_stopping_callback, model_summary_callback]
)

In [None]:
# Train on the stratified train split, validating on the held-out split.
trainer.fit(hhhhb_model, train_loader, val_loader)

In [None]:
# Best monitored "val_score" and the checkpoint that achieved it. Read from the callback
# object directly: Lightning reorders trainer.callbacks (checkpoints last) and adds
# default callbacks, so positional indexing into that list is unreliable.
checkpoint_callback.best_model_score, checkpoint_callback.best_model_path

### TENSORBOARD

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir ./lightning_logs/version_0

#### Inference