In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data
from torch import optim
from scipy.stats import gaussian_kde
import random
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

class CANM(nn.Module):
    """Variational encoder/decoder over rows ``(x, y, conf)``.

    The encoder maps each 3-column row to two diagonal-Gaussian posteriors:
    an ``N``-dimensional latent ``z`` and a 1-dimensional latent for the
    confounding variable. The decoder predicts ``y`` from ``x`` together
    with a sample of each latent.

    Args:
        N: Dimensionality of the main latent variable ``z``.
    """

    def __init__(self, N):
        super().__init__()
        self.N = N

        # Shared first encoder layer; every input row has 3 columns.
        self.fc1 = nn.Linear(3, 20)

        # Encoder head producing the mean of z.
        self.fc21 = nn.Linear(20, 12)
        self.fc22 = nn.Linear(12, 7)
        self.fc23 = nn.Linear(7, N)

        # Encoder head producing the log-variance of z.
        self.fc31 = nn.Linear(20, 12)
        self.fc32 = nn.Linear(12, 7)
        self.fc33 = nn.Linear(7, N)

        # Encoder heads for the confounding-variable latent (mean, log-variance).
        self.fc_conf_mu1 = nn.Linear(20, 12)
        self.fc_conf_mu2 = nn.Linear(12, 1)
        self.fc_conf_logvar1 = nn.Linear(20, 12)
        self.fc_conf_logvar2 = nn.Linear(12, 1)

        # Decoder; its input is the concatenation [x (1), z (N), conf latent (1)].
        self.fc4 = nn.Linear(1 + N + 1, 10)
        self.fc5 = nn.Linear(10, 7)
        self.fc6 = nn.Linear(7, 5)
        self.fc7 = nn.Linear(5, 1)

        self.relu = nn.ReLU()

    def encode(self, data):
        """Map rows of ``data`` to ``(mu, logvar, conf_mu, conf_logvar)``.

        ``data`` is flattened to rows of 3 values. ``mu``/``logvar`` have
        ``N`` columns; ``conf_mu``/``conf_logvar`` have a single column.
        """
        act = self.relu
        shared = act(self.fc1(data.view(-1, 3)))

        mu = self.fc23(act(self.fc22(act(self.fc21(shared)))))
        logvar = self.fc33(act(self.fc32(act(self.fc31(shared)))))

        conf_mu = self.fc_conf_mu2(act(self.fc_conf_mu1(shared)))
        conf_logvar = self.fc_conf_logvar2(act(self.fc_conf_logvar1(shared)))

        return mu, logvar, conf_mu, conf_logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + exp(0.5 * logvar) * eps`` while training; return ``mu`` in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def decode(self, x, z, conf_z):
        """Predict ``y`` as a ``(B, 1)`` tensor from ``x``, ``z`` and the confounder latent."""
        hidden = torch.cat(
            (x.view(-1, 1), z.view(-1, self.N), conf_z.view(-1, 1)), 1
        )
        for layer in (self.fc4, self.fc5, self.fc6):
            hidden = self.relu(layer(hidden))
        return self.fc7(hidden)

    def forward(self, data):
        """Encode ``data``, sample both latents and decode.

        Returns:
            ``(yhat, mu, logvar, conf_mu, conf_logvar)``; ``x`` is column 0
            of the flattened rows.
        """
        rows = data.view(-1, 3)
        mu, logvar, conf_mu, conf_logvar = self.encode(rows)
        # Sample z first, then the confounder latent (fixes RNG consumption order).
        z = self.reparameterize(mu, logvar)
        conf_z = self.reparameterize(conf_mu, conf_logvar)
        yhat = self.decode(rows[:, 0], z, conf_z)
        return yhat, mu, logvar, conf_mu, conf_logvar

