In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub

# Hyperparameters
RESIZE = 128
original_dim = RESIZE * RESIZE * 3
intermediate_dim = 512
latent_dim = 256
num_classes = 5
batch_size = 8  # Further reduced to 4 if needed (see below)
epochs = 3000  # Temporarily reduced for debugging; set back to 3000 once confirmed working
learning_rate = 1e-4
beta_start = 1
beta_end = 10
device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')

# Print GPU information
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

# Custom Dataset for Kaggle Vehicle Type Image Dataset with Preloading
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []
        self.preloaded_images = []

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.classes.append(class_name)
                    self.class_to_idx[class_name] = len(self.classes) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

        # Preload images into memory
        print("Preloading images into memory...")
        for img_path in self.images:
            try:
                image = Image.open(img_path).convert('RGB')
                self.preloaded_images.append(image)
            except:
                print(f"Failed to preload image: {img_path}")
                self.preloaded_images.append(None)  # Placeholder for failed images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.preloaded_images[idx]
        if image is None:
            raise ValueError(f"Image at index {idx} failed to load during preloading")
        label = self.labels[idx]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Enhanced Image Grid for Plotting
def image_grid(imgs, rows, cols, class_label, height=RESIZE, width=RESIZE):
    assert len(imgs) <= rows * cols
    grid = Image.new('RGB', size=(cols * width, rows * height), color=(255, 255, 255))
    for i, img in enumerate(imgs):
        x = (i % cols) * width + 2
        y = (i // cols) * height + 2
        grid.paste(img, box=(x, y))
    return grid

# Weight initialization function
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

# Encoder with kernel_size=5
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_dim, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),    # Output: 64 x 64 x 64
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),  # Output: 128 x 32 x 32
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # Output: 256 x 16 x 16
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2), # Output: 512 x 8 x 8
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2), # Output: 512 x 4 x 4
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 1024, kernel_size=5, stride=2, padding=2),# Output: 1024 x 2 x 2
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=5, stride=1, padding=2),# Output: 1024 x 2 x 2
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )
        self.fc_hidden = nn.Linear(1024 * 2 * 2 + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.apply(init_weights)

    def forward(self, x, y):
        x = self.encoder_cnn(x)
        x = x.view(x.size(0), -1)
        x_with_y = torch.cat([x, y], dim=-1)
        h = F.relu(self.fc_hidden(x_with_y))
        h = self.dropout(h)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        z_logvar = torch.clamp(z_logvar, min=-10, max=10)
        return z_mean, z_logvar

# Decoder with kernel_size=5
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_size, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_to_cnn = nn.Linear(hidden_dim, 1024 * 2 * 2)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(1024, 1024, kernel_size=5, stride=1, padding=2), # Output: 1024 x 2 x 2
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1), # Output: 512 x 4 x 4
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=5, stride=2, padding=2, output_padding=1),  # Output: 512 x 8 x 8
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),  # Output: 256 x 16 x 16
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),  # Output: 128 x 32 x 32
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),   # Output: 64 x 64 x 64
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1),     # Output: 3 x 128 x 128
            nn.Tanh()
        )
        self.apply(init_weights)

    def forward(self, z, y):
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = self.dropout(h)
        h = F.relu(self.fc_to_cnn(h))
        h = h.view(-1, 1024, 2, 2)
        x_reconstructed = self.decoder_cnn(h)
        x_reconstructed = torch.clamp(x_reconstructed, min=-1, max=1)
        return x_reconstructed

# Conditional VAE
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, data, y):
        z_mean, z_logvar = self.encoder(data, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y)
        return x_reconstructed, z_mean, z_logvar

# Download and prepare dataset
path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
print("Path to dataset files:", path)

# Data preparation with train/validation split
train_transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE), interpolation=Image.BILINEAR),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = VehicleTypeDataset(root_dir=path, transform=train_transform)
train_size = int(0.8 * len(dataset))  # ~3834 images
val_size = len(dataset) - train_size  # ~959 images
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

# Instantiate model
input_size = (batch_size, 3, RESIZE, RESIZE)
encoder = Encoder(input_size, intermediate_dim, latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, intermediate_dim, input_size, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Optimizer with learning rate scheduler
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

# Loss function with normalization
def cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta=1.0):
    batch_size = data.size(0)
    mse_loss = F.mse_loss(x_reconstructed, data, reduction='sum') / batch_size
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp()) / batch_size
    return mse_loss + beta * kl_loss

