In [None]:
from __future__ import division
import os
import torch
import torch.nn as nn
import numpy as np
from scipy.ndimage import zoom
import math
import gc
import sys; sys.path.insert(0, '..')

from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from models.models import *
from common.datasets import *

from tqdm.notebook import tqdm

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

<h1>Pretrain Generator</h1>

In [None]:
# Function for the validation pass
def validation(model, validateloader, criterion):
    """Run one evaluation pass and return the summed per-batch loss.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        validateloader: sized iterable yielding ``(xs, ys)`` tensor pairs.
        criterion: callable ``criterion(outputs, ys)`` returning a scalar tensor.

    Returns:
        float: sum of the per-batch losses (not averaged; callers divide by
        ``len(validateloader)`` themselves). ``0.0`` if the loader yields no
        batches.
    """
    model.eval()
    # Accumulate as a Python float so an empty loader still returns a number
    # (calling .item() on the initial int 0 used to raise AttributeError).
    val_loss = 0.0

    with torch.no_grad(), tqdm(total=len(validateloader)) as pbar:
        for xs, ys in validateloader:
            # Insert a singleton dim at position 1 and cast to float32, the same
            # preparation the pretraining loop applies before calling the model.
            xs = torch.unsqueeze(xs, 1).float().to(device)
            ys = torch.unsqueeze(ys, 1).float().to(device)
            outputs = model(xs)  # model prediction for this batch
            val_loss += criterion(outputs, ys).item()
            pbar.update(1)

    return val_loss

In [None]:
# Locations of the input ("bone") and target ("flesh") scans
path = "/path/to/data/" #Change to where your data is stored
xpath = os.path.join(path, "bone")
ypath = os.path.join(path, "flesh")
# os.listdir returns entries in arbitrary order. Sort both listings so the file
# lists are reproducible and line up index-for-index across the two folders
# (assumes paired scans sort identically - TODO confirm against NumpyDataset).
xnames = sorted(os.listdir(xpath))
ynames = sorted(os.listdir(ypath))
split = .2  # fraction of the dataset held out for validation

# Derive the normalisation statistics and volume height from the first scan
sample = np.load(os.path.join(xpath, xnames[0]))
mean = sample.mean()
std = sample.std()
height = sample.shape[0]
sample = None  # drop the reference so the array can be garbage collected
transform = transforms.Compose([transforms.Normalize(mean=[mean], std=[std])])

In [None]:
# Batch geometry: gradients for an effective batch of `bs` samples are
# accumulated over `num_mb` minibatches of `mbs` samples each.
bs = 16
num_mb = 4
mbs = bs // num_mb
nw = 4  # DataLoader worker processes

full = NumpyDataset(
    xnames,
    y=ynames,
    xpath=xpath,
    ypath=ypath,
    transform=transform,
    zoom=.5,
    square=True,
)

# Randomly partition into train / validation subsets; `split` is the validation fraction
split_idx = math.floor(len(full) * (1 - split))
train, valid = torch.utils.data.random_split(full, (split_idx, len(full) - split_idx))

loader_kwargs = dict(batch_size=mbs, shuffle=True, num_workers=nw)
train_loader = DataLoader(train, **loader_kwargs)
valid_loader = DataLoader(valid, **loader_kwargs)

In [None]:
# Generator stage: VNet sized from the scan height, MSE loss, Adam optimiser.
# (The discriminator, its BCE loss and its optimiser are created in later sections.)
G_criterion = nn.MSELoss() # mean squared error loss
G = VNet(height)
G = G.to(device)
G_opt = torch.optim.Adam(
    G.parameters(),
    lr=0.0003,
    betas=(0.5, 0.999),
)

In [None]:
# Resume the generator from a previously saved pretraining checkpoint
savepath = "./pretrain4"
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU machine still loads when `device` fell back to the CPU.
checkpoint = torch.load(savepath, map_location=device)
G.load_state_dict(checkpoint['G_state_dict'])

In [None]:
# Per-epoch loss histories, appended to by the generator pretraining loop
train_losses, valid_losses = [], []

