In [1]:
import segmentation_models_pytorch as smp
import torch

from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections import defaultdict
import pandas as pd
import numpy as np
import torch.fft
import subprocess
import logging
import random
import shutil
import psutil
import scipy
import torch
import copy
import yaml
import time
import tqdm
import sys
import gc


from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

from holodecml.data import PickleReader, UpsamplingReader
from holodecml.propagation import InferencePropagator
from holodecml.transforms import LoadTransformations
from holodecml.models import load_model
from holodecml.losses import load_loss

import os
import warnings
warnings.filterwarnings("ignore")

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [2]:
available_ncpus = len(psutil.Process().cpu_affinity())

In [3]:
# ### Set seeds for reproducibility
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

In [4]:
config = "../results/generator/model.yml"
with open(config) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [5]:
# Set seeds for reproducibility
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

tile_size = int(conf["data"]["tile_size"])
step_size = int(conf["data"]["step_size"])
data_path = conf["data"]["output_path"]
data_path_raw = conf["data"]["output_path_raw"]

total_positive = int(conf["data"]["total_positive"])
total_negative = int(conf["data"]["total_negative"])
total_examples = int(conf["data"]["total_training"])

transform_mode = "None" if "transform_mode" not in conf["data"] else conf["data"]["transform_mode"]
config_ncpus = int(conf["data"]["cores"])
use_cached = False if "use_cached" not in conf["data"] else conf["data"]["use_cached"]

name_tag = f"{tile_size}_{step_size}_{total_positive}_{total_negative}_{total_examples}_{transform_mode}"
fn_train = f"{data_path}/training_{name_tag}.pkl"
fn_valid = f"{data_path}/validation_{name_tag}.pkl"
fn_train_raw = f"{data_path_raw}/training_{name_tag}.pkl"
fn_valid_raw = f"{data_path_raw}/validation_{name_tag}.pkl"

output_path = conf["data"]["output_path"]

train_batch_size = 16
valid_batch_size = 16

latent_dim = 128
img_shape = (1, tile_size, tile_size)

d_learning_rate = 0.0004
g_learning_rate = 0.0001

b1 = 0.0
b2 = 0.9

In [6]:
is_cuda = torch.cuda.is_available()
data_device = torch.device("cpu") if "device" not in conf["data"] else conf["data"]["device"]

In [7]:
# Load the preprocessing transforms
if "Normalize" in conf["transforms"]["training"]:
    conf["transforms"]["validation"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]
    conf["transforms"]["inference"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]

train_transforms = LoadTransformations(conf["transforms"]["training"])
valid_transforms = LoadTransformations(conf["transforms"]["validation"])

