<a href="https://colab.research.google.com/github/ashnvael/FProject_Team26/blob/main/OTC_combined.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Sobolev Training of the NN Solver for Both PDEs in General Equilibrium

In [None]:
from google.colab import drive
import os

# Mount Google Drive inside the Colab VM (remounting if it is already mounted).
drive.mount(mountpoint = "/content/drive", force_remount = True)

# Switch to the project folder using an absolute path. A relative path only
# works on the first execution: once the working directory has moved, re-running
# the cell would look for "drive/MyDrive/OTC_risk" inside the project folder.
os.chdir("/content/drive/MyDrive/OTC_risk")

#### code on local machine starts here

In [None]:
# Experiment identifier; every output location below is derived from it.
import os
model_name = "OTC_combined"

# Checkpoint files: the pretraining checkpoint and the final V / y networks.
model_load_path = f"./output/{model_name}/pretrained_model.pt"
model_V_save_path = f"./output/{model_name}/trained_model_V.pt"
model_y_save_path = f"./output/{model_name}/trained_model_y.pt"

# False -> the pretraining cell trains from scratch; True -> it loads a checkpoint.
loading_saved_model = False

# Folder (with trailing slash) that receives the saved figures.
fig_path = f"./fig/{model_name}/"

# Create the target folders up front so later saves do not hit a missing directory.
for _target in (model_V_save_path, model_y_save_path, fig_path):
    os.makedirs(os.path.dirname(_target), exist_ok=True)



In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from scipy.interpolate import RegularGridInterpolator
from util import alpha_fn
import matplotlib.pyplot as plt
from typing import Union, Optional

import numpy as np
from PDE_nn import initialize_grids, initial_guess
from PDE_nn import solve_y_pde, get_agg # functions to get aggregates and y on grid space
from config import CompetitiveSearch
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [None]:
# Build the shared model-parameter object from the project's CompetitiveSearch config.
param = CompetitiveSearch()
# Build the four state grids from those parameters (unpacked in the order W, S, x, s).
W_grid, S_grid, x_grid, s_grid = initialize_grids(param)

In [None]:
# Override sigma on the shared parameter object. This runs before `Phase` is
# defined, so the default `sigma` argument of `Phase.__init__` picks up this value.
param.sigma = 0.01

In [None]:
class Phase:
    """Settings for one training phase.

    Attributes:
        lr: learning rate handed to the optimizer created for the phase.
        epochs: number of epochs the phase runs.
        sigma: value written to ``param.sigma`` while the phase runs.
        mode: which network is updated ("y", "V" or "both").
    """

    def __init__(self, lr: float,
                 epochs: int,
                 sigma: float = param.sigma,
                 mode: str = "both"):
        # NOTE: the `sigma` default is read from `param` once, when the class
        # is defined; later changes to `param.sigma` do not alter it.
        self.lr, self.epochs = lr, epochs
        self.sigma, self.mode = sigma, mode

# Training schedule: 18 rounds, each a "y" phase followed by a "V" phase
# (36 phases in total), all sharing the same learning rate, epoch count and sigma.
PHASES_LIST = [
    Phase(lr = 1e-5, epochs = 100, sigma = param.sigma, mode = phase_mode)
    for _ in range(18)
    for phase_mode in ("y", "V")
]

In [None]:
class y_net(nn.Module):
    """Residual MLP approximating y as a function of (x, s).

    Parameters:
        param: parameter object; `compute_pertrubation_soln` reads `psi`,
            `sigma`, `gamma1`, `gamma2`, `nu_tilde` and `q_star` from it.
        nn_width: width of every hidden layer.
        nn_num_layers: number of residual hidden layers after the input layer.
        residual_learning: when True the network output is added to the
            closed-form perturbation solution, so the network only learns the
            correction. Defaults to False, which reproduces the plain network
            output (and skips the otherwise unused perturbation computation).
    """

    def __init__(self, param, nn_width=50, nn_num_layers=3, residual_learning=False):
        super(y_net, self).__init__()

        self.input_dim = 2
        self.output_dim = 1
        self.param = param
        self.residual_learning = residual_learning

        # Define input layer
        self.input_layer = nn.Linear(self.input_dim, nn_width)
        self.activation = nn.GELU()

        # Define hidden layers
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(nn_width, nn_width) for _ in range(nn_num_layers)]
        )

        # Define output layer
        self.output_layer = nn.Linear(nn_width, self.output_dim)

    def compute_pertrubation_soln(self, input):
        """Return the closed-form guess 1 / q_star + r1 + pi1 as a (B, 1) tensor.

        `input` is indexed as input[:, 0] = x and input[:, 1] = s.
        """
        x, s = input[:, 0], input[:, 1]

        r1 = -0.5 * (1 + 1 / self.param.psi) * self.param.sigma**2 * (
            self.param.gamma1 * x * (s / x)**2 + self.param.gamma2 * (1 - x) * ((1 - s) / (1 - x))**2
        )
        pi1 = (self.param.nu_tilde * self.param.gamma1 * (s / x) +
               (1 - self.param.nu_tilde) * self.param.gamma2 * (1 - s) / (1 - x)) * self.param.sigma**2
        soln = 1 / self.param.q_star + r1 + pi1
        return soln.view(-1, 1)

    def forward(self, input):
        """Evaluate the network on a (B, 2) batch of (x, s) and return a (B, 1) tensor."""
        # Forward pass with implicit residual connections
        x = self.activation(self.input_layer(input))
        for layer in self.hidden_layers:
            x = x + self.activation(layer(x))

        out = self.output_layer(x)

        # Residual learning: only evaluate the closed-form part when it is used.
        if self.residual_learning:
            out = out + self.compute_pertrubation_soln(input)

        return out

