In [11]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from data import load_dataset_and_make_dataloaders
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

In [12]:
class Model(nn.Module):
    """Noise-conditioned residual CNN used as the network F inside the EDM denoiser D.

    Pipeline: 3x3 conv lifts the image to ``nb_channels`` feature maps, a stack of
    ``num_blocks`` ResidualBlocks refines them (every block receives the same noise
    embedding), and a final 3x3 conv maps back to ``image_channels``.

    Args:
        image_channels: channels of the input/output image.
        nb_channels: width of the hidden feature maps.
        num_blocks: number of residual blocks.
        cond_channels: size of the noise embedding (must be even, see NoiseEmbedding).
    """

    def __init__(
        self,
        image_channels: int,
        nb_channels: int,
        num_blocks: int,
        cond_channels: int,
    ) -> None:
        super().__init__()
        # Submodules are created in the same order as before (embedding, conv_in,
        # blocks, conv_out) so a seeded run draws identical initial weights.
        self.noise_emb = NoiseEmbedding(cond_channels)
        self.conv_in = nn.Conv2d(image_channels, nb_channels, kernel_size=3, padding=1)
        residual_stack = [ResidualBlock(nb_channels, cond_channels) for _ in range(num_blocks)]
        self.blocks = nn.ModuleList(residual_stack)
        self.conv_out = nn.Conv2d(nb_channels, image_channels, kernel_size=3, padding=1)

    def forward(self, noisy_input: torch.Tensor, c_noise: torch.Tensor) -> torch.Tensor:
        """Predict the network output F(c_in * x, c_noise).

        Args:
            noisy_input: preconditioned noisy image batch.
            c_noise: 1-D tensor of noise conditioning values (NoiseEmbedding asserts ndim == 1);
                either one value per sample or a single value shared by the batch.

        Returns:
            Tensor with ``image_channels`` channels and the spatial size of ``noisy_input``.
        """
        embedding = self.noise_emb(c_noise)  # (len(c_noise), cond_channels)
        features = self.conv_in(noisy_input)
        for residual_block in self.blocks:
            features = residual_block(features, embedding)
        return self.conv_out(features)


class NoiseEmbedding(nn.Module):
    """Sinusoidal (random Fourier-feature) embedding of a per-sample noise value.

    At construction a random frequency row ``weight`` of shape (1, cond_channels // 2)
    is drawn from N(0, 1) and registered as a *buffer*. It follows the module across
    ``.to(device)`` and is stored in ``state_dict``, but it is NOT a trainable
    parameter. For a 1-D input ``v`` of length B the output is
    ``concat(cos(2*pi*v*w), sin(2*pi*v*w))`` along the last axis, of shape
    (B, cond_channels).

    The embedding lets the denoising network condition on the noise level.
    """

    def __init__(self, cond_channels: int) -> None:
        super().__init__()
        assert cond_channels % 2 == 0  # half of the channels hold cos, the other half sin
        n_freqs = cond_channels // 2
        self.register_buffer('weight', torch.randn(1, n_freqs))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.ndim == 1
        column = input[:, None]  # (B, 1)
        phases = (2 * torch.pi * column) @ self.weight  # (B, n_freqs)
        return torch.cat((torch.cos(phases), torch.sin(phases)), dim=-1)


class ResidualBlock(nn.Module):
    """Pre-activation residual block conditioned on a noise embedding.

    ``cond`` is linearly projected to one value per feature channel and added to ``x``.
    The sum goes through BatchNorm -> ReLU -> 3x3 conv twice, and the result is
    added back to the *unconditioned* ``x``. The 3x3 / stride 1 / padding 1 convolutions
    preserve both channel count and spatial size.
    """

    def __init__(self, nb_channels: int, cond_channels: int) -> None:
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.norm1 = nn.BatchNorm2d(nb_channels)
        self.conv1 = nn.Conv2d(nb_channels, nb_channels, **conv_kwargs)
        self.norm2 = nn.BatchNorm2d(nb_channels)
        self.conv2 = nn.Conv2d(nb_channels, nb_channels, **conv_kwargs)
        # Learnable map from the cond_channels embedding to a per-channel additive bias.
        self.noise_projection = nn.Linear(cond_channels, nb_channels)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # Append two trailing singleton axes so the bias broadcasts over (H, W).
        channel_bias = self.noise_projection(cond)[..., None, None]
        conditioned = x + channel_bias
        h = self.conv1(F.relu(self.norm1(conditioned)))
        h = self.conv2(F.relu(self.norm2(h)))
        return x + h



