CVAE with 50 epoches


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 16
num_classes = 5
batch_size = 32
epochs = 50
learning_rate = 1e-3
image_size = 128
channels = 3

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class (from your previous code)
class VehicleTypeDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms (normalize to [0, 1] for CVAE training)
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),  # Converts to [0, 1]
])

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the Encoder network (convolutional)
class Encoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels + num_classes, 32, kernel_size=4, stride=2, padding=1)  # Output: 32x64x64
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)  # Output: 64x32x32
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # Output: 128x16x16
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)  # Output: 256x8x8
        self.fc_mean = nn.Linear(256 * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(256 * 8 * 8, latent_dim)

    def forward(self, x, y):
        # Expand one-hot labels to match image dimensions and concatenate
        y = F.one_hot(y, num_classes=num_classes).float()  # Shape: (batch_size, num_classes)
        y = y.unsqueeze(-1).unsqueeze(-1)  # Shape: (batch_size, num_classes, 1, 1)
        y = y.expand(-1, -1, x.size(2), x.size(3))  # Shape: (batch_size, num_classes, 128, 128)
        x_with_y = torch.cat([x, y], dim=1)  # Shape: (batch_size, channels + num_classes, 128, 128)
        
        h = F.relu(self.conv1(x_with_y))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        h = h.view(h.size(0), -1)  # Flatten: (batch_size, 256*8*8)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar

# Define the Decoder network (transposed convolutional)
class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 256 * 8 * 8)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)  # Output: 128x16x16
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)  # Output: 64x32x32
        self.deconv3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)  # Output: 32x64x64
        self.deconv4 = nn.ConvTranspose2d(32, channels, kernel_size=4, stride=2, padding=1)  # Output: 3x128x128

    def forward(self, z, y):
        y = F.one_hot(y, num_classes=num_classes).float()  # Shape: (batch_size, num_classes)
        z_with_y = torch.cat([z, y], dim=-1)  # Shape: (batch_size, latent_dim + num_classes)
        h = F.relu(self.fc(z_with_y))
        h = h.view(h.size(0), 256, 8, 8)  # Reshape: (batch_size, 256, 8, 8)
        h = F.relu(self.deconv1(h))
        h = F.relu(self.deconv2(h))
        h = F.relu(self.deconv3(h))
        x_reconstructed = torch.sigmoid(self.deconv4(h))  # Output: (batch_size, 3, 128, 128)
        return x_reconstructed

# Conditional VAE model
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, x, y):
        z_mean, z_logvar = self.encoder(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y)
        return x_reconstructed, z_mean, z_logvar

# Instantiate Encoder, Decoder, and CVAE
encoder = Encoder(latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Define optimizer and loss function
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)

def cvae_loss(x, x_reconstructed, z_mean, z_logvar):
    recon_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    return recon_loss + kl_loss

# Training loop
cvae.train()
for epoch in range(epochs):
    train_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        loss = cvae_loss(data, x_reconstructed, z_mean, z_logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print(f'Epoch {epoch + 1}/{epochs}, Loss: {train_loss / len(train_loader.dataset):.4f}')

# Save the trained CVAE model
output_dir = "FL_VEHICLE_NON_IID"
os.makedirs(output_dir, exist_ok=True)
model_path = os.path.join(output_dir, "cvae_vehicle.pth")
torch.save(cvae.state_dict(), model_path)
print(f"Saved CVAE model to {model_path}")

# Generate samples for Classes 3 and 4
def generate_samples_labelwise(cvae, num_samples, classes_to_generate, base_dir, latent_dim, device):
    cvae.eval()
    os.makedirs(base_dir, exist_ok=True)
    with torch.no_grad():
        for class_label in classes_to_generate:
            label_tensor = torch.tensor([class_label]).repeat(num_samples).to(device)
            z = torch.randn(num_samples, latent_dim).to(device)
            generated_samples = cvae.decoder(z, label_tensor)  # Shape: (num_samples, 3, 128, 128)
            
            class_dir = os.path.join(base_dir, str(class_label))
            os.makedirs(class_dir, exist_ok=True)
            for idx, sample in enumerate(generated_samples):
                save_image(sample, os.path.join(class_dir, f"sample_{idx}.png"))
            print(f"Generated {num_samples} samples for Class {class_label} ({dataset.class_names[class_label]}).")

# Generate 500 samples each for Classes 3 and 4
base_dir = os.path.join(output_dir, "generated_samples")
classes_to_generate = [3, 4]  # Classes 3 and 4
generate_samples_labelwise(cvae, num_samples=500, classes_to_generate=classes_to_generate, base_dir=base_dir, latent_dim=latent_dim, device=device)

# Plot random samples for Classes 3 and 4
def plot_random_samples(base_dir, classes_to_generate, num_images_per_class=10):
    fig, axs = plt.subplots(len(classes_to_generate), num_images_per_class, figsize=(20, 4))
    for row, class_label in enumerate(classes_to_generate):
        class_dir = os.path.join(base_dir, str(class_label))
        sample_files = os.listdir(class_dir)
        random_samples = np.random.choice(sample_files, num_images_per_class, replace=False)
        
        for col, sample_file in enumerate(random_samples):
            sample_path = os.path.join(class_dir, sample_file)
            sample_image = plt.imread(sample_path)  # RGB image
            if len(classes_to_generate) == 1:
                ax = axs[col]
            else:
                ax = axs[row, col]
            ax.imshow(sample_image)
            ax.axis('off')
            if col == 0:
                ax.set_ylabel(dataset.class_names[class_label], rotation=90, labelpad=10)
    
    plt.tight_layout()
    plot_path = os.path.join(output_dir, "synthetic_samples_classes_3_4.png")
    plt.savefig(plot_path)
    plt.show()
    print(f"Saved synthetic samples plot to {plot_path}")

# Plot 10 samples each for Classes 3 and 4
plot_random_samples(base_dir=base_dir, classes_to_generate=classes_to_generate, num_images_per_class=10)

MODIDFED CVAE 

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.models as models
from torchvision.transforms.functional import adjust_sharpness
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image
from torchvision.models import VGG16_Weights

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 64
num_classes = 5
batch_size = 16
epochs = 500
learning_rate = 1e-3
image_size = 128
channels = 3
output_dir = "FL_VEHICLE_NON_IID"

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class
class VehicleTypeDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),  # Converts to [0, 1]
])

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the Encoder network with deeper layers and skip connections
class Encoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels + num_classes, 32, kernel_size=4, stride=2, padding=1)  # 32x64x64
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)  # 64x32x32
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # 128x16x16
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)  # 256x8x8
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)  # 512x4x4
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)

    def forward(self, x, y):
        y = F.one_hot(y, num_classes=num_classes).float()
        y = y.unsqueeze(-1).unsqueeze(-1)
        y = y.expand(-1, -1, x.size(2), x.size(3))
        x_with_y = torch.cat([x, y], dim=1)
        
        h1 = F.relu(self.conv1(x_with_y))
        h2 = F.relu(self.conv2(h1))
        h3 = F.relu(self.conv3(h2))
        h4 = F.relu(self.conv4(h3))
        h5 = F.relu(self.conv5(h4))
        h = h5.view(h5.size(0), -1)  # Fixed: Use h5.size(0) instead of h.size(0)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar, (h1, h2, h3, h4, h5)

