# In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms, models
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim import Adam
from scipy import linalg
from huggingface_hub import hf_hub_download

# Configuration
class Config:
    """Central bag of hyper-parameters and paths.

    Every setting is a plain instance attribute, so callers can override
    any of them after construction (``model_path``, for instance, starts
    as ``None`` and is filled in once a checkpoint path is known).
    """

    def __init__(self):
        # Data: directory of .npy samples, target side length, channel count.
        self.data_path = "extracted_data/Samples"
        self.image_size, self.channels = 128, 1  # single-channel (grayscale)

        # Optimisation; the device falls back to CPU when CUDA is unavailable.
        self.batch_size, self.epochs, self.lr = 32, 10, 2e-4
        self.device = "cpu"
        if torch.cuda.is_available():
            self.device = "cuda"

        # DDPM: number of diffusion steps and the linear beta range.
        self.timesteps, self.beta_start, self.beta_end = 1000, 1e-4, 0.02

        # UNet: channel width per resolution level, residual blocks per level.
        self.hidden_dims = [64, 128, 256, 512]
        self.num_res_blocks = 2

        # Inference / evaluation settings.
        self.model_path = None  # path to a saved state_dict, set before inference
        self.fid_n_samples = 16


config = Config()

# Dataset class
class LensingDataset(Dataset):
    """Dataset of lensing images stored as individual ``.npy`` files.

    Each item is loaded as float32, min-max scaled to [0, 1] (constant
    arrays become all zeros), coerced to a 2-D array and passed through
    ``transform``. The default transform converts to a tensor, resizes to
    ``(image_size, image_size)`` and maps [0, 1] to [-1, 1].

    Args:
        data_path: Directory scanned (non-recursively) for ``*.npy`` files.
        image_size: Side length for the default resize and for the
            ``np.resize`` fallback used on arrays that cannot be made 2-D.
        transform: Optional callable applied to the 2-D float32 array; its
            result must support ``.float()`` (e.g. a ``torch.Tensor``).
    """

    def __init__(self, data_path, image_size, transform=None):
        self.data_path = data_path
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping (and any seeded subset of indices) reproducible.
        self.file_list = sorted(f for f in os.listdir(data_path) if f.endswith('.npy'))
        self.image_size = image_size

        if transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((self.image_size, self.image_size), antialias=True),
                transforms.Normalize((0.5,), (0.5,))
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _min_max_scale(image):
        """Scale ``image`` to [0, 1]; a constant array becomes all zeros."""
        min_val = np.min(image)
        max_val = np.max(image)
        if max_val > min_val:
            return (image - min_val) / (max_val - min_val)
        return np.zeros_like(image)

    def _to_2d(self, image):
        """Coerce a 1-D or 3-D array to 2-D; other non-2-D ranks are np.resize'd."""
        if image.ndim == 1:
            n = image.shape[0]
            side = int(np.sqrt(n))
            if side * side == n:
                return image.reshape(side, side)
            # Not a perfect square: use the largest factor <= sqrt(n) as rows.
            rows = side
            while n % rows != 0 and rows > 1:
                rows -= 1
            if rows > 1:
                return image.reshape(rows, n // rows)
            # No factor > 1 (prime length): fall back to resizing.
            return np.resize(image, (self.image_size, self.image_size))

        if image.ndim == 3:
            if image.shape[0] == 150:
                # NOTE(review): keeps only index 0 of the leading axis. For a
                # channel-last (150, W, C) array that is a single row, not an
                # image -- confirm which layout this branch was written for.
                image = image[0]
            else:
                # Average over the leading axis (presumably channels).
                image = np.mean(image, axis=0)

        if image.ndim != 2:
            image = np.resize(image, (self.image_size, self.image_size))
        return image

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_path, self.file_list[idx])
        image = np.load(file_path).astype(np.float32)

        # Scale first, then fix the shape, so every pixel ends up in [0, 1].
        image = self._min_max_scale(image)
        image = self._to_2d(image)

        if self.transform:
            image = self.transform(image)

        return image.float()

