In [2]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.models as models
from torchvision.transforms.functional import adjust_sharpness
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image
from torchvision.models import VGG16_Weights

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 64
num_classes = 5
batch_size = 16
epochs = 200  # Reduced to 200 for faster iteration
learning_rate = 1e-3
image_size = 128
channels = 3
output_dir = "FL_CVAE"
beta_max = 50.0  # Increased for stronger KL regularization
annealing_epochs = 20  # Faster annealing
perceptual_weight = 10.0  # Increased for sharper images
recon_weight = 0.5  # Balance reconstruction loss

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class
class VehicleTypeDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the Encoder network
class Encoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels + num_classes, 32, kernel_size=4, stride=2, padding=1)  # 32x64x64
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)  # 64x32x32
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # 128x16x16
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)  # 256x8x8
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)  # 512x4x4
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)

    def forward(self, x, y):
        y = F.one_hot(y, num_classes=num_classes).float()
        y = y.unsqueeze(-1).unsqueeze(-1)
        y = y.expand(-1, -1, x.size(2), x.size(3))
        x_with_y = torch.cat([x, y], dim=1)
        
        h1 = F.relu(self.conv1(x_with_y))
        h2 = F.relu(self.conv2(h1))
        h3 = F.relu(self.conv3(h2))
        h4 = F.relu(self.conv4(h3))
        h5 = F.relu(self.conv5(h4))
        h = h5.view(h5.size(0), -1)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar, (h1, h2, h3, h4, h5)

# Define the Decoder network
class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 512 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)  # 256x8x8
        self.deconv2 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)  # 128x16x16
        self.deconv3 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)  # 64x32x32
        self.deconv4 = nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=1)  # 32x64x64
        self.deconv5 = nn.ConvTranspose2d(64, channels, kernel_size=4, stride=2, padding=1)  # 3x128x128

    def forward(self, z, y, skip_connections):
        h1, h2, h3, h4, h5 = skip_connections
        y = F.one_hot(y, num_classes=num_classes).float()
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = h.view(h.size(0), 512, 4, 4)
        
        h = F.relu(self.deconv1(h))
        h = torch.cat([h, h4], dim=1)
        h = F.relu(self.deconv2(h))
        h = torch.cat([h, h3], dim=1)
        h = F.relu(self.deconv3(h))
        h = torch.cat([h, h2], dim=1)
        h = F.relu(self.deconv4(h))
        h = torch.cat([h, h1], dim=1)
        h = self.deconv5(h)
        # Remove sigmoid, clamp to [0, 1]
        x_reconstructed = torch.clamp(h, 0, 1)
        return x_reconstructed

# Conditional VAE model
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, x, y):
        z_mean, z_logvar, skip_connections = self.encoder(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y, skip_connections)
        return x_reconstructed, z_mean, z_logvar

# Load pretrained VGG16 for perceptual loss
vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad = False

def perceptual_loss(x, x_reconstructed):
    x_features = vgg(x)
    x_recon_features = vgg(x_reconstructed)
    return F.mse_loss(x_features, x_recon_features)

