In [None]:
##RGB--->GREY

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.models as models
from torchvision.transforms.functional import adjust_sharpness
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image
from torchvision.models import VGG16_Weights

# Clear GPU memory
torch.cuda.empty_cache()

# Set CUDA_LAUNCH_BLOCKING for debugging
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Device configuration
device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')

# Hyperparameters
latent_dim = 128
num_classes = 5
batch_size = 16
epochs = 1000
learning_rate = 5e-4
image_size = 128
channels = 1
output_dir = "FL_CVAE"
beta_max = 5.0
annealing_epochs = 50
perceptual_weight = 1.0
recon_weight = 0.5

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class with a limit
class VehicleTypeDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None, max_images=1000):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])
                    if len(self.images) >= max_images:
                        break
                if len(self.images) >= max_images:
                    break

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("L")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# Load the dataset with a limit
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform, max_images=1000)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the Encoder network
class Encoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels + num_classes, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, y):
        y = F.one_hot(y, num_classes=num_classes).float()
        y = y.unsqueeze(-1).unsqueeze(-1)
        y = y.expand(-1, -1, x.size(2), x.size(3))
        x_with_y = torch.cat([x, y], dim=1)
        
        h1 = F.relu(self.conv1(x_with_y))
        if torch.isnan(h1).any() or torch.isinf(h1).any():
            print("NaN or Inf in h1")
        h2 = F.relu(self.conv2(h1))
        if torch.isnan(h2).any() or torch.isinf(h2).any():
            print("NaN or Inf in h2")
        h3 = F.relu(self.conv3(h2))
        if torch.isnan(h3).any() or torch.isinf(h3).any():
            print("NaN or Inf in h3")
        h4 = F.relu(self.conv4(h3))
        if torch.isnan(h4).any() or torch.isinf(h4).any():
            print("NaN or Inf in h4")
        h5 = F.relu(self.conv5(h4))
        if torch.isnan(h5).any() or torch.isinf(h5).any():
            print("NaN or Inf in h5")
        h = h5.view(h5.size(0), -1)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar, (h1, h2, h3, h4, h5)

# Define the Decoder network
class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 512 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=1)
        self.deconv5 = nn.ConvTranspose2d(64, channels, kernel_size=4, stride=2, padding=1)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, z, y, skip_connections):
        h1, h2, h3, h4, h5 = skip_connections
        y = F.one_hot(y, num_classes=num_classes).float()
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = h.view(h.size(0), 512, 4, 4)
        
        h = F.relu(self.deconv1(h))
        h = torch.cat([h, h4], dim=1)
        h = F.relu(self.deconv2(h))
        h = torch.cat([h, h3], dim=1)
        h = F.relu(self.deconv3(h))
        h = torch.cat([h, h2], dim=1)
        h = F.relu(self.deconv4(h))
        h = torch.cat([h, h1], dim=1)
        x_reconstructed = torch.sigmoid(self.deconv5(h))
        if torch.isnan(x_reconstructed).any() or torch.isinf(x_reconstructed).any():
            print("NaN or Inf in x_reconstructed")
        return x_reconstructed

# Convert grayscale to RGB by duplicating channels
def grayscale_to_rgb(tensor):
    return tensor.repeat(1, 3, 1, 1)

# Conditional VAE model
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, x, y):
        z_mean, z_logvar, skip_connections = self.encoder(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y, skip_connections)
        return x_reconstructed, z_mean, z_logvar

# Load pretrained VGG16 for perceptual loss
vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad = False

def perceptual_loss(x, x_reconstructed):
    x_rgb = grayscale_to_rgb(x)
    x_reconstructed_rgb = grayscale_to_rgb(x_reconstructed)
    x_rgb_subset = x_rgb[:8]
    x_reconstructed_rgb_subset = x_reconstructed_rgb[:8]
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
    x_normalized = (x_rgb_subset - mean) / std
    x_reconstructed_normalized = (x_reconstructed_rgb_subset - mean) / std
    x_features = vgg(x_normalized)
    x_recon_features = vgg(x_reconstructed_normalized)
    return F.mse_loss(x_features, x_recon_features)