class V_net(nn.Module):
    """Residual MLP approximating V as a function of (W, S, x, s).

    Parameters:
        param: parameter object; the state bounds (`W_min`, `W_max`, `S_min`, ...)
            are read for input scaling, and `A`, `gamma`, `psi`, `sigma`,
            `gamma1`, `gamma2`, `nu_tilde`, `rho_star` for the perturbation solution.
        nn_width: width of every hidden layer.
        nn_num_layers: number of residual hidden layers after the input layer.
        residual_learning: when True the network output is added to the
            closed-form perturbation solution. Defaults to False, which
            reproduces the plain network output (and skips the otherwise
            unused perturbation computation).
    """

    def __init__(self, param, nn_width=50, nn_num_layers=3, residual_learning=False):
        super(V_net, self).__init__()

        self.input_dim = 4
        self.output_dim = 1
        self.param = param
        self.residual_learning = residual_learning

        # Define input layer
        self.input_layer = nn.Linear(self.input_dim, nn_width)
        self.activation = nn.GELU()

        # Define hidden layers (without using nn.Sequential)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(nn_width, nn_width) for _ in range(nn_num_layers)]
        )

        # Define output layer
        self.output_layer = nn.Linear(nn_width, self.output_dim)
        # Weight of the first-order term in the perturbation solution.
        self.epsilon = 0.05

        # Scaling parameters, registered as buffers so they follow the module
        # across devices and are stored in the state dict.
        # NOTE(review): the W column is fed as W ** (1 - gamma); the offset
        # 1 / W_max and the scale 1 / W_min presumably assume gamma = 2 - confirm.
        W_scale = 1 / self.param.W_min
        S_scale = self.param.S_max - self.param.S_min
        x_scale = self.param.x_max - self.param.x_min
        s_scale = self.param.s_max - self.param.s_min
        scale_tensor = torch.tensor([W_scale, S_scale, x_scale, s_scale])
        self.register_buffer('scale_tensor', scale_tensor)

        subtr_tensor = torch.tensor([1/self.param.W_max, self.param.S_min, self.param.x_min, self.param.s_min])
        self.register_buffer('subtr_tensor', subtr_tensor)

    def compute_pertrubation_soln(self, input):
        """Return the closed-form guess V0 + epsilon * V1 as a (B, 1) tensor.

        `input` columns are read in the order W, S, x, s.
        """
        W, S, x, s = input[:, 0], input[:, 1], input[:, 2], input[:, 3]
        V0 = self.param.A * W ** (1 - self.param.gamma) / (1 - self.param.gamma)
        r1 = -0.5 * (1 + 1 / self.param.psi) * self.param.sigma**2 * (
            self.param.gamma1 * x * (s / x)**2 + self.param.gamma2 * (1 - x) * ((1 - s) / (1 - x))**2
        )
        pi1 = (self.param.nu_tilde * self.param.gamma1 * (s / x) +
               (1 - self.param.nu_tilde) * self.param.gamma2 * (1 - s) / (1 - x)) * self.param.sigma**2

        V1 = (self.param.A * W ** (1 - self.param.gamma) / self.param.rho_star) * (
            r1 + pi1 * S / (self.param.rho_star * W) - 0.5 * self.param.gamma * self.param.sigma**2 * (S / (self.param.rho_star * W))**2
        )
        soln = V0 + self.epsilon * V1

        return soln.view(-1, 1)

    def forward(self, input):
        """Evaluate the network on a (B, 4) batch of (W, S, x, s) and return a (B, 1) tensor."""
        # Work on a copy so the caller's tensor is not modified in place.
        input_transformed = input.clone()
        input_transformed[:, 0] = torch.pow(input[:, 0], 1 - self.param.gamma)
        input_transformed = (input_transformed - self.subtr_tensor) / self.scale_tensor

        # Forward pass with implicit residual connections
        x = self.activation(self.input_layer(input_transformed))  # Initial transformation
        for layer in self.hidden_layers:
            x = x + self.activation(layer(x))  # Residual connection

        out = self.output_layer(x)

        # Residual learning: only evaluate the closed-form part when it is used.
        if self.residual_learning:
            out = out + self.compute_pertrubation_soln(input)

        return out

In [None]:
from aggregate_fast import *
from itertools import chain

