## Masked Distilation


### Baseline:
* Bool masking of weights trained model
* Gradient descent of loss by continious mask
* Clipping masks to bool value

### Ours:
* Bool masking of weights trained model
* Frank Wolfe of loss by continious mask

## Baseline implementation

In [1]:

import sys
import os

repo_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
print(f"Added repo root to sys.path: {repo_root}")

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from src.model import MLP
from src.trainer import Trainer
import os
import json
import copy


%load_ext autoreload
%autoreload 2

Added repo root to sys.path: /Users/igoreshka/Desktop/CFW-in-ML


In [2]:
# MPS, CUDA, or CPU
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

In [3]:

# Step 1: Инициализация и обучение модели
model = MLP().to(device=DEVICE)
trainer = Trainer(dataset_name='MNIST', batch_size=64, model=model, checkpoint_path='checkpoints/ckpt_0', device=DEVICE)
trainer.train(n_epochs=10)


2025-05-24 17:35:03,080 - INFO - Epoch 1: Train Loss = 0.2286, Test Loss = 0.1217, Accuracy = 96.27%
2025-05-24 17:35:07,548 - INFO - Epoch 2: Train Loss = 0.0867, Test Loss = 0.0930, Accuracy = 97.16%
2025-05-24 17:35:11,922 - INFO - Epoch 3: Train Loss = 0.0583, Test Loss = 0.0731, Accuracy = 97.74%
2025-05-24 17:35:16,296 - INFO - Epoch 4: Train Loss = 0.0411, Test Loss = 0.0829, Accuracy = 97.58%
2025-05-24 17:35:20,672 - INFO - Epoch 5: Train Loss = 0.0310, Test Loss = 0.0800, Accuracy = 97.77%
2025-05-24 17:35:25,061 - INFO - Epoch 6: Train Loss = 0.0283, Test Loss = 0.0779, Accuracy = 97.99%
2025-05-24 17:35:29,434 - INFO - Epoch 7: Train Loss = 0.0229, Test Loss = 0.0869, Accuracy = 97.86%
2025-05-24 17:35:33,756 - INFO - Epoch 8: Train Loss = 0.0205, Test Loss = 0.0955, Accuracy = 97.81%
2025-05-24 17:35:38,099 - INFO - Epoch 9: Train Loss = 0.0171, Test Loss = 0.0760, Accuracy = 98.23%
2025-05-24 17:35:42,446 - INFO - Epoch 10: Train Loss = 0.0163, Test Loss = 0.0994, Accurac

In [4]:
import torch.nn as nn

class MaskedLinear(nn.Module):
    def __init__(self, linear_layer):
        super().__init__()
        self.linear = linear_layer
        self.mask = nn.Parameter(torch.ones_like(linear_layer.weight))

    def forward(self, x):
        masked_weight = self.linear.weight * self.mask
        return nn.functional.linear(x, masked_weight, self.linear.bias)


def apply_masked_linear(model):
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            setattr(model, name, MaskedLinear(module))
        else:
            apply_masked_linear(module)  # рекурсивно проходим по слоям


In [5]:
# Step 2: Маскирование весов
ckpt_path = 'checkpoints/ckpt_0/model.pt'
ckpt = torch.load(ckpt_path, map_location=DEVICE)
model.load_state_dict(ckpt)  # <- вот здесь без ['model_state_dict']

<All keys matched successfully>

In [16]:
# 1. Загружаем модель и freeze'им веса
masked_model = copy.deepcopy(model)
for param in masked_model.parameters():
    param.requires_grad = False

# 2. Оборачиваем линейные слои
apply_masked_linear(masked_model)

# 3. Расмаскируем параметры масок
mask_params = [module.mask for module in masked_model.modules() if isinstance(module, MaskedLinear)]
for p in mask_params:
    p.requires_grad = True

optimizer = torch.optim.Adam(mask_params, lr=1e-1)
loss_fn = nn.CrossEntropyLoss()