# Training loop with gradient clipping and progress monitoring
cvae.train()
train_losses, val_losses = [], []

for epoch in range(epochs):
    print(f"Starting epoch {epoch + 1}/{epochs}", flush=True)
    beta = beta_start + (beta_end - beta_start) * (epoch / (epochs - 1)) if epochs > 1 else beta_end
    train_loss = 0
    batches_processed = 0

    for batch_idx, sample in enumerate(train_loader):
        try:
            data = sample['image'].to(device)
            labels = sample['label'].to(device)
            one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()

            optimizer.zero_grad()
            x_reconstructed, z_mean, z_logvar = cvae(data, one_hot_labels)
            loss = cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta)

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"NaN/Inf loss detected at epoch {epoch+1}, batch {batch_idx+1}", flush=True)
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item() * data.size(0)
            batches_processed += 1

            # Print progress every 100 batches
            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch {epoch + 1}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}", flush=True)

        except RuntimeError as e:
            print(f"Error at epoch {epoch + 1}, batch {batch_idx + 1}: {str(e)}", flush=True)
            if "out of memory" in str(e).lower():
                print("Out of memory error detected. Clearing cache and continuing...", flush=True)
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
            continue

    avg_train_loss = train_loss / len(train_loader.dataset)
    train_losses.append(avg_train_loss)

    # Validation loss
    cvae.eval()
    val_loss = 0
    with torch.no_grad():
        for sample in val_loader:
            data = sample['image'].to(device)
            labels = sample['label'].to(device)
            one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()
            x_reconstructed, z_mean, z_logvar = cvae(data, one_hot_labels)
            loss = cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta)
            val_loss += loss.item() * data.size(0)
    avg_val_loss = val_loss / len(val_loader.dataset) if len(val_loader.dataset) > 0 else 0
    val_losses.append(avg_val_loss)

    print(f'Epoch {epoch + 1}/{epochs}, Batches Processed: {batches_processed}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}', flush=True)
    scheduler.step()

    # Clear GPU memory
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Save model and generate samples only at the last epoch
    if epoch == epochs - 1:
        try:
            # Save the model
            model_path = "cvae_vehicle_final.pth"
            torch.save(cvae.state_dict(), model_path)
            print(f"Saved final CVAE model to {model_path}", flush=True)

            # Generate 100 samples per class
            def generate_samples_labelwise(cvae, num_samples_per_class, base_dir, latent_dim, device):
                cvae.eval()
                os.makedirs(base_dir, exist_ok=True)
                class_names = dataset.classes
                with torch.no_grad():
                    for class_label in range(num_classes):
                        label_tensor = torch.tensor([class_label]).repeat(num_samples_per_class).to(device)
                        one_hot_labels = F.one_hot(label_tensor, num_classes=num_classes).float().to(device)
                        z = torch.randn(num_samples_per_class, latent_dim).to(device)
                        generated_samples = cvae.decoder(z, one_hot_labels)

                        class_dir = os.path.join(base_dir, str(class_label))
                        os.makedirs(class_dir, exist_ok=True)
                        images = []
                        for idx, sample in enumerate(generated_samples):
                            sample = sample.cpu().detach().numpy()
                            sample = sample * 0.5 + 0.5
                            sample = np.nan_to_num(sample, nan=0.0, posinf=0.0, neginf=0.0)
                            sample = (255 * sample).astype(np.uint8)
                            sample = np.transpose(sample, (1, 2, 0))
                            pil_image = Image.fromarray(sample).convert('RGB')
                            pil_image.save(os.path.join(class_dir, f"sample_{idx}.png"))
                            if idx < 32:
                                images.append(pil_image)
                        if images:
                            grid = image_grid(images, rows=4, cols=8, class_label=class_names[class_label])
                            grid.save(os.path.join(class_dir, f"grid_{class_names[class_label]}.png"))

            base_dir = "generated_samples-arch-A5"
            generate_samples_labelwise(cvae, num_samples_per_class=100, base_dir=base_dir, latent_dim=latent_dim, device=device)
            print("Finished generating samples", flush=True)

        except Exception as e:
            print(f"Error during saving/generation at last epoch: {str(e)}", flush=True)

