In [None]:
%load_ext autoreload
%autoreload 2

DEVICE = "cuda" # cpu/cuda
RETRAIN = False

if DEVICE == "cpu":
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

import sys
sys.path.append("../")

import torch
from torch.utils.data import Subset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np

from image_gen import GenerativeModel
from image_gen.diffusion import VarianceExploding, VariancePreserving, SubVariancePreserving
from image_gen.noise import LinearNoiseSchedule, CosineNoiseSchedule
from image_gen.samplers import EulerMaruyama, ExponentialIntegrator, ODEProbabilityFlow, PredictorCorrector


# Report how many threads torch reports for CPU work.
n_threads = torch.get_num_threads()
print('Number of threads: {:d}'.format(n_threads))

# Build the torch device used to place tensors created in this notebook.
# If a CUDA device was requested but none is usable, fall back to the CPU
# up front instead of failing later at the first `.to(device)` call.
if DEVICE.startswith("cuda") and not torch.cuda.is_available():
    print("CUDA requested but not available; falling back to CPU")
    device = torch.device("cpu")
else:
    device = torch.device(DEVICE)

In [None]:
# MNIST training split, stored under ./data (downloaded on first use).
# ToTensor() turns each image into a tensor when an item is accessed.
data = datasets.MNIST('data', train=True, transform=ToTensor(), download=True)
print(type(data))

In [None]:
# --- Experiment selection ---
# Short keys choosing the diffusion process, noise schedule and sampler.
# Valid values are the keys of the lookup tables below.
diffusion_key = "ve"
noise_schedule = "l"
sampler = "euler"

# Noise schedules are stored as ready-made instances ...
schedule_map = {
    "l": LinearNoiseSchedule(beta_min=0.0001, beta_max=20),
    "c": CosineNoiseSchedule(beta_max=0.9999)
}
# ... while diffusion processes and samplers are stored as classes.
diffusion_map = {
    "ve": VarianceExploding,
    "vp": VariancePreserving,
    "svp": SubVariancePreserving
}
sampler_map = {
    "euler": EulerMaruyama,
    "exp": ExponentialIntegrator,
    "ode": ODEProbabilityFlow,
    "pc": PredictorCorrector
}

# Index the tables directly: a mistyped key raises KeyError here instead of
# silently handing `None` to GenerativeModel.
model = GenerativeModel(
    diffusion=diffusion_map[diffusion_key],
    sampler=sampler_map[sampler],
    noise_schedule=schedule_map[noise_schedule]
)

# Tag used in checkpoint filenames by later cells. Every process except "ve"
# also carries the schedule key (presumably because "ve" does not depend on
# the noise schedule -- TODO confirm). Kept separate from `diffusion_key` so
# the lookup key is never overwritten.
if diffusion_key != 've':
    diffusion_model = f"{diffusion_key}_{noise_schedule}"
else:
    diffusion_model = diffusion_key

In [None]:
# Local import: the notebook's top cell only imports `os` when DEVICE == "cpu".
import os

epochs = 20
checkpoint_path = f'mnist_{epochs}e_{diffusion_model}_{sampler}.pth'

# Train (and save) when retraining is requested or no checkpoint exists yet;
# otherwise reuse the saved weights so a full re-run of the notebook stays cheap.
if RETRAIN or not os.path.exists(checkpoint_path):
    model.train(data, epochs=epochs)
    model.save(checkpoint_path)
else:
    model.load(checkpoint_path)

In [None]:
# Give each of the ten classes a readable name, then save a second
# checkpoint under a distinct filename.
digit_labels = [f"Number {digit}" for digit in range(10)]
model.set_labels(digit_labels)

labelled_checkpoint = f'mnist_{epochs}e_{diffusion_model}_{sampler}_LabelNameTest.pth'
model.save(labelled_checkpoint)

In [None]:
# Number of images to generate; reused by the display cell further down.
n_images = 16
# Generate a batch without passing class labels.
samples = model.generate(n_images)

In [None]:
def show_images(images, n_images=4):
    """Display up to ``n_images`` images from a batch in a near-square grid.

    Args:
        images: 4-D torch tensor whose axis 1 is moved last before plotting
            (i.e. laid out as N, C, H, W). Values are assumed to lie in
            [-1, 1] and are rescaled to [0, 1] for display.
        n_images: Maximum number of images to show. Any positive integer
            works; it does not have to be a perfect square.

    Raises:
        ValueError: If ``n_images`` is smaller than 1.
    """
    if n_images < 1:
        raise ValueError("n_images must be a positive integer")

    images = images[:n_images]  # Select only the first n_images
    images = images.permute(0, 2, 3, 1).cpu().detach().numpy()
    images = (images + 1) / 2  # Scale from [-1,1] to [0,1]

    # Smallest near-square grid with at least n_images cells. For perfect
    # squares this is exactly sqrt(n) x sqrt(n).
    n_cols = int(np.ceil(np.sqrt(n_images)))
    n_rows = int(np.ceil(n_images / n_cols))

    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4, 4), squeeze=False)
    axes = axes.ravel()

    # Hide every frame first so unused cells stay blank.
    for ax in axes:
        ax.axis('off')

    for ax, img in zip(axes, images):
        ax.imshow(img, cmap="gray")

    plt.tight_layout()
    plt.show()

In [None]:
def show_mnist_images(dataset, n_images=4):
    """Display the first ``n_images`` items of a dataset in a near-square grid.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs,
            where ``image`` is a tensor supporting ``.squeeze().numpy()``.
        n_images: Number of leading items to show. Any positive integer
            works; it does not have to be a perfect square.

    Raises:
        ValueError: If ``n_images`` is smaller than 1.
    """
    if n_images < 1:
        raise ValueError("n_images must be a positive integer")

    # Smallest near-square grid with at least n_images cells. For perfect
    # squares this is exactly sqrt(n) x sqrt(n).
    n_cols = int(np.ceil(np.sqrt(n_images)))
    n_rows = int(np.ceil(n_images / n_cols))

    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4, 4), squeeze=False)
    axes = axes.ravel()

    # Hide every frame first so unused cells stay blank.
    for ax in axes:
        ax.axis('off')

    for i in range(n_images):
        img, _ = dataset[i]  # Label is not needed for display
        img = img.squeeze().numpy()  # Drop size-1 axes to get a 2D array

        axes[i].imshow(img, cmap='gray')  # Display in grayscale

    plt.tight_layout()
    plt.show()

In [None]:
# Preview the first four items of the loaded dataset.
show_mnist_images(dataset=data, n_images=4)

In [None]:
# Show the generated batch, up to n_images of them.
show_images(images=samples, n_images=n_images)

In [None]:
# Print the number of classes of the underlying network. The network is
# unwrapped through `.module` only when that attribute exists, so this also
# works when the network is not wrapped.
network = getattr(model.model, "module", model.model)
print(network.num_classes)

In [None]:
n_images = 16
# Class labels cycle 0..9 across the batch (0, 1, ..., 9, 0, 1, ...).
# Deriving them from n_images keeps the label count in sync with the batch
# size; torch.arange yields int64 directly, so no float-to-long cast is needed.
class_labels = (torch.arange(n_images) % 10).to(device)
samples = model.generate(n_images, class_labels=class_labels)
show_images(samples, n_images=n_images)