In [1]:
from lightning import Fabric
import numpy as np
from sklearn.linear_model import SGDRegressor
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm, trange

from causalpruner.sgd_pruner import ParamDataset

torch.set_float32_matmul_precision('medium')

In [2]:
fb = Fabric()
fb.launch()

In [3]:
ds = ParamDataset("./checkpoints/mlpnet_mnist_causalpruner_1_10_5_0_0.98/weights/0",
                  "./checkpoints/mlpnet_mnist_causalpruner_1_10_5_0_0.98/loss/0",
                  True)
dl = DataLoader(ds, batch_size=len(ds))
dl = fb.setup_dataloaders(dl)

2.8429925514467103e-12; 1.6184477369396433e-11
-8.992925722850487e-05; 3.0292972951428965e-05


In [4]:
X, Y = next(iter(dl))

In [7]:
topk = torch.topk(torch.mean(X, dim=0), k=1, largest=True)
print(topk)

torch.return_types.topk(
values=tensor([30.3825], device='cuda:0'),
indices=tensor([31554], device='cuda:0'))


In [8]:
X[:,topk.indices[0]]

tensor([ 3.2132, 18.6051, 25.2001,  ...,  8.6268, 10.5456, 22.4080],
       device='cuda:0')

In [9]:
Y

tensor([1.0002, 1.2474, 0.8577,  ..., 0.9133, 1.0585, 1.8976], device='cuda:0')

In [10]:
X_abs = torch.abs(X)
Y_abs = torch.abs(Y)

print(f"{torch.max(X_abs)}; {torch.min(X_abs)}")
print(f"{torch.max(Y_abs)}; {torch.min(Y_abs)}")

526.0156860351562; 2.666211912583094e-06
2.7532496452331543; 0.2631295323371887


In [11]:
dl = DataLoader(ds, batch_size=16, shuffle=True)
dl = fb.setup_dataloaders(dl)

In [12]:
num_params = ds.weights_zstats.num_params
l1_regularization_coeff = 0
lr = 1e-3

In [13]:
ds.weights_zstats

ZStats(num_params=32360, mean=tensor([4.0807e-13, 4.0808e-13, 4.0808e-13,  ..., 1.5538e-11, 7.5351e-13,
        4.9722e-12]), std=tensor([6.3572e-13, 6.3571e-13, 6.3571e-13,  ..., 3.3823e-11, 4.1001e-12,
        9.4806e-12]), global_mean=tensor(2.8430e-12), global_std=tensor(1.6184e-11))

In [14]:
ds.loss_zstats

ZStats(num_params=1, mean=tensor(-8.9929e-05), std=tensor(3.0293e-05), global_mean=tensor(-8.9929e-05), global_std=tensor(3.0293e-05))

In [15]:
num_epochs = int(np.ceil(np.log(num_params / len(ds))))
print(num_epochs)
num_epochs=20

4


In [16]:
model = nn.Linear(num_params, 1, bias=False)
nn.init.zeros_(model.weight)
abs_weight = torch.abs(model.weight)
print(f"{torch.max(abs_weight)}; {torch.min(abs_weight)}")
optimizer = optim.SGD(model.parameters(), lr=lr)
model, optimizer = fb.setup(model, optimizer)
model.train()
dl = DataLoader(ds, batch_size=16, shuffle=True)
dl = fb.setup_dataloaders(dl)
dl_iter = iter(dl)

0.0; 0.0


In [None]:
model.train()
for epoch in trange(100):
    total_loss = 0
    num_batches = 0
    for idx, (X, Y) in enumerate(tqdm(dl)):
        optimizer.zero_grad(set_to_none=True)
        outputs = model(X)
        Y = Y.view(outputs.size())
        loss = F.mse_loss(outputs, Y, reduction="mean")
        total_loss += loss.item()
        num_batches += 1
        tqdm.write(f"epoch: {epoch + 1}; batch: {idx + 1}; Loss: {loss.item()}")
        fb.backward(loss)
        optimizer.step()
    avg_loss = total_loss / num_batches
    tqdm.write(f"epoch: {epoch + 1}; avg_loss: {avg_loss}")

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/74 [00:00<?, ?it/s]

