## Student Model(with softmax output) training and evaluation on Amazon reviews polarity dataset

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from llm_classifier_modified import LLMClassifier
from llm_model_modified1 import LLM

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load Amazon reviews polarity train and test data
df_train = pd.read_csv('train_amazon.csv', header=None)
df_test = pd.read_csv('test_amazon.csv', header=None)

n_train = 10000
n_in_context = 5  
n_total_in_context = len(df_train) * n_in_context  
n_test = 5000
n_val = 100

df_train_actual = df_train.iloc[:n_train] 
df_in_context_base = df_train.iloc[n_train:n_train + n_total_in_context]
df_val = df_train.iloc[n_train + n_total_in_context:n_train + n_total_in_context + n_val]
df_test_actual = df_test.iloc[:n_test]  

gt_labels_train = df_train_actual.iloc[:, 0].values.astype(int) 
samples_train = df_train_actual.iloc[:, 2].values 
gt_labels_val = df_val.iloc[:, 0].values.astype(int) 
samples_val = df_val.iloc[:, 2].values 

gt_labels_test = df_test_actual.iloc[:, 0].values.astype(int)
samples_test = df_test_actual.iloc[:, 2].values  


In [4]:
# Define a prompt formatting class for sentiment classification and initializes an LLM-based classifier
class PromptFormatting(object):
    def __init__(self):
        # Best instruction from BayesPE teacher i.e. instruction with highest weight
        self.INSTRUCTION = 'classify the sentiment of the Amazon review below into one of the following classes:'
        self.CLASSES = ['negative', 'positive']
        self.CLASSES_FOR_MATCHING = [self.CLASSES, ['neg', 'pos'], ['1', '2']]
        self.CLASSES_TEXT = '''1. {}\n2. {}'''.format(self.CLASSES[0], self.CLASSES[1])

    def format_instruction(self, instruction):
        return '''{}\n{}\n'''.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        return '''review: {}\nthe review is '''.format(content)

llm = LLM(model_name="mistralai/Mistral-7B-Instruct-v0.3", use_reduced_precision=True,use_lora=True)
prompt_formatting = PromptFormatting()
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

Loading checkpoint shards: 100%|██████████████████| 3/3 [00:02<00:00,  1.09it/s]


trainable params: 7,110,656 || all params: 7,255,134,208 || trainable%: 0.0980


In [5]:
# Load teacher predictions and weights
probs = torch.load("amazon_probs.pt", weights_only=False)
weights = torch.load("amazon_prompt_weights.pt", weights_only=False)
if isinstance(probs, np.ndarray):
    probs = torch.tensor(probs, dtype=torch.float32, device=llm.device)
if isinstance(weights, np.ndarray):
    weights = torch.tensor(weights, dtype=torch.float32, device=llm.device)

In [6]:
# Compute KL divergence loss
def dirichlet_loss(student_probs, probs):
    kl_loss = F.kl_div(student_probs.log(), probs, reduction='batchmean')
    return kl_loss

In [7]:
from torch.utils.data import Dataset, DataLoader

class DirichletDataset(Dataset):
    def __init__(self, samples, num_samples):
        self.samples = samples
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.samples[idx], idx 

In [8]:
# Train student model with teacher predictions
def train_student(samples_train, probs, weights, num_epochs=10, learning_rate=1e-5, batch_size=32):
    dataset = DirichletDataset(samples_train, len(samples_train))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, llm.model.parameters()), lr=learning_rate)
    llm.model.train()  

    for epoch in range(num_epochs):
        total_loss = 0

        for batch_idx, (batch_samples, batch_indices) in enumerate(dataloader, start=1):
            batch_indices = batch_indices.to(llm.device)

            batch_probs = probs[batch_indices] 
            weights = weights.view(-1)

                  
            batch_probs = (batch_probs * weights) 
            batch_probs = batch_probs.sum(dim=2) 
            optimizer.zero_grad()

            student_probs = classifier.soft_labels_batch(input_texts=batch_samples)
            loss = dirichlet_loss(student_probs, batch_probs)
          
            loss.backward()

            optimizer.step()
            total_loss += loss.item()

            if batch_idx % 1000 == 0:
                    print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(dataloader)}")

        torch.cuda.empty_cache()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss}")

train_student(samples_train, probs, weights, batch_size=16)

