In [1]:
#CAVE WITH NEW DATA DISTRIBUTION

In [5]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset
import torchvision.datasets as datasets
import os
import numpy as np
from PIL import Image
import random
import kagglehub
import time

# Ensure output directory exists
output_dir = "FL_VEHICLE_NON_IID_2"
os.makedirs(output_dir, exist_ok=True)

RESIZE = 128
original_dim = RESIZE * RESIZE * 3
intermediate_dim = 512
latent_dim = 256
num_classes = 5
batch_size = 8
epochs = 3000  # Set to 3000 as requested
learning_rate = 1e-4
beta_start = 1
beta_end = 10
device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Base transform for RGB Vehicle Type Dataset
base_transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize to 128x128
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize RGB channels
])

# Debug dataset directory structure
print("Inspecting dataset path:", dataset_path)
for root, dirs, files in os.walk(dataset_path):
    print(f"Root: {root}")
    print(f"Dirs: {dirs}")
    print(f"Files (first 5): {files[:5]}")
    print("-" * 50)

class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Load the Vehicle Type Dataset without transformations for splitting
try:
    dataset_no_transform = VehicleTypeDataset(root_dir=dataset_path, transform=None)
except Exception as e:
    print(f"Failed to load dataset: {e}")
    raise

# Update label_dim
label_dim = len(dataset_no_transform.class_names)
print(f"Number of classes (label_dim): {label_dim}")

# Step 1: Split dataset into train, validation, and test sets per class
validation_ratio = 0.1
test_ratio = 0.1
train_ratio = 0.8

# Separate the dataset by class using labels directly
class_datasets = [[] for _ in range(label_dim)]
for idx in range(len(dataset_no_transform)):
    label = dataset_no_transform.labels[idx]
    class_datasets[label].append(idx)

# Split each class into train, validation, and test sets
train_indices_per_class = []
val_indices_per_class = []
test_indices_per_class = []

for class_idx in range(label_dim):
    indices = class_datasets[class_idx]
    total_samples = len(indices)
    num_train = int(total_samples * train_ratio)
    num_val = int(total_samples * validation_ratio)
    num_test = total_samples - num_train - num_val

    random.shuffle(indices)

    train_indices = indices[:num_train]
    val_indices = indices[num_train:num_train + num_val]
    test_indices = indices[num_train + num_val:]

    train_indices_per_class.append(train_indices)
    val_indices_per_class.append(val_indices)
    test_indices_per_class.append(test_indices)

    print(f"Class {class_idx}: Train={len(train_indices)}, Val={len(val_indices)}, Test={len(test_indices)}")

# Verify no overlap between train, val, and test
for class_idx in range(label_dim):
    train_set = set(train_indices_per_class[class_idx])
    val_set = set(val_indices_per_class[class_idx])
    test_set = set(test_indices_per_class[class_idx])

    assert len(train_set.intersection(val_set)) == 0, f"Overlap between train and val for class {class_idx}"
    assert len(train_set.intersection(test_set)) == 0, f"Overlap between train and test for class {class_idx}"
    assert len(val_set.intersection(test_set)) == 0, f"Overlap between val and test for class {class_idx}"

# Create a new dataset instance with transformations
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=base_transform)

# Create train, val, and test datasets
train_dataset = Subset(dataset, [idx for class_indices in train_indices_per_class for idx in class_indices])
val_dataset = Subset(dataset, [idx for class_indices in val_indices_per_class for idx in class_indices])
test_dataset = Subset(dataset, [idx for class_indices in test_indices_per_class for idx in class_indices])

# Create separate datasets for each class using the original dataset
distinct_class_datasets = []
num_classes = label_dim  # Use label_dim to ensure consistency
for class_idx in range(num_classes):
    distinct_class_dataset = Subset(dataset, train_indices_per_class[class_idx])
    distinct_class_datasets.append(distinct_class_dataset)

# Verify the size of each class dataset
for i, distinct_class_dataset in enumerate(distinct_class_datasets):
    print(f"Class {i} dataset size: {len(distinct_class_dataset)}")

# Function to split a dataset into two parts
def split_dataset(dataset, split_ratio):
    train_size = int(np.round(split_ratio * len(dataset)))
    remaining_size = len(dataset) - train_size
    train_dataset, remaining_dataset = torch.utils.data.random_split(dataset, [train_size, remaining_size])
    return train_dataset, remaining_dataset

