In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class CFG:
    seed = 42  # Random seed
    preset = "bert-base-uncased" # Name of pretrained models
    sequence_length = 512  # Input sequence length
    epochs = 3 # Training epochs
    batch_size = 16  # Batch size
    scheduler = 'cosine'  # Learning rate scheduler
    label2name = {0: 'winner_model_a', 1: 'winner_model_b', 2: 'winner_tie'}
    name2label = {v:k for k, v in label2name.items()}
    class_labels = list(label2name.keys())
    class_names = list(label2name.values())

In [4]:
BASE_PATH = './data'
# Load Train Data
df = pd.read_csv(f'{BASE_PATH}/train.csv') 

df["prompt"] = df.prompt.map(lambda x: eval(x)[0])
df["response_a"] = df.response_a.map(lambda x: eval(x.replace("null","''"))[0])
df["response_b"] = df.response_b.map(lambda x: eval(x.replace("null", "''"))[0])

df["class_name"] = df[["winner_model_a", "winner_model_b" , "winner_tie"]].idxmax(axis=1)
df["class_label"] = df.class_name.map(CFG.name2label)

df.head()

Unnamed: 0,id,model_a,model_b,prompt,response_a,response_b,winner_model_a,winner_model_b,winner_tie,class_name,class_label
0,30192,gpt-4-1106-preview,gpt-4-0613,Is it morally right to try to have a certain p...,The question of whether it is morally right to...,"As an AI, I don't have personal beliefs or opi...",1,0,0,winner_model_a,0
1,53567,koala-13b,gpt-4-0613,What is the difference between marriage licens...,A marriage license is a legal document that al...,A marriage license and a marriage certificate ...,0,1,0,winner_model_b,1
2,65089,gpt-3.5-turbo-0613,mistral-medium,explain function calling. how would you call a...,Function calling is the process of invoking or...,Function calling is the process of invoking a ...,0,0,1,winner_tie,2
3,96401,llama-2-13b-chat,mistral-7b-instruct,How can I create a test set for a very rare ca...,Creating a test set for a very rare category c...,When building a classifier for a very rare cat...,1,0,0,winner_model_a,0
4,198779,koala-13b,gpt-3.5-turbo-0314,What is the best way to travel from Tel-Aviv t...,The best way to travel from Tel Aviv to Jerusa...,The best way to travel from Tel-Aviv to Jerusa...,0,1,0,winner_model_b,1


In [5]:
# Load Test Data
df_test = pd.read_csv(f'{BASE_PATH}/test.csv')

# Take the first prompt and response
df_test["prompt"] = df_test.prompt.map(lambda x: eval(x)[0])
df_test["response_a"] = df_test.response_a.map(lambda x: eval(x.replace("null","''"))[0])
df_test["response_b"] = df_test.response_b.map(lambda x: eval(x.replace("null", "''"))[0])

# Show Sample
df_test.head()

Unnamed: 0,id,prompt,response_a,response_b
0,136060,"I have three oranges today, I ate an orange ye...",You have two oranges today.,You still have three oranges. Eating an orange...
1,211333,You are a mediator in a heated political debat...,Thank you for sharing the details of the situa...,Mr Reddy and Ms Blue both have valid points in...
2,1233961,How to initialize the classification head when...,When you want to initialize the classification...,To initialize the classification head when per...


In [6]:
df_train, df_valid = train_test_split(df, test_size=0.2, stratify=df["class_label"])

In [7]:
df_train.head()

Unnamed: 0,id,model_a,model_b,prompt,response_a,response_b,winner_model_a,winner_model_b,winner_tie,class_name,class_label
10451,775229282,gpt-3.5-turbo-0613,vicuna-33b,Create a chess game in chess notation between ...,Sure! Here's a chess game in algebraic notatio...,Here's a hypothetical chess game between Hikar...,0,1,0,winner_model_b,1
52831,3950226447,mixtral-8x7b-instruct-v0.1,mistral-medium,"251, 239, 239 to rgb hexa","I'll remember to always assist with care, resp...","The RGB color values you provided (251, 239, 2...",0,0,1,winner_tie,2
26168,1947607959,mixtral-8x7b-instruct-v0.1,pplx-70b-online,"I have a package, MetFamily (https:\/\/github....",To view the logs of the Shiny application runn...,To see the log and details of the running code...,1,0,0,winner_model_a,0
18815,1406691554,gpt-4-1106-preview,gemini-pro-dev-api,Do you like feet,"As an AI, I don't have personal likes or disli...",I am a virtual assistant and do not have perso...,1,0,0,winner_model_a,0
13804,1027167915,gpt-3.5-turbo-0613,gpt-4-1106-preview,Can you make a Matlab livescript simulating wa...,Sure! I can help you with that. Here's a MATLA...,Certainly! To simulate water waves generated b...,0,1,0,winner_model_b,1


In [8]:
df_valid.head()

