In [1]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class InfoVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, init_temperature, anneal_rate):
        super(InfoVAE, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)  # Output logits
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
        
        self.latent_dim = latent_dim
        self.temperature = init_temperature
        self.anneal_rate = anneal_rate

    def reparameterize(self, logits):
        # Gumbel-Softmax reparameterization
        gumbels = -torch.empty_like(logits).exponential_().log()
        y = logits + gumbels
        return F.softmax(y / self.temperature, dim=-1)

    def forward(self, x):
        logits = self.encoder(x)
        z = self.reparameterize(logits)
        return self.decoder(z), logits

    def anneal_temperature(self):
        self.temperature *= self.anneal_rate

def compute_loss(recon_x, x, logits, latent_dim):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    q_y = F.softmax(logits, dim=-1)
    log_q_y = torch.log(q_y + 1e-20)
    KLD = torch.sum(q_y * (log_q_y - torch.log(torch.tensor(1.0 / latent_dim))), dim=-1).mean()*latent_dim
    return BCE + KLD, BCE, KLD

def train(model, dataloader, optimizer, epoch):
    model.train()
    train_loss = 0
    BCE_loss = 0
    KLD_loss = 0
    for batch_idx, (data, _) in enumerate(dataloader):
        data = data.view(-1, 784).to(device)  # Flatten the images
        optimizer.zero_grad()

        recon_batch, logits = model(data)
        try:
            loss, BCE, KLD = compute_loss(recon_batch, data, logits, model.latent_dim)
        except:
            break
        loss.backward()
        train_loss += loss.item() / len(data)
        BCE_loss += BCE.item() / len(data)
        KLD_loss += KLD.item() / len(data)
        optimizer.step()

    return train_loss, BCE_loss, KLD_loss

def get_data_loader(dataset_name, transform, batch_size):
    if dataset_name == 'MNIST':
        dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    else:
        raise ValueError('Invalid dataset name')

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

def generate_digit(model, z):
    digit = model.decoder(z)
    digit_reshaped = digit.view(28, 28).detach().cpu()
    plt.imshow(digit_reshaped, cmap='gray')
    plt.show()

# Run the main function on MNIST
dataset_name = 'MNIST'
input_dim = 784  # 28*28, size of MNIST images
hidden_dim = 32
latent_dim = 10
init_temperature = 1.0
anneal_rate = 0.99
epochs = 100
batch_size = 128

# Initialize the model and optimizer
model = InfoVAE(input_dim, hidden_dim, latent_dim, init_temperature, anneal_rate).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Prepare the data
transform = transforms.ToTensor()
dataloader = get_data_loader(dataset_name, transform, batch_size)

for epoch in range(1, epochs + 1):
    try:
        total_loss, BCE_loss, KLD_loss = train(model, dataloader, optimizer, epoch)
        model.anneal_temperature()
        print('Epoch: {}, Loss: {}, BCE: {}, KLD: {}'.format(epoch, total_loss, BCE_loss, KLD_loss))
    except KeyboardInterrupt:
        break

# Generate average digits 0-9
z = torch.eye(10).to(device)
generate_digit(model, z)


Epoch: 1, Loss: 129609.62735493977, BCE: 129529.49825541179, KLD: 80.12901411758503
Epoch: 2, Loss: 97400.51867167155, BCE: 97316.09371948242, KLD: 84.42494252324104
Epoch: 3, Loss: 96926.60726420085, BCE: 96842.18118286133, KLD: 84.42667317887147
Epoch: 4, Loss: 96825.89423116048, BCE: 96741.46793619792, KLD: 84.42725592354934
Epoch: 5, Loss: 96786.72950744629, BCE: 96702.3031056722, KLD: 84.4276611705621
Epoch: 6, Loss: 96770.15310668945, BCE: 96685.72667439778, KLD: 84.42778872450192
Epoch: 7, Loss: 96759.43170674641, BCE: 96675.00527445476, KLD: 84.42791796227296
Epoch: 8, Loss: 96760.5835164388, BCE: 96676.15708414714, KLD: 84.42794930438201
Epoch: 9, Loss: 96750.33815002441, BCE: 96665.91171773274, KLD: 84.42801649371783
Epoch: 10, Loss: 96751.30640157063, BCE: 96666.87996927898, KLD: 84.42804500460625
Epoch: 11, Loss: 96747.79201253255, BCE: 96663.36558024089, KLD: 84.42805350820224
Epoch: 12, Loss: 96750.39926656087, BCE: 96665.97283426921, KLD: 84.42807715634505
Epoch: 13, Los

RuntimeError: shape '[28, 28]' is invalid for input of size 7840