In [1]:
from lightning import Fabric
import numpy as np
from sklearn.linear_model import SGDRegressor
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm, trange

from causalpruner.sgd_pruner import ParamDataset

# Select PyTorch's 'medium' float32 matmul precision mode for the whole session.
# NOTE(review): per the PyTorch docs this mode may trade matmul precision for
# speed -- confirm that is acceptable for this analysis.
torch.set_float32_matmul_precision('medium')

In [2]:
# Create a Lightning Fabric instance with default settings and launch it.
# `fb` is the handle other cells use to set up dataloaders, the model/optimizer
# pair, and to run the backward pass.
fb = Fabric()
fb.launch()

In [3]:
# Build the ParamDataset from the saved weights path and loss path of one run
# under ./checkpoints.
# NOTE(review): the meaning of the positional arguments `1e-3` and `False` is
# not visible in this notebook -- check ParamDataset's signature; passing them
# as keyword arguments would make this call self-documenting.
ds = ParamDataset("./checkpoints/mlpnet_mnist_causalpruner_1_5_1_0.001_0.9/weights/0",
                  "./checkpoints/mlpnet_mnist_causalpruner_1_5_1_0.001_0.9/loss/0",
                  1e-3,
                  False)

In [73]:
# Full-batch loader: batch_size=len(ds) makes the loader yield the whole
# dataset as a single batch. The loader is then handed to Fabric.
dl = DataLoader(ds, batch_size=len(ds))
dl = fb.setup_dataloaders(dl)

In [74]:
# Pull the first (X, Y) batch from whatever loader `dl` currently refers to.
# With a full-batch loader this is the entire dataset.
X, Y = next(iter(dl))

In [75]:
# Display the input batch X.
X

tensor([[3.1209e-10, 2.8103e-10, 3.1312e-10,  ..., 1.7826e-04, 3.0727e-05,
         1.2710e-05],
        [9.4209e-09, 9.3814e-09, 9.4265e-09,  ..., 1.9620e-06, 7.6819e-06,
         4.0881e-06],
        [6.6822e-08, 6.7033e-08, 6.6852e-08,  ..., 9.3237e-05, 3.1727e-05,
         6.2297e-08],
        ...,
        [3.7989e-08, 3.8251e-08, 3.7978e-08,  ..., 8.8818e-10, 6.4171e-06,
         3.3979e-05],
        [9.3884e-08, 9.3314e-08, 9.3884e-08,  ..., 2.0978e-04, 1.8417e-05,
         1.0495e-10],
        [3.6202e-06, 3.6167e-06, 3.6202e-06,  ..., 5.3749e-05, 4.3380e-05,
         3.0852e-06]], device='cuda:0')

In [76]:
# Display the target batch Y.
Y

