In [8]:
import torch
from torch import nn
import einops
from tqdm import tqdm

In [9]:
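# Define models: tied-weight autoencoders (one matrix W encodes and decodes), with and without a ReLU on the output.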
class NonlinearModel(nn.Module):
    """Tied-weight autoencoder with a ReLU applied to the reconstruction.

    The same weight matrix ``w`` of shape ``(n, m)`` is used to encode
    (``h = x @ w``) and to decode (``ReLU(h @ w.T + b)``).

    Args:
        n: number of input/output features.
        m: hidden (bottleneck) dimension.
    """

    def __init__(self, n, m):
        super().__init__()
        self.w = nn.Parameter(torch.empty((n, m)))
        # For a 2-D tensor torch uses size(1) (= m) as fan_in for Kaiming init.
        nn.init.kaiming_normal_(self.w)
        self.b = nn.Parameter(torch.zeros(n))
        self.activ = nn.ReLU()

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: tensor whose last dimension has size ``n``; any number of
                leading (batch) dimensions is accepted, including none.

        Returns:
            Non-negative reconstruction with the same shape as ``x``.
        """
        h = x @ self.w  # (..., n) @ (n, m) -> (..., m)
        # Decode with the transpose of the same matrix: (..., m) @ (m, n) -> (..., n)
        return self.activ(h @ self.w.T + self.b)

class LinearModel(nn.Module):
    """Tied-weight linear autoencoder (no output nonlinearity).

    The same weight matrix ``w`` of shape ``(n, m)`` is used to encode
    (``h = x @ w``) and to decode (``h @ w.T + b``).

    Args:
        n: number of input/output features.
        m: hidden (bottleneck) dimension.
    """

    def __init__(self, n, m):
        super().__init__()
        self.w = nn.Parameter(torch.empty((n, m)))
        nn.init.xavier_normal_(self.w)
        self.b = nn.Parameter(torch.zeros(n))

    def forward(self, x):
        """Encode and linearly reconstruct ``x``.

        Args:
            x: tensor whose last dimension has size ``n``; any number of
                leading (batch) dimensions is accepted, including none.

        Returns:
            Reconstruction with the same shape as ``x``.
        """
        h = x @ self.w  # (..., n) @ (n, m) -> (..., m)
        # Decode with the transpose of the same matrix: (..., m) @ (m, n) -> (..., n)
        return h @ self.w.T + self.b

In [10]:
import numpy as np
import torch

def generate_synthetic_data(batch_size, num_features, sparsity):
    """Draw a batch of sparse, non-negative feature vectors.

    Every entry starts as a uniform sample in [0, 1) and is then zeroed
    independently with probability ``sparsity`` (it survives when a second
    uniform draw is >= ``sparsity``).

    Args:
        batch_size: number of rows to generate.
        num_features: number of columns (features) per row.
        sparsity: probability that any given entry is set to zero.

    Returns:
        float32 tensor of shape (batch_size, num_features).
    """
    shape = (batch_size, num_features)
    # Order matters for reproducibility: values are drawn before the keep-mask.
    values = np.random.rand(*shape)
    keep = np.random.rand(*shape) >= sparsity
    sparse_values = np.where(keep, values, 0.0)
    return torch.tensor(sparse_values, dtype=torch.float32)

In [11]:
# Fixed seed so weight initialisation and data sampling are reproducible
# under Restart & Run All (torch drives model init, numpy drives data).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

num_features = 20  # n: size of each input / reconstructed feature vector
hidden_dim = 5     # m: width of the bottleneck
sparsity = 0       # probability that each feature entry is zeroed in a sample
# Per-feature loss weights 0.9^1, 0.9^2, ..., 0.9^num_features (earlier features matter more).
importance = torch.tensor([0.9**i for i in range(1, num_features+1)])

In [12]:
# Build both architectures with identical dimensions (nonlinear first, so the
# random-init order is unchanged) so their training runs are comparable.
nonlinear_model = NonlinearModel(n=num_features, m=hidden_dim)
linear_model = LinearModel(n=num_features, m=hidden_dim)

In [13]:
# Training hyperparameters, read as globals by train_model.
learning_rate = 0.001  # Adam step size
batch_size = 1024      # freshly sampled examples per optimisation step
steps = 10_000         # maximum number of optimisation steps

def train_model(model, tol=1e-4, patience=10):
    """Fit ``model`` with Adam to reconstruct freshly sampled synthetic batches.

    Reads the notebook-level globals ``learning_rate``, ``batch_size``,
    ``steps``, ``num_features``, ``sparsity`` and ``importance``, and calls
    ``generate_synthetic_data`` to draw a new batch every step.

    The loss is the importance-weighted squared reconstruction error,
    averaged over both the batch and the feature dimension.

    Training stops early once the absolute change in loss between consecutive
    steps has been below ``tol`` for ``patience`` consecutive steps; otherwise
    it runs for ``steps`` steps. The final loss is printed in both cases.

    Args:
        model: module expected to return a tensor shaped like its input batch.
        tol: threshold on the step-to-step loss change counted as "stable".
        patience: number of consecutive stable steps required to stop early.

    Returns:
        The loss (float) of the last optimisation step, or None if ``steps`` is 0.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    prev_loss = float('inf')
    stable_steps = 0
    loss_value = None

    # The context manager closes the bar even when we stop early, instead of
    # leaving a stalled, unfinished progress bar in the output.
    with tqdm(range(steps), position=0, leave=True) as pbar:
        for _ in pbar:
            batch = generate_synthetic_data(batch_size, num_features, sparsity)
            optimizer.zero_grad()
            output = model(batch)
            loss = torch.mean(importance * (batch - output)**2)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()  # extract once per step (one host sync)
            pbar.set_description(f"Loss: {loss_value:.4f}")

            # Convergence check on consecutive (noisy) minibatch losses.
            if abs(prev_loss - loss_value) < tol:
                stable_steps += 1
                if stable_steps >= patience:
                    break
            else:
                stable_steps = 0
            prev_loss = loss_value

    if loss_value is not None:
        print(f"Loss: {loss_value:.4f}")
    return loss_value

# Train both models in turn, nonlinear first, each under its own banner.
for banner, model_to_train in (
    ("Training NonlinearModel:", nonlinear_model),
    ("\nTraining LinearModel:", linear_model),
):
    print(banner)
    train_model(model_to_train)

Training NonlinearModel:


Loss: 0.0176:  98%|█████████▊| 9828/10000 [00:27<00:00, 356.79it/s]


Loss: 0.0176

Training LinearModel:


Loss: 0.0178: 100%|██████████| 10000/10000 [00:26<00:00, 376.01it/s]


In [14]:
import plotly.graph_objects as go

# Gram matrix W W^T of the linear model's weights. W has shape (n, m), so this
# is (n, n): entry (i, j) is the dot product of the hidden-space vectors that
# features i and j are mapped to (row i of W, since h = x @ W).
w_lin = linear_model.w.detach().cpu()
matrix = (w_lin @ w_lin.T).numpy()

# Diverging colour scale centred at 0 so positive/negative overlaps are symmetric.
fig = go.Figure(data=go.Heatmap(z=matrix, colorscale='RdBu', zmid=0))
fig.update_layout(
    title="LinearModel: W W^T",
    xaxis_title="feature j",
    yaxis_title="feature i",
)
# Plotly draws row 0 at the bottom; reverse y so the heatmap reads like the matrix.
fig.update_yaxes(autorange="reversed")
fig.show()

In [17]:
import plotly.graph_objects as go

# Heatmap needs a 2-D array: show the bias as an (n, 1) column, one row per feature.
bias_2d = linear_model.b.detach().cpu().numpy().reshape(-1, 1)

fig = go.Figure(data=go.Heatmap(z=bias_2d, colorscale='RdBu', zmid=0))
fig.update_layout(title="LinearModel: bias b", yaxis_title="feature i")
# A single column has no meaningful x coordinate.
fig.update_xaxes(showticklabels=False)
# Put feature 0 at the top, matching matrix (row) order.
fig.update_yaxes(autorange="reversed")
fig.show()