# **Imports**

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
from torch.optim import Adam
from torchvision.models import inception_v3
from scipy.linalg import sqrtm
import math

# Seed the global PyTorch and NumPy random number generators with a fixed
# value so that weight initialisation and random draws start from the same state.
torch.manual_seed(42)
np.random.seed(42)


# **Define constants**

In [None]:
# Define constants
# Side length in pixels of the square images; default `img_size` of
# DiffusionModel and used to size the U-Net's attention layer.
IMAGE_SIZE = 150
# Mini-batch size passed to the training DataLoader.
BATCH_SIZE = 32
# Default number of training epochs for `train`.
EPOCHS = 100
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 2e-4
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of diffusion (noising) steps; default `noise_steps` of DiffusionModel.
TIMESTEPS = 1000
# First and last value of the linear beta (noise variance) schedule.
BETA_START = 1e-4
BETA_END = 0.02


# **Extract per-timestep schedule coefficients**

In [None]:
# Helper function to gather per-timestep schedule coefficients
def extract(a, t, x_shape):
    """Look up ``a[t]`` for each batch element and shape it for broadcasting.

    Args:
        a: 1-D tensor holding one value per timestep.
        t: 1-D integer (long) tensor of timestep indices, one per batch element.
        x_shape: Shape of the tensor the result will be multiplied with; only
            its number of dimensions is used.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` dimensions,
        located on the same device as ``a``.
    """
    batch_size = t.shape[0]
    # gather() needs the index tensor on the same device as the source table
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))




# **Normalizing image**

In [None]:
def normalize_image(img_tensor):
    """Min-max rescale a tensor into the [-1, 1] range.

    A tensor whose value spread is below 1e-5 (effectively constant) is mapped
    to all zeros instead of being divided by a near-zero range.
    """
    lowest, highest = img_tensor.min(), img_tensor.max()
    spread = highest - lowest

    # Guard against dividing by (almost) zero
    if spread < 1e-5:
        return torch.zeros_like(img_tensor)

    # First map to [0, 1], then stretch and shift to [-1, 1]
    unit_range = (img_tensor - lowest) / spread
    return 2.0 * unit_range - 1.0


# **Dataset Handling**

In [None]:
# Dataset class for loading .npy lensing images
class LensingDataset(Dataset):
    """Dataset of single-channel images stored as individual ``.npy`` files.

    Args:
        data_dir: Directory that is searched (non-recursively) for ``*.npy`` files.
        transform: Optional callable applied to the ``[1, H, W]`` float tensor
            before normalisation.
    """

    def __init__(self, data_dir, transform=None):
        # glob returns files in arbitrary, filesystem-dependent order; sorting
        # keeps index -> file stable across runs so seeded shuffling is reproducible.
        self.data_paths = sorted(glob.glob(os.path.join(data_dir, "*.npy")))
        self.transform = transform

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Load file ``idx`` and return it as a ``[1, H, W]`` tensor scaled to [-1, 1]."""
        data_path = self.data_paths[idx]
        # Load and process the .npy file
        img = np.load(data_path)
        img = img[0]  # Get the first channel as shape is (1, 150, 150)

        # Convert to tensor
        img_tensor = torch.tensor(img, dtype=torch.float32)
        img_tensor = img_tensor.unsqueeze(0)  # Add channel dimension [1, 150, 150]

        # Apply transformations if any
        if self.transform:
            img_tensor = self.transform(img_tensor)

        # Normalize to [-1, 1]
        img_tensor = normalize_image(img_tensor)

        return img_tensor



# **U-Net Building Blocks**

In [None]:
# Define U-Net building blocks
class Block(nn.Module):
    """Two conv -> GroupNorm -> activation stages with optional dropout in between.

    The first convolution is a ``Conv2d`` when ``down`` is true and a
    ``ConvTranspose2d`` otherwise; both use kernel 3, stride 1, padding 1, so
    the spatial size is unchanged. ``act`` selects ReLU for ``"relu"`` and SiLU
    for any other value.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", dropout=0.1):
        super().__init__()

        def activation():
            # A fresh activation module for each position in the stack
            return nn.ReLU() if act == "relu" else nn.SiLU()

        entry_conv = nn.Conv2d if down else nn.ConvTranspose2d
        regulariser = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Layers are instantiated in the same order as they are applied.
        stages = [
            entry_conv(in_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(1, out_channels),
            activation(),
            regulariser,
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(1, out_channels),
            activation(),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class ResBlock(nn.Module):
    """Residual block of two 3x3 convolutions with optional time conditioning.

    When ``time_emb_dim`` is given, a SiLU + Linear projection maps the time
    embedding to one value per channel, which is added to the feature map
    after the first convolution stage.
    """

    def __init__(self, channels, time_emb_dim=None):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(1, channels)
        self.norm2 = nn.GroupNorm(1, channels)
        self.act = nn.SiLU()

        self.time_mlp = None
        if time_emb_dim is not None:
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, channels),
            )

    def forward(self, x, t=None):
        hidden = self.conv1(x)
        hidden = self.act(self.norm1(hidden))

        # Conditioning is skipped unless both the projection and an embedding exist
        conditioned = not (self.time_mlp is None or t is None)
        if conditioned:
            shift = self.time_mlp(t)
            # Reshape (batch, channels) -> (batch, channels, 1, 1) to broadcast over H and W
            hidden = hidden + shift.reshape(-1, shift.shape[1], 1, 1)

        hidden = self.conv2(hidden)
        hidden = self.act(self.norm2(hidden))
        return hidden + x

