<a href="https://colab.research.google.com/github/YoAkeHotaru/Erdos-Deep-Learning-2024-RAG-Project/blob/main/Reranker_finetune_naive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Finetune a Re-ranker

In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd

In [2]:
from google.colab import drive
drive.mount('/content/drive')
data = pd.read_json('/content/drive/MyDrive/Erdos_DL_2024/data/reddit.json')

#Build a test data for the first ~1000 submission chunks
def get_test_data(data, length):
    j = 0
    for i in range(len(data)):
        if (i > length):
            if (data.iloc[i].reddit_title != None):
                j = i - 1
                break
        elif (i > length):
            j = i
            break
    return data.iloc[:j]

class RedditNode:
    def __init__(self, data):
        self.type = data.aware_post_type
        self.subreddit = data.reddit_subreddit
        self.parent = data.reddit_parent_id
        self.id = data.reddit_name
        self.title = data.reddit_title
        self.text = data.reddit_text
        self.url = 'www.reddit.com' + data.reddit_permalink
        self.url2 = data.reddit_url
        self.next = []

    def is_submission(self):
        return self.title != None

    def link_next(self, node):
        self.next.append(node)

class RedditTree:
    def __init__(self, df):
        self.df = df
        self.nodes = None
        self.roots = None
        self.build()

    def build(self):
        #Create a dictionary of data nodes and roots(submission)
        self.nodes = {}
        self.roots = {}

        for i in range(len(self.df)):
            #Build current RedditNode and add to dictionary
            data = self.df.iloc[i]
            node = RedditNode(data)
            self.nodes[node.id] = node

            #If the node is a comment, find and link with the parent
            if (data.reddit_parent_id != None):
                parent = self.nodes.get(data.reddit_parent_id)
                parent.link_next(node)
            #If the node is a submission, add to roots dict
            else:
                self.roots[node.id] = node

    #get the submission text and url link of the roots
    def get_submission(self):
        if self.nodes == None:
            return 'None object'
        res = []
        url = []
        for r in self.roots:
            res.append(self.roots.get(r).title + ' [SEP] ' + self.roots.get(r).text)
            url.append(self.roots.get(r).url)
        return res, url

    #get the comment text associated with the node
    def get_comments(self, node):
        if self.nodes == None:
            return 'None object'
        res = []
        for child in node.next:
            res.append(child.text)
        return res

Mounted at /content/drive


In [3]:
from sklearn.model_selection import train_test_split
#Submissions in the test data
test = get_test_data(data, 100000)
data_tree = RedditTree(test)

data_starbucks = []
data_starbaristas = []
data_walmart = []
min_len = 40

for k,v in data_tree.roots.items():
    if ((len(v.title + v.text) < min_len) or (v.url2[:24] != 'https://www.reddit.com/r')): continue
    if v.subreddit == 'starbucks':
        data_starbucks.append(v)
    elif v.subreddit == 'starbucksbaristas':
        data_starbaristas.append(v)
    elif v.subreddit == 'WalmartEmployees':
        data_walmart.append(v)

data_pos = []
data_neg = []

#Postive pairs
min_len = 40
for post in (data_starbucks + data_walmart):
    query = post.title + '[SEP]' + post.text
    for comment in post.next:
        if len(comment.text) < min_len:
            continue
        else:
            example = [query, comment.text, 1]
            data_pos.append(example)

#Negative pairs
import random
neg_size = 1000
for i in range(neg_size):
    s = random.randint(0, len(data_starbucks)-1)
    w = random.randint(0, len(data_walmart)-1)
    p1 = data_starbucks[s]
    q1 = p1.title + '[SEP]' + p1.text
    p2 = data_walmart[w]
    q2 = p2.title + '[SEP]' + p2.text
    for comment in p1.next:
        if len(comment.text) < min_len:
            continue
        else:
            example = [q2, comment.text, 0]
            data_neg.append(example)
    for comment in p2.next:
        if len(comment.text) < min_len:
            continue
        else:
            example = [q1, comment.text, 0]
            data_neg.append(example)


train_pos, test_pos = train_test_split(data_pos, test_size = 0.2, shuffle = True, random_state = 420)
train_neg, test_neg = train_test_split(data_neg, test_size = 0.2, shuffle = True, random_state = 420)
train_data = train_pos + train_neg
test_data = test_pos + test_neg

test_labels = []
for l in test_data:
    test_labels.append(l[2])

In [97]:
train_data[:2]

