In [None]:
# Pick up edits to the local normalizing_flows modules without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader

from normalizing_flows.src.one_dimensional import data, realnvp
from normalizing_flows.src.callbacks import EarlyStopping


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so that weight initialization and DataLoader shuffling
# (and, presumably, the flow's sampling) are reproducible under Restart & Run All.
# The toy dataset itself is seeded separately via create_dataset(random_state=1).
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

device

# Create dataset

In [None]:
# Generation settings shared by both toy datasets.
# NOTE(review): random_state=1 presumably seeds the underlying generator; confirm in data.create_dataset.
gen_kwargs = dict(n_samples=10000, noise=0.05, random_state=1)

features = data.create_dataset('make_moons', **gen_kwargs)
# Alternative target, concentric circles (swap with the line above):
# features = data.create_dataset('make_circles', factor=0.5, **gen_kwargs)

data.scatter_plot(features)

In [None]:
# Wrap the raw features in a torch Dataset, then batch and shuffle them for training.
dataset = data.OneDimensionalDataset(features)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

# Sanity-check the data pipeline: sizes first, then the shape of one drawn batch.
n_items, n_batches = len(dataset), len(dataloader)
print(f"Dataset size: {n_items}")
print(f"Number of batches: {n_batches}")
probe_batch = next(iter(dataloader))
print(f"Batch shape: {probe_batch.shape}")

# Create model

In [None]:
# Build the coupling flow and place it on the selected device (Module.to returns the module itself).
# NOTE(review): the positional 2 presumably matches the 2-D feature dimension; confirm against CouplingFlow1D.
cf = realnvp.CouplingFlow1D(2).to(device)
cf

# Train coupling flow

In [None]:
# Fit the flow by maximum likelihood: minimize the mean negative log-likelihood
# -log p(x), as returned by cf.log_prob, over the training set.
opt = torch.optim.Adam(cf.parameters(), lr=2e-4)
# Halve the LR once the epoch loss has gone more than `patience` epochs without
# falling at least 0.001 (absolute) below its best value so far.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, factor=0.5, patience=5, threshold=0.001, threshold_mode='abs'
)
# Project-local helper; per its use below, calling it returns True when training should stop.
early_stopping = EarlyStopping(patience=10, threshold=0.001)
n_epochs = 100

cf = cf.train()
cf = cf.to(device)  # no-op if already moved; keeps this cell runnable on its own
n_train = len(dataloader.dataset)

for ep in range(n_epochs):
    loss_sum = 0.0
    for x in dataloader:
        x = x.to(device)
        opt.zero_grad()
        loss = -cf.log_prob(x).mean()
        loss.backward()
        opt.step()
        # Weight each batch mean by its batch size so the smaller final batch
        # does not skew the epoch average.
        loss_sum += loss.item() * len(x)

    loss_avg = loss_sum / n_train
    # LR used during this epoch. Read it from the optimizer because
    # ReduceLROnPlateau.get_last_lr() is missing in older PyTorch releases.
    lr = opt.param_groups[0]['lr']
    scheduler.step(loss_avg)

    print(f"Epoch {ep+1}/{n_epochs}, loss: {loss_avg:.4f}, lr: {lr}")

    if early_stopping(loss_avg):
        print('EarlyStopping activated. Ending training now.')
        break

In [None]:
# Draw samples from the trained flow.
# - eval() disables train-only layer behavior (e.g. dropout/batch-norm updates), if the model has any.
# - no_grad() stops autograd from tracking the sampling pass; a tensor that requires
#   grad cannot be converted with .numpy().
cf.eval()
with torch.no_grad():
    sampled_features = cf.sample(10000).cpu().numpy()

data.scatter_plot(sampled_features)

In [None]:
# Evaluation grid for the learned density: 501 points on [-1.2, 2.2] for the first
# axis and 301 points on [-0.7, 1.2] for the second (ranges chosen to frame the moons data).
grid_x = torch.linspace(-1.2, 2.2, 501)
grid_y = torch.linspace(-0.7, 1.2, 301)
data.density_heatmap_plot(cf, grid_x, grid_y)