# AutoInt with Neural Control Variates


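This notebook extends the original AutoInt example to implement the Neural Control Variates (NCV) framework for 1D integration.
We use a control variate \( g_\theta(x) \) to reduce the variance of Monte Carlo integration. Writing \( \int f(x)\,dx = \int g_\theta(x)\,dx + \int \big(f(x) - g_\theta(x)\big)\,dx \), the first term is computed exactly and only the residual \( f - g_\theta \) is estimated by Monte Carlo. The closer \( g_\theta \) is to \( f \), the lower the variance of the estimate.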

The steps include:
1. Defining the control variate network \( g_\theta(x) \).
2. Modifying the training loop to minimize the mean squared residual \( \mathbb{E}[(f(x) - g_\theta(x))^2] \), which is an upper bound on the variance of \( f - g_\theta \).
3. Using AutoInt to compute the integral of \( g_\theta(x) \).
4. Evaluating the integral with reduced variance.


## Imports

In [None]:

import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

from autoint.session import Session
import autoint.autograd_modules as autoint


## Function Definitions

In [None]:

# Integrand f(x) for the experiment.
def target_fn(x):
    """Return f(x) = cos(5x) + sin(2x), applied elementwise to tensor ``x``."""
    fast_term = torch.cos(5 * x)
    slow_term = torch.sin(2 * x)
    return fast_term + slow_term

# Map-style dataset yielding (x, f(x)) pairs on an evenly spaced 1D grid.
class Implicit1DWrapper(torch.utils.data.Dataset):
    """Dataset of coordinates on a fixed grid together with ``fn`` at each one.

    Args:
        range_: indexable ``(start, end)``; both endpoints are included.
        fn: callable applied to each sample's coordinate tensor.
        sampling_density: number of grid points.

    Each item is ``(x, fn(x))``, where ``x`` has shape ``(1,)``.
    """

    def __init__(self, range_, fn, sampling_density=1000):
        start, end = range_[0], range_[1]
        self.range = torch.linspace(start, end, sampling_density)
        self.fn = fn

    def __len__(self):
        return self.range.shape[0]

    def __getitem__(self, idx):
        # Promote the 0-d scalar to shape (1,) so batches come out as (B, 1).
        coord = self.range[idx].unsqueeze(0)
        return coord, self.fn(coord)


## Define Networks

In [None]:

# Control variate g_theta(x): a small fully connected ReLU network.
class ControlVariateNet(torch.nn.Module):
    """Three-layer MLP used as the control variate.

    Architecture: Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, output_dim),
    stored in ``self.net`` as a ``torch.nn.Sequential``.
    """

    def __init__(self, input_dim=1, hidden_dim=32, output_dim=1):
        super().__init__()
        layers = [
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Evaluate the MLP; the last dimension of ``x`` must equal ``input_dim``."""
        return self.net(x)

# Define the integral network (SIREN from AutoInt).
# NOTE(review): despite the name, this has only a single sine activation layer.
class SIREN(autoint.MetaModule):
    """AutoInt network built from an Input node, Linear(1, 32), SinActivation, Linear(32, 1).

    No ``forward`` is defined here. The modules are handed to ``session`` via
    ``session.add_modules``, and the notebook evaluates the network through
    ``session.compute_graph``.

    Args:
        session: AutoInt ``Session`` the modules are registered with.
    """
    def __init__(self, session):
        super().__init__()
        # Input node. Its id 'x_coords' appears to correspond to the key of the
        # input dict later passed to session.compute_graph -- keep them in sync.
        # torch.Tensor(1, 1) allocates uninitialized memory; presumably it is only
        # a shape placeholder (TODO confirm against autoint.Input).
        self.input = autoint.Input(torch.Tensor(1, 1), id='x_coords')
        # NOTE(review): a plain Python list (not torch.nn.ModuleList) is not
        # registered as submodules by torch.nn.Module itself. Whether
        # integral_net.parameters() / .to(device) see these layers depends on
        # autoint's MetaModule / add_modules -- verify.
        self.net = []
        self.net.append(autoint.Linear(1, 32))
        self.net.append(autoint.SinActivation())
        self.net.append(autoint.Linear(32, 1))
        session.add_modules(self, self.net)


## Initialize Networks and Session

In [None]:

# Create the AutoInt session and both networks.
session = Session()
integral_net = SIREN(session)  # SIREN registers its layers with `session`
control_variate_net = ControlVariateNet()

# Pick the best available accelerator: CUDA GPU, then Apple MPS, then CPU.
# (Previously only MPS was checked, so CUDA GPUs were never used.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
integral_net.to(device)
control_variate_net.to(device)


## Training the Networks

In [None]:

# Training data: f evaluated on 1000 evenly spaced points in [-1, 10].
dataset = Implicit1DWrapper(range_=[-1, 10], fn=target_fn)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

# Optimizer and loss.
# NOTE(review): `integral_net` never appears in the loss below, so its
# parameters receive no gradients and are not updated by this loop; only
# `control_variate_net` is trained. Tying the AutoInt network to g_theta
# is still TODO.
optimizer = torch.optim.Adam(
    list(integral_net.parameters()) + list(control_variate_net.parameters()),
    lr=1e-4,
)
loss_fn = torch.nn.MSELoss()  # mean squared residual (g - f)^2 over the batch

# Training loop
epochs = 1000
log_every = 50
losses = []  # per-epoch mean of the per-batch losses

for epoch in range(epochs):
    epoch_loss = 0.0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)

        # Control variate prediction g_theta(x).
        g = control_variate_net(x)

        # MSELoss(input, target): prediction first, target second.
        loss = loss_fn(g, y)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    mean_loss = epoch_loss / len(dataloader)
    losses.append(mean_loss)
    # Log the same per-epoch average that is plotted, and always log the last epoch.
    if epoch % log_every == 0 or epoch == epochs - 1:
        print(f"Epoch {epoch}/{epochs}, Loss: {mean_loss:.4f}")

# Plot the training loss
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss for NCV")
plt.show()


## Evaluate the Integral

In [None]:

# Evaluate f and g_theta on a plotting grid over the integration interval.
a, b = -1.0, 10.0  # same interval as the training dataset
x_coords = torch.linspace(a, b, 100).unsqueeze(1).to(device)
x_np = x_coords.cpu().numpy()
with torch.no_grad():  # inference only; no autograd graph needed
    true_vals = target_fn(x_coords).cpu().numpy()
    control_vals = control_variate_net(x_coords).cpu().numpy()

# Compute the integral using AutoInt.
# NOTE(review): `integral_net` is not trained by the training loop (its
# parameters are not in the loss), so these values are not yet the integral
# of g_theta.
session_input = {'x_coords': x_coords}
integral_vals = session.compute_graph(session_input).detach().cpu().numpy()

# Estimate I = ∫_a^b f(x) dx with and without the control variate.
#   plain MC:        I ≈ (b - a) * mean(f(X)),                 X ~ U(a, b)
#   control variate: I ≈ ∫ g + (b - a) * mean(f(X) - g(X))
# ∫ g is approximated with the trapezoid rule on a dense grid, because the
# AutoInt network above is not tied to g_theta. The reported ± values are
# Monte Carlo standard errors only (they ignore the trapezoid error).
n_samples = 1000
gen = torch.Generator().manual_seed(0)  # reproducible sample set
width = b - a
with torch.no_grad():
    x_mc = a + width * torch.rand(n_samples, 1, generator=gen)
    f_mc = target_fn(x_mc)
    residual = f_mc - control_variate_net(x_mc.to(device)).cpu()
    x_dense = torch.linspace(a, b, 10001)
    g_dense = control_variate_net(x_dense.unsqueeze(1).to(device)).squeeze(1).cpu()
    g_integral = torch.trapezoid(g_dense, x_dense).item()

plain_est = width * f_mc.mean().item()
plain_err = width * f_mc.std().item() / n_samples ** 0.5
cv_est = g_integral + width * residual.mean().item()
cv_err = width * residual.std().item() / n_samples ** 0.5
# Closed form: ∫cos(5x) dx = sin(5x)/5 and ∫sin(2x) dx = -cos(2x)/2.
exact = (np.sin(5 * b) - np.sin(5 * a)) / 5 - (np.cos(2 * b) - np.cos(2 * a)) / 2
print(f"Exact integral:       {exact:.6f}")
print(f"Plain MC estimate:    {plain_est:.6f} ± {plain_err:.6f}")
print(f"Control-variate est.: {cv_est:.6f} ± {cv_err:.6f}")

# Visualize results
plt.plot(x_np, true_vals, label="Target Function")
plt.plot(x_np, control_vals, label="Control Variate")
plt.legend()
plt.show()
