In [82]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math
import matplotlib.pyplot as plt
import numpy as np
from collections import OrderedDict
import copy


In [None]:
# Read the name corpus: one name per line, with the line terminators removed.
with open('/home/mohammad/Safety-Driven-Self-Compressing-Neural-Networks/NLP/data/names.txt', 'r') as f:
    words = f.read().splitlines()

# Character vocabulary: every distinct character of the corpus, in sorted
# order, receives an id starting at 1. Id 0 is assigned to '.', which serves
# as the end-of-sequence token.
chars = sorted(set(''.join(words)))
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi['.'] = 0
# Inverse lookup: id -> character.
itos = dict((index, char) for char, index in stoi.items())
vocab_size = len(stoi)
print(f"Vocabulary size: {vocab_size}")


In [None]:
def build_dataset(words, block_size=8, char_to_ix=None):
    """
    Turn a list of words into (context, next-character) training pairs.

    Every word is extended with the '.' end-of-sequence token. For each
    character of the extended word one example is emitted: the ids of the
    `block_size` preceding characters (left-padded with 0 at the start of the
    word) as the input, and the id of the character itself as the target.

    Args:
        words: Iterable of strings to encode.
        block_size: Number of context positions per example.
        char_to_ix: Optional mapping from character to integer id. When
            omitted, the notebook-level `stoi` vocabulary is used.

    Returns:
        A tuple (X, Y) of torch.long tensors with shapes
        (num_examples, block_size) and (num_examples,). With no examples the
        tensors are empty but keep those shapes.

    Raises:
        KeyError: If a character (or '.') is missing from the mapping.
    """
    if char_to_ix is None:
        char_to_ix = stoi  # fall back to the global vocabulary

    X, Y = [], []
    for w in words:
        context = [0] * block_size  # start-of-word padding
        for ch in w + '.':  # '.' marks the end of the word
            ix = char_to_ix[ch]
            X.append(context)
            Y.append(ix)
            # Slicing builds a new list, so rows already stored in X are not aliased.
            context = context[1:] + [ix]

    if not X:
        # torch.tensor([]) would be 1-D; keep the 2-D layout callers expect.
        return (torch.empty((0, block_size), dtype=torch.long),
                torch.empty((0,), dtype=torch.long))
    return torch.tensor(X, dtype=torch.long), torch.tensor(Y, dtype=torch.long)

# Seeded shuffle followed by an 80% / 10% / 10% train / dev / test split,
# taken at word level so that no word contributes to more than one split.
random.seed(42)
random.shuffle(words)  # in place: `words` remains shuffled afterwards
n1, n2 = int(0.8 * len(words)), int(0.9 * len(words))

block_size = 8  # number of context characters per example

# Encode the three word slices with the same context size.
(Xtr, Ytr), (Xdev, Ydev), (Xte, Yte) = (
    build_dataset(split_words, block_size)
    for split_words in (words[:n1], words[n1:n2], words[n2:])
)

print(f"Training set size: {Xtr.shape}, {Ytr.shape}")
print(f"Validation set size: {Xdev.shape}, {Ydev.shape}")
print(f"Test set size: {Xte.shape}, {Yte.shape}")


In [None]:
# Load preservation set from a text file
def load_preservation_set(file_path):
    """
    Read the preservation words from a text file, one word per line.

    Each line is stripped of surrounding whitespace and lower-cased. Lines
    that are empty after stripping are skipped: an empty word would otherwise
    be encoded as a lone end-of-sequence example.

    Args:
        file_path: Path of the text file to read.

    Returns:
        List of non-empty, lower-cased words in file order.
    """
    with open(file_path, 'r') as f:
        # Iterate the file object directly instead of materializing readlines().
        stripped_lines = (line.strip().lower() for line in f)
        return [word for word in stripped_lines if word]

# Location of the preservation-set word list (update with your own file path).
preservation_file_path = '/home/mohammad/Safety-Driven-Self-Compressing-Neural-Networks/NLP/data/hardest_examples.txt'
preservation_words = load_preservation_set(preservation_file_path)

# Encode the preservation words with the same context size as the other splits.
# build_dataset returns an (X, Y) tuple, which is kept whole and also unpacked.
preservation_set = build_dataset(preservation_words, block_size)
Xpres, Ypres = preservation_set

print(f"Loaded Preservation Set: {preservation_words[:5]}")


