In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.models as models
from torchvision.models import VGG16_Weights
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image

# Device configuration: use the second GPU when more than one is present,
# otherwise the default GPU, and fall back to the CPU when CUDA is unavailable
# (previously a CPU-only machine would select 'cuda' and crash on first use).
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')
else:
    device = torch.device('cpu')

# Hyperparameters
latent_dim = 256          # size of the latent code z
intermediate_dim = 1024   # width of the FC layer between conv features and latent
num_classes = 5           # one-hot label width used by encoder and decoder
batch_size = 64
epochs = 3000
learning_rate = 3e-4
image_size = 128          # images are resized to image_size x image_size
channels = 3              # image channels (RGB)
output_dir = "FL_CVAE"
beta_max = 1.0            # KL weight reached at the end of linear annealing
annealing_epochs = 100    # epochs over which beta ramps linearly from 0 to beta_max
recon_weight = 0.5        # weight of the summed pixel MSE term
perceptual_weight = 5.0   # weight of the VGG perceptual term
save_freq = 50            # checkpoint / sample-generation interval, in epochs
weight_decay = 1e-5  # Add weight decay for regularization
patience = 50  # Early stopping patience
min_delta = 0.001  # Minimum improvement for early stopping

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# Download Vehicle Type Image Dataset from Kaggle (re-raise after reporting
# so the notebook stops instead of continuing without data).
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class
class VehicleTypeDataset(torch.utils.data.Dataset):
    """Image-folder dataset where every directory that contains images is a class.

    ``root_dir`` is walked recursively. A directory's base name is its class
    name, so same-named directories in different places share one label.
    Class indices are assigned in order of first appearance during a walk whose
    directories and files are visited in case-insensitive sorted order. This
    makes the label mapping identical on every OS/filesystem: ``os.walk`` order
    is otherwise filesystem-dependent.

    Items are ``(image, label)`` pairs; ``image`` is an RGB PIL image passed
    through ``transform`` when one is given.
    """

    def __init__(self, root_dir, transform=None):
        """Index all .jpg/.png/.jpeg files under ``root_dir``.

        Raises:
            ValueError: if no image files are found.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        def order(name):
            # Case-insensitive order, with the raw name as a deterministic tie-break.
            return (name.casefold(), name)

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            # Sorting `dirs` in place controls the order os.walk descends.
            dirs.sort(key=order)
            image_files = sorted(
                (f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))),
                key=order,
            )
            if not image_files:
                continue
            class_name = os.path.basename(root)
            if class_name not in self.class_to_idx:
                self.class_names.append(class_name)
                self.class_to_idx[class_name] = len(self.class_names) - 1
            label = self.class_to_idx[class_name]
            for img_file in image_files:
                self.images.append(os.path.join(root, img_file))
                self.labels.append(label)

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform``, and return it with its label."""
        img_path = self.images[idx]
        label = self.labels[idx]
        # Context manager releases the file handle right away; convert() returns
        # a new, fully loaded image, so it stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms: resize to a fixed square, then convert to a float tensor
# (ToTensor scales 8-bit pixel values into [0, 1]).
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform)
# The models one-hot encode labels with the fixed `num_classes`. Fail fast on a
# mismatch instead of erroring mid-training (more classes found) or at
# sample-generation time (fewer classes found).
if len(dataset.class_names) != num_classes:
    raise ValueError(
        f"num_classes={num_classes} but the dataset has "
        f"{len(dataset.class_names)} classes: {dataset.class_names}"
    )
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the Enhanced Encoder network
class EnhancedEncoder(nn.Module):
    """Conditional convolutional encoder for 128x128 images.

    The one-hot class label is broadcast to constant feature planes and
    concatenated to the image channels. Six stride-2 convolutions then halve
    the spatial size each time (128 -> 2). The flattened 1024x2x2 features feed
    a fully connected layer and two linear heads for the latent mean and
    log-variance. All six intermediate activations are also returned for the
    decoder's skip connections.
    """

    def __init__(self, latent_dim, intermediate_dim, num_classes, in_channels=None):
        """Build the layers.

        Args:
            latent_dim: size of the latent code.
            intermediate_dim: width of the fully connected layer before the heads.
            num_classes: number of classes; sets the one-hot width used by the
                first convolution and by ``forward``.
            in_channels: number of image channels. ``None`` (default) uses the
                module-level ``channels`` setting.
        """
        super(EnhancedEncoder, self).__init__()
        if in_channels is None:
            in_channels = channels
        # Kept on the instance so forward() one-hot encodes with exactly the
        # width conv1 was built for, instead of reading a module global.
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(in_channels + num_classes, 64, kernel_size=4, stride=2, padding=1)  # 64x64x64
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # 128x32x32
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)  # 256x16x16
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)  # 512x8x8
        self.bn4 = nn.BatchNorm2d(512)
        self.conv5 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1)  # 512x4x4
        self.bn5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1)  # 1024x2x2
        self.bn6 = nn.BatchNorm2d(1024)

        self.fc_intermediate = nn.Linear(1024 * 2 * 2, intermediate_dim)
        self.bn_intermediate = nn.BatchNorm1d(intermediate_dim)

        self.fc_mean = nn.Linear(intermediate_dim, latent_dim)
        self.fc_logvar = nn.Linear(intermediate_dim, latent_dim)

    def forward(self, x, y):
        """Encode images conditioned on integer class labels.

        Args:
            x: image batch (N, C, 128, 128). The FC input size (1024*2*2) fixes
                the 128x128 resolution.
            y: integer labels of shape (N,), each in ``[0, num_classes)``.

        Returns:
            ``(z_mean, z_logvar, (h1, ..., h6))``. The first two have shape
            (N, latent_dim); the tuple holds the six conv activations.
        """
        # Broadcast the one-hot label to (N, num_classes, H, W) planes.
        y = F.one_hot(y, num_classes=self.num_classes).to(x.dtype)
        y = y[:, :, None, None].expand(-1, -1, x.size(2), x.size(3))
        x_with_y = torch.cat([x, y], dim=1)

        h1 = F.leaky_relu(self.bn1(self.conv1(x_with_y)), negative_slope=0.2)
        h2 = F.leaky_relu(self.bn2(self.conv2(h1)), negative_slope=0.2)
        h3 = F.leaky_relu(self.bn3(self.conv3(h2)), negative_slope=0.2)
        h4 = F.leaky_relu(self.bn4(self.conv4(h3)), negative_slope=0.2)
        h5 = F.leaky_relu(self.bn5(self.conv5(h4)), negative_slope=0.2)
        h6 = F.leaky_relu(self.bn6(self.conv6(h5)), negative_slope=0.2)

        h = torch.flatten(h6, start_dim=1)
        h = F.leaky_relu(self.bn_intermediate(self.fc_intermediate(h)), negative_slope=0.2)

        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar, (h1, h2, h3, h4, h5, h6)

