In [None]:
import numpy as np
import pickle
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T

%load_ext autoreload
%autoreload 2

In [None]:
# NOTE: pickle.load can execute arbitrary code -- only load these files from a trusted source.
with open("data/pkl_files/trimmed_inputs.pkl", "rb") as inputs_fp:
    inputs = pickle.load(inputs_fp)
with open("data/pkl_files/trimmed_outputs.pkl", "rb") as outputs_fp:
    outputs = pickle.load(outputs_fp)

In [None]:
# Sanity check: shapes of the loaded input and label arrays
for loaded_array in (inputs, outputs):
    print(loaded_array.shape)

In [None]:
# Sanity Check
idx = np.random.randint(0, inputs.shape[0])
def plot_sampled_images(idx):
    """Display sample ``idx``: first its label map, then each of its four input channels.

    Per the variable names this replaces, input channels 0-3 are FLAIR, T1, T1ce and T2.
    Each image is drawn in its own figure.
    """
    plt.imshow(outputs[idx, 0, :, :])
    plt.show()
    for channel in range(4):
        plt.imshow(inputs[idx, channel, :, :])
        plt.show()
plot_sampled_images(idx)

In [None]:
# Shuffle the sample indices reproducibly, then take the first 75% for training,
# the next 20% for validation and the final 5% for testing.
# NOTE(review): the split is per sample. If several samples come from the same scan
# (e.g. neighbouring slices of one patient -- not verifiable from here), they can land
# in different splits and leak information; consider splitting by patient instead.
np.random.seed(0)
shuffled_idxs = list(range(inputs.shape[0]))
np.random.shuffle(shuffled_idxs)
train_cutoff = int(inputs.shape[0]*75/100)
val_cutoff = int(inputs.shape[0]*95/100)

train_order = shuffled_idxs[:train_cutoff]
val_order = shuffled_idxs[train_cutoff:val_cutoff]
test_order = shuffled_idxs[val_cutoff:]

train_inputs = [inputs[i] for i in train_order]
train_outputs = [outputs[i] for i in train_order]
val_inputs = [inputs[i] for i in val_order]
val_outputs = [outputs[i] for i in val_order]
test_inputs = [inputs[i] for i in test_order]
test_outputs = [outputs[i] for i in test_order]

In [None]:
# sanity check: number of samples in each split (inputs and outputs should match pairwise)
for split_list in (train_inputs, train_outputs, val_inputs, val_outputs, test_inputs, test_outputs):
    print(len(split_list))

In [None]:
# Standardise inputs position-wise (per channel and pixel) using the mean/std of the
# TRAINING split only, so no statistics are taken from val/test.
train_inputs = np.asarray(train_inputs)
val_inputs = np.asarray(val_inputs)
test_inputs = np.asarray(test_inputs)

train_inputs_mean = np.mean(train_inputs, axis=0)
train_inputs_std = np.std(train_inputs, axis=0)
# Any position that is constant across the whole training set has std == 0; dividing by
# it would turn those entries into NaN/inf and poison training. Leave such positions
# unscaled (divide by 1) instead.
train_inputs_std = np.where(train_inputs_std == 0, 1, train_inputs_std)

# Out-of-place arithmetic: unlike `-=` / `/=`, this also works if the pickled arrays are
# integer-typed (in-place float subtraction into an int array raises a casting error).
train_inputs = (train_inputs - train_inputs_mean) / train_inputs_std
val_inputs = (val_inputs - train_inputs_mean) / train_inputs_std
test_inputs = (test_inputs - train_inputs_mean) / train_inputs_std

In [None]:
# Pixel statistics of the training labels, used later to build class weights.
train_outputs = np.asarray(train_outputs)
n_tumor_pixels = np.count_nonzero(train_outputs)
n_tumor1, n_tumor2, n_tumor3, n_tumor4 = (
    np.count_nonzero(train_outputs == label) for label in (1, 2, 3, 4)
)
n_label_images = train_outputs.shape[0]
n_rows, n_cols = train_outputs.shape[2:4]
total_pixels = n_label_images * n_rows * n_cols
print(total_pixels)
n_blank_pixels = total_pixels - n_tumor_pixels
frac_tumor = n_tumor_pixels / total_pixels
print(frac_tumor)

