# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import math
import matplotlib.pyplot as plt

#####################################
# GBM-Based Synthetic Dataset (Same as before)
#####################################
class GBMDataset(torch.utils.data.Dataset):
    """Sliding-window dataset over one simulated geometric Brownian motion path.

    A single price path of ``length + seq_len`` points is drawn when the
    dataset is constructed, using the log-normal step
    ``S[t] = S[t-1] * exp((mu - sigma**2 / 2) * dt + sigma * sqrt(dt) * Z)``
    with ``Z ~ N(0, 1)``. Sample ``i`` pairs the window ``S[i : i + seq_len]``
    with the price that immediately follows it.

    Args:
        length: Number of (window, target) samples exposed by the dataset.
        seq_len: Number of prices in each input window.
        S0: First price of the path.
        mu: Drift coefficient of the process.
        sigma: Volatility coefficient of the process.
        dt: Time increment covered by one step.

    Attributes:
        seq_len: Window length, as passed in.
        data: 1-D tensor holding the full simulated price path.
    """

    def __init__(self, length=5000, seq_len=50, S0=100, mu=0.05, sigma=0.2, dt=1/252):
        self.seq_len = seq_len
        total_length = length + seq_len

        # Per-step constants of the log-price increment.
        drift = (mu - 0.5 * sigma**2) * dt
        diffusion = sigma * math.sqrt(dt)

        # Draw every shock at once and accumulate in log space; this yields the
        # same process as multiplying step by step, without a Python-level loop.
        shocks = torch.randn(max(total_length - 1, 0))
        log_growth = torch.cumsum(drift + diffusion * shocks, dim=0)

        prices = torch.empty(total_length)
        prices[0] = S0
        prices[1:] = S0 * torch.exp(log_growth)
        self.data = prices

    def __len__(self):
        """Return the number of windows that still have a following target price."""
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        """Return ``(window, target)`` for sample ``idx``.

        Args:
            idx: Sample index; negative values count from the end.

        Returns:
            A tuple ``(x, y)`` where ``x`` has shape ``(seq_len, 1)`` and ``y``
            has shape ``(1,)``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_samples = len(self)
        # Without normalisation a negative index would slice an empty or
        # misaligned window and pair it with an unrelated target.
        if idx < 0:
            idx += n_samples
        if not 0 <= idx < n_samples:
            raise IndexError(f"index out of range for dataset with {n_samples} samples")

        window_end = idx + self.seq_len
        x = self.data[idx:window_end]
        y = self.data[window_end]
        # Trailing feature axis so batches come out as (batch, seq_len, 1) / (batch, 1).
        return x.unsqueeze(-1), y.unsqueeze(-1)


#####################################
# Two-Layer LSTM Model
#####################################
class TwoLayerLSTM(nn.Module):
    """Stacked LSTM whose final time-step output feeds a linear head.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values produced by the linear head.
    """

    def __init__(self, input_size=1, hidden_size=32, num_layers=2, output_size=1):
        super().__init__()
        # batch_first=True: tensors are laid out as (batch, time, feature).
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the LSTM over ``x`` and project the last time step through ``fc``."""
        sequence_out, _state = self.lstm(x)
        # Only the output at the final time step is used for the prediction.
        final_step = sequence_out[:, -1, :]
        return self.fc(final_step)

#####################################
# Sharpe Ratio Function
#####################################
def sharpe_ratio(returns):
    """Return the mean of ``returns`` divided by its standard deviation.

    The standard deviation is ``returns.std()`` with torch's default settings.
    Degenerate inputs yield zero instead of NaN or a division by zero.

    Args:
        returns: Floating-point tensor of returns.

    Returns:
        A 0-d tensor. It is ``mean / std`` in the normal case and ``0`` when
        ``returns`` has fewer than two elements or its standard deviation is
        exactly zero.
    """
    # The zero is derived from `returns` (rather than a fresh constant) so it
    # keeps the input's dtype and device and stays attached to the autograd
    # graph; calling .backward() on it then produces zero gradients instead of
    # raising.
    if returns.numel() < 2:
        # std() of fewer than two samples is NaN, which `== 0` would not catch.
        return returns.sum() * 0.0

    std_ret = returns.std()
    if std_ret == 0:
        return returns.sum() * 0.0
    return returns.mean() / std_ret

