In [None]:
# Reload edited project modules (e.g. `network`) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

import torch
import torch.optim as optim
from torchvision import transforms
# Project package providing the model (net), losses (loss), metrics, dataset helpers,
# dataset transforms (ds_transforms) and the training loop (training) used below.
import network

import os.path

import sys
assert sys.version_info.major == 3, 'Not running on Python 3'

# Display the value of every bare expression in a cell, not only the last one.
# NOTE(review): this also echoes incidental return values (e.g. a bare `net.to(...)`
# prints the whole model repr) — consider removing if the output gets noisy.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Route INFO-level log records to stdout so they show up in the cell output.
import logging
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

In [None]:
# Training hyperparameters and data location.
epochs = 50
batch_size = 4
patch_size = (96, 48, 96)  # patch size passed to CropRandomPatch for each sample
session_num = 0  # index used to name this run's log directory ("run-<session_num>")
# Expand "~" explicitly: file APIs such as open()/os.listdir()/glob do not expand it,
# and expanduser() leaves an already-expanded path unchanged.
ds_folder = os.path.expanduser("~/imag2/IMAG2_DL/APMRI-DNN/Dataset/All")

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

In [None]:
# Model, objective, metric and optimizer.
# Move the model to `device` before constructing the optimizer, as recommended by the
# PyTorch optim docs, so the optimizer is built over the on-device parameters.
net = network.net.UNet3D().to(device)
loss_fn = network.loss.DiceLoss()
metric_fn = network.metrics.DiceScore()
learning_rate = 4e-4
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

# Per-sample transform chain: Normalize -> CropRandomPatch(patch_size) -> ToTensor.
tsfrms = transforms.Compose([network.ds_transforms.Normalize(),
                             network.ds_transforms.CropRandomPatch(patch_size),
                             network.ds_transforms.ToTensor()])
ds = network.dataset.MRIDataset(ds_folder, transform=tsfrms)

# NOTE(review): 0.2 is presumably the validation fraction — confirm in train_val_dataset.
train_subset, val_subset = network.dataset.train_val_dataset(ds, 0.2)

# Pinned host memory only helps host->GPU copies; PyTorch warns when it is requested
# without an accelerator, so only enable it when training on CUDA.
pin_memory = device == "cuda"
ds_train = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True,
                                       num_workers=4, pin_memory=pin_memory)
# Validation batches need no shuffling; shuffling them only adds nondeterminism.
ds_val = torch.utils.data.DataLoader(val_subset, batch_size=batch_size, shuffle=False,
                                     num_workers=4, pin_memory=pin_memory)

In [None]:
# Name this trial and derive its log/checkpoint directory.
run_name = 'run-{}'.format(session_num)
print('\n\n--- Starting trial: {}'.format(run_name))
run_logdir = os.path.join('tf_logs', 'hparam_tuning', run_name)
# Ensure the directory exists before the checkpoint is written into it below.
os.makedirs(run_logdir, exist_ok=True)

# Idempotent: a no-op if the model is already on `device`. Assigning the result keeps
# IPython (ast_node_interactivity = "all") from echoing the whole module repr.
net = net.to(device)

training = network.training.Training(net, optimizer, loss_fn, metric_fn, ds_train, ds_val, run_logdir)
training.train_model(epochs=epochs, device=device)
training.save_model(path=os.path.join(run_logdir, 'KD.pt'))