In [21]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory (i.e. /kaggle/input, relative to /kaggle/working)
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file shipped with the attached Kaggle datasets.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [22]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from transformers import DeiTForImageClassificationWithTeacher
import torchvision
from tqdm import tqdm

In [23]:
def check_pruning_correctness_and_extent(model, compression_ratio):
    """Report the fraction of zero weights in every ``torch.nn.Linear`` layer.

    Prints one line per Linear layer comparing the measured sparsity of its
    weight matrix with ``compression_ratio`` (the fraction of weights the
    pruning step was asked to zero out). Biases are not inspected.

    Args:
        model: module to inspect; every ``torch.nn.Linear`` submodule is visited
            in ``named_modules()`` order.
        compression_ratio: expected fraction of zero weights per layer; only
            echoed in the report.

    Returns:
        dict mapping each Linear layer's qualified name to its measured
        sparsity, so callers can assert on the result instead of reading the
        printout.
    """
    print("Checking pruning correctness and extent...")
    sparsities = {}
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            total_weights = module.weight.numel()
            # count_nonzero avoids materialising the (nnz, ndim) index tensor
            # that ``nonzero()`` would allocate just to read its length.
            zero_weights = total_weights - int(torch.count_nonzero(module.weight.detach()))
            # Guard against degenerate zero-sized layers.
            sparsity = zero_weights / total_weights if total_weights else 0.0
            sparsities[name] = sparsity
            expected_sparsity = compression_ratio
            print(f"{name}: Sparsity: {sparsity:.2f}. Expected: {expected_sparsity}.")
    return sparsities


In [24]:
def freeze_pruned_weights(model):
    """Mask gradients of every ``torch.nn.Linear`` weight so zeroed entries get no gradient.

    For each Linear layer, the set of currently non-zero weights is snapshotted
    and a backward hook multiplies every incoming weight gradient by that mask.
    Entries that are zero now (i.e. pruned) therefore always receive a zero
    gradient. Biases are not masked.

    Caveat: a zero gradient only keeps a weight at zero if the optimizer does
    not move parameters whose gradient is zero (e.g. no weight decay and no
    momentum/state accumulated before this call).

    Args:
        model: module whose Linear layers should be masked.

    Returns:
        list of hook handles, one per Linear layer; call ``handle.remove()`` on
        each to undo the masking (calling this function twice otherwise stacks
        a second set of hooks).
    """
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            mask = module.weight.detach() != 0  # True where the weight survived pruning
            # Bind ``mask`` as a default argument so each hook keeps its own layer's mask.
            handles.append(module.weight.register_hook(lambda grad, mask=mask: grad * mask))
    print("Pruned weights are now frozen and will not be updated during training.")
    return handles


In [25]:
# Select the compute device: the GPU if PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [26]:
# Data preprocessing shared by the train and test splits: resize every image to
# 224x224 (the input size in the checkpoint name used below), convert it to a
# tensor, and normalise each channel with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [27]:
# Load both CIFAR-10 splits into ./data (downloaded on first run); both use `transform`.
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Batch the splits; only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [28]:
# Load the pretrained distilled DeiT model. It has two classification heads
# (`cls_classifier` and `distillation_classifier`).
# NOTE(review): per the Hugging Face docs, this "WithTeacher" class is meant for
# inference and its `.logits` are the average of both heads' logits, so the
# training below fine-tunes both heads jointly rather than doing true distillation.
model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224')

# Replace both heads with freshly initialised 10-way heads for CIFAR-10
# (same creation order as before, so random initialisation is unchanged).
model.distillation_classifier = torch.nn.Linear(in_features=model.distillation_classifier.in_features, out_features=10)
model.cls_classifier = torch.nn.Linear(in_features=model.cls_classifier.in_features, out_features=10)

# Keep the config consistent with the new heads; otherwise config.num_labels and
# id2label still describe the checkpoint's 1000 classes (this resets id2label
# to generic LABEL_i names).
model.config.num_labels = 10

# Move to the target device once, after the heads are replaced, so the new heads
# are moved too (an earlier move before replacement was redundant).
model = model.to(device)

In [29]:
def apply_pruning_to_layer(layer, compression_ratio):
    """Zero out the smallest-magnitude entries of ``layer.weight`` in place.

    Exactly ``int(layer.weight.numel() * (1 - compression_ratio))`` entries with
    the largest absolute value are kept; all others are set to zero. The bias is
    intentionally left untouched (bias pruning was removed to avoid a shape
    mismatch error).

    Args:
        layer: module with a ``weight`` tensor (e.g. ``torch.nn.Linear``).
        compression_ratio: fraction in [0, 1] of weights to zero out
            (0 keeps everything, 1 zeroes the whole matrix).

    Raises:
        ValueError: if ``compression_ratio`` is outside [0, 1].
    """
    if not 0.0 <= compression_ratio <= 1.0:
        raise ValueError(f"compression_ratio must be in [0, 1], got {compression_ratio}")
    with torch.no_grad():
        weight = layer.weight
        total_weights = weight.numel()
        num_weights_to_keep = int(total_weights * (1 - compression_ratio))
        if num_weights_to_keep >= total_weights:
            return  # nothing to prune
        weights_abs = weight.detach().abs().reshape(-1)
        mask = torch.zeros_like(weights_abs)
        if num_weights_to_keep > 0:
            # topk keeps exactly `num_weights_to_keep` entries even when
            # magnitudes tie. (The previous kthvalue(k=num_keep) + `>=` threshold
            # kept total-num_keep+1 entries, i.e. the inverse ratio, and crashed for k=0.)
            keep_idx = torch.topk(weights_abs, num_weights_to_keep).indices
            mask[keep_idx] = 1.0
        weight.mul_(mask.view_as(weight))

