### Config for loading data

In [None]:
### This is configs/config.py ###

# Training settings for the cropped CelebA run.
config_CELEBA = dict(
    dataset_name="CELEBA",
    image_shape=(3, 64, 64),  # (channels, height, width)
    batch_size=64,
    epochs=500,
    learning_rate=3e-4,
    project_name="CELEBA_wandb",
    log_sample_interval=5,
    wandb_name="Test",
    data_dir="cropped_celeba_bin",  # Path to your .bin files
)

# Training settings for the CIFAR-10 run.
config_CIFAR10 = dict(
    dataset_name="CIFAR10",
    image_shape=(3, 32, 32),  # (channels, height, width)
    batch_size=128,
    epochs=500,
    learning_rate=3e-4,
    project_name="CIFAR10_wandb",
    log_sample_interval=5,
    wandb_name="Test",
)

# Training settings for the MNIST run (the loader pads 28x28 digits to 32x32).
config_MNIST = dict(
    dataset_name="MNIST",
    image_shape=(1, 32, 32),  # (channels, height, width)
    batch_size=128,
    epochs=500,
    learning_rate=3e-4,
    project_name="MNIST_wandb",
    log_sample_interval=5,
    wandb_name="Test",
)

### Load the data

In [None]:
### This is utils/dataset_loader.py ###


import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import Subset, DataLoader
import os

def load_dataset(config, small_sample=False, validation=False):
    """
    Build the dataset selected by ``config["dataset_name"]`` and return one split.

    Args:
        config (dict): Must contain "dataset_name" ("MNIST", "CIFAR10" or
            "CELEBA"). For CELEBA, "image_shape" gives the resize target and
            the optional "data_dir" names the folder holding the
            ``data_batch_*`` files (defaults to "cropped_celeba_bin").
        small_sample (bool): Keep only the first 10 samples before splitting
            (for quick debugging).
        validation (bool): Return the last 10% of the samples instead of the
            first 90%.

    Returns:
        torch.utils.data.Subset: The requested split.

    Raises:
        ValueError: If the dataset name is not supported.
    """
    validationSize = 0.1
    dataset_name = config["dataset_name"]

    # NOTE: the same transform (including the random flip for CIFAR10/CELEBA)
    # is applied to both the training and the validation split.
    if dataset_name == "MNIST":
        transform = transforms.Compose([
            transforms.Pad((2, 2, 2, 2), fill=0),  # Adds 2 pixels to each side with black padding
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),  # Normalize to [-1, 1]
        ])
        dataset = datasets.MNIST("mnist", download=True, transform=transform)

    elif dataset_name == "CIFAR10":
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # RGB Normalize to [-1, 1]
        ])
        dataset = datasets.CIFAR10("cifar10", download=True, transform=transform)

    elif dataset_name == "CELEBA":
        transform = transforms.Compose([
            transforms.Resize((config["image_shape"][1], config["image_shape"][2])),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # Convert PIL image to tensor
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize to [-1, 1]
        ])

        # Honour the configured folder; fall back to the historical default.
        data_batch_dir = config.get("data_dir", "cropped_celeba_bin")
        # os.listdir returns entries in arbitrary order. The train/validation
        # split below is purely index based, so the file order must be fixed
        # for the split to be reproducible across runs and machines.
        batch_files = sorted(
            os.path.join(data_batch_dir, f)
            for f in os.listdir(data_batch_dir)
            if f.startswith("data_batch_")
        )

        # Load all binary files (stored as 128x128 RGB records)
        dataset = CombinedBinaryDataset(bin_files=batch_files, img_size=(128, 128), num_channels=3, transform=transform)

    else:
        raise ValueError(f"Dataset {config['dataset_name']} is not supported!")

    if small_sample:
        # Use only the first 10 samples for quick testing
        dataset = Subset(dataset, range(10))

    dataset_size = len(dataset)
    if validation:
        # Validation split: the trailing 10% of the indices.
        val_size = int(validationSize * dataset_size)
        dataset = Subset(dataset, range(dataset_size - val_size, dataset_size))
        print("Loading validation set")
    else:
        # Training split: the leading 90% of the indices.
        train_size = int((1-validationSize) * dataset_size)
        dataset = Subset(dataset, range(0, train_size))
        print("Loading training set")

    return dataset