# Instantiate Encoder, Decoder, and CVAE
encoder = Encoder(latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Define optimizer and learning rate scheduler
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# Define loss function with weights
def cvae_loss(x, x_reconstructed, z_mean, z_logvar, beta=1.0, recon_weight=1.0, perceptual_weight=1.0):
    recon_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    percep_loss = perceptual_loss(x, x_reconstructed) * perceptual_weight
    total_loss = recon_weight * recon_loss + beta * kl_loss + percep_loss
    return total_loss, recon_loss, kl_loss, percep_loss

# Training loop with checkpointing
cvae.train()
for epoch in range(epochs):
    # Compute beta for KL annealing
    if epoch < annealing_epochs:
        beta = beta_max * (epoch / annealing_epochs)
    else:
        beta = beta_max

    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_percep_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(
            data, x_reconstructed, z_mean, z_logvar, 
            beta=beta, recon_weight=recon_weight, perceptual_weight=perceptual_weight
        )
        total_loss.backward()
        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        train_percep_loss += percep_loss.item()
        optimizer.step()

    scheduler.step()
    avg_loss = train_loss / len(train_loader.dataset)
    avg_recon_loss = train_recon_loss / len(train_loader.dataset)
    avg_kl_loss = train_kl_loss / len(train_loader.dataset)
    avg_percep_loss = train_percep_loss / len(train_loader.dataset)
    print(f'Epoch {epoch + 1}/{epochs}, Beta: {beta:.2f}, Total Loss: {avg_loss:.4f}, '
          f'Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, '
          f'Percep Loss: {avg_percep_loss:.4f}')

    # Save checkpoint every 50 epochs
    if (epoch + 1) % 50 == 0:
        checkpoint_path = os.path.join(output_dir, f"cvae_vehicle_epoch_{epoch + 1}_v3.pth")
        torch.save(cvae.state_dict(), checkpoint_path)
        print(f"Saved checkpoint at epoch {epoch + 1} to {checkpoint_path}")

        # Generate and visualize samples at this checkpoint
        cvae.eval()
        base_dir = os.path.join(output_dir, f"generated_samples_epoch_{epoch + 1}_v3")
        classes_to_generate = [3, 4]
        num_samples = 500
        with torch.no_grad():
            for class_label in classes_to_generate:
                label_tensor = torch.tensor([class_label]).repeat(num_samples).to(device)
                z = torch.randn(num_samples, latent_dim).to(device)
                # Create dummy skip connections with small random noise
                dummy_skips = [
                    torch.randn(num_samples, 32, 64, 64).to(device) * 0.1,
                    torch.randn(num_samples, 64, 32, 32).to(device) * 0.1,
                    torch.randn(num_samples, 128, 16, 16).to(device) * 0.1,
                    torch.randn(num_samples, 256, 8, 8).to(device) * 0.1,
                    torch.randn(num_samples, 512, 4, 4).to(device) * 0.1
                ]
                generated_samples = cvae.decoder(z, label_tensor, dummy_skips)
                class_dir = os.path.join(base_dir, str(class_label))
                os.makedirs(class_dir, exist_ok=True)
                for idx, sample in enumerate(generated_samples):
                    sample = adjust_sharpness(sample, sharpness_factor=2.0)
                    save_image(sample, os.path.join(class_dir, f"sample_{idx}.png"))
                print(f"Generated {num_samples} samples for Class {class_label} ({dataset.class_names[class_label]}) at epoch {epoch + 1}.")

        # Plot samples
        fig, axs = plt.subplots(len(classes_to_generate), 10, figsize=(20, 4))
        for row, class_label in enumerate(classes_to_generate):
            class_dir = os.path.join(base_dir, str(class_label))
            sample_files = os.listdir(class_dir)
            random_samples = np.random.choice(sample_files, 10, replace=False)
            for col, sample_file in enumerate(random_samples):
                sample_path = os.path.join(class_dir, sample_file)
                sample_image = Image.open(sample_path).convert("RGB")
                sample_image = sample_image.resize((128, 128), Image.LANCZOS)
                sample_image = np.array(sample_image) / 255.0
                ax = axs[row, col] if len(classes_to_generate) > 1 else axs[col]
                ax.imshow(sample_image)
                ax.axis('off')
                if col == 0:
                    ax.set_ylabel(dataset.class_names[class_label], rotation=90, labelpad=10)
        plt.tight_layout()
        plot_path = os.path.join(output_dir, f"synthetic_samples_classes_3_4_epoch_{epoch + 1}_v3.png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved synthetic samples plot at epoch {epoch + 1} to {plot_path}")
        cvae.train()

# Final model save
final_model_path = os.path.join(output_dir, "cvae_vehicle_final_v3.pth")
torch.save(cvae.state_dict(), final_model_path)
print(f"Saved final CVAE model to {final_model_path}")

Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# ✅ Updated CVAE Training Code for Multi-GPU Awareness & Error Debugging

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
from torchvision.transforms.functional import adjust_sharpness
from torchvision.models import VGG16_Weights
from PIL import Image
import kagglehub
import matplotlib.pyplot as plt

# Set CUDA_LAUNCH_BLOCKING for debug clarity
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# 🖥️ Device Configuration
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device} ({torch.cuda.get_device_name(device) if torch.cuda.is_available() else 'CPU'})")

