In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import kagglehub
import time

# Start total script timer
total_start_time = time.time()

# Ensure output directory exists
output_dir = "FL_VEHICLE_CVAE_latent_test"
os.makedirs(output_dir, exist_ok=True)

# Hyperparameters
RESIZE = 128
original_dim = RESIZE * RESIZE * 3
intermediate_dim = 512
latent_dim = 256
num_classes = 5
batch_size = 8
epochs = 2  # Changed to 2 for testing
learning_rate = 1e-4
beta_start = 1
beta_end = 10
device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')

# Print GPU information
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Transform for CVAE (normalize to [-1, 1])
cvae_input_transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),  # [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [-1, 1]
])

# Debug dataset directory structure
print("Inspecting dataset path:", dataset_path)
for root, dirs, files in os.walk(dataset_path):
    print(f"Root: {root}")
    print(f"Dirs: {dirs}")
    print(f"Files (first 5): {files[:5]}")
    print("-" * 50)

# Custom Dataset for the Vehicle Type Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.class_names) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.class_names)}")
        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=cvae_input_transform)

# Update label_dim
label_dim = len(dataset.class_names)
print(f"Number of classes (label_dim): {label_dim}")

# Step 1: Split dataset into train, validation, and test sets per class
validation_ratio = 0.1
test_ratio = 0.1
train_ratio = 0.8

class_datasets = [[] for _ in range(label_dim)]
for idx in range(len(dataset)):
    label = dataset.labels[idx]
    class_datasets[label].append(idx)

train_indices_per_class = []
val_indices_per_class = []
test_indices_per_class = []

for class_idx in range(label_dim):
    indices = class_datasets[class_idx]
    total_samples = len(indices)
    num_train = int(total_samples * train_ratio)
    num_val = int(total_samples * validation_ratio)
    num_test = total_samples - num_train - num_val

    np.random.shuffle(indices)

    train_indices = indices[:num_train]
    val_indices = indices[num_train:num_train + num_val]
    test_indices = indices[num_train + num_val:]

    train_indices_per_class.append(train_indices)
    val_indices_per_class.append(val_indices)
    test_indices_per_class.append(test_indices)

    print(f"Class {class_idx}: Train={len(train_indices)}, Val={len(val_indices)}, Test={len(test_indices)}")

# Verify no overlap
for class_idx in range(label_dim):
    train_set = set(train_indices_per_class[class_idx])
    val_set = set(val_indices_per_class[class_idx])
    test_set = set(test_indices_per_class[class_idx])
    assert len(train_set.intersection(val_set)) == 0, f"Overlap between train and val for class {class_idx}"
    assert len(train_set.intersection(test_set)) == 0, f"Overlap between train and test for class {class_idx}"
    assert len(val_set.intersection(test_set)) == 0, f"Overlap between val and test for class {class_idx}"

# Create training dataset
train_dataset = Subset(dataset, [idx for class_indices in train_indices_per_class for idx in class_indices])
val_dataset = Subset(dataset, [idx for class_indices in val_indices_per_class for idx in class_indices])

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Subsample to 500 samples per class
target_samples_per_class = 500
train_class_datasets = []

for class_idx in range(label_dim):
    class_indices = train_indices_per_class[class_idx]
    if len(class_indices) > target_samples_per_class:
        class_indices = np.random.choice(class_indices, target_samples_per_class, replace=False).tolist()
    class_dataset = Subset(dataset, class_indices)
    train_class_datasets.append(class_dataset)
    print(f"Class {class_idx} subsampled dataset length: {len(class_dataset)}")

# Weight initialization
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

# Encoder (Convolutional)
class Encoder(nn.Module):
    def __init__(self, hidden_dim, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),    # Output: 64 x 64 x 64
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),  # Output: 128 x 32 x 32
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # Output: 256 x 16 x 16
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2), # Output: 512 x 8 x 8
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2), # Output: 512 x 4 x 4
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 1024, kernel_size=5, stride=2, padding=2),# Output: 1024 x 2 x 2
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=5, stride=1, padding=2),# Output: 1024 x 2 x 2
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )
        self.fc_hidden = nn.Linear(1024 * 2 * 2 + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.apply(init_weights)

    def forward(self, x, y):
        x = self.encoder_cnn(x)
        x = x.view(x.size(0), -1)
        x_with_y = torch.cat([x, y], dim=-1)
        h = F.relu(self.fc_hidden(x_with_y))
        h = self.dropout(h)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        z_logvar = torch.clamp(z_logvar, min=-10, max=10)
        return z_mean, z_logvar

# Decoder (Convolutional)
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_to_cnn = nn.Linear(hidden_dim, 1024 * 2 * 2)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(1024, 1024, kernel_size=5, stride=1, padding=2), # Output: 1024 x 2 x 2
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1), # Output: 512 x 4 x 4
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=5, stride=2, padding=2, output_padding=1),  # Output: 512 x 8 x 8
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),  # Output: 256 x 16 x 16
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),  # Output: 128 x 32 x 32
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),   # Output: 64 x 64 x 64
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1),     # Output: 3 x 128 x 128
            nn.Tanh()
        )
        self.apply(init_weights)

    def forward(self, z, y):
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = self.dropout(h)
        h = F.relu(self.fc_to_cnn(h))
        h = h.view(-1, 1024, 2, 2)
        x_reconstructed = self.decoder_cnn(h)
        x_reconstructed = torch.clamp(x_reconstructed, min=-1, max=1)
        return x_reconstructed

# Conditional VAE
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, data, y):
        z_mean, z_logvar = self.encoder(data, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y)
        return x_reconstructed, z_mean, z_logvar

# Loss Function
def cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta=1.0):
    batch_size = data.size(0)
    mse_loss = F.mse_loss(x_reconstructed, data, reduction='sum') / batch_size
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp()) / batch_size
    return mse_loss + beta * kl_loss

# Function to format time
def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = seconds % 60
    if hours > 0:
        return f"{hours}h {minutes}m {secs:.2f}s"
    elif minutes > 0:
        return f"{minutes}m {secs:.2f}s"
    else:
        return f"{secs:.2f}s"

# Step 2: Define user classes and train CVAEs
Num_users = 5
user_classes = {
    0: [0, 1, 3],  # User 1
    1: [1, 2, 4],  # User 2
    2: [0, 2, 3],  # User 3
    3: [1, 3, 4],  # User 4
    4: [0, 2, 4]   # User 5
}

cvae_users = {}
train_losses_users = {}
val_losses_users = {}

for user_idx in range(Num_users):
    # Start timer for this user's CVAE training
    user_start_time = time.time()

    # Create user dataset
    user_dataset = ConcatDataset([train_class_datasets[i] for i in user_classes[user_idx]])
    print(f"User {user_idx + 1} dataset length: {len(user_dataset)}")

    # Validate indices
    try:
        for i in range(min(5, len(user_dataset))):
            sample, label = user_dataset[i]
            print(f"User {user_idx + 1}, Sample {i}: Label={label}, Data shape={sample.shape}")
    except Exception as e:
        print(f"Error accessing samples for User {user_idx + 1}: {e}")
        raise

    user_loader = DataLoader(user_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

    # Instantiate CVAE
    encoder = Encoder(intermediate_dim, latent_dim, num_classes).to(device)
    decoder = Decoder(latent_dim, intermediate_dim, num_classes).to(device)
    cvae = ConditionalVAE(encoder, decoder).to(device)

    # Optimizer and scheduler
    optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

    # Training loop
    checkpoint_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{user_idx + 1}')
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    cvae.train()
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        beta = beta_start + (beta_end - beta_start) * ((epoch - 1) / (epochs - 1)) if epochs > 1 else beta_end

        train_loss = 0
        batches_processed = 0

        for batch_idx, (data, labels) in enumerate(user_loader):
            try:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)

                optimizer.zero_grad()
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"NaN/Inf loss detected at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss += loss.item()
                batches_processed += 1

            except Exception as e:
                print(f"Error at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}: {e}")
                if "out of memory" in str(e).lower():
                    print("Out of memory error detected. Clearing cache...")
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                continue

        avg_train_loss = train_loss / batches_processed if batches_processed > 0 else float('inf')
        train_losses.append(avg_train_loss)

        # Validation
        cvae.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for data, labels in val_loader:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta=1.0)
                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        epoch_time = time.time() - epoch_start_time
        print(f"User {user_idx + 1}, Epoch {epoch}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {format_time(epoch_time)}")

        scheduler.step()

        # Clear GPU memory
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        # Save checkpoints, latent vectors, and decoder parameters every 500 epochs or at the end
        if epoch % 500 == 0 or epoch == epochs:
            checkpoint_path = os.path.join(checkpoint_dir, f'cvae_epoch_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': cvae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses
            }, checkpoint_path)
            print(f"Checkpoint saved for User {user_idx + 1} at epoch {epoch} to {checkpoint_path}")

            # Save decoder parameters
            decoder_dir = os.path.join(checkpoint_dir, 'decoder')
            os.makedirs(decoder_dir, exist_ok=True)
            decoder_path = os.path.join(decoder_dir, f'decoder_epoch_{epoch}.pth')
            torch.save(cvae.decoder.state_dict(), decoder_path)
            print(f"Decoder saved for User {user_idx + 1} at epoch {epoch} to {decoder_path}")

            # Save latent vectors with labels
            latent_dir = os.path.join(checkpoint_dir, f'latent_vectors_epoch_{epoch}')
            os.makedirs(latent_dir, exist_ok=True)

            cvae.eval()
            with torch.no_grad():
                latent_vectors = {cls: {'z_mean': [], 'z_logvar': [], 'labels': []} for cls in user_classes[user_idx]}
                for data, labels in user_loader:
                    data = data.to(device)
                    y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                    z_mean, z_logvar = cvae.encoder(data, y)
                    for i, label in enumerate(labels):
                        latent_vectors[label.item()]['z_mean'].append(z_mean[i].cpu())
                        latent_vectors[label.item()]['z_logvar'].append(z_logvar[i].cpu())
                        latent_vectors[label.item()]['labels'].append(label.item())

                for cls in user_classes[user_idx]:
                    if latent_vectors[cls]['z_mean']:
                        z_mean = torch.stack(latent_vectors[cls]['z_mean'])
                        z_logvar = torch.stack(latent_vectors[cls]['z_logvar'])
                        labels = torch.tensor(latent_vectors[cls]['labels'])
                        save_path = os.path.join(latent_dir, f'class_{cls}.pth')
                        torch.save({
                            'z_mean': z_mean,
                            'z_logvar': z_logvar,
                            'labels': labels
                        }, save_path)
                        print(f"Saved latent vectors for User {user_idx + 1}, Class {cls} at epoch {epoch} to {save_path}")

    # Store losses for plotting
    train_losses_users[user_idx] = train_losses
    val_losses_users[user_idx] = val_losses

    # Plot losses
    plt.figure(figsize=(6, 4))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'User {user_idx + 1} CVAE Loss')
    plt.legend()
    plt.grid(True)
    loss_plot_path = os.path.join(checkpoint_dir, 'loss_plot.png')
    plt.savefig(loss_plot_path)
    plt.close()
    print(f"Loss plot saved for User {user_idx + 1} to {loss_plot_path}")

    cvae_users[user_idx] = cvae

    user_time = time.time() - user_start_time
    print(f"Total time for User {user_idx + 1} CVAE training: {format_time(user_time)}\n")

# Step 3: Share latent vectors and decoder parameters to generate synthetic data
class_to_users = {cls: [] for cls in range(label_dim)}
for user_idx, classes in user_classes.items():
    for cls in classes:
        class_to_users[cls].append(user_idx)

# Define sharing scheme
sharing_scheme = {}
for cls in range(label_dim):
    target_users = [user_idx for user_idx in range(Num_users) if cls not in user_classes[user_idx]]
    if target_users and class_to_users[cls]:
        source_user = class_to_users[cls][0]  # First user with this class
        sharing_scheme[f'class_{cls}'] = {
            'source_user': source_user,
            'target_users': target_users,
            'share_decoder': True
        }

synthetic_datasets = [[] for _ in range(Num_users)]
num_synthetic_per_class_generate = 1000
num_synthetic_per_class_select = 500

# Use the latest latent vectors and decoder (from epoch 2)
for class_key, scheme in sharing_scheme.items():
    class_id = int(class_key.split('_')[1])
    source_user = scheme['source_user']
    target_users = scheme['target_users']

    # Load the latest latent vectors (epoch 2)
    latent_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{source_user+1}', 'latent_vectors_epoch_2')
    latent_path = os.path.join(latent_dir, f'class_{class_id}.pth')
    latent_data = torch.load(latent_path, weights_only=False)
    print(f"Loaded latent data for User {source_user+1}, Class {class_id}: z_mean shape={latent_data['z_mean'].shape}")
    z_mean_all = latent_data['z_mean'].to(device)
    z_logvar_all = latent_data['z_logvar'].to(device)

    # Load the latest decoder (epoch 2)
    decoder_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{source_user+1}', 'decoder')
    decoder_path = os.path.join(decoder_dir, 'decoder_epoch_2.pth')
    print(f"Loading decoder for User {source_user + 1}, Class {class_id}")

    # Create shared CVAE instance
    shared_cvae = ConditionalVAE(Encoder(intermediate_dim, latent_dim, num_classes), Decoder(latent_dim, intermediate_dim, num_classes)).to(device)

    if scheme['share_decoder']:
        decoder_params = torch.load(decoder_path, weights_only=False)
        shared_cvae.decoder.load_state_dict(decoder_params)
        print(f"Loaded decoder parameters: {decoder_path}")
    else:
        print(f"Warning: No decoder shared for user {source_user + 1}, Class {class_id}. Using random decoder.")

    # Generate synthetic data for all target users
    for user_idx in target_users:
        synthetic_dir = os.path.join(output_dir, f'synthetic_user_{user_idx + 1}', f'class_{class_id}')
        os.makedirs(synthetic_dir, exist_ok=True)

        print(f"Generating {num_synthetic_per_class_generate} synthetic images for User {user_idx + 1}, Class {class_id}")
        synthetic_images = []
        mean_intensities = []

        shared_cvae.eval()
        with torch.no_grad():
            for i in range(num_synthetic_per_class_generate):
                z = shared_cvae.reparameterize(z_mean_all[i % len(z_mean_all)].unsqueeze(0), 
                                               z_logvar_all[i % len(z_mean_all)].unsqueeze(0))
                y = F.one_hot(torch.tensor([class_id]), num_classes=label_dim).float().to(device)
                synthetic_img = shared_cvae.decoder(z, y).cpu()
                mean_intensity = synthetic_img.mean().item()
                synthetic_images.append(synthetic_img)
                mean_intensities.append(mean_intensity)

                if (i + 1) % 200 == 0:
                    print(f"Generated {i + 1} images for User {user_idx + 1}, Class {class_id}")

        # Select top 500 images based on mean pixel intensity
        print(f"Selecting top {num_synthetic_per_class_select} images for User {user_idx + 1}, Class {class_id}")
        sorted_indices = np.argsort(mean_intensities)[::-1]
        selected_indices = sorted_indices[:num_synthetic_per_class_select]

        # Save selected images
        for idx, img_idx in enumerate(selected_indices):
            img_path = os.path.join(synthetic_dir, f'image_{idx + 1}.png')
            try:
                img = synthetic_images[img_idx].view(3, RESIZE, RESIZE)
                img = img * 0.5 + 0.5  # Denormalize to [0, 1]
                img = img.clamp(0, 1)
                img = transforms.ToPILImage()(img)
                img.save(img_path)
                if (idx + 1) % 100 == 0 or idx == 0:
                    print(f"Saved {idx + 1} selected images for User {user_idx + 1}, Class {class_id}")
            except Exception as e:
                print(f"Error saving image {img_path}: {e}")
                continue

        print(f"Completed generating and selecting {num_synthetic_per_class_select} images for User {user_idx + 1}, Class {class_id}")

        class SyntheticDataset(Dataset):
            def __init__(self, class_label, root_dir, transform=None):
                self.class_label = class_label
                self.root_dir = root_dir
                self.transform = transform
                self.image_files = sorted([f for f in os.listdir(root_dir) if f.endswith('.png')])
                if len(self.image_files) == 0:
                    raise ValueError(f"No images found in {root_dir}")

            def __len__(self):
                return len(self.image_files)

            def __getitem__(self, idx):
                img_path = os.path.join(self.root_dir, self.image_files[idx])
                image = Image.open(img_path).convert("RGB")
                if self.transform:
                    image = self.transform(image)
                return image, self.class_label

        synthetic_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        synthetic_dataset = SyntheticDataset(class_id, synthetic_dir, transform=synthetic_transform)
        synthetic_datasets[user_idx].append(synthetic_dataset)

