In [None]:
import torch
from PIL import Image
from torchvision.datasets import FashionMNIST
from torchvision.utils import save_image
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms
from torchvision.utils import save_image
import torchvision.utils as vutils
import os
import glob
import datetime
import time
# Record and display the wall-clock time at which this notebook run began.
start_time = datetime.datetime.now()
print(start_time)

# Allow duplicate copies of the OpenMP runtime to coexist in this process.
# NOTE(review): this is set after torch/numpy were imported above; the flag may
# need to be exported before those imports to have any effect -- confirm.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

In [None]:
# GPU index this notebook prefers to run on.
PREFERRED_GPU_INDEX = 2

# Pick the preferred GPU when it exists; otherwise fall back to the first GPU
# or to the CPU, so that the later torch.load(map_location=device) and
# .to(device) calls do not fail on hosts with fewer than three GPUs.
if torch.cuda.is_available() and torch.cuda.device_count() > PREFERRED_GPU_INDEX:
    device = torch.device(f"cuda:{PREFERRED_GPU_INDEX}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(torch.cuda.is_available())
# Make a silent fallback visible to whoever runs the notebook.
print(f"Using device: {device}")

In [3]:
# Checkpoint files for the three trained FashionMNIST models (the file names
# appear to encode the epoch number and the loss reached at that epoch).
model_path_hmvae = './saved_models_hmvae_fasionmnist/model_epoch_491_loss_24.2987.pt'
model_path_MEMvae = './saved_models_MEMvae_fasionmnist/model_epoch_487_loss_24.0129.pt'
model_path_vae = './saved_models_vae_fasionmnist/model_epoch_489_loss_24.6418.pt'

# Deserialize every checkpoint directly onto the selected device.
# NOTE(review): torch.load unpickles the file contents; only load checkpoints
# that come from a trusted source.
state_dict = torch.load(model_path_hmvae, map_location=device)
state_dict_MEMvae = torch.load(model_path_MEMvae, map_location=device)
state_dict_vae = torch.load(model_path_vae, map_location=device)

In [None]:
from torchvision.transforms import ToTensor

# Export the FashionMNIST test split as individual PNG files; these serve as
# the "real" image folder for the FID computation at the end of the notebook.
mnist_test = FashionMNIST(root="./fashiondata", train=False, download=True)
real_images_dir = "./fasionmnist_real_images"
os.makedirs(real_images_dir, exist_ok=True)

for i in range(len(mnist_test)):
    # Each dataset item is an (image, label) pair; the image object is saved as-is.
    img_pil, label = mnist_test[i]
    target_path = os.path.join(real_images_dir, f"image_{i:04d}.png")
    img_pil.save(target_path)

In [None]:
from HMVAE import HMVAE

# Hyper-parameters passed positionally to the HMVAE constructor. They have to
# match the configuration the checkpoint was trained with, otherwise
# load_state_dict below will reject the weights.
n_levels = 3
n_regions = 2
# One (list, list) pair per region.
# NOTE(review): presumably (slots per level, slot dimension per level), by
# analogy with num_slots_list / slot_dim_list in the MEMVAE cell -- confirm
# against the HMVAE implementation.
region_slot_config = [
    ([50, 50, 50], [64, 128, 256]),
    ([64, 64, 64], [32, 64, 128]),
]
latent_dim = 20
hidden_dim = 128
image_channels = 1  # single-channel (grayscale) images
image_size = 28  # images are 28x28 pixels

model_hmvae = HMVAE(
    n_levels,
    n_regions,
    region_slot_config,
    latent_dim,
    hidden_dim,
    image_channels,
    image_size,
)
model_hmvae = model_hmvae.to(device)
model_hmvae.load_state_dict(state_dict)
# The model is already on `device`; as the last expression of the cell this
# call also makes the notebook echo the module structure.
model_hmvae.to(device)

In [None]:
# Load the single-level MEMVAE.
from MEMVAE import MEMVAE

# Hyper-parameters passed positionally to the MEMVAE constructor; they must
# agree with the configuration used to train the checkpoint loaded below.
n_levels = 1
num_slots_list = [50]  # one entry per level
slot_dim_list = [64]  # one entry per level
latent_dim = 20
hidden_dim = 128
image_channels = 1  # single-channel (grayscale) images
image_size = 28  # images are 28x28 pixels

model_MEMvae = MEMVAE(
    n_levels,
    num_slots_list,
    slot_dim_list,
    latent_dim,
    hidden_dim,
    image_channels,
    image_size,
)
model_MEMvae = model_MEMvae.to(device)
model_MEMvae.load_state_dict(state_dict_MEMvae)
# Already on `device`; kept as the cell's final expression so the notebook
# echoes the module structure.
model_MEMvae.to(device)

In [None]:
from VAE import VAE

# Baseline VAE, built with the same latent size, hidden size and image
# geometry as the two memory-augmented models above.
model_vae = VAE(
    latent_dim=20,
    hidden_dim=128,
    image_channels=1,
    image_size=28,
)
model_vae = model_vae.to(device)
# Final expression of the cell: the notebook displays the key-matching report.
model_vae.load_state_dict(state_dict_vae)

In [None]:
def generate_from_latent(model, num_samples, save_dir, latent_dim=20):
    """Decode random latent vectors with ``model`` and save each result as a PNG.

    Latent codes are drawn from a standard normal distribution, passed through
    ``model.decoder``, reshaped to ``(N, 1, 28, 28)`` and written to
    ``save_dir`` as ``image_0000.png``, ``image_0001.png``, ...

    Args:
        model: Module exposing a ``decoder`` callable. The decoder may return
            either the image tensor itself or a tuple/list whose first element
            is the image tensor.
        num_samples: Number of images to generate. Zero or a negative value
            creates ``save_dir`` and writes nothing.
        save_dir: Output directory; created if it does not exist.
        latent_dim: Size of each latent vector.

    Side effects:
        Switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    os.makedirs(save_dir, exist_ok=True)

    # Nothing to draw: also avoids view(-1, ...) on an empty tensor, which
    # PyTorch rejects because the inferred dimension would be ambiguous.
    if num_samples <= 0:
        return

    # Sample on whatever device the model lives on rather than depending on a
    # notebook-level `device` variable.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim).to(target_device)

        # Look at what the decoder returned instead of comparing `model`
        # against one specific global model object.
        decoded = model.decoder(z)
        if isinstance(decoded, (tuple, list)):
            decoded = decoded[0]

        gen_data = decoded.view(-1, 1, 28, 28)

    for i in range(num_samples):
        img_path = os.path.join(save_dir, f'image_{i:04d}.png')
        vutils.save_image(gen_data[i], img_path, normalize=False)


# Draw 10,000 prior samples from every trained model, each into its own folder.
for trained_model, output_folder in (
    (model_vae, './generated_images_vae_fasion'),
    (model_hmvae, './generated_images_hmvae_fasion'),
    (model_MEMvae, './generated_images_MEMVAE_fasion'),
):
    generate_from_latent(trained_model, num_samples=10000, save_dir=output_folder)

In [None]:
import random 
from torchvision.utils import make_grid, save_image
import os
import random
import glob
from PIL import Image
import torch
from torchvision import transforms
from torchvision.utils import make_grid, save_image

# Build one 8x8 preview grid per model from randomly chosen generated images.
#
# The sampling cell writes its images to "./generated_images_<suffix>" using
# the "fasion" / "MEMVAE" spellings, so each grid label is mapped to the folder
# suffix that actually exists on disk. (Looking for "..._fashion" folders made
# every model get skipped.) The labels are still used for the output file names.
grid_sources = {
    "vae_fashion": "vae_fasion",
    "hmvae_fashion": "hmvae_fasion",
    "memvae_fashion": "MEMVAE_fasion",
}
models = list(grid_sources)

GRID_SIDE = 8
GRID_IMAGE_COUNT = GRID_SIDE * GRID_SIDE

output_dir = "./generated_grids_fashion"
os.makedirs(output_dir, exist_ok=True)

for model in models:
    input_folder = f"./generated_images_{grid_sources[model]}"

    if not os.path.exists(input_folder):
        print(f"Warning: {input_folder} does not exist, skipping...")
        continue

    # Sorted so that the sample depends only on the RNG state, not on the
    # arbitrary order in which the file system lists the files.
    all_images = sorted(glob.glob(os.path.join(input_folder, "*.png")))

    if len(all_images) < GRID_IMAGE_COUNT:
        print(f"Warning: Insufficient images in {input_folder} (found {len(all_images)}, need {GRID_IMAGE_COUNT}), skipping...")
        continue

    image_paths = random.sample(all_images, GRID_IMAGE_COUNT)

    # Load every picked image as a 3-channel tensor, closing each file promptly.
    transform = transforms.ToTensor()
    tensor_list = []
    for img_path in image_paths:
        with Image.open(img_path) as img:
            tensor_list.append(transform(img.convert("RGB")))

    # Stack into a single (64, 3, H, W) batch and tile it into an 8x8 grid.
    batch = torch.stack(tensor_list, dim=0)
    grid = make_grid(batch, nrow=GRID_SIDE, padding=2, normalize=True)

    output_path = os.path.join(output_dir, f"{model}_grid_{GRID_SIDE}x{GRID_SIDE}.png")
    save_image(grid, output_path)
    print(f"{model} grid image saved to: {output_path}")

In [None]:
from PIL import Image
def convert_grayscale_to_rgb(image_path):
    """Rewrite the image at ``image_path`` in place as a 3-channel image.

    The file is read and converted to 8-bit grayscale (mode "L"), the single
    channel is copied three times along a new last axis, and the result is
    saved back to the same path. Any colour information in the original file
    is therefore discarded.
    """
    gray_pixels = np.array(Image.open(image_path).convert("L"))
    rgb_pixels = np.repeat(gray_pixels[..., np.newaxis], 3, axis=-1)
    Image.fromarray(rgb_pixels).save(image_path)

In [11]:
# Rewrite every generated and real PNG in place as a 3-channel image, visiting
# the three generated-image folders first and the real-image folder last.
for png_pattern in (
    "./generated_images_vae_fasion/*.png",
    "./generated_images_hmvae_fasion/*.png",
    "./generated_images_MEMVAE_fasion/*.png",
    "./fasionmnist_real_images/*.png",
):
    for img_path in glob.glob(png_pattern):
        convert_grayscale_to_rgb(img_path)

In [None]:
import subprocess
import sys

# Compute the FID of each model's generated images against the real test images
# by running the pytorch_fid command-line module once per model.
models = [
    "vae_fasion",
    "hmvae_fasion",
    "MEMVAE_fasion"
]
for model in models:
    gen_dir = f"./generated_images_{model}"
    # Use the kernel's own interpreter so pytorch_fid is resolved from the same
    # environment as this notebook, and pass the arguments as a list so no
    # shell parsing is involved.
    cmd = [sys.executable, "-m", "pytorch_fid", gen_dir, "./fasionmnist_real_images"]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # Report the failure instead of printing an empty FID value.
        print(f"FID for {model}: failed (exit code {result.returncode})\n{result.stderr}")
        continue
    print(f"FID for {model}: {result.stdout}")