# Diffusion schedule helper
def get_noise_schedule(config):
    """Build a linear-beta DDPM schedule as float32 tensors on ``config.device``.

    Reads ``beta_start``, ``beta_end``, ``timesteps`` and ``device`` from
    *config*. All quantities are computed with NumPy first and only cast to
    float32 tensors at the end.

    Returns:
        dict with keys ``beta``, ``sqrt_beta``, ``alpha``, ``alpha_bar``,
        ``sqrt_alpha_bar`` and ``sqrt_one_minus_alpha_bar``; each value is a
        1-D tensor of length ``config.timesteps``.
    """
    betas = np.linspace(config.beta_start, config.beta_end, config.timesteps)
    alphas = 1 - betas
    alphas_cumprod = np.cumprod(alphas)

    as_numpy = {
        'beta': betas,
        'sqrt_beta': np.sqrt(betas),
        'alpha': alphas,
        'alpha_bar': alphas_cumprod,
        'sqrt_alpha_bar': np.sqrt(alphas_cumprod),
        'sqrt_one_minus_alpha_bar': np.sqrt(1 - alphas_cumprod),
    }
    return {
        key: torch.tensor(values, dtype=torch.float32).to(config.device)
        for key, values in as_numpy.items()
    }

# Model components
class SinusoidalPositionEmbeddings(nn.Module):
    """Fixed (parameter-free) sinusoidal embedding of timesteps.

    ``forward`` maps a 1-D tensor of shape ``(batch,)`` to ``(batch, dim)``:
    the first ``dim // 2`` columns are sines and the next ``dim // 2`` are
    cosines of the timesteps scaled by geometrically spaced frequencies; an
    odd ``dim`` gets one trailing zero column.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        n_freq = self.dim // 2
        # Frequencies run from 1 down to 1/10000 on a log scale.
        log_step = np.log(10000) / (n_freq - 1)
        freqs = torch.exp(
            torch.arange(n_freq, device=time.device, dtype=torch.float32) * -log_step
        )
        angles = time.float()[:, None] * freqs[None, :]
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
        if self.dim % 2:
            emb = F.pad(emb, (0, 1))
        return emb

class SelfAttention(nn.Module):
    """Multi-head self-attention across the spatial positions of a feature map.

    An ``(N, C, H, W)`` input is flattened into ``H*W`` tokens of width ``C``.
    A LayerNorm -> attention sub-layer and a LayerNorm -> MLP sub-layer are
    applied, each with a residual connection. The result is reshaped back
    to ``(N, C, H, W)``.

    Args:
        channels: Token width ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        assert channels % num_heads == 0
        self.mha = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        height, width = x.shape[-2:]
        # (N, C, H, W) -> (N, H*W, C): one token per spatial position.
        tokens = x.view(-1, self.channels, height * width).transpose(1, 2)
        normed = self.ln(tokens)
        attended, _ = self.mha(normed, normed, normed)
        hidden = tokens + attended
        hidden = hidden + self.ff_self(hidden)
        return hidden.transpose(1, 2).view(-1, self.channels, height, width)