class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of scalar timesteps to sinusoidal embeddings of width ``dim``.

    The first ``dim // 2`` columns are sines and the next ``dim // 2`` columns
    are cosines of the timestep multiplied by geometrically spaced frequencies.
    For an odd ``dim`` one trailing zero column is appended so that the output
    always has exactly ``dim`` columns.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return a ``(batch, dim)`` float tensor for the 1-D tensor ``time``."""
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) prevents a division by zero when half_dim == 1 (dim 2 or 3);
        # for half_dim >= 2 the value is unchanged.
        scale = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -scale)
        angles = time[:, None] * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover dim - 1 columns; pad to the promised width
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

class Attention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Every pixel becomes one token of width ``channels``; the tokens go through
    pre-norm 4-head attention and a feed-forward layer, each with a residual
    connection, and are then folded back into an image-shaped tensor.

    Args:
        channels: Number of feature channels (must be divisible by the 4 heads).
        size: Expected side length of the square feature map. It is only used
            for inputs that are not 4-D; for ``(B, C, H, W)`` inputs the
            spatial size is read from the tensor itself.
    """

    def __init__(self, channels, size):
        super().__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        if x.dim() == 4:
            # Use the real spatial size so a feature map that differs from
            # `size` is neither rejected nor folded into the batch axis.
            batch, _, height, width = x.shape
        else:
            # Legacy behaviour: infer the batch axis, assume a size x size map
            batch, height, width = -1, self.size, self.size

        # (B, C, H*W) -> (B, H*W, C): one token per pixel
        tokens = x.reshape(batch, self.channels, height * width).swapaxes(1, 2)
        tokens_ln = self.ln(tokens)
        attention_value, _ = self.mha(tokens_ln, tokens_ln, tokens_ln)
        attention_value = attention_value + tokens
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).reshape(batch, self.channels, height, width)



# **U-Net model**

