# Java-fusion training

### Essential imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import math
import numpy as np
from torch.utils.data import DataLoader
from torch import Tensor
from typing import Optional

## Model construction

### Resnet block definition (key part of U-Net)

In [2]:
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_embed_dim):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, out_channels)
        )

        self.block1 = nn.Sequential(
            #nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1)
        )

        self.block2 = nn.Sequential(
            #nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, t, debug=False):
        h = self.block1(x)
        h += self.time_mlp(t)[:, :, None, None]
        h = self.block2(h)
        if (debug):
            #print("Inp: \n"+str(x[0,0])) 
            print("Out: \n"+str(h[0,0]))
            print("Plus 'identity: \n" + str((h + self.shortcut(x))[0,0]))
        return h + self.shortcut(x)

### Timestep embedding (not related to Noise scheduler)

In [3]:
def timestep_embedding(timesteps, dim, max_period=10000):
    """Sinusoidal timestep embeddings"""
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

### Model architecture definition

In [4]:
# =================== SIMPLE U-NET ===================
class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, model_channels=64, out_channels=3, num_res_blocks=2):
        super().__init__()

        self.time_embed_dim = model_channels * 4

        # Time embedding
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
        )

        # Encoder
        self.conv_in = nn.Conv2d(in_channels, model_channels, 3, padding=1)

        # Encoder - Fixed channel progression
        self.down1 = nn.ModuleList([
            ResBlock(model_channels, model_channels, self.time_embed_dim),      # 64->64
            ResBlock(model_channels, model_channels, self.time_embed_dim),      # 64->64
            nn.Conv2d(model_channels, model_channels * 2, 3, stride=2, padding=1)  # 64->128
        ])

        self.down2 = nn.ModuleList([
            ResBlock(model_channels * 2, model_channels * 2, self.time_embed_dim),  # 128->128
            ResBlock(model_channels * 2, model_channels * 4, self.time_embed_dim),  # 128->256
            nn.Conv2d(model_channels * 4, model_channels * 4, 3, stride=2, padding=1)  # 256->256
        ])

        # Middle - operates on 256 channels
        self.middle = nn.ModuleList([
            ResBlock(model_channels * 4, model_channels * 4, self.time_embed_dim),  # 256->256
            ResBlock(model_channels * 4, model_channels * 4, self.time_embed_dim),  # 256->256
        ])

        # Decoder - Fixed to handle concatenations properly
        self.up1 = nn.ModuleList([
            nn.ConvTranspose2d(model_channels * 4, model_channels * 4, 4, stride=2, padding=1),  # 256->256, upsample
            ResBlock(model_channels * 4 + model_channels * 4, model_channels * 2, self.time_embed_dim),  # 512->128 (concat+reduce)
            ResBlock(model_channels * 2 + model_channels * 2, model_channels * 2, self.time_embed_dim),  # 256->128 (concat+keep)
        ])

        self.up2 = nn.ModuleList([
            nn.ConvTranspose2d(model_channels * 2, model_channels * 2, 4, stride=2, padding=1),  # 128->128, upsample
            ResBlock(model_channels * 2 + model_channels, model_channels, self.time_embed_dim),      # 192->64 (concat+reduce)
            ResBlock(model_channels + model_channels, model_channels, self.time_embed_dim),          # 128->64 (concat+reduce)
        ])

        self.conv_out = nn.Sequential(
            #nn.GroupNorm(8, model_channels),
            nn.SiLU(),
            nn.Conv2d(model_channels, out_channels, 3, padding=1)
        )

    def forward(self, x, timesteps):
        # Time embedding
        t = self.time_embed(timestep_embedding(timesteps, dim=64))

        # Encoder
        h = self.conv_in(x)  # 3->64
        hs = []

        # Down1: 64->64->64, then downsample to 128
        h = self.down1[0](h, t, False)  # ResBlock: 64->64
        hs.append(h)
        h = self.down1[1](h, t)  # ResBlock: 64->64
        hs.append(h)
        h = self.down1[2](h)     # Downsample: 64->128

        # Down2: 128->128->256, then downsample to 256
        h = self.down2[0](h, t)  # ResBlock: 128->128
        hs.append(h)
        h = self.down2[1](h, t)  # ResBlock: 128->256
        hs.append(h)
        h = self.down2[2](h)     # Downsample: 256->256

        # Middle: 256->256->256
        h = self.middle[0](h, t)
        h = self.middle[1](h, t)

        # Decoder - carefully match the skip connections
        # Up1: 256 + skip connections
        h = self.up1[0](h)  # Upsample: 256->256
        h = torch.cat([h, hs.pop()], dim=1)  # 256+256=512
        h = self.up1[1](h, t)  # ResBlock: 512->128
        h = torch.cat([h, hs.pop()], dim=1)  # 128+128=256
        h = self.up1[2](h, t)  # ResBlock: 256->128

        # Up2: 128 + skip connections
        h = self.up2[0](h)  # Upsample: 128->128
        h = torch.cat([h, hs.pop()], dim=1)  # 128+64=192
        h = self.up2[1](h, t)  # ResBlock: 192->64
        #print(h[0,26])
        h = torch.cat([h, hs.pop()], dim=1)  # 64+64=128
        h = self.up2[2](h, t)  # ResBlock: 128->64

        return self.conv_out(h)

