In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from llm_classifier_modified import LLMClassifier
from llm_model_modified import LLM
from torch.utils.data import Dataset, DataLoader
from scipy.stats import dirichlet
import evaluation 

In [3]:
# Load Yahoo answers train and test data
df_train = pd.read_csv('train_yahoo.csv', header=None)
df_test = pd.read_csv('test_yahoo.csv', header=None)
n_train = 10000  
n_in_context = 5  
n_val = 100
n_test = 5000
df_train_actual = df_train.iloc[:n_train]
df_test_actual = df_test.iloc[:n_test]

def format_prompt(q1, q2, a):
    return "Question: " + q1.astype(str) + " " + q2.astype(str) + "\nAnswer: " + a.astype(str)

gt_labels_train = df_train_actual.iloc[:, 0].values.astype(int)
samples_train = format_prompt(df_train_actual.iloc[:, 1], df_train_actual.iloc[:, 2], df_train_actual.iloc[:, 3]).values
gt_labels_test = df_test_actual.iloc[:, 0].values.astype(int)
samples_test = format_prompt(df_test_actual.iloc[:, 1], df_test_actual.iloc[:, 2], df_test_actual.iloc[:, 3]).values

In [4]:
# Define a prompt formatting class for sentiment classification and initializes an LLM-based classifier
class PromptFormatting(object):
    def __init__(self):
        # Best instruction from BayesPE teacher i.e. instruction with highest weight
        self.INSTRUCTION = 'Identify the topic that the following question and answer belong to:'
        self.CLASSES = [
    'Society & Culture',
    'Science & Mathematics',
    'Health',
    'Education & Reference',
    'Computers & Internet',
    'Sports',
    'Business & Finance',
    'Entertainment & Music',
    'Family & Relationships',
    'Politics & Government'
]
        self.CLASSES_FOR_MATCHING = [self.CLASSES]
        self.CLASSES_TEXT = '''1. {}\n2. {}\n3. {}\n4. {}\n5. {}\n6. {}\n7. {}\n8. {}\n9. {}\n10. {}'''.format(self.CLASSES[0],self.CLASSES[1], self.CLASSES[2], self.CLASSES[3], self.CLASSES[4], self.CLASSES[5], self.CLASSES[6], self.CLASSES[7], self.CLASSES[8], self.CLASSES[9])
    def format_instruction(self, instruction):
        return '''{}\n{}\n'''.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        return '''{}\nthe topic is '''.format(content)

prompt_formatting = PromptFormatting()
llm = LLM(model_name="mistralai/Mistral-7B-Instruct-v0.3", use_reduced_precision=True,use_lora=True)
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 14,221,312 || all params: 7,262,244,864 || trainable%: 0.1958


In [5]:
# Compute Dirichlet-based distillation loss
def dirichlet_loss(alpha, probs, weights):
    alpha_0 = torch.sum(alpha, dim=1, keepdim=True)                      
    log_gamma_alpha_0 = torch.lgamma(alpha_0)                           
    log_gamma_alpha_c = torch.lgamma(alpha).sum(dim=1, keepdim=True)   
    alpha_expanded = alpha.unsqueeze(-1)                                
    weighted_log_probs = (alpha_expanded - 1) * torch.log(probs + 1e-8) 
    class_sum = weighted_log_probs.sum(dim=1)                           
    if weights.ndim == 1:
        weights = weights.unsqueeze(1)                                   
    weights_broadcasted = weights.T.expand(probs.shape[0], -1)          
    weighted_terms = class_sum * weights_broadcasted                  
    prompt_sum = weighted_terms.sum(dim=1, keepdim=True)               
    loss = -(log_gamma_alpha_0 - log_gamma_alpha_c + prompt_sum).mean()
    return loss

