###Setup Tasks###

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/gdrive')

###Import Statements###

In [32]:
import os
import random

import numpy as np
import pandas as pd
import sentencepiece as spm
import pyarrow.parquet as pq
import pyarrow as pa

from gensim.models import Word2Vec


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [33]:
# Paths (only for running on Mark's Mac)
os.chdir('/content/gdrive/My Drive/Colab Notebooks/Document Ranking')

data_dir = os.getcwd()
spm_path = os.getcwd() + '/spm.model'
w2v_path = os.getcwd() + '/w2v.model'
train_set_path = os.getcwd() + '/train-00000-of-00001.parquet'

###Constants and Global Variables###

In [72]:
# Constants for data preparation
data_max_len = 1000 # 512
chunk_size = 1000 # make much larger - set to a multiple of the number of rows in a row group of the parquet file
batch_size = 100

###Class Definitions###

In [73]:
class ChunkedTripletDataset(torch.utils.data.Dataset):
    def __init__(self, parquet_file_path, chunk_size, data_max_len):
        super().__init__()
        self.parquet_file = pq.ParquetFile(parquet_file_path)
        self.chunk_size = chunk_size
        self.total_rows = self.parquet_file.metadata.num_rows
        self.total_row_groups = self.parquet_file.num_row_groups

        # Define the schema for the columns
        self.schema = pa.schema([
            ("query", pa.string()),
            ("passages", pa.struct([
            ("is_selected", pa.list_(pa.int32())),
            ("passage_text", pa.list_(pa.string())),
            ("url", pa.list_(pa.string()))
            ]))
        ])

        # Print the total number of rows and row groups
        print(f'Total number of rows: {self.total_rows}')
        print(f'Total number of row groups: {self.total_row_groups}')

    def __len__(self):
        return self.total_rows

    def __iter__(self):
        start_row = 0
        end_row = self.total_rows

        for start in range(start_row, end_row, self.chunk_size):
            end = min(start + self.chunk_size, end_row)
            table_chunk = self.parquet_file.read_rows(start, end, columns=['query', 'passages'])
            table = pa.Table.from_batches([table_chunk])

            query_list = table['query'].to_pylist()
            passages = table['passages'].to_pylist()

            neg_doc_list = []
            for query, passage_row in zip(query_list, passages):
                pos_doc_list = passage_row['passage_text']
                print("POS DOC LIST CHECK:", pos_doc_list)

                # Randomly select negative documents from other rows
                random_index = random.choice([i for i in range(len(query_list)) if query_list[i] != query])
                neg_pos_doc_list = passages[random_index]['passage_text']
                neg_doc_list.append(neg_pos_doc_list)

            # Instantiate the TokenizerAndEmbedder class
            spm_path = "your_spm_model_path.model"
            w2v_path = "your_word2vec_model_path.model"
            tokenizer_and_embedder = TokenizerAndEmbedder(spm_path, w2v_path)

            # Generate dataset
            yield generate_triplets(query_list, pos_doc_list, neg_doc_list, tokenizer_and_embedder, data_max_len)

    def __getitem__(self, index):
        # Calculate the start and end indices for the current chunk
        #start_index = index * self.chunk_size
        #end_index = min((index + 1) * self.chunk_size, self.total_rows)

        # Calculate the row group index and row index within the row group
        row_group_index = index // self.chunk_size
        row_index_within_group = index % self.chunk_size

        # Check if the row group index is valid
        if row_group_index >= self.total_row_groups:
            raise IndexError(f"Row group index {row_group_index} out of bounds")

        # Read the data chunk from the Parquet file
        table_chunk = self.parquet_file.read_row_group(row_group_index, columns=['query', 'passages'])
        table = pa.Table.from_batches([table_chunk])

        # Extract queries and passages from the data chunk
        query_list = table['query'].to_pylist()
        passages = table['passages'].to_pylist()

        # Initialize lists to store positive and negative documents
        pos_doc_list = []
        neg_doc_list = []

        for query, passage_row in zip(query_list, passages):
            pos_doc_list.append(passage_row['passage_text'])

            # Randomly select a negative document from another row
            random_index = random.choice([i for i in range(len(query_list)) if i != index])
            neg_doc_list.append(passages[random_index]['passage_text'])

        # Instantiate the TokenizerAndEmbedder class
        spm_path = "your_spm_model_path.model"
        w2v_path = "your_word2vec_model_path.model"
        tokenizer_and_embedder = TokenizerAndEmbedder(spm_path, w2v_path)

        # Generate triplets for the current chunk
        return generate_triplets(query_list, pos_doc_list, neg_doc_list, tokenizer_and_embedder, self.data_max_len)