In [31]:

# Step 1.0: pick the device, seed the RNGs and build the FashionMNIST dataloaders.
SEED = 0  # fixed seed so noise sampling and model initialisation are reproducible
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

gpu = torch.cuda.is_available()
device = torch.device('cuda:0' if gpu else 'cpu')

dl, info = load_dataset_and_make_dataloaders(
    dataset_name='FashionMNIST',
    root_dir='data',  # directory where the dataset is downloaded / stored
    batch_size=32,
    num_workers=0,    # raise this if the GPU is left waiting for batches
    pin_memory=gpu,   # pinned host memory is only useful when copying to a GPU
)
# Assumes info.sigma_data is a tensor (it is moved with .to) — TODO confirm in data.py.
sigma_data = info.sigma_data.to(device)
train_dl = dl.train
test_dl = dl.valid
# Step 1.1: sample noise levels sigma from p_noise (log-normal).
def sample_sigma(n, loc=-1.2, scale=1.2, sigma_min=2e-3, sigma_max=80):
    """Draw ``n`` noise levels with ln(sigma) ~ N(loc, scale**2), clamped to [sigma_min, sigma_max].

    The samples are created directly on the global ``device``.
    """
    log_sigma = loc + scale * torch.randn(n, device=device)
    return torch.exp(log_sigma).clamp(min=sigma_min, max=sigma_max)


# Step 1.2: add Gaussian noise at level sigma.
def add_noise_to_image(image, sigma, p_noise_mean=0, p_noise_std=1, clip_range=None):
    """Return ``image + sigma * epsilon`` with epsilon ~ N(p_noise_mean, p_noise_std**2).

    The result is NOT clipped by default. The EDM training target
    ``(y - c_skip(sigma) * x) / c_out(sigma)`` and the sampler's ``D`` both assume the
    plain sum ``x = y + sigma * eps``. Clipping to [0, 1] (the previous behaviour)
    erases almost all of the noise at large sigma (up to 80 here), so the network
    would be trained on inputs it never sees at sampling time.

    Args:
        image: tensor of any shape (a single image or a batch).
        sigma: noise level; a scalar or a tensor broadcastable against ``image``
            (e.g. shape [1], or [B, 1, 1, 1] for per-sample levels on a batch).
        p_noise_mean: mean of epsilon (0 gives standard EDM noise).
        p_noise_std: std of epsilon (1 gives standard EDM noise).
        clip_range: optional ``(low, high)`` to clamp the noisy result; ``None`` disables clipping.

    Returns:
        Noisy tensor with the same shape as ``image`` (after broadcasting with ``sigma``).
    """
    epsilon = torch.randn(image.shape, device=image.device) * p_noise_std + p_noise_mean
    noisy_image = image + sigma * epsilon
    if clip_range is not None:
        low, high = clip_range
        noisy_image = torch.clamp(noisy_image, low, high)
    return noisy_image

# Step 1.3: EDM preconditioning coefficients (sigma_data comes from the dataset info).
def c_in(sigma):
    """Input scaling 1 / sqrt(sigma_data**2 + sigma**2) applied to the noisy image before the network."""
    total_var = sigma**2 + sigma_data**2
    return 1 / torch.sqrt(total_var)

def c_out(sigma):
    """Output scaling sigma * sigma_data / sqrt(sigma_data**2 + sigma**2) applied to the network output."""
    denominator = torch.sqrt(sigma**2 + sigma_data**2)
    return (sigma * sigma_data) / denominator

def c_skip(sigma):
    """Skip-connection weight sigma_data**2 / (sigma_data**2 + sigma**2) on the noisy input."""
    data_var = sigma_data**2
    return data_var / (data_var + sigma**2)

