In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.models import vgg16, VGG16_Weights
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import kagglehub

# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Hyperparameters (optimized for 128x128 RGB vehicle images)
latent_dim = 256
num_classes = 5
batch_size = 64
epochs = 3000
learning_rate = 5e-4
image_size = 128
channels = 3
beta_max = 1.0
annealing_epochs = 100
perceptual_weight = 10.0
recon_weight = 1.0

# Output directories
output_dir = "FL_CVAE"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, "checkpoints"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "samples"), exist_ok=True)

# Dataset Class (replace with your actual dataset)

## Dataset Class
class VehicleDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}

        print(f"Scanning dataset directory: {root_dir}")
        if not os.path.isdir(root_dir):
            raise ValueError(f"Root directory {root_dir} does not exist or is not a directory.")

        for idx, class_dir in enumerate(sorted(os.listdir(root_dir))):
            class_path = os.path.join(root_dir, class_dir)
            if os.path.isdir(class_path):
                self.class_to_idx[class_dir] = idx
                image_files = [f for f in os.listdir(class_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
                if not image_files:
                    print(f"No images found in {class_path}")
                    continue
                for img_file in image_files:
                    img_path = os.path.join(class_path, img_file)
                    try:
                        with Image.open(img_path) as img:
                            img.verify()
                        self.image_paths.append(img_path)
                        self.labels.append(idx)
                    except Exception as e:
                        print(f"Skipping corrupted image {img_path}: {e}")

        print(f"Found {len(self.image_paths)} images across {len(self.class_to_idx)} classes")
        print(f"Class mapping: {self.class_to_idx}")
        if len(self.image_paths) == 0:
            raise ValueError("No valid images found in the dataset. Check the directory structure and file formats.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            raise
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

# Initialize dataset and dataloader
path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
print(f"Dataset path: {path}")
print(f"Directory contents: {os.listdir(path)}")
for subdir in os.listdir(path):
    subdir_path = os.path.join(path, subdir)
    if os.path.isdir(subdir_path):
        print(f"Contents of {subdir}: {os.listdir(subdir_path)}")

dataset = VehicleDataset(root_dir=path, transform=transform)
print(f"Dataset size: {len(dataset)}")
num_classes = len(dataset.class_to_idx)  # Update num_classes dynamically
print(f"Number of classes: {num_classes}")

train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
# Initialize dataset and dataloader
path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
print(f"Dataset path: {path}")
print(f"Directory contents: {os.listdir(path)}")

dataset = VehicleDataset(root_dir=path, transform=transform)
if len(dataset) == 0:
    raise ValueError("No images found in the dataset. Check the directory structure and file formats.")
print(f"Dataset size: {len(dataset)}")

train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)  # Set num_workers=0 for debugging
# Transformations
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Initialize dataset and dataloader
path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
dataset = VehicleDataset(root_dir= path, transform=transform)  # Replace with your dataset path
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Enhanced Encoder
class EnhancedEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.initial_conv = nn.Sequential(
            nn.Conv2d(channels + num_classes, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2)
        )
        self.down1 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2)
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )
        self.down4 = nn.Sequential(
            nn.Conv2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )
        self.fc_mu = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_var = nn.Linear(512 * 4 * 4, latent_dim)

    def forward(self, x, y):
        y = F.one_hot(y, num_classes).float()
        y = y.view(-1, num_classes, 1, 1).expand(-1, -1, x.size(2), x.size(3))
        x = torch.cat([x, y], dim=1)
        h0 = self.initial_conv(x)
        h1 = self.down1(h0)
        h2 = self.down2(h1)
        h3 = self.down3(h2)
        h4 = self.down4(h3)
        h4 = h4.view(h4.size(0), -1)
        return self.fc_mu(h4), self.fc_var(h4), (h0, h1, h2, h3)

# Enhanced Decoder
class EnhancedDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 512 * 4 * 4),
            nn.BatchNorm1d(512 * 4 * 4),
            nn.LeakyReLU(0.2)
        )
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(1024, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2)
        )
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2)
        )
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, z, y, skip_connections):
        h0, h1, h2, h3 = skip_connections
        y = F.one_hot(y, num_classes).float()
        z = torch.cat([z, y], dim=1)
        h = self.fc(z)
        h = h.view(-1, 512, 4, 4)
        h = self.up1(h)
        h = torch.cat([h, h3], dim=1)
        h = self.up2(h)
        h = torch.cat([h, h2], dim=1)
        h = self.up3(h)
        h = torch.cat([h, h1], dim=1)
        h = self.up4(h)
        h = torch.cat([h, h0], dim=1)
        return self.final(h)

