# VAE (Variational Autoencoder)

Import the required modules

In [None]:
import torch
import torch.nn as nn
from torch import optim 
from torchinfo import summary
from torchmetrics import Accuracy
from torch.utils.data import Dataset, DataLoader
from VAE.model import VAE
from progress.bar import IncrementalBar
import matplotlib.pyplot as p
from torch.autograd import Variable 
import matplotlib.pyplot as plt
from VAE.data import Ego4d, DEVICE, BATCH_SIZE, transform1, transform2, ResumableRandomSampler

prepare dataset for training

In [None]:
# Load a training checkpoint written by train(); the dataset cell uses its
# 'sampler_state' to resume the data order where training stopped.
# weights_only=False is required: train() also pickles the whole model object
# ('full_model') into the checkpoint, which the weights_only loader (the
# default since PyTorch 2.6) refuses to unpickle. Only load trusted files.
checkpoint = torch.load(
    '/home/qwest/project/PycharmProjects/Reinforsment_Learning/VAE/weights/main/VAE_checkpoint_32_44.pt',
    weights_only=False,
)

In [None]:
print('transform initializate sucsess')
train_dataset = Ego4d(img_dir='/home/qwest/data_for_ml/2_25',
                      transform1=transform1,
                      transform2=transform2)
print("train_dataset init")

# Resume the sample order from the previously loaded checkpoint.
sampler = ResumableRandomSampler(train_dataset)
sampler.set_state(checkpoint['sampler_state'])
print("train_sampler init")

# Use the shared BATCH_SIZE so the loader agrees with the batch size the VAE
# is constructed with (VAE(..., batch_size=BATCH_SIZE)) instead of a separate
# hard-coded 32. shuffle must stay False: ordering comes from the sampler, and
# DataLoader rejects shuffle=True together with a custom sampler.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          sampler=sampler,
                          num_workers=6)
print("train_loader init")

In [None]:
# Report how many batches one pass over the loader yields.
n_batches = len(train_loader)
print("Len of trainloader: ", n_batches)


In [5]:
# Dump the sampler's current state to a scratch file (the name suggests a
# quick test of the resume mechanism).
sampler_state = sampler.get_state()
torch.save(sampler_state, "test_samp.pth")

Set up the parameters for VAE training

In [13]:
# Training hyper-parameters.
latent_dim = 32  # latent size passed to VAE(...)
epochs = 50      # last epoch number the training loop runs (inclusive)
lr = 1e-3        # Adam learning rate

Create the model and optimizer with the parameters above

In [14]:
# Build the VAE, move it to the target device, and optimize all of its
# parameters with Adam.
model = VAE(latent_dim, batch_size=BATCH_SIZE)
model = model.to(DEVICE)
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Restore parameters from a plain state dict (model.pt).
# map_location=DEVICE lets a state dict saved on another device (e.g. a GPU)
# load on the current one. This path restores weights only: it does not set
# `epoch` or `loss`, and the optimizer keeps its fresh state.
model.load_state_dict(torch.load('model.pt', map_location=DEVICE, weights_only=True))

## OR (resume from a full training checkpoint)

In [None]:
# Resume from a full checkpoint written by train(). It restores the model and
# optimizer state, the epoch number saved with it (train() saves it after
# finishing that epoch), and the stored loss.
# weights_only=False: the checkpoint also pickles the whole model object
# ('full_model'), which the weights_only loader (PyTorch >= 2.6 default)
# refuses. Only load checkpoints you trust.
# map_location=DEVICE lets a checkpoint saved on another device load here.
checkpoint = torch.load(
    '/home/qwest/project/PycharmProjects/Reinforsment_Learning/VAE/weights/main/VAE_checkpoint_32_44.pt',
    map_location=DEVICE,
    weights_only=False,
)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

Create the training function

