In [1]:
import torch
from torch import nn
import einops
import numpy as np
from tqdm import tqdm

In [8]:
# Define models
class NonlinearModel(nn.Module):
    """Tied-weight autoencoder with a ReLU on the reconstruction.

    A single matrix ``w`` of shape ``(n, m)`` is used twice: once to project an
    ``n``-wide input down to ``m`` hidden units, and once (transposed) to map the
    hidden units back to ``n`` outputs. A bias ``b`` of length ``n`` is added
    before the ReLU.
    """

    def __init__(self, n, m):
        super().__init__()
        # Kaiming-normal initialisation happens in place and returns the same tensor.
        self.w = nn.Parameter(nn.init.kaiming_normal_(torch.empty(n, m)))
        self.b = nn.Parameter(torch.zeros(n))
        self.activ = nn.ReLU()

    def forward(self, x):
        """Return ``ReLU((x @ w) @ w.T + b)`` for a batch ``x`` of shape ``(b, n)``."""
        hidden = einops.einsum(x, self.w, "b n, n m -> b m")
        pre_activation = einops.einsum(self.w, hidden, "n m, b m -> b n") + self.b
        return self.activ(pre_activation)

class LinearModel(nn.Module):
    """Tied-weight linear autoencoder (no activation on the output).

    A single matrix ``w`` of shape ``(n, m)`` projects an ``n``-wide input down
    to ``m`` hidden units and, transposed, maps them back to ``n`` outputs. A
    bias ``b`` of length ``n`` is added to the reconstruction.
    """

    def __init__(self, n, m):
        super().__init__()
        # Xavier-normal initialisation happens in place and returns the same tensor.
        self.w = nn.Parameter(nn.init.xavier_normal_(torch.empty(n, m)))
        self.b = nn.Parameter(torch.zeros(n))

    def forward(self, x):
        """Return ``(x @ w) @ w.T + b`` for a batch ``x`` of shape ``(b, n)``."""
        hidden = einops.einsum(x, self.w, "b n, n m -> b m")
        reconstruction = einops.einsum(self.w, hidden, "n m, b m -> b n")
        return reconstruction + self.b

In [9]:
import numpy as np
import torch

def generate_synthetic_data(batch_size: int, num_features: int, sparsity: float, rng=None) -> torch.Tensor:
    """Draw a batch of sparse, uniformly distributed feature vectors.

    Every entry is sampled uniformly from [0, 1) and then independently zeroed
    when a second uniform draw falls below ``sparsity``.

    Args:
        batch_size: Number of rows in the batch.
        num_features: Number of columns (features) per row.
        sparsity: Threshold for the second uniform draw; entries whose draw is
            below it are zeroed. ``0`` keeps every entry, ``1`` zeroes them all.
        rng: Optional ``numpy.random.Generator`` (or any object with a
            ``random(size)`` method) to sample from. When ``None`` the legacy
            global NumPy stream is used, so ``np.random.seed`` still controls
            the output exactly as before.

    Returns:
        A ``float32`` tensor of shape ``(batch_size, num_features)``.
    """
    # `np.random.random` draws from the same global stream as `np.random.rand`,
    # so the default path is sample-for-sample identical to the legacy behaviour.
    source = np.random if rng is None else rng
    shape = (batch_size, num_features)

    # Order matters for reproducibility: values first, keep-mask second.
    data = source.random(shape)
    mask = (source.random(shape) >= sparsity).astype(float)
    data *= mask
    return torch.tensor(data, dtype=torch.float32)

In [10]:
# Experiment configuration
num_features = 20   # width of each input vector
hidden_dim = 5      # size of the hidden layer
sparsity = 0        # threshold passed to the data generator; 0 leaves every entry non-zero

# Per-feature loss weight, decaying geometrically: 0.9, 0.9^2, ..., 0.9^num_features
importance = torch.tensor([pow(0.9, k) for k in range(1, num_features + 1)])

In [16]:
# Fresh, randomly initialised models sharing the same input and hidden sizes
nonlinear_model = NonlinearModel(n=num_features, m=hidden_dim)
linear_model = LinearModel(n=num_features, m=hidden_dim)

In [17]:
# Optimisation hyperparameters, read as globals by the training loop
steps = 10_000          # upper bound on optimiser steps
batch_size = 1024       # samples drawn per step
learning_rate = 1e-3    # step size handed to the optimiser

