In [1]:
import os
# Order GPUs by PCI bus id and expose only device "1" to this process.
# Set here, in the first cell, so the values are in place before torch is imported.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "1",
    }
)

In [2]:
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from llm_classifier_modified import LLMClassifier
from llm_model_modified1 import LLM

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load sst2 train and test data
# Raw SST-2 splits, read from CSV files in the working directory.
df_train = pd.read_csv('train_sst2.csv')
df_test = pd.read_csv('test_sst2.csv')

# Experiment sizes.
n_train = 10000                         # number of leading training rows to keep
n_in_context = 5
n_total_in_context = 9 * n_in_context   # NOTE(review): 9 presumably = number of teacher prompts - confirm
n_val = 100

# Working subsets: first `n_train` training rows, and every test row.
df_train_actual = df_train.iloc[:n_train]
df_test_actual = df_test.iloc[:]

# Positional columns: 1 is used as the input text, 2 as the integer label.
samples_train = df_train_actual.iloc[:, 1].values
gt_labels_train = df_train_actual.iloc[:, 2].values.astype(int)
samples_test = df_test_actual.iloc[:, 1].values
gt_labels_test = df_test_actual.iloc[:, 2].values.astype(int)

In [4]:
# Define a prompt-formatting class for sentiment classification and initialize an LLM-based classifier
class PromptFormatting:
    """Prompt templates and label vocabulary for binary sentiment classification.

    Attributes:
        INSTRUCTION: Task instruction text.
        CLASSES: Canonical class names, in label order.
        CLASSES_FOR_MATCHING: Alternative spellings of each class (full name,
            abbreviation, 1-based number); the first entry is ``CLASSES`` itself.
        CLASSES_TEXT: Numbered listing of the classes, one per line.
    """

    def __init__(self):
        # Best instruction from BayesPE teacher i.e. instruction with highest weight
        self.INSTRUCTION = 'Select the sentiment category that best matches the opinion expressed in the review snippet.'
        self.CLASSES = ['negative', 'positive']
        self.CLASSES_FOR_MATCHING = [self.CLASSES, ['neg', 'pos'], ['1', '2']]
        first_class, second_class = self.CLASSES
        self.CLASSES_TEXT = f'1. {first_class}\n2. {second_class}'

    def format_instruction(self, instruction):
        """Return ``instruction`` followed by the numbered class list, each ending in a newline."""
        return f'{instruction}\n{self.CLASSES_TEXT}\n'

    def format_content(self, content):
        """Wrap a review text in the 'review: ... / the review is ' completion template."""
        return f'review: {content}\nthe review is '

# Prompt templates used by the classifier.
prompt_formatting = PromptFormatting()

# Student model: Mistral-7B-Instruct loaded with reduced precision and LoRA enabled.
llm = LLM(
    model_name="mistralai/Mistral-7B-Instruct-v0.3",
    use_reduced_precision=True,
    use_lora=True,
)

# Classifier wrapper that pairs the student LLM with the prompt templates.
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.08it/s]


trainable params: 7,110,656 || all params: 7,255,134,208 || trainable%: 0.0980


In [5]:
# Load teacher predictions and weights
# SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary Python
# objects (kept here because the files may contain pickled numpy arrays).
# Only load these files if they come from a trusted source.
probs = torch.load("sst2_probs.pt", weights_only=False)
weights = torch.load("sst2_prompt_weights.pt", weights_only=False)

# Normalise both payloads to float32 tensors on the student's device.
# torch.as_tensor handles numpy arrays AND tensors, so a payload that was saved
# as a torch.Tensor is also moved to llm.device (previously only ndarrays were
# converted, which left tensors on whatever device/dtype they were loaded with).
probs = torch.as_tensor(probs, dtype=torch.float32, device=llm.device)
weights = torch.as_tensor(weights, dtype=torch.float32, device=llm.device)

In [6]:
# Compute KL divergence loss
def dirichlet_loss(student_probs, probs):
    """KL divergence KL(probs || student_probs), averaged over the batch.

    Args:
        student_probs: Student class probabilities (not log-probabilities),
            floating-point tensor whose first dimension is the batch.
        probs: Target (teacher) class probabilities with the same shape.

    Returns:
        Scalar tensor: the summed KL divergence divided by the batch size
        (``reduction='batchmean'``).
    """
    # F.kl_div expects the *input* in log-space. Clamp to the smallest positive
    # normal value of the dtype first: an exact zero would give log(0) = -inf,
    # and a zero target times -inf is NaN, which would poison the whole loss.
    smallest = torch.finfo(student_probs.dtype).tiny
    student_log_probs = student_probs.clamp_min(smallest).log()
    kl_loss = F.kl_div(student_log_probs, probs, reduction='batchmean')
    return kl_loss

In [7]:
from torch.utils.data import Dataset, DataLoader