In [20]:
def train(epoch):
    """Train the VAE for one epoch, save a checkpoint, and show a reconstruction.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``sampler``, ``latent_dim``, ``DEVICE`` and ``p`` (matplotlib.pyplot).

    Args:
        epoch (int): number of the epoch being trained. It is stored in the
            checkpoint and used in its file name; because the checkpoint is
            written after the epoch finishes, it marks the last completed epoch.

    The checkpoint goes to
    ``VAE/weights/main/VAE_checkpoint_{latent_dim}_{epoch}.pt``. It holds the
    model/optimizer state dicts, the pickled model object, the sampler state,
    the epoch number, and the epoch's mean batch loss as a Python float.
    If the loader yields no batches, nothing is saved or plotted.
    """
    bar = IncrementalBar('Countdown', max=len(train_loader))

    model.train()
    print(f'Epoch {epoch} start')
    total_loss = 0.0
    n_batches = 0
    preview = None  # one input image kept for the end-of-epoch reconstruction

    # Loop through all batches in the training dataset
    for data in train_loader:
        data = data.to(DEVICE)
        if preview is None:
            # Reuse the first batch instead of calling next(iter(train_loader)),
            # which would start a second loader iterator (extra worker
            # processes, an extra pass through the sampler) for a single image.
            # clone() so the rest of the batch is not kept alive.
            preview = data[:1].clone()

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = model.loss_function(recon_batch, data, mu, logvar)
        loss.backward()   # compute gradients w.r.t. the model parameters
        optimizer.step()  # update the model parameters

        # Accumulate a Python float. Summing the loss tensor itself would keep
        # every batch's autograd graph alive until the end of the epoch.
        total_loss += loss.item()
        n_batches += 1
        bar.next()  # advance by exactly one batch (next(n) advances by n)
    bar.finish()

    if n_batches == 0:
        print(f'Epoch {epoch}: train_loader yielded no batches, nothing saved')
        return

    avg_loss = total_loss / n_batches
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
        'epoch': epoch,
        'full_model': model,
        'sampler_state': sampler.get_state(),
    }, f'VAE/weights/main/VAE_checkpoint_{latent_dim}_{epoch}.pt')
    print(f"Avg loss: {avg_loss:.4f} \n")

    # Show the reconstruction of the preview image. The view(3, 64, 64) assumes
    # 3x64x64 outputs, as the original code did.
    model.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        recon_img, _, _ = model(preview)
    img = recon_img.view(3, 64, 64).cpu().numpy().transpose(1, 2, 0)
    p.imshow(img)
    p.show()


Train the model (intended to log to MLflow; the cell below currently only writes the model summary to model_summary.txt)

In [None]:
# Log the model summary to a text file.
with open("model_summary.txt", "w", encoding="utf-8") as f:
    f.write(str(summary(model)))

# train(t) saves 'epoch': t only after epoch t has finished, so a loaded
# checkpoint's epoch has already been trained. Resume from the next one
# instead of repeating it. If no checkpoint was loaded (e.g. only model.pt),
# `epoch` does not exist, so start from epoch 1.
try:
    start_epoch = epoch + 1
except NameError:
    start_epoch = 1

for t in range(start_epoch, epochs + 1):
    train(t)

In [11]:
# Take one batch from the training loader (used below as example input).
loader_iter = iter(train_loader)
x = next(loader_iter)
del loader_iter

In [23]:
# Export the VAE to ONNX, tracing it with the sample batch `x`.
# forward() returns (reconstruction, mu, logvar), so each output gets its own
# name. Every value in an ONNX graph must have a unique name, so no output may
# reuse the input name 'image' (the old output_names=['image'] collided with
# the input and left mu/logvar unnamed).
# NOTE(review): the exported graph has the batch size of `x` baked in.
torch.onnx.export(
    model,
    x.to(DEVICE),
    "model.onnx",
    input_names=['image'],
    output_names=['reconstruction', 'mu', 'logvar'],
)

## Test model

In [None]:
# Optionally restore a different checkpoint before testing:
# checkpoint = torch.load('/home/qwest/project/PycharmProjects/Reinforsment_Learning/test_W/VAE_checkpoint_1283.pt')
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# loss = checkpoint['loss']

# Switch the model to inference mode (same as model.eval(); affects layers
# such as dropout/batch-norm). Returns the model, which the notebook displays.
model.train(False)

