In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
import numpy as np
from sklearn.manifold import TSNE

In [3]:
#transforms
transform = transforms.Compose([
    transforms.ToTensor(),  #to tensor
    transforms.Normalize((0.5,), (0.5,))  #normalize
])

#dataset split
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(len(train_dataset), len(test_dataset)) #60K, 10K
print(train_dataset[0][0].shape) #1x28x28

60000 10000
torch.Size([1, 28, 28])


In [4]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super(VectorQuantizer, self).__init__()

        self.num_embeddings = num_embeddings #number of vectors in codebook
        self.embedding_dim = embedding_dim #dimension of each vector
        self.commitment_cost = commitment_cost #beta term

        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, inputs):
        inputs_flatten = inputs.view(-1, self.embedding_dim) #flatten input tensor

        #finding closest codebook vectors
        distances = (torch.sum(inputs_flatten ** 2, dim=1, keepdim=True) +
                     torch.sum(self.embeddings.weight ** 2, dim=1) -
                     2 * torch.matmul(inputs_flatten, self.embeddings.weight.t()))
        
        #closest embedding indices
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

        #one-hot indices -. new representation
        encodings = torch.zeros(encoding_indices.size(0), self.num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        #quantize
        #multiplying one-hot indices with codebook vectors
        quantized = torch.matmul(encodings, self.embeddings.weight).view(inputs.shape)

        #losses
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()  #something for backprop

        return quantized, loss, encoding_indices

In [5]:
##simple encoder
class Encoder(nn.Module):
    def __init__(self, in_channels=1, embedding_dim=64, num_classes=10):
        super(Encoder, self).__init__()
        self.label_embedding = nn.Embedding(num_classes, in_channels*28*28)
        self.conv1 = nn.Conv2d(in_channels*2, 32, kernel_size=4, stride=2, padding=1)  
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)  
        self.conv3 = nn.Conv2d(64, embedding_dim, kernel_size=3, stride=1, padding=1)  

    def forward(self, x, labels):
        label_embed = self.label_embedding(labels)
        label_embed = label_embed.view(-1, 1, 28, 28)
        x = torch.cat((x, label_embed), dim=1)

        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(x))
        x = self.conv3(x)

        return x


##simple decoder
class Decoder(nn.Module):
    def __init__(self, embedding_dim=64, out_channels=1):
        super(Decoder, self).__init__()

        self.conv1 = nn.ConvTranspose2d(embedding_dim, 64, kernel_size=3, stride=1, padding=1)  
        self.conv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)  
        self.conv3 = nn.ConvTranspose2d(32, out_channels, kernel_size=4, stride=2, padding=1)  

    def forward(self, x):

        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(x))
        x = torch.tanh(self.conv3(x))  #-1,1 output

        return x

In [6]:
##put together
class VQVAE(nn.Module):
    def __init__(self, embedding_dim=64, num_embeddings=512):
        super(VQVAE, self).__init__()

        self.encoder = Encoder(embedding_dim=embedding_dim)
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim)
        self.decoder = Decoder(embedding_dim=embedding_dim)

    def forward(self, x, labels):

        z = self.encoder(x, labels)
        quantized, vq_loss, _ = self.quantizer(z)
        x_recon = self.decoder(quantized)

        return x_recon, vq_loss

In [7]:
##training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = VQVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def train_vqvae(model, dataloader, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in dataloader:
            images = images.to(device)
            optimizer.zero_grad()
            recon, vq_loss = model(images, labels)
            recon_loss = F.mse_loss(recon, images)
            loss = recon_loss + vq_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader):.4f}")

train_vqvae(model, train_loader, epochs=10)

Epoch 1, Loss: 38.8650
Epoch 2, Loss: 47.8278
Epoch 3, Loss: 23.4310
Epoch 4, Loss: 13.3208
Epoch 5, Loss: 9.0492
Epoch 6, Loss: 5.7711
Epoch 7, Loss: 4.0559
Epoch 8, Loss: 3.0746
Epoch 9, Loss: 2.4344
Epoch 10, Loss: 2.0251


In [27]:
#extract latent representations and labels
def extract_latent_codes(model, dataloader):
    model.eval()
    latents, labels = [], []
    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device)
            z = model.encoder(images, targets)
            quantized, _, _ = model.quantizer(z)

            pooled_latents = quantized.mean(dim=[2, 3])
            print(pooled_latents.shape)
            latents.append(pooled_latents.cpu())
            labels.append(targets)
    return torch.cat(latents), torch.cat(labels)

In [28]:
##small nn to classify latents
class LatentClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(LatentClassifier, self).__init__()
        
        self.fc1 = nn.Linear(input_dim, 128)
        self.bn1 = nn.BatchNorm1d(128)  
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)  
        self.fc3 = nn.Linear(64, num_classes)
        self.dropout = nn.Dropout(0.3)
        

    def forward(self, x):
        x = F.leaky_relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = F.leaky_relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)

        return x


In [29]:
#extract
train_latents, train_labels = extract_latent_codes(model, train_loader)
test_latents, test_labels = extract_latent_codes(model, test_loader)

