## Masked Distillation


### Baseline 1:
* Boolean masking of the trained model's weights
* Gradient descent on the loss with respect to a continuous mask
* Clipping the masks to Boolean values

### Baseline 2:
* Random pruning

### Ours:
* Boolean masking of the trained model's weights
* Frank-Wolfe optimization of the loss over a continuous mask

## Default model training

In [24]:

import sys
import os

# Put the parent of the current working directory at the front of sys.path
# (only if it is not there yet) so the `src.*` imports that follow can resolve.
repo_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
already_on_path = repo_root in sys.path
if not already_on_path:
    sys.path.insert(0, repo_root)
# Printed unconditionally, whether or not the path was inserted on this run.
print(f"Added repo root to sys.path: {repo_root}")

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from src.model import MLP
from src.trainer import Trainer
import os
import json
import copy


from src.utils import count_all_params


%load_ext autoreload
%autoreload 2

Added repo root to sys.path: /Users/igoreshka/Desktop/CFW-in-ML
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [74]:
# Pick the compute backend: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# The model has to exist before the Trainer is built, because the Trainer
# receives it via `model=model`. No earlier cell shown in this notebook defines
# `model`, so without this line a fresh-kernel run fails with NameError.
model = MLP().to(device=DEVICE)
trainer = Trainer(dataset_name='MNIST', batch_size=64, model=model, checkpoint_path='checkpoints/ckpt_0', device=DEVICE)

In [None]:
# Step 1: Initialize and train the model
# NOTE(review): `trainer` was constructed in the previous cell with `model=model`;
# the line below rebinds `model` to a new MLP instance after that. Presumably
# `trainer.train` operates on the instance the Trainer was given rather than on
# this new one — confirm against src.trainer.Trainer.
model = MLP().to(device=DEVICE)
trainer.train(n_epochs=10)


2025-05-24 17:35:03,080 - INFO - Epoch 1: Train Loss = 0.2286, Test Loss = 0.1217, Accuracy = 96.27%
2025-05-24 17:35:07,548 - INFO - Epoch 2: Train Loss = 0.0867, Test Loss = 0.0930, Accuracy = 97.16%
2025-05-24 17:35:11,922 - INFO - Epoch 3: Train Loss = 0.0583, Test Loss = 0.0731, Accuracy = 97.74%
2025-05-24 17:35:16,296 - INFO - Epoch 4: Train Loss = 0.0411, Test Loss = 0.0829, Accuracy = 97.58%
2025-05-24 17:35:20,672 - INFO - Epoch 5: Train Loss = 0.0310, Test Loss = 0.0800, Accuracy = 97.77%
2025-05-24 17:35:25,061 - INFO - Epoch 6: Train Loss = 0.0283, Test Loss = 0.0779, Accuracy = 97.99%
2025-05-24 17:35:29,434 - INFO - Epoch 7: Train Loss = 0.0229, Test Loss = 0.0869, Accuracy = 97.86%
2025-05-24 17:35:33,756 - INFO - Epoch 8: Train Loss = 0.0205, Test Loss = 0.0955, Accuracy = 97.81%
2025-05-24 17:35:38,099 - INFO - Epoch 9: Train Loss = 0.0171, Test Loss = 0.0760, Accuracy = 98.23%
2025-05-24 17:35:42,446 - INFO - Epoch 10: Train Loss = 0.0163, Test Loss = 0.0994, Accurac

## Baseline 1 implementation

In [75]:
from src.neurodistil.masking_prune import prune_model

# Load the saved weights into a fresh MLP and prune that model.
ckpt_path = 'checkpoints/ckpt_0/model.pt'
# weights_only=True restricts unpickling to tensors and plain containers instead
# of arbitrary Python objects. This assumes the file holds a plain state dict,
# which is how `load_state_dict(ckpt)` below consumes it.
ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=True)
model = MLP().to(device=DEVICE)
model.load_state_dict(ckpt)

# Baseline 1: mask-based pruning driven by the trainer's training loader.
pruned_model = prune_model(model, trainer.get_train_loader(), device=DEVICE, prune_ratio=0.7, n_epochs=1)


In [76]:
# Report loss and accuracy for the original model and the Baseline 1 pruned model.
# The second call is the cell's last expression, so its return value is also
# echoed as the cell output.
trainer.evaluate_model(model, description="Original model")
trainer.evaluate_model(pruned_model, description="Pruned model")

Original model: Test Loss = 0.0994, Accuracy = 97.80%
Pruned model: Test Loss = 0.6904, Accuracy = 90.65%


(0.690416181564331, 90.65)

In [77]:
# Parameter counts as an (original, Baseline 1 pruned) pair.
(count_all_params(model), count_all_params(pruned_model))

(669706, 146310)

## Baseline 2: Implementation

In [90]:
from src.neurodistil.prune import get_mlp_with_pruned_layers

In [121]:
# Build the Baseline 2 model from the trained `model` with sparsity_level=0.83.
# NOTE(review): the recorded output reports "Target parameters: 113850", yet
# count_all_params later reports 141915 for this model — confirm how
# get_mlp_with_pruned_layers turns sparsity_level into hidden-layer sizes.
pruned_model_2 = get_mlp_with_pruned_layers(model, sparsity_level=0.83)
# Bare expression: display the resulting module structure as the cell output.
pruned_model_2

Original parameters: 669706
Target parameters: 113850
Pruned hidden dims: [145, 180]


MLP(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=145, bias=True)
    (1): ReLU()
    (2): Linear(in_features=145, out_features=180, bias=True)
    (3): ReLU()
    (4): Linear(in_features=180, out_features=10, bias=True)
  )
)

In [122]:
# Report loss and accuracy for the original model and the Baseline 2 pruned model.
# The second call is the cell's last expression, so its return value is also
# echoed as the cell output.
trainer.evaluate_model(model, description="Original model")
trainer.evaluate_model(pruned_model_2, description="Pruned model")

Original model: Test Loss = 0.0994, Accuracy = 97.80%
Pruned model: Test Loss = 0.2121, Accuracy = 94.27%


(0.2120740130662918, 94.27)

In [123]:
# Parameter counts as an (original, Baseline 2 pruned) pair.
(count_all_params(model), count_all_params(pruned_model_2))

(669706, 141915)