In [1]:
import segmentation_models_pytorch as smp
import torch

In [2]:
generator = smp.Unet(
    encoder_name="resnet18",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset),
    activation = "tanh"
)

In [11]:
#gen_noise = generator(torch.ones([1, 1, 64, 64]))

In [12]:
#gen_noise.shape

In [13]:
# for thing in generator.encoder(torch.ones([1, 1, 64, 64])):
#     print(thing.shape)

In [5]:
aux_params=dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.2,               # dropout ratio, default is None
    activation='sigmoid',      # activation function, default is None
    classes=1,                 # define number of output labels
)

discriminator = smp.Unet('resnet18', in_channels=1, aux_params=aux_params)
_, verdict = discriminator(torch.ones([32, 1, 64, 64]))

In [6]:
#torch.where(verdict > 0.5, True, False)

In [7]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections import defaultdict
import pandas as pd
import numpy as np
import torch.fft
import subprocess
import logging
import random
import shutil
import psutil
import scipy
import torch
import copy
import yaml
import time
import tqdm
import sys
import gc


from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

from holodecml.data import PickleReader, UpsamplingReader
from holodecml.propagation import InferencePropagator
from holodecml.transforms import LoadTransformations
from holodecml.models import load_model
from holodecml.losses import load_loss

import os
import warnings
warnings.filterwarnings("ignore")

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [8]:
available_ncpus = len(psutil.Process().cpu_affinity())

In [9]:
# ### Set seeds for reproducibility
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

In [10]:
config = "../results/gan/model.yml"
with open(config) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [11]:
# Set seeds for reproducibility
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

tile_size = int(conf["data"]["tile_size"])
step_size = int(conf["data"]["step_size"])
data_path = conf["data"]["output_path"]
data_path_raw = conf["data"]["output_path_raw"]

total_positive = int(conf["data"]["total_positive"])
total_negative = int(conf["data"]["total_negative"])
total_examples = int(conf["data"]["total_training"])

transform_mode = "None" if "transform_mode" not in conf["data"] else conf["data"]["transform_mode"]
config_ncpus = int(conf["data"]["cores"])
use_cached = False if "use_cached" not in conf["data"] else conf["data"]["use_cached"]

name_tag = f"{tile_size}_{step_size}_{total_positive}_{total_negative}_{total_examples}_{transform_mode}"
fn_train = f"{data_path}/training_{name_tag}.pkl"
fn_valid = f"{data_path}/validation_{name_tag}.pkl"
fn_train_raw = f"{data_path_raw}/training_{name_tag}.pkl"
fn_valid_raw = f"{data_path_raw}/validation_{name_tag}.pkl"

output_path = conf["data"]["output_path"]

train_batch_size = 16
valid_batch_size = 16

latent_dim = 128
img_shape = (1, tile_size, tile_size)

d_learning_rate = 0.0004
g_learning_rate = 0.0001

b1 = 0.0
b2 = 0.9

In [12]:
is_cuda = torch.cuda.is_available()
data_device = torch.device("cpu") if "device" not in conf["data"] else conf["data"]["device"]

In [13]:
# Load the preprocessing transforms
if "Normalize" in conf["transforms"]["training"]:
    conf["transforms"]["validation"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]
    conf["transforms"]["inference"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]

train_transforms = LoadTransformations(conf["transforms"]["training"])
valid_transforms = LoadTransformations(conf["transforms"]["validation"])

In [14]:
# train_synthetic_dataset = UpsamplingReader(
#     conf,
#     transform=train_transforms,
#     max_size=100,
#     device=data_device
# )