# Step 4: Verify the converted IID distribution
user_data = []
for user_idx in range(Num_users):
    real_data = ConcatDataset([train_class_datasets[i] for i in user_classes[user_idx]])
    if synthetic_datasets[user_idx]:
        user_data.append(ConcatDataset([real_data] + synthetic_datasets[user_idx]))
    else:
        user_data.append(real_data)

print("\n=== Verifying Converted IID Data Distribution Across Users ===")
class_counts_per_user = []
for user_idx in range(Num_users):
    user_dataset = user_data[user_idx]
    class_counts = [0] * label_dim
    for idx in range(len(user_dataset)):
        _, label = user_dataset[idx]
        class_counts[label] += 1
    class_counts_per_user.append(class_counts)
    print(f"User {user_idx + 1} (CVAE IID) Class Distribution: {class_counts}")
    total_samples = len(user_dataset)
    class_percentages = [count / total_samples * 100 if total_samples > 0 else 0 for count in class_counts]
    print(f"User {user_idx + 1} (CVAE IID) Class Percentages: {[f'{p:.2f}%' for p in class_percentages]}")

# Calculate and print total script time
total_time = time.time() - total_start_time
print(f"\nTotal time for the entire script: {format_time(total_time)}")

In [1]:
##3000epoches

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import kagglehub
import time
import random

# Start total script timer
total_start_time = time.time()

# Ensure output directory exists
output_dir = "FL_VEHICLE_CVAE_latent"
os.makedirs(output_dir, exist_ok=True)

# Hyperparameters
RESIZE = 128
original_dim = RESIZE * RESIZE * 3
intermediate_dim = 512
latent_dim = 256
num_classes = 5
batch_size = 8
epochs = 3000
learning_rate = 1e-4
beta_start = 1
beta_end = 10
device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')

# Print GPU information
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Transform for CVAE (normalize to [-1, 1])
cvae_input_transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),  # [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [-1, 1]
])

# Debug dataset directory structure
print("Inspecting dataset path:", dataset_path)
for root, dirs, files in os.walk(dataset_path):
    print(f"Root: {root}")
    print(f"Dirs: {dirs}")
    print(f"Files (first 5): {files[:5]}")
    print("-" * 50)

# Custom Dataset for the Vehicle Type Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.class_names) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.class_names)}")
        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=cvae_input_transform)

# Update label_dim
label_dim = len(dataset.class_names)
print(f"Number of classes (label_dim): {label_dim}")

# Step 1: Split dataset into train, validation, and test sets per class
validation_ratio = 0.1
test_ratio = 0.1
train_ratio = 0.8

class_datasets = [[] for _ in range(label_dim)]
for idx in range(len(dataset)):
    label = dataset.labels[idx]
    class_datasets[label].append(idx)

train_indices_per_class = []
val_indices_per_class = []
test_indices_per_class = []

for class_idx in range(label_dim):
    indices = class_datasets[class_idx]
    total_samples = len(indices)
    num_train = int(total_samples * train_ratio)
    num_val = int(total_samples * validation_ratio)
    num_test = total_samples - num_train - num_val

    np.random.shuffle(indices)

    train_indices = indices[:num_train]
    val_indices = indices[num_train:num_train + num_val]
    test_indices = indices[num_train + num_val:]

    train_indices_per_class.append(train_indices)
    val_indices_per_class.append(val_indices)
    test_indices_per_class.append(test_indices)

    print(f"Class {class_idx}: Train={len(train_indices)}, Val={len(val_indices)}, Test={len(test_indices)}")

# Verify no overlap
for class_idx in range(label_dim):
    train_set = set(train_indices_per_class[class_idx])
    val_set = set(val_indices_per_class[class_idx])
    test_set = set(test_indices_per_class[class_idx])
    assert len(train_set.intersection(val_set)) == 0, f"Overlap between train and val for class {class_idx}"
    assert len(train_set.intersection(test_set)) == 0, f"Overlap between train and test for class {class_idx}"
    assert len(val_set.intersection(test_set)) == 0, f"Overlap between val and test for class {class_idx}"

# Create training and validation datasets
train_dataset = Subset(dataset, [idx for class_indices in train_indices_per_class for idx in class_indices])
val_dataset = Subset(dataset, [idx for class_indices in val_indices_per_class for idx in class_indices])

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Step 1.5: Ensure 500 samples per class by generating synthetic samples for underrepresented classes
target_samples_per_class = 500
train_class_datasets = []
cvae_per_class = {}

# Weight initialization
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

# Encoder (Convolutional)
class Encoder(nn.Module):
    def __init__(self, hidden_dim, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 1024, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )
        self.fc_hidden = nn.Linear(1024 * 2 * 2 + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.apply(init_weights)

    def forward(self, x, y):
        x = self.encoder_cnn(x)
        x = x.view(x.size(0), -1)
        x_with_y = torch.cat([x, y], dim=-1)
        h = F.relu(self.fc_hidden(x_with_y))
        h = self.dropout(h)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        z_logvar = torch.clamp(z_logvar, min=-10, max=10)
        return z_mean, z_logvar

# Decoder (Convolutional)
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_to_cnn = nn.Linear(hidden_dim, 1024 * 2 * 2)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )
        self.apply(init_weights)

    def forward(self, z, y):
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = self.dropout(h)
        h = F.relu(self.fc_to_cnn(h))
        h = h.view(-1, 1024, 2, 2)
        x_reconstructed = self.decoder_cnn(h)
        x_reconstructed = torch.clamp(x_reconstructed, min=-1, max=1)
        return x_reconstructed

# Conditional VAE
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, data, y):
        z_mean, z_logvar = self.encoder(data, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y)
        return x_reconstructed, z_mean, z_logvar

# Loss Function
def cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta=1.0):
    batch_size = data.size(0)
    mse_loss = F.mse_loss(x_reconstructed, data, reduction='sum') / batch_size
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp()) / batch_size
    return mse_loss + beta * kl_loss

# Function to format time
def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = seconds % 60
    if hours > 0:
        return f"{hours}h {minutes}m {secs:.2f}s"
    elif minutes > 0:
        return f"{minutes}m {secs:.2f}s"
    else:
        return f"{secs:.2f}s"

# Generate synthetic samples for classes with fewer than 500 samples
for class_idx in range(label_dim):
    class_indices = train_indices_per_class[class_idx]
    num_real_samples = len(class_indices)
    print(f"Class {class_idx} has {num_real_samples} real samples before augmentation.")

    if num_real_samples >= target_samples_per_class:
        class_indices = np.random.choice(class_indices, target_samples_per_class, replace=False).tolist()
        class_dataset = Subset(dataset, class_indices)
        train_class_datasets.append(class_dataset)
        print(f"Class {class_idx} subsampled to {len(class_dataset)} samples.")
        continue

    num_synthetic_needed = target_samples_per_class - num_real_samples
    print(f"Generating {num_synthetic_needed} synthetic samples for Class {class_idx} to reach {target_samples_per_class} samples.")

    class_dataset = Subset(dataset, class_indices)
    class_loader = DataLoader(class_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

    encoder = Encoder(intermediate_dim, latent_dim, num_classes).to(device)
    decoder = Decoder(latent_dim, intermediate_dim, num_classes).to(device)
    cvae = ConditionalVAE(encoder, decoder).to(device)

    optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

    class_checkpoint_dir = os.path.join(output_dir, f'checkpoints_cvae_class_{class_idx}')
    os.makedirs(class_checkpoint_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    cvae.train()
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        beta = beta_start + (beta_end - beta_start) * ((epoch - 1) / (epochs - 1)) if epochs > 1 else beta_end

        train_loss = 0
        batches_processed = 0

        for batch_idx, (data, labels) in enumerate(class_loader):
            try:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)

                optimizer.zero_grad()
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"NaN/Inf loss detected at epoch {epoch}, batch {batch_idx + 1} for Class {class_idx}")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss += loss.item()
                batches_processed += 1

            except Exception as e:
                print(f"Error at epoch {epoch}, batch {batch_idx + 1} for Class {class_idx}: {e}")
                if "out of memory" in str(e).lower():
                    print("Out of memory error detected. Clearing cache...")
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                continue

        avg_train_loss = train_loss / batches_processed if batches_processed > 0 else float('inf')
        train_losses.append(avg_train_loss)

        cvae.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for data, labels in val_loader:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta=1.0)
                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        epoch_time = time.time() - epoch_start_time
        print(f"Class {class_idx}, Epoch {epoch}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {format_time(epoch_time)}")

        scheduler.step()

        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if epoch % 500 == 0 or epoch == epochs:
            checkpoint_path = os.path.join(class_checkpoint_dir, f'cvae_epoch_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': cvae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses
            }, checkpoint_path)
            print(f"Checkpoint saved for Class {class_idx} at epoch {epoch} to {checkpoint_path}")

            latent_dir = os.path.join(class_checkpoint_dir, f'latent_vectors_epoch_{epoch}')
            os.makedirs(latent_dir, exist_ok=True)

            cvae.eval()
            with torch.no_grad():
                latent_vectors = {'z_mean': [], 'z_logvar': [], 'labels': []}
                for data, labels in class_loader:
                    data = data.to(device)
                    y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                    z_mean, z_logvar = cvae.encoder(data, y)
                    for i in range(len(labels)):
                        latent_vectors['z_mean'].append(z_mean[i].cpu())
                        latent_vectors['z_logvar'].append(z_logvar[i].cpu())
                        latent_vectors['labels'].append(labels[i].item())

                if latent_vectors['z_mean']:
                    z_mean = torch.stack(latent_vectors['z_mean'])
                    z_logvar = torch.stack(latent_vectors['z_logvar'])
                    labels = torch.tensor(latent_vectors['labels'])
                    save_path = os.path.join(latent_dir, f'class_{class_idx}.pth')
                    torch.save({
                        'z_mean': z_mean,
                        'z_logvar': z_logvar,
                        'labels': labels
                    }, save_path)
                    print(f"Saved latent vectors for Class {class_idx} at epoch {epoch} to {save_path}")

    cvae_per_class[class_idx] = cvae

    synthetic_dir = os.path.join(output_dir, f'synthetic_class_{class_idx}')
    os.makedirs(synthetic_dir, exist_ok=True)

    print(f"Generating {num_synthetic_needed} synthetic images for Class {class_idx}")
    synthetic_images = []
    latent_path = os.path.join(class_checkpoint_dir, 'latent_vectors_epoch_3000', f'class_{class_idx}.pth')
    latent_data = torch.load(latent_path, weights_only=False)
    z_mean_all = latent_data['z_mean'].to(device)
    z_logvar_all = latent_data['z_logvar'].to(device)

    cvae.eval()
    with torch.no_grad():
        for i in range(num_synthetic_needed):
            z = cvae.reparameterize(z_mean_all[i % len(z_mean_all)].unsqueeze(0), 
                                    z_logvar_all[i % len(z_mean_all)].unsqueeze(0))
            y = F.one_hot(torch.tensor([class_idx]), num_classes=label_dim).float().to(device)
            synthetic_img = cvae.decoder(z, y).cpu()
            synthetic_images.append(synthetic_img)

    for idx, img in enumerate(synthetic_images):
        img_path = os.path.join(synthetic_dir, f'image_{idx + 1}.png')
        try:
            img = img.view(3, RESIZE, RESIZE)
            img = img * 0.5 + 0.5
            img = img.clamp(0, 1)
            img = transforms.ToPILImage()(img)
            img.save(img_path)
            if (idx + 1) % 100 == 0 or idx == 0:
                print(f"Saved {idx + 1} synthetic images for Class {class_idx}")
        except Exception as e:
            print(f"Error saving image {img_path}: {e}")
            continue

    class SyntheticDataset(Dataset):
        def __init__(self, class_label, root_dir, transform=None):
            self.class_label = class_label
            self.root_dir = root_dir
            self.transform = transform
            self.image_files = sorted([f for f in os.listdir(root_dir) if f.endswith('.png')])
            if len(self.image_files) == 0:
                raise ValueError(f"No images found in {root_dir}")

        def __len__(self):
            return len(self.image_files)

        def __getitem__(self, idx):
            img_path = os.path.join(self.root_dir, self.image_files[idx])
            image = Image.open(img_path).convert("RGB")
            if self.transform:
                image = self.transform(image)
            return image, self.class_label

    synthetic_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    synthetic_dataset = SyntheticDataset(class_idx, synthetic_dir, transform=synthetic_transform)

    combined_dataset = ConcatDataset([class_dataset, synthetic_dataset])
    combined_indices = list(range(len(combined_dataset)))
    if len(combined_indices) > target_samples_per_class:
        combined_indices = np.random.choice(combined_indices, target_samples_per_class, replace=False).tolist()
    final_class_dataset = Subset(combined_dataset, combined_indices)
    train_class_datasets.append(final_class_dataset)
    print(f"Class {class_idx} final dataset length: {len(final_class_dataset)}")

# Step 1.6: Split train_class_datasets like the first code
# Function to split a dataset into two parts
def split_dataset(dataset, split_ratio):
    train_size = int(np.round(split_ratio * len(dataset)))
    remaining_size = len(dataset) - train_size
    train_dataset, remaining_dataset = torch.utils.data.random_split(dataset, [train_size, remaining_size])
    return train_dataset, remaining_dataset

# Split each class dataset into two halves (50/50)
split_ratio = 0.5
split_datasets = []
train_class_datasets1 = []
train_class_datasets2 = []

for class_dataset in train_class_datasets:
    train_class_dataset1, train_class_dataset2 = split_dataset(class_dataset, split_ratio)
    split_datasets.append((train_class_dataset1, train_class_dataset2))
    train_class_datasets1.append(train_class_dataset1)
    train_class_datasets2.append(train_class_dataset2)