# Instantiate Encoder, Decoder, and CVAE
encoder = Encoder(latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Define optimizer and learning rate scheduler
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

# Define loss function with beta annealing
def cvae_loss(x, x_reconstructed, z_mean, z_logvar, beta=1.0, recon_weight=1.0, perceptual_weight=1.0):
    if torch.isnan(x).any() or torch.isinf(x).any():
        print("NaN or Inf detected in x")
    if torch.isnan(x_reconstructed).any() or torch.isinf(x_reconstructed).any():
        print("NaN or Inf detected in x_reconstructed")
    if torch.isnan(z_mean).any() or torch.isinf(z_mean).any():
        print("NaN or Inf detected in z_mean")
    if torch.isnan(z_logvar).any() or torch.isinf(z_logvar).any():
        print("NaN or Inf detected in z_logvar")

    recon_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + torch.clamp(z_logvar, -5, 5) - z_mean.pow(2) - torch.clamp(z_logvar, -5, 5).exp())
    percep_loss = perceptual_loss(x, x_reconstructed) * perceptual_weight
    total_loss = recon_weight * recon_loss + beta * kl_loss + percep_loss
    return total_loss, recon_loss, kl_loss, percep_loss

# Training loop with checkpointing
cvae.train()
for epoch in range(epochs):
    if epoch < annealing_epochs:
        beta = beta_max * (epoch / annealing_epochs)
    else:
        beta = beta_max

    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_percep_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
       

        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(
            data, x_reconstructed, z_mean, z_logvar, 
            beta=beta, recon_weight=recon_weight, perceptual_weight=perceptual_weight
        )
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=0.5)
        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        train_percep_loss += percep_loss.item()
        optimizer.step()

    scheduler.step()
    avg_loss = train_loss / len(train_loader.dataset)
    avg_recon_loss = train_recon_loss / len(train_loader.dataset)
    avg_kl_loss = train_kl_loss / len(train_loader.dataset)
    avg_percep_loss = train_percep_loss / len(train_loader.dataset)
    print(f'Epoch {epoch + 1}/{epochs}, Beta: {beta:.2f}, Total Loss: {avg_loss:.4f}, '
          f'Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, '
          f'Percep Loss: {avg_percep_loss:.4f}')

    if (epoch + 1) % 100 == 0:
        checkpoint_path = os.path.join(output_dir, f"cvae_vehicle_epoch_{epoch + 1}.pth")
        torch.save(cvae.state_dict(), checkpoint_path)
        print(f"Saved checkpoint at epoch {epoch + 1} to {checkpoint_path}")

        cvae.eval()
        base_dir = os.path.join(output_dir, f"generated_samples_epoch_{epoch + 1}")
        classes_to_generate = [3, 4]
        num_samples = 100
        with torch.no_grad():
            for class_label in classes_to_generate:
                label_tensor = torch.tensor([class_label]).repeat(num_samples).to(device)
                z = torch.randn(num_samples, latent_dim).to(device)
                dummy_skips = [
                    torch.randn(num_samples, 32, 64, 64).to(device) * 0.1,
                    torch.randn(num_samples, 64, 32, 32).to(device) * 0.1,
                    torch.randn(num_samples, 128, 16, 16).to(device) * 0.1,
                    torch.randn(num_samples, 256, 8, 8).to(device) * 0.1,
                    torch.randn(num_samples, 512, 4, 4).to(device) * 0.1
                ]
                generated_samples = cvae.decoder(z, label_tensor, dummy_skips)
                generated_samples_rgb = grayscale_to_rgb(generated_samples)
                class_dir = os.path.join(base_dir, str(class_label))
                os.makedirs(class_dir, exist_ok=True)
                for idx, sample in enumerate(generated_samples_rgb):
                    sample = adjust_sharpness(sample, sharpness_factor=2.0)
                    save_image(sample, os.path.join(class_dir, f"sample_{idx}.png"))
                print(f"Generated {num_samples} samples for Class {class_label} ({dataset.class_names[class_label]}) at epoch {epoch + 1}.")

        fig, axs = plt.subplots(len(classes_to_generate), 10, figsize=(20, 4))
        for row, class_label in enumerate(classes_to_generate):
            class_dir = os.path.join(base_dir, str(class_label))
            sample_files = os.listdir(class_dir)
            random_samples = np.random.choice(sample_files, 10, replace=False)
            for col, sample_file in enumerate(random_samples):
                sample_path = os.path.join(class_dir, sample_file)
                sample_image = Image.open(sample_path).convert("RGB")
                sample_image = sample_image.resize((128, 128), Image.LANCZOS)
                sample_image = np.array(sample_image) / 255.0
                ax = axs[row, col] if len(classes_to_generate) > 1 else axs[col]
                ax.imshow(sample_image)
                ax.axis('off')
                if col == 0:
                    ax.set_ylabel(dataset.class_names[class_label], rotation=90, labelpad=10)
        plt.tight_layout()
        plot_path = os.path.join(output_dir, f"synthetic_samples_classes_3_4_epoch_{epoch + 1}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved synthetic samples plot at epoch {epoch + 1} to {plot_path}")
        cvae.train()