class TokenizerAndEmbedder:
    def __init__(self, spm_path, w2v_path):
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(spm_path)
        self.w2v = Word2Vec.load(w2v_path)

    def tokenize_and_embed(self, input_tensor):
        tokenized_tensor = self.sp.EncodeAsPieces(input_tensor)
        embedded_tensor = [self.w2v.wv[word] for word in tokenized_tensor if word in self.w2v.wv]

        return embedded_tensor





###Function definitions###

In [74]:
def generate_triplets(queries, pos_docs, neg_docs, tokenizer_and_embedder, data_max_len):

    # Truncate queries and documents to data_max_len
    queries = [query[:data_max_len] for query in queries]
    pos_docs = [[doc[:data_max_len] for doc in docs] for docs in pos_docs]
    neg_docs = [[doc[:data_max_len] for doc in docs] for docs in neg_docs]

    # Build query_tensor
    query_list = [query for query, docs in zip(queries, pos_docs) for _ in docs]
    query_tensor = torch.tensor(query_list)

    # SBuild pos_doc_tensor
    pos_doc_list = [doc for docs in pos_docs for doc in docs]
    pos_doc_tensor = torch.tensor(pos_doc_list)

    # Build neg_doc_tensor
    neg_doc_list = [doc for docs in neg_docs for doc in docs]
    neg_doc_tensor = torch.tensor(neg_doc_list)

    # Verification
    print("pos_doc_tensor:", pos_doc_tensor)
    print("neg_doc_tensor:", neg_doc_tensor)
    print("query_tensor:", query_tensor)

    # tokenize and embed the tensors
    pos_doc_tensor_embedded = [tokenizer_and_embedder.tokenize_and_embed(doc) for doc in pos_doc_list]
    neg_doc_tensor_embedded = [tokenizer_and_embedder.tokenize_and_embed(doc) for doc in neg_doc_list]

    # Sort the tensors
    # Calculate the lengths of the documents in pos_doc_tensor and neg_doc_tensor
    pos_doc_lengths = [len(doc) for doc in pos_doc_tensor]
    neg_doc_lengths = [len(doc) for doc in neg_doc_tensor]

    # Determine which tensor has the longest documents
    if max(pos_doc_lengths) > max(neg_doc_lengths):
        longest_tensor = pos_doc_tensor
    else:
        longest_tensor = neg_doc_tensor

    # Sort the longest tensor along with the other tensors
    sorted_indices = torch.argsort(longest_tensor.apply(len))  # Sort indices based on document length
    pos_doc_tensor = pos_doc_tensor[sorted_indices]
    neg_doc_tensor = neg_doc_tensor[sorted_indices]
    query_tensor = query_tensor[sorted_indices]

    # Verification
    print("pos_doc_tensor:", pos_doc_tensor)
    print("neg_doc_tensor:", neg_doc_tensor)
    print("query_tensor:", query_tensor)

    return (query_tensor, pos_doc_tensor, neg_doc_tensor)


def tokenize_and_embed(input_tensor, sp, w2v):
    tokenized_tensor = sp(input_tensor)
    embedded_tensor = w2v(tokenized_tensor)

    return embedded_tensor



###Main program logic###

In [75]:
# Create dataset and dataloader
dataset = ChunkedTripletDataset(train_set_path, chunk_size=chunk_size, data_max_len=data_max_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Total number of rows: 82326
Total number of row groups: 83


###Test the dataloader###

In [76]:
for batch in dataloader:
    query_tensor, pos_doc_tensor, neg_doc_tensor = batch
    print("Query tensor size:", query_tensor.size())
    print("Positive document tensor size:", pos_doc_tensor.size())
    print("Negative document tensor size:", neg_doc_tensor.size())

TypeError: Cannot convert pyarrow.lib.Table to pyarrow.lib.RecordBatch