In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# Define the Encoder
class Encoder(nn.Module):
    """Conditional encoder producing Gaussian posterior parameters.

    Two stride-2 convolutions downsample the image, the flattened feature map
    is concatenated with the conditioning vector ``y``, and two linear heads
    emit the mean and log-variance of the latent distribution.

    Args:
        latent_dim: Size of the latent vector (default 20).
        num_classes: Width of the conditioning vector ``y`` (default 10).

    Note:
        The linear heads are sized for a 7x7x64 feature map, which is what a
        single-channel 28x28 input yields after the two stride-2 convolutions.
    """

    def __init__(self, latent_dim=20, num_classes=10):
        super(Encoder, self).__init__()
        # Layers are created in a fixed order so seeded initialisation stays stable.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        # Input to both heads = flattened conv features + conditioning vector.
        head_in_features = 7 * 7 * 64 + num_classes
        self.fc_mu = nn.Linear(head_in_features, latent_dim)
        self.fc_logvar = nn.Linear(head_in_features, latent_dim)
        self.relu = nn.ReLU()

    def forward(self, x, y):
        """Encode image batch ``x`` conditioned on ``y``.

        Args:
            x: Image batch of shape ``(batch, 1, 28, 28)``.
            y: Conditioning batch of shape ``(batch, num_classes)``.

        Returns:
            Tuple ``(mu, logvar)``, each of shape ``(batch, latent_dim)``.
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        # flatten(1) keeps the batch axis and also copes with non-contiguous activations.
        x = x.flatten(1)
        x = torch.cat([x, y], dim=1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

# Define the Decoder
class Decoder(nn.Module):
    """Conditional decoder mapping a latent code and class vector to an image.

    A linear layer expands ``[z, y]`` to a 64x7x7 feature map, which two
    stride-2 transposed convolutions upsample to a single-channel 28x28 image.
    The final sigmoid keeps every output value in ``[0, 1]``.

    Args:
        latent_dim: Size of the latent vector ``z`` (default 20).
        num_classes: Width of the conditioning vector ``y`` (default 10).
    """

    def __init__(self, latent_dim=20, num_classes=10):
        super(Decoder, self).__init__()
        # Attribute names and order are kept so saved state_dict keys stay
        # fc.*, deconv1.*, deconv2.*.
        self.fc = nn.Linear(latent_dim + num_classes, 7 * 7 * 64)
        self.deconv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z, y):
        """Decode latent batch ``z`` conditioned on ``y``.

        Args:
            z: Latent batch of shape ``(batch, latent_dim)``.
            y: Conditioning batch of shape ``(batch, num_classes)``.

        Returns:
            Tensor of shape ``(batch, 1, 28, 28)`` with values in ``[0, 1]``.
        """
        z = torch.cat([z, y], dim=1)
        x = self.fc(z)
        # Reshape the flat projection into (channels, height, width) for the deconvs.
        x = x.view(x.size(0), 64, 7, 7)
        x = self.relu(self.deconv1(x))      # 7x7  -> 14x14
        x = self.sigmoid(self.deconv2(x))   # 14x14 -> 28x28
        return x

# Define the CVAE
class CVAE(nn.Module):
    """Conditional VAE that pairs an ``Encoder`` with a ``Decoder``.

    Both sub-networks are built with the same ``latent_dim`` and receive the
    same conditioning tensor ``y`` during ``forward``.
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + sigma * eps`` where ``eps`` is standard-normal noise.

        ``sigma`` is recovered from the log-variance as ``exp(0.5 * logvar)``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x, y):
        """Encode ``(x, y)``, draw a latent sample, and decode it given ``y``.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encoder(x, y)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decoder(latent, y)
        return reconstruction, mu, logvar

# Define the Loss Function
def loss_function(recon_x, x, mu, logvar):
    """Return the summed VAE objective for one batch.

    The result is the element-wise binary cross-entropy between ``recon_x``
    and ``x`` (summed, not averaged) plus the closed-form KL divergence between
    the diagonal Gaussian described by ``(mu, logvar)`` and a standard normal.

    Returns:
        A zero-dimensional tensor holding the total loss over the batch.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL(q || N(0, I)) = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_integrand = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_integrand)
    return reconstruction_term + kl_term

# Set up training environment
# Seed the global RNG before the model is built so that weight initialisation,
# DataLoader shuffling and the latent noise are repeatable across
# Restart & Run All (GPU kernels may still add small run-to-run differences).
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Load MNIST dataset
# ToTensor scales pixels to [0, 1], the range binary cross-entropy expects for its targets.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Training loop
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print the mean loss.

    Relies on the notebook-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. The printed figure is the sum of all batch losses divided by
    the number of samples in the dataset.

    Args:
        epoch: Epoch number; only used in the printed progress line.
    """
    model.train()
    running_total = 0
    for images, digit_ids in train_loader:
        images = images.to(device)
        # Integer labels become float one-hot vectors (10 classes) for conditioning.
        conditions = F.one_hot(digit_ids, num_classes=10).float().to(device)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images, conditions)
        batch_loss = loss_function(reconstruction, images, mu, logvar)
        batch_loss.backward()
        running_total += batch_loss.item()
        optimizer.step()

    mean_loss = running_total / len(train_loader.dataset)
    print(f'Epoch {epoch}, Loss: {mean_loss}')

# Train for NUM_EPOCHS epochs (epochs are numbered from 1 in the printed log)
NUM_EPOCHS = 10
for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)

# Save the decoder weights
# Only the decoder's parameters are written. If training ran on CUDA the saved
# tensors are CUDA tensors, so pass map_location to torch.load when loading the
# file on a CPU-only machine.
torch.save(model.decoder.state_dict(), 'decoder_weights.pth')

100%|██████████| 9.91M/9.91M [00:00<00:00, 56.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.69MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 12.6MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.89MB/s]


Epoch 1, Loss: 151.38528248697918
Epoch 2, Loss: 114.30116756184896
Epoch 3, Loss: 108.90424615885416
Epoch 4, Loss: 106.06983440755208
Epoch 5, Loss: 104.44439197591146
Epoch 6, Loss: 103.30896795247396
Epoch 7, Loss: 102.35019763997396
Epoch 8, Loss: 101.66701414388021
Epoch 9, Loss: 101.08671310221354
Epoch 10, Loss: 100.59428188476562


In [2]:
# Trigger a browser download of the saved weights when running on Google Colab.
# The import stays local to this cell and is guarded so that Run All still
# completes in environments where the google.colab package is not installed.
try:
    from google.colab import files
except ImportError:
    print("google.colab is not available; skipping browser download of 'decoder_weights.pth'.")
else:
    files.download('decoder_weights.pth')


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>