In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Pruning

In [None]:
# Pruning function
def prune_model(model=None, pruning_rate=0):
    """Apply L1-magnitude unstructured pruning to every Conv2d and Linear weight of ``model``.

    ``torch.nn.utils.prune.l1_unstructured`` re-parametrises each pruned module: ``weight``
    becomes ``weight_orig * weight_mask`` until the mask is made permanent with ``prune.remove``.
    Calling this again on an already-pruned model stacks a further round of pruning on top of
    the existing mask (so re-running the notebook cell prunes more than intended).

    Args:
        model: the ``nn.Module`` to prune, modified in place. Required despite the ``None`` default.
        pruning_rate: forwarded as ``amount`` to ``prune.l1_unstructured`` -- a float in [0, 1]
            is the fraction of weights to zero in each layer, an int is an absolute count.

    Returns:
        The same ``model`` object (pruned in place).

    Raises:
        ValueError: if ``model`` is None.
    """
    if model is None:
        raise ValueError("prune_model() requires a model to prune")

    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.l1_unstructured(module, name='weight', amount=pruning_rate)

    return model

# Prune 90% of the weights (lowest L1 magnitude) in every Conv2d / Linear layer of `model`, in place.
# NOTE: re-running this cell prunes the already-pruned model a second time.
prune_model(model=model, pruning_rate=0.9)

In [None]:
# Function that removes pruning mask
def delete_mask(model):
    """Make pruning permanent on every pruned Conv2d / Linear weight of ``model``.

    After ``prune.remove`` the module's ``weight`` is an ordinary parameter holding the masked
    values, and ``weight_orig`` / ``weight_mask`` plus the pruning hook are gone.

    Modules whose weight is not currently pruned (no ``weight_orig``) are skipped. This covers
    re-running the cell and layers that were never pruned. Previously ``prune.remove`` raised
    ``ValueError`` for them.

    Args:
        model: the ``nn.Module`` to finalise; modified in place.

    Returns:
        The same ``model`` object.
    """
    for module in model.modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        # prune.remove raises on an un-pruned parameter, so only finalise re-parametrised weights.
        if hasattr(module, "weight_orig"):
            prune.remove(module, 'weight')

    return model

# Make the pruning permanent: fold each mask into its weight and drop the pruning re-parametrisation.
model = delete_mask(model=model)

# Knowledge Distillation

In [None]:
'''
This code calculates the distillation loss as part of the distillation process.
The full distillation pipeline will be provided in a future update.
'''
# This cell relies on names defined elsewhere (presumably the enclosing training loop).
# TODO confirm they exist before it runs: teacher_model, student_model, X_train, y_train,
# T, loss_fn, soft_target_loss_weight, ce_loss_weight, and epoch_loss (the running total).
# NOTE(review): the teacher is usually put in eval mode (teacher_model.eval()) so that
# dropout and batch-norm do not perturb its targets. Confirm this happens where the loop is set up.

# Forward pass with the teacher model. No gradients are tracked because the teacher's weights are not updated.
with torch.no_grad():
    teacher_logits = teacher_model(X_train)  # X_train are the batched training images

# Forward pass with the student model
student_logits = student_model(X_train)  # X_train are the batched training images

# Soften both distributions with the temperature hyperparameter T:
# teacher -> probabilities (the soft targets), student -> log-probabilities.
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

# Soft-target loss: KL(teacher || student), summed over classes and divided by the batch size
# ('batchmean' divides by soft_prob.size(0)). It is then scaled by T**2, as suggested by the
# authors of the paper "Distilling the Knowledge in a Neural Network".
# kl_div replaces the hand-written sum(soft_targets * (soft_targets.log() - soft_prob)).
# When a teacher probability underflows to 0, .log() gives -inf and 0 * -inf = NaN, which
# poisons the whole loss; kl_div treats those terms as 0 instead.
soft_targets_loss = nn.functional.kl_div(soft_prob, soft_targets, reduction="batchmean") * (T**2)

# Calculate the true (hard) label loss
label_loss = loss_fn(student_logits, y_train)  # y_train are the batched labels

# Weighted sum of the two losses.
# 'soft_target_loss_weight' and 'ce_loss_weight' weight the soft-target and hard-label losses respectively.
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

# Accumulate the total loss for the epoch for monitoring
epoch_loss += loss.item()

# Quantization

In [None]:
# Dynamically quantize every torch.nn.Linear layer of the current `model` from float32 to int8.
# If the cells above were run, this is the pruned model with masks removed, not the original.
# quantize_dynamic returns a quantized copy and leaves `model` itself untouched (inplace defaults to False).
model_int8 = torch.ao.quantization.quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)

# Low-Rank Factorization

In [None]:
# Decompose a dense layer using SVD
def decompose_linear_layer(layer, rank):
    """Factor a Linear layer's weight into two low-rank matrices via truncated SVD.

    For ``W`` of shape (out_features, in_features), ``W ~= U_r @ (S_r @ V_r^T)``. The layer
    ``y = W x + b`` can therefore be replaced by two stacked maps:

    * first ``h = W1 x`` with ``W1 = S_r V_r^T`` (shape (rank, in_features), no bias);
    * then ``y = W2 h + b`` with ``W2 = U_r`` (shape (out_features, rank)) and the original bias.

    An exact, deterministic SVD is used. The previous ``torch.svd_lowrank(weights)`` call ran
    with its default ``q=6``, so any ``rank`` above 6 was silently cut to 6 components, and the
    factors were only a randomized approximation.

    Args:
        layer: an ``nn.Linear``-like module with a 2-D ``weight`` and a ``bias`` (tensor or None).
        rank: number of singular values to keep. Values above
            ``min(out_features, in_features)`` keep all of them.

    Returns:
        ``(W1, W2, bias)``. ``bias`` is the layer's bias tensor (sharing its storage), or
        ``None`` if the layer has no bias.

    Raises:
        ValueError: if ``rank`` is less than 1.
    """
    if rank < 1:
        raise ValueError(f"rank must be >= 1, got {rank}")

    weights = layer.weight.detach()
    # Reduced SVD: U (out, k), S (k,), Vh (k, in) with k = min(out, in); S is sorted descending.
    U, S, Vh = torch.linalg.svd(weights, full_matrices=False)
    U_hat = U[:, :rank].contiguous()
    # diag(S_r) @ Vh_r is the same as scaling each row of Vh_r by its singular value.
    W1 = S[:rank, None] * Vh[:rank, :]

    bias = layer.bias.detach() if layer.bias is not None else None
    return W1, U_hat, bias

# Decompose the dense layer.
# 'original_model' is the model we want to perform the LRF on;
# 'model_lrf' is an instance of the decomposed model class.
W1, W2, biases = decompose_linear_layer(original_model.fc[1], rank)

# Load the factors into the decomposed model. W1 (= S_r V_r^T) is the first factor;
# W2 (= U_r) together with the original bias is the second factor.
# NOTE(review): W1 goes into `fc[1]` while W2 and the bias go into `fc1[1]`. Confirm these
# attribute names match the decomposed model class, since an `fc` / `fc1` mix-up would go unnoticed.
# NOTE(review): if `model_lrf.fc[1]` was built with bias=True, its bias is left as initialised
# here and probably should be zeroed. Confirm against the decomposed model class.
first_factor = model_lrf.fc[1]
first_factor.weight.data = W1

second_factor = model_lrf.fc1[1]
second_factor.weight.data = W2
second_factor.bias.data = biases