In [None]:
import importlib
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from load_and_sample import *

# Permit faster reduced-precision internal math (e.g. TF32) for float32 matmuls where supported.
torch.set_float32_matmul_precision("high")

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using: {device}")



Using: cuda


In [None]:
# --- Model Parameters --- #

# Width of each latent vector; passed to GaussianDiffusion1D as its seq_length.
# NOTE(review): the dataloader cell's recorded output shows latents of shape (16, 256),
# which does not match 128 — confirm whether the stored latents are 256-wide
# (or e.g. two 128-wide halves) and adjust this or the data accordingly.
latent_dim = 128
unet_dim = 128            # passed as `dim` to Unet1D
train_batch_size = 16     # default batch_size of train_diffusion
sample_batch_size = 4     # number of latents drawn per sample_diffusion call
num_timesteps = 1_000     # number of diffusion timesteps

In [None]:
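# Read from the latent data file and put it into a dataloader
# TODO: this cell is empty — the dataloader is currently taken from
# `data_module.full_dataloader` in the "make the data loader" cell further down.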

In [None]:
# Load the VAE
# Restores a VAE from a saved checkpoint (relative path, resolved against the notebook's
# working directory). `load_vae_selfies` presumably comes from the `load_and_sample`
# star import and, judging by its name, loads a SELFIES-based VAE — verify.
# NOTE(review): `vae` is not used by any cell visible here.
vae = load_vae_selfies("./saved_models/epoch=447-step=139328.ckpt")


In [None]:
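# Initialize the diffusion model
# TODO: this cell is empty — the model is currently built in the
# "Create diffusion model" cell further down.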

In [None]:
# Diffusion model training

def train_diffusion(diffusion_model, dataloader, batch_size=train_batch_size, epochs=10, lr=1e-4):
    """Train ``diffusion_model`` on the latents yielded by ``dataloader``.

    NOTE: the next cell redefines ``train_diffusion`` (adding checkpointing), so after a
    full run that later definition is the one in effect.

    Args:
        diffusion_model: module whose forward pass on a (b, 1, n) latent batch returns a
            scalar loss.
        dataloader: iterable of dicts holding a ``'latent'`` tensor, presumably (b, n).
        batch_size: unused; kept for signature compatibility — the real batch size is
            read from each batch, so a short final batch is handled.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        The trained model (moved to the global ``device``).
    """
    model = diffusion_model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    num_batches = len(dataloader)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0

        progress_bar = tqdm(enumerate(dataloader), total=num_batches, desc=f"Epoch {epoch}")

        for batch_idx, batch in progress_bar:
            # Batches are dicts with a 'latent' entry, not bare tensors.
            latent = batch['latent'].to(device)
            # (b, n) -> (b, 1, n): the 1D UNet is built with a single channel.
            latent = latent.reshape(latent.shape[0], 1, -1)

            optimizer.zero_grad()
            loss = model(latent)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            progress_bar.set_postfix(loss=batch_loss)

        # max(..., 1) guards against an empty dataloader.
        print(f"Epoch {epoch}, Average Loss: {epoch_loss / max(num_batches, 1):.6f}")

    return model


In [None]:
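# Define the training loop (this redefinition of `train_diffusion` replaces the one in the
# previous cell; no training is run in this cell)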



def train_diffusion(diffusion_model, dataloader=None, batch_size=train_batch_size, epochs=10, lr=1e-4, device=device,
                    checkpoint_path='diff_checkpoints/best_model.pt'):
    """Train ``diffusion_model`` on latent batches and checkpoint the best epoch.

    Args:
        diffusion_model: module whose forward pass on a (b, 1, n) latent batch returns a
            scalar loss.
        dataloader: iterable of dicts holding a ``'latent'`` tensor of shape (b, n).
            When omitted, the notebook-global ``dataloader`` is looked up at call time.
            It is not bound at definition time, which failed on a fresh kernel because
            that global is only created in a later cell.
        batch_size: unused; kept for backward compatibility. The actual batch size is
            read from each batch, so a short final batch no longer breaks the reshape.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.
        device: device the model and batches are moved to.
        checkpoint_path: where the model/optimizer state of the lowest-average-loss
            epoch so far is saved; its parent directory is created if missing.

    Returns:
        The trained model.

    Raises:
        ValueError: if no dataloader is passed and no global ``dataloader`` exists.
    """
    if dataloader is None:
        dataloader = globals().get("dataloader")
        if dataloader is None:
            raise ValueError("No dataloader passed and no global `dataloader` defined; "
                             "run the dataloader cell first.")

    model = diffusion_model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)

    num_batches = len(dataloader)
    best_loss = float("inf")

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0

        progress_bar = tqdm(enumerate(dataloader), total=num_batches, desc=f"Epoch {epoch}")

        for batch_idx, batch in progress_bar:
            latent = batch['latent'].to(device)

            # IMPORTANT: the dataloader stores objects of shape (b, n), but the
            # UNET / diffusion want (b, 1, n). The batch and width come from the tensor
            # itself; sample_diffusion flattens the channel axis back out.
            latent = latent.reshape(latent.shape[0], 1, -1)

            optimizer.zero_grad()
            loss = model(latent)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            progress_bar.set_postfix(loss=batch_loss)

            if batch_idx % 100 == 0:
                # tqdm.write keeps the progress bar intact, unlike a bare print.
                tqdm.write(f"Epoch {epoch}, Batch: {batch_idx}: Batch Loss: {batch_loss}")

        # max(..., 1) guards against an empty dataloader.
        avg_loss = epoch_loss / max(num_batches, 1)
        print(f"Epoch {epoch}, Average Loss: {avg_loss:.6f}")

        # Only overwrite the "best" checkpoint when this epoch actually improved.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()},
                       checkpoint_path)

    return model


In [None]:
# make the data loader
# NOTE(review): `data_module` is not defined in any visible cell; it presumably comes from
# `from load_and_sample import *` or from a missing cell — confirm before a fresh Run All.
dataloader = data_module.full_dataloader
for batch in dataloader:
    latent = batch['latent']
    print(f"latent shape: {latent.shape}")
    # The diffusion model is built with seq_length=latent_dim, so each latent vector must be
    # exactly that wide. Fail here with a clear message rather than deep inside training.
    assert latent.shape[-1] == latent_dim, (
        f"latent width {latent.shape[-1]} != latent_dim {latent_dim}; "
        "update latent_dim in the parameters cell (or slice the stored latents)"
    )
    break

latent shape: torch.Size([16, 256])


In [None]:
def sample_diffusion(diffusion_model, batch_size=None):
    """Draw latent vectors from a diffusion model.

    Args:
        diffusion_model: model exposing ``sample(batch_size=...)``, assumed to return a
            (b, 1, n) tensor.
        batch_size: number of samples to draw; defaults to the global
            ``sample_batch_size`` (looked up at call time).

    Returns:
        Tensor of shape (batch_size, n) — everything after the batch axis is flattened.

    Side effects:
        Leaves ``diffusion_model`` in eval mode.
    """
    if batch_size is None:
        batch_size = sample_batch_size
    diffusion_model.eval()
    with torch.no_grad():
        latents = diffusion_model.sample(batch_size=batch_size)
        # latents are (b, 1, n); the width is taken from the tensor, not the global
        # latent_dim, so this stays correct if the two ever disagree.
        return latents.reshape(latents.shape[0], -1)
    


In [None]:
# Create diffusion model
# Rebuilds the 1D U-Net denoiser and wraps it in GaussianDiffusion1D. The module is
# reloaded first so edits to guided_diffusion_1d take effect without a kernel restart.
# NOTE(review): `guided_diffusion` is not imported explicitly; presumably it arrives via
# `from load_and_sample import *` — confirm.
importlib.reload(guided_diffusion.guided_diffusion_1d)
torch.cuda.empty_cache()  # release unused cached GPU memory before allocating new models

epochs = 1
print(f"Using device: {device}")

unet_model = guided_diffusion.guided_diffusion_1d.Unet1D(dim=unet_dim, channels=1, dim_mults=(1, 2, 4, 8))
unet_model = unet_model.to(device)

diffusion_model = guided_diffusion.guided_diffusion_1d.GaussianDiffusion1D(
    unet_model, seq_length=latent_dim, timesteps=num_timesteps, objective='pred_v'
)
diffusion_model = diffusion_model.to(device)


Using device: cuda
