In [None]:
!pip install --upgrade git+https://github.com/terrierteam/pyterrier_t5.git
!pip install ir_datasets

In [None]:
# Mount Google Drive and point ir_datasets at a data directory stored on it,
# so downloaded datasets are kept on Drive rather than on the Colab VM.
import os
from google.colab import drive

drive.mount("/content/drive", force_remount=True)

# Set before `import ir_datasets` (next cell) so the library sees it when locating its data dir.
irds_home = "/content/drive/MyDrive/Colab Notebooks/Dissertation/ir_datasets"
os.environ['IR_DATASETS_HOME'] = irds_home
print(os.environ['IR_DATASETS_HOME'])

Mounted at /content/drive
/content/drive/MyDrive/Colab Notebooks/Dissertation/ir_datasets


In [None]:
#Importing libraries
import itertools
import json
import pickle
from random import Random

import ir_datasets
import torch
torch.cuda.empty_cache()
from torch.nn import CrossEntropyLoss, Softmax
import pandas as pd
import pyterrier as pt
if not pt.started():
  pt.init()
from pyterrier.measures import *
from pyterrier_t5 import MonoT5ReRanker
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import AdamW

terrier-assemblies 5.6 jar-with-dependencies not found, downloading to /root/.pyterrier...
Done
terrier-python-helper 0.0.6 jar not found, downloading to /root/.pyterrier...
Done


PyTerrier 0.8.1 has loaded Terrier 5.6 (built by craigmacdonald on 2021-09-17 13:27)



## Load Dataset

In [None]:
TRAIN_DATA_PATH = '/content/drive/MyDrive/Colab Notebooks/Dissertation/Training and Eval/train_data.pkl'

# Pre-built training frame with qid / query / docno / rank / relevance / weight columns.
# NOTE: pickle can execute arbitrary code on load — only open files this project produced.
with open(TRAIN_DATA_PATH, 'rb') as pkl_file:
    data = pickle.load(pkl_file)
data

Unnamed: 0,qid,query,docno,rank,relevance,weight
0,1,A potlatch is considered an example of,4063745-0,1.0,false,0.000000
1,1,A potlatch is considered an example of,4063746-0,2.0,false,0.000000
2,1,A potlatch is considered an example of,4063747-0,3.0,false,0.000000
3,1,A potlatch is considered an example of,4063748-0,4.0,false,0.000000
4,1,A potlatch is considered an example of,4063749-0,5.0,false,0.000000
...,...,...,...,...,...,...
6999995,864260,what is yield to worst,4221105-0,4.0,false,0.000000
6999996,864260,what is yield to worst,4221106-0,5.0,false,0.000000
6999997,864260,what is yield to worst,4221107-0,6.0,false,0.000000
6999998,864260,what is yield to worst,4221108-0,7.0,false,0.000000


## Removing Class Imbalance

In [None]:
#Removing class imbalance by undersampling the majority (non-relevant) class

# Fixed seed so the balanced subset (and the shuffled order of relevant rows) is reproducible.
SAMPLING_SEED = 0

# Relevance is stored as the strings 'true' / 'false' in this frame.
# Total number of non-relevant documents - 6541972
nr = data[data['relevance'] == 'false']

# Total number of relevant documents - 458028
r = data[data['relevance'] == 'true']

#Undersampling: draw as many non-relevant rows as there are relevant rows (capped at what exists).
# The previous hand-tuned frac=0.0700137 only produced 458028 rows through rounding.
n_rel = nr.sample(n=min(len(r), len(nr)), random_state=SAMPLING_SEED)
# frac=1 keeps every relevant row, just in shuffled order.
rel = r.sample(frac=1, random_state=SAMPLING_SEED)

print(n_rel.shape)
rel.shape

(458028, 6)


(458028, 6)

# T5

In [None]:
# ---- Training configuration ----------------------------------------------------
# ir_datasets' convenience logger; supplies the progress bars and timers used in later cells.
_logger = ir_datasets.log.easy()

# Target strings the model learns to generate: index 0 for relevant, index 1 for non-relevant.
OUTPUTS = ['true', 'false']