tensor([1.0924, 0.7842, 1.0919, 1.0024, 1.2729, 0.8606, 1.6258, 1.0034, 0.7199,
        0.8807, 0.7204, 0.9818, 1.0346, 1.3488, 1.5137, 0.7802, 0.6877, 0.8823,
        0.7591, 1.0939, 0.6203, 1.6413, 0.7380, 0.9551, 0.8566, 0.7717, 1.0848,
        1.2427, 0.7913, 0.8179, 1.6524, 1.0175, 1.0979, 1.0255, 0.7802, 1.0074,
        1.0260, 1.1618, 0.9994, 0.8802, 0.6329, 1.0185, 1.0109, 0.6294, 1.7057,
        1.1351, 1.0436, 0.6540, 0.8631, 0.8028, 0.7269, 0.6922, 0.7189, 1.0054,
        1.0245, 0.8647, 0.9652, 0.8833, 0.9084, 0.5907, 1.0652, 0.9370, 0.8541,
        0.9064, 0.7249, 0.7154, 0.6394, 0.9094, 0.6193, 0.9592, 1.4342, 1.4126,
        0.8933, 0.8732, 0.8968, 1.0632, 1.1432, 0.8063, 1.5021, 1.4347, 0.5922,
        0.8717, 0.5706, 0.7023, 0.9561, 1.4664, 1.3930, 0.6465, 1.1477, 0.7495,
        0.5510, 1.3427, 0.7415, 0.8184, 0.8596, 1.1542, 1.4734, 0.7626, 1.0135,
        0.6937, 1.1934, 0.7289, 0.8340, 1.1276, 0.8380, 0.8325, 0.5615, 1.0949,
        0.9903, 0.8144, 1.1366, 0.7410, 

In [7]:
# Average X over dim 0 (one mean per column) and keep the single largest mean
# together with its column index.
topk = X.mean(dim=0).topk(k=1, largest=True)
print(topk)

torch.return_types.topk(
values=tensor([30.3825], device='cuda:0'),
indices=tensor([31554], device='cuda:0'))


In [8]:
# Show, for every row of X, the value in the column selected by `topk`.
X[:,topk.indices[0]]

tensor([ 3.2132, 18.6051, 25.2001,  ...,  8.6268, 10.5456, 22.4080],
       device='cuda:0')

In [9]:
# Display the target batch Y.
Y

tensor([1.0002, 1.2474, 0.8577,  ..., 0.9133, 1.0585, 1.8976], device='cuda:0')

In [10]:
# Magnitude ranges: print "<max>; <min>" of the absolute values, first for the
# inputs X and then for the targets Y.
X_abs = X.abs()
Y_abs = Y.abs()

for magnitudes in (X_abs, Y_abs):
    print(f"{magnitudes.max()}; {magnitudes.min()}")

526.0156860351562; 2.666211912583094e-06
2.7532496452331543; 0.2631295323371887


In [11]:
# Rebind `dl` to a shuffled mini-batch loader (16 samples per batch) and hand
# it to Fabric.
dl = DataLoader(ds, batch_size=16, shuffle=True)
dl = fb.setup_dataloaders(dl)

In [12]:
# Hyper-parameters for the manual linear fit.
# num_params: parameter count reported by the dataset's weight statistics; used
# as the input width of the nn.Linear model.
num_params = ds.weights_zstats.num_params
# NOTE(review): l1_regularization_coeff is not referenced by any cell visible
# in this notebook.
l1_regularization_coeff = 0
# Learning rate passed to optim.SGD.
lr = 1e-3

In [13]:
# Display the dataset's ZStats record for the weights.
ds.weights_zstats

ZStats(num_params=32360, mean=tensor([4.0807e-13, 4.0808e-13, 4.0808e-13,  ..., 1.5538e-11, 7.5351e-13,
        4.9722e-12]), std=tensor([6.3572e-13, 6.3571e-13, 6.3571e-13,  ..., 3.3823e-11, 4.1001e-12,
        9.4806e-12]), global_mean=tensor(2.8430e-12), global_std=tensor(1.6184e-11))

In [14]:
# Display the dataset's ZStats record for the loss.
ds.loss_zstats

ZStats(num_params=1, mean=tensor(-8.9929e-05), std=tensor(3.0293e-05), global_mean=tensor(-8.9929e-05), global_std=tensor(3.0293e-05))

In [15]:
# Heuristic epoch count: ceil(ln(num_params / len(ds))), printed for reference.
num_epochs = int(np.ceil(np.log(num_params / len(ds))))
print(num_epochs)
# NOTE(review): the heuristic value printed above is immediately replaced by a
# hard-coded 20, so the printed number is not the one kept in `num_epochs`.
# The training loop visible in this notebook iterates over trange(100) and
# reads neither value -- decide which epoch count is intended.
num_epochs=20

4


In [16]:
# Bias-free linear layer mapping num_params inputs to a single output, with all
# weights initialised to zero.
model = nn.Linear(num_params, 1, bias=False)
nn.init.zeros_(model.weight)
# Sanity check: print "<max>; <min>" of the absolute weights right after init.
abs_weight = torch.abs(model.weight)
print(f"{torch.max(abs_weight)}; {torch.min(abs_weight)}")
# Plain SGD over the layer's parameters; model and optimizer are then set up
# together through Fabric and the model is put in training mode.
optimizer = optim.SGD(model.parameters(), lr=lr)
model, optimizer = fb.setup(model, optimizer)
model.train()
# Rebind `dl` to a shuffled mini-batch loader (16 samples per batch) set up
# through Fabric.
dl = DataLoader(ds, batch_size=16, shuffle=True)
dl = fb.setup_dataloaders(dl)
# NOTE(review): dl_iter is not used by any cell visible in this notebook.
dl_iter = iter(dl)

0.0; 0.0


In [None]:
# Manual training loop: fit the linear model to (X, Y) batches from `dl` with
# mean-squared-error loss, logging every batch loss and the per-epoch average.
# NOTE(review): the epoch count is hard-coded to 100 here; `num_epochs` is not read.
model.train()
for epoch in trange(100):
    running_loss, batches_seen = 0, 0
    for batch_no, (X, Y) in enumerate(tqdm(dl), start=1):
        optimizer.zero_grad(set_to_none=True)
        predictions = model(X)
        # Give the targets the same shape as the predictions before comparing.
        Y = Y.view(predictions.size())
        batch_loss = F.mse_loss(predictions, Y, reduction="mean")
        # Read the scalar once and reuse it for both bookkeeping and logging.
        batch_loss_value = batch_loss.item()
        running_loss += batch_loss_value
        batches_seen += 1
        tqdm.write(f"epoch: {epoch + 1}; batch: {batch_no}; Loss: {batch_loss_value}")
        # Backward pass goes through Fabric rather than batch_loss.backward().
        fb.backward(batch_loss)
        optimizer.step()
    avg_loss = running_loss / batches_seen
    tqdm.write(f"epoch: {epoch + 1}; avg_loss: {avg_loss}")

In [22]:
# Largest absolute weight currently held by the model.
model.weight.abs().max()

tensor(5.0420e-09, device='cuda:0', grad_fn=<MaxBackward1>)

In [23]:
# Smallest absolute weight currently held by the model.
model.weight.abs().min()

tensor(0., device='cuda:0', grad_fn=<MinBackward1>)

In [4]:
from causalpruner.causal_weights_trainer import CausalWeightsTrainerConfig, get_causal_weights_trainer

In [5]:
# Configuration for the causal-weights trainer: torch backend, zero
# initialisation, L1 coefficient 1e-3, initial learning rate 0.1.
# NOTE(review): the exact semantics of prune_amount, max_iter, loss_tol and
# num_iter_no_change live in CausalWeightsTrainerConfig, which is not visible
# here -- presumably a prune fraction and early-stopping controls; confirm.
config = CausalWeightsTrainerConfig(
    fabric=fb,
    init_lr=0.1,
    l1_regularization_coeff=1e-3,
    initialization='zeros',
    prune_amount=0.9,
    max_iter=30,
    loss_tol=1e-7,
    num_iter_no_change=2,
    backend='torch'
)
num_params = ds.weights_zstats.num_params
# Build the trainer for num_params parameters, passing an all-ones tensor of
# that length as the third argument.
# NOTE(review): the meaning of the all-ones tensor and of the trailing
# positional arguments (1, 1, True) is not visible here -- check
# get_causal_weights_trainer's signature; keyword arguments would make this
# call self-documenting.
trainer = get_causal_weights_trainer(
    config, 
    num_params, 
    torch.ones(num_params),
    1,
    1,
    True
)

In [6]:
# Inspect the trainer's prune_amount_this_iteration attribute.
trainer.prune_amount_this_iteration

0.9

In [7]:
# Fit the trainer on mini-batches of 16 samples (no shuffle argument is given).
# Full-batch alternative, kept for reference:
#   dl = DataLoader(ds, batch_size=len(ds))
# NOTE(review): unlike the other loaders in this notebook, this one is not
# passed through fb.setup_dataloaders -- presumably trainer.fit takes care of
# that; confirm.
dl = DataLoader(ds, batch_size=16)
trainer.fit(dl)

Prune amount this iteration: 0.9
Setting learning rate to 0.1


Prune weight fitting:   0%|                                                        | 0/30 [00:00<?, ?it/s]

  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 1; loss: 1.117681622505188; best_loss: inf


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 2; loss: 1.1176562309265137; best_loss: 1.117681622505188


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 3; loss: 1.1176329851150513; best_loss: 1.1176562309265137


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 4; loss: 1.1176104545593262; best_loss: 1.1176329851150513


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 5; loss: 1.1175885200500488; best_loss: 1.1176104545593262


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 6; loss: 1.1175671815872192; best_loss: 1.1175885200500488


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 7; loss: 1.1175463199615479; best_loss: 1.1175671815872192


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 8; loss: 1.117525577545166; best_loss: 1.1175463199615479


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 9; loss: 1.1175053119659424; best_loss: 1.117525577545166


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 10; loss: 1.1174850463867188; best_loss: 1.1175053119659424


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 11; loss: 1.1174652576446533; best_loss: 1.1174850463867188


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 12; loss: 1.1174453496932983; best_loss: 1.1174652576446533


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 13; loss: 1.1174259185791016; best_loss: 1.1174453496932983


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 14; loss: 1.1174066066741943; best_loss: 1.1174259185791016


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 15; loss: 1.117387294769287; best_loss: 1.1174066066741943


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 16; loss: 1.117368221282959; best_loss: 1.117387294769287


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 17; loss: 1.1173491477966309; best_loss: 1.117368221282959


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 18; loss: 1.1173304319381714; best_loss: 1.1173491477966309


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 19; loss: 1.1173115968704224; best_loss: 1.1173304319381714


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 20; loss: 1.1172927618026733; best_loss: 1.1173115968704224


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 21; loss: 1.1172741651535034; best_loss: 1.1172927618026733


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 22; loss: 1.1172558069229126; best_loss: 1.1172741651535034


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 23; loss: 1.1172374486923218; best_loss: 1.1172558069229126


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 24; loss: 1.1172192096710205; best_loss: 1.1172374486923218


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 25; loss: 1.1172008514404297; best_loss: 1.1172192096710205


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 26; loss: 1.1171826124191284; best_loss: 1.1172008514404297


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 27; loss: 1.1171646118164062; best_loss: 1.1171826124191284


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 28; loss: 1.117146611213684; best_loss: 1.1171646118164062


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 29; loss: 1.117128610610962; best_loss: 1.117146611213684


  0%|                                                                              | 0/15 [00:00<?, ?it/s]

Pruning iter: 30; loss: 1.1171108484268188; best_loss: 1.117128610610962
Before: 31556
After: 3236


30

In [8]:
# Number of samples in the dataset.
len(ds)

235

In [9]:
# Sum over the trainer's non-zero-weight mask.
# NOTE(review): this equals the number of kept weights only if the mask holds
# 0/1 values -- confirm against get_non_zero_weights.
trainer.get_non_zero_weights().sum()

tensor(3236., device='cuda:0')

In [10]:
# Display the parameter count, for comparison with the mask sum.
num_params

32360

In [11]:
from tests.models import get_mlpnet

# Build the MLP for "mnist" from the test helpers and set it up through Fabric.
# NOTE: this rebinds `model`, which other cells in this notebook use for the
# nn.Linear regressor.
model = get_mlpnet("mnist")
model = fb.setup(model)

In [12]:
from causalpruner.base import Pruner

In [14]:
# Map qualified module name -> module for every module the pruner supports.
modules_dict = {
    module_name: layer
    for module_name, layer in model.named_modules()
    if Pruner.is_module_supported(layer)
}

# Supported modules that carry a `weight`, in traversal order, and the number
# of elements in each of those weights.
params = [
    module_name
    for module_name, layer in modules_dict.items()
    if hasattr(layer, "weight")
]
params_to_dims = {
    module_name: modules_dict[module_name].weight.numel()
    for module_name in params
}

mask = trainer.get_non_zero_weights()

# Cut the mask into consecutive slices, one per weight, and give each slice
# the shape and device of the weight it belongs to.
masks = dict()
offset = 0
for module_name in params:
    layer_weight = modules_dict[module_name].weight
    next_offset = offset + params_to_dims[module_name]
    layer_mask = mask[offset:next_offset].reshape_as(layer_weight)
    masks[module_name] = layer_mask.to(layer_weight.device, non_blocking=True)
    offset = next_offset

In [15]:
# Sum of the mask for fc1. The '_forward_module.' prefix presumably comes from
# the wrapper returned by fb.setup -- confirm.
masks['_forward_module.fc1'].sum()

tensor(2603., device='cuda:0')

In [16]:
# Sum of the mask for fc2. The '_forward_module.' prefix presumably comes from
# the wrapper returned by fb.setup -- confirm.
masks['_forward_module.fc2'].sum()

tensor(512., device='cuda:0')

In [17]:
# Sum of the mask for fc3. The '_forward_module.' prefix presumably comes from
# the wrapper returned by fb.setup -- confirm.
masks['_forward_module.fc3'].sum()

tensor(121., device='cuda:0')