In [8]:
train_synthetic_dataset = PickleReader(
    fn_train,
    transform=train_transforms,
    max_images=int(0.8 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=True
)

test_synthetic_dataset = PickleReader(
    fn_valid,
    transform=valid_transforms,
    max_images=int(0.1 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=False
)

In [9]:
train_synthetic_loader = torch.utils.data.DataLoader(
    train_synthetic_dataset,
    batch_size=train_batch_size,
    num_workers=0, #available_ncpus//2,
    pin_memory=True,
    shuffle=True)

test_synthetic_loader = torch.utils.data.DataLoader(
    test_synthetic_dataset,
    batch_size=valid_batch_size,
    num_workers=0,  # 0 = One worker with the main process
    pin_memory=True,
    shuffle=False)

In [10]:
# holo_conf = copy.deepcopy(conf)
# holo_conf["data"]["data_path"] = holo_conf["data"]["raw_data"]


# train_holodec_dataset = UpsamplingReader(
#     holo_conf,
#     transform=train_transforms,
#     device=data_device
# )


# # test_holodec_inputs = torch.from_numpy(np.load(os.path.join(
# #             output_path, f'manual_images_{transform_mode}.npy'))).float()
# # test_holodec_labels = torch.from_numpy(np.load(os.path.join(
# #             output_path, f'manual_labels_{transform_mode}.npy'))).float()

# # test_holodec_dataset = torch.utils.data.TensorDataset(test_holodec_inputs, test_holodec_labels)


# holo_conf = copy.deepcopy(conf)
# holo_conf["data"]["raw_data"] = "/glade/p/cisl/aiml/ai4ess_hackathon/holodec/real_holograms_CSET_RF07_20150719_203600-203700.nc"
# holo_conf["data"]["data_path"] = holo_conf["data"]["raw_data"]

# test_holodec_dataset = UpsamplingReader(
#     holo_conf,
#     transform=valid_transforms,
#     device=data_device
# )

In [None]:
train_holodec_dataset = PickleReader(
    fn_train_raw,
    transform=train_transforms,
    max_images=int(0.8 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=True
)

test_holodec_dataset = PickleReader(
    fn_valid_raw,
    transform=valid_transforms,
    max_images=int(0.1 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=False
)

In [11]:
train_holodec_loader = torch.utils.data.DataLoader(
    train_holodec_dataset,
    batch_size=train_batch_size,
    num_workers=available_ncpus,
    pin_memory=True,
    shuffle=True)

test_holodec_loader = torch.utils.data.DataLoader(
    test_holodec_dataset,
    batch_size=valid_batch_size,
    num_workers=available_ncpus,
    pin_memory=True,
    shuffle=False)

In [12]:
generator = smp.Unet(
    encoder_name="resnet18",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset),
    activation = "tanh"
)

In [13]:
# Loss function
adversarial_loss = torch.nn.L1Loss()

if is_cuda:
    generator.cuda()
    adversarial_loss.cuda()
    
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

In [14]:
optimizer_G = torch.optim.Adam(
    filter(lambda p: p.requires_grad, generator.parameters()), 
    lr=g_learning_rate
)

In [15]:
lr_scheduler = ReduceLROnPlateau(
        optimizer_G,
        patience=1,
        min_lr=g_learning_rate * 1e-3,
        verbose=True
    )

In [None]:
n_epochs = 100
batches_per_epoch = 50
valid_batches_per_epoch = 50
stopping_patience = 4

Tensor = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor

results = defaultdict(list)
for epoch in range(n_epochs):
    
    ### Train
    real_images = iter(train_holodec_loader)
    synthethic_images = iter(train_synthetic_loader)
    dual_iter = tqdm.tqdm(
        enumerate(zip(real_images, synthethic_images)),
        total = batches_per_epoch, 
        leave = True)
    
    train_results = defaultdict(list)
    for i, ((holo_img, holo_label), (synth_img, synth_label)) in dual_iter:
                    
        if holo_img.shape[0] != synth_img.shape[0]:
            continue
            
        optimizer_G.zero_grad()
        requires_grad(generator, True)

        # Configure input
        real_imgs = Variable(holo_img.type(Tensor))
        synthethic_imgs = Variable(synth_img.type(Tensor))
        batched_imgs = torch.cat([real_imgs, synthethic_imgs], 0)
        
        pred_imgs = generator(batched_imgs)
        loss = adversarial_loss(pred_imgs, batched_imgs)
        
        train_results["loss"].append(loss.item())
        
        to_print = f'Epoch {epoch} train_loss {np.mean(train_results["loss"])}'
        dual_iter.set_description(to_print)
        dual_iter.update()
        
        loss.backward()
        optimizer_G.step()
            
        if i == batches_per_epoch and i > 0:
            break
    
    ### Validation
    requires_grad(generator, False)
    
    real_images = iter(test_holodec_loader)
    synthethic_images = iter(test_synthetic_loader)
    dual_iter = tqdm.tqdm(
        enumerate(zip(real_images, synthethic_images)),
        total = valid_batches_per_epoch, 
        leave = True)
    
    valid_results = defaultdict(list)
    for i, ((holo_img, holo_label), (synth_img, synth_label)) in dual_iter:
                    
        if holo_img.shape[0] != synth_img.shape[0]:
            continue

        # Configure input
        real_imgs = Variable(holo_img.type(Tensor))
        synthethic_imgs = Variable(synth_img.type(Tensor))
        batched_imgs = torch.cat([real_imgs, synthethic_imgs], 0)
        
        pred_images = generator(batched_imgs)
        loss = adversarial_loss(pred_images, batched_imgs)
        
        valid_results["loss"].append(loss.item())
        
        to_print = f'Epoch {epoch} valid_loss {np.mean(valid_results["loss"])}'
        dual_iter.set_description(to_print)
        dual_iter.update()
        
        if i == valid_batches_per_epoch and i > 0:
            break
            
    results["epoch"].append(epoch)
    results["train_loss"].append(np.mean(train_results["loss"]))
    results["valid_loss"].append(np.mean(valid_results["loss"]))
    
    print_str = f"Epoch {epoch}"
    print_str += f' train_loss {np.mean(train_results["loss"])} valid_loss {np.mean(valid_results["loss"])}'
    print(print_str)
    
    df = pd.DataFrame.from_dict(results).reset_index()

    # Save the dataframe to disk
    df.to_csv(f'{conf["save_loc"]}/training_log.csv', index=False)
    
    # If the best model, save to disk
    # Save the model if its the best so far.
    if results["valid_loss"][-1] == min(results["valid_loss"]):
        state_dict = {
            'epoch': epoch,
            'model_state_dict': generator.state_dict(),
            'optimizer_state_dict': optimizer_G.state_dict(),
            'loss': results["valid_loss"][-1]
        }
        torch.save(state_dict, f'{conf["save_loc"]}/best.pt')
    
    # Lower the learning rate
    lr_scheduler.step(np.mean(valid_results["loss"]))
    
    # Early stopping
    best_epoch = [i for i, j in enumerate(
        results["valid_loss"]) if j == min(results["valid_loss"])][0]
    offset = epoch - best_epoch
    if offset >= stopping_patience:
        break

Epoch 0 train_loss 0.1900023131393919: 100%|██████████| 50/50 [03:29<00:00,  4.18s/it] 
Epoch 0 valid_loss 0.10530071968541425: 100%|██████████| 50/50 [03:37<00:00,  4.34s/it]


Epoch 0 train_loss 0.1900023131393919 valid_loss 0.10530071968541425


Epoch 1 train_loss 0.09508555163355435: 100%|██████████| 50/50 [03:33<00:00,  4.27s/it]
Epoch 1 valid_loss 0.09006606349173714: 100%|██████████| 50/50 [03:32<00:00,  4.26s/it]


Epoch 1 train_loss 0.09508555163355435 valid_loss 0.09006606349173714


Epoch 2 train_loss 0.08504742442392836: 100%|██████████| 50/50 [03:36<00:00,  4.33s/it]
Epoch 2 valid_loss 0.08281190649551504: 100%|██████████| 50/50 [03:34<00:00,  4.30s/it]


Epoch 2 train_loss 0.08504742442392836 valid_loss 0.08281190649551504


Epoch 3 train_loss 0.07833634010132622: 100%|██████████| 50/50 [03:36<00:00,  4.34s/it]
Epoch 3 valid_loss 0.07767726831576403: 100%|██████████| 50/50 [03:32<00:00,  4.25s/it]


Epoch 3 train_loss 0.07833634010132622 valid_loss 0.07767726831576403


Epoch 4 train_loss 0.07700869411814447: 100%|██████████| 50/50 [03:41<00:00,  4.43s/it]
Epoch 4 valid_loss 0.07633406478984683: 100%|██████████| 50/50 [03:33<00:00,  4.27s/it]


Epoch 4 train_loss 0.07700869411814447 valid_loss 0.07633406478984683


Epoch 5 train_loss 0.07263892874413845: 100%|██████████| 50/50 [03:35<00:00,  4.31s/it]
Epoch 5 valid_loss 0.0715823461319886: 100%|██████████| 50/50 [03:28<00:00,  4.17s/it] 


Epoch 5 train_loss 0.07263892874413845 valid_loss 0.0715823461319886


Epoch 6 train_loss 0.06938006100701351: 100%|██████████| 50/50 [03:39<00:00,  4.38s/it]
Epoch 6 valid_loss 0.07189084136602926: 100%|██████████| 50/50 [03:31<00:00,  4.24s/it]

Epoch 6 train_loss 0.06938006100701351 valid_loss 0.07189084136602926



Epoch 7 train_loss 0.06950113885835105: 100%|██████████| 50/50 [03:37<00:00,  4.36s/it]
Epoch 7 valid_loss 0.070238297476488: 100%|██████████| 50/50 [03:27<00:00,  4.15s/it]  


Epoch 7 train_loss 0.06950113885835105 valid_loss 0.070238297476488


Epoch 8 train_loss 0.06826724820569449: 100%|██████████| 50/50 [03:42<00:00,  4.45s/it]
Epoch 8 valid_loss 0.06666889881678656: 100%|██████████| 50/50 [03:27<00:00,  4.15s/it]


Epoch 8 train_loss 0.06826724820569449 valid_loss 0.06666889881678656


Epoch 9 train_loss 0.06332850751156609: : 57it [03:10,  2.07s/it]                      