In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import sys
import os
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from bpe import BayesPE 
from llm_model import LLM
import evaluation
import constants

In [3]:
# Define task instructions
instructions = [
    'classify the question and answer below into one of the following topics:',
    'Assign a topic label to the following question and answer from the list provided:',
    'Determine which topic best fits the question and answer shown below:',
    'Categorize the following Q&A under one of these topics:',
    'Select the most appropriate topic for the question and answer pair below:',
    'Choose the correct topic category for the given question and answer:',
    'Identify the topic that the following question and answer belong to:',
    'Match the question and answer below to the relevant topic:',
    'Label the question and answer below with the most fitting topic from the list:'
]

In [4]:
# Load Yahoo Answers train and test data
df_train = pd.read_csv('train_yahoo.csv', header=None)
df_test = pd.read_csv('test_yahoo.csv', header=None)

n_train = 50000  
n_in_context = 5  
n_val = 100
n_test = 5000
n_total_in_context = len(instructions) * n_in_context
df_train_actual = df_train.iloc[:n_train]
df_in_context_base = df_train.iloc[n_train:n_train + n_total_in_context]
df_val = df_train.iloc[n_train + n_total_in_context:n_train + n_total_in_context + n_val]
df_test_actual = df_test.iloc[:n_test]
def format_prompt(q1, q2, a):
    return "Question: " + q1.astype(str) + " " + q2.astype(str) + "\nAnswer: " + a.astype(str)

gt_labels_train = df_train_actual.iloc[:, 0].values.astype(int)
samples_train = format_prompt(df_train_actual.iloc[:, 1], df_train_actual.iloc[:, 2], df_train_actual.iloc[:, 3]).values
gt_labels_val = df_val.iloc[:, 0].values.astype(int)
samples_val = format_prompt(df_val.iloc[:, 1], df_val.iloc[:, 2], df_val.iloc[:, 3]).values
gt_labels_test = df_test_actual.iloc[:, 0].values.astype(int)
samples_test = format_prompt(df_test_actual.iloc[:, 1], df_test_actual.iloc[:, 2], df_test_actual.iloc[:, 3]).values

for i in range(len(instructions)):
    start_idx = i * n_in_context
    end_idx = (i + 1) * n_in_context
    df_in_context = df_in_context_base.iloc[start_idx:end_idx]

    samples_in_context_i = format_prompt(df_in_context.iloc[:, 1], df_in_context.iloc[:, 2], df_in_context.iloc[:, 3]).values
    gt_labels_in_context_i = df_in_context.iloc[:, 0].values.astype(int)

    if i == 0:
        samples_in_context = np.expand_dims(samples_in_context_i, axis=1)
        gt_labels_in_context = np.expand_dims(gt_labels_in_context_i, axis=1)
    else:
        samples_in_context = np.concatenate((samples_in_context, np.expand_dims(samples_in_context_i, axis=1)), axis=1)
        gt_labels_in_context = np.concatenate((gt_labels_in_context, np.expand_dims(gt_labels_in_context_i, axis=1)), axis=1)


In [5]:
# Define a prompt formatting class for topic classification and initializes an LLM-based classifier
class PromptFormatting(object):
    def __init__(self):
        self.INSTRUCTION = 'classify the question and answer below into one of the following topics:'
        self.CLASSES = [
    'Society & Culture',
    'Science & Mathematics',
    'Health',
    'Education & Reference',
    'Computers & Internet',
    'Sports',
    'Business & Finance',
    'Entertainment & Music',
    'Family & Relationships',
    'Politics & Government'
]
        self.CLASSES_FOR_MATCHING = [self.CLASSES]
        self.CLASSES_TEXT = '''1. {}\n2. {}\n3. {}\n4. {}\n5. {}\n6. {}\n7. {}\n8. {}\n9. {}\n10. {}'''.format(self.CLASSES[0],self.CLASSES[1], self.CLASSES[2], self.CLASSES[3], self.CLASSES[4], self.CLASSES[5], self.CLASSES[6], self.CLASSES[7], self.CLASSES[8], self.CLASSES[9])
    def format_instruction(self, instruction):
        return '''{}\n{}\n'''.format(instruction, self.CLASSES_TEXT)

    def format_content(self, content):
        return '''{}\nthe topic is '''.format(content)