In [None]:
# U-Net model for diffusion
class UNet(nn.Module):
    """U-Net noise predictor: conv encoder, attention bottleneck, conv decoder.

    The encoder stores the stem output and the output of every down block
    except the last as skip tensors and max-pools between blocks. The decoder
    consumes those skips in reverse order, resizing with nearest-neighbour
    interpolation whenever spatial sizes differ. The time embedding is fed to
    the two bottleneck residual blocks.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted output.
        time_dim: Width of the timestep embedding.
        features: Base channel count; deeper stages use multiples of it.
    """

    def __init__(self, in_channels=1, out_channels=1, time_dim=256, features=64):
        super().__init__()
        f1, f2, f3, f4, f6, f8 = (features * k for k in (1, 2, 3, 4, 6, 8))

        # Timestep -> embedding vector
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU()
        )

        # Stem convolution lifting the image to `features` channels
        self.conv0 = nn.Conv2d(in_channels, f1, kernel_size=3, padding=1)

        # Encoder: (in, out) channel pairs, built in application order
        encoder_plan = [(f1, f1), (f1, f2), (f2, f2), (f2, f4), (f4, f4)]
        self.downs = nn.ModuleList(
            [Block(c_in, c_out, act="silu") for c_in, c_out in encoder_plan]
        )

        # Bottleneck: time-conditioned residual blocks around self-attention
        self.mid_block1 = ResBlock(f4, time_dim)
        self.mid_attn = Attention(f4, IMAGE_SIZE // 16)
        self.mid_block2 = ResBlock(f4, time_dim)

        # Decoder: input widths already include the concatenated skip channels
        decoder_plan = [(f8, f4), (f6, f2), (f4, f2), (f3, f1)]
        self.ups = nn.ModuleList(
            [Block(c_in, c_out, down=False, act="silu") for c_in, c_out in decoder_plan]
        )

        # Halves the resolution between encoder blocks
        self.pool = nn.MaxPool2d(kernel_size=2)

        # Doubles the resolution between decoder blocks
        self.upsamples = nn.ModuleList(
            [nn.Upsample(scale_factor=2, mode="nearest") for _ in range(3)]
        )

        # Output head; its input is the last decoder output joined with the stem skip
        self.final = nn.Sequential(
            nn.Conv2d(f2, f1, kernel_size=3, padding=1),
            nn.GroupNorm(1, f1),
            nn.ReLU(),
            nn.Conv2d(f1, out_channels, kernel_size=1),
        )

    @staticmethod
    def _join_skip(h, skip):
        """Concatenate ``skip`` onto ``h`` along channels, resizing ``h`` if needed."""
        if skip is None:
            return h
        if h.shape[2:] != skip.shape[2:]:
            h = F.interpolate(h, size=skip.shape[2:], mode="nearest")
        return torch.cat((h, skip), dim=1)

    def forward(self, x, t):
        # Embed the timesteps once; reused by both bottleneck blocks
        t = self.time_mlp(t)

        # Stem features double as the outermost skip connection
        h = self.conv0(x)
        skips = [h]

        # Encoder: every block but the last records a skip and is then pooled
        bottom_inputs = []
        final_down = len(self.downs) - 1
        for idx, down in enumerate(self.downs):
            h = down(h)
            if idx < final_down:
                skips.append(h)
                h = self.pool(h)
            else:
                bottom_inputs.append(h)

        # Bottleneck
        h = self.mid_block1(bottom_inputs[-1], t)
        h = self.mid_attn(h)
        h = self.mid_block2(h, t)

        # Decoder: join the most recent skip, convolve, then upsample
        # (no upsampling after the last decoder block)
        final_up = len(self.ups) - 1
        for idx, up in enumerate(self.ups):
            h = self._join_skip(h, skips.pop() if skips else None)
            h = up(h)
            if idx < final_up and idx < len(self.upsamples):
                h = self.upsamples[idx](h)

        # Remaining skip (the stem output) feeds the output head
        h = self._join_skip(h, skips.pop() if skips else None)
        return self.final(h)



# **Diffusion model**

In [None]:
# Diffusion model
class DiffusionModel:
    """Forward noising and reverse sampling with a linear beta schedule.

    All schedule tables are 1-D tensors with one entry per timestep, stored
    on ``device``.
    """

    def __init__(self, noise_steps=TIMESTEPS, beta_start=BETA_START, beta_end=BETA_END, img_size=IMAGE_SIZE, device=DEVICE):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        # Linear schedule and its cumulative products
        betas = torch.linspace(beta_start, beta_end, noise_steps).to(device)
        alphas = 1. - betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        # alpha_bar shifted right by one step, with 1.0 in front for t = 0
        alpha_bar_prev = F.pad(alpha_bar[:-1], (1, 0), value=1.0)

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alpha_bar
        self.alphas_cumprod_prev = alpha_bar_prev

        # Coefficients used by q_sample / p_sample
        self.sqrt_alphas_cumprod = torch.sqrt(alpha_bar)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alpha_bar)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
        self.posterior_variance = betas * (1. - alpha_bar_prev) / (1. - alpha_bar)

    def q_sample(self, x_0, t, noise=None):
        """Noise ``x_0`` to timestep ``t``; returns ``(noisy_image, noise_used)``."""
        if noise is None:
            noise = torch.randn_like(x_0)

        signal_scale = extract(self.sqrt_alphas_cumprod, t, x_0.shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, x_0.shape)

        noisy = signal_scale * x_0 + noise_scale * noise
        return noisy, noise

    def p_sample(self, model, x, t, t_index):
        """Take one reverse step from ``x`` at timestep ``t`` using ``model``."""
        beta_t = extract(self.betas, t, x.shape)
        noise_scale_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
        recip_sqrt_alpha_t = extract(self.sqrt_recip_alphas, t, x.shape)

        # The network's output is used as the predicted noise
        predicted_noise = model(x, t)

        mean = recip_sqrt_alpha_t * (x - beta_t * predicted_noise / noise_scale_t)

        # The final step (index 0) returns the mean without added noise
        if t_index == 0:
            return mean

        variance_t = extract(self.posterior_variance, t, x.shape)
        fresh_noise = torch.randn_like(x)
        return mean + torch.sqrt(variance_t) * fresh_noise

    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        """Run every reverse step, from the last timestep down to 0, and return the image."""
        device = next(model.parameters()).device
        batch = shape[0]
        # Starting noise is drawn from a standard normal and truncated to [-2, 2]
        img = torch.randn(shape, device=device).clamp(-2, 2)

        steps = reversed(range(0, self.noise_steps))
        for step in tqdm(steps, desc='Sampling', total=self.noise_steps):
            t = torch.full((batch,), step, device=device, dtype=torch.long)
            img = self.p_sample(model, img, t, step)

        return img

    @torch.no_grad()
    def sample(self, model, n_samples, channels=1):
        """Generate ``n_samples`` square images of side ``img_size``."""
        shape = (n_samples, channels, self.img_size, self.img_size)
        return self.p_sample_loop(model, shape)