In [None]:
# Pretrain Generator
total_epoch = 40
num_batches = len(train_loader)
G.float() # keep the parameters in float32 to match the float32 inputs below

for epoch in range(total_epoch): # loops through number of epochs
    running_loss = 0
    G.train() # validation() leaves the model in eval mode, so re-enable training every epoch
    G_opt.zero_grad()

    with tqdm(total=num_batches) as pbar:
        for i, data in enumerate(train_loader): # loops through training_loader
            # Separate, add the singleton dim at position 1, move to device, cast to float32
            inputs, labels = data
            inputs = torch.unsqueeze(inputs, 1).to(device).float()
            labels = torch.unsqueeze(labels, 1).to(device).float()

            # forward + backward
            outputs = G(inputs)
            loss = G_criterion(outputs, labels)
            loss.backward()
            running_loss += loss.item()

            # Gradients accumulate over num_mb minibatches before each optimiser step.
            # Also step on the final minibatch of the epoch so a trailing partial group
            # is applied instead of being wiped by the zero_grad() of the next epoch.
            if (i + 1) % num_mb == 0 or (i + 1) == num_batches:
                G_opt.step()
                G_opt.zero_grad()

            pbar.update(1)

    val_loss = validation(G, valid_loader, G_criterion)

    # validation() returns a summed loss, so both figures are averaged per batch here
    epoch_train_loss = running_loss/num_batches
    epoch_valid_loss = val_loss/len(valid_loader)
    train_losses.append(epoch_train_loss)
    valid_losses.append(epoch_valid_loss)

    print("Epoch: {}/{}, Training Loss: {}, Validation Loss: {}".format(epoch+1, total_epoch, epoch_train_loss, epoch_valid_loss))
    print('-' * 20)

print("Finished Training")

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
 
# multiple line plot
plt.plot(UNet_t)
plt.plot(VNet_t)
plt.plot(DVNet_t)
plt.xlabel("epochs")
plt.ylabel("loss")
plt.legend(["UNet", "VNet", "DVNet"])
plt.title("Training Losses")
plt.show

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
 
# multiple line plot
plt.plot(UNet_v)
plt.plot(VNet_v)
plt.plot(DVNet_v)
plt.xlabel("epochs")
plt.ylabel("loss")
plt.legend(["UNet", "VNet", "DVNet"])
plt.title("Validation Losses")
plt.show

In [None]:
# Persist the generator weights under the key the loading cells read back
savepath = "./pretrain4"
generator_checkpoint = {'G_state_dict': G.state_dict()}
torch.save(generator_checkpoint, savepath)

In [None]:
# Run the pretrained generator via get_generations.
# NOTE(review): argument roles (output dir, then input dir) are assumed from the
# folder names - confirm against get_generations in common.datasets.
path = "../testnpys"
normalize = transforms.Normalize(mean=[mean], std=[std])
transform = transforms.Compose([normalize])
generated_dir = path + "/generated"
bone_dir = path + "/bone"
get_generations(G, generated_dir, bone_dir, zoom_f=.5)

<h1>Pretrain Discriminator</h1>

This section still does not work on massive; it may also not be necessary.

In [None]:
# Function for the validation pass
def validation(model, validateloader, criterion):
    """Evaluate a classifier and return ``(summed loss, accuracy)``.

    NOTE: this redefines the ``validation`` function used by the generator
    pretraining section earlier in the notebook and returns a tuple instead of a
    single float, so re-running the generator loop after this cell will break.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        validateloader: iterable yielding ``(images, labels)`` tensor pairs, where
            ``labels`` holds one score per class along dim 1 (argmax is taken).
        criterion: callable ``criterion(outputs, labels)`` returning a scalar tensor.

    Returns:
        tuple[float, float]: the sum of the per-batch losses (not averaged) and the
        fraction of samples whose argmax prediction matches the argmax label.
        ``(0.0, 0.0)`` when the loader yields no batches.
    """
    model.eval()
    # Python-float accumulator: an empty loader then returns 0.0 instead of
    # failing on int.item()
    val_loss = 0.0
    total = 0
    correct = 0

    with torch.no_grad():
        for images, labels in validateloader:
            images, labels = images.to(device), labels.to(device)
            # Both operands are promoted to float64 before the loss is computed
            outputs = model(images).double()
            labels = labels.double()

            val_loss += criterion(outputs, labels).item()
            _, idxprediction = torch.max(outputs, 1) # predicted class index per sample
            _, idxlabels = torch.max(labels, 1) # target class index per sample

            total += labels.size(0) # number of samples seen
            correct += (idxprediction == idxlabels).sum().item() # number predicted correctly

    # Guard the division so an empty loader does not raise ZeroDivisionError
    accuracy = correct / total if total else 0.0
    return val_loss, accuracy

