In [65]:
import torch
import torch.nn as nn

import numpy as np

from src.data.datasets import ModelParamsDataset

from src.data.helpers import get_moons_dataset, rotate, get_accuracy
from src.model.models import DBModel, Autoencoder

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from src.visualization.visualize import plot_interpolation, plot_decision_boundary


In [66]:
def sample_2d_normal(mean, cov = [[1, 0], [0, 1]], num_samples = 1000):
    """
    Sample from a 2D normal distribution with a given mean and covariance matrix.

    Parameters:
    mean (list or tensor): Mean of the distribution (2 elements).
    cov (list or tensor): Covariance matrix of the distribution (2x2).
    num_samples (int): Number of samples to draw.

    Returns:
    tensor: Samples drawn from the distribution.
    """
    mean = torch.tensor(mean, dtype=torch.float32)
    cov = torch.tensor(cov, dtype=torch.float32)
    distribution = torch.distributions.MultivariateNormal(mean, cov)
    samples = distribution.sample((num_samples,))
    return samples

In [67]:
# Hyperparameters
epochs = 100
batch_size = 100
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
random_seed = 0

torch.manual_seed(random_seed)

# Model
autoencoder = Autoencoder(2)
model = DBModel()
model.requires_grad_(False)

autoencoder.to(device)
model.to(device)

# Optimizer
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)
loss_fn = nn.BCELoss()
loss_fn_latent = nn.MSELoss()

# Dataset
dataset = ModelParamsDataset("../data/eight_angles_small.csv")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Grid of points 100x100
xx, yy = np.meshgrid(np.linspace(-2, 2, 100), np.linspace(-2, 2, 100))
input = np.c_[xx.ravel(), yy.ravel()]
input = torch.tensor(input, dtype=torch.float32)
input = input.unsqueeze(0).repeat(batch_size, 1, 1)
input = input.to(device)

# Locations
path = "../models/autoencoders/matching"
model_path = "{}/model_final.pth".format(path)

In [68]:
autoencoder.train()
for epoch in tqdm(range(epochs)):
    total_loss = 0
    for (parameters_batch, angles_batch) in dataloader:
        # Model prediction
        parameters_batch = parameters_batch.to(device)
        latent = autoencoder.encoder(parameters_batch)

        c = torch.cos(angles_batch[0])*15
        s = torch.sin(angles_batch[0])*15
        mean = torch.Tensor([c, s])
        samples = sample_2d_normal(mean, num_samples=batch_size)
        samples = samples.to(device)
        reconstructed = autoencoder.decoder(latent)
        output = model(reconstructed, input)

        # Ground truth
        goal = model(parameters_batch, input)

        optimizer.zero_grad()
        loss = loss_fn(output, goal)
        loss_latent = loss_fn_latent(latent, samples)
        loss += loss_latent
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        
    print(f"Epoch {epoch} - Loss: {total_loss}")
    if(epoch % 5 == 0):
        torch.save(autoencoder.state_dict(), f"{path}/checkpoint_{epoch}.pth")

torch.save(autoencoder.state_dict(), model_path)

  0%|          | 0/100 [00:00<?, ?it/s]

  mean = torch.tensor(mean, dtype=torch.float32)


Epoch 0 - Loss: 68583.71393585205
Epoch 1 - Loss: 78006.68474578857
Epoch 2 - Loss: 82021.20597839355
Epoch 3 - Loss: 80858.83161735535
Epoch 4 - Loss: 82182.57918739319
Epoch 5 - Loss: 81671.71192932129
Epoch 6 - Loss: 81503.73538208008
Epoch 7 - Loss: 80567.34192085266
Epoch 8 - Loss: 80576.95140838623
Epoch 9 - Loss: 81219.01641464233
Epoch 10 - Loss: 80906.15961837769
Epoch 11 - Loss: 80930.2382068634


KeyboardInterrupt: 