In [30]:
def prune_model(model, compression_ratio):
    """Magnitude-prune the weight of every ``torch.nn.Linear`` submodule in place.

    Each Linear layer (visited in ``model.modules()`` order) is passed to
    ``apply_pruning_to_layer`` with the same ``compression_ratio``; all other
    module types are left alone. Runs under ``torch.no_grad()`` so the in-place
    edits are not recorded by autograd.
    """
    linear_layers = [m for m in model.modules() if isinstance(m, torch.nn.Linear)]
    with torch.no_grad():
        for layer in linear_layers:
            apply_pruning_to_layer(layer, compression_ratio)

In [31]:
# (This cell previously held only a commented-out copy of `prune_model`, which is
#  defined in the cell above; the dead duplicate has been removed.)

In [32]:
def train(model, train_loader, optimizer, epochs, device=None):
    """Train ``model`` with cross-entropy for ``epochs`` passes over ``train_loader``.

    Args:
        model: ``torch.nn.Module`` whose output exposes a ``.logits`` tensor that
            ``F.cross_entropy`` can score against the batch targets.
        train_loader: sized iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        epochs: number of full passes over ``train_loader``.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter (CPU if it has none), so the function no longer
            depends on a notebook-global ``device`` variable.

    Returns:
        list with the mean training loss of each epoch (NaN for an epoch with
        no batches).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch+1}/{epochs}')
        for batch_idx, (data, target) in progress_bar:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output.logits, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches = batch_idx + 1
            # Running mean of the loss over the batches seen so far this epoch.
            progress_bar.set_postfix(loss=total_loss / num_batches)
        epoch_losses.append(total_loss / num_batches if num_batches else float('nan'))
    return epoch_losses

In [33]:
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    progress_bar = tqdm(enumerate(test_loader), total=len(test_loader), desc='Testing')
    with torch.no_grad():
        for batch_idx, (data, target) in progress_bar:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.logits.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
            progress_bar.set_postfix(acc=f'{100. * correct / total:.2f}%')

    print(f'Test set: Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)')

In [None]:
# Parameters
compression_ratio = 0.5  # fraction of each Linear layer's weights to zero out
epochs = 16  # number of retraining epochs after pruning

# 1. Magnitude-prune every Linear layer, then report the resulting sparsity.
prune_model(model, compression_ratio)
check_pruning_correctness_and_extent(model, compression_ratio)

# 2. Mask gradients so the pruned (zeroed) weights receive no updates.
freeze_pruned_weights(model)

# 3. Retrain the pruned model. The optimizer only holds references to the
#    parameters, so creating it after pruning is equivalent to creating it before.
# NOTE(review): lr=1e-3 looks high for fine-tuning a pretrained transformer; the
#    recorded run's loss rose between epochs 4 and 5 (0.646 -> 0.781). Consider a
#    smaller value; left unchanged here to preserve behaviour.
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model, train_loader, optimizer, epochs)

# 4. Evaluate the pruned and retrained model.
test(model, test_loader)



Checking pruning correctness and extent...
deit.encoder.layer.0.attention.attention.query: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.0.attention.attention.key: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.0.attention.attention.value: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.0.attention.output.dense: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.0.intermediate.dense: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.0.output.dense: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.1.attention.attention.query: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.1.attention.attention.key: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.1.attention.attention.value: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.1.attention.output.dense: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.1.intermediate.dense: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.1.output.dense: Sparsity: 0.50. Expected: 0.5.
deit.encoder.layer.2.attention.attention.query: Sparsity: 0

Epoch 1/16: 100%|██████████| 1563/1563 [15:20<00:00,  1.70it/s, loss=1.43]
Epoch 2/16: 100%|██████████| 1563/1563 [15:20<00:00,  1.70it/s, loss=0.958]
Epoch 3/16: 100%|██████████| 1563/1563 [15:19<00:00,  1.70it/s, loss=0.769]
Epoch 4/16: 100%|██████████| 1563/1563 [15:20<00:00,  1.70it/s, loss=0.646]
Epoch 5/16: 100%|██████████| 1563/1563 [15:20<00:00,  1.70it/s, loss=0.781]
Epoch 6/16: 100%|██████████| 1563/1563 [15:20<00:00,  1.70it/s, loss=0.531]
Epoch 7/16: 100%|██████████| 1563/1563 [15:21<00:00,  1.70it/s, loss=0.449]
Epoch 8/16: 100%|██████████| 1563/1563 [15:21<00:00,  1.70it/s, loss=0.389]
Epoch 9/16: 100%|██████████| 1563/1563 [15:20<00:00,  1.70it/s, loss=0.339]
Epoch 10/16: 100%|██████████| 1563/1563 [15:20<00:00,  1.70it/s, loss=0.284]
Epoch 11/16: 100%|██████████| 1563/1563 [15:19<00:00,  1.70it/s, loss=0.246]
Epoch 12/16: 100%|██████████| 1563/1563 [15:19<00:00,  1.70it/s, loss=0.211]
Epoch 13/16: 100%|██████████| 1563/1563 [15:19<00:00,  1.70it/s, loss=0.186]
Epoch 14/

In [None]:
# Re-check sparsity after retraining: the gradient masks are meant to have kept pruned entries at zero.
check_pruning_correctness_and_extent(model, compression_ratio=compression_ratio)