# **Learning Rate Scheduler**

In [None]:
# Create a learning rate scheduler
def create_lr_scheduler(optimizer, total_epochs):
    """Return a cosine-annealing scheduler whose period (``T_max``) is ``total_epochs``."""
    cosine_schedule = torch.optim.lr_scheduler.CosineAnnealingLR
    return cosine_schedule(optimizer, T_max=total_epochs)



# **FID score calculation**

In [None]:
# Implement FID score calculation
def calculate_fid(real_images, fake_images, device=DEVICE):
    """Calculate the Frechet Inception Distance between two sets of images.

    Args:
        real_images: Tensor ``(N, C, H, W)`` (or sequence of ``(C, H, W)``
            tensors) with values in [-1, 1].
        fake_images: Same layout as ``real_images``.
        device: Device on which InceptionV3 is evaluated.

    Returns:
        The FID as a float. If the matrix square root fails, only the squared
        distance between the feature means is returned. If anything else
        fails, the error is printed and ``float('inf')`` is returned.
    """
    # Load the pretrained model; it is only used for inference
    model = inception_v3(pretrained=True, transform_input=False)
    # FID is defined on the 2048-d pooled activations, not on the 1000-way
    # classifier logits, so the final fully connected layer is bypassed.
    model.fc = nn.Identity()
    model = model.to(device)
    model.eval()

    # Handle smaller batch sizes for memory constraints
    batch_size = 8

    # Resize images to inception input size
    resize = transforms.Resize((299, 299))
    # ImageNet statistics expected by the pretrained weights
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def get_features(images):
        """Return an ``(n_images, n_features)`` array of Inception activations."""
        features = []
        # Process in batches to avoid OOM
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size]
            try:
                if not torch.is_tensor(batch):
                    # A list/tuple of per-image tensors was passed
                    batch = torch.stack(list(batch))
                with torch.no_grad():
                    batch = resize(batch.to(device))
                    # Expand to 3 channels if grayscale
                    if batch.shape[1] == 1:
                        batch = batch.repeat(1, 3, 1, 1)
                    # Convert from [-1, 1] to [0, 1]; generated samples can
                    # overshoot the range, so clamp before normalising
                    batch = ((batch + 1) / 2.0).clamp(0.0, 1.0)
                    batch = normalize(batch)

                    feat = model(batch)
                    if isinstance(feat, tuple):
                        feat = feat[0]  # Drop the auxiliary output if present
                features.append(feat.reshape(feat.shape[0], -1).cpu().numpy())
            except Exception as e:
                # Best effort: report and skip the batch that failed
                print(f"Error processing image: {e}")
                continue

        if not features:
            raise ValueError("No features were successfully extracted")

        return np.concatenate(features, axis=0)

    try:
        # Get activations
        real_features = get_features(real_images)
        fake_features = get_features(fake_images)

        # Check feature dimensions
        print(f"Feature shapes - Real: {real_features.shape}, Fake: {fake_features.shape}")

        # Calculate mean and covariance
        mu_real = np.mean(real_features, axis=0)
        sigma_real = np.cov(real_features, rowvar=False)

        mu_fake = np.mean(fake_features, axis=0)
        sigma_fake = np.cov(fake_features, rowvar=False)

        # Ensure matrices are proper shape
        print(f"Covariance shapes - Real: {sigma_real.shape}, Fake: {sigma_fake.shape}")

        # Squared distance between the feature means
        ssdiff = np.sum((mu_real - mu_fake) ** 2.0)

        # Handle numerical stability in sqrtm
        try:
            if sigma_real.shape[0] > 0 and sigma_fake.shape[0] > 0:
                # Add a small epsilon to the diagonal for numerical stability
                epsilon = 1e-6
                sigma_real += np.eye(sigma_real.shape[0]) * epsilon
                sigma_fake += np.eye(sigma_fake.shape[0]) * epsilon

                covmean = sqrtm(sigma_real.dot(sigma_fake))

                # sqrtm can return tiny imaginary parts from round-off
                if np.iscomplexobj(covmean):
                    covmean = covmean.real

                fid = ssdiff + np.trace(sigma_real + sigma_fake - 2.0 * covmean)
            else:
                print("Error: Empty covariance matrices")
                fid = ssdiff
        except Exception as e:
            print(f"Error in FID sqrtm calculation: {e}")
            # Fallback to just the mean squared difference
            fid = ssdiff

        return fid
    except Exception as e:
        print(f"Error in FID calculation: {e}")
        return float('inf')  # Return infinity if calculation fails