# Split each class dataset into two halves (50/50)
split_ratio = 0.5
split_datasets = []
train_class_datasets1 = []
train_class_datasets2 = []

for distinct_class_dataset in distinct_class_datasets:
    train_class_dataset1, train_class_dataset2 = split_dataset(distinct_class_dataset, split_ratio)
    split_datasets.append((train_class_dataset1, train_class_dataset2))
    train_class_datasets1.append(train_class_dataset1)
    train_class_datasets2.append(train_class_dataset2)

for i, (train_class_dataset1, train_class_dataset2) in enumerate(split_datasets):
    print(f"Class {i}:")
    print(f"  Number of samples in train_class_datasets1: {len(train_class_dataset1)}")
    print(f"  Number of samples in train_class_datasets2: {len(train_class_dataset2)}")

# Further split train_class_datasets2 into 70% and 30% parts
split_ratio = 0.7
split_datasets2 = []
train_class_datasets2_part1 = []
train_class_datasets2_part2 = []

for class_dataset in train_class_datasets2:
    part1_dataset, part2_dataset = split_dataset(class_dataset, split_ratio)
    split_datasets2.append((part1_dataset, part2_dataset))
    train_class_datasets2_part1.append(part1_dataset)
    train_class_datasets2_part2.append(part2_dataset)

for i, (part1_dataset, part2_dataset) in enumerate(split_datasets2):
    print(f"Class {i}:")
    print(f"  Number of samples in train_class_datasets2_part1 (70%): {len(part1_dataset)}")
    print(f"  Number of samples in train_class_datasets2_part2 (30%): {len(part2_dataset)}")

# Define augmentation transforms
augmentation_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
])

# Function to apply augmentation only to PIL images
def augment_image_if_needed(image):
    if isinstance(image, torch.Tensor):
        image = transforms.ToPILImage()(image)
    image = augmentation_transform(image)
    image = base_transform(image)
    return image

# Function to augment the dataset to a target length
def augment_dataset(dataset, target_length):
    augmented_samples = []
    current_length = len(dataset)
    num_samples_to_augment = target_length - current_length
    
    if num_samples_to_augment <= 0:
        return dataset
    
    for _ in range(num_samples_to_augment):
        index = random.randint(0, current_length - 1)
        image, label = dataset[index]
        augmented_image = augment_image_if_needed(image)
        augmented_samples.append((augmented_image, label))
    
    augmented_dataset = ConcatDataset([dataset, augmented_samples])
    return augmented_dataset

# Target length for all datasets
lengthiest_length = 500

# Augment each dataset to have exactly 500 samples
train_class_datasets1 = [augment_dataset(dataset, lengthiest_length) for dataset in train_class_datasets1]
train_class_datasets2_part1 = [augment_dataset(dataset, lengthiest_length) for dataset in train_class_datasets2_part1]
train_class_datasets2_part2 = [augment_dataset(dataset, lengthiest_length) for dataset in train_class_datasets2_part2]

# Print the sizes of the augmented datasets
print("Sizes of augmented datasets:")
for i, dataset in enumerate(train_class_datasets1):
    print(f"  Length of augmented train_class_datasets1[{i}]: {len(dataset)}")
for i, dataset in enumerate(train_class_datasets2_part1):
    print(f"  Length of augmented train_class_datasets2_part1[{i}]: {len(dataset)}")
for i, dataset in enumerate(train_class_datasets2_part2):
    print(f"  Length of augmented train_class_datasets2_part2[{i}]: {len(dataset)}")

train_class_datasets2= [augment_dataset(dataset, lengthiest_length) for dataset in train_class_datasets2]
print("Sizes of augmented datasets:")
for i, dataset in enumerate(train_class_datasets2):
    print(f"  Length of augmented train_class_datasets2[{i}]: {len(dataset)}")


Num_users=5
    
user_data = [] 

user_data.append(torch.utils.data.ConcatDataset([train_class_datasets1[0], train_class_datasets2[1],train_class_datasets2_part2[2],train_class_datasets2_part2[3],train_class_datasets2_part2[4]]))
user_data.append(torch.utils.data.ConcatDataset([train_class_datasets1[1],train_class_datasets2_part1[2]]))
user_data.append(torch.utils.data.ConcatDataset([train_class_datasets1[2],train_class_datasets2_part1[3]]))
user_data.append(torch.utils.data.ConcatDataset([train_class_datasets1[3],train_class_datasets2_part1[4]]))
user_data.append(torch.utils.data.ConcatDataset([train_class_datasets1[4]]))
for i, user_dataset in enumerate(user_data):
    print(f"User {i + 1}:")
    print("Number of samples in the user dataset:", len(user_dataset))