class TransformerVAE(nn.Module):
    """VAE with a Transformer encoder that predicts ``y`` from ``x`` and two latents.

    Each input row (``input_dim`` values; by default ``(x, y, u)``) is linearly
    embedded to ``d_model``, shifted by a learned vector, and passed through a
    Transformer encoder as its own length-1 sequence. The encoding is mapped
    to diagonal-Gaussian posteriors over a ``latent_dim`` latent and a
    ``confounding_dim`` confounder latent. An MLP decodes ``y`` from column 0
    of the row (``x``) plus samples of both latents.

    Args:
        input_dim: Number of columns per input row.
        latent_dim: Size of the main latent.
        confounding_dim: Size of the confounder latent.
        d_model: Transformer embedding width.
        nhead: Number of attention heads (PyTorch requires it to divide ``d_model``).
        num_encoder_layers: Number of Transformer encoder layers.
        num_decoder_layers: Number of layers in ``transformer_decoder``.
    """

    def __init__(self, input_dim=3, latent_dim=10, confounding_dim=1, d_model=64, nhead=4, num_encoder_layers=4, num_decoder_layers=4):
        super(TransformerVAE, self).__init__()

        self.input_embedding = nn.Linear(input_dim, d_model)
        # With length-1 sequences this acts as a learned additive bias.
        self.positional_encoding = nn.Parameter(torch.randn(1, d_model))

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # Posterior heads; built in this order so seeded init is unchanged.
        self.fc_mu = self._gaussian_head(d_model, latent_dim)
        self.fc_logvar = self._gaussian_head(d_model, latent_dim)
        self.fc_conf_mu = self._gaussian_head(d_model, confounding_dim)
        self.fc_conf_logvar = self._gaussian_head(d_model, confounding_dim)

        # NOTE(review): transformer_decoder is never used by forward()/decode();
        # it is kept only so previously saved state_dicts still load strictly.
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Decoder input is [x (1), z (latent_dim), conf_z (confounding_dim)].
        self.decoder = nn.Sequential(
            nn.Linear(1 + latent_dim + confounding_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    @staticmethod
    def _gaussian_head(in_dim, out_dim):
        """Build the two-layer MLP used for each posterior parameter."""
        return nn.Sequential(
            nn.Linear(in_dim, 32),
            nn.ReLU(),
            nn.Linear(32, out_dim)
        )

    def encode(self, x):
        """Return ``(mu, logvar, conf_mu, conf_logvar)`` for rows ``x`` of shape (B, input_dim)."""
        h = self.input_embedding(x) + self.positional_encoding
        # The encoder uses the default batch_first=False layout (seq, batch,
        # feature). Put the rows on the batch axis with a length-1 sequence so
        # each row attends only to itself; unsqueeze(1) would make the batch
        # the sequence axis and let rows attend to one another.
        encoded = self.transformer_encoder(h.unsqueeze(0)).squeeze(0)

        mu = self.fc_mu(encoded)
        logvar = self.fc_logvar(encoded)

        conf_mu = self.fc_conf_mu(encoded)
        conf_logvar = self.fc_conf_logvar(encoded)

        return mu, logvar, conf_mu, conf_logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + exp(0.5 * logvar) * eps`` while training; return ``mu`` in eval mode."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        return mu

    def decode(self, x, z, conf_z):
        """Predict ``y`` as a ``(B, 1)`` tensor from ``x``, ``z`` and ``conf_z``."""
        return self.decoder(torch.cat([x.view(-1, 1), z, conf_z], dim=1))

    def forward(self, data):
        """Encode ``data``, sample both latents and decode.

        ``data`` is reshaped to rows of ``input_dim`` values; column 0 is ``x``.

        Returns:
            ``(yhat, mu, logvar, conf_mu, conf_logvar)``.
        """
        rows = data.reshape(-1, self.input_embedding.in_features)

        mu, logvar, conf_mu, conf_logvar = self.encode(rows)

        z = self.reparameterize(mu, logvar)
        conf_z = self.reparameterize(conf_mu, conf_logvar)

        yhat = self.decode(rows[:, 0], z, conf_z)

        return yhat, mu, logvar, conf_mu, conf_logvar

def canm_loss_function(y, yhat, mu, logvar, sdy, beta, conf_mu, conf_logvar):
    """Loss for CANM: Gaussian reconstruction NLL plus beta-weighted KL terms.

    Args:
        y: Observed targets. Reshaped to ``yhat``'s shape, so a 1-D ``y`` of
            shape (B,) lines up with a (B, 1) ``yhat`` instead of
            broadcasting to a (B, B) residual matrix.
        yhat: Predicted targets.
        mu, logvar: Mean and log-variance of the main latent posterior.
        sdy: Noise standard deviation, as a scalar tensor or Python float.
            A non-positive value is replaced by ``-sdy + 0.05`` so the Normal
            scale stays positive.
        beta: Weight applied to both KL terms.
        conf_mu, conf_logvar: Mean and log-variance of the confounder posterior.

    Returns:
        Scalar tensor: summed NLL + KL(main) + KL(confounder).
    """
    # Accept plain floats as well as (possibly learnable) scalar tensors;
    # as_tensor returns an existing tensor unchanged, preserving its grad.
    sdy = torch.as_tensor(sdy)

    # Residuals under a zero-mean Gaussian noise model.
    residual = y.reshape(yhat.shape) - yhat

    if sdy.item() <= 0:
        sdy = -sdy + 0.05

    # Reconstruction term: Gaussian negative log-likelihood (not BCE).
    nll = -torch.sum(torch.distributions.Normal(0, sdy).log_prob(residual))

    # KL(q(conf | data) || N(0, I)) for the confounder latent.
    conf_reg = -0.5 * torch.sum(1 + conf_logvar - conf_mu.pow(2) - conf_logvar.exp()) * beta

    # KL(q(z | data) || N(0, I)) for the main latent.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) * beta

    return nll + kld + conf_reg


def transformer_loss_function(y, yhat, mu, logvar, conf_mu, conf_logvar, sdy, beta):
    """Loss for TransformerVAE: Gaussian reconstruction NLL plus beta-weighted KL terms.

    Args:
        y: Observed targets. Reshaped to ``yhat``'s shape, so a 1-D ``y`` of
            shape (B,) lines up with a (B, 1) ``yhat`` instead of
            broadcasting to a (B, B) residual matrix.
        yhat: Predicted targets.
        mu, logvar: Mean and log-variance of the main latent posterior.
        conf_mu, conf_logvar: Mean and log-variance of the confounder posterior.
        sdy: Noise standard deviation, as a scalar tensor or Python float.
            A non-positive value is replaced by ``-sdy + 0.05`` so the Normal
            scale stays positive.
        beta: Weight applied to both KL terms.

    Returns:
        Scalar tensor: summed NLL + KL(main) + KL(confounder).
    """
    # Accept plain floats as well as (possibly learnable) scalar tensors.
    sdy = torch.as_tensor(sdy)

    residual = y.reshape(yhat.shape) - yhat

    if sdy.item() <= 0:
        sdy = -sdy + 0.05

    # Gaussian negative log-likelihood of the residuals.
    nll = -torch.sum(torch.distributions.Normal(0, sdy).log_prob(residual))
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) * beta
    conf_kld = -0.5 * torch.sum(1 + conf_logvar - conf_mu.pow(2) - conf_logvar.exp()) * beta
    return nll + kld + conf_kld
