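# Model Training

Trains each DDPM variant below, plus the MNIST classifier used for classifier guidance, on the MNIST training set.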

In [None]:
# importing dependencies 
import os

import torch

from utils import *
from ddpms import *

In [None]:
# Defining general parameters 
# DDPM's Parameters
T = 1000              # passed as `T=` to every DDPM model and to the classifier wrapper below
learning_rate = 1e-3  # Adam learning rate shared by all training runs
epochs = 10
batch_size = 256

# Importance sampling specific parameters
history_length = 10  # Number of recent values to store
# NOTE(review): history_length is not passed to DDPM_importance (or anything else) in this
# notebook — TODO confirm whether DDPM_importance should receive it.

# Seed torch's global RNG so the dequantization noise (torch.rand), DataLoader shuffling and
# model initialisation are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [None]:
# Loading training data.
# Each MNIST image becomes a float tensor in [0, 1] (ToTensor), is dequantized with uniform
# noise of at most one intensity level, rescaled to (approximately) [-1, 1], and flattened
# into a 1-D vector.
transform = transforms.Compose([
    transforms.ToTensor(),
    # Dequantize pixel values: add U[0, 1/255) noise so the discrete levels become continuous
    transforms.Lambda(lambda x: x + torch.rand_like(x) / 255),
    # Map from [0, 1] -> [-1, 1]  (slightly above 1 after dequantization)
    transforms.Lambda(lambda x: (x - 0.5) * 2.0),
    transforms.Lambda(lambda x: x.flatten())
])

# Download and transform train dataset
dataloader_train = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist_data', download=True, train=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)

In [None]:
# Choose the compute device: the first CUDA GPU when available, otherwise the CPU
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Device: {device}")

In [None]:
# Folder that receives model checkpoints. It is created up front so the torch.save calls
# further down do not fail with FileNotFoundError on a fresh checkout.
model_folder = './model_checkpoints'
os.makedirs(model_folder, exist_ok=True)

## Provided DDPM

In [None]:
%%time

mnist_unet = ScoreNet((lambda t: torch.ones(1).to(device)))

# Construct model
model = DDPM_classic(mnist_unet, T=T).to(device)

# Construct optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Setup simple scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

# Call training loop
train(model, optimizer, scheduler, dataloader_train,
      epochs=epochs, device=device, ema=True, per_epoch_callback=None)

# Save the model 
#torch.save(model.state_dict(), model_folder+"/model_classic.pth")

## Low-discrepancy sampling (VDM)

In [None]:
%%time

mnist_unet = ScoreNet((lambda t: torch.ones(1).to(device)))

# Construct model
model = DDPM_low_discrepancy(mnist_unet, T=T, sampler="simple").to(device)

# Construct optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Setup simple scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

# Call training loop
train(model, optimizer, scheduler, dataloader_train,
      epochs=epochs, device=device, ema=True, per_epoch_callback=None)

# Save the model 
#torch.save(model.state_dict(), model_folder+"/model_lds_simple.pth")

## Low-discrepancy sampling (Sobol)

In [None]:
%%time

mnist_unet = ScoreNet((lambda t: torch.ones(1).to(device)))

# Construct model
model = DDPM_low_discrepancy(mnist_unet, T=T, sampler="sobol").to(device)

# Construct optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Setup simple scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

# Call training loop
train(model, optimizer, scheduler, dataloader_train,
      epochs=epochs, device=device, ema=True, per_epoch_callback=None)

# Save the model 
#torch.save(model.state_dict(), model_folder+"/model_lds_sobol.pth")

## Importance sampling 

In [None]:

mnist_unet = ScoreNet((lambda t: torch.ones(1).to(device)))

# Construct model
model = DDPM_importance(mnist_unet, T=T).to(device)

# Construct optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Setup simple scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

# Call training loop
train(model, optimizer, scheduler, dataloader_train,
      epochs=epochs, device=device, ema=True, per_epoch_callback=reporter)

# Save the model 
#torch.save(model.state_dict(), model_folder+"/model_is.pt")

## Predicting $x_0$

In [None]:
%%time

mnist_unet = ScoreNet((lambda t: torch.ones(1).to(device)))

# Construct model
model = DDPM_x0(mnist_unet, T=T).to(device)

# Construct optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Setup simple scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

# Call training loop
train(model, optimizer, scheduler, dataloader_train,
      epochs=epochs, device=device, ema=True, per_epoch_callback=None)

# Save model
#torch.save(model.state_dict(), model_folder+"/model_x0.pt")

## Predicting $\mu$

In [None]:
%%time


mnist_unet = ScoreNet2()

# Construct model
model = DDPM_mu(mnist_unet, T=T).to(device)

# Construct optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Setup simple scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

# Call training loop
train(model, optimizer, scheduler, dataloader_train,
      epochs=epochs, device=device, ema=False, per_epoch_callback=reporter)

# Save model
#torch.save(model.state_dict(), model_folder+"/model_mu.pt")

## Classifier Guided

In [None]:
%%time

# Classifier specific parameters
beta_1 = 1e-4
beta_T = 2e-2

# Instantiating the classifier
model_classifier = RobustMNISTClassifier().to(device)

# instantiating the classifier-wrapper
wrapper = ClassifierWrapper(model_classifier, T=T, beta_1=beta_1, beta_T=beta_T).to(device)

# train the classifier
classifier = train_classifier(model_classifier, wrapper)

# Saving the classifier
torch.save(model_classifier.state_dict(), model_folder + "/classifier.pt")

mnist_unet = ScoreNet((lambda t: torch.ones(1).to(device)))

# Construct model
model = DDPM_class(mnist_unet, T=T).to(device)

# Construct optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Setup simple scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

# Call training loop
train(model, optimizer, scheduler, dataloader_train,
      epochs=epochs, device=device, ema=True, per_epoch_callback=None)

# No saving of the DDPM model since is the same used in the provided implementation we will load "/model_classic.pth" afterwards

## Classifier-free Guidance

In [None]:
%%time

# Construct model
mnist_unet = ScoreNet_class((lambda t: torch.ones(1).to(device)))

# Construct model
model = DDPM_class_free(mnist_unet, T=T).to(device)

# Construct optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Setup simple scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

# Call training loop
train_class_free(model, optimizer, scheduler, dataloader_train,
                 epochs=epochs, device=device, ema=True, per_epoch_callback=None)

# Save model
#torch.save(model.state_dict(), model_folder+"/model_classifier_free.pt")