The dataset is available at: https://drive.google.com/drive/folders/13USe0gzuzmgJxKuQQqGZck3stj5WkY00?usp=sharing

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import pandas as pd
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from llm_classifier_modified import LLMClassifier
from llm_model_modified import LLM

In [None]:
# Both CSVs are read without a header row, so columns are addressed by position.
df_train = pd.read_csv('train_modified.csv', header=None)
df_test = pd.read_csv('test_modified.csv', header=None)

# Split sizes (in rows).
n_train = 5000
n_in_context = 5
n_total_in_context = n_in_context * len(df_train)
n_test = 5000
n_val = 100

# The training file is cut into consecutive, non-overlapping row ranges:
#   [0, n_train)                 -> student training rows
#   [n_train, val_start)         -> in-context example pool
#   [val_start, val_start+n_val) -> validation rows
# NOTE(review): n_total_in_context is never smaller than len(df_train), so
# val_start lies past the last row and df_val comes out empty, while
# df_in_context_base takes every row after the training rows. Confirm the
# intended pool size.
val_start = n_train + n_total_in_context

df_train_actual = df_train.iloc[:n_train]
df_in_context_base = df_train.iloc[n_train:val_start]
df_val = df_train.iloc[val_start:val_start + n_val]
df_test_actual = df_test.iloc[:n_test]


def _labels_and_texts(frame):
    """Return (column 0 cast to int, column 2 unchanged) of ``frame`` as arrays."""
    return frame.iloc[:, 0].values.astype(int), frame.iloc[:, 2].values


gt_labels_train, samples_train = _labels_and_texts(df_train_actual)
gt_labels_val, samples_val = _labels_and_texts(df_val)
gt_labels_test, samples_test = _labels_and_texts(df_test_actual)


In [4]:
class PromptFormatting(object):
    """Prompt templates for two-class sentiment classification of Amazon reviews.

    Attributes:
        INSTRUCTION: Task instruction sentence placed at the top of a prompt.
        CLASSES: Class names, in order: negative first, positive second.
        CLASSES_FOR_MATCHING: Alternative spellings of the classes, one list
            per naming scheme (full names, abbreviations, ordinals). The first
            entry is the very same list object as ``CLASSES``.
        CLASSES_TEXT: Numbered, newline-separated listing of the class names.
    """

    def __init__(self):
        self.INSTRUCTION = 'classify the sentiment of the Amazon review below into one of the following classes:'
        self.CLASSES = ['negative', 'positive']

        abbreviated = ['neg', 'pos']
        ordinal = ['1', '2']
        self.CLASSES_FOR_MATCHING = [self.CLASSES, abbreviated, ordinal]

        first_class, second_class = self.CLASSES
        self.CLASSES_TEXT = '1. {}\n2. {}'.format(first_class, second_class)

    def format_instruction(self, instruction):
        """Return ``instruction`` followed by the numbered class list, each on its own line."""
        template = '{}\n{}\n'
        return template.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        """Wrap a review in the answer stub that the model is expected to complete."""
        template = 'review: {}\nthe review is '
        return template.format(content)

# Student model. The reduced-precision and LoRA flags are interpreted by the
# project-local LLM wrapper (llm_model_modified).
llm = LLM(
    model_name="mistralai/Mistral-7B-Instruct-v0.3",
    use_reduced_precision=True,
    use_lora=True,
)

# The classifier combines the student model with the sentiment prompt templates.
prompt_formatting = PromptFormatting()
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 14,221,312 || all params: 7,262,244,864 || trainable%: 0.1958


In [5]:
# Load teacher predictions and weights
# Load teacher predictions and weights.
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, so only
# load files that were produced by a trusted run of the teacher notebook.
# map_location="cpu" avoids failures when the files were saved from a GPU that
# is not visible in this process; everything is moved to llm.device below.
probs = torch.load("teacher_probs_promptwise.pt", map_location="cpu", weights_only=False)
weights = torch.load("prompt_weights.pt", map_location="cpu", weights_only=False)


def _to_student_device(value):
    """Place a loaded array/tensor on the student's device.

    NumPy arrays are converted to float32 tensors; tensors keep their dtype and
    are only moved. Anything else is returned untouched.
    """
    if isinstance(value, np.ndarray):
        return torch.tensor(value, dtype=torch.float32, device=llm.device)
    if torch.is_tensor(value):
        # Previously tensors were left on whatever device they were saved from,
        # which breaks indexing with indices that live on llm.device.
        return value.to(llm.device)
    return value