Epoch 1/10, Loss: 14.995201678422745
Epoch 2/10, Loss: 5.594062375283102
Epoch 3/10, Loss: 3.152453308532131
Epoch 4/10, Loss: 2.143950943296659
Epoch 5/10, Loss: 2.2521223495932645
Epoch 6/10, Loss: 2.2971361029849504
Epoch 7/10, Loss: 1.934872161684325
Epoch 8/10, Loss: 1.2925368519863696
Epoch 9/10, Loss: 1.029416300902085
Epoch 10/10, Loss: 0.947909543679998


In [9]:
# Evaluate performance of model on amazon reviews polarity test data
class DirichletDataset(Dataset):
    def __init__(self, samples, n_samples):
        self.samples = samples
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.samples[idx]

llm.model.eval()
test_dataset = DirichletDataset(samples_test, n_test)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False) 

def get_test_alpha(test_dataloader, classifier):
    all_alpha = []

    with torch.no_grad():
        for batch_samples in test_dataloader:
            alpha_batch = classifier.soft_labels_batch(input_texts=batch_samples)
            all_alpha.append(alpha_batch)

    return torch.cat(all_alpha, dim=0) 



stu_probs = get_test_alpha(test_dataloader, classifier)

In [10]:
import evaluation  
stu_probs=stu_probs.cpu().numpy()
f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

Student f1-score: 0.9589682890525113, Student ECE: 0.01881118305027485


## Evaluation on out-of-distribution data

In [9]:
# Load Yahoo Answers dataset
df_train = pd.read_csv('train_yahoo.csv', header=None)
df_test = pd.read_csv('test_yahoo.csv', header=None)

n_train = 10000  
n_in_context = 5  
n_val = 100
n_test = 5000

# Split data
df_train_actual = df_train.iloc[:n_train]
df_test_actual = df_test.iloc[:n_test]

# Format function for prompts
def format_prompt(q1, q2, a):
    return "Question: " + q1.astype(str) + " " + q2.astype(str) + "\nAnswer: " + a.astype(str)

# Extract training data
gt_labels_train = df_train_actual.iloc[:, 0].values.astype(int)
samples_train = format_prompt(df_train_actual.iloc[:, 1], df_train_actual.iloc[:, 2], df_train_actual.iloc[:, 3]).values

gt_labels_test = df_test_actual.iloc[:, 0].values.astype(int)
samples_test = format_prompt(df_test_actual.iloc[:, 1], df_test_actual.iloc[:, 2], df_test_actual.iloc[:, 3]).values

In [10]:
class PromptFormatting(object):
    def __init__(self):
        self.INSTRUCTION = 'Identify the topic that the following question and answer belong to:'
        self.CLASSES = [
            ' ',
    'Society & Culture',
    'Science & Mathematics',
    'Health',
    'Education & Reference',
    'Computers & Internet',
    'Sports',
    'Business & Finance',
    'Entertainment & Music',
    'Family & Relationships',
    'Politics & Government'
]
        self.CLASSES_FOR_MATCHING = [self.CLASSES]
        self.CLASSES_TEXT = '''1. {}\n2. {}\n3. {}\n4. {}\n5. {}\n6. {}\n7. {}\n8. {}\n9. {}\n10. {}'''.format(self.CLASSES[1], self.CLASSES[2], self.CLASSES[3], self.CLASSES[4], self.CLASSES[5], self.CLASSES[6], self.CLASSES[7], self.CLASSES[8], self.CLASSES[9], self.CLASSES[10])

    def format_instruction(self, instruction):
        return '''{}\n{}\n'''.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        return '''{}\nthe topic is '''.format(content)

prompt_formatting = PromptFormatting()
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

In [11]:
# Evaluate performance of model on Yahoo answers test data
class DirichletDataset(Dataset):
    def __init__(self, samples, n_samples):
        self.samples = samples
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.samples[idx]

llm.model.eval()
test_dataset = DirichletDataset(samples_test, n_test)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False) 

def get_test_alpha(test_dataloader, classifier):
    all_alpha = []

    with torch.no_grad():
        for batch_samples in test_dataloader:
            alpha_batch = classifier.soft_labels_batch(input_texts=batch_samples)
            all_alpha.append(alpha_batch)

    return torch.cat(all_alpha, dim=0) 



stu_probs = get_test_alpha(test_dataloader, classifier)

In [12]:
import evaluation  
stu_probs=stu_probs.cpu().numpy()
f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

Student f1-score: 0.5477921854622683, Student ECE: 0.23658406734466553