# Number of (prompt, target, weight) samples per optimizer step.
BATCH_SIZE = 16

# Seed torch's global RNG.
torch.manual_seed(0)

#Function that streams query-document training samples (prompt, target, weight) to the training loop
def iter_train_samples(relevant=None, non_relevant=None, docstore=None):
  """Yield ``(prompt, target, weight)`` training samples forever.

  Rows of the relevant and non-relevant frames are paired with ``zip`` (a pass stops at the
  shorter frame). For each pair, two samples are yielded in this order:

  1. the relevant pair, target ``OUTPUTS[0]`` and the row's ``weight`` value;
  2. the non-relevant pair, target ``OUTPUTS[1]`` and weight ``0``.

  Prompts have the form ``'Query: <query> Document: <text>'``. After a full pass the frames
  are iterated again from the start, in the same order.

  Args:
    relevant: frame with ``query``, ``docno`` and ``weight`` columns; defaults to global ``rel``.
    non_relevant: frame with ``query`` and ``docno`` columns; defaults to global ``n_rel``.
    docstore: object whose ``get(docno).text`` gives the document text; defaults to global ``docs``.

  Raises:
    ValueError: if either frame is empty (previously this looped forever without yielding).
  """
  def _prompt(store, query, docno):
    # Build the text-to-text input for one query/document pair.
    return 'Query: ' + str(query) + ' Document: ' + str(store.get(docno).text)

  while True:
    # Defaults are resolved when iteration starts, so the globals only need to exist by then.
    pos = rel if relevant is None else relevant
    neg = n_rel if non_relevant is None else non_relevant
    # NOTE(review): no visible cell defines a global `docs` (the one in build_validation_data is
    # local to that function); create it elsewhere or pass `docstore` explicitly.
    store = docs if docstore is None else docstore
    if len(pos) == 0 or len(neg) == 0:
      raise ValueError('iter_train_samples needs non-empty relevant and non-relevant frames')
    # itertuples over just the needed columns is much faster than iterrows.
    pairs = zip(neg[['query', 'docno']].itertuples(index=False, name=None),
                pos[['query', 'docno', 'weight']].itertuples(index=False, name=None))
    for (nr_query, nr_docno), (r_query, r_docno, r_weight) in pairs:
      yield _prompt(store, r_query, r_docno), OUTPUTS[0], r_weight
      #Setting weight of non-relevant documents to 0
      yield _prompt(store, nr_query, nr_docno), OUTPUTS[1], 0


# Endless stream of (prompt, target, weight) samples wrapped in a progress bar.
train_iter = _logger.pbar(iter_train_samples(), desc='total train samples')

# Start from the pre-trained t5-base checkpoint and fine-tune it on the GPU.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base").cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

# Validation re-ranker; align its relevant / non-relevant token ids with the first token
# of each training target string.
reranker = MonoT5ReRanker(verbose=False, batch_size=BATCH_SIZE)
rel_token_id = tokenizer.encode(OUTPUTS[0])[0]
nrel_token_id = tokenizer.encode(OUTPUTS[1])[0]
reranker.REL, reranker.NREL = rel_token_id, nrel_token_id

In [None]:
#Function to build validation dataset
def build_validation_data(dataset_id='msmarco-passage/trec-dl-2019/judged'):
  """Return validation candidates as a DataFrame with ``qid, query, docno, text`` columns.

  Iterates the dataset's scored documents and keeps those whose query id appears among the
  dataset's queries, attaching the query text and the document text.

  Args:
    dataset_id: ir_datasets identifier to load; defaults to the TREC DL 2019 (judged) subset,
      which was previously hard-coded.

  Returns:
    pandas.DataFrame with one row per kept scored document.
  """
  dataset = ir_datasets.load(dataset_id)
  docs = dataset.docs_store()
  queries = {q.query_id: q.text for q in dataset.queries_iter()}
  rows = []
  # Reuse the already-loaded dataset handle rather than calling ir_datasets.load() again.
  # These records are scored (candidate) documents, not qrels.
  for scored in _logger.pbar(dataset.scoreddocs, desc='dev data'):
    if scored.query_id in queries:
      rows.append([scored.query_id, queries[scored.query_id], scored.doc_id, docs.get(scored.doc_id).text])
  return pd.DataFrame(rows, columns=['qid', 'query', 'docno', 'text'])

