In [1]:
import torch.nn as nn
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [12]:
class SparseLinearDataset(Dataset):
    """Synthetic sparse linear-regression data: ``y = x @ w + noise``.

    ``w`` (``self.true_weight``) has length ``total_sequence_length`` and is
    non-zero only on ``self.active_set``, a random subset of ``sparsity``
    indices. Inputs and weights are truncated to integers with ``.int()``
    before the label is computed, so a sampled weight whose magnitude is
    below 1 becomes 0 and the effective support can be smaller than
    ``sparsity``.

    Each item is ``(x, y)`` where ``x`` is a float32 tensor of shape
    ``(total_sequence_length,)`` and ``y`` is a 0-dim float32 tensor, for
    both input distributions.

    Args:
        total_sequence_length: Dimension of every input vector.
        sparsity: Number of active (potentially non-zero) weights.
        num_samples: Number of ``(x, y)`` pairs generated up front.
        noise_std: Scale applied to the standard-normal label noise.
        input_dist: ``"gaussian"`` or ``"uniform"``.
        input_std: Scale of the inputs when ``input_dist == "gaussian"``.
        input_range: ``(low, high)`` used when ``input_dist == "uniform"``.
        true_weight_dist: ``"gaussian"`` or ``"uniform"``; ignored when
            ``coefficient`` is given.
        weight_std: Scale of the weights when ``true_weight_dist == "gaussian"``.
        weight_range: ``(low, high)`` used when ``true_weight_dist == "uniform"``.
        coefficient: Optional explicit values for the active weights, assigned
            in the order of ``self.active_set``; must contain exactly
            ``sparsity`` entries.

    Raises:
        ValueError: If ``sparsity`` is not in ``[0, total_sequence_length]`` or
            ``coefficient`` does not have ``sparsity`` entries.
        AssertionError: If a distribution name is not recognised.
    """

    def __init__(self,
                 total_sequence_length: int = 200,
                 sparsity: int = 6,
                 num_samples: int = 1000,
                 noise_std: float = 0.1,
                 input_dist: str = "gaussian",
                 input_std: int = 5,
                 input_range: tuple[float, float] = (-10, 10),
                 true_weight_dist: str = "gaussian",
                 weight_std: int = 5,
                 weight_range: tuple[float, float] = (-10, 10),
                 coefficient: list[int] = None):
        super().__init__()

        if not 0 <= sparsity <= total_sequence_length:
            raise ValueError(
                f"sparsity must be in [0, {total_sequence_length}], got {sparsity}."
            )

        self.num_samples = num_samples
        self.active_set = torch.randperm(total_sequence_length)[:sparsity]
        self.true_weight = torch.zeros(total_sequence_length).int()

        if coefficient is not None:
            # A plain Python list cannot be assigned into a tensor index, so
            # convert it first; this also lets the length be validated.
            true_weights = torch.as_tensor(coefficient).reshape(-1)
            if true_weights.numel() != sparsity:
                raise ValueError(
                    f"coefficient must have {sparsity} entries, got {true_weights.numel()}."
                )
        elif true_weight_dist == "gaussian":
            true_weights = torch.randn(sparsity) * weight_std
        else:
            assert true_weight_dist == "uniform", f"unknown distribution {true_weight_dist}."
            low, high = weight_range
            true_weights = torch.empty(sparsity).uniform_(low, high)
        # Truncate toward zero so the ground-truth weights are integers.
        self.true_weight[self.active_set] = true_weights.int()

        if input_dist == "gaussian":
            def draw_input():
                return torch.randn(total_sequence_length) * input_std
        else:
            assert input_dist == "uniform", f"unknown input distribution {input_dist}."
            in_low, in_high = input_range

            def draw_input():
                return torch.empty(total_sequence_length).uniform_(in_low, in_high)

        self.data = []
        for _ in range(num_samples):
            # Draw the input first and the noise second for every sample.
            x = draw_input().int()
            y = x @ self.true_weight + noise_std * torch.randn(1)
            # Store the label as a 0-dim float32 tensor in both branches so a
            # batch of labels stacks to shape (batch,) in float32.
            self.data.append((x.float(), y.float().squeeze(0)))

    def __len__(self):
        """Return the number of pre-generated samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair stored at position ``idx``."""
        return self.data[idx]