# Plot training and validation losses
try:
    plt.figure(figsize=(8, 6))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig("training_validation_losses-A5.png")
    plt.close()
    print("Saved loss plot to training_validation_losses-A5.png", flush=True)
except Exception as e:
    print(f"Error plotting losses: {str(e)}", flush=True)

Using device: cuda:1
GPU Name: NVIDIA RTX A5000
GPU Memory Allocated: 1632.76 MB
Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Preloading images into memory...
Starting epoch 1/3000
Epoch 1, Batch 100/479, Loss: 8330.7422
Epoch 1, Batch 200/479, Loss: 8530.8848
Epoch 1, Batch 300/479, Loss: 7411.3530
Epoch 1, Batch 400/479, Loss: 7409.0112
Epoch 1/3000, Batches Processed: 479, Train Loss: 9118.5234, Val Loss: 7856.2977
Starting epoch 2/3000
Epoch 2, Batch 100/479, Loss: 7354.0107
Epoch 2, Batch 200/479, Loss: 5924.6113
Epoch 2, Batch 300/479, Loss: 7629.7148
Epoch 2, Batch 400/479, Loss: 8198.4824
Epoch 2/3000, Batches Processed: 479, Train Loss: 7342.5527, Val Loss: 7084.9774
Starting epoch 3/3000
Epoch 3, Batch 10

In [1]:
##increasing epoches

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import kagglehub

# Hyperparameters
RESIZE = 128
original_dim = RESIZE * RESIZE * 3
intermediate_dim = 512
latent_dim = 256
num_classes = 5
batch_size = 8
previous_epochs = 3000  # Number of epochs already trained
additional_epochs = 2000  # Additional epochs to train
total_epochs = previous_epochs + additional_epochs  # 5000 epochs total
learning_rate = 1e-4
beta_start = 1
beta_end = 10
device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')

# Print GPU information
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

# Custom Dataset for Kaggle Vehicle Type Image Dataset with Preloading
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []
        self.preloaded_images = []

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.classes.append(class_name)
                    self.class_to_idx[class_name] = len(self.classes) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

        # Preload images into memory
        print("Preloading images into memory...")
        for img_path in self.images:
            try:
                image = Image.open(img_path).convert('RGB')
                self.preloaded_images.append(image)
            except:
                print(f"Failed to preload image: {img_path}")
                self.preloaded_images.append(None)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.preloaded_images[idx]
        if image is None:
            raise ValueError(f"Image at index {idx} failed to load during preloading")
        label = self.labels[idx]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Enhanced Image Grid for Plotting
def image_grid(imgs, rows, cols, class_label, height=RESIZE, width=RESIZE):
    assert len(imgs) <= rows * cols
    grid = Image.new('RGB', size=(cols * width, rows * height), color=(255, 255, 255))
    for i, img in enumerate(imgs):
        x = (i % cols) * width + 2
        y = (i // cols) * height + 2
        grid.paste(img, box=(x, y))
    return grid

# Weight initialization function
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

# Encoder with kernel_size=5
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_dim, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 1024, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )
        self.fc_hidden = nn.Linear(1024 * 2 * 2 + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.apply(init_weights)

    def forward(self, x, y):
        x = self.encoder_cnn(x)
        x = x.view(x.size(0), -1)
        x_with_y = torch.cat([x, y], dim=-1)
        h = F.relu(self.fc_hidden(x_with_y))
        h = self.dropout(h)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        z_logvar = torch.clamp(z_logvar, min=-10, max=10)
        return z_mean, z_logvar

# Decoder with kernel_size=5
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_size, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_to_cnn = nn.Linear(hidden_dim, 1024 * 2 * 2)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )
        self.apply(init_weights)

    def forward(self, z, y):
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = self.dropout(h)
        h = F.relu(self.fc_to_cnn(h))
        h = h.view(-1, 1024, 2, 2)
        x_reconstructed = self.decoder_cnn(h)
        x_reconstructed = torch.clamp(x_reconstructed, min=-1, max=1)
        return x_reconstructed