# Define the Enhanced Decoder network
class EnhancedDecoder(nn.Module):
    """Conditional decoder with encoder skip connections, producing 128x128 images.

    The latent code and one-hot label are mapped to a 1024x2x2 tensor and then
    upsampled by six stride-2 transposed convolutions. After each of the first
    five, the matching encoder activation (h5 down to h1) is concatenated along
    the channel axis, which is why the later layers' input widths are doubled.

    NOTE(review): the skip connections hand the decoder the encoder's image
    features directly. The decoder can therefore likely reconstruct without
    relying much on z (the training log shows a small KL term). At sampling
    time no encoder features exist, so callers must supply substitutes.
    """

    def __init__(self, latent_dim, intermediate_dim, num_classes, out_channels=None):
        """Build the layers.

        Args:
            latent_dim: size of the latent code.
            intermediate_dim: width of the first fully connected layer.
            num_classes: number of classes; sets the one-hot width used by
                ``fc`` and by ``forward``.
            out_channels: channels of the output image. ``None`` (default) uses
                the module-level ``channels`` setting.
        """
        super(EnhancedDecoder, self).__init__()
        if out_channels is None:
            out_channels = channels
        # Kept on the instance so forward() one-hot encodes with the same width
        # `fc` was built for, instead of reading a module global.
        self.num_classes = num_classes

        self.fc = nn.Linear(latent_dim + num_classes, intermediate_dim)
        self.bn_fc = nn.BatchNorm1d(intermediate_dim)
        self.fc_to_features = nn.Linear(intermediate_dim, 1024 * 2 * 2)

        self.deconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)  # 512x4x4
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)  # 512x8x8
        self.bn2 = nn.BatchNorm2d(512)
        self.deconv3 = nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2, padding=1)  # 256x16x16
        self.bn3 = nn.BatchNorm2d(256)
        self.deconv4 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)  # 128x32x32
        self.bn4 = nn.BatchNorm2d(128)
        self.deconv5 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)  # 64x64x64
        self.bn5 = nn.BatchNorm2d(64)
        self.deconv6 = nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1)  # Cx128x128

    def forward(self, z, y, skip_connections):
        """Decode latent codes conditioned on integer labels.

        Args:
            z: latent codes (N, latent_dim).
            y: integer labels (N,), each in ``[0, num_classes)``.
            skip_connections: sequence ``(h1, ..., h6)`` shaped like the
                encoder's activations; h6 is accepted but not used.

        Returns:
            Images of shape (N, out_channels, 128, 128) with values in [0, 1].
        """
        h1, h2, h3, h4, h5, _h6 = skip_connections
        y = F.one_hot(y, num_classes=self.num_classes).to(z.dtype)
        z_with_y = torch.cat([z, y], dim=-1)

        h = F.leaky_relu(self.bn_fc(self.fc(z_with_y)), negative_slope=0.2)
        h = self.fc_to_features(h)
        h = h.view(h.size(0), 1024, 2, 2)

        h = F.leaky_relu(self.bn1(self.deconv1(h)), negative_slope=0.2)
        h = torch.cat([h, h5], dim=1)
        h = F.leaky_relu(self.bn2(self.deconv2(h)), negative_slope=0.2)
        h = torch.cat([h, h4], dim=1)
        h = F.leaky_relu(self.bn3(self.deconv3(h)), negative_slope=0.2)
        h = torch.cat([h, h3], dim=1)
        h = F.leaky_relu(self.bn4(self.deconv4(h)), negative_slope=0.2)
        h = torch.cat([h, h2], dim=1)
        h = F.leaky_relu(self.bn5(self.deconv5(h)), negative_slope=0.2)
        h = torch.cat([h, h1], dim=1)
        x_reconstructed = self.deconv6(h)
        # NOTE: clamp keeps outputs in [0, 1] but gives zero gradient for values
        # outside that range; kept as-is so existing checkpoints behave the same.
        return torch.clamp(x_reconstructed, 0, 1)