# **Training**

In [None]:
# Training function
def train(model, diffusion, dataloader, optimizer, device=DEVICE, epochs=EPOCHS):
    """Train the noise-prediction model and report a final FID score.

    Each step noises a batch to random timesteps and minimises the MSE between
    the model output and the noise that was added. A loss curve is written
    after each epoch. Every 10th epoch and the last epoch also write a
    checkpoint and a 2x2 grid of generated samples.

    Args:
        model: Network called as ``model(x_t, t)``.
        diffusion: Object providing ``noise_steps``, ``q_sample`` and ``sample``.
        dataloader: Iterable of image batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device used for training and for the FID evaluation.
        epochs: Number of passes over ``dataloader``.

    Returns:
        ``(model, fid_score)``; ``fid_score`` is ``float('inf')`` when the
        final evaluation fails.
    """
    model.to(device)

    # Create directories for saving
    for out_dir in ("model_checkpoints", "generated_samples", "loss_plots"):
        os.makedirs(out_dir, exist_ok=True)

    # Initialize learning rate scheduler
    scheduler = create_lr_scheduler(optimizer, epochs)

    # Keep track of losses for plotting
    epoch_losses = []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for images in progress_bar:
            images = images.to(device)
            batch_size = images.shape[0]

            # Sample random timesteps
            t = torch.randint(0, diffusion.noise_steps, (batch_size,), device=device).long()

            # Add noise to images
            x_t, noise = diffusion.q_sample(images, t)

            # Predict noise
            predicted_noise = model(x_t, t)

            # Calculate loss
            loss = F.mse_loss(predicted_noise, noise)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss

            progress_bar.set_postfix({"loss": batch_loss})

        # Step the learning rate scheduler (once per epoch)
        scheduler.step()

        # Average loss for the epoch; max(..., 1) avoids dividing by zero
        # when the dataloader yields no batches
        avg_loss = epoch_loss / max(len(dataloader), 1)
        epoch_losses.append(avg_loss)

        # Print epoch results
        print(f"Epoch {epoch+1}/{epochs} - Average Loss: {avg_loss:.4f} - LR: {scheduler.get_last_lr()[0]:.6f}")

        # Plot and save loss curve
        plt.figure(figsize=(10, 5))
        plt.plot(epoch_losses, label='Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss Over Time')
        plt.legend()
        plt.grid(True)
        plt.savefig(f"loss_plots/loss_curve_epoch_{epoch+1}.png")
        plt.close()

        # Save model checkpoint and generate samples
        if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_loss,
            }, f"model_checkpoints/diffusion_epoch_{epoch+1}.pt")

            # Generate and save samples (model.train() is restored at the
            # start of the next epoch)
            model.eval()
            with torch.no_grad():
                samples = diffusion.sample(model, n_samples=4)

                _, axes = plt.subplots(2, 2, figsize=(10, 10))
                axes = axes.flatten()

                for ax, sample in zip(axes, samples):
                    # Convert from [-1, 1] to [0, 1] for display
                    sample_img = (sample.cpu().squeeze() + 1) / 2
                    ax.imshow(sample_img, cmap='viridis')
                    ax.axis('off')

                plt.tight_layout()
                plt.savefig(f"generated_samples/samples_epoch_{epoch+1}.png")
                plt.close()

    # Calculate FID score after training
    model.eval()

    try:
        # Collect up to 50 real images (kept small for speed)
        real_batches = []
        n_collected = 0
        for batch in dataloader:
            real_batches.append(batch)
            n_collected += batch.shape[0]
            if n_collected >= 50:
                break
        real_images = torch.cat(real_batches, dim=0)[:50].to(device)

        # Generate the same number of fake images
        fake_images = diffusion.sample(model, n_samples=min(50, len(real_images)))

        # Evaluate on the device this function was asked to use rather than
        # calculate_fid's module-level default
        fid_score = calculate_fid(real_images, fake_images, device=device)
        print(f"Final FID Score: {fid_score:.4f}")
    except Exception as e:
        print(f"Error in final evaluation: {e}")
        fid_score = float('inf')

    return model, fid_score



# **Main Function**

