In [None]:
import torch
import torch.utils.data
from torch import nn, optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# CANM Model Definition
class CANM(nn.Module):
    """VAE-style network with an MLP encoder and an MLP decoder.

    The encoder maps each ``(x, y)`` pair to the mean and log-variance of an
    ``N``-dimensional latent ``z``; the decoder predicts ``y`` from ``x``
    concatenated with a sampled ``z``.

    Args:
        N: Dimensionality of the latent variable ``z``.
    """

    def __init__(self, N=1):
        super().__init__()
        self.N = N
        # Shared encoder trunk over the (x, y) pair.
        self.fc1 = nn.Linear(2, 20)
        # Encoder head producing the latent mean.
        self.fc21 = nn.Linear(20, 12)
        self.fc22 = nn.Linear(12, 7)
        self.fc23 = nn.Linear(7, N)
        # Encoder head producing the latent log-variance.
        self.fc31 = nn.Linear(20, 12)
        self.fc32 = nn.Linear(12, 7)
        self.fc33 = nn.Linear(7, N)
        # Decoder over the concatenation [x, z].
        self.fc4 = nn.Linear(1 + N, 10)
        self.fc5 = nn.Linear(10, 7)
        self.fc6 = nn.Linear(7, 5)
        self.fc7 = nn.Linear(5, 1)
        self.relu = nn.ReLU()

    def encode(self, xy):
        """Return ``(mu, logvar)`` of the latent for each ``(x, y)`` row.

        ``xy`` is viewed as ``(-1, 2)``; both outputs have shape ``(rows, N)``.
        """
        h1 = self.relu(self.fc1(xy.view(-1, 2)))

        mu = self.relu(self.fc21(h1))
        mu = self.relu(self.fc22(mu))
        mu = self.fc23(mu)

        logvar = self.relu(self.fc31(h1))
        logvar = self.relu(self.fc32(logvar))
        logvar = self.fc33(logvar)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def decode(self, x, z):
        """Predict ``y`` from ``x`` and ``z``; returns a ``(rows, 1)`` tensor."""
        xz = torch.cat((x.view(-1, 1), z.view(-1, self.N)), dim=1)
        hidden = self.relu(self.fc4(xz))
        hidden = self.relu(self.fc5(hidden))
        hidden = self.relu(self.fc6(hidden))
        return self.fc7(hidden)

    def forward(self, data):
        """Encode ``data``, sample a latent and reconstruct ``y``.

        Args:
            data: Tensor whose column 0 is ``x``; the full rows (viewed as
                pairs) are fed to the encoder.

        Returns:
            Tuple ``(yhat, mu, logvar)``.
        """
        # Only x is needed by the decoder; y reaches the model solely
        # through the encoder, which consumes the whole (x, y) row.
        x = data[:, 0]
        mu, logvar = self.encode(data)
        z = self.reparameterize(mu, logvar)
        yhat = self.decode(x, z)
        return yhat, mu, logvar

# TransformerVAE Model Definition
class TransformerVAE(nn.Module):
    """VAE whose encoder is a Transformer encoder stack.

    Each ``(x, y)`` row is projected to ``d_model`` features, passed through
    the Transformer encoder as a length-1 sequence, and mapped to the mean
    and log-variance of an ``N``-dimensional latent. A small MLP decodes
    ``[x, z]`` into a prediction of ``y``.

    Args:
        N: Dimensionality of the latent variable.
        num_layers: Number of Transformer encoder layers.
        d_model: Feature width used inside the encoder and decoder MLP.
        nhead: Number of attention heads per encoder layer.
        dim_feedforward: Width of each encoder layer's feed-forward block.
        dropout: Dropout probability inside the encoder layers.
    """

    def __init__(self, N=1, num_layers=2, d_model=32, nhead=2, dim_feedforward=64, dropout=0.1):
        super().__init__()
        self.N = N
        self.d_model = d_model

        # Lift the 2-feature (x, y) row to the Transformer width.
        self.input_projection = nn.Sequential(
            nn.Linear(2, d_model),
            nn.ReLU(),
        )

        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

        # Latent heads.
        self.fc_mu = nn.Linear(d_model, N)
        self.fc_logvar = nn.Linear(d_model, N)

        # Decoder over the concatenation [x, z].
        self.decoder = nn.Sequential(
            nn.Linear(N + 1, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1),
        )

    def encode(self, xy):
        """Return ``(mu, logvar)`` for each row of ``xy``."""
        # Insert a sequence axis of length 1 (batch_first layout), run the
        # encoder, then drop that axis again.
        tokens = self.input_projection(xy).unsqueeze(1)
        hidden = self.transformer_encoder(tokens).squeeze(1)
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return noise.mul(sigma).add_(mu)

    def decode(self, x, z):
        """Predict ``y`` from ``x`` and latent ``z``; output has one column."""
        x_col = x.view(-1, 1)
        z_cols = z.view(-1, self.N)
        return self.decoder(torch.cat((x_col, z_cols), dim=-1))

    def forward(self, data):
        """Return ``(yhat, mu, logvar)`` for a batch of ``(x, y)`` rows."""
        pairs = data.view(-1, 2)
        mu, logvar = self.encode(pairs)
        latent = self.reparameterize(mu, logvar)
        x_only = data[:, 0].unsqueeze(-1)
        yhat = self.decode(x_only, latent)
        return yhat, mu, logvar

# Define the loss function
def loss_function(y, yhat, mu, logvar, sdy, beta=0.1):
    """Return the beta-weighted VAE loss for a batch.

    The loss is the Gaussian negative log-likelihood of the residual
    ``y - yhat`` under ``Normal(0, sdy)``, plus ``beta`` times the KL
    divergence between ``N(mu, exp(logvar))`` and the standard normal.
    Both terms are summed over all elements.

    Args:
        y: Observed targets.
        yhat: Predicted targets.
        mu: Latent means.
        logvar: Latent log-variances.
        sdy: Standard deviation of the residual noise; a single-element
            tensor or a plain scalar. Non-positive values are replaced
            by 0.70.
        beta: Weight applied to the KL term.

    Returns:
        Scalar tensor ``reconstruction_loss + beta * kl_loss``.
    """
    # Accept a plain scalar as well as a tensor; tensors pass through
    # untouched so a learnable sdy keeps its gradient.
    if not torch.is_tensor(sdy):
        sdy = torch.as_tensor(sdy, dtype=yhat.dtype, device=yhat.device)

    # Normal requires a strictly positive scale; fall back to 0.70 otherwise.
    if sdy.item() <= 0:
        sdy = torch.clamp(sdy, min=0.70)

    # Guard against silent broadcasting: a (B,) target against a (B, 1)
    # prediction would otherwise expand the residual to (B, B) and inflate
    # the reconstruction term.
    if y.shape != yhat.shape and y.numel() == yhat.numel():
        y = y.reshape(yhat.shape)

    noise_dist = torch.distributions.Normal(0, sdy)
    reconstruction_loss = -torch.sum(noise_dist.log_prob(y - yhat))
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_loss + beta * kl_loss