epoch: 1; batch: 1; Loss: 0.9917110800743103
epoch: 1; batch: 2; Loss: 263.8990783691406
epoch: 1; batch: 3; Loss: 112447.390625
epoch: 1; batch: 4; Loss: 34414484.0
epoch: 1; batch: 5; Loss: 7719519232.0
epoch: 1; batch: 6; Loss: 1601117880320.0
epoch: 1; batch: 7; Loss: 438176455852032.0
epoch: 1; batch: 8; Loss: 6.725357003813683e+16
epoch: 1; batch: 9; Loss: 1.3028591565075907e+19
epoch: 1; batch: 10; Loss: 4.4868980281523925e+21
epoch: 1; batch: 11; Loss: 1.0956368702681989e+24
epoch: 1; batch: 12; Loss: 1.9223794721070493e+26
epoch: 1; batch: 13; Loss: 4.0755991973988395e+28
epoch: 1; batch: 14; Loss: 1.7680411965727073e+31
epoch: 1; batch: 15; Loss: 8.937979696092115e+33
epoch: 1; batch: 16; Loss: 3.001955077665476e+36
epoch: 1; batch: 17; Loss: inf
epoch: 1; batch: 18; Loss: inf
epoch: 1; batch: 19; Loss: inf
epoch: 1; batch: 20; Loss: inf
epoch: 1; batch: 21; Loss: inf
epoch: 1; batch: 22; Loss: inf
epoch: 1; batch: 23; Loss: inf
epoch: 1; batch: 24; Loss: inf
epoch: 1; batch:

  0%|          | 0/74 [00:00<?, ?it/s]

epoch: 2; batch: 1; Loss: nan
epoch: 2; batch: 2; Loss: nan
epoch: 2; batch: 3; Loss: nan
epoch: 2; batch: 4; Loss: nan
epoch: 2; batch: 5; Loss: nan
epoch: 2; batch: 6; Loss: nan
epoch: 2; batch: 7; Loss: nan
epoch: 2; batch: 8; Loss: nan
epoch: 2; batch: 9; Loss: nan
epoch: 2; batch: 10; Loss: nan
epoch: 2; batch: 11; Loss: nan
epoch: 2; batch: 12; Loss: nan
epoch: 2; batch: 13; Loss: nan
epoch: 2; batch: 14; Loss: nan
epoch: 2; batch: 15; Loss: nan
epoch: 2; batch: 16; Loss: nan
epoch: 2; batch: 17; Loss: nan
epoch: 2; batch: 18; Loss: nan
epoch: 2; batch: 19; Loss: nan
epoch: 2; batch: 20; Loss: nan
epoch: 2; batch: 21; Loss: nan
epoch: 2; batch: 22; Loss: nan
epoch: 2; batch: 23; Loss: nan
epoch: 2; batch: 24; Loss: nan
epoch: 2; batch: 25; Loss: nan
epoch: 2; batch: 26; Loss: nan
epoch: 2; batch: 27; Loss: nan
epoch: 2; batch: 28; Loss: nan
epoch: 2; batch: 29; Loss: nan
epoch: 2; batch: 30; Loss: nan
epoch: 2; batch: 31; Loss: nan
epoch: 2; batch: 32; Loss: nan
epoch: 2; batch: 

  0%|          | 0/74 [00:00<?, ?it/s]

epoch: 3; batch: 1; Loss: nan
epoch: 3; batch: 2; Loss: nan
epoch: 3; batch: 3; Loss: nan
epoch: 3; batch: 4; Loss: nan
epoch: 3; batch: 5; Loss: nan
epoch: 3; batch: 6; Loss: nan
epoch: 3; batch: 7; Loss: nan
epoch: 3; batch: 8; Loss: nan
epoch: 3; batch: 9; Loss: nan
epoch: 3; batch: 10; Loss: nan
epoch: 3; batch: 11; Loss: nan
epoch: 3; batch: 12; Loss: nan
epoch: 3; batch: 13; Loss: nan
epoch: 3; batch: 14; Loss: nan
epoch: 3; batch: 15; Loss: nan
epoch: 3; batch: 16; Loss: nan
epoch: 3; batch: 17; Loss: nan
epoch: 3; batch: 18; Loss: nan
epoch: 3; batch: 19; Loss: nan
epoch: 3; batch: 20; Loss: nan
epoch: 3; batch: 21; Loss: nan
epoch: 3; batch: 22; Loss: nan
epoch: 3; batch: 23; Loss: nan
epoch: 3; batch: 24; Loss: nan
epoch: 3; batch: 25; Loss: nan
epoch: 3; batch: 26; Loss: nan
epoch: 3; batch: 27; Loss: nan
epoch: 3; batch: 28; Loss: nan
epoch: 3; batch: 29; Loss: nan
epoch: 3; batch: 30; Loss: nan
epoch: 3; batch: 31; Loss: nan
epoch: 3; batch: 32; Loss: nan
epoch: 3; batch: 

  0%|          | 0/74 [00:00<?, ?it/s]