prompt_formatting = PromptFormatting()



# Initialize BayesPE (Teacher Model)
bayespe_classifier = BayesPE(
    model_name="mistralai/Mistral-7B-Instruct-v0.3", 
    prompt_formatting=prompt_formatting,
    instructions=instructions, 
    few_shot_texts_sets=samples_in_context, 
    few_shot_labels_sets=gt_labels_in_context, 
    use_reduced_precision=True
)

# Print example prompt
bayespe_classifier.print_prompt_example()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

EXAMPLE 1:
classify the question and answer below into one of the following topics:
1. Society & Culture
2. Science & Mathematics
3. Health
4. Education & Reference
5. Computers & Internet
6. Sports
7. Business & Finance
8. Entertainment & Music
9. Family & Relationships
10. Politics & Government

Question: when you talk about the volume of a gas are you refering to the volume of the molecules themselves? explain?
Answer: No, the volume refers to the total space in which those molecules are found moving around (should be the same as the volume of the container). In any case, atoms and molecules are pretty much all empty space themselves - most of the mass is concentrated in the nucleus, but the electron cloud takes up a lot more space.
the topic is Science & Mathematics

EXAMPLE 2:
classify the question and answer below into one of the following topics:
1. Society & Culture
2. Science & Mathematics
3. Health
4. Education & Reference
5. Computers & Internet
6. Sports
7. Business & Finan

In [6]:
# Optimize prompt weights
bayespe_classifier.optimise_weights(samples_val, gt_labels_val)

inference for promt 1 out of 9


100%|█████████████████████████████████████████| 100/100 [00:21<00:00,  4.72it/s]


inference for promt 2 out of 9


100%|█████████████████████████████████████████| 100/100 [00:21<00:00,  4.67it/s]


inference for promt 3 out of 9


100%|█████████████████████████████████████████| 100/100 [00:20<00:00,  4.92it/s]


inference for promt 4 out of 9


100%|█████████████████████████████████████████| 100/100 [00:23<00:00,  4.33it/s]


inference for promt 5 out of 9


100%|█████████████████████████████████████████| 100/100 [00:18<00:00,  5.40it/s]


inference for promt 6 out of 9


100%|█████████████████████████████████████████| 100/100 [00:19<00:00,  5.01it/s]


inference for promt 7 out of 9


100%|█████████████████████████████████████████| 100/100 [00:20<00:00,  4.82it/s]


inference for promt 8 out of 9


100%|█████████████████████████████████████████| 100/100 [00:23<00:00,  4.25it/s]


inference for promt 9 out of 9


100%|█████████████████████████████████████████| 100/100 [00:26<00:00,  3.75it/s]


iteration 0, loss: 157.72197375504285


array([0.20779699, 0.12849686, 0.06699287, 0.07335307, 0.0165612 ,
       0.0174329 , 0.3899038 , 0.0432402 , 0.05622215], dtype=float32)

In [7]:
# Get prompt weights and prompt wise probabilities on train data
_,probs,weights = bayespe_classifier.forward(samples_train, n_forward_passes=9)

inference for promt 1 out of 9


100%|███████████████████████████████████| 50000/50000 [2:34:49<00:00,  5.38it/s]


inference for promt 2 out of 9


100%|███████████████████████████████████| 50000/50000 [2:47:11<00:00,  4.98it/s]


inference for promt 3 out of 9


100%|███████████████████████████████████| 50000/50000 [3:16:54<00:00,  4.23it/s]


inference for promt 4 out of 9


100%|███████████████████████████████████| 50000/50000 [3:41:26<00:00,  3.76it/s]


inference for promt 5 out of 9


100%|███████████████████████████████████| 50000/50000 [2:51:42<00:00,  4.85it/s]


inference for promt 6 out of 9


100%|███████████████████████████████████| 50000/50000 [3:14:01<00:00,  4.29it/s]


