In [3]:
import torch.nn as nn
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [40]:
class SparseLinearDataset(Dataset):
    def __init__(self, 
                 total_sequence_length: int = 200, 
                 sparsity: int = 6,
                 num_samples: int = 1000, 
                 noise_std: float = 0.1,
                 input_dist: str = "gaussian",
                 input_std:int = 5,
                 input_range:tuple[float, float] = (-10, 10),
                 true_weight_dist: str = "gaussian",
                 weight_std:int = 5,
                 weight_range:tuple[float, float] = (-10, 10),
                 coefficient:list[int] = None):
        super().__init__()

        self.active_set = torch.randperm(total_sequence_length)[:sparsity]
        self.true_weight = torch.zeros(total_sequence_length).int()
        self.num_samples = num_samples
        if coefficient is not None:
            true_weights = torch.Tensor(coefficient)
        else:
            if true_weight_dist == "gaussian":
                true_weights = torch.randn(sparsity) * weight_std
            else:
                assert true_weight_dist == "uniform", f"unknown distribution {true_weight_dist}."
                low, high = weight_range
                true_weights = torch.empty(sparsity).uniform_(low, high)
        true_weights = true_weights.int()
        self.true_weight[self.active_set] = true_weights
        self.data = []
        if input_dist == "gaussian":
            for _ in range(num_samples):
                x = torch.randint(0, 10, (total_sequence_length,)).int()
                y = x @ self.true_weight + noise_std * torch.randn(1)
                x = x.float()
                y = y.float()
                self.data.append((x, y))
        else:
            assert input_dist == "uniform", f"unknown input distribution {input_dist}."
            low, high = input_range
            for _ in range(num_samples):
                x = torch.empty(total_sequence_length).uniform_(low, high).float()
                x = x.int()               
                y = x @ self.true_weight + noise_std * torch.randn(1)
                x = x.float()
                y = y.float()
                self.data.append((x, y.item()))


    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]


In [41]:
class linear_network(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.network = nn.Linear(in_features=input_dim, out_features=1)
    def forward(self, x):
        return self.network(x).squeeze(1)

In [42]:
class linear_network_hidden(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.network = nn.ModuleList([
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        ])
    def forward(self, x):
        for layer in self.network:
            x = layer(x)
        return x.squeeze(1)

In [65]:
configurations = [
    (50, 6),
    (400, 6),
    (1000, 6)
]
runs = 1
epochs = 500

In [66]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [67]:
import matplotlib.pyplot as plt
import numpy as np

def plot_weight_comparison(learned_weights, true_weights, true_active_set, title="", save_path=None):
    n = len(learned_weights)
    x = np.arange(n)

    # Build color scheme
    bar_colors = ['red' if i in true_active_set else 'blue' for i in range(n)]
    bar_alpha = [1.0 if i in true_active_set else 0.3 for i in range(n)]

    plt.figure(figsize=(12, 3))

    # Plot learned weights
    for i in range(n):
        plt.bar(i, learned_weights[i], color=bar_colors[i], alpha=bar_alpha[i], width=0.8)

    # Overlay true weights with dashed black outlines
    plt.plot(x, true_weights, color='black', linestyle='--', linewidth=1.5, label='True Weights')

    # Highlight true active set weights
    plt.scatter(true_active_set, true_weights[true_active_set], color='black', zorder=5)

    plt.title(title)
    plt.xlabel("Weight Index")
    plt.ylabel("Value")
    plt.legend()
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()


### Case 1: noiseless label, inputs and true weights being both i.i.d. gaussian

In [68]:

for (n,k) in configurations:
    losses_across_runs = []
    stored_weights = 0
    for run_id in range(runs):
        losses = []
        dataset = SparseLinearDataset(total_sequence_length=n, sparsity=k, num_samples=1600, noise_std=0)
        dataloader = DataLoader(dataset, shuffle=True)
        true_active_set = dataset.active_set
        true_weight = dataset.true_weight
        model = linear_network(n)
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr = 5e-3)

        model.train()
        loss_fn = nn.MSELoss()
        for idx in range(epochs):
            total_loss = 0
            for (x, y) in dataloader:
                x = x.to(device)
                y = y.to(device)
                preds = model(x)
                loss = loss_fn(preds, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            avg_loss = total_loss / len(dataloader)

            if idx % 20 == 0:
                print(f"Epoch {idx:02d} - Loss: {avg_loss:.4f}")
            losses.append(avg_loss)
        losses_across_runs.append(losses)
        # After training completes per run:
        learned_weights = model.network.weight.data.detach().cpu().squeeze()

        plot_weight_comparison(
            learned_weights=learned_weights.numpy(),
            true_weights=true_weight.numpy(),
            true_active_set=true_active_set,
            title=f"n={n}, k={k} - Learned vs True Weights",
            save_path=f"../log/plots/linear/n_{n}_k_{k}.png"
        )

    np.save(f"../log/n_{n}_k_{k}_50runs_linear.npy", np.array(losses_across_runs))
        

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 00 - Loss: 1061.6757
Epoch 20 - Loss: 0.1280
Epoch 40 - Loss: 0.1479
Epoch 60 - Loss: 0.2025
Epoch 80 - Loss: 0.1790
Epoch 100 - Loss: 0.1659
Epoch 120 - Loss: 0.2353
Epoch 140 - Loss: 0.2181
Epoch 160 - Loss: 0.1844
Epoch 180 - Loss: 0.1693
Epoch 200 - Loss: 0.1535
Epoch 220 - Loss: 0.1903
Epoch 240 - Loss: 0.2487
Epoch 260 - Loss: 0.1622
Epoch 280 - Loss: 0.1657
Epoch 300 - Loss: 0.2317
Epoch 320 - Loss: 0.0781
Epoch 340 - Loss: 0.1648
Epoch 360 - Loss: 0.1320
Epoch 380 - Loss: 0.1766
Epoch 400 - Loss: 0.1930
Epoch 420 - Loss: 0.2154
Epoch 440 - Loss: 0.2096
Epoch 460 - Loss: 0.2449
Epoch 480 - Loss: 0.1312
Epoch 00 - Loss: 875.8969
Epoch 20 - Loss: 10.4376
Epoch 40 - Loss: 12.2000
Epoch 60 - Loss: 8.9046
Epoch 80 - Loss: 13.6201
Epoch 100 - Loss: 12.3132
Epoch 120 - Loss: 11.6207
Epoch 140 - Loss: 13.9156
Epoch 160 - Loss: 8.6433
Epoch 180 - Loss: 11.7261
Epoch 200 - Loss: 10.8515
Epoch 220 - Loss: 10.3456
Epoch 240 - Loss: 11.0494
Epoch 260 - Loss: 11.5550
Epoch 280 - Loss: 1

### Case 2: noiseless label, inputs being gaussian, true weights being uniform

### Case 3: noiseless label, inputs and true weights being both uniform

### Case 4: noiseless label, inputs being uniform and true weights being gaussian

### Case 5: noiseless label, inputs being uniform and true weights being 1, -1, ...

### Case 6: noiseless label, inputs being gaussian and true weights being 1, -1, ...