In [2]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import numpy as np
import matplotlib.pyplot as plt
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import torchvision
import torchvision.transforms as transforms
# Environment sanity check: CUDA availability plus the PyTorch / CUDA versions in use.
print(torch.cuda.is_available())  # True when PyTorch can use a CUDA device
print(torch.__version__)  # Installed PyTorch version
print(torch.version.cuda)  # CUDA version reported by this PyTorch build

True
2.4.1
12.4


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# ====================================
# Fourier Time Embeddings (NCSN++-Style)
# ====================================
class FourierTimeEmbedding(nn.Module):
    """Random-Fourier-feature embedding of a scalar time, followed by a small MLP.

    Args:
        embedding_size: Width of the returned embedding. ``embedding_size // 2``
            frequencies are drawn, so the sin/cos features only add up to
            ``embedding_size`` when it is even.
        scale: Standard deviation of the (fixed) random frequencies.
    """

    def __init__(self, embedding_size=64, scale=1.0):
        super().__init__()
        n_freqs = embedding_size // 2
        # Frequencies are sampled once and frozen (requires_grad=False); keeping
        # them as a Parameter means they are saved in the state_dict.
        self.W = nn.Parameter(torch.randn(n_freqs) * scale, requires_grad=False)
        mlp_layers = [
            nn.Linear(embedding_size, embedding_size),
            nn.SiLU(),
            nn.Linear(embedding_size, embedding_size),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, t):
        """Embed ``t``; a new axis is inserted at position 1 and broadcast against the frequencies."""
        angles = t.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
        features = torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)
        return self.mlp(features)

class SinusoidalPosEmb(nn.Module):
    """Deterministic sinusoidal embedding of timesteps (geometric frequency ladder).

    Args:
        dim: Number of features in the returned embedding.

    The first ``dim // 2`` features are sines and the next ``dim // 2`` are
    cosines. For odd ``dim`` a trailing zero column is appended so the output
    always has exactly ``dim`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, timesteps):
        """Return a ``(len(timesteps), dim)`` embedding of a 1-D ``timesteps`` tensor."""
        device = timesteps.device
        half_dim = self.dim // 2
        # Guard the denominator: with dim == 2 the original ``half_dim - 1`` is 0,
        # which produced inf/NaN frequencies.
        log_step = torch.log(torch.tensor(10000.0, device=device)) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=device) * -log_step)
        angles = timesteps[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            # Odd widths would otherwise come out one feature short.
            emb = F.pad(emb, (0, 1))
        return emb

class SelfAttention2D(nn.Module):
    """Residual multi-head self-attention over the spatial positions of a feature map.

    Args:
        channels: Number of feature channels; used as the attention embedding
            width and normalised with 8 groups.
        num_heads: Number of attention heads.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True)

    def forward(self, x):
        """Attend across all H*W positions of ``x`` (B, C, H, W) and add the result to ``x``."""
        batch, channels, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C): every pixel becomes one token.
        tokens = self.norm(x).flatten(2).transpose(1, 2)
        attended, _ = self.attn(tokens, tokens, tokens)
        # Back to image layout; the skip connection uses the un-normalised input.
        attended = attended.transpose(1, 2).reshape(batch, channels, height, width)
        return x + attended

