# nflows package

In [110]:
# Classes implemented in the previous task

from nflows.transforms import AffineCouplingTransform
from nflows.transforms.base import CompositeTransform
from nflows.flows.base import Flow
from nflows.distributions.normal import StandardNormal

In [114]:
# Let us try more expressive transforms!

from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform   # MAF (from the lecture)
from nflows.transforms import PiecewiseRationalQuadraticCouplingTransform  # neural spline flows https://arxiv.org/abs/1906.04032


from nflows.transforms import RandomPermutation
from nflows.transforms import BatchNorm  # introduced along with the Real-NVP

# MNIST dataset

### Define dataset

In [84]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np


# Define the MNIST dataset with preprocessing
class NoisyMNIST(datasets.MNIST):
    """MNIST with additive Gaussian noise, per-image standardization and
    one-hot labels over a chosen subset of digits.

    Each item is ``(img, target, one_hot_target)``:

    * ``img``: the image produced by the parent class (after its
      ``transform``), plus Gaussian noise, standardized to zero mean and
      unit std. Fresh noise is drawn on every access. This assumes the
      ``transform`` yields a tensor (e.g. ``transforms.ToTensor()``).
    * ``target``: the label exactly as returned by the parent class.
    * ``one_hot_target``: a float one-hot vector of length
      ``len(subset_digits)``. It is ``None`` when ``target`` is not in
      ``subset_digits``. Such items cannot be batched by the default
      ``DataLoader`` collate function, so filter them out first.

    Args:
        noise_std: standard deviation of the additive Gaussian noise.
        subset_digits: digits to encode. Defaults to all ten (0-9). The
            position of a digit in this list is its one-hot index.
    """

    # Lower bound on the per-image std used for standardization. It prevents
    # a division by zero (NaN image) when noise_std=0 and an image is constant.
    _STD_EPS = 1e-8

    def __init__(self, *args, noise_std=0.1, subset_digits=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.noise_std = noise_std
        self.subset_digits = subset_digits if subset_digits is not None else list(range(10))
        self.num_classes = len(self.subset_digits)
        # Original digit label -> index in the one-hot vector.
        self.mapping = {digit: idx for idx, digit in enumerate(self.subset_digits)}

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        img = img + torch.randn_like(img) * self.noise_std
        # Standardize the image; the clamp guards against a zero std.
        img = (img - img.mean()) / img.std().clamp(min=self._STD_EPS)

        mapped_idx = self.mapping.get(target, -1)
        if mapped_idx < 0:
            # Digit outside the selected subset: no one-hot encoding exists.
            return img, target, None

        one_hot_target = F.one_hot(torch.tensor(mapped_idx), self.num_classes).float()
        return img, target, one_hot_target

### Pick your two favorite digits

In [85]:
# Digits to keep -- edit this list to pick a different pair.
subset_digits = [2, 5]

# Both splits share the same root, preprocessing and digit subset;
# only the `train` flag differs.
transform = transforms.Compose([transforms.ToTensor()])
_mnist_kwargs = dict(
    root='./data',
    download=True,
    transform=transform,
    subset_digits=subset_digits,
)
train_dataset = NoisyMNIST(train=True, **_mnist_kwargs)
test_dataset = NoisyMNIST(train=False, **_mnist_kwargs)

In [86]:
# Filter the dataset to include only the desired digits
def filter_digits(dataset, digits):
    """Return a ``Subset`` of *dataset* keeping only items whose label is in *digits*.

    Fast path: if the dataset exposes a ``targets`` attribute and has no
    ``target_transform``, labels are read directly from ``targets``. This
    avoids calling ``__getitem__`` on every item, which would load,
    transform and add noise to every image just to read its label.

    Fallback: iterate the dataset. Each item is expected to be a 3-tuple
    ``(img, target, extra)``, as produced by ``NoisyMNIST``.

    Args:
        dataset: map-style dataset to filter.
        digits: labels to keep.

    Returns:
        ``torch.utils.data.Subset`` over the matching indices, in their
        original order.
    """
    wanted = set(digits)  # O(1) membership tests
    targets = getattr(dataset, 'targets', None)
    if targets is not None and getattr(dataset, 'target_transform', None) is None:
        # `targets` may be a tensor (torchvision MNIST) or a plain sequence.
        labels = targets.tolist() if hasattr(targets, 'tolist') else list(targets)
        indices = [i for i, label in enumerate(labels) if label in wanted]
    else:
        indices = [i for i, (_, target, _) in enumerate(dataset) if target in wanted]
    return Subset(dataset, indices)


# Restrict both splits to the selected digits (train first, then test).
train_dataset, test_dataset = (
    filter_digits(split, subset_digits) for split in (train_dataset, test_dataset)
)

# Training batches are shuffled; evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Define your model

In [None]:
# Each image is flattened into a vector of 28 * 28 = 784 pixels.
D = 28 * 28
# The condition is the one-hot digit label (third element of each batch),
# so its size equals the number of selected digits.
context_features = len(subset_digits)
num_transforms = 5
hidden_features = 256

base_dist = StandardNormal(shape=[D])

# NOTE: deliberately not named `transforms` / `transform`. Those names already
# hold torchvision's `transforms` module and the MNIST preprocessing pipeline
# from the dataset cells. Rebinding them breaks re-running those cells.
flow_layers = []

for _ in range(num_transforms):
    # Randomly permute the features before each MAF layer so successive
    # layers use different autoregressive orderings.
    flow_layers.append(RandomPermutation(D))
    flow_layers.append(
        MaskedAffineAutoregressiveTransform(
            D,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=2,
            use_residual_blocks=True,
        )
    )

flow_transform = CompositeTransform(flow_layers)

flow = Flow(flow_transform, base_dist)

print('Number of parameters:', sum(p.numel() for p in flow.parameters()))

# Train

In [116]:
from tqdm.notebook import trange, tqdm

import matplotlib.pyplot as plt
%matplotlib inline


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

flow.to(device)

optim = AdamW(flow.parameters(), lr=1e-3)

# Training loop
num_epochs = 10

pbar = trange(num_epochs)

losses = []

# Fixed context for the per-epoch preview sample: the one-hot vector of the
# first selected digit, i.e. [1, 0] when two digits are selected.
# Shape (1, num_digits): one context row -> one sample.
test_context = F.one_hot(torch.tensor([0]), num_classes=len(subset_digits)).float().to(device)

for epoch in pbar:
    flow.train()
    total_loss = 0.0

    epoch_pbar = tqdm(train_loader)

    for i, batch in enumerate(epoch_pbar):
        optim.zero_grad()
        imgs, _, context = batch

        # Flatten each image to a D-dimensional vector and move the inputs
        # and their one-hot context to the model's device.
        inputs = imgs.reshape(imgs.shape[0], -1).to(device)
        context = context.to(device)

        # Maximum likelihood: minimise the mean negative log-density of the
        # batch under the conditional flow.
        loss = -flow.log_prob(inputs=inputs, context=context).mean()

        loss.backward()
        optim.step()
        total_loss += loss.item()

        epoch_pbar.set_description(f'current loss: {total_loss / (i + 1):.2e}')

    avg_loss = total_loss / len(train_loader)
    pbar.set_description(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.2e}')

    losses.append(avg_loss)

    flow.eval()

    # Sample one image from the flow conditioned on `test_context`.
    with torch.no_grad():
        sample = flow.sample(1, context=test_context)
        img = sample.view(28, 28).cpu().numpy()
        plt.imshow(img, cmap='gray')
        plt.show()

# Tests

TODO: Sample several images from the trained flow, varying the context (the one-hot digit label) so that each selected digit is generated.