In [None]:
!pip install flash-attn torch accelerate transformers datasets scikit-learn

In [2]:
import json
import os
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch.nn as nn
import torch.optim as optim
import time


In [None]:
# Load model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(
    "Model_Name", 
    device_map="cuda", 
    torch_dtype="auto", 
    trust_remote_code=True, 
).to(device)

tokenizer = AutoTokenizer.from_pretrained("Model_Name")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Load dataset
dataset = load_dataset('Dataset_Name')

# Dataset features
train_data = dataset['train']
valid_data = dataset['valid']
test_data = dataset['test']

In [5]:
# Storage for layer representations and classifiers
tot_layer = len(model.model.layers)+1  # Total layers in the model
layer_representations = [[] for _ in range(tot_layer)]
classifiers = []
layer_metrics = []  # Store metrics for each layer

In [6]:
# Define a shared-head neural network classifier
class SharedNN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SharedNN, self).__init__()
        self.shared_fc1 = nn.Linear(input_size, 128)
        self.shared_relu = nn.ReLU()
        self.shared_fc2 = nn.Linear(128, 64)
        self.shared_relu2 = nn.ReLU()
        self.output_layer = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.shared_fc1(x)
        x = self.shared_relu(x)
        x = self.shared_fc2(x)
        x = self.shared_relu2(x)
        return self.output_layer(x)

In [None]:
# Preprocess data
def preprocess(example):
    query = example['query']
    input_text = f"{query}"
    label = example['gold']
    return input_text, label

In [None]:
# Collect representations for each layer
def collect_representations(data):
    for example in tqdm(data):
        input_text, label = preprocess(example)
        with torch.no_grad():
            inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=400).to(device)
            outputs = model(**inputs, output_hidden_states=True)

        hidden_states = outputs.hidden_states
        for i, hs in enumerate(hidden_states):
            hs_last = hs[0, -1, :].float().cpu().numpy()
            layer_representations[i].append((hs_last, label))

collect_representations(train_data)

In [None]:
# Train classifiers for each layer
shared_nn = SharedNN(input_size=model.config.hidden_size, num_classes=len(set(train_data['gold']))).to(device)
criterion = nn.CrossEntropyLoss()

In [None]:
# Train classifiers for each layer
trained_classifiers = {}

for i, layer_data in enumerate(layer_representations):
    X = np.array([x[0] for x in layer_data])
    y = np.array([x[1] for x in layer_data])
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Convert data to tensors
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(device)
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)

    # Initialize a new classifier for this layer
    layer_classifier = SharedNN(input_size=model.config.hidden_size, num_classes=len(set(train_data['gold']))).to(device)
    optimizer = optim.Adam(layer_classifier.parameters(), lr=2e-4)

    # Train the classifier
    layer_classifier.train()
    for epoch in range(200):
        optimizer.zero_grad()
        outputs = layer_classifier(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        loss.backward()
        optimizer.step()

    # Evaluate the classifier
    layer_classifier.eval()
    with torch.no_grad():
        y_pred = torch.argmax(layer_classifier(X_test_tensor), dim=1).cpu().numpy()

    accuracy = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average="weighted")

    print(f"Layer {i}: Accuracy={accuracy}, F1={f1}")
    layer_metrics.append((i, accuracy, f1))

    # Store the trained classifier
    trained_classifiers[i] = layer_classifier

Layer 0: Accuracy=0.5806451612903226, F1=0.42659644502962474
Layer 1: Accuracy=0.5806451612903226, F1=0.42659644502962474
Layer 2: Accuracy=0.5806451612903226, F1=0.42659644502962474
Layer 3: Accuracy=0.5806451612903226, F1=0.42659644502962474
Layer 4: Accuracy=0.6, F1=0.4733455539617122
Layer 5: Accuracy=0.6564516129032258, F1=0.5849051318735032
Layer 6: Accuracy=0.6564516129032258, F1=0.5849051318735032
Layer 7: Accuracy=0.6532258064516129, F1=0.5956133822636247
Layer 8: Accuracy=0.6629032258064517, F1=0.6091863301156157
Layer 9: Accuracy=0.6645161290322581, F1=0.6027132951462165
Layer 10: Accuracy=0.6725806451612903, F1=0.618333481035094
Layer 11: Accuracy=0.6709677419354839, F1=0.6250487988235577
Layer 12: Accuracy=0.6612903225806451, F1=0.6011536329115249
Layer 13: Accuracy=0.6709677419354839, F1=0.6142827528194994
Layer 14: Accuracy=0.7645161290322581, F1=0.7526584669044527
Layer 15: Accuracy=0.8064516129032258, F1=0.8014296124817752
Layer 16: Accuracy=0.8338709677419355, F1=0.83

In [11]:
# Calculate the standard deviation for accuracy and F1
accuracy_std = np.std([accuracy for _, accuracy, _ in layer_metrics])
f1_std = np.std([f1 for _, _, f1 in layer_metrics])
margin_accuracy = 0.5 * accuracy_std
margin_f1 = 0.5 * f1_std

best_layers = []

# Find layers that are not strictly dominated
for i, accuracy, f1 in layer_metrics:
    is_dominated = False
    for j, acc, f in layer_metrics:
        if (acc >= accuracy + margin_accuracy) and (f >= f1 + margin_f1) and j != i:
            is_dominated = True
            break
    if not is_dominated and i not in best_layers:  # Check if already in list
        best_layers.append(i)

print(f"Selected Best Layers (Dynamic Margin Based on Std Dev): {best_layers}")

Selected Best Layers (Dynamic Margin Based on Std Dev): [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]


