In [1]:
import segmentation_models_pytorch as smp
import torch

from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from collections import defaultdict
import pandas as pd
import numpy as np
import torch.fft
import subprocess
import logging
import random
import shutil
import psutil
import scipy
import torch
import copy
import yaml
import time
import tqdm
import sys
import gc


from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

from holodecml.data import PickleReader, UpsamplingReader
from holodecml.propagation import InferencePropagator
from holodecml.transforms import LoadTransformations
from holodecml.models import load_model
from holodecml.losses import load_loss

import os
import warnings
warnings.filterwarnings("ignore")

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [2]:
import lpips

In [3]:
available_ncpus = len(psutil.Process().cpu_affinity())

In [4]:
# ### Set seeds for reproducibility
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

In [5]:
config = "../results/generator/model.yml"
with open(config) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [6]:
# Set seeds for reproducibility
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

tile_size = int(conf["data"]["tile_size"])
step_size = int(conf["data"]["step_size"])
data_path = conf["data"]["output_path"]
data_path_raw = conf["data"]["output_path_raw"]

total_positive = int(conf["data"]["total_positive"])
total_negative = int(conf["data"]["total_negative"])
total_examples = int(conf["data"]["total_training"])

transform_mode = "None" if "transform_mode" not in conf["data"] else conf["data"]["transform_mode"]
config_ncpus = int(conf["data"]["cores"])
use_cached = False if "use_cached" not in conf["data"] else conf["data"]["use_cached"]

name_tag = f"{tile_size}_{step_size}_{total_positive}_{total_negative}_{total_examples}_{transform_mode}"
fn_train = f"{data_path}/training_{name_tag}.pkl"
fn_valid = f"{data_path}/validation_{name_tag}.pkl"
fn_train_raw = f"{data_path_raw}/training_{name_tag}.pkl"
fn_valid_raw = f"{data_path_raw}/validation_{name_tag}.pkl"

output_path = conf["data"]["output_path"]

train_batch_size = 16
valid_batch_size = 16

latent_dim = 128
img_shape = (1, tile_size, tile_size)

d_learning_rate = 0.0004
g_learning_rate = 0.0001

b1 = 0.0
b2 = 0.9

In [7]:
is_cuda = torch.cuda.is_available()
data_device = torch.device("cpu") if "device" not in conf["data"] else conf["data"]["device"]

In [8]:
# Load the preprocessing transforms
if "Normalize" in conf["transforms"]["training"]:
    conf["transforms"]["validation"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]
    conf["transforms"]["inference"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]

train_transforms = LoadTransformations(conf["transforms"]["training"])
valid_transforms = LoadTransformations(conf["transforms"]["validation"])