cvae_users = {}
train_losses_users = {}
val_losses_users = {}

for user_idx in range(Num_users):
    # Start timer for this user's CVAE training
    user_start_time = time.time()

    # Create user dataset
    user_dataset = ConcatDataset([train_class_datasets[i] for i in user_classes[user_idx]])
    print(f"User {user_idx + 1} dataset length: {len(user_dataset)}")

    # Validate indices
    try:
        for i in range(min(5, len(user_dataset))):
            sample, label = user_dataset[i]
            print(f"User {user_idx + 1}, Sample {i}: Label={label}, Data shape={sample.shape}")
    except Exception as e:
        print(f"Error accessing samples for User {user_idx + 1}: {e}")
        raise

    user_loader = DataLoader(user_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

    # Instantiate CVAE
    encoder = Encoder(intermediate_dim, latent_dim, num_classes).to(device)
    decoder = Decoder(latent_dim, intermediate_dim, num_classes).to(device)
    cvae = ConditionalVAE(encoder, decoder).to(device)

    # Optimizer and scheduler
    optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

    # Training loop
    checkpoint_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{user_idx + 1}')
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    cvae.train()
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        beta = beta_start + (beta_end - beta_start) * ((epoch - 1) / (epochs - 1)) if epochs > 1 else beta_end

        train_loss = 0
        batches_processed = 0

        for batch_idx, (data, labels) in enumerate(user_loader):
            try:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)

                optimizer.zero_grad()
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"NaN/Inf loss detected at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss += loss.item()
                batches_processed += 1

            except Exception as e:
                print(f"Error at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}: {e}")
                if "out of memory" in str(e).lower():
                    print("Out of memory error detected. Clearing cache...")
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                continue

        avg_train_loss = train_loss / batches_processed if batches_processed > 0 else float('inf')
        train_losses.append(avg_train_loss)

        # Validation
        cvae.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for data, labels in val_loader:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta=1.0)
                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        epoch_time = time.time() - epoch_start_time
        print(f"User {user_idx + 1}, Epoch {epoch}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {format_time(epoch_time)}")

        scheduler.step()

        # Clear GPU memory
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        # Save checkpoints, latent vectors, and decoder parameters every 500 epochs or at the end
        if epoch % 500 == 0 or epoch == epochs:
            checkpoint_path = os.path.join(checkpoint_dir, f'cvae_epoch_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': cvae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses
            }, checkpoint_path)
            print(f"Checkpoint saved for User {user_idx + 1} at epoch {epoch} to {checkpoint_path}")

            # Save decoder parameters
            decoder_dir = os.path.join(checkpoint_dir, 'decoder')
            os.makedirs(decoder_dir, exist_ok=True)
            decoder_path = os.path.join(decoder_dir, f'decoder_epoch_{epoch}.pth')
            torch.save(cvae.decoder.state_dict(), decoder_path)
            print(f"Decoder saved for User {user_idx + 1} at epoch {epoch} to {decoder_path}")

            # Save latent vectors with labels
            latent_dir = os.path.join(checkpoint_dir, f'latent_vectors_epoch_{epoch}')
            os.makedirs(latent_dir, exist_ok=True)

            cvae.eval()
            with torch.no_grad():
                latent_vectors = {cls: {'z_mean': [], 'z_logvar': [], 'labels': []} for cls in user_classes[user_idx]}
                for data, labels in user_loader:
                    data = data.to(device)
                    y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                    z_mean, z_logvar = cvae.encoder(data, y)
                    for i, label in enumerate(labels):
                        latent_vectors[label.item()]['z_mean'].append(z_mean[i].cpu())
                        latent_vectors[label.item()]['z_logvar'].append(z_logvar[i].cpu())
                        latent_vectors[label.item()]['labels'].append(label.item())

                for cls in user_classes[user_idx]:
                    if latent_vectors[cls]['z_mean']:
                        z_mean = torch.stack(latent_vectors[cls]['z_mean'])
                        z_logvar = torch.stack(latent_vectors[cls]['z_logvar'])
                        labels = torch.tensor(latent_vectors[cls]['labels'])
                        save_path = os.path.join(latent_dir, f'class_{cls}.pth')
                        torch.save({
                            'z_mean': z_mean,
                            'z_logvar': z_logvar,
                            'labels': labels
                        }, save_path)
                        print(f"Saved latent vectors for User {user_idx + 1}, Class {cls} at epoch {epoch} to {save_path}")

    # Store losses for plotting
    train_losses_users[user_idx] = train_losses
    val_losses_users[user_idx] = val_losses

    # Plot losses
    plt.figure(figsize=(6, 4))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'User {user_idx + 1} CVAE Loss')
    plt.legend()
    plt.grid(True)
    loss_plot_path = os.path.join(checkpoint_dir, 'loss_plot.png')
    plt.savefig(loss_plot_path)
    plt.close()
    print(f"Loss plot saved for User {user_idx + 1} to {loss_plot_path}")

    cvae_users[user_idx] = cvae

    user_time = time.time() - user_start_time
    print(f"Total time for User {user_idx + 1} CVAE training: {format_time(user_time)}\n")

