###Setup Tasks###

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/gdrive')

###Import Statements###

In [43]:
import os
import random

import numpy as np
import pandas as pd
import sentencepiece as spm
import pyarrow.parquet as pq
import pyarrow as pa

from gensim.models import Word2Vec


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [2]:
# Paths (only for running on Mark's Mac)
os.chdir('/content/gdrive/My Drive/Colab Notebooks/Document Ranking/')

data_dir = os.getcwd()
spm_path = os.getcwd()
w2v_path = os.getcwd()
train_set_path = os.getcwd() + '/train-00000-of-00001.parquet'

###Constants and Global Variables###

In [3]:
# Constants for data preparation
data_max_len = 50 # 512
chunk_size = 10 # make much larger - set to a multiple of the number of rows in a row group of the parquet file

###Class Definitions###

In [None]:
class ChunkedTripletDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, spm_path, w2v_path, max_len, chunk_size):
        self.data_dir = data_dir
        self.sp_model_path = spm_path
        self.word2vec_model_path = w2v_path
        self.max_len = max_len
        self.chunk_size = chunk_size

        self.data_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.json')]
        self.total_examples = sum(self.count_examples(f) for f in self.data_files)

    def count_examples(self, file_path):
        with open(file_path, 'r') as f:
            return sum(1 for _ in f)

    def __len__(self):
        return self.total_examples

    def __getitem__(self, idx):
        for file_path in self.data_files:
            with open(file_path, 'r') as f:
                for batch_examples in self.chunk_data(f):
                    if idx < len(batch_examples):
                        example = batch_examples[idx]
                        dataset = TripletDataset([example], self.sp_model_path, self.word2vec_model_path, self.max_len)
                        return dataset[0]
                    else:
                        idx -= len(batch_examples)

    def chunk_data(self, file_obj):
        while True:
            chunk = list(islice(file_obj, self.chunk_size))
            if not chunk:
                break
            chunk = [json.loads(line) for line in chunk]
            yield chunk

Function definitions.

In [50]:
def generate_triplets(query_list, pos_doc_list, neg_doc_list):

    # Create neg_doc_list - do this before calling this function


    # Assuming you have the following lists:
    queries = ['q1', 'q2', ..., 'q100']
    positive_documents = [['pd1_1', 'pd1_2', ..., 'pd1_10'], ['pd2_1', 'pd2_2', ..., 'pd2_10'], ..., ['pd100_1', 'pd100_2', ..., 'pd100_10']]
    negative_documents = [['nd1_1', 'nd1_2', ..., 'nd1_10'], ['nd2_1', 'nd2_2', ..., 'nd2_10'], ..., ['nd100_1', 'nd100_2', ..., 'nd100_10']]

    # Step 1: Flatten positive and negative documents into a single list
    pos_doc_tensor = [doc for docs in positive_documents for doc in docs]
    neg_doc_tensor = [doc for docs in negative_documents for doc in docs]

    # Step 2: Repeat each query for the number of corresponding positive documents
    query_tensor = [query for query, docs in zip(queries, positive_documents) for _ in docs]

    # Convert lists to tensors
    pos_doc_tensor = torch.tensor(pos_doc_tensor)
    neg_doc_tensor = torch.tensor(neg_doc_tensor)
    query_tensor = torch.tensor(query_tensor)

    # Verification
    print("pos_doc_tensor:", pos_doc_tensor)
    print("neg_doc_tensor:", neg_doc_tensor)
    print("query_tensor:", query_tensor)


    # tokenize and embed the tensors
    # call function to do this

    # sort the tensors

    # Calculate the lengths of the documents in pos_doc_tensor and neg_doc_tensor
    pos_doc_lengths = [len(doc) for doc in pos_doc_tensor]
    neg_doc_lengths = [len(doc) for doc in neg_doc_tensor]

    # Determine which tensor has the longest documents
    if max(pos_doc_lengths) > max(neg_doc_lengths):
        longest_tensor = pos_doc_tensor
    else:
        longest_tensor = neg_doc_tensor

    # Sort the longest tensor along with the other tensors
    sorted_indices = torch.argsort(longest_tensor.apply(len))  # Sort indices based on document length
    pos_doc_tensor = pos_doc_tensor[sorted_indices]
    neg_doc_tensor = neg_doc_tensor[sorted_indices]
    query_tensor = query_tensor[sorted_indices]

    # Verification
    print("pos_doc_tensor:", pos_doc_tensor)
    print("neg_doc_tensor:", neg_doc_tensor)
    print("query_tensor:", query_tensor)



    triplets_list = []
    num_rows = query_and_relevant_doc_df.shape[0]

    for index, row in query_and_relevant_doc_df.iterrows():
        query = row["query"]  # Assuming 'query' is the column name for the query text
        passages = row["passages"]  # Accessing the dictionary in the 'passages' column
        passage_texts = passages["passage_text"]  # Extracting the list of passage texts


        for relevant_document in passage_texts:
            # Randomly select another index
            random_row_index = random.randint(0, num_rows - 1)
            # Spliting each passage into a sentence
            sentences = sent_tokenize(relevant_document)
            # appending each sentence into corpus.txt
            with open('corpus.txt', 'a') as file:
                for sentence in sentences:
                    file.write(sentence + '\n')  # Write each sentence on a new line

            while index == random_row_index:
                random_row_index = random.randint(0, num_rows - 1)


            # Retrieve a passage from the randomly selected row
            random_passages = query_and_relevant_doc_df.loc[random_row_index, "passages"]
            random_passage_texts = random_passages["passage_text"]

            # Optionally, select a random passage text from the selected row
            irrelevant_document = random.choice(random_passage_texts)
            triplets_list.append((query, relevant_document, irrelevant_document))

    return query_tensor,  pos_doc_tensor, neg_doc_tensor