def c_noise(sigma):
    """Noise conditioning value fed to the network: ln(sigma) / 4 (sigma must be a tensor)."""
    return 0.25 * torch.log(sigma)
# Build pre-noised train/test tensors.
# - Train: ONE sigma per dataloader batch, stored in sigma_batch / c_noise_list so the
#   training loop can look it up by batch index. Inputs are pre-scaled by c_in and
#   targets are the EDM network targets (y - c_skip * x) / c_out.
# - Test: one sigma per image, not stored; inputs are left unscaled (D applies c_in).
# Noising is vectorised per batch instead of looping over individual images.
# NOTE(review): noise is drawn once here, so every epoch sees the same noisy copies;
# sampling fresh noise inside the training loop would give the model more variety.
train_input_chunks, train_target_chunks, batch_sigmas, test_input_chunks = [], [], [], []
batch_siz = None
for images, _ in train_dl:
    images = images.to(device)
    if batch_siz is None:
        batch_siz = images.shape[0]  # first batch fixes the train_loader batch size
    sigma = sample_sigma(1)  # shape [1]: shared by the whole batch
    noised = add_noise_to_image(images, sigma)
    batch_sigmas.append(sigma)
    train_input_chunks.append(noised * c_in(sigma))
    train_target_chunks.append((images - c_skip(sigma) * noised) / c_out(sigma))

for images, _ in test_dl:
    images = images.to(device)
    # One sigma per image, reshaped to [B, 1, ..., 1] so it broadcasts over each image.
    sigma = sample_sigma(images.shape[0]).view(-1, *([1] * (images.ndim - 1)))
    test_input_chunks.append(add_noise_to_image(images, sigma))

noisy_train_input = torch.cat(train_input_chunks).float().to(device)
train_target = torch.cat(train_target_chunks).float().to(device)
noisy_test_input = torch.cat(test_input_chunks).float().to(device)
sigma_batch = torch.stack(batch_sigmas).float().to(device)  # (num_batches, 1)
c_noise_list = c_noise(sigma_batch).to(device)              # (num_batches, 1)

train_dataset = TensorDataset(noisy_train_input, train_target)
# shuffle=False is required: batch i of train_loader must stay aligned with c_noise_list[i].
train_loader = DataLoader(train_dataset, batch_size=batch_siz, shuffle=False)


In [None]:
# Hyper-parameters and model
eta = 0.1  # SGD learning rate
image_channels = 1  # FashionMNIST images have 1 channel
nb_channels = 16
num_blocks = 4
cond_channels = 8
model = Model(image_channels, nb_channels, num_blocks, cond_channels).to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=eta)

# Batch i of train_loader was noised with the sigma behind c_noise_list[i] (data-prep
# cell); the lookup below is only valid if both have the same number of batches.
assert len(train_loader) == len(c_noise_list), "train_loader batches and c_noise_list are misaligned"

# Training loop
nb_epochs = 100
loss_history = []  # mean training loss per epoch
epoch_bar = tqdm(range(nb_epochs), desc="Training")
for epoch in epoch_bar:
    model.train()
    total_loss = 0.0
    for batch_idx, (x_batch, t_batch) in enumerate(train_loader):
        x_batch = x_batch.to(device)
        t_batch = t_batch.to(device)

        optimizer.zero_grad()
        # c_noise_list[batch_idx] has shape [1]: one conditioning value shared by the batch.
        output = model(x_batch, c_noise_list[batch_idx])
        loss = criterion(output, t_batch)  # mean squared error over the batch
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Previously total_loss was accumulated but never reported; surface it per epoch.
    epoch_loss = total_loss / max(len(train_loader), 1)
    loss_history.append(epoch_loss)
    epoch_bar.set_postfix(loss=epoch_loss)


