In [1]:
class VAEEncoder(nn.Module):
    """Convolutional VAE encoder for 3x64x64 images.

    A 3x3 stem conv is followed by four 4x4 / stride-2 conv + BatchNorm +
    LeakyReLU(0.2) blocks (64 -> 32 -> 16 -> 8 -> 4). The 4x4x1024 result is
    flattened, passed through dropout, and projected to the latent mean and
    log-variance.

    Args:
        latent_dim: size of the latent vector.

    forward(x) returns ``(mu, log_var, skip_connection)``:
        mu, log_var: tensors of shape (B, latent_dim).
        skip_connection: the activated 512-channel feature map after the third
            downsampling block (8x8 for a 64x64 input), handed to the decoder.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Stem: 3 -> 64 channels, spatial size preserved.
        self.conv_initial = nn.Conv2d(3, 64, 3, stride=1, padding=1)

        # Downsampling convs; each halves height and width.
        self.conv1 = nn.Conv2d(64, 128, 4, stride=2, padding=1)    # 64 -> 32
        self.conv2 = nn.Conv2d(128, 256, 4, stride=2, padding=1)   # 32 -> 16
        self.conv3 = nn.Conv2d(256, 512, 4, stride=2, padding=1)   # 16 -> 8, tapped for the skip connection
        self.conv4 = nn.Conv2d(512, 1024, 4, stride=2, padding=1)  # 8 -> 4

        # One BatchNorm per downsampling conv (registered after the convs, as before).
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(512)
        self.bn4 = nn.BatchNorm2d(1024)

        # Heads producing the Gaussian posterior parameters from the flattened 1024x4x4 map.
        flat_features = 1024 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Regularisation on the flattened features feeding both heads.
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        slope = 0.2
        h = F.leaky_relu(self.conv_initial(x), slope)

        # The first three downsampling blocks; the output of the third is the skip tensor.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            h = F.leaky_relu(norm(conv(h)), slope)
        skip_connection = h

        h = F.leaky_relu(self.bn4(self.conv4(skip_connection)), slope)
        h = self.dropout(h.flatten(1))

        return self.fc_mu(h), self.fc_var(h), skip_connection


class VAEDecoder(nn.Module):
    """Convolutional VAE decoder: latent vector -> 3x64x64 image in [-1, 1].

    A linear layer maps ``z`` to a 1024x4x4 feature map. Four stride-2
    transposed convolutions then upsample it to 64x64, and a final 3x3 conv
    with ``tanh`` produces the 3-channel output.

    At the 8x8 stage, the decoder's 512-channel feature map is concatenated
    with a 512-channel encoder feature map (``skip_connection``). That is why
    ``conv2`` takes 1024 input channels.

    Args:
        latent_dim: size of the latent vector ``z``.
    """

    def __init__(self, latent_dim):
        super(VAEDecoder, self).__init__()

        # Initial fully connected layer: z -> 1024 * 4 * 4 features
        self.fc = nn.Linear(latent_dim, 1024 * 4 * 4)

        # Upsampling blocks (each 4x4 / stride 2 / pad 1 transposed conv doubles H and W)
        self.conv1 = nn.ConvTranspose2d(1024, 512, 4, stride=2, padding=1)  # 4x4 -> 8x8
        self.conv2 = nn.ConvTranspose2d(1024, 256, 4, stride=2, padding=1)  # 8x8 -> 16x16 (512 decoder + 512 skip channels in)
        self.conv3 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)   # 16x16 -> 32x32
        self.conv4 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)    # 32x32 -> 64x64

        # Batch normalization
        self.bn1 = nn.BatchNorm2d(512)
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(64)

        # Final convolution for output: 64 -> 3 channels, spatial size preserved
        self.conv_final = nn.Conv2d(64, 3, 3, stride=1, padding=1)

        # NOTE(review): defined but never used in forward(); kept so code that
        # references `decoder.dropout` keeps working.
        self.dropout = nn.Dropout(0.2)

    def forward(self, z, skip_connection=None):
        """Decode a latent batch into images.

        Args:
            z: latent batch of shape (B, latent_dim).
            skip_connection: encoder feature map concatenated after ``conv1``.
                It must be (B, 512, 8, 8) to line up with the decoder features.
                ``None`` substitutes zeros, for example when decoding ``z``
                sampled from the prior, where no encoder features exist.

        Returns:
            Tensor of shape (B, 3, 64, 64) with values in [-1, 1].
        """
        # Reshape from latent space
        x = F.relu(self.fc(z))
        x = x.view(x.size(0), 1024, 4, 4)

        # 4x4 -> 8x8, 512 channels
        x = F.relu(self.bn1(self.conv1(x)))

        # conv2 always expects 512 + 512 = 1024 input channels. Feeding only the
        # 512 decoder channels, as the old "skip if None" branch did, crashes.
        # Pad with zeros instead.
        # NOTE(review): the network is trained with real encoder features here,
        # so outputs on this zero-skip path may be poor until trained for it.
        if skip_connection is None:
            skip_connection = torch.zeros_like(x)
        x = torch.cat([x, skip_connection], dim=1)  # concatenate along channels

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))

        # Final convolution with tanh activation -> values in [-1, 1]
        x = torch.tanh(self.conv_final(x))

        return x

class ConvVAE(nn.Module):
    """Convolutional VAE: VAEEncoder -> reparameterised z -> VAEDecoder.

    The encoder's 8x8 skip feature map is passed straight to the decoder
    alongside the latent sample.

    Args:
        latent_dim: size of the latent vector shared by encoder and decoder.

    forward(x) returns ``(recon_x, mu, log_var)``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = VAEEncoder(latent_dim)
        self.decoder = VAEDecoder(latent_dim)

    def reparameterize(self, mu, log_var):
        """Sample z = mu + eps * exp(0.5 * log_var), eps ~ N(0, I), in training mode.

        In eval mode the mean is returned unchanged, so reconstructions are
        deterministic.
        """
        if not self.training:
            return mu
        std = (0.5 * log_var).exp()
        return mu + torch.randn_like(std) * std

    def forward(self, x):
        mu, log_var, skip = self.encoder(x)
        recon_x = self.decoder(self.reparameterize(mu, log_var), skip)
        return recon_x, mu, log_var