# Step 3: Share latent vectors and decoder parameters to generate synthetic data
class_to_users = {cls: [] for cls in range(label_dim)}
for user_idx, classes in user_classes.items():
    for cls in classes:
        class_to_users[cls].append(user_idx)

# Define sharing scheme
sharing_scheme = {}
for cls in range(label_dim):
    target_users = [user_idx for user_idx in range(Num_users) if cls not in user_classes[user_idx]]
    if target_users and class_to_users[cls]:
        source_user = class_to_users[cls][0]  # First user with this class
        sharing_scheme[f'class_{cls}'] = {
            'source_user': source_user,
            'target_users': target_users,
            'share_decoder': True
        }

synthetic_datasets = [[] for _ in range(Num_users)]
num_synthetic_per_class_generate = 1000
num_synthetic_per_class_select = 500

# Use the latest latent vectors and decoder (from epoch 2)
for class_key, scheme in sharing_scheme.items():
    class_id = int(class_key.split('_')[1])
    source_user = scheme['source_user']
    target_users = scheme['target_users']

    # Load the latest latent vectors (epoch 2)
    latent_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{source_user+1}', 'latent_vectors_epoch_2')
    latent_path = os.path.join(latent_dir, f'class_{class_id}.pth')
    latent_data = torch.load(latent_path, weights_only=False)
    print(f"Loaded latent data for User {source_user+1}, Class {class_id}: z_mean shape={latent_data['z_mean'].shape}")
    z_mean_all = latent_data['z_mean'].to(device)
    z_logvar_all = latent_data['z_logvar'].to(device)

    # Load the latest decoder (epoch 2)
    decoder_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{source_user+1}', 'decoder')
    decoder_path = os.path.join(decoder_dir, 'decoder_epoch_2.pth')
    print(f"Loading decoder for User {source_user + 1}, Class {class_id}")

    # Create shared CVAE instance
    shared_cvae = ConditionalVAE(Encoder(intermediate_dim, latent_dim, num_classes), Decoder(latent_dim, intermediate_dim, num_classes)).to(device)

    if scheme['share_decoder']:
        decoder_params = torch.load(decoder_path, weights_only=False)
        shared_cvae.decoder.load_state_dict(decoder_params)
        print(f"Loaded decoder parameters: {decoder_path}")
    else:
        print(f"Warning: No decoder shared for user {source_user + 1}, Class {class_id}. Using random decoder.")

    # Generate synthetic data for all target users
    for user_idx in target_users:
        synthetic_dir = os.path.join(output_dir, f'synthetic_user_{user_idx + 1}', f'class_{class_id}')
        os.makedirs(synthetic_dir, exist_ok=True)

        print(f"Generating {num_synthetic_per_class_generate} synthetic images for User {user_idx + 1}, Class {class_id}")
        synthetic_images = []
        mean_intensities = []

        shared_cvae.eval()
        with torch.no_grad():
            for i in range(num_synthetic_per_class_generate):
                z = shared_cvae.reparameterize(z_mean_all[i % len(z_mean_all)].unsqueeze(0), 
                                               z_logvar_all[i % len(z_mean_all)].unsqueeze(0))
                y = F.one_hot(torch.tensor([class_id]), num_classes=label_dim).float().to(device)
                synthetic_img = shared_cvae.decoder(z, y).cpu()
                mean_intensity = synthetic_img.mean().item()
                synthetic_images.append(synthetic_img)
                mean_intensities.append(mean_intensity)

                if (i + 1) % 200 == 0:
                    print(f"Generated {i + 1} images for User {user_idx + 1}, Class {class_id}")

        # Select top 500 images based on mean pixel intensity
        print(f"Selecting top {num_synthetic_per_class_select} images for User {user_idx + 1}, Class {class_id}")
        sorted_indices = np.argsort(mean_intensities)[::-1]
        selected_indices = sorted_indices[:num_synthetic_per_class_select]

        # Save selected images
        for idx, img_idx in enumerate(selected_indices):
            img_path = os.path.join(synthetic_dir, f'image_{idx + 1}.png')
            try:
                img = synthetic_images[img_idx].view(3, RESIZE, RESIZE)
                img = img * 0.5 + 0.5  # Denormalize to [0, 1]
                img = img.clamp(0, 1)
                img = transforms.ToPILImage()(img)
                img.save(img_path)
                if (idx + 1) % 100 == 0 or idx == 0:
                    print(f"Saved {idx + 1} selected images for User {user_idx + 1}, Class {class_id}")
            except Exception as e:
                print(f"Error saving image {img_path}: {e}")
                continue

        print(f"Completed generating and selecting {num_synthetic_per_class_select} images for User {user_idx + 1}, Class {class_id}")

        class SyntheticDataset(Dataset):
            def __init__(self, class_label, root_dir, transform=None):
                self.class_label = class_label
                self.root_dir = root_dir
                self.transform = transform
                self.image_files = sorted([f for f in os.listdir(root_dir) if f.endswith('.png')])
                if len(self.image_files) == 0:
                    raise ValueError(f"No images found in {root_dir}")

            def __len__(self):
                return len(self.image_files)

            def __getitem__(self, idx):
                img_path = os.path.join(self.root_dir, self.image_files[idx])
                image = Image.open(img_path).convert("RGB")
                if self.transform:
                    image = self.transform(image)
                return image, self.class_label

        synthetic_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        synthetic_dataset = SyntheticDataset(class_id, synthetic_dir, transform=synthetic_transform)
        synthetic_datasets[user_idx].append(synthetic_dataset)

