## Student Model(with dirichlet output) training and evaluation on Amazon reviews polarity dataset

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from llm_classifier_modified import LLMClassifier
from llm_model_modified import LLM
from torch.utils.data import Dataset, DataLoader
from scipy.stats import dirichlet
import evaluation 

In [3]:
# Load Amazon reviews polarity train and test data
df_train = pd.read_csv('train_amazon.csv', header=None)
df_test = pd.read_csv('test_amazon.csv', header=None)

n_train = 10000
n_in_context = 5  
n_total_in_context = len(df_train) * n_in_context  
n_test = 5000
n_val = 100

df_train_actual = df_train.iloc[:n_train] 
df_in_context_base = df_train.iloc[n_train:n_train + n_total_in_context]
df_val = df_train.iloc[n_train + n_total_in_context:n_train + n_total_in_context + n_val]
df_test_actual = df_test.iloc[:n_test]  

gt_labels_train = df_train_actual.iloc[:, 0].values.astype(int) 
samples_train = df_train_actual.iloc[:, 2].values 
gt_labels_val = df_val.iloc[:, 0].values.astype(int) 
samples_val = df_val.iloc[:, 2].values 

gt_labels_test = df_test_actual.iloc[:, 0].values.astype(int)
samples_test = df_test_actual.iloc[:, 2].values  


In [4]:
# Define a prompt formatting class for sentiment classification and initializes an LLM-based classifier
class PromptFormatting(object):
    def __init__(self):
        # Best instruction from BayesPE teacher i.e. instruction with highest weight
        self.INSTRUCTION = 'classify the sentiment of the Amazon review below into one of the following classes:'
        self.CLASSES = ['negative', 'positive']
        self.CLASSES_FOR_MATCHING = [self.CLASSES, ['neg', 'pos'], ['1', '2']]
        self.CLASSES_TEXT = '''1. {}\n2. {}'''.format(self.CLASSES[0], self.CLASSES[1])

    def format_instruction(self, instruction):
        return '''{}\n{}\n'''.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        return '''review: {}\nthe review is '''.format(content)

llm = LLM(model_name="mistralai/Mistral-7B-Instruct-v0.3", use_reduced_precision=True,use_lora=True)
prompt_formatting = PromptFormatting()
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 14,221,312 || all params: 7,262,244,864 || trainable%: 0.1958


In [5]:
# Load teacher predictions and weights
probs = torch.load("amazon_probs.pt", weights_only=False)
weights = torch.load("amazon_prompt_weights.pt", weights_only=False)
if isinstance(probs, np.ndarray):
    probs = torch.tensor(probs, dtype=torch.float32, device=llm.device)
if isinstance(weights, np.ndarray):
    weights = torch.tensor(weights, dtype=torch.float32, device=llm.device)

In [6]:
# Compute Dirichlet-based distillation loss
import torch

def dirichlet_loss(alpha, probs, weights):

    alpha_0 = torch.sum(alpha, dim=1, keepdim=True)                      
    log_gamma_alpha_0 = torch.lgamma(alpha_0)                           
    log_gamma_alpha_c = torch.lgamma(alpha).sum(dim=1, keepdim=True)   

    alpha_expanded = alpha.unsqueeze(-1)                                

    weighted_log_probs = (alpha_expanded - 1) * torch.log(probs + 1e-8) 

    class_sum = weighted_log_probs.sum(dim=1)                           

    if weights.ndim == 1:
        weights = weights.unsqueeze(1)                                   
    weights_broadcasted = weights.T.expand(probs.shape[0], -1)          
    
    weighted_terms = class_sum * weights_broadcasted                  

    prompt_sum = weighted_terms.sum(dim=1, keepdim=True)               

    loss = -(log_gamma_alpha_0 - log_gamma_alpha_c + prompt_sum).mean()

    return loss