probs = _to_student_device(probs)
weights = _to_student_device(weights)

In [6]:
import torch

def dirichlet_loss(alpha, probs, weights):
    """Negative prompt-weighted Dirichlet log-likelihood, averaged over the batch.

    For every sample ``b`` the value being averaged is::

        lgamma(sum_c alpha[b, c]) - sum_c lgamma(alpha[b, c])
            + sum_p weights[p] * sum_c (alpha[b, c] - 1) * log(probs[b, c, p] + 1e-8)

    The normalising term is added once per sample (it is not multiplied by the
    weights), so this equals a weighted Dirichlet log-likelihood only when the
    weights sum to 1.

    Args:
        alpha: Tensor of shape (B, C) with the Dirichlet concentrations.
        probs: Tensor of shape (B, C, P) with the target probabilities of each
            of the P prompts.
        weights: Tensor holding one weight per prompt. Accepted layouts are
            (P,), (P, 1) and (1, P); a single-element tensor is broadcast to
            every prompt.

    Returns:
        Scalar tensor with the loss.

    Raises:
        ValueError: If ``weights`` has neither 1 nor P elements.
    """
    num_prompts = probs.shape[-1]
    if weights.numel() not in (1, num_prompts):
        raise ValueError(
            "weights must contain one value per prompt ({}), got {} values".format(
                num_prompts, weights.numel()
            )
        )
    # Flatten to a row vector regardless of the incoming layout. The previous
    # ``weights.T.expand(B, -1)`` mis-handled a (1, P) input: it raised, or,
    # when B == P, silently applied the weights per sample instead of per prompt.
    prompt_weights = weights.reshape(1, -1)                              # (1, P)

    # Log normaliser of the Dirichlet density, one value per sample.
    alpha_0 = torch.sum(alpha, dim=1, keepdim=True)                      # (B, 1)
    log_normaliser = torch.lgamma(alpha_0) - torch.lgamma(alpha).sum(dim=1, keepdim=True)

    # (alpha - 1) * log p, summed over classes; the epsilon guards log(0).
    log_probs = torch.log(probs + 1e-8)                                  # (B, C, P)
    per_prompt = ((alpha.unsqueeze(-1) - 1) * log_probs).sum(dim=1)      # (B, P)

    # Combine the prompts using their weights.
    weighted = (per_prompt * prompt_weights).sum(dim=1, keepdim=True)    # (B, 1)

    return -(log_normaliser + weighted).mean()


In [7]:
from torch.utils.data import Dataset, DataLoader

class DirichletDataset(Dataset):
    """Dataset that yields ``(sample, index)`` pairs.

    The index is returned alongside the sample so that the training loop can
    select the matching rows of per-sample targets.

    Args:
        samples: Indexable, sized collection of samples.
        num_samples: Number of leading samples to expose. Defaults to
            ``len(samples)`` when omitted.

    Raises:
        ValueError: If ``num_samples`` is negative or larger than
            ``len(samples)``. Failing here is clearer than an IndexError
            raised later from inside a DataLoader.
    """

    def __init__(self, samples, num_samples=None):
        available = len(samples)
        if num_samples is None:
            num_samples = available
        elif not 0 <= num_samples <= available:
            raise ValueError(
                "num_samples must be between 0 and {}, got {}".format(available, num_samples)
            )
        self.samples = samples
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.samples[idx], idx

