In [None]:
import warnings
warnings.filterwarnings("ignore")

import pytorch_lightning as pl
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.nn import functional as F
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from statistics import mean
import pandas as pd
import torch

from VAE_Architectures import VAE
from utils import create_dataset, get_train_images, plot_vae_images


This notebook trains the VAE; its decoder is later used as the generative model.

For reproducibility, the random seed is fixed with `seed_everything` from PyTorch Lightning and with `torch.manual_seed`, and PyTorch is restricted to deterministic algorithms.

The model is described in the paper under "Autoencoder-based generative model".

In [2]:
# Single source of truth for the random seed used by every library below.
SEED = 41

# With deterministic algorithms enabled, cuBLAS (CUDA >= 10.2) only runs
# deterministically with a fixed workspace size; otherwise the first cuBLAS op
# raises a RuntimeError. The variable must be set before the first CUDA/cuBLAS
# call, so set it here (without overriding a value the user already exported).
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

seed_everything(SEED)
torch.use_deterministic_algorithms(True)
torch.manual_seed(SEED)

Global seed set to 41


<torch._C.Generator at 0x1ff91acc490>

In [10]:
# --- Data --------------------------------------------------------------------
path_to_data = "../Datasets/Augmented_One_Particle_Dataset"
batch_size = 64
crop_size = 128        # passed to both create_dataset and the VAE constructor
num_of_channels = 1

# --- Optimisation ------------------------------------------------------------
lr = 0.001
kl_coeff = 0.01        # forwarded to the VAE as `kl_coeff`
max_epochs = 2

# --- Output ------------------------------------------------------------------
run_id = "Results_VAE"  # experiment name for TensorBoard and the output sub-directory

In [11]:
# Output directory for per-epoch reconstruction images and the metrics CSV.
# `exist_ok=True` replaces the check-then-create pattern: it is race-free and
# keeps this cell idempotent when re-run.
save_path = f"Results_VAE/{run_id}/"
os.makedirs(save_path, exist_ok=True)

In [12]:
# Build the training dataset and the loader that is later passed to trainer.fit.
dataset, dataloader = create_dataset(path_to_data, batch_size, crop_size, num_of_channels)

In [13]:
class GenerateCallback(pl.Callback):
    """Lightning callback that logs and saves VAE reconstructions of a fixed batch.

    On every training epoch whose index is a multiple of ``every_n_epochs`` it:
      * reconstructs ``input_images`` with the module in eval mode,
      * logs an input/reconstruction image grid to the trainer's logger,
      * saves comparison images via ``plot_vae_images`` under
        ``<save_path>Images/epoch_<n>/``, and
      * appends ``[epoch, mean PSNR, mean SSIM]`` as a new row of
        ``<save_path>df_metrics.csv``.

    Args:
        input_images: Tensor batch of images to reconstruct; moved to the
            module's device on each call.
        save_path: Output directory prefix. Paths are built by string
            concatenation, so it must end with a separator. ``df_metrics.csv``
            must already exist there with exactly three columns.
        every_n_epochs: Run when ``trainer.current_epoch % every_n_epochs == 0``.
    """

    def __init__(self, input_images, save_path, every_n_epochs=1):
        super().__init__()
        self.input_images = input_images
        self.every_n_epochs = every_n_epochs
        self.save_path = save_path

    def on_train_epoch_end(self, trainer, pl_module):
        """Reconstruct, log, save images and append metrics for this epoch."""
        if trainer.current_epoch % self.every_n_epochs != 0:
            return

        epoch = trainer.current_epoch
        image_save_path = self.save_path + "Images/epoch_{e}/".format(e=epoch)
        os.makedirs(image_save_path, exist_ok=True)

        input_images = self.input_images.to(pl_module.device)

        # Reconstruct in eval mode, then restore the module's previous mode even
        # if the forward pass raises.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                reconstructed_images = pl_module(input_images)
        finally:
            pl_module.train(was_training)

        # Interleave inputs and reconstructions ([in_0, rec_0, in_1, rec_1, ...])
        # so that each row of the 2-column grid shows one pair.
        imgs = torch.stack([input_images, reconstructed_images], dim=1).flatten(0, 1)
        # `value_range` replaces make_grid's deprecated `range` keyword, which
        # newer torchvision releases no longer accept.
        # NOTE(review): assumes images are normalised to [-1, 1]; confirm against
        # the transforms used in create_dataset.
        grid = torchvision.utils.make_grid(
            imgs, nrow=2, normalize=True, value_range=(-1, 1)
        )
        trainer.logger.experiment.add_image(
            "Reconstructions", grid, global_step=trainer.global_step
        )

        psnr_metrics, ssim_metrics = plot_vae_images(
            input_images, reconstructed_images, image_save_path
        )

        # Append one row to the on-disk metrics log (read-modify-write).
        metrics_csv = self.save_path + "df_metrics.csv"
        df_metrics = pd.read_csv(metrics_csv)
        df_metrics.loc[len(df_metrics)] = [
            epoch,
            mean(psnr_metrics),
            mean(ssim_metrics),
        ]
        df_metrics.to_csv(metrics_csv, index=False)