NameError: name 'nn' is not defined

In [2]:
# Model training parameters
learning_rate = 0.0001
step_size = 10    # forwarded to run_vae_training; presumably the LR-scheduler step (epochs) — confirm
gamma = 0.5       # forwarded to run_vae_training; presumably the LR decay factor — confirm
kl_weight = 0.01  # forwarded to run_vae_training; presumably weights the KL term in the loss
num_epochs = 40
latent_dim = 512  # size of the VAE latent vector z

# Derive the run name from the actual hyperparameters, so the run label and
# checkpoint filename cannot drift from the configuration. A hard-coded
# "ld_1024" previously contradicted latent_dim = 512.
name = f"run_kl_wgt_{kl_weight}_ep_{num_epochs}_ld_{latent_dim}_skip_cn"
project = "assignment-5"

# Build the model and train it.
model = ConvVAE(latent_dim=latent_dim).to(device)
model = run_vae_training(
    model, train_loader, val_loader, device,
    num_epochs=num_epochs, learning_rate=learning_rate,
    step_size=step_size, gamma=gamma,
    kl_weight=kl_weight,
    name=name, project=project,
)

# torch.save on a module pickles the whole object. Reloading it requires the
# ConvVAE / VAEEncoder / VAEDecoder classes to be defined in the loading session.
save_path = os.path.join(saved_model_folder, name)
torch.save(model, save_path)
print(f"Model saved at: {save_path}")

NameError: name 'ConvVAE' is not defined

In [3]:
# Reload the checkpoint written by the training cell and score it with FID on
# at most 2000 validation images (fewer if the validation set is smaller).
# NOTE(review): the encoder->decoder skip connection hands 8x8 encoder features
# to the decoder. If compute_fid_score scores reconstructions rather than prior
# samples, the FID may be optimistic — confirm which it does.
model = load_model(save_path)
max_samples = min(len(val_loader.dataset), 2000)
fid_score = compute_fid_score(
    model,
    val_loader,
    device,
    max_samples=max_samples,
)
print(f"FID Score for model {name}: {fid_score}")

NameError: name 'load_model' is not defined

In [4]:
# Take one validation batch and show it next to the model's reconstructions.
# The labels are unused. Binding them to `_labels` rather than `_` keeps
# IPython's `_` / `__` / `___` output history working: once user code assigns
# `_`, IPython stops updating it.
data_iter = iter(val_loader)
images, _labels = next(data_iter)
visualize_reconstructions(model, images, device)

NameError: name 'val_loader' is not defined