In [3]:
# !pip install tqdm
# !pip install torch
# !pip install torchvision torchaudio
# !pip install fastai
# !pip install tensorboardX
# !pip install scikit-learn
# !pip install "git+https://github.com/ncullen93/torchsample.git#egg=torchsample"
# !pip install nibabel
# !pip install nbdev
# !pip install torch==1.6.0 torchvision==0.7.0

In [12]:
import shutil
import os
import time
from datetime import datetime
import argparse
import numpy as np
from tqdm import tqdm
from nbdev import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchsample.transforms import RandomRotate, RandomTranslate, RandomFlip, ToTensor, Compose, RandomAffine
from torchvision import transforms
import torch.nn.functional as F
from tensorboardX import SummaryWriter

from dataloader import MRDataset
import model

from sklearn import metrics
from fastai.vision.all import *

In [28]:
import multiprocessing

# DataLoader worker count: every logical CPU except one, which is left for the
# main (notebook) process.
_total_cores = multiprocessing.cpu_count()
num_workers = _total_cores - 1
num_workers

23

In [29]:
class Args:
    """Hyper-parameters and switches for one training run.

    A plain attribute container; the notebook reads these values as
    ``args.<name>``. Switch-like options use 0 (off) / 1 (on).
    """

    def __init__(self):
        settings = {
            "task": "abnormal",          # one of 'abnormal', 'acl', 'meniscus'
            "plane": "sagittal",         # one of 'sagittal', 'coronal', 'axial'
            "prefix_name": "Test",
            "augment": 1,                # 0 or 1
            "lr_scheduler": "plateau",   # 'plateau' or 'step'
            "gamma": 0.5,
            "epochs": 1,
            "lr": 1e-5,
            "flush_history": 0,          # 0 or 1
            "save_model": 1,             # 0 or 1
            "patience": 5,
            "log_every": 100,
        }
        # Assign in declaration order so vars(self) keeps the same ordering.
        for name, value in settings.items():
            setattr(self, name, value)


args = Args()

In [30]:
# TensorBoard log root for this task/plane; each run gets its own timestamped subfolder.
log_root_folder = "./logs/{0}/{1}/".format(args.task, args.plane)

# Optionally delete earlier run folders. The root may not exist yet (e.g. the
# very first run with flush_history=1); os.listdir would raise FileNotFoundError.
if args.flush_history == 1 and os.path.isdir(log_root_folder):
    for entry in os.listdir(log_root_folder):
        entry_path = os.path.join(log_root_folder, entry)
        if os.path.isdir(entry_path):
            shutil.rmtree(entry_path)

now = datetime.now()
logdir = log_root_folder + now.strftime("%Y%m%d-%H%M%S") + "/"
# exist_ok: re-running this cell within the same second must not crash.
os.makedirs(logdir, exist_ok=True)

writer = SummaryWriter(logdir)

def _to_float_tensor(volume):
    """Convert the raw volume to a float ``torch.Tensor``."""
    return torch.Tensor(volume)


def _replicate_to_three_channels(volume):
    """Repeat the volume 3x along a new leading axis, then swap the first two axes.

    For a 3-D input of shape (S, H, W) this produces (S, 3, H, W).
    NOTE(review): assumes the volume is 3-D (slices, height, width) at this
    point — confirm against MRDataset.
    """
    return volume.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)


# Training-time augmentation pipeline: tensor conversion, random rotation,
# random translation, random flip, then 3-channel replication.
augmentor = Compose([
    transforms.Lambda(_to_float_tensor),
    RandomRotate(25),
    RandomTranslate([0.11, 0.11]),
    RandomFlip(),
    transforms.Lambda(_replicate_to_three_channels),
])

# Honour args.augment (0/1): when it is off, build the training set exactly like
# the validation set (no transform argument passed).
train_dataset_kwargs = {"transform": augmentor} if args.augment == 1 else {}

train_dataset = MRDataset('./data/', args.task, args.plane, train=True, **train_dataset_kwargs)
validation_dataset = MRDataset('./data/', args.task, args.plane, train=False)

# NOTE(review): `DataLoader` is not imported explicitly; it presumably comes from
# `from fastai.vision.all import *`. batch_size=1 feeds one exam per step —
# presumably because volumes differ in slice count; confirm.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=num_workers, drop_last=False)
validation_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False, num_workers=num_workers, drop_last=False)


mrnet = model.MRNet()
# BCEWithLogitsLoss applies the sigmoid itself, so the model is expected to emit raw logits.
bceloss = nn.BCEWithLogitsLoss()

./data/train-abnormal.csv


In [31]:
# Bundle the loaders for fastai: the first is used for training, the second for validation.
train_valid_loaders = (train_loader, validation_loader)
dls = DataLoaders(*train_valid_loaders)

In [32]:
# Named `learner_metrics` (not `metrics`) so it does not shadow the
# `sklearn.metrics` module imported at the top of the notebook.
# NOTE(review): fastai's RocAuc is the multi-class variant (softmax over the last
# axis). If MRNet emits a single logit per exam, RocAucBinary is likely the right
# metric — confirm the model's output shape.
learner_metrics = [RocAuc()]

learn = Learner(dls, mrnet, loss_func=bceloss, metrics=learner_metrics)

In [33]:
# Epoch count and learning rate come from the config cell (args.epochs, args.lr)
# rather than being hard-coded here.
# NOTE(review): fine_tune first runs `freeze_epochs` (default 1) frozen epoch(s);
# for a custom model without a parameter-group splitter, freezing may be a no-op —
# consider learn.fit_one_cycle(args.epochs, args.lr) instead.
learn.fine_tune(args.epochs, base_lr=args.lr)

epoch,train_loss,valid_loss,roc_auc_score,time
0,,00:00,,


In [34]:
# Checkpoint name: first three letters of task and plane plus a version tag,
# e.g. "abn_sag_v1" -> models/abn_sag_v1.pth.
checkpoint_name = args.task[:3] + '_' + args.plane[:3] + '_v1'
# Respect the save_model switch (0/1). The last expression displays the saved
# path, or None when saving is disabled.
saved_path = learn.save(checkpoint_name) if args.save_model == 1 else None
saved_path

Path('models/abn_sag_v1.pth')