class Train_NN:
    def __init__(self, param, input_dim, device=DEVICE):
        """Build both networks, their pretraining optimizers and the initial guess.

        Parameters:
            param: shared parameter object; its `sigma` is overwritten with the
                sigma of the first training phase.
            input_dim: accepted for backward compatibility; not used here (the
                networks fix their own input sizes).
            device: torch device that holds the networks and the stored initial guess.
        """
        self.param = param
        self.device = device
        self.model_V = V_net(param, nn_width = 64, nn_num_layers = 5).to(self.device)
        self.model_y = y_net(param = param, nn_width = 64, nn_num_layers = 5).to(self.device)
        self.phases: list[Phase] = PHASES_LIST
        self.pretrain_losses_V = []
        self.pretrain_losses_y = []
        self.train_losses = []
        self.optimizer_y = optim.Adam(self.model_y.parameters(), lr=1e-4) ## ONLY FOR PRETRAINING
        self.optimizer_V = optim.Adam(self.model_V.parameters(), lr=1e-4) ## ONLY FOR PRETRAINING
        self.agg = Aggregates(V_model = self.model_V, y_model = self.model_y,
                              param = self.param, device = self.device)

        ## TODO: make it more explicit and general - pretraining should align with phase 0
        self.param.sigma = self.phases[0].sigma

        # Initialize grids and the grid-based initial guess
        W_grid, S_grid, x_grid, s_grid = initialize_grids(param)
        self.x_grid, self.s_grid = x_grid, s_grid
        _, self.V_init, y_init = initial_guess(W_grid, S_grid, x_grid, s_grid, param)
        # Keep the stored guess on the same device as the networks (the `device`
        # argument) instead of the module-level default device.
        self.y = torch.tensor(y_init, device = self.device)


    def update_training_mode(self, mode: str) -> None:

        if mode == "y" or mode == "train_y":
            self.model_V.eval()
            self.model_y.train()
            for param in self.model_V.parameters():
                param.requires_grad = False
            for param in self.model_y.parameters():
                param.requires_grad = True
        elif mode == "V" or mode == "train_V":
            self.model_V.train()
            self.model_y.eval()
            for param in self.model_V.parameters():
                param.requires_grad = True
            for param in self.model_y.parameters():
                param.requires_grad = False

        elif mode == "both" or mode is None:
            self.model_V.eval()
            self.model_y.eval()
            for param in self.model_V.parameters():
                param.requires_grad = True
            for param in self.model_y.parameters():
                param.requires_grad = True
        else:
            raise ValueError(f"Unknown training mode {mode} (expected 'V', 'y', 'both')")


    def compute_initial_guess_y(self, x, s):
        """
        Compute the initial guess for \( y(x, s) \) using the given formula.
        """
        # Parameters
        q_star = self.param.q_star
        sigma = self.param.sigma
        psi = self.param.psi
        gamma1 = self.param.gamma1
        gamma2 = self.param.gamma2
        nu_tilde = self.param.nu_tilde

        # Compute r1 and pi1
        r1 = -0.5 * (1 + 1 / psi) * sigma**2 * (
            gamma1 * x * (s / x)**2 + gamma2 * (1 - x) * ((1 - s) / (1 - x))**2
        )
        pi1 = (nu_tilde * gamma1 * (s / x) + (1 - nu_tilde) * gamma2 * (1 - s) / (1 - x)) * sigma**2

        # Compute the initial guess for y(x, s)
        y_guess = 1 / q_star + r1 + pi1
        return y_guess


    def compute_initial_guess_V(self, W, S, x, s):
        """
        Compute the initial guess for \( V(W, S, x, s) \) using the given formula.
        """
        A = self.param.A
        gamma = self.param.gamma
        rho_star = self.param.rho_star
        sigma = self.param.sigma
        psi = self.param.psi
        gamma1 = self.param.gamma1
        gamma2 = self.param.gamma2
        nu_tilde = self.param.nu_tilde
        epsilon = 0.05  # Perturbation factor

        # Components of \( V \)
        V0 = A * W ** (1 - gamma) / (1 - gamma)
        r1 = -0.5 * (1 + 1 / psi) * sigma**2 * (gamma1 * x * (s / x)**2 + gamma2 * (1 - x) * ((1 - s) / (1 - x))**2)
        pi1 = (nu_tilde * gamma1 * (s / x) + (1 - nu_tilde) * gamma2 * (1 - s) / (1 - x)) * sigma**2

        V1 = (A * W ** (1 - gamma) / rho_star) * (r1 + pi1 * S / (rho_star * W) -
            0.5 * gamma * sigma**2 * (S / (rho_star * W))**2)

        return V0 + epsilon * V1


    def sample_data_y(self, num_samples, x_s_distance_min = 0.01):
        """
        Sample data points for \( W, S, x, s \) ensuring that \( |x - s| \geq x_s_distance_min \).
        """

        # Sample x
        x_mean = (self.param.x_min + self.param.x_max) / 2
        x_std = (self.param.x_max - self.param.x_min) / 6
        x = torch.normal(x_mean, x_std, size=(num_samples, 1), requires_grad=True).to(self.device)
        x = torch.clamp(x, min=self.param.x_min, max=self.param.x_max)

        # Sample s ensuring |x - s| >= x_s_distance_min
        s_mean = (self.param.s_min + self.param.s_max) / 2
        s_std = (self.param.s_max - self.param.s_min) / 6
        s = torch.normal(s_mean, s_std, size=(num_samples, 1), requires_grad=True).to(self.device)
        s = torch.clamp(s, min=self.param.s_min, max=self.param.s_max)

        # Return the concatenated tensor
        return torch.cat([x, s], dim=1)


    def sample_data_V(self, num_samples, x_s_distance_min = 0.01):
        """
        Sample data points for \( W, S, x, s \) ensuring that \( |x - s| \geq x_s_distance_min \).
        """
        # Sample W
        W_mean = (self.param.W_min + self.param.W_max) / 2
        W_std = (self.param.W_max - self.param.W_min) / 24
        W = torch.normal(W_mean, W_std, size=(num_samples, 1), requires_grad=True).to(self.device)
        W = torch.clamp(W, min=self.param.W_min, max=self.param.W_max)

        # Sample S
        S_mean = (self.param.S_min + self.param.S_max) / 2
        S_std = (self.param.S_max - self.param.S_min) / 6
        S = torch.normal(S_mean, S_std, size=(num_samples, 1), requires_grad=True).to(self.device)
        S = torch.clamp(S, min=self.param.S_min, max=self.param.S_max)

        # Sample x
        x_mean = (self.param.x_min + self.param.x_max) / 2
        x_std = (self.param.x_max - self.param.x_min) / 6
        x = torch.normal(x_mean, x_std, size=(num_samples, 1), requires_grad=True).to(self.device)
        x = torch.clamp(x, min=self.param.x_min, max=self.param.x_max)

        # Sample s ensuring |x - s| >= x_s_distance_min
        s_mean = (self.param.s_min + self.param.s_max) / 2
        s_std = (self.param.s_max - self.param.s_min) / 6
        s = torch.normal(s_mean, s_std, size=(num_samples, 1), requires_grad=True).to(self.device)
        s = torch.clamp(s, min=self.param.s_min, max=self.param.s_max)

        # Enforce |x - s| >= x_s_distance_min
        diff = torch.abs(x - s)
        mask = diff < x_s_distance_min  # Find violations
        while mask.any():
            # Resample only the violated s values
            resample_s = torch.normal(s_mean, s_std, size=(mask.sum(), 1), requires_grad=True).to(self.device)
            resample_s = torch.clamp(resample_s, min=self.param.s_min, max=self.param.s_max)
            s[mask] = resample_s.squeeze()
            diff = torch.abs(x - s)
            mask = diff < x_s_distance_min  # Recheck violations

        # Return the concatenated tensor
        return torch.cat([W, S, x, s], dim=1)


    def sobolev_loss(self, outputs, y_target, X, sobolev_weight=1.0):
        """
        Computes the Sobolev loss (already combined with relative error)

        Parameters:
            outputs: The model predictions of shape (B, 1).
            y_target: The target values (initial guess) of shape (B, 1).
            X: The input tensor of shape (B, 2)
            sobolev_weight: The weighting factor for the derivative (Sobolev) term.

        Returns:
            Tensor:  Sobolev loss.
        """

        value_loss = torch.mean(torch.square((outputs - y_target) / y_target))

        grad_outputs = torch.ones_like(outputs)

        grad_pred = torch.autograd.grad(
            outputs, X, grad_outputs=grad_outputs,
            create_graph=True, retain_graph=True
        )[0]

        grad_target = torch.autograd.grad(
            y_target, X, grad_outputs=grad_outputs,
            create_graph=True, retain_graph=True
        )[0]

        derivative_loss = torch.mean(torch.square(grad_pred - grad_target))

        return value_loss + sobolev_weight * derivative_loss


    def pretrain_y(self, epochs=1000, batch_size=1000, print_freq = 500):
        """
        Pretrain the model using the initial guess.
        """
        self.model_y.train(True)
        for epoch in range(1, epochs + 1):
            # Sample training data
            X_pretrain = self.sample_data_y(batch_size)
            x, s = X_pretrain[:, 0], X_pretrain[:, 1]

            # Compute the initial guess
            # y_pretrain = self.compute_initial_guess(W, S, x, s).view(-1, 1)
            y_pretrain = self.compute_initial_guess_y(x, s).view(-1, 1)
            # print(V_pretrain)

            # Zero gradients
            self.optimizer_y.zero_grad()

            # Forward pass through the model
            outputs = self.model_y(X_pretrain)

            loss = self.sobolev_loss(outputs, y_pretrain, X_pretrain)


            # Backward pass and optimization
            loss.backward()
            self.pretrain_losses_y.append(loss.detach().item())
            self.optimizer_y.step()

            # Monitor training progress
            if epoch % print_freq == 0 or epoch == 1:
                print(f"Pretrain Epoch {epoch}/{epochs}, Loss: {loss.item():.4e}")

            # Early stopping
            if loss.item() < 1e-7:
                print(f"Pretraining for y converged at epoch {epoch} with loss {loss.item():.4e}")
                break

    def pretrain_V(self, epochs=1000, batch_size=1000, print_freq = 500):
        """
        Pretrain the model using the initial guess.
        """
        self.model_V.train(True)
        for epoch in range(1, epochs + 1):
            # Sample training data
            X_pretrain = self.sample_data_V(batch_size)
            W, S, x, s = X_pretrain[:, 0], X_pretrain[:, 1], X_pretrain[:, 2], X_pretrain[:, 3]

            # Compute the initial guess
            # y_pretrain = self.compute_initial_guess(W, S, x, s).view(-1, 1)
            V_pretrain = self.compute_initial_guess_V(W, S, x, s).view(-1, 1)
            # print(V_pretrain)

            # Zero gradients
            self.optimizer_V.zero_grad()

            # Forward pass through the model
            outputs = self.model_V(X_pretrain)

            loss = self.sobolev_loss(outputs, V_pretrain, X_pretrain)


            # Backward pass and optimization
            loss.backward()
            self.pretrain_losses_V.append(loss.detach().item())
            self.optimizer_V.step()

            # Monitor training progress
            if epoch % print_freq == 0 or epoch == 1:
                print(f"Pretrain Epoch {epoch}/{epochs}, Loss: {loss.item():.4e}")

            # Early stopping
            if loss.item() < 1e-7:
                print(f"Pretraining for V converged at epoch {epoch} with loss {loss.item():.4e}")
                break

    class PDELoss(nn.Module):
        """Per-sample squared relative residuals of the V and y equations."""

        def __init__(self, param, y, agg, x_grid, s_grid):
            """
            Parameters:
                param: parameter object read by `forward` (gamma, sigma, mu,
                    chi, eta, alpha_bar, d_bar, rho_hat).
                y: initial guess for y (array-like or tensor); stored as a
                    float32 tensor on DEVICE. `forward` does not read it.
                agg: object providing compute_r, compute_pi, compute_mu_x,
                    compute_mu_s, compute_sigma_x, compute_vd and
                    compute_sigma_R, each called with (x, s).
                x_grid, s_grid: accepted for interface compatibility; unused.
            """
            super().__init__()
            self.param = param
            # as_tensor accepts arrays and tensors alike; torch.tensor() on an
            # existing tensor triggers a copy-construct UserWarning.
            self.y = torch.as_tensor(y, dtype=torch.float32, device = DEVICE)
            self.agg = agg

        def forward(self, model_y, model_V, inputs):
            """
            inputs - (B, 4) tensor of (W, S, x, s); it must track gradients,
            because first and second derivatives are taken with respect to it.

            Returns the tuple (V_residual_sq, y_residual_sq) of per-sample
            squared relative residuals (not yet averaged).
            """

            inputs_V = inputs
            inputs_y = inputs[:, 2:]

            # y and V
            y = model_y(inputs_y).view(-1, )
            V = model_V(inputs_V).view(-1, )

            # derivatives for y and V
            y_grad = torch.autograd.grad(y, inputs_y, grad_outputs=torch.ones_like(y),
                                         create_graph=True, retain_graph=True)[0]
            y_x = y_grad[:, 0]
            y_s = y_grad[:, 1]
            y_xx = torch.autograd.grad(y_x, inputs_y, grad_outputs=torch.ones_like(y_x),
                                       create_graph=True, retain_graph=True)[0][:, 0]

            V_grad = torch.autograd.grad(V, inputs_V, grad_outputs=torch.ones_like(V),
                                         create_graph=True, retain_graph=True)[0]
            V_W = V_grad[:, 0]
            V_S = V_grad[:, 1]
            V_x = V_grad[:, 2]
            V_s = V_grad[:, 3]

            # One differentiation of V_W yields both V_WW (column 0) and the
            # cross term V_Wx (column 2); there is no need to run it twice.
            V_W_grad = torch.autograd.grad(V_W, inputs_V, grad_outputs=torch.ones_like(V_W),
                                           create_graph=True, retain_graph=True)[0]
            V_WW = V_W_grad[:, 0]
            V_Wx = V_W_grad[:, 2]
            V_xx = torch.autograd.grad(V_x, inputs_V, grad_outputs=torch.ones_like(V_x),
                                    create_graph=True, retain_graph=True)[0][:, 2]

            W, S, x, s = inputs[:, 0], inputs[:, 1], inputs[:, 2], inputs[:, 3]

            # Extract aggregates for y
            r_X = self.agg.compute_r(x, s)
            pi_X = self.agg.compute_pi(x, s)
            mu_x = self.agg.compute_mu_x(x, s)
            mu_s = self.agg.compute_mu_s(x, s)
            sigma_x = self.agg.compute_sigma_x(x, s)
            vd = self.agg.compute_vd(x, s)
            sigma_R = self.agg.compute_sigma_R(x, s)
            hat_r = r_X + self.param.gamma * self.param.sigma**2 - self.param.mu
            hat_pi = pi_X - self.param.gamma * self.param.sigma * sigma_R

            # Policy functions for V (guarded against division by a zero V_W / Omega)
            Omega = torch.where(V_W != 0, V_S / V_W, torch.zeros_like(V_S))
            n = (1 / self.param.chi) * ((1 - self.param.eta) / (2 - self.param.eta)) * Omega * y
            theta = torch.where(Omega != 0,
                                ((self.param.alpha_bar / vd) * torch.abs(Omega) / ((2 - self.param.eta) / y)) ** (1 / (1 - self.param.eta)),
                                torch.zeros_like(Omega))

            # Compute residual for y
            y_term1 = r_X + pi_X
            y_term2 = y + self.param.mu
            y_term3 = (y_x * mu_x + y_s * mu_s) / y + 0.5 * y_xx * sigma_x ** 2 / y
            y_term4 = (y_x * sigma_x / y) ** 2
            y_term5 = self.param.sigma * y_x * sigma_x / y

            y_residual = y_term1 - y_term2 + y_term3 - y_term4 + y_term5
            y_residual_sq = (y_residual / y) ** 2

            # Compute residual for V

            V_term1 = self.param.rho_hat * V
            # WARNING: I removed torch.abs(V_W), which we used so far
            # (a negative V_W raised to this fractional power yields NaN).
            V_term2 = (self.param.gamma / (1 - self.param.gamma)) * (V_W ** ((self.param.gamma - 1) / self.param.gamma))
            V_term3 = V_W * (hat_r * W + (hat_pi * S / y) + (vd * self.param.d_bar / y) -
                           (theta * vd / y * torch.abs(n)) - (self.param.chi * n ** 2 / y * alpha_fn(theta, self.param)))
            V_term4 = V_x * (mu_x + (1 - self.param.gamma) * self.param.sigma * sigma_x)
            V_term5 = V_s * mu_s
            V_term6 = V_S * n * alpha_fn(theta, self.param)
            V_term7 = 0.5 * V_WW * (sigma_R * S / y - self.param.sigma * W) ** 2
            V_term8 = 0.5 * V_xx * sigma_x ** 2
            V_term9 = V_Wx * (sigma_R * S / y - self.param.sigma * W) * sigma_x

            V_residual = V_term1 - V_term2 - (V_term3 + V_term4 + V_term5 + V_term6 + V_term7 + V_term8 + V_term9)
            V_residual_sq = (V_residual / V) ** 2

            # return torch.mean(y_residual_sq) + torch.mean(V_residual_sq)
            return (V_residual_sq, y_residual_sq)


    def train(self, num_samples=128, print_freq=100, loss_thres=1e-10, loss_equiv_const: float = 1e4):
        """
        Train the model to solve the PDE after pretraining.

        Runs every phase in `self.phases` in order. Each phase sets
        `param.sigma`, creates a fresh Adam optimizer over the parameters of
        both networks with the phase's learning rate, and selects the trained
        network through `update_training_mode`.

        Parameters:
            num_samples (int): Number of samples to generate for training at each epoch.
            print_freq (int): Frequency at which to print the training loss
                (epoch 1 of each phase is always printed).
            loss_thres (float): The threshold value for stopping the training phase when the loss is smaller than this value.
            loss_equiv_const (float): Currently unused; kept for the disabled
                weighted variant `loss_V + loss_equiv_const * loss_y`.

        Side effects:
            Appends `[loss_V, loss_y]` to `self.train_losses` after every
            optimizer step and leaves the last optimizer in `self.optimizer`.
        """
        loss_fn = self.PDELoss(param = self.param, y = self.y, agg = self.agg,
                               x_grid = self.x_grid, s_grid = self.s_grid)

        for phase_indx, phase in enumerate(self.phases):
            self.param.sigma = phase.sigma
            self.optimizer = optim.Adam(chain(self.model_V.parameters(),
                                              self.model_y.parameters()),
                                        lr=phase.lr)
            self.update_training_mode(phase.mode)

            for epoch in range(1, phase.epochs + 1):
                # Sample data points
                data = self.sample_data_V(num_samples).to(self.device)

                # Zero gradients
                self.optimizer.zero_grad()

                # Compute loss
                V_residual_sq, y_residual_sq = loss_fn(model_y = self.model_y,
                                                       model_V = self.model_V,
                                                       inputs = data)
                loss_V = torch.mean(V_residual_sq)
                loss_y = torch.mean(y_residual_sq)
                # loss = loss_V + loss_equiv_const * loss_y
                loss = loss_V + loss_y

                if torch.isnan(loss):
                    print(f"Warning: NaN detected in loss at Phase {phase_indx}, Epoch {epoch}. Skipping update.")

                    torch.save(data, f"problematic_input_phase{phase_indx}_epoch{epoch}.pt")
                    # Flush here as well, mirroring the regular path, so the
                    # next batch does not start from this batch's cached
                    # aggregates. NOTE(review): assumes flush_cache() is safe
                    # to call without a preceding optimizer step - confirm.
                    self.agg.flush_cache()
                    continue

                # Backward pass and optimization
                loss.backward()
                self.optimizer.step()
                self.agg.flush_cache()
                self.train_losses.append([loss_V.detach().item(), loss_y.detach().item()])

                # Monitor training progress (the per-epoch debug print of the
                # raw loss tensor was removed; `print_freq` now controls output)
                if epoch % print_freq == 0 or epoch == 1:
                    print(f"Phase {phase_indx}, Epoch {epoch}/{phase.epochs}, Loss: {loss.item():.4e}")
                    # self.visualize_inputs(save_path=fig_path, phase = phase, epoch=epoch)

                # Check if the loss is below the threshold and stop the phase if it is
                if loss.item() < loss_thres:
                    print(f"Phase {phase_indx} stopped early at Epoch {epoch} due to loss threshold {loss_thres}.")
                    break

