In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## Imports

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as T
from torch.utils.data import DataLoader,random_split
from tqdm import tqdm
import torchvision.datasets as datasets
import random

In [3]:
def set_seed(seed):
    torch.manual_seed(seed)  # Set seed for CPU
    torch.cuda.manual_seed_all(seed)  # Set seed for all GPUs
    random.seed(seed)  # Set seed for Python's random module
    np.random.seed(seed)  # Set seed for NumPy

set_seed(42)

In [4]:
# setting up hyperparams
batch_size = 128
lr = 1e-4
num_epochs = 15
early_stopping_patience = 5

In [5]:
trans = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [6]:
train_data = datasets.CIFAR10(root = "/data", train = True, download = True, transform = trans)
test_data = datasets.CIFAR10(root = "/data", train = False, download = True, transform = trans)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 35132553.68it/s]


Extracting /data/cifar-10-python.tar.gz to /data
Files already downloaded and verified


In [7]:
train_size = int(0.8 * len(train_data))  
val_size = len(train_data) - train_size 

# Split the dataset
train_set, val_set = random_split(train_data, [train_size, val_size])

# Create DataLoaders for train, val and test

train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_set, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_data, batch_size = batch_size, shuffle = True)

In [8]:
model = models.wide_resnet50_2(pretrained = True)
model.fc = nn.Linear(model.fc.in_features, 10)

Downloading: "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth" to /root/.cache/torch/hub/checkpoints/wide_resnet50_2-95faca4d.pth
100%|██████████| 132M/132M [00:02<00:00, 59.9MB/s] 


In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
model = model.to(device)

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = lr)

In [12]:
best_loss = np.inf
patience_counter = 0

In [14]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    train_correct = 0
    train_total = 0
    for inputs, labels in tqdm(train_dl, desc = f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        _,pred_train = torch.max(outputs,1)
        train_total += labels.size(0)
        train_correct += (pred_train == labels).sum().item()
        running_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        
    avg_loss = running_loss / len(train_dl)
    train_acc = (train_correct/train_total)*100
    print(f"Train Loss : {avg_loss:.4f}   Train Accuracy : {train_acc:.4f}")   
    
    
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    for inputs, labels in tqdm(val_dl, desc = f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        _, pred_test = torch.max(outputs, 1)
        val_total += labels.size(0)
        val_correct += (pred_test == labels).sum().item()
        val_loss += loss.item()
        
    avg_val_loss = val_loss / len(val_dl)
    test_acc = (val_correct/val_total)*100
    print(f"Val Loss : {avg_val_loss:.4f}   VAl Accuracy : {test_acc:.4f}")
    
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'plain_wide_resnet.pth')  # Save the best model
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered!")
            break

Epoch 1/15: 100%|██████████| 313/313 [00:43<00:00,  7.19it/s]


Train Loss : 0.5782   Train Accuracy : 80.7975


Epoch 1/15: 100%|██████████| 79/79 [00:05<00:00, 15.43it/s]


Val Loss : 0.5277   VAl Accuracy : 81.8300


Epoch 2/15: 100%|██████████| 313/313 [00:45<00:00,  6.89it/s]


Train Loss : 0.2885   Train Accuracy : 90.3600


Epoch 2/15: 100%|██████████| 79/79 [00:05<00:00, 15.40it/s]


Val Loss : 0.5276   VAl Accuracy : 82.8900


Epoch 3/15: 100%|██████████| 313/313 [00:46<00:00,  6.68it/s]


Train Loss : 0.1753   Train Accuracy : 94.0900


Epoch 3/15: 100%|██████████| 79/79 [00:05<00:00, 15.18it/s]


Val Loss : 0.5803   VAl Accuracy : 82.9500


Epoch 4/15: 100%|██████████| 313/313 [00:46<00:00,  6.74it/s]


Train Loss : 0.1299   Train Accuracy : 95.7050


Epoch 4/15: 100%|██████████| 79/79 [00:05<00:00, 15.32it/s]


Val Loss : 0.5991   VAl Accuracy : 83.4700


Epoch 5/15: 100%|██████████| 313/313 [00:46<00:00,  6.71it/s]


Train Loss : 0.1009   Train Accuracy : 96.6425


Epoch 5/15: 100%|██████████| 79/79 [00:05<00:00, 15.41it/s]


Val Loss : 0.6079   VAl Accuracy : 83.6800


Epoch 6/15: 100%|██████████| 313/313 [00:46<00:00,  6.73it/s]


Train Loss : 0.0775   Train Accuracy : 97.4075


Epoch 6/15: 100%|██████████| 79/79 [00:05<00:00, 15.21it/s]


Val Loss : 0.6382   VAl Accuracy : 84.0300


Epoch 7/15: 100%|██████████| 313/313 [00:46<00:00,  6.72it/s]


Train Loss : 0.0754   Train Accuracy : 97.5900


Epoch 7/15: 100%|██████████| 79/79 [00:05<00:00, 15.60it/s]

Val Loss : 0.6437   VAl Accuracy : 84.1300
Early stopping triggered!





In [15]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_dl:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy on test set: {100 * correct / total:.2f}%")

Accuracy on test set: 83.16%


In [16]:
def check_model_size(model):
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total number of parameters: {total_params}")
    model_size_mb = total_params * 4 / (1024 ** 2)  # Assuming float32 (4 bytes)
    print(f"Approximate model size: {model_size_mb:.2f} MB")

In [17]:
print("model size before pruning : ")
check_model_size(model)

model size before pruning : 
Total number of parameters: 66854730
Approximate model size: 255.03 MB


In [18]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [19]:
print("model size after pruning : ")
count_parameters(model)

model size after pruning : 


66854730

### model_size = 255MB, model_accuracy_on_test_set = 83.16%

## implementing research paper for pruning

In [72]:
class PolicyNetwork(nn.Module):
    def __init__(self, num_filters, hidden_size=128):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(num_filters, hidden_size)
        self.activation = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_filters)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)

        return torch.sigmoid(x)  # Ensure output is between 0 and 1
    