# 🔽 Your existing code here (imports, class definitions, model architecture)...
# ⏩ Keep all the definitions as is from your current code
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.models as models
from torchvision.transforms.functional import adjust_sharpness
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image
from torchvision.models import VGG16_Weights

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 64
num_classes = 5
batch_size = 16
epochs = 200  # Reduced to 200 for faster iteration
learning_rate = 1e-3
image_size = 128
channels = 3
output_dir = "FL_CVAE"
beta_max = 50.0  # Increased for stronger KL regularization
annealing_epochs = 20  # Faster annealing
perceptual_weight = 10.0  # Increased for sharper images
recon_weight = 0.5  # Balance reconstruction loss

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class
class VehicleTypeDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the Encoder network
class Encoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels + num_classes, 32, kernel_size=4, stride=2, padding=1)  # 32x64x64
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)  # 64x32x32
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # 128x16x16
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)  # 256x8x8
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)  # 512x4x4
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)

    def forward(self, x, y):
        y = F.one_hot(y, num_classes=num_classes).float()
        y = y.unsqueeze(-1).unsqueeze(-1)
        y = y.expand(-1, -1, x.size(2), x.size(3))
        x_with_y = torch.cat([x, y], dim=1)
        
        h1 = F.relu(self.conv1(x_with_y))
        h2 = F.relu(self.conv2(h1))
        h3 = F.relu(self.conv3(h2))
        h4 = F.relu(self.conv4(h3))
        h5 = F.relu(self.conv5(h4))
        h = h5.view(h5.size(0), -1)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar, (h1, h2, h3, h4, h5)

# Define the Decoder network
class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 512 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)  # 256x8x8
        self.deconv2 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)  # 128x16x16
        self.deconv3 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)  # 64x32x32
        self.deconv4 = nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=1)  # 32x64x64
        self.deconv5 = nn.ConvTranspose2d(64, channels, kernel_size=4, stride=2, padding=1)  # 3x128x128

    def forward(self, z, y, skip_connections):
        h1, h2, h3, h4, h5 = skip_connections
        y = F.one_hot(y, num_classes=num_classes).float()
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = h.view(h.size(0), 512, 4, 4)
        
        h = F.relu(self.deconv1(h))
        h = torch.cat([h, h4], dim=1)
        h = F.relu(self.deconv2(h))
        h = torch.cat([h, h3], dim=1)
        h = F.relu(self.deconv3(h))
        h = torch.cat([h, h2], dim=1)
        h = F.relu(self.deconv4(h))
        h = torch.cat([h, h1], dim=1)
        h = self.deconv5(h)
        # Remove sigmoid, clamp to [0, 1]
        x_reconstructed = torch.clamp(h, 0, 1)
        return x_reconstructed

# Conditional VAE model
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, x, y):
        z_mean, z_logvar, skip_connections = self.encoder(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y, skip_connections)
        return x_reconstructed, z_mean, z_logvar

# Load pretrained VGG16 for perceptual loss
vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad = False

def perceptual_loss(x, x_reconstructed):
    x_features = vgg(x)
    x_recon_features = vgg(x_reconstructed)
    return F.mse_loss(x_features, x_recon_features)