In [13]:
# Compute the predictive entropy on yahoo answers test data
def entropy_numpy(probs: np.ndarray) -> np.ndarray:
    """
    Compute entropy from probabilities for each row (example) in a NumPy array.
    `probs`: shape [num_examples, num_classes]
    Returns entropy: shape [num_examples]
    """
    return entropy(stu_probs, axis=1)  # computes entropy along the class dimension
    
ent_yahoo = entropy_numpy(stu_probs)
print(ent_yahoo)


[1.7640457  0.11008723 0.3987073  ... 1.3392062  0.31525946 0.01033301]


In [14]:
# Mean predictive entropy
print(ent_yahoo.mean())

0.6054618


In [15]:
# Load sst2 dataset
df_train = pd.read_csv('train_sst2.csv')
df_test = pd.read_csv('test_sst2.csv')
n_train = 10000  
n_in_context = 5  
n_total_in_context = 9 * n_in_context  
n_val=100
df_train_actual = df_train.iloc[:n_train] 
df_test_actual = df_test.iloc[:]  
gt_labels_train = df_train_actual.iloc[:, 2].values.astype(int) 
samples_train = df_train_actual.iloc[:, 1].values 
gt_labels_test = df_test_actual.iloc[:, 2].values.astype(int)
samples_test = df_test_actual.iloc[:, 1].values 

In [16]:
class PromptFormatting(object):
    def __init__(self):
        self.INSTRUCTION = 'Select the sentiment category that best matches the opinion expressed in the review snippet.'
        self.CLASSES = ['negative', 'positive']
        self.CLASSES_FOR_MATCHING = [self.CLASSES, ['neg', 'pos'], ['1', '2']]
        self.CLASSES_TEXT = '''1. {}\n2. {}'''.format(self.CLASSES[0], self.CLASSES[1])

    def format_instruction(self, instruction):
        return '''{}\n{}\n'''.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        return '''review: {}\nthe review is '''.format(content)

prompt_formatting = PromptFormatting()
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

In [17]:
# Evaluate performance of model on sst2 test data
class DirichletDataset(Dataset):
    def __init__(self, samples, n_samples):
        self.samples = samples
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.samples[idx]

llm.model.eval()
test_dataset = DirichletDataset(samples_test, 872)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False) 

def get_test_alpha(test_dataloader, classifier):
    all_alpha = []

    with torch.no_grad():
        for batch_samples in test_dataloader:
            alpha_batch = classifier.soft_labels_batch(input_texts=batch_samples)
            all_alpha.append(alpha_batch)

    return torch.cat(all_alpha, dim=0) 

stu_probs = get_test_alpha(test_dataloader, classifier)

In [18]:
import evaluation  
stu_probs=stu_probs.cpu().numpy()
f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

Student f1-score: 0.9483830232236934, Student ECE: 0.030479537323117256


In [19]:
# Predictive entropy on sst2 test data
ent_sst2 = entropy_numpy(stu_probs)
print(ent_sst2)