In [7]:
# Evaluate performance of model on Amazon reviews polarity test data
def evaluate():
    def dirichlet_to_prob(alpha):
        return alpha / alpha.sum(dim=1, keepdim=True) 
    
    
    class TestDirichletDataset(Dataset):
        def __init__(self, samples, n_samples):
            self.samples = samples
            self.n_samples = n_samples
    
        def __len__(self):
            return self.n_samples
    
        def __getitem__(self, idx):
            return self.samples[idx]
    
    llm.model.eval()
    test_dataset = TestDirichletDataset(samples_test, n_test)
    test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False) 
    
    def get_test_alpha(test_dataloader, classifier):
        all_alpha = []
    
        with torch.no_grad():
            for batch_samples in test_dataloader:
                alpha_batch = classifier.soft_labels_batch(input_texts=batch_samples)
                all_alpha.append(alpha_batch)
    
        return torch.cat(all_alpha, dim=0) 
    
    alpha_test = get_test_alpha(test_dataloader, classifier)
    stu_probs = dirichlet_to_prob(alpha_test)
    stu_probs=stu_probs.cpu().numpy()
    f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
    ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
    print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

In [8]:
from torch.utils.data import Dataset, DataLoader

class DirichletDataset(Dataset):
    def __init__(self, samples, num_samples):
        self.samples = samples
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.samples[idx], idx 

In [9]:
# Train student model with teacher predictions and evaluate after each epoch on test data
def train_student(samples_train, probs, weights, num_epochs=10, learning_rate=1e-5, batch_size=32):
    dataset = DirichletDataset(samples_train, len(samples_train))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, llm.model.parameters()), lr=learning_rate)
    llm.model.train()  

    for epoch in range(num_epochs):
        total_loss = 0
        llm.model.train()
        for batch_idx, (batch_samples, batch_indices) in enumerate(dataloader, start=1):
            batch_indices = batch_indices.to(llm.device)

            batch_probs = probs[batch_indices] 

            optimizer.zero_grad()

            alpha = classifier.soft_labels_batch(input_texts=batch_samples)
            alpha = torch.clamp(alpha, min=1e-3)
            loss = dirichlet_loss(alpha, batch_probs, weights)
          
            loss.backward()

            optimizer.step()
            total_loss += loss.item()

            if batch_idx % 1000 == 0:
                    print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(dataloader)}")

        torch.cuda.empty_cache()
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss}")
        evaluate()
train_student(samples_train, probs, weights, batch_size=16)

Epoch 1/10, Loss: -1564.08486405015
Student f1-score: 0.951456341839787, Student ECE: 0.010912657715380192
Epoch 2/10, Loss: -2205.8668997883797
Student f1-score: 0.9585356323742725, Student ECE: 0.01002875529229641
Epoch 3/10, Loss: -2415.1293152570724
Student f1-score: 0.9579417981779905, Student ECE: 0.005028365179896355
Epoch 4/10, Loss: -2402.583825945854
Student f1-score: 0.9569887121176082, Student ECE: 0.009572663344442844
Epoch 5/10, Loss: -2564.2673350572586
Student f1-score: 0.9591458699117983, Student ECE: 0.018178431317210197
Epoch 6/10, Loss: -2664.221682071686
Student f1-score: 0.9575293015765765, Student ECE: 0.021279077976942062
Epoch 7/10, Loss: -2782.1060552597046
Student f1-score: 0.9585468720130468, Student ECE: 0.019260089844465256
Epoch 8/10, Loss: -2873.131416320801
Student f1-score: 0.959346726324044, Student ECE: 0.02085147425532341
Epoch 9/10, Loss: -2956.709189891815
Student f1-score: 0.9601576571517145, Student ECE: 0.022752856835722923
Epoch 10/10, Loss: -

## Evaluation on out-of-distribution data

In [9]:
# Load Yahoo Answers dataset
df_train = pd.read_csv('train_yahoo.csv', header=None)
df_test = pd.read_csv('test_yahoo.csv', header=None)

n_train = 10000  
n_in_context = 5  
n_val = 100
n_test = 5000

# Split data
df_train_actual = df_train.iloc[:n_train]
df_test_actual = df_test.iloc[:n_test]

# Format function for prompts
def format_prompt(q1, q2, a):
    return "Question: " + q1.astype(str) + " " + q2.astype(str) + "\nAnswer: " + a.astype(str)

# Extract training data
gt_labels_train = df_train_actual.iloc[:, 0].values.astype(int)
samples_train = format_prompt(df_train_actual.iloc[:, 1], df_train_actual.iloc[:, 2], df_train_actual.iloc[:, 3]).values

