In [None]:
from __future__ import division
import os
import torch
import torch.nn as nn
import numpy as np
from scipy.ndimage import zoom
import math
import gc
import sys; sys.path.insert(0, '..')

from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from models.models import *
from common.datasets import *

from tqdm.notebook import tqdm

In [None]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

<h1>Pretrain Generator</h1>

In [None]:
# Function for the validation pass
def validation(model, validateloader, criterion):
    """Run one pass over ``validateloader`` and return the summed batch loss.

    Args:
        model: ``nn.Module`` to evaluate. It is switched to eval mode and left
            there; callers that keep training must call ``model.train()``.
        validateloader: iterable of ``(xs, ys)`` tensor batches that supports
            ``len()``. An axis is inserted at dim 1 of both tensors before the
            forward pass.
        criterion: callable ``criterion(outputs, ys)`` returning a scalar tensor.

    Returns:
        float: sum (not mean) of the per-batch losses; ``0.0`` for an empty
        loader. Callers divide by ``len(validateloader)`` themselves.
    """
    model.eval()

    # The training loops cast their batches to float explicitly; do the
    # equivalent here by matching floating-point tensors to the model's
    # parameter dtype so a float model never receives double inputs.
    first_param = next(model.parameters(), None)
    model_dtype = first_param.dtype if first_param is not None else None

    def _prepare(batch):
        batch = torch.unsqueeze(batch, 1).to(device)
        if model_dtype is not None and batch.is_floating_point():
            batch = batch.to(model_dtype)
        return batch

    # Accumulate as a Python float: the previous `0` + `.item()` combination
    # raised AttributeError whenever the loader yielded no batches.
    val_loss = 0.0
    with torch.no_grad(), tqdm(total=len(validateloader)) as pbar:
        for xs, ys in validateloader:
            xs, ys = _prepare(xs), _prepare(ys)
            outputs = model(xs)  # model predictions for this batch
            val_loss += criterion(outputs, ys).item()
            pbar.update(1)

    return val_loss

In [None]:
# Dataset location: inputs are read from <path>/bone, targets from <path>/flesh.
path = "/path/to/data/"
xpath = path + "/bone"
ypath = path + "/flesh"
# os.listdir returns entries in arbitrary order. Sort both listings so the
# bone/flesh file lists line up by name and so the "first scan" used for the
# normalisation statistics is the same on every run.
# NOTE(review): assumes NumpyDataset pairs xnames[i] with ynames[i] — confirm.
xnames = sorted(os.listdir(xpath))
ynames = sorted(os.listdir(ypath))
split = .2  # fraction of the dataset held out for validation

# Get transforms from first scan
sample = np.load(xpath + "/" + xnames[0])
mean = sample.mean()
std = sample.std()
height = sample.shape[0]
sample = None  # drop the reference so the array can be freed
transform = transforms.Compose([transforms.Normalize(mean=[mean], std=[std])])

In [None]:
full = NumpyDataset(xnames, y=ynames, xpath=xpath, ypath=ypath, transform=transform, zoom=.5, square=True)
split_idx = math.floor(len(full) * (1 - split))

# Seed the split: this notebook reloads checkpoints and keeps training, so an
# unseeded split would move previously-trained scans into the validation set
# after every kernel restart.
split_rng = torch.Generator().manual_seed(0)
train, valid = torch.utils.data.random_split(full, (split_idx, len(full) - split_idx), generator=split_rng)

# Get minibatch size: `bs` samples per optimiser step, accumulated over
# `num_mb` minibatches of `mbs` samples each.
bs = 15
num_mb = 5
mbs = bs // num_mb

nw = 4  # DataLoader worker processes
train_loader = DataLoader(train, batch_size=mbs, shuffle=True, num_workers=nw)
# Validation order does not affect the summed loss, so there is nothing to shuffle.
valid_loader = DataLoader(valid, batch_size=mbs, shuffle=False, num_workers=nw)