class CombinedBinaryDataset(torch.utils.data.Dataset):
    """
    In-memory dataset over one or more fixed-record binary files.

    Each record is one label byte followed by
    ``num_channels * img_size[0] * img_size[1]`` pixel bytes. All records of
    all files are read into memory at construction time; trailing bytes that
    do not form a complete record are ignored.

    Args:
        bin_files (list[str]): Paths of the binary files to read.
        img_size (tuple[int, int]): Spatial size used to reshape the pixels.
        num_channels (int): Number of channels per image.
        transform (callable, optional): Applied to the image after it has been
            converted to a PIL image.
    """

    def __init__(self, bin_files, img_size, num_channels=3, transform=None):
        self.bin_files = bin_files  # List of binary file paths
        self.img_size = img_size
        self.num_channels = num_channels
        self.samples = []
        self.transform = transform

        # The record size does not depend on the file, so compute it once.
        sample_size = 1 + num_channels * img_size[0] * img_size[1]  # 1 byte label + pixel data

        # Read all batches into memory
        for bin_file in self.bin_files:
            num_samples = os.path.getsize(bin_file) // sample_size
            with open(bin_file, "rb") as f:
                for _ in range(num_samples):
                    self.samples.append(f.read(sample_size))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(pixels, label)`` for record ``idx``; the label is the first byte."""
        raw = self.samples[idx]
        label = raw[0]  # Dummy label

        # Decode the pixel bytes directly instead of going through a Python
        # list of ints. frombuffer needs a writable buffer, hence the
        # bytearray copy; offset=1 skips the label byte.
        pixels = torch.frombuffer(bytearray(raw), dtype=torch.uint8, offset=1)
        pixels = pixels.to(torch.float32).reshape(self.num_channels, *self.img_size) / 255.0  # Normalize to [0, 1]

        if self.transform:
            # Convert to PIL Image for compatibility with transforms
            pixels = transforms.ToPILImage()(pixels)
            pixels = self.transform(pixels)

        return pixels, label


### Define the blocks used in our model

In [None]:
### This is UNetResBlock/blocks.py ###


import torch.nn as nn
import torch
import torch.nn.functional as F


# Taken from https://github.com/dome272/Diffusion-Models-pytorch
class SelfAttention(nn.Module):
    """
    Multi-head self-attention over the spatial positions of a feature map.

    The input is viewed as ``size * size`` tokens of ``channels`` features,
    passed through pre-norm 4-head attention and a feed-forward layer (both
    with residual connections), then viewed back as a feature map.
    """

    def __init__(self, channels, size):
        super().__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        feed_forward = [
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.ff_self = nn.Sequential(*feed_forward)

    def forward(self, x):
        side = self.size
        # (b, c, h, w) -> (b, h*w, c): one token per spatial position
        tokens = x.view(-1, self.channels, side * side).transpose(1, 2)
        normed = self.ln(tokens)
        attended, _ = self.mha(normed, normed, normed)
        attended = attended + tokens
        mixed = self.ff_self(attended) + attended
        # (b, h*w, c) -> (b, c, h, w)
        return mixed.transpose(2, 1).view(-1, self.channels, side, side)


class ConvBlock(nn.Module):
    """
    3x3 conv -> GroupNorm(32 groups) -> SiLU -> Dropout(0.1), plus a skip path.

    ``time_dim`` is accepted for signature parity with ResBlock but is not used.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        main_path = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),  # Convolution
            nn.GroupNorm(32, out_channels),  # Group normalization
            nn.SiLU(),  # SiLU activation
            nn.Dropout(p=0.1),
        ]
        self.conv = nn.Sequential(*main_path)

        # The skip path needs a 1x1 projection only when the channel count changes
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        transformed = self.conv(x)
        shortcut = self.residual_conv(x)
        return transformed + shortcut

class ResBlock(nn.Module):
    """
    Residual conv block conditioned on a time embedding.

    Computes ``SiLU(conv(x) + time_proj(temb) + residual_conv(x))`` where
    ``conv`` is 3x3 conv -> GroupNorm(32 groups) -> SiLU -> Dropout(0.1) and
    the projected time embedding is added per channel at every spatial
    position.

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Channels of the output; must be divisible by 32
            because of the 32-group GroupNorm.
        time_dim (int): Size of the time embedding passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.Dropout(p=0.1),
        )

        # 1x1 projection on the skip path only when the channel count changes
        self.residual_conv = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
            if in_channels != out_channels else nn.Identity()
        )

        self.activation = nn.SiLU()
        self.time_proj = nn.Linear(time_dim, out_channels)  # Project time embedding

    def forward(self, x, temb):
        """
        Args:
            x: Feature map of shape (batch, in_channels, H, W).
            temb: Time embedding of shape (batch, time_dim).

        Returns:
            Tensor of shape (batch, out_channels, H, W).
        """
        x_conv = self.conv(x)  # Standard conv operation
        # (batch, out_channels) -> (batch, out_channels, 1, 1). Broadcasting
        # adds it at every spatial position without materialising a full-size
        # copy, which is what the previous .repeat() did.
        temb_proj = self.time_proj(temb)[:, :, None, None]
        return self.activation(x_conv + temb_proj + self.residual_conv(x))


class DownBlock(nn.Module):
    """
    Encoder stage: ResBlock -> stride-2 3x3 conv -> ResBlock.

    The strided convolution halves the spatial resolution; both ResBlocks
    receive the same time embedding.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.res_block_1 = ResBlock(in_channels, out_channels, time_dim)
        self.downSample = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=2, padding=1
        )
        self.res_block_2 = ResBlock(out_channels, out_channels, time_dim)

    def forward(self, x, temb):
        features = self.res_block_1(x, temb)
        halved = self.downSample(features)
        return self.res_block_2(halved, temb)


class UpBlock(nn.Module):
    """
    Decoder stage: transposed-conv upsampling, skip concatenation, two ResBlocks.

    The 2x2 stride-2 transposed convolution doubles the spatial resolution.
    The skip tensor must carry ``in_channels`` channels so the concatenation
    has ``2 * in_channels`` channels, as the first ResBlock expects.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.upSample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        self.res_block_1 = ResBlock(2 * in_channels, in_channels, time_dim)
        self.res_block_2 = ResBlock(in_channels, out_channels, time_dim)

    def forward(self, x, x_skip, temb):
        upsampled = self.upSample(x)
        # Skip features first, upsampled features second, along the channel axis
        merged = torch.cat([x_skip, upsampled], dim=1)
        hidden = self.res_block_1(merged, temb)
        return self.res_block_2(hidden, temb)


class TimeEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) that projects a timestep embedding."""

    def __init__(self, embedding_dim, projection_dim):
        super().__init__()
        self.dense1 = nn.Linear(embedding_dim, projection_dim)
        self.dense2 = nn.Linear(projection_dim, projection_dim)
        self.activation = nn.SiLU()

    def forward(self, t):
        hidden = self.activation(self.dense1(t))
        return self.dense2(hidden)

### Functions for alpha and beta

In [None]:
### This is UNetResBlock/alphabeta.py ###


import torch.nn as nn
import torch
import torch.nn.functional as F

def compute_linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """
    Return ``timesteps`` betas spaced evenly from ``beta_start`` to ``beta_end``
    (both endpoints included).
    """
    return torch.linspace(start=beta_start, end=beta_end, steps=timesteps)

def compute_alpha_schedule(betas):
    """
    Turn a beta schedule into the matching alpha schedule,
    element-wise: alpha_t = 1 - beta_t.
    """
    alphas = 1.0 - betas
    return alphas

def compute_alpha_cumulative_product(alpha_t):
    """
    Return the running product of the alphas along dimension 0, i.e. entry t
    holds alpha_0 * alpha_1 * ... * alpha_t (the alpha_cumprod used by the
    diffusion process).
    """
    return alpha_t.cumprod(dim=0)


### EMA

In [None]:
### This is UNetResBlock/EMA.py ###