# Conditional VAE
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, data, y):
        z_mean, z_logvar = self.encoder(data, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y)
        return x_reconstructed, z_mean, z_logvar

# Download and prepare dataset
path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
print("Path to dataset files:", path)

# Data preparation with train/validation split
train_transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE), interpolation=Image.BILINEAR),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = VehicleTypeDataset(root_dir=path, transform=train_transform)
train_size = int(0.8 * len(dataset))  # ~3834 images
val_size = len(dataset) - train_size  # ~959 images
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

# Instantiate model
input_size = (batch_size, 3, RESIZE, RESIZE)
encoder = Encoder(input_size, intermediate_dim, latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, intermediate_dim, input_size, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Load the saved model checkpoint
model_path = "cvae_vehicle_final.pth"
if os.path.exists(model_path):
    cvae.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Loaded saved model from {model_path}", flush=True)
else:
    raise FileNotFoundError(f"Model checkpoint {model_path} not found. Please ensure the file exists.")

# Optimizer with learning rate scheduler
optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

# Simulate scheduler steps for the previous 3000 epochs to match the learning rate
for _ in range(previous_epochs):
    scheduler.step()

# Loss function with normalization
def cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta=1.0):
    batch_size = data.size(0)
    mse_loss = F.mse_loss(x_reconstructed, data, reduction='sum') / batch_size
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp()) / batch_size
    return mse_loss + beta * kl_loss

# Initialize loss lists with placeholders for the first 3000 epochs
# Note: If you have the actual train_losses and val_losses from the previous run, load them instead
train_losses = [0] * previous_epochs  # Placeholder
val_losses = [0] * previous_epochs    # Placeholder
# If you saved the losses previously, you can load them here:
# train_losses = np.load("train_losses.npy").tolist()
# val_losses = np.load("val_losses.npy").tolist()
# For now, we'll append the last known values for continuity in plotting
train_losses[-1] = 357.2150  # Last known train loss
val_losses[-1] = 1270.8426   # Last known val loss

# Training loop for additional 2000 epochs (3001 to 5000)
cvae.train()
for epoch in range(previous_epochs, total_epochs):
    epoch_display = epoch + 1  # Display epoch as 3001, 3002, ..., 5000
    print(f"Starting epoch {epoch_display}/{total_epochs}", flush=True)
    
    # Beta annealing (continuing from epoch 3000 where beta was already at beta_end)
    beta = beta_end  # Beta was already at 10 by epoch 3000, so we keep it there
    
    train_loss = 0
    batches_processed = 0

    for batch_idx, sample in enumerate(train_loader):
        try:
            data = sample['image'].to(device)
            labels = sample['label'].to(device)
            one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()

            optimizer.zero_grad()
            x_reconstructed, z_mean, z_logvar = cvae(data, one_hot_labels)
            loss = cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta)

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"NaN/Inf loss detected at epoch {epoch_display}, batch {batch_idx+1}", flush=True)
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item() * data.size(0)
            batches_processed += 1

            # Print progress every 100 batches
            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch {epoch_display}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}", flush=True)

        except RuntimeError as e:
            print(f"Error at epoch {epoch_display}, batch {batch_idx + 1}: {str(e)}", flush=True)
            if "out of memory" in str(e).lower():
                print("Out of memory error detected. Clearing cache and continuing...", flush=True)
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
            continue

    avg_train_loss = train_loss / len(train_loader.dataset)
    train_losses.append(avg_train_loss)

    # Validation loss
    cvae.eval()
    val_loss = 0
    with torch.no_grad():
        for sample in val_loader:
            data = sample['image'].to(device)
            labels = sample['label'].to(device)
            one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()
            x_reconstructed, z_mean, z_logvar = cvae(data, one_hot_labels)
            loss = cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta)
            val_loss += loss.item() * data.size(0)
    avg_val_loss = val_loss / len(val_loader.dataset) if len(val_loader.dataset) > 0 else 0
    val_losses.append(avg_val_loss)

    print(f'Epoch {epoch_display}/{total_epochs}, Batches Processed: {batches_processed}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}', flush=True)
    scheduler.step()

    # Clear GPU memory
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Save model and generate samples only at the last epoch (5000)
    if epoch == total_epochs - 1:
        try:
            # Save the model
            model_path = "cvae_vehicle_final_epoch5000.pth"
            torch.save(cvae.state_dict(), model_path)
            print(f"Saved final CVAE model to {model_path}", flush=True)

            # Generate 100 samples per class
            def generate_samples_labelwise(cvae, num_samples_per_class, base_dir, latent_dim, device):
                cvae.eval()
                os.makedirs(base_dir, exist_ok=True)
                class_names = dataset.classes
                with torch.no_grad():
                    for class_label in range(num_classes):
                        label_tensor = torch.tensor([class_label]).repeat(num_samples_per_class).to(device)
                        one_hot_labels = F.one_hot(label_tensor, num_classes=num_classes).float().to(device)
                        z = torch.randn(num_samples_per_class, latent_dim).to(device)
                        generated_samples = cvae.decoder(z, one_hot_labels)

                        class_dir = os.path.join(base_dir, str(class_label))
                        os.makedirs(class_dir, exist_ok=True)
                        images = []
                        for idx, sample in enumerate(generated_samples):
                            sample = sample.cpu().detach().numpy()
                            sample = sample * 0.5 + 0.5
                            sample = np.nan_to_num(sample, nan=0.0, posinf=0.0, neginf=0.0)
                            sample = (255 * sample).astype(np.uint8)
                            sample = np.transpose(sample, (1, 2, 0))
                            pil_image = Image.fromarray(sample).convert('RGB')
                            pil_image.save(os.path.join(class_dir, f"sample_{idx}.png"))
                            if idx < 32:
                                images.append(pil_image)
                        if images:
                            grid = image_grid(images, rows=4, cols=8, class_label=class_names[class_label])
                            grid.save(os.path.join(class_dir, f"grid_{class_names[class_label]}.png"))

            base_dir = "generated_samples-arch-A5-epoch5000"
            generate_samples_labelwise(cvae, num_samples_per_class=100, base_dir=base_dir, latent_dim=latent_dim, device=device)
            print("Finished generating samples", flush=True)

        except Exception as e:
            print(f"Error during saving/generation at last epoch: {str(e)}", flush=True)