In [None]:
# Make Discriminator and remove output layers
# In this pretraining section D is only a source of encoder layers: it is not
# moved to `device` and no optimiser is created for it here.
D = ResNet(BasicBlock, [2, 2, 2, 2], sample_size=112, sample_duration=16, num_classes=2, conv1_t_size=3)
# Keep every top-level child module of D except the last three.
# NOTE(review): presumably those three are the pooling/classifier head — confirm
# against the ResNet definition.
body = nn.Sequential(*list(D.children())[:-3])

# Generate VNet
# img_size is the shape of the first element of the first training sample
# (the same element the training loop uses as the network input).
G = DynamicVnet(body, img_size=train[0][0].shape, blur=False, blur_final=False,
          self_attention=False, norm_type=None, last_cross=True, bottle=False).to(device)
# Pixel-wise mean squared error between generated and target volumes.
G_criterion = nn.MSELoss()
G_opt = torch.optim.Adam(G.parameters(), lr=0.0003, betas=(0.5, 0.999))

In [None]:
# Resume the generator from a previously saved pretraining checkpoint.
# NOTE(review): this file is written by the save cell further down, so on a
# fresh run it will not exist yet — skip this cell in that case.
savepath = "./pretrain_smol"
# map_location remaps the stored tensors onto whatever `device` resolved to, so
# a checkpoint saved on a GPU still loads on a CPU-only machine.
checkpoint = torch.load(savepath, map_location=device)
G.load_state_dict(checkpoint['G_state_dict'])

In [None]:
# Per-epoch loss histories; the pretraining loop appends one value to each per epoch.
train_losses, valid_losses = [], []

In [None]:
# Pretrain Generator
# Gradients are accumulated over `num_mb` minibatches before each optimiser
# step, so one step sees roughly `bs` samples.
total_epoch = 40
num_train_mb = len(train_loader)
for epoch in range(total_epoch): # loops through number of epochs
    running_loss = 0

    # validation() leaves G in eval mode, so switch back once per epoch rather
    # than on every minibatch.
    G.train()
    G.float()
    G_opt.zero_grad()

    with tqdm(total=num_train_mb) as pbar:
        for i, (inputs, labels) in enumerate(train_loader): # loops through training_loader
            # Separate, fix dimensions, put to device
            inputs, labels = torch.unsqueeze(inputs, 1), torch.unsqueeze(labels, 1)
            inputs, labels = inputs.to(device), labels.to(device)

            # Cast to float32 to match the model's parameters
            inputs = inputs.float()
            labels = labels.float()
            outputs = G(inputs) # forward pass and get predictions

            # calculate loss and accumulate its gradients
            loss = G_criterion(outputs, labels)
            loss.backward()
            running_loss += loss.item()

            # Step once per accumulated batch. Also step on the final minibatch
            # of the epoch: otherwise the gradients of a trailing partial batch
            # are wiped by the zero_grad() at the start of the next epoch and
            # those samples never contribute to an update.
            if (i + 1) % num_mb == 0 or (i + 1) == num_train_mb:
                G_opt.step()
                G_opt.zero_grad()

            pbar.update(1)

    # validation() returns the summed batch loss; average it below
    val_loss = validation(G, valid_loader, G_criterion)

    train_losses.append(running_loss/len(train_loader))
    valid_losses.append(val_loss/len(valid_loader))

    print("Epoch: {}/{}, Training Loss: {}, Validation Loss: {}".format(epoch+1, total_epoch, running_loss/len(train_loader), val_loss/len(valid_loader)))
    print('-' * 20)

print("Finished Training")

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
 
# multiple line plot
plt.plot(train_losses)
plt.plot(valid_losses)
plt.xlabel("epochs")
plt.ylabel("loss")
plt.legend(["train", "valid"])
plt.title("Dynamic VNet (No Pretrain)")
plt.show