# Conditional VAE model
class ConditionalVAE(nn.Module):
    """Combine a conditional encoder and decoder into one VAE.

    The encoder is called as ``encoder(x, y)`` and must return
    ``(z_mean, z_logvar, skips)``; the decoder is called as
    ``decoder(z, y, skips)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        """Sample z = mean + sigma * eps, with sigma = exp(logvar / 2) and eps ~ N(0, I)."""
        sigma = (0.5 * z_logvar).exp()
        return z_mean + torch.randn_like(sigma) * sigma

    def forward(self, x, y):
        """Return ``(reconstruction, z_mean, z_logvar)`` for inputs ``x`` with labels ``y``."""
        mean, logvar, skips = self.encoder(x, y)
        latent = self.reparameterize(mean, logvar)
        return self.decoder(latent, y, skips), mean, logvar

# Load pretrained VGG16 for perceptual loss: keep only the convolutional
# "features" stack, move it to the device, and freeze it (eval mode, no grads).
vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
vgg = vgg.to(device).eval()
vgg.requires_grad_(False)

def perceptual_loss(x, x_reconstructed, feature_extractor=None):
    """Return the MSE between feature maps of target and reconstructed images.

    Both batches are normalized with the ImageNet per-channel mean/std expected
    by the pretrained VGG16 before feature extraction. This assumes 3-channel
    inputs scaled to [0, 1] (the ToTensor range used in this notebook).

    Args:
        x: target images, (N, 3, H, W).
        x_reconstructed: reconstructions with the same shape as ``x``.
        feature_extractor: callable mapping a normalized batch to features.
            ``None`` (default) uses the module-level frozen ``vgg``.

    Returns:
        Scalar tensor: mean squared error over all feature elements.
    """
    if feature_extractor is None:
        feature_extractor = vgg
    # Build the constants with the inputs' device and dtype so no implicit
    # dtype promotion happens (e.g. under half precision).
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
    x_features = feature_extractor((x - mean) / std)
    x_recon_features = feature_extractor((x_reconstructed - mean) / std)
    return F.mse_loss(x_features, x_recon_features)

# Instantiate the model: encoder and decoder are built separately, wrapped in
# the CVAE container, and placed on the training device.
encoder = EnhancedEncoder(latent_dim, intermediate_dim, num_classes).to(device)
decoder = EnhancedDecoder(latent_dim, intermediate_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Adam with weight decay; the learning rate is halved whenever the monitored
# (minimized) loss fails to improve for 20 consecutive scheduler steps.
optimizer = optim.Adam(params=cvae.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20)

# Define loss function
def cvae_loss(x, x_reconstructed, z_mean, z_logvar, beta=1.0, recon_weight=1.0, perceptual_weight=1.0,
              perceptual_fn=None):
    """Compute the weighted CVAE objective and its components.

    total = recon_weight * recon + beta * kl + perceptual_weight * perceptual

    ``recon`` and ``kl`` are summed over the whole batch. The perceptual term is
    whatever ``perceptual_fn`` returns (a mean for ``perceptual_loss``).

    Args:
        perceptual_fn: callable ``(x, x_reconstructed) -> scalar tensor``.
            ``None`` (default) uses the module-level ``perceptual_loss``.

    Returns:
        ``(total_loss, recon_loss, kl_loss, percep_loss)``; ``percep_loss``
        already includes ``perceptual_weight``.
    """
    if perceptual_fn is None:
        perceptual_fn = perceptual_loss
    recon_loss = F.mse_loss(x_reconstructed, x, reduction='sum')
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and latent
    # dims. (The former "+ 1e-6" only offset the value and protected nothing.)
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    percep_loss = perceptual_fn(x, x_reconstructed) * perceptual_weight
    total_loss = recon_weight * recon_loss + beta * kl_loss + percep_loss
    return total_loss, recon_loss, kl_loss, percep_loss

# Generate samples for all classes
def generate_samples_labelwise(cvae, num_samples, classes_to_generate, base_dir, latent_dim, device, epoch):
    """Decode random latent codes for each class and save every image as a PNG.

    For each label in ``classes_to_generate``, ``num_samples`` codes are drawn
    from N(0, I) and decoded with that label. Images are written to
    ``{base_dir}/class_{label}/sample_{idx}_epoch_{epoch}.png``.

    The model runs in eval mode for the duration of the call, and its previous
    train/eval mode is restored on exit. Calling this from inside a training
    loop therefore no longer leaves BatchNorm stuck in eval mode for later epochs.

    NOTE(review): the decoder was trained with encoder activations as skip
    connections. None exist when sampling from the prior, so small Gaussian
    noise tensors (scaled by 0.1) are substituted. The images therefore do not
    depend on the latent code and label alone.
    """
    # Per-sample shapes of the encoder activations h1..h6 for 128x128 inputs.
    skip_shapes = [
        (64, 64, 64),
        (128, 32, 32),
        (256, 16, 16),
        (512, 8, 8),
        (512, 4, 4),
        (1024, 2, 2),
    ]
    was_training = cvae.training
    cvae.eval()
    os.makedirs(base_dir, exist_ok=True)
    try:
        with torch.no_grad():
            for class_label in classes_to_generate:
                label_tensor = torch.full((num_samples,), class_label, dtype=torch.long, device=device)
                # Allocate directly on the target device instead of CPU + copy.
                z = torch.randn(num_samples, latent_dim, device=device)
                dummy_skips = [torch.randn(num_samples, *shape, device=device) * 0.1 for shape in skip_shapes]
                generated_samples = cvae.decoder(z, label_tensor, dummy_skips)

                class_dir = os.path.join(base_dir, f"class_{class_label}")
                os.makedirs(class_dir, exist_ok=True)
                for idx, sample in enumerate(generated_samples):
                    save_image(sample, os.path.join(class_dir, f"sample_{idx}_epoch_{epoch}.png"))
                print(f"Epoch {epoch}: Generated {num_samples} samples for Class {class_label} ({dataset.class_names[class_label]}).")
    finally:
        # Restore whatever mode the caller had the model in.
        cvae.train(was_training)

# Plot random samples for all classes
def plot_random_samples(base_dir, classes_to_generate, num_images_per_class, epoch):
    """Save a grid of randomly chosen generated images, one row per class.

    Reads the PNGs in ``{base_dir}/class_{label}/`` whose names end with
    ``epoch_{epoch}.png`` and picks ``num_images_per_class`` of them at random
    (without replacement). The grid is saved to
    ``{output_dir}/synthetic_samples_epoch_{epoch}.png``. A class with fewer
    matching files than requested leaves its row blank.
    """
    n_rows = len(classes_to_generate)
    # squeeze=False keeps `axs` two-dimensional even with a single row or
    # column, so axs[row, col] indexing always works.
    fig, axs = plt.subplots(
        n_rows, num_images_per_class, figsize=(20, n_rows * 2), squeeze=False
    )
    for row, class_label in enumerate(classes_to_generate):
        class_dir = os.path.join(base_dir, f"class_{class_label}")
        sample_files = [f for f in os.listdir(class_dir) if f.endswith(f"epoch_{epoch}.png")]
        if len(sample_files) < num_images_per_class:
            continue
        random_samples = np.random.choice(sample_files, num_images_per_class, replace=False)
        for col, sample_file in enumerate(random_samples):
            axs[row, col].imshow(plt.imread(os.path.join(class_dir, sample_file)))
        axs[row, 0].set_ylabel(dataset.class_names[class_label], rotation=90, labelpad=10)

    # Hide ticks and frames but keep every axis "on": ax.axis('off') would
    # also hide the y-label that shows the class name in the first column.
    for ax in axs.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    fig.tight_layout()
    plot_path = os.path.join(output_dir, f"synthetic_samples_epoch_{epoch}.png")
    fig.savefig(plot_path)
    plt.close(fig)
    print(f"Saved synthetic samples plot for epoch {epoch} to {plot_path}")

# Training loop with early stopping
best_loss = float('inf')
patience_counter = 0
# Defined before the loop so the final generation below also works when
# training stops before the first periodic checkpoint.
classes_to_generate = list(range(num_classes))
last_epoch = 0  # 1-based number of the last epoch that finished training
num_train_samples = len(train_loader.dataset)

for epoch in range(epochs):
    # Periodic sample generation switches the model to eval mode, so re-enter
    # train mode (BatchNorm batch statistics) at the start of every epoch.
    cvae.train()

    # Linear KL annealing from 0 to beta_max over the first annealing_epochs.
    if epoch < annealing_epochs:
        beta = beta_max * (epoch / annealing_epochs)
    else:
        beta = beta_max

    train_loss = 0.0
    train_recon_loss = 0.0
    train_kl_loss = 0.0
    train_percep_loss = 0.0
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(
            data, x_reconstructed, z_mean, z_logvar,
            beta=beta, recon_weight=recon_weight, perceptual_weight=perceptual_weight
        )
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
        optimizer.step()

        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        # percep_loss is a per-batch mean (recon/KL are batch sums); weight it by
        # the batch size so the per-sample average logged below is comparable.
        train_percep_loss += percep_loss.item() * data.size(0)

    last_epoch = epoch + 1
    avg_loss = train_loss / num_train_samples
    scheduler.step(avg_loss)
    avg_recon_loss = train_recon_loss / num_train_samples
    avg_kl_loss = train_kl_loss / num_train_samples
    avg_percep_loss = train_percep_loss / num_train_samples
    print(f'Epoch {last_epoch}/{epochs}, Beta: {beta:.2f}, Total Loss: {avg_loss:.4f}, '
          f'Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, Percep Loss: {avg_percep_loss:.4f}')

    # Early stopping
    if avg_loss < best_loss - min_delta:
        best_loss = avg_loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {last_epoch} due to no improvement in loss.")
            break

    # Save checkpoint and generate samples every `save_freq` epochs
    if last_epoch % save_freq == 0:
        checkpoint_path = os.path.join(output_dir, f"cvae_vehicle_epoch_{last_epoch}.pth")
        torch.save(cvae.state_dict(), checkpoint_path)
        print(f"Saved checkpoint at epoch {last_epoch} to {checkpoint_path}")

        base_dir = os.path.join(output_dir, f"generated_samples_epoch_{last_epoch}")
        generate_samples_labelwise(cvae, num_samples=500, classes_to_generate=classes_to_generate,
                                   base_dir=base_dir, latent_dim=latent_dim, device=device, epoch=last_epoch)

        plot_random_samples(base_dir=base_dir, classes_to_generate=classes_to_generate,
                            num_images_per_class=10, epoch=last_epoch)

# Save the final model
model_path = os.path.join(output_dir, "cvae_vehicle_final.pth")
torch.save(cvae.state_dict(), model_path)
print(f"Saved final CVAE model to {model_path}")

# Final generation and plotting, labelled with the last epoch actually trained
# (which differs from `epochs` after early stopping).
base_dir = os.path.join(output_dir, "generated_samples_final")
generate_samples_labelwise(cvae, num_samples=500, classes_to_generate=classes_to_generate,
                           base_dir=base_dir, latent_dim=latent_dim, device=device, epoch=last_epoch)
plot_random_samples(base_dir=base_dir, classes_to_generate=classes_to_generate,
                    num_images_per_class=10, epoch=last_epoch)

Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Epoch 1/3000, Beta: 0.00, Total Loss: 1234.7925, Recon Loss: 2469.1533, KL Loss: 51425368689866216.0000, Percep Loss: 0.2159
Epoch 2/3000, Beta: 0.01, Total Loss: 248366.8210, Recon Loss: 678.2645, KL Loss: 24802751.1481, Percep Loss: 0.1792
Epoch 3/3000, Beta: 0.02, Total Loss: 199.4490, Recon Loss: 396.2839, KL Loss: 57.2738, Percep Loss: 0.1616
Epoch 4/3000, Beta: 0.03, Total Loss: 138.1492, Recon Loss: 273.6968, KL Loss: 38.4055, Percep Loss: 0.1487
Epoch 5/3000, Beta: 0.04, Total Loss: 115.7033, Recon Loss: 228.4818, KL Loss: 33.0613, Percep Loss: 0.1399
Epoch 6/3000, Beta: 0.05, Total Loss: 89.1474, Recon Loss: 175.5803, KL Loss: 24.6396, Percep Loss: 0.1253
Epoch 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.models as models
from torchvision.models import VGG16_Weights
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image

# Device configuration: use the second GPU when more than one is present,
# otherwise the default GPU, and fall back to the CPU when CUDA is unavailable
# (previously a CPU-only machine would select 'cuda' and crash on first use).
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')
else:
    device = torch.device('cpu')

# Hyperparameters
latent_dim = 256          # size of the latent code z
intermediate_dim = 1024   # width of the FC layer between conv features and latent
num_classes = 5           # one-hot label width used by encoder and decoder
batch_size = 64
epochs = 3000
learning_rate = 3e-4
image_size = 128          # images are resized to image_size x image_size
channels = 3              # image channels (RGB)
output_dir = "FL_CVAE"
beta_max = 1.0            # KL weight reached at the end of linear annealing
annealing_epochs = 100    # epochs over which beta ramps linearly from 0 to beta_max
recon_weight = 0.5        # weight of the summed pixel MSE term
perceptual_weight = 5.0   # weight of the VGG perceptual term
weight_decay = 1e-5       # Adam weight decay
patience = 50             # early-stopping patience, in epochs
min_delta = 0.001         # minimum loss improvement that resets patience

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# Download Vehicle Type Image Dataset from Kaggle (re-raise after reporting
# so the notebook stops instead of continuing without data).
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class
class VehicleTypeDataset(torch.utils.data.Dataset):
    """Image-folder dataset where every directory that contains images is a class.

    ``root_dir`` is walked recursively. A directory's base name is its class
    name, so same-named directories in different places share one label.
    Class indices are assigned in order of first appearance during a walk whose
    directories and files are visited in case-insensitive sorted order. This
    makes the label mapping identical on every OS/filesystem: ``os.walk`` order
    is otherwise filesystem-dependent.

    Items are ``(image, label)`` pairs; ``image`` is an RGB PIL image passed
    through ``transform`` when one is given.
    """

    def __init__(self, root_dir, transform=None):
        """Index all .jpg/.png/.jpeg files under ``root_dir``.

        Raises:
            ValueError: if no image files are found.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        def order(name):
            # Case-insensitive order, with the raw name as a deterministic tie-break.
            return (name.casefold(), name)

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            # Sorting `dirs` in place controls the order os.walk descends.
            dirs.sort(key=order)
            image_files = sorted(
                (f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))),
                key=order,
            )
            if not image_files:
                continue
            class_name = os.path.basename(root)
            if class_name not in self.class_to_idx:
                self.class_names.append(class_name)
                self.class_to_idx[class_name] = len(self.class_names) - 1
            label = self.class_to_idx[class_name]
            for img_file in image_files:
                self.images.append(os.path.join(root, img_file))
                self.labels.append(label)

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform``, and return it with its label."""
        img_path = self.images[idx]
        label = self.labels[idx]
        # Context manager releases the file handle right away; convert() returns
        # a new, fully loaded image, so it stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms: resize to a fixed square, then convert to a float tensor
# (ToTensor scales 8-bit pixel values into [0, 1]).
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform)
# The models one-hot encode labels with the fixed `num_classes`. Fail fast on a
# mismatch instead of erroring mid-training (more classes found) or at
# sample-generation time (fewer classes found).
if len(dataset.class_names) != num_classes:
    raise ValueError(
        f"num_classes={num_classes} but the dataset has "
        f"{len(dataset.class_names)} classes: {dataset.class_names}"
    )
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the Enhanced Encoder network
class EnhancedEncoder(nn.Module):
    """Conditional convolutional encoder for 128x128 images.

    The one-hot class label is broadcast to constant feature planes and
    concatenated to the image channels. Six stride-2 convolutions then halve
    the spatial size each time (128 -> 2). The flattened 1024x2x2 features feed
    a fully connected layer and two linear heads for the latent mean and
    log-variance. All six intermediate activations are also returned for the
    decoder's skip connections.
    """

    def __init__(self, latent_dim, intermediate_dim, num_classes, in_channels=None):
        """Build the layers.

        Args:
            latent_dim: size of the latent code.
            intermediate_dim: width of the fully connected layer before the heads.
            num_classes: number of classes; sets the one-hot width used by the
                first convolution and by ``forward``.
            in_channels: number of image channels. ``None`` (default) uses the
                module-level ``channels`` setting.
        """
        super(EnhancedEncoder, self).__init__()
        if in_channels is None:
            in_channels = channels
        # Kept on the instance so forward() one-hot encodes with exactly the
        # width conv1 was built for, instead of reading a module global.
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(in_channels + num_classes, 64, kernel_size=4, stride=2, padding=1)  # 64x64x64
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # 128x32x32
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)  # 256x16x16
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)  # 512x8x8
        self.bn4 = nn.BatchNorm2d(512)
        self.conv5 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1)  # 512x4x4
        self.bn5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1)  # 1024x2x2
        self.bn6 = nn.BatchNorm2d(1024)

        self.fc_intermediate = nn.Linear(1024 * 2 * 2, intermediate_dim)
        self.bn_intermediate = nn.BatchNorm1d(intermediate_dim)

        self.fc_mean = nn.Linear(intermediate_dim, latent_dim)
        self.fc_logvar = nn.Linear(intermediate_dim, latent_dim)

    def forward(self, x, y):
        """Encode images conditioned on integer class labels.

        Args:
            x: image batch (N, C, 128, 128). The FC input size (1024*2*2) fixes
                the 128x128 resolution.
            y: integer labels of shape (N,), each in ``[0, num_classes)``.

        Returns:
            ``(z_mean, z_logvar, (h1, ..., h6))``. The first two have shape
            (N, latent_dim); the tuple holds the six conv activations.
        """
        # Broadcast the one-hot label to (N, num_classes, H, W) planes.
        y = F.one_hot(y, num_classes=self.num_classes).to(x.dtype)
        y = y[:, :, None, None].expand(-1, -1, x.size(2), x.size(3))
        x_with_y = torch.cat([x, y], dim=1)

        h1 = F.leaky_relu(self.bn1(self.conv1(x_with_y)), negative_slope=0.2)
        h2 = F.leaky_relu(self.bn2(self.conv2(h1)), negative_slope=0.2)
        h3 = F.leaky_relu(self.bn3(self.conv3(h2)), negative_slope=0.2)
        h4 = F.leaky_relu(self.bn4(self.conv4(h3)), negative_slope=0.2)
        h5 = F.leaky_relu(self.bn5(self.conv5(h4)), negative_slope=0.2)
        h6 = F.leaky_relu(self.bn6(self.conv6(h5)), negative_slope=0.2)

        h = torch.flatten(h6, start_dim=1)
        h = F.leaky_relu(self.bn_intermediate(self.fc_intermediate(h)), negative_slope=0.2)

        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar, (h1, h2, h3, h4, h5, h6)

# Define the Enhanced Decoder network
class EnhancedDecoder(nn.Module):
    """Conditional decoder with encoder skip connections, producing 128x128 images.

    The latent code and one-hot label are mapped to a 1024x2x2 tensor and then
    upsampled by six stride-2 transposed convolutions. After each of the first
    five, the matching encoder activation (h5 down to h1) is concatenated along
    the channel axis, which is why the later layers' input widths are doubled.

    NOTE(review): the skip connections hand the decoder the encoder's image
    features directly. The decoder can therefore likely reconstruct without
    relying much on z (the training log shows a small KL term). At sampling
    time no encoder features exist, so callers must supply substitutes.
    """

    def __init__(self, latent_dim, intermediate_dim, num_classes, out_channels=None):
        """Build the layers.

        Args:
            latent_dim: size of the latent code.
            intermediate_dim: width of the first fully connected layer.
            num_classes: number of classes; sets the one-hot width used by
                ``fc`` and by ``forward``.
            out_channels: channels of the output image. ``None`` (default) uses
                the module-level ``channels`` setting.
        """
        super(EnhancedDecoder, self).__init__()
        if out_channels is None:
            out_channels = channels
        # Kept on the instance so forward() one-hot encodes with the same width
        # `fc` was built for, instead of reading a module global.
        self.num_classes = num_classes

        self.fc = nn.Linear(latent_dim + num_classes, intermediate_dim)
        self.bn_fc = nn.BatchNorm1d(intermediate_dim)
        self.fc_to_features = nn.Linear(intermediate_dim, 1024 * 2 * 2)

        self.deconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)  # 512x4x4
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)  # 512x8x8
        self.bn2 = nn.BatchNorm2d(512)
        self.deconv3 = nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2, padding=1)  # 256x16x16
        self.bn3 = nn.BatchNorm2d(256)
        self.deconv4 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)  # 128x32x32
        self.bn4 = nn.BatchNorm2d(128)
        self.deconv5 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)  # 64x64x64
        self.bn5 = nn.BatchNorm2d(64)
        self.deconv6 = nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1)  # Cx128x128

    def forward(self, z, y, skip_connections):
        """Decode latent codes conditioned on integer labels.

        Args:
            z: latent codes (N, latent_dim).
            y: integer labels (N,), each in ``[0, num_classes)``.
            skip_connections: sequence ``(h1, ..., h6)`` shaped like the
                encoder's activations; h6 is accepted but not used.

        Returns:
            Images of shape (N, out_channels, 128, 128) with values in [0, 1].
        """
        h1, h2, h3, h4, h5, _h6 = skip_connections
        y = F.one_hot(y, num_classes=self.num_classes).to(z.dtype)
        z_with_y = torch.cat([z, y], dim=-1)

        h = F.leaky_relu(self.bn_fc(self.fc(z_with_y)), negative_slope=0.2)
        h = self.fc_to_features(h)
        h = h.view(h.size(0), 1024, 2, 2)

        h = F.leaky_relu(self.bn1(self.deconv1(h)), negative_slope=0.2)
        h = torch.cat([h, h5], dim=1)
        h = F.leaky_relu(self.bn2(self.deconv2(h)), negative_slope=0.2)
        h = torch.cat([h, h4], dim=1)
        h = F.leaky_relu(self.bn3(self.deconv3(h)), negative_slope=0.2)
        h = torch.cat([h, h3], dim=1)
        h = F.leaky_relu(self.bn4(self.deconv4(h)), negative_slope=0.2)
        h = torch.cat([h, h2], dim=1)
        h = F.leaky_relu(self.bn5(self.deconv5(h)), negative_slope=0.2)
        h = torch.cat([h, h1], dim=1)
        x_reconstructed = self.deconv6(h)
        # NOTE: clamp keeps outputs in [0, 1] but gives zero gradient for values
        # outside that range; kept as-is so existing checkpoints behave the same.
        return torch.clamp(x_reconstructed, 0, 1)

# Conditional VAE model
class ConditionalVAE(nn.Module):
    """Combine a conditional encoder and decoder into one VAE.

    The encoder is called as ``encoder(x, y)`` and must return
    ``(z_mean, z_logvar, skips)``; the decoder is called as
    ``decoder(z, y, skips)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        """Sample z = mean + sigma * eps, with sigma = exp(logvar / 2) and eps ~ N(0, I)."""
        sigma = (0.5 * z_logvar).exp()
        return z_mean + torch.randn_like(sigma) * sigma

    def forward(self, x, y):
        """Return ``(reconstruction, z_mean, z_logvar)`` for inputs ``x`` with labels ``y``."""
        mean, logvar, skips = self.encoder(x, y)
        latent = self.reparameterize(mean, logvar)
        return self.decoder(latent, y, skips), mean, logvar