In [None]:
# Sampling / evaluation: put the network in inference mode so its BatchNorm layers use running statistics.
model.train(False)
def build_sigma_schedule(steps, rho=7, sigma_min=2e-3, sigma_max=80):
    """Return ``steps`` decreasing noise levels from sigma_max to sigma_min.

    The levels are evenly spaced in sigma**(1/rho) space and raised back to the
    power rho, which concentrates steps near sigma_min for rho > 1.
    """
    inv_rho = 1 / rho
    low = sigma_min ** inv_rho
    high = sigma_max ** inv_rho
    ramp = torch.linspace(0, 1, steps)
    return (high + ramp * (low - high)) ** rho

def D(x, sigma):
    """EDM-preconditioned denoiser D(x, sigma) = c_skip*x + c_out*F(c_in*x, c_noise).

    Args:
        x: batch of noisy images; dimension 0 is the batch.
        sigma: noise level(s): a scalar, or a 1-D tensor with one entry per batch
            element (a length-1 tensor is shared by the whole batch).

    Returns:
        Tensor shaped like ``x`` holding the denoised estimate.
    """
    sigma = torch.as_tensor(sigma, device=x.device)
    if sigma.ndim == 0:
        sigma = sigma.expand(x.shape[0])
    # The scalings must broadcast per sample over (C, H, W). A flat [B] sigma would
    # broadcast against the LAST (width) axis of x instead, which is wrong whenever
    # B > 1. Reshape to [B, 1, ..., 1] for the scalings and keep the flat vector for
    # c_noise, whose embedding expects a 1-D input.
    sigma_b = sigma.reshape(-1, *([1] * (x.ndim - 1)))
    return c_skip(sigma_b) * x + c_out(sigma_b) * model(c_in(sigma_b) * x, c_noise(sigma))

sigmas = build_sigma_schedule(steps=50).to(device=device)  # 50 decreasing noise levels (80 -> 2e-3) on the sampling device
def euler_step(x, schedule=None):
    """Run the Euler sampler over a whole noise schedule and return the final sample.

    Despite its name, this integrates over *every* level of ``schedule``. At each
    level sigma_i it computes d = (x - D(x, sigma_i)) / sigma_i and moves
    x <- x + d * (sigma_{i+1} - sigma_i), with sigma after the last level taken as 0.

    Args:
        x: starting batch (batch-first), treated as having noise level schedule[0].
        schedule: optional 1-D sequence/tensor of decreasing noise levels; defaults to
            the global ``sigmas``.

    Returns:
        Tensor shaped like ``x``.
    """
    if schedule is None:
        schedule = sigmas
    schedule = torch.as_tensor(schedule, device=x.device)
    batch_size = x.size(0)
    # Per-sample sigma must broadcast over (C, H, W), not over the last (width) axis.
    bcast_shape = (batch_size,) + (1,) * (x.ndim - 1)
    n_steps = len(schedule)
    with torch.no_grad():
        for i in range(n_steps):
            sigma = schedule[i].expand(batch_size)  # flat [B], as D / c_noise expect
            x_denoised = D(x, sigma)
            sigma_next = schedule[i + 1] if i < n_steps - 1 else 0
            sigma_b = sigma.reshape(bcast_shape)
            d = (x - x_denoised) / sigma_b
            x = x + d * (sigma_next - sigma_b)  # one Euler step from sigma to sigma_next
    return x

# Visualize the first 5 noisy test images (left) next to the sampler's output (right).
# NOTE(review): euler_step starts at sigmas[0] (sigma_max = 80) whatever the input is,
# while these test images were noised with their own, unrecorded sigma (at most 80 and
# usually much smaller). The right column is therefore probably closer to a fresh sample
# than to a denoised version of the left one. Store the test sigmas to fix this.
fig, axes = plt.subplots(5, 2, figsize=(10, 20))
for i, (noisy_ax, denoised_ax) in enumerate(axes):
    noisy_image = noisy_test_input[i].unsqueeze(0)  # add a batch axis
    denoised_image = euler_step(noisy_image).squeeze(0)  # drop the batch axis again

    panels = (
        (noisy_ax, noisy_image, "Noisy Image"),
        (denoised_ax, denoised_image, "Denoised Image"),
    )
    for ax, picture, title in panels:
        ax.imshow(picture.squeeze().cpu().numpy(), cmap='gray')
        ax.set_title(title)
        ax.axis('off')

plt.show()