In [None]:
%load_ext autoreload
%autoreload 2

import torch
from data.utils import *
import matplotlib.pyplot as plt
import numpy as np
import umap
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from scipy.linalg import sqrtm
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
# Pick the compute device first so the loaded model can be moved onto it; later
# cells move their inputs to `device`, so model and inputs must agree.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# CVAE checkpoint and test loader used by the reconstruction / latent-space sections.
# NOTE(review): `unpickle` comes from the data.utils star import and presumably wraps
# pickle.load — only load checkpoints from trusted sources.
vae_model = unpickle('../saved_models/mnist_cvae_new_2000.pkl').to(device)
test_loader = unpickle('../data/contrastive_mnist_data_loaders.pkl')['test']

## Reconstruction

In [None]:
# Test-set indices to display: originals in the top row, reconstructions below.
sample_indices = [50, 70, 90, 110, 132]

vae_model.eval()

fig, axes = plt.subplots(2, len(sample_indices), figsize=(10, 4))

with torch.no_grad():  # inference only — no autograd graph needed
    for col, img_idx in enumerate(sample_indices):
        test_image = test_loader.dataset[img_idx][0].to(device).view(1, 1, 28, 28)

        # The first element of the forward output is used as the reconstruction.
        recon = vae_model(test_image)[0].view(28, 28).cpu()
        x = test_image.view(28, 28).cpu()

        # Map back to [0, 1] for display (assumes inputs were normalised to [-1, 1]).
        recon = recon * 0.5 + 0.5
        x = x * 0.5 + 0.5

        # Fixed vmin/vmax so both rows share one grey scale; imshow would otherwise
        # stretch each image's contrast independently and hide intensity errors.
        for row, img in enumerate((x, recon)):
            axes[row, col].imshow(img, cmap='gray', vmin=0, vmax=1)
            axes[row, col].axis('off')

plt.tight_layout()
plt.show()

## Latent Space Visualization

In [None]:
vae_model.eval()

latent_batches, label_batches = [], []

with torch.no_grad():  # inference only
    for batch_images, batch_labels in test_loader:
        # Each image is represented by the first output of encode (mu); the
        # second output is discarded.
        # NOTE(review): assumes encode returns (batch, latent_dim) tensors — confirm.
        mu, _ = vae_model.encode(batch_images.to(device))
        latent_batches.append(mu.cpu().numpy())
        label_batches.append(batch_labels.cpu().numpy())

# Stacked over the whole test set; `latents` / `labels` are also consumed by the
# Logistic Regression section.
latents = np.concatenate(latent_batches, axis=0)
labels = np.concatenate(label_batches, axis=0)

# UMAP is stochastic: a fixed random_state makes the 2-D layout reproducible
# (at the cost of UMAP's parallel optimisation).
reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
embedding = reducer.fit_transform(latents)

fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='tab10', s=10)
fig.colorbar(scatter, ax=ax, ticks=range(10), label="Class label")
ax.set_title("UMAP Projection of CVAE Latent Space (Colored by MNIST Labels)")
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")
ax.grid(True)
fig.tight_layout()
plt.show()

## Image Generation

In [None]:
# Draw one latent vector from a standard normal and decode it into an image.
# NOTE(review): the latent width 10 is hard-coded — confirm it matches the encoder output.
z = torch.randn((1, 10)).to(device)
with torch.no_grad():  # sampling only, no autograd graph needed
    sample = vae_model.decode(z)[0].cpu()
# Map back to [0, 1] for display (assumes decoder output is on the [-1, 1] scale).
sample = sample * 0.5 + 0.5

# Fixed vmin/vmax so imshow does not stretch the contrast of this single sample.
plt.imshow(sample.squeeze(0), cmap='gray', vmin=0, vmax=1)
plt.axis('off')
plt.show()

## Logistic Regression

In [None]:
# Linear probe: how linearly separable are the digit classes in the latent space?
# stratify keeps class proportions identical in the train and test splits.
X_train, X_test, y_train, y_test = train_test_split(
    latents, labels, test_size=0.2, random_state=42, stratify=labels
)

# lbfgs already fits a multinomial (softmax) model for multiclass targets; the
# explicit multi_class='multinomial' argument is deprecated since scikit-learn 1.5.
clf = LogisticRegression(max_iter=1000, solver='lbfgs')
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_pred)

print(f"Test Accuracy: {acc * 100:.2f}%")
# Per-class precision / recall shows which digits the latent space confuses.
print(classification_report(y_test, y_pred))

## FID Score

In [None]:
# ImageNet-pretrained Inception-v3 as the FID feature extractor. Replacing the final
# fc layer with Identity exposes the pooled penultimate features instead of logits.
# `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13).
# transform_input=False: the FID inputs below are already scaled to [-1, 1].
inception = models.inception_v3(
    weights=models.Inception_V3_Weights.IMAGENET1K_V1, transform_input=False
)
inception.fc = torch.nn.Identity()
# Follow the notebook's `device` instead of hard-coding CUDA; assigning the result
# also keeps Jupyter from echoing the full model repr.
inception = inception.eval().to(device)