# Load pretrained VGG16 for perceptual loss: keep only the convolutional
# "features" stack, move it to the device, and freeze it (eval mode, no grads).
vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
vgg = vgg.to(device).eval()
vgg.requires_grad_(False)

def perceptual_loss(x, x_reconstructed, feature_extractor=None):
    """Return the MSE between feature maps of target and reconstructed images.

    Both batches are normalized with the ImageNet per-channel mean/std expected
    by the pretrained VGG16 before feature extraction. This assumes 3-channel
    inputs scaled to [0, 1] (the ToTensor range used in this notebook).

    Args:
        x: target images, (N, 3, H, W).
        x_reconstructed: reconstructions with the same shape as ``x``.
        feature_extractor: callable mapping a normalized batch to features.
            ``None`` (default) uses the module-level frozen ``vgg``.

    Returns:
        Scalar tensor: mean squared error over all feature elements.
    """
    if feature_extractor is None:
        feature_extractor = vgg
    # Build the constants with the inputs' device and dtype so no implicit
    # dtype promotion happens (e.g. under half precision).
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
    x_features = feature_extractor((x - mean) / std)
    x_recon_features = feature_extractor((x_reconstructed - mean) / std)
    return F.mse_loss(x_features, x_recon_features)

# Instantiate the model
encoder = EnhancedEncoder(latent_dim, intermediate_dim, num_classes).to(device)
decoder = EnhancedDecoder(latent_dim, intermediate_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Load checkpoint: resume from the epoch-50 weights saved by the first run.
checkpoint_path = os.path.join(output_dir, "cvae_vehicle_epoch_50.pth")
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint file {checkpoint_path} not found")
# The checkpoint is a plain state_dict, so weights_only=True suffices. It
# restricts unpickling to tensors and primitive containers, and avoids the
# FutureWarning recent torch versions emit when weights_only is not given.
cvae.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
print(f"Loaded checkpoint from {checkpoint_path}")

# Define optimizer and scheduler.
# NOTE: only model weights were checkpointed, so Adam's moment estimates and
# the LR-scheduler state restart from scratch on resume.
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=20, factor=0.5)

