In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from llm_classifier_modified import LLMClassifier
from llm_model_modified import LLM
from torch.utils.data import Dataset, DataLoader
from scipy.stats import dirichlet
import evaluation 

In [3]:
# Load sst2 train and test data
df_train = pd.read_csv('train_sst2.csv')
df_test = pd.read_csv('test_sst2.csv')
n_train = 10000  
n_in_context = 5  
n_total_in_context = 9 * n_in_context  
n_val=100
df_train_actual = df_train.iloc[:n_train] 
df_test_actual = df_test.iloc[:]  
gt_labels_train = df_train_actual.iloc[:, 2].values.astype(int) 
samples_train = df_train_actual.iloc[:, 1].values 
gt_labels_test = df_test_actual.iloc[:, 2].values.astype(int)
samples_test = df_test_actual.iloc[:, 1].values 

In [4]:
# Define a prompt formatting class for sentiment classification and initializes an LLM-based classifier
class PromptFormatting(object):
    def __init__(self):
        # Best instruction from BayesPE teacher i.e. instruction with highest weight
        self.INSTRUCTION = 'Select the sentiment category that best matches the opinion expressed in the review snippet.'
        self.CLASSES = ['negative', 'positive']
        self.CLASSES_FOR_MATCHING = [self.CLASSES, ['neg', 'pos'], ['1', '2']]
        self.CLASSES_TEXT = '''1. {}\n2. {}'''.format(self.CLASSES[0], self.CLASSES[1])

    def format_instruction(self, instruction):
        return '''{}\n{}\n'''.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        return '''review: {}\nthe review is '''.format(content)

llm = LLM(model_name="mistralai/Mistral-7B-Instruct-v0.3", use_reduced_precision=True,use_lora=True)
prompt_formatting = PromptFormatting()
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 14,221,312 || all params: 7,262,244,864 || trainable%: 0.1958


In [5]:
# Load teacher predictions and weights
probs = torch.load("sst2_probs.pt", weights_only=False)
weights = torch.load("sst2_prompt_weights.pt", weights_only=False)
if isinstance(probs, np.ndarray):
    probs = torch.tensor(probs, dtype=torch.float32, device=llm.device)
if isinstance(weights, np.ndarray):
    weights = torch.tensor(weights, dtype=torch.float32, device=llm.device)

In [6]:
# Compute Dirichlet-based distillation loss
def dirichlet_loss(alpha, probs, weights):
    alpha_0 = torch.sum(alpha, dim=1, keepdim=True)                      
    log_gamma_alpha_0 = torch.lgamma(alpha_0)                           
    log_gamma_alpha_c = torch.lgamma(alpha).sum(dim=1, keepdim=True)   
    alpha_expanded = alpha.unsqueeze(-1)                                
    weighted_log_probs = (alpha_expanded - 1) * torch.log(probs + 1e-8) 
    class_sum = weighted_log_probs.sum(dim=1)                           
    if weights.ndim == 1:
        weights = weights.unsqueeze(1)                                   
    weights_broadcasted = weights.T.expand(probs.shape[0], -1)          
    weighted_terms = class_sum * weights_broadcasted                 
    prompt_sum = weighted_terms.sum(dim=1, keepdim=True)               
    loss = -(log_gamma_alpha_0 - log_gamma_alpha_c + prompt_sum).mean()
    return loss

In [7]:
# Evaluate performance of model on Amazon reviews polarity test data
def evaluate():
    def dirichlet_to_prob(alpha):
        return alpha / alpha.sum(dim=1, keepdim=True) 
    
    
    class TestDirichletDataset(Dataset):
        def __init__(self, samples, n_samples):
            self.samples = samples
            self.n_samples = n_samples
    
        def __len__(self):
            return self.n_samples
    
        def __getitem__(self, idx):
            return self.samples[idx]
    
    llm.model.eval()
    test_dataset = TestDirichletDataset(samples_test, 872)
    test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False) 
    
    def get_test_alpha(test_dataloader, classifier):
        all_alpha = []
    
        with torch.no_grad():
            for batch_samples in test_dataloader:
                alpha_batch = classifier.soft_labels_batch(input_texts=batch_samples)
                all_alpha.append(alpha_batch)
    
        return torch.cat(all_alpha, dim=0) 
    
    alpha_test = get_test_alpha(test_dataloader, classifier)
    stu_probs = dirichlet_to_prob(alpha_test)
    stu_probs=stu_probs.cpu().numpy()
    f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
    ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
    print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

In [8]:
from torch.utils.data import Dataset, DataLoader

class DirichletDataset(Dataset):
    def __init__(self, samples, num_samples):
        self.samples = samples
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.samples[idx], idx 

In [9]:
# Train student model with teacher predictions and evaluate after each epoch on test data
def train_student(samples_train, probs, weights, num_epochs=10, learning_rate=1e-5, batch_size=32):
    dataset = DirichletDataset(samples_train, len(samples_train))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, llm.model.parameters()), lr=learning_rate)
    llm.model.train()  

    for epoch in range(num_epochs):
        total_loss = 0

        for batch_idx, (batch_samples, batch_indices) in enumerate(dataloader, start=1):
            batch_indices = batch_indices.to(llm.device)

            batch_probs = probs[batch_indices] 

            optimizer.zero_grad()

            alpha = classifier.soft_labels_batch(input_texts=batch_samples)
            alpha = torch.clamp(alpha, min=1e-3)
            loss = dirichlet_loss(alpha, batch_probs, weights)
          
            loss.backward()

            optimizer.step()
            total_loss += loss.item()

            if batch_idx % 1000 == 0:
                    print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(dataloader)}")

        torch.cuda.empty_cache()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss}")
        evaluate()
train_student(samples_train, probs, weights, batch_size=16)

Epoch 1/10, Loss: -1346.9719011932611
Student f1-score: 0.9541088861405678, Student ECE: 0.014382947236299515
Epoch 2/10, Loss: -1943.2796809077263
Student f1-score: 0.9529810948545125, Student ECE: 0.022758912295103073
Epoch 3/10, Loss: -2075.80743509531
Student f1-score: 0.9541245791245792, Student ECE: 0.016509659588336945
Epoch 4/10, Loss: -2215.766918540001
Student f1-score: 0.9529810948545125, Student ECE: 0.013660853728652
Epoch 5/10, Loss: -2309.298662543297
Student f1-score: 0.9529801054501886, Student ECE: 0.015524783171713352
Epoch 6/10, Loss: -2334.9343638420105
Student f1-score: 0.9552746999835607, Student ECE: 0.01670568808913231
Epoch 7/10, Loss: -2379.418802857399
Student f1-score: 0.9541281990583655, Student ECE: 0.015593266114592552
Epoch 8/10, Loss: -2430.474953174591
Student f1-score: 0.9541274751173117, Student ECE: 0.012486808933317661
Epoch 9/10, Loss: -2352.5335055589676
Student f1-score: 0.9518186410709775, Student ECE: 0.012769356369972229
Epoch 10/10, Loss: -