In [None]:
!pip install transformers

In [None]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.optim as optim
from tqdm import tqdm

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Preprocessing shared by the train and test sets: upscale to 224x224 and
# normalize with ImageNet channel statistics to match the pretrained DeiT model.
# NOTE(review): no data augmentation and no random seed are configured here, so the
# shuffled training order (and therefore the results) change from run to run.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize CIFAR-10 images to 224x224 to match DeiT input size
    transforms.ToTensor(),
    # Normalize using ImageNet mean and std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Download and load CIFAR-10 training dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)

# Download and load CIFAR-10 test dataset
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)


In [None]:
from transformers import DeiTForImageClassificationWithTeacher

# Load the 'facebook/deit-base-distilled-patch16-224' checkpoint twice so the two
# models have independent weights: `baseline_model` stays unmodified, while
# `pruned_model` gets new heads and is pruned / fine-tuned in later cells.
# NOTE(review): `baseline_model` is never used after this cell in the visible
# notebook, so it only occupies device memory.
baseline_model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224')
baseline_model = baseline_model.to(device)
pruned_model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224')
pruned_model = pruned_model.to(device)

In [None]:
# Show the module tree as loaded, before the classification heads are replaced.
pruned_model

In [None]:
# Replace both heads (`distillation_classifier` and `cls_classifier`) with freshly
# initialised 10-output Linear layers for the 10 CIFAR-10 classes, keeping each
# head's input width. The new layers are created on the CPU.
pruned_model.distillation_classifier = torch.nn.Linear(in_features=pruned_model.distillation_classifier.in_features, out_features=10)
pruned_model.cls_classifier = torch.nn.Linear(in_features=pruned_model.cls_classifier.in_features, out_features=10)

In [None]:
# Move the model again so the newly created (CPU) heads end up on `device` as well.
pruned_model = pruned_model.to(device)

In [None]:
# Inspect the module tree again to confirm the new 10-output heads.
pruned_model

In [None]:
def initialize_mask_MLP(layer):
    """Register an all-ones ``pruning_mask`` buffer shaped like ``layer.weight``.

    The mask is created with ``torch.ones_like`` on the weight, so it has the same
    shape, dtype and device. As a registered buffer it follows ``model.to(...)``.
    """
    mask = torch.ones_like(layer.weight.data)
    layer.register_buffer('pruning_mask', mask)
    print("Initialized mask for MLP layer")


def initialize_mask_attn_layer(layer):
    """Register ``{query,key,value}_pruning_mask`` buffers on a self-attention module.

    Each mask is all ones and shaped like the matching projection's weight. It is
    stored on the attention module itself, not on the child ``nn.Linear``.
    Projections missing from ``layer`` are skipped.
    """
    for name in ['query', 'key', 'value']:
        linear_layer = getattr(layer, name, None)
        if linear_layer is None:
            continue
        mask = torch.ones_like(linear_layer.weight.data)
        layer.register_buffer(f'{name}_pruning_mask', mask)
        print(f"Initialized mask for self-attention {name}")


def initialize_masks_for_model(model):
    """Attach all-ones pruning masks to every prunable layer of ``model``.

    - Every ``nn.Linear`` gets a ``pruning_mask`` buffer.
    - Every module exposing ``query``, ``key`` and ``value`` attributes gets
      ``{query,key,value}_pruning_mask`` buffers.
    - Modules whose name contains ``cls_classifier`` or
      ``distillation_classifier`` (the classification heads) are skipped.

    Self-attention modules are recognised by duck typing, the same
    query/key/value ``hasattr`` test used by the apply/freeze/check helpers.
    Matching on the class name would silently miss attention classes with a
    different name.

    Note: the query/key/value projections are themselves ``nn.Linear``, so they
    also receive their own ``pruning_mask``. Each attention projection therefore
    ends up with two masks derived from the same weights.
    """
    for name, module in model.named_modules():
        if 'cls_classifier' in name or 'distillation_classifier' in name:
            continue  # Skip initializing masks for these specific layers

        if isinstance(module, torch.nn.Linear):
            initialize_mask_MLP(module)
        elif all(hasattr(module, part) for part in ('query', 'key', 'value')):
            initialize_mask_attn_layer(module)
                
                
# Register all-ones masks on every prunable layer of the model that will be pruned.
initialize_masks_for_model(pruned_model)

In [None]:
def check_and_view_masks(model):
    """Print every pruning mask registered on ``model`` together with its values.

    For each submodule this reports the single ``pruning_mask`` buffer (linear
    layers) followed by any ``query``/``key``/``value`` ``_pruning_mask`` buffers
    (self-attention modules), printing the module name and then the tensor.
    """
    for module_name, submodule in model.named_modules():
        if hasattr(submodule, 'pruning_mask'):
            print(f"{module_name} has a pruning mask.")
            print(submodule.pruning_mask)

        attn_buffer_names = (f'{part}_pruning_mask' for part in ('query', 'key', 'value'))
        for buffer_name in attn_buffer_names:
            if not hasattr(submodule, buffer_name):
                continue
            print(f"{module_name} has a {buffer_name}.")
            print(getattr(submodule, buffer_name))

# Sanity check: list every mask registered by initialize_masks_for_model(pruned_model).
check_and_view_masks(pruned_model)


In [None]:
def copy_weights_to_mask(layer):
    """Overwrite the layer's mask buffer(s) with a copy of the current weights.

    - Linear layers: ``pruning_mask`` <- clone of ``weight``.
    - Self-attention modules: each ``{query,key,value}_pruning_mask`` <- clone of
      the matching projection's weight.

    The copy is written through ``.data``, so the existing buffer tensor objects
    are kept (any reference held elsewhere sees the new values).
    """
    if hasattr(layer, 'pruning_mask'):
        layer.pruning_mask.data = layer.weight.data.clone()

    for part in ('query', 'key', 'value'):
        buffer_name = part + '_pruning_mask'
        if not hasattr(layer, buffer_name):
            continue
        source_weight = getattr(layer, part).weight
        target_mask = getattr(layer, buffer_name)
        target_mask.data = source_weight.data.clone()


In [None]:
def view_weights_and_masks(model):
    """Print each masked weight matrix next to its pruning mask.

    Linear layers with a ``pruning_mask`` are printed as ``Layer: <name>``.
    Self-attention modules print one section per projection that has a
    ``<part>_pruning_mask``, titled ``Layer: <name> - <Part>``. Each section is
    terminated by a dashed separator line.
    """
    divider = "-----------------------------------"

    for module_name, submodule in model.named_modules():
        # Layers carrying a single mask buffer (typically nn.Linear)
        if hasattr(submodule, 'weight') and hasattr(submodule, 'pruning_mask'):
            print(f"Layer: {module_name}")
            print("Weights:")
            print(submodule.weight.data)
            print("Mask:")
            print(submodule.pruning_mask.data)
            print(divider)

        # Self-attention modules keep one mask buffer per projection
        is_attention = all(hasattr(submodule, attr) for attr in ('query', 'key', 'value'))
        if not is_attention:
            continue
        for part in ('query', 'key', 'value'):
            buffer_name = f'{part}_pruning_mask'
            if not hasattr(submodule, buffer_name):
                continue
            projection = getattr(submodule, part)
            part_mask = getattr(submodule, buffer_name)
            print(f"Layer: {module_name} - {part.capitalize()}")
            print("Weights:")
            print(projection.weight.data)
            print("Mask:")
            print(part_mask.data)
            print(divider)

In [None]:
# Weights next to their freshly initialised (all-ones) masks.
view_weights_and_masks(pruned_model)