In [None]:
# Main function
def main():
    """Build the dataset, train the diffusion model and save a 4x4 grid of final samples."""
    print(f"Using device: {DEVICE}")

    # Location of the .npy training images
    data_dir = "/kaggle/input/diffusion/Samples"
    dataset = LensingDataset(data_dir)

    # Nothing to train on: report and stop
    n_found = len(dataset)
    if n_found == 0:
        print(f"No .npy files found in {data_dir}")
        return
    print(f"Found {n_found} samples in the dataset")

    # Two worker processes keep the loader light for Kaggle
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

    # Model, diffusion process and optimizer
    model = UNet(in_channels=1, out_channels=1, time_dim=256)
    diffusion = DiffusionModel()
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

    # Run training; also returns the FID measured afterwards
    trained_model, fid_score = train(model, diffusion, dataloader, optimizer)
    print(f"Training complete. Final FID Score: {fid_score:.4f}")

    # Final sample sheet
    os.makedirs("final_samples", exist_ok=True)
    trained_model.eval()

    with torch.no_grad():
        samples = diffusion.sample(trained_model, n_samples=16)

        _, grid = plt.subplots(4, 4, figsize=(15, 15))
        panels = grid.flatten()

        for panel, sample in zip(panels, samples):
            # Shift from [-1, 1] to [0, 1] for display
            picture = (sample.cpu().squeeze() + 1) / 2
            panel.imshow(picture, cmap='viridis')
            panel.axis('off')

        plt.tight_layout()
        plt.savefig("final_samples/final_samples.png")
        plt.close()

    print("Generated final samples saved to 'final_samples/final_samples.png'")


# Entry point
# Run the training pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Using device: cuda
Found 10000 samples in the dataset


Epoch 1/100: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, loss=0.0442] 


Epoch 1/100 - Average Loss: 0.0347 - LR: 0.000200


Epoch 2/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.0384] 


Epoch 2/100 - Average Loss: 0.0110 - LR: 0.000200


Epoch 3/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00211]


Epoch 3/100 - Average Loss: 0.0087 - LR: 0.000200


Epoch 4/100: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, loss=0.00338]


Epoch 4/100 - Average Loss: 0.0072 - LR: 0.000199


Epoch 5/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.0174] 


Epoch 5/100 - Average Loss: 0.0065 - LR: 0.000199


Epoch 6/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.0128] 


Epoch 6/100 - Average Loss: 0.0066 - LR: 0.000198


Epoch 7/100: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, loss=0.00135] 


Epoch 7/100 - Average Loss: 0.0067 - LR: 0.000198


Epoch 8/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.0101]  


Epoch 8/100 - Average Loss: 0.0058 - LR: 0.000197


Epoch 9/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.0159]  


Epoch 9/100 - Average Loss: 0.0063 - LR: 0.000196


Epoch 10/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00135] 


Epoch 10/100 - Average Loss: 0.0059 - LR: 0.000195


Sampling: 100%|██████████| 1000/1000 [00:23<00:00, 42.13it/s]
Epoch 11/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00116] 


Epoch 11/100 - Average Loss: 0.0062 - LR: 0.000194


Epoch 12/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.0025]  


Epoch 12/100 - Average Loss: 0.0053 - LR: 0.000193


Epoch 13/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00206] 


Epoch 13/100 - Average Loss: 0.0061 - LR: 0.000192


Epoch 14/100: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, loss=0.0097]  


Epoch 14/100 - Average Loss: 0.0057 - LR: 0.000190


Epoch 15/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00564] 


Epoch 15/100 - Average Loss: 0.0054 - LR: 0.000189


Epoch 16/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.0395]  


Epoch 16/100 - Average Loss: 0.0053 - LR: 0.000188


Epoch 17/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00105] 


Epoch 17/100 - Average Loss: 0.0060 - LR: 0.000186


Epoch 18/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00232] 


Epoch 18/100 - Average Loss: 0.0053 - LR: 0.000184


Epoch 19/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00303] 


Epoch 19/100 - Average Loss: 0.0051 - LR: 0.000183


Epoch 20/100: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, loss=0.00202] 


Epoch 20/100 - Average Loss: 0.0048 - LR: 0.000181


Sampling: 100%|██████████| 1000/1000 [00:23<00:00, 42.11it/s]
Epoch 21/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00158] 


Epoch 21/100 - Average Loss: 0.0056 - LR: 0.000179


Epoch 22/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000649]


Epoch 22/100 - Average Loss: 0.0055 - LR: 0.000177


Epoch 23/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.000481]


Epoch 23/100 - Average Loss: 0.0053 - LR: 0.000175


Epoch 24/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.000871]


Epoch 24/100 - Average Loss: 0.0053 - LR: 0.000173


Epoch 25/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00257] 


Epoch 25/100 - Average Loss: 0.0052 - LR: 0.000171


Epoch 26/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.0322]  


Epoch 26/100 - Average Loss: 0.0051 - LR: 0.000168


Epoch 27/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00109] 


Epoch 27/100 - Average Loss: 0.0053 - LR: 0.000166


Epoch 28/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00396] 


Epoch 28/100 - Average Loss: 0.0050 - LR: 0.000164


Epoch 29/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00115] 


Epoch 29/100 - Average Loss: 0.0058 - LR: 0.000161


Epoch 30/100: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, loss=0.0017]  