class ResidualBlock(nn.Module):
    """Two conv3x3 -> GroupNorm -> SiLU stages plus an additive skip path.

    When ``time_emb_dim`` is truthy, a linear projection of the time
    embedding (passed through SiLU) is added per channel between the two
    stages. The skip path is a 1x1 conv if the channel count changes and an
    identity otherwise.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        time_emb_dim: Width of the time embedding; None/0 disables it.
        groups: Maximum number of GroupNorm groups (capped at out_channels).
    """

    def __init__(self, in_channels, out_channels, time_emb_dim=None, groups=8):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if time_emb_dim:
            self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        else:
            self.time_mlp = None

        n_groups = min(groups, out_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(n_groups, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(n_groups, out_channels)

        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )
        self.act = nn.SiLU()

    def forward(self, x, time_emb=None):
        out = self.act(self.norm1(self.conv1(x)))

        conditioned = time_emb is not None and self.time_mlp is not None
        if conditioned:
            # Broadcast the (N, out_channels) projection over H and W.
            shift = self.act(self.time_mlp(time_emb))
            out = out + shift.view(-1, shift.shape[1], 1, 1)

        out = self.act(self.norm2(self.conv2(out)))
        return out + self.shortcut(x)

# UNet model definition
class UNet(nn.Module):
    """Noise-prediction UNet built from ``config``.

    Reads ``config.channels``, ``config.hidden_dims`` and
    ``config.num_res_blocks``. There is one encoder level per entry of
    ``hidden_dims``. Consecutive levels are joined by stride-2 convolutions,
    the deepest level feeds a middle block, and the decoder mirrors the
    encoder using transposed convolutions and channel-concatenated skip
    connections. Self-attention is applied on the two deepest levels and in
    the middle block.

    NOTE: submodule attribute names and the ModuleList layout determine the
    ``state_dict`` keys, so renaming or reordering them breaks loading of
    previously saved checkpoints.
    """

    def __init__(self, config):
        super().__init__()
        self.channels = config.channels
        self.time_dim = config.hidden_dims[0]

        # Time embedding: sinusoidal features followed by an MLP that widens
        # them to 4 * time_dim; this width is fed to every ResidualBlock.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(self.time_dim),
            nn.Linear(self.time_dim, self.time_dim * 4),
            nn.SiLU(),
            nn.Linear(self.time_dim * 4, self.time_dim * 4)
        )
        time_emb_dim_unet = self.time_dim * 4

        # Initial projection from image channels to the first hidden width.
        init_dim = config.hidden_dims[0]
        self.init_conv = nn.Conv2d(config.channels, init_dim, 3, padding=1)

        # Encoder: level i has num_res_blocks ResidualBlocks at width dims[i],
        # optional SelfAttention, and (except at the deepest level) a
        # kernel-4/stride-2/padding-1 conv that halves H and W (for even sizes)
        # while moving to width dims[i+1].
        self.downs = nn.ModuleList()
        dims = config.hidden_dims
        current_dim = init_dim

        for i in range(len(dims)):
            down_block_layers = nn.ModuleList()

            for _ in range(config.num_res_blocks):
                down_block_layers.append(ResidualBlock(current_dim, current_dim, time_emb_dim=time_emb_dim_unet))

            # Attention only on the two deepest (lowest-resolution) levels.
            if i >= (len(dims) - 2):
                down_block_layers.append(SelfAttention(current_dim))

            if i < len(dims) - 1:
                down_block_layers.append(nn.Conv2d(current_dim, dims[i+1], kernel_size=4, stride=2, padding=1))
                current_dim = dims[i+1]

            self.downs.append(down_block_layers)

        # Middle block at the deepest width: ResidualBlock -> attention -> ResidualBlock.
        mid_dim = dims[-1]
        self.mid = nn.ModuleList([
            ResidualBlock(mid_dim, mid_dim, time_emb_dim=time_emb_dim_unet),
            SelfAttention(mid_dim),
            ResidualBlock(mid_dim, mid_dim, time_emb_dim=time_emb_dim_unet)
        ])

        # Decoder: levels are built deepest-first. Every level except the
        # deepest receives a dims[i]-channel skip concatenated onto a
        # dims[i]-channel input, so its first ResidualBlock takes 2 * dims[i].
        # Levels above 0 end with a stride-2 transposed conv that doubles H and
        # W and maps dims[i] -> dims[i-1].
        self.ups = nn.ModuleList()
        for i in reversed(range(len(dims))):
            up_block_layers = nn.ModuleList()

            if i == len(dims) - 1:
                in_dim = dims[i]
            else:
                in_dim = dims[i] * 2

            up_block_layers.append(ResidualBlock(in_dim, dims[i], time_emb_dim=time_emb_dim_unet))

            for _ in range(config.num_res_blocks - 1):
                up_block_layers.append(ResidualBlock(dims[i], dims[i], time_emb_dim=time_emb_dim_unet))

            if i >= (len(dims) - 2):
                up_block_layers.append(SelfAttention(dims[i]))

            if i > 0:
                up_block_layers.append(nn.ConvTranspose2d(dims[i], dims[i-1], kernel_size=4, stride=2, padding=1))

            self.ups.append(up_block_layers)

        # Final layers: GroupNorm -> SiLU -> 1x1 conv back to image channels.
        final_dim = dims[0]
        self.final_conv = nn.Sequential(
            nn.GroupNorm(min(8, final_dim), final_dim),
            nn.SiLU(),
            nn.Conv2d(final_dim, config.channels, 1)
        )

    def forward(self, x, t):
        """Predict the noise present in ``x`` at timesteps ``t``.

        Args:
            x: Noisy images of shape ``(N, config.channels, H, W)``.
            t: 1-D tensor of timesteps, one per image.

        Returns:
            Tensor with ``config.channels`` channels (the predicted noise).
            Its spatial size equals the input's when H and W are divisible by
            ``2 ** (len(hidden_dims) - 1)``.
        """
        # Time embedding
        t = self.time_mlp(t)

        # Initial convolution
        x = self.init_conv(x)
        skip_connections = []

        # Encoder: the only layers that are neither ResidualBlock nor
        # SelfAttention are the downsampling convs; the activation entering
        # each one is saved as a skip connection (len(dims) - 1 in total).
        for down_block in self.downs:
            for layer in down_block:
                if isinstance(layer, ResidualBlock) or isinstance(layer, SelfAttention):
                    x = layer(x, t) if isinstance(layer, ResidualBlock) else layer(x)
                else:
                    skip_connections.append(x)
                    x = layer(x)

        # Middle
        for layer in self.mid:
            x = layer(x, t) if isinstance(layer, ResidualBlock) else layer(x)

        # Decoder: every level after the deepest pops one skip (LIFO order).
        # A skip is bilinearly resized if its spatial size differs from x,
        # which happens when H or W is not divisible by the downsampling factor.
        for i, up_block in enumerate(self.ups):
            if i > 0:
                skip = skip_connections.pop()
                if x.shape[2:] != skip.shape[2:]:
                    skip = F.interpolate(skip, size=x.shape[2:], mode='bilinear', align_corners=False)
                x = torch.cat([x, skip], dim=1)

            for layer in up_block:
                if isinstance(layer, ResidualBlock) or isinstance(layer, SelfAttention):
                    x = layer(x, t) if isinstance(layer, ResidualBlock) else layer(x)
                else:
                    x = layer(x)

        # Final convolution
        return self.final_conv(x)