In [None]:
# Locations of the two discriminator classes: "bone" scans and "generated" scans
path = "/path/to/data"
xpath = os.path.join(path, "bone")
ypath = os.path.join(path, "generated")
# os.listdir returns entries in arbitrary order; sort so the file lists, and the
# scan used for the statistics below, are the same on every run.
xnames = sorted(os.listdir(xpath))
ynames = sorted(os.listdir(ypath))
split = .2  # fraction of the dataset held out for validation

# Derive the normalisation statistics and volume height from the first scan
sample = np.load(os.path.join(xpath, xnames[0]))
mean = sample.mean()
std = sample.std()
height = sample.shape[0]
sample = None  # drop the reference so the array can be garbage collected
transform = transforms.Compose([transforms.Normalize(mean=[mean], std=[std])])

In [None]:
# Two-class dataset built from the two file listings
full = NumpyClassDataset(
    xnames,
    ynames,
    class1path=xpath,
    class2path=ypath,
    transform=transform,
    square=True,
)

# Randomly partition into train / validation subsets; `split` is the validation fraction
split_idx = math.floor(len(full) * (1 - split))
train, valid = torch.utils.data.random_split(full, (split_idx, len(full) - split_idx))

loader_kwargs = dict(batch_size=15, shuffle=True, num_workers=11)
train_loader = DataLoader(train, **loader_kwargs)
valid_loader = DataLoader(valid, **loader_kwargs)

In [None]:
# Discriminator stage: ResNet with two output classes, BCE loss, Adam optimiser
criterion = nn.BCELoss() # binary cross entropy loss
D = ResNet(
    BasicBlock,
    [2, 2, 2, 2],
    sample_size=112,
    sample_duration=16,
    num_classes=2,
)
D = D.to(device)
opt = torch.optim.Adam(D.parameters(), lr=0.001)

In [None]:
# NOTE(review): this loads "./toast", while the save cell at the end of this
# section writes "./pretrain_D" - confirm which checkpoint is intended.
savepath = "./toast"
# train_loader = None
# valid_loader = None
# Release the previous stage's last output before loading to free memory
outputs = None
gc.collect()

# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU machine still loads when `device` fell back to the CPU.
checkpoint = torch.load(savepath, map_location=device)
D.load_state_dict(checkpoint['D_state_dict'])

In [None]:
# Pretrain the discriminator: one optimiser step per batch, validation after every epoch
total_epoch = 30
for epoch in range(total_epoch):
    running_loss = 0
    for step, (inputs, labels) in enumerate(train_loader, start=1):
        # validation() switches the model to eval mode, so training mode is
        # (re)asserted here; parameters are kept in float32
        D.train()
        D.float()
        print(step, "/", len(train_loader))

        # Move the batch to the device and cast it to float32
        inputs = inputs.to(device).float()
        labels = labels.to(device).float()

        # forward + backward + optimise
        opt.zero_grad()
        outputs = D(inputs).float()
        loss = criterion(outputs, labels)
        loss.backward()
        opt.step()

        # accumulate the epoch's loss for the summary line
        running_loss += loss.item()

    # validation() returns the summed validation loss and the accuracy
    val_loss, accuracy = validation(D, valid_loader, criterion)
    print("Epoch: {}/{}, Loss: {}, Val Loss: {}, Val Accuracy: {}".format(epoch+1, total_epoch, running_loss/len(train_loader), val_loss, accuracy))
    print('-' * 20)