# Instantiate Encoder, Decoder, and CVAE
encoder = Encoder(latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Define optimizer and learning rate scheduler
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# Define loss function with weights
def cvae_loss(x, x_reconstructed, z_mean, z_logvar, beta=1.0, recon_weight=1.0, perceptual_weight=1.0):
    recon_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    percep_loss = perceptual_loss(x, x_reconstructed) * perceptual_weight
    total_loss = recon_weight * recon_loss + beta * kl_loss + percep_loss
    return total_loss, recon_loss, kl_loss, percep_loss

# Training loop with checkpointing
cvae.train()
for epoch in range(epochs):
    # Compute beta for KL annealing
    if epoch < annealing_epochs:
        beta = beta_max * (epoch / annealing_epochs)
    else:
        beta = beta_max

    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_percep_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(
            data, x_reconstructed, z_mean, z_logvar, 
            beta=beta, recon_weight=recon_weight, perceptual_weight=perceptual_weight
        )
        total_loss.backward()
        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        train_percep_loss += percep_loss.item()
        optimizer.step()

    scheduler.step()
    avg_loss = train_loss / len(train_loader.dataset)
    avg_recon_loss = train_recon_loss / len(train_loader.dataset)
    avg_kl_loss = train_kl_loss / len(train_loader.dataset)
    avg_percep_loss = train_percep_loss / len(train_loader.dataset)
    print(f'Epoch {epoch + 1}/{epochs}, Beta: {beta:.2f}, Total Loss: {avg_loss:.4f}, '
          f'Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, '
          f'Percep Loss: {avg_percep_loss:.4f}')

    # Save checkpoint every 50 epochs
    if (epoch + 1) % 50 == 0:
        checkpoint_path = os.path.join(output_dir, f"cvae_vehicle_epoch_{epoch + 1}_v3.pth")
        torch.save(cvae.state_dict(), checkpoint_path)
        print(f"Saved checkpoint at epoch {epoch + 1} to {checkpoint_path}")

        # Generate and visualize samples at this checkpoint
        cvae.eval()
        base_dir = os.path.join(output_dir, f"generated_samples_epoch_{epoch + 1}_v3")
        classes_to_generate = [3, 4]
        num_samples = 500
        with torch.no_grad():
            for class_label in classes_to_generate:
                label_tensor = torch.tensor([class_label]).repeat(num_samples).to(device)
                z = torch.randn(num_samples, latent_dim).to(device)
                # Create dummy skip connections with small random noise
                dummy_skips = [
                    torch.randn(num_samples, 32, 64, 64).to(device) * 0.1,
                    torch.randn(num_samples, 64, 32, 32).to(device) * 0.1,
                    torch.randn(num_samples, 128, 16, 16).to(device) * 0.1,
                    torch.randn(num_samples, 256, 8, 8).to(device) * 0.1,
                    torch.randn(num_samples, 512, 4, 4).to(device) * 0.1
                ]
                generated_samples = cvae.decoder(z, label_tensor, dummy_skips)
                class_dir = os.path.join(base_dir, str(class_label))
                os.makedirs(class_dir, exist_ok=True)
                for idx, sample in enumerate(generated_samples):
                    sample = adjust_sharpness(sample, sharpness_factor=2.0)
                    save_image(sample, os.path.join(class_dir, f"sample_{idx}.png"))
                print(f"Generated {num_samples} samples for Class {class_label} ({dataset.class_names[class_label]}) at epoch {epoch + 1}.")

        # Plot samples
        fig, axs = plt.subplots(len(classes_to_generate), 10, figsize=(20, 4))
        for row, class_label in enumerate(classes_to_generate):
            class_dir = os.path.join(base_dir, str(class_label))
            sample_files = os.listdir(class_dir)
            random_samples = np.random.choice(sample_files, 10, replace=False)
            for col, sample_file in enumerate(random_samples):
                sample_path = os.path.join(class_dir, sample_file)
                sample_image = Image.open(sample_path).convert("RGB")
                sample_image = sample_image.resize((128, 128), Image.LANCZOS)
                sample_image = np.array(sample_image) / 255.0
                ax = axs[row, col] if len(classes_to_generate) > 1 else axs[col]
                ax.imshow(sample_image)
                ax.axis('off')
                if col == 0:
                    ax.set_ylabel(dataset.class_names[class_label], rotation=90, labelpad=10)
        plt.tight_layout()
        plot_path = os.path.join(output_dir, f"synthetic_samples_classes_3_4_epoch_{epoch + 1}_v3.png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved synthetic samples plot at epoch {epoch + 1} to {plot_path}")
        cvae.train()

# Final model save
final_model_path = os.path.join(output_dir, "cvae_vehicle_final_v3.pth")
torch.save(cvae.state_dict(), final_model_path)
print(f"Saved final CVAE model to {final_model_path}")

# ✅ Check for valid labels before training
print("Validating labels...")
labels = [label for _, label in dataset]
invalid_labels = [label for label in labels if label >= num_classes or label < 0]
assert len(invalid_labels) == 0, f"Found invalid labels: {invalid_labels}"

# ✅ Wrap training loop with try/except
try:
    for epoch in range(epochs):
        if epoch < annealing_epochs:
            beta = beta_max * (epoch / annealing_epochs)
        else:
            beta = beta_max

        cvae.train()
        train_loss = 0
        train_recon_loss = 0
        train_kl_loss = 0
        train_percep_loss = 0

        for batch_idx, (data, labels) in enumerate(train_loader):
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()
            x_reconstructed, z_mean, z_logvar = cvae(data, labels)
            total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(
                data, x_reconstructed, z_mean, z_logvar,
                beta=beta, recon_weight=recon_weight, perceptual_weight=perceptual_weight
            )
            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()
            train_recon_loss += recon_loss.item()
            train_kl_loss += kl_loss.item()
            train_percep_loss += percep_loss.item()

        scheduler.step()

        avg_loss = train_loss / len(train_loader.dataset)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}, Beta: {beta:.2f}")