Main program logic.

In [None]:
# Load pre-trained models
sp = spm.SentencePieceProcessor()
sp.load('spm.model')
w2v = Word2Vec.load('w2v.model')

# Create dataset and dataloader
dataset = ChunkedTripletDataset(data_dir, spm_path, w2v_path, data_max_len, chunk_size)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [52]:

# This code is used to modify ChunkedTripletDataset() class




# Load the .parquet file
parquet_file = pq.ParquetFile(train_set_path)

# Print the total number of rows and row groups
total_rows = parquet_file.metadata.num_rows
total_row_groups = parquet_file.num_row_groups
print(f'Total rows in train parquet file: {total_rows}')
print(f'Total row groups in train parquet file: {total_row_groups}')

# Define the schema for the columns
schema = pa.schema([
    ("query", pa.string()),
    ("passages", pa.struct([
        ("is_selected", pa.list_(pa.int32())),
        ("passage_text", pa.list_(pa.string())),
        ("url", pa.list_(pa.string()))
    ]))
])

# Iterate over the file in chunks
for i in range(0, parquet_file.num_row_groups, chunk_size):

    # Read a chunk of rows with selected columns
    table_chunk = parquet_file.read_row_group(i, use_threads=True, columns=['query', 'passages'])
    table = pa.Table.from_arrays(table_chunk.columns, schema=schema)

    # Get the query data
    queries = table['query']
    query_list = [str(query) for query in queries]
    print("\nQUERIES:\n", query_list)

    # Get the positive document data
    passages = table['passages']
    passages_struct = passages[0]
    pos_doc_list = passages_struct['passage_text']
    print("\nPOSITIVE DOCS:\n", pos_doc_list)

    # Create neg_doc_list
    # Select random docs to build the list
    neg_doc_list = []

    # Generate dataset
    generate_triplets(query_list, pos_doc_list, neg_doc_list)

    break




Total rows in train parquet file: 82326
Total row groups in train parquet file: 83

QUERIES:
 ['what is rba', 'was ronald reagan a democrat', 'how long do you need for sydney and surrounding areas', 'price to install tile in shower', 'why conversion observed in body', 'where are the lungs located in the back', 'cost to get a patent', 'what does a metabolic acidosis need to reverse the condition', 'best tragedies of ancient greece', 'what is a conifer', 'in animals somatic cells are produced by and gametic cells are produced by', 'remembering the name of the author who wrote the cat in the hat', 'how long cooking chicken legs in the big easy', 'average cost of heating per square foot', 'is mount pinatubo made of granite or basalt', 'concrete pads cost', 'what kind of organism is a black damsel', 'who coined the phrase it is what it is', 'what is oilskin fabric', 'how long is german measles contagious', 'what is a camerata', 'how long does it take to bake a pound cake', 'what is the maxi

TypeError: 'ellipsis' object is not iterable

In [33]:
training_query_dataset = pd.read_parquet(train_set_path)
passages = training_query_dataset.passages

pd.set_option('display.max_colwidth', None)

print(passages.head(1))
print(passages.info())

0    {'is_selected': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], 'passage_text': ['Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.', 'The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonwealth Bank. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sy