In [None]:
import os
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn

from flow_model import Flow
from utils import gaussian_nll, gen_data

In [None]:
# Experiment presets. Previously both configurations were written out in
# sequence, so the first one was silently overwritten. Now each preset lives
# here and `exp_name` selects which one the rest of the notebook uses.
EXPERIMENTS = {
    "quad_line": dict(
        x_dist_std=10,
        source_fn=lambda x: 0.004 * x * x,
        target_fn=lambda x: 0,
        use_x_dist=False,
    ),
    "cos_sin": dict(
        x_dist_std=10,
        source_fn=lambda x: torch.cos(x),
        target_fn=lambda x: torch.sin(x),
        use_x_dist=True,
    ),
}

exp_name = "cos_sin"
device = "cpu"
SEED = 0

# Unpack the selected preset into the module-level names later cells read.
_cfg = EXPERIMENTS[exp_name]
x_dist_std = _cfg["x_dist_std"]
source_fn = _cfg["source_fn"]
target_fn = _cfg["target_fn"]
use_x_dist = _cfg["use_x_dist"]

# Seed numpy and torch so that data generation (presumably random inside
# gen_data), model initialisation and batch order are reproducible under
# Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def target_neg_log_prob(points):
    """Negative log-density of the target distribution evaluated at ``points``.

    The y coordinate is scored against a Gaussian centred on ``target_fn(x)``.
    When the module-level flag ``use_x_dist`` is set, a zero-mean Gaussian
    term with variance ``x_dist_std**2`` on x is added. This assumes
    ``gaussian_nll(mean, log_var, value)``; confirm against utils.

    Args:
        points: tensor indexed as ``points[:, 0]`` (x) and ``points[:, 1]`` (y);
            presumably shape ``(batch, 2)``. TODO confirm against Flow output.

    Returns:
        Whatever ``gaussian_nll`` returns, summed over the active terms. The
        training loop subtracts a log-Jacobian and then averages, so this is
        presumably one value per point.
    """
    x = points[:, 0]
    y = points[:, 1]

    # Conditional term: y should lie on the curve target_fn(x).
    # NOTE(review): the x term below passes log(std**2) as the log-variance,
    # but here log(0.05) is passed. That treats 0.05 as a *variance*, while
    # gen_data is called with target_std=0.05. Confirm which is intended.
    total_loss = gaussian_nll(target_fn(x), np.log(0.05), y)

    if use_x_dist:
        # Marginal term on x; only computed when it actually contributes.
        # Out-of-place add on purpose. The old `+=` mutated the tensor returned
        # by gaussian_nll in place, which autograd can reject if that tensor
        # is saved for the backward pass.
        total_loss = total_loss + gaussian_nll(0, np.log(x_dist_std ** 2), x)

    return total_loss
# Build the source/target data loaders with matching noise and batch size.
loaders = gen_data(
    x_dist_std,
    source_fn,
    target_fn,
    source_std=0.05,
    target_std=0.05,
    num_points=1024,
    batch_size=512,
)

# Overlay one batch from each distribution to sanity-check the setup
# before the flow is trained.
fig, ax = plt.subplots()

points2 = next(iter(loaders["source"])).numpy()
ax.scatter(points2[:, 0], points2[:, 1], label="Source Dist")

points = next(iter(loaders["target"])).numpy()
ax.scatter(points[:, 0], points[:, 1], label="Target Dist")

ax.set_xlim(-30, 30)
ax.set_ylim(-5, 5)
ax.legend()
ax.grid()
plt.show()

In [None]:
# Training hyper-parameters (previously inline magic numbers).
N_EPOCHS = 2000
LEARNING_RATE = 1e-3

model = Flow(128, 32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for _ in tqdm(range(N_EPOCHS)):
    # Reset each epoch so the final print reports the last epoch only.
    losses = []
    for data in loaders["source"]:
        data = data.float().to(device)

        # Push source samples through the flow. log_jacobian is presumably
        # log|det J_f| per sample (confirm in flow_model).
        target_samples, log_jacobian = model(data)

        # Objective: -log p_target(f(x)) - log|det J_f(x)|, averaged over the
        # source batch. Under the log_jacobian assumption above, this is the
        # reverse KL between the pushed-forward source and the target, up to
        # a constant.
        neg_log_p = target_neg_log_prob(target_samples)
        neg_log_likelihood = neg_log_p - log_jacobian
        # Batch mean divided by the sample dimensionality -> nats per dim.
        neg_log_likelihood = neg_log_likelihood.mean() / target_samples.shape[-1]

        optimizer.zero_grad()
        neg_log_likelihood.backward()
        optimizer.step()
        losses.append(neg_log_likelihood.item())

print("Final negative log likelihood: {:.2f}".format(np.mean(losses)))

N_FRAMES = 10
frame_dir = os.path.join("plots_flow", exp_name)
# makedirs also creates the intermediate "plots_flow" directory.
os.makedirs(frame_dir, exist_ok=True)

data = loaders["source"].dataset.float().to(device)

# Frame i shows the source data after i applications of the flow
# (frame 0 is the raw source data). Inference only: without no_grad, every
# chained forward pass would keep extending the autograd graph across frames.
with torch.no_grad():
    for i in range(N_FRAMES):
        fig, ax = plt.subplots()
        pts = data.detach().cpu().numpy()
        ax.scatter(pts[:, 0], pts[:, 1], color="blue")
        ax.grid()
        ax.set_xlim(-10, 10)
        ax.set_ylim(-1.5, 1.5)

        fig.savefig(os.path.join(frame_dir, f"frame_{i:03d}.png"))
        plt.close(fig)  # close explicitly so figures neither display nor accumulate

        data = model(data)[0]