# Define loss function
def cvae_loss(x, x_reconstructed, z_mean, z_logvar, beta=1.0, recon_weight=1.0, perceptual_weight=1.0,
              perceptual_fn=None):
    """Compute the weighted CVAE objective and its components.

    total = recon_weight * recon + beta * kl + perceptual_weight * perceptual

    ``recon`` and ``kl`` are summed over the whole batch. The perceptual term is
    whatever ``perceptual_fn`` returns (a mean for ``perceptual_loss``).

    Args:
        perceptual_fn: callable ``(x, x_reconstructed) -> scalar tensor``.
            ``None`` (default) uses the module-level ``perceptual_loss``.

    Returns:
        ``(total_loss, recon_loss, kl_loss, percep_loss)``; ``percep_loss``
        already includes ``perceptual_weight``.
    """
    if perceptual_fn is None:
        perceptual_fn = perceptual_loss
    recon_loss = F.mse_loss(x_reconstructed, x, reduction='sum')
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and latent
    # dims. (The former "+ 1e-6" only offset the value and protected nothing.)
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    percep_loss = perceptual_fn(x, x_reconstructed) * perceptual_weight
    total_loss = recon_weight * recon_loss + beta * kl_loss + percep_loss
    return total_loss, recon_loss, kl_loss, percep_loss

