In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torchvision.io import read_image
from torch.utils.data import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image

import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
import copy
import os
import gc

# Intended to switch on verbose TorchDynamo (torch.compile) logging.
# NOTE(review): these variables are set AFTER `import torch` above. TORCH_LOGS
# may be read while torch is being imported, in which case this assignment has
# no effect in the current process — confirm, or move both lines above the
# torch import.
os.environ['TORCH_LOGS'] = "+dynamo"
os.environ['TORCHDYNAMO_VERBOSE'] = "1"

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder over 3-channel images.

    The encoder applies five stride-2 convolutions and flattens the result into
    ``4 * 4 * 512`` features, i.e. it is sized for 128x128 inputs. Two linear
    heads turn those features into the mean and log-variance of a
    ``latent_dim``-dimensional Gaussian. The decoder projects a latent vector
    back to a (512, 4, 4) feature map, upsamples it with five transposed
    convolutions and ends in ``Tanh``, so reconstructions lie in [-1, 1].
    """

    def __init__(self, latent_dim):
        """Build encoder, latent heads and decoder for a ``latent_dim``-sized code."""
        super().__init__()

        # Size of the flattened encoder output / decoder input feature map.
        flat_features = 4 * 4 * 512

        # Encoder: Conv -> BatchNorm -> LeakyReLU per stage; every stage halves
        # height and width (128 -> 64 -> 32 -> 16 -> 8 -> 4 for a 128x128 input).
        down_channels = (3, 32, 64, 128, 256, 512)
        down_layers = []
        for c_in, c_out in zip(down_channels, down_channels[1:]):
            down_layers += [
                nn.Conv2d(c_in, c_out, 3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(),
            ]
        self.encoder = nn.Sequential(*down_layers)

        # Latent heads: mean and log-variance of the approximate posterior.
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Decoder entry: latent vector -> flat (512 * 4 * 4) features.
        self.decoder_input = nn.Linear(latent_dim, flat_features)

        # Decoder: normalise/activate the reshaped features, then five
        # ConvTranspose -> BatchNorm -> LeakyReLU stages that each double
        # height and width (4 -> 8 -> 16 -> 32 -> 64 -> 128).
        up_channels = (512, 256, 128, 64, 32, 16)
        up_layers = [nn.BatchNorm2d(512), nn.LeakyReLU()]
        for c_in, c_out in zip(up_channels, up_channels[1:]):
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(),
            ]
        # Final projection to 3 channels; Tanh keeps outputs in [-1, 1].
        up_layers += [nn.Conv2d(16, 3, kernel_size=3, padding=1), nn.Tanh()]
        self.decoder = nn.Sequential(*up_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent Gaussian for an image batch ``x``."""
        features = self.encoder(x)
        flat = features.view(features.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructed images."""
        feature_map = self.decoder_input(z).view(-1, 512, 4, 4)
        return self.decoder(feature_map)

    def vae_loss(self, recon_x, x, mu, logvar, gamma=5e-5):
        """Summed MSE reconstruction loss plus ``gamma`` times the KL divergence.

        Both terms are summed (not averaged) over every element of the batch.
        """
        reconstruction_term = F.mse_loss(recon_x, x, reduction='sum')
        # KL divergence between N(mu, exp(logvar)) and the standard normal.
        kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return reconstruction_term + (gamma * kl_term)

    def forward(self, x):
        """Encode ``x``, sample a latent code, decode it.

        Returns ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)

        # Reparameterisation: z = mu + eps * std with eps ~ N(0, I).
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        latent = mu + noise * sigma

        return self.decode(latent), mu, logvar