In [None]:
%%time
# Initialize the Train_NN instance
# (builds both networks, their pretraining optimizers and the aggregates helper;
#  `input_dim` is accepted by the constructor but not used by it).
train_nn = Train_NN(param, input_dim=2, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

In [None]:
%%time
# Check whether to pretrain or load the existing model.
# V and y each get their own checkpoint file derived from `model_load_path`;
# writing both to the same file would let the second save overwrite the first.
_ckpt_root, _ckpt_ext = os.path.splitext(model_load_path)
pretrain_targets = (
    ("V", train_nn.model_V, train_nn.optimizer_V, f"{_ckpt_root}_V{_ckpt_ext}"),
    ("y", train_nn.model_y, train_nn.optimizer_y, f"{_ckpt_root}_y{_ckpt_ext}"),
)

if not loading_saved_model:
    print("Starting pretraining...")
    train_nn.pretrain_V(epochs=1_000, batch_size=1024)
    train_nn.pretrain_y(epochs=1_000, batch_size=1024)

    # Save the pretrained models together with their pretraining optimizers
    for label, model, optimizer, ckpt_path in pretrain_targets:
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, ckpt_path)
        print(f"Pretrained model for {label} saved at {ckpt_path}")

    # Plot and save the pretraining losses, one figure (and one file) per network
    for label, losses in (("y", train_nn.pretrain_losses_y),
                          ("V", train_nn.pretrain_losses_V)):
        plot_file = os.path.join(fig_path, f"pretrain_loss_{label}.png")
        plt.figure()
        plt.plot(losses)
        plt.yscale('log')
        plt.xlabel('Epoch')
        plt.ylabel('Loss during pretrain')
        plt.title(f'Pretraining Loss for {label}')
        plt.savefig(plot_file)  # Save the plot
        # plt.close()
        print(f"Pretraining loss plot saved at {plot_file}")

