In [None]:
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from llm_classifier_modified import LLMClassifier
from llm_model_modified import LLM

In [None]:
# **Load modified datasets**
df_train = pd.read_csv('train_modified.csv', header=None)
df_test = pd.read_csv('test_modified.csv', header=None)

# **Dataset parameters**
n_train = 50000       # rows used as the student's training set
n_in_context = 5      # in-context examples reserved per training row
# Size of the in-context pool. NOTE(review): assumes one set of n_in_context rows per
# training row — confirm. The previous len(df_train) * n_in_context overshot the whole
# file, so the in-context slice swallowed every remaining row and df_val came out empty.
n_total_in_context = n_train * n_in_context
n_test = 5000
n_val = 100

# **Split Data**
# train_modified.csv is cut into contiguous, non-overlapping row ranges:
#   first n_train rows -> train | next n_total_in_context rows -> in-context pool | next n_val rows -> validation
df_train_actual = df_train.iloc[:n_train]
df_in_context_base = df_train.iloc[n_train:n_train + n_total_in_context]
df_val = df_train.iloc[n_train + n_total_in_context:n_train + n_total_in_context + n_val]
df_test_actual = df_test.iloc[:n_test]

# Fail loudly instead of silently continuing with a truncated/empty validation split.
assert len(df_val) == n_val, (
    f"train_modified.csv has {len(df_train)} rows; the train/in-context/val split "
    f"needs {n_train + n_total_in_context + n_val}"
)

# **Extract Training Data** (column 0 -> integer label, column 2 -> review text; column 1 unused here)
gt_labels_train = df_train_actual.iloc[:, 0].values.astype(int)
samples_train = df_train_actual.iloc[:, 2].values
gt_labels_val = df_val.iloc[:, 0].values.astype(int)
samples_val = df_val.iloc[:, 2].values

# **Extract Test Data**
gt_labels_test = df_test_actual.iloc[:, 0].values.astype(int)
samples_test = df_test_actual.iloc[:, 2].values


In [None]:
# **Define Prompt Formatting Class**
class PromptFormatting:
    """Builds the prompt strings for two-class Amazon-review sentiment classification.

    Attributes (all set in ``__init__``):
        INSTRUCTION: default task instruction text.
        CLASSES: class names in order, ``['negative', 'positive']``.
        CLASSES_FOR_MATCHING: alternative spellings of the classes, each list in the
            same order as CLASSES; its first entry is the CLASSES list object itself.
            Presumably consumed by LLMClassifier when parsing model output — confirm.
        CLASSES_TEXT: the classes as a numbered list, one per line ("1. negative\\n2. positive").
    """

    def __init__(self):
        self.INSTRUCTION = 'classify the sentiment of the Amazon review below into one of the following classes:'
        self.CLASSES = ['negative', 'positive']
        self.CLASSES_FOR_MATCHING = [self.CLASSES, ['neg', 'pos'], ['1', '2']]
        numbered_lines = [f'{position}. {label}' for position, label in enumerate(self.CLASSES, start=1)]
        self.CLASSES_TEXT = '\n'.join(numbered_lines)

    def format_instruction(self, instruction):
        """Return ``instruction``, a newline, the numbered class list, and a trailing newline."""
        return f'{instruction}\n{self.CLASSES_TEXT}\n'

    def format_content(self, content):
        """Wrap one review so the prompt ends with 'the review is ' for the model to complete."""
        return f'review: {content}\nthe review is '

# **Load Model and Classifier**
# Student model: the project's LLM wrapper around Mistral-7B-Instruct-v0.3 (use_lora=True),
# combined with the prompt templates above into an LLMClassifier.
prompt_formatting = PromptFormatting()
llm = LLM(model_name="mistralai/Mistral-7B-Instruct-v0.3", use_lora=True)
classifier = LLMClassifier(prompt_formatting=prompt_formatting, model=llm)

