In [None]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

import gc
gc.collect()

import torch
import torch.optim as optim
from torchvision import transforms

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()

from network.net import UNet3D
from network.loss import KDLoss
from network.metrics import DiceScore
from network.dataset import train_val_dataset, MRIDataset
from network.ds_transforms import Normalize, CropRandomPatch, ToTensor
from network.training import Training

import os.path

import sys
assert sys.version_info.major == 3, 'Not running on Python 3'

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

import logging
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

In [None]:
import os
import shutil

# Training configuration.
epochs = 50
batch_size = 4
patch_size = (96, 48, 96)  # patch shape handed to CropRandomPatch
session_num = 0            # selects the logs/run-<session_num> directory
seed = 0                   # global torch RNG seed

# Dataset location. The default is machine-specific; set KD_DS_PATH to override it elsewhere.
ds_path = os.environ.get("KD_DS_PATH", "/home/imag2/IMAG2_DL/KDCompression/Dataset/ds.h5")

# Start this session with an empty log directory. Only this session's directory is removed:
# wiping all of ./logs/ would also delete other sessions' TensorBoard runs and the KD.pt
# checkpoints that the training cell saves inside each run directory.
shutil.rmtree(os.path.join("logs", "run-{}".format(session_num)), ignore_errors=True)

# Seed torch's global RNG (DataLoader shuffling draws from it) so runs are reproducible.
# NOTE(review): if the project's split/crop helpers use numpy or `random`, seed those too.
torch.manual_seed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

In [None]:
# Model, objective, metric and optimizer.
learning_rate = 4e-4
val_fraction = 0.2  # second argument of train_val_dataset; presumably the validation share — confirm

# Reassigning (instead of a bare `net.to(...)`) keeps the full network repr from being echoed.
net = UNet3D()
net = net.to(device, non_blocking=True)

loss_fn = KDLoss()
metric_fn = DiceScore()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

# Per-sample pipeline: normalize, crop a random patch_size patch, convert to tensor.
tsfrms = transforms.Compose([Normalize(),
                             CropRandomPatch(patch_size),
                             ToTensor()])
ds = MRIDataset(ds_path, transform=tsfrms)

# NOTE(review): both subsets share `tsfrms`, so validation samples are randomly cropped too,
# which makes validation metrics noisy between epochs — consider a deterministic val transform.
train_subset, val_subset = train_val_dataset(ds, val_fraction)

# Pinned host memory only helps host-to-GPU copies; on CPU it just triggers a warning.
loader_kwargs = dict(batch_size=batch_size, num_workers=4,
                     pin_memory=(device == "cuda"), prefetch_factor=10)
ds_train = torch.utils.data.DataLoader(train_subset, shuffle=True, **loader_kwargs)
# Shuffling validation data has no training benefit; a fixed order keeps evaluation comparable.
ds_val = torch.utils.data.DataLoader(val_subset, shuffle=False, **loader_kwargs)

In [None]:
# Launch TensorBoard inline over every run directory under ./logs/.
# NOTE(review): --host 0.0.0.0 listens on all network interfaces, so anyone who can reach this
# machine can open the dashboard; prefer 127.0.0.1 plus an SSH tunnel if that is not intended.
%tensorboard --logdir ./logs/ --host 0.0.0.0

In [None]:
# Train one session, logging to logs/run-<session_num> and saving the weights there as KD.pt.
run_name = 'run-{}'.format(session_num)
print('\n\n--- Starting trial: {}'.format(run_name))
run_logdir = os.path.join('logs', run_name)

training = Training(net, optimizer, loss_fn, metric_fn, ds_train, ds_val, run_logdir)
training.train_model(epochs=epochs, device=device)

# Make sure the target directory exists even if Training did not create it while logging.
os.makedirs(run_logdir, exist_ok=True)
training.save_model(path=os.path.join(run_logdir, 'KD.pt'))

# Wait for any queued GPU work to finish; skipped on CPU-only machines, where this call raises.
if torch.cuda.is_available():
    torch.cuda.synchronize()