# Define the Decoder network with skip connections
class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 512 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)  # 256x8x8
        self.deconv2 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)  # 128x16x16
        self.deconv3 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)  # 64x32x32
        self.deconv4 = nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=1)  # 32x64x64
        self.deconv5 = nn.ConvTranspose2d(64, channels, kernel_size=4, stride=2, padding=1)  # 3x128x128

    def forward(self, z, y, skip_connections):
        h1, h2, h3, h4, h5 = skip_connections
        y = F.one_hot(y, num_classes=num_classes).float()
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = h.view(h.size(0), 512, 4, 4)
        
        h = F.relu(self.deconv1(h))  # 256x8x8
        h = torch.cat([h, h4], dim=1)  # Concatenate skip connection: 256 + 256 = 512
        h = F.relu(self.deconv2(h))  # 128x16x16
        h = torch.cat([h, h3], dim=1)  # Concatenate skip connection: 128 + 128 = 256
        h = F.relu(self.deconv3(h))  # 64x32x32
        h = torch.cat([h, h2], dim=1)  # Concatenate skip connection: 64 + 64 = 128
        h = F.relu(self.deconv4(h))  # 32x64x64
        h = torch.cat([h, h1], dim=1)  # Concatenate skip connection: 32 + 32 = 64
        x_reconstructed = torch.sigmoid(self.deconv5(h))
        return x_reconstructed

# Conditional VAE model
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, x, y):
        z_mean, z_logvar, skip_connections = self.encoder(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y, skip_connections)
        return x_reconstructed, z_mean, z_logvar

# Load pretrained VGG16 for perceptual loss (updated API)
vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad = False

def perceptual_loss(x, x_reconstructed):
    x_features = vgg(x)
    x_recon_features = vgg(x_reconstructed)
    return F.mse_loss(x_features, x_recon_features)

# Instantiate Encoder, Decoder, and CVAE
encoder = Encoder(latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Define optimizer and learning rate scheduler
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

# Define loss function with separate logging
def cvae_loss(x, x_reconstructed, z_mean, z_logvar, perceptual_weight=0.1):
    recon_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    percep_loss = perceptual_loss(x, x_reconstructed) * perceptual_weight
    total_loss = recon_loss + kl_loss + percep_loss
    return total_loss, recon_loss, kl_loss, percep_loss

# Training loop with checkpointing
cvae.train()
for epoch in range(epochs):
    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_percep_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(data, x_reconstructed, z_mean, z_logvar)
        total_loss.backward()
        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        train_percep_loss += percep_loss.item()
        optimizer.step()

    scheduler.step()
    avg_loss = train_loss / len(train_loader.dataset)
    avg_recon_loss = train_recon_loss / len(train_loader.dataset)
    avg_kl_loss = train_kl_loss / len(train_loader.dataset)
    avg_percep_loss = train_percep_loss / len(train_loader.dataset)
    print(f'Epoch {epoch + 1}/{epochs}, Total Loss: {avg_loss:.4f}, '
          f'Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, '
          f'Percep Loss: {avg_percep_loss:.4f}')

    # Save checkpoint every 50 epochs
    if (epoch + 1) % 50 == 0:
        checkpoint_path = os.path.join(output_dir, f"cvae_vehicle_epoch_{epoch + 1}.pth")
        torch.save(cvae.state_dict(), checkpoint_path)
        print(f"Saved checkpoint at epoch {epoch + 1} to {checkpoint_path}")

        # Generate and visualize samples at this checkpoint
        cvae.eval()
        base_dir = os.path.join(output_dir, f"generated_samples_epoch_{epoch + 1}")
        classes_to_generate = [3, 4]
        with torch.no_grad():
            for class_label in classes_to_generate:
                label_tensor = torch.tensor([class_label]).repeat(500).to(device)
                z = torch.randn(500, latent_dim).to(device)
                # Dummy skip connections for generation (since we're not passing through encoder)
                dummy_skips = [
                    torch.zeros(batch_size, 32, 64, 64).to(device),
                    torch.zeros(batch_size, 64, 32, 32).to(device),
                    torch.zeros(batch_size, 128, 16, 16).to(device),
                    torch.zeros(batch_size, 256, 8, 8).to(device),
                    torch.zeros(batch_size, 512, 4, 4).to(device)
                ]
                generated_samples = cvae.decoder(z, label_tensor, dummy_skips)
                class_dir = os.path.join(base_dir, str(class_label))
                os.makedirs(class_dir, exist_ok=True)
                for idx, sample in enumerate(generated_samples):
                    sample = adjust_sharpness(sample, sharpness_factor=2.0)
                    save_image(sample, os.path.join(class_dir, f"sample_{idx}.png"))
                print(f"Generated 500 samples for Class {class_label} ({dataset.class_names[class_label]}) at epoch {epoch + 1}.")

        # Plot samples
        fig, axs = plt.subplots(len(classes_to_generate), 10, figsize=(20, 4))
        for row, class_label in enumerate(classes_to_generate):
            class_dir = os.path.join(base_dir, str(class_label))
            sample_files = os.listdir(class_dir)
            random_samples = np.random.choice(sample_files, 10, replace=False)
            for col, sample_file in enumerate(random_samples):
                sample_path = os.path.join(class_dir, sample_file)
                sample_image = Image.open(sample_path).convert("RGB")
                sample_image = sample_image.resize((128, 128), Image.LANCZOS)
                sample_image = np.array(sample_image) / 255.0
                ax = axs[row, col] if len(classes_to_generate) > 1 else axs[col]
                ax.imshow(sample_image)
                ax.axis('off')
                if col == 0:
                    ax.set_ylabel(dataset.class_names[class_label], rotation=90, labelpad=10)
        plt.tight_layout()
        plot_path = os.path.join(output_dir, f"synthetic_samples_classes_3_4_epoch_{epoch + 1}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved synthetic samples plot at epoch {epoch + 1} to {plot_path}")
        cvae.train()

# Final model save
final_model_path = os.path.join(output_dir, "cvae_vehicle_final.pth")
torch.save(cvae.state_dict(), final_model_path)
print(f"Saved final CVAE model to {final_model_path}")

Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Epoch 1/500, Total Loss: 26322.5639, Recon Loss: 26321.3656, KL Loss: 1.1939, Percep Loss: 0.0042
Epoch 2/500, Total Loss: 25217.1438, Recon Loss: 25217.1412, KL Loss: 0.0008, Percep Loss: 0.0021
Epoch 3/500, Total Loss: 25122.8553, Recon Loss: 25122.8534, KL Loss: 0.0001, Percep Loss: 0.0013
Epoch 4/500, Total Loss: 25086.5595, Recon Loss: 25086.5589, KL Loss: 0.0000, Percep Loss: 0.0009
Epoch 5/500, Total Loss: 25070.5351, Recon Loss: 25070.5351, KL Loss: 0.0000, Percep Loss: 0.0007
Epoch 6/500, Total Loss: 25054.9869, Recon Loss: 25054.9869, KL Loss: 0.0000, Percep Loss: 0.0006
Epoch 7/500, Total Loss: 25052.4975, Recon Loss: 25052.4975, KL Loss: 0.0000, Percep Loss: 

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 500 but got size 16 for tensor number 1 in the list.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.models as models
from torchvision.transforms.functional import adjust_sharpness
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image
from torchvision.models import VGG16_Weights

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 64
num_classes = 5
batch_size = 16
epochs = 500
learning_rate = 1e-3
image_size = 128
channels = 3
output_dir = "FL_VEHICLE_NON_IID"
beta_max = 10.0  # Maximum beta for KL loss
annealing_epochs = 50  # Epochs over which to anneal beta

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class
class VehicleTypeDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the Encoder network
class Encoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels + num_classes, 32, kernel_size=4, stride=2, padding=1)  # 32x64x64
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)  # 64x32x32
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # 128x16x16
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)  # 256x8x8
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)  # 512x4x4
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)

    def forward(self, x, y):
        y = F.one_hot(y, num_classes=num_classes).float()
        y = y.unsqueeze(-1).unsqueeze(-1)
        y = y.expand(-1, -1, x.size(2), x.size(3))
        x_with_y = torch.cat([x, y], dim=1)
        
        h1 = F.relu(self.conv1(x_with_y))
        h2 = F.relu(self.conv2(h1))
        h3 = F.relu(self.conv3(h2))
        h4 = F.relu(self.conv4(h3))
        h5 = F.relu(self.conv5(h4))
        h = h5.view(h5.size(0), -1)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar, (h1, h2, h3, h4, h5)