In [18]:
class ManualConvTranspose2d(nn.Module):
    """Upsampling layer built from zero-insertion followed by a stride-1 convolution.

    The input is spread onto a grid that is ``stride`` times larger in height
    and width (original values at every ``stride``-th position, zeros
    elsewhere, trailing zeros included), and a regular ``nn.Conv2d`` is then
    applied to that grid.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of the inner convolution.
        stride: integer upsampling factor used for the zero-insertion.
        padding: padding forwarded to the inner convolution (``'same'`` keeps
            the upsampled spatial size).
        bias: whether the inner convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding='same',bias=True):
        super(ManualConvTranspose2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # The convolution itself always runs with stride 1; the requested
        # padding is passed straight through to it.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding,bias=bias)

    def forward(self, x):
        """Zero-insert ``x`` by ``self.stride`` and convolve the result."""
        batch_size, channels, height, width = x.size()

        # Allocate the enlarged grid with the SAME dtype and device as the
        # input, so the layer also works for double / half precision modules
        # (a float32 grid would clash with non-float32 conv weights).
        upsampled = x.new_zeros(
            batch_size, channels, height * self.stride, width * self.stride
        )

        # Place the original values on every `stride`-th row/column.
        upsampled[:, :, ::self.stride, ::self.stride] = x

        return self.conv(upsampled)

# Smoke test: a 4x4 input upsampled with stride 2 and 'same' padding must come
# out as 8x8 with the requested number of channels.
input_tensor = torch.randn(1, 16, 4, 4)  # Example input
model = ManualConvTranspose2d(in_channels=16, out_channels=16, kernel_size=4, stride=2, padding='same')
output = model(input_tensor)
# Fail loudly under "Restart & Run All" instead of relying on eyeballing the print.
assert output.shape == (1, 16, 8, 8), f"unexpected output shape: {tuple(output.shape)}"
print("Output shape:", output.shape)


Output shape: torch.Size([1, 16, 8, 8])


In [19]:
# Preprocessing: PIL image -> float tensor in [0, 1], then shift/scale every
# RGB channel with mean 0.5 and std 0.5 so values land in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Latent size and compute device (GPU when one is available, CPU otherwise)
latent_dim = 128
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [20]:
class TiledImageDataset(Dataset):
    """Dataset over image files laid out as ``data_dir/<image_id>/<file>``.

    Every regular file inside the folders named by ``image_ids`` becomes one
    sample. Folders are visited in the order given by ``image_ids`` and the
    files inside each folder in sorted order, so the sample order is
    reproducible across runs and file systems.

    Args:
        data_dir: root directory containing one sub-folder per image id.
        image_ids: names of the sub-folders to include.
        transform: optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, data_dir, image_ids, transform=None):
        self.data_dir = data_dir
        self.image_ids = image_ids
        self.transform = transform
        self.image_paths = []

        # Collect all image paths for the given image IDs
        for image_id in image_ids:
            image_folder = os.path.join(data_dir, image_id)
            # os.listdir returns entries in arbitrary order -> sort them.
            for img_name in sorted(os.listdir(image_folder)):
                img_path = os.path.join(image_folder, img_name)
                # Skip nested directories; only regular files can be images.
                if os.path.isfile(img_path):
                    self.image_paths.append(img_path)

    def __len__(self):
        """Number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as RGB and apply ``transform`` if one was given."""
        img_path = self.image_paths[idx]
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img


In [None]:
# Every sub-folder of `data_dir` is the ID of one original image; its files are
# that image's tiles. Splitting by ID keeps all tiles of an image in one split.
data_dir = 'dataset'
# os.listdir returns entries in arbitrary order, so sort them: otherwise the
# fixed random_state below would not give a reproducible split. Non-directory
# entries (stray files) are dropped because they cannot hold tiles.
image_ids = sorted(
    entry for entry in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, entry))
)

# Split image IDs: 70% train, then the remaining 30% into 10% validation / 20% test
train_ids, temp_ids = train_test_split(image_ids, test_size=0.3, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=2/3, random_state=42)  # 10% validation, 20% test

# Create datasets
train_dataset = TiledImageDataset(data_dir, train_ids, transform=transform)
val_dataset = TiledImageDataset(data_dir, val_ids, transform=transform)
test_dataset = TiledImageDataset(data_dir, test_ids, transform=transform)

# Evaluation needs no gradients, so validation/test use a doubled batch size.
batch_size=2**9
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size*2, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size*2, shuffle=False)
len(train_dataset)

In [None]:
# Get 10% of training data for each epoch
def get_subset_sampler(dataset, percentage=0.1):
    """Return a ``SubsetRandomSampler`` over a random fraction of ``dataset``.

    The indices ``0 .. len(dataset) - 1`` are shuffled with NumPy's global RNG
    and the first ``ceil(percentage * len(dataset))`` of them are kept.
    """
    total = len(dataset)

    shuffled = [*range(total)]
    np.random.shuffle(shuffled)

    keep = int(np.ceil(percentage * total))
    return SubsetRandomSampler(shuffled[:keep])