epoch: 4; batch: 1; Loss: nan
epoch: 4; batch: 2; Loss: nan
epoch: 4; batch: 3; Loss: nan
epoch: 4; batch: 4; Loss: nan
epoch: 4; batch: 5; Loss: nan
epoch: 4; batch: 6; Loss: nan
epoch: 4; batch: 7; Loss: nan
epoch: 4; batch: 8; Loss: nan
epoch: 4; batch: 9; Loss: nan
epoch: 4; batch: 10; Loss: nan
epoch: 4; batch: 11; Loss: nan
epoch: 4; batch: 12; Loss: nan
epoch: 4; batch: 13; Loss: nan
epoch: 4; batch: 14; Loss: nan
epoch: 4; batch: 15; Loss: nan
epoch: 4; batch: 16; Loss: nan
epoch: 4; batch: 17; Loss: nan
epoch: 4; batch: 18; Loss: nan
epoch: 4; batch: 19; Loss: nan
epoch: 4; batch: 20; Loss: nan
epoch: 4; batch: 21; Loss: nan
epoch: 4; batch: 22; Loss: nan
epoch: 4; batch: 23; Loss: nan
epoch: 4; batch: 24; Loss: nan
epoch: 4; batch: 25; Loss: nan
epoch: 4; batch: 26; Loss: nan
epoch: 4; batch: 27; Loss: nan
epoch: 4; batch: 28; Loss: nan
epoch: 4; batch: 29; Loss: nan
epoch: 4; batch: 30; Loss: nan
epoch: 4; batch: 31; Loss: nan
epoch: 4; batch: 32; Loss: nan
epoch: 4; batch: 

  0%|          | 0/74 [00:00<?, ?it/s]

epoch: 5; batch: 1; Loss: nan
epoch: 5; batch: 2; Loss: nan
epoch: 5; batch: 3; Loss: nan
epoch: 5; batch: 4; Loss: nan
epoch: 5; batch: 5; Loss: nan
epoch: 5; batch: 6; Loss: nan
epoch: 5; batch: 7; Loss: nan
epoch: 5; batch: 8; Loss: nan
epoch: 5; batch: 9; Loss: nan
epoch: 5; batch: 10; Loss: nan
epoch: 5; batch: 11; Loss: nan
epoch: 5; batch: 12; Loss: nan
epoch: 5; batch: 13; Loss: nan
epoch: 5; batch: 14; Loss: nan
epoch: 5; batch: 15; Loss: nan
epoch: 5; batch: 16; Loss: nan
epoch: 5; batch: 17; Loss: nan
epoch: 5; batch: 18; Loss: nan
epoch: 5; batch: 19; Loss: nan
epoch: 5; batch: 20; Loss: nan
epoch: 5; batch: 21; Loss: nan
epoch: 5; batch: 22; Loss: nan
epoch: 5; batch: 23; Loss: nan
epoch: 5; batch: 24; Loss: nan
epoch: 5; batch: 25; Loss: nan
epoch: 5; batch: 26; Loss: nan
epoch: 5; batch: 27; Loss: nan
epoch: 5; batch: 28; Loss: nan
epoch: 5; batch: 29; Loss: nan
epoch: 5; batch: 30; Loss: nan
epoch: 5; batch: 31; Loss: nan
epoch: 5; batch: 32; Loss: nan
epoch: 5; batch: 

  0%|          | 0/74 [00:00<?, ?it/s]

epoch: 6; batch: 1; Loss: nan
epoch: 6; batch: 2; Loss: nan
epoch: 6; batch: 3; Loss: nan
epoch: 6; batch: 4; Loss: nan
epoch: 6; batch: 5; Loss: nan
epoch: 6; batch: 6; Loss: nan
epoch: 6; batch: 7; Loss: nan
epoch: 6; batch: 8; Loss: nan
epoch: 6; batch: 9; Loss: nan
epoch: 6; batch: 10; Loss: nan
epoch: 6; batch: 11; Loss: nan
epoch: 6; batch: 12; Loss: nan
epoch: 6; batch: 13; Loss: nan
epoch: 6; batch: 14; Loss: nan
epoch: 6; batch: 15; Loss: nan
epoch: 6; batch: 16; Loss: nan
epoch: 6; batch: 17; Loss: nan
epoch: 6; batch: 18; Loss: nan
epoch: 6; batch: 19; Loss: nan
epoch: 6; batch: 20; Loss: nan
epoch: 6; batch: 21; Loss: nan
epoch: 6; batch: 22; Loss: nan
epoch: 6; batch: 23; Loss: nan
epoch: 6; batch: 24; Loss: nan
epoch: 6; batch: 25; Loss: nan
epoch: 6; batch: 26; Loss: nan
epoch: 6; batch: 27; Loss: nan
epoch: 6; batch: 28; Loss: nan
epoch: 6; batch: 29; Loss: nan
epoch: 6; batch: 30; Loss: nan
epoch: 6; batch: 31; Loss: nan
epoch: 6; batch: 32; Loss: nan
epoch: 6; batch: 

In [22]:
torch.max(torch.abs(model.weight))

tensor(5.0420e-09, device='cuda:0', grad_fn=<MaxBackward1>)

In [23]:
torch.min(torch.abs(model.weight))

tensor(0., device='cuda:0', grad_fn=<MinBackward1>)