In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor
from PIL import Image
import matplotlib.pyplot as plt


In [2]:
# Hyperparameters
# Seed the global PyTorch RNG so weight initialisation, DataLoader shuffling and the
# VAE's reparameterisation noise (torch.randn_like) repeat across Restart & Run All.
seed = 42
torch.manual_seed(seed)
epochs = 200
batch_size = 50
latent_dim = 25
h_image, w_image = 96, 96  # Target image size for training


In [3]:
# VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel images.

    The encoder applies four stride-2 convolutions, so the input height and width
    must both be divisible by 16; the decoder mirrors it and reconstructs an image
    of the same size, squashed into (0, 1) by a final sigmoid.

    Args:
        latent_dim: size of the latent vector z.
        image_size: (height, width) of the inputs, or a single int for square
            images. Defaults to (96, 96), which gives a 128 x 6 x 6 bottleneck.

    Raises:
        ValueError: if a side of ``image_size`` is not a positive multiple of 16.

    ``forward`` returns ``(reconstruction, mean, var)`` where ``var`` is the
    LOG-variance of q(z|x): ``reparameterize`` uses ``exp(0.5 * var)`` as the std.
    """

    def __init__(self, latent_dim, image_size=(96, 96)):
        super(VAE, self).__init__()
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        height, width = image_size
        if height <= 0 or width <= 0 or height % 16 or width % 16:
            raise ValueError(
                f"image_size must be positive and divisible by 16, got {image_size}"
            )
        # Four stride-2 convolutions shrink each side by 16 (96 x 96 -> 6 x 6).
        self.feature_shape = (128, height // 16, width // 16)
        feature_count = self.feature_shape[0] * self.feature_shape[1] * self.feature_shape[2]

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.flatten = nn.Flatten()
        self.fc_mean = nn.Linear(feature_count, latent_dim)
        # Second head: log-variance of q(z|x), despite the attribute name.
        self.fc_var = nn.Linear(feature_count, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, feature_count)
        # Decoder: each stride-2 transposed conv with output_padding=1 exactly doubles
        # H and W, undoing the encoder; the last stride-1 conv keeps the size.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def reparameterize(self, mean, var):
        """Draw z = mean + eps * std with eps ~ N(0, I).

        ``var`` is a log-variance, so std = exp(0.5 * var). Fresh noise is drawn on
        every call, including in eval mode.
        """
        std = torch.exp(0.5 * var)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        """Encode ``x``, sample z, and decode it.

        Args:
            x: batch of shape (N, 1, H, W) with (H, W) equal to ``image_size``.

        Returns:
            (reconstruction of shape (N, 1, H, W), mean, log-variance), the last two
            of shape (N, latent_dim).
        """
        x = self.encoder(x)
        x = self.flatten(x)
        mean = self.fc_mean(x)
        var = self.fc_var(x)
        z = self.reparameterize(mean, var)
        x = self.fc_decoder(z)
        x = x.view(-1, *self.feature_shape)
        x = self.decoder(x)
        return x, mean, var

In [4]:
# Loss Function
def vae_loss(recon_x, x, mean, var):
    """Negative ELBO for the VAE, summed over every element of the batch.

    Args:
        recon_x: decoder output; must lie in [0, 1] for binary cross-entropy.
        x: target images, also in [0, 1].
        mean: latent means produced by the encoder.
        var: latent LOG-variances produced by the encoder (exponentiated below).

    Returns:
        Scalar tensor: summed BCE reconstruction term plus KL(N(mean, exp(var)) || N(0, I)).
    """
    bce_term = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL divergence of a diagonal Gaussian from the standard normal.
    kl_term = -0.5 * torch.sum(1 + var - mean.pow(2) - var.exp())
    return bce_term + kl_term


In [5]:
# Device: use CUDA when PyTorch reports it is available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Custom Dataset
class CityscapesDataset(Dataset):
    """Grayscale image dataset over the .jpg/.jpeg/.png files directly inside ``root_dir``.

    Args:
        root_dir: directory holding the image files (not searched recursively).
        transform: optional callable applied to each PIL image before it is returned.

    File paths are sorted, so a given index always refers to the same image no matter
    what order ``os.listdir`` returns; two datasets built on the same directory
    therefore index the same files identically.
    """

    # Matched case-insensitively, so e.g. "frame.PNG" is included too.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.image_files = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` converted to single-channel 'L' mode, then apply ``transform``."""
        img_path = self.image_files[idx]
        # Context manager releases the file handle even if decoding fails.
        with Image.open(img_path) as img:
            image = img.convert('L')
        if self.transform:
            image = self.transform(image)
        return image

Using device: cpu


In [6]:
# Dataset Paths: training images; both the left and right crops are cut from these files
dataset_dir = '/kaggle/input/cityscapes/train/img'