# print(f"Latent shape train: {train_latents.shape}")
# print(f"Labels: {train_labels.shape}")

# print(f"Latent shape test: {test_latents.shape}")
# print(f"Labels: {test_labels.shape}")

torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([6

In [30]:
#latents dataset
train_dataset_lat = torch.utils.data.TensorDataset(train_latents, train_labels)
test_dataset_lat = torch.utils.data.TensorDataset(test_latents, test_labels)

train_loader_lat = DataLoader(train_dataset_lat, batch_size=64, shuffle=True)
test_loader_lat = DataLoader(test_dataset_lat, batch_size=64, shuffle=False)

In [None]:
##train classifier
latent_dim = train_latents.shape[1] 
num_classes = 10
classifier = LatentClassifier(latent_dim, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3, weight_decay=3e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.3)

def train_classifier(model, dataloader, epochs=30):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for latents, labels in dataloader:
            latents, labels = latents.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(latents)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        scheduler.step()
        print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader):.4f}")

train_classifier(classifier, train_loader_lat, epochs=100)


Epoch 1, Loss: 1.4610
Epoch 2, Loss: 1.2394
Epoch 3, Loss: 1.1828
Epoch 4, Loss: 1.1483
Epoch 5, Loss: 1.1221
Epoch 6, Loss: 1.0901
Epoch 7, Loss: 1.0756
Epoch 8, Loss: 1.0693
Epoch 9, Loss: 1.0666
Epoch 10, Loss: 1.0605


In [37]:
##evaluate
def evaluate_classifier(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for latents, labels in dataloader:
            latents, labels = latents.to(device), labels.to(device)
            outputs = model(latents)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

accuracy = evaluate_classifier(classifier, test_loader_lat)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Test Accuracy: 69.60%


In [1]:
#make some reconstructions
def reconstruct_images(model, test_loader, device, num_images=10):
    model.eval()
    data_iter = iter(test_loader)
    images, _ = next(data_iter)
    images = images.to(device)
    
    latents = model.encoder(images)
    reconstructed_images = model.decoder(latents)
    
    fig, axes = plt.subplots(2, num_images, figsize=(15, 3))
    for i in range(num_images):
        axes[0, i].imshow(images[i].cpu().detach().numpy().transpose(1, 2, 0))
        axes[0, i].axis('off')
        axes[1, i].imshow(reconstructed_images[i].cpu().detach().numpy().transpose(1, 2, 0))
        axes[1, i].axis('off')

    
    plt.show()

reconstruct_images(model, test_loader, device)

NameError: name 'model' is not defined

In [27]:
#reconstruction mse
def compute_mse(model, test_loader, device):
    model.eval()
    total_mse = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            # Get the reconstructed output from the model
            reconstructed = model.decoder(model.encoder(data))
            
            # Compute Mean Squared Error (MSE) between original and reconstructed image
            mse = F.mse_loss(reconstructed, data)
            total_mse += mse.item()
    
    avg_mse = total_mse / len(test_loader)
    print(f"Average MSE: {avg_mse:.4f}")
    return avg_mse

# Call the function to compute MSE
compute_mse(model, test_loader, device)

Average MSE: 0.0179


0.017920462159546697

In [28]:
#reconstructiuon ssim
def compute_ssim(model, test_loader, device):
    model.eval()
    total_ssim = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            reconstructed = model.decoder(model.encoder(data))
            
            data_np = data.cpu().numpy()
            reconstructed_np = reconstructed.cpu().numpy()
            
            batch_ssim = 0
            for i in range(data.size(0)):  
                batch_ssim += ssim(data_np[i, 0], reconstructed_np[i, 0], data_range=1)
            
            total_ssim += batch_ssim / data.size(0)
    
    avg_ssim = total_ssim / len(test_loader)
    print(f"Average SSIM: {avg_ssim:.4f}")
    return avg_ssim

compute_ssim(model, test_loader, device)


Average SSIM: 0.9058


np.float64(0.905776253233012)

In [2]:
#visualize latent space
def visualize_latent_space(model, data_loader, device):
    model.eval()
    latents = []
    labels = []
    
    with torch.no_grad():
        for data, target in data_loader:
            data = data.to(device)
            
            latent_representation = model.encoder(data, target)
            latents.append(latent_representation.view(latent_representation.size(0), -1).cpu().numpy())
            labels.append(target.cpu().numpy())
    
    latents = np.concatenate(latents, axis=0)
    labels = np.concatenate(labels, axis=0)
    
    tsne = TSNE(n_components=2, random_state=42)
    latent_tsne = tsne.fit_transform(latents)
    
    plt.figure(figsize=(8, 6))
    plt.scatter(latent_tsne[:, 0], latent_tsne[:, 1], c=labels, cmap='jet', s=2)
    # plt.colorbar(scatter)
    plt.axis('off')
    # plt.title("t-SNE Visualization of VQ-VAE Latent Space")
    plt.show()

visualize_latent_space(model, test_loader, device)

NameError: name 'model' is not defined