except RuntimeError as e:
    print("\nCUDA RuntimeError triggered:")
    print(e)
    print("\n🔎 Try checking the labels, device mismatch, or tensor shapes.")
    torch.cuda.empty_cache()
    raise


CUDA available: True
Using device: cuda:1 (NVIDIA RTX A5000)
Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Epoch 1/200, Beta: 0.00, Total Loss: 26174.1840, Recon Loss: 52347.1708, KL Loss: 597.5877, Percep Loss: 0.5986
Epoch 2/200, Beta: 2.50, Total Loss: 25397.4838, Recon Loss: 50621.8584, KL Loss: 34.3796, Percep Loss: 0.6055
Epoch 3/200, Beta: 5.00, Total Loss: 22136.9983, Recon Loss: 44267.0337, KL Loss: 0.5784, Percep Loss: 0.5895
Epoch 4/200, Beta: 7.50, Total Loss: 29973.4969, Recon Loss: 59845.7363, KL Loss: 6.6682, Percep Loss: 0.6170
Epoch 5/200, Beta: 10.00, Total Loss: 32555.9683, Recon Loss: 64536.4847, KL Loss: 28.7111, Percep Loss: 0.6144
Epoch 6/200, Beta: 12.50, Total Loss: 26916.0161, Recon Loss: 5

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Hyperparameters (optimized for stability and quality)
latent_dim = 256  # Increased for better representation
num_classes = 5
batch_size = 32
epochs = 300  # Set to 300 epochs
learning_rate = 3e-4
image_size = 128
channels = 3  # RGB
output_dir = "CVAE_GENERATED_SAMPLES"
os.makedirs(output_dir, exist_ok=True)

# Memory management
torch.cuda.empty_cache()

# Define transforms
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])