# Diffusion model
class DiffusionModel:
    """DDPM trainer and sampler around a :class:`UNet` noise predictor.

    Owns the UNet (moved to ``config.device``), the linear noise schedule,
    an Adam optimiser (``config.lr``) and the MSE loss between the true and
    the predicted noise.
    """

    def __init__(self, config):
        self.config = config
        self.model = UNet(config).to(config.device)
        self.noise_schedule = get_noise_schedule(config)
        self.optimizer = Adam(self.model.parameters(), lr=config.lr)
        self.loss_fn = nn.MSELoss()

    def forward_diffusion(self, x_0, t):
        """Draw ``x_t`` from ``q(x_t | x_0)`` in closed form.

        Args:
            x_0: Clean images, expected shape ``(N, C, H, W)``.
            t: Integer timesteps of shape ``(N,)`` indexing the schedule.

        Returns:
            ``(x_t, noise)`` with ``x_t = sqrt(alpha_bar_t) * x_0 +
            sqrt(1 - alpha_bar_t) * noise`` and ``noise`` standard normal.
        """
        device = self.config.device
        x_0 = x_0.float().to(device)
        noise = torch.randn_like(x_0, dtype=torch.float32, device=device)
        # Per-sample coefficients, reshaped to broadcast over C, H and W.
        sqrt_alpha_bar = self.noise_schedule['sqrt_alpha_bar'][t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar = self.noise_schedule['sqrt_one_minus_alpha_bar'][t].view(-1, 1, 1, 1)
        x_t = sqrt_alpha_bar * x_0 + sqrt_one_minus_alpha_bar * noise
        return x_t, noise

    def train_step(self, x_0):
        """Run one optimisation step on a batch and return the scalar loss."""
        self.optimizer.zero_grad()
        x_0 = x_0.float().to(self.config.device)
        batch_size = x_0.shape[0]
        # One uniformly random timestep per image.
        t = torch.randint(0, self.config.timesteps, (batch_size,), device=self.config.device).long()
        x_t, noise = self.forward_diffusion(x_0, t)
        predicted_noise = self.model(x_t, t)
        loss = self.loss_fn(predicted_noise, noise)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    @torch.no_grad()
    def sample(self, n_samples, size=None):
        """Generate images by running the full reverse diffusion chain.

        Each step computes the mean
        ``1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta)``
        and, except at ``t == 0``, adds Gaussian noise with std ``sqrt(beta_t)``.

        Args:
            n_samples: Number of images to generate.
            size: Per-sample shape; defaults to
                ``(config.channels, config.image_size, config.image_size)``.

        Returns:
            Tensor of shape ``(n_samples, *size)`` on ``config.device``
            (not clamped).
        """
        if size is None:
            size = (self.config.channels, self.config.image_size, self.config.image_size)

        # Remember the caller's mode so it is restored even if sampling fails.
        was_training = self.model.training
        self.model.eval()
        try:
            x = torch.randn((n_samples, *size), device=self.config.device, dtype=torch.float32)

            for t in tqdm(reversed(range(self.config.timesteps)), desc="Sampling", total=self.config.timesteps):
                t_tensor = torch.full((n_samples,), t, device=self.config.device, dtype=torch.long)
                predicted_noise = self.model(x, t_tensor)
                alpha_t = self.noise_schedule['alpha'][t]
                beta_t = self.noise_schedule['beta'][t]
                sqrt_one_minus_alpha_t_bar = self.noise_schedule['sqrt_one_minus_alpha_bar'][t]
                sqrt_recip_alpha_t = torch.sqrt(1.0 / alpha_t)
                mean = sqrt_recip_alpha_t * (x - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise)

                if t > 0:
                    std_dev = self.noise_schedule['sqrt_beta'][t]
                    noise = torch.randn_like(x, dtype=torch.float32)
                    x = mean + std_dev * noise
                else:
                    x = mean
        finally:
            self.model.train(was_training)

        return x

# FID calculation helper
class FID:
    """Fréchet Inception Distance between two sets of images.

    Uses torchvision's Inception-v3 with ``fc`` and ``AuxLogits.fc``
    replaced by identities, so features are the pooled activations that fed
    the original classifier.
    """

    def __init__(self, config):
        self.device = config.device
        try:
            self.inception_model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT, transform_input=False)
        except (AttributeError, TypeError):
            # Older torchvision (pre weights-API) lacks Inception_V3_Weights
            # (AttributeError) and the ``weights`` kwarg (TypeError).
            self.inception_model = models.inception_v3(pretrained=True, transform_input=False)

        self.inception_model.fc = nn.Identity()
        self.inception_model.AuxLogits.fc = nn.Identity()
        self.inception_model = self.inception_model.to(self.device)
        self.inception_model.eval()
        # ImageNet normalisation expected by the pretrained weights; built once
        # instead of on every batch.
        self._normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    @torch.no_grad()
    def extract_features(self, loader):
        """Return Inception features for every image yielded by ``loader``.

        Batches may be bare tensors or 1-element tuples/lists (e.g. from a
        ``TensorDataset``). Images are expected in [-1, 1]; single-channel
        images are repeated to 3 channels and resized to 299x299.

        Returns:
            ``np.ndarray`` of shape ``(num_images, feature_dim)``.
        """
        features = []

        for batch in tqdm(loader, desc="Extracting features"):
            if isinstance(batch, (tuple, list)) and len(batch) == 1:
                imgs = batch[0]
            else:
                imgs = batch

            imgs = imgs.to(self.device)
            if imgs.shape[1] == 1:
                imgs = imgs.repeat(1, 3, 1, 1)
            imgs = F.interpolate(imgs, size=(299, 299), mode='bilinear', align_corners=False, antialias=True)
            # [-1, 1] -> [0, 1]. Clamp so raw diffusion samples (which can
            # overshoot [-1, 1]) never feed out-of-range pixels to Inception.
            imgs = ((imgs + 1) / 2.0).clamp(0.0, 1.0)
            imgs = self._normalize(imgs)
            feat = self.inception_model(imgs)
            if isinstance(feat, models.inception.InceptionOutputs):
                feat = feat.logits
            features.append(feat.cpu().numpy())

        return np.concatenate(features, axis=0)

    def calculate_fid(self, real_features, generated_features):
        """Return the Fréchet distance between Gaussians fitted to two feature sets.

        Computes ``||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrtm(S1 @ S2 + eps*I))``
        with ``eps = 1e-6`` for numerical stability; any imaginary part left
        by ``sqrtm`` is discarded.

        Raises:
            ValueError: if either feature set has fewer than 2 rows, since a
                sample covariance cannot be estimated.
        """
        for name, feats in (("real_features", real_features), ("generated_features", generated_features)):
            if len(feats) < 2:
                raise ValueError(f"{name} needs at least 2 samples to estimate a covariance, got {len(feats)}")

        mu1, sigma1 = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
        mu2, sigma2 = np.mean(generated_features, axis=0), np.cov(generated_features, rowvar=False)
        sum_sq_diff = np.sum((mu1 - mu2)**2)
        eps = 1e-6
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2) + eps * np.eye(sigma1.shape[0]), disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        fid = sum_sq_diff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
        return fid