for i, (train_class_dataset1, train_class_dataset2) in enumerate(split_datasets):
    print(f"Class {i}:")
    print(f"  Number of samples in train_class_datasets1: {len(train_class_dataset1)}")
    print(f"  Number of samples in train_class_datasets2: {len(train_class_dataset2)}")

# Further split train_class_datasets2 into 70% and 30% parts
split_ratio = 0.7
split_datasets2 = []
train_class_datasets2_part1 = []
train_class_datasets2_part2 = []

for class_dataset in train_class_datasets2:
    part1_dataset, part2_dataset = split_dataset(class_dataset, split_ratio)
    split_datasets2.append((part1_dataset, part2_dataset))
    train_class_datasets2_part1.append(part1_dataset)
    train_class_datasets2_part2.append(part2_dataset)

for i, (part1_dataset, part2_dataset) in enumerate(split_datasets2):
    print(f"Class {i}:")
    print(f"  Number of samples in train_class_datasets2_part1 (70%): {len(part1_dataset)}")
    print(f"  Number of samples in train_class_datasets2_part2 (30%): {len(part2_dataset)}")

# Step 2: Define user datasets as per the first code
Num_users = 5
user_data = []
user_classes = {}  # For synthetic data generation

# Assign datasets to users
user_data.append(ConcatDataset([train_class_datasets1[0], train_class_datasets2[1], train_class_datasets2_part2[2], train_class_datasets2_part2[3], train_class_datasets2_part2[4]]))
user_classes[0] = [0, 1, 2, 3, 4]

user_data.append(ConcatDataset([train_class_datasets1[1], train_class_datasets2_part1[2]]))
user_classes[1] = [1, 2]

user_data.append(ConcatDataset([train_class_datasets1[2], train_class_datasets2_part1[3]]))
user_classes[2] = [2, 3]

user_data.append(ConcatDataset([train_class_datasets1[3], train_class_datasets2_part1[4]]))
user_classes[3] = [3, 4]

user_data.append(ConcatDataset([train_class_datasets1[4]]))
user_classes[4] = [4]

for i, user_dataset in enumerate(user_data):
    print(f"User {i + 1}:")
    print(f"Number of samples in the user dataset: {len(user_dataset)}")

cvae_users = {}
train_losses_users = {}
val_losses_users = {}

for user_idx in range(Num_users):
    user_start_time = time.time()

    user_dataset = user_data[user_idx]
    print(f"User {user_idx + 1} dataset length: {len(user_dataset)}")

    try:
        for i in range(min(5, len(user_dataset))):
            sample, label = user_dataset[i]
            print(f"User {user_idx + 1}, Sample {i}: Label={label}, Data shape={sample.shape}")
    except Exception as e:
        print(f"Error accessing samples for User {user_idx + 1}: {e}")
        raise

    user_loader = DataLoader(user_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

    encoder = Encoder(intermediate_dim, latent_dim, num_classes).to(device)
    decoder = Decoder(latent_dim, intermediate_dim, num_classes).to(device)
    cvae = ConditionalVAE(encoder, decoder).to(device)

    optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

    checkpoint_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{user_idx + 1}')
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    cvae.train()
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        beta = beta_start + (beta_end - beta_start) * ((epoch - 1) / (epochs - 1)) if epochs > 1 else beta_end

        train_loss = 0
        batches_processed = 0

        for batch_idx, (data, labels) in enumerate(user_loader):
            try:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)

                optimizer.zero_grad()
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"NaN/Inf loss detected at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss += loss.item()
                batches_processed += 1

            except Exception as e:
                print(f"Error at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}: {e}")
                if "out of memory" in str(e).lower():
                    print("Out of memory error detected. Clearing cache...")
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                continue

        avg_train_loss = train_loss / batches_processed if batches_processed > 0 else float('inf')
        train_losses.append(avg_train_loss)

        cvae.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for data, labels in val_loader:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta=1.0)
                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        epoch_time = time.time() - epoch_start_time
        print(f"User {user_idx + 1}, Epoch {epoch}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {format_time(epoch_time)}")

        scheduler.step()

        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if epoch % 500 == 0 or epoch == epochs:
            checkpoint_path = os.path.join(checkpoint_dir, f'cvae_epoch_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': cvae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses
            }, checkpoint_path)
            print(f"Checkpoint saved for User {user_idx + 1} at epoch {epoch} to {checkpoint_path}")

            decoder_dir = os.path.join(checkpoint_dir, 'decoder')
            os.makedirs(decoder_dir, exist_ok=True)
            decoder_path = os.path.join(decoder_dir, f'decoder_epoch_{epoch}.pth')
            torch.save(cvae.decoder.state_dict(), decoder_path)
            print(f"Decoder saved for User {user_idx + 1} at epoch {epoch} to {decoder_path}")

            latent_dir = os.path.join(checkpoint_dir, f'latent_vectors_epoch_{epoch}')
            os.makedirs(latent_dir, exist_ok=True)

            cvae.eval()
            with torch.no_grad():
                latent_vectors = {cls: {'z_mean': [], 'z_logvar': [], 'labels': []} for cls in user_classes[user_idx]}
                for data, labels in user_loader:
                    data = data.to(device)
                    y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                    z_mean, z_logvar = cvae.encoder(data, y)
                    for i, label in enumerate(labels):
                        latent_vectors[label.item()]['z_mean'].append(z_mean[i].cpu())
                        latent_vectors[label.item()]['z_logvar'].append(z_logvar[i].cpu())
                        latent_vectors[label.item()]['labels'].append(label.item())

                for cls in user_classes[user_idx]:
                    if latent_vectors[cls]['z_mean']:
                        z_mean = torch.stack(latent_vectors[cls]['z_mean'])
                        z_logvar = torch.stack(latent_vectors[cls]['z_logvar'])
                        labels = torch.tensor(latent_vectors[cls]['labels'])
                        save_path = os.path.join(latent_dir, f'class_{cls}.pth')
                        torch.save({
                            'z_mean': z_mean,
                            'z_logvar': z_logvar,
                            'labels': labels
                        }, save_path)
                        print(f"Saved latent vectors for User {user_idx + 1}, Class {cls} at epoch {epoch} to {save_path}")

    train_losses_users[user_idx] = train_losses
    val_losses_users[user_idx] = val_losses

    plt.figure(figsize=(6, 4))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'User {user_idx + 1} CVAE Loss')
    plt.legend()
    plt.grid(True)
    loss_plot_path = os.path.join(checkpoint_dir, 'loss_plot.png')
    plt.savefig(loss_plot_path)
    plt.close()
    print(f"Loss plot saved for User {user_idx + 1} to {loss_plot_path}")

    cvae_users[user_idx] = cvae

    user_time = time.time() - user_start_time
    print(f"Total time for User {user_idx + 1} CVAE training: {format_time(user_time)}\n")

# Step 3: Share latent vectors and decoder parameters to generate synthetic data
class_to_users = {cls: [] for cls in range(label_dim)}
for user_idx, classes in user_classes.items():
    for cls in classes:
        class_to_users[cls].append(user_idx)

sharing_scheme = {}
for cls in range(label_dim):
    target_users = [user_idx for user_idx in range(Num_users) if cls not in user_classes[user_idx]]
    if target_users and class_to_users[cls]:
        source_user = class_to_users[cls][0]
        sharing_scheme[f'class_{cls}'] = {
            'source_user': source_user,
            'target_users': target_users,
            'share_decoder': True
        }

synthetic_datasets = [[] for _ in range(Num_users)]
num_synthetic_per_class_generate = 1000
num_synthetic_per_class_select = 500

for class_key, scheme in sharing_scheme.items():
    class_id = int(class_key.split('_')[1])
    source_user = scheme['source_user']
    target_users = scheme['target_users']

    latent_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{source_user+1}', 'latent_vectors_epoch_3000')
    latent_path = os.path.join(latent_dir, f'class_{class_id}.pth')
    latent_data = torch.load(latent_path, weights_only=False)
    print(f"Loaded latent data for User {source_user+1}, Class {class_id}: z_mean shape={latent_data['z_mean'].shape}")
    z_mean_all = latent_data['z_mean'].to(device)
    z_logvar_all = latent_data['z_logvar'].to(device)

    decoder_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{source_user+1}', 'decoder')
    decoder_path = os.path.join(decoder_dir, 'decoder_epoch_3000.pth')
    print(f"Loading decoder for User {source_user + 1}, Class {class_id}")

    shared_cvae = ConditionalVAE(Encoder(intermediate_dim, latent_dim, num_classes), Decoder(latent_dim, intermediate_dim, num_classes)).to(device)

    if scheme['share_decoder']:
        decoder_params = torch.load(decoder_path, weights_only=False)
        shared_cvae.decoder.load_state_dict(decoder_params)
        print(f"Loaded decoder parameters: {decoder_path}")
    else:
        print(f"Warning: No decoder shared for user {source_user + 1}, Class {class_id}. Using random decoder.")

    for user_idx in target_users:
        synthetic_dir = os.path.join(output_dir, f'synthetic_user_{user_idx + 1}', f'class_{class_id}')
        os.makedirs(synthetic_dir, exist_ok=True)

        print(f"Generating {num_synthetic_per_class_generate} synthetic images for User {user_idx + 1}, Class {class_id}")
        synthetic_images = []
        mean_intensities = []

        shared_cvae.eval()
        with torch.no_grad():
            for i in range(num_synthetic_per_class_generate):
                z = shared_cvae.reparameterize(z_mean_all[i % len(z_mean_all)].unsqueeze(0), 
                                               z_logvar_all[i % len(z_mean_all)].unsqueeze(0))
                y = F.one_hot(torch.tensor([class_id]), num_classes=label_dim).float().to(device)
                synthetic_img = shared_cvae.decoder(z, y).cpu()
                mean_intensity = synthetic_img.mean().item()
                synthetic_images.append(synthetic_img)
                mean_intensities.append(mean_intensity)

                if (i + 1) % 200 == 0:
                    print(f"Generated {i + 1} images for User {user_idx + 1}, Class {class_id}")

        print(f"Selecting top {num_synthetic_per_class_select} images for User {user_idx + 1}, Class {class_id}")
        sorted_indices = np.argsort(mean_intensities)[::-1]
        selected_indices = sorted_indices[:num_synthetic_per_class_select]

        for idx, img_idx in enumerate(selected_indices):
            img_path = os.path.join(synthetic_dir, f'image_{idx + 1}.png')
            try:
                img = synthetic_images[img_idx].view(3, RESIZE, RESIZE)
                img = img * 0.5 + 0.5
                img = img.clamp(0, 1)
                img = transforms.ToPILImage()(img)
                img.save(img_path)
                if (idx + 1) % 100 == 0 or idx == 0:
                    print(f"Saved {idx + 1} selected images for User {user_idx + 1}, Class {class_id}")
            except Exception as e:
                print(f"Error saving image {img_path}: {e}")
                continue

        print(f"Completed generating and selecting {num_synthetic_per_class_select} images for User {user_idx + 1}, Class {class_id}")

        class SyntheticDataset(Dataset):
            def __init__(self, class_label, root_dir, transform=None):
                self.class_label = class_label
                self.root_dir = root_dir
                self.transform = transform
                self.image_files = sorted([f for f in os.listdir(root_dir) if f.endswith('.png')])
                if len(self.image_files) == 0:
                    raise ValueError(f"No images found in {root_dir}")

            def __len__(self):
                return len(self.image_files)

            def __getitem__(self, idx):
                img_path = os.path.join(self.root_dir, self.image_files[idx])
                image = Image.open(img_path).convert("RGB")
                if self.transform:
                    image = self.transform(image)
                return image, self.class_label

        synthetic_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        synthetic_dataset = SyntheticDataset(class_id, synthetic_dir, transform=synthetic_transform)
        synthetic_datasets[user_idx].append(synthetic_dataset)

# Step 4: Verify the final non-IID distribution
final_user_data = []
for user_idx in range(Num_users):
    real_data = user_data[user_idx]
    if synthetic_datasets[user_idx]:
        final_user_data.append(ConcatDataset([real_data] + synthetic_datasets[user_idx]))
    else:
        final_user_data.append(real_data)

print("\n=== Verifying Final Non-IID Data Distribution Across Users ===")
class_counts_per_user = []
for user_idx in range(Num_users):
    user_dataset = final_user_data[user_idx]
    class_counts = [0] * label_dim
    for idx in range(len(user_dataset)):
        _, label = user_dataset[idx]
        class_counts[label] += 1
    class_counts_per_user.append(class_counts)
    print(f"User {user_idx + 1} (Non-IID) Class Distribution: {class_counts}")
    total_samples = len(user_dataset)
    class_percentages = [count / total_samples * 100 if total_samples > 0 else 0 for count in class_counts]
    print(f"User {user_idx + 1} (Non-IID) Class Percentages: {[f'{p:.2f}%' for p in class_percentages]}")

total_time = time.time() - total_start_time
print(f"\nTotal time for the entire script: {format_time(total_time)}")

In [2]:
##testing 2epoches

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import kagglehub
import time
import random

# Start total script timer
total_start_time = time.time()

# Ensure output directory exists
output_dir = "FL_VEHICLE_CVAE_latent_test2_non_iid_2"
os.makedirs(output_dir, exist_ok=True)

# Hyperparameters
RESIZE = 128
original_dim = RESIZE * RESIZE * 3
intermediate_dim = 512
latent_dim = 256
num_classes = 5
batch_size = 8
epochs = 2  # Changed from 3000 to 2
learning_rate = 1e-4
beta_start = 1
beta_end = 10
device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')

# Print GPU information
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Transform for CVAE (normalize to [-1, 1])
cvae_input_transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),  # [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [-1, 1]
])

# Debug dataset directory structure
print("Inspecting dataset path:", dataset_path)
for root, dirs, files in os.walk(dataset_path):
    print(f"Root: {root}")
    print(f"Dirs: {dirs}")
    print(f"Files (first 5): {files[:5]}")
    print("-" * 50)

# Custom Dataset for the Vehicle Type Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.class_names) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.class_names)}")
        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=cvae_input_transform)

# Update label_dim
label_dim = len(dataset.class_names)
print(f"Number of classes (label_dim): {label_dim}")

# Step 1: Split dataset into train, validation, and test sets per class
validation_ratio = 0.1
test_ratio = 0.1
train_ratio = 0.8

class_datasets = [[] for _ in range(label_dim)]
for idx in range(len(dataset)):
    label = dataset.labels[idx]
    class_datasets[label].append(idx)

train_indices_per_class = []
val_indices_per_class = []
test_indices_per_class = []

for class_idx in range(label_dim):
    indices = class_datasets[class_idx]
    total_samples = len(indices)
    num_train = int(total_samples * train_ratio)
    num_val = int(total_samples * validation_ratio)
    num_test = total_samples - num_train - num_val

    np.random.shuffle(indices)

    train_indices = indices[:num_train]
    val_indices = indices[num_train:num_train + num_val]
    test_indices = indices[num_train + num_val:]

    train_indices_per_class.append(train_indices)
    val_indices_per_class.append(val_indices)
    test_indices_per_class.append(test_indices)

    print(f"Class {class_idx}: Train={len(train_indices)}, Val={len(val_indices)}, Test={len(test_indices)}")