In [7]:
# Transformations for left part
class LeftCrop:
    """Cut a fixed-size box anchored at the top-left corner of a PIL image.

    ``size`` is (width, height) in pixels; the crop box is (left, upper, right, lower).
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        crop_w, crop_h = self.size[0], self.size[1]
        box = (0, 0, crop_w, crop_h)
        return image.crop(box)

transform_left = Compose([LeftCrop((96, 96)), ToTensor()])

# Dataset and shuffled dataloader for the left-hand crops only
dataset_left = CityscapesDataset(root_dir=dataset_dir, transform=transform_left)
dataloader_left = DataLoader(dataset_left, batch_size=batch_size, shuffle=True)

In [8]:
# Transformations for right part
class RightCrop:
    """Cut a fixed-size box anchored at the top-right corner of a PIL image.

    ``size`` is (width, height) in pixels; the crop box is (left, upper, right, lower).
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        right_edge = image.width
        left_edge = right_edge - self.size[0]
        return image.crop((left_edge, 0, right_edge, self.size[1]))

transform_right = Compose([RightCrop((96, 96)), ToTensor()])

# Dataset and shuffled dataloader for the right-hand crops only
dataset_right = CityscapesDataset(root_dir=dataset_dir, transform=transform_right)
dataloader_right = DataLoader(dataset_right, batch_size=batch_size, shuffle=True)


In [None]:
# Fresh model/optimiser pair for the left-hand crops
vae = VAE(latent_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=0.001)

# Train for `epochs` full passes over dataloader_left
for epoch in range(epochs):
    vae.train()
    train_loss = 0
    for batch in dataloader_left:
        left_image = batch.to(device)
        optimizer.zero_grad()
        recon_x, mean, var = vae(left_image)
        loss = vae_loss(recon_x, left_image, mean, var)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # vae_loss sums over the batch, so dividing by the dataset size gives loss per image
    avg_loss = train_loss / len(dataloader_left.dataset)
    print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

# Persist the trained weights so later cells can reload the left-part model
checkpoint_path = "vae_left_model.pth"
torch.save(vae.state_dict(), checkpoint_path)
print("Model saved for left part")

Epoch 1, Loss: 5622.0710
Epoch 2, Loss: 5418.6849
Epoch 3, Loss: 5329.4353
Epoch 4, Loss: 5237.2745
Epoch 5, Loss: 5202.8994
Epoch 6, Loss: 5191.6126
Epoch 7, Loss: 5178.1017
Epoch 8, Loss: 5173.0608
Epoch 9, Loss: 5166.6909
Epoch 10, Loss: 5162.6443
Epoch 11, Loss: 5158.0510
Epoch 12, Loss: 5154.1058
Epoch 13, Loss: 5151.7131
Epoch 14, Loss: 5150.8892
Epoch 15, Loss: 5145.8149
Epoch 16, Loss: 5143.5682
Epoch 17, Loss: 5140.8395
Epoch 18, Loss: 5138.4157
Epoch 19, Loss: 5136.5334
Epoch 20, Loss: 5134.0772
Epoch 21, Loss: 5132.7660
Epoch 22, Loss: 5130.9791
Epoch 23, Loss: 5128.6579
Epoch 24, Loss: 5128.7781
Epoch 25, Loss: 5126.1630
Epoch 26, Loss: 5128.7244
Epoch 27, Loss: 5124.7843
Epoch 28, Loss: 5121.9531
Epoch 29, Loss: 5119.6676
Epoch 30, Loss: 5118.5167
Epoch 31, Loss: 5117.1983
Epoch 32, Loss: 5117.8383
Epoch 33, Loss: 5115.9203
Epoch 34, Loss: 5114.3748
Epoch 35, Loss: 5113.1187
Epoch 36, Loss: 5111.2699
Epoch 37, Loss: 5110.0986
Epoch 38, Loss: 5111.8304
Epoch 39, Loss: 5109.

In [None]:
# Fresh model/optimiser pair for the right-hand crops (this rebinds `vae`)
vae = VAE(latent_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=0.001)

# Train for `epochs` full passes over dataloader_right
for epoch in range(epochs):
    vae.train()
    train_loss = 0
    for batch in dataloader_right:
        right_image = batch.to(device)
        optimizer.zero_grad()
        recon_x, mean, var = vae(right_image)
        loss = vae_loss(recon_x, right_image, mean, var)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # vae_loss sums over the batch, so dividing by the dataset size gives loss per image
    avg_loss = train_loss / len(dataloader_right.dataset)
    print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

# Persist the trained weights so later cells can reload the right-part model
checkpoint_path = "vae_right_model.pth"
torch.save(vae.state_dict(), checkpoint_path)
print("Model saved for right part")

In [None]:
# Compare each VAE's reconstructions of a few left/right crops.
# Each model is reloaded from the checkpoint its own training cell saved: once the
# right-part training cell has run, `vae` refers only to the right-part model, so
# reusing it here would reconstruct the left crops with the wrong network.
vae_left = VAE(latent_dim).to(device)
vae_left.load_state_dict(torch.load("vae_left_model.pth", map_location=device))
vae_left.eval()  # Set the model to evaluation mode

vae_right = VAE(latent_dim).to(device)
vae_right.load_state_dict(torch.load("vae_right_model.pth", map_location=device))
vae_right.eval()