Epoch 30/100 - Average Loss: 0.0048 - LR: 0.000159


Sampling: 100%|██████████| 1000/1000 [00:23<00:00, 42.11it/s]
Epoch 31/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00135] 


Epoch 31/100 - Average Loss: 0.0047 - LR: 0.000156


Epoch 32/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000695]


Epoch 32/100 - Average Loss: 0.0045 - LR: 0.000154


Epoch 33/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00108] 


Epoch 33/100 - Average Loss: 0.0049 - LR: 0.000151


Epoch 34/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.0105]  


Epoch 34/100 - Average Loss: 0.0047 - LR: 0.000148


Epoch 35/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00256] 


Epoch 35/100 - Average Loss: 0.0049 - LR: 0.000145


Epoch 36/100: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, loss=0.00226] 


Epoch 36/100 - Average Loss: 0.0054 - LR: 0.000143


Epoch 37/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.0191]  


Epoch 37/100 - Average Loss: 0.0054 - LR: 0.000140


Epoch 38/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000703]


Epoch 38/100 - Average Loss: 0.0053 - LR: 0.000137


Epoch 39/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00476] 


Epoch 39/100 - Average Loss: 0.0048 - LR: 0.000134


Epoch 40/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000552]


Epoch 40/100 - Average Loss: 0.0055 - LR: 0.000131


Sampling: 100%|██████████| 1000/1000 [00:23<00:00, 42.13it/s]
Epoch 41/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.000931]


Epoch 41/100 - Average Loss: 0.0054 - LR: 0.000128


Epoch 42/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.000562]


Epoch 42/100 - Average Loss: 0.0050 - LR: 0.000125


Epoch 43/100: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, loss=0.000934]


Epoch 43/100 - Average Loss: 0.0046 - LR: 0.000122


Epoch 44/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00147] 


Epoch 44/100 - Average Loss: 0.0045 - LR: 0.000119


Epoch 45/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000891]


Epoch 45/100 - Average Loss: 0.0048 - LR: 0.000116


Epoch 46/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00417] 


Epoch 46/100 - Average Loss: 0.0046 - LR: 0.000113


Epoch 47/100: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, loss=0.00162] 


Epoch 47/100 - Average Loss: 0.0045 - LR: 0.000109


Epoch 48/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.0201]  


Epoch 48/100 - Average Loss: 0.0048 - LR: 0.000106


Epoch 49/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00274] 


Epoch 49/100 - Average Loss: 0.0051 - LR: 0.000103


Epoch 50/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000877]


Epoch 50/100 - Average Loss: 0.0041 - LR: 0.000100


Sampling: 100%|██████████| 1000/1000 [00:23<00:00, 42.13it/s]
Epoch 51/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00201] 


Epoch 51/100 - Average Loss: 0.0049 - LR: 0.000097


Epoch 52/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00373] 


Epoch 52/100 - Average Loss: 0.0046 - LR: 0.000094


Epoch 53/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00944] 


Epoch 53/100 - Average Loss: 0.0047 - LR: 0.000091


Epoch 54/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00212] 


Epoch 54/100 - Average Loss: 0.0049 - LR: 0.000087


Epoch 55/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00406] 


Epoch 55/100 - Average Loss: 0.0045 - LR: 0.000084


Epoch 56/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00618] 


Epoch 56/100 - Average Loss: 0.0055 - LR: 0.000081


Epoch 57/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.000439]


Epoch 57/100 - Average Loss: 0.0044 - LR: 0.000078


Epoch 58/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00345] 


Epoch 58/100 - Average Loss: 0.0045 - LR: 0.000075


Epoch 59/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000301]


Epoch 59/100 - Average Loss: 0.0048 - LR: 0.000072


Epoch 60/100: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, loss=0.000432]


Epoch 60/100 - Average Loss: 0.0047 - LR: 0.000069


Sampling: 100%|██████████| 1000/1000 [00:23<00:00, 42.12it/s]
Epoch 61/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000414]


Epoch 61/100 - Average Loss: 0.0052 - LR: 0.000066


Epoch 62/100: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, loss=0.000978]


Epoch 62/100 - Average Loss: 0.0049 - LR: 0.000063


Epoch 63/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.0141]  


Epoch 63/100 - Average Loss: 0.0040 - LR: 0.000060


Epoch 64/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00326] 


Epoch 64/100 - Average Loss: 0.0047 - LR: 0.000057


Epoch 65/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00383] 


Epoch 65/100 - Average Loss: 0.0050 - LR: 0.000055


Epoch 66/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000468]


Epoch 66/100 - Average Loss: 0.0044 - LR: 0.000052


Epoch 67/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00201] 


Epoch 67/100 - Average Loss: 0.0045 - LR: 0.000049


Epoch 68/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000338]


