In [None]:
import os
# CUDA_* variables must be set before torch initializes CUDA, so they stay above `import torch`.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
import numpy as np
import torch 
from torch import nn  # used as nn.MSELoss below; previously only reachable via star imports
import lightning
import lightning as L  # used as L.Trainer / L.pytorch.callbacks below
import matplotlib.pyplot as plt
from networks import *
from utils import *
import random
import tqdm
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.tuner import Tuner
from torch.utils.data import DataLoader
import pandas as pd
%load_ext autoreload
%autoreload 2

In [None]:
#initialize datasets
batchsize = 16
# Data root defaults to the original location; set the PE_DATA_ROOT environment
# variable to run the notebook on another machine without editing code.
data_root = os.environ.get("PE_DATA_ROOT", "/data-pool/data_no_backup/ga63cun/PE").rstrip("/")
path_sparse = f"{data_root}/64/"    # directory handed to SparseDataset as `path_sparse`
path_gt = f"{data_root}/4095/"      # directory handed to SparseDataset as `path_gt`
save_path = "./model_weights/2DUNet/"

df_train = pd.read_csv("./train.csv")
df_val = pd.read_csv("./val.csv")

# Seed python/numpy/torch RNGs once, before datasets, loaders and the model are built,
# so that stochastic steps (e.g. shuffling, augmentation, weight init) are reproducible
# under Restart & Run All. workers=True also seeds DataLoader worker processes.
SEED = 42
lightning.seed_everything(SEED, workers=True)

#initialize training parameters
lr = 5e-5
weight_decay = 1e-2
optimizer_algo = "AdamW"
optimizer_params = {"weight_decay": weight_decay}
scheduler_algo = "StepLR"
scheduler_params = {"step_size": 4, "gamma": 0.9}
patch_size = (256, 256)
# NOTE(review): presumably intensity window width / level used by SparseDataset — confirm.
ww = 3_000
wl = 0

In [None]:
# Train and validation sets share every setting except the dataframe and augmentation.
shared_dataset_kwargs = dict(
    path_sparse=path_sparse,
    path_gt=path_gt,
    image_size=patch_size,
    ww=ww,
    wl=wl,
)
dataset_train = SparseDataset(df=df_train, augmentation=True, **shared_dataset_kwargs)
dataset_val = SparseDataset(df=df_val, augmentation=False, **shared_dataset_kwargs)

# Only the training loader shuffles; both use the same batch size and worker count.
shared_loader_kwargs = dict(batch_size=batchsize, num_workers=4)
dataloader_train = DataLoader(dataset_train, shuffle=True, **shared_loader_kwargs)
dataloader_val = DataLoader(dataset_val, shuffle=False, **shared_loader_kwargs)

In [None]:
#test dataloader: pull the first batch of each split; each batch unpacks into three items.
def _first_batch(loader):
    """Return the first batch yielded by a fresh iterator over `loader`."""
    return next(iter(loader))

batch_sparse_train, batch_gt_train, batch_label_train = _first_batch(dataloader_train)
batch_sparse_val, batch_gt_val, batch_label_val = _first_batch(dataloader_val)


In [None]:
def plot_sample_pair(sparse, gt, labels, k=0, split=""):
    """Show sample `k` of a batch: sparse input next to ground truth.

    Args:
        sparse, gt: batches indexed as [sample, channel, ...]; only channel 0 is shown.
        labels: per-sample labels; labels[k] is written into the figure title.
        k: index of the sample within the batch.
        split: name of the split ("train"/"val"), used in the title.

    Returns:
        The matplotlib Figure.
    """
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    panels = ((sparse[k, 0], "sparse input"), (gt[k, 0], "ground truth"))
    for ax, (image, title) in zip(axes, panels):
        # Fixed [0, 1] display range so both panels share one gray scale.
        ax.imshow(image, cmap="gray", vmin=0, vmax=1)
        ax.set_title(title)
        ax.axis("off")
    fig.suptitle(f"{split} sample {k} | label: {labels[k]}")
    fig.tight_layout()
    return fig