class DirichletDataset(Dataset):
    """Map-style dataset that yields each sample together with its own index.

    The index is returned so that callers can look up per-sample data stored
    elsewhere, using the same position.

    Args:
        samples: Indexable collection of samples.
        num_samples: Length reported by ``len()``.
    """

    def __init__(self, samples, num_samples):
        self.samples, self.num_samples = samples, num_samples

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(sample, idx)`` for position ``idx``."""
        sample = self.samples[idx]
        return sample, idx

In [8]:
# Train student model with teacher predictions
def train_student(samples_train, probs, weights, num_epochs=10, learning_rate=1e-5, batch_size=32):
    """Fine-tune the student LLM to match the teacher's weighted predictions.

    For every batch, the teacher predictions of the batch's samples are
    combined into one distribution per sample (weighted sum over the last
    axis), and the student is trained to minimise ``dirichlet_loss`` against
    it. Uses the notebook-level ``llm`` and ``classifier`` objects.

    Args:
        samples_train: Indexable collection of training texts.
        probs: Teacher predictions, indexed by training-sample position on the
            first axis and summed over axis 2 after weighting.
            NOTE(review): presumably shaped [n_samples, n_classes, n_prompts] - confirm.
        weights: Per-prompt weights; flattened and broadcast over the last
            axis of ``probs``.
        num_epochs: Number of passes over ``samples_train``.
        learning_rate: AdamW learning rate.
        batch_size: Number of samples per optimisation step.

    Raises:
        ValueError: If ``probs`` has fewer rows than ``samples_train``.
    """
    teacher_probs_all = torch.as_tensor(probs)
    # Fail fast: an out-of-range row lookup in the middle of training is much
    # harder to diagnose (on CUDA it surfaces as a device-side assert).
    if teacher_probs_all.shape[0] < len(samples_train):
        raise ValueError(
            "probs has {} rows but there are {} training samples".format(
                teacher_probs_all.shape[0], len(samples_train)
            )
        )
    # Loop-invariant: flatten the prompt weights once and keep them next to the
    # teacher predictions so the weighted sum never mixes devices.
    prompt_weights = torch.as_tensor(weights).view(-1).to(teacher_probs_all.device)

    dataset = DirichletDataset(samples_train, len(samples_train))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Only parameters that require gradients are optimised.
    trainable_params = [p for p in llm.model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=learning_rate)
    llm.model.train()

    for epoch in range(num_epochs):
        total_loss = 0

        for batch_idx, (batch_samples, batch_indices) in enumerate(dataloader, start=1):
            # Index on the teacher tensor's own device, then move the (small)
            # per-batch target to the student's device.
            batch_probs = teacher_probs_all[batch_indices.to(teacher_probs_all.device)]
            # Weighted sum over the last axis (weights broadcast over batch and class dims).
            batch_probs = (batch_probs * prompt_weights).sum(dim=2).to(llm.device)

            optimizer.zero_grad()
            student_probs = classifier.soft_labels_batch(input_texts=batch_samples)
            loss = dirichlet_loss(student_probs, batch_probs)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if batch_idx % 1000 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(dataloader)}")

        torch.cuda.empty_cache()

        # Reported value is the SUM of the per-batch losses for the epoch.
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss}")

# Run distillation with batches of 16; epochs and learning rate use the function defaults.
train_student(
    samples_train=samples_train,
    probs=probs,
    weights=weights,
    batch_size=16,
)

Epoch 1/10, Loss: 22.32190666499082
Epoch 2/10, Loss: 8.624014998029452
Epoch 3/10, Loss: 6.080231167346938
Epoch 4/10, Loss: 4.490629538078792
Epoch 5/10, Loss: 3.567622763948748
Epoch 6/10, Loss: 2.9407568451279076
Epoch 7/10, Loss: 2.9864752750872867
Epoch 8/10, Loss: 2.6392889528360683
Epoch 9/10, Loss: 2.2708205345188617
Epoch 10/10, Loss: 2.2003654105355963


In [9]:
# Evaluate performance of model on sst2 test data
class EvalTextDataset(Dataset):
    """Map-style dataset that yields one sample per index (no index returned).

    Deliberately NOT named ``DirichletDataset``: re-using that name here would
    shadow the training dataset (which yields ``(sample, idx)`` pairs) and
    break ``train_student`` if the training cell is re-run after this one.

    Args:
        samples: Indexable, sized collection of samples.
    """

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``."""
        return self.samples[idx]

# Switch the student to evaluation mode before inference.
llm.model.eval()
# The dataset length comes from the data itself (instead of a hard-coded 872),
# so predictions always line up one-to-one with gt_labels_test.
test_dataset = EvalTextDataset(samples_test)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

def get_test_alpha(test_dataloader, classifier):
    """Run the classifier over every batch and stack the outputs.

    Args:
        test_dataloader: Iterable yielding batches of input texts.
        classifier: Object exposing ``soft_labels_batch(input_texts=...)`` that
            returns a tensor per batch.

    Returns:
        The per-batch tensors concatenated along dimension 0, in iteration order.
    """
    # Gradients are not tracked during the forward passes.
    with torch.no_grad():
        per_batch_outputs = [
            classifier.soft_labels_batch(input_texts=texts)
            for texts in test_dataloader
        ]

    return torch.cat(per_batch_outputs, dim=0)

# Student predictions for every test batch, stacked into one tensor.
stu_probs = get_test_alpha(
    test_dataloader=test_dataloader,
    classifier=classifier,
)

In [10]:
import evaluation  
# Convert the predictions to a numpy array for the metric helpers.
# Guarded so the cell is idempotent: on a second run stu_probs is already an
# ndarray, and calling .cpu() on it would raise AttributeError.
if isinstance(stu_probs, torch.Tensor):
    stu_probs = stu_probs.detach().cpu().numpy()

f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

Student f1-score: 0.9541262684981088, Student ECE: 0.024856507778167725
