In [None]:
import pickle
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from torch import nn
import torch
import pytorch_lightning as pl
import random
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from utils import Create_data, Create_text_embeddings, Create_image_embeddings, Plot_results
from torch.utils.data import DataLoader
from transformers import logging
from Linking_AE_Architectures import Linking_AE

This notebook trains the linking autoencoder (linking AE), which translates a text vector into an image vector; the image vector is then decoded into an image using the decoder of the VAE trained earlier.

To fix the seed, the function `seed_everything` from the `pytorch_lightning` library is used, together with `manual_seed` from `torch`.

This model is described in the paper as the "Linking" autoencoder model.

In [None]:
# Reproducibility: pin every random number generator before any stochastic step runs.
SEED = 41

# Ask torch to use deterministic algorithms only.
torch.use_deterministic_algorithms(True)

# Lightning's seeding helper, followed by an explicit torch seed.
seed_everything(SEED)
# Kept as the last expression so the value displayed by this cell stays the same.
torch.manual_seed(SEED)

In [2]:
# ---- Run configuration ----
# Identifier that ends up in the embedding file names and the logger directory.
run_id = 'Linking_AE_Test'
# Mini-batch size of the training DataLoader.
batch_size = 8

# ---- Image settings ----
# Used both when creating image embeddings and when re-loading the original images.
crop_size = 128
num_of_channels = 1

# ---- Input locations ----
# NOTE(review): absolute, machine-specific paths; adjust before running elsewhere.
path_to_data = "G:/Nanomaterial_Morphology_Prediction/Datasets/Augmented_One_Particle_Dataset/"
vae_checkpoint_path = "G:/Nanomaterial_Morphology_Prediction/VAE_Training/Results_VAE/Results_VAE_Final_Validation/version_0/checkpoints/epoch=39-step=26320.ckpt"

# ---- Compute device ----
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [3]:
def _save_embeddings(embeddings, file_path):
    """Pickle ``embeddings`` to ``file_path``, creating the parent directory if needed."""
    import os  # local import so the notebook's import cell does not have to change

    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        # Without this, open(..., "wb") raises FileNotFoundError on a fresh checkout
        # where the output directory has not been created yet.
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, "wb") as handle:
        pickle.dump(embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)


def Create_All_embeddings():
    """Create, save and return the text and image embeddings of the dataset.

    Reads the notebook-level settings ``path_to_data``, ``vae_checkpoint_path``,
    ``device``, ``crop_size``, ``num_of_channels`` and ``run_id``.

    Side effects:
        Writes two pickle files, ``Embeddings/text_embeddings_<run_id>.embs`` and
        ``Embeddings/image_embeddings_<run_id>.embs`` (relative to the current
        working directory). The ``Embeddings`` directory is created when missing.

    Returns:
        tuple: ``(preprocessed_data, Text_embeddings, Image_embeddings)`` exactly as
        returned by ``Create_data``, ``Create_text_embeddings`` and
        ``Create_image_embeddings``.
    """
    preprocessed_data = Create_data(path_to_data, pattern=0)

    # Text embeddings are saved before the image embeddings are computed, so they
    # are already on disk if the image step fails.
    Text_embeddings = Create_text_embeddings(preprocessed_data, device, batch_size=2000)
    _save_embeddings(
        Text_embeddings, "Embeddings/text_embeddings_{}.embs".format(run_id)
    )

    Image_embeddings = Create_image_embeddings(
        path_to_data,
        vae_checkpoint_path,
        batch_size=2000,
        crop_size=crop_size,
        num_of_channels=num_of_channels,
    )
    _save_embeddings(
        Image_embeddings, "Embeddings/image_embeddings_{}.embs".format(run_id)
    )

    return preprocessed_data, Text_embeddings, Image_embeddings


In [4]:
# Build the text and image embeddings for the whole dataset (also pickled to disk).
# `text` and `image` are consumed by the cells below; `preprocessed_data` is not
# used again in the cells visible here.
preprocessed_data, text, image = Create_All_embeddings()

Batch 1 completed out of 1
1755 images preprocessed out of 1755


In [5]:
# Pair the text embedding and the image embedding that share the same index;
# each pair is one training sample (input, target) for the linking AE.
paired_embeddings = [[text[idx], image[idx]] for idx in range(len(text))]

train_dl = DataLoader(
    paired_embeddings,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Fresh linking autoencoder built with its default constructor arguments.
ae = Linking_AE()

# Logs are written below the "Result_Linking_AE" folder, in a sub-folder named after the run.
logger = TensorBoardLogger("Result_Linking_AE", name=run_id)

# Training schedule: 200 epochs, metrics logged every 10 steps.
trainer_options = dict(
    logger=logger,
    max_epochs=200,
    log_every_n_steps=10,
)
trainer = pl.Trainer(**trainer_options)

In [None]:
# Train the linking AE on the paired (text, image) embeddings.
# Only a training loader is passed; no validation loader is supplied here.
trainer.fit(ae, train_dl)

In [8]:
# Checkpoint used for evaluation.
# NOTE(review): the path is hard-coded to run 'Linking_AE_Test', version_0 and a fixed
# epoch/step; presumably it was written by the training cell above. Update it when
# `run_id`, the logger version or the training length changes.
ae_ckpt_path = "Result_Linking_AE/Linking_AE_Test/version_0/checkpoints/epoch=199-step=44000.ckpt"
# Rebinds `ae` to the model restored from the checkpoint, replacing the instance created earlier.
ae = Linking_AE.load_from_checkpoint(ae_ckpt_path)

In [9]:
# How many randomly chosen dataset entries are passed through the model and plotted.
number_of_samples = 10

In [10]:
# Output folder handed to Plot_results. Forward slashes are used, like every other
# path in this notebook, so the value is a valid directory path on Windows as well
# as on Linux/macOS (backslashes are only a separator on Windows).
save_to = "Result_Linking_AE/{r}/".format(r=run_id)

# Pick `number_of_samples` distinct random indices and take detached copies of the
# matching text embeddings (model input) and image embeddings (target).
randomlist = random.sample(range(len(text)), number_of_samples)
x = text[randomlist].clone().detach()
y = image[randomlist].clone().detach()

# Inference only: evaluation mode and no gradient tracking while predicting the
# image embeddings from the text embeddings.
ae.eval()
with torch.no_grad():
    y_reconst = ae(x.to(ae.device))
# Leave the model in training mode afterwards.
ae.train()

# Same resize / crop / grayscale settings as the notebook configuration.
transform = transforms.Compose(
    [
        transforms.Resize(crop_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Grayscale(num_output_channels=num_of_channels),
    ]
)
dataset = ImageFolder(root=path_to_data, transform=transform)
# Original images for the sampled indices (element 0 of each item is the image).
# NOTE(review): assumes ImageFolder enumerates the files in the same order that was
# used to build `text` and `image`; confirm against utils.Create_image_embeddings.
initial = [dataset[i][0] for i in randomlist]

In [11]:
# Visualise the outcome. Arguments, in order: predicted image embeddings, target
# image embeddings, the original images, the output folder and the VAE checkpoint.
# NOTE(review): presumably the VAE checkpoint is used to decode the embeddings back
# into images; confirm in utils.Plot_results.
Plot_results(y_reconst, y, initial, save_to, vae_checkpoint_path)