In [1]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections import defaultdict
import pandas as pd
import numpy as np
import torch.fft
import subprocess
import logging
import random
import shutil
import psutil
import scipy
import torch
import copy
import yaml
import time
import tqdm
import sys
import gc

import segmentation_models_pytorch as smp

from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

from holodecml.data import PickleReader, UpsamplingReader
from holodecml.propagation import InferencePropagator
from holodecml.transforms import LoadTransformations
from holodecml.models import load_model
from holodecml.losses import load_loss

import os
import warnings
warnings.filterwarnings("ignore")

from stylegan import StyledGenerator
import lpips

In [2]:
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

available_ncpus = len(psutil.Process().cpu_affinity())

# Set up the GPU
is_cuda = torch.cuda.is_available()
device = torch.device("cpu") if not is_cuda else torch.device("cuda:0")

In [3]:
# ### Set seeds for reproducibility
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True
        
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

In [4]:
config = "../config/gan.yml"
with open(config) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [5]:
# Set seeds for reproducibility
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

save_loc = conf["save_loc"]
os.makedirs(save_loc, exist_ok = True)
os.makedirs(os.path.join(save_loc, "images"), exist_ok = True)
shutil.copyfile(config, os.path.join(save_loc, "model.yml"))

tile_size = int(conf["data"]["tile_size"])
step_size = int(conf["data"]["step_size"])
data_path = conf["data"]["output_path"]
data_path_raw = conf["data"]["output_path_raw"]

total_positive = int(conf["data"]["total_positive"])
total_negative = int(conf["data"]["total_negative"])
total_examples = int(conf["data"]["total_training"])

transform_mode = "None" if "transform_mode" not in conf["data"] else conf["data"]["transform_mode"]
config_ncpus = int(conf["data"]["cores"])
use_cached = False if "use_cached" not in conf["data"] else conf["data"]["use_cached"]

name_tag = f"{tile_size}_{step_size}_{total_positive}_{total_negative}_{total_examples}_{transform_mode}"
fn_train = f"{data_path}/training_{name_tag}.pkl"
fn_valid = f"{data_path}/validation_{name_tag}.pkl"
fn_train_raw = f"{data_path_raw}/training_{name_tag}.pkl"
fn_valid_raw = f"{data_path_raw}/validation_{name_tag}.pkl"

# Trainer params
train_batch_size = conf["trainer"]["train_batch_size"]
valid_batch_size = conf["trainer"]["valid_batch_size"]

epochs = conf["trainer"]["epochs"]
batches_per_epoch = conf["trainer"]["batches_per_epoch"]
Tensor = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor
adv_loss = conf["trainer"]["adv_loss"]
lambda_gp = conf["trainer"]["lambda_gp"]
train_gen_every = conf["trainer"]["train_gen_every"]
train_disc_every = conf["trainer"]["train_disc_every"]
threshold = conf["trainer"]["threshold"]

In [6]:
# Load the preprocessing transforms
if "Normalize" in conf["transforms"]["training"]:
    conf["transforms"]["validation"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]
    conf["transforms"]["inference"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]

train_transforms = LoadTransformations(conf["transforms"]["training"])
valid_transforms = LoadTransformations(conf["transforms"]["validation"])

In [7]:
train_synthetic_dataset = PickleReader(
    fn_train,
    transform=train_transforms,
    max_images=int(0.8 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["discriminator"]["in_channels"],
    shuffle=True
)

test_synthetic_dataset = PickleReader(
    fn_valid,
    transform=valid_transforms,
    max_images=int(0.1 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["discriminator"]["in_channels"],
    shuffle=False
)

In [8]:
train_synthetic_loader = torch.utils.data.DataLoader(
    train_synthetic_dataset,
    batch_size=train_batch_size,
    num_workers=0,
    pin_memory=True,
    shuffle=True)

test_synthetic_loader = torch.utils.data.DataLoader(
    test_synthetic_dataset,
    batch_size=valid_batch_size,
    num_workers=0,  # 0 = One worker with the main process
    pin_memory=True,
    shuffle=False)

In [9]:
train_holodec_dataset = PickleReader(
    fn_train_raw,
    transform=train_transforms,
    max_images=int(0.8 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["discriminator"]["in_channels"],
    shuffle=True
)

test_holodec_dataset = PickleReader(
    fn_valid_raw,
    transform=valid_transforms,
    max_images=int(0.1 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["discriminator"]["in_channels"],
    shuffle=False
)

In [10]:
train_holodec_loader = torch.utils.data.DataLoader(
    train_holodec_dataset,
    batch_size=train_batch_size,
    num_workers=0,
    pin_memory=True,
    shuffle=True)

test_holodec_loader = torch.utils.data.DataLoader(
    test_holodec_dataset,
    batch_size=valid_batch_size,
    num_workers=0,  # 0 = One worker with the main process
    pin_memory=True,
    shuffle=False)

In [11]:
class Generator(nn.Module):
    def __init__(self, latent_dim = 512, img_shape = (1, 512, 512)):
        
        super(Generator, self).__init__()
        
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], 1, *self.img_shape)
        return img
    
class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim = 512, img_shape = (1, 512, 512)):
        
        super(ConditionalGenerator, self).__init__()
        
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(2 * latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
        
        self.encoder = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 1024),
            *block(1024, 512),
            *block(512, 256),
            *block(256, 128),
            *block(128, latent_dim, normalize=False)
        )

    def forward(self, x, z):
        x = self.encoder(x.reshape(x.shape[0], int(np.prod(self.img_shape))))
        z = torch.cat([x, z], 1)
        img = self.model(z)
        img = img.view(img.shape[0], 1, *self.img_shape)
        return img

