In [None]:
import os

# Restrict this process to a single physical GPU. These variables must be set
# before torch initialises CUDA, which is why this lives in the first cell.
# CUDA_DEVICE_ORDER=PCI_BUS_ID makes the index below follow PCI bus order
# (the ordering nvidia-smi reports) rather than CUDA's "fastest first" default.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from llm_classifier_modified import LLMClassifier
from llm_model_modified import LLM

In [None]:
# Split the raw CSVs into disjoint, contiguous row ranges of train_modified.csv:
#   [0, n_train)                           -> student training set
#   [n_train, val_start)                   -> pool of in-context examples
#   [val_start, val_start + n_val)         -> validation set
# and take the first n_test rows of test_modified.csv as the test set.
# Column 0 is used as the integer label and column 2 as the review text.
df_train = pd.read_csv('train_modified.csv', header=None)
df_test = pd.read_csv('test_modified.csv', header=None)

n_train = 50000
n_in_context = 5   # presumably in-context examples reserved per training sample
# Previously `len(df_train) * n_in_context`, which is larger than the whole CSV:
# the in-context slice swallowed every remaining row and df_val was always empty.
n_total_in_context = n_train * n_in_context
n_test = 5000
n_val = 100

val_start = n_train + n_total_in_context
val_end = val_start + n_val

# Fail loudly instead of silently working with truncated/empty slices.
assert len(df_train) >= val_end, (
    f"train_modified.csv has {len(df_train)} rows; the requested split needs {val_end}"
)

df_train_actual = df_train.iloc[:n_train]
df_in_context_base = df_train.iloc[n_train:val_start]
df_val = df_train.iloc[val_start:val_end]
df_test_actual = df_test.iloc[:n_test]

gt_labels_train = df_train_actual.iloc[:, 0].values.astype(int)
samples_train = df_train_actual.iloc[:, 2].values
gt_labels_val = df_val.iloc[:, 0].values.astype(int)
samples_val = df_val.iloc[:, 2].values

gt_labels_test = df_test_actual.iloc[:, 0].values.astype(int)
samples_test = df_test_actual.iloc[:, 2].values


In [None]:
class PromptFormatting:
    """Prompt pieces for binary Amazon-review sentiment classification.

    Attributes:
        INSTRUCTION: task description placed before the numbered class list.
        CLASSES: canonical label names; label i is shown as option i+1.
        CLASSES_FOR_MATCHING: alternative spellings per label (full name,
            abbreviation, option number) — presumably used by the classifier
            to map generated text back to a class; confirm in LLMClassifier.
        CLASSES_TEXT: the numbered class list, one option per line.
    """

    def __init__(self):
        self.INSTRUCTION = 'classify the sentiment of the Amazon review below into one of the following classes:'
        self.CLASSES = ['negative', 'positive']
        # First entry aliases CLASSES itself (same list object).
        self.CLASSES_FOR_MATCHING = [self.CLASSES, ['neg', 'pos'], ['1', '2']]
        self.CLASSES_TEXT = '\n'.join(
            f'{number}. {label}' for number, label in enumerate(self.CLASSES, start=1)
        )

    def format_instruction(self, instruction):
        """Return ``instruction``, a newline, the numbered class list, and a trailing newline."""
        return f'{instruction}\n{self.CLASSES_TEXT}\n'

    def format_content(self, content):
        """Wrap a review text; the result ends with 'the review is ' for the model to complete."""
        return f'review: {content}\nthe review is '

MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"

# Student model. `use_reduced_precision` / `use_lora` are interpreted by the
# local llm_model_modified.LLM wrapper (their exact effect is not visible here).
llm = LLM(
    model_name=MODEL_NAME,
    use_reduced_precision=True,
    use_lora=True,
)
prompt_formatting = PromptFormatting()
classifier = LLMClassifier(model=llm, prompt_formatting=prompt_formatting)

In [None]:
import torch

def dirichlet_loss(alpha, probs, weights, eps=1e-8):
    """Negative weighted Dirichlet log-likelihood of teacher probability vectors.

    ``alpha`` holds the student's Dirichlet concentration parameters. Each
    teacher probability vector in ``probs`` (one per prompt) contributes its
    log-kernel ``sum_c (alpha_c - 1) * log p_c`` scaled by ``weights``; the
    log normalising constant ``lgamma(alpha_0) - sum_c lgamma(alpha_c)`` is
    added once per sample.

    Args:
        alpha: (batch, classes) positive concentrations.
        probs: teacher probabilities, assumed (batch, classes, prompts) — TODO
            confirm against the saved teacher tensor.
        weights: prompt weights broadcastable to (batch, prompts), e.g.
            (prompts,) or (batch, prompts).
        eps: added inside the log to avoid log(0).

    Returns:
        Scalar loss: the negative log-likelihood averaged over the batch.

    NOTE(review): a proper weighted likelihood over several observations scales
    the normaliser by sum(weights); adding it once (as here) is only equivalent
    when the weights sum to 1 per sample — confirm how prompt_weights.pt is built.
    """
    # Keep per-sample terms 1-D: the old keepdim=True (batch, 1) tensors added to
    # the (batch,) prompt sum broadcast into a (batch, batch) matrix.
    alpha_0 = alpha.sum(dim=1)                                          # (batch,)
    log_norm = torch.lgamma(alpha_0) - torch.lgamma(alpha).sum(dim=1)  # (batch,)

    # (batch, classes, 1) so alpha lines up with the prompt axis of probs.
    log_kernel = (alpha.unsqueeze(-1) - 1) * torch.log(probs + eps)
    per_prompt = log_kernel.sum(dim=1)                                 # (batch, prompts)
    prompt_sum = (per_prompt * weights).sum(dim=-1)                    # (batch,)

    return -(log_norm + prompt_sum).mean()