# Final model save
final_model_path = os.path.join(output_dir, "cvae_vehicle_final.pth")
torch.save(cvae.state_dict(), final_model_path)
print(f"Saved final CVAE model to {final_model_path}")

Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 1000 images across 2 classes.
Classes: ['Hatchback', 'Other']
Epoch 1/1000, Beta: 0.00, Total Loss: 5096.8752, Recon Loss: 10193.5189, KL Loss: 538.2337, Percep Loss: 0.1158
Epoch 2/1000, Beta: 0.10, Total Loss: 4382.3090, Recon Loss: 8762.3615, KL Loss: 10.2121, Percep Loss: 0.1071
Epoch 3/1000, Beta: 0.20, Total Loss: 4330.1876, Recon Loss: 8659.9946, KL Loss: 0.4808, Percep Loss: 0.0941
Epoch 4/1000, Beta: 0.30, Total Loss: 4303.3290, Recon Loss: 8606.4371, KL Loss: 0.0896, Percep Loss: 0.0837
Epoch 5/1000, Beta: 0.40, Total Loss: 4295.4675, Recon Loss: 8590.7705, KL Loss: 0.0263, Percep Loss: 0.0717
Epoch 6/1000, Beta: 0.50, Total Loss: 4289.2458, Recon Loss: 8578.3385, KL Loss: 0.0169, Percep Loss: 0.0682
Epoch 7/1000, Beta: 0.60, Total Loss: 4281.7255

IndexError: list index out of range

In [5]:
###SAVES AT LAST EPOCH

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.models as models
from torchvision.transforms.functional import adjust_sharpness
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from PIL import Image
from torchvision.models import VGG16_Weights

# Clear GPU memory
torch.cuda.empty_cache()

# Set CUDA_LAUNCH_BLOCKING for debugging
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 128
num_classes = 5
batch_size = 16
epochs = 1000
learning_rate = 5e-4
image_size = 128
channels = 1
output_dir = "FL_CVAE"
beta_max = 5.0
annealing_epochs = 50
perceptual_weight = 1.0
recon_weight = 0.5

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Define the VehicleTypeDataset class with a limit
class VehicleTypeDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None, max_images=1000):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])
                    if len(self.images) >= max_images:
                        break
                if len(self.images) >= max_images:
                    break

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("L")
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transforms
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# Load the dataset with a limit
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=transform, max_images=1000)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the Encoder network
class Encoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(channels + num_classes, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.fc_mean = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, y):
        y = F.one_hot(y, num_classes=num_classes).float()
        y = y.unsqueeze(-1).unsqueeze(-1)
        y = y.expand(-1, -1, x.size(2), x.size(3))
        x_with_y = torch.cat([x, y], dim=1)
        
        h1 = F.relu(self.conv1(x_with_y))
        if torch.isnan(h1).any() or torch.isinf(h1).any():
            print("NaN or Inf in h1")
        h2 = F.relu(self.conv2(h1))
        if torch.isnan(h2).any() or torch.isinf(h2).any():
            print("NaN or Inf in h2")
        h3 = F.relu(self.conv3(h2))
        if torch.isnan(h3).any() or torch.isinf(h3).any():
            print("NaN or Inf in h3")
        h4 = F.relu(self.conv4(h3))
        if torch.isnan(h4).any() or torch.isinf(h4).any():
            print("NaN or Inf in h4")
        h5 = F.relu(self.conv5(h4))
        if torch.isnan(h5).any() or torch.isinf(h5).any():
            print("NaN or Inf in h5")
        h = h5.view(h5.size(0), -1)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        return z_mean, z_logvar, (h1, h2, h3, h4, h5)

