In [13]:
# `%pip` installs into the environment of the running kernel and does not depend on
# IPython's automagic being enabled (a bare `pip install` does).
# NOTE(review): consider pinning versions for reproducibility (this run resolved
# torchvision 0.20.1 and torchaudio 2.5.1).
%pip install torch torchvision torchaudio

Defaulting to user installation because normal site-packages is not writeable
Collecting torchvision
  Downloading torchvision-0.20.1-cp312-cp312-win_amd64.whl.metadata (6.2 kB)
Collecting torchaudio
  Downloading torchaudio-2.5.1-cp312-cp312-win_amd64.whl.metadata (6.5 kB)
Downloading torchvision-0.20.1-cp312-cp312-win_amd64.whl (1.6 MB)
   ---------------------------------------- 0.0/1.6 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.6 MB ? eta -:--:--
   - -------------------------------------- 0.0/1.6 MB 653.6 kB/s eta 0:00:03
   ----- ---------------------------------- 0.2/1.6 MB 2.3 MB/s eta 0:00:01
   ------------------- -------------------- 0.8/1.6 MB 5.4 MB/s eta 0:00:01
   ---------------------------------- ----- 1.3/1.6 MB 7.1 MB/s eta 0:00:01
   ---------------------------------------- 1.6/1.6 MB 7.1 MB/s eta 0:00:00
Downloading torchaudio-2.5.1-cp312-cp312-win_amd64.whl (2.4 MB)
   ---------------------------------------- 0.0/2.4 MB ? eta -:--:--
   ---

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Each MNIST sample is a single-channel 28x28 image, i.e. a tensor of shape 1x28x28 (C x H x W).

class Encoder(nn.Module):
    """Convolutional VAE encoder: image batch -> parameters of q(z|x).

    Three stride-2 convolutions reduce a 1x28x28 input to a 128x4x4 feature map.
    That map is flattened and projected to the mean and log-variance of a
    diagonal Gaussian.

    Args:
        input_channels: number of channels in the input images (1 for MNIST).
        hidden_dim: accepted for signature parity with ``Decoder``/``VAE``; not used here.
        z_dim: dimensionality of the latent space.
    """

    def __init__(self, input_channels, hidden_dim, z_dim):
        super().__init__()
        # Creation order is kept stable so seeded initialisation and state_dict keys match.
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=4, stride=2, padding=1)  # 28x28 -> 14x14
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)              # 14x14 -> 7x7
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)             # 7x7   -> 4x4

        flat_features = 128 * 4 * 4  # 2048 = flattened size of the 128x4x4 conv3 output
        self.fc_mu = nn.Linear(flat_features, z_dim)
        self.fc_logvar = nn.Linear(flat_features, z_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape (batch, z_dim).

        The linear heads assume the conv stack ends in a 128x4x4 map, which it
        does for 28x28 inputs.
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        # (batch, 128, 4, 4) -> (batch, 2048): collapse channel and spatial dims per sample.
        features = features.view(features.size(0), -1)
        return self.fc_mu(features), self.fc_logvar(features)

In [5]:
class Decoder(nn.Module):
    """Convolutional VAE decoder: latent code -> image with values in (0, 1).

    A linear layer lifts z to a 128x4x4 map. Three transposed convolutions
    upsample it to 32x32, and a final 5x5 unpadded convolution trims it to 28x28.
    A sigmoid squashes the result into (0, 1).

    Args:
        z_dim: dimensionality of the latent space.
        hidden_dim: accepted for signature parity with ``Encoder``/``VAE``; not used here.
        output_channels: number of channels in the reconstructed image.
    """

    def __init__(self, z_dim, hidden_dim, output_channels):
        super().__init__()
        # Creation order is kept stable so seeded initialisation and state_dict keys match.
        self.fc = nn.Linear(z_dim, 128 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)  # 4x4 -> 8x8
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)                    # 8x8 -> 16x16
        self.deconv3 = nn.ConvTranspose2d(32, output_channels, kernel_size=4, stride=2, padding=1)       # 16x16 -> 32x32
        self.output_layer = nn.Conv2d(output_channels, output_channels, kernel_size=5, stride=1, padding=0)  # 32x32 -> 28x28

    def forward(self, z):
        """Decode ``z`` of shape (batch, z_dim) into (batch, output_channels, 28, 28)."""
        grid = F.relu(self.fc(z)).view(-1, 128, 4, 4)
        # NOTE(review): deconv3 already outputs only `output_channels` channels, so its
        # ReLU clamps a very narrow map before the final conv — confirm this is intended.
        for upsample in (self.deconv1, self.deconv2, self.deconv3):
            grid = F.relu(upsample(grid))
        return torch.sigmoid(self.output_layer(grid))