In [12]:
generator = load_model(conf["generator"]).to(device) #ConditionalGenerator(latent_dim = 512, img_shape = (512, 512)).to(device) 
discriminator = load_model(conf["discriminator"]).to(device)

In [13]:
# generator = StyledGenerator(512).to(device)
# zz = torch.randn(2, 32, 512, device='cuda').chunk(2, 0)[0].squeeze(0)
# generator(zz, step = int(math.log2(512)) - 2, alpha = 1).shape

In [14]:
adv_loss = conf["trainer"]["adv_loss"]
if adv_loss == "bce":
    adversarial_loss = torch.nn.BCELoss().to(device)
    
perceptual_alex = lpips.LPIPS(net='alex').cuda()

Mask_Loss = load_loss("focal-tyversky")

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /glade/work/schreck/py37/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth


In [15]:
optimizer_G = torch.optim.Adam(
    filter(lambda p: p.requires_grad, generator.parameters()),
    lr = conf["optimizer_G"]["learning_rate"],
    betas = (conf["optimizer_G"]["b0"], conf["optimizer_G"]["b1"]))

optimizer_D = torch.optim.Adam(
    filter(lambda p: p.requires_grad, discriminator.parameters()), 
    lr = conf["optimizer_D"]["learning_rate"], 
    betas = (conf["optimizer_D"]["b0"], conf["optimizer_D"]["b1"]))

In [16]:
def compute_gradient_penalty(discriminator, real_imgs, gen_imgs):
    """Calculates the gradient penalty loss for WGAN GP"""
    alpha = torch.rand(real_imgs.size(0), 1, 1, 1).cuda().expand_as(real_imgs)
    interpolated = Variable(alpha * real_imgs.data + (1 - alpha) * gen_imgs.data, requires_grad=True)
    out = discriminator(interpolated)[1]
    grad = torch.autograd.grad(outputs=out,
                               inputs=interpolated,
                               grad_outputs=torch.ones(out.size()).cuda(),
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)[0]
    grad = grad.view(grad.size(0), -1)
    grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
    d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
    return d_loss_gp