class ResidualBlock2D(nn.Module):
    """Time-conditioned residual block: (GroupNorm -> SiLU -> 3x3 conv) twice.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        time_emb_dim: Width of the time embedding passed to ``forward``.
        use_attention: If True, apply ``SelfAttention2D`` to the conv branch
            before the residual sum.
    """

    @staticmethod
    def _num_groups(channels, max_groups=32):
        """Largest group count <= min(max_groups, channels // 4) that divides ``channels``.

        Returns the same value as the previous ``min(32, channels // 4)`` whenever
        that value was valid, and falls back to a smaller divisor (at least 1)
        where it was not (channels < 4, or channels not divisible by it).
        """
        candidate = max(1, min(max_groups, channels // 4))
        while channels % candidate != 0:
            candidate -= 1
        return candidate

    def __init__(self, in_channels, out_channels, time_emb_dim, use_attention=False):
        super().__init__()
        # Submodules are created in a fixed order so parameter initialisation
        # and state_dict layout stay the same as before.
        self.norm1 = nn.GroupNorm(self._num_groups(in_channels), in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.norm2 = nn.GroupNorm(self._num_groups(out_channels), out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # 1x1 projection on the skip path only when the channel count changes.
        self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.use_attention = use_attention
        if use_attention:
            self.attention = SelfAttention2D(out_channels)

    def forward(self, x, temb):
        """Apply the block to ``x`` (B, C_in, H, W) conditioned on ``temb`` (leading dim B)."""
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)

        # Project the time embedding to one bias per output channel and
        # broadcast it over the spatial dimensions.
        t_emb = self.time_mlp(F.silu(temb)).view(temb.shape[0], -1, 1, 1)
        h = h + t_emb

        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)

        if self.use_attention:
            h = self.attention(h)

        return h + self.residual(x)

class UNet2D(nn.Module):
    """Time-conditioned residual network with additive skip connections.

    All blocks use stride-1, padding-1 convolutions, so the spatial size of the
    input is kept throughout; only the channel width changes
    (base -> 2x -> 4x -> 4x -> 2x -> base).

    Args:
        in_channels: Channels of the input image.
        base_channels: Width after the first convolution.
        out_channels: Channels of the prediction.
        time_emb_dim: Width of the time embedding fed to every block.
    """

    def __init__(self, in_channels=3, base_channels=64, out_channels=1, time_emb_dim=64):
        super().__init__()
        width1 = base_channels
        width2 = base_channels * 2
        width4 = base_channels * 4
        hidden = time_emb_dim * 4

        # NOTE(review): `time_emb` is registered but never called in forward();
        # only `time_embedding` is used. It is kept here so the parameter set
        # (and therefore state_dict keys) stays as it was.
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, time_emb_dim),
        )
        self.time_embedding = FourierTimeEmbedding(time_emb_dim)
        self.init_conv = nn.Conv2d(in_channels, width1, kernel_size=3, padding=1)

        # Encoder path (channel widening only).
        self.down1 = ResidualBlock2D(width1, width2, time_emb_dim, use_attention=False)
        self.down2 = ResidualBlock2D(width2, width4, time_emb_dim, use_attention=True)

        self.bottleneck = ResidualBlock2D(width4, width4, time_emb_dim, use_attention=True)

        # Decoder path (channel narrowing only).
        self.up1 = ResidualBlock2D(width4, width2, time_emb_dim, use_attention=True)
        self.up2 = ResidualBlock2D(width2, width1, time_emb_dim, use_attention=False)

        self.final_conv = nn.Conv2d(width1, out_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Predict from image batch ``x`` and per-sample times ``t``."""
        temb = self.time_embedding(t)

        stem = self.init_conv(x)
        skip1 = self.down1(stem, temb)
        skip2 = self.down2(skip1, temb)

        # Skip connections are summed (not concatenated) with the decoder input.
        mid = self.bottleneck(skip2, temb)
        dec = self.up1(mid + skip2, temb)
        dec = self.up2(dec + skip1, temb)

        return self.final_conv(dec)


In [4]:

#print(model_pas)

In [5]:
# ====================================
# 2. Function to Sample Noisy Data at Time t
# ====================================
def sample_data_at_time_t(x0, T, t):
    """Draw x_t given x_0 with mean ``x0 * exp(-t)`` and variance ``T * (1 - exp(-2t))``.

    Args:
        x0: Clean data with the batch on dim 0 and any number of trailing dims.
        T: Scalar multiplying the variance term.
        t: One time per sample (any shape with B elements).

    Returns:
        Tensor with the same shape as ``x0``.
    """
    # One trailing singleton axis per non-batch dim of x0. The previous fixed
    # ``view(-1, 1, 1)`` mis-broadcast against 4-D image batches, producing
    # a (B, B, H, W) result for (B, 1, H, W) input.
    t = t.to(x0.device).reshape(-1, *([1] * (x0.dim() - 1)))
    mean_t = x0 * torch.exp(-t)
    # The small constant keeps the variance strictly positive at t == 0.
    var_t = T * (1 - torch.exp(-2 * t)) + 1e-8
    noise = torch.randn_like(x0) * torch.sqrt(var_t)
    return mean_t + noise

# 3. Define Loss Function
def diffusion_loss(model, x0, T, t):
    """Noise-matching loss for the passive process.

    A sample ``x_t = x0 * exp(-t) + sqrt(var_t) * eps`` is drawn with
    ``var_t = T * (1 - exp(-2t)) + 1e-8`` and ``eps ~ N(0, I)``; the loss is the
    batch mean of the per-sample summed squared error between
    ``-sqrt(var_t) * model(x_t, t)`` and ``eps`` (so the model output is fitted
    to ``-eps / sqrt(var_t)``).

    Args:
        model: Callable ``model(x_t, t)``; ``t`` is passed with one singleton
            axis per non-batch dim of ``x0`` (e.g. (B, 1, 1, 1) for images).
        x0: Clean data batch, batch on dim 0.
        T: Scalar multiplying the variance term.
        t: One time per sample.

    Returns:
        Scalar loss tensor.
    """
    # Run wherever the model lives instead of assuming a GPU is present;
    # parameter-free callables fall back to the data's device.
    first_param = next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
    device = first_param.device if first_param is not None else x0.device

    x0 = x0.to(device)
    t = t.reshape(-1, *([1] * (x0.dim() - 1))).to(device)

    mean_t = x0 * torch.exp(-t)
    var_t = T * (1 - torch.exp(-2 * t)) + 1e-8  # constant keeps var_t > 0 at t == 0
    std_t = torch.sqrt(var_t)

    eps = torch.randn_like(x0)  # standard-normal draw the model must recover
    xt = mean_t + eps * std_t

    # Mixed precision only on CUDA; elsewhere the block runs in full precision.
    with torch.amp.autocast(device_type=device.type, enabled=(device.type == "cuda")):
        pred = model(xt, t)
        sq_err = (-std_t * pred - eps) ** 2
        loss = torch.mean(torch.sum(sq_err.reshape(sq_err.shape[0], -1), dim=-1))

    return loss

In [6]:
# ====================================
# 1. Generate Correlated Gaussian Noise
# ====================================
def generate_correlated_noise(covariance_matrix, image_shape):
    """Sample per-pixel 2-channel Gaussian noise with a given 2x2 covariance per sample.

    Args:
        covariance_matrix (torch.Tensor): One 2x2 covariance per batch element;
            any shape that can be viewed as [batch_size, 2, 2].
        image_shape (tuple): Shape whose last two entries are (height, width),
            e.g. (1, 28, 28).

    Returns:
        torch.Tensor: Noise of shape [batch_size, 2, height, width]; at every
        pixel the two channels are ``L @ z`` with ``L`` the Cholesky factor of
        the sample's covariance and ``z`` standard normal.
    """
    n_batch = covariance_matrix.shape[0]
    rows, cols = image_shape[-2:]

    # Lower-triangular factor L with L @ L^T == covariance, one per sample.
    chol = torch.linalg.cholesky(covariance_matrix.view(n_batch, 2, 2))

    # White noise is drawn with the default generator, then moved to the covariance's device.
    white = torch.randn(n_batch, 2, rows, cols).to(covariance_matrix.device)

    # Flatten pixels so one batched matmul mixes the two channels everywhere.
    mixed = torch.bmm(chol, white.view(n_batch, 2, -1))

    return mixed.view(n_batch, 2, rows, cols)


# ====================================
# 2. Compute Covariance Matrix Components
# ====================================
def compute_covariance(Tp, Ta, tau, k, t):
    """Closed-form 2x2 covariance entries (M11, M12, M22) at time ``t``.

    Args:
        Tp, Ta, tau, k: Scalar parameters of the formulas below.
        t: Times with the batch on dim 0 and at least one more axis
            (the entries are concatenated along dim 1).

    Returns:
        Tensor of shape (B, 2, 2) when ``t`` is (B, 1); for ``t`` with extra
        trailing singleton axes those axes are kept, e.g. (B, 2, 2, 1, 1).

    Note:
        The formulas divide by ``(k - 1/tau)``, so ``k == 1/tau`` gives inf/NaN.
    """
    decay_k = torch.exp(-k * t)
    decay_tau = torch.exp(-t / tau)
    rate = 1 / tau
    temp_x = Tp
    temp_y = Ta / (tau * tau)

    # Variance of the first component: a Tp term plus a Ta-driven term.
    ta_bracket = (
        1 / (rate * (k + rate))
        + 4 * decay_k * decay_tau * k / ((k + rate) * (k - rate) ** 2)
        - (k * decay_tau * decay_tau + rate * decay_k * decay_k) / (rate * (k - rate) ** 2)
    )
    var_x = (1 / k) * temp_x * (1 - decay_k * decay_k) + (1 / k) * temp_y * ta_bracket

    # Cross-covariance between the two components.
    cross = (temp_y / (rate * (k * k - rate * rate))) * (
        k * (1 - decay_tau * decay_tau)
        - rate * (1 + decay_tau * decay_tau - 2 * decay_k * decay_tau)
    )

    # Variance of the second component.
    var_eta = (temp_y / rate) * (1 - decay_tau * decay_tau)

    top_row = torch.cat([var_x, cross], dim=1)
    bottom_row = torch.cat([cross, var_eta], dim=1)
    return torch.stack([top_row, bottom_row], dim=1)

# ====================================
# 3. Sample Noisy Data at Time t
# ====================================
def sample_data_active(x0, eta0, t, Tp, Ta, tau, k):
    """Sample the pair (x_t, eta_t) given (x0, eta0) for the active process.

    Args:
        x0, eta0: Initial state batches of identical shape (image layout).
        t: Per-sample times, broadcastable against ``x0``.
        Tp, Ta, tau, k: Scalars forwarded to ``compute_covariance``.

    Returns:
        Tuple ``(x_t, eta_t, noise)`` where ``noise`` is the 2-channel correlated
        noise (channel 0 was added to x, channel 1 to eta).
    """
    # Per-sample 2x2 covariance of the (x, eta) fluctuations at time t.
    cov_mat = compute_covariance(Tp, Ta, tau, k, t)

    # Conditional means: x mixes in eta0 through the coupling coefficient.
    decay_x = torch.exp(-k * t)
    decay_eta = torch.exp(-t / tau)
    coupling = (decay_eta - decay_x) / (k - (1 / tau))

    mean_x = decay_x * x0 + coupling * eta0
    mean_eta = decay_eta * eta0

    noise = generate_correlated_noise(cov_mat, x0.shape[-3:])

    # Slicing with a length-1 range keeps the channel axis.
    noise_x = noise[:, 0:1, :, :]
    noise_eta = noise[:, 1:2, :, :]

    return mean_x + noise_x, mean_eta + noise_eta, noise


# ====================================
# 4. Define Loss Function
# ====================================

def diffusion_loss_active(model_eta, x0, eta0, Tp, Ta, tau, k, t):
    """Loss for the eta-component model of the active process.

    A pair ``(x_t, eta_t)`` is sampled with ``sample_data_active``; the target is
    ``(M12 * dx - M11 * deta) / sqrt(det M)`` with ``dx``/``deta`` the deviations
    of the sample from its conditional mean, and the prediction is
    ``sqrt(det M) * model_eta(cat(x_t, eta_t), t)``. The loss is their mean
    squared difference.

    Args:
        model_eta: Callable taking the 2-channel input and ``t`` of shape (B, 1, 1, 1).
        x0, eta0: Initial state batches of shape (B, 1, H, W).
        Tp, Ta, tau, k: Scalar parameters (tensors or Python numbers).
        t: One time per sample.

    Returns:
        Scalar loss tensor.

    Note:
        ``k == 1/tau`` makes the coupling coefficient and covariance singular.
    """
    # Run wherever the model lives instead of assuming a GPU is present;
    # parameter-free callables fall back to the data's device.
    first_param = next(iter(model_eta.parameters()), None) if hasattr(model_eta, "parameters") else None
    device = first_param.device if first_param is not None else x0.device

    x0, eta0 = x0.to(device), eta0.to(device)
    t = t.reshape(-1, 1, 1, 1).to(device)
    # as_tensor also accepts plain Python numbers, not only tensors.
    Tp, Ta, tau, k = (torch.as_tensor(v, device=device) for v in (Tp, Ta, tau, k))

    # Third return value (the raw correlated noise) is not needed here.
    xt, etat, _ = sample_data_active(x0, eta0, t, Tp, Ta, tau, k)

    # Model input: x_t and eta_t stacked as two channels.
    xin = torch.cat((xt, etat), dim=1).float()

    # Coefficients of the conditional means (same expressions as in the sampler).
    a = torch.exp(-k * t)
    c = torch.exp(-t / tau)
    b = (c - a) / (k - (1 / tau))

    M = compute_covariance(Tp, Ta, tau, k, t)
    M11 = M[:, 0, 0].reshape(-1, 1, 1, 1)
    M12 = M[:, 0, 1].reshape(-1, 1, 1, 1)
    M22 = M[:, 1, 1].reshape(-1, 1, 1, 1)
    det = M11 * M22 - M12 * M12

    # Both sides are scaled by sqrt(det) so target and prediction stay comparable in size.
    Feta = torch.sqrt(1 / det) * (-M11 * (etat - c * eta0) + M12 * (xt - a * x0 - b * eta0))
    scr_eta = torch.sqrt(det) * model_eta(xin, t)

    loss_eta = torch.mean((scr_eta - Feta) ** 2)

    return loss_eta

In [14]:
# Run configuration: epoch to resume from and the parameters used in directory names.
restart_epoch = 0
T_var = 1.0
Tp_var = 1e-3
Ta_var = 1.0
tau_var = 0.5

# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_pas = UNet2D(in_channels=1, base_channels=16, out_channels=1).to(device)
model_act = UNet2D(in_channels=2, base_channels=16, out_channels=1).to(device)

fname_pas = "Passive_models_T_{}".format(T_var)
fname_act = "Active_models_Ta_{}_tau_{}".format(Ta_var, tau_var)


def _restore_checkpoint(model, directory, epoch, label):
    """Create ``directory`` if needed and load ``model_at_epoch_<epoch>.pth`` from it.

    Returns the loaded checkpoint dict, or None when no such file exists.
    """
    os.makedirs(directory, exist_ok=True)
    path = '{}/model_at_epoch_{}.pth'.format(directory, epoch)
    if not os.path.exists(path):
        return None
    # map_location lets a checkpoint written on GPU be restored on a CPU-only machine.
    checkpoint = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("{} model loaded".format(label))
    return checkpoint


checkpoint_pas = _restore_checkpoint(model_pas, fname_pas, restart_epoch, "Passive")
checkpoint_act = _restore_checkpoint(model_act, fname_act, restart_epoch, "Active")

'\nfname_pas = "Passive_models_T_{}".format(T_var)\nfname_act = "Active_models_Ta_{}_tau_{}".format(Ta_var, tau_var)\nif not os.path.exists(fname_pas):\n    os.mkdir(fname_pas)\nelse:\n    if os.path.exists(\'{}/model_at_epoch_{}.pth\'.format(fname_pas, restart_epoch)):\n        checkpoint_pas = torch.load(\'{}/model_at_epoch_{}.pth\'.format(fname_pas, restart_epoch), weights_only=True)\n        model_pas.load_state_dict(checkpoint_pas[\'model_state_dict\'])\n        print("Passive model loaded")\n        \nif not os.path.exists(fname_act):\n    os.mkdir(fname_act)\nelse:\n    if os.path.exists(\'{}/model_at_epoch_{}.pth\'.format(fname_act, restart_epoch)):\n        checkpoint_act = torch.load(\'{}/model_at_epoch_{}.pth\'.format(fname_act, restart_epoch), weights_only=True)\n        model_act.load_state_dict(checkpoint_act[\'model_state_dict\'])\n        print("Active model loaded")\n'

In [15]:
# NOTE(review): this cell repeats the configuration of the previous cell and
# builds fresh models, so any weights the previous cell restored from a
# checkpoint are replaced by newly initialised ones. Skip it when resuming.
restart_epoch = 0
T_var = 1.0
Tp_var = 1e-3
Ta_var = 1.0
tau_var = 0.5


# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_pas = UNet2D(in_channels=1, base_channels=16, out_channels=1).to(device)
model_act = UNet2D(in_channels=2, base_channels=16, out_channels=1).to(device)

In [16]:
# Define transforms (convert images to tensor and normalize)
# Preprocessing pipeline: image -> tensor in [0, 1] -> rescaled to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Arguments shared by the train and test splits (downloaded into ./data on first use).
_mnist_kwargs = dict(root="./data", transform=transform, download=True)
train_dataset = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)

# Mini-batch iterators; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [17]:
# Number of mini-batches the training loader yields per epoch.
len(train_loader)

938

In [18]:
# ====================================
# 5. Initialize Optimizers & Train with Mini-batches
# ====================================
# Define the optimizers (Adam is used; the SGD-with-momentum alternative below is kept commented out)
#optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)  # 🔥 SGD with Momentum
import torch.amp as amp  # this cell is where the `amp` alias is introduced

# Use the GPU when present; mixed precision is only switched on for CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

# Both networks are optimised with Adam.
optimizer_pas = optim.Adam(model_pas.parameters(), lr=1e-3)
optimizer_act = optim.Adam(model_act.parameters(), lr=1e-3)

# NOTE(review): the scaler is created but the loop below calls backward()/step()
# directly, i.e. autocast runs without loss scaling. With fp16 autocast this can
# let small gradients underflow; confirm whether that is intended.
scaler = amp.GradScaler(device, enabled=use_amp)

# Process parameters as 0-dim tensors on the training device.
T = torch.tensor(1.0, device=device)
Tp = torch.tensor(1e-3, device=device)
Ta = torch.tensor(1.0, device=device)
tau = torch.tensor(0.5, device=device)
k = torch.tensor(1.0, device=device)

nsamples = len(train_dataset)
batch_size = 64  # mini-batch size of the eta loader (matches train_loader)
num_epochs = 10
dataloader = train_loader

# Checkpoint locations; created up front so saving never depends on an earlier cell.
CHECKPOINT_EVERY = 10  # save both models every this many mini-batches
ckpt_dir_pas = 'Passive_models_T_{}'.format(T)
ckpt_dir_act = 'Active_models_Ta_{}_tau_{}'.format(Ta, tau)
os.makedirs(ckpt_dir_pas, exist_ok=True)
os.makedirs(ckpt_dir_act, exist_ok=True)

training_diverged = False
for epoch in range(num_epochs):
    total_loss_pas = 0.0
    total_loss_act = 0.0
    num_batches = len(dataloader)
    batches_done = 0

    # Fresh eta_0 samples for every epoch: zero-mean Gaussian with variance Ta / tau,
    # one (1, 28, 28) field per training image.
    gaussian_data_eta = np.random.normal(
        loc=0,
        scale=np.sqrt(Ta.cpu().numpy() / tau.cpu().numpy()),
        size=(nsamples, 28, 28)
    )
    g_tensor_eta = torch.tensor(gaussian_data_eta, dtype=torch.float32).view(-1, 1, 28, 28)
    dataset_eta = TensorDataset(g_tensor_eta)
    dataloader_eta = DataLoader(dataset_eta, batch_size=batch_size, shuffle=True)

    for dat_idx, ((x0_batch, _), eta0_batch) in enumerate(zip(dataloader, dataloader_eta)):
        x0 = x0_batch.to(device=device)
        eta0 = eta0_batch[0].to(device=device)  # TensorDataset batches are 1-element lists
        # Times are uniform in [1e-3, 1); the lower bound avoids t == 0.
        t = 1e-3 + (1.0 - 1e-3) * torch.rand(x0.shape[0], device=device)
        t = t.view(-1, 1)

        with amp.autocast(device, enabled=use_amp):
            loss_pas = diffusion_loss(model_pas, x0, T, t)
            loss_act = diffusion_loss_active(model_act, x0, eta0, Tp, Ta, tau, k, t)

        if torch.isnan(loss_pas) or torch.isnan(loss_act):
            print(f"⚠️ Loss became NaN at epoch {epoch}, stopping training!")
            # Remember the failure so the epoch loop stops too, not just this batch loop.
            training_diverged = True
            break

        optimizer_pas.zero_grad()
        optimizer_act.zero_grad()
        loss_pas.backward()
        loss_act.backward()
        optimizer_pas.step()
        optimizer_act.step()

        print(f"Epoch [{epoch}/{num_epochs}, {dat_idx}/{len(train_loader)}], Intermidiate Loss, Passive: {loss_pas.item()}, Active: {loss_act.item()}")
        total_loss_pas += loss_pas.item()
        total_loss_act += loss_act.item()
        batches_done += 1

        if (dat_idx + 1) % CHECKPOINT_EVERY == 0:
            for ckpt_model, ckpt_dir in ((model_pas, ckpt_dir_pas), (model_act, ckpt_dir_act)):
                torch.save({
                    'model_state_dict': ckpt_model.state_dict(),  # Model weights
                    'epoch': epoch,  # Epoch
                }, '{}/model_at_epoch_{}_datidx_{}.pth'.format(ckpt_dir, epoch, dat_idx))

    # Average over the batches actually processed (equals num_batches for a full epoch).
    # The active average is no longer overwritten with 0.0 before printing.
    avg_loss_pas = total_loss_pas / max(batches_done, 1)
    avg_loss_act = total_loss_act / max(batches_done, 1)
    print(f"Epoch [{epoch}/{num_epochs}], Average Loss, Passive: {avg_loss_pas:.6f}, Active: {avg_loss_act:.6f}")

    if training_diverged:
        break


Epoch [0/10, 0/938], Intermidiate Loss, Passive: 862.3075561523438, Active: 0.19194954633712769
Epoch [0/10, 1/938], Intermidiate Loss, Passive: 866.6494750976562, Active: 0.3731131851673126
Epoch [0/10, 2/938], Intermidiate Loss, Passive: 320.67596435546875, Active: 0.12253031134605408
Epoch [0/10, 3/938], Intermidiate Loss, Passive: 352.5382385253906, Active: 0.11905665695667267
Epoch [0/10, 4/938], Intermidiate Loss, Passive: 293.0059814453125, Active: 0.12519028782844543
Epoch [0/10, 5/938], Intermidiate Loss, Passive: 258.12744140625, Active: 0.08196710795164108
Epoch [0/10, 6/938], Intermidiate Loss, Passive: 253.77542114257812, Active: 0.08357680588960648
Epoch [0/10, 7/938], Intermidiate Loss, Passive: 228.73153686523438, Active: 0.10557541996240616
Epoch [0/10, 8/938], Intermidiate Loss, Passive: 188.74935913085938, Active: 0.09194540232419968
Epoch [0/10, 9/938], Intermidiate Loss, Passive: 229.35169982910156, Active: 0.06588846445083618
Epoch [0/10, 10/938], Intermidiate Los

KeyboardInterrupt: 