In [None]:
# Resize to Inception's 299x299 input and shift [0, 1] channels to [-1, 1].
# NOTE(review): this transform is never passed to a dataset in this notebook — the
# FID inputs are resized with F.interpolate instead. Applying Normalize(0.5, 0.5)
# to images already in [-1, 1] would rescale them a second time; check before use.
transform = transforms.Compose(
    [
        transforms.Resize(size=(299, 299)),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [None]:
class TensorImageDataset(Dataset):
    """Map-style dataset over an in-memory stack of images.

    Indexing returns ``images[idx]``, passed through ``transform`` when one is set.
    No labels are returned, so a DataLoader over it yields bare image batches.
    """

    def __init__(self, images, transform=None):
        # images: anything indexable along dim 0 that exposes .shape (e.g. a tensor).
        self.images, self.transform = images, transform

    def __len__(self):
        n_items = self.images.shape[0]
        return n_items

    def __getitem__(self, idx):
        item = self.images[idx]
        return self.transform(item) if self.transform else item

In [None]:
def get_activations(dataloader, model=None, device=None):
    """Run every batch of ``dataloader`` through a feature extractor and stack the outputs.

    Args:
        dataloader: iterable of image-batch tensors. Batches given as
            ``(images, labels, ...)`` tuples/lists are also accepted; only the
            first element is used.
        model: feature extractor; defaults to the notebook-level ``inception``.
        device: where to run the model; defaults to the device of the model's
            first parameter (CPU for a parameter-free model).

    Returns:
        np.ndarray with one row of model outputs per input image, in loader order.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if model is None:
        model = inception  # notebook-level feature extractor
    if device is None:
        # Follow the model rather than hard-coding CUDA, so CPU / MPS runs work too.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    features = []
    with torch.no_grad():
        for batch in dataloader:
            if isinstance(batch, (list, tuple)):
                batch = batch[0]  # drop labels from (images, labels) batches
            preds = model(batch.to(device))
            features.append(preds.cpu().numpy())

    if not features:
        raise ValueError("dataloader yielded no batches")
    return np.concatenate(features, axis=0)


In [None]:
def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Fréchet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2)).

    Args:
        mu1, mu2: mean feature vectors of equal shape.
        sigma1, sigma2: covariance matrices of equal shape.
        eps: diagonal offset added to both covariances when the matrix square
            root of their product is not finite (e.g. rank-deficient estimates).

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: if the means or the covariances have mismatched shapes.
    """
    import warnings

    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    if mu1.shape != mu2.shape:
        raise ValueError(f"mean vectors have different shapes: {mu1.shape} vs {mu2.shape}")
    if sigma1.shape != sigma2.shape:
        raise ValueError(
            f"covariances have different shapes: {sigma1.shape} vs {sigma2.shape}"
        )

    diff = mu1 - mu2
    covmean = sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # Singular product: retry with a small ridge on both covariances.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))

    if np.iscomplexobj(covmean):
        # Tiny imaginary parts are numerical noise; large ones signal a real problem.
        max_imag = np.max(np.abs(np.diagonal(covmean).imag))
        if max_imag > 1e-3:
            warnings.warn(
                f"sqrtm returned a significant imaginary component ({max_imag:.3g}); "
                "FID may be unreliable"
            )
        covmean = covmean.real

    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))

In [None]:
# The FID evaluation uses a different checkpoint and the plain MNIST loaders. These
# rebind `vae_model` / `test_loader`, so earlier cells re-run after this one will use
# them. `device` is reused from the setup cell rather than being recomputed here.
# .eval() puts the freshly unpickled model in inference mode before it is used for
# decoding (matters if it contains dropout / batch-norm layers).
vae_model = unpickle('../saved_models/mnist_cvae_new_2_400.pkl').to(device).eval()
test_loader = unpickle('../data/mnist_data_loaders.pkl')['test']

In [None]:
# Build the real / generated image sets for FID.
# Images stay 28x28 here and are upsampled to Inception's 299x299 per item via the
# dataset transform: materialising 2 x 10,000 x 3 x 299 x 299 float32 tensors up
# front would need roughly 21 GB of RAM.

# Real MNIST test digits: uint8 pixels in [0, 255] -> floats in [-1, 1].
# Dividing by 255 (not 256) maps the brightest pixel exactly to 1.
real_images = test_loader.dataset.data.reshape(-1, 1, 28, 28).float()
real_images = (real_images / 255) * 2 - 1
n_samples = real_images.shape[0]

# As many generated digits as real ones, decoded from a fixed-seed latent sample so
# the FID score is reproducible between runs.
# NOTE(review): latent width 10 is hard-coded — confirm it matches the model.
vae_model.eval()
latent_rng = torch.Generator().manual_seed(0)
z = torch.randn((n_samples, 10), generator=latent_rng).to(device)
with torch.no_grad():
    gen_images = vae_model.decode(z).reshape(-1, 1, 28, 28).cpu()
# Keep generated pixels on the same [-1, 1] scale as the real ones
# (a no-op if the decoder output is already confined to that range).
gen_images = gen_images.clamp(-1, 1)


def _to_inception_input(img):
    """Turn a single-channel (1, H, W) image into a 3-channel (3, 299, 299) Inception input."""
    img = img.repeat(3, 1, 1).unsqueeze(0)  # interpolate expects a batch dimension
    img = F.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)
    return img.squeeze(0)


real = TensorImageDataset(real_images, transform=_to_inception_input)
gen = TensorImageDataset(gen_images, transform=_to_inception_input)


In [None]:
# Batch both image sets identically; order is irrelevant to the statistics below.
real_loader = DataLoader(real, batch_size=64, shuffle=False)
fake_loader = DataLoader(gen, batch_size=64, shuffle=False)

# Inception feature vectors, one row per image.
real_acts = get_activations(real_loader)
fake_acts = get_activations(fake_loader)

# Gaussian fit (mean vector, covariance matrix) of each feature cloud; rows are samples.
mu_real = real_acts.mean(axis=0)
sigma_real = np.cov(real_acts, rowvar=False)
mu_fake = fake_acts.mean(axis=0)
sigma_fake = np.cov(fake_acts, rowvar=False)

fid_score = calculate_fid(mu_real, sigma_real, mu_fake, sigma_fake)
print(f"FID Score: {fid_score:.2f}")