# Encoder Network
class Encoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels + num_classes, 64, 4, 2, 1)
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)
        self.conv4 = nn.Conv2d(256, 512, 4, 2, 1)
        self.conv5 = nn.Conv2d(512, 512, 4, 2, 1)
        self.fc_mu = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_var = nn.Linear(512 * 4 * 4, latent_dim)
        
        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, y):
        # One-hot encode and reshape label
        y = F.one_hot(y, num_classes=num_classes).float()
        y = y.view(-1, num_classes, 1, 1).expand(-1, -1, x.size(2), x.size(3))
        
        # Concatenate image and label
        x = torch.cat([x, y], dim=1)
        
        # Encoder path
        h1 = F.leaky_relu(self.conv1(x), 0.2)
        h2 = F.leaky_relu(self.conv2(h1), 0.2)
        h3 = F.leaky_relu(self.conv3(h2), 0.2)
        h4 = F.leaky_relu(self.conv4(h3), 0.2)
        h5 = F.leaky_relu(self.conv5(h4), 0.2)
        
        h5 = h5.view(h5.size(0), -1)
        mu = self.fc_mu(h5)
        log_var = self.fc_var(h5)
        
        return mu, log_var, (h1, h2, h3, h4)

# Decoder Network
class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 512 * 4 * 4)
        
        self.deconv1 = nn.ConvTranspose2d(512, 512, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(512 + 512, 256, 4, 2, 1)  # + skip connection
        self.deconv3 = nn.ConvTranspose2d(256 + 256, 128, 4, 2, 1)  # + skip connection
        self.deconv4 = nn.ConvTranspose2d(128 + 128, 64, 4, 2, 1)   # + skip connection
        self.deconv5 = nn.ConvTranspose2d(64 + 64, channels, 4, 2, 1)  # + skip connection
        
        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, z, y, skip_connections):
        h1, h2, h3, h4 = skip_connections
        
        # One-hot encode label
        y = F.one_hot(y, num_classes=num_classes).float()
        z = torch.cat([z, y], dim=1)
        
        # Project and reshape
        h = F.leaky_relu(self.fc(z), 0.2)
        h = h.view(-1, 512, 4, 4)
        
        # Decoder path with skip connections
        h = F.leaky_relu(self.deconv1(h), 0.2)
        h = torch.cat([h, h4], dim=1)
        
        h = F.leaky_relu(self.deconv2(h), 0.2)
        h = torch.cat([h, h3], dim=1)
        
        h = F.leaky_relu(self.deconv3(h), 0.2)
        h = torch.cat([h, h2], dim=1)
        
        h = F.leaky_relu(self.deconv4(h), 0.2)
        h = torch.cat([h, h1], dim=1)
        
        # Final layer with tanh activation for [-1, 1] range
        x_recon = torch.tanh(self.deconv5(h))
        
        return x_recon

# CVAE Model
class CVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(CVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, y):
        mu, log_var, skips = self.encoder(x, y)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z, y, skips)
        return x_recon, mu, log_var

# Loss function
def cvae_loss(x_recon, x, mu, log_var):
    # Reconstruction loss (MSE for stability)
    recon_loss = F.mse_loss(x_recon, x, reduction='sum')
    
    # KL divergence
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    
    return recon_loss + kl_loss, recon_loss, kl_loss

# Training function with memory management
def train_cvae(model, train_loader, epochs, device):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5)
    
    for epoch in range(epochs):
        total_loss = 0
        recon_loss = 0
        kl_loss = 0
        
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            
            optimizer.zero_grad()
            x_recon, mu, log_var = model(x, y)
            
            loss, r_loss, k_loss = cvae_loss(x_recon, x, mu, log_var)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            
            total_loss += loss.item()
            recon_loss += r_loss.item()
            kl_loss += k_loss.item()
            
            # Clear memory
            if batch_idx % 50 == 0:
                torch.cuda.empty_cache()
        
        avg_loss = total_loss / len(train_loader.dataset)
        avg_recon = recon_loss / len(train_loader.dataset)
        avg_kl = kl_loss / len(train_loader.dataset)
        
        scheduler.step(avg_loss)
        
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Recon: {avg_recon:.4f}, KL: {avg_kl:.4f}')
        
        # Save checkpoint only at the final epoch
        if epoch + 1 == epochs:
            torch.save(model.state_dict(), os.path.join(output_dir, f'cvae_checkpoint_epoch_{epoch+1}.pth'))
            print(f"Saved checkpoint at epoch {epoch+1}")
    
    return model