In [None]:
# Pair each input with its label so the DataLoaders yield [x, y] batches.
train_data = [[x, y] for x, y in zip(train_inputs, train_outputs)]
val_data = [[x, y] for x, y in zip(val_inputs, val_outputs)]
test_data = [[x, y] for x, y in zip(test_inputs, test_outputs)]

In [None]:
# Only the training loader is shuffled. Val/test are read in a fixed order, so the
# per-batch-averaged IoU computed by check_iou is reproducible from call to call.
# With shuffling, the composition of the smaller last batch changed on every call.
BATCH_SIZE = 64
train_loader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_data, shuffle=False, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_data, shuffle=False, batch_size=BATCH_SIZE)

print(len(train_loader))
print(len(train_data))

In [None]:
USE_GPU = True

dtype = torch.float32  # single floating-point dtype used for all model inputs

# Use CUDA only when it is both requested and actually present; otherwise run on the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

# Constant to control how frequently we print train loss
print_every = 30

print('using device:', device)

In [None]:
from models.unet import *

In [None]:
def test_UNet():
    """Smoke test: run an all-zero batch through a fresh UNet and print the output size."""
    in_channels, num_classes = 4, 5
    dummy_batch = torch.zeros((64, 4, 32, 32), dtype=dtype)
    net = UNet(in_channels, num_classes)
    print(net(dummy_batch).size())
test_UNet()

In [None]:
# Defining all parameters
print_every = 25
loss_history = []
learning_rate = 1e-1
input_channels = 4
n_classes = 5
# Place the model on `device` before building the optimizer, as PyTorch recommends.
# `device` is 'cuda' only when USE_GPU is set and CUDA is available, which matches the
# old conditional `model.cuda()`.
model = UNet(channel_in=input_channels, classes=n_classes).to(device)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
# Up-weight each tumour label by (label-0 pixel count / that label's pixel count) to
# counter the heavy class imbalance. Label 3 keeps weight 1.0. n_tumor3 is computed
# earlier but not used here, presumably because label 3 never occurs and the ratio
# would divide by zero -- TODO confirm against the data.
class_weights = [1.0, n_blank_pixels/float(n_tumor1), n_blank_pixels/float(n_tumor2), 1.0, n_blank_pixels/float(n_tumor4)]
print(class_weights)
# Build the weight tensor directly on the model's device. The previous unconditional
# `.cuda()` crashed on CPU-only machines even when USE_GPU / is_available() chose the CPU.
criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float32, device=device))
# TODO: a Dice loss is a possible alternative for this imbalanced segmentation task.

In [None]:
def create_masks(x):
    """One-hot encode a 2-D label map into a stack of per-class binary masks.

    Args:
        x: 2-D array of integral class ids (any numeric dtype).

    Returns:
        float64 array of shape (n_classes, x.shape[0], x.shape[1]). Slice ``c`` is 1.0
        where ``x == c`` and 0.0 elsewhere. ``n_classes`` is the notebook-level global.
    """
    class_ids = np.arange(n_classes).reshape(-1, 1, 1)
    return (np.asarray(x)[np.newaxis, :, :] == class_ids).astype(np.float64)

In [None]:
def compute_iou(scores, output):
    """Intersection-over-union of two masks.

    Args:
        scores: predicted mask; non-zero entries mark the positive region.
        output: ground-truth mask with the same shape as ``scores``.

    Returns:
        ``|scores AND output| / |scores OR output|`` as a float. A tiny epsilon in the
        denominator avoids division by zero, so two empty masks give 0.0.
        NOTE(review): this counts a class that is absent from both masks as a miss,
        which lowers averaged IoU for classes that are missing from many images.
    """
    # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24 (AttributeError).
    # The builtin `bool` gives the same boolean dtype.
    pred_mask = np.asarray(scores).astype(bool)
    true_mask = np.asarray(output).astype(bool)
    overlap = np.logical_and(pred_mask, true_mask)
    union = np.logical_or(pred_mask, true_mask)
    return overlap.sum() / (float(union.sum()) + 1e-10)