In [None]:
# Load teacher predictions and prompt weights produced by an earlier run.
# NOTE(review): weights_only=False unpickles arbitrary Python objects, which can
# execute code — only load these files from a trusted source.
# map_location="cpu" avoids restoring tensors onto whatever GPU they were saved from.
probs = torch.load("teacher_probs_promptwise.pt", map_location="cpu", weights_only=False)
weights = torch.load("prompt_weights.pt", map_location="cpu", weights_only=False)

# Normalise to float32 tensors on the model's device whether the files held numpy
# arrays or torch tensors. Previously only numpy arrays were converted, so saved
# tensors kept their original device/dtype and could mismatch `alpha` in the loss.
probs = torch.as_tensor(probs, dtype=torch.float32, device=llm.device)
weights = torch.as_tensor(weights, dtype=torch.float32, device=llm.device)

In [None]:
import torch.optim as optim

def _report_grad_status(model):
    """Print whether each parameter of ``model`` is frozen, trainable without a
    gradient, or trainable with a populated ``.grad`` (call after backward())."""
    print(" Checking Gradients AFTER Backpropagation:")
    for name, param in model.named_parameters():
        if not param.requires_grad:
            print(f"{name} is detached (requires_grad=False)")
        elif param.grad is None:
            print(f"{name} is trainable but has no gradient (not updated)")
        else:
            print(f"{name} is trainable with {param.numel()} params")


def train_student(samples_train, probs, weights, num_epochs=10, learning_rate=1e-4,
                  batch_size=8, shuffle=True, seed=0, debug_grads=False):
    """Distil the teacher's prompt-wise predictions into the student ``llm.model``.

    Uses AdamW on the parameters with ``requires_grad=True``. For each minibatch,
    the Dirichlet concentrations returned by ``classifier.soft_labels_batch`` are
    scored with ``dirichlet_loss`` against the matching rows of ``probs``.

    Args:
        samples_train: sequence of review texts. Row i is assumed to correspond to
            row i of ``probs`` (and of ``weights`` when those are per-sample).
        probs: teacher probability tensor whose first axis indexes samples.
        weights: prompt weights. If the tensor has >= 2 dims and its first axis has
            one row per sample, it is sliced with the batch; otherwise it is
            broadcast unchanged (e.g. 1-D per-prompt weights).
        num_epochs: full passes over ``samples_train``.
        learning_rate: AdamW learning rate.
        batch_size: samples per forward/backward pass; ``None`` means the whole
            set in one pass. The old full-set pass sent all training reviews
            through the 7B model at once, which cannot fit in GPU memory.
        shuffle: reshuffle the sample order every epoch.
        seed: seed for the shuffling RNG, for reproducible runs.
        debug_grads: print per-parameter gradient status after the first
            backward pass of each epoch.

    Returns:
        List with the mean minibatch loss of each epoch.

    Raises:
        ValueError: empty ``samples_train``, ``probs`` not aligned with it, or a
            non-positive ``batch_size``.
    """
    n_samples = len(samples_train)
    if n_samples == 0:
        raise ValueError("samples_train is empty")
    if probs.shape[0] != n_samples:
        raise ValueError(
            f"probs has {probs.shape[0]} rows but samples_train has {n_samples}; "
            "teacher outputs must be aligned row-for-row with the training texts"
        )
    step = n_samples if batch_size is None else int(batch_size)
    if step < 1:
        raise ValueError(f"batch_size must be >= 1 or None, got {batch_size}")

    # Object array so fancy indexing works for lists and ndarrays alike.
    texts = np.asarray(samples_train, dtype=object)
    per_sample_weights = weights.dim() >= 2 and weights.shape[0] == n_samples
    rng = np.random.default_rng(seed)

    optimizer = optim.AdamW(
        [p for p in llm.model.parameters() if p.requires_grad], lr=learning_rate
    )
    llm.model.train()

    epoch_losses = []
    for epoch in range(num_epochs):
        order = rng.permutation(n_samples) if shuffle else np.arange(n_samples)
        loss_sum, n_batches = 0.0, 0

        for start in range(0, n_samples, step):
            idx = order[start:start + step]
            batch_probs = probs[torch.as_tensor(idx, dtype=torch.long, device=probs.device)]
            batch_weights = (
                weights[torch.as_tensor(idx, dtype=torch.long, device=weights.device)]
                if per_sample_weights else weights
            )

            optimizer.zero_grad(set_to_none=True)
            alpha = classifier.soft_labels_batch(input_texts=texts[idx])
            loss = dirichlet_loss(alpha, batch_probs, batch_weights)
            assert loss.requires_grad, "Loss does not require gradients!"
            loss.backward()

            # Inspect only after backward(): zero_grad() has just cleared .grad,
            # so a pre-backward check can only ever report "None".
            if debug_grads and n_batches == 0:
                print(f"\n[Epoch {epoch+1}]")
                _report_grad_status(llm.model)

            optimizer.step()
            loss_sum += loss.item()
            n_batches += 1

        torch.cuda.empty_cache()
        epoch_loss = loss_sum / n_batches
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

    return epoch_losses


epoch_losses = train_student(samples_train, probs, weights)
