<a href="https://colab.research.google.com/github/OmidGhadami95/Fake_Detection_BERT_Pruning/blob/main/Pruning_Part.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from transformers import AutoModel

# Define model architecture (must match original training code)
class BERT_Arch(nn.Module):
    """Two-class classification head on top of a BERT-style encoder.

    Pipeline: encoder ``pooler_output`` -> Linear(768, 512) -> ReLU
    -> Dropout(0.1) -> Linear(512, 2) -> LogSoftmax(dim=1).

    The attribute names below (``bert``, ``fc1``, ``fc2``, ...) become the
    state_dict keys, so they must match the checkpoint being loaded; do not
    rename them or change the layer sizes.
    """

    def __init__(self, bert):
        super().__init__()
        # Wrapped encoder; called with (ids, attention_mask=...) and expected
        # to return a mapping that contains 'pooler_output'.
        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        # 768 input features: presumably the encoder's hidden size
        # (bert-base models use 768) -- must match pooler_output's last dim.
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 2)
        # Emits log-probabilities rather than probabilities.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        """Return per-class log-probabilities for a batch.

        ``sent_id`` and ``mask`` are forwarded unchanged to the encoder as the
        input ids and ``attention_mask``. Assumes ``pooler_output`` is shaped
        (batch, 768) -- TODO confirm against the encoder in use -- giving a
        (batch, 2) result.
        """
        pooled = self.bert(sent_id, attention_mask=mask)['pooler_output']
        hidden = self.dropout(self.relu(self.fc1(pooled)))
        return self.softmax(self.fc2(hidden))

# Load pre-trained model
def load_model(model_path, map_location='cpu'):
    """Build ``BERT_Arch`` around a fresh 'bert-base-uncased' encoder and load weights.

    Args:
        model_path: Path to a checkpoint holding the model's state_dict
            (e.g. written by ``torch.save(model.state_dict(), path)``).
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved from a GPU session can be loaded on a CPU-only
            machine. The freshly built model's parameters are on CPU and
            ``load_state_dict`` copies values into them, so this default does
            not change the resulting model.

    Returns:
        The ``BERT_Arch`` instance with the checkpoint's weights loaded.
    """
    bert = AutoModel.from_pretrained('bert-base-uncased')
    model = BERT_Arch(bert)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

# Smooth pruning implementation
def smooth_pruning(model, pruning_rate=0.2, pruning_steps=10):
    """Iteratively L1-prune the weights of every ``nn.Linear`` in ``model``.

    Each Linear layer (including any inside a wrapped encoder) ends up with
    exactly ``round(pruning_rate * weight.numel())`` of its weights set to
    zero. Pruning proceeds on a linear ramp over ``pruning_steps`` steps, and
    each step zeroes the smallest-magnitude weights that are still unpruned.
    The pruning is then made permanent with ``prune.remove``, so the
    state_dict contains plain ``weight`` tensors again.

    Per-step amounts are passed as absolute integer counts. PyTorch's
    iterative pruning interprets a *float* amount relative to the weights
    that remain unpruned. Re-applying ``pruning_rate / pruning_steps`` as a
    float therefore reaches only ``1 - (1 - r/s)**s`` total sparsity (about
    18.3% for the defaults), not ``r``.

    NOTE: no training happens between steps, so the final mask matches
    one-shot magnitude pruning at the same sparsity (up to ties).

    Args:
        model: Module to prune; it is modified in place.
        pruning_rate: Target fraction of each Linear layer's weights to zero,
            in [0, 1].
        pruning_steps: Number of pruning iterations (>= 1).

    Returns:
        The same ``model`` object, pruned.

    Raises:
        ValueError: If ``pruning_steps < 1`` or ``pruning_rate`` is outside [0, 1].
    """
    if pruning_steps < 1:
        raise ValueError(f"pruning_steps must be >= 1, got {pruning_steps}")
    if not 0.0 <= pruning_rate <= 1.0:
        raise ValueError(f"pruning_rate must be in [0, 1], got {pruning_rate}")

    linear_layers = [m for m in model.modules() if isinstance(m, nn.Linear)]

    for step in range(1, pruning_steps + 1):
        print(f"Pruning step {step}/{pruning_steps}")
        for layer in linear_layers:
            total = layer.weight.numel()
            # Cumulative pruned-count targets; their difference is how many of
            # the *remaining* weights to prune now (an int amount is absolute).
            target = round(pruning_rate * total * step / pruning_steps)
            already = round(pruning_rate * total * (step - 1) / pruning_steps)
            prune.l1_unstructured(layer, name='weight', amount=target - already)

    # Make pruning permanent: fold weight_orig * weight_mask back into `weight`.
    for layer in linear_layers:
        prune.remove(layer, 'weight')
    return model

if __name__ == "__main__":
    # Checkpoint to read and destination for the pruned state_dict.
    MODEL_PATH = 'c3_new_model_weights.pt'
    SAVE_PATH = 'pruned_model_weights.pt'
    # Schedule: target fraction of each Linear layer's weights to zero,
    # reached over this many pruning iterations.
    PRUNING_RATE = 0.2
    PRUNING_STEPS = 10

    # smooth_pruning mutates `model` in place and returns the same object;
    # both names are kept bound for any later notebook cells that use them.
    model = load_model(MODEL_PATH)
    pruned_model = smooth_pruning(
        model,
        pruning_rate=PRUNING_RATE,
        pruning_steps=PRUNING_STEPS,
    )

    # Persist only the weights, mirroring how the input checkpoint is loaded.
    torch.save(pruned_model.state_dict(), SAVE_PATH)
    print(f"Pruned model saved to {SAVE_PATH}")