# Verify no overlap
for class_idx in range(label_dim):
    train_set = set(train_indices_per_class[class_idx])
    val_set = set(val_indices_per_class[class_idx])
    test_set = set(test_indices_per_class[class_idx])
    assert len(train_set.intersection(val_set)) == 0, f"Overlap between train and val for class {class_idx}"
    assert len(train_set.intersection(test_set)) == 0, f"Overlap between train and test for class {class_idx}"
    assert len(val_set.intersection(test_set)) == 0, f"Overlap between val and test for class {class_idx}"

# Create training and validation datasets
train_dataset = Subset(dataset, [idx for class_indices in train_indices_per_class for idx in class_indices])
val_dataset = Subset(dataset, [idx for class_indices in val_indices_per_class for idx in class_indices])

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Step 1.5: Ensure 500 samples per class by generating synthetic samples for underrepresented classes
target_samples_per_class = 500
train_class_datasets = []
cvae_per_class = {}

# Weight initialization
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

# Encoder (Convolutional)
class Encoder(nn.Module):
    def __init__(self, hidden_dim, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 1024, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )
        self.fc_hidden = nn.Linear(1024 * 2 * 2 + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.apply(init_weights)

    def forward(self, x, y):
        x = self.encoder_cnn(x)
        x = x.view(x.size(0), -1)
        x_with_y = torch.cat([x, y], dim=-1)
        h = F.relu(self.fc_hidden(x_with_y))
        h = self.dropout(h)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        z_logvar = torch.clamp(z_logvar, min=-10, max=10)
        return z_mean, z_logvar

# Decoder (Convolutional)
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_to_cnn = nn.Linear(hidden_dim, 1024 * 2 * 2)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )
        self.apply(init_weights)

    def forward(self, z, y):
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = self.dropout(h)
        h = F.relu(self.fc_to_cnn(h))
        h = h.view(-1, 1024, 2, 2)
        x_reconstructed = self.decoder_cnn(h)
        x_reconstructed = torch.clamp(x_reconstructed, min=-1, max=1)
        return x_reconstructed

# Conditional VAE
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, data, y):
        z_mean, z_logvar = self.encoder(data, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y)
        return x_reconstructed, z_mean, z_logvar

# Loss Function
def cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta=1.0):
    batch_size = data.size(0)
    mse_loss = F.mse_loss(x_reconstructed, data, reduction='sum') / batch_size
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp()) / batch_size
    return mse_loss + beta * kl_loss

# Function to format time
def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = seconds % 60
    if hours > 0:
        return f"{hours}h {minutes}m {secs:.2f}s"
    elif minutes > 0:
        return f"{minutes}m {secs:.2f}s"
    else:
        return f"{secs:.2f}s"

# Generate synthetic samples for classes with fewer than 500 samples
for class_idx in range(label_dim):
    class_indices = train_indices_per_class[class_idx]
    num_real_samples = len(class_indices)
    print(f"Class {class_idx} has {num_real_samples} real samples before augmentation.")

    if num_real_samples >= target_samples_per_class:
        class_indices = np.random.choice(class_indices, target_samples_per_class, replace=False).tolist()
        class_dataset = Subset(dataset, class_indices)
        train_class_datasets.append(class_dataset)
        print(f"Class {class_idx} subsampled to {len(class_dataset)} samples.")
        continue

    num_synthetic_needed = target_samples_per_class - num_real_samples
    print(f"Generating {num_synthetic_needed} synthetic samples for Class {class_idx} to reach {target_samples_per_class} samples.")

    class_dataset = Subset(dataset, class_indices)
    class_loader = DataLoader(class_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

    encoder = Encoder(intermediate_dim, latent_dim, num_classes).to(device)
    decoder = Decoder(latent_dim, intermediate_dim, num_classes).to(device)
    cvae = ConditionalVAE(encoder, decoder).to(device)

    optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

    class_checkpoint_dir = os.path.join(output_dir, f'checkpoints_cvae_class_{class_idx}')
    os.makedirs(class_checkpoint_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    cvae.train()
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        beta = beta_start + (beta_end - beta_start) * ((epoch - 1) / (epochs - 1)) if epochs > 1 else beta_end

        train_loss = 0
        batches_processed = 0

        for batch_idx, (data, labels) in enumerate(class_loader):
            try:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)

                optimizer.zero_grad()
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"NaN/Inf loss detected at epoch {epoch}, batch {batch_idx + 1} for Class {class_idx}")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss += loss.item()
                batches_processed += 1

            except Exception as e:
                print(f"Error at epoch {epoch}, batch {batch_idx + 1} for Class {class_idx}: {e}")
                if "out of memory" in str(e).lower():
                    print("Out of memory error detected. Clearing cache...")
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                continue

        avg_train_loss = train_loss / batches_processed if batches_processed > 0 else float('inf')
        train_losses.append(avg_train_loss)

        cvae.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for data, labels in val_loader:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta=1.0)
                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        epoch_time = time.time() - epoch_start_time
        print(f"Class {class_idx}, Epoch {epoch}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {format_time(epoch_time)}")

        scheduler.step()

        if device.type == 'cuda':
            torch.cuda.empty_cache()

        # Save checkpoint after epoch 1 and at the final epoch
        if epoch == 1 or epoch == epochs:
            checkpoint_path = os.path.join(class_checkpoint_dir, f'cvae_epoch_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': cvae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses
            }, checkpoint_path)
            print(f"Checkpoint saved for Class {class_idx} at epoch {epoch} to {checkpoint_path}")

            latent_dir = os.path.join(class_checkpoint_dir, f'latent_vectors_epoch_{epoch}')
            os.makedirs(latent_dir, exist_ok=True)

            cvae.eval()
            with torch.no_grad():
                latent_vectors = {'z_mean': [], 'z_logvar': [], 'labels': []}
                for data, labels in class_loader:
                    data = data.to(device)
                    y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                    z_mean, z_logvar = cvae.encoder(data, y)
                    for i in range(len(labels)):
                        latent_vectors['z_mean'].append(z_mean[i].cpu())
                        latent_vectors['z_logvar'].append(z_logvar[i].cpu())
                        latent_vectors['labels'].append(labels[i].item())

                if latent_vectors['z_mean']:
                    z_mean = torch.stack(latent_vectors['z_mean'])
                    z_logvar = torch.stack(latent_vectors['z_logvar'])
                    labels = torch.tensor(latent_vectors['labels'])
                    save_path = os.path.join(latent_dir, f'class_{class_idx}.pth')
                    torch.save({
                        'z_mean': z_mean,
                        'z_logvar': z_logvar,
                        'labels': labels
                    }, save_path)
                    print(f"Saved latent vectors for Class {class_idx} at epoch {epoch} to {save_path}")

    cvae_per_class[class_idx] = cvae

    synthetic_dir = os.path.join(output_dir, f'synthetic_class_{class_idx}')
    os.makedirs(synthetic_dir, exist_ok=True)

    print(f"Generating {num_synthetic_needed} synthetic images for Class {class_idx}")
    synthetic_images = []
    latent_path = os.path.join(class_checkpoint_dir, f'latent_vectors_epoch_{epochs}', f'class_{class_idx}.pth')
    latent_data = torch.load(latent_path, weights_only=False)
    z_mean_all = latent_data['z_mean'].to(device)
    z_logvar_all = latent_data['z_logvar'].to(device)

    cvae.eval()
    with torch.no_grad():
        for i in range(num_synthetic_needed):
            z = cvae.reparameterize(z_mean_all[i % len(z_mean_all)].unsqueeze(0), 
                                    z_logvar_all[i % len(z_mean_all)].unsqueeze(0))
            y = F.one_hot(torch.tensor([class_idx]), num_classes=label_dim).float().to(device)
            synthetic_img = cvae.decoder(z, y).cpu()
            synthetic_images.append(synthetic_img)

    for idx, img in enumerate(synthetic_images):
        img_path = os.path.join(synthetic_dir, f'image_{idx + 1}.png')
        try:
            img = img.view(3, RESIZE, RESIZE)
            img = img * 0.5 + 0.5
            img = img.clamp(0, 1)
            img = transforms.ToPILImage()(img)
            img.save(img_path)
            if (idx + 1) % 100 == 0 or idx == 0:
                print(f"Saved {idx + 1} synthetic images for Class {class_idx}")
        except Exception as e:
            print(f"Error saving image {img_path}: {e}")
            continue

    class SyntheticDataset(Dataset):
        def __init__(self, class_label, root_dir, transform=None):
            self.class_label = class_label
            self.root_dir = root_dir
            self.transform = transform
            self.image_files = sorted([f for f in os.listdir(root_dir) if f.endswith('.png')])
            if len(self.image_files) == 0:
                raise ValueError(f"No images found in {root_dir}")

        def __len__(self):
            return len(self.image_files)

        def __getitem__(self, idx):
            img_path = os.path.join(self.root_dir, self.image_files[idx])
            image = Image.open(img_path).convert("RGB")
            if self.transform:
                image = self.transform(image)
            return image, self.class_label

    synthetic_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    synthetic_dataset = SyntheticDataset(class_idx, synthetic_dir, transform=synthetic_transform)

    combined_dataset = ConcatDataset([class_dataset, synthetic_dataset])
    combined_indices = list(range(len(combined_dataset)))
    if len(combined_indices) > target_samples_per_class:
        combined_indices = np.random.choice(combined_indices, target_samples_per_class, replace=False).tolist()
    final_class_dataset = Subset(combined_dataset, combined_indices)
    train_class_datasets.append(final_class_dataset)
    print(f"Class {class_idx} final dataset length: {len(final_class_dataset)}")

# Step 1.6: Split train_class_datasets like the first code
# Function to split a dataset into two parts
def split_dataset(dataset, split_ratio):
    train_size = int(np.round(split_ratio * len(dataset)))
    remaining_size = len(dataset) - train_size
    train_dataset, remaining_dataset = torch.utils.data.random_split(dataset, [train_size, remaining_size])
    return train_dataset, remaining_dataset

# Split each class dataset into two halves (50/50)
split_ratio = 0.5
split_datasets = []
train_class_datasets1 = []
train_class_datasets2 = []

for class_dataset in train_class_datasets:
    train_class_dataset1, train_class_dataset2 = split_dataset(class_dataset, split_ratio)
    split_datasets.append((train_class_dataset1, train_class_dataset2))
    train_class_datasets1.append(train_class_dataset1)
    train_class_datasets2.append(train_class_dataset2)

for i, (train_class_dataset1, train_class_dataset2) in enumerate(split_datasets):
    print(f"Class {i}:")
    print(f"  Number of samples in train_class_datasets1: {len(train_class_dataset1)}")
    print(f"  Number of samples in train_class_datasets2: {len(train_class_dataset2)}")

# Further split train_class_datasets2 into 70% and 30% parts
split_ratio = 0.7
split_datasets2 = []
train_class_datasets2_part1 = []
train_class_datasets2_part2 = []

for class_dataset in train_class_datasets2:
    part1_dataset, part2_dataset = split_dataset(class_dataset, split_ratio)
    split_datasets2.append((part1_dataset, part2_dataset))
    train_class_datasets2_part1.append(part1_dataset)
    train_class_datasets2_part2.append(part2_dataset)

for i, (part1_dataset, part2_dataset) in enumerate(split_datasets2):
    print(f"Class {i}:")
    print(f"  Number of samples in train_class_datasets2_part1 (70%): {len(part1_dataset)}")
    print(f"  Number of samples in train_class_datasets2_part2 (30%): {len(part2_dataset)}")

# Step 2: Define user datasets as per the first code
Num_users = 5
user_data = []
user_classes = {}  # For synthetic data generation

# Assign datasets to users
user_data.append(ConcatDataset([train_class_datasets1[0], train_class_datasets2[1], train_class_datasets2_part2[2], train_class_datasets2_part2[3], train_class_datasets2_part2[4]]))
user_classes[0] = [0, 1, 2, 3, 4]

user_data.append(ConcatDataset([train_class_datasets1[1], train_class_datasets2_part1[2]]))
user_classes[1] = [1, 2]

user_data.append(ConcatDataset([train_class_datasets1[2], train_class_datasets2_part1[3]]))
user_classes[2] = [2, 3]

user_data.append(ConcatDataset([train_class_datasets1[3], train_class_datasets2_part1[4]]))
user_classes[3] = [3, 4]

user_data.append(ConcatDataset([train_class_datasets1[4]]))
user_classes[4] = [4]

for i, user_dataset in enumerate(user_data):
    print(f"User {i + 1}:")
    print(f"Number of samples in the user dataset: {len(user_dataset)}")

cvae_users = {}
train_losses_users = {}
val_losses_users = {}

for user_idx in range(Num_users):
    user_start_time = time.time()

    user_dataset = user_data[user_idx]
    print(f"User {user_idx + 1} dataset length: {len(user_dataset)}")

    try:
        for i in range(min(5, len(user_dataset))):
            sample, label = user_dataset[i]
            print(f"User {user_idx + 1}, Sample {i}: Label={label}, Data shape={sample.shape}")
    except Exception as e:
        print(f"Error accessing samples for User {user_idx + 1}: {e}")
        raise

    user_loader = DataLoader(user_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

    encoder = Encoder(intermediate_dim, latent_dim, num_classes).to(device)
    decoder = Decoder(latent_dim, intermediate_dim, num_classes).to(device)
    cvae = ConditionalVAE(encoder, decoder).to(device)

    optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

    checkpoint_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{user_idx + 1}')
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    cvae.train()
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        beta = beta_start + (beta_end - beta_start) * ((epoch - 1) / (epochs - 1)) if epochs > 1 else beta_end

        train_loss = 0
        batches_processed = 0

        for batch_idx, (data, labels) in enumerate(user_loader):
            try:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)

                optimizer.zero_grad()
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"NaN/Inf loss detected at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss += loss.item()
                batches_processed += 1

            except Exception as e:
                print(f"Error at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}: {e}")
                if "out of memory" in str(e).lower():
                    print("Out of memory error detected. Clearing cache...")
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                continue

        avg_train_loss = train_loss / batches_processed if batches_processed > 0 else float('inf')
        train_losses.append(avg_train_loss)

        cvae.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for data, labels in val_loader:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta=1.0)
                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        epoch_time = time.time() - epoch_start_time
        print(f"User {user_idx + 1}, Epoch {epoch}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {format_time(epoch_time)}")

        scheduler.step()

        if device.type == 'cuda':
            torch.cuda.empty_cache()

        # Save checkpoint after epoch 1 and at the final epoch
        if epoch == 1 or epoch == epochs:
            checkpoint_path = os.path.join(checkpoint_dir, f'cvae_epoch_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': cvae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses
            }, checkpoint_path)
            print(f"Checkpoint saved for User {user_idx + 1} at epoch {epoch} to {checkpoint_path}")

            decoder_dir = os.path.join(checkpoint_dir, 'decoder')
            os.makedirs(decoder_dir, exist_ok=True)
            decoder_path = os.path.join(decoder_dir, f'decoder_epoch_{epoch}.pth')
            torch.save(cvae.decoder.state_dict(), decoder_path)
            print(f"Decoder saved for User {user_idx + 1} at epoch {epoch} to {decoder_path}")

            latent_dir = os.path.join(checkpoint_dir, f'latent_vectors_epoch_{epoch}')
            os.makedirs(latent_dir, exist_ok=True)

            cvae.eval()
            with torch.no_grad():
                latent_vectors = {cls: {'z_mean': [], 'z_logvar': [], 'labels': []} for cls in user_classes[user_idx]}
                for data, labels in user_loader:
                    data = data.to(device)
                    y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                    z_mean, z_logvar = cvae.encoder(data, y)
                    for i, label in enumerate(labels):
                        latent_vectors[label.item()]['z_mean'].append(z_mean[i].cpu())
                        latent_vectors[label.item()]['z_logvar'].append(z_logvar[i].cpu())
                        latent_vectors[label.item()]['labels'].append(label.item())

                for cls in user_classes[user_idx]:
                    if latent_vectors[cls]['z_mean']:
                        z_mean = torch.stack(latent_vectors[cls]['z_mean'])
                        z_logvar = torch.stack(latent_vectors[cls]['z_logvar'])
                        labels = torch.tensor(latent_vectors[cls]['labels'])
                        save_path = os.path.join(latent_dir, f'class_{cls}.pth')
                        torch.save({
                            'z_mean': z_mean,
                            'z_logvar': z_logvar,
                            'labels': labels
                        }, save_path)
                        print(f"Saved latent vectors for User {user_idx + 1}, Class {cls} at epoch {epoch} to {save_path}")

    train_losses_users[user_idx] = train_losses
    val_losses_users[user_idx] = val_losses

    plt.figure(figsize=(6, 4))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'User {user_idx + 1} CVAE Loss')
    plt.legend()
    plt.grid(True)
    loss_plot_path = os.path.join(checkpoint_dir, 'loss_plot.png')
    plt.savefig(loss_plot_path)
    plt.close()
    print(f"Loss plot saved for User {user_idx + 1} to {loss_plot_path}")

    cvae_users[user_idx] = cvae

    user_time = time.time() - user_start_time
    print(f"Total time for User {user_idx + 1} CVAE training: {format_time(user_time)}\n")