# Candidate (query, passage) pairs to re-rank during validation.
valid_data = build_validation_data()
# Relevance judgments for the same topics, via PyTerrier's ir_datasets bridge ('irds:' prefix).
trec_dl_2019 = pt.get_dataset('irds:msmarco-passage/trec-dl-2019/judged')
valid_qrels = trec_dl_2019.get_qrels()

In [None]:
# Fine-tune for NUM_EPOCHS rounds of STEPS_PER_EPOCH batches, validating on TREC DL 2019 after
# each round, appending metrics to log.jsonl and saving a checkpoint to Drive every round.
NUM_EPOCHS = 10
SAMPLES_PER_EPOCH = 16384
STEPS_PER_EPOCH = SAMPLES_PER_EPOCH // BATCH_SIZE
# The sample stream alternates relevant / non-relevant, so each batch is half-and-half only
# when BATCH_SIZE is even; the x2 weight rescaling below depends on that.
assert BATCH_SIZE % 2 == 0, 'BATCH_SIZE must be even for the relevant-weight rescaling'

epoch = 0
max_ndcg = 0.

while epoch < NUM_EPOCHS:
  with _logger.pbar_raw(desc=f'train {epoch}', total=STEPS_PER_EPOCH) as pbar:
    model.train()
    total_loss = 0.
    count = 0
    for _ in range(STEPS_PER_EPOCH):
      inp, out, weight = [], [], []
      for _ in range(BATCH_SIZE):
        prompt, target, sample_weight = next(train_iter)
        inp.append(prompt)
        out.append(target)
        weight.append(sample_weight)
      inp_ids = tokenizer(inp, return_tensors='pt', padding=True).input_ids.cuda()
      out_ids = tokenizer(out, return_tensors='pt', padding=True).input_ids.cuda()

      # ------------------------------CUSTOM LOSS FUNCTION---------------------------------------------
      # Non-relevant samples have weight 0, so multiplying the batch mean by two gives the mean
      # weight of the BATCH_SIZE / 2 relevant documents rather than of all BATCH_SIZE documents.
      weight_tensor = 2 * torch.tensor(weight, dtype=torch.float32).mean().cuda()
      # Using the built-in (batch-averaged) loss
      loss_mean = model(input_ids=inp_ids, labels=out_ids).loss
      # Multiplying the mean loss by the mean relevant weight.
      # NOTE(review): this is one scalar per batch, so non-relevant samples still contribute to
      # the loss; it is not a per-sample weighting.
      loss = loss_mean * weight_tensor
      # -----------------------------------------------------------------------------------------------

      loss.backward()
      optimizer.step()
      optimizer.zero_grad()
      # Accumulate (previously overwritten) so total_loss / count is the running mean loss.
      total_loss += loss.item()
      count += 1
      pbar.update(1)
      pbar.set_postfix({'loss': total_loss / count})

  with _logger.duration(f'valid {epoch}'):
    reranker.model = model
    # model.train() was set above; switch to eval mode so dropout is off while scoring.
    # model.train() at the top of the next epoch turns training behaviour back on.
    model.eval()
    reranker.verbose = True
    res = reranker(valid_data)
    reranker.verbose = False
    metrics = {'epoch': epoch, 'loss': total_loss / count}
    metrics.update(pt.Utils.evaluate(res, valid_qrels, metrics=[nDCG, RR(rel=2)]))
    _logger.info(metrics)
    with open('log.jsonl', 'at') as f:
      f.write(json.dumps(metrics) + '\n')
    model.save_pretrained(f'/content/drive/MyDrive/Colab Notebooks/Dissertation/Models/relwgt-nrwgt0-{epoch}')
    if metrics['nDCG'] > max_ndcg:
      _logger.info('New Best nDCG')
      max_ndcg = metrics['nDCG']
  epoch += 1