# Step 4: Verify the converted IID distribution
user_data = []
for user_idx in range(Num_users):
    real_data = ConcatDataset([train_class_datasets[i] for i in user_classes[user_idx]])
    if synthetic_datasets[user_idx]:
        user_data.append(ConcatDataset([real_data] + synthetic_datasets[user_idx]))
    else:
        user_data.append(real_data)

print("\n=== Verifying Converted IID Data Distribution Across Users ===")
class_counts_per_user = []
for user_idx in range(Num_users):
    user_dataset = user_data[user_idx]
    class_counts = [0] * label_dim
    for idx in range(len(user_dataset)):
        _, label = user_dataset[idx]
        class_counts[label] += 1
    class_counts_per_user.append(class_counts)
    print(f"User {user_idx + 1} (CVAE IID) Class Distribution: {class_counts}")
    total_samples = len(user_dataset)
    class_percentages = [count / total_samples * 100 if total_samples > 0 else 0 for count in class_counts]
    print(f"User {user_idx + 1} (CVAE IID) Class Percentages: {[f'{p:.2f}%' for p in class_percentages]}")

# Calculate and print total script time
total_time = time.time() - total_start_time
print(f"\nTotal time for the entire script: {format_time(total_time)}")

Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Inspecting dataset path: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Dirs: ['Vehicle Type Image Dataset (Version 2) VTID2']
Files (first 5): []
--------------------------------------------------
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1\Vehicle Type Image Dataset (Version 2) VTID2
Dirs: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Files (first 5): []
--------------------------------------------------
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1\Vehicle Type Image Dataset (Version 2) VTID2\Hatchback
Dirs: []
Files (first 5): ['PHOTO_0.jpg', 'PHOTO_1.jpg', 'PHOTO_10.jpg', 'PHOTO_100.jpg', 'PHOTO_101.jpg']
-------------------------

NameError: name 'user_classes' is not defined