[['Where’s my money ?!💁🏻\u200d♀️[SEP]Nov 1 I changed positions from a stocker ($17) to electronic Teamlead($20) and when I check  my even app it still says I’m getting paid  the same($17)  but my position changed. Is there suppose to be a delay or something. If so I wasn’t told..',
  'The only delay is usually new pays kicks in next pay period. Stop looking at Even, seriously. Like it says, it’s just an estimate, it’ll be the last person/thing to know you got a new payrate. Instead, check your workday compensation page or the My Money page on OneWalmart to see what you’re being paid.',
  1],
 ['A coworker that hates me walked in on me crying[SEP]They were making snarky comments about me like they usually do, and I was overstimulated. I usually work part-time because I can’t really handle being social with multiple strangers for several hours, but someone called off and they really needed someone to stay longer so I was the one who did it. I shouldn’t of. Anyway, on my lunchbreak I comp

In [98]:
test_data[-2:]

[['what does it mean when you\'re called in for a global ethics investigation?[SEP]My friend is freaking out bc she was called in to like testify or something about a "global ethics investigation" against an employee but we don\'t know what that means.',
  "Customers don't understand what the difference would be. if they place the order at the box, it still takes the same amount of time until they get to the window. \n\nI always place my order when I'm leaving my house, but there's 100% been times when my phone hasn't disconnected from my wifi yet and the order doesn't go through. If it's just my drink I'll order at the box but if I'm grabbing a drink for my coworker, I don't have hers memorized so it's way easier to just hit the button on the app. I'm sure plenty of people order in-app for similar reasons.\n\nIt may help if you let customers know if there's something different between ordering at the window vs from the app in the lot. I assume the timer on those orders is different? B

In [4]:
!pip install accelerate

[31mERROR: Invalid requirement: 'transformers,'[0m[31m
[0m

In [None]:
!pip install sentence_transformers
from sentence_transformers import SentenceTransformer

# Finetune CrossEncoder

In [6]:
from sentence_transformers import SentenceTransformer, losses, CrossEncoder
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader
import sentence_transformers


model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512, device = 'cuda')

config.json:   0%|          | 0.00/794 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [7]:
model.predict([('what does starbuck baristas do?', 'starbucks baristas prepare coffee and serve customers'), ('what does starbuck baristas do?', 'starbucks baristas drink coffee')])

array([-1.1548539, -3.127532 ], dtype=float32)

In [8]:
from sentence_transformers import evaluation

In [9]:
train_examples = []
test_examples = []

for example in train_data:
  train_examples.append(InputExample(texts = [example[0], example[1]], label = example[2]))

for example in test_data:
  test_examples.append(InputExample(texts = [example[0], example[1]], label = example[2]))

In [10]:
len(train_examples)

11286

## Evaluation result before fine-tuning

In [11]:
total = len(test_examples)
correct = 0
for test in test_examples:
  text = test.texts
  label = test.label
  pred = model.predict((text[0], text[1]))
  if ((pred >= 0 and label == 1) or (pred < 0 and label == 0)):
    correct += 1
print(f"The accuracy of the model before finetuning on the test data set is {correct}/{total} = {correct/total}")

The accuracy of the model before finetuning on the test data set is 1782/2823 = 0.6312433581296493


In [12]:
small_len = 100
small_test_idx = [random.randint(0, total - 1) for i in range(small_len)]
small_test = [test_examples[idx] for idx in small_test_idx]

In [13]:
def small_eval(model, small_test):
  small_len = len(small_test)
  correct = 0
  for test in small_test:
    text = test.texts
    label = test.label
    pred = model.predict((text[0], text[1]))
    if ((pred >= 0 and label == 1) or (pred < 0 and label == 0)):
      correct += 1
  print(f"The accuracy of the model before finetuning on the test data set is {correct}/{small_len} = {correct/small_len}")

In [14]:
#This takes around 10 seconds to run on a CPU
small_eval(model, small_test)

The accuracy of the model before finetuning on the test data set is 58/100 = 0.58


In [15]:
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss = losses.ContrastiveLoss(model=model)

In [61]:
#from sentence_transformers.cross_encoder import evaluation
#evaluator = evaluation.CEBinaryAccuracyEvaluator(test_examples, labels = [1,0])

In [16]:
num_epochs = 2
warmup_steps = 0

for n in range(num_epochs):
  model.fit(
      train_dataloader=train_dataloader,
      #evaluator=evaluator,
      #evaluation_steps=1,
      epochs=1,
      #warmup_steps=warmup_steps,
      #output_path=model_save_path,
  )
  small_eval(model, small_test)


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/706 [00:00<?, ?it/s]

The accuracy of the model before finetuning on the test data set is 78/100 = 0.78


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/706 [00:00<?, ?it/s]

The accuracy of the model before finetuning on the test data set is 81/100 = 0.81


In [17]:
model.predict([('what does a worker at a coffee shop do?', 'starbucks baristas prepare cafe and serve customers'), ('what does a baristas do?', 'walmart employees push carts')])

array([ 1.3057055, -9.353149 ], dtype=float32)

In [31]:
query = "Walmart: What does a maintenance associate do? I got hired as a maintenance associate a few weeks ago and I still have no clue what I’m supposed to do. "
text1 = "You are a cart pusher and janitor essentially. You push carts most of the day. Clean the bathrooms, restock the bathrooms, clean spills around the store, and that’s about it."
text2 = "In simple, you worry about the trash/bathrooms. Stocking them up or keeping them clean or any calls to clean in aisles. Never heard of maintenance to grab carts but maybe that’s your store?[SEP]Yeah that’s why I’m confused because that’s how all of the maintenance people are treated as cart pushers and maintenance. They schedule 1 person on the busiest days and then expect everything to get done but I flat out refuse to stop doing carts or whatever just because they want me to clean 1 tiny spot (happens all the time)[SEP]"
model.predict([(query, text1), (query, text2)])

array([-0.3682082,  1.3098446], dtype=float32)

In [22]:
#Save the model
model.save("/content/drive/MyDrive/Erdos_DL_2024/Reranker")

In [23]:
model2 = CrossEncoder('/content/drive/MyDrive/Erdos_DL_2024/Reranker')

In [27]:
model2

<sentence_transformers.cross_encoder.CrossEncoder.CrossEncoder at 0x7b19288756f0>

#Rerank function

In [None]:
def rerank(query, documents):
  #documents is a list of string
  input = [(query, document) for document in documents]
  scores = model.predict(input)
  ranking = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) # return the index of sorted list
  res = [documents[i] for i in ranking]
  return res