# Helper functions
def _describe_data_dir(data_path):
    """Print diagnostics about ``data_path`` after a failed dataset load."""
    if not os.path.exists(data_path):
        print(f"Data directory '{data_path}' not found.")
        print("Please ensure the data is extracted correctly.")
        return
    if not os.path.isdir(data_path):
        return
    try:
        entries = os.listdir(data_path)
        print(f"Files in {data_path}: {entries[:10]}...")
        if not any(name.endswith('.npy') for name in entries):
            print("No .npy files found in the directory.")
    except Exception as list_exc:
        print(f"Could not list files in {data_path}: {list_exc}")


def load_dataset(config):
    """Build a :class:`LensingDataset` from ``config`` and report on it.

    Prints the dataset size and the shape/dtype of the first sample.
    Failures never propagate. An empty dataset prints a message and yields
    ``None``. Any exception while building or reading the dataset prints the
    error, diagnostics about ``config.data_path`` and the traceback, and
    also yields ``None``.

    Returns:
        The dataset, or ``None`` if it is empty or could not be loaded.
    """
    data_path = config.data_path
    try:
        dataset = LensingDataset(data_path, config.image_size)
        n_items = len(dataset)
        print(f"Successfully loaded dataset with {n_items} samples")
        if n_items == 0:
            print("Dataset is empty!")
            return None
        first = dataset[0]
        print(f"Sample shape: {first.shape}, dtype: {first.dtype}")
        return dataset
    except Exception as exc:
        print(f"Error loading dataset: {exc}")
        _describe_data_dir(data_path)
        import traceback
        traceback.print_exc()
        return None