In [8]:
def train_student(samples_train, probs, weights, num_epochs=3, learning_rate=1e-4, batch_size=32):
    """Fine-tune the student on teacher probabilities with the Dirichlet loss.

    Uses the module-level ``llm`` (model being trained) and ``classifier``
    (produces the per-sample alpha values), together with ``DirichletDataset``
    and ``dirichlet_loss`` from this notebook. Only parameters with
    ``requires_grad=True`` are optimised. Nothing is returned; the loss summed
    over all batches is printed after every epoch.

    Args:
        samples_train: Indexable collection of training texts.
        probs: Teacher probabilities whose first dimension is aligned with
            ``samples_train``.
        weights: Per-prompt weights passed through to ``dirichlet_loss``.
        num_epochs: Number of passes over the training samples.
        learning_rate: AdamW learning rate.
        batch_size: Number of samples per optimisation step.
    """
    dataset = DirichletDataset(samples_train, len(samples_train))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    trainable_params = [p for p in llm.model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=learning_rate)
    llm.model.train()

    # No-ops for tensors; makes array inputs indexable with tensor indices.
    probs = torch.as_tensor(probs)
    weights = torch.as_tensor(weights)

    for epoch in range(num_epochs):
        total_loss = 0

        for batch_idx, (batch_samples, batch_indices) in enumerate(dataloader, start=1):
            optimizer.zero_grad()

            alpha = classifier.soft_labels_batch(input_texts=batch_samples)
            # Keep concentrations strictly positive so lgamma stays finite.
            alpha = torch.clamp(alpha, min=1e-3)

            # Index on whichever device `probs` lives on, then move the selected
            # rows next to alpha. Moving the indices to llm.device first fails
            # when `probs` is stored on a different device (e.g. the CPU).
            batch_probs = probs[batch_indices.to(probs.device)].to(alpha.device)
            loss = dirichlet_loss(alpha, batch_probs, weights.to(alpha.device))

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            if batch_idx % 1000 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(dataloader)}")

        # Release cached GPU blocks between epochs.
        torch.cuda.empty_cache()

        # total_loss is the sum of the per-batch mean losses, not an average.
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss}")

# Train the student with the default number of epochs and learning rate, but
# with a batch size of 16 instead of the default 32.
train_student(samples_train, probs, weights, batch_size=16)

Epoch 1/3, Loss: -776.5150373280048
Epoch 2/3, Loss: -968.3509067893028
Epoch 3/3, Loss: -1057.5121891498566


In [9]:
def dirichlet_to_prob(alpha):
    """Normalise each row of ``alpha`` to sum to one.

    For Dirichlet concentrations this is the mean of the distribution. Rows
    are indexed by dim 0 and the normalisation runs over dim 1.
    """
    row_totals = alpha.sum(dim=1, keepdim=True)
    return alpha / row_totals

class UnlabeledTextDataset(Dataset):
    """Dataset that yields bare samples (no index) for inference.

    This class is intentionally NOT called ``DirichletDataset``. The training
    cell defines a class of that name whose items are ``(sample, index)``
    pairs; redefining it here with a different item layout would silently
    break ``train_student`` if training were re-run after this cell.

    Args:
        samples: Indexable collection of samples.
        n_samples: Number of leading samples to expose.
    """

    def __init__(self, samples, n_samples):
        self.samples = samples
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.samples[idx]


# Switch the student to evaluation mode before running inference.
llm.model.eval()

# Never expose more items than actually exist, in case the test file holds
# fewer than n_test rows.
test_dataset = UnlabeledTextDataset(samples_test, min(n_test, len(samples_test)))
# shuffle=False keeps predictions in the same order as the test labels.
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

def get_test_alpha(test_dataloader, classifier, min_alpha=1e-3):
    """Collect the classifier's alpha values for every batch of a dataloader.

    Gradients are disabled while predicting.

    Args:
        test_dataloader: Iterable yielding batches of input texts.
        classifier: Object exposing ``soft_labels_batch(input_texts=...)`` that
            returns a tensor with one row per input.
        min_alpha: Lower bound applied to the predicted values. The default
            matches the clamp used on alpha in ``train_student``, so inference
            uses the same parameterisation as training. Pass ``None`` to
            return the raw predictions.

    Returns:
        Tensor with the per-batch outputs concatenated along dim 0, in
        iteration order.
    """
    alpha_batches = []

    with torch.no_grad():
        for batch_samples in test_dataloader:
            alpha_batch = classifier.soft_labels_batch(input_texts=batch_samples)
            if min_alpha is not None:
                alpha_batch = torch.clamp(alpha_batch, min=min_alpha)
            alpha_batches.append(alpha_batch)

    return torch.cat(alpha_batches, dim=0)

# Run the student over the whole test loader and gather its alpha outputs.
alpha_test = get_test_alpha(test_dataloader, classifier)
# Normalise each row of alpha so it sums to one, giving class probabilities.
stu_probs = dirichlet_to_prob(alpha_test)  


In [10]:
import evaluation  
# Convert the student probabilities to a NumPy array for the metric helpers.
# The isinstance guard keeps this cell re-runnable: after the first run
# `stu_probs` is already an ndarray, which has no `.cpu()` method.
if isinstance(stu_probs, torch.Tensor):
    stu_probs = stu_probs.detach().cpu().numpy()

f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

Student f1-score: 0.9540861317702354, Student ECE: 0.021748656406998634