In [None]:
def iou_metric(preds, labels):
    """Per-class IoU averaged over the images of one batch.

    Args:
        preds: tensor of predicted class ids with one 2-D map per batch element
            (the previous comment states (64, 32, 32)).
        labels: tensor of ground-truth class ids with the same shape as ``preds``.

    Returns:
        1-D numpy array of length ``n_classes``. Entry ``c`` is the mean over the batch
        of ``compute_iou`` on the class-``c`` masks.
    """
    pred_maps = preds.cpu().numpy()
    label_maps = labels.cpu().numpy()

    per_image_iou = []
    for pred_map, label_map in zip(pred_maps, label_maps):
        pred_onehot = create_masks(pred_map)
        label_onehot = create_masks(label_map)
        per_image_iou.append([compute_iou(pred_onehot[c], label_onehot[c]) for c in range(n_classes)])
    # reshape keeps the (batch, n_classes) layout even for an empty batch
    iou_table = np.asarray(per_image_iou, dtype=np.float64).reshape(-1, n_classes)
    return np.mean(iou_table, axis=0)

In [None]:
def check_iou(loader, model):
    """Mean per-class IoU of ``model`` over every batch in ``loader``.

    Args:
        loader: iterable of ``(x, y)`` batches whose ``len`` is the number of batches
            (e.g. a DataLoader).
        model: network returning class scores along dim 1.

    Returns:
        1-D numpy array of length ``n_classes``: the per-batch IoU vectors from
        ``iou_metric``, averaged over batches (each batch weighted equally).
    """
    all_iou = np.zeros((len(loader), n_classes))

    # Evaluate in eval mode so that BatchNorm/Dropout layers (if the model has any) use
    # inference behaviour and BN running statistics are not updated with evaluation
    # data. The caller's mode is restored afterwards, so calling this from inside a
    # training loop does not leave the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for t, (x, y) in enumerate(loader):
                x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
                # Flatten labels to (batch, H, W) class-id maps; previously hard-coded to 32x32.
                y = y.to(device=device, dtype=torch.float).view(y.shape[0], *y.shape[-2:])
                scores = model(x)
                _, preds = scores.max(1)
                all_iou[t] = iou_metric(preds, y)
    finally:
        model.train(was_training)
    return np.mean(all_iou, axis=0)

In [None]:
val_iou_before_training = check_iou(val_loader, model)  # per-class IoU of the untrained model
val_iou_before_training

In [None]:
def plot_output(scores, y, min_tumor_frac=0.10):
    """Show the predicted and ground-truth label maps for one random sample of a batch.

    Args:
        scores: model scores with classes along dim 1; the argmax over dim 1 is plotted.
        y: ground-truth label maps for the same batch, indexable as ``y[i, :, :]``.
        min_tumor_frac: a sample is eligible only if at least this fraction of its
            ground-truth pixels are non-zero. Defaults to 10%, as before.

    The sample is drawn uniformly from the eligible ones, and index 0 is now included.
    If no sample meets the threshold, a uniformly random sample is shown. Previously the
    selection loop never terminated in that case.
    """
    with torch.no_grad():
        _, preds = scores.max(1)
        numpy_scores = preds.cpu().numpy().astype(np.uint8)
        numpy_truth = y.cpu().numpy().astype(np.uint8)

    n_samples = numpy_truth.shape[0]
    if n_samples == 0:
        return
    # Threshold scales with the real label-map size instead of assuming 32*32.
    min_tumor_pixels = int(min_tumor_frac * numpy_truth[0].size)
    tumor_counts = np.count_nonzero(numpy_truth.reshape(n_samples, -1), axis=1)
    eligible = np.flatnonzero(tumor_counts >= min_tumor_pixels)
    if eligible.size > 0:
        idx = eligible[np.random.randint(0, eligible.size)]
    else:
        idx = np.random.randint(0, n_samples)

    plt.title('Segmented Output')
    plt.imshow(numpy_scores[idx,:,:])
    plt.show()
    plt.title('Ground Truth')
    plt.imshow(numpy_truth[idx,:,:])
    plt.show()