In [None]:
results = defaultdict(list)
for epoch in range(epochs):
    
    ### Train
    real_images = iter(train_holodec_loader)
    synthethic_images = iter(train_synthetic_loader)
    dual_iter = tqdm.tqdm(
        enumerate(zip(real_images, synthethic_images)),
        total = batches_per_epoch, 
        leave = True)
    
    train_results = defaultdict(list)
    for i, ((holo_img, holo_label), (synth_img, synth_label)) in dual_iter:
            
        if holo_img.shape[0] != synth_img.shape[0]:
            continue
                            
        # Adversarial ground truths
        valid = Variable(Tensor(holo_img.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(holo_img.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(holo_img.type(Tensor))
        synthethic_imgs = Variable(synth_img.type(Tensor))
        #stacked = torch.cat([real_imgs, synthethic_imgs])
        
        # Sample noise as generator input
        #z = Variable(Tensor(np.random.normal(0, 1.0, (holo_img.shape[0], 512))))
        z = Variable(Tensor(np.random.normal(0, 1.0, synthethic_imgs.shape)))
        # C-GAN-like input using the synthethic image as conditional input
        gen_input = torch.cat([synthethic_imgs, z], 1)
        # Generate a batch of images
        gen_imgs = generator(gen_input)
        # Discriminate the fake images
        pred_masks, verdict = discriminator(gen_imgs)
        
        # If using D for mask prediction, compute the loss as the metric
        #mask_loss = Mask_Loss(pred_masks, synth_label.to(device))
        #train_results["fake_mask_loss"].append(mask_loss.item())
        
        # -----------------
        #  Train Generator
        # -----------------
            
        if (i + 1) % train_gen_every == 0:
            
            optimizer_G.zero_grad()
            requires_grad(generator, True)
            requires_grad(discriminator, False)
            
            # Loss measures generator's ability to fool the discriminator
            if adv_loss == 'wgan-gp':
                g_loss = -verdict.mean()
            elif adv_loss == 'hinge':
                g_loss = -verdict.mean()
            elif adv_loss == 'bce':
                g_loss = adversarial_loss(verdict, valid)
                
            mask_loss = Mask_Loss(pred_masks, synth_label.to(device))
            g_loss += mask_loss
                
            g_loss.backward()
            optimizer_G.step()
            
            train_results["g_loss"].append(g_loss.item())
            train_results["fake_mask_loss"].append(mask_loss.item())
            
            # compute perception scores
            p_score_syn = perceptual_alex(gen_imgs, synthethic_imgs).mean()
            p_score_real = perceptual_alex(gen_imgs, real_imgs).mean()
            train_results["g_perc_syn"].append(p_score_syn.item())
            train_results["g_perc_real"].append(p_score_real.item())
            
        # ---------------------
        #  Train Discriminator
        # ---------------------
        
        if (i + 1) % train_disc_every == 0:
        
            optimizer_D.zero_grad()
            requires_grad(generator, False)
            requires_grad(discriminator, True)

            # Measure discriminator's ability to classify real from generated samples
            _, disc_real = discriminator(real_imgs)
            _, disc_synth = discriminator(gen_imgs.detach())
            pred_masks, disc_synth_true = discriminator(synthethic_imgs)
            
            train_results["real_acc"].append(((disc_real > threshold) == valid).float().mean().item())
            train_results["syn_acc"].append(((disc_synth > threshold) == fake).float().mean().item())
            train_results["syn_true_acc"].append(((disc_synth_true > threshold) == valid).float().mean().item())
            
            if adv_loss == 'wgan-gp':
                real_loss = -torch.mean(disc_real) - torch.mean(disc_synth_true)
                fake_loss = disc_synth.mean() 
            elif adv_loss == 'hinge':
                real_loss = torch.nn.ReLU()(1.0 - disc_real).mean() + torch.nn.ReLU()(1.0 - disc_synth_true).mean()
                fake_loss = torch.nn.ReLU()(1.0 + disc_synth).mean()             
            elif adv_loss == 'bce':
                real_loss = adversarial_loss(torch.cat([disc_real, disc_synth_true]), torch.cat([valid, valid]))
                #real_loss = adversarial_loss(disc_real, valid) #+ adversarial_loss(disc_synth_true, valid)
                fake_loss = adversarial_loss(disc_synth, fake) 
                
            #reg_loss = torch.nn.L1Loss()(gen_imgs.detach(), synthethic_imgs)
            mask_loss = Mask_Loss(pred_masks, synth_label.to(device))
            d_loss = real_loss + fake_loss + mask_loss
            #d_loss = real_loss + fake_loss + reg_loss + mask_loss
            
            train_results["d_loss"].append(d_loss.item())
            train_results["mask_loss"].append(mask_loss.item())
            
            if adv_loss == 'wgan-gp':
                # Compute gradient penalty
                alpha = torch.rand(real_imgs.size(0), 1, 1, 1).cuda().expand_as(real_imgs)
                interpolated = Variable(alpha * real_imgs.data + (1 - alpha) * gen_imgs.data, requires_grad=True)
                out = discriminator(interpolated)[1]

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
                d_loss_reg = lambda_gp * d_loss_gp
                #optimizer_D.zero_grad()
                #d_loss_reg.backward()
                #optimizer_D.step()
                d_loss += d_loss_reg
                train_results["d_reg"].append(d_loss_reg.item())
            
            d_loss.backward()
            optimizer_D.step()

        print_str =  f'Epoch {epoch}'
        print_str += f' D_loss {np.mean(train_results["d_loss"]):.6f}'
        if adv_loss == 'wgan-gp':
            print_str += f' D_reg {np.mean(train_results["d_reg"]):.6f}'
        print_str += f' G_loss {np.mean(train_results["g_loss"]):6f}'
        print_str += f' mask_loss {np.mean(train_results["mask_loss"]):6f}'
        print_str += f' fake_mask_loss {np.mean(train_results["fake_mask_loss"]):6f}'
        print_str += f' h_acc {np.mean(train_results["real_acc"]):.4f}'
        print_str += f' s_p_acc {np.mean(train_results["syn_acc"]):.4f}'
        print_str += f' s_t_acc {np.mean(train_results["syn_true_acc"]):.4f}'
        print_str += f' perc_syn {np.mean(train_results["g_perc_syn"]):.4f}'
        print_str += f' perc_real {np.mean(train_results["g_perc_real"]):.4f}'
        dual_iter.set_description(print_str)
        dual_iter.refresh()
        
        if i == batches_per_epoch and i > 0:
            break
        
    # Validate 
        
    # Epoch is over. Save some stuff.
    save_image(synthethic_imgs.data[:16], f'{conf["save_loc"]}/images/synth_{epoch}.png', nrow=4, normalize=True)
    save_image(real_imgs.data[:16], f'{conf["save_loc"]}/images/real_{epoch}.png', nrow=4, normalize=True)
    save_image(gen_imgs.data[:16], f'{conf["save_loc"]}/images/pred_{epoch}.png', nrow=4, normalize=True)

    # Save the dataframe to disk
    results["epoch"].append(epoch)
    results["d_loss"].append(np.mean(train_results["d_loss"]))
    if adv_loss == 'wgan-gp':
        results["d_loss_reg"].append(np.mean(train_results["d_reg"]))
    results["g_loss"].append(np.mean(train_results["g_loss"]))
    results["mask_loss"].append(np.mean(train_results["mask_loss"]))
    results["fake_mask_loss"].append(np.mean(train_results["fake_mask_loss"]))
    results["holo_acc"].append(np.mean(train_results["real_acc"]))
    results["pred_synth_acc"].append(np.mean(train_results["syn_acc"]))
    results["true_synth_acc"].append(np.mean(train_results["syn_true_acc"]))
    results["perception_syn"].append(np.mean(train_results["g_perc_syn"]))
    results["perception_holo"].append(np.mean(train_results["g_perc_real"]))
    
    df = pd.DataFrame.from_dict(results).reset_index()
    df.to_csv(f'{conf["save_loc"]}/training_log.csv', index=False)
    
    # Save the model
    state_dict = {
        'epoch': epoch,
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'generator_state_dict': generator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
    }
    torch.save(state_dict, f'{conf["save_loc"]}/best.pt')

Epoch 0 D_loss 2.101567 G_loss 2.603015 mask_loss 0.998724 fake_mask_loss 0.999392 h_acc 0.6975 s_p_acc 0.8875 s_t_acc 0.7600 perc_syn 0.8900 perc_real 0.6055: 100%|██████████| 250/250 [04:29<00:00,  1.08s/it]
Epoch 1 D_loss 1.779466 G_loss 3.157093 mask_loss 0.997504 fake_mask_loss 0.999379 h_acc 0.7575 s_p_acc 0.8650 s_t_acc 0.8625 perc_syn 0.8086 perc_real 0.6560: 100%|██████████| 250/250 [04:48<00:00,  1.16s/it]
Epoch 2 D_loss 2.074030 G_loss 2.747243 mask_loss 0.994488 fake_mask_loss 0.999172 h_acc 0.5675 s_p_acc 0.8075 s_t_acc 0.8450 perc_syn 0.9103 perc_real 0.6429: 100%|██████████| 250/250 [05:04<00:00,  1.22s/it]
Epoch 3 D_loss 2.033770 G_loss 2.620184 mask_loss 0.988791 fake_mask_loss 0.995693 h_acc 0.6050 s_p_acc 0.8300 s_t_acc 0.9000 perc_syn 0.7446 perc_real 0.6580: 100%|██████████| 250/250 [05:05<00:00,  1.22s/it]
Epoch 4 D_loss 1.807148 G_loss 2.667739 mask_loss 0.967748 fake_mask_loss 0.981550 h_acc 0.5650 s_p_acc 0.8325 s_t_acc 1.0000 perc_syn 0.6655 perc_real 0.6145: 

In [None]:
# "Images0" and "training_log_0" used both real and synthetic in D and G, and the L1 loss term on G. 
# "Images1" and "training_log_1" used both real and synthetic in D, and the L1 loss term on D. 

# "Images2" and "training_log_2" used both real and synthetic in D, and the mask loss term on D
# "Images3" and "training_log_3" used both real and synthetic in D, and the L1 loss and mask loss term on D
# "Images4" and "training_log_4" used both real and synthetic in D, and the L1 loss and mask loss term on D, mask loss on G.