# Both datasets are built over the same directory, so one index gives the left and
# right crop of the same source image (two independently shuffled loaders would not).
sample_indices = [0, 1]

with torch.no_grad():
    sample_left = torch.stack([dataset_left[i] for i in sample_indices]).to(device)
    sample_right = torch.stack([dataset_right[i] for i in sample_indices]).to(device)
    print(f"Input shape (Left Part): {sample_left.shape}")
    print(f"Input shape (Right Part): {sample_right.shape}")
    print(f"Processing images with indices: {', '.join(map(str, sample_indices))}")

    # NOTE: reparameterize() draws fresh noise even in eval mode, so reconstructions
    # differ between runs unless the RNG is seeded.
    recon_x_left, mean_left, var_left = vae_left(sample_left)
    print(f"Reconstructed shape (Left Part): {recon_x_left.shape}")

    recon_x_right, mean_right, var_right = vae_right(sample_right)
    print(f"Reconstructed shape (Right Part): {recon_x_right.shape}")

    # The second encoder head is a log-variance: reparameterize() uses exp(0.5 * var).
    print("Latent Mean (Left Part):", mean_left)
    print("Latent Log-Variance (Left Part):", var_left)
    print("Latent Mean (Right Part):", mean_right)
    print("Latent Log-Variance (Right Part):", var_right)

# One row per sampled image: original/reconstructed left, then original/reconstructed right.
fig, axes = plt.subplots(len(sample_indices), 4, figsize=(12, 10), squeeze=False)
for row, idx in enumerate(sample_indices):
    panels = (
        (sample_left[row], f"Original Left Part (Image {idx})"),
        (recon_x_left[row], f"Reconstructed Left Part (Image {idx})"),
        (sample_right[row], f"Original Right Part (Image {idx})"),
        (recon_x_right[row], f"Reconstructed Right Part (Image {idx})"),
    )
    for ax, (img, title) in zip(axes[row], panels):
        ax.imshow(img.cpu().squeeze(), cmap='gray')
        ax.set_title(title)
        ax.axis('off')
plt.show()


In [None]:
import torch
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms

# Stitch a test "middle" image between VAE reconstructions of a left and a right crop.
input_img_path = '/kaggle/input/byopproject/testset/middle/train1.png'
sample_indices = [0, 1]

# Reload each model from the checkpoint its own training cell saved: once the
# right-part training cell has run, `vae` holds only the right-part model, so
# reusing it would reconstruct the left crops with the wrong network.
vae_left = VAE(latent_dim).to(device)
vae_left.load_state_dict(torch.load("vae_left_model.pth", map_location=device))
vae_left.eval()

vae_right = VAE(latent_dim).to(device)
vae_right.load_state_dict(torch.load("vae_right_model.pth", map_location=device))
vae_right.eval()

with torch.no_grad():
    # Same dataset index on both sides, so each row pairs crops of one source image.
    sample_left = torch.stack([dataset_left[i] for i in sample_indices]).to(device)
    sample_right = torch.stack([dataset_right[i] for i in sample_indices]).to(device)
    print(f"Input shape (Left Part): {sample_left.shape}")
    print(f"Input shape (Right Part): {sample_right.shape}")

    recon_x_left, mean_left, var_left = vae_left(sample_left)
    print(f"Reconstructed shape (Left Part): {recon_x_left.shape}")

    recon_x_right, mean_right, var_right = vae_right(sample_right)
    print(f"Reconstructed shape (Right Part): {recon_x_right.shape}")

    # The second encoder head is a log-variance: reparameterize() uses exp(0.5 * var).
    print("Latent Mean (Left Part):", mean_left)
    print("Latent Log-Variance (Left Part):", var_left)
    print("Latent Mean (Right Part):", mean_right)
    print("Latent Log-Variance (Right Part):", var_right)

    # Size the middle image from this cell's own reconstructions, so the cell does not
    # depend on tensors left over from an earlier cell.
    batch_count, _, recon_h, recon_w = recon_x_left.shape
    with Image.open(input_img_path) as img:
        input_img = img.convert('L')
    # PIL's resize takes (width, height).
    input_img_resized = input_img.resize((recon_w, recon_h))
    transform = transforms.ToTensor()
    input_tensor = transform(input_img_resized).unsqueeze(0).to(device)
    # One copy of the middle image per reconstructed sample.
    input_tensor = input_tensor.repeat(batch_count, 1, 1, 1)

    # Concatenate along width (dim=3): [left reconstruction | middle | right reconstruction].
    stitched_image = torch.cat((recon_x_left, input_tensor, recon_x_right), dim=3)

fig, axes = plt.subplots(len(sample_indices), 1, figsize=(10, 5), squeeze=False)
for i, ax in enumerate(axes[:, 0]):
    ax.imshow(stitched_image[i].cpu().squeeze(), cmap='gray')
    ax.set_title(f"Stitched Image (Image {i})")
    ax.axis('off')
plt.show()