# Generate samples for all classes
def generate_samples_labelwise(cvae, num_samples, classes_to_generate, base_dir, latent_dim, device, epoch):
    """Decode random latent codes for each class and save every image as a PNG.

    For each label in ``classes_to_generate``, ``num_samples`` codes are drawn
    from N(0, I) and decoded with that label. Images are written to
    ``{base_dir}/class_{label}/sample_{idx}_epoch_{epoch}.png``.

    The model runs in eval mode for the duration of the call, and its previous
    train/eval mode is restored on exit. Calling this from inside a training
    loop therefore no longer leaves BatchNorm stuck in eval mode for later epochs.

    NOTE(review): the decoder was trained with encoder activations as skip
    connections. None exist when sampling from the prior, so small Gaussian
    noise tensors (scaled by 0.1) are substituted. The images therefore do not
    depend on the latent code and label alone.
    """
    # Per-sample shapes of the encoder activations h1..h6 for 128x128 inputs.
    skip_shapes = [
        (64, 64, 64),
        (128, 32, 32),
        (256, 16, 16),
        (512, 8, 8),
        (512, 4, 4),
        (1024, 2, 2),
    ]
    was_training = cvae.training
    cvae.eval()
    os.makedirs(base_dir, exist_ok=True)
    try:
        with torch.no_grad():
            for class_label in classes_to_generate:
                label_tensor = torch.full((num_samples,), class_label, dtype=torch.long, device=device)
                # Allocate directly on the target device instead of CPU + copy.
                z = torch.randn(num_samples, latent_dim, device=device)
                dummy_skips = [torch.randn(num_samples, *shape, device=device) * 0.1 for shape in skip_shapes]
                generated_samples = cvae.decoder(z, label_tensor, dummy_skips)

                class_dir = os.path.join(base_dir, f"class_{class_label}")
                os.makedirs(class_dir, exist_ok=True)
                for idx, sample in enumerate(generated_samples):
                    save_image(sample, os.path.join(class_dir, f"sample_{idx}_epoch_{epoch}.png"))
                print(f"Epoch {epoch}: Generated {num_samples} samples for Class {class_label} ({dataset.class_names[class_label]}).")
    finally:
        # Restore whatever mode the caller had the model in.
        cvae.train(was_training)

