In [1]:
import os
import gc
import sys
import cv2
import glob
import math
import time
import tqdm
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

from accelerate import Accelerator

from functools import partial
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
import torch.optim as optim
import torch.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms

import albumentations as A 
from albumentations.pytorch.transforms import ToTensorV2

from transformers import get_cosine_schedule_with_warmup  # keep this
from torch.optim import AdamW

from colorama import Fore, Back, Style
# Short aliases for colorama escape codes, used to colour console output.
r_, b_, c_ = Fore.RED, Fore.BLUE, Fore.CYAN
g_, y_, m_ = Fore.GREEN, Fore.YELLOW, Fore.MAGENTA
# Resets every style/colour attribute back to the terminal default.
sr_ = Style.RESET_ALL

In [52]:
def generate_and_save_images(model, test_sample, figsize=(20,15), save_path='image.png'):
    """Plot an input batch next to images decoded from random latents, then save the figure.

    The left panel shows ``test_sample`` as an image grid. The right panel shows
    ``model.decode(z)``, where ``z`` is standard-normal noise shaped like the
    first value returned by ``model.encode(test_sample)``. The encoder output is
    used only as a template for that noise, so the right panel is a random
    generation, not a reconstruction of the input.

    Args:
        model: Object exposing ``encode(x)`` (returning a pair whose first item
            is a tensor) and ``decode(z)`` (returning an image batch tensor).
        test_sample: Image batch tensor accepted by ``model.encode`` and by
            ``torchvision.utils.make_grid``.
        figsize: Figure size forwarded to ``plt.subplots``.
        save_path: File the two-panel figure is written to.

    Returns:
        None. Drawing and saving the figure are the only effects.
    """
    fig, axarr = plt.subplots(1, 2, figsize=figsize)

    # Inference only: no autograd graph is needed just to draw pictures.
    with torch.no_grad():
        mu, _ = model.encode(test_sample)
        # randn_like already matches the shape, dtype and device of `mu`.
        z = torch.randn_like(mu)
        predictions = model.decode(z)

    # detach()/cpu() keep the NumPy conversion valid for CUDA or grad-tracking tensors.
    panels = (test_sample.detach().cpu(), predictions.detach().cpu())
    for ax, batch in zip(axarr, panels):
        grid = torchvision.utils.make_grid(batch, normalize=True)
        # make_grid returns channels-first; imshow expects channels-last.
        ax.imshow(grid.permute(1, 2, 0).numpy())

    # Save this figure explicitly, and only after both panels are drawn.
    # Opening another figure before a pyplot-level savefig would write an empty image.
    fig.savefig(save_path)


In [None]:
from vae_training import VAE

# Machine-specific locations, gathered here so the cell can be re-pointed with
# a single edit instead of hunting for inline string literals.
MODEL_WEIGHTS_PATH = 'C:/Users/pzaka/Documents/CNN/imagenet_vae_model.bin'
DATASET_ROOT = "C:/Users/pzaka/Documents/datasets/imagewoof2"

config = {'lr':1e-3,
        'wd':1e-2,
        'bs':256,
        'img_size':128,
        'epochs':100,
        'seed':1000}

# Inference runs on the CPU. The batch drawn from `valid_dl` below is passed to
# the model without being moved to another device, so the model stays with the data.
device = "cpu"

# The model is built with the same image size the transform below resizes to.
model = VAE(latent_dim=512, img_size=config['img_size'], beta=1.0).to(device)
# NOTE(review): torch.load unpickles the checkpoint file; only load weights
# that come from a trusted source.
model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=device))
model.eval()