In [73]:
class PruningAgent:
    def __init__(self, num_filters):
        self.num_filters = num_filters
        self.policy_net = PolicyNetwork(num_filters).to(device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=1e-3)
        
    def calculate_reward(self, p_hat, p_star, b, A_l):
        reward_accuracy = b - (p_star - p_hat) / b  # Calculate reward based on performance drop
        # reward accuracy will be negative if p_hat is greter than p_star

        # Calculate efficiency term
        C_A_l = A_l.sum()  # THis is the number of filters kept ie number of ones
        if C_A_l.item() == 0:
            reward_efficiency = 0  # If no filters are kept, efficiency reward is zero
        else:
            reward_efficiency = np.log(self.num_filters / C_A_l.item())  # Efficiency reward
            # if C_A_l is less, then number of filters kept is less, effienciency is increased

        # Total reward
        reward = reward_accuracy * reward_efficiency
        return reward
    
    def get_binary_actions(self, weights):
        # a randomly intialized tensor consisting of 0s and 1s
        # 1=> keep the filter, 0=> remove the filter
        weights_flat = weights.view(weights.size(0), -1)
        probs = self.policy_net(weights_flat)
        actions = (torch.rand(weights.size(0), device=device) < probs).float()  # Sample actions
        return actions

In [74]:
def prune_layer(model, layer_id, A_l):
    with torch.no_grad():
        layer_name = layer[:-1]
        layer = model._modules[layer_name]
        A_l = A_l.to(device)

        # If the layer is a Sequential block
        if isinstance(layer, torch.nn.Sequential):
            # Access Block
            block = layer[int(layer_id[-1])]  
            conv_layer = block.conv1
            W_l = conv_layer.weight.data.to(device)
        elif isinstance(layer, torch.nn.Conv2d):
            # If it's directly a Conv2d layer
            W_l = layer.weight.data.to(device)
        elif isinstance(layer, Bottleneck):
            # If it's a Bottleneck block, access the conv1 layer inside it
            W_l = layer.conv1.weight.data.to(device)
        else:
            raise ValueError(f"Layer type {type(layer)} not handled for pruning.")

        # Apply binary actions to prune weights
        pruned_W_l = W_l * A_l.view(-1, 1, 1, 1)  # Apply binary actions
        pruned_W_l = pruned_W_l.to(device)

        # Update the weights in the layer
        if isinstance(layer, torch.nn.Sequential):
            conv_layer.weight.data = pruned_W_l
        elif isinstance(layer, Bottleneck):
            layer.conv1.weight.data = pruned_W_l
        else:
            layer.weight.data = pruned_W_l

    return model


In [75]:
def fine_tune_model(model, train_loader, criterion, optimizer, num_epochs=5):
    """Fine-tune the pruned model to recover performance."""
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        print(f"Fine-tuning Epoch {epoch+1}/{num_epochs}, Loss: {running_loss:.4f}")


