In [None]:
# Python packages

#!pip install import-ipynb
#!pip install git+https://github.com/patrick-kidger/torchcubicspline.git # this pip install shouldn't be strictly needed but kept just in case.
import time
import import_ipynb
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch 
import torch.nn as nn 
import torch.optim as optim 
from torchvision import datasets
from torchvision import transforms

from torchcubicspline import(natural_cubic_spline_coeffs, 
                             NaturalCubicSpline)
from datetime import datetime, timedelta
from scipy.interpolate import CubicSpline
from scipy.optimize import newton,least_squares
from torch.func import hessian
import math
from torch.func import jacfwd
from torch.func import vmap, vjp

In [None]:

# Every tensor created without an explicit dtype will be float64.
torch.set_default_dtype(torch.float64)


def set_seed(inte):
    """Seed PyTorch's global RNG so that all models start from the same weight initialisation.

    Args:
        inte: integer seed forwarded to ``torch.manual_seed``.
    """
    torch.manual_seed(inte)


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Functions used across all networks

# Learning-rate scheduler: multiplies the base learning rate by 0.9 every 50 scheduler steps (epochs).

class CustomLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Step-decay learning-rate schedule.

    The learning rate of every parameter group is
    ``base_lr * gamma ** (last_epoch // step_size)``, i.e. it is multiplied by ``gamma``
    once every ``step_size`` calls to ``step()``. The defaults (``step_size=50``,
    ``gamma=0.9``) reproduce the original hard-coded schedule.

    Args:
        optimizer: wrapped optimizer.
        last_epoch: index of the last epoch (``-1`` starts a fresh schedule).
        step_size: number of scheduler steps between two decays (must be >= 1).
        gamma: multiplicative decay factor applied at each decay.

    Raises:
        ValueError: if ``step_size < 1``.
    """

    def __init__(self, optimizer, last_epoch=-1, step_size=50, gamma=0.9):
        if step_size < 1:
            raise ValueError(f"step_size must be >= 1, got {step_size!r}")
        # These must exist before super().__init__(), which performs an initial step()
        # and therefore already calls get_lr().
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the closed-form learning rate of every parameter group for ``last_epoch``."""
        lr_factor = self.gamma ** (self.last_epoch // self.step_size)
        return [base_lr * lr_factor for base_lr in self.base_lrs]
    
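# Centered sigmoid activation, sigmoid(x) - 0.5 with range (-0.5, 0.5). Despite the class name, it is an element-wise sigmoid, not a softmax.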

class nn_centered_softmax(nn.Module):
    """Element-wise centered logistic sigmoid: ``sigmoid(x) - 0.5``.

    Output lies in (-0.5, 0.5) and is 0 at x = 0. (Despite the name, this is not a softmax.)
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # torch.sigmoid is evaluated in a numerically stable way. The hand-written
        # 1 / (1 + exp(-x)) overflows exp(-x) to inf for very negative x, which turns the
        # backward pass (and jacfwd/hessian through the decoder) into NaN.
        return torch.sigmoid(input) - 0.5

In [None]:
#Torch functions for calculating swaps, zero coupon prices and interpolation


# function for discounting zero coupon rates for tensors

def discounting_torch(curve, maturity):
    """Continuously-compounded discount factor ``exp(-maturity * rate)``.

    The rate is read from ``curve[int(maturity) - 1]``, i.e. element ``k - 1`` of ``curve`` is
    used as the zero rate for maturity ``k``.
    NOTE(review): if ``curve`` is a decoded curve whose column 0 is the short rate (tau = 0),
    this index is shifted by one maturity; confirm the intended alignment.

    Args:
        curve: indexable tensor of zero rates.
        maturity: maturity in years; must satisfy ``int(maturity) >= 1``.

    Returns:
        The discount factor as a tensor.

    Raises:
        ValueError: if ``int(maturity) < 1``. A negative index would otherwise silently wrap
            around to the end of ``curve``.
    """
    index = int(maturity) - 1
    if index < 0:
        raise ValueError(f"maturity must be >= 1, got {maturity!r}")
    return torch.exp(-maturity * curve[index])


# function for calculating swap rates for tensors, calls discounting_torch().

def swap_rate_torch_paper(curve, maturities):
    """Par swap rates with annual fixed payments, computed from a zero-rate curve.

    For each maturity T the par rate is

        (1 - P(T)) / sum_{t=1}^{int(T)} P(t),

    where P is ``discounting_torch`` (floating leg over fixed-leg annuity with unit accruals).

    Args:
        curve: 1-D tensor of zero rates passed through to ``discounting_torch``.
        maturities: iterable of swap maturities (each with ``int(T) >= 1``).

    Returns:
        1-D tensor with one par rate per maturity. It is empty when ``maturities`` is empty and
        has the dtype/device of ``curve``.
    """
    rates = []
    for maturity in maturities:
        float_leg = 1 - discounting_torch(curve, maturity)
        fixed_leg = torch.sum(torch.stack([discounting_torch(curve, t) for t in range(1, int(maturity) + 1)]))
        rates.append(float_leg / fixed_leg)
    if not rates:
        return curve.new_empty(0)
    # A single stack replaces the repeated torch.cat (quadratic, and pinned to the CPU by torch.empty(0)).
    return torch.stack(rates)


# curve class for doing interpolation for tensors, currently not needed. 
class Curve_torch():
    """Natural cubic-spline curve sampled on a dense grid over [0, 30] years.

    The spline is fitted through the knots ``(x_values, y_values)`` with torchcubicspline. It is
    then evaluated on ``30 * 365`` equally spaced points from 0 to 30 (inclusive). Values can be
    read or overwritten at the grid point nearest to a given x.
    """

    def __init__(self, x_values, y_values):
        # One channel per knot: (n_knots,) -> (n_knots, 1). Previously hard-coded to 6 knots.
        coeffs = natural_cubic_spline_coeffs(x_values, y_values.reshape(-1, 1))
        spline = NaturalCubicSpline(coeffs)
        # Same float64 grid as numpy's linspace (0 .. 30, inclusive).
        grid = torch.from_numpy(np.linspace(0, 30, 30 * 365, endpoint=True))
        self.x_values = grid
        # Evaluate the whole grid in one call: (n_points, 1) -> (n_points,). This replaces
        # n_points scalar evaluations glued together with repeated torch.cat.
        self.y_values = spline.evaluate(grid).reshape(-1)

    def set_value(self, x, y):
        """Overwrite the curve value at the grid point nearest to ``x``."""
        idx = torch.abs(self.x_values - x).argmin()
        self.y_values[idx] = y

    def get_value(self, x):
        """Return the curve value at the grid point nearest to ``x``."""
        idx = torch.abs(self.x_values - x).argmin()
        return self.y_values[idx]


In [None]:
# Functions used to calculate the no-arbitrage condition and the components it needs, i.e. sigma, mu and the decoder derivatives.
# Both 2-factor and 3-factor versions exist for each function (where needed).


# construction and calculation of gradients used in arbitrage condition
def construct_grad_z2(taus, latents, model):
    """Per-sample Jacobian of ``model.decoder`` w.r.t. its latent inputs and the maturity.

    The decoder is evaluated at ``cat(latents[b], tau)`` for every row ``b``, with the same
    scalar maturity ``taus`` for all rows. The Jacobian is computed with ``vmap(jacfwd(...))``.

    Args:
        taus: scalar maturity shared by all rows.
        latents: (batch, n_latent) tensor of latent states.
        model: object whose ``decoder`` maps an (n_latent + 1,) input to an (out,) output
            (out = 1 for the decoders in this notebook).

    Returns:
        grad_tau: (batch, out) derivative w.r.t. the maturity input.
        grad_z: (batch, out, n_latent) derivatives w.r.t. the latent inputs.
    """
    n_latent = latents.size(1)
    # Match latents' dtype/device. torch.full would otherwise infer int64 for integer
    # maturities and always allocate on the CPU.
    tau_column = torch.full((latents.size(0), 1), taus, dtype=latents.dtype, device=latents.device)

    batch_jacobian = vmap(jacfwd(model.decoder, argnums=0), in_dims=0)
    latent_grad = batch_jacobian(torch.cat((latents, tau_column), dim=1))  # (batch, out, n_latent + 1)

    # The last input column is the maturity; everything before it is the latent state.
    grad_z = latent_grad[:, :, :n_latent]
    grad_tau = latent_grad[:, :, n_latent]
    return grad_tau, grad_z


# construction and calculation of gradients used in arbitrage condition
def construct_grad_z3(taus, latents, model):
    """Per-sample Jacobian of ``model.decoder`` w.r.t. its latent inputs and the maturity (3-factor use).

    The decoder is evaluated at ``cat(latents[b], tau)`` for every row ``b``, with the same
    scalar maturity ``taus`` for all rows. The Jacobian is computed with ``vmap(jacfwd(...))``.

    Args:
        taus: scalar maturity shared by all rows.
        latents: (batch, n_latent) tensor of latent states (n_latent = 3 for the 3-factor models).
        model: object whose ``decoder`` maps an (n_latent + 1,) input to an (out,) output
            (out = 1 for the decoders in this notebook).

    Returns:
        grad_tau: (batch, out) derivative w.r.t. the maturity input.
        grad_z: (batch, out, n_latent) derivatives w.r.t. the latent inputs.
    """
    n_latent = latents.size(1)
    # Match latents' dtype/device. torch.full would otherwise infer int64 for integer
    # maturities and always allocate on the CPU.
    tau_column = torch.full((latents.size(0), 1), taus, dtype=latents.dtype, device=latents.device)

    batch_jacobian = vmap(jacfwd(model.decoder, argnums=0), in_dims=0)
    latent_grad = batch_jacobian(torch.cat((latents, tau_column), dim=1))  # (batch, out, n_latent + 1)

    # First n_latent columns: latent state; last column: maturity.
    grad_z = latent_grad[:, :, :n_latent]
    grad_tau = latent_grad[:, :, n_latent]
    return grad_tau, grad_z

# construction and calculation of hessian used in arbitrage condition
def construct_hessian(taus, latents, model):
    """Per-sample Hessian of ``model.decoder`` w.r.t. its latent inputs only.

    The decoder is evaluated at ``cat(latents[b], tau)`` with the same scalar maturity ``taus``
    for all rows. The full Hessian over (latents, tau) is computed with ``vmap(hessian(...))``,
    and the maturity row/column is dropped.

    Args:
        taus: scalar maturity shared by all rows.
        latents: (batch, n_latent) tensor of latent states.
        model: object whose ``decoder`` maps an (n_latent + 1,) input to a (1,) output.

    Returns:
        (batch, n_latent, n_latent) tensor of second derivatives w.r.t. the latents.
    """
    # Match latents' dtype/device. torch.full would otherwise infer int64 for integer
    # maturities and always allocate on the CPU.
    tau_column = torch.full((latents.size(0), 1), taus, dtype=latents.dtype, device=latents.device)

    batch_hessian = vmap(hessian(model.decoder, argnums=0), in_dims=0)
    full_hessian = batch_hessian(torch.cat((latents, tau_column), dim=1))  # (batch, 1, n+1, n+1)

    # Drop the maturity (last) row and column, then the singleton output dimension.
    latent_hessian = full_hessian[:, :, :-1, :-1]
    return torch.squeeze(latent_hessian, dim=1)


# Function for constructing the sigma matrix given output of network
def construct_sigma2(sigma1, sigma2, rho):
    """Build the lower-triangular 2x2 volatility matrix of every sample.

    For each sample the result is

        [[sigma1,       0                     ],
         [rho * sigma2, sqrt(1 - rho^2) * sigma2]],

    the Cholesky factor of the covariance [[s1^2, rho s1 s2], [rho s1 s2, s2^2]].

    Args:
        sigma1, sigma2, rho: tensors holding one value per sample, e.g. (batch, 1) or (batch,).

    Returns:
        (batch, 2, 2) tensor with the dtype/device of the inputs. It is built without in-place
        writes, so gradients flow to all inputs.
    """
    s1 = sigma1.reshape(-1)
    s2 = sigma2.reshape(-1)
    r = rho.reshape(-1)
    zero = torch.zeros_like(s1)

    top_row = torch.stack((s1, zero), dim=-1)
    bottom_row = torch.stack((r * s2, torch.sqrt(1 - torch.pow(r, 2)) * s2), dim=-1)
    return torch.stack((top_row, bottom_row), dim=-2)

# Function for constructing the sigma matrix given output of network
def construct_sigma3(sigma1, sigma2, sigma3, rho12, rho13, rho23, eps=1e-12):
    """Build the lower-triangular 3x3 volatility matrix of every sample.

    Row k is sigma_k times row k of the Cholesky factor of the correlation matrix
    [[1, r12, r13], [r12, 1, r23], [r13, r23, 1]]:

        a   = (r23 - r13 * r12) / sqrt(1 - r12^2)
        row0 = [s1,        0,                     0                         ]
        row1 = [r12 * s2,  sqrt(1 - r12^2) * s2,  0                         ]
        row2 = [r13 * s3,  a * s3,                sqrt(1 - r13^2 - a^2) * s3]

    Args:
        sigma1, sigma2, sigma3: volatilities, one value per sample (e.g. (batch, 1)).
        rho12, rho13, rho23: correlations in (-1, 1), one value per sample.
        eps: lower bound for the radicand of the last diagonal entry (see below).

    Returns:
        (batch, 3, 3) tensor with the dtype/device of the inputs.
    """
    s1, s2, s3 = (t.reshape(-1) for t in (sigma1, sigma2, sigma3))
    r12, r13, r23 = (t.reshape(-1) for t in (rho12, rho13, rho23))
    zero = torch.zeros_like(s1)

    sqrt_12 = torch.sqrt(1 - torch.pow(r12, 2))
    a = (r23 - r13 * r12) / sqrt_12
    # Row 3 of a Cholesky factor must have unit norm: 1 - r13^2 - a^2. (The previous code
    # computed 1 - (r13^2 - a^2), i.e. added a^2, so the implied variance of factor 3 was
    # s3^2 * (1 + 2 a^2).) The three correlations are produced independently and need not form
    # a positive-definite matrix, so the radicand is clamped to avoid NaN.
    l22 = torch.sqrt(torch.clamp(1 - torch.pow(r13, 2) - torch.pow(a, 2), min=eps))

    row0 = torch.stack((s1, zero, zero), dim=-1)
    row1 = torch.stack((r12 * s2, sqrt_12 * s2, zero), dim=-1)
    row2 = torch.stack((s3 * r13, s3 * a, s3 * l22), dim=-1)
    return torch.stack((row0, row1, row2), dim=-2)

# Function for calculation of the arbitrage condition 
def arbitrage_condition2(pi, r, partial_tau, grad_z, mu, sigma_matrix, hessian_z):
    """No-arbitrage residual of every sample (2-factor version).

        L = -r * pi - d(pi)/d(tau) + grad_z(pi) . mu + 0.5 * tr(sigma^T H sigma)

    Args:
        pi: (batch,) decoded values at one maturity.
        r: (batch,) short rate.
        partial_tau: (batch, 1) derivative of the decoder w.r.t. the maturity.
        grad_z: (batch, 1, d) derivative of the decoder w.r.t. the latent state.
        mu: (batch, d) drift.
        sigma_matrix: (batch, d, d) volatility matrix.
        hessian_z: (batch, d, d) Hessian of the decoder w.r.t. the latent state.

    Returns:
        (batch, 1) tensor of residuals.
    """
    batch = len(pi)
    carry = -r.reshape(batch) * pi.reshape(batch) - partial_tau.reshape(batch)
    drift = (grad_z.reshape(batch, -1) * mu.reshape(batch, -1)).sum(dim=1)
    # Matrix products, not element-wise '*': tr(sigma^T H sigma) = sum_jk H_jk (sigma sigma^T)_jk.
    # The previous element-wise version kept only the diagonal terms sigma_ii^2 * H_ii.
    quad_form = torch.matmul(torch.matmul(sigma_matrix.transpose(-2, -1), hessian_z), sigma_matrix)
    diffusion = 0.5 * torch.diagonal(quad_form, dim1=-2, dim2=-1).sum(dim=-1)
    return (carry + drift + diffusion).unsqueeze(1)


# Function for calculation of the arbitrage condition 
def arbitrage_condition3(pi, r, partial_tau, grad_z, mu, sigma_matrix, hessian_z):
    """No-arbitrage residual of every sample (3-factor version).

        L = -r * pi - d(pi)/d(tau) + grad_z(pi) . mu + 0.5 * tr(sigma^T H sigma)

    Args:
        pi: (batch,) decoded values at one maturity.
        r: (batch,) short rate.
        partial_tau: (batch, 1) derivative of the decoder w.r.t. the maturity.
        grad_z: (batch, 1, d) derivative of the decoder w.r.t. the latent state (d = 3).
        mu: (batch, d) drift.
        sigma_matrix: (batch, d, d) volatility matrix.
        hessian_z: (batch, d, d) Hessian of the decoder w.r.t. the latent state.

    Returns:
        (batch, 1) tensor of residuals.
    """
    batch = len(pi)
    carry = -r.reshape(batch) * pi.reshape(batch) - partial_tau.reshape(batch)
    drift = (grad_z.reshape(batch, -1) * mu.reshape(batch, -1)).sum(dim=1)
    # Matrix products, not element-wise '*': tr(sigma^T H sigma) = sum_jk H_jk (sigma sigma^T)_jk.
    # The previous element-wise version kept only the diagonal terms sigma_ii^2 * H_ii.
    quad_form = torch.matmul(torch.matmul(sigma_matrix.transpose(-2, -1), hessian_z), sigma_matrix)
    diffusion = 0.5 * torch.diagonal(quad_form, dim1=-2, dim2=-1).sum(dim=-1)
    return (carry + drift + diffusion).unsqueeze(1)

In [None]:
# Custom loss functions for networks


# Custom loss function for adding the arbitrage condition with MSE
class CustomLoss(nn.Module):
    """Swap-repricing MSE plus a weighted no-arbitrage penalty.

        loss = mean_b( sum_k (yhat - y)^2 / 8 ) + w * mean_b( sum_t L^2 / 30 )

    The divisors 8 and 30 are presumably the number of swap maturities and the number of
    arbitrage maturities; TODO confirm.
    """

    def __init__(self):
        super().__init__()

    def forward(self, yhat, y, L, w):
        """Args: yhat/y: (batch, k) predicted/target swaps; L: (batch, t) residuals; w: penalty weight."""
        squared_error = (yhat - y).pow(2)
        fit_term = (squared_error.sum(dim=1) / 8).mean()
        arbitrage_term = (L.pow(2).sum(dim=1) / 30).mean()
        return fit_term + w * arbitrage_term
    
# Custom loss function with an added Lipschitz regularization term (not used here, but was used during development)
class CustomLipshitzLoss(nn.Module):
    """Un-normalised MSE + weighted arbitrage penalty + ``beta`` times a Lipschitz regulariser.

        loss = mean_b( sum_k (yhat - y)^2 ) + mean_b( w * sum_t L^2 ) + beta * lipshitz
    """

    def __init__(self):
        super().__init__()

    def forward(self, yhat, y, L, w, lipshitz, beta):
        fit_term = (yhat - y).pow(2).sum(dim=1).mean()
        arbitrage_term = (L.pow(2).sum(dim=1) * w).mean()
        return fit_term + arbitrage_term + beta * lipshitz
        
# Custom loss function for adding the arbitrage condition with MSE and Kullback Leibler Divergence
class CustomklLoss(nn.Module):
    """Swap-repricing MSE + weighted arbitrage penalty + a small KL term for the VAE latents.

        loss = mean_b( sum_k (yhat - y)^2 / 8 )
             + 1e-7 * ( -mean_b( sum_j (1 + log(var^2) - mean^2 - var^2) ) )
             + mean_b( sum_t L^2 / 30 * w )

    ``var.pow(2)`` is treated as the variance, i.e. ``var`` is used as a standard deviation.
    """

    def __init__(self):
        super().__init__()

    def forward(self, yhat, y, L, w, mean, var):
        fit_term = ((yhat - y).pow(2).sum(dim=1) / 8).mean()
        arbitrage_term = (L.pow(2).sum(dim=1) / 30 * w).mean()
        variance = var.pow(2)
        kl_term = -(1 + torch.log(variance) - mean.pow(2) - variance).sum(dim=1).mean()
        return fit_term + 1e-7 * kl_term + arbitrage_term

In [None]:
# Class for the 2-factor Finance-Informed Auto-Encoder model. To evaluate the model without the finance-informed part,
# remove the arbitrage_loss term from the CustomLoss class. This isn't the most efficient approach, but it works and
# keeps the code uncluttered.

class FIRAutoEncoder(nn.Module):
    """2-factor finance-informed auto-encoder.

    - encoder: linear map from the input swap rates to the latent state z (``latent_dim`` values).
    - decoder: small MLP mapping ``(z, tau)`` to one value per maturity (the training code treats
      these as the zero-coupon curve).
    - drift: head producing mu (2 values) from z.
    - volatility: head producing (sigma_1, sigma_2, rho) from z.

    The 2-value drift head and the 3-way volatility split fix the model at two factors, so
    ``latent_dim`` is expected to be 2.
    """

    def __init__(self, in_features, latent_dim):
        super().__init__()
        self.name = "FIRAE"
        self.in_features = in_features
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Linear(self.in_features, self.latent_dim, bias=False)
        )
        self.volatility = nn.Sequential(
            nn_centered_softmax(),
            nn.Linear(self.latent_dim, 3, bias=False),
            nn.Linear(3, 3, bias=False),
        )
        self.drift = nn.Sequential(
            nn_centered_softmax(),
            nn.Linear(self.latent_dim, 2, bias=False),
            nn.Linear(2, 2, bias=False),
        )
        # Decoder input is the latent state plus one maturity column.
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim + 1, 10, bias=False),
            nn_centered_softmax(),
            nn.Linear(10, 1, bias=False),
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode a batch of swaps and decode it at every requested maturity.

        Args:
            x: sequence ``[swaps, maturities]``. ``swaps`` is a (batch, in_features) tensor;
               ``maturities`` is an iterable of scalar maturities.

        Returns:
            ``(curve, encoded, mu, exp(sigma_1), exp(sigma_2), tanh(rho))``, where ``curve`` is
            (batch, len(maturities)) and the sigma/rho entries are (batch, 1).
        """
        swaps, maturities = x[0], x[-1]
        encoded = self.encoder(swaps)
        # Latent state must be differentiable for the gradient/Hessian of the decoder w.r.t. z.
        encoded.requires_grad_(True)
        batch = swaps.size(0)

        decoded = []
        for tau in maturities:
            # dtype/device follow the encoding, so this also works on the GPU and for int maturities.
            taus = torch.full((batch, 1), tau, dtype=encoded.dtype, device=encoded.device).requires_grad_(True)
            decoded.append(self.decoder(torch.cat((encoded, taus), dim=1)))

        # Financial-information part.
        mu = self.drift(encoded)
        sigma_1, sigma_2, rho = torch.split(self.volatility(encoded), 1, dim=1)
        # (batch, n_maturities, 1) -> (batch, n_maturities); previously hard-coded to 31 maturities.
        curve = torch.stack(decoded, dim=1).reshape(batch, -1)

        return curve, encoded, mu, torch.exp(sigma_1), torch.exp(sigma_2), torch.tanh(rho)

In [None]:
# Training function for the 2-factor Auto-Encoder models; includes both a training step and a validation step

def _train_ae_batch_loss(model, swaps, maturities, swap_maturities, criterion, arbitrage_weight):
    """Forward one batch through a 2-factor auto-encoder and return ``(loss, repriced_swaps)``.

    The no-arbitrage residual is evaluated at the integer maturities 1..30 (columns 1..30 of the
    decoded curve); column 0 is used as the short rate.
    """
    zc_rates, latents, mu, sigma1, sigma2, rho = model([swaps, maturities])
    short_rate = zc_rates[:, 0]
    sigma_matrix = construct_sigma2(sigma1, sigma2, rho)

    residuals = []
    for tau in range(1, 31):
        grad_tau, grad_z = construct_grad_z2(tau, latents, model)
        hessian_z = construct_hessian(tau, latents, model)
        residuals.append(
            arbitrage_condition2(zc_rates[:, tau], short_rate, grad_tau, grad_z, mu, sigma_matrix, hessian_z)
        )
    # Concatenate once, after the loop (it used to be rebuilt on every iteration).
    regulizer = torch.cat(residuals, dim=1)

    repriced_swaps = torch.stack([swap_rate_torch_paper(curve, swap_maturities) for curve in zc_rates], dim=0)
    loss = criterion(repriced_swaps, swaps, regulizer, arbitrage_weight)
    return loss, repriced_swaps


def train_ae(model, data, epochs, optimizer, criterion, maturities, swap_maturities, scheduler, valloader,
             arbitrage_weight=5):
    """Train a 2-factor finance-informed auto-encoder and validate it after every epoch.

    Per batch: forward pass; no-arbitrage residual at maturities 1..30; swap rates repriced from
    the decoded curve; ``criterion(repriced, swaps, residuals, arbitrage_weight)``.

    Args:
        model: model with ``name == "FIRAE"``.
        data, valloader: iterables of swap batches. Each must yield at least one batch per epoch.
        epochs: number of epochs.
        optimizer, scheduler: optimizer and LR scheduler (stepped once per epoch).
        criterion: loss taking ``(yhat, y, L, w)``.
        maturities: maturities at which the decoder is evaluated.
        swap_maturities: maturities of the swaps being repriced.
        arbitrage_weight: weight ``w`` of the arbitrage penalty (default 5, as before).

    Returns:
        ``(model, losses, outputs, val_losses)``.
        - losses / val_losses: the last training / validation batch loss of each epoch (numpy).
        - outputs: one ``(epoch, swaps, repriced_swaps)`` tuple per epoch for the last training batch.

    Raises:
        ValueError: if ``model.name`` is not ``"FIRAE"``.
    """
    if model.name != "FIRAE":
        raise ValueError(f"train_ae expects a model named 'FIRAE', got {model.name!r}")

    losses = []
    outputs = []
    val_losses = []
    for epoch in range(epochs):
        print("epoch number:", epoch)

        # Training step. Validation of the previous epoch left the model in eval mode.
        model.train()
        acc_loss = 0.0
        n_batches = 0
        for swaps in data:
            loss, result_swaps = _train_ae_batch_loss(
                model, swaps, maturities, swap_maturities, criterion, arbitrage_weight
            )
            print(loss)
            # .item(): a tensor sum would keep every batch's autograd graph alive all epoch.
            acc_loss += loss.item()
            n_batches += 1

            optimizer.zero_grad()
            # Every batch builds a fresh graph, so retain_graph is unnecessary.
            loss.backward()
            optimizer.step()

        # Validation step: evaluate the current model on the validation set (no updates).
        model.eval()
        val_acc_loss = 0.0
        for val_swaps in valloader:
            val_loss, _ = _train_ae_batch_loss(
                model, val_swaps, maturities, swap_maturities, criterion, arbitrage_weight
            )
            val_acc_loss += val_loss.item()
            print("val loss:", val_acc_loss)
        scheduler.step()

        print("acc loss:", acc_loss / max(n_batches, 1))
        losses.append(loss.detach().cpu().numpy())
        val_losses.append(val_loss.detach().cpu().numpy())
        # Record the current epoch index (previously the constant total `epochs`).
        outputs.append((epoch, swaps.detach().cpu().numpy(), result_swaps.detach().cpu().numpy()))

    return model, losses, outputs, val_losses

In [None]:
# Class for the 3-factor Finance-Informed Auto-Encoder model. To evaluate the model without the finance-informed part,
# remove the arbitrage_loss term from the CustomLoss class. This isn't the most efficient approach, but it works and
# keeps the code uncluttered.

class FI3AutoEncoder(nn.Module):
    """3-factor finance-informed auto-encoder.

    - encoder: linear map from the input swap rates to the latent state z (``latent_dim`` values).
    - decoder: small MLP mapping ``(z, tau)`` to one value per maturity (the training code treats
      these as the zero-coupon curve).
    - drift: head producing mu (3 values) from z.
    - volatility: head producing (sigma_1, sigma_2, sigma_3, rho12, rho13, rho23) from z.

    ``latent_dim`` is expected to be 3.
    """

    def __init__(self, in_features, latent_dim):
        super().__init__()
        self.name = "FI3AE"
        self.in_features = in_features
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Linear(self.in_features, self.latent_dim, bias=False)
        )
        self.volatility = nn.Sequential(
            nn_centered_softmax(),
            nn.Linear(self.latent_dim, 6, bias=False),
            nn.Linear(6, 6, bias=False),
        )
        self.drift = nn.Sequential(
            nn_centered_softmax(),
            nn.Linear(self.latent_dim, 3, bias=False),
            nn.Linear(3, 3, bias=False),
        )
        # Decoder input is the latent state plus one maturity column.
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim + 1, 10, bias=False),
            nn_centered_softmax(),
            nn.Linear(10, 1, bias=False),
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode a batch of swaps and decode it at every requested maturity.

        Args:
            x: sequence ``[swaps, maturities]``. ``swaps`` is a (batch, in_features) tensor;
               ``maturities`` is an iterable of scalar maturities.

        Returns:
            ``(curve, encoded, mu, exp(sigma_1), exp(sigma_2), exp(sigma_3), tanh(rho12),
            tanh(rho13), tanh(rho23))``, where ``curve`` is (batch, len(maturities)).
        """
        swaps, maturities = x[0], x[-1]
        encoded = self.encoder(swaps)
        # Latent state must be differentiable for the gradient/Hessian of the decoder w.r.t. z.
        encoded.requires_grad_(True)
        batch = swaps.size(0)

        decoded = []
        for tau in maturities:
            # dtype/device follow the encoding, so this also works on the GPU and for int maturities.
            taus = torch.full((batch, 1), tau, dtype=encoded.dtype, device=encoded.device).requires_grad_(True)
            decoded.append(self.decoder(torch.cat((encoded, taus), dim=1)))

        # Financial-information part.
        mu = self.drift(encoded)
        sigma_1, sigma_2, sigma_3, rho12, rho13, rho23 = torch.split(self.volatility(encoded), 1, dim=1)
        # (batch, n_maturities, 1) -> (batch, n_maturities); previously hard-coded to 31 maturities.
        curve = torch.stack(decoded, dim=1).reshape(batch, -1)

        return (curve, encoded, mu, torch.exp(sigma_1), torch.exp(sigma_2), torch.exp(sigma_3),
                torch.tanh(rho12), torch.tanh(rho13), torch.tanh(rho23))

In [None]:
# Training function for the 3-factor Auto-Encoder models; includes both a training step and a validation step

def _train_ae3_batch_loss(model, swaps, maturities, swap_maturities, criterion, arbitrage_weight):
    """Forward one batch through a 3-factor auto-encoder and return ``(loss, repriced_swaps)``.

    The no-arbitrage residual is evaluated at the integer maturities 1..30 (columns 1..30 of the
    decoded curve); column 0 is used as the short rate.
    """
    zc_rates, latents, mu, sigma1, sigma2, sigma3, rho12, rho13, rho23 = model([swaps, maturities])
    short_rate = zc_rates[:, 0]
    sigma_matrix = construct_sigma3(sigma1, sigma2, sigma3, rho12, rho13, rho23)

    residuals = []
    for tau in range(1, 31):
        grad_tau, grad_z = construct_grad_z3(tau, latents, model)
        hessian_z = construct_hessian(tau, latents, model)
        residuals.append(
            arbitrage_condition3(zc_rates[:, tau], short_rate, grad_tau, grad_z, mu, sigma_matrix, hessian_z)
        )
    # Concatenate once, after the loop (it used to be rebuilt on every iteration).
    regulizer = torch.cat(residuals, dim=1)

    # The stacked swaps already require grad through zc_rates; no requires_grad_ needed.
    repriced_swaps = torch.stack([swap_rate_torch_paper(curve, swap_maturities) for curve in zc_rates], dim=0)
    loss = criterion(repriced_swaps, swaps, regulizer, arbitrage_weight)
    return loss, repriced_swaps


def train_ae3(model, data, epochs, optimizer, criterion, maturities, swap_maturities, scheduler, valloader,
              arbitrage_weight=5):
    """Train a 3-factor finance-informed auto-encoder and validate it after every epoch.

    Per batch: forward pass; no-arbitrage residual at maturities 1..30; swap rates repriced from
    the decoded curve; ``criterion(repriced, swaps, residuals, arbitrage_weight)``.

    Args:
        model: model with ``name == "FI3AE"``.
        data, valloader: iterables of swap batches. Each must yield at least one batch per epoch.
        epochs: number of epochs.
        optimizer, scheduler: optimizer and LR scheduler (stepped once per epoch).
        criterion: loss taking ``(yhat, y, L, w)``.
        maturities: maturities at which the decoder is evaluated.
        swap_maturities: maturities of the swaps being repriced.
        arbitrage_weight: weight ``w`` of the arbitrage penalty (default 5, as before).

    Returns:
        ``(model, losses, outputs, val_losses)``.
        - losses / val_losses: the last training / validation batch loss of each epoch (numpy).
        - outputs: one ``(epoch, swaps, repriced_swaps)`` tuple per epoch for the last training batch.

    Raises:
        ValueError: if ``model.name`` is not ``"FI3AE"``.
    """
    if model.name != "FI3AE":
        raise ValueError(f"train_ae3 expects a model named 'FI3AE', got {model.name!r}")

    losses = []
    outputs = []
    val_losses = []
    for epoch in range(epochs):
        print("epoch number:", epoch)

        # Training step. Validation of the previous epoch left the model in eval mode.
        model.train()
        acc_loss = 0.0
        n_batches = 0
        for swaps in data:
            loss, result_swaps = _train_ae3_batch_loss(
                model, swaps, maturities, swap_maturities, criterion, arbitrage_weight
            )
            print(loss)
            # .item(): a tensor sum would keep every batch's autograd graph alive all epoch.
            acc_loss += loss.item()
            n_batches += 1

            optimizer.zero_grad()
            # Every batch builds a fresh graph, so retain_graph is unnecessary.
            loss.backward()
            optimizer.step()

        # Validation step: evaluate the current model on the validation set (no updates).
        model.eval()
        val_acc_loss = 0.0
        for val_swaps in valloader:
            val_loss, _ = _train_ae3_batch_loss(
                model, val_swaps, maturities, swap_maturities, criterion, arbitrage_weight
            )
            val_acc_loss += val_loss.item()
            print("val loss:", val_acc_loss)
        scheduler.step()

        print("acc loss:", acc_loss / max(n_batches, 1))
        losses.append(loss.detach().cpu().numpy())
        val_losses.append(val_loss.detach().cpu().numpy())
        # Record the current epoch index (previously the constant total `epochs`).
        outputs.append((epoch, swaps.detach().cpu().numpy(), result_swaps.detach().cpu().numpy()))

    return model, losses, outputs, val_losses

In [None]:
# Class for the 2-factor Finance-Informed Variational Auto-Encoder model. To evaluate the model without the finance-informed
# part, remove the arbitrage_loss term from the loss class (CustomklLoss for the variational models). This isn't the most
# efficient approach, but it works and keeps the code uncluttered.

class FirVariationalAutoEncoder(nn.Module):
    """2-factor finance-informed variational auto-encoder.

    - A shared encoder feeds two heads (``encoder1``, ``encoder2``). Each head outputs
      ``(mean, log_std)`` of one latent factor, so ``latent_dim`` must be 2 for the unpacking to work.
    - The latent state ``z = (z1, z2)`` is sampled with the reparameterisation trick.
    - The decoder maps ``(z, tau)`` to one value per maturity.
    - The drift/volatility heads produce mu and (sigma_1, sigma_2, rho) from z.
    """

    def __init__(self, in_features, latent_dim):
        super().__init__()
        self.name = "FIRVAE"

        self.in_features = in_features
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(self.in_features, 5, bias=False),
            nn.Sigmoid(),
            nn.Linear(5, 5, bias=False),
        )
        self.encoder1 = nn.Sequential(
            nn.Linear(5, self.latent_dim, bias=False)
        )
        self.encoder2 = nn.Sequential(
            nn.Linear(5, self.latent_dim, bias=False)
        )
        # Decoder input: the 2 sampled latent factors plus one maturity column.
        self.decoder = nn.Sequential(
            nn.Linear(3, 10, bias=False),
            nn_centered_softmax(),
            nn.Linear(10, 1, bias=False),
        )
        self.volatility = nn.Sequential(
            nn_centered_softmax(),
            nn.Linear(self.latent_dim, 3, bias=False),
            nn.Linear(3, 3, bias=False),
        )
        self.drift = nn.Sequential(
            nn_centered_softmax(),
            nn.Linear(self.latent_dim, 2, bias=False),
            nn.Linear(2, 2, bias=False),
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode, sample the latent state, and decode at every requested maturity.

        Args:
            x: sequence ``[swaps, maturities]``. ``swaps`` is a (batch, in_features) tensor;
               ``maturities`` is an iterable of scalar maturities.

        Returns:
            ``(curve, z, mu, exp(sigma_1), exp(sigma_2), tanh(rho), var, mean)``.
            - ``curve`` is (batch, len(maturities)).
            - ``var`` holds the per-factor standard deviations (exp of the log-std outputs).
        """
        swaps, maturities = x[0], x[-1]
        hidden = self.encoder(swaps)

        mean1, var1 = torch.split(self.encoder1(hidden), 1, dim=1)
        mean2, var2 = torch.split(self.encoder2(hidden), 1, dim=1)

        # Despite the name these are standard deviations: they scale unit-normal noise below.
        var1 = torch.exp(var1)
        var2 = torch.exp(var2)

        # Reparameterisation trick. randn_like draws on mean's device/dtype; the previous
        # Normal(0, 1).sample always produced CPU tensors of the default dtype.
        z1 = mean1 + var1 * torch.randn_like(mean1)
        z2 = mean2 + var2 * torch.randn_like(mean2)

        z = torch.cat((z1, z2), dim=1)
        mean = torch.cat((mean1, mean2), dim=1)
        var = torch.cat((var1, var2), dim=1)

        z.requires_grad_(True)
        batch = hidden.size(0)
        decoded = []
        for tau in maturities:
            taus = torch.full((batch, 1), tau, dtype=z.dtype, device=z.device)
            decoded.append(self.decoder(torch.cat((z, taus), dim=1)))

        # Financial-information part.
        mu = self.drift(z)
        sigma_1, sigma_2, rho = torch.split(self.volatility(z), 1, dim=1)
        # (batch, n_maturities, 1) -> (batch, n_maturities); previously hard-coded to 31 maturities.
        curve = torch.stack(decoded, dim=1).reshape(batch, -1)
        return curve, z, mu, torch.exp(sigma_1), torch.exp(sigma_2), torch.tanh(rho), var, mean

In [None]:
# Training function for the 2-factor Variational Auto-Encoder models; includes both a training step and a validation step


def _train_vae_batch_loss(model, swaps, maturities, swap_maturities, criterion, arbitrage_weight):
    """Forward one batch through a 2-factor VAE and return ``(loss, repriced_swaps)``.

    The no-arbitrage residual is evaluated at the integer maturities 1..30 (columns 1..30 of the
    decoded curve); column 0 is used as the short rate.
    """
    zc_rates, latents, mu, sigma1, sigma2, rho, var, mean = model([swaps, maturities])
    short_rate = zc_rates[:, 0]
    sigma_matrix = construct_sigma2(sigma1, sigma2, rho)

    residuals = []
    for tau in range(1, 31):
        grad_tau, grad_z = construct_grad_z2(tau, latents, model)
        hessian_z = construct_hessian(tau, latents, model)
        residuals.append(
            arbitrage_condition2(zc_rates[:, tau], short_rate, grad_tau, grad_z, mu, sigma_matrix, hessian_z)
        )
    # Concatenate once, after the loop (it used to be rebuilt on every iteration).
    regulizer = torch.cat(residuals, dim=1)

    repriced_swaps = torch.stack([swap_rate_torch_paper(curve, swap_maturities) for curve in zc_rates], dim=0)
    loss = criterion(repriced_swaps, swaps, regulizer, arbitrage_weight, mean, var)
    return loss, repriced_swaps


def train_vae(model, data, epochs, optimizer, criterion, maturities, swap_maturities, scheduler, valloader,
              arbitrage_weight=1):
    """Train a 2-factor finance-informed VAE and validate it after every epoch.

    Per batch: forward pass; no-arbitrage residual at maturities 1..30; swap rates repriced from
    the decoded curve; ``criterion(repriced, swaps, residuals, arbitrage_weight, mean, var)``.
    NOTE(review): the variational forward pass samples noise regardless of train/eval mode, so
    validation losses are stochastic.

    Args:
        model: model with ``name == "FIRVAE"``.
        data, valloader: iterables of swap batches. Each must yield at least one batch per epoch.
        epochs: number of epochs.
        optimizer, scheduler: optimizer and LR scheduler (stepped once per epoch).
        criterion: loss taking ``(yhat, y, L, w, mean, var)``.
        maturities: maturities at which the decoder is evaluated.
        swap_maturities: maturities of the swaps being repriced.
        arbitrage_weight: weight ``w`` of the arbitrage penalty (default 1, as before).

    Returns:
        ``(model, losses, outputs, val_losses)``.
        - losses / val_losses: the last training / validation batch loss of each epoch (numpy).
        - outputs: one ``(epoch, swaps, repriced_swaps)`` tuple per epoch for the last training batch.

    Raises:
        ValueError: if ``model.name`` is not ``"FIRVAE"``.
    """
    if model.name != "FIRVAE":
        raise ValueError(f"train_vae expects a model named 'FIRVAE', got {model.name!r}")

    losses = []
    outputs = []
    val_losses = []
    for epoch in range(epochs):
        print("epoch number:", epoch)

        # Training step. Validation of the previous epoch left the model in eval mode.
        model.train()
        acc_loss = 0.0
        n_batches = 0
        for swaps in data:
            loss, result_swaps = _train_vae_batch_loss(
                model, swaps, maturities, swap_maturities, criterion, arbitrage_weight
            )
            print(loss)
            # .item(): a tensor sum would keep every batch's autograd graph alive all epoch.
            acc_loss += loss.item()
            n_batches += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation step: evaluate the current model on the validation set (no updates).
        model.eval()
        val_acc_loss = 0.0
        for val_swaps in valloader:
            val_loss, _ = _train_vae_batch_loss(
                model, val_swaps, maturities, swap_maturities, criterion, arbitrage_weight
            )
            val_acc_loss += val_loss.item()
            print("val loss:", val_acc_loss)
        scheduler.step()

        print("acc loss:", acc_loss / max(n_batches, 1))
        losses.append(loss.detach().cpu().numpy())
        val_losses.append(val_loss.detach().cpu().numpy())
        # Record the current epoch index (previously the constant total `epochs`).
        outputs.append((epoch, swaps.detach().cpu().numpy(), result_swaps.detach().cpu().numpy()))

    return model, losses, outputs, val_losses

In [None]:
# Class for the 3-factor Finance-Informed Variational Auto-Encoder model. To evaluate the model without the finance-informed
# part, remove the arbitrage_loss term from the loss class (CustomklLoss for the variational models). This isn't the most
# efficient approach, but it works and keeps the code uncluttered.


class Fi3VariationalAutoEncoder(nn.Module):
    """3-factor finance-informed variational auto-encoder.

    - A shared encoder feeds three heads (``encoder1..3``). Each head outputs
      ``(mean, log_std)`` of one latent factor.
    - The latent state ``z = (z1, z2, z3)`` is sampled with the reparameterisation trick.
    - The decoder maps ``(z, tau)`` to one value per maturity.
    - The drift/volatility heads produce mu and (sigma_1..3, rho12, rho13, rho23) from z.

    The heads consume the 3-dimensional z, so ``latent_dim`` is expected to be 3.
    """

    def __init__(self, in_features, latent_dim):
        super().__init__()
        self.name = "FI3VAE"

        self.in_features = in_features
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Linear(self.in_features, 5),
            nn.Sigmoid(),
            nn.Linear(5, 5),
        )
        self.encoder1 = nn.Sequential(
            nn.Linear(5, 2, bias=False)
        )
        self.encoder2 = nn.Sequential(
            nn.Linear(5, 2, bias=False)
        )
        self.encoder3 = nn.Sequential(
            nn.Linear(5, 2, bias=False)
        )
        # Decoder input: the 3 sampled latent factors plus one maturity column.
        self.decoder = nn.Sequential(
            nn.Linear(4, 10, bias=False),
            nn_centered_softmax(),
            nn.Linear(10, 1, bias=False),
        )
        self.volatility = nn.Sequential(
            nn_centered_softmax(),
            nn.Linear(self.latent_dim, 6, bias=False),
            nn.Linear(6, 6, bias=False),
        )
        self.drift = nn.Sequential(
            nn_centered_softmax(),
            nn.Linear(self.latent_dim, 3, bias=False),
            nn.Linear(3, 3),
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Encode, sample the latent state, and decode at every requested maturity.

        Args:
            x: sequence ``[swaps, maturities]``. ``swaps`` is a (batch, in_features) tensor;
               ``maturities`` is an iterable of scalar maturities.

        Returns:
            ``(curve, z, mu, exp(sigma_1), exp(sigma_2), exp(sigma_3), tanh(rho12), tanh(rho13),
            tanh(rho23), var, mean)``.
            - ``curve`` is (batch, len(maturities)).
            - ``var`` holds the per-factor standard deviations.
        """
        swaps, maturities = x[0], x[-1]
        hidden = self.encoder(swaps)

        mean1, var1 = torch.split(self.encoder1(hidden), 1, dim=1)
        mean2, var2 = torch.split(self.encoder2(hidden), 1, dim=1)
        mean3, var3 = torch.split(self.encoder3(hidden), 1, dim=1)

        # Despite the name these are standard deviations: they scale unit-normal noise below.
        var1 = torch.exp(var1)
        var2 = torch.exp(var2)
        var3 = torch.exp(var3)

        # Reparameterisation trick. randn_like draws on mean's device/dtype; the previous
        # Normal(0, 1).sample always produced CPU tensors of the default dtype.
        z1 = mean1 + var1 * torch.randn_like(mean1)
        z2 = mean2 + var2 * torch.randn_like(mean2)
        z3 = mean3 + var3 * torch.randn_like(mean3)

        z = torch.cat((z1, z2, z3), dim=1)
        mean = torch.cat((mean1, mean2, mean3), dim=1)
        var = torch.cat((var1, var2, var3), dim=1)

        z.requires_grad_(True)
        batch = hidden.size(0)
        decoded = []
        for tau in maturities:
            taus = torch.full((batch, 1), tau, dtype=z.dtype, device=z.device)
            decoded.append(self.decoder(torch.cat((z, taus), dim=1)))

        # Financial-information part.
        mu = self.drift(z)
        sigma_1, sigma_2, sigma_3, rho12, rho13, rho23 = torch.split(self.volatility(z), 1, dim=1)
        # (batch, n_maturities, 1) -> (batch, n_maturities); previously hard-coded to 31 maturities.
        curve = torch.stack(decoded, dim=1).reshape(batch, -1)

        return (curve, z, mu, torch.exp(sigma_1), torch.exp(sigma_2), torch.exp(sigma_3),
                torch.tanh(rho12), torch.tanh(rho13), torch.tanh(rho23), var, mean)

In [None]:
# Training function for the 3-factor Variational Auto-Encoder models; includes both a training step and a validation step


def _train_vae3_batch_loss(model, swaps, maturities, swap_maturities, criterion, arbitrage_weight):
    """Forward one batch through a 3-factor VAE and return ``(loss, repriced_swaps)``.

    The no-arbitrage residual is evaluated at the integer maturities 1..30 (columns 1..30 of the
    decoded curve); column 0 is used as the short rate.
    """
    zc_rates, latents, mu, sigma1, sigma2, sigma3, rho12, rho13, rho23, var, mean = model([swaps, maturities])
    short_rate = zc_rates[:, 0]
    sigma_matrix = construct_sigma3(sigma1, sigma2, sigma3, rho12, rho13, rho23)

    residuals = []
    for tau in range(1, 31):
        grad_tau, grad_z = construct_grad_z3(tau, latents, model)
        hessian_z = construct_hessian(tau, latents, model)
        residuals.append(
            arbitrage_condition3(zc_rates[:, tau], short_rate, grad_tau, grad_z, mu, sigma_matrix, hessian_z)
        )
    # Concatenate once, after the loop (it used to be rebuilt on every iteration).
    regulizer = torch.cat(residuals, dim=1)

    repriced_swaps = torch.stack([swap_rate_torch_paper(curve, swap_maturities) for curve in zc_rates], dim=0)
    loss = criterion(repriced_swaps, swaps, regulizer, arbitrage_weight, mean, var)
    return loss, repriced_swaps


def train_vae3(model, data, epochs, optimizer, criterion, maturities, swap_maturities, scheduler, valloader,
               arbitrage_weight=1):
    """Train a 3-factor finance-informed VAE and validate it after every epoch.

    Per batch: forward pass; no-arbitrage residual at maturities 1..30; swap rates repriced from
    the decoded curve; ``criterion(repriced, swaps, residuals, arbitrage_weight, mean, var)``.
    NOTE(review): the variational forward pass samples noise regardless of train/eval mode, so
    validation losses are stochastic.

    Args:
        model: model with ``name == "FI3VAE"``.
        data, valloader: iterables of swap batches. Each must yield at least one batch per epoch.
        epochs: number of epochs.
        optimizer, scheduler: optimizer and LR scheduler (stepped once per epoch).
        criterion: loss taking ``(yhat, y, L, w, mean, var)``.
        maturities: maturities at which the decoder is evaluated.
        swap_maturities: maturities of the swaps being repriced.
        arbitrage_weight: weight ``w`` of the arbitrage penalty (default 1, as before).

    Returns:
        ``(model, losses, outputs, val_losses)``.
        - losses / val_losses: the last training / validation batch loss of each epoch (numpy).
        - outputs: one ``(epoch, swaps, repriced_swaps)`` tuple per epoch for the last training batch.

    Raises:
        ValueError: if ``model.name`` is not ``"FI3VAE"``.
    """
    if model.name != "FI3VAE":
        raise ValueError(f"train_vae3 expects a model named 'FI3VAE', got {model.name!r}")

    losses = []
    outputs = []
    val_losses = []
    for epoch in range(epochs):
        print("epoch number:", epoch)

        # Training step. Validation of the previous epoch left the model in eval mode.
        model.train()
        acc_loss = 0.0
        n_batches = 0
        for swaps in data:
            loss, result_swaps = _train_vae3_batch_loss(
                model, swaps, maturities, swap_maturities, criterion, arbitrage_weight
            )
            print(loss)
            # .item(): a tensor sum would keep every batch's autograd graph alive all epoch.
            acc_loss += loss.item()
            n_batches += 1

            optimizer.zero_grad()
            # Every batch builds a fresh graph, so retain_graph is unnecessary.
            loss.backward()
            optimizer.step()

        # Validation step: evaluate the current model on the validation set (no updates).
        model.eval()
        val_acc_loss = 0.0
        for val_swaps in valloader:
            val_loss, _ = _train_vae3_batch_loss(
                model, val_swaps, maturities, swap_maturities, criterion, arbitrage_weight
            )
            val_acc_loss += val_loss.item()
            print("val loss:", val_acc_loss)
        scheduler.step()

        print("acc loss:", acc_loss / max(n_batches, 1))
        losses.append(loss.detach().cpu().numpy())
        val_losses.append(val_loss.detach().cpu().numpy())
        # Record the current epoch index (previously the constant total `epochs`).
        outputs.append((epoch, swaps.detach().cpu().numpy(), result_swaps.detach().cpu().numpy()))

    return model, losses, outputs, val_losses