In [None]:
# !pip install tqdm
# !pip install torch
# !pip install torchvision torchaudio
# !pip install tensorboardX
# !pip install scikit-learn
# !pip install pytorch-lightning
# !pip install git+https://github.com/ncullen93/torchsample
# !pip install nibabel
# !pip install wget
# !pip install ipywidgets
# !pip install widgetsnbextension
# !pip install tensorflow

# !jupyter labextension install @jupyter-widgets/jupyterlab-manager > /dev/null
# !jupyter nbextension enable --py widgetsnbextension

In [None]:
import shutil
import os
import time
from datetime import datetime
import argparse
import numpy as np
from tqdm import tqdm
import multiprocessing

import torch
import torch.nn as nn
import torchmetrics
import torch.optim as optim
from torch.autograd import Variable
from torchsample.transforms import RandomRotate, RandomTranslate, RandomFlip, ToTensor, Compose, RandomAffine
from torchvision import transforms
import torch.nn.functional as F
from tensorboardX import SummaryWriter

import model
from dataset import MRDatasetMerged
from torch.utils.data import DataLoader 

import pytorch_lightning as pl
from sklearn import metrics
from ipywidgets import IntProgress

In [None]:
# Enable the ipywidgets notebook extension (presumably so the progress bars
# shown by later cells render as widgets — confirm it is still needed).
!jupyter nbextension enable --py widgetsnbextension
#%load_ext tensorboard
#%tensorboard --logdir lightning_logs/

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [None]:
class Args:
    """Container for the notebook's run options.

    Each option is a plain instance attribute initialised to the default
    below. Any of them can be overridden with a keyword argument, e.g.
    ``Args(task="acl", epochs=10)``; calling ``Args()`` with no arguments
    yields exactly the defaults.

    NOTE(review): not every option is necessarily consumed by the cells of
    this notebook — verify an option is actually read before relying on it.

    Raises:
        TypeError: if an override names an option that does not exist, so a
            typo cannot silently create a setting nothing reads.
    """

    def __init__(self, **overrides):
        self.task = "abnormal"  # one of ['abnormal', 'acl', 'meniscus']
        self.plane = "sagittal"  # one of ['sagittal', 'coronal', 'axial']
        self.prefix_name = "Test"
        self.augment = 1  # one of [0, 1]
        self.lr_scheduler = "plateau"  # one of ['plateau', 'step']
        self.gamma = 0.5
        self.epochs = 1
        self.lr = 1e-5
        self.flush_history = 0  # one of [0, 1]
        self.save_model = 1  # one of [0, 1]
        self.patience = 5
        self.log_every = 100

        # Apply overrides last so they win over the defaults above.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Args got an unexpected option {name!r}")
            setattr(self, name, value)


args = Args()

In [None]:
def to_tensor(x):
    """Build a ``torch.Tensor`` from ``x`` using the ``torch.Tensor`` constructor."""
    converted = torch.Tensor(x)
    return converted

# Use all but one CPU core as DataLoader workers. On a single-core machine
# this is 0, which makes the loaders fetch data in the main process.
num_workers = multiprocessing.cpu_count() - 1

# Logs are grouped by task and plane, with one timestamped folder per run.
log_root_folder = "./logs/{0}/{1}/".format(args.task, args.plane)

# Optionally wipe the folders of earlier runs. The root folder does not exist
# before the first run, so only list it when it is actually there.
if args.flush_history == 1 and os.path.isdir(log_root_folder):
    for entry in os.listdir(log_root_folder):
        run_dir = os.path.join(log_root_folder, entry)
        if os.path.isdir(run_dir):
            shutil.rmtree(run_dir)

now = datetime.now()
logdir = log_root_folder + now.strftime("%Y%m%d-%H%M%S") + "/"
# The timestamp has one-second resolution; exist_ok lets the cell be re-run
# within the same second without raising FileExistsError.
os.makedirs(logdir, exist_ok=True)