# Plot random samples for all classes
def plot_random_samples(base_dir, classes_to_generate, num_images_per_class, epoch):
    """Save a grid of randomly chosen generated images, one row per class.

    Reads the PNGs in ``{base_dir}/class_{label}/`` whose names end with
    ``epoch_{epoch}.png`` and picks ``num_images_per_class`` of them at random
    (without replacement). The grid is saved to
    ``{output_dir}/synthetic_samples_epoch_{epoch}.png``. A class with fewer
    matching files than requested leaves its row blank.
    """
    n_rows = len(classes_to_generate)
    # squeeze=False keeps `axs` two-dimensional even with a single row or
    # column, so axs[row, col] indexing always works.
    fig, axs = plt.subplots(
        n_rows, num_images_per_class, figsize=(20, n_rows * 2), squeeze=False
    )
    for row, class_label in enumerate(classes_to_generate):
        class_dir = os.path.join(base_dir, f"class_{class_label}")
        sample_files = [f for f in os.listdir(class_dir) if f.endswith(f"epoch_{epoch}.png")]
        if len(sample_files) < num_images_per_class:
            continue
        random_samples = np.random.choice(sample_files, num_images_per_class, replace=False)
        for col, sample_file in enumerate(random_samples):
            axs[row, col].imshow(plt.imread(os.path.join(class_dir, sample_file)))
        axs[row, 0].set_ylabel(dataset.class_names[class_label], rotation=90, labelpad=10)

    # Hide ticks and frames but keep every axis "on": ax.axis('off') would
    # also hide the y-label that shows the class name in the first column.
    for ax in axs.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    fig.tight_layout()
    plot_path = os.path.join(output_dir, f"synthetic_samples_epoch_{epoch}.png")
    fig.savefig(plot_path)
    plt.close(fig)
    print(f"Saved synthetic samples plot for epoch {epoch} to {plot_path}")