# Define the Decoder network
class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 512 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)  # 256x8x8
        self.deconv2 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)  # 128x16x16
        self.deconv3 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)  # 64x32x32
        self.deconv4 = nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=1)  # 32x64x64
        self.deconv5 = nn.ConvTranspose2d(64, channels, kernel_size=4, stride=2, padding=1)  # 3x128x128

    def forward(self, z, y, skip_connections):
        h1, h2, h3, h4, h5 = skip_connections
        y = F.one_hot(y, num_classes=num_classes).float()
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = h.view(h.size(0), 512, 4, 4)
        
        h = F.relu(self.deconv1(h))
        h = torch.cat([h, h4], dim=1)
        h = F.relu(self.deconv2(h))
        h = torch.cat([h, h3], dim=1)
        h = F.relu(self.deconv3(h))
        h = torch.cat([h, h2], dim=1)
        h = F.relu(self.deconv4(h))
        h = torch.cat([h, h1], dim=1)
        x_reconstructed = torch.sigmoid(self.deconv5(h))
        return x_reconstructed

# Conditional VAE model
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, x, y):
        z_mean, z_logvar, skip_connections = self.encoder(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y, skip_connections)
        return x_reconstructed, z_mean, z_logvar

# Load pretrained VGG16 for perceptual loss
vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad = False

def perceptual_loss(x, x_reconstructed):
    x_features = vgg(x)
    x_recon_features = vgg(x_reconstructed)
    return F.mse_loss(x_features, x_recon_features)

# Instantiate Encoder, Decoder, and CVAE
encoder = Encoder(latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Define optimizer and learning rate scheduler
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

# Define loss function with beta scaling and annealing
def cvae_loss(x, x_reconstructed, z_mean, z_logvar, beta=1.0, perceptual_weight=1.0):
    recon_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    percep_loss = perceptual_loss(x, x_reconstructed) * perceptual_weight
    total_loss = recon_loss + beta * kl_loss + percep_loss
    return total_loss, recon_loss, kl_loss, percep_loss

# Training loop with checkpointing
cvae.train()
for epoch in range(epochs):
    # Compute beta for KL annealing
    if epoch < annealing_epochs:
        beta = beta_max * (epoch / annealing_epochs)
    else:
        beta = beta_max

    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_percep_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta=beta)
        total_loss.backward()
        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        train_percep_loss += percep_loss.item()
        optimizer.step()

    scheduler.step()
    avg_loss = train_loss / len(train_loader.dataset)
    avg_recon_loss = train_recon_loss / len(train_loader.dataset)
    avg_kl_loss = train_kl_loss / len(train_loader.dataset)
    avg_percep_loss = train_percep_loss / len(train_loader.dataset)
    print(f'Epoch {epoch + 1}/{epochs}, Beta: {beta:.2f}, Total Loss: {avg_loss:.4f}, '
          f'Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, '
          f'Percep Loss: {avg_percep_loss:.4f}')

    # Save checkpoint every 50 epochs
    if (epoch + 1) % 50 == 0:
        checkpoint_path = os.path.join(output_dir, f"cvae_vehicle_epoch_{epoch + 1}.pth")
        torch.save(cvae.state_dict(), checkpoint_path)
        print(f"Saved checkpoint at epoch {epoch + 1} to {checkpoint_path}")

        # Generate and visualize samples at this checkpoint
        cvae.eval()
        base_dir = os.path.join(output_dir, f"generated_samples_epoch_{epoch + 1}")
        classes_to_generate = [3, 4]
        num_samples = 500
        with torch.no_grad():
            for class_label in classes_to_generate:
                label_tensor = torch.tensor([class_label]).repeat(num_samples).to(device)
                z = torch.randn(num_samples, latent_dim).to(device)
                # Create dummy skip connections with correct batch size
                dummy_skips = [
                    torch.zeros(num_samples, 32, 64, 64).to(device),
                    torch.zeros(num_samples, 64, 32, 32).to(device),
                    torch.zeros(num_samples, 128, 16, 16).to(device),
                    torch.zeros(num_samples, 256, 8, 8).to(device),
                    torch.zeros(num_samples, 512, 4, 4).to(device)
                ]
                generated_samples = cvae.decoder(z, label_tensor, dummy_skips)
                class_dir = os.path.join(base_dir, str(class_label))
                os.makedirs(class_dir, exist_ok=True)
                for idx, sample in enumerate(generated_samples):
                    sample = adjust_sharpness(sample, sharpness_factor=2.0)
                    save_image(sample, os.path.join(class_dir, f"sample_{idx}.png"))
                print(f"Generated {num_samples} samples for Class {class_label} ({dataset.class_names[class_label]}) at epoch {epoch + 1}.")

        # Plot samples
        fig, axs = plt.subplots(len(classes_to_generate), 10, figsize=(20, 4))
        for row, class_label in enumerate(classes_to_generate):
            class_dir = os.path.join(base_dir, str(class_label))
            sample_files = os.listdir(class_dir)
            random_samples = np.random.choice(sample_files, 10, replace=False)
            for col, sample_file in enumerate(random_samples):
                sample_path = os.path.join(class_dir, sample_file)
                sample_image = Image.open(sample_path).convert("RGB")
                sample_image = sample_image.resize((128, 128), Image.LANCZOS)
                sample_image = np.array(sample_image) / 255.0
                ax = axs[row, col] if len(classes_to_generate) > 1 else axs[col]
                ax.imshow(sample_image)
                ax.axis('off')
                if col == 0:
                    ax.set_ylabel(dataset.class_names[class_label], rotation=90, labelpad=10)
        plt.tight_layout()
        plot_path = os.path.join(output_dir, f"synthetic_samples_classes_3_4_epoch_{epoch + 1}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved synthetic samples plot at epoch {epoch + 1} to {plot_path}")
        cvae.train()

