In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

import torchvision
from torchvision import transforms

In [2]:
t = transforms.Compose([
    transforms.PILToTensor(),
    #transforms.ConvertImageDtype(torch.float)
    transforms.Lambda(lambda x: x.float())
])

In [3]:
traindata = torchvision.datasets.CIFAR10('./', train=True, download=False, transform= transforms.ToTensor())
testdata = torchvision.datasets.CIFAR10('./', train=False, download=False, transform= transforms.ToTensor())

In [4]:
trainloader = DataLoader(traindata, batch_size=64, shuffle=True)
testloader = DataLoader(testdata, batch_size=64, shuffle=False)

In [5]:
#next(iter(trainloader))

In [6]:
class VAE(torch.nn.Module):
    def __init__(self, input_dim=3*32*32, latent_dim=25, output_dim=3*32*32):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, 250)
        self.mean = nn.Linear(250, latent_dim)
        self.lgvar = nn.Linear(250, latent_dim)

        self.layer2 = nn.Linear(latent_dim, 250)
        self.outputLayer = nn.Linear(250, output_dim)

    def encoder(self, x):
        x = F.relu(self.layer1(x))
        mu = self.mean(x)
        logvar = self.lgvar(x)

        return mu, logvar

    def reparameterization(self, mean, logvar):
        #std = torch.sqrt(torch.exp(logvar))
        std = torch.exp(0.5 * logvar)
        exp = torch.rand_like(std)
        z = mean + std*exp
        
        return z

    def decoder(self, z):
        z = F.relu(self.layer2(z))
        return self.outputLayer(z)

    def forward(self, x):
        mean, lgvar = self.encoder(x.view(-1, 3*32*32))
        z = self.reparameterization(mean, lgvar)
        recons_x = self.decoder(z)

        return recons_x, mean, lgvar
        

In [7]:
def vae_loss(recons_x, x, mean, logvar):
    logvar = torch.clamp(logvar, -1, 1)
    #print(recons_x.shape)
    recons_loss = F.mse_loss(recons_x, x.view(-1, 3*32*32), reduction='sum')
    #print(recons_loss)
    kl_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    #print(kl_loss)
    return recons_loss + kl_loss

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE().to(device)
print(model)

VAE(
  (layer1): Linear(in_features=3072, out_features=250, bias=True)
  (mean): Linear(in_features=250, out_features=25, bias=True)
  (lgvar): Linear(in_features=250, out_features=25, bias=True)
  (layer2): Linear(in_features=25, out_features=250, bias=True)
  (outputLayer): Linear(in_features=250, out_features=3072, bias=True)
)


In [9]:
optimizer = optim.Adam(model.parameters(), lr=0.002)

In [10]:
epochs = 10
for epoch in range(epochs):
    model.train()
    train_loss = 0
    cosine_sims = []

    for batch_idx, (data, _) in enumerate(trainloader):
        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, logvar = model(data)
        #print(mu, logvar)
        loss = vae_loss(recon_batch, data, mu, logvar)
        loss.backward()
        #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=25)
        optimizer.step()
        train_loss += loss.item()

        flat_data = data.view(-1, 3*32*32)
        cos_sim = F.cosine_similarity(flat_data, recon_batch, dim=1)  
        cosine_sims.append(cos_sim)

    # Aggregate cosine similarity across all batches
    epoch_cosine_similarity = torch.cat(cosine_sims).mean().item()

    print(f'Epoch {epoch + 1}, Loss: {train_loss / len(trainloader.dataset):.4f}, Cosine Similarity: {epoch_cosine_similarity:.4f}')

Epoch 1, Loss: 103.7209, Cosine Similarity: 0.9327
Epoch 2, Loss: 82.1583, Cosine Similarity: 0.9465
Epoch 3, Loss: 81.5829, Cosine Similarity: 0.9468
Epoch 4, Loss: 81.2420, Cosine Similarity: 0.9471
Epoch 5, Loss: 81.0293, Cosine Similarity: 0.9472
Epoch 6, Loss: 80.7730, Cosine Similarity: 0.9474
Epoch 7, Loss: 81.0075, Cosine Similarity: 0.9473
Epoch 8, Loss: 80.3353, Cosine Similarity: 0.9476
Epoch 9, Loss: 80.3086, Cosine Similarity: 0.9477
Epoch 10, Loss: 80.1898, Cosine Similarity: 0.9478


In [11]:
model.eval()
test_loss = 0
cosine_sim_test = []

for batch_idx, (data, _) in enumerate(testloader):
    data = data.to(device)
    recon_batch, mu, logvar = model(data)
    loss = vae_loss(recon_batch, data, mu, logvar)

    test_loss += loss.item()

    flat_data = data.view(-1, 3*32*32)
    cos_sim = F.cosine_similarity(flat_data, recon_batch, dim=1)  
    cosine_sim_test.append(cos_sim)

cosine_similarity = torch.cat(cosine_sims).mean().item()
print(f"Test Cosine Similarity: {round(cosine_similarity,4)}")

Test Cosine Similarity: 0.9478