In [89]:
class TransformerModel(nn.Module):
    """
    Next-token classifier over a fixed-size context window.

    Token and learned position embeddings are summed, passed through a stack
    of ``nn.TransformerEncoderLayer`` blocks (GELU activation, feed-forward
    width of four times the embedding size, batch-first layout) and a final
    LayerNorm. Only the representation at the last position is projected to
    vocabulary logits. No attention mask is passed to the encoder.

    Args:
        vocab_size: Number of distinct token ids.
        embed_size: Width of the embeddings and of the encoder.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        block_size: Maximum sequence length accepted by ``forward``.
        dropout: Dropout probability used inside the encoder layers.
    """

    def __init__(self, vocab_size, embed_size, num_heads, num_layers, block_size, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(block_size, embed_size)
        self.layers = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_size,
                nhead=num_heads,
                dim_feedforward=4 * embed_size,
                dropout=dropout,
                activation='gelu',
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.ln_f = nn.LayerNorm(embed_size)
        self.head = nn.Linear(embed_size, vocab_size)
        self.block_size = block_size
        self.embed_size = embed_size

    def forward(self, idx):
        """
        Compute next-token logits for a batch of contexts.

        Args:
            idx: Integer tensor of token ids with shape (B, T), T <= block_size.

        Returns:
            Tensor of shape (B, vocab_size) with the logits taken from the
            last position of each sequence.
        """
        _, seq_len = idx.size()
        assert seq_len <= self.block_size, "Sequence length exceeds block size."

        # (T, embed_size) position vectors broadcast over the batch dimension.
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.token_embedding(idx) + self.position_embedding(positions)

        hidden = self.ln_f(self.layers(hidden))
        return self.head(hidden[:, -1, :])


In [None]:
def calculate_model_size(model, precision_bits=32):
    """
    Estimate the storage footprint of a model's parameters.

    Args:
        model: Module whose ``parameters()`` are counted.
        precision_bits: Bits assumed per parameter value.

    Returns:
        Tuple (total_params, size_mb): the element count over all parameters
        and the corresponding size in mebibytes (bytes / 1024**2).
    """
    bytes_per_param = precision_bits / 8
    total_params = 0
    for tensor in model.parameters():
        total_params += tensor.numel()
    size_mb = total_params * bytes_per_param / (1024 ** 2)
    return total_params, size_mb

# Model hyperparameters
embed_size, num_heads, num_layers, dropout = 128, 8, 4, 0.1

# Arguments follow the constructor's positional order:
# (vocab_size, embed_size, num_heads, num_layers, block_size, dropout).
model = TransformerModel(vocab_size, embed_size, num_heads, num_layers, block_size, dropout)

# Materialized list of the model's parameter tensors.
parameters = [p for p in model.parameters()]

# Size of the freshly built model, assuming 32 bits per parameter.
original_params, original_size = calculate_model_size(model, precision_bits=32)
print(f"Original Model Size: {original_params} parameters, {original_size:.2f} MB")


In [94]:
def quantize_attention_heads(model):
    """
    Build an int8 dynamically-quantized copy of `model`.

    Only modules of type ``nn.Linear`` are targeted for quantization. The
    model passed in is left unmodified.

    Returns:
        The quantized copy of the model.
    """
    return torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},
        dtype=torch.qint8,
        inplace=False,  # quantize_dynamic works on its own deep copy
    )


In [95]:
def evaluate_preservation_set(preservation_set, model):
    """
    Compute the cross-entropy loss of `model` on an (inputs, targets) pair.

    The model is switched to eval mode (and left in it) and the forward pass
    runs without gradient tracking.

    Args:
        preservation_set: Tuple (X, Y) of model inputs and target class ids.
        model: Callable module mapping X to logits.

    Returns:
        The loss as a Python float.
    """
    model.eval()
    inputs, targets = preservation_set
    with torch.no_grad():
        return F.cross_entropy(model(inputs), targets).item()

# Maximum relative increase in preservation loss that is still accepted when
# deciding whether to keep a quantized model (0.05 corresponds to 5%).
loss_increase_threshold = 0.05


In [None]:
# Training with periodic quantization checks.
#
# The float `model` is the only network that is ever trained. Dynamic
# quantization is a post-training transformation: the quantized Linear modules
# keep their weights in packed int8 form rather than as nn.Parameters, so an
# optimizer built from the quantized model's parameters() cannot update them.
# Every 500 steps a quantized *copy* is therefore evaluated and, if it passes
# the preservation check, stored as a snapshot while training continues on the
# float model.
optimizer = torch.optim.AdamW(parameters, lr=1e-3)
lossi = []  # training loss sampled every 100 steps
stepi = []  # step index of each entry in `lossi`
max_steps = 5000
batch_size = 64

# Preservation loss of the model before any training (reported for reference).
initial_preservation_loss = evaluate_preservation_set(preservation_set, model)
print(f"Initial Preservation Loss: {initial_preservation_loss:.4f}")

best_model_state = copy.deepcopy(model.state_dict())
best_preservation_loss = initial_preservation_loss
best_quantized_model = None  # most recent quantized snapshot that passed the check