# Final model save
final_model_path = os.path.join(output_dir, "cvae_vehicle_final.pth")
torch.save(cvae.state_dict(), final_model_path)
print(f"Saved final CVAE model to {final_model_path}")

Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']


KeyboardInterrupt: 

In [1]:
import torch
torch.cuda.empty_cache()


In [3]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"



RGB MODIFED

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.models as models
from torchvision.transforms.functional import adjust_sharpness
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image
from torchvision.models import VGG16_Weights

# Set CUDA_LAUNCH_BLOCKING for debugging
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 128  # Increased for better detail capture
num_classes = 5
batch_size = 32
epochs = 1000  # Extended for better convergence
learning_rate = 5e-4  # Reduced for stability
image_size = 128
channels = 3
output_dir = "FL_VEHICLE_NON_IID"
beta_max = 10.0  # Reduced to balance KL loss
annealing_epochs = 50  # Gradual increase of KL weight
perceptual_weight = 1.0  # Increased to emphasize perceptual quality
recon_weight = 0.5  # Slightly reduced to balance losses

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class
class VehicleTypeDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),  # Converts to [0, 1]
])

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the Encoder network with deeper layers and skip connections
class Encoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels + num_classes, 32, kernel_size=4, stride=2, padding=1)  # 32x64x64
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)  # 64x32x32
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # 128x16x16
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)  # 256x8x8
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)  # 512x4x4
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, y):
        y = F.one_hot(y, num_classes=num_classes).float()
        y = y.unsqueeze(-1).unsqueeze(-1)
        y = y.expand(-1, -1, x.size(2), x.size(3))
        x_with_y = torch.cat([x, y], dim=1)
        
        h1 = F.relu(self.conv1(x_with_y))
        if torch.isnan(h1).any() or torch.isinf(h1).any():
            print("NaN or Inf in h1")
        h2 = F.relu(self.conv2(h1))
        if torch.isnan(h2).any() or torch.isinf(h2).any():
            print("NaN or Inf in h2")
        h3 = F.relu(self.conv3(h2))
        if torch.isnan(h3).any() or torch.isinf(h3).any():
            print("NaN or Inf in h3")
        h4 = F.relu(self.conv4(h3))
        if torch.isnan(h4).any() or torch.isinf(h4).any():
            print("NaN or Inf in h4")
        h5 = F.relu(self.conv5(h4))
        if torch.isnan(h5).any() or torch.isinf(h5).any():
            print("NaN or Inf in h5")
        h = h5.view(h5.size(0), -1)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar, (h1, h2, h3, h4, h5)

# Define the Decoder network with skip connections
class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 512 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)  # 256x8x8
        self.deconv2 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)  # 128x16x16
        self.deconv3 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)  # 64x32x32
        self.deconv4 = nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=1)  # 32x64x64
        self.deconv5 = nn.ConvTranspose2d(64, channels, kernel_size=4, stride=2, padding=1)  # 3x128x128

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, z, y, skip_connections):
        h1, h2, h3, h4, h5 = skip_connections
        y = F.one_hot(y, num_classes=num_classes).float()
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = h.view(h.size(0), 512, 4, 4)
        
        h = F.relu(self.deconv1(h))
        h = torch.cat([h, h4], dim=1)
        h = F.relu(self.deconv2(h))
        h = torch.cat([h, h3], dim=1)
        h = F.relu(self.deconv3(h))
        h = torch.cat([h, h2], dim=1)
        h = F.relu(self.deconv4(h))
        h = torch.cat([h, h1], dim=1)
        x_reconstructed = torch.sigmoid(self.deconv5(h))
        if torch.isnan(x_reconstructed).any() or torch.isinf(x_reconstructed).any():
            print("NaN or Inf in x_reconstructed")
        return x_reconstructed

# Conditional VAE model
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, x, y):
        z_mean, z_logvar, skip_connections = self.encoder(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y, skip_connections)
        return x_reconstructed, z_mean, z_logvar

# Load pretrained VGG16 for perceptual loss
vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad = False

def perceptual_loss(x, x_reconstructed):
    # Normalize inputs to ImageNet mean and std
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
    x_normalized = (x - mean) / std
    x_reconstructed_normalized = (x_reconstructed - mean) / std
    x_features = vgg(x_normalized)
    x_recon_features = vgg(x_reconstructed_normalized)
    return F.mse_loss(x_features, x_recon_features)