## DDPM framework

### Noise scheduler

In [5]:
# =================== NOISE SCHEDULE ===================

def linear_beta_schedule(timesteps:int, beta_start:float=0.0001, beta_end:float=0.02) -> Tensor:
    """Simple linear noise schedule - most reliable"""
    return torch.linspace(beta_start, beta_end, timesteps)

### Main DDPM class (where all comes together)

In [6]:
# =================== DDPM CLASS ===================
class SimpleDDPM:
    def __init__(self, model:SimpleUNet, timesteps:int=2500, device:str="cuda") -> None:
        self.model = model.to(device)
        self.timesteps = timesteps
        self.device = device

        # Noise schedule
        self.betas = linear_beta_schedule(timesteps).to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Pre-compute values for sampling
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

    def q_sample(self, x_start:Tensor, t:Tensor, noise:Optional[Tensor]=None) -> Tensor:
        """Add noise to images (forward process)"""
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t][:, None, None, None]
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t][:, None, None, None]

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def train_step(self, batch:Tensor) -> Tensor:
        """Single training step"""
        x_start = batch.to(self.device)
        batch_size = x_start.shape[0]

        # Sample random timesteps
        t = torch.randint(0, self.timesteps, (batch_size,), device=self.device).long()

        # Add noise
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise)

        # Predict noise
        predicted_noise = self.model(x_noisy, t)

        # Calculate loss
        loss = F.mse_loss(predicted_noise, noise)
        return loss

    @torch.no_grad()
    def sample(self, shape:tuple, debug:bool=False, \
        init:Tensor=Tensor([]), alphas:Tensor=Tensor([]), betas:Tensor=Tensor([]), alphas_cumprod:Tensor=Tensor([]),  noise:Tensor=Tensor([])) \
        -> Tensor | tuple[np.ndarray, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Generate samples"""
        device = self.device
        b = shape[0]
        img = torch.randn(shape, device=device) if not init.numel() else init
        ini = img.numpy()
        noiseHist = []

        for i in tqdm(reversed(range(0, self.timesteps)), desc='Sampling'):
            t = torch.full((b,), i, device=device, dtype=torch.long)

            # Predict noise
            predicted_noise = self.model(img, t)

            # Compute coefficients
            alpha = self.alphas[i] if not alphas.numel() else alphas[i]
            alpha_cumprod = self.alphas_cumprod[i] if not alphas_cumprod.numel() else alphas_cumprod[i]
            beta = self.betas[i] if not betas.numel() else betas[i]

            # Update image
            img = (1 / torch.sqrt(alpha)) * (img - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * predicted_noise)

            if i > 0:
                noise = torch.randn_like(img) if not noise.numel() else noise[i]
                if debug:
                    noiseHist.append(noise)
                img += torch.sqrt(beta) * noise

        return img if not debug else (ini, img, self.alphas, self.alphas_cumprod, self.betas, torch.stack(noiseHist, 0))

## Training and inference

### Dataset loading

In [13]:
# =================== TRAINING SETUP ===================
def get_cifar10_dataloader(batch_size=32, class_idx=6):  # 6 = frogs
    """Load CIFAR-10 dataset filtered to single class"""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Scale to [-1, 1]
    ])

    dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform
    )

    # Filter to only the desired class
    indices = [i for i, (_, label) in enumerate(dataset) if label == class_idx]
    subset = torch.utils.data.Subset(dataset, indices)

    print(f"Training on {len(indices)} samples from class {class_idx} (ships)")

    return DataLoader(subset, batch_size=batch_size, shuffle=True, num_workers=2)

def get_mnist_dataloader(batch_size=32, digit:int=None):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # Single channel normalization
    ])

    dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform
    )

    if digit:
      # Filter to only the desired digit
      indices = [i for i, (_, label) in enumerate(dataset) if label == digit]
      subset = torch.utils.data.Subset(dataset, indices)

      print(f"Training on {len(indices)} samples of digit {digit}")

    if digit:
      return DataLoader(subset, batch_size=batch_size, shuffle=True, num_workers=2)

    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

def show_samples(samples):
    """Display generated samples"""
    samples = (samples + 1) / 2  # Convert from [-1, 1] to [0, 1]
    samples = torch.clamp(samples, 0, 1)

    fig, axes = plt.subplots(1, 4, figsize=(12, 3))
    for i in range(4):
        img = samples[i].cpu().permute(1, 2, 0).numpy()
        axes[i].imshow(img)
        axes[i].axis('off')
    plt.show()

### Training loop

In [14]:
def train_ddpm(epochs=10, dataset="cifar-10", digit=None):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Setup
    if dataset == "cifar-10":
      model = SimpleUNet()
      dataloader = get_cifar10_dataloader(batch_size=32, class_idx=8)  # ships
    elif dataset == "mnist":
      model = SimpleUNet(in_channels=1, out_channels=1)
      dataloader = get_mnist_dataloader(batch_size=32, digit=None)

    ddpm = SimpleDDPM(model, timesteps=2500, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # Lower initial LR

    # Add learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)


    # Training loop
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}')

        for batch, _ in progress_bar:
            optimizer.zero_grad()
            loss = ddpm.train_step(batch)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'lr': f'{scheduler.get_last_lr()[0]:.6f}'
            })

        # Step the scheduler after each epoch
        scheduler.step()

        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1} - Average Loss: {avg_loss:.4f} - LR: {scheduler.get_last_lr()[0]:.6f}')

        # Generate samples every 5 epochs
        if (epoch + 1) % 5 == 0:
            model.eval()
            if dataset=="cifar-10":
              samples = ddpm.sample((4, 3, 32, 32))
            elif dataset=="mnist":
              samples = ddpm.sample((4, 1, 28, 28))
            show_samples(samples)
            model.train()

    return ddpm

In [15]:
# Train the model
ddpm_mnist = train_ddpm(epochs=50, dataset="mnist")

Using device: cpu


NameError: name 'transforms' is not defined

### Inference

In [16]:
# Generate final samples
ddpm_mnist.model.eval()
samples = ddpm_mnist.sample((1, 1, 28, 28))
show_samples(samples)

NameError: name 'ddpm_mnist' is not defined

In [None]:
sum(p.numel() for p in ddpm_mnist.model.parameters() if p.requires_grad)

7706561

## Saving and loading!

### Save model

In [None]:
# Save the current trained model
torch.save({
    'model_state_dict': ddpm_mnist.model.state_dict(),
    'timesteps': ddpm_mnist.timesteps,
    'epoch': 50,  # or however many epochs you trained
}, 'ddpm_mnist_model.pth')

print("Model saved!")

Model saved!


### Load model

In [7]:
# Sample model loading
def load_trained_model(filepath:str='ddpm_fast_mnist_model.pth') -> SimpleDDPM:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Create model architecture
    model = SimpleUNet(in_channels=1, out_channels=1)
    ddpm = SimpleDDPM(model, timesteps=5, device=device)
    
    # Load saved weights
    checkpoint = torch.load(filepath, map_location=device)
    ddpm.model.load_state_dict(checkpoint['model_state_dict'])
    
    print(f"Model loaded from {filepath}")
    return ddpm

# Usage:
loaded_ddpm = load_trained_model()
#ini, samples, alpha, alpha_cum, beta, noise = loaded_ddpm.sample((1, 1, 28, 28), True)
#show_samples(samples)

Model loaded from ddpm_fast_mnist_model.pth


In [25]:
loaded_ddpm.sample((1, 1, 28, 28), False, torch.from_numpy(ini),alpha, beta, alpha_cum,\
  noise)

Sampling: 0it [00:00, ?it/s]


RuntimeError: output with shape [1, 1, 28, 28] doesn't match the broadcast shape [4, 1, 1, 28, 28]

In [12]:
np.save("noise.npy", noise.numpy())
np.save("init.npy", ini)
np.save("out.npy", samples.numpy())
np.save("alpha.npy", alpha.numpy())
np.save("alpha_cum.npy", alpha_cum.numpy())
np.save("beta.npy", beta.numpy())

In [14]:
np.load('model_out.bin', allow_pickle=True)

UnpicklingError: Failed to interpret file 'model_out.bin' as a pickle

In [None]:
# Load your model checkpoint
checkpoint = torch.load('ddpm_fast_mnist_model.pth', map_location='cpu')
state_dict = checkpoint.get('state_dict', checkpoint)

with open('layer_mapping.txt', 'w') as f:
    for key, tensor in state_dict['model_state_dict'].items():
        filename = 'weights/' + key.replace('.', '_') + '.npy'
        npy_array = tensor.cpu().numpy()
        np.save(filename, npy_array)

        shape_str = str(list(npy_array.shape))
        f.write(f"{key} -> {filename} (shape: {shape_str})\n")