In [76]:
def evaluate_accuracy(model, val_loader):
    """Evaluate the model accuracy on the validation set."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [77]:
from torchvision.models.resnet import Bottleneck  # Import Bottleneck

In [78]:
def prune_layer_by_layer(model, agents, criterion, optimizer, train_loader, val_loader, baseline_acc, b, num_epochs=10):
    model.train()
    
    # Iterate over each convolutional layer to prune
    for layer_id, agent in agents.items():
        print(f"Pruning {layer_id}...")
        # Access the weights of the convolutional layer within a Sequential block
        layer_name = layer_id[:-1]
        x = int(layer_id[-1])
        layer = model._modules[layer_name]
        
        if isinstance(layer, torch.nn.Sequential):
            # If the layer is a Sequential container, access the first Conv2d layer
            block = layer[x]  
            conv_layer = block.conv1
            W_l = conv_layer.weight.data.to(device) 
        elif isinstance(layer, Bottleneck):
            # If the layer is a Bottleneck block, access its convolutional layers
            conv_layer = layer.conv1
            W_l = conv_layer.weight.data.to(device) 
        else:
            # If it's not a Sequential container, directly access the weights
            W_l = layer.weight.data.to(device)  
        
        A_l = agent.get_binary_actions(W_l)
        
        # Prune the current layer based on actions
        prune_layer(model, layer_id, A_l)
        
        # Fine-tune the entire network after pruning the current layer
        print(f"Fine-tuning the model after pruning {layer_id}...")
        fine_tune_model(model, train_loader, criterion, optimizer, num_epochs=4)
        
        # Validate the pruned model
        new_acc = evaluate_accuracy(model, val_loader)
        print(f"Accuracy after pruning {layer_name}: {new_acc * 100:.2f}%")
        
        # Calculate the reward
        reward = agent.calculate_reward(new_acc, baseline_acc, b, A_l)
        print(f"Reward for pruning {layer_name}: {reward:.4f}")
        
        self.optimizer.zero_grad()
        log_prob = torch.log(agent.policy_net(W_l))
        loss = -reward * log_prob.sum()  # REINFORCE loss
        loss.backward()
        self.optimizer.step()

    print("Pruning completed.")

In [79]:
# Define a set of pruning agents for the convolutional layers
pruning_agents = {
   # 'conv10': PruningAgent(num_filters=model.conv1.weight.size(0)),
    'layer10': PruningAgent(num_filters=model.layer1[0].conv1.weight.size(0)),  # First block of layer1
    'layer11': PruningAgent(num_filters=model.layer1[1].conv1.weight.size(0)), # Second block of layer1
    'layer12': PruningAgent(num_filters=model.layer1[2].conv1.weight.size(0)),
    'layer20': PruningAgent(num_filters=model.layer2[0].conv1.weight.size(0)),  # First block of layer2
    'layer21': PruningAgent(num_filters=model.layer2[1].conv1.weight.size(0)),  
    'layer22': PruningAgent(num_filters=model.layer2[2].conv1.weight.size(0)),  
    'layer23': PruningAgent(num_filters=model.layer2[3].conv1.weight.size(0)),  
    'layer30': PruningAgent(num_filters=model.layer3[0].conv1.weight.size(0)),  # First block of layer3
    'layer31': PruningAgent(num_filters=model.layer3[1].conv1.weight.size(0)),
    'layer32': PruningAgent(num_filters=model.layer3[2].conv1.weight.size(0)),
    'layer33': PruningAgent(num_filters=model.layer3[3].conv1.weight.size(0)),
    'layer34': PruningAgent(num_filters=model.layer3[4].conv1.weight.size(0)),
    'layer35': PruningAgent(num_filters=model.layer3[5].conv1.weight.size(0)),
    'layer40': PruningAgent(num_filters=model.layer4[0].conv1.weight.size(0)),  # First block of layer4
    'layer41': PruningAgent(num_filters=model.layer4[1].conv1.weight.size(0)), 
    'layer42': PruningAgent(num_filters=model.layer4[2].conv1.weight.size(0))
}


In [53]:
optimizer2 = optim.Adam(model.parameters(), lr = lr)
baseline_acc = evaluate_accuracy(model, test_dl)
b = 1  # Performance drop bound
num_epochs = 10

In [80]:
prune_layer_by_layer(model, pruning_agents, criterion, optimizer2, train_dl, val_dl, baseline_acc, b, num_epochs)

Pruning layer10...


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x64 and 128x128)

In [157]:
print("model size after pruning : ")
check_model_size(model)

model size after pruning : 
Total number of parameters: 66854730
Approximate model size: 255.03 MB


In [158]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [159]:
print("model size after pruning : ")
count_parameters(model)

model size after pruning : 


66854730

In [160]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_dl:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy on test set: {100 * correct / total:.2f}%")

Accuracy on test set: 85.34%