def _save_epoch_samples(diffusion, epoch):
    """Draw 16 samples and save them as a 4x4 grid ``samples_epoch_{epoch+1}.png``."""
    print(f"Generating samples for epoch {epoch+1}...")
    samples = diffusion.sample(16)
    samples = (samples + 1) / 2
    samples = samples.clamp(0.0, 1.0)

    plt.figure(figsize=(10, 10))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow(samples[i, 0].cpu().numpy(), cmap='viridis')
        plt.axis('off')
    plt.suptitle(f"Samples Epoch {epoch+1}", fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    save_path = f"samples_epoch_{epoch+1}.png"
    plt.savefig(save_path)
    print(f"Saved samples to {save_path}")
    plt.close()


def _fid_for_epoch(diffusion, dataset, config, epoch, num_workers, max_real=256, fid_batch_size=32):
    """Compute, print and return FID between a random real subset and new samples.

    Up to ``max_real`` real images are drawn without replacement. The same
    number of images is generated in chunks of ``fid_batch_size``.
    """
    print(f"Calculating FID for epoch {epoch+1}...")
    fid_calculator = FID(config)

    subset_indices = np.random.choice(len(dataset), min(max_real, len(dataset)), replace=False)
    real_subset = torch.utils.data.Subset(dataset, subset_indices)
    real_loader = DataLoader(real_subset, batch_size=fid_batch_size, shuffle=False, num_workers=num_workers)
    real_features = fid_calculator.extract_features(real_loader)

    n_fid_samples = len(real_subset)
    # Sample in chunks: peak memory grows with batch size (the UNet's
    # self-attention spans every spatial position), so one batch of the
    # whole set can exhaust GPU memory.
    chunks = [
        diffusion.sample(min(fid_batch_size, n_fid_samples - start))
        for start in range(0, n_fid_samples, fid_batch_size)
    ]
    generated_samples = torch.cat(chunks, dim=0)
    generated_loader = DataLoader(TensorDataset(generated_samples), batch_size=fid_batch_size, shuffle=False)
    generated_features = fid_calculator.extract_features(generated_loader)

    fid_score = fid_calculator.calculate_fid(real_features, generated_features)
    print(f"Epoch {epoch+1}, FID Score: {fid_score:.2f}")
    return fid_score


def train(config):
    """Train a :class:`DiffusionModel` on the dataset at ``config.data_path``.

    Every 10th epoch and after the final epoch a 4x4 sample grid is saved.
    FID is computed only after the final epoch because it requires hundreds
    of full reverse-diffusion runs. Writes ``gravitational_lensing_diffusion.pth``
    (the UNet state_dict) and ``training_loss.png``.

    Returns:
        The trained ``DiffusionModel``, or ``None`` if the dataset failed to load.
    """
    dataset = load_dataset(config)
    if dataset is None:
        print("Dataset loading failed. Exiting training.")
        return None

    num_workers = 2
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True,
                            num_workers=num_workers, pin_memory=config.device == "cuda")

    diffusion = DiffusionModel(config)
    losses = []

    for epoch in range(config.epochs):
        epoch_losses = []
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.epochs}", leave=True)

        for batch in progress_bar:
            loss = diffusion.train_step(batch)
            epoch_losses.append(loss)
            # Smoothed display: mean of the last 20 step losses.
            progress_bar.set_postfix({"loss": np.mean(epoch_losses[-20:])})

        avg_loss = np.mean(epoch_losses)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{config.epochs}, Average Loss: {avg_loss:.4f}")

        is_last_epoch = epoch == config.epochs - 1
        if (epoch + 1) % 10 == 0 or is_last_epoch:
            _save_epoch_samples(diffusion, epoch)
            # Only calculate FID for the final epoch to save time.
            if is_last_epoch:
                _fid_for_epoch(diffusion, dataset, config, epoch, num_workers)

    # Save the trained model
    model_save_path = "gravitational_lensing_diffusion.pth"
    torch.save(diffusion.model.state_dict(), model_save_path)
    print(f"Saved trained model to {model_save_path}")

    # Plot loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(losses)
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.title('Training Loss Curve')
    plt.grid(True)
    loss_plot_path = "training_loss.png"
    plt.savefig(loss_plot_path)
    print(f"Saved training loss plot to {loss_plot_path}")
    plt.close()

    return diffusion