# Written by chatgpt, modified by us
class EMA:
    """
    Exponential Moving Average (EMA) of a model's parameters.

    Typical use: call ``update`` after every optimizer step; to evaluate or
    save with the averaged weights call ``store()``, ``apply_shadow()``, do
    the work, then ``restore()``.
    """

    def __init__(self, model, decay=0.999):
        """
        Args:
            model (torch.nn.Module): The model to track.
            decay (float): EMA decay factor, usually close to 1 (e.g., 0.999).
        """
        self.model = model
        self.decay = decay
        # Detached copies of the parameters that hold the running average.
        self.shadow = {name: param.clone().detach() for name, param in model.named_parameters()}
        # Filled by store(); None means there is nothing to restore yet.
        self.backup = None

    def update(self, model):
        """
        Update the EMA parameters with the current model parameters:
        ``shadow = decay * shadow + (1 - decay) * param``.

        Args:
            model (torch.nn.Module): The current model with updated weights.
        """
        weight = 1.0 - self.decay
        for name, param in model.named_parameters():
            if param.requires_grad:
                # alpha= fuses the scaling into the in-place add and avoids a
                # temporary tensor per parameter.
                self.shadow[name].data.mul_(self.decay).add_(param.data, alpha=weight)

    def apply_shadow(self):
        """Copy the EMA weights into the tracked model (trainable parameters only)."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name].data)

    def store(self):
        """Store the current model parameters for restoration."""
        # detach() first so the backup is plain data and not part of autograd.
        self.backup = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
        }

    def restore(self):
        """
        Restore the parameters saved by the last ``store()`` call.

        Raises:
            RuntimeError: If ``store()`` has not been called yet.
        """
        if self.backup is None:
            raise RuntimeError("EMA.restore() called before EMA.store()")
        for name, param in self.model.named_parameters():
            param.data.copy_(self.backup[name])

### Define the model

In [None]:
### This is UNetResBlock/model.py ###


import torch.nn as nn
import torch
import torch.nn.functional as F
import math
import os
from PIL import Image
import torchvision.utils as vutils

class UNet(nn.Module):
    """
    Time-conditioned UNet with four down/up stages and self-attention.

    ``dim`` is the spatial side length of the (square) input; it is only used
    to size the SelfAttention layers, so it must match the input resolution.
    ``time_dim`` is accepted but not used (the embedding sizes are fixed to
    128 -> 512 below) and ``device`` is only stored on the instance.
    The shape comments use ``dim`` for the input side length.
    """

    def __init__(self, dim=32, in_channels=3, out_channels=3, time_dim=256, device="cuda"):
        super(UNet, self).__init__()

        self.device = device

        self.time_embedding_dim = 128  # Size of the sinusoidal timestep embedding
        self.time_projection_dim = 512  # Size after the TimeEmbedding MLP; fed to every ResBlock
        self.time_embedder = TimeEmbedding(self.time_embedding_dim, self.time_projection_dim)

        # Encoder (Downsampling path)
        self.inc = ResBlock(in_channels, 64, self.time_projection_dim)  # (b, 64, dim, dim)
        self.down1 = DownBlock(64, 128, self.time_projection_dim)       # (b, 128, dim/2, dim/2)
        self.attn1 = SelfAttention(128, dim // 2)
        self.down2 = DownBlock(128, 256, self.time_projection_dim)      # (b, 256, dim/4, dim/4)
        self.attn2 = SelfAttention(256, dim // 4)
        self.down3 = DownBlock(256, 512, self.time_projection_dim)      # (b, 512, dim/8, dim/8)
        self.attn3 = SelfAttention(512, dim // 8)
        self.down4 = DownBlock(512, 1024, self.time_projection_dim)     # (b, 1024, dim/16, dim/16)

        # Bottleneck
        self.bot1 = ResBlock(1024, 1024, self.time_projection_dim)      # (b, 1024, dim/16, dim/16)
        self.bot2 = ResBlock(1024, 1024, self.time_projection_dim)      # (b, 1024, dim/16, dim/16)
        self.attn_bot = SelfAttention(1024, dim // 16)
        self.bot3 = ResBlock(1024, 512, self.time_projection_dim)       # (b, 512, dim/16, dim/16)

        # Decoder (Upsampling path)
        self.up1 = UpBlock(512, 256, self.time_projection_dim)          # (b, 256, dim/8, dim/8)
        self.attn4 = SelfAttention(256, dim // 8)
        self.up2 = UpBlock(256, 128, self.time_projection_dim)          # (b, 128, dim/4, dim/4)
        self.attn5 = SelfAttention(128, dim // 4)
        self.up3 = UpBlock(128, 64, self.time_projection_dim)           # (b, 64, dim/2, dim/2)
        self.attn6 = SelfAttention(64, dim // 2)
        self.up4 = UpBlock(64, 64, self.time_projection_dim)            # (b, 64, dim, dim)

        # Output layer
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)          # Final output (b, c_out, dim, dim)


    def get_timestep_embedding(self, timesteps, embedding_dim):
        """
        Sinusoidal embeddings for a 1D tensor of timesteps.

        Returns a (len(timesteps), embedding_dim) tensor: the first half of
        each row holds sines, the second half cosines, over geometrically
        spaced frequencies; an odd ``embedding_dim`` is zero padded by one.
        """
        assert len(timesteps.shape) == 1, "Timesteps should be a 1D tensor"
        half_dim = embedding_dim // 2
        # Frequencies run geometrically from 1 down to 1/10000
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
        emb = timesteps[:, None].float() * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:
            emb = torch.nn.functional.pad(emb, (0, 1))  # Zero pad to match dimensions
        return emb

    def forward(self, x, t):
        """
        Run the network on images ``x`` conditioned on timesteps ``t``.

        ``x`` is (b, in_channels, dim, dim) and ``t`` is a 1D tensor with one
        entry per sample; the output is (b, out_channels, dim, dim).
        """
        t_emb = self.get_timestep_embedding(t, self.time_embedding_dim)
        t_emb = self.time_embedder(t_emb)  # Project embedding

        # Encoder
        x1 = self.inc(x, t_emb)            # (b, 64, dim, dim)
        x2 = self.down1(x1, t_emb)         # (b, 128, dim/2, dim/2)
        x2 = self.attn1(x2)
        x3 = self.down2(x2, t_emb)         # (b, 256, dim/4, dim/4)
        x3 = self.attn2(x3)
        x4 = self.down3(x3, t_emb)         # (b, 512, dim/8, dim/8)
        x4 = self.attn3(x4)
        x5 = self.down4(x4, t_emb)         # (b, 1024, dim/16, dim/16)

        # Bottleneck
        x = self.bot1(x5, t_emb)           # (b, 1024, dim/16, dim/16)
        x = self.bot2(x, t_emb)           # (b, 1024, dim/16, dim/16)
        x = self.attn_bot(x)
        x = self.bot3(x, t_emb)           # (b, 512, dim/16, dim/16)

        # Decoder path (Upsampling); each stage consumes the matching encoder output as skip
        x = self.up1(x, x4, t_emb)         # (b, 256, dim/8, dim/8)
        x = self.attn4(x)
        x = self.up2(x, x3, t_emb)         # (b, 128, dim/4, dim/4)
        x = self.attn5(x)
        x = self.up3(x, x2, t_emb)         # (b, 64, dim/2, dim/2)
        x = self.attn6(x)
        x = self.up4(x, x1, t_emb)         # (b, 64, dim, dim)

        # Final output
        output = self.outc(x)              # (b, c_out, dim, dim)
        return output



def linear_beta_schedule(timesteps):
    """Return ``timesteps`` noise variances rising linearly from 1e-4 to 0.02."""
    first_beta, last_beta = 0.0001, 0.02  # small variance at the start, larger at the end
    return torch.linspace(first_beta, last_beta, steps=timesteps)

def compute_alpha_and_alpha_bar(betas):
    """Return ``(alphas, alpha_bar)``: ``1 - betas`` and its cumulative product along dim 0."""
    alphas = 1.0 - betas
    return alphas, alphas.cumprod(dim=0)

def calc_loss(u_net, x, timesteps=1000):
    """
    Noise-prediction training loss for one batch.

    Draws a random timestep per sample, noises ``x`` with the forward process
    of the linear beta schedule, asks ``u_net`` for the noise (the timestep is
    passed scaled as ``t / timesteps``) and returns the mean squared error
    against the noise that was actually added.
    """
    schedule = linear_beta_schedule(timesteps)
    alpha_bar = compute_alpha_and_alpha_bar(schedule)[1]

    # One random timestep per sample in the batch
    batch = x.size(0)
    t = torch.randint(0, timesteps, (batch,), device=x.device).long()

    # alpha_bar at the chosen steps, shaped to broadcast over (c, h, w)
    alpha_bar_t = alpha_bar.to(x.device)[t].view(-1, 1, 1, 1)

    # Forward process: x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise
    noise = torch.randn_like(x)
    signal_scale = torch.sqrt(alpha_bar_t)
    noise_scale = torch.sqrt(1 - alpha_bar_t)
    x_t = signal_scale * x + noise_scale * noise

    predicted_noise = u_net(x_t, t/timesteps)

    squared_error = (predicted_noise - noise) ** 2
    return torch.mean(squared_error)

def generate_samples(u_net, nsamples, image_shape, timesteps=1000):
    """
    Draw ``nsamples`` images by running the reverse diffusion process.

    Starts from Gaussian noise and applies ``timesteps`` denoising steps with
    the linear beta schedule. The model is put in eval mode (so dropout is
    off) and run without gradient tracking for the duration of sampling; its
    previous train/eval mode is restored afterwards.

    Args:
        u_net (torch.nn.Module): Noise-prediction network called as
            ``u_net(x_t, t / timesteps)``.
        nsamples (int): Number of images to generate.
        image_shape (tuple): Shape of one image, e.g. (channels, height, width).
        timesteps (int): Number of diffusion steps.

    Returns:
        torch.Tensor: Generated batch of shape ``(nsamples, *image_shape)`` on
        the model's device.
    """
    # Define the linear beta schedule
    betas = linear_beta_schedule(timesteps)
    alphas, alpha_bar = compute_alpha_and_alpha_bar(betas)

    device = next(u_net.parameters()).device

    # Sampling is inference: disable dropout and autograd bookkeeping, but
    # hand the model back in the mode the caller had it in.
    was_training = u_net.training
    u_net.eval()
    try:
        with torch.no_grad():
            x_t = torch.randn((nsamples, *image_shape), device=device)  # Start from pure noise

            for t in reversed(range(timesteps)):
                t_tensor = torch.full((x_t.size(0),), t, device=device).long()  # Current time step

                alpha_t = alphas[t]
                alpha_bar_t = alpha_bar[t]

                # Predict the noise
                predicted_noise = u_net(x_t, t_tensor/timesteps)

                # Reconstruct the mean (mu) of x_{t-1}
                mu = (1 / torch.sqrt(alpha_t)) * (
                    x_t - (betas[t] / torch.sqrt(1 - alpha_bar_t)) * predicted_noise
                )

                # Add noise for all steps except the final one
                if t > 0:
                    alpha_bar_prev = alpha_bar[t - 1]
                    noise = torch.randn_like(x_t)
                    sigma_t = torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t) * betas[t])
                    x_t = mu + sigma_t * noise
                else:
                    x_t = mu  # No noise added at the final step
    finally:
        u_net.train(was_training)

    return x_t


def save_intermediate_images(images, timestep, save_dir = "UNet/GenInt"):
    """
    Write a batch of images as one PNG grid named after the timestep.

    The grid has 4 images per row and every image is min-max normalised on
    its own before being written to ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists
    tiled = vutils.make_grid(images, nrow=4, normalize=True, scale_each=True)
    # (c, h, w) in [0, 1] -> (h, w, c) uint8
    pixels = tiled.permute(1, 2, 0).cpu().numpy() * 255
    target = os.path.join(save_dir, f"timestep_{timestep:04d}.png")
    Image.fromarray(pixels.astype('uint8')).save(target)