#####################################
# Helper: Compute Gradient Norm
#####################################
def compute_grad_norm(model):
    """Return the global L2 norm of all gradients currently stored on ``model``.

    Parameters whose ``.grad`` is ``None`` are ignored. The result is a Python
    float; it is ``0.0`` when no parameter has a gradient.
    """
    available_grads = [
        param.grad for param in model.parameters() if param.grad is not None
    ]

    squared_sum = 0.0
    for grad in available_grads:
        # Square each per-parameter L2 norm so the total is the norm of the
        # concatenation of all gradients.
        squared_sum += grad.data.norm(2).item() ** 2
    return squared_sum ** 0.5

#####################################
# Training Setup
#####################################
# Hyperparameters for the run.
seq_len = 50
batch_size = 32
epochs = 5
learning_rate = 0.001

train_dataset = GBMDataset(length=5000, seq_len=seq_len, S0=100, mu=0.05, sigma=0.2, dt=1/252)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

model = TwoLayerLSTM(input_size=1, hidden_size=32, num_layers=2, output_size=1)
criterion_mse = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Weight of the Sharpe-ratio term. This starting value is only used until the
# first batch with a usable Sharpe-ratio gradient; after that it is recomputed
# every batch so that the two gradient norms match.
lambda_sr = 0.1

model.train()
for epoch in range(1, epochs+1):
    for i, (features, targets) in enumerate(train_loader):
        # Forward pass
        preds = model(features)
        mse_loss = criterion_mse(preds, targets)

        # Predicted return relative to the last price of each input window.
        last_price = features[:, -1, 0]  # shape: (batch,)
        # squeeze(-1) drops only the output axis; a bare squeeze() would also
        # drop the batch axis when a batch contains a single sample.
        predicted_returns = (preds.squeeze(-1) - last_price) / last_price
        sr = sharpe_ratio(predicted_returns)

        # The Sharpe term is only usable when it is finite and connected to the
        # autograd graph. Otherwise backward() would raise or inject NaNs into
        # the weights, so such batches fall back to the MSE loss alone.
        sr_usable = sr.requires_grad and bool(torch.isfinite(sr))

        # 1) Gradient norm of the MSE term on its own.
        optimizer.zero_grad()
        mse_loss.backward(retain_graph=True)  # graph is reused by the passes below
        gradnorm_mse = compute_grad_norm(model)

        # 2) Gradient norm of the Sharpe ratio on its own. The sign and the
        #    lambda_sr factor do not matter for the norm, so SR is
        #    backpropagated as if it were the loss.
        if sr_usable:
            optimizer.zero_grad()
            sr.backward(retain_graph=True)
            gradnorm_sr = compute_grad_norm(model)
        else:
            gradnorm_sr = 0.0

        # 3) Choose lambda_sr so that gradnorm(MSE) == lambda_sr * gradnorm(SR).
        #    When the SR gradient is negligible the previous value is kept,
        #    which also avoids dividing by zero.
        if gradnorm_sr > 1e-12:
            lambda_sr = gradnorm_mse / gradnorm_sr

        # 4) Final backward pass on the combined loss = MSE - lambda_sr * SR.
        optimizer.zero_grad()
        loss = mse_loss - lambda_sr * sr if sr_usable else mse_loss
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f"Epoch [{epoch}], Batch [{i+1}] Loss: {loss.item():.4f}, lambda_sr: {lambda_sr:.4f}, gradnorm_mse: {gradnorm_mse:.4f}, gradnorm_sr: {gradnorm_sr:.4f}")