# Instantiate Encoder, Decoder, and CVAE
encoder = Encoder(latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Define optimizer and learning rate scheduler
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

# Define loss function with beta annealing
def cvae_loss(x, x_reconstructed, z_mean, z_logvar, beta=1.0, recon_weight=1.0, perceptual_weight=1.0):
    if torch.isnan(x).any() or torch.isinf(x).any():
        print("NaN or Inf detected in x")
    if torch.isnan(x_reconstructed).any() or torch.isinf(x_reconstructed).any():
        print("NaN or Inf detected in x_reconstructed")
    if torch.isnan(z_mean).any() or torch.isinf(z_mean).any():
        print("NaN or Inf detected in z_mean")
    if torch.isnan(z_logvar).any() or torch.isinf(z_logvar).any():
        print("NaN or Inf detected in z_logvar")

    recon_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + torch.clamp(z_logvar, -10, 10) - z_mean.pow(2) - torch.clamp(z_logvar, -10, 10).exp())
    percep_loss = perceptual_loss(x, x_reconstructed) * perceptual_weight
    total_loss = recon_weight * recon_loss + beta * kl_loss + percep_loss
    return total_loss, recon_loss, kl_loss, percep_loss

# Training loop with checkpointing
cvae.train()
for epoch in range(epochs):
    if epoch < annealing_epochs:
        beta = beta_max * (epoch / annealing_epochs)
    else:
        beta = beta_max

    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_percep_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
       

        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(
            data, x_reconstructed, z_mean, z_logvar, 
            beta=beta, recon_weight=recon_weight, perceptual_weight=perceptual_weight
        )
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)  # Gradient clipping
        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        train_percep_loss += percep_loss.item()
        optimizer.step()

    scheduler.step()
    avg_loss = train_loss / len(train_loader.dataset)
    avg_recon_loss = train_recon_loss / len(train_loader.dataset)
    avg_kl_loss = train_kl_loss / len(train_loader.dataset)
    avg_percep_loss = train_percep_loss / len(train_loader.dataset)
    print(f'Epoch {epoch + 1}/{epochs}, Beta: {beta:.2f}, Total Loss: {avg_loss:.4f}, '
          f'Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, '
          f'Percep Loss: {avg_percep_loss:.4f}')

    # Save checkpoint every 50 epochs
    if (epoch + 1) % 50 == 0:
        checkpoint_path = os.path.join(output_dir, f"cvae_vehicle_epoch_{epoch + 1}.pth")
        torch.save(cvae.state_dict(), checkpoint_path)
        print(f"Saved checkpoint at epoch {epoch + 1} to {checkpoint_path}")

        # Generate and visualize samples at this checkpoint
        cvae.eval()
        base_dir = os.path.join(output_dir, f"generated_samples_epoch_{epoch + 1}")
        classes_to_generate = [3, 4]
        num_samples = 500
        with torch.no_grad():
            for class_label in classes_to_generate:
                label_tensor = torch.tensor([class_label]).repeat(num_samples).to(device)
                z = torch.randn(num_samples, latent_dim).to(device)
                # Use realistic dummy skip connections
                dummy_skips = [
                    torch.randn(num_samples, 32, 64, 64).to(device) * 0.1,
                    torch.randn(num_samples, 64, 32, 32).to(device) * 0.1,
                    torch.randn(num_samples, 128, 16, 16).to(device) * 0.1,
                    torch.randn(num_samples, 256, 8, 8).to(device) * 0.1,
                    torch.randn(num_samples, 512, 4, 4).to(device) * 0.1
                ]
                generated_samples = cvae.decoder(z, label_tensor, dummy_skips)
                class_dir = os.path.join(base_dir, str(class_label))
                os.makedirs(class_dir, exist_ok=True)
                for idx, sample in enumerate(generated_samples):
                    sample = adjust_sharpness(sample, sharpness_factor=2.0)
                    save_image(sample, os.path.join(class_dir, f"sample_{idx}.png"))
                print(f"Generated {num_samples} samples for Class {class_label} ({dataset.class_names[class_label]}) at epoch {epoch + 1}.")

        # Plot samples
        fig, axs = plt.subplots(len(classes_to_generate), 10, figsize=(20, 4))
        for row, class_label in enumerate(classes_to_generate):
            class_dir = os.path.join(base_dir, str(class_label))
            sample_files = os.listdir(class_dir)
            random_samples = np.random.choice(sample_files, 10, replace=False)
            for col, sample_file in enumerate(random_samples):
                sample_path = os.path.join(class_dir, sample_file)
                sample_image = Image.open(sample_path).convert("RGB")
                sample_image = sample_image.resize((128, 128), Image.LANCZOS)
                sample_image = np.array(sample_image) / 255.0
                ax = axs[row, col] if len(classes_to_generate) > 1 else axs[col]
                ax.imshow(sample_image)
                ax.axis('off')
                if col == 0:
                    ax.set_ylabel(dataset.class_names[class_label], rotation=90, labelpad=10)
        plt.tight_layout()
        plot_path = os.path.join(output_dir, f"synthetic_samples_classes_3_4_epoch_{epoch + 1}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved synthetic samples plot at epoch {epoch + 1} to {plot_path}")
        cvae.train()

# Final model save
final_model_path = os.path.join(output_dir, "cvae_vehicle_final.pth")
torch.save(cvae.state_dict(), final_model_path)
print(f"Saved final CVAE model to {final_model_path}")

Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Epoch 1/1000, Beta: 0.00, Total Loss: 16987.8285, Recon Loss: 33975.5287, KL Loss: 2253.9204, Percep Loss: 0.0641
Epoch 2/1000, Beta: 0.20, Total Loss: 12782.9845, Recon Loss: 25562.7434, KL Loss: 7.8139, Percep Loss: 0.0500
Epoch 3/1000, Beta: 0.40, Total Loss: 12708.2272, Recon Loss: 25416.2942, KL Loss: 0.0983, Percep Loss: 0.0407
Epoch 4/1000, Beta: 0.60, Total Loss: 12653.7986, Recon Loss: 25307.5122, KL Loss: 0.0118, Percep Loss: 0.0354
Epoch 5/1000, Beta: 0.80, Total Loss: 12615.7262, Recon Loss: 25231.3893, KL Loss: 0.0021, Percep Loss: 0.0298
Epoch 6/1000, Beta: 1.00, Total Loss: 12595.2531, Recon Loss: 25190.4501, KL Loss: 0.0017, Percep Loss: 0.0265
Epoch 7/10

In [None]:
##RGB

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.models as models
from torchvision.transforms.functional import adjust_sharpness
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image
from torchvision.models import VGG16_Weights

# Clear GPU memory
torch.cuda.empty_cache()

# Set CUDA_LAUNCH_BLOCKING for debugging
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 128
batch_size = 16
epochs = 1000
learning_rate = 5e-4
image_size = 128
channels = 3
output_dir = "FL_CVAE"
beta_max = 5.0
annealing_epochs = 50
perceptual_weight = 1.0
recon_weight = 0.5

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class
class VehicleTypeDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Dynamically set num_classes based on the dataset
num_classes = len(dataset.class_names)
print(f"Number of classes in dataset: {num_classes}")