def print_memory_usage(tag=""):
    """Print the CUDA memory currently allocated and reserved (in MB), prefixed with ``tag``."""
    allocated_mb = torch.cuda.memory_allocated() / 1e6
    print(f"[{tag}] Allocated Memory: {allocated_mb} MB")
    reserved_mb = torch.cuda.memory_reserved() / 1e6
    print(f"[{tag}] Reserved Memory: {reserved_mb} MB")



### Trainer for the model

In [None]:
### This is UNetResBlock/trainer.py ###


import torch
import wandb
import os
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

def train_model(u_net, dataset, config, model_name, log=False, save_model=False):
    """
    Train ``u_net`` on ``dataset`` with Adam while tracking EMA weights.

    Args:
        u_net (torch.nn.Module): Noise-prediction network.
        dataset: Dataset yielding ``(image, label)`` pairs.
        config (dict): Needs "image_shape", "learning_rate", "batch_size" and
            "epochs"; the optional "log_sample_interval" sets how often (in
            epochs) samples are generated after the first five epochs
            (default 5).
        model_name (str): Used for the results folder and checkpoint name.
        log (bool): Log the loss and generated samples to Weights & Biases.
            Requires ``wandb.init`` to have been called by the caller.
        save_model (bool): Save the EMA weights after every epoch to
            ``UNetResBlock/models/{model_name}_ema.pt``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    u_net = u_net.to(device)
    ema = EMA(u_net, decay=0.995)  # Initialize EMA tracker

    image_shape = config["image_shape"]
    save_dir = os.path.join("UNetResBlock/results", model_name)  # Create the full directory path
    os.makedirs(save_dir, exist_ok=True)
    if save_model:
        # torch.save does not create missing folders
        os.makedirs("UNetResBlock/models", exist_ok=True)

    # Sampling cadence comes from the config; guard against 0 to avoid a modulo-by-zero
    sample_interval = max(1, int(config.get("log_sample_interval", 5)))

    # Optimizer
    opt = torch.optim.Adam(u_net.parameters(), lr=config["learning_rate"])

    # DataLoader
    dloader = torch.utils.data.DataLoader(dataset, batch_size=config["batch_size"], shuffle=True)

    p_bar = False  # Set to True to show a tqdm progress bar per epoch

    for i_epoch in range(config["epochs"]):
        print(f"Epoch {i_epoch} started, EMA")
        total_loss = 0

        if p_bar:
            progress_bar = tqdm(dloader, total=len(dloader), desc=f"Epoch {i_epoch}", ncols=100)
        else:
            progress_bar = dloader  # Use the dataloader directly without tqdm

        # Training loop
        for batch_idx, (data, _) in enumerate(progress_bar):
            data = data.to(device)
            opt.zero_grad()
            loss = calc_loss(u_net, data)
            loss.backward()
            opt.step()

            # Update EMA after optimizer step
            ema.update(u_net)

            # Weight by batch size so the epoch average is per sample
            total_loss += loss.item() * data.shape[0]

            if p_bar:
                progress_bar.set_postfix(loss=loss.item(), total_loss=total_loss)

        avg_loss = total_loss / len(dataset)
        if log:
            wandb.log({"loss": avg_loss}, step=i_epoch)

        # Save the model with EMA weights
        if save_model:
            ema.store()  # Save the current model parameters
            ema.apply_shadow()  # Apply EMA weights to the model
            torch.save(u_net.state_dict(), f"UNetResBlock/models/{model_name}_ema.pt")
            ema.restore()  # Restore original weights for further training

        # Generate samples using EMA weights
        if (i_epoch < 5 or i_epoch % sample_interval == 0):
            print("Generating samples with EMA weights...")
            ema.store()  # Save the current model parameters
            ema.apply_shadow()  # Apply EMA weights
            generated_samples = generate_samples(u_net, nsamples=16, image_shape=image_shape, timesteps=1000)
            ema.restore()  # Restore original weights for further training

            # Only talk to wandb when logging was requested; otherwise
            # wandb.log would be called without an initialised run.
            save_samples(generated_samples, save_dir, f"epoch{i_epoch}", i_epoch, wandb_log=log)
            # Best-effort copy to the author's web folder: this path only
            # exists on one machine and must not abort training elsewhere.
            try:
                save_samples(generated_samples, "/zhome/1a/a/156609/public_html/ShowResults", "resultOwn")
            except OSError as err:
                print(f"Skipping web export of samples: {err}")

            # Save the first 8 images from the dataset
            first_batch, _ = next(iter(torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)))
            save_samples(first_batch.to(device), save_dir, f"dataset_samples")

        print(f"Epoch {i_epoch}, Loss: {avg_loss}")

    print("Training complete.")


def save_samples(samples, save_dir, filename, i_epoch=0, wandb_log=False):
    """
    Save a batch of images in [-1, 1] as one horizontal PNG strip.

    Prints min/max/mean/std of the raw batch, rescales it to [0, 1], writes
    ``{save_dir}/{filename}.png`` and, when ``wandb_log`` is set, also logs a
    4x4 matplotlib grid of the images to wandb at step ``i_epoch``.

    Raises:
        ValueError: If the images have neither 1 nor 3 channels.
    """
    os.makedirs(save_dir, exist_ok=True)

    samples = samples.cpu()

    # Statistics of the raw (un-rescaled) batch
    min_val = samples.min().item()
    max_val = samples.max().item()
    avg_val = samples.mean().item()
    std_val = samples.std().item()
    print(f"file name: {filename}")
    print(f"with samples min: {min_val:.2f}, samples max: {max_val:.2f}, avg: {avg_val:.2f}, std: {std_val:.2f}")

    # [-1, 1] -> [0, 1], clipping anything outside
    samples = ((samples + 1) / 2).clamp(0, 1)
    count = samples.shape[0]
    n_channels = samples.shape[1]

    # One PIL image per sample
    images = []
    for sample in samples:
        pixels = sample.permute(1, 2, 0).clamp(0, 1).numpy() * 255
        if n_channels == 1:  # Grayscale
            images.append(Image.fromarray(pixels[:, :, 0].astype('uint8'), mode='L'))
        elif n_channels == 3:  # RGB
            images.append(Image.fromarray(pixels.astype('uint8')))
        else:
            raise ValueError(f"Unexpected number of channels: {samples.shape[1]}")

    # Paste the images side by side into one strip
    strip = Image.new(
        'RGB' if n_channels == 3 else 'L',
        (count * images[0].width, images[0].height)
    )
    for column, img in enumerate(images):
        strip.paste(img, (column * img.width, 0))
    strip.save(os.path.join(save_dir, f"{filename}.png"))

    if not wandb_log:
        return

    # 4x4 grid for wandb; cells beyond the batch size stay blank
    grid_size = 4
    is_grayscale = n_channels == 1
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(8, 8))
    for i, ax in enumerate(axes.flatten()):
        if i < count:
            if is_grayscale:
                channel = samples[i][0].cpu().numpy()  # Take the grayscale channel
                ax.imshow(1 - channel, cmap="Greys")
            else:
                rgb = samples[i].permute(1, 2, 0).cpu().numpy()  # (C, H, W) -> (H, W, C)
                ax.imshow((rgb * 255).astype("uint8"))
        ax.axis("off")  # Turn off axes for all plots, including blanks

    wandb.log({f"GeneratedSamples:": wandb.Image(fig)}, step = i_epoch)
    plt.close(fig)



### Run the model

In [None]:
### This is mainUNetResBlock.py ###

import wandb
import torch
import os
from utils.dataset_loader import load_dataset
from configs.config import config_MNIST, config_CIFAR10, config_CELEBA
from UNetResBlock.model import UNet
from UNetResBlock.trainer import train_model
import random

# Dataset configuration used for this run
config = config_CELEBA

# Run name: dataset name plus a random numeric suffix
configDSName = config["dataset_name"]
name = f"{configDSName}_{random.randint(100000, 1000000)}"
print (f"starting run: {name}")

# Log to Weights & Biases
logwandb = True

# Use only small subset of data (for debugging)
debugDataSize = False
if debugDataSize:
    modelNameTest = "_smallDS"
else:
    modelNameTest = ""

# File that the model is saved as. Only relevant if save_model = True
save_model = True
model_name = "".join(["ResNet", name, modelNameTest])

if logwandb:
    wandb.init(project=config["project_name"], name=model_name, config=config)

# Load dataset
dataset = load_dataset(config, small_sample=debugDataSize)

# image_shape is (channels, height, width)
channels, dim = config["image_shape"][0], config["image_shape"][1]

# Initialize model
model = UNet(in_channels=channels, dim=dim, out_channels=channels)

# Train model
train_model(
    model,
    dataset,
    config,
    model_name=model_name,
    log=logwandb,
    save_model=save_model,
)



: 

### Make binary files used for FID scoring

In [None]:
### This is generateAndRead/save_model_binary.py ###

import sys
import os

# Add the directory containing the module to the Python path
external_module_path = "/zhome/1a/a/156609/project/path"
sys.path.append(external_module_path)


import torch
import numpy as np

# Configurations
config = config_CELEBA
model_filename = "ResNetCELEBA_518406_ema.pt"
batch_size = 200  # Number of samples to generate per batch
num_samples = 10000  # Number of samples to generate

image_shape = config["image_shape"]  # Channels, Height, Width

# Define the output folder and ensure it exists
output_folder = "generateAndRead/binSamples"
os.makedirs(output_folder, exist_ok=True)
dataset_name = config["dataset_name"]
output_binary_file = os.path.join(output_folder, f"model_{dataset_name}_{num_samples}samples_ema.bin")

# Location of the saved model
model_folder = "savedModels"
model_path = os.path.join(model_folder, model_filename)


channels = config["image_shape"][0]
dim = config["image_shape"][1]

# Pick the device first so the checkpoint can be mapped onto it. Without
# map_location a checkpoint saved from CUDA cannot be loaded on a CPU-only
# machine, even though the CPU fallback below is intended to work.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize and load the model
u_net = UNet(in_channels=channels, dim=dim, out_channels=channels)
state_dict = torch.load(model_path, map_location=device, weights_only=True)
u_net.load_state_dict(state_dict)
u_net.to(device)
u_net.eval()

# Generate samples in parallel batches
print("Generating samples in parallel...")

# One uint8 array for all samples, laid out as (N, C, H, W)
all_samples = np.zeros((num_samples, *image_shape), dtype=np.uint8)

for i in range(0, num_samples, batch_size):
    current_batch_size = min(batch_size, num_samples - i)
    print(f"Generating batch {i // batch_size + 1}: {current_batch_size} samples...")

    # Generate samples for the current batch
    samples = generate_samples(u_net, current_batch_size, image_shape).detach().cpu().numpy()

    # Min-max normalize every image on its own to [0, 255]
    lowest = samples.min(axis=(1, 2, 3), keepdims=True)
    span = samples.max(axis=(1, 2, 3), keepdims=True) - lowest
    span[span == 0] = 1  # a constant image would otherwise divide by zero
    samples = ((samples - lowest) / span * 255).astype(np.uint8)

    # Store samples directly without squeezing the channel dimension
    all_samples[i:i + current_batch_size] = samples

# Save all samples to a binary file (raw pixel bytes, no per-image label)
with open(output_binary_file, "wb") as f:
    f.write(all_samples.tobytes())

print(f"All samples saved to {output_binary_file}.")



### FID scoring

In [None]:
import tensorflow as tf
import tensorflow_hub as tfhub
import tensorflow_gan as tfgan
import numpy as np
from torchvision import datasets, transforms
import torch
from pathlib import Path
from tqdm import tqdm
from PIL import Image
import logging
import os


# Set up logging: INFO level on the root logger, plus a logger for this module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the pre-trained MNIST classifier (TF-GAN evaluation module from TF-Hub).
# NOTE: tfhub.load runs at import time and needs the module to be reachable or cached.
MNIST_MODULE = "https://tfhub.dev/tensorflow/tfgan/eval/mnist/logits/1"
mnist_classifier_fn = tfhub.load(MNIST_MODULE)

# Load the TF-GAN evaluation Inception module; wrapped_image_classifier_fn
# calls it and reads its 'pool_3' output.
IMAGE_MODULE = "https://www.kaggle.com/models/tensorflow/inception/tensorFlow1/tfgan-eval-inception/1"
image_classifier_fn = tfhub.load(IMAGE_MODULE)

def wrapped_image_classifier_fn(input_tensor):
    """
    Run the Inception classifier on a channels-last batch and return 'pool_3'.

    Single-channel input is expanded to RGB first. No resizing is done here,
    so the images are passed to the classifier at their current resolution.
    """
    is_grayscale = input_tensor.shape[-1] == 1
    rgb_tensor = tf.image.grayscale_to_rgb(input_tensor) if is_grayscale else input_tensor

    features = image_classifier_fn(rgb_tensor)
    return features['pool_3']  # Use 'pool_3' for FID, or adjust based on requirements




def compute_activations(tensors, num_batches, classifier_fn):
    """
    Compute ``classifier_fn`` activations for a tensor of shape
    (batch_size, height, width, channels).

    The input is split into ``num_batches`` equal chunks that are mapped
    through the classifier one at a time with gradients stopped; the per-chunk
    results are then concatenated back along the first axis.
    """
    chunks = tf.stack(tf.split(tensors, num_or_size_splits=num_batches))
    per_chunk = tf.map_fn(classifier_fn, chunks, parallel_iterations=1, swap_memory=True)
    per_chunk = tf.nest.map_structure(tf.stop_gradient, per_chunk)
    return tf.concat(tf.unstack(per_chunk), 0)



def read_binary_file(file_path, image_shape, isValidation):
    """
    Reads a binary file of fixed-size uint8 images.

    Two layouts are supported:
      * ``isValidation=True``: every record is 1 label byte followed by the
        image bytes; the label is skipped.
      * ``isValidation=False``: records are the raw image bytes only.

    Trailing bytes that do not form a complete record are ignored.

    Args:
        file_path (str): Path to the binary file.
        image_shape (tuple): Shape of each image (channels, height, width).
        isValidation (bool): Whether each image is preceded by a label byte.

    Returns:
        np.ndarray: uint8 array of shape (num_images, *image_shape).
    """
    pixels_per_image = int(np.prod(image_shape))
    # Only labelled files carry the extra leading byte. Skipping it
    # unconditionally read unlabelled files one byte out of alignment.
    label_bytes = 1 if isValidation else 0
    bytes_per_image = pixels_per_image + label_bytes

    with open(file_path, "rb") as f:
        data = np.frombuffer(f.read(), dtype=np.uint8)

    # Every complete record is returned, including the last one.
    num_images = data.size // bytes_per_image

    records = data[:num_images * bytes_per_image].reshape(num_images, bytes_per_image)
    # np.array makes a writable copy (frombuffer views are read-only)
    images = np.array(records[:, label_bytes:])
    return images.reshape((num_images, *image_shape))

def dataset_to_numpy(config, validation=False):
    """
    Load a dataset and convert it to a NumPy array of unnormalized images.

    Images are fetched through ``load_dataset`` and mapped from the
    normalized range [-1, 1] back to [0, 255] as uint8.

    Args:
        config (dict): Configuration dictionary specifying dataset parameters.
        validation (bool): Forwarded to ``load_dataset`` as ``validation``.

    Returns:
        np.ndarray: uint8 array with one entry per dataset item, in dataset
        order; (num_images, channels, height, width) for the image datasets
        produced by ``load_dataset``.

    Raises:
        ValueError: If the dataset yields no images.
    """
    from utils.dataset_loader import load_dataset

    dataset = load_dataset(config, validation=validation)
    # Order is preserved (shuffle=False); batching only reduces the number of
    # DataLoader iterations, the concatenated result is the same.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=False)

    batches = []
    for images, _ in dataloader:
        pixels = ((images * 0.5) + 0.5) * 255
        # Round before the integer cast: .byte() truncates, so a float32
        # round-trip value such as 127.99999 would otherwise become 127.
        # Clamp so rounding noise can never wrap around in uint8.
        batches.append(pixels.round().clamp(0, 255).byte().numpy())

    if not batches:
        raise ValueError("Dataset is empty; no images to convert.")

    return np.concatenate(batches, axis=0)

def compute_fid_for_CIFAR_or_CELEBA(dataset, image_shape, validation_load_data):
    """
    Compute the FID between a dataset's real images and pre-generated samples.

    Real images are loaded with ``dataset_to_numpy``; generated images are
    read from a binary file under ``generateAndRead/binSamples`` with
    ``read_binary_file(..., isValidation=False)``. Each image set is passed
    through the classifier in a single batch. The activations of both sets
    are saved as ``.npy`` files in ``./output``.

    NOTE(review): this function uses a module-level ``logger``; confirm it is
    defined before this function is called.

    Args:
        dataset (str): Name of the dataset ('CIFAR10', 'CELEBA' or 'MNIST').
        image_shape (tuple): Shape of each generated image
            (channels, height, width).
        validation_load_data (bool): Forwarded to ``dataset_to_numpy`` as
            ``validation``.

    Returns:
        The FID score, as returned by ``.numpy()`` on the TF-GAN result.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    # Generated-sample file and loader config for each supported dataset.
    sources = {
        "CIFAR10": ("generateAndRead/binSamples/model_CIFAR10_10000samples.bin", config_CIFAR10),
        "CELEBA": ("generateAndRead/binSamples/model_CELEBA_10000samples_int.bin", config_CELEBA),
        "MNIST": ("generateAndRead/binSamples/model_MNIST_10000samples.bin", config_MNIST),
    }
    if dataset not in sources:
        raise ValueError(f"Unknown dataset: {dataset}")
    generated_path, config = sources[dataset]

    # NOTE(review): this switches on device-placement logging for the whole
    # process and is never switched off again; presumably a debugging aid.
    tf.debugging.set_log_device_placement(True)

    output_dir = "./output"
    os.makedirs(output_dir, exist_ok=True)

    # Load real images and scale to [0, 1]. Casting to float32 before the
    # division avoids a float64 copy of the whole image set.
    logger.info("Reading validation images from binary file...")
    validation_images = dataset_to_numpy(config, validation=validation_load_data)
    validation_images = validation_images.astype(np.float32) / 255.0
    validation_images_tf = tf.convert_to_tensor(validation_images, dtype=tf.float32)
    validation_images_tf = tf.transpose(validation_images_tf, [0, 2, 3, 1])  # NCHW -> NHWC

    # Load generated images and scale to [0, 1].
    logger.info("Reading generated images from binary file...")
    generated_images = read_binary_file(generated_path, image_shape, isValidation=False)
    generated_images = generated_images.astype(np.float32) / 255.0
    generated_images_tf = tf.convert_to_tensor(generated_images, dtype=tf.float32)
    generated_images_tf = tf.transpose(generated_images_tf, [0, 2, 3, 1])  # NCHW -> NHWC

    logger.info("Computing activations for validation images...")
    activations_real = compute_activations(
        validation_images_tf, num_batches=1, classifier_fn=wrapped_image_classifier_fn
    )

    logger.info("Computing activations for generated images...")
    activations_fake = compute_activations(
        generated_images_tf, num_batches=1, classifier_fn=wrapped_image_classifier_fn
    )

    logger.info("Computing FID score...")

    # The FID helper expects rank-2 (num_images, features) activations.
    activations_fake = tf.reshape(activations_fake, [activations_fake.shape[0], activations_fake.shape[-1]])
    activations_real = tf.reshape(activations_real, [activations_real.shape[0], activations_real.shape[-1]])

    fid_score = tfgan.eval.frechet_classifier_distance_from_activations(activations_real, activations_fake)
    logger.info(f"FID score: {fid_score}")

    # Save activations for later inspection.
    np.save(os.path.join(output_dir, "activations_fake.npy"), activations_fake.numpy())
    np.save(os.path.join(output_dir, "activations_real.npy"), activations_real.numpy())

    fid_value = fid_score.numpy()
    print(f"Computed FID for {dataset}, val comparison {validation_load_data}: {fid_value}")
    return fid_value