In [None]:
def apply_copy_weights_to_masks(model):
    """Refresh every mask buffer in ``model`` with a copy of the current weights.

    ``copy_weights_to_mask`` already handles both the linear ``pruning_mask``
    buffer and the self-attention ``{query,key,value}_pruning_mask`` buffers of
    a module. It is therefore called exactly once per relevant module: any
    module with a ``pruning_mask`` or with a ``query``/``key``/``value``
    attribute. Calling it once per matching attribute would redo the same three
    copies up to four times.
    """
    attn_parts = ('query', 'key', 'value')
    for _, module in model.named_modules():
        has_linear_mask = hasattr(module, 'pruning_mask')
        has_attn_part = any(hasattr(module, part) for part in attn_parts)
        if has_linear_mask or has_attn_part:
            copy_weights_to_mask(module)


In [None]:
# apply_copy_weights_to_masks(pruned_model)

In [None]:
# Same view again; the copy cell above is commented out, so the masks are still all ones.
view_weights_and_masks(pruned_model)

In [None]:
# def check_and_compare_weights_with_masks(model):
#     for name, module in model.named_modules():
#         # Check and compare for MLP layer
#         if hasattr(module, 'pruning_mask'):
#             print(f"{name} (MLP layer) - Mask copied correctly: ", 
#                   torch.equal(module.pruning_mask.data, module.weight.data))

#         # Check and compare for self-attention layer components
#         for attn_part in ['query', 'key', 'value']:
#             mask_name = f'{attn_part}_pruning_mask'
#             if hasattr(module, mask_name):
#                 weight = getattr(module, attn_part).weight
#                 mask = getattr(module, mask_name)
#                 print(f"{name} (Self-attention {attn_part}) - Mask copied correctly: ",
#                       torch.equal(mask.data, weight.data))

# # After performing the copy operation with apply_copy_weights_to_masks(model):
# # apply_copy_weights_to_masks(model)
# # Now, check and compare the masks with the weights:
# check_and_compare_weights_with_masks(pruned_model)


In [None]:
def apply_z_score_normalization(mask):
    """Return ``(mask - mean) / std`` computed over all elements of ``mask``.

    ``torch.std`` uses its default (unbiased) estimator. When that standard
    deviation is zero or not finite (a constant tensor, or a single element),
    the division would fill the result with NaN/inf. In that case the centred
    tensor ``mask - mean`` is returned instead, which is all zeros for a
    constant mask.

    Prints a marker line and the normalised tensor.
    """
    print("Z score called")
    mean = torch.mean(mask)
    std = torch.std(mask)
    centred = mask - mean
    if torch.isfinite(std) and std > 0:
        normalized_mask = centred / std
    else:
        # Degenerate spread: avoid propagating NaN/inf into the pruning masks.
        normalized_mask = centred
    print(normalized_mask)
    return normalized_mask

In [None]:
def binarize_mask(mask, pruning_ratio):
    """Turn a score tensor into a 0/1 mask that keeps the largest-magnitude entries.

    Args:
        mask: score tensor of any shape; entries are ranked by absolute value.
        pruning_ratio: fraction of entries to zero out. Values outside [0, 1]
            are clamped to that range.

    Returns:
        Tensor of the same shape as ``mask`` (on ``mask.device``) holding 1.0 for
        kept entries and 0.0 for pruned ones. Every entry whose magnitude ties
        with the smallest kept magnitude is kept, so slightly more than
        ``int((1 - pruning_ratio) * mask.numel())`` entries may survive.
        ``pruning_ratio == 1`` prunes everything instead of raising.
    """
    print("Binarize called")
    # reshape (unlike view) also accepts non-contiguous inputs
    flat_mask = mask.reshape(-1)
    magnitudes = flat_mask.abs()
    pruning_ratio = min(max(pruning_ratio, 0.0), 1.0)
    # Calculate the number of weights to keep
    num_weights_to_keep = int((1 - pruning_ratio) * flat_mask.numel())

    if num_weights_to_keep <= 0:
        # topk(k=0) returns an empty tensor, so there is no threshold to index.
        keep = torch.zeros_like(magnitudes, dtype=torch.bool)
    else:
        top_values, _ = torch.topk(magnitudes, num_weights_to_keep, largest=True)
        min_value_to_keep = top_values[-1]
        keep = magnitudes >= min_value_to_keep

    binarized_mask = torch.where(keep, torch.tensor(1.0, device=mask.device), torch.tensor(0.0, device=mask.device))
    print(binarized_mask)
    return binarized_mask.view_as(mask)

In [None]:
class AdaptivePruning:
    """Nudge the pruning rate up or down depending on how validation accuracy moved.

    Args:
        prev_val_acc: accuracy of the previous evaluation, used as the first baseline.
        delta: step added to / subtracted from the pruning rate.
        performance_threshold: dead band on the accuracy change. It must be in the
            same units as the accuracies passed to ``adjust_pruning_rate``.
    """

    def __init__(self, prev_val_acc, delta=0.01, performance_threshold=0.02):
        self.previous_accuracy = prev_val_acc
        self.performance_threshold = performance_threshold
        self.delta = delta

    def adjust_pruning_rate(self, current_accuracy, current_pruning_rate):
        """Return the adjusted rate and store ``current_accuracy`` as the new baseline.

        - accuracy fell by more than the threshold -> rate - delta (floored at 0)
        - accuracy rose by at least the threshold  -> rate + delta (capped at 1)
        - otherwise                                -> rate unchanged
        """
        change = current_accuracy - self.previous_accuracy
        # The baseline moves forward regardless of which branch is taken.
        self.previous_accuracy = current_accuracy

        if change < -self.performance_threshold:
            return max(current_pruning_rate - self.delta, 0)
        if change >= self.performance_threshold:
            return min(current_pruning_rate + self.delta, 1)
        return current_pruning_rate


In [None]:
import torch
import torch.nn.functional as F

def combined_loss(output, target, model, lambda_sparsity=1e-4):
    """Cross-entropy plus an L1 penalty on every parameter of ``model``.

    Returns ``F.cross_entropy(output, target) + lambda_sparsity * sum(|p|)``. The
    sum runs over all of ``model.parameters()``: weights, biases, normalisation
    parameters and classifier heads alike.
    NOTE(review): with ``lambda_sparsity=1e-4`` and a large model the L1 term may
    dominate the cross-entropy term; confirm the intended scale.
    """
    classification_term = F.cross_entropy(output, target)
    l1_term = sum(torch.sum(torch.abs(param)) for param in model.parameters())
    return classification_term + lambda_sparsity * l1_term

In [None]:
def update_pruning_rate(epoch, max_epoch, initial_rate, final_rate):
    """Linearly interpolate the pruning rate between two end points.

    Gives ``initial_rate`` at ``epoch == 0`` and ``final_rate`` at
    ``epoch == max_epoch``. The result is not clamped, and ``max_epoch`` must be
    non-zero.
    """
    progress = epoch / max_epoch
    span = final_rate - initial_rate
    return initial_rate + span * progress

In [None]:
def apply_mask_MLP(layer):
    """Zero the entries of ``layer.weight`` where ``layer.pruning_mask`` is 0.

    Does nothing for layers without a ``pruning_mask``. If the layer also
    carries a ``bias_mask`` buffer and has a bias, the bias is masked the same
    way. The masked result is assigned back to ``.data`` as a new tensor.
    NOTE(review): no ``bias_mask`` is registered anywhere in the visible code,
    so the bias branch does not run in this notebook.
    """
    if not hasattr(layer, 'pruning_mask'):
        return

    weight_mask = layer.pruning_mask.to(layer.weight.device)
    layer.weight.data = layer.weight.data.mul(weight_mask)

    if layer.bias is not None and hasattr(layer, 'bias_mask'):
        bias_mask = layer.bias_mask.to(layer.bias.device)
        layer.bias.data = layer.bias.data.mul(bias_mask)

    print("Applied pruning mask to MLP layer")