In [None]:
# Model, optimizer, and compilation
latent_dim = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = VAE(latent_dim)
compiled_vae = torch.compile(vae).to(device)
#compiled_vae.load_state_dict(torch.load('epochs/epoch_10/vae_weights_epoch_10.pth',weights_only=True))

optimizer = optim.AdamW(compiled_vae.parameters(), lr=0.001, weight_decay=1e-5)

# Learning-rate scheduler: multiply the LR by 0.9 when the monitored
# (validation) loss stops decreasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min', 
    factor=0.9, 
    patience=1,
    cooldown=1)

num_epochs = 100
train_losses = []
val_losses = []

# Take the first (up to) 24 test images as a fixed reference batch. The test
# loader is not shuffled, so the same images are reconstructed every epoch.
n = 24
fixed_batch = next(iter(DataLoader(test_dataset, batch_size=n, shuffle=False)))
n = fixed_batch.size(0)  # the test set may hold fewer than 24 images

# `fixed_batch` is normalised to [-1, 1]; image writers expect [0, 1] and clip
# everything below 0. Undo the normalisation for saving/logging only — the
# normalised `fixed_batch` is still what gets fed to the model later.
fixed_batch_vis = torch.clamp(fixed_batch * 0.5 + 0.5, 0, 1)

# save_image does not create missing directories.
os.makedirs('epochs', exist_ok=True)
for i in range(n):
    save_image(fixed_batch_vis[i],fp=f'epochs/img_{i}.png')

# Initialize TensorBoard writer
writer = SummaryWriter(log_dir='runs/vae_experiment')
# Log original images to TensorBoard
writer.add_images('Original Images', fixed_batch_vis, 0)