In [None]:
# Display the per-epoch validation losses appended by the pretraining loop
# (summed batch loss divided by the number of validation batches).
valid_losses

In [None]:
# Save the generator weights only; no optimiser state or loss history is stored.
savepath = "./pretrain_smol"
torch.save({'G_state_dict': G.state_dict()}, savepath)

In [None]:
# get_generations is defined outside this notebook (star-imported).
# NOTE(review): presumably runs G over `num` scans from the bone folder and
# writes the results to "../generations" — confirm against its definition.
get_generations(G, "../generations", path + "/bone", transform=transform, zoom=.5, bs=5, device=device, num=5)

<h1>Combine Models</h1>

In [None]:
# Dataset location: inputs are read from <path>/bone, targets from <path>/flesh.
# (Same configuration as the pretraining section, repeated so this section can
# be run on its own.)
path = "/path/to/data/"
xpath = path + "/bone"
ypath = path + "/flesh"
# os.listdir returns entries in arbitrary order. Sort both listings so the
# bone/flesh file lists line up by name and so the "first scan" used for the
# normalisation statistics is the same on every run.
# NOTE(review): assumes NumpyDataset pairs xnames[i] with ynames[i] — confirm.
xnames = sorted(os.listdir(xpath))
ynames = sorted(os.listdir(ypath))
split = .2

# Get transforms from first scan
sample = np.load(xpath + "/" + xnames[0])
mean = sample.mean()
std = sample.std()
height = sample.shape[0]
sample = None  # drop the reference so the array can be freed
transform = transforms.Compose([transforms.Normalize(mean=[mean], std=[std])])

In [None]:
full = NumpyDataset(xnames, y=ynames, xpath=xpath, ypath=ypath, transform=transform, zoom=.5, square=True)
# Hold out 5 scans for the per-epoch sample dumps. The split is seeded so the
# held-out scans stay the same across kernel restarts and checkpoint reloads.
split_rng = torch.Generator().manual_seed(0)
train, valid = torch.utils.data.random_split(full, (len(full) - 5, 5), generator=split_rng)

# Get minibatch size: `bs` samples per optimiser step, accumulated over
# `num_mb` minibatches of `mbs` samples each.
bs = 16
num_mb = 16
mbs = bs // num_mb

nw = 4  # DataLoader worker processes
train_loader = DataLoader(train, batch_size=mbs, shuffle=True, num_workers=nw)
# Not shuffled: the training cell saves samples as "<epoch>_<i>.npy", and a
# fixed order makes index i refer to the same scan in every epoch.
valid_loader = DataLoader(valid, batch_size=mbs, shuffle=False, num_workers=nw)

In [None]:
# Make Discriminator and remove output layers
D = ResNet(BasicBlock, [3, 4, 6, 2], sample_size=112, sample_duration=16, num_classes=2, conv1_t_size=3)
# Keep every top-level child module of D except the last three. nn.Sequential
# wraps the very same module objects, it does not copy them.
# NOTE(review): unless DynamicVnet copies `body`, G's encoder and D share these
# parameters, so both optimisers below update them — confirm this is intended.
body = nn.Sequential(*list(D.children())[:-3])

# Generate VNet
# img_size is the shape of the first element of the first training sample.
G = DynamicVnet(body, img_size=train[0][0].shape, blur=False, blur_final=False,
          self_attention=False, norm_type=None, last_cross=True,
          bottle=False).to(device)
D.to(device)

# BCELoss expects probabilities in [0, 1].
# NOTE(review): assumes D's forward already applies a sigmoid/softmax — confirm.
D_criterion = nn.BCELoss() # binary cross entropy loss
D_opt = torch.optim.Adam(D.parameters(), lr=0.0003, betas=(0.9, 0.999))
# Pixel-wise mean squared error between generated and target volumes.
G_criterion = nn.MSELoss()
# NOTE(review): betas=(0.5, 0.5) differs from the (0.5, 0.999) used for the
# pretraining optimiser — confirm the second beta is intentional.
G_opt = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.5))