In [9]:
train_synthetic_dataset = PickleReader(
    fn_train,
    transform=train_transforms,
    max_images=int(0.8 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=True
)

test_synthetic_dataset = PickleReader(
    fn_valid,
    transform=valid_transforms,
    max_images=int(0.1 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=False
)

In [10]:
train_synthetic_loader = torch.utils.data.DataLoader(
    train_synthetic_dataset,
    batch_size=train_batch_size,
    num_workers=0, #available_ncpus//2,
    pin_memory=True,
    shuffle=True)

test_synthetic_loader = torch.utils.data.DataLoader(
    test_synthetic_dataset,
    batch_size=valid_batch_size,
    num_workers=0,  # 0 = One worker with the main process
    pin_memory=True,
    shuffle=False)

In [12]:
train_holodec_dataset = PickleReader(
    fn_train_raw,
    transform=train_transforms,
    max_images=int(0.8 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=True
)

test_holodec_dataset = PickleReader(
    fn_valid_raw,
    transform=valid_transforms,
    max_images=int(0.1 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=False
)

In [13]:
train_holodec_loader = torch.utils.data.DataLoader(
    train_holodec_dataset,
    batch_size=train_batch_size,
    num_workers=0,
    pin_memory=True,
    shuffle=True)

test_holodec_loader = torch.utils.data.DataLoader(
    test_holodec_dataset,
    batch_size=valid_batch_size,
    num_workers=0,
    pin_memory=True,
    shuffle=False)

In [14]:
generator = smp.Linknet(
    encoder_name="xception",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset),
    activation = "tanh"
)

In [15]:
# Loss function
adversarial_loss = torch.nn.L1Loss()
perceptual_alex = lpips.LPIPS(net='alex').cuda()

if is_cuda:
    generator.cuda()
    adversarial_loss.cuda()
    
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /glade/work/schreck/py37/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth


In [16]:
optimizer_G = torch.optim.Adam(
    filter(lambda p: p.requires_grad, generator.parameters()), 
    lr=g_learning_rate
)

In [17]:
# lr_scheduler = ReduceLROnPlateau(
#         optimizer_G,
#         patience=1,
#         min_lr=g_learning_rate * 1e-3,
#         verbose=True
#     )


T_0=10
T_mult=1
eta_min=0.001
last_epoch=-1

lr_scheduler = CosineAnnealingWarmRestarts(
    optimizer_G, 
    T_0=T_0, 
    T_mult=T_mult,
    eta_min=eta_min,
    last_epoch=last_epoch
)

In [18]:
n_epochs = 200
batches_per_epoch = 500
valid_batches_per_epoch = 100
stopping_patience = 4
metric = "valid_loss"

Tensor = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor

results = defaultdict(list)
for epoch in range(n_epochs):
    
    ### Train
    dual_iter = tqdm.tqdm(
        enumerate(iter(train_holodec_loader)),
        total = batches_per_epoch,
        leave = True
    )
    
    train_results = defaultdict(list)
    for i, (holo_img, holo_label) in dual_iter:
                                
        optimizer_G.zero_grad()
        requires_grad(generator, True)
        requires_grad(perceptual_alex, False)

        # Configure input
        batched_imgs = Variable(holo_img.type(Tensor))
        #synthethic_imgs = Variable(synth_img.type(Tensor))
        #batched_imgs = torch.cat([real_imgs, synthethic_imgs], 0)
        
        pred_imgs = generator(batched_imgs)
        loss = adversarial_loss(pred_imgs, batched_imgs)
        perceptual_score = perceptual_alex(pred_imgs, batched_imgs).mean()
        
        train_results["loss"].append(loss.item())
        train_results["lpips"].append(perceptual_score.item())
        
        to_print = f'Epoch {epoch} train_loss {np.mean(train_results["loss"]):.6f}'
        to_print += f' train_percept {np.mean(train_results["lpips"]):.6f}'
        to_print += f' lr {optimizer_G.param_groups[0]["lr"]}'
        dual_iter.set_description(to_print)
        dual_iter.update()
        
        loss.backward()
        optimizer_G.step()
        
        lr_scheduler.step(epoch + i / batches_per_epoch)
            
        if i == batches_per_epoch and i > 0:
            break
    
    ### Validation
    requires_grad(generator, False)
    requires_grad(perceptual_alex, False)

    dual_iter = tqdm.tqdm(
        enumerate(iter(test_holodec_loader)),
        total = valid_batches_per_epoch, 
        leave = True)
    
    valid_results = defaultdict(list)
    for i, (holo_img, holo_label) in dual_iter:
                    
        # Configure input
        batched_imgs = Variable(holo_img.type(Tensor))
        pred_images = generator(batched_imgs)
        loss = adversarial_loss(pred_images, batched_imgs)
        perceptual_score = perceptual_alex(pred_imgs, batched_imgs).mean()
        
        valid_results["loss"].append(loss.item())
        valid_results["lpips"].append(perceptual_score.item())
        
        to_print = f'Epoch {epoch} valid_loss {np.mean(valid_results["loss"]):.6f}'
        to_print += f' valid_percept {np.mean(valid_results["lpips"]):.6f}'
        dual_iter.set_description(to_print)
        dual_iter.update()
        
        if i == valid_batches_per_epoch and i > 0:
            break
         
    # Save the last validation batch images
    save_image(batched_imgs.data[:16], f'{conf["save_loc"]}/images/real_{epoch}.png', nrow=4, normalize=True)
    save_image(pred_images.data[:16], f'{conf["save_loc"]}/images/pred_{epoch}.png', nrow=4, normalize=True)
            
    results["epoch"].append(epoch)
    results["train_loss"].append(np.mean(train_results["loss"]))
    results["train_perception"].append(np.mean(train_results["lpips"]))
    results["valid_loss"].append(np.mean(valid_results["loss"]))
    results["valid_perception"].append(np.mean(valid_results["lpips"]))
    
    #print_str = f"Epoch {epoch}"
    #print_str += f' train_loss {np.mean(train_results["loss"])} valid_loss {np.mean(valid_results["loss"])}'
    #print(print_str)
    
    df = pd.DataFrame.from_dict(results).reset_index()

    # Save the dataframe to disk
    df.to_csv(f'{conf["save_loc"]}/training_log.csv', index=False)
    
    # If the best model, save to disk
    # Save the model if its the best so far.
    if results[metric][-1] == min(results[metric]):
        state_dict = {
            'epoch': epoch,
            'model_state_dict': generator.state_dict(),
            'optimizer_state_dict': optimizer_G.state_dict(),
            'loss': results["valid_loss"][-1]
        }
        torch.save(state_dict, f'{conf["save_loc"]}/best.pt')
    
    # Lower the learning rate
    #lr_scheduler.step(np.mean(valid_results["loss"]))
    
    # Save some images
    #save_image(synthethic_imgs.data[:16], f'{conf["save_loc"]}/images/synth_{epoch}.png', nrow=4, normalize=True)
    save_image(batched_imgs.data[:16], f'{conf["save_loc"]}/images/real_{epoch}.png', nrow=4, normalize=True)
    save_image(pred_images.data[:16], f'{conf["save_loc"]}/images/pred_{epoch}.png', nrow=4, normalize=True)
    
    # Early stopping
    best_epoch = [i for i, j in enumerate(
        results[metric]) if j == min(results[metric])][0]
    offset = epoch - best_epoch
    if offset >= stopping_patience:
        break

Epoch 0 train_loss 0.142729 train_percept 0.265159 lr 0.00012193727965514856: 100%|██████████| 500/500 [05:59<00:00,  1.39it/s]
Epoch 0 valid_loss 0.094591 valid_percept 0.464260: 100%|██████████| 100/100 [00:26<00:00,  3.79it/s]
Epoch 1 train_loss 0.079879 train_percept 0.084073 lr 0.0001857762320395614: 100%|██████████| 500/500 [06:20<00:00,  1.31it/s] 
Epoch 1 valid_loss 0.065890 valid_percept 0.438006: 100%|██████████| 100/100 [00:26<00:00,  3.79it/s]
Epoch 2 train_loss 0.059784 train_percept 0.043496 lr 0.00028526794452815325: 100%|██████████| 500/500 [06:21<00:00,  1.31it/s]
Epoch 2 valid_loss 0.049015 valid_percept 0.509616: 100%|██████████| 100/100 [00:26<00:00,  3.77it/s]
Epoch 3 train_loss 0.044944 train_percept 0.030726 lr 0.00041067347510301863: 100%|██████████| 500/500 [06:22<00:00,  1.31it/s]
Epoch 3 valid_loss 0.043115 valid_percept 0.449869: 100%|██████████| 100/100 [00:26<00:00,  3.85it/s]
Epoch 4 train_loss 0.035364 train_percept 0.019876 lr 0.0005497172566797806: 100

KeyboardInterrupt: 

In [20]:
requires_grad(generator, False)

In [22]:
genn = iter(test_synthetic_loader)

In [29]:
batch = next(genn)[0].cuda()

In [30]:
pred_imgs = generator(batch)

In [31]:
save_image(batch.data[:16], f'{conf["save_loc"]}/images/synth_{epoch}.png', nrow=4, normalize=True)
save_image(pred_imgs.data[:16], f'{conf["save_loc"]}/images/synth_pred_{epoch}.png', nrow=4, normalize=True)