In [14]:
def Train_VAE(run_id):
    """Build the VAE and a configured Lightning ``Trainer``; training is not started.

    Hyperparameters (``crop_size``, ``lr``, ``kl_coeff``, ``max_epochs``), the
    ``dataset`` used for the callback's fixed images and ``save_path`` are read
    from the notebook's configuration cells.

    Args:
        run_id: Experiment name for the TensorBoard logger; logs are written
            under ``Results_VAE/<run_id>/``.

    Returns:
        ``(trainer, vae)``. Call ``trainer.fit(vae, dataloader)`` to train.
    """
    # Use the configured learning rate rather than a hard-coded duplicate of it.
    vae = VAE(crop_size, lr=lr, kl_coeff=kl_coeff)
    logger = TensorBoardLogger("Results_VAE", name=run_id)

    # NOTE(review): callback outputs go to the global `save_path` (built from the
    # global `run_id`), not from this function's `run_id` argument; keep them equal.
    # get_train_images(10, dataset) presumably selects 10 fixed training images.
    reconstruction_callback = GenerateCallback(
        get_train_images(10, dataset), every_n_epochs=1, save_path=save_path
    )

    trainer = pl.Trainer(
        # Train on one GPU when available; fall back to CPU instead of failing.
        gpus=1 if torch.cuda.is_available() else None,
        logger=logger,
        max_epochs=max_epochs,
        callbacks=[reconstruction_callback],
        log_every_n_steps=5,
        auto_lr_find=False,
        auto_scale_batch_size=False,
    )
    return trainer, vae


In [15]:
# Reset the metrics log to a header-only CSV; GenerateCallback reads this file
# and appends one (Epoch, PSNR, SSIM) row per logged epoch.
df_metrics = pd.DataFrame(columns=["Epoch", "PSNR", "SSIM"])
df_metrics.to_csv(f"{save_path}df_metrics.csv", index=False)

In [None]:
# Build the model and its trainer; training itself happens in the next cell.
trainer, ae = Train_VAE(run_id=run_id)


In [None]:
# Train the VAE on the full training loader for the configured number of epochs.
trainer.fit(ae, dataloader)


In [1]:
import pandas as pd
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# The execution count restarts at 1 here, so this analysis section appears to be
# run in a fresh kernel: it re-imports what it needs and reads results from disk.
log_dir = "Results_VAE/Results_VAE_Final_Validation/version_0"

event_accumulator = EventAccumulator(log_dir)
event_accumulator.Reload()  # parse the TensorBoard event files in log_dir

# Extract the logged `train_loss` scalars as parallel step / value lists.
events = event_accumulator.Scalars("train_loss")
x = [event.step for event in events]
y = [event.value for event in events]

df = pd.DataFrame(dict(step=x, train_loss=y))
df.to_csv("Results_VAE/Results_VAE_Final_Validation/train_loss.csv", index=False)

In [2]:
# Reload the exported training-loss curve and the per-epoch PSNR/SSIM log.
# NOTE(review): the row index is used as the epoch number. That is only correct
# if `train_loss` is logged once per epoch; if it is logged per step
# (log_every_n_steps=5), this is really the log-event index. Confirm against the
# `step` column.
losses = pd.read_csv(
    "Results_VAE/Results_VAE_Final_Validation/train_loss.csv"
).assign(Epoch=lambda frame: frame.index)
metrics = pd.read_csv("Results_VAE/Results_VAE_Final_Validation/df_metrics.csv")

In [3]:
# Preview the first rows of the per-epoch reconstruction metrics.
metrics.head(n=5)

Unnamed: 0,Epoch,PSNR,SSIM
0,0.0,29.119607,0.976639
1,1.0,29.890835,0.978315
2,2.0,31.374599,0.980945
3,3.0,31.732449,0.982828
4,4.0,32.474348,0.982533


In [10]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Framed-axis look shared by the x-axis and the primary y-axis: mirrored black
# axis lines, outside ticks and light-grey gridlines.
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)

# PSNR (~30 dB) and SSIM (~0.98) differ in scale, so each gets its own y-axis.
fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(
    go.Scatter(x=metrics["Epoch"], y=metrics["PSNR"], name="PSNR"), secondary_y=False
)
fig.add_trace(
    go.Scatter(x=metrics["Epoch"], y=metrics["SSIM"], name="SSIM"), secondary_y=True
)

fig.update_xaxes(title_text="Epoch", **axis_style)
fig.update_yaxes(title_text="PSNR", secondary_y=False, showgrid=True, **axis_style)
fig.update_yaxes(title_text="SSIM", secondary_y=True, showgrid=True)

fig.update_layout(
    autosize=False,
    width=800,
    height=800,
    margin=dict(l=50, r=50, b=100, t=100, pad=4),
    paper_bgcolor="White",
    plot_bgcolor="white",
    title_x=0.5,
    font=dict(size=30, color="black"),
    legend=dict(
        x=0.02,
        y=0.98,
        traceorder="normal",
        font=dict(family="sans-serif", size=30, color="black"),
        bordercolor="Black",
        borderwidth=1,
    ),
)
fig.update_traces(line=dict(width=5))
fig.show()

In [13]:
import plotly.express as px

# Framed-axis look: mirrored black axis lines, outside ticks, light-grey grid.
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)

df = losses  # rebinds `df` (previously the raw exported event table)
fig = px.line(df, x="Epoch", y="train_loss")

fig.update_xaxes(**axis_style)
fig.update_yaxes(title_text="Train loss", showgrid=True, **axis_style)

fig.update_layout(
    autosize=False,
    width=800,
    height=800,
    margin=dict(l=50, r=50, b=100, t=100, pad=4),
    paper_bgcolor="White",
    plot_bgcolor="white",
    title_x=0.5,
    font=dict(size=30, color="black"),
)
fig.update_traces(line=dict(width=5))
fig.show()