train_synthetic_dataset = PickleReader(
    fn_train,
    transform=train_transforms,
    max_images=int(0.8 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=True
)

test_synthetic_dataset = PickleReader(
    fn_valid,
    transform=valid_transforms,
    max_images=int(0.1 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=False
)

In [15]:
train_synthetic_loader = torch.utils.data.DataLoader(
    train_synthetic_dataset,
    batch_size=train_batch_size,
    num_workers=0, #available_ncpus//2,
    pin_memory=True,
    shuffle=True)

test_synthetic_loader = torch.utils.data.DataLoader(
    test_synthetic_dataset,
    batch_size=valid_batch_size,
    num_workers=0,  # 0 = One worker with the main process
    pin_memory=True,
    shuffle=False)

In [16]:
# holo_conf = copy.deepcopy(conf)
# holo_conf["data"]["data_path"] = holo_conf["data"]["raw_data"]


# train_holodec_dataset = UpsamplingReader(
#     holo_conf,
#     transform=train_transforms,
#     max_size=1000,
#     device=data_device
# )


# # test_holodec_inputs = torch.from_numpy(np.load(os.path.join(
# #             output_path, f'manual_images_{transform_mode}.npy'))).float()
# # test_holodec_labels = torch.from_numpy(np.load(os.path.join(
# #             output_path, f'manual_labels_{transform_mode}.npy'))).float()

# # test_holodec_dataset = torch.utils.data.TensorDataset(test_holodec_inputs, test_holodec_labels)


# holo_conf = copy.deepcopy(conf)
# holo_conf["data"]["raw_data"] = "/glade/p/cisl/aiml/ai4ess_hackathon/holodec/real_holograms_CSET_RF07_20150719_203600-203700.nc"
# holo_conf["data"]["data_path"] = holo_conf["data"]["raw_data"]

# test_holodec_dataset = UpsamplingReader(
#     holo_conf,
#     transform=valid_transforms,
#     max_size=1000,
#     device=data_device
# )

In [None]:
train_holodec_dataset = PickleReader(
    fn_train_raw,
    transform=train_transforms,
    max_images=int(0.8 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=True
)

test_holodec_dataset = PickleReader(
    fn_valid_raw,
    transform=valid_transforms,
    max_images=int(0.1 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
    color_dim=conf["model"]["in_channels"],
    shuffle=False
)

In [17]:
train_holodec_loader = torch.utils.data.DataLoader(
    train_holodec_dataset,
    batch_size=train_batch_size,
    num_workers=available_ncpus,
    pin_memory=True,
    shuffle=True)

test_holodec_loader = torch.utils.data.DataLoader(
    test_holodec_dataset,
    batch_size=valid_batch_size,
    num_workers=available_ncpus,
    pin_memory=True,
    shuffle=False)

In [18]:
generator = smp.Unet(
    encoder_name="resnet18",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset),
#     activation = "tanh"
)

aux_params=dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.0,               # dropout ratio, default is None
    activation='sigmoid',      # activation function, default is None
    classes=1,                 # define number of output labels
)

discriminator = smp.Unet(
#     encoder_name="resnet18",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1, 
    aux_params=aux_params
)

#print(generator(torch.ones([1, 1, 64, 64])).shape)
#print(discriminator(torch.ones([1, 1, 64, 64]))[1].shape)

In [19]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

if is_cuda:
    discriminator.cuda()
    adversarial_loss.cuda()
    
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

In [20]:
optimizer_D = torch.optim.Adam(
    filter(lambda p: p.requires_grad, discriminator.parameters()), 
    lr=d_learning_rate
)

In [21]:
lr_scheduler = ReduceLROnPlateau(
        optimizer_D,
        patience=1,
        min_lr=1.0e-13,
        verbose=True
    )

In [None]:
n_epochs = 5000
batches_per_epoch = 10
valid_batches_per_epoch = 50
stopping_patience = 4
threshold = 0.5

Tensor = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor

results = defaultdict(list)
for epoch in range(n_epochs):
    
    ### Train
    real_images = iter(train_holodec_loader)
    synthethic_images = iter(train_synthetic_loader)
    dual_iter = tqdm.tqdm(
        enumerate(zip(real_images, synthethic_images)),
        total = batches_per_epoch, 
        leave = True)
    
    train_results = defaultdict(list)
    for i, ((holo_img, holo_label), (synth_img, synth_label)) in dual_iter:
                    
        if holo_img.shape[0] != synth_img.shape[0]:
            continue
            
        optimizer_D.zero_grad()
        requires_grad(discriminator, True)
                
        # Adversarial ground truths
        real = Variable(Tensor(holo_img.size(0), 1).fill_(1.0), requires_grad=False)
        synth = Variable(Tensor(synth_img.size(0), 1).fill_(0.0), requires_grad=False)
        labels = torch.cat([real, synth], 0)

        # Configure input
        real_imgs = Variable(holo_img.type(Tensor))
        synthethic_imgs = Variable(synth_img.type(Tensor))
        batched_imgs = torch.cat([real_imgs, synthethic_imgs], 0)
        
        _, verdict = discriminator(batched_imgs)
        loss = adversarial_loss(verdict, labels)
        
        train_results["loss"].append(loss.item())
        train_results["accuracy"].append(((verdict > threshold) == labels).float().mean().item())
        
        to_print = f'Epoch {epoch} train_loss {np.mean(train_results["loss"])} train_acc {np.mean(train_results["accuracy"])}'
        dual_iter.set_description(to_print)
        dual_iter.update()
        
        loss.backward()
        optimizer_D.step()
            
        if i == batches_per_epoch and i > 0:
            break
    
    ### Validation
    requires_grad(discriminator, False)
    
    real_images = iter(test_holodec_loader)
    synthethic_images = iter(test_synthetic_loader)
    dual_iter = tqdm.tqdm(
        enumerate(zip(real_images, synthethic_images)),
        total = valid_batches_per_epoch, 
        leave = True)
    
    valid_results = defaultdict(list)
    for i, ((holo_img, holo_label), (synth_img, synth_label)) in dual_iter:
                    
        if holo_img.shape[0] != synth_img.shape[0]:
            continue
               
        # Adversarial ground truths
        real = Variable(Tensor(holo_img.size(0), 1).fill_(1.0), requires_grad=False)
        synth = Variable(Tensor(synth_img.size(0), 1).fill_(0.0), requires_grad=False)
        labels = torch.cat([real, synth], 0)

        # Configure input
        real_imgs = Variable(holo_img.type(Tensor))
        synthethic_imgs = Variable(synth_img.type(Tensor))
        batched_imgs = torch.cat([real_imgs, synthethic_imgs], 0)
        
        _, verdict = discriminator(batched_imgs)
        loss = adversarial_loss(verdict, labels)
        
        valid_results["loss"].append(loss.item())
        valid_results["accuracy"].append(((verdict > threshold) == labels).float().mean().item())
        
        to_print = f'Epoch {epoch} valid_loss {np.mean(valid_results["loss"])} valid_acc {np.mean(valid_results["accuracy"])}'
        dual_iter.set_description(to_print)
        dual_iter.update()
        
        if i == valid_batches_per_epoch and i > 0:
            break
            
    results["epoch"].append(epoch)
    results["train_loss"].append(np.mean(train_results["loss"]))
    results["train_accuracy"].append(np.mean(train_results["accuracy"]))
    results["valid_loss"].append(np.mean(valid_results["loss"]))
    results["valid_accuracy"].append(np.mean(valid_results["accuracy"]))
    
    print_str = f"Epoch {epoch}"
    print_str += f' train_loss {np.mean(train_results["loss"])} valid_loss {np.mean(valid_results["loss"])}'
    print_str += f' train_acc {np.mean(train_results["accuracy"])} valid_acc {np.mean(valid_results["accuracy"])}'

    print(print_str)
    
    df = pd.DataFrame.from_dict(results).reset_index()

    # Save the dataframe to disk
    df.to_csv(f'{conf["save_loc"]}/training_log.csv', index=False)
    
    # If the best model, save to disk
    # Save the model if its the best so far.
    if results["valid_accuracy"][-1] == max(results["valid_accuracy"]):
        state_dict = {
            'epoch': epoch,
            'model_state_dict': discriminator.state_dict(),
            'optimizer_state_dict': optimizer_D.state_dict(),
            'loss': results["valid_accuracy"][-1]
        }
        torch.save(state_dict, f'{conf["save_loc"]}/best.pt')
    
    # Lower the learning rate
    lr_scheduler.step(1.0-np.mean(valid_results["accuracy"]))
    
    # Early stopping
    best_epoch = [i for i, j in enumerate(
        results["valid_accuracy"]) if j == max(results["valid_accuracy"])][0]
    offset = epoch - best_epoch
    if offset >= stopping_patience:
        break

Epoch 0 train_loss 0.13253163713600521 train_acc 0.9517045454545454: 100%|██████████| 10/10 [01:08<00:00,  6.88s/it]
Epoch 0 valid_loss 0.045717847129922384 valid_acc 0.9865196078431373: 100%|██████████| 50/50 [03:55<00:00,  4.70s/it]

Epoch 0 train_loss 0.13253163713600521 valid_loss 0.045717847129922384 train_acc 0.9517045454545454 valid_acc 0.9865196078431373



Epoch 1 train_loss 0.01508790817266262 train_acc 0.9943181818181818: 100%|██████████| 10/10 [01:16<00:00,  7.65s/it] 
Epoch 1 valid_loss 0.0013701987045351416 valid_acc 1.0:   6%|▌         | 3/50 [00:51<10:43, 13.69s/it]