Test on existing images (reconstruct one batch from the training set)

In [None]:
# Reconstruct one training batch and plot the results in a grid.
# `mu` is kept for the latent-walk cell.
x = next(iter(train_loader))
print(x.size())
with torch.no_grad():  # inference only: no autograd graph needed
    reconstructed, mu, _ = model(x.to(DEVICE))
# (N, 3, 64, 64) -> (N, 64, 64, 3): the channel-last layout imshow expects
reconstructed = reconstructed.view(-1, 3, 64, 64).cpu().numpy().transpose(0, 2, 3, 1)

n_cols = 8
n_rows = (len(reconstructed) + n_cols - 1) // n_cols  # enough rows for the whole batch
fig = plt.figure(figsize=(25, 16))
for ii, img in enumerate(reconstructed):
    ax = fig.add_subplot(n_rows, n_cols, ii + 1, xticks=[], yticks=[])
    ax.imshow(img)

Transition from one image to another (linear interpolation in latent space)

In [None]:
# Walk in a straight line between the latent means (mu) of two images from the
# batch above, decode every step, and plot the sequence.
first_dog_idx = 3
second_dog_idx = 8
n_steps = 32

with torch.no_grad():  # inference only: no autograd graph needed
    start = mu[first_dog_idx]
    end = mu[second_dog_idx]
    # Weights 0, 1/(n_steps-1), ..., 1, shaped (n_steps, 1, 1, ...) to broadcast
    # over start/end. torch.lerp hits both endpoints exactly, unlike repeatedly
    # adding a step, which accumulates rounding error.
    weights = torch.linspace(0, 1, n_steps, device=start.device, dtype=start.dtype)
    weights = weights.view(-1, *([1] * start.dim()))
    # Result is (n_steps, *start.shape); presumably (n_steps, latent_dim, 4, 4),
    # the decoder input shape used by the sampling cell -- TODO confirm.
    walk = torch.lerp(start, end, weights)
    walk = model.decoder(walk).cpu().numpy().transpose(0, 2, 3, 1)

n_cols = 8
n_rows = (len(walk) + n_cols - 1) // n_cols
fig = plt.figure(figsize=(25, 16))
for ii, img in enumerate(walk):
    ax = fig.add_subplot(n_rows, n_cols, ii + 1, xticks=[], yticks=[])
    ax.imshow(img)

Generate images from random noise

In [None]:
# Decode random latent codes drawn from a standard normal into new images.
# (torch.autograd.Variable is deprecated; plain tensors are used directly.)
n_samples = 32
with torch.no_grad():  # inference only: no autograd graph needed
    # Noise is drawn on the CPU generator and then moved, as before.
    z = torch.randn(n_samples, latent_dim, 4, 4).to(DEVICE)
    samples = model.decoder(z).cpu().numpy().transpose(0, 2, 3, 1)

n_cols = 8
n_rows = (len(samples) + n_cols - 1) // n_cols
fig = plt.figure(figsize=(25, 16))
for ii, img in enumerate(samples):
    ax = fig.add_subplot(n_rows, n_cols, ii + 1, xticks=[], yticks=[])
    ax.imshow(img)

In [11]:
# Save only the learned parameters (the state dict) to model.pt.
weights = model.state_dict()
torch.save(weights, "model.pt")

In [17]:
# Reconstruct only the first image of batch `x`, kept as a batch of one.
single = x[0].unsqueeze(0).to(DEVICE)
reconstructed, _mu, _logvar = model(single)

In [18]:
# Convert to a NumPy image batch in (N, H, W, C) layout for matplotlib.
reconstructed = (
    reconstructed.view(-1, 3, 64, 64)
    .detach()
    .cpu()
    .numpy()
    .transpose(0, 2, 3, 1)
)


In [None]:
# Display the shape of the first reconstructed image.
reconstructed[0, ...].shape

In [32]:
# Appears to be a quick check that tqdm progress bars render in this notebook.
from tqdm import tqdm

for i in tqdm(iterable=range(10)):
    continue

IntProgress(value=0)