print("Finished Training")

In [None]:
# Persist the discriminator weights under the key the loading cells read back
savepath = "./pretrain_D"
discriminator_checkpoint = {'D_state_dict': D.state_dict()}
torch.save(discriminator_checkpoint, savepath)

<h1>Combine Models (NoGAN)</h1>

In [None]:
# Locations of the input ("bone") and target ("flesh") scans
path = "/path/to/data/"
xpath = os.path.join(path, "bone")
ypath = os.path.join(path, "flesh")
# os.listdir returns entries in arbitrary order. Sort both listings so the file
# lists are reproducible and line up index-for-index across the two folders
# (assumes paired scans sort identically - TODO confirm against NumpyDataset).
xnames = sorted(os.listdir(xpath))
ynames = sorted(os.listdir(ypath))
split = .2

# Derive the normalisation statistics and volume height from the first scan
sample = np.load(os.path.join(xpath, xnames[0]))
mean = sample.mean()
std = sample.std()
height = sample.shape[0]
sample = None  # drop the reference so the array can be garbage collected
transform = transforms.Compose([transforms.Normalize(mean=[mean], std=[std])])

In [None]:
bs = 12  # batch size for both loaders
nw = 12  # DataLoader worker processes

full = NumpyDataset(
    xnames,
    y=ynames,
    xpath=xpath,
    ypath=ypath,
    transform=transform,
    zoom=.5,
    square=True,
)

# Keep 5 randomly chosen samples aside; everything else is used for training
train, valid = torch.utils.data.random_split(full, (len(full) - 5, 5))

loader_kwargs = dict(batch_size=bs, shuffle=True, num_workers=nw)
train_loader = DataLoader(train, **loader_kwargs)
valid_loader = DataLoader(valid, **loader_kwargs)

In [None]:
# Both optimisers share the same Adam settings
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))

# Generator: VNet sized from the scan height, scored with mean squared error
G = VNet(height)
G = G.to(device)
G_criterion = nn.MSELoss() # mean squared error loss
G_opt = torch.optim.Adam(G.parameters(), **adam_kwargs)

# Discriminator: two-class ResNet, scored with binary cross entropy
D = ResNet(BasicBlock, [2, 2, 2, 2], sample_size=112, sample_duration=16, num_classes=2)
D = D.to(device)
D_criterion = nn.BCELoss() # binary cross entropy loss
D_opt = torch.optim.Adam(D.parameters(), **adam_kwargs)

In [None]:
# Target rows for a full batch of `bs` samples; the training loop slices them
# down to the actual batch length.
# D_reals uses softened targets (.9 / .1) - label smoothing to encourage the
# discriminator to generalise; G_reals and D_fakes use hard targets.
D_reals = torch.FloatTensor([[.9, .1] for _ in range(bs)]).to(device)
D_fakes = torch.FloatTensor([[0, 1] for _ in range(bs)]).to(device)
G_reals = torch.FloatTensor([[1, 0] for _ in range(bs)]).to(device)

In [None]:
# Settings for the combined training loop
max_epoch = 5        # number of passes over train_loader
loss_scaling = 30    # weight applied to the generator's MSE term in the combined loss
D_threshold = .2     # the generator is only trained when the discriminator loss is below this
D_loss = D_threshold + 1  # starts above the threshold
count = 0            # incremented on every generator update
cycles = 1           # not referenced by the training loop in this notebook

In [None]:
# If we continue to train the discriminator until some threshold does this mean that we do not need to pretrain the discriminator?

In [None]:
# savepath = "./pretrain4"
# checkpoint = torch.load(savepath)
# G.load_state_dict(checkpoint['G_state_dict'])

In [None]:
# Resume both networks from a previously saved GAN checkpoint
savepath = "./GAN"
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU machine still loads when `device` fell back to the CPU.
checkpoint = torch.load(savepath, map_location=device)
G.load_state_dict(checkpoint['G_state_dict'])
D.load_state_dict(checkpoint['D_state_dict'])