In [None]:
def dirichlet_loss(alpha, probs, weights):
    """Negative weighted Dirichlet log-likelihood of teacher probabilities under Dir(alpha).

    Computes
        -mean[ lgamma(sum_0 alpha) - sum_0 lgamma(alpha)
               + sum_last( weights * sum_0 (alpha - 1) * log(probs + 1e-8) ) ]
    where ``sum_0`` reduces dimension 0 and ``sum_last`` reduces the last dimension.

    Args:
        alpha: Dirichlet concentrations from the student. Dimension 0 is reduced as the
            class axis. NOTE(review): this assumes a class-first layout — if alpha is
            (batch, classes) the sums run over the batch instead; confirm against
            LLMClassifier.soft_labels_batch.
        probs: teacher probabilities, broadcastable against ``alpha`` with the class axis
            at dim 0. A 1e-8 floor is added before the log to avoid log(0).
        weights: multiplied in after the dim-0 reduction and summed over the last axis
            (presumably one weight per prompt — TODO confirm).

    Returns:
        0-dim tensor: the negated mean of the log-likelihood after broadcasting.

    NOTE(review): only the (alpha - 1) * log p term is scaled by ``weights``; the
    normaliser lgamma(alpha_0) - sum lgamma(alpha_c) is added once. This equals a
    weighted sum of per-prompt Dirichlet log-likelihoods only if the weights sum to 1.
    """
    # Log of the Dirichlet normalising constant 1 / B(alpha), reduced over dim 0.
    concentration = alpha.sum(dim=0, keepdim=True)
    log_normaliser = torch.lgamma(concentration) - torch.lgamma(alpha).sum(dim=0, keepdim=True)

    # Data term: (alpha - 1) * log p, reduced over dim 0, then weighted and reduced over the last axis.
    log_p = (probs + 1e-8).log()
    per_column = ((alpha - 1) * log_p).sum(dim=0)
    weighted = (per_column * weights).sum(dim=-1)

    log_likelihood = log_normaliser + weighted
    return -log_likelihood.mean()

In [None]:
# **Training Function**
def train_student(input_texts, probs, weights, num_epochs=10, learning_rate=1e-4):
    """Fine-tune the global student ``llm`` towards teacher targets with ``dirichlet_loss``.

    Each epoch runs ONE forward pass of ``classifier.soft_labels_batch`` over all of
    ``input_texts`` and takes a single AdamW step (full-batch; no mini-batching).

    Args:
        input_texts: texts passed unchanged to ``classifier.soft_labels_batch``.
        probs: teacher probability targets for ``dirichlet_loss`` (tensor or tensor-like).
        weights: weights for ``dirichlet_loss`` (tensor or tensor-like).
        num_epochs: number of optimisation steps (one per epoch).
        learning_rate: AdamW learning rate.

    Side effects: puts ``llm.model`` in train mode, updates its trainable parameters
    in place, and prints the loss after every epoch. Returns None.
    """
    # Register only parameters that can receive gradients; with use_lora=True the base
    # weights are presumably frozen and would otherwise be handed to the optimiser for nothing.
    trainable_params = [param for param in llm.model.parameters() if param.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=learning_rate)
    llm.model.train()

    for epoch in range(num_epochs):
        optimizer.zero_grad()

        # Dirichlet concentrations from the student; the second return value is unused here.
        alpha, _ = classifier.soft_labels_batch(input_texts=input_texts)

        # Targets loaded via torch.load keep the device they were saved on; align them with
        # the student's output so the loss does not fail on a device mismatch. After the first
        # epoch they already live on alpha's device and as_tensor returns them unchanged.
        probs = torch.as_tensor(probs, device=alpha.device)
        weights = torch.as_tensor(weights, device=alpha.device)

        loss = dirichlet_loss(alpha, probs, weights)

        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

# Precomputed teacher targets (by file name: prompt-wise teacher probabilities and prompt weights).
# torch.load unpickles its input — only point it at trusted files.
probs, weights = [torch.load(path) for path in ("teacher_probs_promptwise.pt", "prompt_weights.pt")]
train_student(input_texts=samples_train, probs=probs, weights=weights)