In [13]:
class linear_network(nn.Module):
    """Single affine layer mapping ``input_dim`` features to one scalar.

    For a 2-D input of shape ``(batch, input_dim)`` the forward pass returns
    a tensor of shape ``(batch,)``.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The attribute name determines the state_dict keys, so it is kept.
        self.network = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Apply the layer and drop the size-1 output axis (dim 1)."""
        out = self.network(x)
        return out.squeeze(dim=1)

In [14]:
class linear_network_hidden(nn.Module):
    """One-hidden-layer ReLU network producing one scalar per input row.

    Layers are kept in an ``nn.ModuleList`` in the order
    ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, 1)``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        hidden = nn.Linear(input_dim, hidden_dim)
        activation = nn.ReLU()
        readout = nn.Linear(hidden_dim, 1)
        self.network = nn.ModuleList([hidden, activation, readout])

    def forward(self, x):
        """Run the three stages in order and drop the size-1 output axis (dim 1)."""
        hidden, activation, readout = self.network
        return readout(activation(hidden(x))).squeeze(dim=1)

In [15]:
# (total_sequence_length, sparsity) pairs swept by the experiments below.
configurations = [
    (20, 6),
    (200, 6),
    (2000, 6)
]
runs = 100      # independent trials per configuration
epochs = 4000   # maximum number of training epochs per trial
patience = 200  # early stopping: epochs tolerated without a new best loss

# Seed the torch RNG so dataset sampling, model initialisation and batch
# shuffling are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

### Case 1: noiseless labels; inputs and true weights both i.i.d. Gaussian

In [None]:
# Case 1 experiment: for every (n, k) configuration, run `runs` independent
# trials. Each trial samples a fresh noiseless dataset, fits a linear model
# with Adam, and stops early once the epoch loss has failed to improve for
# more than `patience` consecutive epochs.
loss_fn = nn.MSELoss()
case1_results = []  # one summary dict per (configuration, run)

for (n, k) in configurations:
    for run_idx in range(runs):
        dataset = SparseLinearDataset(total_sequence_length=n, sparsity=k, num_samples=16000, noise_std=0)
        dataloader = DataLoader(dataset, batch_size=200, shuffle=True)
        true_active_set = dataset.active_set
        model = linear_network(n)
        optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

        best_loss = float("inf")
        epochs_without_improvement = 0  # for early stopping

        model.train()
        for idx in range(epochs):
            total_loss = 0
            for (x, y) in dataloader:
                preds = model(x)
                # The collated target may carry a trailing singleton axis or a
                # different float dtype than `preds`. Align both so MSELoss
                # compares element-wise instead of broadcasting to
                # (batch, batch).
                target = y.to(preds.dtype).reshape(preds.shape)
                loss = loss_fn(preds, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            avg_loss = total_loss / len(dataloader)
            if avg_loss < best_loss:
                best_loss = avg_loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
            if epochs_without_improvement > patience:
                print(f"Early stopping at Epoch {idx}, with loss {avg_loss:.4f}.")
                break
            if idx % 20 == 0:
                print(f"Epoch {idx:02d} - Loss: {avg_loss:.4f}")

        # Keep what is needed to compare the learned weights with the true
        # support afterwards; otherwise every trained model is discarded.
        case1_results.append({
            "n": n,
            "k": k,
            "run": run_idx,
            "best_loss": best_loss,
            "true_active_set": true_active_set,
            "final_weight": model.network.weight.detach().clone().squeeze(0),
        })
        

### Case 2: noiseless labels; inputs Gaussian, true weights uniform

### Case 3: noiseless labels; inputs and true weights both uniform

### Case 4: noiseless labels; inputs uniform, true weights Gaussian

### Case 5: noiseless labels; inputs uniform, true weights 1, -1, ...

### Case 6: noiseless labels; inputs Gaussian, true weights 1, -1, ...