writer = SummaryWriter(logdir)

# NOTE(review): augmentation is currently disabled — the pipeline below is
# commented out and transform=None is passed to the training dataset, so
# args.augment has no effect in this cell.
# augmentor = Compose([
#     transforms.Lambda(to_tensor),
#     RandomRotate(25),
#     RandomTranslate([0.11, 0.11]),
#     RandomFlip(),
# #     transforms.Lambda(lambda x: x.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)),
# ])

train_dataset = MRDatasetMerged('./data/', transform=None, train=True)
validation_dataset = MRDatasetMerged('./data/', train=False)

# Only the training loader shuffles; both keep the last (partial) batch.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=num_workers, drop_last=False)
validation_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False, num_workers=num_workers, drop_last=False)

mrnet = model.MRNet()



In [None]:
# Name of the logged metric used to rank checkpoints.
# NOTE(review): assumes MRNet logs a metric under exactly this name —
# confirm in model.py.
monitor = "val_f1"

# Keep the three best checkpoints by the monitored metric. F1 is a
# higher-is-better score, so the ranking must use mode='max'; mode='min'
# would keep the lowest-F1 checkpoints and make `best_model_path` point at
# the worst one.
callback = pl.callbacks.ModelCheckpoint(
    monitor=monitor,
    dirpath=f'/notebooks/checkpoints_{monitor}/',
    # Expands to 'checkpoint-{epoch:02d}-{val_f1:.2f}' for the default monitor.
    filename='checkpoint-{epoch:02d}-{' + monitor + ':.2f}',
    save_top_k=3,
    mode='max',
)

In [None]:
# Take the epoch count from the run configuration instead of a second
# hard-coded value, so changing args.epochs actually changes the training length.
# gpus=0 trains on CPU; the original trailing note "#1" presumably records
# gpus=1 as the setting for a single-GPU run — confirm.
trainer = pl.Trainer(max_epochs=args.epochs, gpus=0, callbacks=[callback])

  return torch._C._cuda_getDeviceCount() > 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [None]:
# Train `mrnet` on the training loader, validating on the validation loader.
# The checkpoint callback registered on `trainer` is active during this call.
trainer.fit(mrnet, train_loader, validation_loader)


  | Name             | Type              | Params
-------------------------------------------------------
0 | pretrained_model | AlexNet           | 61.1 M
1 | pooling_layer    | AdaptiveAvgPool2d | 0     
2 | classifer        | Linear            | 771   
3 | train_f1         | F1                | 0     
4 | valid_f1         | F1                | 0     
5 | train_auc        | AUROC             | 0     
-------------------------------------------------------
61.1 M    Trainable params
0         Non-trainable params
61.1 M    Total params
244.406   Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

In [None]:
# Reload the checkpoint the ModelCheckpoint callback ranked best. The notebook
# imports the `model` module rather than the MRNet name, so the class has to
# be reached through the module.
m = model.MRNet.load_from_checkpoint(callback.best_model_path)

In [None]:
# Run the reloaded model on the first validation item as a quick check.
# NOTE(review): the raw dataset item is passed straight to the model; if
# MRDatasetMerged yields (input, label, ...) tuples this needs unpacking
# (and possibly a batch dimension) — confirm against dataset.py.
m(validation_dataset[0])

In [None]:
# Export the trained network to ONNX.
# `mrnet` is used directly instead of rebinding the name `model`: that name
# is the imported `model` module, and shadowing it breaks any re-run of the
# cells that call `model.MRNet`.
filepath = 'model_v2.onnx'
# Dummy input handed to the exporter.
# NOTE(review): (64, 3, 227, 227) is presumably 64 slices of 3x227x227 —
# confirm it matches what MRNet.forward expects.
input_sample = torch.randn((64, 3, 227, 227))
mrnet.to_onnx(filepath, input_sample, export_params=True)