# 4. Обучение масок
n_mask_epochs = 1
for epoch in range(n_mask_epochs):
    for x, y in trainer.train_loader:
        x, y = x.to(trainer.device), y.to(trainer.device)

        optimizer.zero_grad()
        logits = masked_model(x)
        loss = loss_fn(logits, y)

        loss.backward()  # ← градиенты посчитаны тут

        # # теперь можно посмотреть на градиенты масок
        # for module in masked_model.modules():
        #     if isinstance(module, MaskedLinear):
        #         print(module.mask.grad)  # ← теперь не None

        optimizer.step()
        # break


In [7]:
with torch.no_grad():
    for module in masked_model.modules():
        if isinstance(module, MaskedLinear):
            print(module.mask)
            

Parameter containing:
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='mps:0', requires_grad=True)
Parameter containing:
tensor([[ 1.0000,  1.9594,  1.4129,  ...,  0.4427,  2.3374,  2.1956],
        [ 1.0000,  2.2058,  0.4408,  ...,  1.5718,  1.1755,  1.8411],
        [ 1.0000, -1.1590,  1.3135,  ...,  0.9250,  0.4462,  1.1584],
        ...,
        [ 1.0000,  2.6824,  1.3716,  ...,  0.1664,  3.6657,  0.7383],
        [ 1.0000,  1.1696,  2.7084,  ...,  1.0120,  2.8961,  1.4096],
        [ 1.0000,  2.1034,  1.0000,  ..., -1.2630,  2.2218,  1.0000]],
       device='mps:0', requires_grad=True)