Epoch 68/100 - Average Loss: 0.0047 - LR: 0.000046


Epoch 69/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00528] 


Epoch 69/100 - Average Loss: 0.0043 - LR: 0.000044


Epoch 70/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00868] 


Epoch 70/100 - Average Loss: 0.0046 - LR: 0.000041


Sampling: 100%|██████████| 1000/1000 [00:23<00:00, 42.12it/s]
Epoch 71/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000385]


Epoch 71/100 - Average Loss: 0.0042 - LR: 0.000039


Epoch 72/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.001]   


Epoch 72/100 - Average Loss: 0.0040 - LR: 0.000036


Epoch 73/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00307] 


Epoch 73/100 - Average Loss: 0.0051 - LR: 0.000034


Epoch 74/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00339] 


Epoch 74/100 - Average Loss: 0.0046 - LR: 0.000032


Epoch 75/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.000182]


Epoch 75/100 - Average Loss: 0.0047 - LR: 0.000029


Epoch 76/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00502] 


Epoch 76/100 - Average Loss: 0.0045 - LR: 0.000027


Epoch 77/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00616] 


Epoch 77/100 - Average Loss: 0.0046 - LR: 0.000025


Epoch 78/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00107] 


Epoch 78/100 - Average Loss: 0.0042 - LR: 0.000023


Epoch 79/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.0168]  


Epoch 79/100 - Average Loss: 0.0048 - LR: 0.000021


Epoch 80/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00244] 


Epoch 80/100 - Average Loss: 0.0049 - LR: 0.000019


Sampling: 100%|██████████| 1000/1000 [00:23<00:00, 42.12it/s]
Epoch 81/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00988] 


Epoch 81/100 - Average Loss: 0.0046 - LR: 0.000017


Epoch 82/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00868] 


Epoch 82/100 - Average Loss: 0.0044 - LR: 0.000016


Epoch 83/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000615]


Epoch 83/100 - Average Loss: 0.0045 - LR: 0.000014


Epoch 84/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.0101]  


Epoch 84/100 - Average Loss: 0.0048 - LR: 0.000012


Epoch 85/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.0105]  


Epoch 85/100 - Average Loss: 0.0048 - LR: 0.000011


Epoch 86/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00254] 


Epoch 86/100 - Average Loss: 0.0042 - LR: 0.000010


Epoch 87/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00112] 


Epoch 87/100 - Average Loss: 0.0043 - LR: 0.000008


Epoch 88/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.000485]


Epoch 88/100 - Average Loss: 0.0046 - LR: 0.000007


Epoch 89/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.0036]  


Epoch 89/100 - Average Loss: 0.0044 - LR: 0.000006


Epoch 90/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.0283]  


Epoch 90/100 - Average Loss: 0.0041 - LR: 0.000005


Sampling: 100%|██████████| 1000/1000 [00:23<00:00, 42.12it/s]
Epoch 91/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00219] 


Epoch 91/100 - Average Loss: 0.0044 - LR: 0.000004


Epoch 92/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.0112]  


Epoch 92/100 - Average Loss: 0.0043 - LR: 0.000003


Epoch 93/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00225] 


Epoch 93/100 - Average Loss: 0.0048 - LR: 0.000002


Epoch 94/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00309] 


Epoch 94/100 - Average Loss: 0.0044 - LR: 0.000002


Epoch 95/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00164] 


Epoch 95/100 - Average Loss: 0.0046 - LR: 0.000001


Epoch 96/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.000701]


Epoch 96/100 - Average Loss: 0.0044 - LR: 0.000001


Epoch 97/100: 100%|██████████| 313/313 [01:31<00:00,  3.40it/s, loss=0.00111] 


Epoch 97/100 - Average Loss: 0.0045 - LR: 0.000000


Epoch 98/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.0159]  


Epoch 98/100 - Average Loss: 0.0045 - LR: 0.000000


Epoch 99/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00445] 


Epoch 99/100 - Average Loss: 0.0048 - LR: 0.000000


Epoch 100/100: 100%|██████████| 313/313 [01:31<00:00,  3.41it/s, loss=0.00587] 


Epoch 100/100 - Average Loss: 0.0048 - LR: 0.000000


Sampling: 100%|██████████| 1000/1000 [00:23<00:00, 42.12it/s]
Sampling: 100%|██████████| 1000/1000 [02:39<00:00,  6.25it/s]


Feature shapes - Real: (50, 1000), Fake: (50, 1000)
Covariance shapes - Real: (1000, 1000), Fake: (1000, 1000)
Final FID Score: 34.3615
Training complete. Final FID Score: 34.3615


Sampling: 100%|██████████| 1000/1000 [00:56<00:00, 17.73it/s]


Generated final samples saved to 'final_samples/final_samples.png'