[8.50894838e-04 3.96795012e-03 8.55409773e-04 1.11426064e-03
 1.85965863e-03 4.10740450e-03 6.62713200e-02 5.76008439e-01
 8.41927191e-04 4.35562851e-03 9.44217958e-04 4.40091500e-03
 2.62905564e-02 4.57564890e-01 1.94608793e-03 8.31549056e-04
 5.71298599e-01 5.27614611e-04 9.60300211e-03 3.16841295e-03
 3.10213596e-01 1.01728719e-02 1.39706712e-02 6.49171066e-04
 1.15419098e-03 1.43031366e-02 4.69876640e-03 4.89693973e-03
 2.43290840e-03 1.92579790e-03 1.72192173e-03 3.31451744e-03
 8.04147101e-04 4.00499851e-02 3.00773047e-03 4.10983711e-02
 1.58589741e-03 2.85308272e-01 2.30503222e-03 9.13104101e-04
 7.43988960e-04 1.34720095e-03 1.60268337e-01 8.75237340e-04
 2.29302887e-03 4.61061954e-01 3.36075597e-03 3.49129667e-03
 7.58618349e-04 2.54554977e-03 2.51035555e-03 1.98733737e-03
 2.40848988e-01 1.27145126e-02 1.80523982e-03 1.00607274e-03
 9.15724318e-03 2.39584267e-01 4.55524167e-03 2.42446247e-03
 1.43257820e-03 1.83217116e-02 5.42911608e-03 1.22100220e-03
 2.01691031e-01 1.332820

In [20]:
print(ent_sst2.mean())

0.059649516


In [21]:
# Load youtube comments dataset
df_train = pd.read_csv('youtube.csv')
n_train = 1100  
n_in_context = 5 

n_total_in_context = 9 * n_in_context  
n_val=100
df_train_actual = df_train.iloc[:n_train] 
df_in_context_base = df_train.iloc[n_train:n_train + n_total_in_context]
df_val = df_train.iloc[n_train + n_total_in_context:n_train+n_total_in_context+n_val]
df_test_actual = df_train.iloc[n_train+n_total_in_context+n_val:]  

gt_labels_train = df_train_actual.iloc[:, 4].values.astype(int) 
samples_train = df_train_actual.iloc[:, 3].values 
gt_labels_val = df_val.iloc[:, 4].values.astype(int) 
samples_val = df_val.iloc[:, 3].values 
gt_labels_test = df_test_actual.iloc[:, 4].values.astype(int)
samples_test = df_test_actual.iloc[:, 3].values 

In [22]:
class PromptFormatting(object):
    def __init__(self):
        self.INSTRUCTION = 'Judge whether the Youtube comment should be flagged as spam.'
        self.CLASSES = ['not spam', 'spam']
        self.CLASSES_FOR_MATCHING = [self.CLASSES, ['ham', 'spam'], ['0', '1']]
        self.CLASSES_TEXT = '''1. {}\n2. {}'''.format(self.CLASSES[0], self.CLASSES[1])

    def format_instruction(self, instruction):
        return '''{}\n{}\n'''.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        return '''comment: {}\nthe comment is '''.format(content)

prompt_formatting = PromptFormatting()
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

In [23]:
# Evaluate performance of model on youtube comments test data
class DirichletDataset(Dataset):
    def __init__(self, samples, n_samples):
        self.samples = samples
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.samples[idx]

llm.model.eval()
test_dataset = DirichletDataset(samples_test, 711)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False) 

def get_test_alpha(test_dataloader, classifier):
    all_alpha = []

    with torch.no_grad():
        for batch_samples in test_dataloader:
            alpha_batch = classifier.soft_labels_batch(input_texts=batch_samples)
            all_alpha.append(alpha_batch)

    return torch.cat(all_alpha, dim=0) 



stu_probs = get_test_alpha(test_dataloader, classifier)

In [24]:
import evaluation  
stu_probs=stu_probs.cpu().numpy()
f1_score = evaluation.compute_metric(gt_labels_test, stu_probs, metric='f1')
ece = evaluation.compute_metric(gt_labels_test, stu_probs, metric='ece')
print('Student f1-score: {}, Student ECE: {}'.format(f1_score, ece))

Student f1-score: 0.6526290734770916, Student ECE: 0.15583941340446472


In [25]:
# Predictive entropy on youtube comments test data
ent = entropy_numpy(stu_probs)
print(ent)

[0.68108755 0.4733035  0.65510744 0.4924947  0.4924947  0.14531085
 0.0855129  0.18770973 0.598686   0.598686   0.28673667 0.5370392
 0.3835734  0.6804848  0.5336786  0.6087053  0.69040054 0.598686
 0.23086795 0.42448777 0.6561176  0.6833564  0.31778398 0.10269673
 0.2968649  0.5633357  0.27128315 0.52521867 0.31778398 0.31322876
 0.10054242 0.10176843 0.69265914 0.23395337 0.29979986 0.5286122
 0.65510744 0.36533386 0.19139497 0.36369503 0.4262181  0.30127415
 0.4418527  0.09092619 0.38861087 0.2838846  0.4021613  0.30127415
 0.17799634 0.67138517 0.01731138 0.07965304 0.6822505  0.41586155
 0.37856194 0.5387141  0.5166848  0.4924947  0.4523219  0.29249662
 0.33167794 0.54703087 0.16490765 0.27683643 0.52181333 0.1983931
 0.41758323 0.05518102 0.46630877 0.50117457 0.43663013 0.67859197
 0.11686215 0.4942335  0.54703087 0.52010643 0.10908221 0.18562979
 0.12111352 0.17550981 0.3208422  0.21415447 0.65305144 0.3254612
 0.6765737  0.6432272  0.0948991  0.633747   0.6804848  0.0683818
 0

In [26]:
print(ent.mean())

0.38787