# Define the Decoder network
class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 512 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=1)
        self.deconv5 = nn.ConvTranspose2d(64, channels, kernel_size=4, stride=2, padding=1)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, z, y, skip_connections):
        h1, h2, h3, h4, h5 = skip_connections
        y = F.one_hot(y, num_classes=num_classes).float()
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = h.view(h.size(0), 512, 4, 4)
        
        h = F.relu(self.deconv1(h))
        h = torch.cat([h, h4], dim=1)
        h = F.relu(self.deconv2(h))
        h = torch.cat([h, h3], dim=1)
        h = F.relu(self.deconv3(h))
        h = torch.cat([h, h2], dim=1)
        h = F.relu(self.deconv4(h))
        h = torch.cat([h, h1], dim=1)
        x_reconstructed = torch.sigmoid(self.deconv5(h))
        if torch.isnan(x_reconstructed).any() or torch.isinf(x_reconstructed).any():
            print("NaN or Inf in x_reconstructed")
        return x_reconstructed

# Convert grayscale to RGB by duplicating channels
def grayscale_to_rgb(tensor):
    return tensor.repeat(1, 3, 1, 1)

# Conditional VAE model
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, x, y):
        z_mean, z_logvar, skip_connections = self.encoder(x, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y, skip_connections)
        return x_reconstructed, z_mean, z_logvar

# Load pretrained VGG16 for perceptual loss
vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad = False

def perceptual_loss(x, x_reconstructed):
    x_rgb = grayscale_to_rgb(x)
    x_reconstructed_rgb = grayscale_to_rgb(x_reconstructed)
    x_rgb_subset = x_rgb[:8]
    x_reconstructed_rgb_subset = x_reconstructed_rgb[:8]
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
    x_normalized = (x_rgb_subset - mean) / std
    x_reconstructed_normalized = (x_reconstructed_rgb_subset - mean) / std
    x_features = vgg(x_normalized)
    x_recon_features = vgg(x_reconstructed_normalized)
    return F.mse_loss(x_features, x_recon_features)

# Instantiate Encoder, Decoder, and CVAE
encoder = Encoder(latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Define optimizer and learning rate scheduler
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

# Define loss function with beta annealing
def cvae_loss(x, x_reconstructed, z_mean, z_logvar, beta=1.0, recon_weight=1.0, perceptual_weight=1.0):
    if torch.isnan(x).any() or torch.isinf(x).any():
        print("NaN or Inf detected in x")
    if torch.isnan(x_reconstructed).any() or torch.isinf(x_reconstructed).any():
        print("NaN or Inf detected in x_reconstructed")
    if torch.isnan(z_mean).any() or torch.isinf(z_mean).any():
        print("NaN or Inf detected in z_mean")
    if torch.isnan(z_logvar).any() or torch.isinf(z_logvar).any():
        print("NaN or Inf detected in z_logvar")

    recon_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + torch.clamp(z_logvar, -5, 5) - z_mean.pow(2) - torch.clamp(z_logvar, -5, 5).exp())
    percep_loss = perceptual_loss(x, x_reconstructed) * perceptual_weight
    total_loss = recon_weight * recon_loss + beta * kl_loss + percep_loss
    return total_loss, recon_loss, kl_loss, percep_loss