# Perceptual Loss
class PerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:16].to(device).eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.criterion = nn.L1Loss()
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, target):
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        input_features = self.vgg(input)
        target_features = self.vgg(target)
        return self.criterion(input_features, target_features)

# CVAE Model
class ConditionalVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = EnhancedEncoder().to(device)
        self.decoder = EnhancedDecoder().to(device)
        self.perceptual_loss = PerceptualLoss().to(device)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, y):
        mu, logvar, skips = self.encoder(x, y)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z, y, skips)
        return x_recon, mu, logvar

    def loss_function(self, x, x_recon, mu, logvar, beta=1.0):
        recon_loss = F.l1_loss(x_recon, x, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        percep_loss = self.perceptual_loss(x_recon, x) * perceptual_weight
        total_loss = recon_weight * recon_loss + beta * kl_loss + percep_loss
        return total_loss, recon_loss, kl_loss, percep_loss

# Training function
def train():
    cvae = ConditionalVAE().to(device)
    optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=50, factor=0.5)
    
    for epoch in range(epochs):
        beta = min(beta_max, beta_max * (epoch / annealing_epochs))
        total_loss = 0
        recon_loss = 0
        kl_loss = 0
        percep_loss = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        for batch_idx, (data, labels) in enumerate(pbar):
            data, labels = data.to(device), labels.to(device)
            
            optimizer.zero_grad()
            x_recon, mu, logvar = cvae(data, labels)
            loss, r_loss, k_loss, p_loss = cvae.loss_function(data, x_recon, mu, logvar, beta)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(cvae.parameters(), 1.0)
            optimizer.step()
            
            total_loss += loss.item()
            recon_loss += r_loss.item()
            kl_loss += k_loss.item()
            percep_loss += p_loss.item()
            
            pbar.set_postfix({
                'Loss': total_loss/(batch_idx+1),
                'Recon': recon_loss/(batch_idx+1),
                'KL': kl_loss/(batch_idx+1),
                'Percep': percep_loss/(batch_idx+1),
                'Beta': beta
            })
        
        avg_loss = total_loss / len(train_loader)
        scheduler.step(avg_loss)
        
        if (epoch + 1) % 50 == 0:
            checkpoint_path = os.path.join(output_dir, "checkpoints", f"cvae_epoch_{epoch+1}.pth")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': cvae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, checkpoint_path)
            generate_samples(cvae, epoch + 1)
    
    final_path = os.path.join(output_dir, "cvae_final.pth")
    torch.save(cvae.state_dict(), final_path)

# Sample generation function
def generate_samples(model, epoch, num_samples=100):
    model.eval()
    os.makedirs(os.path.join(output_dir, "samples", f"epoch_{epoch}"), exist_ok=True)
    
    with torch.no_grad():
        for class_idx in range(num_classes):
            z = torch.randn(num_samples, latent_dim).to(device)
            labels = torch.full((num_samples,), class_idx, dtype=torch.long).to(device)
            dummy_skips = (
                torch.zeros(num_samples, 64, 64, 64).to(device),
                torch.zeros(num_samples, 128, 32, 32).to(device),
                torch.zeros(num_samples, 256, 16, 16).to(device),
                torch.zeros(num_samples, 512, 8, 8).to(device)
            )
            samples = model.decoder(z, labels, dummy_skips)
            samples = (samples + 1) / 2
            
            class_dir = os.path.join(output_dir, "samples", f"epoch_{epoch}", f"class_{class_idx}")
            os.makedirs(class_dir, exist_ok=True)
            
            for i in range(num_samples):
                save_image(samples[i], os.path.join(class_dir, f"sample_{i}.png"))
    
    print(f"Generated samples for epoch {epoch}")

if __name__ == "__main__":
    print("Starting training...")
    train()
    print("Training complete!")

Using device: cuda
Dataset path: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Directory contents: ['Vehicle Type Image Dataset (Version 2) VTID2']
Contents of Vehicle Type Image Dataset (Version 2) VTID2: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Scanning dataset directory: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
No images found in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1\Vehicle Type Image Dataset (Version 2) VTID2
Found 0 images across 1 classes
Class mapping: {'Vehicle Type Image Dataset (Version 2) VTID2': 0}


ValueError: No valid images found in the dataset. Check the directory structure and file formats.