# Train on the full training set every epoch (the 10% subset sampler below is
# currently disabled), then validate, save reconstructions and a checkpoint,
# and log everything to TensorBoard.
best_val_loss = float('inf')
# torch.save does not create missing parent directories, so make sure the
# best-model folder exists before the first improvement is written.
best_model_dir = 'epochs/best_model/'
os.makedirs(best_model_dir, exist_ok=True)
for epoch in range(num_epochs):
    #sampler = get_subset_sampler(train_dataset, percentage=0.1)
    #train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)

    # Per-parameter running sum of gradients over all batches of this epoch.
    accumulated_gradients = {name: torch.zeros_like(param) for name, param in compiled_vae.named_parameters()}
    compiled_vae.train()
    train_loss = 0
    num_samples=0
    # Use tqdm to create a progress bar for the training loop
    with tqdm(total=len(train_loader), desc=f'Training Epoch {epoch + 1}/{num_epochs}', unit='batch', position=0, leave=False) as pbar:
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()

            recon_batch, mu, logvar = compiled_vae(batch)
            loss = compiled_vae.vae_loss(recon_batch, batch, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            for name, param in compiled_vae.named_parameters():
                if param.grad is not None:
                    accumulated_gradients[name] += param.grad.detach().clone()

            num_samples+=batch.size()[0]

            # Update the progress bar
            pbar.update(1)  # Increment the progress bar by 1
            pbar.set_postfix({'train_loss': loss.item()})  # Display current loss

    # The loss is summed over samples, so divide by the sample count to get a
    # per-sample average.
    train_loss /= num_samples
    train_losses.append(train_loss)

    # Per-epoch output directory
    current_dir=f'epochs/epoch_{epoch +1}/'
    os.makedirs(current_dir, exist_ok=True)

    # Validation loop
    compiled_vae.eval()
    val_loss = 0
    num_samples_val=0
    with torch.no_grad():
        with tqdm(total=len(val_loader), desc=f'Validation Epoch {epoch + 1}/{num_epochs}', unit='batch',position=1, leave=False) as pbar:
            for batch in val_loader:
                batch = batch.to(device)
                recon_batch, mu, logvar = compiled_vae(batch)
                loss = compiled_vae.vae_loss(recon_batch, batch, mu, logvar)
                val_loss += loss.item()
                num_samples_val+=batch.size()[0]

                # Update the progress bar
                pbar.update(1)  # Increment the progress bar by 1
                pbar.set_postfix({'val_loss': loss.item()})  # Display current loss

    val_loss /= num_samples_val
    val_losses.append(val_loss)

    # Update the learning rate scheduler based on validation loss
    scheduler.step(val_loss)
    current_lr = scheduler.get_last_lr()[0]  # To check the current learning rate

    # Reconstruct the fixed reference batch deterministically: decode the
    # posterior mean instead of a random sample.
    compiled_vae.eval()
    with torch.no_grad():
        fixed_batch_epoch = fixed_batch.to(device)
        mu, _ = compiled_vae.encode(fixed_batch_epoch)
        recon_batch = compiled_vae.decode(mu)
        recon_batch=recon_batch.cpu()

        # Reverse normalization
        recon_batch = (recon_batch*0.5)+0.5 
        # Clamp the values to be in the range [0, 1]
        recon_batch = torch.clamp(recon_batch, 0, 1)
        for i in range(n):
            save_image(recon_batch[i],fp=os.path.join(current_dir,f'img_recon_{i}.png'))
    
    # Log reconstructed images to TensorBoard
    writer.add_images('Reconstructed Images', recon_batch, epoch)

    # Log the model parameters and per-sample averaged gradients to TensorBoard.
    # A shallow copy is enough here: `accumulated_gradients` is rebuilt from
    # scratch every epoch and entries are replaced, never mutated in place.
    avg_gradients = dict(accumulated_gradients)
    for name, param in compiled_vae.named_parameters():
        writer.add_histogram(f'Weights/{name}', param, global_step=epoch)
        if param.grad is not None:
            # Log averaged gradients
            avg_gradient = accumulated_gradients[name] / num_samples
            avg_gradients[name]=avg_gradient
            writer.add_histogram(f'Gradients/{name}', avg_gradient, global_step=epoch)

    # Model, optimizer and scheduler state plus this epoch's statistics.
    # NOTE(review): the state_dict of a torch.compile-wrapped module is expected
    # to carry an `_orig_mod.` key prefix — confirm before loading these
    # checkpoints into an uncompiled VAE.
    checkpoint = {
        'epoch': epoch+1,
        'model_state_dict': compiled_vae.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss':val_loss,
        'num_samples': num_samples,
        'avg_gradients':avg_gradients
    }
    torch.save(checkpoint, os.path.join(current_dir, f'checkpoint_epoch_{epoch+1}.pth'))

    # Save best model when validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(checkpoint, os.path.join(best_model_dir, 'vae_best_model.pth'))

    # Log training and validation losses
    writer.add_scalar('Loss/Train', train_loss, epoch)
    writer.add_scalar('Loss/Validation', val_loss, epoch)

    tqdm.write(f'Epoch {epoch + 1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Learning Rate: {current_lr}')


# After training loop

# Calculate the per-sample average loss on the held-out test set
compiled_vae.eval()
test_loss = 0
num_samples = 0
# The bar must be sized and labelled for the test loader, which is what the
# loop below actually iterates.
with tqdm(total=len(test_loader), desc='Test', unit='batch') as pbar:
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            recon_batch, mu, logvar = compiled_vae(batch)
            loss = compiled_vae.vae_loss(recon_batch, batch, mu, logvar)
            test_loss += loss.item()
            num_samples += batch.size()[0]

            # Update the progress bar
            pbar.update(1)  # Increment the progress bar by 1
            pbar.set_postfix({'test_loss': loss.item()})  # Display current loss
test_loss /= num_samples

# Log test loss at the step of the last training epoch (`epoch` is the loop
# variable left over from the training loop).
writer.add_scalar('Loss/Test', test_loss, epoch)

# Close the TensorBoard writer
writer.close()

# Print test loss
tqdm.write(f'Test Loss: {test_loss:.4f}')

# Plot the per-epoch training and validation loss curves
epoch_axis = range(1, num_epochs + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epoch_axis, train_losses, label='Training Loss')
ax.plot(epoch_axis, val_losses, label='Validation Loss')
ax.set_title('Training and Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.grid()
fig.savefig('loss_plot.png')  # Keep a copy of the figure on disk
plt.show()  # Render the figure in the notebook

# Release cached GPU memory and run a garbage-collection pass
torch.cuda.empty_cache()
gc.collect()