Parameter containing:
tensor([[ 1.0073, -0.1485, -1.7482,  ...,  2.1074,  3.0330,  4.3077],
        [ 1.0672,  1.8483,  1.9075,  ...,  2.7925,  3.9435,  2.1373],
        [ 2.3592,  0.1757, 

In [27]:
prune_ratio = 0.10  # 10%

with torch.no_grad():
    for module in masked_model.modules():
        if isinstance(module, MaskedLinear):
            mean_per_neuron = module.mask.abs().mean(dim=1)  # shape: [output_dim]
            n_prune = int(prune_ratio * mean_per_neuron.numel())
            if n_prune == 0:
                continue  # если слишком мало нейронов — ничего не делаем
            prune_indices = torch.topk(mean_per_neuron, k=n_prune, largest=False).indices
            keep_mask = torch.ones_like(mean_per_neuron)
            keep_mask[prune_indices] = 0.0
            keep_mask = keep_mask.view(-1, 1)  # для broadcast по входам
            module.mask.data *= keep_mask


In [30]:
trainer.evaluate_model(model, description="Original model")
trainer.evaluate_model(masked_model, description="Masked model")


Original model: Test Loss = 0.0994, Accuracy = 97.80%
Masked model: Test Loss = 0.2265, Accuracy = 94.52%


(0.22649068035781383, 94.52)

In [31]:
def extract_pruned_mlp(masked_model, input_dim=28*28, output_dim=10):
    # Сначала получаем список слоёв MaskedLinear из masked_model (в правильном порядке)
    masked_linears = [m for m in masked_model.modules() if isinstance(m, MaskedLinear)]
    
    # Запишем размеры слоёв по маске (сколько нейронов осталось)
    layer_output_masks = []
    for layer in masked_linears:
        # Маска shape: [out_dim, in_dim]
        mask = layer.mask.abs()
        # Средняя маска по входам нейрона
        mean_per_neuron = mask.mean(dim=1)
        # Нейроны, которые НЕ занулены полностью (mean > 0)
        keep_neurons = (mean_per_neuron > 0).cpu()
        layer_output_masks.append(keep_neurons)
    
    # Размерности скрытых слоёв (число оставшихся нейронов в каждом слое)
    hidden_dims = [mask.sum().item() for mask in layer_output_masks[:-1]]  # все кроме последнего слоя
    # Последний слой должен иметь output_dim нейронов (обычно 10)
    # Но, если последний слой тоже прунинговали — оставим сколько есть
    last_layer_out_dim = layer_output_masks[-1].sum().item()
    if last_layer_out_dim != output_dim:
        print(f"Warning: output layer pruned to {last_layer_out_dim} neurons instead of {output_dim}")
        output_dim = last_layer_out_dim
    
    # Создаём новую модель MLP с нужными hidden_dims
    new_mlp = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims)
    
    # Копируем веса из masked_model в new_mlp с учётом масок
    prev_keep_mask = torch.ones(input_dim, dtype=torch.bool)  # входной слой — все входы
    
    for i, (old_layer, new_layer, keep_mask) in enumerate(zip(masked_linears, new_mlp.model, layer_output_masks)):
        # keep_mask — булев mask для выходных нейронов слоя (out_dim)
        keep_mask = keep_mask.to(old_layer.mask.device)
        
        # old weights: [out_dim, in_dim]
        w = old_layer.linear.weight.data
        b = old_layer.linear.bias.data if old_layer.linear.bias is not None else None
        
        # Обрезаем вес по входам (столбцам), оставляем только prev_keep_mask
        w = w[:, prev_keep_mask]
        
        # Обрезаем вес по выходам (строкам) — оставляем keep_mask
        w = w[keep_mask, :]
        
        if b is not None:
            b = b[keep_mask]
        
        # Копируем в new_layer (nn.Linear)
        new_layer.weight.data = w.clone()
        if b is not None:
            new_layer.bias.data = b.clone()
        else:
            new_layer.bias = None
        
        # Теперь входной маск для следующего слоя — это keep_mask
        prev_keep_mask = keep_mask
    
    return new_mlp


In [11]:
def get_active_neurons(masked_linear):
    # Предполагается, что masked_linear.mask shape == weight.shape
    # Для нейронов (выходов) — это строки
    row_mask = masked_linear.mask.abs().sum(dim=1) > 0
    return row_mask  # bool-тензор длины out_features
def prune_linear_layer(old_layer, active_out_indices, active_in_indices):
    new_layer = nn.Linear(
        in_features=len(active_in_indices),
        out_features=len(active_out_indices),
        bias=old_layer.bias is not None,
    )

    # Скопируем веса и байасы
    with torch.no_grad():
        new_weight = old_layer.weight[active_out_indices][:, active_in_indices]
        new_layer.weight.copy_(new_weight)

        if old_layer.bias is not None:
            new_layer.bias.copy_(old_layer.bias[active_out_indices])
    
    return new_layer



In [42]:
def get_active_neurons(module, threshold=0):
    # Возвращает булев вектор, какие выходные нейроны активны
    mask = module.mask.abs()
    mean_mask = mask.mean(dim=1)  # среднее по входам
    return mean_mask > threshold

def prune_linear_layer(module, out_indices, in_indices):
    # Создаёт новый nn.Linear с весами обрезанными по индексам out_indices (выходы) и in_indices (входы)
    new_layer = nn.Linear(len(in_indices), len(out_indices), bias=module.linear.bias is not None)
    new_layer.weight.data = module.linear.weight.data[out_indices][:, in_indices].clone()
    if module.linear.bias is not None:
        new_layer.bias.data = module.linear.bias.data[out_indices].clone()
    return new_layer

def convert_masked_to_pruned_model(masked_model, input_dim=28*28, output_dim=10):
    pruned_layers = []
    prev_active_indices = None
    modules = [m for m in masked_model.modules() if isinstance(m, MaskedLinear)]
    
    for i, module in enumerate(modules):
        if i == len(modules) - 1:
            # Для последнего слоя не пруним выходные нейроны - оставляем все
            out_indices = torch.arange(module.linear.weight.shape[0])
        else:
            active_neurons = get_active_neurons(module)
            out_indices = torch.where(active_neurons)[0]

        in_indices = (
            torch.arange(module.linear.weight.shape[1])
            if prev_active_indices is None
            else prev_active_indices
        )
        
        pruned_layer = prune_linear_layer(module, out_indices, in_indices)
        pruned_layers.append(pruned_layer)
        prev_active_indices = out_indices
    
    # hidden_dims для новой модели
    hidden_dims = [layer.out_features for layer in pruned_layers[:-1]]
    pruned_output_dim = pruned_layers[-1].out_features
    
    # Создаём новую MLP с нужными размерами слоёв
    pruned_model = MLP(input_dim=input_dim, output_dim=pruned_output_dim, hidden_dims=hidden_dims)

    # Копируем веса
    pruned_model_layers = [module for module in pruned_model.model if isinstance(module, nn.Linear)]
    for new_layer, old_layer in zip(pruned_model_layers, pruned_layers):
        new_layer.weight.data = old_layer.weight.data.clone()
        if old_layer.bias is not None:
            new_layer.bias.data = old_layer.bias.data.clone()
        else:
            new_layer.bias = None

    return pruned_model


In [43]:
pruned_model = convert_masked_to_pruned_model(masked_model)

In [51]:
trainer.evaluate_model(model, description="Original model")
trainer.evaluate_model(masked_model, description="Masked model")
trainer.evaluate_model(pruned_model, description="Pruned model")



Original model: Test Loss = 0.0994, Accuracy = 97.80%
Masked model: Test Loss = 0.2265, Accuracy = 94.52%
Pruned model: Test Loss = 0.0990, Accuracy = 97.70%


(0.09896772104129486, 97.7)

In [48]:
def count_all_params(model):
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters in model: {total_params}")

    print("\nParameters by Linear layer:")
    for i, module in enumerate(model.modules()):
        if isinstance(module, nn.Linear) or (hasattr(module, 'linear') and isinstance(module.linear, nn.Linear)):
            if isinstance(module, nn.Linear):
                w = module.weight.numel()
                b = module.bias.numel() if module.bias is not None else 0
                print(f"Layer {i} (Linear): weights = {w}, bias = {b}, total = {w+b}")
            else:  # Например MaskedLinear
                w = module.linear.weight.numel()
                b = module.linear.bias.numel() if module.linear.bias is not None else 0
                m = module.mask.numel()
                print(f"Layer {i} (MaskedLinear): weights = {w}, bias = {b}, mask = {m}, total = {w+b+m}")

# Пример использования:
print("Original model:")
count_all_params(model)

print("\nMasked model:")
count_all_params(masked_model)

print("\nPruned model:")
count_all_params(pruned_model)


Original model:
Total parameters in model: 669706

Parameters by Linear layer:
Layer 2 (Linear): weights = 401408, bias = 512, total = 401920
Layer 4 (Linear): weights = 262144, bias = 512, total = 262656
Layer 6 (Linear): weights = 5120, bias = 10, total = 5130

Masked model:
Total parameters in model: 1338378

Parameters by Linear layer:
Layer 2 (MaskedLinear): weights = 401408, bias = 512, mask = 401408, total = 803328
Layer 3 (Linear): weights = 401408, bias = 512, total = 401920
Layer 5 (MaskedLinear): weights = 262144, bias = 512, mask = 262144, total = 524800
Layer 6 (Linear): weights = 262144, bias = 512, total = 262656
Layer 8 (MaskedLinear): weights = 5120, bias = 10, mask = 5120, total = 10250
Layer 9 (Linear): weights = 5120, bias = 10, total = 5130

Pruned model:
Total parameters in model: 579487

Parameters by Linear layer:
Layer 2 (Linear): weights = 361424, bias = 461, total = 361885
Layer 4 (Linear): weights = 212521, bias = 461, total = 212982
Layer 6 (Linear): weight