In [None]:
import torch.nn.functional as F
loss_history = []
val_iou_history = []
train_iou_history = []
plot_every = 100
def train(model, optimizer, epochs=1):
    """Train ``model`` on the global ``train_loader`` with the global ``criterion``.

    Every ``print_every`` iterations (counted from 0 within each epoch), the current loss
    is printed and the per-class IoU on ``val_loader`` and ``train_loader`` is recorded.

    Side effects: clears, then refills, the global ``loss_history`` (one entry per
    iteration) and ``val_iou_history`` / ``train_iou_history`` (one entry per evaluation).
    """
    model = model.to(device=device)
    for history in (loss_history, val_iou_history, train_iou_history):
        history.clear()

    for epoch in range(epochs):
        for step, (batch_inputs, batch_targets) in enumerate(train_loader):
            model.train()
            batch_inputs = batch_inputs.to(device=device, dtype=dtype)
            # Integer class-id maps of shape (N, 32, 32), as CrossEntropyLoss expects.
            labels = batch_targets.to(device=device, dtype=torch.long).view(batch_targets.shape[0], 32, 32)

            loss = criterion(model(batch_inputs), labels)
            loss_history.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % print_every == 0:
                print('Epoch %d, Iteration %d, loss = %.4f' % (epoch, step, loss.item()))
                val_iou = check_iou(val_loader, model)
                val_iou_history.append(val_iou)
                train_iou = check_iou(train_loader, model)
                train_iou_history.append(train_iou)
                print('Validation IOU: ', val_iou)
                print('Training IOU: ', train_iou)
train(model, optimizer, epochs=250)

In [None]:
# Training loss, subsampled to every 100th iteration to keep the curve readable.
LOSS_PLOT_STRIDE = 100
fig, ax = plt.subplots()
ax.set_title('Loss over time')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
ax.plot(range(0, len(loss_history), LOSS_PLOT_STRIDE), loss_history[::LOSS_PLOT_STRIDE])
fig.savefig('Loss over time (5 classes).png')

In [None]:
# Stack the recorded per-class IoU vectors: one row per evaluation, one column per class.
val_hist = np.asarray(val_iou_history)
train_hist = np.asarray(train_iou_history)
for hist in (val_hist, train_hist):
    print(hist.shape)

In [None]:
fig = plt.figure()
plt.title('IOU for no tumor over time')
# IoU is recorded only every `print_every` training steps (within each epoch), so x is
# the index of the evaluation checkpoint, not the training iteration.
plt.xlabel('Evaluation checkpoint')
plt.ylabel('IOU')
plt.plot(range(val_hist.shape[0]), val_hist[:,0], label='val')
plt.plot(range(train_hist.shape[0]), train_hist[:,0], label='train')
plt.legend()
fig.savefig('IOU for no tumor (5 classes) - Poster.png')

In [None]:
fig = plt.figure()
plt.title('IOU for tumor 1 over time')
# IoU is recorded only every `print_every` training steps (within each epoch), so x is
# the index of the evaluation checkpoint, not the training iteration.
plt.xlabel('Evaluation checkpoint')
plt.ylabel('IOU')
plt.plot(range(val_hist.shape[0]), val_hist[:,1], label='val')
plt.plot(range(train_hist.shape[0]), train_hist[:,1], label='train')
plt.legend()
fig.savefig('IOU for tumor 1 (5 classes) - Poster.png')

In [None]:
fig = plt.figure()
plt.title('IOU for tumor 2 over time')
# IoU is recorded only every `print_every` training steps (within each epoch), so x is
# the index of the evaluation checkpoint, not the training iteration.
plt.xlabel('Evaluation checkpoint')
plt.ylabel('IOU')
plt.plot(range(val_hist.shape[0]), val_hist[:,2], label='val')
plt.plot(range(train_hist.shape[0]), train_hist[:,2], label='train')
plt.legend()
fig.savefig('IOU for tumor 2 (5 classes) - Poster.png')

In [None]:
fig = plt.figure()
plt.title('IOU for tumor 4 over time')
# IoU is recorded only every `print_every` training steps (within each epoch), so x is
# the index of the evaluation checkpoint, not the training iteration.
plt.xlabel('Evaluation checkpoint')
plt.ylabel('IOU')
plt.plot(range(val_hist.shape[0]), val_hist[:,4], label='val')
plt.plot(range(train_hist.shape[0]), train_hist[:,4], label='train')
plt.legend()
fig.savefig('IOU for tumor 4 (5 classes) - Poster.png')

In [None]:
test_iou = check_iou(test_loader, model)  # final per-class IoU on the held-out test split
test_iou