# Plot training and validation losses
try:
    plt.figure(figsize=(8, 6))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig("training_validation_losses-A5-epoch5000.png")
    plt.close()
    print("Saved loss plot to training_validation_losses-A5-epoch5000.png", flush=True)
except Exception as e:
    print(f"Error plotting losses: {str(e)}", flush=True)

Using device: cuda:1
GPU Name: NVIDIA RTX A5000
GPU Memory Allocated: 0.00 MB
Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Preloading images into memory...


  cvae.load_state_dict(torch.load(model_path, map_location=device))


Loaded saved model from cvae_vehicle_final.pth
Starting epoch 3001/5000




Epoch 3001, Batch 100/479, Loss: 15672.8477
Epoch 3001, Batch 200/479, Loss: 15397.9023
Epoch 3001, Batch 300/479, Loss: 16449.3066
Epoch 3001, Batch 400/479, Loss: 16850.7617
Epoch 3001/5000, Batches Processed: 479, Train Loss: 18745.6700, Val Loss: 16731.6559
Starting epoch 3002/5000
Epoch 3002, Batch 100/479, Loss: 18578.0859
Epoch 3002, Batch 200/479, Loss: 10942.2852
Epoch 3002, Batch 300/479, Loss: 11249.3809
Epoch 3002, Batch 400/479, Loss: 14757.4131
Epoch 3002/5000, Batches Processed: 479, Train Loss: 13883.8601, Val Loss: 12317.9610
Starting epoch 3003/5000
Epoch 3003, Batch 100/479, Loss: 10846.0088
Epoch 3003, Batch 200/479, Loss: 11713.5244
Epoch 3003, Batch 300/479, Loss: 13184.1279
Epoch 3003, Batch 400/479, Loss: 12124.3184
Epoch 3003/5000, Batches Processed: 479, Train Loss: 11577.0236, Val Loss: 10782.1390
Starting epoch 3004/5000
Epoch 3004, Batch 100/479, Loss: 11554.3125
Epoch 3004, Batch 200/479, Loss: 10980.0859
Epoch 3004, Batch 300/479, Loss: 15027.5361
Epoch 3