In [7]:
class VAE(nn.Module):
    """Variational autoencoder wiring an ``Encoder`` to a ``Decoder``.

    Args:
        input_channels: image channels; also used as the decoder's output channels.
        hidden_dim: forwarded to both sub-modules.
        z_dim: latent dimensionality.
    """

    def __init__(self, input_channels, hidden_dim, z_dim):
        super().__init__()
        self.encoder = Encoder(input_channels, hidden_dim, z_dim)
        self.decoder = Decoder(z_dim, hidden_dim, input_channels)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) differentiably as z = mu + sigma * eps.

        eps is standard-normal noise. The N(0, I) prior on z enters through the
        KL term of the loss, not here.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(x_recon, mu, logvar)`` for an input batch ``x``."""
        mu, logvar = self.encoder(x)
        return self.decoder(self.reparameterize(mu, logvar)), mu, logvar

# Summed (not averaged) reconstruction loss, so it is on the same per-batch scale as
# the summed KL term. BCE expects predictions in [0, 1]. To use squared error
# instead, substitute nn.MSELoss(reduction="sum").
criterion = nn.BCELoss(reduction="sum")
def loss_function(recon_x, x, mu, logvar, beta=1.0):
    """Negative ELBO: summed reconstruction loss plus ``beta``-weighted KL divergence.

    Args:
        recon_x: decoder output. It must lie in [0, 1] when ``criterion`` is ``nn.BCELoss``.
        x: target images, same shape as ``recon_x``.
        mu, logvar: encoder outputs parameterising q(z|x) = N(mu, exp(logvar)).
        beta: weight on the KL term. The default ``1.0`` gives the standard VAE
            loss, numerically identical to an unweighted sum.

    Returns:
        ``criterion(recon_x, x) + beta * KL(q(z|x) || N(0, I))``. This is a scalar
        tensor when ``criterion`` uses a summing or averaging reduction.
    """
    # Reconstruction term from the module-level `criterion` (BCE or MSE, whichever is configured).
    recon_loss = criterion(recon_x, x)
    # Closed-form KL between a diagonal Gaussian and the standard normal prior,
    # summed over latent dimensions and the batch.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kld

In [9]:
class ConditionalEncoder(nn.Module):
    """Label-conditioned VAE encoder (CVAE), parameterising q(z | x, y).

    The class label is one-hot encoded, broadcast into ``num_classes`` constant
    image planes and concatenated to the input channels, so every convolution sees
    the condition. The conv/linear stack mirrors the unconditional encoder.

    Args:
        input_channels: image channels (1 for MNIST).
        hidden_dim: kept for signature parity with the unconditional encoder; unused.
        z_dim: latent dimensionality.
        num_classes: number of label classes used for the one-hot encoding.
    """

    def __init__(self, input_channels, hidden_dim, z_dim, num_classes):
        # Must name this class (or use zero-arg super): `super(Encoder, self)` raised
        # TypeError because ConditionalEncoder is not a subclass of Encoder.
        super().__init__()
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(input_channels + num_classes, 32, kernel_size=4, stride=2, padding=1)  # 28x28 -> 14x14
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)                            # 14x14 -> 7x7
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)                           # 7x7   -> 4x4
        self.fc_mu = nn.Linear(128 * 4 * 4, z_dim)
        self.fc_logvar = nn.Linear(128 * 4 * 4, z_dim)

    def forward(self, x, labels):
        """Return ``(mu, logvar)``, each of shape (batch, z_dim).

        Args:
            x: images of shape (batch, input_channels, H, W). The linear heads
                assume the conv stack ends at 4x4, which holds for 28x28 inputs.
            labels: integer class indices of shape (batch,), in ``[0, num_classes)``.
        """
        # One-hot via torch's built-in: the previously called `one_hot_encode` helper is
        # not defined anywhere in the visible notebook. `.to(x)` matches x's dtype and
        # device so the channel concatenation below is well-typed.
        y = F.one_hot(labels.long(), num_classes=self.num_classes).to(x)  # (batch, num_classes)
        # (batch, num_classes) -> (batch, num_classes, 1, 1) -> (batch, num_classes, H, W):
        # one constant "condition plane" per class; expand creates a view without copying.
        y = y[:, :, None, None].expand(-1, -1, x.size(2), x.size(3))
        # Concatenate along the channel dimension.
        x = torch.cat([x, y], dim=1)

        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = h.view(h.size(0), -1)  # (batch, 128, 4, 4) -> (batch, 2048)
        return self.fc_mu(h), self.fc_logvar(h)

In [11]:
def loss_function(recon_x, x, mu, logvar, beta=1.0):
    """beta-VAE objective: reconstruction loss + ``beta`` * KL(q(z|x) || N(0, I)).

    This redefines, and therefore shadows, the earlier ``loss_function``. With
    ``beta=1.0`` it is the plain negative ELBO. Reconstruction uses the
    module-level ``criterion``.
    """
    reconstruction = criterion(recon_x, x)
    # Analytic KL between the diagonal Gaussian posterior and the standard normal prior.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return reconstruction + beta * kl_divergence