def train_model(model, tol=1e-4, patience=10):
    """Fit ``model`` to reconstruct freshly sampled batches.

    The objective is the mean of ``importance * (batch - output)**2``. Training
    runs for at most ``steps`` iterations and stops early once the loss has
    changed by less than ``tol`` for ``patience`` consecutive steps.

    Relies on notebook-level globals: ``learning_rate``, ``batch_size``,
    ``steps``, ``num_features``, ``sparsity``, ``importance`` and
    ``generate_synthetic_data``.

    Args:
        model: Module called as ``model(batch)``; its output must broadcast
            against the input batch.
        tol: Two consecutive step losses closer than this count as "no change".
        patience: Number of consecutive "no change" steps that ends training.

    Returns:
        None. ``model`` is updated in place.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    prev_loss = float('inf')
    counter = 0
    # The context manager closes the progress bar on every exit path,
    # including the early return taken on convergence.
    with tqdm(range(steps), position=0, leave=True) as pbar:
        for i in pbar:
            batch = generate_synthetic_data(batch_size, num_features, sparsity)
            optimizer.zero_grad()
            output = model(batch)
            loss = torch.mean(importance * (batch - output)**2)
            loss.backward()
            optimizer.step()

            # Convert to a Python float once per step instead of on every use.
            loss_value = loss.item()
            pbar.set_description(f"Loss: {loss_value:.4f}")

            # Check convergence (each step uses a new random batch, so this
            # compares two noisy estimates of the loss).
            if abs(prev_loss - loss_value) < tol:
                counter += 1
                if counter >= patience:
                    print(f"Loss: {loss_value:.4f}")
                    return
            else:
                counter = 0
            if i % 1000 == 0:
                print(f"Loss: {loss_value:.4f}")
            prev_loss = loss_value

# Train both models in turn; the second header keeps its leading blank line
for header, net in (
    ("Training NonlinearModel:", nonlinear_model),
    ("\nTraining LinearModel:", linear_model),
):
    print(header)
    train_model(net)

Training NonlinearModel:
Loss: 0.8097
Loss: 0.2001
Loss: 0.1208
Loss: 0.1032
Loss: 0.0964
Loss: 0.0906
Loss: 0.0869
Loss: 0.0832
Loss: 0.0835
Loss: 0.0789
Loss: 0.0762
Loss: 0.0740
Loss: 0.0701
Loss: 0.0669
Loss: 0.0657
Loss: 0.0632
Loss: 0.0590
Loss: 0.0550
Loss: 0.0517
Loss: 0.0481
Loss: 0.0448
Loss: 0.0415
Loss: 0.0385
Loss: 0.0361
Loss: 0.0335
Loss: 0.0315
Loss: 0.0298
Loss: 0.0278
Loss: 0.0265
Loss: 0.0255
Loss: 0.0240
Loss: 0.0227
Loss: 0.0225
Loss: 0.0216
Loss: 0.0211
Loss: 0.0206
Loss: 0.0202
Loss: 0.0203
Loss: 0.0199
Loss: 0.0196
Loss: 0.0195
Loss: 0.0194
Loss: 0.0195
Loss: 0.0192
Loss: 0.0196
Loss: 0.0194
Loss: 0.0189
Loss: 0.0189
Loss: 0.0192
Loss: 0.0187
Loss: 0.0189
Loss: 0.0186
Loss: 0.0190
Loss: 0.0187
Loss: 0.0188
Loss: 0.0187
Loss: 0.0188
Loss: 0.0183
Loss: 0.0185
Loss: 0.0185
Loss: 0.0183
Loss: 0.0181
Loss: 0.0180
Loss: 0.0183
Loss: 0.0181
Loss: 0.0180
Loss: 0.0179
Loss: 0.0178
Loss: 0.0178
Loss: 0.0176
Loss: 0.0178
Loss: 0.0177
Loss: 0.0177
Loss: 0.0177
Loss: 0.0177


In [19]:
import plotly.graph_objects as go

# Gram matrix of the learned feature directions: entry (i, j) is the dot product
# of rows i and j of `w`. `w` is stored as (num_features, hidden_dim), so the
# product is w @ w.T (this is W^T W only under the convention W = w.T).
weights = linear_model.w.detach()
matrix = weights.mm(weights.t()).numpy()

# Diverging colour scale centred on zero so the sign of each entry is visible
fig = go.Figure(data=go.Heatmap(z=matrix, colorscale='RdBu', zmid=0))
fig.update_layout(
    title="LinearModel: w @ w.T",
    xaxis_title="feature index",
    yaxis_title="feature index",
)
fig.show()

In [25]:
import plotly.graph_objects as go

# The bias is 1-D; view it as a single column so it can be drawn as a heatmap
bias_2d = linear_model.b.detach().numpy().reshape(-1, 1)

# Diverging colour scale centred on zero so the sign of each bias is visible
fig = go.Figure(data=go.Heatmap(z=bias_2d, colorscale='RdBu', zmid=0))
fig.update_layout(
    title="LinearModel: bias",
    yaxis_title="feature index",
    xaxis_showticklabels=False,  # the single column carries no meaning
)
fig.show()