# Resize to a square image, convert to a tensor, then shift/scale each channel
# with mean 0.5 and std 0.5.
train_transform = transforms.Compose([
    transforms.Resize((config['img_size'],config['img_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5]),
])

# Validation split, read one image at a time and in a fixed order.
valid_dataset = datasets.ImageFolder(root=os.path.join(DATASET_ROOT, "val"), transform=train_transform)
valid_dl = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)

# Take the first (image, label) batch and visualise it next to a generated sample.
dataiter = iter(valid_dl)
sample = next(dataiter)

generate_and_save_images(model,sample[0])

In [None]:
from diffusers import AutoencoderKL

# Pretrained `stabilityai/sd-vae-ft-mse` autoencoder, used for inference only:
# gradients are switched off for all of its parameters before it is placed on `device`.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").requires_grad_(False)
vae.to(device)

In [None]:
# Round-trip the first validation image through the pretrained autoencoder and
# show the input (left) next to its reconstruction (right).
dataiter = iter(valid_dl)
# Put the batch on the device the autoencoder was moved to.
test_sample = next(dataiter)[0].to(device)

f, axarr = plt.subplots(1,2,figsize=(20,15))
img = torchvision.utils.make_grid(test_sample.cpu(), normalize=True).permute(1,2,0).numpy()
axarr[0].imshow(img)
axarr[0].set_title("Input")

# Inference only: encode, draw one latent from the posterior, decode it again.
with torch.no_grad():
    sampled_latent = vae.encode(test_sample).latent_dist.sample()
    predictions = vae.decode(sampled_latent).sample
predictions = predictions.detach().cpu()

# Draw into the existing second panel; opening another figure here would only
# add an empty plot to the cell output.
img = torchvision.utils.make_grid(predictions, normalize=True).permute(1,2,0).numpy()
axarr[1].imshow(img)
axarr[1].set_title("Reconstruction")
plt.show()

In [None]:
import tqdm

# Label index whose images are encoded; labels are the ones yielded by `valid_dl`.
TARGET_CLASS = 0

# One sampled latent per image of TARGET_CLASS, each entry keeping a leading
# batch dimension of size 1 so downstream `torch.stack(latent)` calls see the
# same layout whatever the loader's batch size is.
latent = []
with torch.no_grad():  # inference only
    for images, labels in tqdm.tqdm(valid_dl):
        # An element-wise mask works for any batch size; a bare `if labels == 0`
        # is only valid when the batch holds exactly one label.
        keep = labels == TARGET_CLASS
        if not keep.any():
            continue
        encoded = vae.encode(images[keep]).latent_dist.sample()
        latent.extend(encoded.split(1))

In [41]:
# Fit an element-wise Gaussian to the first two collected latents and draw one
# sample from it. The noise takes its shape, dtype and device from the latents
# themselves instead of a hard-coded size, so it keeps working when the image
# size (and therefore the latent size) changes.
first_two = torch.stack(latent[:2])
latent_mean = first_two.mean(0)
latent_std = first_two.std(0)
l = torch.randn_like(latent_mean) * latent_std + latent_mean

In [None]:
# Decode two of the collected latents and the midpoint between them, then show
# the three decoded images side by side (first, second, midpoint).
first_latent, second_latent = latent[0], latent[5]

predictions1 = vae.decode(first_latent).sample
predictions2 = vae.decode(second_latent).sample
predictions = vae.decode((first_latent + second_latent) / 2).sample

f, axarr = plt.subplots(1,3)
for panel, decoded in zip(axarr, (predictions1, predictions2, predictions)):
    # Channels-first grid -> channels-last array for imshow.
    img = torchvision.utils.make_grid(decoded, normalize=True).permute(1,2,0).numpy()
    panel.imshow(img)

In [None]:
# Scatter two neighbouring latent entries against each other across all
# collected latents: index 0 vs index 1 on the last axis, index 0 on the
# other per-latent axes.
stacked_latents = torch.stack(latent).cpu()
x_values = stacked_latents[:,0,0,0,0].numpy()
y_values = stacked_latents[:,0,0,0,1].numpy()
plt.scatter(x_values, y_values)
plt.show()

In [None]:
# Histogram (30 bins) of a single latent entry, index 0 on every per-latent
# axis, taken across all collected latents.
first_entry = torch.stack(latent).cpu()[:,0,0,0,0]
plt.hist(first_entry.numpy(), bins=30)
plt.show()