<a href="https://colab.research.google.com/github/jvschw/fmriGAN/blob/main/Masterarbeit_WGAN_GP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Setup**

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')
!pip install nilearn
!pip install wandb -q
!pip install nltools

%cd "/content/gdrive/MyDrive/Masterarbeit"
%pylab inline

import scipy
import numpy as np
import os
import nibabel as nib
import datetime
import glob
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from scipy.ndimage import zoom
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torch.autograd import Variable
from tqdm.notebook import tqdm

from model_wgan_gp_tanh import Discriminator, Generator, initialize_weights

## Choose Hyperparameters

In [None]:
# Training hyperparameters: everything a reader might want to tune lives here.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 2e-4      # Adam learning rate for both generator and critic
BATCH_SIZE = 4
IMAGE_SIZE = 64           # nominal edge length of the cubic input volumes
CHANNELS_IMG = 1
Z_DIM = 1000              # length of the latent noise vector
NUM_EPOCHS = 101
FEATURES_CRITIC = 32
FEATURES_GEN = 32
CRITIC_ITERATIONS = 5     # critic updates per generator update
LAMBDA_GP = 10            # gradient-penalty weight
load_pretrained = False   # load generator/critic checkpoints before training
SEED = 42                 # set to None for a non-deterministic run

if SEED is not None:
    # Seeds torch's generators on all devices (used e.g. for the latent noise
    # and for DataLoader shuffling) so Restart & Run All is reproducible.
    torch.manual_seed(SEED)

## Configuration

In [None]:
# Run configuration: passed to Weights & Biases as the run config and used to build
# the experiment id, the W&B tags and the data / checkpoint paths.
cfg = {
    'env': 'COLAB',
    'usr': "MICHAEL",
    'version_name': 'v3.1',
    'epochs': NUM_EPOCHS,                 # epoch = (Number of iterations * batch size) / total number of images in training
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'dbg_rescaled': 0,
    'labels': ['footright'],
    'smoothing': '0mm',
    'latent_dim': Z_DIM,
    # The remaining training hyperparameters are recorded as well, so runs that differ
    # only in these values can be told apart in the W&B dashboard.
    'critic_iterations': CRITIC_ITERATIONS,
    'lambda_gp': LAMBDA_GP,
    'features_critic': FEATURES_CRITIC,
    'features_gen': FEATURES_GEN,
}

# DO NOT EDIT: unique experiment id, <usr>_<version>_<timestamp>_<labels>_SS<smoothing>
run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
cfg['expid'] = f"{cfg['usr']}_{cfg['version_name']}_{run_timestamp}_{'_'.join(cfg['labels'])}_SS{cfg['smoothing']}"

In [None]:
# W&B tags describing this run (handed to wandb.init inside train()).
tags = [
    'Brain_GAN',
    'Initial Run',
    '_'.join(cfg['labels']),
    f"SS{cfg['smoothing']}",
    f"latentdim{cfg['latent_dim']}",
    f"batch_size{cfg['batch_size']}",
]
print(tags)

['Brain_GAN', 'Initial Run', 'footright', 'SS0mm', 'latentdim1000', 'batch_size4']


In [None]:
gpu = True            # NOTE(review): not read anywhere in this notebook; `device` selects CPU vs GPU
workers = 0           # DataLoader worker processes (0 = load in the main process)
LAMBDA = LAMBDA_GP    # alias kept for compatibility; the training loop reads LAMBDA_GP
_eps = 1e-15          # small epsilon for the gradient-penalty norm (keeps the sqrt differentiable at 0)

PROJECT_DIR = "/content/gdrive/MyDrive/Masterarbeit"
# Input volumes for the chosen smoothing level (e.g. .../tstat_0mm), one sub-folder per label.
# Derived from PROJECT_DIR so the project root is spelled out only once.
PATH_TO_DATA = os.path.join(PROJECT_DIR, "tstat_" + cfg['smoothing'])

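Load the reference subject volume (its shape and affine are used when saving generated volumes as NIfTI) and create the checkpoint directory for this run.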


In [None]:
# Reference subject volume: its shape and affine are used to resample generated
# 64^3 volumes back to the native grid when they are saved as NIfTI.
imgREF = nib.load(os.path.join(PATH_TO_DATA, cfg['labels'][0], 'sub100307.nii.gz'))

# One checkpoint directory per experiment id.
PATH_CHECKPOINT_MOD = os.path.join(PROJECT_DIR, 'checkpoint_models', cfg['expid'])

if not os.path.isdir(PATH_CHECKPOINT_MOD):
    print('Creating directory: ', PATH_CHECKPOINT_MOD)
# exist_ok makes this cell safe to re-run and tolerant of the directory appearing in between.
os.makedirs(PATH_CHECKPOINT_MOD, exist_ok=True)

Creating directory:  /content/gdrive/MyDrive/Masterarbeit/checkpoint_models/MICHAEL_v3.1_20211130-132728_footright_SS0mm


In [None]:
def do_log(i, some_dict):
    """Send one dict of metrics to the active Weights & Biases run.

    Args:
        i: ignored; kept so existing call sites (which pass the epoch) keep
            working. No explicit ``step`` is handed to ``wandb.log``.
        some_dict: mapping of metric name -> value, forwarded unchanged.
    """
    del i  # intentionally unused
    wandb.log(some_dict)

# **Dataloader** 

In [None]:
class NiftiDataset_(Dataset):
    """Dataset of NIfTI volumes stored as ``<data_dir>/<label>/*.nii.gz``.

    Each item is resampled to the shape of the brain mask, min-max scaled to
    [-1, 1], multiplied by the mask and returned as a ``(1, *mask.shape)``
    float32 tensor together with its label.

    Args:
        data_dir: root directory containing one folder per label.
        labels: label (= folder) names to load.
        n: maximum number of files taken per label (files sorted by name).
        transforms: stored on the instance but currently not applied.
        mask_path: ``.npy`` brain mask; its shape is the target volume shape
            (the default mask is expected to be 64^3).
    """

    def __init__(self, data_dir, labels, n, transforms=None, mask_path="mask_dil64.npy"):
        self.transforms = transforms
        self.mask = np.load(mask_path)

        files, file_labels = [], []
        for label in labels:
            label_files = sorted(glob.glob(os.path.join(data_dir, label, "*.nii.gz")))[:n]
            files.extend(label_files)
            # Exactly one label entry per file of *this* label, keeping data and labels aligned.
            file_labels.extend([label] * len(label_files))

        self.data = np.array(files, dtype=str)
        self.labels = np.array(file_labels, dtype=str)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # nan_to_num replaces NaN with 0 (and +/-inf with large finite values).
        img = np.nan_to_num(nib.load(self.data[idx]).get_fdata())

        # Resample every axis to the mask's grid.
        target_shape = self.mask.shape
        img = zoom(img, [t / s for t, s in zip(target_shape, img.shape)])

        # Min-max scale to [0, 1], then map to [-1, 1]
        # (presumably to match the generator's tanh output range).
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = (img - lo) / (hi - lo)
        else:
            # Constant volume: avoid 0/0 (NaNs); every voxel maps to -1 below.
            img = np.zeros_like(img)
        img = 2 * img - 1

        masked = torch.from_numpy(img * self.mask).float().unsqueeze(0)
        return {'img': masked, 'label': self.labels[idx]}

In [None]:
# Load at most 802 files per label (was 709 for handleft).
dataset = NiftiDataset_(PATH_TO_DATA, cfg['labels'], n=802)

In [None]:
# Shuffled mini-batches of cfg['batch_size'] volumes; the incomplete last batch is dropped.
loader = DataLoader(
    dataset,
    batch_size=cfg['batch_size'],
    shuffle=True,
    num_workers=workers,
    drop_last=True,
)

# **Training**


Define helper functions: the WGAN-GP gradient penalty and utilities for saving generated volumes as NIfTI files.

In [None]:
def calc_gradient_penalty(model, x, x_gen, w=10, eps=1e-15):
    """Two-sided WGAN-GP gradient penalty.

    Samples random points on straight lines between real and generated
    samples and penalises deviations of the critic's gradient norm from 1.

    Args:
        model: critic; called on a batch shaped like ``x``.
        x: batch of real samples, shape ``(B, ...)``.
        x_gen: batch of generated samples, same shape as ``x``.
        w: penalty weight; the returned value is already multiplied by it.
        eps: added to every squared gradient element before the square root
            so the norm stays differentiable at zero.

    Returns:
        Scalar tensor ``w * mean((||d model(x_hat) / d x_hat|| - 1)^2)``.
        It keeps its graph (``create_graph=True``), so back-propagating it
        reaches the critic's parameters, but not the generator's.

    Raises:
        AssertionError: if ``x`` and ``x_gen`` have different shapes.
    """
    assert x.size() == x_gen.size(), "real and generated batches must have the same shape"

    # One interpolation coefficient per sample, broadcast over all other dims.
    # Created on x's own device and dtype (the legacy torch.cuda.FloatTensor
    # constructor always used the *current* CUDA device and float32).
    alpha_shape = (x.size(0),) + (1,) * (x.dim() - 1)
    alpha = torch.rand(alpha_shape, device=x.device, dtype=x.dtype)

    # The interpolates form a fresh leaf tensor, so no gradient flows into the generator.
    x_hat = (x.detach() * alpha + x_gen.detach() * (1 - alpha)).requires_grad_(True)

    grad_xhat = torch.autograd.grad(model(x_hat).sum(), x_hat, create_graph=True)[0]

    # Per-sample L2 norm over all non-batch dimensions.
    grad_norm = (grad_xhat.reshape(len(grad_xhat), -1).pow(2) + eps).sum(-1).sqrt()
    return w * ((grad_norm - 1) ** 2).mean()

In [None]:
def saveResampledNifti(generatedData64, imgREF, fNameOut):
    """Upsample a generated volume to the reference grid and save it as NIfTI.

    Args:
        generatedData64: 3-D array, expected to be 64x64x64 (the zoom factors
            below assume that size).
        imgREF: reference nibabel image; its first three dimensions give the
            output shape and its affine is copied to the saved image.
        fNameOut: destination path (e.g. ``*.nii.gz``).
    """
    factors = tuple(imgREF.shape[axis] / 64 for axis in range(3))
    resampled = zoom(generatedData64, factors)
    nib.save(nib.Nifti1Image(resampled, affine=imgREF.affine), fNameOut)

def generate_fake(generator, z_dim, output_dir, n_fakes=100, batch_size=1):
    """Sample volumes from ``generator`` and save them as NIfTI files.

    Uses the notebook-level ``device`` and reference image ``imgREF``; every
    sample is resampled to ``imgREF``'s grid by ``saveResampledNifti``.

    Args:
        generator: network called on noise of shape ``(batch_size, z_dim, 1, 1, 1)``.
        z_dim: latent dimension.
        output_dir: existing directory that receives ``fakeNNN.nii.gz`` files.
        n_fakes: number of generator calls.
        batch_size: samples per call. Files are numbered consecutively, so
            ``n_fakes * batch_size`` files are written; with the default of 1
            they are ``fake001 ... fake{n_fakes}``.
    """
    print('saving into %s' % output_dir)
    # Inference only: no autograd graph needs to be built for the samples.
    # NOTE(review): the generator is left in its current train/eval mode; confirm
    # whether eval() is wanted here (matters if it contains BatchNorm/Dropout).
    with torch.no_grad():
        for k in tqdm(range(n_fakes)):
            noise = torch.randn(batch_size, z_dim, 1, 1, 1).to(device)
            volumes = generator(noise).cpu().numpy()
            # Save each sample of the batch separately: squeezing the whole batch would
            # leave a 4-D array for batch_size > 1, which the 3-D resampling rejects.
            for b in range(volumes.shape[0]):
                fNameOut = os.path.join(output_dir, 'fake' + '{:03}'.format(k * batch_size + b + 1) + '.nii.gz')
                saveResampledNifti(np.squeeze(volumes[b]), imgREF, fNameOut)

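Build the generator and critic, initialize their weights (optionally loading pretrained checkpoints), and create the Adam optimizers.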

In [None]:
gen = Generator(cfg['latent_dim'], CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)

# Fresh initialisation first ...
initialize_weights(gen)
initialize_weights(critic)

# ... then optionally overwrite it with the pretrained weights. initialize_weights
# presumably re-initialises the layer parameters, so running it *after* loading
# would silently discard the checkpoint.
if load_pretrained:
    # map_location lets GPU-saved checkpoints load on a CPU-only runtime.
    gen.load_state_dict(torch.load("checkpoint_models/tanh/ckpt_gen_iter000100.pth", map_location=device))
    critic.load_state_dict(torch.load("checkpoint_models/tanh/ckpt_critic_iter000100.pth", map_location=device))

# Adam with betas (0.0, 0.9) for both networks, the setting used in the WGAN-GP paper.
opt_gen = optim.Adam(gen.parameters(), lr=cfg['learning_rate'], betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=cfg['learning_rate'], betas=(0.0, 0.9))

**Training**

In [None]:
def train(nbr_models_saved=10):
    """Train the WGAN-GP generator/critic pair, logging to W&B and checkpointing.

    Relies on the notebook-level objects ``gen``, ``critic``, ``opt_gen``,
    ``opt_critic``, ``loader``, ``cfg``, ``tags``, ``device``,
    ``PATH_CHECKPOINT_MOD`` and the hyperparameter constants.

    Args:
        nbr_models_saved: checkpoint *interval* in epochs (despite the name).
            Both models are saved after every epoch whose 0-based index is a
            multiple of this value, as ``ckpt_{gen,critic}_iterNNNNNN.pth``
            with NNNNNN = number of completed epochs.
    """
    gen.train()
    critic.train()

    wandb.init(
        entity="ml4ni",
        project='alpha-wgan',
        name=cfg['expid'],
        notes='https://colab.research.google.com/drive/10hVK7wxbZRzynuxahMx5L2ziLwjdQx89#scrollTo=zIt_iSG52ug4',
        tags=tags,
        config=cfg)
    config = wandb.config
    num_epochs = config['epochs']
    z_dim = cfg['latent_dim']  # the same value the generator was built with

    try:
        for epoch in range(num_epochs):
            for batch_idx, batch in enumerate(loader):
                real = batch["img"].to(device)
                cur_batch_size = real.shape[0]

                # Train critic: max E[critic(real)] - E[critic(fake)], i.e. minimise the negative.
                for _ in range(CRITIC_ITERATIONS):
                    noise = torch.randn(cur_batch_size, z_dim, 1, 1, 1).to(device)
                    fake = gen(noise)
                    critic_real = critic(real).reshape(-1)
                    # Detached: the critic loss needs no gradient w.r.t. the generator, so
                    # nothing has to be retained for the generator step below.
                    critic_fake = critic(fake.detach()).reshape(-1)
                    # calc_gradient_penalty already multiplies by its weight `w`; pass
                    # LAMBDA_GP as that weight instead of multiplying its result by
                    # LAMBDA_GP again (with the default w=10 the weight was applied twice).
                    gp = calc_gradient_penalty(critic, real, fake, w=LAMBDA_GP)
                    loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + gp

                    critic.zero_grad()
                    loss_critic.backward()
                    opt_critic.step()

                # Train generator on the last critic iteration's fakes:
                # max E[critic(gen_fake)] <-> min -E[critic(gen_fake)].
                gen_fake = critic(fake).reshape(-1)
                loss_gen = -torch.mean(gen_fake)
                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()

                d_loss, g_loss = loss_critic.item(), loss_gen.item()
                # The console print interval reuses the batch size (every cur_batch_size-th batch).
                if batch_idx % cur_batch_size == 0 and batch_idx > 0:
                    print(f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                          f"Loss D: {d_loss:.4f}, loss G: {g_loss:.4f}")

                do_log(epoch, {"D:": d_loss, "G:": g_loss})

            if epoch % nbr_models_saved == 0:
                for tag, model in (('gen', gen), ('critic', critic)):
                    fNameOut = os.path.join(PATH_CHECKPOINT_MOD, f'ckpt_{tag}_iter{epoch + 1:06}.pth')
                    print('Saving ', fNameOut)
                    torch.save(model.state_dict(), fNameOut)
    finally:
        # Close the W&B run even when training stops early (error or interrupt).
        wandb.finish()

In [None]:
# Run training; checkpoints are written every 10 epochs.
train(nbr_models_saved=10)