# Step 3: Share latent vectors and decoder parameters to generate synthetic data
class_to_users = {cls: [] for cls in range(label_dim)}
for user_idx, classes in user_classes.items():
    for cls in classes:
        class_to_users[cls].append(user_idx)

sharing_scheme = {}
for cls in range(label_dim):
    target_users = [user_idx for user_idx in range(Num_users) if cls not in user_classes[user_idx]]
    if target_users and class_to_users[cls]:
        source_user = class_to_users[cls][0]
        sharing_scheme[f'class_{cls}'] = {
            'source_user': source_user,
            'target_users': target_users,
            'share_decoder': True
        }

synthetic_datasets = [[] for _ in range(Num_users)]
num_synthetic_per_class_generate = 1000
num_synthetic_per_class_select = 500

for class_key, scheme in sharing_scheme.items():
    class_id = int(class_key.split('_')[1])
    source_user = scheme['source_user']
    target_users = scheme['target_users']

    latent_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{source_user+1}', f'latent_vectors_epoch_{epochs}')
    latent_path = os.path.join(latent_dir, f'class_{class_id}.pth')
    latent_data = torch.load(latent_path, weights_only=False)
    print(f"Loaded latent data for User {source_user+1}, Class {class_id}: z_mean shape={latent_data['z_mean'].shape}")
    z_mean_all = latent_data['z_mean'].to(device)
    z_logvar_all = latent_data['z_logvar'].to(device)

    decoder_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{source_user+1}', 'decoder')
    decoder_path = os.path.join(decoder_dir, f'decoder_epoch_{epochs}.pth')
    print(f"Loading decoder for User {source_user + 1}, Class {class_id}")

    shared_cvae = ConditionalVAE(Encoder(intermediate_dim, latent_dim, num_classes), Decoder(latent_dim, intermediate_dim, num_classes)).to(device)

    if scheme['share_decoder']:
        decoder_params = torch.load(decoder_path, weights_only=False)
        shared_cvae.decoder.load_state_dict(decoder_params)
        print(f"Loaded decoder parameters: {decoder_path}")
    else:
        print(f"Warning: No decoder shared for user {source_user + 1}, Class {class_id}. Using random decoder.")

    for user_idx in target_users:
        synthetic_dir = os.path.join(output_dir, f'synthetic_user_{user_idx + 1}', f'class_{class_id}')
        os.makedirs(synthetic_dir, exist_ok=True)

        print(f"Generating {num_synthetic_per_class_generate} synthetic images for User {user_idx + 1}, Class {class_id}")
        synthetic_images = []
        mean_intensities = []

        shared_cvae.eval()
        with torch.no_grad():
            for i in range(num_synthetic_per_class_generate):
                z = shared_cvae.reparameterize(z_mean_all[i % len(z_mean_all)].unsqueeze(0), 
                                               z_logvar_all[i % len(z_mean_all)].unsqueeze(0))
                y = F.one_hot(torch.tensor([class_id]), num_classes=label_dim).float().to(device)
                synthetic_img = shared_cvae.decoder(z, y).cpu()
                mean_intensity = synthetic_img.mean().item()
                synthetic_images.append(synthetic_img)
                mean_intensities.append(mean_intensity)

                if (i + 1) % 200 == 0:
                    print(f"Generated {i + 1} images for User {user_idx + 1}, Class {class_id}")

        print(f"Selecting top {num_synthetic_per_class_select} images for User {user_idx + 1}, Class {class_id}")
        sorted_indices = np.argsort(mean_intensities)[::-1]
        selected_indices = sorted_indices[:num_synthetic_per_class_select]

        for idx, img_idx in enumerate(selected_indices):
            img_path = os.path.join(synthetic_dir, f'image_{idx + 1}.png')
            try:
                img = synthetic_images[img_idx].view(3, RESIZE, RESIZE)
                img = img * 0.5 + 0.5
                img = img.clamp(0, 1)
                img = transforms.ToPILImage()(img)
                img.save(img_path)
                if (idx + 1) % 100 == 0 or idx == 0:
                    print(f"Saved {idx + 1} selected images for User {user_idx + 1}, Class {class_id}")
            except Exception as e:
                print(f"Error saving image {img_path}: {e}")
                continue

        print(f"Completed generating and selecting {num_synthetic_per_class_select} images for User {user_idx + 1}, Class {class_id}")

        class SyntheticDataset(Dataset):
            def __init__(self, class_label, root_dir, transform=None):
                self.class_label = class_label
                self.root_dir = root_dir
                self.transform = transform
                self.image_files = sorted([f for f in os.listdir(root_dir) if f.endswith('.png')])
                if len(self.image_files) == 0:
                    raise ValueError(f"No images found in {root_dir}")

            def __len__(self):
                return len(self.image_files)

            def __getitem__(self, idx):
                img_path = os.path.join(self.root_dir, self.image_files[idx])
                image = Image.open(img_path).convert("RGB")
                if self.transform:
                    image = self.transform(image)
                return image, self.class_label

        synthetic_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        synthetic_dataset = SyntheticDataset(class_id, synthetic_dir, transform=synthetic_transform)
        synthetic_datasets[user_idx].append(synthetic_dataset)

# Step 4: Verify the final non-IID distribution
final_user_data = []
for user_idx in range(Num_users):
    real_data = user_data[user_idx]
    if synthetic_datasets[user_idx]:
        final_user_data.append(ConcatDataset([real_data] + synthetic_datasets[user_idx]))
    else:
        final_user_data.append(real_data)

print("\n=== Verifying Final Non-IID Data Distribution Across Users ===")
class_counts_per_user = []
for user_idx in range(Num_users):
    user_dataset = final_user_data[user_idx]
    class_counts = [0] * label_dim
    for idx in range(len(user_dataset)):
        _, label = user_dataset[idx]
        class_counts[label] += 1
    class_counts_per_user.append(class_counts)
    print(f"User {user_idx + 1} (Non-IID) Class Distribution: {class_counts}")
    total_samples = len(user_dataset)
    class_percentages = [count / total_samples * 100 if total_samples > 0 else 0 for count in class_counts]
    print(f"User {user_idx + 1} (Non-IID) Class Percentages: {[f'{p:.2f}%' for p in class_percentages]}")

total_time = time.time() - total_start_time
print(f"\nTotal time for the entire script: {format_time(total_time)}")

Using device: cuda:1
GPU Name: NVIDIA RTX A5000
GPU Memory Allocated: 0.00 MB
Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Inspecting dataset path: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Dirs: ['Vehicle Type Image Dataset (Version 2) VTID2']
Files (first 5): []
--------------------------------------------------
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1\Vehicle Type Image Dataset (Version 2) VTID2
Dirs: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Files (first 5): []
--------------------------------------------------
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1\Vehicle Type Image Dataset (Version 2) VTID2\Hatchback
Dirs: []
Files (first 5): ['PHOTO_0.jpg', 'PHOTO_1.jpg

In [1]:
##generating 1000samples

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import kagglehub
import time
import random

# Start total script timer
total_start_time = time.time()

# Ensure output directory exists
output_dir = "FL_VEHICLE_CVAE_latent_test2_noniid_2"
os.makedirs(output_dir, exist_ok=True)

# Hyperparameters
RESIZE = 128
original_dim = RESIZE * RESIZE * 3
intermediate_dim = 512
latent_dim = 256
num_classes = 5
batch_size = 8
epochs = 2  # Already set to 2 for testing
learning_rate = 1e-4
beta_start = 1
beta_end = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Print GPU information
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Transform for CVAE (normalize to [-1, 1])
cvae_input_transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),  # [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [-1, 1]
])

# Debug dataset directory structure
print("Inspecting dataset path:", dataset_path)
for root, dirs, files in os.walk(dataset_path):
    print(f"Root: {root}")
    print(f"Dirs: {dirs}")
    print(f"Files (first 5): {files[:5]}")
    print("-" * 50)

# Custom Dataset for the Vehicle Type Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.class_names) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.class_names)}")
        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=cvae_input_transform)

# Update label_dim
label_dim = len(dataset.class_names)
print(f"Number of classes (label_dim): {label_dim}")

# Step 1: Split dataset into train, validation, and test sets per class
validation_ratio = 0.1
test_ratio = 0.1
train_ratio = 0.8

class_datasets = [[] for _ in range(label_dim)]
for idx in range(len(dataset)):
    label = dataset.labels[idx]
    class_datasets[label].append(idx)

train_indices_per_class = []
val_indices_per_class = []
test_indices_per_class = []

for class_idx in range(label_dim):
    indices = class_datasets[class_idx]
    total_samples = len(indices)
    num_train = int(total_samples * train_ratio)
    num_val = int(total_samples * validation_ratio)
    num_test = total_samples - num_train - num_val

    np.random.shuffle(indices)

    train_indices = indices[:num_train]
    val_indices = indices[num_train:num_train + num_val]
    test_indices = indices[num_train + num_val:]

    train_indices_per_class.append(train_indices)
    val_indices_per_class.append(val_indices)
    test_indices_per_class.append(test_indices)

    print(f"Class {class_idx}: Train={len(train_indices)}, Val={len(val_indices)}, Test={len(test_indices)}")

# Verify no overlap
for class_idx in range(label_dim):
    train_set = set(train_indices_per_class[class_idx])
    val_set = set(val_indices_per_class[class_idx])
    test_set = set(test_indices_per_class[class_idx])
    assert len(train_set.intersection(val_set)) == 0, f"Overlap between train and val for class {class_idx}"
    assert len(train_set.intersection(test_set)) == 0, f"Overlap between train and test for class {class_idx}"
    assert len(val_set.intersection(test_set)) == 0, f"Overlap between val and test for class {class_idx}"

# Create training and validation datasets
train_dataset = Subset(dataset, [idx for class_indices in train_indices_per_class for idx in class_indices])
val_dataset = Subset(dataset, [idx for class_indices in val_indices_per_class for idx in class_indices])

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Create separate datasets for each class using the training indices
distinct_class_datasets = []
for class_idx in range(label_dim):
    distinct_class_dataset = Subset(dataset, train_indices_per_class[class_idx])
    distinct_class_datasets.append(distinct_class_dataset)

# Verify the size of each class dataset
for i, distinct_class_dataset in enumerate(distinct_class_datasets):
    print(f"Class {i} dataset size: {len(distinct_class_dataset)}")

# Function to split a dataset into two parts
def split_dataset(dataset, split_ratio):
    train_size = int(np.round(split_ratio * len(dataset)))
    remaining_size = len(dataset) - train_size
    train_dataset, remaining_dataset = torch.utils.data.random_split(dataset, [train_size, remaining_size])
    return train_dataset, remaining_dataset

# Split each class dataset into two halves (50/50)
split_ratio = 0.5
split_datasets = []
train_class_datasets1 = []
train_class_datasets2 = []

for distinct_class_dataset in distinct_class_datasets:
    train_class_dataset1, train_class_dataset2 = split_dataset(distinct_class_dataset, split_ratio)
    split_datasets.append((train_class_dataset1, train_class_dataset2))
    train_class_datasets1.append(train_class_dataset1)
    train_class_datasets2.append(train_class_dataset2)

for i, (train_class_dataset1, train_class_dataset2) in enumerate(split_datasets):
    print(f"Class {i}:")
    print(f"  Number of samples in train_class_datasets1: {len(train_class_dataset1)}")
    print(f"  Number of samples in train_class_datasets2: {len(train_class_dataset2)}")

# Further split train_class_datasets2 into 70% and 30% parts
split_ratio = 0.7
split_datasets2 = []
train_class_datasets2_part1 = []
train_class_datasets2_part2 = []

for class_dataset in train_class_datasets2:
    part1_dataset, part2_dataset = split_dataset(class_dataset, split_ratio)
    split_datasets2.append((part1_dataset, part2_dataset))
    train_class_datasets2_part1.append(part1_dataset)
    train_class_datasets2_part2.append(part2_dataset)

for i, (part1_dataset, part2_dataset) in enumerate(split_datasets2):
    print(f"Class {i}:")
    print(f"  Number of samples in train_class_datasets2_part1 (70%): {len(part1_dataset)}")
    print(f"  Number of samples in train_class_datasets2_part2 (30%): {len(part2_dataset)}")

# Define augmentation transforms
augmentation_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
])

# Function to apply augmentation only to PIL images
def augment_image_if_needed(image):
    if isinstance(image, torch.Tensor):
        # Denormalize from [-1, 1] to [0, 1] before converting to PIL
        image = image * 0.5 + 0.5
        image = transforms.ToPILImage()(image)
    image = augmentation_transform(image)
    # Reapply the CVAE transform to normalize back to [-1, 1]
    image = cvae_input_transform(image)
    return image

# Function to augment the dataset to a target length
def augment_dataset(dataset, target_length):
    augmented_samples = []
    current_length = len(dataset)
    num_samples_to_augment = target_length - current_length
    
    if num_samples_to_augment <= 0:
        return dataset
    
    for _ in range(num_samples_to_augment):
        index = random.randint(0, current_length - 1)
        image, label = dataset[index]
        augmented_image = augment_image_if_needed(image)
        augmented_samples.append((augmented_image, label))
    
    augmented_dataset = ConcatDataset([dataset, augmented_samples])
    return augmented_dataset