In [None]:
def apply_mask_attn_layer(layer):
    """Multiply each query/key/value projection weight in place by its mask buffer.

    A projection is skipped when ``layer`` lacks that projection or lacks the
    matching ``<part>_pruning_mask``. The mask is moved to the weight's device
    before the multiplication.
    """
    for part in ('query', 'key', 'value'):
        projection = getattr(layer, part, None)
        if projection is None:
            continue
        weight = projection.weight
        part_mask = getattr(layer, f'{part}_pruning_mask', None)
        if part_mask is None:
            continue
        weight.data.mul_(part_mask.to(weight.device))
        print(f"Applied mask to {part} weights in self-attention layer")


In [None]:
def apply_masks_for_model(model):
    """Apply every registered pruning mask in ``model`` to its weights.

    ``nn.Linear`` modules go through ``apply_mask_MLP``, which is a no-op
    without a ``pruning_mask``. Other modules exposing ``query``, ``key`` and
    ``value`` go through ``apply_mask_attn_layer``. A trace line is printed for
    every visited module.
    """
    for _, submodule in model.named_modules():
        print("Going to modules")

        if isinstance(submodule, torch.nn.Linear):
            print("going to MLP")
            apply_mask_MLP(submodule)
            continue

        if all(hasattr(submodule, attr) for attr in ('query', 'key', 'value')):
            print("going to self attn")
            apply_mask_attn_layer(submodule)


In [None]:
def check_binarization_accuracy(model, pruning_ratio):
    """Compare each mask's actual fraction of zeros with the requested ``pruning_ratio``.

    Modules whose name contains ``cls_classifier`` or ``distillation_classifier``
    are ignored, since those heads are never given masks.

    Returns:
        dict mapping a key to ``(expected_sparsity, actual_sparsity)`` as floats.
        A linear layer's ``pruning_mask`` is keyed by its module name. Each
        self-attention mask is keyed ``"<module>.<part>_pruning_mask"``, so the
        query, key and value results of one module do not overwrite each other.
        An empty mask reports ``nan`` sparsity.
    """
    binarization_accuracy = {}

    for name, module in model.named_modules():
        if 'cls_classifier' in name or 'distillation_classifier' in name:
            continue  # The classification heads carry no pruning masks.

        masks_to_check = []

        # For MLP layers, the single pruning mask keeps the module name as its key
        if hasattr(module, 'pruning_mask'):
            masks_to_check.append((name, module.pruning_mask))

        # For self-attention layers, give each of query/key/value its own key
        for attn_part in ['query', 'key', 'value']:
            mask_name = f'{attn_part}_pruning_mask'
            if hasattr(module, mask_name):
                key = f"{name}.{mask_name}" if name else mask_name
                masks_to_check.append((key, getattr(module, mask_name)))

        # Calculate and compare sparsity for each mask
        for key, mask in masks_to_check:
            total_elements = mask.numel()
            zero_elements = total_elements - torch.count_nonzero(mask).item()
            actual_sparsity = zero_elements / total_elements if total_elements else float('nan')
            expected_sparsity = pruning_ratio

            binarization_accuracy[key] = (expected_sparsity, actual_sparsity)
            print(f"{key}: Expected Sparsity: {expected_sparsity:.3f}, Actual Sparsity: {actual_sparsity:.3f}")

    return binarization_accuracy


In [None]:
#apply_masks_for_model(pruned_model)

In [None]:
# Same view again; apply_masks_for_model is commented out above, so no mask has been applied yet.
view_weights_and_masks(pruned_model)

In [None]:
# def view_weights_of_layers(model):
#     for name, module in model.named_modules():
#         # Check for MLP layer weights
#         if isinstance(module, torch.nn.Linear):
#             print(f"Viewing weights for MLP layer: {name}")
#             print(module.weight.data)

#         # Check for self-attention layer weights
#         # This assumes your model's self-attention mechanism follows a specific naming convention or structure
#         elif hasattr(module, 'query') and hasattr(module, 'key') and hasattr(module, 'value'):
#             print(f"Viewing weights for Self-Attention layer: {name}")
#             for attn_part in ['query', 'key', 'value']:
#                 weight_matrix = getattr(module, attn_part).weight.data
#                 print(f"Weights for {attn_part}:")
#                 print(weight_matrix)

In [None]:
# view_weights_of_layers(pruned_model)

In [None]:
# def check_weights_squared(model):
#     for name, module in model.named_modules():
#         # Check for MLP layers and their masks
#         if isinstance(module, torch.nn.Linear) and hasattr(module, 'pruning_mask'):
#             original_mask = module.pruning_mask.data
#             squared_mask = original_mask ** 2
#             current_weights = module.weight.data
            
#             # Compare the squared mask with the current weights
#             if torch.allclose(squared_mask, current_weights, rtol=1e-05, atol=1e-08):
#                 print(f"{name}: Weights are squared of the original mask values.")
#             else:
#                 print(f"{name}: Weights are NOT squared of the original mask values.")
                
#         # Check for self-attention layer weights and their masks
#         elif hasattr(module, 'query') and hasattr(module, 'key') and hasattr(module, 'value'):
#             for part in ['query', 'key', 'value']:
#                 linear_layer = getattr(module, part)
#                 if hasattr(module, f'{part}_pruning_mask'):
#                     original_mask = getattr(module, f'{part}_pruning_mask').data
#                     squared_mask = original_mask ** 2
#                     current_weights = linear_layer.weight.data
                    
#                     # Compare the squared mask with the current weights
#                     if torch.allclose(squared_mask, current_weights, rtol=1e-05, atol=1e-08):
#                         print(f"{name} {part}: Weights are squared of the original mask values.")
#                     else:
#                         print(f"{name} {part}: Weights are NOT squared of the original mask values.")


In [None]:
# check_weights_squared(pruned_model)

In [None]:
# import inspect
# from transformers import DeiTForImageClassificationWithTeacher

# # Attempt to print the forward method source code
# print(inspect.getsource(DeiTForImageClassificationWithTeacher.forward))

In [None]:
# import torchvision.utils as utils
# class_labels = train_dataset.classes
# class_labels

In [None]:
# for data, targets in train_loader:
#     print(f'{data.shape} {targets.shape}')
    
#     num_images_to_display = 4
#     grid_images = utils.make_grid(data[:num_images_to_display], nrow=num_images_to_display)

#     # Move the tensor to CPU and convert it to a NumPy array for visualization
#     grid_images = grid_images.cpu().numpy().transpose((1, 2, 0))
#     print([class_labels[x] for x in targets[:num_images_to_display]])
#     # Display the images
#     plt.imshow(grid_images)
#     plt.title('Batch of Images')
#     plt.axis('off')
#     plt.show()
#     break

In [None]:
# with torch.no_grad():
#     for data, targets in train_loader:
#         input_data = data
#         break
#     outputs = x_model(input_data)
#     logits = outputs["logits"]
#     logits.shape

In [None]:
# for epoch in epochs:
#     for data, targets in train_loader:
#         outputs = model(data)
        

