In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from transformers import ViTModel, ViTConfig
import os
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F

class UNetGenerator(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(UNetGenerator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),  # 224x224 -> 112x112
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),             # 112x112 -> 56x56
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1),            # 56x56 -> 28x28
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1),            # 28x28 -> 14x14
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        self.middle = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, stride=1, padding=1),  # Keeps size at 14x14
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=1, padding=1),  # Keeps size at 14x14
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 14x14 -> 28x28
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # 28x28 -> 56x56
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),   # 56x56 -> 112x112
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, output_channels, kernel_size=4, stride=2, padding=1),  # 112x112 -> 224x224
            nn.Tanh(),
        )

    def forward(self, x):
        # Encoder
        enc = self.encoder(x)
        #print(f"Shape after encoder: {enc.shape}")  # Check encoder output shape
        
        # Middle
        mid = self.middle(enc)
        #print(f"Shape after middle: {mid.shape}")  # Check shape after middle layer
        
        # Decoder
        dec = self.decoder(mid)
        #print(f"Shape after decoder: {dec.shape}")  # Final generator output shape
        return dec


class Discriminator(nn.Module):
    def __init__(self, input_channels):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=0),  # Ensures output is [B, 1, 1, 1]
            nn.Sigmoid(),
        )
    
    def forward(self, x):
        return self.model(x)

# Dataset and DataLoader
transform = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

class KITTIDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = sorted([os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.png') or img.endswith('.jpg')])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, 0  # Dummy label as GANs don't require labels

dataset = KITTIDataset("/kaggle/input/kitti-dataset/data_object_image_2/training/image_2", transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


# Initialize Models
generator = UNetGenerator(input_channels=128, output_channels=3).cuda()
discriminator = Discriminator(input_channels=3).cuda()
vit_config = ViTConfig(hidden_size=128, num_attention_heads=4, num_hidden_layers=6)
vit_encoder = ViTModel(vit_config).cuda()

# Optimizers
opt_gen = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_disc = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Loss Functions
adversarial_loss = nn.BCELoss()
reconstruction_loss = nn.MSELoss()

# Training Loop
epochs = 100
for epoch in range(epochs):
    for real_images, _ in tqdm(dataloader):
        real_images = real_images.cuda()
        batch_size = real_images.size(0)
        valid = torch.ones((batch_size, 1)).cuda()
        fake = torch.zeros((batch_size, 1)).cuda()

        # Train Generator
        opt_gen.zero_grad()

        #print(f"Input shape: {real_images.shape}")  # Shape of input image batch

        # Get the output from ViT encoder
        output = vit_encoder(real_images).last_hidden_state
        #print(f"Output shape (ViT): {output.shape}")  # ViT encoder output
        
        # Exclude the CLS token
        patch_embeddings = output[:, 1:, :]  # Remove the CLS token
        #print(f"Patch embeddings shape: {patch_embeddings.shape}")  # After excluding CLS token
        
        # Reshape the output to feed into the generator
        num_patches = int(patch_embeddings.shape[1] ** 0.5)  # Compute square root of number of patches
        z = patch_embeddings.permute(0, 2, 1).view(batch_size, 128, num_patches, num_patches)
        #print(f"Final tensor shape (z): {z.shape}")  # After reshaping for generator

        # After shaping z to [batch, 128, 14, 14], upsample it:
        z_upsampled = F.interpolate(z, size=(224, 224), mode='bilinear', align_corners=False)
        gen_images = generator(z_upsampled)

        #print(f"Generated images shape: {gen_images.shape}")


        valid = torch.ones_like(discriminator(gen_images)).cuda()  # Match shape dynamically
        fake = torch.zeros_like(discriminator(gen_images)).cuda()  # Match shape dynamically


        #print(f"Discriminator output shape: {discriminator(gen_images).shape}")
        #print(f"Valid tensor shape: {valid.shape}")

        
        g_loss = adversarial_loss(discriminator(gen_images), valid) + reconstruction_loss(gen_images, real_images)
        g_loss.backward()
        opt_gen.step()

        # Train Discriminator
        opt_disc.zero_grad()
        real_loss = adversarial_loss(discriminator(real_images), valid)
        fake_loss = adversarial_loss(discriminator(gen_images.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        opt_disc.step()

    print(f"Epoch {epoch}/{epochs} | Generator Loss: {g_loss.item()} | Discriminator Loss: {d_loss.item()}")

    # Save model checkpoint
    if epoch % 10 == 0:
        torch.save(generator.state_dict(), f"generator_epoch_{epoch}.pth")

100%|██████████| 234/234 [03:01<00:00,  1.29it/s]


Epoch 0/100 | Generator Loss: 1.0204671621322632 | Discriminator Loss: 0.5757521390914917


100%|██████████| 234/234 [02:04<00:00,  1.88it/s]


Epoch 1/100 | Generator Loss: 0.8944224119186401 | Discriminator Loss: 0.6046009659767151


100%|██████████| 234/234 [02:09<00:00,  1.81it/s]


Epoch 2/100 | Generator Loss: 0.9148299098014832 | Discriminator Loss: 0.662074863910675


100%|██████████| 234/234 [02:08<00:00,  1.81it/s]


Epoch 3/100 | Generator Loss: 0.9174225926399231 | Discriminator Loss: 0.6196024417877197


100%|██████████| 234/234 [02:09<00:00,  1.81it/s]


Epoch 4/100 | Generator Loss: 0.8649089336395264 | Discriminator Loss: 0.6286202669143677


100%|██████████| 234/234 [02:11<00:00,  1.78it/s]


Epoch 5/100 | Generator Loss: 0.8477680683135986 | Discriminator Loss: 0.6478431224822998


100%|██████████| 234/234 [02:11<00:00,  1.78it/s]


Epoch 6/100 | Generator Loss: 0.8825801610946655 | Discriminator Loss: 0.6523526310920715


100%|██████████| 234/234 [02:08<00:00,  1.82it/s]


Epoch 7/100 | Generator Loss: 0.8609124422073364 | Discriminator Loss: 0.6547753810882568


100%|██████████| 234/234 [02:10<00:00,  1.79it/s]


Epoch 8/100 | Generator Loss: 0.8716676235198975 | Discriminator Loss: 0.676280677318573


100%|██████████| 234/234 [02:06<00:00,  1.84it/s]


Epoch 9/100 | Generator Loss: 0.7923703193664551 | Discriminator Loss: 0.6853745579719543


100%|██████████| 234/234 [02:06<00:00,  1.85it/s]


Epoch 10/100 | Generator Loss: 0.8484306335449219 | Discriminator Loss: 0.641687273979187


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 11/100 | Generator Loss: 0.8051302433013916 | Discriminator Loss: 0.7096637487411499


100%|██████████| 234/234 [02:04<00:00,  1.88it/s]


Epoch 12/100 | Generator Loss: 0.8951003551483154 | Discriminator Loss: 0.6600557565689087


100%|██████████| 234/234 [02:04<00:00,  1.87it/s]


Epoch 13/100 | Generator Loss: 0.9105572700500488 | Discriminator Loss: 0.6478037238121033


100%|██████████| 234/234 [02:07<00:00,  1.84it/s]


Epoch 14/100 | Generator Loss: 0.9278598427772522 | Discriminator Loss: 0.6569019556045532


100%|██████████| 234/234 [02:06<00:00,  1.86it/s]


Epoch 15/100 | Generator Loss: 0.8418715596199036 | Discriminator Loss: 0.6609691977500916


100%|██████████| 234/234 [02:06<00:00,  1.85it/s]


Epoch 16/100 | Generator Loss: 0.8129839897155762 | Discriminator Loss: 0.673197865486145


100%|██████████| 234/234 [02:11<00:00,  1.78it/s]


Epoch 17/100 | Generator Loss: 0.8519625067710876 | Discriminator Loss: 0.6657014489173889


100%|██████████| 234/234 [02:07<00:00,  1.84it/s]


Epoch 18/100 | Generator Loss: 0.8247244954109192 | Discriminator Loss: 0.6730898022651672


100%|██████████| 234/234 [02:08<00:00,  1.82it/s]


Epoch 19/100 | Generator Loss: 0.8537623882293701 | Discriminator Loss: 0.6696504950523376


100%|██████████| 234/234 [02:07<00:00,  1.83it/s]


Epoch 20/100 | Generator Loss: 0.8032670021057129 | Discriminator Loss: 0.6897951364517212


100%|██████████| 234/234 [02:06<00:00,  1.85it/s]


Epoch 21/100 | Generator Loss: 0.8628685474395752 | Discriminator Loss: 0.6753075122833252


100%|██████████| 234/234 [02:08<00:00,  1.82it/s]


Epoch 22/100 | Generator Loss: 0.8021018505096436 | Discriminator Loss: 0.6760892868041992


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 23/100 | Generator Loss: 0.8990809917449951 | Discriminator Loss: 0.6575722098350525


100%|██████████| 234/234 [02:04<00:00,  1.88it/s]


Epoch 24/100 | Generator Loss: 0.8075598478317261 | Discriminator Loss: 0.6773127913475037


100%|██████████| 234/234 [02:06<00:00,  1.85it/s]


Epoch 25/100 | Generator Loss: 0.8177750110626221 | Discriminator Loss: 0.6819217801094055


100%|██████████| 234/234 [02:06<00:00,  1.84it/s]


Epoch 26/100 | Generator Loss: 0.7951428890228271 | Discriminator Loss: 0.672096848487854


100%|██████████| 234/234 [02:07<00:00,  1.84it/s]


Epoch 27/100 | Generator Loss: 0.8677485585212708 | Discriminator Loss: 0.6698155403137207


100%|██████████| 234/234 [02:07<00:00,  1.83it/s]


Epoch 28/100 | Generator Loss: 0.8236097693443298 | Discriminator Loss: 0.702592134475708


100%|██████████| 234/234 [02:08<00:00,  1.82it/s]


Epoch 29/100 | Generator Loss: 0.8524274230003357 | Discriminator Loss: 0.6662145853042603


100%|██████████| 234/234 [02:02<00:00,  1.91it/s]


Epoch 30/100 | Generator Loss: 0.8164740800857544 | Discriminator Loss: 0.682445764541626


100%|██████████| 234/234 [02:03<00:00,  1.90it/s]


Epoch 31/100 | Generator Loss: 0.8253035545349121 | Discriminator Loss: 0.6872433423995972


100%|██████████| 234/234 [02:05<00:00,  1.87it/s]


Epoch 32/100 | Generator Loss: 0.8390141725540161 | Discriminator Loss: 0.6749070882797241


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 33/100 | Generator Loss: 0.8341342806816101 | Discriminator Loss: 0.6916966438293457


100%|██████████| 234/234 [02:07<00:00,  1.83it/s]


Epoch 34/100 | Generator Loss: 0.8566270470619202 | Discriminator Loss: 0.6547859907150269


100%|██████████| 234/234 [02:07<00:00,  1.83it/s]


Epoch 35/100 | Generator Loss: 0.8684451580047607 | Discriminator Loss: 0.7392113208770752


100%|██████████| 234/234 [02:08<00:00,  1.82it/s]


Epoch 36/100 | Generator Loss: 0.8120260834693909 | Discriminator Loss: 0.663360595703125


100%|██████████| 234/234 [02:07<00:00,  1.83it/s]


Epoch 37/100 | Generator Loss: 0.781323254108429 | Discriminator Loss: 0.6838430762290955


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 38/100 | Generator Loss: 0.7895152568817139 | Discriminator Loss: 0.6805707216262817


100%|██████████| 234/234 [02:03<00:00,  1.89it/s]


Epoch 39/100 | Generator Loss: 0.8016489148139954 | Discriminator Loss: 0.6808002591133118


100%|██████████| 234/234 [02:06<00:00,  1.86it/s]


Epoch 40/100 | Generator Loss: 0.8106327056884766 | Discriminator Loss: 0.688918948173523


100%|██████████| 234/234 [02:07<00:00,  1.84it/s]


Epoch 41/100 | Generator Loss: 0.8124960064888 | Discriminator Loss: 0.6835490465164185


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 42/100 | Generator Loss: 0.787822961807251 | Discriminator Loss: 0.6889368295669556


100%|██████████| 234/234 [02:06<00:00,  1.85it/s]


Epoch 43/100 | Generator Loss: 0.8456916809082031 | Discriminator Loss: 0.6765224933624268


100%|██████████| 234/234 [02:07<00:00,  1.84it/s]


Epoch 44/100 | Generator Loss: 0.8331030011177063 | Discriminator Loss: 0.6788831949234009


100%|██████████| 234/234 [02:05<00:00,  1.87it/s]


Epoch 45/100 | Generator Loss: 0.7839365601539612 | Discriminator Loss: 0.6863541603088379


100%|██████████| 234/234 [02:03<00:00,  1.89it/s]


Epoch 46/100 | Generator Loss: 0.8360519409179688 | Discriminator Loss: 0.6839478611946106


100%|██████████| 234/234 [02:04<00:00,  1.89it/s]


Epoch 47/100 | Generator Loss: 0.8060393929481506 | Discriminator Loss: 0.6806306838989258


100%|██████████| 234/234 [02:03<00:00,  1.89it/s]


Epoch 48/100 | Generator Loss: 0.7386088371276855 | Discriminator Loss: 0.7018717527389526


100%|██████████| 234/234 [02:04<00:00,  1.88it/s]


Epoch 49/100 | Generator Loss: 0.8969568014144897 | Discriminator Loss: 0.6736968159675598


100%|██████████| 234/234 [02:03<00:00,  1.90it/s]


Epoch 50/100 | Generator Loss: 0.7967444658279419 | Discriminator Loss: 0.7000998258590698


100%|██████████| 234/234 [02:01<00:00,  1.92it/s]


Epoch 51/100 | Generator Loss: 0.796057939529419 | Discriminator Loss: 0.6890507340431213


100%|██████████| 234/234 [02:03<00:00,  1.89it/s]


Epoch 52/100 | Generator Loss: 0.8251014947891235 | Discriminator Loss: 0.666301965713501


100%|██████████| 234/234 [02:03<00:00,  1.90it/s]


Epoch 53/100 | Generator Loss: 0.8232868313789368 | Discriminator Loss: 0.6785391569137573


100%|██████████| 234/234 [02:04<00:00,  1.88it/s]


Epoch 54/100 | Generator Loss: 0.8287081718444824 | Discriminator Loss: 0.6807428598403931


100%|██████████| 234/234 [02:02<00:00,  1.90it/s]


Epoch 55/100 | Generator Loss: 0.8102254867553711 | Discriminator Loss: 0.6958191394805908


100%|██████████| 234/234 [02:01<00:00,  1.92it/s]


Epoch 56/100 | Generator Loss: 0.7907676696777344 | Discriminator Loss: 0.7077564001083374


100%|██████████| 234/234 [02:03<00:00,  1.90it/s]


Epoch 57/100 | Generator Loss: 0.7805426716804504 | Discriminator Loss: 0.680925190448761


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 58/100 | Generator Loss: 0.7663694024085999 | Discriminator Loss: 0.6856697797775269


100%|██████████| 234/234 [02:02<00:00,  1.90it/s]


Epoch 59/100 | Generator Loss: 0.7818620800971985 | Discriminator Loss: 0.706287145614624


100%|██████████| 234/234 [02:03<00:00,  1.90it/s]


Epoch 60/100 | Generator Loss: 0.8768738508224487 | Discriminator Loss: 0.6662469506263733


100%|██████████| 234/234 [02:02<00:00,  1.90it/s]


Epoch 61/100 | Generator Loss: 0.7843970060348511 | Discriminator Loss: 0.6893675327301025


100%|██████████| 234/234 [02:03<00:00,  1.90it/s]


Epoch 62/100 | Generator Loss: 0.8284315466880798 | Discriminator Loss: 0.6790874004364014


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 63/100 | Generator Loss: 0.785324215888977 | Discriminator Loss: 0.7036466598510742


100%|██████████| 234/234 [02:03<00:00,  1.89it/s]


Epoch 64/100 | Generator Loss: 0.8799338936805725 | Discriminator Loss: 0.6578869819641113


100%|██████████| 234/234 [02:03<00:00,  1.89it/s]


Epoch 65/100 | Generator Loss: 0.8422272801399231 | Discriminator Loss: 0.6843662261962891


100%|██████████| 234/234 [02:03<00:00,  1.90it/s]


Epoch 66/100 | Generator Loss: 0.8745805621147156 | Discriminator Loss: 0.6813532114028931


100%|██████████| 234/234 [02:02<00:00,  1.91it/s]


Epoch 67/100 | Generator Loss: 0.7499045133590698 | Discriminator Loss: 0.7044329643249512


100%|██████████| 234/234 [02:02<00:00,  1.92it/s]


Epoch 68/100 | Generator Loss: 0.8427927494049072 | Discriminator Loss: 0.6611875295639038


100%|██████████| 234/234 [02:04<00:00,  1.88it/s]


Epoch 69/100 | Generator Loss: 0.7960138320922852 | Discriminator Loss: 0.6697055101394653


100%|██████████| 234/234 [02:03<00:00,  1.89it/s]


Epoch 70/100 | Generator Loss: 0.8366439938545227 | Discriminator Loss: 0.6792103052139282


100%|██████████| 234/234 [02:06<00:00,  1.84it/s]


Epoch 71/100 | Generator Loss: 0.770671010017395 | Discriminator Loss: 0.6804149150848389


100%|██████████| 234/234 [02:06<00:00,  1.85it/s]


Epoch 72/100 | Generator Loss: 0.7552427053451538 | Discriminator Loss: 0.6939644813537598


100%|██████████| 234/234 [02:04<00:00,  1.87it/s]


Epoch 73/100 | Generator Loss: 0.7815706133842468 | Discriminator Loss: 0.6834418177604675


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 74/100 | Generator Loss: 0.8751962184906006 | Discriminator Loss: 0.673210620880127


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 75/100 | Generator Loss: 0.775391697883606 | Discriminator Loss: 0.7037074565887451


100%|██████████| 234/234 [02:06<00:00,  1.84it/s]


Epoch 76/100 | Generator Loss: 0.7334122657775879 | Discriminator Loss: 0.699561595916748


100%|██████████| 234/234 [02:07<00:00,  1.83it/s]


Epoch 77/100 | Generator Loss: 0.8079275488853455 | Discriminator Loss: 0.694121241569519


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 78/100 | Generator Loss: 0.7455877065658569 | Discriminator Loss: 0.7043793797492981


100%|██████████| 234/234 [02:05<00:00,  1.87it/s]


Epoch 79/100 | Generator Loss: 0.817805826663971 | Discriminator Loss: 0.6662888526916504


100%|██████████| 234/234 [02:07<00:00,  1.84it/s]


Epoch 80/100 | Generator Loss: 0.7745766043663025 | Discriminator Loss: 0.6922572255134583


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 81/100 | Generator Loss: 0.8213468790054321 | Discriminator Loss: 0.6743678450584412


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 82/100 | Generator Loss: 0.7712714076042175 | Discriminator Loss: 0.6945867538452148


100%|██████████| 234/234 [02:05<00:00,  1.87it/s]


Epoch 83/100 | Generator Loss: 0.8062005639076233 | Discriminator Loss: 0.6746228933334351


100%|██████████| 234/234 [02:06<00:00,  1.86it/s]


Epoch 84/100 | Generator Loss: 0.7782719135284424 | Discriminator Loss: 0.6844481825828552


100%|██████████| 234/234 [02:07<00:00,  1.83it/s]


Epoch 85/100 | Generator Loss: 0.8368995189666748 | Discriminator Loss: 0.6862497329711914


100%|██████████| 234/234 [02:06<00:00,  1.86it/s]


Epoch 86/100 | Generator Loss: 0.8261236548423767 | Discriminator Loss: 0.6736732125282288


100%|██████████| 234/234 [02:07<00:00,  1.83it/s]


Epoch 87/100 | Generator Loss: 0.7959935069084167 | Discriminator Loss: 0.696147084236145


100%|██████████| 234/234 [02:04<00:00,  1.88it/s]


Epoch 88/100 | Generator Loss: 0.7963318824768066 | Discriminator Loss: 0.6735004186630249


100%|██████████| 234/234 [02:07<00:00,  1.83it/s]


Epoch 89/100 | Generator Loss: 0.787826418876648 | Discriminator Loss: 0.6843894720077515


100%|██████████| 234/234 [02:06<00:00,  1.85it/s]


Epoch 90/100 | Generator Loss: 0.8171634078025818 | Discriminator Loss: 0.6750611066818237


100%|██████████| 234/234 [02:06<00:00,  1.86it/s]


Epoch 91/100 | Generator Loss: 0.8015762567520142 | Discriminator Loss: 0.67087322473526


100%|██████████| 234/234 [02:04<00:00,  1.87it/s]


Epoch 92/100 | Generator Loss: 0.818140983581543 | Discriminator Loss: 0.6868385076522827


100%|██████████| 234/234 [02:04<00:00,  1.89it/s]


Epoch 93/100 | Generator Loss: 0.740729570388794 | Discriminator Loss: 0.6803363561630249


100%|██████████| 234/234 [02:05<00:00,  1.87it/s]


Epoch 94/100 | Generator Loss: 0.7625223994255066 | Discriminator Loss: 0.706882655620575


100%|██████████| 234/234 [02:04<00:00,  1.88it/s]


Epoch 95/100 | Generator Loss: 0.7480475902557373 | Discriminator Loss: 0.6965979337692261


100%|██████████| 234/234 [02:03<00:00,  1.90it/s]


Epoch 96/100 | Generator Loss: 0.80366450548172 | Discriminator Loss: 0.7076535224914551


100%|██████████| 234/234 [02:05<00:00,  1.86it/s]


Epoch 97/100 | Generator Loss: 0.8222793936729431 | Discriminator Loss: 0.685386061668396


100%|██████████| 234/234 [02:05<00:00,  1.87it/s]


Epoch 98/100 | Generator Loss: 0.7330008745193481 | Discriminator Loss: 0.6944208741188049


100%|██████████| 234/234 [02:03<00:00,  1.89it/s]


Epoch 99/100 | Generator Loss: 0.7911421656608582 | Discriminator Loss: 0.6760014295578003


### Load the model and test

In [31]:
import torch
from torchvision.utils import save_image

# Assume you are in the same environment where UNetGenerator is defined
# or have imported it from your module.

# Create a new instance of the generator and load the weights
generator = UNetGenerator(input_channels=128, output_channels=3).cuda()
checkpoint_path = "/kaggle/working/generator_epoch_90.pth"  
generator.load_state_dict(torch.load(checkpoint_path))
generator.eval()


  generator.load_state_dict(torch.load(checkpoint_path))


UNetGenerator(
  (encoder): Sequential(
    (0): Conv2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU()
    (5): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
  )
  (middle): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(1

#### Using ViT. Requires sample image

In [33]:
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F

# Load ViT and set to eval mode
vit_encoder = ViTModel(ViTConfig(hidden_size=128, num_attention_heads=4, num_hidden_layers=6)).cuda()
vit_encoder.eval()

# Same transformation used during training
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Load a sample image
test_image = Image.open("/kaggle/input/kitti-dataset/data_object_image_2/testing/image_2/000000.png").convert("RGB")
test_image = transform(test_image).unsqueeze(0).cuda()  # shape: [1, 3, 224, 224]

# Get ViT embeddings
with torch.no_grad():
    output = vit_encoder(test_image).last_hidden_state  # shape: [1, 197, 128]

# Remove CLS token
patch_embeddings = output[:, 1:, :]  # shape: [1, 196, 128]

# Reshape into spatial dimensions
num_patches = int(patch_embeddings.shape[1] ** 0.5)  # Should be 14 for 14x14
z = patch_embeddings.permute(0, 2, 1).view(1, 128, num_patches, num_patches)  # [1,128,14,14]

# Upsample to 224x224 if that's what your generator expects
z_upsampled = F.interpolate(z, size=(224, 224), mode='bilinear', align_corners=False)  # [1,128,224,224]

# Generate image
with torch.no_grad():
    gen_images = generator(z_upsampled)  # [1, 3, 224, 224]

# Save the generated image
# If you normalized your data with mean=0.5 and std=0.5, convert back
gen_images = gen_images * 0.5 + 0.5
save_image(gen_images, "generated_sample.png")


#### Using noise

In [36]:
import torch
from torchvision.utils import save_image
import torch.nn.functional as F

# Create random input (matching the shape generator expects)
# The generator expects [N, 128, 224, 224] since we upsampled z
z_random = torch.randn(1, 128, 224, 224).cuda()

with torch.no_grad():
    gen_images = generator(z_random)  # [1, 3, 224, 224]

# Convert back to [0,1] range if you used the same normalization
gen_images = gen_images * 0.5 + 0.5

# Save the generated image
save_image(gen_images, "generated_sample_random.png")