Unnamed: 0,id,model_a,model_b,prompt,response_a,response_b,winner_model_a,winner_model_b,winner_tie,class_name,class_label
20487,1526327941,mistral-medium,mixtral-8x7b-instruct-v0.1,How fast an infection can kill someone ?,The speed at which an infection can kill someo...,The speed at which an infection can kill someo...,1,0,0,winner_model_a,0
34781,2581593185,gpt-4-0613,claude-1,Please draw a grid ascii map of the real world...,Creating an ASCII map of the real world is qui...,Here is a very rough ascii grid map of the maj...,1,0,0,winner_model_a,0
11228,831606738,llama2-70b-steerlm-chat,gpt-4-1106-preview,Can you describe the notion of a relational fr...,"In Relational Frame Theory (RFT), a relational...",Relational Frame Theory (RFT) is a psychologic...,0,1,0,winner_model_b,1
50489,3770890701,claude-instant-1,vicuna-7b,I don't know much about AWS. Can you recommend...,Here are some AWS services that can be used to...,Sure! AWS offers several services that you can...,1,0,0,winner_model_a,0
39525,2942949049,chatglm2-6b,mistral-7b-instruct,"in dominion, is chariot race a good open?","In Dominion, the Chariot Race is considered to...",The Chariot Race in Dominion can be a good ope...,0,0,1,winner_tie,2


In [9]:
# Define a function to create options based on the prompt and choices
def make_pairs(row):
    row["encode_fail"] = False
    try:
        prompt = row.prompt.encode("utf-8").decode("utf-8")
    except:
        prompt = ""
        row["encode_fail"] = True

    try:
        response_a = row.response_a.encode("utf-8").decode("utf-8")
    except:
        response_a = ""
        row["encode_fail"] = True

    try:
        response_b = row.response_b.encode("utf-8").decode("utf-8")
    except:
        response_b = ""
        row["encode_fail"] = True
        
    row['options'] = [f"Prompt: {prompt}\n\nResponse: {response_a}",  # Response from Model A
                      f"Prompt: {prompt}\n\nResponse: {response_b}"  # Response from Model B
                     ]
    return row

In [10]:
df_train = df_train.apply(make_pairs, axis=1)  # Apply the make_pairs function to each row in df
display(df_train.head(2))  # Display the first 2 rows of df

df_valid = df_valid.apply(make_pairs, axis=1)  # Apply the make_pairs function to each row in df
display(df_valid.head(2))  # Display the first 2 rows of df

Unnamed: 0,id,model_a,model_b,prompt,response_a,response_b,winner_model_a,winner_model_b,winner_tie,class_name,class_label,encode_fail,options
10451,775229282,gpt-3.5-turbo-0613,vicuna-33b,Create a chess game in chess notation between ...,Sure! Here's a chess game in algebraic notatio...,Here's a hypothetical chess game between Hikar...,0,1,0,winner_model_b,1,False,[Prompt: Create a chess game in chess notation...
52831,3950226447,mixtral-8x7b-instruct-v0.1,mistral-medium,"251, 239, 239 to rgb hexa","I'll remember to always assist with care, resp...","The RGB color values you provided (251, 239, 2...",0,0,1,winner_tie,2,False,"[Prompt: 251, 239, 239 to rgb hexa\n\nResponse..."


Unnamed: 0,id,model_a,model_b,prompt,response_a,response_b,winner_model_a,winner_model_b,winner_tie,class_name,class_label,encode_fail,options
20487,1526327941,mistral-medium,mixtral-8x7b-instruct-v0.1,How fast an infection can kill someone ?,The speed at which an infection can kill someo...,The speed at which an infection can kill someo...,1,0,0,winner_model_a,0,False,[Prompt: How fast an infection can kill someon...
34781,2581593185,gpt-4-0613,claude-1,Please draw a grid ascii map of the real world...,Creating an ASCII map of the real world is qui...,Here is a very rough ascii grid map of the maj...,1,0,0,winner_model_a,0,False,[Prompt: Please draw a grid ascii map of the r...


In [11]:
class CustomDataset(Dataset):
    def __init__(self, texts, labels=None, num_classes=3, preprocess_fn=None):
        self.texts = texts
        self.labels = labels
        self.num_classes = num_classes
        self.preprocess_fn = preprocess_fn  # Optional preprocessing function

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = None
        if self.labels is not None:
            label = self.labels[idx]
            # Convert labels to one-hot encoding if needed
            label = np.eye(self.num_classes)[label].astype(np.float32)
        if self.preprocess_fn:
            text = self.preprocess_fn(text)
        return (text, label) if label is not None else text


# Define the function to build the DataLoader
def build_dataloader(texts, labels=None, batch_size=32, shuffle=True, preprocess_fn=None):
    dataset = CustomDataset(texts, labels, preprocess_fn=preprocess_fn)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,  # Similar to `drop_remainder=False`
        num_workers=4,    # Parallel data loading
        prefetch_factor=2 # Prefetch batches
    )
    return dataloader

In [None]:
# Train
train_texts = df_train.options.tolist()  # Extract training texts
train_labels = df_train.class_label.tolist()  # Extract training labels
train_dataloader = build_dataloader(
    texts=train_texts,
    labels=train_labels,
    batch_size=CFG.batch_size,
    shuffle=True
)

# Valid
valid_texts = df_valid.options.tolist()  # Extract validation texts
valid_labels = df_valid.class_label.tolist()  # Extract validation labels
valid_dataloader = build_dataloader(
    texts=valid_texts,
    labels=valid_labels,
    batch_size=CFG.batch_size,
    shuffle=False
)

In [None]:
print(next(iter(train_dataloader)))