# Training loop with early stopping, resumed from the epoch-50 checkpoint
best_loss = float('inf')
patience_counter = 0
start_epoch = 50  # Start from epoch 50
classes_to_generate = list(range(num_classes))
save_freq = 50  # periodic checkpoint interval, as in the first training run
# 1-based number of the last completed epoch. Initialized so the final
# generation below works even if the loop body never runs.
last_epoch = start_epoch
num_train_samples = len(train_loader.dataset)

for epoch in range(start_epoch, epochs):
    # Enter train mode every epoch (BatchNorm batch statistics), so nothing
    # that switched the model to eval mode can leak into training.
    cvae.train()

    # Linear KL annealing continues from the resumed epoch.
    if epoch < annealing_epochs:
        beta = beta_max * (epoch / annealing_epochs)
    else:
        beta = beta_max

    train_loss = 0.0
    train_recon_loss = 0.0
    train_kl_loss = 0.0
    train_percep_loss = 0.0
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(
            data, x_reconstructed, z_mean, z_logvar,
            beta=beta, recon_weight=recon_weight, perceptual_weight=perceptual_weight
        )
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
        optimizer.step()

        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        # percep_loss is a per-batch mean (recon/KL are batch sums); weight it by
        # the batch size so the per-sample average logged below is comparable.
        train_percep_loss += percep_loss.item() * data.size(0)

    last_epoch = epoch + 1
    avg_loss = train_loss / num_train_samples
    scheduler.step(avg_loss)
    avg_recon_loss = train_recon_loss / num_train_samples
    avg_kl_loss = train_kl_loss / num_train_samples
    avg_percep_loss = train_percep_loss / num_train_samples
    print(f'Epoch {last_epoch}/{epochs}, Beta: {beta:.2f}, Total Loss: {avg_loss:.4f}, '
          f'Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, Percep Loss: {avg_percep_loss:.4f}')

    # Early stopping. Samples are generated once after the loop; generating
    # here as well duplicated the work and the plot file was overwritten anyway.
    if avg_loss < best_loss - min_delta:
        best_loss = avg_loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {last_epoch} due to no improvement in loss.")
            break

    # Periodic checkpoint so an interruption does not lose the whole run.
    if last_epoch % save_freq == 0:
        checkpoint_path = os.path.join(output_dir, f"cvae_vehicle_epoch_{last_epoch}.pth")
        torch.save(cvae.state_dict(), checkpoint_path)
        print(f"Saved checkpoint at epoch {last_epoch} to {checkpoint_path}")

# Save the final model
model_path = os.path.join(output_dir, "cvae_vehicle_final.pth")
torch.save(cvae.state_dict(), model_path)
print(f"Saved final CVAE model to {model_path}")

# Generate and plot samples at the final epoch (either `epochs` or early stopping)
base_dir = os.path.join(output_dir, "generated_samples_final")
generate_samples_labelwise(cvae, num_samples=500, classes_to_generate=classes_to_generate,
                           base_dir=base_dir, latent_dim=latent_dim, device=device, epoch=last_epoch)
plot_random_samples(base_dir=base_dir, classes_to_generate=classes_to_generate,
                    num_images_per_class=10, epoch=last_epoch)

Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']


  cvae.load_state_dict(torch.load(checkpoint_path, map_location=device))


Loaded checkpoint from FL_CVAE\cvae_vehicle_epoch_50.pth
Epoch 51/3000, Beta: 0.50, Total Loss: 20.6018, Recon Loss: 39.8786, KL Loss: 1.2381, Percep Loss: 0.0434
Epoch 52/3000, Beta: 0.51, Total Loss: 19.7890, Recon Loss: 36.2003, KL Loss: 3.2417, Percep Loss: 0.0356
Epoch 53/3000, Beta: 0.52, Total Loss: 16.5824, Recon Loss: 31.5419, KL Loss: 1.4915, Percep Loss: 0.0358
Epoch 54/3000, Beta: 0.53, Total Loss: 14.6445, Recon Loss: 28.5190, KL Loss: 0.6714, Percep Loss: 0.0292
Epoch 55/3000, Beta: 0.54, Total Loss: 19.5544, Recon Loss: 36.0875, KL Loss: 2.7309, Percep Loss: 0.0360
Epoch 56/3000, Beta: 0.55, Total Loss: 16.9137, Recon Loss: 32.2648, KL Loss: 1.3578, Percep Loss: 0.0345
Epoch 57/3000, Beta: 0.56, Total Loss: 13.5545, Recon Loss: 26.3678, KL Loss: 0.6119, Percep Loss: 0.0279
Epoch 58/3000, Beta: 0.57, Total Loss: 17.3054, Recon Loss: 31.8745, KL Loss: 2.3420, Percep Loss: 0.0332
Epoch 59/3000, Beta: 0.58, Total Loss: 14.9173, Recon Loss: 29.1466, KL Loss: 0.5508, Percep Lo