In [None]:
%load_ext autoreload
%autoreload 2

import os.path

import gc
gc.collect()

import torch
import torch.optim as optim
from torchvision import transforms
import torch.autograd.profiler as profiler

%load_ext tensorboard

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.multiprocessing.set_start_method('spawn')
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

from network.net import UNet3D
from network.loss import KDLoss
from network.metrics import DiceScore
from network.dataset import train_val_dataset, MRIDataset
from network.ds_transforms import ToTensor, RandomCropCollate
from network.training import Training

import sys
assert sys.version_info.major == 3, 'Not running on Python 3'

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

import logging
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

In [None]:
# Run configuration: everything a user is expected to tune lives in this cell.
import shutil

# Wipes ./logs/, i.e. the TensorBoard logs of *every* session, not only this one.
clear_logs = True
if clear_logs:
    # Pure-Python equivalent of `!rm -rf ./logs/` (no shell dependency, works on any OS).
    shutil.rmtree('./logs/', ignore_errors=True)

epochs = 200
batch_size = 4
# Size of each random crop; presumably (D, H, W) voxels — confirm against RandomCropCollate.
patch_size = (96, 48, 96)
# Forwarded to Training.train_model.
patch_per_image = 18
lr = 4e-4
# Selects the log sub-directory logs/run-<session_num>.
session_num = 0
# Override with the KD_DS_PATH environment variable when running on another machine.
ds_path = os.environ.get("KD_DS_PATH", "/home/imag2/IMAG2_DL/KDCompression/Dataset/ds.npy")
# Wrap training in torch.autograd.profiler (slow; for debugging only).
profile = False

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

In [None]:
# Build the model, loss/metric, optimiser, and the train/validation data loaders.
net = UNet3D()
# Assigned rather than left as a bare expression: with
# InteractiveShell.ast_node_interactivity = "all" a bare `net.to(...)` would
# print the entire UNet3D repr. Module.to() returns the same module object.
net = net.to(device, non_blocking=True)

loss_fn = KDLoss()
metric_fn = DiceScore()
optimizer = optim.Adam(net.parameters(), lr=lr)

tsfrms = transforms.Compose([ToTensor(hdf5=False)])
ds = MRIDataset(ds_path, transform=tsfrms)

# 20% held out for validation. NOTE(review): whether this split is random/seeded
# is decided inside train_val_dataset — confirm if reproducibility matters.
ds_train, ds_val = train_val_dataset(ds, 0.2)

# NOTE: ds_train / ds_val are rebound below from dataset subsets to DataLoaders;
# the training cell consumes them under these names.
# NOTE: the collate_fn lambdas only work with num_workers=0 — lambdas cannot be
# pickled for worker processes under the 'spawn' start method.

# Training: draw images with replacement, then crop patches in the collate step.
collate_fn_train = RandomCropCollate(patch_size)
sampler_train = torch.utils.data.RandomSampler(ds_train, replacement=True)
ds_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, num_workers=0, sampler=sampler_train,
                                       collate_fn=lambda b: collate_fn_train.collate(b, device))

# Validation: images in fixed order. NOTE(review): RandomCropCollate presumably
# crops at random positions, so validation scores may vary between runs — TODO
# consider a deterministic crop for validation.
collate_fn_test = RandomCropCollate(patch_size)
sampler_test = torch.utils.data.SequentialSampler(ds_val)
ds_val = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, num_workers=0, sampler=sampler_test,
                                     collate_fn=lambda b: collate_fn_test.collate(b, device))

In [None]:
# Launch TensorBoard inline over ./logs/ (each session's run-<n> directory lives under it).
# NOTE: --host 0.0.0.0 listens on every network interface, so anyone who can reach
# this machine can view the dashboard; use localhost plus SSH port-forwarding if
# that exposure is not intended.
%tensorboard --logdir ./logs/ --host 0.0.0.0

In [None]:
# Train one session (optionally under the autograd profiler), then save the weights.
run_name = 'run-{}'.format(session_num)
print('\n\n--- Starting trial: {}'.format(run_name))
run_logdir = os.path.join('logs', run_name)

training = Training(net, optimizer, loss_fn, metric_fn, ds_train, ds_val, patch_size, run_logdir)

if profile:
    with profiler.profile(record_shapes=True, profile_memory=True,
                          use_cuda=torch.cuda.is_available()) as prof:
        training.train_model(epochs=epochs, patch_per_image=patch_per_image)
    # Profiler results are only finalised when the `with` block exits; calling
    # key_averages()/export_chrome_trace() inside it raises a RuntimeError.
    print(prof.key_averages().table(sort_by="cpu_time_total"))
    print(prof.key_averages().table(sort_by="cpu_memory_usage"))
    prof.export_chrome_trace('trace.json')
else:
    training.train_model(epochs=epochs, patch_per_image=patch_per_image)

# Make sure the target directory exists so a long run cannot fail at the save step.
os.makedirs(run_logdir, exist_ok=True)
training.save_model(path=os.path.join(run_logdir, 'KD.pt'))
# Closing the dataset invalidates ds_train/ds_val: re-run the model/data cell
# before training again.
ds.close()