# Define classes to generate (choose the last two classes if possible)
if num_classes >= 2:
    classes_to_generate = [num_classes - 2, num_classes - 1]  # Last two classes
else:
    classes_to_generate = [0]  # Fallback to the first class if fewer than 2 classes
print(f"Classes to generate: {classes_to_generate}")

# Define the Encoder network
class Encoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels + num_classes, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, y):
        y = F.one_hot(y, num_classes=num_classes).float()
        y = y.unsqueeze(-1).unsqueeze(-1)
        y = y.expand(-1, -1, x.size(2), x.size(3))
        x_with_y = torch.cat([x, y], dim=1)
        
        h1 = F.relu(self.conv1(x_with_y))
        if torch.isnan(h1).any() or torch.isinf(h1).any():
            print("NaN or Inf in h1")
        h2 = F.relu(self.conv2(h1))
        if torch.isnan(h2).any() or torch.isinf(h2).any():
            print("NaN or Inf in h2")
        h3 = F.relu(self.conv3(h2))
        if torch.isnan(h3).any() or torch.isinf(h3).any():
            print("NaN or Inf in h3")
        h4 = F.relu(self.conv4(h3))
        if torch.isnan(h4).any() or torch.isinf(h4).any():
            print("NaN or Inf in h4")
        h5 = F.relu(self.conv5(h4))
        if torch.isnan(h5).any() or torch.isinf(h5).any():
            print("NaN or Inf in h5")
        h = h5.view(h5.size(0), -1)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar, (h1, h2, h3, h4, h5)

# Define the Decoder network
class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 512 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=1)
        self.deconv5 = nn.ConvTranspose2d(64, channels, kernel_size=4, stride=2, padding=1)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, z, y, skip_connections):
        h1, h2, h3, h4, h5 = skip_connections
        y = F.one_hot(y, num_classes=num_classes).float()
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = h.view(h.size(0), 512, 4, 4)
        
        h = F.relu(self.deconv1(h))
        h = torch.cat([h, h4], dim=1)
        h = F.relu(self.deconv2(h))
        h = torch.cat([h, h3], dim=1)
        h = F.relu(self.deconv3(h))
        h = torch.cat([h, h2], dim=1)
        h = F.relu(self.deconv4(h))
        h = torch.cat([h, h1], dim=1)
        x_reconstructed = torch.sigmoid(self.deconv5(h))
        if torch.isnan(x_reconstructed).any() or torch.isinf(x_reconstructed).any():
            print("NaN or Inf in x_reconstructed")
        return x_reconstructed

# Conditional VAE model
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, x, y):
        z_mean, z_logvar, skip_connections = self.encoder(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y, skip_connections)
        return x_reconstructed, z_mean, z_logvar

# Load pretrained VGG16 for perceptual loss
vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad = False

def perceptual_loss(x, x_reconstructed):
    x_subset = x[:8]
    x_reconstructed_subset = x_reconstructed[:8]
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
    x_normalized = (x_subset - mean) / std
    x_reconstructed_normalized = (x_reconstructed_subset - mean) / std
    x_features = vgg(x_normalized)
    x_recon_features = vgg(x_reconstructed_normalized)
    return F.mse_loss(x_features, x_recon_features)

# Instantiate Encoder, Decoder, and CVAE
encoder = Encoder(latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Define optimizer and learning rate scheduler
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

# Define loss function with beta annealing
def cvae_loss(x, x_reconstructed, z_mean, z_logvar, beta=1.0, recon_weight=1.0, perceptual_weight=1.0):
    if torch.isnan(x).any() or torch.isinf(x).any():
        print("NaN or Inf detected in x")
    if torch.isnan(x_reconstructed).any() or torch.isinf(x_reconstructed).any():
        print("NaN or Inf detected in x_reconstructed")
    if torch.isnan(z_mean).any() or torch.isinf(z_mean).any():
        print("NaN or Inf detected in z_mean")
    if torch.isnan(z_logvar).any() or torch.isinf(z_logvar).any():
        print("NaN or Inf detected in z_logvar")

    recon_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + torch.clamp(z_logvar, -5, 5) - z_mean.pow(2) - torch.clamp(z_logvar, -5, 5).exp())
    percep_loss = perceptual_loss(x, x_reconstructed) * perceptual_weight
    total_loss = recon_weight * recon_loss + beta * kl_loss + percep_loss
    return total_loss, recon_loss, kl_loss, percep_loss

# Training loop with checkpointing
cvae.train()
for epoch in range(epochs):
    if epoch < annealing_epochs:
        beta = beta_max * (epoch / annealing_epochs)
    else:
        beta = beta_max

    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_percep_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(
            data, x_reconstructed, z_mean, z_logvar, 
            beta=beta, recon_weight=recon_weight, perceptual_weight=perceptual_weight
        )
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=0.5)
        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        train_percep_loss += percep_loss.item()
        optimizer.step()

    scheduler.step()
    avg_loss = train_loss / len(train_loader.dataset)
    avg_recon_loss = train_recon_loss / len(train_loader.dataset)
    avg_kl_loss = train_kl_loss / len(train_loader.dataset)
    avg_percep_loss = train_percep_loss / len(train_loader.dataset)
    print(f'Epoch {epoch + 1}/{epochs}, Beta: {beta:.2f}, Total Loss: {avg_loss:.4f}, '
          f'Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, '
          f'Percep Loss: {avg_percep_loss:.4f}')

    # Save checkpoint every 100 epochs
    if (epoch + 1) % 100 == 0:
        checkpoint_path = os.path.join(output_dir, f"cvae_vehicle_epoch_{epoch + 1}.pth")
        torch.save(cvae.state_dict(), checkpoint_path)
        print(f"Saved checkpoint at epoch {epoch + 1} to {checkpoint_path}")

        # Generate and visualize samples
        cvae.eval()
        base_dir = os.path.join(output_dir, f"generated_samples_epoch_{epoch + 1}")
        num_samples = 100
        with torch.no_grad():
            for class_label in classes_to_generate:
                label_tensor = torch.tensor([class_label]).repeat(num_samples).to(device)
                z = torch.randn(num_samples, latent_dim).to(device)
                dummy_skips = [
                    torch.randn(num_samples, 32, 64, 64).to(device) * 0.1,
                    torch.randn(num_samples, 64, 32, 32).to(device) * 0.1,
                    torch.randn(num_samples, 128, 16, 16).to(device) * 0.1,
                    torch.randn(num_samples, 256, 8, 8).to(device) * 0.1,
                    torch.randn(num_samples, 512, 4, 4).to(device) * 0.1
                ]
                generated_samples = cvae.decoder(z, label_tensor, dummy_skips)
                class_dir = os.path.join(base_dir, str(class_label))
                os.makedirs(class_dir, exist_ok=True)
                for idx, sample in enumerate(generated_samples):
                    sample = adjust_sharpness(sample, sharpness_factor=2.0)
                    save_image(sample, os.path.join(class_dir, f"sample_{idx}.png"))
                # Safely access class name
                class_name = dataset.class_names[class_label] if class_label < len(dataset.class_names) else f"Class_{class_label}"
                print(f"Generated {num_samples} samples for Class {class_label} ({class_name}) at epoch {epoch + 1}.")

        # Plot samples
        fig, axs = plt.subplots(len(classes_to_generate), 10, figsize=(20, 4))
        for row, class_label in enumerate(classes_to_generate):
            class_dir = os.path.join(base_dir, str(class_label))
            sample_files = os.listdir(class_dir)
            random_samples = np.random.choice(sample_files, 10, replace=False)
            for col, sample_file in enumerate(random_samples):
                sample_path = os.path.join(class_dir, sample_file)
                sample_image = Image.open(sample_path).convert("RGB")
                sample_image = sample_image.resize((128, 128), Image.LANCZOS)
                sample_image = np.array(sample_image) / 255.0
                ax = axs[row, col] if len(classes_to_generate) > 1 else axs[col]
                ax.imshow(sample_image)
                ax.axis('off')
                if col == 0:
                    # Safely access class name for plotting
                    class_name = dataset.class_names[class_label] if class_label < len(dataset.class_names) else f"Class_{class_label}"
                    ax.set_ylabel(class_name, rotation=90, labelpad=10)
        plt.tight_layout()
        plot_path = os.path.join(output_dir, f"synthetic_samples_classes_{'_'.join(map(str, classes_to_generate))}_epoch_{epoch + 1}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved synthetic samples plot at epoch {epoch + 1} to {plot_path}")
        cvae.train()