In [6]:
# Load teacher predictions and weights
probs = torch.load("yahoo_probs.pt", weights_only=False)
weights = torch.load("yahoo_prompt_weights.pt", weights_only=False)
if isinstance(probs, np.ndarray):
    probs = torch.tensor(probs, dtype=torch.float32, device=llm.device)
if isinstance(weights, np.ndarray):
    weights = torch.tensor(weights, dtype=torch.float32, device=llm.device)

In [7]:
# Evaluate performance of model on yahoo answers test data
def evaluate():
    def dirichlet_to_prob(alpha):
        return alpha / alpha.sum(dim=1, keepdim=True) 
    
    
    class TestDirichletDataset(Dataset):
        def __init__(self, samples, n_samples):
            self.samples = samples
            self.n_samples = n_samples
    
        def __len__(self):
            return self.n_samples
    
        def __getitem__(self, idx):
            return self.samples[idx]
    
    llm.model.eval()
    test_dataset = TestDirichletDataset(samples_test, n_test)
    test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False) 
    
    def get_test_alpha(test_dataloader, classifier):
        all_alpha = []
    
        with torch.no_grad():
            for batch_samples in test_dataloader:
                alpha_batch = classifier.soft_labels_batch(input_texts=batch_samples)
                all_alpha.append(alpha_batch)
    
        return torch.cat(all_alpha, dim=0) 
    
    alpha_test = get_test_alpha(test_dataloader, classifier)
    stu_probs = dirichlet_to_prob(alpha_test)
    stu_probs=stu_probs.cpu().numpy()
    f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
    ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
    print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

In [8]:
from torch.utils.data import Dataset, DataLoader

class DirichletDataset(Dataset):
    def __init__(self, samples, num_samples):
        self.samples = samples
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.samples[idx], idx 

In [9]:
# Train student model with teacher predictions and evaluate after each epoch on test data
def train_student(samples_train, probs, weights, num_epochs=10, learning_rate=1e-5, batch_size=32):
    dataset = DirichletDataset(samples_train, len(samples_train))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, llm.model.parameters()), lr=learning_rate)
    llm.model.train()  

    for epoch in range(num_epochs):
        total_loss = 0

        for batch_idx, (batch_samples, batch_indices) in enumerate(dataloader, start=1):
            batch_indices = batch_indices.to(llm.device)

            batch_probs = probs[batch_indices] 

            optimizer.zero_grad()

            alpha = classifier.soft_labels_batch(input_texts=batch_samples)
            alpha = torch.clamp(alpha, min=1e-3)
            loss = dirichlet_loss(alpha, batch_probs, weights)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()


        torch.cuda.empty_cache()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss}")
        evaluate()
train_student(samples_train, probs, weights, batch_size=1)

Epoch 1/10, Loss: -219809.61242961884
Student f1-score: 0.633404153879325, Student ECE: 0.027892576530575752
Epoch 2/10, Loss: -255079.65225172043
Student f1-score: 0.6325764009906196, Student ECE: 0.06699836254119873
Epoch 3/10, Loss: -269250.308511734
Student f1-score: 0.6298374066746713, Student ECE: 0.05541273579001427
Epoch 4/10, Loss: -279097.34998989105
Student f1-score: 0.6214519725698301, Student ECE: 0.09352415800094604
Epoch 5/10, Loss: -286695.85922527313
Student f1-score: 0.6295344252669677, Student ECE: 0.07703009247779846
Epoch 6/10, Loss: -292516.57237911224
Student f1-score: 0.6270476503891149, Student ECE: 0.10067932307720184
Epoch 7/10, Loss: -298502.26456832886
Student f1-score: 0.6248490290990344, Student ECE: 0.10534060001373291
Epoch 8/10, Loss: -303035.6002044678
Student f1-score: 0.6312031593375466, Student ECE: 0.1013832837343216
Epoch 9/10, Loss: -307017.99964523315
Student f1-score: 0.6241456164990662, Student ECE: 0.10634072870016098
Epoch 10/10, Loss: -310