# Target length for all datasets
lengthiest_length = 500

# Augment each dataset to have exactly 500 samples
train_class_datasets1 = [augment_dataset(dataset, lengthiest_length) for dataset in train_class_datasets1]
train_class_datasets2_part1 = [augment_dataset(dataset, lengthiest_length) for dataset in train_class_datasets2_part1]
train_class_datasets2_part2 = [augment_dataset(dataset, lengthiest_length) for dataset in train_class_datasets2_part2]
train_class_datasets2 = [augment_dataset(dataset, lengthiest_length) for dataset in train_class_datasets2]

# Print the sizes of the augmented datasets
print("Sizes of augmented datasets:")
for i, dataset in enumerate(train_class_datasets1):
    print(f"  Length of augmented train_class_datasets1[{i}]: {len(dataset)}")
for i, dataset in enumerate(train_class_datasets2_part1):
    print(f"  Length of augmented train_class_datasets2_part1[{i}]: {len(dataset)}")
for i, dataset in enumerate(train_class_datasets2_part2):
    print(f"  Length of augmented train_class_datasets2_part2[{i}]: {len(dataset)}")
for i, dataset in enumerate(train_class_datasets2):
    print(f"  Length of augmented train_class_datasets2[{i}]: {len(dataset)}")

# Define number of users
Num_users = 5

# Assign datasets to users in a non-IID manner
user_data = []
user_classes = {}

user_data.append(ConcatDataset([train_class_datasets1[0], train_class_datasets2[1], train_class_datasets2_part2[2], train_class_datasets2_part2[3], train_class_datasets2_part2[4]]))
user_classes[0] = [0, 1, 2, 3, 4]

user_data.append(ConcatDataset([train_class_datasets1[1], train_class_datasets2_part1[2]]))
user_classes[1] = [1, 2]

user_data.append(ConcatDataset([train_class_datasets1[2], train_class_datasets2_part1[3]]))
user_classes[2] = [2, 3]

user_data.append(ConcatDataset([train_class_datasets1[3], train_class_datasets2_part1[4]]))
user_classes[3] = [3, 4]

user_data.append(ConcatDataset([train_class_datasets1[4]]))
user_classes[4] = [4]

for i, user_dataset in enumerate(user_data):
    print(f"User {i + 1}:")
    print(f"Number of samples in the user dataset: {len(user_dataset)}")

# Weight initialization
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

# Encoder (Convolutional)
class Encoder(nn.Module):
    def __init__(self, hidden_dim, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 1024, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )
        self.fc_hidden = nn.Linear(1024 * 2 * 2 + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.apply(init_weights)

    def forward(self, x, y):
        x = self.encoder_cnn(x)
        x = x.view(x.size(0), -1)
        x_with_y = torch.cat([x, y], dim=-1)
        h = F.relu(self.fc_hidden(x_with_y))
        h = self.dropout(h)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        z_logvar = torch.clamp(z_logvar, min=-10, max=10)
        return z_mean, z_logvar

# Decoder (Convolutional)
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_to_cnn = nn.Linear(hidden_dim, 1024 * 2 * 2)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )
        self.apply(init_weights)

    def forward(self, z, y):
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = self.dropout(h)
        h = F.relu(self.fc_to_cnn(h))
        h = h.view(-1, 1024, 2, 2)
        x_reconstructed = self.decoder_cnn(h)
        x_reconstructed = torch.clamp(x_reconstructed, min=-1, max=1)
        return x_reconstructed

# Conditional VAE
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, data, y):
        z_mean, z_logvar = self.encoder(data, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y)
        return x_reconstructed, z_mean, z_logvar

# Loss Function
def cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta=1.0):
    batch_size = data.size(0)
    mse_loss = F.mse_loss(x_reconstructed, data, reduction='sum') / batch_size
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp()) / batch_size
    return mse_loss + beta * kl_loss

# Function to format time
def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = seconds % 60
    if hours > 0:
        return f"{hours}h {minutes}m {secs:.2f}s"
    elif minutes > 0:
        return f"{minutes}m {secs:.2f}s"
    else:
        return f"{secs:.2f}s"

# Step 2: Train CVAEs for each user
cvae_users = {}
train_losses_users = {}
val_losses_users = {}

for user_idx in range(Num_users):
    user_start_time = time.time()

    user_dataset = user_data[user_idx]
    print(f"User {user_idx + 1} dataset length: {len(user_dataset)}")

    try:
        for i in range(min(5, len(user_dataset))):
            sample, label = user_dataset[i]
            print(f"User {user_idx + 1}, Sample {i}: Label={label}, Data shape={sample.shape}")
    except Exception as e:
        print(f"Error accessing samples for User {user_idx + 1}: {e}")
        raise

    user_loader = DataLoader(user_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

    encoder = Encoder(intermediate_dim, latent_dim, num_classes).to(device)
    decoder = Decoder(latent_dim, intermediate_dim, num_classes).to(device)
    cvae = ConditionalVAE(encoder, decoder).to(device)

    optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

    checkpoint_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{user_idx + 1}')
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    cvae.train()
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        beta = beta_start + (beta_end - beta_start) * ((epoch - 1) / (epochs - 1)) if epochs > 1 else beta_end

        train_loss = 0
        batches_processed = 0

        for batch_idx, (data, labels) in enumerate(user_loader):
            try:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)

                optimizer.zero_grad()
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"NaN/Inf loss detected at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss += loss.item()
                batches_processed += 1

            except Exception as e:
                print(f"Error at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}: {e}")
                if "out of memory" in str(e).lower():
                    print("Out of memory error detected. Clearing cache...")
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                continue

        avg_train_loss = train_loss / batches_processed if batches_processed > 0 else float('inf')
        train_losses.append(avg_train_loss)

        cvae.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for data, labels in val_loader:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                x_recon, z_mean, z_logvar = cvae(data, y)
                loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta=1.0)
                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        epoch_time = time.time() - epoch_start_time
        print(f"User {user_idx + 1}, Epoch {epoch}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {format_time(epoch_time)}")

        scheduler.step()

        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if epoch == 1 or epoch == epochs:  # Save after epoch 1 as requested
            checkpoint_path = os.path.join(checkpoint_dir, f'cvae_epoch_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': cvae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses
            }, checkpoint_path)
            print(f"Checkpoint saved for User {user_idx + 1} at epoch {epoch} to {checkpoint_path}")

            decoder_dir = os.path.join(checkpoint_dir, 'decoder')
            os.makedirs(decoder_dir, exist_ok=True)
            decoder_path = os.path.join(decoder_dir, f'decoder_epoch_{epoch}.pth')
            torch.save(cvae.decoder.state_dict(), decoder_path)
            print(f"Decoder saved for User {user_idx + 1} at epoch {epoch} to {decoder_path}")

            latent_dir = os.path.join(checkpoint_dir, f'latent_vectors_epoch_{epoch}')
            os.makedirs(latent_dir, exist_ok=True)

            cvae.eval()
            with torch.no_grad():
                latent_vectors = {cls: {'z_mean': [], 'z_logvar': [], 'labels': []} for cls in user_classes[user_idx]}
                for data, labels in user_loader:
                    data = data.to(device)
                    y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                    z_mean, z_logvar = cvae.encoder(data, y)
                    for i, label in enumerate(labels):
                        latent_vectors[label.item()]['z_mean'].append(z_mean[i].cpu())
                        latent_vectors[label.item()]['z_logvar'].append(z_logvar[i].cpu())
                        latent_vectors[label.item()]['labels'].append(label.item())

                for cls in user_classes[user_idx]:
                    if latent_vectors[cls]['z_mean']:
                        z_mean = torch.stack(latent_vectors[cls]['z_mean'])
                        z_logvar = torch.stack(latent_vectors[cls]['z_logvar'])
                        labels = torch.tensor(latent_vectors[cls]['labels'])
                        save_path = os.path.join(latent_dir, f'class_{cls}.pth')
                        torch.save({
                            'z_mean': z_mean,
                            'z_logvar': z_logvar,
                            'labels': labels
                        }, save_path)
                        print(f"Saved latent vectors for User {user_idx + 1}, Class {cls} at epoch {epoch} to {save_path}")

    train_losses_users[user_idx] = train_losses
    val_losses_users[user_idx] = val_losses

    plt.figure(figsize=(6, 4))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'User {user_idx + 1} CVAE Loss')
    plt.legend()
    plt.grid(True)
    loss_plot_path = os.path.join(checkpoint_dir, 'loss_plot.png')
    plt.savefig(loss_plot_path)
    plt.close()
    print(f"Loss plot saved for User {user_idx + 1} to {loss_plot_path}")

    cvae_users[user_idx] = cvae

    user_time = time.time() - user_start_time
    print(f"Total time for User {user_idx + 1} CVAE training: {format_time(user_time)}\n")

# Step 3: Share latent vectors and decoder parameters to generate 1000 synthetic samples per class
class_to_users = {cls: [] for cls in range(label_dim)}
for user_idx, classes in user_classes.items():
    for cls in classes:
        class_to_users[cls].append(user_idx)

sharing_scheme = {}
for cls in range(label_dim):
    target_users = [user_idx for user_idx in range(Num_users) if cls not in user_classes[user_idx]]
    if target_users and class_to_users[cls]:
        source_user = class_to_users[cls][0]
        sharing_scheme[f'class_{cls}'] = {
            'source_user': source_user,
            'target_users': target_users,
            'share_decoder': True
        }

synthetic_datasets = [[] for _ in range(Num_users)]
num_synthetic_per_class_generate = 1000  # Generate exactly 1000 samples per class

for class_key, scheme in sharing_scheme.items():
    class_id = int(class_key.split('_')[1])
    source_user = scheme['source_user']
    target_users = scheme['target_users']

    latent_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{source_user+1}', 'latent_vectors_epoch_2')
    latent_path = os.path.join(latent_dir, f'class_{class_id}.pth')
    latent_data = torch.load(latent_path, weights_only=False)
    print(f"Loaded latent data for User {source_user+1}, Class {class_id}: z_mean shape={latent_data['z_mean'].shape}")
    z_mean_all = latent_data['z_mean'].to(device)
    z_logvar_all = latent_data['z_logvar'].to(device)

    decoder_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{source_user+1}', 'decoder')
    decoder_path = os.path.join(decoder_dir, 'decoder_epoch_2.pth')
    print(f"Loading decoder for User {source_user + 1}, Class {class_id}")

    shared_cvae = ConditionalVAE(Encoder(intermediate_dim, latent_dim, num_classes), Decoder(latent_dim, intermediate_dim, num_classes)).to(device)

    if scheme['share_decoder']:
        decoder_params = torch.load(decoder_path, weights_only=False)
        shared_cvae.decoder.load_state_dict(decoder_params)
        print(f"Loaded decoder parameters: {decoder_path}")
    else:
        print(f"Warning: No decoder shared for user {source_user + 1}, Class {class_id}. Using random decoder.")

    for user_idx in target_users:
        synthetic_dir = os.path.join(output_dir, f'synthetic_user_{user_idx + 1}', f'class_{class_id}')
        os.makedirs(synthetic_dir, exist_ok=True)

        print(f"Generating {num_synthetic_per_class_generate} synthetic images for User {user_idx + 1}, Class {class_id}")
        synthetic_images = []

        shared_cvae.eval()
        with torch.no_grad():
            for i in range(num_synthetic_per_class_generate):
                z = shared_cvae.reparameterize(z_mean_all[i % len(z_mean_all)].unsqueeze(0), 
                                               z_logvar_all[i % len(z_mean_all)].unsqueeze(0))
                y = F.one_hot(torch.tensor([class_id]), num_classes=label_dim).float().to(device)
                synthetic_img = shared_cvae.decoder(z, y).cpu()
                synthetic_images.append(synthetic_img)

                if (i + 1) % 200 == 0:
                    print(f"Generated {i + 1} images for User {user_idx + 1}, Class {class_id}")

        # Save all 1000 generated images
        for idx, img in enumerate(synthetic_images):
            img_path = os.path.join(synthetic_dir, f'image_{idx + 1}.png')
            try:
                img = img.view(3, RESIZE, RESIZE)
                img = img * 0.5 + 0.5  # Denormalize to [0, 1]
                img = img.clamp(0, 1)
                img = transforms.ToPILImage()(img)
                img.save(img_path)
                if (idx + 1) % 200 == 0 or idx == 0:
                    print(f"Saved {idx + 1} images for User {user_idx + 1}, Class {class_id}")
            except Exception as e:
                print(f"Error saving image {img_path}: {e}")
                continue

        print(f"Completed generating {num_synthetic_per_class_generate} images for User {user_idx + 1}, Class {class_id}")

        class SyntheticDataset(Dataset):
            def __init__(self, class_label, root_dir, transform=None):
                self.class_label = class_label
                self.root_dir = root_dir
                self.transform = transform
                self.image_files = sorted([f for f in os.listdir(root_dir) if f.endswith('.png')])
                if len(self.image_files) == 0:
                    raise ValueError(f"No images found in {root_dir}")

            def __len__(self):
                return len(self.image_files)

            def __getitem__(self, idx):
                img_path = os.path.join(self.root_dir, self.image_files[idx])
                image = Image.open(img_path).convert("RGB")
                if self.transform:
                    image = self.transform(image)
                return image, self.class_label

        synthetic_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        synthetic_dataset = SyntheticDataset(class_id, synthetic_dir, transform=synthetic_transform)
        synthetic_datasets[user_idx].append(synthetic_dataset)

# Step 4: Verify the final non-IID distribution
final_user_data = []
for user_idx in range(Num_users):
    real_data = user_data[user_idx]
    if synthetic_datasets[user_idx]:
        final_user_data.append(ConcatDataset([real_data] + synthetic_datasets[user_idx]))
    else:
        final_user_data.append(real_data)

print("\n=== Verifying Final Non-IID Data Distribution Across Users ===")
class_counts_per_user = []
for user_idx in range(Num_users):
    user_dataset = final_user_data[user_idx]
    class_counts = [0] * label_dim
    for idx in range(len(user_dataset)):
        _, label = user_dataset[idx]
        class_counts[label] += 1
    class_counts_per_user.append(class_counts)
    print(f"User {user_idx + 1} (Non-IID) Class Distribution: {class_counts}")
    total_samples = len(user_dataset)
    class_percentages = [count / total_samples * 100 if total_samples > 0 else 0 for count in class_counts]
    print(f"User {user_idx + 1} (Non-IID) Class Percentages: {[f'{p:.2f}%' for p in class_percentages]}")

total_time = time.time() - total_start_time
print(f"\nTotal time for the entire script: {format_time(total_time)}")