inference for promt 7 out of 9


100%|███████████████████████████████████| 50000/50000 [3:02:30<00:00,  4.57it/s]


inference for promt 8 out of 9


100%|███████████████████████████████████| 50000/50000 [2:58:27<00:00,  4.67it/s]


inference for promt 9 out of 9


100%|███████████████████████████████████| 50000/50000 [2:54:00<00:00,  4.79it/s]


In [8]:
torch.save(probs,'yahoo_probs.pt')
torch.save(weights,'yahoo_prompt_weights.pt')

In [9]:
# Evaluate BayesPE performance on yahoo answers test data
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
teacher_probs,_,_ = bayespe_classifier.forward(samples_test, n_forward_passes=9)
end.record()

# Wait for the events to be recorded
torch.cuda.synchronize()

# Report in seconds
elapsed_time_ms = start.elapsed_time(end) 
elapsed_time_sec = elapsed_time_ms / 1000  

print(f"Inference time: {elapsed_time_sec:.4f} seconds")
print(teacher_probs[:10, :])
f1_score = evaluation.compute_metric(gt_labels_test, teacher_probs, metric='f1')
ece = evaluation.compute_metric(gt_labels_test, teacher_probs, metric='ece')
print('Teacher f1-score: {}, Teacher ECE: {}'.format(f1_score, ece))

inference for promt 1 out of 9


100%|███████████████████████████████████████| 5000/5000 [15:23<00:00,  5.42it/s]


inference for promt 2 out of 9


100%|███████████████████████████████████████| 5000/5000 [16:37<00:00,  5.01it/s]


inference for promt 3 out of 9


100%|███████████████████████████████████████| 5000/5000 [19:32<00:00,  4.26it/s]


inference for promt 4 out of 9


100%|███████████████████████████████████████| 5000/5000 [22:08<00:00,  3.76it/s]


inference for promt 5 out of 9


100%|███████████████████████████████████████| 5000/5000 [17:07<00:00,  4.87it/s]


inference for promt 6 out of 9


100%|███████████████████████████████████████| 5000/5000 [19:24<00:00,  4.30it/s]


inference for promt 7 out of 9


100%|███████████████████████████████████████| 5000/5000 [18:11<00:00,  4.58it/s]


inference for promt 8 out of 9


100%|███████████████████████████████████████| 5000/5000 [17:44<00:00,  4.70it/s]


inference for promt 9 out of 9


100%|███████████████████████████████████████| 5000/5000 [17:19<00:00,  4.81it/s]


Inference time: 9808.3280 seconds
[[1.93036354e-01 2.57357897e-05 3.65092082e-05 5.45402775e-03
  1.26509534e-05 2.28543786e-05 2.70930552e-05 2.19088936e-03
  7.99151768e-01 4.21669012e-05]
 [1.35133549e-01 8.50703641e-01 2.55414551e-03 1.08454302e-02
  3.69958667e-05 8.94036967e-05 1.82077871e-05 3.35173085e-04
  2.58672280e-04 2.48294007e-05]
 [4.56086223e-02 5.87470741e-02 2.42778728e-04 9.58485288e-02
  4.28586243e-04 3.46439326e-03 7.02500764e-05 7.87274751e-01
  8.14672994e-03 1.68333976e-04]
 [1.41373868e-05 2.61183168e-05 1.04747918e-05 9.99873712e-01
  1.19752805e-05 1.01062609e-05 1.32049412e-05 1.46576673e-05
  1.51186073e-05 1.05436736e-05]
 [4.40773411e-04 1.32909309e-02 9.80813199e-01 4.28599835e-03
  1.04046251e-05 1.14920683e-05 1.05185334e-05 1.28341023e-05
  1.10336284e-03 2.05349281e-05]
 [4.77729435e-01 1.17814362e-04 3.41906512e-02 8.39525247e-03
  4.76183763e-05 4.25647387e-05 1.95559361e-02 3.28599674e-04
  4.46479828e-01 1.31123485e-02]
 [2.05624900e-04 7.34238