### Synthetic dataset generation

In [1]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# Seed the global RNG so the synthetic data generated below is reproducible
torch.manual_seed(1337)

# Problem sizes
num_samples = 1000  # Corresponds to m in the paper
dim_x, dim_y = 10, 1  # input / output feature dimensions

# Linear ground truth with additive noise: Y = XW + 0.1 * noise
# (the three randn draws are kept in this order so the values are unchanged)
X = torch.randn(num_samples, dim_x)
true_weights = torch.randn(dim_x, dim_y)
noise = torch.randn(num_samples, dim_y)
Y = X @ true_weights + 0.1 * noise

# Wrap the tensors in a dataset and serve them as shuffled mini-batches
batch_size = 32
dataset = TensorDataset(X, Y)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)


### Model instantiation, optimizer and loss function

In [2]:
import torch.nn as nn
import sys
import os

# Make the parent of the current working directory importable so that the
# project-local `Utils` package can be found by the import below.
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
if parent_dir not in sys.path:
    # Append (not prepend) so already-importable modules keep priority
    sys.path.append(parent_dir)

from Utils.classes import LinearNN

# Hidden layer dimensions handed to LinearNN
hidden_dims = [64, 32]

# Instantiate the model (all arguments are passed by keyword).
# The input/output sizes reuse dim_x / dim_y from the dataset cell instead of
# repeating the literals 10 and 1, so changing the data dimensions there
# cannot silently leave the model with mismatched shapes.
model = LinearNN(dim_x=dim_x, dim_y=dim_y, hidden_dims=hidden_dims)


# Define the loss function: mean squared error between predictions and targets
criterion = nn.MSELoss()

# Define the optimizer: Adam over all model parameters
learning_rate = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

### Training loop

In [3]:
# Training parameters
num_epochs = 100
log_interval = 10  # report the loss on the first epoch and every `log_interval` epochs

# Training loop: one full pass over `dataloader` per epoch
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    epoch_loss = 0.0  # running sum of per-sample losses for this epoch
    samples_seen = 0  # number of samples actually iterated this epoch

    for batch_X, batch_Y in dataloader:
        current_batch_size = batch_X.size(0)

        # Zero the gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass
        outputs = model(batch_X)

        # Compute the loss
        loss = criterion(outputs, batch_Y)

        # Backward pass
        loss.backward()

        # Update the weights
        optimizer.step()

        # `criterion` (nn.MSELoss with default reduction) yields a batch mean;
        # weight it by the batch size so a smaller final batch is not
        # over-represented in the epoch average.
        epoch_loss += loss.item() * current_batch_size
        samples_seen += current_batch_size

    # Average over the samples that were really processed rather than the
    # global `num_samples`, so the value stays correct if the loader drops the
    # last batch or is pointed at a different dataset. max(..., 1) keeps an
    # empty loader from raising ZeroDivisionError (the loss is then 0.0).
    avg_loss = epoch_loss / max(samples_seen, 1)

    if (epoch + 1) % log_interval == 0 or epoch == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

Epoch [1/100], Loss: 0.8296
Epoch [10/100], Loss: 0.0140
Epoch [20/100], Loss: 0.0193
Epoch [30/100], Loss: 0.0176
Epoch [40/100], Loss: 0.0141
Epoch [50/100], Loss: 0.0152
Epoch [60/100], Loss: 0.0134
Epoch [70/100], Loss: 0.0128
Epoch [80/100], Loss: 0.0225
Epoch [90/100], Loss: 0.0143
Epoch [100/100], Loss: 0.0152