# Final model save
final_model_path = os.path.join(output_dir, "cvae_vehicle_final.pth")
torch.save(cvae.state_dict(), final_model_path)
print(f"Saved final CVAE model to {final_model_path}")

Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Number of classes in dataset: 5
Classes to generate: [3, 4]
Epoch 1/1000, Beta: 0.00, Total Loss: 14687.9958, Recon Loss: 29375.7637, KL Loss: 399.6378, Percep Loss: 0.1140
Epoch 2/1000, Beta: 0.10, Total Loss: 12661.4874, Recon Loss: 25322.2727, KL Loss: 2.7964, Percep Loss: 0.0714
Epoch 3/1000, Beta: 0.20, Total Loss: 12599.1264, Recon Loss: 25198.1397, KL Loss: 0.0146, Percep Loss: 0.0536
Epoch 4/1000, Beta: 0.30, Total Loss: 12569.7557, Recon Loss: 25139.4285, KL Loss: 0.0043, Percep Loss: 0.0402
Epoch 5/1000, Beta: 0.40, Total Loss: 12550.0420, Recon Loss: 25100.0218, KL Loss: 0.0011, Percep Loss: 0.0308
Epoch 6/1000, Beta: 0.50, Total Loss: 12543.0925, Recon Loss: 

In [2]:
##NOT SAVING AT EVERY 100 RGB

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.models as models
from torchvision.transforms.functional import adjust_sharpness
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image
from torchvision.models import VGG16_Weights

# Clear GPU memory
torch.cuda.empty_cache()

# Set CUDA_LAUNCH_BLOCKING for debugging
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Device configuration
device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')

# Hyperparameters
latent_dim = 128
batch_size = 16
epochs = 1000
learning_rate = 5e-4
image_size = 128
channels = 3
output_dir = "FL_CVAE"
beta_max = 5.0
annealing_epochs = 50
perceptual_weight = 1.0
recon_weight = 0.5

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class
class VehicleTypeDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Dynamically set num_classes based on the dataset
num_classes = len(dataset.class_names)
print(f"Number of classes in dataset: {num_classes}")

# Define classes to generate (choose the last two classes if possible)
if num_classes >= 2:
    classes_to_generate = [num_classes - 2, num_classes - 1]  # Last two classes
else:
    classes_to_generate = [0]  # Fallback to the first class if fewer than 2 classes
print(f"Classes to generate: {classes_to_generate}")

# Define the Encoder network
class Encoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels + num_classes, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, y):
        y = F.one_hot(y, num_classes=num_classes).float()
        y = y.unsqueeze(-1).unsqueeze(-1)
        y = y.expand(-1, -1, x.size(2), x.size(3))
        x_with_y = torch.cat([x, y], dim=1)
        
        h1 = F.relu(self.conv1(x_with_y))
        if torch.isnan(h1).any() or torch.isinf(h1).any():
            print("NaN or Inf in h1")
        h2 = F.relu(self.conv2(h1))
        if torch.isnan(h2).any() or torch.isinf(h2).any():
            print("NaN or Inf in h2")
        h3 = F.relu(self.conv3(h2))
        if torch.isnan(h3).any() or torch.isinf(h3).any():
            print("NaN or Inf in h3")
        h4 = F.relu(self.conv4(h3))
        if torch.isnan(h4).any() or torch.isinf(h4).any():
            print("NaN or Inf in h4")
        h5 = F.relu(self.conv5(h4))
        if torch.isnan(h5).any() or torch.isinf(h5).any():
            print("NaN or Inf in h5")
        h = h5.view(h5.size(0), -1)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar, (h1, h2, h3, h4, h5)

# Define the Decoder network
class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 512 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=1)
        self.deconv5 = nn.ConvTranspose2d(64, channels, kernel_size=4, stride=2, padding=1)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, z, y, skip_connections):
        h1, h2, h3, h4, h5 = skip_connections
        y = F.one_hot(y, num_classes=num_classes).float()
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = h.view(h.size(0), 512, 4, 4)
        
        h = F.relu(self.deconv1(h))
        h = torch.cat([h, h4], dim=1)
        h = F.relu(self.deconv2(h))
        h = torch.cat([h, h3], dim=1)
        h = F.relu(self.deconv3(h))
        h = torch.cat([h, h2], dim=1)
        h = F.relu(self.deconv4(h))
        h = torch.cat([h, h1], dim=1)
        x_reconstructed = torch.sigmoid(self.deconv5(h))
        if torch.isnan(x_reconstructed).any() or torch.isinf(x_reconstructed).any():
            print("NaN or Inf in x_reconstructed")
        return x_reconstructed

# Conditional VAE model
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, x, y):
        z_mean, z_logvar, skip_connections = self.encoder(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y, skip_connections)
        return x_reconstructed, z_mean, z_logvar