for i in range(max_steps):
    # evaluate_preservation_set leaves the model in eval mode, so training
    # mode is re-enabled at the start of every step.
    model.train()
    ix = torch.randint(0, Xtr.shape[0], (batch_size,))
    X_batch, Y_batch = Xtr[ix], Ytr[ix]

    logits = model(X_batch)
    loss = F.cross_entropy(logits, Y_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 100 == 0:
        lossi.append(loss.item())
        stepi.append(i)

    # Quantization and Evaluation on Preservation Set
    if i % 500 == 0 and i > 0:
        # Baseline: the float model at this very step. Comparing against the
        # untrained model's loss would make the check pass almost trivially.
        float_preservation_loss = evaluate_preservation_set(preservation_set, model)

        # Quantize a copy; the float model itself is not modified.
        quantized_model = quantize_attention_heads(model)
        preservation_loss_after = evaluate_preservation_set(preservation_set, quantized_model)

        # Relative degradation caused by quantization alone.
        loss_increase = (preservation_loss_after - float_preservation_loss) / float_preservation_loss
        if loss_increase > loss_increase_threshold:
            # Reject this snapshot; nothing to undo because only a copy was quantized.
            print(f"Step {i}: Restored original model due to preservation loss increase ({loss_increase * 100:.2f}%).")
        else:
            # Keep the snapshot and its weights; training continues on the float model.
            best_quantized_model = quantized_model
            best_model_state = copy.deepcopy(quantized_model.state_dict())
            best_preservation_loss = preservation_loss_after
            print(f"Step {i}: Quantization successful. Preservation loss increase: {loss_increase * 100:.2f}%")

    if i % 1000 == 0:
        print(f"Step {i}, Training Loss: {loss.item():.4f}")

# Hand the result to the following cells: `model` becomes the last accepted
# quantized snapshot (whose state dict is `best_model_state`). If no snapshot
# was accepted, `model` stays the trained float network and the recorded best
# state is refreshed so that it no longer points at the untrained weights.
if best_quantized_model is not None:
    model = best_quantized_model
else:
    best_model_state = copy.deepcopy(model.state_dict())
    best_preservation_loss = evaluate_preservation_set(preservation_set, model)


In [None]:
def calculate_quantized_model_size(model):
    """
    Estimate the storage footprint of a (possibly dynamically quantized) model.

    Regular parameters are counted per module with their actual element size.
    Quantized Linear modules do not expose their weight and bias through
    ``named_parameters()``; they provide ``weight()`` / ``bias()`` accessor
    methods instead, so those tensors are counted separately. Quantization
    metadata (scales, zero points) is not included in the estimate.

    Args:
        model: Module to measure.

    Returns:
        Tuple (total_params, size_mb): number of counted elements and their
        size in mebibytes (bytes / 1024**2).
    """
    total_params = 0
    total_size_in_bytes = 0
    counted_param_ids = set()  # a parameter shared by several modules counts once

    for module in model.modules():
        # Ordinary parameters owned directly by this module.
        for param in module.parameters(recurse=False):
            if id(param) in counted_param_ids:
                continue
            counted_param_ids.add(id(param))
            total_params += param.numel()
            total_size_in_bytes += param.numel() * param.element_size()

        # On float modules `weight` is a tensor attribute (not callable); on
        # quantized modules it is a method returning the stored tensor.
        weight_getter = getattr(module, 'weight', None)
        if callable(weight_getter):
            stored_tensors = [weight_getter()]
            bias_getter = getattr(module, 'bias', None)
            if callable(bias_getter):
                stored_tensors.append(bias_getter())
            for tensor in stored_tensors:
                if tensor is None:  # layer without bias
                    continue
                total_params += tensor.numel()
                # element_size() is 1 byte for qint8 weights, 4 for float32 biases.
                total_size_in_bytes += tensor.numel() * tensor.element_size()

    total_size_in_megabytes = total_size_in_bytes / (1024 ** 2)
    return total_params, total_size_in_megabytes

# Report the parameter count and estimated size (in MB) of whatever `model`
# refers to at this point of the notebook.
quantized_params, quantized_size = calculate_quantized_model_size(model)
print(f"Quantized Model Size: {quantized_params} parameters, {quantized_size:.2f} MB")


In [None]:
# Load the best model state
model.load_state_dict(best_model_state)
# `evaluate_model` is not defined in any visible cell of this notebook, so the
# call raised NameError. evaluate_preservation_set computes the cross-entropy
# of `model` on an (X, Y) pair in eval mode without gradients, which is what
# is needed for the held-out test split as well.
test_loss = evaluate_preservation_set((Xte, Yte), model)
print(f"Test Loss: {test_loss:.4f}")