In [None]:
def train_one_epoch(epoch, model, train_loader, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    Each batch is scored with ``combined_loss`` (called with its default
    ``lambda_sparsity``) on ``outputs.logits``. A tqdm bar shows the running
    mean loss and accuracy.

    Returns:
        tuple: (mean loss per batch, training accuracy in percent).
    """
    model.train()
    loss_total = 0.0
    hits = 0
    seen = 0

    batches = tqdm(enumerate(train_loader), total=len(train_loader))
    for step, (images, labels) in batches:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(images).logits
        loss = combined_loss(logits, labels, model)
        loss.backward()
        optimizer.step()

        loss_total += loss.item()
        predicted = logits.max(1)[1]
        seen += labels.size(0)
        hits += predicted.eq(labels).sum().item()

        batches.set_description(
            f'Epoch {epoch} Loss: {loss_total/(step+1):.3f} Acc: {100.*hits/seen:.3f}%'
        )

    return loss_total / len(train_loader), 100. * hits / seen

In [None]:
def validate(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` on ``test_loader``, in percent.

    The model is put in eval mode and run without gradient tracking.
    Predictions are the arg-max of ``outputs.logits`` along dim 1. The accuracy
    is also printed.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images).logits.max(1)[1]
            seen += labels.size(0)
            hits += predictions.eq(labels).sum().item()

    acc = 100. * hits / seen
    print(f'Validation Accuracy: {acc:.3f}%')
    return acc

In [None]:
# def unified_pruning(epoch, model, train_loader, optimizer, device, initial_pruning_rate, performance_threshold, max_epoch):
    
    
#     apply_copy_weights_to_masks(pruned_model)
    
#     for name, module in model.named_modules():
    
#         if hasattr(module, 'pruning_mask'):
#             print("z score for MLP unified pruning")
#             module.pruning_mask.data = apply_z_score_normalization(module.pruning_mask.data)
#         elif hasattr(module, 'query') and hasattr(module, 'key') and hasattr(module, 'value'):
#             print("going to self attn")
#             for name in ['query', 'key', 'value']:
#                 # Access the linear layer for query, key, value
#                 attn_component = getattr(module, name, None)
#                 if attn_component is not None:
#                     mask_name = f'{name}_pruning_mask'
#                     mask = getattr(module, mask_name, None)
#                     if mask is not None:
#                         zscore = apply_z_score_normalization(mask)
#                         setattr(module, mask_name, zscore)
            
# #         for attn_part in ['query', 'key', 'value']:
# #             mask_name = f'{attn_part}_pruning_mask'
# #             if hasattr(module, mask_name):
# #                 mask = getattr(module, mask_name)
# #                 print("z score for attn unified pruning")
# #                 setattr(module, mask_name, apply_z_score_normalization(mask))
                
#     current_accuracy = validate(model, test_loader, device) 
#     adaptive_pruning = AdaptivePruning(prev_val_acc=current_accuracy, initial_pruning_rate=initial_pruning_rate, delta=0.01, performance_threshold=performance_threshold)
#     current_pruning_rate = adaptive_pruning.adjust_pruning_rate(current_accuracy)
    
#     for name, module in model.named_modules():
        
#         if hasattr(module, 'pruning_mask'):
#             print("brinarize for MLP unified pruning")
#             module.pruning_mask.data = binarize_mask(module.pruning_mask.data, current_pruning_rate)
#         elif hasattr(module, 'query') and hasattr(module, 'key') and hasattr(module, 'value'):
#             print("going to self attn")
#             for name in ['query', 'key', 'value']:
#                 # Access the linear layer for query, key, value
#                 attn_component = getattr(module, name, None)
#                 if attn_component is not None:
#                     mask_name = f'{name}_pruning_mask'
#                     mask = getattr(module, mask_name, None)
#                     if mask is not None:
#                         binarized = binarize_mask(mask, current_pruning_rate)
#                         setattr(module, mask_name, binarized)
            
# #         for attn_part in ['query', 'key', 'value']:
# #             mask_name = f'{attn_part}_pruning_mask'
# #             if hasattr(module, mask_name):
# #                 mask = getattr(module, mask_name)
# #                 print("brinarize for attn unified pruning")
# #                 setattr(module, mask_name, binarize_mask(mask, current_pruning_rate))
    
#     binarization_results = check_binarization_accuracy(model, current_pruning_rate)
#     apply_masks_for_model(model)
#     check_pruning_effectiveness(model, current_pruning_rate)
#     freeze_pruned_weights(model)
    
#     train_loss, train_acc = train_one_epoch(epoch, model, train_loader, optimizer, device)
#     return train_loss, train_acc
# # epoch_loss, epoch_acc = unified_pruning(pruned_model, current_epoch, total_epochs, train_loader, optimizer, device, initial_rate, final_rate)


In [None]:
def _map_pruning_masks(model, transform, linear_message):
    """Replace every pruning-mask buffer in ``model`` with ``transform(mask)``.

    Covers linear ``pruning_mask`` buffers (announced with ``linear_message``)
    and self-attention ``{query,key,value}_pruning_mask`` buffers. Results are
    written through ``.data`` so every buffer keeps its tensor identity. Any
    gradient hook that captured a mask tensor (``freeze_pruned_weights`` binds
    masks as default arguments) therefore sees the current values, not a stale
    copy.
    """
    for _, module in model.named_modules():
        if hasattr(module, 'pruning_mask'):
            print(linear_message)
            module.pruning_mask.data = transform(module.pruning_mask.data)
        elif all(hasattr(module, part) for part in ('query', 'key', 'value')):
            print("going to self attn")
            for part in ('query', 'key', 'value'):
                if getattr(module, part, None) is None:
                    continue
                mask = getattr(module, f'{part}_pruning_mask', None)
                if mask is not None:
                    mask.data = transform(mask.data)


def unified_progressive_pruning(model, epoch, max_epoch, train_loader, optimizer, device,
                                initial_pruning_rate, final_pruning_rate,
                                pruning_controller=None, val_loader=None):
    """Run one epoch of train -> score -> re-prune.

    1. Train one epoch with the current masks (``train_one_epoch``).
    2. Copy the weights into the masks and z-score them.
    3. Choose the pruning rate. Start from the linear schedule
       (``update_pruning_rate``), then adjust it with
       ``pruning_controller.adjust_pruning_rate`` using a fresh validation
       accuracy.
    4. Binarize the masks at that rate, report their sparsity, apply them to the
       weights and re-register the gradient-masking hooks.

    Args:
        model, epoch, max_epoch, train_loader, optimizer, device,
        initial_pruning_rate, final_pruning_rate: as used by the helpers above.
        pruning_controller: an ``AdaptivePruning`` instance. Defaults to the
            notebook-global ``adaptive_pruning``.
        val_loader: loader used for validation. Defaults to the notebook-global
            ``test_loader``.

    Returns:
        Validation accuracy in percent, measured after training but before this
        epoch's new masks are applied.

    NOTE(review): pruned weights receive zero gradient via the hooks, but Adam's
    running moments may still move them off zero during the next epoch's
    training. Masks are recomputed from those weights every epoch.
    """
    # Resolve the implicit notebook globals up front, so a missing one fails before training.
    if pruning_controller is None:
        pruning_controller = adaptive_pruning
    if val_loader is None:
        val_loader = test_loader

    # The returned loss/accuracy are not used by this routine.
    train_one_epoch(epoch, model, train_loader, optimizer, device)

    print("copying weights for prog pruning")
    apply_copy_weights_to_masks(model)
    _map_pruning_masks(model, apply_z_score_normalization,
                       "z score for MLP progressive pruning")

    print("updating weights for prog pruning")
    pruning_rate = update_pruning_rate(epoch, max_epoch, initial_pruning_rate, final_pruning_rate)
    current_accuracy = validate(model, val_loader, device)
    pruning_rate = pruning_controller.adjust_pruning_rate(current_accuracy, pruning_rate)

    print("Binarizing masks for prog pruning")
    _map_pruning_masks(model, lambda scores: binarize_mask(scores, pruning_rate),
                       "brinarize for MLP unified pruning")

    print("applying for progressive pruning")
    check_binarization_accuracy(model, pruning_rate)
    apply_masks_for_model(model)
    check_pruning_effectiveness(model, pruning_rate)
    freeze_pruned_weights(model)

    return current_accuracy


In [None]:
def freeze_pruned_weights(model):
    """Mask weight gradients so pruned entries receive a zero gradient.

    - Every module carrying a ``pruning_mask`` (linear layers) gets a hook on
      ``module.weight``.
    - Every module exposing ``query``/``key``/``value`` gets a hook on each
      projection weight, using the matching ``<part>_pruning_mask``.
    - Each hook returns ``grad * mask``.

    The hook handles are stored on the owning module as ``_pruning_hook_handles``
    and removed at the start of the next call. Calling this once per epoch
    therefore neither piles up hooks nor keeps applying an old mask tensor after
    the buffer has been replaced by a new one.

    NOTE(review): a zero gradient does not stop Adam-style optimizers from moving
    a weight through their running moment estimates. The masks are enforced on
    the weights themselves only by ``apply_masks_for_model``.
    """
    for name, module in model.named_modules():
        # Remove the hooks this function registered for this module last time.
        for handle in getattr(module, '_pruning_hook_handles', ()):
            handle.remove()
        handles = []

        if hasattr(module, 'pruning_mask'):
            # Freeze weights for MLP layers
            def mlp_hook(grad, mask=module.pruning_mask):
                return grad * mask

            handles.append(module.weight.register_hook(mlp_hook))
            print(f"Freezing pruned weights in {name} (MLP layer)")

        elif hasattr(module, 'query') and hasattr(module, 'key') and hasattr(module, 'value'):
            # Handle self-attention layers' query, key, value
            for attn_part in ['query', 'key', 'value']:
                linear_layer = getattr(module, attn_part, None)
                if linear_layer is None:
                    continue
                mask = getattr(module, f'{attn_part}_pruning_mask', None)
                if mask is None:
                    continue

                def attn_hook(grad, mask=mask):
                    return grad * mask

                # The hook sits on the child projection's weight, but its handle is
                # kept on this attention module so the next call can remove it.
                handles.append(linear_layer.weight.register_hook(attn_hook))
                print(f"Freezing pruned weights in {name}.{attn_part} (Self-attention component)")

        module._pruning_hook_handles = handles


In [None]:
def check_pruning_effectiveness(model, desired_pruning_ratio):
    """Print each mask's fraction of zeros next to the desired ratio and warn on mismatch.

    Linear layers are checked through ``pruning_mask``. Modules exposing
    ``query``/``key``/``value`` are checked through ``<part>_pruning_mask``. The
    comparison is exact (no tolerance), so any deviation is reported as under-
    or over-pruning.
    """
    def _warn_if_off(subject, sparsity):
        if sparsity < desired_pruning_ratio:
            print(f"Warning: {subject} is under-pruned.")
        elif sparsity > desired_pruning_ratio:
            print(f"Warning: {subject} is over-pruned.")

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            mask = getattr(module, 'pruning_mask', None)
            if mask is None:
                continue
            sparsity = (mask == 0).float().mean().item()
            print(f"MLP Layer {name}: Desired sparsity = {desired_pruning_ratio}, Actual sparsity = {sparsity}")
            _warn_if_off(f"{name} layer", sparsity)
            continue

        if not all(hasattr(module, part) for part in ('query', 'key', 'value')):
            continue

        for attn_name in ('query', 'key', 'value'):
            if getattr(module, attn_name, None) is None:
                continue
            mask = getattr(module, f'{attn_name}_pruning_mask', None)
            if mask is None:
                continue
            sparsity = (mask == 0).float().mean().item()
            print(f"Self-attention {attn_name} in {name}: Desired sparsity = {desired_pruning_ratio}, Actual sparsity = {sparsity}")
            _warn_if_off(f"{name} {attn_name}", sparsity)


In [None]:
# Hyperparameters and settings
# Start / end points of the linear schedule in update_pruning_rate (fraction of
# each mask's entries set to zero at epoch 0 / epoch == max_epoch).
# NOTE(review): the training loop stops pruning two epochs before the end, so with
# max_epoch = 16 the scheduled rate peaks at 0.1 + 0.7 * 14/16 ≈ 0.71 (before the
# adaptive adjustment) and 0.8 is never reached.
initial_pruning_rate = 0.1 
final_pruning_rate = 0.8   
# Accuracy dead band for adaptive rate changes, written as a fraction;
# note that validate() returns accuracy in percent (0-100).
performance_threshold = 0.02  
# NOTE(review): train_one_epoch calls combined_loss without lambda_sparsity, so
# combined_loss uses its own 1e-4 default and changing this value has no effect.
lambda_sparsity = 1e-4  # Sparsity weighting factor for combined loss
# NOTE(review): pruned weights get zero gradient from the freeze hooks, but Adam's
# running moments may still move them away from zero between mask applications.
optimizer = optim.Adam(pruned_model.parameters(), lr=0.001)
# Length of the linear pruning schedule used by update_pruning_rate.
max_epoch = 16

In [None]:
# Initiating Training
# Every epoch except the last `finetune_epochs` applies progressive pruning; the final
# epochs train the pruned model without further pruning. The loop is driven by
# `max_epoch` (also passed to the pruning schedule), so the two cannot drift apart.
# With max_epoch = 16 this is epochs 1-14 pruning and 15-16 normal training.
finetune_epochs = 2
first_finetune_epoch = max_epoch - finetune_epochs + 1

# Accuracy of the model before any pruning; the first pruning epoch starts from it.
prev_val_acc = validate(pruned_model, test_loader, device)

for epoch in range(1, max_epoch + 1):
    if epoch >= first_finetune_epoch:
        # Train the model normally
        print(f"Epoch {epoch}: Normal Training")
        train_loss, train_acc = train_one_epoch(epoch, pruned_model, train_loader, optimizer, device)
        print(f"Training Loss: {train_loss}, Training Accuracy: {train_acc}")
        val_acc = validate(pruned_model, test_loader, device)
        print(f"Validation Accuracy: {val_acc}%")
    else:
        # Apply Progressive Pruning
        # NOTE(review): `adaptive_pruning` is never passed to `unified_progressive_pruning`;
        # it is kept in case that function reads it as a global. Remove if unused.
        adaptive_pruning = AdaptivePruning(prev_val_acc)
        print(f"Epoch {epoch}: Applying Progressive Pruning")
        # Assumes the return value is the post-epoch validation accuracy (it is printed as such).
        prev_val_acc = unified_progressive_pruning(pruned_model, epoch, max_epoch, train_loader, optimizer, device, initial_pruning_rate, final_pruning_rate)
        print(f"Validation Accuracy: {prev_val_acc}%")


In [None]:
# Inspect the pruned model's weights and masks after the training loop.
# NOTE(review): `view_weights_and_masks` is not defined in this part of the notebook;
# confirm it is defined in an earlier cell so Restart & Run All succeeds.
view_weights_and_masks(pruned_model)

In [None]:
# !pip install transformers
# import torch
# from torchvision import datasets, transforms
# import matplotlib.pyplot as plt
# import numpy as np
# import torchvision
# import torch.optim as optim
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),  # Resize CIFAR-10 images to 224x224 to match DeiT input size
#     transforms.ToTensor(),
#     # Normalize using ImageNet mean and std
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# # Download and load CIFAR-10 training dataset
# train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

# # Download and load CIFAR-10 test dataset
# test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

# no_of_batches = len(train_loader)
# no_of_batches, 50000/64

# from transformers import DeiTForImageClassificationWithTeacher

# baseline_model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224')
# baseline_model = baseline_model.to(device)
# pruned_model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224')
# pruned_model = pruned_model.to(device)

# pruned_model

# pruned_model.distillation_classifier = torch.nn.Linear(in_features=pruned_model.distillation_classifier.in_features, out_features=10)
# pruned_model.cls_classifier = torch.nn.Linear(in_features=pruned_model.cls_classifier.in_features, out_features=10)

# pruned_model

# num_output_features = pruned_model.distillation_classifier.out_features
# print("Number of output features:", num_output_features)

# pruned_model = pruned_model.to(device)

# pruned_model

# def initialize_mask_MLP(layer):
#     mask = torch.ones_like(layer.weight.data)
#     layer.register_buffer('pruning_mask', mask)
    
# def initialize_mask_attn_layer(layer):
#     for name in ['query', 'key', 'value']:
#         weight_matrix = getattr(layer, f'{name}_weight')
#         mask = torch.ones_like(weight_matrix.data)
#         layer.register_buffer(f'{name}_pruning_mask', mask)
        
# x_model = pruned_model
# x_model

# def initialize_masks_for_model(model):
#     for name, module in model.named_modules():
#         # Check if the module is an MLP layer
#         if isinstance(module, torch.nn.Linear):
#             initialize_mask_MLP(module)
#         # Check if the module is a self-attention layer
#         # This is a simplified check; you may need a more specific condition based on your model's architecture
#         elif hasattr(module, 'query_weight') and hasattr(module, 'key_weight') and hasattr(module, 'value_weight'):
#             initialize_mask_attn_layer(module)
            
            
# initialize_masks_for_model(pruned_model)

# def check_and_view_masks(model):
#     for name, module in model.named_modules():
#         # Check for MLP layer masks
#         if hasattr(module, 'pruning_mask'):
#             print(f"{name} has a pruning mask.")
#             print(module.pruning_mask)
#         # Check for self-attention layer masks
#         for attn_part in ['query', 'key', 'value']:
#             mask_name = f'{attn_part}_pruning_mask'
#             if hasattr(module, mask_name):
#                 print(f"{name} has a {mask_name}.")
#                 print(getattr(module, mask_name))

# # # Call this function with your model
# check_and_view_masks(pruned_model)

# def copy_weights_to_mask(layer):
#     # Assuming the mask is already initialized and has the same shape as the layer's weights
#     layer.pruning_mask.data = torch.clone(layer.weight.data)
    
# def apply_z_score_normalization(mask):
#     mean = torch.mean(mask)
#     std = torch.std(mask)
#     normalized_mask = (mask - mean) / std
#     return normalized_mask


# def binarize_mask(mask, pruning_ratio):
#     # Flatten the mask to simplify thresholding
#     flat_mask = mask.view(-1)
#     # Calculate the number of weights to keep
#     num_weights_to_keep = int((1 - pruning_ratio) * flat_mask.numel())
#     # Use torch.topk to get the threshold value
#     threshold_value, _ = torch.topk(flat_mask.abs(), num_weights_to_keep, largest=True)
#     min_value_to_keep = threshold_value[-1]
#     # Binarize the mask
#     binarized_mask = torch.where(flat_mask.abs() >= min_value_to_keep, torch.tensor(1.0, device=mask.device), torch.tensor(0.0, device=mask.device))
#     return binarized_mask.view_as(mask)


# class AdaptivePruning:
#     def __init__(self, prev_val_acc, initial_pruning_rate=0.1, delta=0.01, performance_threshold=0.02):
#         self.pruning_rate = initial_pruning_rate
#         self.delta = delta  # Increment/decrement step for the pruning rate
#         self.performance_threshold = performance_threshold  # Minimum acceptable change in validation accuracy
#         self.previous_accuracy = prev_val_acc  # Placeholder for the last recorded accuracy

#     def adjust_pruning_rate(self, current_accuracy):
#         """
#         Adjusts the pruning rate based on the change in accuracy.
        
#         :param current_accuracy: The current accuracy of the model on the validation set.
#         """
#         accuracy_change = current_accuracy - self.previous_accuracy

#         # If performance drops significantly, decrease pruning rate
#         if accuracy_change < -self.performance_threshold:
#             self.pruning_rate = max(self.pruning_rate - self.delta, 0)  # Ensure pruning rate doesn't go negative
#         # If performance improves or remains stable, consider increasing the pruning rate
#         elif accuracy_change >= self.performance_threshold:
#             self.pruning_rate = min(self.pruning_rate + self.delta, 1)  # Ensure pruning rate doesn't exceed 1

#         self.previous_accuracy = current_accuracy  # Update the previous accuracy
#         return self.pruning_rate

    
# import torch
# import torch.nn.functional as F

# def combined_loss(output, target, model, lambda_sparsity=1e-4):
#     """
#     Calculate combined loss = CrossEntropyLoss + lambda * SparsityLoss
    
#     :param output: Tensor, model output logits
#     :param target: Tensor, ground truth labels
#     :param model: PyTorch model, to calculate sparsity loss over its parameters
#     :param lambda_sparsity: float, weighting factor for sparsity loss
#     :return: combined loss value
#     """
#     # Cross Entropy Loss
#     ce_loss = F.cross_entropy(output, target)
    
#     # Sparsity Loss (L1 norm of model weights)
#     sparsity_loss = 0
#     for param in model.parameters():
#         sparsity_loss += torch.sum(torch.abs(param))
    
#     # Combined Loss
#     combined_loss = ce_loss + lambda_sparsity * sparsity_loss
    
#     return combined_loss

# def update_pruning_rate(epoch, max_epoch, initial_rate, final_rate):
#     """
#     Updates the pruning rate over time, increasing from an initial rate to a final rate linearly over the epochs.

#     Args:
#         epoch (int): Current epoch number.
#         max_epoch (int): Total number of epochs for training.
#         initial_rate (float): Initial pruning rate at the beginning of training.
#         final_rate (float): Final pruning rate by the end of training.
    
#     Returns:
#         float: The updated pruning rate for the current epoch.
#     """
#     # Linear scheduling from initial_rate to final_rate
#     current_rate = initial_rate + (final_rate - initial_rate) * (epoch / max_epoch)
#     return current_rate

# def apply_pruning(model):
#     for name, module in model.named_modules():
#         # Check for MLP layer masks
#         if hasattr(module, 'pruning_mask'):
#             print(f"{name} has a pruning mask.")
#             module.weight.data = module.weight.data * module.pruning_mask
#         # Check for self-attention layer masks
#         for attn_part in ['query', 'key', 'value']:
#             mask_name = f'{attn_part}_pruning_mask'
#             if hasattr(module, mask_name):
#                 print(f"{name} has a {mask_name}.")
#                 module.weight.data = module.weight.data * module.pruning_mask

# import inspect
# from transformers import DeiTForImageClassificationWithTeacher

# # Attempt to print the forward method source code
# print(inspect.getsource(DeiTForImageClassificationWithTeacher.forward))


# import torchvision.utils as utils
# class_labels = train_dataset.classes
# class_labels

# for data, targets in train_loader:
#     print(f'{data.shape} {targets.shape}')
    
#     num_images_to_display = 4
#     grid_images = utils.make_grid(data[:num_images_to_display], nrow=num_images_to_display)

#     # Move the tensor to CPU and convert it to a NumPy array for visualization
#     grid_images = grid_images.cpu().numpy().transpose((1, 2, 0))
#     print([class_labels[x] for x in targets[:num_images_to_display]])
#     # Display the images
#     plt.imshow(grid_images)
#     plt.title('Batch of Images')
#     plt.axis('off')
#     plt.show()
#     break
    
    
    
# import torch
# from tqdm import tqdm

# # Assuming you have defined train_loader, test_loader, model, optimizer, and device

# def train_one_epoch(epoch, model, train_loader, optimizer, device):
#     model.train()
#     running_loss = 0.0
#     correct = 0
#     total = 0
    
#     progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
#     for batch_idx, (inputs, targets) in progress_bar:
#         inputs, targets = inputs.to(device), targets.to(device)
        
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = combined_loss(outputs.logits, targets, model)  # Use the defined combined_loss function
#         loss.backward()
#         optimizer.step()

#         running_loss += loss.item()
#         _, predicted = outputs.logits.max(1)
#         total += targets.size(0)
#         correct += predicted.eq(targets).sum().item()

#         progress_bar.set_description(f'Epoch {epoch} Loss: {running_loss/(batch_idx+1):.3f} Acc: {100.*correct/total:.3f}%')
    
#     return running_loss / len(train_loader), 100.*correct / total

# def validate(model, test_loader, device):
#     model.eval()
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for inputs, targets in test_loader:
#             inputs, targets = inputs.to(device), targets.to(device)
#             outputs = model(inputs)
#             _, predicted = outputs.logits.max(1)
#             total += targets.size(0)
#             correct += predicted.eq(targets).sum().item()
    
#     acc = 100.*correct / total
#     print(f'Validation Accuracy: {acc:.3f}%')
#     return acc

# def unified_pruning(epoch, model, train_loader, optimizer, device, initial_pruning_rate, performance_threshold, max_epoch):
    
#     for name, module in model.named_modules():
#         # Check if the module is an MLP layer or a self-attention layer
#         if isinstance(module, torch.nn.Linear) or \
#         (hasattr(module, 'query_weight') and hasattr(module, 'key_weight') and hasattr(module, 'value_weight')):
#             copy_weights_to_mask(module)
   
    
#     for name, module in model.named_modules():
    
#         if hasattr(module, 'pruning_mask'):
#             module.pruning_mask.data = apply_z_score_normalization(module.pruning_mask.data)
#         for attn_part in ['query', 'key', 'value']:
#             mask_name = f'{attn_part}_pruning_mask'
#             if hasattr(module, mask_name):
#                 mask = getattr(module, mask_name)
#                 setattr(module, mask_name, apply_z_score_normalization(mask))
                
#     current_accuracy = validate(model, test_loader, device) 
#     adaptive_pruning = AdaptivePruning(prev_val_acc=current_accuracy, initial_pruning_rate=initial_pruning_rate, delta=0.01, performance_threshold=performance_threshold)
#     current_pruning_rate = adaptive_pruning.adjust_pruning_rate(current_accuracy)
    
#     for name, module in model.named_modules():
        
#         if hasattr(module, 'pruning_mask'):
#             module.pruning_mask.data = binarize_mask(module.pruning_mask.data, current_pruning_rate)
#         for attn_part in ['query', 'key', 'value']:
#             mask_name = f'{attn_part}_pruning_mask'
#             if hasattr(module, mask_name):
#                 mask = getattr(module, mask_name)
#                 setattr(module, mask_name, binarize_mask(mask, current_pruning_rate))
#     apply_pruning(model)
#     train_loss, train_acc = train_one_epoch(epoch, model, train_loader, optimizer, device)
#     return train_loss, train_acc
# # epoch_loss, epoch_acc = unified_pruning(pruned_model, current_epoch, total_epochs, train_loader, optimizer, device, initial_rate, final_rate)


# def progressive_pruning(model, epoch, max_epoch, train_loader, optimizer, device, initial_pruning_rate, final_pruning_rate):

#     for name, module in model.named_modules():
#         if isinstance(module, torch.nn.Linear) or \
#         (hasattr(module, 'query_weight') and hasattr(module, 'key_weight') and hasattr(module, 'value_weight')):
#             copy_weights_to_mask(module)
    
#     for name, module in model.named_modules():
#         if hasattr(module, 'pruning_mask'):
#             module.pruning_mask.data = apply_z_score_normalization(module.pruning_mask)
#         for attn_part in ['query', 'key', 'value']:
#             mask_name = f'{attn_part}_pruning_mask'
#             if hasattr(module, mask_name):
#                 getattr(module, mask_name).data = apply_z_score_normalization(getattr(module, mask_name))
                
#     pruning_rate = update_pruning_rate(epoch, max_epoch, initial_pruning_rate, final_pruning_rate)
   
#     for name, module in model.named_modules():
        
#         if hasattr(module, 'pruning_mask'):
#             module.pruning_mask.data = binarize_mask(module.pruning_mask.data, pruning_rate)
#         for attn_part in ['query', 'key', 'value']:
#             mask_name = f'{attn_part}_pruning_mask'
#             if hasattr(module, mask_name):
#                 mask = getattr(module, mask_name)
#                 setattr(module, mask_name, binarize_mask(mask, pruning_rate))
                
#     apply_pruning(model)
#     train_loss, train_acc = train_one_epoch(epoch, model, train_loader, optimizer, device)
    
#     return train_loss, train_acc


# # Hyperparameters and settings
# initial_pruning_rate = 0.1  # Example initial pruning rate
# final_pruning_rate = 0.5   # Example final pruning rate for unified pruning
# performance_threshold = 0.02  # Example performance threshold for progressive pruning
# lambda_sparsity = 1e-4  # Sparsity weighting factor for combined loss
# optimizer = optim.Adam(pruned_model.parameters(), lr=0.001)
# max_epoch = 16

# #Initiating Training
# for epoch in range(1, 17):  # 1 to 16 epochs
#     if epoch in [4, 8, 12]:
#         # Apply Unified Pruning
#         print(f"Epoch {epoch}: Applying Unified Pruning")
#         unified_pruning(epoch, pruned_model, train_loader, optimizer, device, initial_pruning_rate, final_pruning_rate, 16)
#     elif epoch in [15, 16]:
#         # Train the model normally
#         print(f"Epoch {epoch}: Normal Training")
#         train_loss, train_acc = train_one_epoch(epoch, pruned_model, train_loader, optimizer, device)
#         print(f"Training Loss: {train_loss}, Training Accuracy: {train_acc}")
#     else:
#         # Apply Progressive Pruning
#         print(f"Epoch {epoch}: Applying Progressive Pruning")
#         progressive_pruning(pruned_model, epoch, max_epoch, train_loader, optimizer, device, initial_pruning_rate, final_pruning_rate)

#     # Optional: Perform validation after each epoch to monitor progress
#     val_acc = validate(pruned_model, test_loader, device)
#     print(f"Validation Accuracy: {val_acc}%")

In [None]:
# Presumably prints every pruning mask registered on the pruned model.
# NOTE(review): the only definition of `check_and_view_masks` visible nearby is inside the
# fully commented-out cell above; confirm a live definition exists in an earlier cell,
# otherwise this raises NameError on a fresh kernel.
check_and_view_masks(pruned_model)

In [None]:
def check_mask_binarization(pruned_model):
    """Return True if every ``pruning_mask`` in the model contains only 0s and 1s.

    Only modules exposing a ``pruning_mask`` attribute are inspected. The check
    stops at the first offending module, printing its name and returning False.
    Otherwise it prints a confirmation and returns True.
    """
    for module_name, submodule in pruned_model.named_modules():
        if not hasattr(submodule, 'pruning_mask'):
            continue
        mask = submodule.pruning_mask
        binary_entries = torch.logical_or(mask == 0, mask == 1)
        if not binary_entries.all():
            print(f"Mask in {module_name} is not binary.")
            return False
    print("All masks are binary.")
    return True

def verify_magnitude_based_ranking(pruned_model):
    """Check that no masked-out weight is larger in magnitude than any kept weight.

    For every module that has both ``pruning_mask`` and ``weight``, mask entries
    equal to 0 are treated as pruned and all others as kept. A magnitude-based
    mask is consistent when ``max(|pruned|) <= min(|kept|)``.

    The previous version compared against the minimum over *all* weights,
    pruned ones included. That flagged almost every layer.

    NOTE(review): if pruned weights were already zeroed in ``weight.data``,
    their magnitudes are 0 and the check trivially passes for that layer.

    Returns:
        False (after printing the offending module) on the first violation,
        True otherwise.
    """
    for name, module in pruned_model.named_modules():
        if not (hasattr(module, 'pruning_mask') and hasattr(module, 'weight')):
            continue
        magnitudes = module.weight.data.abs()
        # expand_as keeps the original's broadcasting tolerance for mask shapes.
        pruned = (module.pruning_mask == 0).expand_as(magnitudes)
        kept = ~pruned
        # The ranking can only be violated when both groups are non-empty.
        if not pruned.any() or not kept.any():
            continue
        if magnitudes[pruned].max() > magnitudes[kept].min():
            print(f"Ranking issue in {name}: pruned max > unpruned min.")
            return False
    print("Magnitude-based ranking verified.")
    return True

def assess_model_pruning(pruned_model, expected_pruning_rate, tolerance=0.01):
    """Compare the zero fraction of the model's trainable parameters with a target.

    Both the total and the non-zero counts are taken over the same set: every
    parameter with ``requires_grad=True``. Previously the total included frozen
    parameters while the non-zero count did not, which inflated the rate.
    Buffers (such as ``pruning_mask``) are not parameters and are not counted.

    NOTE(review): every trainable parameter tensor is counted, not only the
    masked weight matrices, so the rate understates the sparsity of pruned layers.

    Args:
        pruned_model: torch.nn.Module to assess.
        expected_pruning_rate: Target fraction of zero-valued parameters.
        tolerance: Absolute deviation accepted as "within bounds" (default 0.01).

    Returns:
        True when ``|actual - expected| < tolerance``.
    """
    trainable = [p for p in pruned_model.parameters() if p.requires_grad]
    total_weights = sum(p.numel() for p in trainable)
    total_non_zero = sum(p.count_nonzero().item() for p in trainable)
    # Guard against a model with no trainable parameters (avoids ZeroDivisionError).
    actual_pruning_rate = (total_weights - total_non_zero) / total_weights if total_weights else 0.0
    print(f"Actual pruning rate: {actual_pruning_rate:.2f}, Expected: {expected_pruning_rate}")
    print(f"Total weights: {total_weights}, Total non-zero (active) weights: {total_non_zero}")
    return abs(actual_pruning_rate - expected_pruning_rate) < tolerance

# Target used by assess_model_pruning below. Tie it to the configured schedule target
# (`final_pruning_rate`, also used by check_pruning_effectiveness at the end of the
# notebook) instead of the stale 0.5 from the older, commented-out configuration.
expected_pruning_rate = final_pruning_rate

import torch
import numpy as np

def check_z_score_normalization(model):
    """Report modules whose ``pruning_mask`` does not look z-score normalised.

    A mask passes when its mean is within 0.1 of 0 and its (unbiased) standard
    deviation is within 0.1 of 1, as judged by ``np.isclose``. The names of the
    failing modules are printed as a list; otherwise a success message is printed.
    Returns None.
    """
    def _looks_standardized(mask):
        mean = torch.mean(mask).item()
        std = torch.std(mask).item()
        return np.isclose(mean, 0, atol=0.1) and np.isclose(std, 1, atol=0.1)

    z_score_issues = [
        module_name
        for module_name, submodule in model.named_modules()
        if hasattr(submodule, 'pruning_mask') and not _looks_standardized(submodule.pruning_mask)
    ]
    if not z_score_issues:
        print("All masks correctly normalized with Z-score.")
    else:
        print(f"Z-score normalization issue in: {z_score_issues}")

def verify_pruning_operations(model):
    """Print the epoch-by-epoch pruning schedule (placeholder; `model` is not inspected).

    NOTE(review): the schedule printed here (unified pruning at epochs 4/8/12,
    progressive pruning elsewhere, normal training at 15/16) matches the
    commented-out training loop, not the active one, which never applies unified
    pruning. Nothing is actually verified.
    """
    unified_epochs = (4, 8, 12)
    finetune_epochs = (15, 16)
    for ep in unified_epochs:
        print(f"Verifying unified pruning for epoch {ep}...")
    progressive_epochs = [ep for ep in range(1, 17) if ep not in unified_epochs + finetune_epochs]
    for ep in progressive_epochs:
        print(f"Verifying progressive pruning for epoch {ep}...")
    print("Verifying normal training for epochs 15 and 16...")

def check_pruning_rate_adjustments(model, initial_pruning_rate, final_pruning_rate):
    """Print the configured pruning-rate range (conceptual placeholder).

    `model` is accepted for interface symmetry but is not inspected; implement
    real checks based on how pruning rates are adjusted during training.
    """
    messages = (
        f"Initial pruning rate set to: {initial_pruning_rate}, aiming for final rate: {final_pruning_rate}",
        "Assuming dynamic adjustments are made based on performance thresholds.",
    )
    for line in messages:
        print(line)

def integrated_debugging_checks():
    """Run the notebook-level debugging checks against the global `pruned_model`.

    Reads the globals `pruned_model`, `initial_pruning_rate` and
    `final_pruning_rate`; intended to be called around the training loop.
    """
    model = pruned_model
    check_z_score_normalization(model)
    verify_pruning_operations(model)
    check_pruning_rate_adjustments(model, initial_pruning_rate, final_pruning_rate)
    

# Combined debugging checks; can be called at the beginning or end of training.
integrated_debugging_checks()

# Mask-level and model-level pruning checks.
check_mask_binarization(pruned_model)
verify_magnitude_based_ranking(pruned_model)
pruning_within_bounds = assess_model_pruning(pruned_model, expected_pruning_rate)
print("Model pruning is within expected bounds." if pruning_within_bounds
      else "Model pruning is not within expected bounds.")


In [None]:
def display_model_parameters(model):
    """Print shape, non-zero count and sparsity for each trainable parameter.

    Parameters with ``requires_grad=False`` are skipped. Sparsity is the
    fraction of zero entries, printed with two decimals. Returns None.
    """
    trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
    for param_name, tensor in trainable:
        numel = tensor.numel()
        nonzero = tensor.count_nonzero().item()
        frac_zero = 1 - nonzero / numel
        print(f"{param_name}: shape = {tensor.shape}, non-zero elements = {nonzero}/{numel}, sparsity = {frac_zero:.2f}")


# Per-parameter sparsity report for every trainable tensor of the pruned model.
display_model_parameters(pruned_model)


In [None]:
# Compare each registered mask's sparsity with the final target pruning rate.
check_pruning_effectiveness(pruned_model, final_pruning_rate)