Using device: cuda
GPU Name: NVIDIA RTX A5000
GPU Memory Allocated: 1632.76 MB
Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Inspecting dataset path: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Dirs: ['Vehicle Type Image Dataset (Version 2) VTID2']
Files (first 5): []
--------------------------------------------------
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1\Vehicle Type Image Dataset (Version 2) VTID2
Dirs: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Files (first 5): []
--------------------------------------------------
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1\Vehicle Type Image Dataset (Version 2) VTID2\Hatchback
Dirs: []
Files (first 5): ['PHOTO_0.jpg', 'PHOTO_1.jp

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import kagglehub
import time
from torch.cuda.amp import GradScaler, autocast  # For mixed precision training

# Start total script timer
total_start_time = time.time()

# Ensure output directory exists
output_dir = "FL_VEHICLE_CVAE_latent_test3_noniid_2"
os.makedirs(output_dir, exist_ok=True)

# Hyperparameters
RESIZE = 128
original_dim = RESIZE * RESIZE * 3
intermediate_dim = 512
latent_dim = 256
num_classes = 5
batch_size = 8  # Reduced to prevent memory issues
epochs = 2  # Reduced for faster testing; set to 3000 for full training
learning_rate = 1e-4
beta_start = 1
beta_end = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Print GPU information and check available memory
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**2  # in MB
    print(f"Total GPU Memory: {total_memory:.2f} MB")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

# Download Vehicle Type Image Dataset from Kaggle
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
    print("Path to dataset files:", path)
    dataset_path = path
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise

# Transform for CVAE (normalize to [-1, 1])
cvae_input_transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),  # [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [-1, 1]
])

# Debug dataset directory structure
print("Inspecting dataset path:", dataset_path)
for root, dirs, files in os.walk(dataset_path):
    print(f"Root: {root}")
    print(f"Dirs: {dirs}")
    print(f"Files (first 5): {files[:5]}")
    print("-" * 50)

# Custom Dataset for the Vehicle Type Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = []
        self.class_to_idx = {}

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.class_names.append(class_name)
                    self.class_to_idx[class_name] = len(self.class_names) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.class_names) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.class_names)}")
        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Load the dataset
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=cvae_input_transform)

# Update label_dim
label_dim = len(dataset.class_names)
print(f"Number of classes (label_dim): {label_dim}")

# Step 1: Split dataset into train, validation, and test sets per class
validation_ratio = 0.1
test_ratio = 0.1
train_ratio = 0.8

class_datasets = [[] for _ in range(label_dim)]
for idx in range(len(dataset)):
    label = dataset.labels[idx]
    class_datasets[label].append(idx)

train_indices_per_class = []
val_indices_per_class = []
test_indices_per_class = []

for class_idx in range(label_dim):
    indices = class_datasets[class_idx]
    total_samples = len(indices)
    num_train = int(total_samples * train_ratio)
    num_val = int(total_samples * validation_ratio)
    num_test = total_samples - num_train - num_val

    np.random.shuffle(indices)

    train_indices = indices[:num_train]
    val_indices = indices[num_train:num_train + num_val]
    test_indices = indices[num_train + num_val:]

    train_indices_per_class.append(train_indices)
    val_indices_per_class.append(val_indices)
    test_indices_per_class.append(test_indices)

    print(f"Class {class_idx}: Train={len(train_indices)}, Val={len(val_indices)}, Test={len(test_indices)}")

# Verify no overlap
for class_idx in range(label_dim):
    train_set = set(train_indices_per_class[class_idx])
    val_set = set(val_indices_per_class[class_idx])
    test_set = set(test_indices_per_class[class_idx])
    assert len(train_set.intersection(val_set)) == 0, f"Overlap between train and val for class {class_idx}"
    assert len(train_set.intersection(test_set)) == 0, f"Overlap between train and test for class {class_idx}"
    assert len(val_set.intersection(test_set)) == 0, f"Overlap between val and test for class {class_idx}"

# Create training and validation datasets
train_dataset = Subset(dataset, [idx for class_indices in train_indices_per_class for idx in class_indices])
val_dataset = Subset(dataset, [idx for class_indices in val_indices_per_class for idx in class_indices])

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Step 1.5: Ensure 500 samples per class by generating synthetic samples for underrepresented classes
target_samples_per_class = 500
train_class_datasets = []
cvae_per_class = {}

# Weight initialization
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

# Encoder (Convolutional)
class Encoder(nn.Module):
    def __init__(self, hidden_dim, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 1024, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )
        self.fc_hidden = nn.Linear(1024 * 2 * 2 + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.apply(init_weights)

    def forward(self, x, y):
        x = self.encoder_cnn(x)
        x = x.view(x.size(0), -1)
        x_with_y = torch.cat([x, y], dim=-1)
        h = F.relu(self.fc_hidden(x_with_y))
        h = self.dropout(h)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        z_logvar = torch.clamp(z_logvar, min=-10, max=10)
        return z_mean, z_logvar

# Decoder (Convolutional)
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_to_cnn = nn.Linear(hidden_dim, 1024 * 2 * 2)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )
        self.apply(init_weights)

    def forward(self, z, y):
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = self.dropout(h)
        h = F.relu(self.fc_to_cnn(h))
        h = h.view(-1, 1024, 2, 2)
        x_reconstructed = self.decoder_cnn(h)
        x_reconstructed = torch.clamp(x_reconstructed, min=-1, max=1)
        return x_reconstructed

# Conditional VAE
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, data, y):
        z_mean, z_logvar = self.encoder(data, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y)
        return x_reconstructed, z_mean, z_logvar

# Loss Function
def cvae_loss(data, x_reconstructed, z_mean, z_logvar, beta=1.0):
    batch_size = data.size(0)
    mse_loss = F.mse_loss(x_reconstructed, data, reduction='sum') / batch_size
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp()) / batch_size
    return mse_loss + beta * kl_loss

# Function to format time
def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = seconds % 60
    if hours > 0:
        return f"{hours}h {minutes}m {secs:.2f}s"
    elif minutes > 0:
        return f"{minutes}m {secs:.2f}s"
    else:
        return f"{secs:.2f}s"

# Generate synthetic samples for classes with fewer than 500 samples
for class_idx in range(label_dim):
    class_indices = train_indices_per_class[class_idx]
    num_real_samples = len(class_indices)
    print(f"Class {class_idx} has {num_real_samples} real samples before augmentation.")

    if num_real_samples >= target_samples_per_class:
        # Subsample to exactly 500 if more than 500
        class_indices = np.random.choice(class_indices, target_samples_per_class, replace=False).tolist()
        class_dataset = Subset(dataset, class_indices)
        train_class_datasets.append(class_dataset)
        print(f"Class {class_idx} subsampled to {len(class_dataset)} samples.")
        continue

    # If fewer than 500, train a CVAE to generate additional samples
    num_synthetic_needed = target_samples_per_class - num_real_samples
    print(f"Generating {num_synthetic_needed} synthetic samples for Class {class_idx} to reach {target_samples_per_class} samples.")

    # Create dataset for this class
    class_dataset = Subset(dataset, class_indices)
    class_loader = DataLoader(class_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

    # Instantiate CVAE for this class
    encoder = Encoder(intermediate_dim, latent_dim, num_classes).to(device)
    decoder = Decoder(latent_dim, intermediate_dim, num_classes).to(device)
    cvae = ConditionalVAE(encoder, decoder).to(device)

    # Optimizer and scheduler
    optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)
    scaler = GradScaler()  # For mixed precision training

    # Training loop for this class
    class_checkpoint_dir = os.path.join(output_dir, f'checkpoints_cvae_class_{class_idx}')
    os.makedirs(class_checkpoint_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    cvae.train()
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        beta = beta_start + (beta_end - beta_start) * ((epoch - 1) / (epochs - 1)) if epochs > 1 else beta_end

        train_loss = 0
        batches_processed = 0

        for batch_idx, (data, labels) in enumerate(class_loader):
            try:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)

                optimizer.zero_grad()
                with autocast():
                    x_recon, z_mean, z_logvar = cvae(data, y)
                    loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"NaN/Inf loss detected at epoch {epoch}, batch {batch_idx + 1} for Class {class_idx}")
                    continue

                scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                train_loss += loss.item()
                batches_processed += 1

                # Monitor GPU memory
                if device.type == 'cuda':
                    mem_allocated = torch.cuda.memory_allocated(0) / 1024**2
                    mem_reserved = torch.cuda.memory_reserved(0) / 1024**2
                    

            except Exception as e:
                print(f"Error at epoch {epoch}, batch {batch_idx + 1} for Class {class_idx}: {e}")
                if "out of memory" in str(e).lower():
                    print("Out of memory error detected. Clearing cache...")
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                continue

        avg_train_loss = train_loss / batches_processed if batches_processed > 0 else float('inf')
        train_losses.append(avg_train_loss)

        # Validation
        cvae.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for data, labels in val_loader:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                with autocast():
                    x_recon, z_mean, z_logvar = cvae(data, y)
                    loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta=1.0)
                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        epoch_time = time.time() - epoch_start_time
        print(f"Class {class_idx}, Epoch {epoch}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {format_time(epoch_time)}")

        scheduler.step()

        # Clear GPU memory
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        # Save checkpoints and latent vectors at the final epoch
        if epoch == epochs:
            checkpoint_path = os.path.join(class_checkpoint_dir, f'cvae_epoch_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': cvae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses
            }, checkpoint_path)
            print(f"Checkpoint saved for Class {class_idx} at epoch {epoch} to {checkpoint_path}")

            # Save latent vectors with labels
            latent_dir = os.path.join(class_checkpoint_dir, f'latent_vectors_epoch_{epoch}')
            os.makedirs(latent_dir, exist_ok=True)

            cvae.eval()
            with torch.no_grad():
                latent_vectors = {'z_mean': [], 'z_logvar': [], 'labels': []}
                for data, labels in class_loader:
                    data = data.to(device)
                    y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                    with autocast():
                        z_mean, z_logvar = cvae.encoder(data, y)
                    for i in range(len(labels)):
                        latent_vectors['z_mean'].append(z_mean[i].cpu())
                        latent_vectors['z_logvar'].append(z_logvar[i].cpu())
                        latent_vectors['labels'].append(labels[i].item())

                if latent_vectors['z_mean']:
                    z_mean = torch.stack(latent_vectors['z_mean'])
                    z_logvar = torch.stack(latent_vectors['z_logvar'])
                    labels = torch.tensor(latent_vectors['labels'])
                    save_path = os.path.join(latent_dir, f'class_{class_idx}.pth')
                    torch.save({
                        'z_mean': z_mean,
                        'z_logvar': z_logvar,
                        'labels': labels
                    }, save_path)
                    print(f"Saved latent vectors for Class {class_idx} at epoch {epoch} to {save_path}")

    # Store the trained CVAE for this class
    cvae_per_class[class_idx] = cvae

    # Generate synthetic samples to reach 500
    synthetic_dir = os.path.join(output_dir, f'synthetic_class_{class_idx}')
    os.makedirs(synthetic_dir, exist_ok=True)

    print(f"Generating {num_synthetic_needed} synthetic images for Class {class_idx}")
    synthetic_images = []
    latent_path = os.path.join(class_checkpoint_dir, f'latent_vectors_epoch_{epochs}', f'class_{class_idx}.pth')
    latent_data = torch.load(latent_path, weights_only=False)
    z_mean_all = latent_data['z_mean'].to(device)
    z_logvar_all = latent_data['z_logvar'].to(device)

    cvae.eval()
    with torch.no_grad():
        for i in range(num_synthetic_needed):
            z = cvae.reparameterize(z_mean_all[i % len(z_mean_all)].unsqueeze(0), 
                                    z_logvar_all[i % len(z_mean_all)].unsqueeze(0))
            y = F.one_hot(torch.tensor([class_idx]), num_classes=label_dim).float().to(device)
            with autocast():
                synthetic_img = cvae.decoder(z, y).cpu()
            synthetic_images.append(synthetic_img)

    # Save synthetic images
    for idx, img in enumerate(synthetic_images):
        img_path = os.path.join(synthetic_dir, f'image_{idx + 1}.png')
        try:
            img = img.view(3, RESIZE, RESIZE)
            img = img * 0.5 + 0.5  # Denormalize to [0, 1]
            img = img.clamp(0, 1)
            img = transforms.ToPILImage()(img)
            img.save(img_path)
            if (idx + 1) % 100 == 0 or idx == 0:
                print(f"Saved {idx + 1} synthetic images for Class {class_idx}")
        except Exception as e:
            print(f"Error saving image {img_path}: {e}")
            continue

    # Create synthetic dataset
    class SyntheticDataset(Dataset):
        def __init__(self, class_label, root_dir, transform=None):
            self.class_label = class_label
            self.root_dir = root_dir
            self.transform = transform
            self.image_files = sorted([f for f in os.listdir(root_dir) if f.endswith('.png')])
            if len(self.image_files) == 0:
                raise ValueError(f"No images found in {root_dir}")

        def __len__(self):
            return len(self.image_files)

        def __getitem__(self, idx):
            img_path = os.path.join(self.root_dir, self.image_files[idx])
            image = Image.open(img_path).convert("RGB")
            if self.transform:
                image = self.transform(image)
            return image, self.class_label

    synthetic_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    synthetic_dataset = SyntheticDataset(class_idx, synthetic_dir, transform=synthetic_transform)

    # Combine real and synthetic datasets to reach exactly 500 samples
    combined_dataset = ConcatDataset([class_dataset, synthetic_dataset])
    combined_indices = list(range(len(combined_dataset)))
    if len(combined_indices) > target_samples_per_class:
        combined_indices = np.random.choice(combined_indices, target_samples_per_class, replace=False).tolist()
    final_class_dataset = Subset(combined_dataset, combined_indices)
    train_class_datasets.append(final_class_dataset)
    print(f"Class {class_idx} final dataset length: {len(final_class_dataset)}")

# Step 2: Define user classes and train CVAEs
Num_users = 5
user_classes = {
    0: [0, 1, 2, 3, 4],  # User 1
    1: [1, 2],           # User 2
    2: [3, 4, 0],        # User 3
    3: [1],              # User 4
    4: [4, 0]            # User 5
}

cvae_users = {}
train_losses_users = {}
val_losses_users = {}