# Generation function
def generate_samples(model, num_samples_per_class, device):
    model.eval()
    os.makedirs(os.path.join(output_dir, "generated_samples"), exist_ok=True)
    
    for class_idx in range(num_classes):
        print(f"Generating samples for class {class_idx}...")
        class_dir = os.path.join(output_dir, "generated_samples", f"class_{class_idx}")
        os.makedirs(class_dir, exist_ok=True)
        
        # Create labels
        labels = torch.full((num_samples_per_class,), class_idx, dtype=torch.long).to(device)
        
        # Generate latent vectors
        z = torch.randn(num_samples_per_class, latent_dim).to(device)
        
        # Create dummy skip connections (zeros)
        dummy_skips = (
            torch.zeros(num_samples_per_class, 64, 64, 64).to(device),
            torch.zeros(num_samples_per_class, 128, 32, 32).to(device),
            torch.zeros(num_samples_per_class, 256, 16, 16).to(device),
            torch.zeros(num_samples_per_class, 512, 8, 8).to(device)
        )
        
        with torch.no_grad():
            generated = model.decoder(z, labels, dummy_skips)
            generated = (generated + 1) / 2  # Convert from [-1,1] to [0,1]
            
            # Save images
            for i in range(num_samples_per_class):
                img = generated[i].cpu()
                save_image(img, os.path.join(class_dir, f'sample_{i}.png'))
        
        print(f"Generated {num_samples_per_class} samples for class {class_idx}")

# Load your dataset (replace with your actual dataset loading)
# For demonstration, we'll create a dummy dataset
from torchvision.datasets import FakeData
from torchvision import transforms as T

dummy_dataset = FakeData(size=1000, transform=transform, num_classes=num_classes)
train_loader = DataLoader(dummy_dataset, batch_size=batch_size, shuffle=True)

# Initialize and train CVAE
encoder = Encoder(latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, num_classes).to(device)
cvae = CVAE(encoder, decoder).to(device)

print("Starting training...")
cvae = train_cvae(cvae, train_loader, epochs, device)

# Save final model
torch.save(cvae.state_dict(), os.path.join(output_dir, 'cvae_final.pth'))
print("Training complete. Model saved.")

# Generate 100 samples per class
print("Generating samples...")
generate_samples(cvae, num_samples_per_class=100, device=device)
print("Sample generation complete.")

Using device: cuda
Starting training...
Epoch 1/300, Loss: 2121.8276, Recon: 2121.7309, KL: 0.0966
Epoch 2/300, Loss: 1216.8069, Recon: 1216.6914, KL: 0.1155
Epoch 3/300, Loss: 731.6235, Recon: 731.5613, KL: 0.0622
Epoch 4/300, Loss: 469.7885, Recon: 469.7244, KL: 0.0641
Epoch 5/300, Loss: 315.5756, Recon: 315.5164, KL: 0.0591
Epoch 6/300, Loss: 229.1670, Recon: 229.1084, KL: 0.0587
Epoch 7/300, Loss: 172.1056, Recon: 172.0516, KL: 0.0540
Epoch 8/300, Loss: 135.4635, Recon: 135.4059, KL: 0.0576
Epoch 9/300, Loss: 116.1134, Recon: 116.0435, KL: 0.0699
Epoch 10/300, Loss: 94.6187, Recon: 94.5375, KL: 0.0812
Epoch 11/300, Loss: 85.5184, Recon: 85.4537, KL: 0.0646
Epoch 12/300, Loss: 74.3899, Recon: 74.3587, KL: 0.0312
Epoch 13/300, Loss: 66.7432, Recon: 66.7061, KL: 0.0372
Epoch 14/300, Loss: 62.2126, Recon: 62.1792, KL: 0.0335
Epoch 15/300, Loss: 57.4597, Recon: 57.4351, KL: 0.0246
Epoch 16/300, Loss: 53.8213, Recon: 53.7978, KL: 0.0236
Epoch 17/300, Loss: 50.9600, Recon: 50.9320, KL: 0.