gt_labels_test = df_test_actual.iloc[:, 0].values.astype(int)
samples_test = format_prompt(df_test_actual.iloc[:, 1], df_test_actual.iloc[:, 2], df_test_actual.iloc[:, 3]).values

In [10]:
class PromptFormatting(object):
    def __init__(self):
        self.INSTRUCTION = 'classify the question and answer below into one of the following topics:'
        self.CLASSES = [
    'Society & Culture',
    'Science & Mathematics',
    'Health',
    'Education & Reference',
    'Computers & Internet',
    'Sports',
    'Business & Finance',
    'Entertainment & Music',
    'Family & Relationships',
    'Politics & Government'
]
        self.CLASSES_FOR_MATCHING = [self.CLASSES]
        self.CLASSES_TEXT = '''1. {}\n2. {}\n3. {}\n4. {}\n5. {}\n6. {}\n7. {}\n8. {}\n9. {}\n10. {}'''.format(self.CLASSES[0],self.CLASSES[1], self.CLASSES[2], self.CLASSES[3], self.CLASSES[4], self.CLASSES[5], self.CLASSES[6], self.CLASSES[7], self.CLASSES[8], self.CLASSES[9])
    def format_instruction(self, instruction):
        return '''{}\n{}\n'''.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        return '''{}\nthe topic is '''.format(content)

prompt_formatting = PromptFormatting()
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

In [11]:
# Evaluate performance of model on Yahoo answers test data
def dirichlet_to_prob(alpha):
    return alpha / alpha.sum(dim=1, keepdim=True) 


class DirichletDataset(Dataset):
    def __init__(self, samples, n_samples):
        self.samples = samples
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.samples[idx]

llm.model.eval()
test_dataset = DirichletDataset(samples_test, n_test)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False) 

def get_test_alpha(test_dataloader, classifier):
    all_alpha = []

    with torch.no_grad():
        for batch_samples in test_dataloader:
            alpha_batch = classifier.soft_labels_batch(input_texts=batch_samples)
            all_alpha.append(alpha_batch)

    return torch.cat(all_alpha, dim=0) 

yahoo_alpha_test = get_test_alpha(test_dataloader, classifier)
stu_probs = dirichlet_to_prob(yahoo_alpha_test)

In [12]:
import evaluation  
stu_probs=stu_probs.cpu().numpy()
f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

Student f1-score: 0.5545621356087371, Student ECE: 0.2955533266067505


In [47]:
# Compute the predictive entropy by sampling from Dirichlet distribution
def dirichlet_entropy(alpha: torch.Tensor, n_samples: int = 1000) -> torch.Tensor:

    batch_size, num_classes = alpha.shape
    
    alpha_expanded = alpha.unsqueeze(1).expand(batch_size, n_samples, num_classes)
    
    samples = torch.distributions.Dirichlet(alpha_expanded).sample()  
    
    entropy_samples = -torch.sum(samples * torch.log(samples + 1e-10), dim=2) 
    
    entropy_estimate = entropy_samples.mean(dim=1) 
    
    return entropy_estimate


In [48]:
# Predictive entropy on yahoo answers test data
ent_yahoo = dirichlet_entropy(yahoo_alpha_test)
print(ent_yahoo)

tensor([2.1386, 2.0383, 2.0178,  ..., 2.0230, 2.0151, 1.8256], device='cuda:0')


In [49]:
# Mean predictive entropy
print(ent_yahoo.mean())

tensor(2.0286, device='cuda:0')


In [50]:
# Load sst2 dataset
df_train = pd.read_csv('train_sst2.csv')
df_test = pd.read_csv('test_sst2.csv')
n_train = 10000  
n_in_context = 5  
n_total_in_context = 9 * n_in_context  
n_val=100
df_train_actual = df_train.iloc[:n_train] 
df_test_actual = df_test.iloc[:]  
gt_labels_train = df_train_actual.iloc[:, 2].values.astype(int) 
samples_train = df_train_actual.iloc[:, 1].values 
gt_labels_test = df_test_actual.iloc[:, 2].values.astype(int)
samples_test = df_test_actual.iloc[:, 1].values 