In [12]:
# Freeze all layers except best layers
def freeze_non_best_layers(model, best_layers):
    for idx, layer in enumerate(model.model.layers):
        requires_grad = idx in best_layers
        for param in layer.parameters():
            param.requires_grad = requires_grad

freeze_non_best_layers(model, best_layers)

In [13]:
def count_trainable_parameters(model, ensemble_clf):
    model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    clf_params = sum(p.numel() for p in shared_nn.parameters() if p.requires_grad)
    
    total_params = model_params + clf_params
    
    return model_params, clf_params, total_params

# Assuming you've frozen some layers in the model and ensemble_clf is your classifier
model_params, clf_params, total_params = count_trainable_parameters(model, shared_nn)

print(f"Total trainable parameters in model: {model_params}")
print(f"Total trainable parameters in classifier: {clf_params}")
print(f"Total trainable parameters (model + classifier): {total_params}")


Total trainable parameters in model: 2122294272
Total trainable parameters in classifier: 401795
Total trainable parameters (model + classifier): 2122696067


In [14]:
def evaluate_with_voting(valid_data, best_layers, trained_classifiers):
    """
    Evaluate using a voting ensemble of classifiers for the best-performing layers.
    
    Args:
        valid_data: Validation dataset.
        best_layers: List of indices of the best-performing layers.
        trained_classifiers: Dictionary of trained classifiers for each layer.
    """
    model.eval()  # Ensure the LLM is in evaluation mode
    shared_nn.eval()  # Ensure the classifier is in evaluation mode
    
    total_correct = 0
    total_samples = 0
    y_true = []
    y_pred = []
    total_inference_time = 0  # To accumulate inference time

    with torch.no_grad():
        for example in tqdm(valid_data):
            input_text, label = preprocess(example)
            label = torch.tensor([label]).to(device)

            # Tokenize input
            inputs = tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=400
            ).to(device)

            # Record the start time for inference
            start_time = time.time()

            # Forward pass through the LLM (model)
            outputs = model(**inputs, output_hidden_states=True)

            # Calculate inference time for the LLM forward pass
            inference_time = time.time() - start_time
            total_inference_time += inference_time  # Accumulate inference time

            # Collect predictions from classifiers (shared_nn)
            predictions = []
            for layer_idx in best_layers:
                layer_output = outputs.hidden_states[layer_idx][0, -1, :].float()
                logits = shared_nn(layer_output)  # Using shared_nn as classifier
                pred = torch.argmax(logits, dim=0).cpu().item()
                predictions.append(pred)

            # Voting ensemble
            final_prediction = max(set(predictions), key=predictions.count)

            total_correct += (final_prediction == label.cpu().item())
            total_samples += 1

            y_true.append(label.cpu().item())
            y_pred.append(final_prediction)

            # Print inference time for this sample (optional)
            #print(f"Inference Time for Sample {total_samples}: {inference_time:.4f} seconds")

    # Metrics
    accuracy = total_correct / total_samples
    f1 = f1_score(y_true, y_pred, average="weighted")
    avg_inference_time = total_inference_time / total_samples  # Average inference time per sample
    print(f"Evaluation Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")
    print(f"Average Inference Time per Sample: {avg_inference_time:.4f} seconds")

In [None]:

import time
from torch.optim import Adam
from torch import nn
from tqdm import tqdm
def fine_tune_with_ensemble(train_data, valid_data, best_layers, trained_classifiers):
    """
    Fine-tune the best layers of the LLM and their associated classifiers.
    
    Args:
        train_data: Training dataset.
        valid_data: Validation dataset.
        best_layers: List of indices of the best-performing layers.
        trained_classifiers: Dictionary of trained classifiers for each layer.
    """
    model.train()  # Set LLM to training mode
    shared_nn.train()  # Set classifier to training mode
    
    # Prepare optimizer for fine-tuning both the layers (LLM) and their classifiers (shared_nn)
    optimizer = Adam([
        {"params": model.parameters(), "lr": 5e-5, "weight_decay": 1e-4},  # LLM parameters
        {"params": [param for i in best_layers for param in trained_classifiers[i].parameters()], "lr": 1e-4}  # Classifier parameters
    ])

    criterion = nn.CrossEntropyLoss()

    total_train_loss = 0  # Variable to accumulate training loss

    for epoch in range(100):
        total_train_loss = 0
        epoch_start_time = time.time()
    
        for example in tqdm(train_data):
            input_text, label = preprocess(example)
            label = torch.tensor([label]).to(device)
    
            # Tokenize input
            inputs = tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=400
            ).to(device)
    
            # Forward pass
            outputs = model(**inputs, output_hidden_states=True)
    
            # Loss for each layer
            layer_losses = []
            for layer_idx in best_layers:
                layer_output = outputs.hidden_states[layer_idx][0, -1, :].float()
                logits = shared_nn(layer_output)
                loss = criterion(logits.unsqueeze(0), label)
                layer_losses.append(loss)
    
            # Average or sum layer losses
            total_layer_loss = sum(layer_losses)  # Average loss
            optimizer.zero_grad()
            total_layer_loss.backward()
            optimizer.step()
    
            total_train_loss += total_layer_loss.item()
    

        # Calculate and print time taken for each epoch
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch + 1}, Total Loss: {total_train_loss:.4f}, Time Taken: {epoch_time:.4f}s")

        # Evaluate after each epoch
        print(f"Evaluating at Epoch {epoch + 1}...")
        evaluate_with_voting(valid_data, best_layers, trained_classifiers)
        evaluate_with_voting(test_data, best_layers, trained_classifiers)

fine_tune_with_ensemble(train_data, valid_data, best_layers, trained_classifiers)