k = 0
plot_sample_pair(batch_sparse_train, batch_gt_train, batch_label_train, k, split="train")
plot_sample_pair(batch_sparse_val, batch_gt_val, batch_label_val, k, split="val");


In [None]:
#initialize model: one input channel, one output channel, bilinear upsampling
unet = UNet(
    n_channels=1,
    n_classes=1,
    bilinear=True,
)
unet = unet.float()  # cast floating-point parameters/buffers to float32

In [None]:
# Wrap the UNet in the LightningModule using the optimizer/scheduler settings from the config cell.
model = LitModel(unet=unet,
                 optimizer_algo=optimizer_algo,
                 optimizer_params=optimizer_params,
                 loss=nn.MSELoss(reduction='mean'),
                 lr=lr,
                 # take the config value (not a literal) so it stays in sync with scheduler_params
                 scheduler_algo=scheduler_algo,
                 scheduler_params=scheduler_params,
                 )

# Log the learning rate once per epoch.
lr_monitor = L.pytorch.callbacks.LearningRateMonitor(logging_interval='epoch')
# Both loggers write under save_path; the CSV logger reuses the TensorBoard logger's
# version number so both logs of one run share the same version.
tblogger = TensorBoardLogger(save_path)
csvlogger = CSVLogger(save_path, version=tblogger.version)
# NOTE(review): both callbacks assume LitModel logs a metric named "val_loss" — confirm.
checkpoint = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3)
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=7)

trainer = L.Trainer(logger=[csvlogger, tblogger],
                    callbacks=[lr_monitor, checkpoint, early_stopping],
                    max_epochs=400)

In [None]:
# find good initial learning rate with Lightning's LR range test.
# NOTE(review): Tuner.lr_find defaults to update_attr=True, so the model's `lr`
# attribute is overwritten with the suggestion; later cells reuse this `model`.
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, train_dataloaders=dataloader_train,
                          min_lr=1e-7, max_lr=9e-1,
                          num_training=150, early_stop_threshold=50)
suggested_lr = lr_finder.suggestion()  # may be None if no suggestion could be computed
print(f"suggested lr: {suggested_lr}")

fig, ax = plt.subplots()
ax.plot(lr_finder.results["lr"], lr_finder.results["loss"])
if suggested_lr is not None:
    ax.axvline(suggested_lr, color="red", linestyle="--", label=f"suggestion: {suggested_lr:.2e}")
    ax.legend()
ax.set(xscale="log", xlabel="learning rate", ylabel="loss", title="LR range test");

In [None]:
#test if model overfits on 2 batches: the training loss should approach zero.
# EarlyStopping is intentionally omitted here: it watches val_loss with patience=7, and
# val_loss need not keep improving while two training batches are memorized, so it could
# end this sanity check long before max_epochs.
trainer = L.Trainer(logger=[csvlogger, tblogger],
                    callbacks=[lr_monitor, checkpoint],
                    max_epochs=400, overfit_batches=2)

trainer.fit(model, dataloader_train, dataloader_val)

In [None]:
#load model from checkpoint
# Prefer the best checkpoint tracked by the ModelCheckpoint callback in this session.
# The literal path is only a fallback, because version_N increases with every fresh run
# and would otherwise point at a stale checkpoint after Restart & Run All.
fallback_ckpt_path = "./model_weights/2DUNet/lightning_logs/version_1/checkpoints/epoch=0-step=5.ckpt"
session_ckpt_path = checkpoint.best_model_path if "checkpoint" in globals() else ""
ckpt_path = session_ckpt_path or fallback_ckpt_path
print(f"loading weights from {ckpt_path}")

# A fresh UNet is passed in explicitly, as in the original cell.
unet = UNet(n_channels=1, n_classes=1, bilinear=True).float()
model = LitModel.load_from_checkpoint(ckpt_path, unet=unet)