In [None]:
# Re-import edited project modules (e.g. normalizing_flows.*) automatically before each cell runs,
# so changes to the library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import torch
from torch.utils.data import DataLoader

from normalizing_flows.src.one_dimensional import data, coupling_flow
from normalizing_flows.src.realnvp.callbacks import EarlyStopping, ModelCheckpoint


# Seed torch's global RNG (all devices) so model initialisation, DataLoader shuffling and
# sampling from the flow are reproducible under Restart & Run All. The dataset itself is
# seeded separately via `random_state=1` below.
SEED = 1
torch.manual_seed(SEED)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

# Create dataset

In [None]:
# Toy 2-D training data from the 'make_moons' generator (name presumably mirrors sklearn.datasets).
# Other generators that can be swapped into the call below:
#   data.create_dataset('make_circles', n_samples=n_samples, noise=0.05, random_state=1, factor=0.5)
#   data.create_dataset('make_blobs', n_samples=n_samples, centers=3, random_state=1)
n_samples = 50_000
features = data.create_dataset(
    'make_moons',
    n_samples=n_samples,
    noise=0.05,
    random_state=1,
)

data.scatter_plot(features)

In [None]:
# Wrap the feature array in a torch Dataset and serve shuffled mini-batches of 32 samples.
dataset = data.OneDimensionalDataset(features)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Sanity-check sizes before training
n_items, n_batches = len(dataset), len(dataloader)
print(f"Dataset size: {n_items}")
print(f"Number of batches: {n_batches}")
sample_batch = next(iter(dataloader))
print(f"Batch shape: {sample_batch.shape}")

# Create model

In [None]:
# Coupling-flow density model. The first positional argument is presumably the data
# dimensionality (2 for the toy datasets above). Derive it from the data rather than hard-coding
# it, so switching datasets cannot silently mismatch it. TODO confirm against CouplingFlow1D.
n_features = features.shape[1]
cf = coupling_flow.CouplingFlow1D(n_features, hidden_dim=128, n_coupling_layers=5)
cf = cf.to(device)
cf

# Train coupling flow

In [None]:
# Optimiser and LR schedule: halve the LR once the epoch-average loss has not improved by more
# than 0.001 (absolute) for more than 4 consecutive epochs.
opt = torch.optim.Adam(cf.parameters(), lr=2e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=4, threshold=0.001, threshold_mode='abs')
# Presumably signals a stop after 8 epochs without an improvement larger than 0.001 — see EarlyStopping.
early_stopping = EarlyStopping(patience=8, threshold=0.001)
save_dir = 'checkpoints'
model_checkpoint = ModelCheckpoint(save_dir=save_dir, filename='model_{epoch:03d}_{score:.3f}.pt', save_best_only=True)

n_epochs = 100

cf = cf.train()
cf = cf.to(device)
for ep in range(n_epochs):
    loss_sum = 0.0
    for x in dataloader:
        x = x.to(device)
        opt.zero_grad()
        # Maximum likelihood: minimise the mean negative log-density of the batch.
        loss = -cf.log_prob(x).mean()
        loss.backward()
        opt.step()
        loss_sum += loss.item()

    loss_avg = loss_sum / len(dataloader)
    # LR actually used during this epoch (read before the scheduler may lower it). Read from the
    # optimiser because ReduceLROnPlateau.get_last_lr() is not available in older PyTorch releases.
    lr = opt.param_groups[0]['lr']
    scheduler.step(loss_avg)
    model_checkpoint.save(cf, score=loss_avg, epoch=ep)

    print(f"Epoch {ep+1}/{n_epochs}, loss: {loss_avg:.4f}, lr: {lr}")

    if early_stopping(loss_avg):
        print('EarlyStopping activated. Ending training now.')
        break

# Reload the best checkpoint of this run. os.listdir() returns entries in arbitrary order, and the
# directory may still hold checkpoints from earlier runs. So take the most recently written *.pt
# file: with save_best_only=True the last file saved is presumably this run's best model.
checkpoints = [
    os.path.join(save_dir, name) for name in os.listdir(save_dir) if name.endswith('.pt')
]
if not checkpoints:
    raise FileNotFoundError(f"No checkpoint (*.pt) found in {save_dir!r}.")
best_path = max(checkpoints, key=os.path.getmtime)
print(f"Loading best model from checkpoint: {best_path}.")
model_checkpoint.load(cf, best_path)

In [None]:
# Draw samples from the trained flow. Switch to eval mode (the training cell leaves the model in
# train mode, which matters for any dropout/batch-norm layers). Disable autograd: no gradients
# are needed, and a tensor that requires grad cannot be converted with .numpy().
cf = cf.eval()
with torch.no_grad():
    sampled_features = cf.sample(25000).cpu().numpy()

data.scatter_plot(sampled_features)

In [None]:
def padded_axis(values, n_points=401, pad_frac=0.1):
    """Evenly spaced 1-D grid covering the range of `values`, widened by `pad_frac` of that range on each side.

    Padding by a fraction of the range (instead of scaling min/max by 1.1) also works when all
    values share a sign, where scaling would shift or shrink the window and clip the data.
    """
    lo, hi = float(values.min()), float(values.max())
    margin = pad_frac * (hi - lo)
    if margin == 0:  # constant column: fall back to a small non-empty window
        margin = pad_frac * max(abs(lo), 1.0)
    return torch.linspace(lo - margin, hi + margin, n_points)


data.density_heatmap_plot(
    cf,
    padded_axis(features[:, 0]),
    padded_axis(features[:, 1]),
)