def _activations_in_batches(images, batch_size, logger):
    """Return rank-2 classifier activations for NCHW images, computed one batch at a time."""
    batches = tf.data.Dataset.from_tensor_slices(images).batch(batch_size)
    total_batches = len(batches)

    chunks = []
    for batch_index, batch in enumerate(batches):
        logger.info(f"Processing batch {batch_index + 1}/{total_batches}...")
        batch_tf = tf.transpose(batch, [0, 2, 3, 1])  # NCHW -> NHWC
        chunks.append(
            compute_activations(batch_tf, num_batches=1, classifier_fn=wrapped_image_classifier_fn)
        )

    activations = tf.concat(chunks, axis=0)
    # Flatten everything but the image axis: (num_images, features).
    return tf.reshape(activations, [activations.shape[0], -1])


def compute_fid_for_CIFAR_or_CELEBA_batch(dataset, image_shape, validation_load_data, batch_size=10000):
    """
    Compute the FID between a dataset's real images and pre-generated
    samples, running the classifier batch by batch.

    Real images are loaded with ``dataset_to_numpy``; generated images are
    read from a binary file under ``generateAndRead/binSamples`` with
    ``read_binary_file(..., isValidation=False)``. Both image sets are pushed
    through the classifier in batches of ``batch_size``. The activations of
    both sets are saved as ``.npy`` files in ``./output``.

    Args:
        dataset (str): Name of the dataset ('CIFAR10', 'CELEBA' or 'MNIST').
        image_shape (tuple): Shape of each generated image
            (channels, height, width).
        validation_load_data (bool): Forwarded to ``dataset_to_numpy`` as
            ``validation``.
        batch_size (int): Number of images to process in each batch.

    Returns:
        The FID score, as returned by ``.numpy()`` on the TF-GAN result.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    # Generated-sample file and loader config for each supported dataset.
    paths = {
        "CIFAR10": ("model_CIFAR10_10000samples.bin", config_CIFAR10),
        "CELEBA": ("model_CELEBA_10000samples_ema.bin", config_CELEBA),
        "MNIST": ("model_MNIST_10000samples.bin", config_MNIST),
    }

    if dataset not in paths:
        raise ValueError(f"Unknown dataset: {dataset}")

    generated_file, config = paths[dataset]
    generated_path = os.path.join("generateAndRead/binSamples", generated_file)

    output_dir = "./output"
    os.makedirs(output_dir, exist_ok=True)

    # Setup logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Real images: scale to [0, 1] as float32, then run through the classifier.
    logger.info("Reading and processing validation images in batches...")
    validation_images = dataset_to_numpy(config, validation=validation_load_data)
    validation_images = validation_images.astype(np.float32) / 255.0
    activations_real = _activations_in_batches(validation_images, batch_size, logger)

    # Generated images: same preprocessing and the same batched path, so
    # batch_size bounds the classifier input for both image sets.
    logger.info("Reading generated images...")
    generated_images = read_binary_file(generated_path, image_shape, isValidation=False)
    generated_images = generated_images.astype(np.float32) / 255.0

    logger.info("Computing activations for generated images...")
    activations_fake = _activations_in_batches(generated_images, batch_size, logger)

    logger.info("Computing FID score...")
    fid_score = tfgan.eval.frechet_classifier_distance_from_activations(activations_real, activations_fake)

    logger.info(f"FID score: {fid_score}")

    # Save activations for later inspection.
    np.save(os.path.join(output_dir, "activations_fake.npy"), activations_fake.numpy())
    np.save(os.path.join(output_dir, "activations_real.npy"), activations_real.numpy())

    fid_value = fid_score.numpy()
    print(f"Computed FID for {dataset}, validation comparison {validation_load_data}: {fid_value}")
    return fid_value



# Example usage. MNIST is accepted as well, though the resulting FID might
# not be as accurate.

# CIFAR-10 through the batched implementation.
fid = compute_fid_for_CIFAR_or_CELEBA_batch(
    dataset="CIFAR10", image_shape=(3, 32, 32), validation_load_data=False
)

# CelebA through the single-pass implementation; this rebinds `fid`.
fid = compute_fid_for_CIFAR_or_CELEBA(
    dataset="CELEBA", image_shape=(3, 64, 64), validation_load_data=True
)