In [None]:
# Constant two-column target tensors with `bs` rows each; the training loop
# slices the first len(batch) rows to match the discriminator output.
# One sided label smoothing to encourage the discriminator to generalise:
# D's targets for real data are softened, the generator's targets stay hard.
D_reals = torch.FloatTensor([[.9, .1]] * bs).to(device)
G_reals = torch.FloatTensor([[1, 0]] * bs).to(device)
D_fakes = torch.FloatTensor([[0, 1]] * bs).to(device)

In [None]:
# Option 1: start the GAN stage from the pretrained generator only.
# NOTE(review): the pretraining section builds G on a [2, 2, 2, 2] ResNet body
# while this section uses [3, 4, 6, 2]; a strict load_state_dict raises if the
# checkpoint came from the smaller model — confirm which file belongs here.
savepath = "./pretrain_smol"
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only machine.
checkpoint = torch.load(savepath, map_location=device)
G.load_state_dict(checkpoint['G_state_dict'])
# D.load_state_dict(checkpoint['D_state_dict'])

In [None]:
# Option 2: resume both networks from a saved GAN checkpoint. Running this
# after the previous cell overwrites the generator weights loaded there.
savepath = "./DGAN"
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only machine.
checkpoint = torch.load(savepath, map_location=device)
G.load_state_dict(checkpoint['G_state_dict'])
D.load_state_dict(checkpoint['D_state_dict'])

In [None]:
# Settings and initial state for the adversarial training loop in the next cell.
max_epoch = 40  # number of passes over train_loader
loss_scaling = 10  # weight of the pixel (MSE) loss in the combined generator loss
D_threshold = .4  # a generator batch is run when the latest D_loss is below this
# Start exactly at the threshold: the loop's check is a strict `<`, so the
# first accumulated batch trains the discriminator.
D_loss = D_threshold
generating = False  # True while the current accumulated batch trains G
save_folder= "../generations"  # per-epoch sample volumes are written here as .npy

In [None]:
# Adversarial training. Each accumulated batch (`num_mb` minibatches) trains
# either the discriminator D or the generator G: G gets a batch whenever the
# most recent discriminator loss is below `D_threshold`, otherwise D trains.
os.makedirs(save_folder, exist_ok=True)  # np.save does not create directories
num_train_mb = len(train_loader)

