In [1]:
import sys
sys.path.insert(0,"..")
import os
import torch.optim as optim
import torch, gc
from torch.utils.data import SubsetRandomSampler

from src.dataset import ImageDataset
from src.utils import get_indices

import numpy as np
%matplotlib inline

from unet.model import UNet
from unet.training import Trainer
import unet.utils

from gan.discriminator import Discriminator
from mwcnn.mwcnn import MWCNN

from src.loss_functions import FFTloss, VGGPerceptualLoss

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Dataset part used for testing
TEST_SPLIT = 0.15
# Batch size for training. Limited by GPU memory
BATCH_SIZE = 5
# Dataset folder used
DATASET_USED = 'e9_5_GLM87a_cycle1_8_8'
#DATASET_USED = 'e12_5_slide7_round1_section1'
# Full Dataset path
DATASETS = ['e9_5_GLM87a_cycle1_8_8', 'e12_5_slide7_round1_section1' , 'train', 'val']
ROOTDIR = '../data/mip2edof_2samples/'

# Training Epochs
EPOCHS = 200

gc.collect()
torch.cuda.empty_cache()



In [2]:
image_dataset = ImageDataset(ROOTDIR, DATASETS, normalize="percentile")


train_indices,validation_indices, test_indices = get_indices(len(image_dataset), image_dataset.root_dir, TEST_SPLIT, new=True)
train_sampler,validation_sampler, test_sampler = SubsetRandomSampler(train_indices),SubsetRandomSampler(validation_indices), SubsetRandomSampler(test_indices)

trainloader = torch.utils.data.DataLoader(image_dataset, BATCH_SIZE, sampler=train_sampler)
validationloader = torch.utils.data.DataLoader(image_dataset, BATCH_SIZE, sampler=validation_sampler)

testloader = torch.utils.data.DataLoader(image_dataset, 1, sampler=test_sampler)



## Pretrain UNET
First pretrain a UNET and MWCNN, afterwards, this can be used in the GAN set-up

In [None]:
import matplotlib.pyplot as plt
def plot_loss(num_epochs,train_losses):
    plt.plot(range(1, num_epochs + 1), train_losses, label='Training Loss')
    plt.title('Training Loss Curve')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

In [3]:
filter_num = [16,32,64,128,256]
unet_model = UNet(filter_num).to(device)
# Training
unet_trainer = Trainer(unet_model,device)
unet_loss_record, validation_loss_record = unet_trainer.train(EPOCHS,trainloader,validationloader,mini_batch=BATCH_SIZE)

MODEL_NAME = f"models/UNet-{filter_num}.pth"
torch.save(unet_model, MODEL_NAME)

print(f'Training finished!')

Training finished!


In [None]:
plot_loss(EPOCHS,unet_loss_record)

In [None]:
# Create model
criterion = torch.nn.L1Loss()
MWCNN_model = MWCNN(n_feats=64, n_colors=1, batch_normalize=False).to(device)

# Training
MWCNN_trainer = Trainer(MWCNN_model, criterion, device, clear_cache=True)
mwcnn_loss_record = MWCNN_trainer.train(EPOCHS,trainloader,mini_batch=1)
torch.save(unet_model, MODEL_NAME)


In [None]:
plot_loss(EPOCHS,mwcnn_loss_record)

In [4]:
# make gan and train

In [6]:
generator = unet_model
discriminator = Discriminator(n_feats=64, patch_size=1024)
discriminator = discriminator.to(device)
discriminator.train()

d_optim = optim.Adam(discriminator.parameters(), lr = 1e-4)
g_optim = optim.Adam(generator.parameters(), lr = 1e-4)
scheduler = optim.lr_scheduler.StepLR(g_optim, step_size = 2000, gamma = 0.1)

VGG_loss = VGGPerceptualLoss(resize=False).to(device)
cross_ent = torch.nn.BCELoss()
L1_loss = torch.nn.L1Loss()
real_label = torch.ones((BATCH_SIZE, 1)).to(device)
fake_label = torch.zeros((BATCH_SIZE, 1)).to(device)

for epoch in range(EPOCHS):
    
    for i, data in enumerate(trainloader):
        # Loading data to device used.
        noisy = data['input_image'].to(device)
        sharp = data['output_image'].float().to(device)
                    
        ## Training Discriminator
        output = generator(noisy)
        fake_prob = discriminator(output)
        real_prob = discriminator(sharp)
        
        d_loss_real = cross_ent(real_prob, real_label)
        d_loss_fake = cross_ent(fake_prob, fake_label)
        
        d_loss = d_loss_real + d_loss_fake

        g_optim.zero_grad()
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()
        
        ## Training Generator
        output, _ = generator(noisy)
        fake_prob = discriminator(output)
        
        _percep_loss, hr_feat, sr_feat = VGG_loss((sharp + 1.0) / 2.0, (output + 1.0) / 2.0)
        
        L1_loss = L1_loss(output, sharp)
        percep_loss = _percep_loss
        adversarial_loss = cross_ent(fake_prob, real_label)
        
        g_loss = percep_loss + adversarial_loss + L1_loss
        
        g_optim.zero_grad()
        d_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        scheduler.step()

        
    if epoch % 2 == 0:
        print(epoch)
        print(g_loss.item())
        print(d_loss.item())
        print('=========')



torch.Size([5, 1, 1024, 1024])
torch.Size([5, 32, 1024, 1024])
torch.Size([5, 32, 512, 512])
torch.Size([5, 128, 128, 128])
torch.Size([5, 1, 1024, 1024])
torch.Size([5, 32, 1024, 1024])
torch.Size([5, 32, 512, 512])
torch.Size([5, 128, 128, 128])


In [None]:
# show results
generator.eval()

for data in testloader:
    noisy = data['input_image'].to(device)
    # sharp = data['output_image'].to(device)
    output = generator(noisy)
    output = output.cpu().numpy()
    output = (output + 1.0) / 2.0
    output = output.transpose(1,2,0)
    result = Image.fromarray((output * 255.0).astype(np.uint8))
    result.save('./result/res_%04d.png'%i)