In [99]:
%matplotlib notebook
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms, datasets
import torchvision.transforms as transforms
from transform import Resize, RandomZoom, RandomHorizontalFlip

import matplotlib.pyplot as plt
from einops import rearrange, repeat
import numpy as np
import argparse
import os
from data import build_oasis, build_mprage
from train import train
from metrics import dice
from models.unet import UNet

%load_ext autoreload
%autoreload 2

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [100]:
pre_train_mean = 0.4
pre_train_std = 0.25

In [194]:
transform_train =  transforms.Compose([
                                        Resize(100),
                                        RandomZoom(.15),
                                        RandomHorizontalFlip(.5)
                                    ])

transform_val = None

In [195]:
# train_set    = build_oasis(root='/scratch/backUps/jzopes/data/oasis_project/Transformer/', train=True, transform=transform_train)
# val_set      = build_oasis(root='/scratch/backUps/jzopes/data/oasis_project/Transformer/', train=False, transform=transform_val)


train_set    = build_mprage(root='/scratch/mplatscher/imaging_data/', train=True, train_size=0.8, transform=transform_train)
val_set      = build_mprage(root='/scratch/mplatscher/imaging_data/', train=False, train_size=0.8, transform=transform_val)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=1)
val_loader   = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=True, num_workers=1)



In [196]:
for sample in train_loader:
    img, lbl = sample
    break

In [197]:
plt.figure()

plt.hist(img[0,0].flatten(), bins=50)

<IPython.core.display.Javascript object>

(array([787474.,  34611.,   5405.,   2411.,   1832.,   1651.,   1623.,
          1649.,   1869.,   2110.,   2333.,   2478.,   2851.,   3110.,
          3481.,   3690.,   3870.,   4176.,   4399.,   4640.,   4844.,
          5007.,   5521.,   5723.,   6197.,   6701.,   6785.,   7015.,
          7175.,   7206.,   6732.,   6039.,   5452.,   4812.,   4454.,
          3998.,   3698.,   3329.,   3220.,   2879.,   2685.,   2658.,
          2475.,   2336.,   1940.,   1518.,   1202.,    977.,    878.,
           881.]),
 array([0.  , 0.02, 0.04, 0.06, 0.08, 0.1 , 0.12, 0.14, 0.16, 0.18, 0.2 ,
        0.22, 0.24, 0.26, 0.28, 0.3 , 0.32, 0.34, 0.36, 0.38, 0.4 , 0.42,
        0.44, 0.46, 0.48, 0.5 , 0.52, 0.54, 0.56, 0.58, 0.6 , 0.62, 0.64,
        0.66, 0.68, 0.7 , 0.72, 0.74, 0.76, 0.78, 0.8 , 0.82, 0.84, 0.86,
        0.88, 0.9 , 0.92, 0.94, 0.96, 0.98, 1.  ]),
 <a list of 50 Patch objects>)

In [202]:
plt.figure()

plt.imshow(img[0,0].numpy()[::-1,::-1,44].T, cmap='bone')
plt.contourf(lbl[0,0].numpy()[::-1,::-1,44].T, alpha=.3, levels=np.unique(lbl))

<IPython.core.display.Javascript object>

<matplotlib.contour.QuadContourSet at 0x7f5f87145b00>

## Check Augmentations

In [None]:
r = Resize(64)
r = RandomHorizontalFlip(.9)
r = RandomZoom(1.9)

In [None]:
fig, ax = plt.subplots(3,2, figsize=(10,8))

sample_trans = r(sample)
ax[0,0].imshow(sample[0][:,:,64], cmap='bone')
ax[1,0].imshow(sample[0][:,44,:], cmap='bone')
ax[2,0].imshow(sample[0][64,...], cmap='bone')
ax[0,0].contourf(sample[1][:,:,64], levels=np.arange(args['classes']), colors=list(sns.color_palette('pastel', args['classes'])), alpha=.5)
ax[1,0].contour(sample[1][:,44,:], levels=np.arange(args['classes']), colors=list(sns.color_palette('pastel', args['classes'])), alpha=.5)
ax[2,0].contour(sample[1][64,...], levels=np.arange(args['classes']), colors=list(sns.color_palette('pastel', args['classes'])), alpha=.5)



ax[0,1].imshow(sample_trans[0].squeeze()[...,64], cmap='bone')
ax[1,1].imshow(sample_trans[0].squeeze()[:,44,:], cmap='bone')
ax[2,1].imshow(sample_trans[0].squeeze()[64,...], cmap='bone')
ax[0,1].contourf(sample_trans[1].squeeze()[:,:,64], levels=np.arange(args['classes']), colors=list(sns.color_palette('pastel', args['classes'])), alpha=.5)
ax[1,1].contour(sample_trans[1].squeeze()[:,44,:], levels=np.arange(args['classes']), colors=list(sns.color_palette('pastel', args['classes'])), alpha=.5)
ax[2,1].contour(sample_trans[1].squeeze()[64,...], levels=np.arange(args['classes']), colors=list(sns.color_palette('pastel', args['classes'])), alpha=.5)