# Training loop with checkpointing
cvae.train()
for epoch in range(epochs):
    if epoch < annealing_epochs:
        beta = beta_max * (epoch / annealing_epochs)
    else:
        beta = beta_max

    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_percep_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        
        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = cvae(data, labels)
        total_loss, recon_loss, kl_loss, percep_loss = cvae_loss(
            data, x_reconstructed, z_mean, z_logvar, 
            beta=beta, recon_weight=recon_weight, perceptual_weight=perceptual_weight
        )
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=0.5)
        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        train_percep_loss += percep_loss.item()
        optimizer.step()

    scheduler.step()
    avg_loss = train_loss / len(train_loader.dataset)
    avg_recon_loss = train_recon_loss / len(train_loader.dataset)
    avg_kl_loss = train_kl_loss / len(train_loader.dataset)
    avg_percep_loss = train_percep_loss / len(train_loader.dataset)
    print(f'Epoch {epoch + 1}/{epochs}, Beta: {beta:.2f}, Total Loss: {avg_loss:.4f}, '
          f'Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, '
          f'Percep Loss: {avg_percep_loss:.4f}')

    # Save checkpoint, generate samples, and create plot only in the last epoch
    if epoch + 1 == epochs:
        # Save checkpoint
        checkpoint_path = os.path.join(output_dir, f"cvae_vehicle_final.pth")
        torch.save(cvae.state_dict(), checkpoint_path)
        print(f"Saved final checkpoint at epoch {epoch + 1} to {checkpoint_path}")

        # Generate and save samples
        cvae.eval()
        base_dir = os.path.join(output_dir, f"generated_samples_final")
        classes_to_generate = [3, 4]
        num_samples = 100
        with torch.no_grad():
            for class_label in classes_to_generate:
                label_tensor = torch.tensor([class_label]).repeat(num_samples).to(device)
                z = torch.randn(num_samples, latent_dim).to(device)
                dummy_skips = [
                    torch.randn(num_samples, 32, 64, 64).to(device) * 0.1,
                    torch.randn(num_samples, 64, 32, 32).to(device) * 0.1,
                    torch.randn(num_samples, 128, 16, 16).to(device) * 0.1,
                    torch.randn(num_samples, 256, 8, 8).to(device) * 0.1,
                    torch.randn(num_samples, 512, 4, 4).to(device) * 0.1
                ]
                generated_samples = cvae.decoder(z, label_tensor, dummy_skips)
                generated_samples_rgb = grayscale_to_rgb(generated_samples)
                class_dir = os.path.join(base_dir, str(class_label))
                os.makedirs(class_dir, exist_ok=True)
                for idx, sample in enumerate(generated_samples_rgb):
                    sample = adjust_sharpness(sample, sharpness_factor=2.0)
                    save_image(sample, os.path.join(class_dir, f"sample_{idx}.png"))
                print(f"Generated {num_samples} samples for Class {class_label} ({dataset.class_names[class_label]}) at final epoch.")

        # Create and save plot
        fig, axs = plt.subplots(len(classes_to_generate), 10, figsize=(20, 4))
        for row, class_label in enumerate(classes_to_generate):
            class_dir = os.path.join(base_dir, str(class_label))
            sample_files = os.listdir(class_dir)
            random_samples = np.random.choice(sample_files, 10, replace=False)
            for col, sample_file in enumerate(random_samples):
                sample_path = os.path.join(class_dir, sample_file)
                sample_image = Image.open(sample_path).convert("RGB")
                sample_image = sample_image.resize((128, 128), Image.LANCZOS)
                sample_image = np.array(sample_image) / 255.0
                ax = axs[row, col] if len(classes_to_generate) > 1 else axs[col]
                ax.imshow(sample_image)
                ax.axis('off')
                if col == 0:
                    ax.set_ylabel(dataset.class_names[class_label], rotation=90, labelpad=10)
        plt.tight_layout()
        plot_path = os.path.join(output_dir, f"synthetic_samples_classes_3_4_final.png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved synthetic samples plot at final epoch to {plot_path}")
        cvae.train()

# Final model save (redundant since saved above, but kept for consistency)
final_model_path = os.path.join(output_dir, "cvae_vehicle_final.pth")
torch.save(cvae.state_dict(), final_model_path)
print(f"Saved final CVAE model to {final_model_path}")

Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 1000 images across 2 classes.
Classes: ['Hatchback', 'Other']
Epoch 1/1000, Beta: 0.00, Total Loss: 6630.5874, Recon Loss: 13260.9483, KL Loss: 2531.6387, Percep Loss: 0.1133
Epoch 2/1000, Beta: 0.10, Total Loss: 4385.7789, Recon Loss: 8740.7821, KL Loss: 152.8180, Percep Loss: 0.1060
Epoch 3/1000, Beta: 0.20, Total Loss: 4329.0559, Recon Loss: 8657.9318, KL Loss: 0.0045, Percep Loss: 0.0891
Epoch 4/1000, Beta: 0.30, Total Loss: 4315.0779, Recon Loss: 8629.9949, KL Loss: 0.0011, Percep Loss: 0.0802
Epoch 5/1000, Beta: 0.40, Total Loss: 4299.5715, Recon Loss: 8598.9928, KL Loss: 0.0006, Percep Loss: 0.0749
Epoch 6/1000, Beta: 0.50, Total Loss: 4285.4196, Recon Loss: 8570.7064, KL Loss: 0.0004, Percep Loss: 0.0662
Epoch 7/1000, Beta: 0.60, Total Loss: 4277.32

IndexError: list index out of range