In [None]:
for epoch in range(max_epoch):
    # Training mode and float32 parameters; re-applied every epoch because the
    # sampling pass at the end of the epoch switches G to eval mode.
    D.train()
    D.float()
    G.train()
    G.float()

    for idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.float()
        labels = labels.float()
        inputs, labels = inputs.to(device), labels.to(device) # send data to the training device

        print(idx + 1, "/", len(train_loader))

        # ---- Training Discriminator (D) ----
        # NOTE(review): the "real" batch shown to D is the input scans, not `labels` -
        # confirm this is intended.
        # NOTE(review): D receives the unsqueezed inputs while G receives `inputs` as
        # loaded, whereas the pretraining loop unsqueezes before calling G - confirm
        # the shapes each network expects.
        un_inputs = torch.unsqueeze(inputs, 1)

        # D's loss on real data; targets are sliced to handle a shorter last batch
        x_outputs = D(un_inputs)
        D_x_loss = D_criterion(x_outputs, D_reals[0:len(x_outputs)])

        # D's loss on generated data. detach() stops this backward pass from running
        # through G, which is not updated in the discriminator step.
        z_outputs = D(G(inputs).detach())
        D_z_loss = D_criterion(z_outputs, D_fakes[0:len(z_outputs)])

        # total loss
        D_loss = D_x_loss + D_z_loss
        print("Discriminator", D_loss)

        # back prop
        D.zero_grad()
        D_loss.backward()
        D_opt.step()

        # The generator is only updated while the discriminator loss is below the threshold
        if D_loss < D_threshold:
            # ---- Training Generator (G) ----
            # Generate images
            z_outputs = G(inputs)

            # Adversarial term: D's loss when generated data is labelled as real
            D_outputs = D(z_outputs)
            D_loss2 = D_criterion(D_outputs, G_reals[0:len(D_outputs)])
            G_loss = G_criterion(z_outputs, labels)

            # Combine the scaled reconstruction loss with the adversarial loss
            C_loss = G_loss * loss_scaling + D_loss2
            print("Generator", G_loss, D_loss2)

            # back prop
            G.zero_grad()
            C_loss.backward()
            G_opt.step()

            count += 1

    # Save sample generations for this epoch
    G.eval()
    with torch.no_grad(): # inference only: no autograd graph is needed
        for (inputs, _) in train_loader:
            # Same float32 cast as the training step, since G's parameters are float32
            inputs = inputs.float().to(device)
            outputs = G(inputs)
            for j, output in enumerate(outputs):
                output.mul_(std).add_(mean) # undo the Normalize(mean, std) transform
                output = output.cpu().numpy()
                np.save("../testnpys/GAN/" + str(epoch) + "_" + str(j) + ".npy", output)
            # File names depend only on (epoch, j), so later batches would just
            # overwrite these files; one shuffled batch gives the same set of samples.
            break

In [None]:
# Persist both networks in one checkpoint, keyed the way the loading cells expect
savepath = "./GAN"
gan_checkpoint = {
    'G_state_dict': G.state_dict(),
    'D_state_dict': D.state_dict(),
}
torch.save(gan_checkpoint, savepath)

In [None]:
# Reload both networks from the GAN checkpoint
savepath = "./GAN"
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU machine still loads when `device` fell back to the CPU.
checkpoint = torch.load(savepath, map_location=device)
G.load_state_dict(checkpoint['G_state_dict'])
D.load_state_dict(checkpoint['D_state_dict'])

<h1>Save</h1>

In [None]:
# Produce generations with the GAN-trained generator.
# NOTE(review): `path` is whatever an earlier cell last assigned (the data root set
# in the "Combine Models" section, not "../testnpys") - confirm this is the intended
# location. Unlike the pretraining call, no zoom_f is passed here.
gan_dir = path + "/GAN"
bone_dir = path + "/bone"
get_generations(G, gan_dir, bone_dir, num=5)