for user_idx in range(Num_users):
    # Start timer for this user's CVAE training
    user_start_time = time.time()

    # Create user dataset
    user_dataset = ConcatDataset([train_class_datasets[i] for i in user_classes[user_idx]])
    print(f"User {user_idx + 1} dataset length: {len(user_dataset)}")

    # Validate indices
    try:
        for i in range(min(5, len(user_dataset))):
            sample, label = user_dataset[i]
            print(f"User {user_idx + 1}, Sample {i}: Label={label}, Data shape={sample.shape}")
    except Exception as e:
        print(f"Error accessing samples for User {user_idx + 1}: {e}")
        raise

    user_loader = DataLoader(user_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

    # Instantiate CVAE
    encoder = Encoder(intermediate_dim, latent_dim, num_classes).to(device)
    decoder = Decoder(latent_dim, intermediate_dim, num_classes).to(device)
    cvae = ConditionalVAE(encoder, decoder).to(device)

    # Optimizer and scheduler
    optimizer = optim.Adam(cvae.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)
    scaler = GradScaler()

    # Training loop
    checkpoint_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{user_idx + 1}')
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    cvae.train()
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        beta = beta_start + (beta_end - beta_start) * ((epoch - 1) / (epochs - 1)) if epochs > 1 else beta_end

        train_loss = 0
        batches_processed = 0

        for batch_idx, (data, labels) in enumerate(user_loader):
            try:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)

                optimizer.zero_grad()
                with autocast():
                    x_recon, z_mean, z_logvar = cvae(data, y)
                    loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"NaN/Inf loss detected at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}")
                    continue

                scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(cvae.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                train_loss += loss.item()
                batches_processed += 1

                # Monitor GPU memory
                if device.type == 'cuda':
                    mem_allocated = torch.cuda.memory_allocated(0) / 1024**2
                    mem_reserved = torch.cuda.memory_reserved(0) / 1024**2
                    print(f"Batch {batch_idx + 1}/{len(user_loader)} - GPU Memory Allocated: {mem_allocated:.2f} MB, Reserved: {mem_reserved:.2f} MB")

            except Exception as e:
                print(f"Error at epoch {epoch}, batch {batch_idx + 1} for User {user_idx + 1}: {e}")
                if "out of memory" in str(e).lower():
                    print("Out of memory error detected. Clearing cache...")
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                continue

        avg_train_loss = train_loss / batches_processed if batches_processed > 0 else float('inf')
        train_losses.append(avg_train_loss)

        # Validation
        cvae.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for data, labels in val_loader:
                data = data.to(device)
                y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                with autocast():
                    x_recon, z_mean, z_logvar = cvae(data, y)
                    loss = cvae_loss(data, x_recon, z_mean, z_logvar, beta=1.0)
                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        epoch_time = time.time() - epoch_start_time
        print(f"User {user_idx + 1}, Epoch {epoch}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {format_time(epoch_time)}")

        scheduler.step()

        # Clear GPU memory
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        # Save checkpoints, latent vectors, and decoder parameters at the final epoch
        if epoch == epochs:
            checkpoint_path = os.path.join(checkpoint_dir, f'cvae_epoch_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': cvae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses
            }, checkpoint_path)
            print(f"Checkpoint saved for User {user_idx + 1} at epoch {epoch} to {checkpoint_path}")

            # Save decoder parameters
            decoder_dir = os.path.join(checkpoint_dir, 'decoder')
            os.makedirs(decoder_dir, exist_ok=True)
            decoder_path = os.path.join(decoder_dir, f'decoder_epoch_{epoch}.pth')
            torch.save(cvae.decoder.state_dict(), decoder_path)
            print(f"Decoder saved for User {user_idx + 1} at epoch {epoch} to {decoder_path}")

            # Save latent vectors with labels
            latent_dir = os.path.join(checkpoint_dir, f'latent_vectors_epoch_{epoch}')
            os.makedirs(latent_dir, exist_ok=True)

            cvae.eval()
            with torch.no_grad():
                latent_vectors = {cls: {'z_mean': [], 'z_logvar': [], 'labels': []} for cls in user_classes[user_idx]}
                for data, labels in user_loader:
                    data = data.to(device)
                    y = F.one_hot(labels, num_classes=num_classes).float().to(device)
                    with autocast():
                        z_mean, z_logvar = cvae.encoder(data, y)
                    for i, label in enumerate(labels):
                        latent_vectors[label.item()]['z_mean'].append(z_mean[i].cpu())
                        latent_vectors[label.item()]['z_logvar'].append(z_logvar[i].cpu())
                        latent_vectors[label.item()]['labels'].append(label.item())

                for cls in user_classes[user_idx]:
                    if latent_vectors[cls]['z_mean']:
                        z_mean = torch.stack(latent_vectors[cls]['z_mean'])
                        z_logvar = torch.stack(latent_vectors[cls]['z_logvar'])
                        labels = torch.tensor(latent_vectors[cls]['labels'])
                        save_path = os.path.join(latent_dir, f'class_{cls}.pth')
                        torch.save({
                            'z_mean': z_mean,
                            'z_logvar': z_logvar,
                            'labels': labels
                        }, save_path)
                        print(f"Saved latent vectors for User {user_idx + 1}, Class {cls} at epoch {epoch} to {save_path}")

    # Store losses for plotting
    train_losses_users[user_idx] = train_losses
    val_losses_users[user_idx] = val_losses

    # Plot losses
    plt.figure(figsize=(6, 4))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'User {user_idx + 1} CVAE Loss')
    plt.legend()
    plt.grid(True)
    loss_plot_path = os.path.join(checkpoint_dir, 'loss_plot.png')
    plt.savefig(loss_plot_path)
    plt.close()
    print(f"Loss plot saved for User {user_idx + 1} to {loss_plot_path}")

    cvae_users[user_idx] = cvae

    user_time = time.time() - user_start_time
    print(f"Total time for User {user_idx + 1} CVAE training: {format_time(user_time)}\n")

# Step 3: Generate 1000 new synthetic samples for each class per user and save user-wise
num_synthetic_per_class_generate = 1000
batch_size_synthetic = 200  # Generate in batches to manage memory

for user_idx in range(Num_users):
    user_cvae = cvae_users[user_idx]
    classes = user_classes[user_idx]

    # Load the latent vectors for the user's classes
    latent_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{user_idx + 1}', f'latent_vectors_epoch_{epochs}')
    latent_vectors = {}
    for cls in classes:
        latent_path = os.path.join(latent_dir, f'class_{cls}.pth')
        latent_data = torch.load(latent_path, weights_only=False)
        latent_vectors[cls] = {
            'z_mean': latent_data['z_mean'].to(device),
            'z_logvar': latent_data['z_logvar'].to(device)
        }
        print(f"Loaded latent vectors for User {user_idx + 1}, Class {cls}: z_mean shape={latent_data['z_mean'].shape}")

    # Load the decoder
    decoder_dir = os.path.join(output_dir, f'checkpoints_cvae_user_{user_idx + 1}', 'decoder')
    decoder_path = os.path.join(decoder_dir, f'decoder_epoch_{epochs}.pth')
    user_cvae.decoder.load_state_dict(torch.load(decoder_path, weights_only=False))
    print(f"Loaded decoder for User {user_idx + 1} from {decoder_path}")

    # Generate synthetic samples for each class
    user_cvae.eval()
    for cls in classes:
        class_synthetic_dir = os.path.join(output_dir, f'synthetic_user_{user_idx + 1}', f'class_{cls}')
        latent_subdir = os.path.join(class_synthetic_dir, 'latent_vectors')
        os.makedirs(class_synthetic_dir, exist_ok=True)
        os.makedirs(latent_subdir, exist_ok=True)

        print(f"Generating {num_synthetic_per_class_generate} synthetic samples for User {user_idx + 1}, Class {cls}")
        total_generated = 0
        z_mean_all = latent_vectors[cls]['z_mean']
        z_logvar_all = latent_vectors[cls]['z_logvar']

        with torch.no_grad():
            while total_generated < num_synthetic_per_class_generate:
                remaining = min(batch_size_synthetic, num_synthetic_per_class_generate - total_generated)
                synthetic_images = []
                synthetic_z_means = []
                synthetic_z_logvars = []

                try:
                    for i in range(remaining):
                        idx = (total_generated + i) % len(z_mean_all)
                        z_mean = z_mean_all[idx].unsqueeze(0)
                        z_logvar = z_logvar_all[idx].unsqueeze(0)
                        z = user_cvae.reparameterize(z_mean, z_logvar)
                        y = F.one_hot(torch.tensor([cls]), num_classes=label_dim).float().to(device)
                        with autocast():
                            synthetic_img = user_cvae.decoder(z, y).cpu()
                        synthetic_images.append(synthetic_img)
                        synthetic_z_means.append(z_mean.cpu())
                        synthetic_z_logvars.append(z_logvar.cpu())

                    # Save the batch of synthetic images and their latent vectors
                    for idx, (img, z_mean, z_logvar) in enumerate(zip(synthetic_images, synthetic_z_means, synthetic_z_logvars)):
                        img_path = os.path.join(class_synthetic_dir, f'image_{total_generated + idx + 1}.png')
                        try:
                            img = img.view(3, RESIZE, RESIZE)
                            img = img * 0.5 + 0.5  # Denormalize to [0, 1]
                            img = img.clamp(0, 1)
                            img = transforms.ToPILImage()(img)
                            img.save(img_path)

                            # Save latent vectors for this image
                            latent_path = os.path.join(latent_subdir, f'image_{total_generated + idx + 1}_latent.pth')
                            torch.save({
                                'z_mean': z_mean.squeeze(0),
                                'z_logvar': z_logvar.squeeze(0),
                                'label': cls
                            }, latent_path)

                            if (total_generated + idx + 1) % 200 == 0 or (total_generated + idx) == 0:
                                print(f"Saved {total_generated + idx + 1} images for User {user_idx + 1}, Class {cls}")
                        except Exception as e:
                            print(f"Error saving image {img_path}: {e}")
                            continue

                    total_generated += remaining
                    print(f"Generated batch of {remaining} images. Total generated: {total_generated}/{num_synthetic_per_class_generate}")

                    # Clear memory after each batch
                    del synthetic_images, synthetic_z_means, synthetic_z_logvars
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()

                except Exception as e:
                    print(f"Error generating synthetic images for User {user_idx + 1}, Class {cls}: {e}")
                    if "out of memory" in str(e).lower():
                        print("Out of memory error during synthetic generation. Clearing cache...")
                        if device.type == 'cuda':
                            torch.cuda.empty_cache()
                    continue

        # Save the decoder parameters for this class
        decoder_class_path = os.path.join(class_synthetic_dir, 'decoder.pth')
        torch.save(user_cvae.decoder.state_dict(), decoder_class_path)
        print(f"Saved decoder for User {user_idx + 1}, Class {cls} to {decoder_class_path}")

        print(f"Completed generating {total_generated} synthetic samples for User {user_idx + 1}, Class {cls}")

# Step 4: Verify the non-IID distribution
user_data = []
for user_idx in range(Num_users):
    user_data.append(ConcatDataset([train_class_datasets[i] for i in user_classes[user_idx]]))

print("\n=== Verifying Non-IID Data Distribution Across Users ===")
class_counts_per_user = []
for user_idx in range(Num_users):
    user_dataset = user_data[user_idx]
    class_counts = [0] * label_dim
    for idx in range(len(user_dataset)):
        _, label = user_dataset[idx]
        class_counts[label] += 1
    class_counts_per_user.append(class_counts)
    print(f"User {user_idx + 1} (Non-IID) Class Distribution: {class_counts}")
    total_samples = len(user_dataset)
    class_percentages = [count / total_samples * 100 if total_samples > 0 else 0 for count in class_counts]
    print(f"User {user_idx + 1} (Non-IID) Class Percentages: {[f'{p:.2f}%' for p in class_percentages]}")

# Calculate and print total script time
total_time = time.time() - total_start_time
print(f"\nTotal time for the entire script: {format_time(total_time)}")

Using device: cuda
GPU Name: NVIDIA RTX A5000
Total GPU Memory: 24563.50 MB
GPU Memory Allocated: 1627.25 MB
Path to dataset files: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Inspecting dataset path: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Dirs: ['Vehicle Type Image Dataset (Version 2) VTID2']
Files (first 5): []
--------------------------------------------------
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1\Vehicle Type Image Dataset (Version 2) VTID2
Dirs: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Files (first 5): []
--------------------------------------------------
Root: C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1\Vehicle Type Image Dataset (Version 2) VTID2\Hatchback
Dirs: []
Files (first 5

  scaler = GradScaler()  # For mixed precision training
  with autocast():
  with autocast():


Class 0, Epoch 1/2, Train Loss: 15891.5903, Val Loss: 10343.9302, Time: 6.05s
Class 0, Epoch 2/2, Train Loss: 9625.8331, Val Loss: 8858.1745, Time: 5.98s
Checkpoint saved for Class 0 at epoch 2 to FL_VEHICLE_CVAE_latent_test3_noniid_2\checkpoints_cvae_class_0\cvae_epoch_2.pth


  with autocast():


Saved latent vectors for Class 0 at epoch 2 to FL_VEHICLE_CVAE_latent_test3_noniid_2\checkpoints_cvae_class_0\latent_vectors_epoch_2\class_0.pth
Generating 19 synthetic images for Class 0


  with autocast():


Saved 1 synthetic images for Class 0
Class 0 final dataset length: 500
Class 1 has 480 real samples before augmentation.
Generating 20 synthetic samples for Class 1 to reach 500 samples.
Class 1, Epoch 1/2, Train Loss: 12518.4573, Val Loss: 11204.9463, Time: 6.67s
Class 1, Epoch 2/2, Train Loss: 6667.1490, Val Loss: 11129.3255, Time: 6.62s
Checkpoint saved for Class 1 at epoch 2 to FL_VEHICLE_CVAE_latent_test3_noniid_2\checkpoints_cvae_class_1\cvae_epoch_2.pth
Saved latent vectors for Class 1 at epoch 2 to FL_VEHICLE_CVAE_latent_test3_noniid_2\checkpoints_cvae_class_1\latent_vectors_epoch_2\class_1.pth
Generating 20 synthetic images for Class 1
Saved 1 synthetic images for Class 1
Class 1 final dataset length: 500
Class 2 has 1351 real samples before augmentation.
Class 2 subsampled to 500 samples.
Class 3 has 977 real samples before augmentation.
Class 3 subsampled to 500 samples.
Class 4 has 544 real samples before augmentation.
Class 4 subsampled to 500 samples.
User 1 dataset lengt

  scaler = GradScaler()
  with autocast():


Batch 1/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 2/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 3/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 4/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 5/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 6/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 7/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 8/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 9/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 10/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 11/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 12/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 13/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 14/312 - GPU Memory Allocated: 2432.26 MB, Reserved: 2904.00 MB
Batch 15/312 - GPU Memory All

  with autocast():


User 1, Epoch 1/2, Train Loss: 9863.5816, Val Loss: 7821.8646, Time: 22.99s
Batch 1/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 2/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 3/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 4/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 5/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 6/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 7/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 8/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 9/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 10/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 11/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 12/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 13/312 - GPU Memory Allocated: 3236.08 MB, Reserved: 3726.00 MB
Batch 14/312 - GPU Memo

  with autocast():


Saved latent vectors for User 1, Class 0 at epoch 2 to FL_VEHICLE_CVAE_latent_test3_noniid_2\checkpoints_cvae_user_1\latent_vectors_epoch_2\class_0.pth
Saved latent vectors for User 1, Class 1 at epoch 2 to FL_VEHICLE_CVAE_latent_test3_noniid_2\checkpoints_cvae_user_1\latent_vectors_epoch_2\class_1.pth
Saved latent vectors for User 1, Class 2 at epoch 2 to FL_VEHICLE_CVAE_latent_test3_noniid_2\checkpoints_cvae_user_1\latent_vectors_epoch_2\class_2.pth
Saved latent vectors for User 1, Class 3 at epoch 2 to FL_VEHICLE_CVAE_latent_test3_noniid_2\checkpoints_cvae_user_1\latent_vectors_epoch_2\class_3.pth
Saved latent vectors for User 1, Class 4 at epoch 2 to FL_VEHICLE_CVAE_latent_test3_noniid_2\checkpoints_cvae_user_1\latent_vectors_epoch_2\class_4.pth


In [2]:
import torch
torch.cuda.empty_cache()