for epoch in range(max_epoch):
    running_G_loss = 0
    running_D_loss2 = 0
    G_count = 0

    # The sampling pass at the end of every epoch leaves G in eval mode, so
    # restore training mode (and float32 weights) once per epoch.
    D.train()
    D.float()
    G.train()
    G.float()
    G_opt.zero_grad()
    D_opt.zero_grad()

    with tqdm(total=num_train_mb) as pbar:
        for i, (inputs, labels) in enumerate(train_loader):
            # fix dtype and dimensions, put to device
            inputs, labels = inputs.float(), labels.float()
            inputs, labels = inputs.to(device), labels.to(device)
            inputs, labels = torch.unsqueeze(inputs, 1), torch.unsqueeze(labels, 1)

            # Decide the mode on the first minibatch of each accumulated batch.
            # `i % num_mb == 0` equals the former `(i + 1) % num_mb == 1` for
            # num_mb >= 2 and, unlike it, also fires when num_mb == 1.
            if i % num_mb == 0 and D_loss < D_threshold:
                generating = True

            # Step on the last minibatch of an accumulated batch, and on the
            # last minibatch of the epoch so a trailing partial batch is not
            # thrown away by the zero_grad() calls at the next epoch start.
            batch_end = (i + 1) % num_mb == 0 or (i + 1) == num_train_mb

            if not generating:
                # Training Discriminator (D)

                # D's loss on real data (smoothed "real" targets)
                x_outputs = D(inputs)
                D_x_loss = D_criterion(x_outputs, D_reals[0:len(x_outputs)])

                # D's loss on generated data. The fakes are produced without a
                # graph so this backward pass cannot deposit gradients in G
                # that would later be applied by G_opt.step().
                with torch.no_grad():
                    fakes = G(inputs)
                z_outputs = D(fakes)
                D_z_loss = D_criterion(z_outputs, D_fakes[0:len(z_outputs)])

                # total loss, back-propagated for every minibatch
                D_total = D_x_loss + D_z_loss
                D_total.backward()
                # Keep only the scalar for the mode switch, not the graph tensor.
                D_loss = D_total.item()

                if batch_end:
                    D_opt.step()
                    D_opt.zero_grad()
                    generating = False

            else:
                # Training Generator (G)

                # Generate images
                z_outputs = G(inputs)

                # Adversarial loss: G wants D to label its output as real
                D_outputs = D(z_outputs)
                D_loss2 = D_criterion(D_outputs, G_reals[0:len(D_outputs)])
                G_loss = G_criterion(z_outputs, labels)

                # Combine pixel and adversarial loss
                C_loss = G_loss * loss_scaling + D_loss2
                running_G_loss += G_loss.item()
                running_D_loss2 += D_loss2.item()
                G_count += 1

                # back prop
                C_loss.backward()

                if batch_end:
                    G_opt.step()
                    G_opt.zero_grad()
                    # C_loss.backward() also filled D's gradients; clear them
                    # so they are not applied by the next D_opt.step().
                    D_opt.zero_grad()
                    # Push the stored loss above the threshold so the next
                    # accumulated batch trains D.
                    D_loss += 1
                    generating = False

            pbar.update(1)

    if G_count == 0:
        print("Epoch: {}/{}, No G training".format(epoch+1, max_epoch))
    else:
        print("Epoch: {}/{}, Total Loss: {}, Pixel Loss: {}, Adversarial Loss: {}".format(epoch+1, max_epoch, (running_D_loss2 + loss_scaling * running_G_loss)/G_count, running_G_loss/G_count, running_D_loss2/G_count))

    # Get samples: run G over the held-out scans and save each generated
    # volume, de-normalised with the dataset statistics, as "<epoch>_<n>.npy".
    gc.collect()
    with torch.no_grad():
        G.eval()
        sample_idx = 0

        for (inputs, _) in valid_loader:
            # same float cast as the training minibatches above
            inputs = inputs.float().to(device)
            inputs = torch.unsqueeze(inputs, 1)
            outputs = G(inputs)
            for output in outputs:
                output.mul_(std).add_(mean)  # undo Normalize(mean, std)
                output = output.detach().cpu().numpy()
                # output[0] drops the axis that unsqueeze added at dim 1 of the batch
                np.save(save_folder + "/" + str(epoch) + "_" + str(sample_idx) + ".npy", output[0])
                sample_idx += 1

In [None]:
# Run a garbage-collection pass to release unreachable Python objects.
# gc is already imported in the first cell; repeating the import is harmless.
import gc
gc.collect()

In [None]:
# Save both generator and discriminator weights; optimiser state is not stored.
# NOTE(review): the load cells in this section read "./DGAN", not this path —
# confirm the intended filename.
savepath = "./pretrain_big"
torch.save({'G_state_dict': G.state_dict(), 'D_state_dict': D.state_dict()}, savepath)

In [None]:
# Reload both networks from the saved GAN checkpoint before generating samples.
savepath = "./DGAN"
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only machine.
checkpoint = torch.load(savepath, map_location=device)
G.load_state_dict(checkpoint['G_state_dict'])
D.load_state_dict(checkpoint['D_state_dict'])

In [None]:
# get_generations is defined outside this notebook (star-imported).
# NOTE(review): presumably runs G over `num` scans from the bone folder and
# writes the results to "../generations" — confirm against its definition.
get_generations(G, "../generations", path + "/bone", transform=transform, zoom=.5, bs=5, device=device, num=5)