# In [None]:
from flask import Flask, request, jsonify, render_template
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from transformers import OPTForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import load_dataset
from sklearn.cluster import KMeans
from kneed import KneeLocator
import time
# Module-level compute device: CUDA when torch reports it as available,
# otherwise the CPU. Functions below read this global.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_opt_classifier(model_name):
    """Build a two-label OPT sequence classifier together with its tokenizer.

    NOTE: a later definition with the same name further down this file
    rebinds ``load_opt_classifier`` once that definition has executed.

    Args:
        model_name: Checkpoint identifier forwarded to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been moved to the
        module-level ``device``.
    """
    classifier = OPTForSequenceClassification.from_pretrained(model_name, num_labels=2)
    classifier = classifier.to(device)
    return classifier, AutoTokenizer.from_pretrained(model_name)

def get_model_size(model, path="temp_model.pth"):
    """Return the on-disk size of ``model``'s state dict in MiB.

    The state dict is written to ``path`` with ``torch.save``, the file is
    measured, and the file is removed again.

    Args:
        model: Object exposing ``state_dict()``.
        path: Scratch file used for the measurement. Any existing file at
            this location is overwritten and then deleted.

    Returns:
        The file size in bytes divided by ``1024 * 1024``.
    """
    # Resolve the state dict first so a failure here never touches `path`.
    state = model.state_dict()
    try:
        torch.save(state, path)
        return os.path.getsize(path) / (1024 * 1024)
    finally:
        # Remove the scratch file even when saving or measuring fails, so a
        # partially written checkpoint is never left behind.
        if os.path.exists(path):
            os.remove(path)
def enforce_head_constraint(num_heads, embed_dim):
    """Return the largest head count <= ``num_heads`` that divides ``embed_dim``.

    Args:
        num_heads: Requested number of attention heads.
        embed_dim: Embedding width that the head count must divide evenly.

    Returns:
        ``num_heads`` lowered as far as needed so that
        ``embed_dim % result == 0``. Requests below one are clamped to a
        single head, which divides every width.
    """
    # A request of zero would otherwise hit a modulo-by-zero, and a negative
    # request would walk away from any meaningful answer.
    num_heads = max(1, num_heads)
    while embed_dim % num_heads != 0:
        num_heads -= 1
    return num_heads

def load_opt_classifier(model_name="facebook/opt-350m"):
    """Load a two-label OPT sequence classifier and its tokenizer.

    This definition rebinds the earlier ``load_opt_classifier(model_name)``
    in this file, so it accepts the checkpoint name as well; otherwise calls
    that pass a name would fail with ``TypeError``.

    Args:
        model_name: Checkpoint identifier forwarded to ``from_pretrained``.
            Defaults to ``"facebook/opt-350m"``, the checkpoint this
            function previously hard-coded, so zero-argument calls keep
            loading the same model.

    Returns:
        A ``(model, tokenizer)`` tuple with the model moved to the
        module-level ``device``.
    """
    model = OPTForSequenceClassification.from_pretrained(model_name, num_labels=2)
    # Inputs elsewhere in this file are moved to `device` before the forward
    # pass, so the weights have to live there too.
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
def get_attention_scores(model, input_ids):
    """Collect averaged attention values for every layer of ``model``.

    The model is called once with ``output_attentions=True`` under
    ``torch.no_grad()``. Each returned attention tensor is converted to a
    float32 NumPy array and reduced: 4-D tensors are averaged over axes
    ``(0, 2, 3)``, anything else over axis 0.

    Args:
        model: Callable module accepting ``(input_ids, output_attentions=True)``
            and returning an object with an ``attentions`` attribute.
        input_ids: Tensor of token ids; moved to the model's device.

    Returns:
        Dict mapping layer index to the reduced NumPy array. Empty when the
        model returns no attention tensors.
    """
    attention_scores = {}

    # Follow the model's own placement; fall back to the module-level device
    # only for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    input_ids = input_ids.to(target_device)

    with torch.no_grad():
        outputs = model(input_ids, output_attentions=True)

    # `attentions` may be None (or contain None entries) when the attention
    # implementation does not expose weights; report that as "no scores"
    # instead of failing on enumerate(None).
    for layer_idx, attn in enumerate(outputs.attentions or ()):
        if attn is None:
            continue
        # Cast to float32 first: NumPy cannot represent bfloat16 tensors.
        attn = attn.float().cpu().numpy()
        if attn.ndim == 4:
            attn = np.mean(attn, axis=(0, 2, 3))
        else:
            attn = np.mean(attn, axis=0)
        attention_scores[layer_idx] = attn
    return attention_scores