In [51]:
class PromptFormatting(object):
    def __init__(self):
        self.INSTRUCTION = 'Select the sentiment category that best matches the opinion expressed in the review snippet.'
        self.CLASSES = ['negative', 'positive']
        self.CLASSES_FOR_MATCHING = [self.CLASSES, ['neg', 'pos'], ['1', '2']]
        self.CLASSES_TEXT = '''1. {}\n2. {}'''.format(self.CLASSES[0], self.CLASSES[1])

    def format_instruction(self, instruction):
        return '''{}\n{}\n'''.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        return '''review: {}\nthe review is '''.format(content)

prompt_formatting = PromptFormatting()
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

In [52]:
# Evaluate performance of model on sst2 test data
def dirichlet_to_prob(alpha):
    return alpha / alpha.sum(dim=1, keepdim=True) 


class DirichletDataset(Dataset):
    def __init__(self, samples, n_samples):
        self.samples = samples
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.samples[idx]

llm.model.eval()
test_dataset = DirichletDataset(samples_test, 872)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False) 

def get_test_alpha(test_dataloader, classifier):
    all_alpha = []

    with torch.no_grad():
        for batch_samples in test_dataloader:
            alpha_batch = classifier.soft_labels_batch(input_texts=batch_samples)
            all_alpha.append(alpha_batch)

    return torch.cat(all_alpha, dim=0) 

sst2_alpha_test = get_test_alpha(test_dataloader, classifier)
stu_probs = dirichlet_to_prob(sst2_alpha_test)

In [53]:
import evaluation  
stu_probs=stu_probs.cpu().numpy()
f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

Student f1-score: 0.9506771110804181, Student ECE: 0.016975531354546547


In [54]:
# Predictive entropy on sst2 test data
ent_sst2 = dirichlet_entropy(sst2_alpha_test)
print(ent_sst2)