else:
    # Load the pretrained networks and optimizer states written by the branch above
    for label, model, optimizer, ckpt_path in pretrain_targets:
        checkpoint = torch.load(ckpt_path, map_location=train_nn.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Model for {label} loaded from {ckpt_path}")

In [None]:
# # Load the pretrained model
# checkpoint = torch.load(model_save_path, map_location=train_nn.device)
# train_nn.model.load_state_dict(checkpoint['model_state_dict'])
# train_nn.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# print(f"Model loaded from {model_save_path}")

In [None]:
%%time
# After pretraining, proceed with training
# (runs every phase of the trainer's phase list with 1024 freshly sampled points
#  per epoch; the formatted progress line appears on epoch 1 and every 500th epoch).
train_nn.train(num_samples=1024, print_freq=500)

In [None]:
# Save the trained model and optimizer state
# TODO: not rewritten for a combined trainer
# Each network gets its own file; both files carry the state of the last
# phase's optimizer and the full training-loss history.
for trained_model, save_path in ((train_nn.model_V, model_V_save_path),
                                 (train_nn.model_y, model_y_save_path)):
    torch.save({
        'model_state_dict': trained_model.state_dict(),
        'optimizer_state_dict': train_nn.optimizer.state_dict(),
        'training_losses': train_nn.train_losses,
    }, save_path)
    print(f"Training solution saved to {save_path}")

In [None]:
# Number of recorded optimizer steps (one [loss_V, loss_y] pair is stored per step).
len(train_nn.train_losses)

In [None]:
# Plot the two residual losses recorded during training on a log scale.
plt.figure(figsize=(6, 4))

# Each history entry is a [loss_V, loss_y] pair.
loss_V = [entry[0] for entry in train_nn.train_losses]
loss_y = [entry[1] for entry in train_nn.train_losses]
for series, series_label in ((loss_V, 'Loss_V'), (loss_y, 'Loss_y')):
    plt.plot(series, label=series_label)

plt.yscale('log')
plt.xlabel('Epoch', fontsize=20)
plt.ylabel('Training Loss', fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.legend()
plt.savefig(fig_path + "/train_loss.png")
plt.show()

In [None]:
# Compare the trained y network with the closed-form initial guess along s,
# holding x fixed.
x_value = 0.9   # Fixed x
s_value = 0.5   # Not used below; s is swept over its whole range

# Generate s values for the plot
s_values = torch.linspace(param.s_min, param.s_max, 500).unsqueeze(1).to(train_nn.device)

# Prepare inputs for the neural network: columns (x, s)
# W_tensor = torch.full_like(s_values, W_value)
x_tensor = torch.full_like(s_values, x_value)
s_tensor = s_values
inputs = torch.cat([x_tensor, s_tensor], dim=1)

# Get the outputs from the neural network (y(x, s))
nn_outputs = train_nn.model_y(inputs).detach().cpu().numpy()

# Compute the initial guess for y(x, s).
# Stored under its own name: assigning to `initial_guess` would shadow the
# `initial_guess` function imported from PDE_nn, which Train_NN.__init__ calls.
y_initial_guess = train_nn.compute_initial_guess_y(
    x=x_tensor,
    s=s_tensor
).cpu().numpy()

# Plot both the neural network outputs and the initial guess
plt.figure(figsize=(8, 6))
plt.plot(s_values.cpu().numpy(), nn_outputs, label="NN Solution", linewidth=2)
plt.plot(s_values.cpu().numpy(), y_initial_guess, label="Initial Guess", linestyle="--", linewidth=2)

# Customize the plot
plt.xlabel("s", fontsize=20)
plt.ylabel("y(s)", fontsize=20)
plt.title(f"NN Solution of y(x,s) vs Initial Guess at x={x_value}", fontsize=20)
plt.legend(fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.savefig(f"{fig_path}/comparison_y_x_{x_value}.png", bbox_inches="tight")


In [None]:
# Slice plot: trained network y(x, s) against the initial guess,
# sweeping s while holding x fixed.
x_value = 0.1   # Fixed x for this slice
s_value = 0.5   # Not read in this cell (s is swept below); left defined for other cells

# Sweep s over [param.s_min, param.s_max] as a (500, 1) column on the model's device
s_values = torch.linspace(param.s_min, param.s_max, 500).unsqueeze(1).to(train_nn.device)

# Network input columns are ordered (x, s)
x_tensor = torch.full_like(s_values, x_value)
s_tensor = s_values
inputs = torch.cat([x_tensor, s_tensor], dim=1)

# Evaluate the trained network y(x, s)
nn_outputs = train_nn.model_y(inputs).detach().cpu().numpy()

# Evaluate the initial guess for y(x, s) on the same points.
# The result gets its own name: `initial_guess` is also the function imported
# from PDE_nn, and rebinding that name to an array would hide the function
# from every cell executed afterwards.
# detach() changes nothing when the guess carries no autograd graph and keeps
# numpy() from raising if it does.
initial_guess_y = train_nn.compute_initial_guess_y(
    x=x_tensor,
    s=s_tensor
).detach().cpu().numpy()

# Convert the s axis once and reuse it for both curves
s_axis = s_values.cpu().numpy()

# Plot both the neural network outputs and the initial guess
plt.figure(figsize=(8, 6))
plt.plot(s_axis, nn_outputs, label="NN Solution", linewidth=2)
plt.plot(s_axis, initial_guess_y, label="Initial Guess", linestyle="--", linewidth=2)

# Customize the plot
plt.xlabel("s", fontsize=20)
plt.ylabel("y(s)", fontsize=20)
plt.title(f"NN Solution of y(x,s) vs Initial Guess at x={x_value}", fontsize=20)
plt.legend(fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.savefig(f"{fig_path}/comparison_y_x_{x_value}.png", bbox_inches="tight")


In [None]:
# Slice plot: trained network y(x, s) against the initial guess,
# sweeping s while holding x fixed.
x_value = 0.5   # Fixed x for this slice
s_value = 0.5   # Not read in this cell (s is swept below); left defined for other cells

# Sweep s over [param.s_min, param.s_max] as a (500, 1) column on the model's device
s_values = torch.linspace(param.s_min, param.s_max, 500).unsqueeze(1).to(train_nn.device)

# Network input columns are ordered (x, s)
x_tensor = torch.full_like(s_values, x_value)
s_tensor = s_values
inputs = torch.cat([x_tensor, s_tensor], dim=1)

# Evaluate the trained network y(x, s)
nn_outputs = train_nn.model_y(inputs).detach().cpu().numpy()

# Evaluate the initial guess for y(x, s) on the same points.
# The result gets its own name: `initial_guess` is also the function imported
# from PDE_nn, and rebinding that name to an array would hide the function
# from every cell executed afterwards.
# detach() changes nothing when the guess carries no autograd graph and keeps
# numpy() from raising if it does.
initial_guess_y = train_nn.compute_initial_guess_y(
    x=x_tensor,
    s=s_tensor
).detach().cpu().numpy()

# Convert the s axis once and reuse it for both curves
s_axis = s_values.cpu().numpy()

# Plot both the neural network outputs and the initial guess
plt.figure(figsize=(8, 6))
plt.plot(s_axis, nn_outputs, label="NN Solution", linewidth=2)
plt.plot(s_axis, initial_guess_y, label="Initial Guess", linestyle="--", linewidth=2)

# Customize the plot
plt.xlabel("s", fontsize=20)
plt.ylabel("y(s)", fontsize=20)
plt.title(f"NN Solution of y(x,s) vs Initial Guess at x={x_value}", fontsize=20)
plt.legend(fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.savefig(f"{fig_path}/comparison_y_x_{x_value}.png", bbox_inches="tight")


In [None]:
# Slice plot: trained network V(W, S, x, s) against the initial guess,
# sweeping W while holding S, x and s fixed.
S_value = 1.0  # Fixed S
x_value = 0.5  # Fixed x
s_value = 0.5  # Fixed s

# Sweep W over [10, 50] as a (500, 1) column on the model's device
W_values = torch.linspace(10, 50, 500).unsqueeze(1).to(train_nn.device)

# Network input columns are ordered (W, S, x, s)
S_tensor = torch.full_like(W_values, S_value)
x_tensor = torch.full_like(W_values, x_value)
s_tensor = torch.full_like(W_values, s_value)
inputs = torch.cat([W_values, S_tensor, x_tensor, s_tensor], dim=1)

# Evaluate the trained network V(W, S, x, s)
nn_outputs = train_nn.model_V(inputs).detach().cpu().numpy()

# Evaluate the initial guess for V on the same points.
# The result gets its own name: `initial_guess` is also the function imported
# from PDE_nn, and rebinding that name to an array would hide the function
# from every cell executed afterwards.
# detach() changes nothing when the guess carries no autograd graph and keeps
# numpy() from raising if it does.
initial_guess_V = train_nn.compute_initial_guess_V(
    W=W_values,
    S=S_tensor,
    x=x_tensor,
    s=s_tensor
).detach().cpu().numpy()

# Convert the W axis once and reuse it for both curves
W_axis = W_values.cpu().numpy()

# Plot both the neural network outputs and the initial guess
plt.figure(figsize=(8, 6))
plt.plot(W_axis, nn_outputs, label="NN Solution", linewidth=2)
plt.plot(W_axis, initial_guess_V, label="Initial Guess", linestyle="--", linewidth=2)

# Customize the plot
plt.xlabel("W", fontsize=20)
plt.ylabel("V(W)", fontsize=20)
plt.title(f"NN Solution vs Initial Guess at S={S_value}, x={x_value}, s={s_value}", fontsize=20)
plt.legend(fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

# Save the plot.
# The output directory lives in a cell-local name: assigning "./fig" to
# `fig_path` here would overwrite the notebook-wide figure directory defined
# in the setup cell, so re-running any other plotting cell afterwards would
# save into a different folder. The files are written to the same locations
# as before.
comparison_dir = "./fig"  # Adjust the path as needed
plt.savefig(f"{comparison_dir}/comparison.png", bbox_inches="tight")

# Second copy whose file name records the slice parameters
plt.savefig(comparison_dir + f"/comparison_V_W_S_{S_value}_x_{x_value}_s_{s_value}_0505.png")


In [None]:
import torch
import matplotlib.pyplot as plt

# Values of s to draw; one curve per value in each panel
s_vals = [0.25, 0.5, 0.75]

# Two side-by-side panels: risk premium (left) and interest rate (right)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
pi_ax = axes[0]
r_ax = axes[1]

# Per-curve color and line style, matched by position with s_vals
colors = ['blue', 'red', 'green']
linestyles = ['--', '-', ':']

# Normaliser for the left panel: pi evaluated at x = 0.5, s = 0.5.
# Two separate scalar tensors are built, one per argument.
x_ref = torch.tensor(0.5, requires_grad=True, device = train_nn.device)
s_ref = torch.tensor(0.5, requires_grad=True, device = train_nn.device)
pi_repr = train_nn.agg.compute_pi(x_ref, s_ref).item()
train_nn.agg.flush_cache()


# Left panel: pi(x, s) / pi_repr over x in [0.25, 0.75]
for s, curve_color, curve_style in zip(s_vals, colors, linestyles):
    x = torch.linspace(0.25, 0.75, 100, requires_grad=True, device = train_nn.device)
    s_tensor = torch.full((100,), s, requires_grad=True, device = train_nn.device)

    # Risk premium on this slice, scaled by the reference value
    pi_vals = train_nn.agg.compute_pi(x, s_tensor) / pi_repr
    # The cache is flushed after every aggregate evaluation
    train_nn.agg.flush_cache()

    x_np = x.detach().cpu().numpy()
    pi_np = pi_vals.detach().cpu().numpy()
    pi_ax.plot(x_np, pi_np, label=fr"$s={s}$", color=curve_color,
               linestyle=curve_style, linewidth=2)

pi_ax.set_title("Risk Premium", fontsize=14)
pi_ax.set_xlabel("Wealth share $x$", fontsize=12)
pi_ax.set_ylabel(r"$\pi/\pi_{repr}$", fontsize=12)
pi_ax.grid(True)
pi_ax.legend(fontsize=10)

# Right panel: r(x, s) over the same x range
for s, curve_color, curve_style in zip(s_vals, colors, linestyles):
    x = torch.linspace(0.25, 0.75, 100, requires_grad=True, device = train_nn.device)
    s_tensor = torch.full((100,), s, requires_grad=True, device = train_nn.device)

    # Interest rate on this slice
    r_vals = train_nn.agg.compute_r(x, s_tensor)
    train_nn.agg.flush_cache()

    x_np = x.detach().cpu().numpy()
    r_np = r_vals.detach().cpu().numpy()
    r_ax.plot(x_np, r_np, label=fr"$s={s}$", color=curve_color,
              linestyle=curve_style, linewidth=2)

r_ax.set_title("Interest Rate", fontsize=14)
r_ax.set_xlabel("Wealth share $x$", fontsize=12)
r_ax.set_ylabel(r"$r(x,s)$", fontsize=12)
r_ax.grid(True)
r_ax.legend(fontsize=10)

# Tighten the layout, write the figure to the working directory, then display it
plt.tight_layout()
plt.savefig("r,pi_plot.png")
plt.show()