In [None]:
# Colab setup: install/upgrade the training dependencies and fetch the project code.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install --quiet -U torch torchmetrics lightning torchvision seaborn nibabel
# Clone only on the first run; re-running a bare `git clone` fails with
# "destination path already exists". Updates are fetched with `git -C ensf-619/ pull`.
![ -d ensf-619 ] || git clone https://github.com/PCiunkiewicz/ensf-619.git

In [None]:
# Optional for nightly release; may need to update cuda version tag
# Installs pre-release (nightly) torch/torchvision wheels from the PyTorch index.
# The `cu117` suffix selects CUDA 11.7 builds; change it to match the runtime's CUDA version.
!pip install -U --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu117

In [None]:
# Optional for TPU support; may need to update tpu-pytorch target url
# The wheel URL is pinned to torch_xla 2.0 built for CPython 3.9 (cp39) on linux x86_64;
# it must match the runtime's Python version and the installed torch release.
!pip install cloud-tpu-client https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp39-cp39-linux_x86_64.whl

In [None]:
# Pull the latest project code into the existing `ensf-619/` clone
# (the modules imported below are loaded from ensf-619/final_project).
!git -C ensf-619/ pull

In [1]:
import copy
import os
import sys
sys.path.insert(0, '/content/ensf-619/final_project')

import numpy as np
import matplotlib.pyplot as plt

import torch
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from torch.utils.data import DataLoader, Subset, random_split

from paths import MODEL_PATH
from data import load_data, DANNDataset
from dann import DeepCascadeDANN
from transform import DeepCascadeTransform

# DataLoader worker processes. os.cpu_count() returns None when the count cannot be
# determined, and DataLoader rejects num_workers=None, so fall back to 0 (load in the
# main process) in that case.
NUM_WORKERS = os.cpu_count() or 0

In [2]:
SIZE = 164
SEED = 42  # fixes the train/val split and the global RNGs so runs are reproducible

seed_everything(SEED, workers=True)

# Target-domain images ('newborn'); the second value returned by load_data is unused here.
target_images, _ = load_data(size=SIZE, mode='newborn')
# Insert a singleton dim at axis 1 (presumably the channel axis — confirm against DANNDataset).
target_images = torch.tensor(target_images, dtype=torch.float32).unsqueeze(1)

# Source-domain images and their masks ('negative').
src_images, masks = load_data(size=SIZE, mode='negative')
src_images = torch.tensor(src_images, dtype=torch.float32).unsqueeze(1)
masks = torch.tensor(masks, dtype=torch.float32)

transform = DeepCascadeTransform(size=SIZE)
dataset = DANNDataset(src_images, masks, target_images, transform=transform)

# random_split returns Subsets that wrap the *same* dataset object, so setting
# `val_ds.dataset.val = True` would also flip the flag for the training subset
# (presumably disabling training-time augmentation there too). Give validation its
# own shallow copy — the tensors are shared, not duplicated — and set the flag only on it.
val_dataset = copy.copy(dataset)
val_dataset.val = True

train_ds, val_split = random_split(
    dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(SEED)
)
val_ds = Subset(val_dataset, val_split.indices)

In [None]:
# Show the logs the Trainer below writes: its default logger puts them in
# `lightning_logs` under default_root_dir (MODEL_PATH / 'DeepCascadeDANN').
# The path is derived from MODEL_PATH instead of duplicating it as an absolute literal.
LOG_DIR = MODEL_PATH / 'DeepCascadeDANN' / 'lightning_logs'
%reload_ext tensorboard
%tensorboard --logdir "$LOG_DIR"

In [None]:
BATCH_SIZE = 32

train_dataloader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,  # every training batch is exactly BATCH_SIZE
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

val_dataloader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    # Keep the trailing partial batch: with shuffle=False, drop_last=True would exclude
    # the same samples from every validation pass and bias the val_loss that
    # ModelCheckpoint uses to select the best model.
    # NOTE(review): confirm the model's validation step handles a smaller final batch.
    drop_last=False,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

model = DeepCascadeDANN(
    # depth_str='ikikii',
    depth_str='iiiii',
    img_size=SIZE,
    beta=0.01,
    lr=1e-3,
    weight_decay=1e-5
)

trainer = Trainer(
    max_epochs=200,
    precision='16-mixed',
    default_root_dir=MODEL_PATH / 'DeepCascadeDANN',
    accelerator='auto',
    callbacks=[
        # Track the checkpoint with the lowest val_loss. Only weights are saved
        # (no optimizer state), so these checkpoints cannot resume training exactly.
        ModelCheckpoint(save_weights_only=True, mode='min', monitor='val_loss'),
        LearningRateMonitor('epoch')
    ]
)

# Optional learning-rate range test; uncomment to run before fitting.
# tuner = Tuner(trainer)
# lr_finder = tuner.lr_find(model, train_dataloader, val_dataloader)
# fig = lr_finder.plot(suggest=True)
# fig.show()

# compiled_model = torch.compile(model) # Complex64 not yet supported for compile.
trainer.fit(model, train_dataloader, val_dataloader)

# Reload the best weights. best_model_path is empty when no checkpoint was written
# (e.g. training stopped before the first validation), so fail with a clear message
# instead of an opaque file error.
best_ckpt = trainer.checkpoint_callback.best_model_path
if not best_ckpt:
    raise RuntimeError(
        "No best checkpoint was recorded; training may have stopped before the first validation run."
    )
model = DeepCascadeDANN.load_from_checkpoint(best_ckpt)