# Inference with downloaded model
class DiffusionInference:
    """Sampling-only wrapper around a UNet restored from ``config.model_path``.

    The model is loaded onto ``config.device`` and left in eval mode.
    Construction re-raises any error raised while loading the checkpoint,
    after printing it.
    """

    def __init__(self, config):
        self.config = config
        self.model = UNet(config).to(config.device)
        self.noise_schedule = get_noise_schedule(config)

        print(f"Loading model from {config.model_path}...")
        try:
            self.model.load_state_dict(self._load_checkpoint(config.model_path, config.device))
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

        self.model.eval()

    @staticmethod
    def _load_checkpoint(path, device):
        """Load a state_dict, restricting unpickling to tensors where supported."""
        # SECURITY: torch.load unpickles its input, and this checkpoint can be
        # downloaded from the network. weights_only=True refuses arbitrary
        # pickled objects, which is enough for a plain state_dict.
        try:
            return torch.load(path, map_location=device, weights_only=True)
        except TypeError:
            # torch versions without the ``weights_only`` argument: this falls
            # back to unrestricted unpickling -- only load trusted checkpoints.
            return torch.load(path, map_location=device)

    @torch.no_grad()
    def sample(self, n_samples, size=None):
        """Generate images by running the full DDPM reverse chain.

        Each step computes the mean
        ``1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta)``
        and, except at ``t == 0``, adds Gaussian noise with std ``sqrt(beta_t)``.

        Args:
            n_samples: Number of images to generate.
            size: Per-sample shape; defaults to
                ``(config.channels, config.image_size, config.image_size)``.

        Returns:
            Tensor of shape ``(n_samples, *size)`` on ``config.device``
            (not clamped).
        """
        if size is None:
            size = (self.config.channels, self.config.image_size, self.config.image_size)

        x = torch.randn((n_samples, *size), device=self.config.device, dtype=torch.float32)

        for t in tqdm(reversed(range(self.config.timesteps)), desc="Sampling", total=self.config.timesteps):
            t_tensor = torch.full((n_samples,), t, device=self.config.device, dtype=torch.long)
            predicted_noise = self.model(x, t_tensor)
            alpha_t = self.noise_schedule['alpha'][t]
            beta_t = self.noise_schedule['beta'][t]
            sqrt_one_minus_alpha_t_bar = self.noise_schedule['sqrt_one_minus_alpha_bar'][t]
            sqrt_recip_alpha_t = torch.sqrt(1.0 / alpha_t)
            mean = sqrt_recip_alpha_t * (x - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise)

            if t > 0:
                std_dev = self.noise_schedule['sqrt_beta'][t]
                noise = torch.randn_like(x, dtype=torch.float32)
                x = mean + std_dev * noise
            else:
                x = mean

        return x