tensor([0.0164, 0.0517, 0.0154, 0.0157, 0.0231, 0.0180, 0.2753, 0.3825, 0.0158,
        0.0355, 0.0156, 0.0451, 0.1752, 0.5100, 0.0200, 0.0157, 0.4975, 0.0159,
        0.0945, 0.0563, 0.5022, 0.1333, 0.1261, 0.0152, 0.0161, 0.1583, 0.0990,
        0.0397, 0.0338, 0.0192, 0.0167, 0.0422, 0.0156, 0.2825, 0.0335, 0.1814,
        0.0156, 0.3833, 0.0521, 0.0157, 0.0160, 0.0167, 0.0852, 0.0161, 0.0297,
        0.4942, 0.0300, 0.0178, 0.0167, 0.0335, 0.0232, 0.0162, 0.5056, 0.0647,
        0.0200, 0.0168, 0.0657, 0.3615, 0.0377, 0.0257, 0.0170, 0.1691, 0.0453,
        0.0166, 0.4944, 0.1473, 0.5023, 0.0154, 0.0184, 0.0193, 0.2788, 0.0158,
        0.0161, 0.4915, 0.0223, 0.0398, 0.0518, 0.0854, 0.2397, 0.0187, 0.0178,
        0.0464, 0.0653, 0.0161, 0.0170, 0.1716, 0.0356, 0.0625, 0.4656, 0.0178,
        0.2267, 0.1086, 0.5044, 0.4945, 0.0150, 0.1523, 0.0332, 0.0162, 0.0196,
        0.2057, 0.0176, 0.0575, 0.4990, 0.0165, 0.1202, 0.1625, 0.0155, 0.0383,
        0.0339, 0.0173, 0.0588, 0.0570, 

In [55]:
print(ent_sst2.mean())

tensor(0.1146, device='cuda:0')


In [56]:
# Load youtube comments dataset
df_train = pd.read_csv('youtube.csv')
n_train = 1100  
n_in_context = 5 
n_total_in_context = 9 * n_in_context  
n_val=100
df_train_actual = df_train.iloc[:n_train] 
df_in_context_base = df_train.iloc[n_train:n_train + n_total_in_context]
df_val = df_train.iloc[n_train + n_total_in_context:n_train+n_total_in_context+n_val]
df_test_actual = df_train.iloc[n_train+n_total_in_context+n_val:]  
gt_labels_train = df_train_actual.iloc[:, 4].values.astype(int) 
samples_train = df_train_actual.iloc[:, 3].values 
gt_labels_val = df_val.iloc[:, 4].values.astype(int) 
samples_val = df_val.iloc[:, 3].values 
gt_labels_test = df_test_actual.iloc[:, 4].values.astype(int)
samples_test = df_test_actual.iloc[:, 3].values 

In [57]:
class PromptFormatting(object):
    def __init__(self):
        self.INSTRUCTION = 'Judge whether the Youtube comment should be flagged as spam.'
        self.CLASSES = ['not spam', 'spam']
        self.CLASSES_FOR_MATCHING = [self.CLASSES, ['ham', 'spam'], ['0', '1']]
        self.CLASSES_TEXT = '''1. {}\n2. {}'''.format(self.CLASSES[0], self.CLASSES[1])

    def format_instruction(self, instruction):
        return '''{}\n{}\n'''.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        return '''comment: {}\nthe comment is '''.format(content)

prompt_formatting = PromptFormatting()
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

In [58]:
# Evaluate performance of model on youtube comments test data
def dirichlet_to_prob(alpha):
    return alpha / alpha.sum(dim=1, keepdim=True) 


class DirichletDataset(Dataset):
    def __init__(self, samples, n_samples):
        self.samples = samples
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.samples[idx]

llm.model.eval()
test_dataset = DirichletDataset(samples_test, 711)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False) 

def get_test_alpha(test_dataloader, classifier):
    all_alpha = []

    with torch.no_grad():
        for batch_samples in test_dataloader:
            alpha_batch = classifier.soft_labels_batch(input_texts=batch_samples)
            all_alpha.append(alpha_batch)

    return torch.cat(all_alpha, dim=0) 

youtube_alpha_test = get_test_alpha(test_dataloader, classifier)
stu_probs = dirichlet_to_prob(youtube_alpha_test)

In [59]:
import evaluation  
stu_probs=stu_probs.cpu().numpy()
f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

Student f1-score: 0.5224537771227076, Student ECE: 0.14621475338935852


In [60]:
# Predictive entropy on youtube comments test data
ent = dirichlet_entropy(youtube_alpha_test)
print(ent)

tensor([0.5304, 0.6640, 0.5778, 0.6527, 0.6513, 0.4010, 0.4197, 0.4753, 0.5635,
        0.5617, 0.4277, 0.4687, 0.6423, 0.6407, 0.4868, 0.5873, 0.6159, 0.5650,
        0.3918, 0.4491, 0.6012, 0.6071, 0.6593, 0.5052, 0.6575, 0.4285, 0.6543,
        0.5418, 0.6603, 0.6568, 0.4886, 0.4058, 0.5742, 0.4640, 0.6565, 0.5254,
        0.5735, 0.5169, 0.4534, 0.5153, 0.5023, 0.4488, 0.4186, 0.4121, 0.6528,
        0.4436, 0.6571, 0.4449, 0.3824, 0.5734, 0.3325, 0.3367, 0.5573, 0.4536,
        0.5345, 0.5636, 0.6524, 0.6364, 0.5444, 0.4319, 0.6348, 0.6579, 0.4182,
        0.4703, 0.5836, 0.4301, 0.6283, 0.4295, 0.6406, 0.4463, 0.4902, 0.5741,
        0.4492, 0.6432, 0.6519, 0.5506, 0.4316, 0.5154, 0.3295, 0.4721, 0.4323,
        0.4964, 0.6419, 0.4410, 0.6502, 0.6170, 0.4659, 0.6415, 0.6274, 0.3363,
        0.5445, 0.5906, 0.5664, 0.4536, 0.4660, 0.3626, 0.5363, 0.6487, 0.3917,
        0.5114, 0.5735, 0.4504, 0.4441, 0.3016, 0.4395, 0.6676, 0.3281, 0.5830,
        0.5075, 0.4550, 0.6255, 0.6217, 

In [61]:
print(ent.mean())

tensor(0.5143, device='cuda:0')