# Load pretrained VGG16 for perceptual loss
vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad = False

def perceptual_loss(x, x_reconstructed):
    x_subset = x[:8]
    x_reconstructed_subset = x_reconstructed[:8]
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
    x_normalized = (x_subset - mean) / std
    x_reconstructed_normalized = (x_reconstructed_subset - mean) / std
    x_features = vgg(x_normalized)
    x_recon_features = vgg(x_reconstructed_normalized)
    return F.mse_loss(x_features, x_recon_features)

# Instantiate Encoder, Decoder, and CVAE
encoder = Encoder(latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Define optimizer and learning rate scheduler
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

# Define loss function with beta annealing
def cvae_loss(x, x_reconstructed, z_mean, z_logvar, beta=1.0, recon_weight=1.0, perceptual_weight=1.0):
    if torch.isnan(x).any() or torch.isinf(x).any():
        print("NaN or Inf detected in x")
    if torch.isnan(x_reconstructed).any() or torch.isinf(x_reconstructed).any():
        print("NaN or Inf detected in x_reconstructed")
    if torch.isnan(z_mean).any() or torch.isinf(z_mean).any():
        print("NaN or Inf detected in z_mean")
    if torch.isnan(z_logvar).any() or torch.isinf(z_logvar).any():
        print("NaN or Inf detected in z_logvar")

    recon_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + torch.clamp(z_logvar, -5, 5) - z_mean.pow(2) - torch.clamp(z_logvar, -5, 5).exp())
    percep_loss = perceptual_loss(x, x_reconstructed) * perceptual_weight
    total_loss = recon_weight * recon_loss + beta * kl_loss + percep_loss
    return total_loss, recon_loss, kl_loss, percep_loss

# Training loop
cvae.train()
for epoch in range(epochs):
    if epoch < annealing_epochs:
        beta = beta_max * (epoch / annealing_epochs)
    else:
        beta = beta_max

    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_percep_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(
            data, x_reconstructed, z_mean, z_logvar, 
            beta=beta, recon_weight=recon_weight, perceptual_weight=perceptual_weight
        )
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=0.5)
        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        train_percep_loss += percep_loss.item()
        optimizer.step()

    scheduler.step()
    avg_loss = train_loss / len(train_loader.dataset)
    avg_recon_loss = train_recon_loss / len(train_loader.dataset)
    avg_kl_loss = train_kl_loss / len(train_loader.dataset)
    avg_percep_loss = train_percep_loss / len(train_loader.dataset)
    print(f'Epoch {epoch + 1}/{epochs}, Beta: {beta:.2f}, Total Loss: {avg_loss:.4f}, '
          f'Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, '
          f'Percep Loss: {avg_percep_loss:.4f}')

# Save final model
final_model_path = os.path.join(output_dir, "cvae_vehicle_final.pth")
torch.save(cvae.state_dict(), final_model_path)
print(f"Saved final CVAE model to {final_model_path}")

# Generate and visualize samples at the end
cvae.eval()
base_dir = os.path.join(output_dir, "generated_samples_final")
num_samples = 100
with torch.no_grad():
    for class_label in classes_to_generate:
        label_tensor = torch.tensor([class_label]).repeat(num_samples).to(device)
        z = torch.randn(num_samples, latent_dim).to(device)
        dummy_skips = [
            torch.randn(num_samples, 32, 64, 64).to(device) * 0.1,
            torch.randn(num_samples, 64, 32, 32).to(device) * 0.1,
            torch.randn(num_samples, 128, 16, 16).to(device) * 0.1,
            torch.randn(num_samples, 256, 8, 8).to(device) * 0.1,
            torch.randn(num_samples, 512, 4, 4).to(device) * 0.1
        ]
        generated_samples = cvae.decoder(z, label_tensor, dummy_skips)
        class_dir = os.path.join(base_dir, str(class_label))
        os.makedirs(class_dir, exist_ok=True)
        for idx, sample in enumerate(generated_samples):
            sample = adjust_sharpness(sample, sharpness_factor=2.0)
            save_image(sample, os.path.join(class_dir, f"sample_{idx}.png"))
        # Safely access class name
        class_name = dataset.class_names[class_label] if class_label < len(dataset.class_names) else f"Class_{class_label}"
        print(f"Generated {num_samples} samples for Class {class_label} ({class_name}) at the end of training.")

# Plot samples
fig, axs = plt.subplots(len(classes_to_generate), 10, figsize=(20, 4))
for row, class_label in enumerate(classes_to_generate):
    class_dir = os.path.join(base_dir, str(class_label))
    sample_files = os.listdir(class_dir)
    random_samples = np.random.choice(sample_files, 10, replace=False)
    for col, sample_file in enumerate(random_samples):
        sample_path = os.path.join(class_dir, sample_file)
        sample_image = Image.open(sample_path).convert("RGB")
        sample_image = sample_image.resize((128, 128), Image.LANCZOS)
        sample_image = np.array(sample_image) / 255.0
        ax = axs[row, col] if len(classes_to_generate) > 1 else axs[col]
        ax.imshow(sample_image)
        ax.axis('off')
        if col == 0:
            # Safely access class name for plotting
            class_name = dataset.class_names[class_label] if class_label < len(dataset.class_names) else f"Class_{class_label}"
            ax.set_ylabel(class_name, rotation=90, labelpad=10)
plt.tight_layout()
plot_path = os.path.join(output_dir, f"synthetic_samples_classes_{'_'.join(map(str, classes_to_generate))}_final.png")
plt.savefig(plot_path)
plt.close()
print(f"Saved synthetic samples plot to {plot_path}")

Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Number of classes in dataset: 5
Classes to generate: [3, 4]
Epoch 1/1000, Beta: 0.00, Total Loss: 21212.9881, Recon Loss: 42425.7422, KL Loss: 568.2313, Percep Loss: 0.1169
Epoch 2/1000, Beta: 0.10, Total Loss: 12676.3793, Recon Loss: 25352.1531, KL Loss: 2.2939, Percep Loss: 0.0733
Epoch 3/1000, Beta: 0.20, Total Loss: 12612.4983, Recon Loss: 25224.8883, KL Loss: 0.0052, Percep Loss: 0.0532
Epoch 4/1000, Beta: 0.30, Total Loss: 12579.3610, Recon Loss: 25158.6395, KL Loss: 0.0021, Percep Loss: 0.0407
Epoch 5/1000, Beta: 0.40, Total Loss: 12554.8546, Recon Loss: 25109.6483, KL Loss: 0.0007, Percep Loss: 0.0303
Epoch 6/1000, Beta: 0.50, Total Loss: 12542.9959, Recon Loss: 