def cluster_heads(attention_scores, num_clusters):
    """Pick representative head indices by clustering per-head scores.

    Inputs with ten or fewer entries are returned unchanged as
    ``[0, 1, ..., n - 1]``. Larger inputs are clustered with KMeans on a
    single-column view of the scores; from each non-empty cluster the first
    half of its members (at least one) is kept.

    Args:
        attention_scores: Sequence of per-head scores.
        num_clusters: Number of KMeans clusters to fit.

    Returns:
        Sorted list of the kept indices.
    """
    head_count = len(attention_scores)

    # Small inputs are not clustered at all.
    if head_count <= 10:
        return list(range(head_count))

    samples = np.array(attention_scores).reshape(-1, 1)
    clusterer = KMeans(n_clusters=num_clusters, random_state=42, n_init="auto")
    clusterer.fit(samples)
    assignments = clusterer.labels_

    kept = []
    for cluster_id in range(num_clusters):
        members = np.where(assignments == cluster_id)[0]
        if len(members) == 0:
            continue
        # Keep half of the cluster's members, but never fewer than one.
        quota = max(1, len(members) * 5 // 10)
        kept.extend(members[:quota])

    return sorted(kept)

def _slice_linear(layer, index, dim):
    """Return a new nn.Linear holding the part of ``layer`` picked by ``index`` along weight axis ``dim``."""
    index = index.to(layer.weight.device)
    weight = layer.weight.detach().index_select(dim, index)
    keep_bias = layer.bias is not None
    # nn.Linear stores its weight as (out_features, in_features).
    sliced = nn.Linear(weight.shape[1], weight.shape[0], bias=keep_bias)
    sliced = sliced.to(device=weight.device, dtype=weight.dtype)
    with torch.no_grad():
        sliced.weight.copy_(weight)
        if keep_bias:
            bias = layer.bias.detach()
            # Only an output-side slice (dim 0) removes bias entries.
            sliced.bias.copy_(bias.index_select(0, index) if dim == 0 else bias)
    return sliced


def prune_attention_heads(model, clustered_heads):
    """Shrink each decoder layer's self-attention to the selected heads.

    For every layer the number of kept heads is passed through
    ``enforce_head_constraint``; the q/k/v projections are reduced to the
    output rows belonging to the kept heads and the output projection to the
    matching input columns. Existing weights and biases are carried over and
    the new layers stay on the original device and dtype, rather than being
    re-created with fresh random weights.

    Args:
        model: Model exposing ``model.decoder.layers[i].self_attn`` with
            ``num_heads``, ``embed_dim`` and ``q_proj``/``k_proj``/``v_proj``/
            ``out_proj`` linear layers.
        clustered_heads: One sequence of head indices to keep per layer, in
            layer order. A layer with an empty sequence is left untouched.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        IndexError: If a head index lies outside the layer's head range.
    """
    for layer_idx, heads_to_keep in enumerate(clustered_heads):
        # Nothing selected for this layer: leave it as it is.
        if len(heads_to_keep) == 0:
            continue

        attn_layer = model.model.decoder.layers[layer_idx].self_attn
        original_num_heads = attn_layer.num_heads
        embed_dim = attn_layer.embed_dim
        head_dim = embed_dim // original_num_heads

        new_num_heads = enforce_head_constraint(len(heads_to_keep), embed_dim)
        # The constraint may lower the count; keep the lowest-numbered heads.
        kept_heads = sorted(int(h) for h in heads_to_keep)[:new_num_heads]

        # Flat feature indices covered by the kept heads.
        index = torch.cat(
            [torch.arange(h * head_dim, (h + 1) * head_dim) for h in kept_heads]
        )

        attn_layer.q_proj = _slice_linear(attn_layer.q_proj, index, 0)
        attn_layer.k_proj = _slice_linear(attn_layer.k_proj, index, 0)
        attn_layer.v_proj = _slice_linear(attn_layer.v_proj, index, 0)
        attn_layer.out_proj = _slice_linear(attn_layer.out_proj, index, 1)
        attn_layer.num_heads = new_num_heads

        # NOTE(review): `embed_dim` on the attention module is left unchanged;
        # if the module's forward pass reshapes with it, the narrower
        # projections may not be accepted - confirm against the installed
        # transformers version.

    return model

def divide_layers_by_sensitivity(sensitivities):
    """Rank layers by sensitivity and cut the ranking into three groups.

    Layers are ordered from highest to lowest score. The first 20% form the
    high group, the span up to the 70% mark the medium group, and the rest
    the low group; both cut points are truncated to integers.

    Args:
        sensitivities: Mapping from layer key to sensitivity score.

    Returns:
        Tuple ``(high, medium, low)`` of lists of layer keys.
    """
    ranked = sorted(sensitivities, key=sensitivities.get, reverse=True)
    total = len(ranked)
    high_end = int(total * 0.2)
    medium_end = int(total * 0.7)
    return ranked[:high_end], ranked[high_end:medium_end], ranked[medium_end:]

def apply_mixed_precision(model, medium, low):
    """Cast the model's parameters to float16.

    The decoder layers listed in ``medium`` and ``low`` are converted first,
    parameter by parameter; afterwards ``model.half()`` is called on the
    whole model.

    Args:
        model: Model exposing ``model.decoder.layers`` and ``half()``.
        medium: List of decoder layer indices in the medium group.
        low: List of decoder layer indices in the low group.

    Returns:
        The same ``model`` object.
    """
    targets = medium + low
    for idx in targets:
        layer = model.model.decoder.layers[idx]
        for weight in layer.parameters():
            weight.data = weight.data.half()
    # NOTE(review): this final call converts every parameter, including
    # layers outside `medium`/`low`, so the result is uniformly float16 -
    # confirm whether that is intended.
    model.half()
    return model

def compute_sensitivity(attention_scores):
    """Reduce each layer's scores to one mean-absolute-value number.

    Entries whose value is not a list or NumPy array, or is empty, are
    skipped. Both the raw input and the result are printed for debugging.

    Args:
        attention_scores: Dict mapping layer index to its scores.

    Returns:
        Dict mapping layer index to the float32 mean of absolute scores.
    """
    print("Raw Attention Scores:", attention_scores)

    sensitivities = {}
    for layer_idx, scores in attention_scores.items():
        # Ignore anything that is not a non-empty list/array.
        if not isinstance(scores, (list, np.ndarray)) or len(scores) == 0:
            continue
        magnitudes = np.abs(np.array(scores, dtype=np.float32))
        sensitivities[layer_idx] = np.mean(magnitudes)

    print("Computed Sensitivities:", sensitivities)

    return sensitivities
# -------------------- Evaluation -------------------- #
# def evaluate_model(model, tokenizer):
#     """ Evaluates model accuracy on PIQA dataset (multiple-choice task). """
#     dataset = load_dataset("piqa", split="validation[:100]")
#     model.eval()
#     correct, total = 0, 0

#     with torch.no_grad():
#         for example in dataset:
#             prompt = example["goal"]
#             choices = [example["sol1"], example["sol2"]]
#             inputs = tokenizer([prompt + " " + choice for choice in choices], return_tensors="pt", padding=True, truncation=True)
#             outputs = model(**inputs)
#             logits = outputs.logits.squeeze()
#             predicted_choice = torch.argmax(logits).item()
#             correct += (predicted_choice == example["label"])
#             total += 1

#     return (correct / total) * 100

# -------------------- Main Execution -------------------- #
def get_optimal_clusters(attention_scores):
    """Choose a cluster count for a layer's heads via the elbow method.

    Inputs with ten or fewer entries return their own length. Otherwise
    KMeans is fitted for every cluster count from 1 up to an upper bound,
    the inertia curve is handed to ``KneeLocator``, and the result is never
    allowed to fall below ``n - int(0.2 * n)``.

    Args:
        attention_scores: Per-head scores; must support ``reshape`` when it
            has more than ten entries.

    Returns:
        The selected number of clusters.
    """
    head_count = len(attention_scores)
    if head_count <= 10:
        return head_count

    floor_clusters = head_count - int(0.2 * head_count)
    upper = max(floor_clusters, head_count * 4 // 5)
    candidates = range(1, upper + 1)

    samples = attention_scores.reshape(-1, 1)
    inertias = []
    for k in candidates:
        fitted = KMeans(n_clusters=k, random_state=42, n_init="auto")
        fitted.fit(samples)
        inertias.append(fitted.inertia_)

    knee = KneeLocator(candidates, inertias, curve="convex", direction="decreasing")
    # Fall back to the upper bound when no elbow was found.
    chosen = knee.elbow if knee.elbow else upper
    return max(floor_clusters, chosen)

def get_model_size(model, path="temp_model.pth"):
    """Save the model's state dict temporarily and return its size in MiB.

    Args:
        model: Object exposing ``state_dict()``.
        path: Scratch file used for the measurement. Any existing file at
            this location is overwritten and then deleted.

    Returns:
        The file size in bytes divided by ``1024 * 1024``.
    """
    # Resolve the state dict first so a failure here never touches `path`.
    state = model.state_dict()
    try:
        torch.save(state, path)
        return os.path.getsize(path) / (1024 * 1024)  # Convert bytes to MiB
    finally:
        # Clean up after the measurement, including when saving or measuring
        # raised, so no partial checkpoint stays on disk.
        if os.path.exists(path):
            os.remove(path)

def evaluate_model(model, tokenizer, dataset_name):
    """Evaluate model accuracy on the given dataset, one example at a time.

    Args:
        model: Sequence classifier whose output exposes ``logits``.
        tokenizer: Tokenizer used to encode each example's text column.
        dataset_name: One of ``"sst2"``, ``"piqa"`` or ``"rte"``.

    Returns:
        ``(accuracy_percent, elapsed_seconds)`` on success, or a dict
        ``{"error": message}`` when the name is unsupported or the dataset
        cannot be loaded.
    """
    # name -> (dataset source, config name or None, text column)
    dataset_mapping = {
        "sst2": ("glue", "sst2", "sentence"),
        # "train" is a split, not a config name, so no config is passed here.
        "piqa": ("piqa", None, "goal"),
        "rte": ("glue", "rte", "sentence1"),
    }

    if dataset_name not in dataset_mapping:
        return {"error": f"Unsupported dataset: {dataset_name}"}

    dataset_source, dataset_subset, text_key = dataset_mapping[dataset_name]

    try:
        dataset = load_dataset(dataset_source, dataset_subset, split="train", trust_remote_code=True)
    except Exception as e:
        return {"error": f"Error loading dataset: {str(e)}"}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The inputs below are moved to `device`, so the weights must be there too.
    model.to(device)
    model.eval()

    correct, total = 0, 0
    start_time = time.time()
    with torch.no_grad():
        for example in dataset:
            inputs = tokenizer(example[text_key], return_tensors="pt", padding=True, truncation=True).to(device)
            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=-1)
            correct += (predictions.item() == example["label"])
            total += 1

    end_time = time.time()
    accuracy = (correct / total) * 100 if total > 0 else 0.0
    latency = end_time - start_time
    return accuracy, latency

def apply_pruning(model, tokenizer, dataset_name):
    """Cluster attention heads per layer and prune the model (CHAI-Base).

    Attention scores are extracted from one forward pass over random token
    ids, a cluster count is chosen for each layer, representative heads are
    selected, and the model's attention layers are pruned accordingly.

    Args:
        model: Model to prune; modified and returned.
        tokenizer: Not used by this function; kept for call compatibility.
        dataset_name: Not used by this function; kept for call compatibility.

    Returns:
        The pruned model.

    Raises:
        ValueError: If no attention scores could be extracted.
    """
    # The baseline accuracy evaluation is currently disabled; only the
    # progress message remains.
    print("\n Evaluating Accuracy Before Any Modification...")

    print("\n Computing Attention Scores...")
    # Random token ids are enough to obtain attention tensors of the right
    # shape for head extraction.
    input_ids = torch.randint(0, 50256, (1, 32))
    attention_scores = get_attention_scores(model, input_ids)

    print("\n Clustering Attention Heads...")
    if not attention_scores:
        raise ValueError(" No attention scores extracted! Check if model supports output_attentions.")

    available_layers = list(attention_scores.keys())
    print(f" Available Layers for Clustering: {available_layers}")

    clustered_heads = []
    for layer in available_layers:
        layer_scores = attention_scores[layer]
        clustered_heads.append(
            cluster_heads(layer_scores, get_optimal_clusters(layer_scores))
        )

    print("\n Applying Clustering and Pruning (CHAI-Base)...")
    chai_base_model = prune_attention_heads(model, clustered_heads)
    print("got heads")
    return chai_base_model

# In [None]:
from flask import Flask, request, jsonify, render_template
import torch
import torch.nn.functional as F
import numpy as np
import os
from transformers import OPTForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import load_dataset
from sklearn.cluster import KMeans
from kneed import KneeLocator
import time

# Pick the compute device in order of preference: CUDA, then Apple MPS,
# otherwise the CPU. This rebinds the module-level `device` that was set
# earlier in the file, so functions reading the global see this value from
# here on.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f" Using device: {device}")

def load_teacher_model(model_name, model_path):
    """Create the teacher classifier and try to restore its saved weights.

    Args:
        model_name: Checkpoint identifier forwarded to ``from_pretrained``.
        model_path: Path of a state-dict file to load into the model.

    Returns:
        The teacher model on the module-level ``device`` in eval mode. When
        restoring the weights fails, the error is printed and the model is
        returned with whatever weights ``from_pretrained`` produced.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    print(f" Loading teacher model {model_name} from {model_path}")

    checkpoint_missing = not os.path.exists(model_path)
    if checkpoint_missing:
        raise FileNotFoundError(f" Model file not found at {model_path}")

    teacher = OPTForSequenceClassification.from_pretrained(model_name)

    try:
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        saved_state = torch.load(model_path, map_location=device)
        teacher.load_state_dict(saved_state)
        print(" Successfully loaded teacher model weights.")
    except Exception as err:
        # Best effort: report the problem and continue without the weights.
        print(f" Error loading teacher model weights: {err}")

    teacher.to(device)
    teacher.eval()
    return teacher

def chai_knowledge_distillation_enhancement(student_model, teacher_model, tokenizer, dataset_name):
    """Fine-tune the student against a teacher with knowledge distillation (CHAI-KD).

    The student is trained for one epoch on the first 5000 SST-2 training
    examples using a ``Trainer`` subclass whose loss mixes a
    temperature-scaled KL term against the teacher's logits with ordinary
    cross-entropy against the labels. Parameter names and shapes are printed
    before and after training.

    Args:
        student_model: Model to train; moved to the module-level ``device``.
        teacher_model: Model providing target logits; called under
            ``torch.no_grad()``.
        tokenizer: Tokenizer used to encode the ``sentence`` column and to
            build the padding collator.
        dataset_name: NOTE(review): not used in this function - the training
            data is always GLUE SST-2; confirm whether that is intended.

    Returns:
        The same ``student_model`` object after training.
    """
    
    print("\n Applying Knowledge Distillation (CHAI-KD)...")
    print(" [chai-kd] Before modification: model parameters")
    for name, param in student_model.named_parameters():
        print(f"  {name}: {param.shape}")

    #  Move student model to the module-level device
    student_model.to(device)

    #  Training hyperparameters
    epochs = 1
    batch_size = 16
    temperature = 2.0  # divisor applied to both logit sets before softmax
    alpha = 0.5  # weight of the distillation term; (1 - alpha) goes to cross-entropy

    #  Load the first 5000 SST-2 training rows and tokenize the `sentence`
    #  column, padded/truncated to 128 tokens
    dataset = load_dataset("glue", "sst2", split="train[:5000]")
    tokenized_dataset = dataset.map(
        lambda e: tokenizer(e['sentence'], truncation=True, padding='max_length', max_length=128),
        batched=True
    )

    training_args = TrainingArguments(
        output_dir="./chai_kd_model",
        learning_rate=5e-5,
        per_device_train_batch_size=batch_size,
        num_train_epochs=epochs,
        save_strategy="epoch",
        report_to="none"
    )

    class KDTrainer(Trainer):
        # Trainer subclass that swaps in the distillation loss. It closes
        # over `teacher_model`, `temperature` and `alpha` from the enclosing
        # function.
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            """Return ``alpha * KD + (1 - alpha) * CE`` for one batch.

            When ``return_outputs`` is true the student's outputs are
            returned alongside the loss as ``(loss, student_outputs)``.
            """

            # Assumes the batch carries a "labels" entry - TODO confirm
            # against the collator/Trainer version in use.
            labels = inputs.pop("labels")  # Extract labels so they are not passed to the models
            inputs = {k: v.to(device) for k, v in inputs.items()}  #  Move inputs to the module-level device
            # NOTE(review): `labels` is not moved here; presumably it is
            # already on the same device as the logits - verify.
            model.to(device)

            student_outputs = model(**inputs)
            student_logits = student_outputs.logits

            with torch.no_grad():
                # Teacher forward pass without gradients. The teacher is not
                # moved here; assumes it already lives on `device`.
                teacher_outputs = teacher_model(**inputs)
                teacher_logits = teacher_outputs.logits

            # KL divergence between temperature-softened distributions,
            # multiplied by temperature squared.
            kd_loss = F.kl_div(
                F.log_softmax(student_logits / temperature, dim=-1),
                F.softmax(teacher_logits / temperature, dim=-1),
                reduction="batchmean"
            ) * (temperature ** 2)

            ce_loss = F.cross_entropy(student_logits, labels)
            loss = alpha * kd_loss + (1 - alpha) * ce_loss  # Combined loss

            return (loss, student_outputs) if return_outputs else loss

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="longest")

    # The same tokenized subset is passed as both training and evaluation data.
    trainer = KDTrainer(
        model=student_model,
        args=training_args,
        train_dataset=tokenized_dataset,
        eval_dataset=tokenized_dataset,
        data_collator=data_collator  #  Pads each batch to its longest sequence
    )

    trainer.train()

    print(" [chai-kd] After modification: model parameters")
    for name, param in student_model.named_parameters():
        print(f"  {name}: {param.shape}")

    return student_model


# In [None]:
def evaluate_model(model, tokenizer, dataset_name):
    """Evaluate model accuracy on the validation split of the given dataset.

    Examples are processed in batches of 16. The model is moved to CUDA when
    available (otherwise CPU, where it is also cast to float32) and put in
    eval mode.

    Args:
        model: Sequence classifier whose output exposes ``logits``.
        tokenizer: Tokenizer used to encode the text column(s).
        dataset_name: One of ``"sst2"``, ``"rte"`` or ``"piqa"``.

    Returns:
        ``[accuracy_percent, elapsed_seconds]`` on success, or a dict
        ``{"error": message}`` when the name is unsupported or the dataset
        cannot be loaded.
    """
    # name -> (dataset source, config name or None, split, text column(s))
    dataset_mapping = {
        "sst2": ("glue", "sst2", "validation", "sentence"),
        "rte": ("glue", "rte", "validation", ("sentence1", "sentence2")),  # RTE has two sentences
        "piqa": ("piqa", None, "validation", ("goal", "sol1", "sol2")),  # PIQA has a goal and two solutions
    }

    if dataset_name not in dataset_mapping:
        return {"error": f"Unsupported dataset: {dataset_name}"}

    dataset_source, dataset_subset, split_name, text_key = dataset_mapping[dataset_name]

    # Load the dataset, passing a config name only when one is given.
    try:
        if dataset_subset:
            dataset = load_dataset(dataset_source, dataset_subset, split=split_name, cache_dir="./cache")
        else:
            dataset = load_dataset(dataset_source, split=split_name, cache_dir="./cache")
    except Exception as e:
        return {"error": f"Error loading dataset: {str(e)}"}

    # Imported here because the file's import block does not bring
    # DataLoader into scope.
    from torch.utils.data import DataLoader

    dataloader = DataLoader(dataset, batch_size=16)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Avoid float16 on CPU.
    if device.type == "cpu":
        model.to(torch.float32)

    start_time = time.time()
    correct, total = 0, 0

    with torch.no_grad():
        for batch in dataloader:
            if isinstance(text_key, tuple):
                # The tokenizer takes a text and an optional text pair. Any
                # further fields are folded into the pair so they are not
                # passed as a third positional argument.
                # NOTE(review): for PIQA this joins both solutions into one
                # pair string - confirm this is the intended encoding.
                first = list(batch[text_key[0]])
                second = [" ".join(parts) for parts in zip(*(batch[key] for key in text_key[1:]))]
                inputs = tokenizer(first, second, return_tensors="pt", padding=True, truncation=True).to(device)
            else:
                inputs = tokenizer(batch[text_key], return_tensors="pt", padding=True, truncation=True).to(device)

            # as_tensor accepts both lists and already-collated tensors.
            labels = torch.as_tensor(batch["label"], dtype=torch.long).to(device, non_blocking=True)

            outputs = model(**inputs)
            # argmax over raw logits; softmax would not change the winner.
            predictions = torch.argmax(outputs.logits, dim=-1)

            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    end_time = time.time()
    accuracy = (correct / total) * 100 if total > 0 else 0.0
    latency = end_time - start_time

    return [accuracy, latency]

# Run configuration: dataset to evaluate on and the checkpoint to start from.
dataset_name = "sst2"
model_name = "facebook/opt-350m"
# model = OPTForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)
# This tokenizer is rebound two lines below by load_opt_classifier's result.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE(review): verify that the load_opt_classifier definition in effect at
# this point accepts a checkpoint-name argument.
model, tokenizer = load_opt_classifier(model_name)
# Prune the freshly loaded model; `model` now refers to the pruned model.
model = apply_pruning(model,tokenizer,dataset_name)
# NOTE(review): kd_load_model_for_kd is not defined in the visible part of
# this file - presumably provided by another cell; TODO confirm.
classifier_model = kd_load_model_for_kd(model_name)


#  Track applied methods
def evaluate_model(model, tokenizer, dataset_name):
    """Evaluate model accuracy on the validation split of the given dataset.

    Examples are processed in batches of 16. The model is moved to CUDA when
    available (otherwise CPU, where it is also cast to float32) and put in
    eval mode.

    Args:
        model: Sequence classifier whose output exposes ``logits``.
        tokenizer: Tokenizer used to encode the text column(s).
        dataset_name: One of ``"sst2"``, ``"rte"`` or ``"piqa"``.

    Returns:
        ``[accuracy_percent, elapsed_seconds]`` on success, or a dict
        ``{"error": message}`` when the name is unsupported or the dataset
        cannot be loaded.
    """
    # name -> (dataset source, config name or None, split, text column(s))
    dataset_mapping = {
        "sst2": ("glue", "sst2", "validation", "sentence"),
        "rte": ("glue", "rte", "validation", ("sentence1", "sentence2")),  # RTE has two sentences
        "piqa": ("piqa", None, "validation", ("goal", "sol1", "sol2")),  # PIQA has a goal and two solutions
    }

    if dataset_name not in dataset_mapping:
        return {"error": f"Unsupported dataset: {dataset_name}"}

    dataset_source, dataset_subset, split_name, text_key = dataset_mapping[dataset_name]

    # Load the dataset, passing a config name only when one is given.
    try:
        if dataset_subset:
            dataset = load_dataset(dataset_source, dataset_subset, split=split_name, cache_dir="./cache")
        else:
            dataset = load_dataset(dataset_source, split=split_name, cache_dir="./cache")
    except Exception as e:
        return {"error": f"Error loading dataset: {str(e)}"}

    # Imported here because the file's import block does not bring
    # DataLoader into scope.
    from torch.utils.data import DataLoader

    dataloader = DataLoader(dataset, batch_size=16)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Avoid float16 on CPU.
    if device.type == "cpu":
        model.to(torch.float32)

    start_time = time.time()
    correct, total = 0, 0

    with torch.no_grad():
        for batch in dataloader:
            if isinstance(text_key, tuple):
                # The tokenizer takes a text and an optional text pair. Any
                # further fields are folded into the pair so they are not
                # passed as a third positional argument.
                # NOTE(review): for PIQA this joins both solutions into one
                # pair string - confirm this is the intended encoding.
                first = list(batch[text_key[0]])
                second = [" ".join(parts) for parts in zip(*(batch[key] for key in text_key[1:]))]
                inputs = tokenizer(first, second, return_tensors="pt", padding=True, truncation=True).to(device)
            else:
                inputs = tokenizer(batch[text_key], return_tensors="pt", padding=True, truncation=True).to(device)

            # as_tensor accepts both lists and already-collated tensors.
            labels = torch.as_tensor(batch["label"], dtype=torch.long).to(device, non_blocking=True)

            outputs = model(**inputs)
            # argmax over raw logits; softmax would not change the winner.
            predictions = torch.argmax(outputs.logits, dim=-1)

            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    end_time = time.time()
    accuracy = (correct / total) * 100 if total > 0 else 0.0
    latency = end_time - start_time

    return [accuracy, latency]


# Pipeline driver: evaluate, prune, evaluate again, distil, evaluate a final
# time, printing accuracy, latency and size reduction after each stage.
#
# NOTE(review): `evaluate_model_piqa`, `dataset` and `size1` are not defined
# in the visible part of this file - presumably they come from another cell;
# TODO confirm before running this cell on a fresh kernel.

# Baseline accuracy/latency of the model as it stands before the pruning
# call below.
[a1, b1] = evaluate_model_piqa(model, tokenizer,dataset, num_classes=2)

# Apply Pruning
model = apply_pruning(model, tokenizer, dataset_name)

print("\n Initial Model Evaluation:")
print(f"Accuracy: {a1:.2f}%")
print(f"Latency: {b1:.4f} sec")
print("------------------------------------------------------------------")

#  Measure size after pruning and report the reduction relative to `size1`
print("\n Chai Base Model Evaluation (After Pruning)")
size2 = get_model_size(model)
print(f"Reduction Percentage: {(size1 - size2) * 100 / size1:.2f}%")

[a, b] = evaluate_model_piqa(model, tokenizer,dataset, num_classes=2)
print(f"Accuracy: {a:.2f}%")
print(f"Latency: {b:.4f} sec")
print("------------------------------------------------------------------")

#  Apply CHAI-KD (knowledge distillation) enhancement
# NOTE(review): the definition of chai_knowledge_distillation_enhancement
# visible in this file takes (student_model, teacher_model, tokenizer,
# dataset_name); this call passes only two arguments - confirm which
# definition is in effect.
model = chai_knowledge_distillation_enhancement(model, tokenizer)
size3 = get_model_size(model)
print("\n Applied Methods: KD")
[a, b] = evaluate_model_piqa(model, tokenizer,dataset, num_classes=2)

# Final report; the reduction here is relative to the post-pruning size.
print("\n Final Model Evaluation (After CHAI-KD)")
print(f"Accuracy after applying methods: {a:.2f}%")
print(f"Latency: {b:.4f} sec")
print(f"Reduction Percentage: {(size2 - size3) * 100 / size2:.2f}%")