def download_and_infer():
    """Fetch the pretrained checkpoint, draw 16 samples and plot them.

    Side effects: downloads ``model.pth`` from the ``oussamaor/gravlens``
    HuggingFace repo, stores its local path in the module-level
    ``config.model_path``, saves ``pretrained_samples.png`` and shows the
    figure.
    """
    checkpoint = hf_hub_download(repo_id="oussamaor/gravlens", filename="model.pth")
    print("Model downloaded to:", checkpoint)

    # DiffusionInference reads the checkpoint location from config.
    config.model_path = checkpoint

    sampler = DiffusionInference(config)
    n_samples = 16
    images = ((sampler.sample(n_samples) + 1) / 2).clamp(0.0, 1.0)

    plt.figure(figsize=(12, 12))
    for pos in range(1, n_samples + 1):
        plt.subplot(4, 4, pos)
        plt.imshow(images[pos - 1, 0].cpu().numpy(), cmap='viridis')
        plt.axis('off')
    plt.suptitle("Generated Samples from Pretrained Model", fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig("pretrained_samples.png")
    plt.show()

def calculate_fid_with_pretrained():
    """Compute FID between real data and samples from the pretrained model.

    Downloads the checkpoint first if ``config.model_path`` is unset. Uses
    ``config.fid_n_samples`` images (capped at the dataset size) on each side.

    Returns:
        The FID value, or ``None`` if the dataset could not be loaded.
    """
    if config.model_path is None:
        config.model_path = hf_hub_download(repo_id="oussamaor/gravlens", filename="model.pth")

    dataset = load_dataset(config)
    if dataset is None:
        print("Cannot calculate FID: dataset loading failed")
        return None

    sampler = DiffusionInference(config)
    fid = FID(config)

    # Real side: a random subset without replacement.
    n = min(config.fid_n_samples, len(dataset))
    picked = np.random.choice(len(dataset), n, replace=False)
    real_loader = DataLoader(torch.utils.data.Subset(dataset, picked),
                             batch_size=config.batch_size, shuffle=False, num_workers=2)
    real_feats = fid.extract_features(real_loader)

    # Generated side: the same number of fresh samples.
    fake_loader = DataLoader(TensorDataset(sampler.sample(n)),
                             batch_size=config.batch_size, shuffle=False)
    fake_feats = fid.extract_features(fake_loader)

    score = fid.calculate_fid(real_feats, fake_feats)
    print(f"FID Score with pretrained model: {score:.2f}")
    return score

# Main function
def main():
    """Entry point: train from scratch or evaluate the pretrained model.

    Edit the ``train_mode`` and ``calculate_fid`` flags below to choose
    which steps run.
    """
    print(f"Using device: {config.device}")

    # You can choose to either train or infer
    train_mode = False  # Set to True to train, False to use pretrained
    if train_mode:
        print("Training diffusion model for gravitational lensing...")
        diffusion_model = train(config)

        if diffusion_model is not None:
            print("Generating final samples...")
            final_samples = diffusion_model.sample(16)
            final_samples = (final_samples + 1) / 2
            final_samples = final_samples.clamp(0.0, 1.0)

            plt.figure(figsize=(12, 12))
            for i in range(16):
                plt.subplot(4, 4, i+1)
                plt.imshow(final_samples[i, 0].cpu().numpy(), cmap='viridis')
                plt.axis('off')
            plt.suptitle("Final Generated Samples", fontsize=16)
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            plt.savefig("final_samples.png")
            print("Saved final generated samples to final_samples.png")
            plt.show()
    else:
        # Use pretrained model from HuggingFace
        print("Using pretrained model from HuggingFace")
        download_and_infer()

        # Optionally calculate FID score
        calculate_fid = True  # Set to False to skip FID calculation
        if calculate_fid:
            fid_score = calculate_fid_with_pretrained()
            # calculate_fid_with_pretrained returns None when the dataset
            # cannot be loaded; formatting None with ':.2f' would raise TypeError.
            if fid_score is not None:
                print(f"FID Score: {fid_score:.2f}")


if __name__ == "__main__":
    main()