In [None]:
import numpy as np
import scipy
import pandas as pd

In [None]:
#from langchain_community.document_loaders import WikipediaLoader
#docs = WikipediaLoader(query='Barack Obama', load_max_docs=2, doc_content_chars_max=100000).load()
#print(len(docs))
# print(docs[0].metadata)
# print(docs[0].page_content)
from datasets import load_dataset
# Load the 'cjlovering/natural-questions-short' dataset with `datasets.load_dataset`.
# Later cells read its 'train' and 'validation' splits.
ds = load_dataset('cjlovering/natural-questions-short')
class DocWrapper(object):
  """Minimal document-like wrapper around one dataset example.

  Exposes the two attributes the text splitters read in this notebook:
  `page_content` (taken from the example's 'contexts' field) and `metadata`
  (a fresh, empty dict per instance).
  """

  def __init__(self, s):
    self.page_content = s['contexts']
    self.metadata = {}

  def __repr__(self):
    # Without this, printing a list of wrappers only shows object addresses.
    text = str(self.page_content)
    preview = text if len(text) <= 60 else text[:57] + '...'
    return f'DocWrapper(page_content={preview!r})'
# Wrap every training example so it can be handed to the text splitters,
# then show the first few wrappers and the text of the first one.
docs = list(map(DocWrapper, ds['train']))
print(docs[:10])
print(docs[0].page_content)

In [None]:
import os
# Chunks are cached on disk: only (re)build them when the cache file is absent.
create_chunks = not os.path.isfile('data/chunks.json')

In [None]:
from langchain_text_splitters import RecursiveCharacterTextSplitter, TokenTextSplitter, CharacterTextSplitter, NLTKTextSplitter,\
                                     SentenceTransformersTokenTextSplitter, SpacyTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings

if create_chunks:
    def get_splits(splitter, data):
        """Split `data` with `splitter` and return the text of every resulting chunk.

        Args:
            splitter: object exposing `split_documents(documents)`.
            data: iterable of document-like objects to split.

        Returns:
            list with the `page_content` of each split, in splitter order.
        """
        # Use the argument, not the notebook-global `docs`, so the helper works on any corpus.
        return [split.page_content for split in splitter.split_documents(data)]

    # Candidate chunking strategies; the next cell decides which of them are actually run.
    rec_char_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=200, add_start_index=True
    )
    rec_char_splitter_2 = RecursiveCharacterTextSplitter(
        chunk_size=2000, chunk_overlap=200, add_start_index=True
    )
    token_splitter = TokenTextSplitter(
        encoding_name='gpt2', chunk_size=100, chunk_overlap=0
    )
    char_splitter = CharacterTextSplitter(
        separator='\n', is_separator_regex=False
    )
    nltk_splitter = NLTKTextSplitter(
        separator='\n\n', language='english'
    )
    sentence_transformer_splitter = SentenceTransformersTokenTextSplitter(
        chunk_overlap=50, model_name='sentence-transformers/all-mpnet-base-v2', tokens_per_chunk=None
    )
    spacy_splitter = SpacyTextSplitter(
        separator='\n\n', pipeline='en_core_web_sm', max_length=1000000
    )
    semantic_chunker = SemanticChunker(
        embeddings=HuggingFaceEmbeddings(), buffer_size=1, add_start_index=True
    )
    #... add more, like the SemanticChunker, etc.                    # DONE
    #... use different parameters for these (especially chunk sizes) # TODO: How wide of a parameter space should we use? Default is 4000/200

In [None]:
if create_chunks:
    all_splits, all_ids = [], []
    # Full candidate list, currently disabled in favour of the two-splitter subset below:
    # [rec_char_splitter, rec_char_splitter_2, token_splitter, char_splitter, nltk_splitter,
    #  sentence_transformer_splitter]
    splitters = [rec_char_splitter, token_splitter]
    for splitter in splitters:
        splitter_name = splitter.__class__.__name__
        if splitter_name != 'SemanticChunker':
            label = f'{splitter_name}_{splitter._chunk_size}_{splitter._chunk_overlap}'
        else:
            # The original code does not read chunk size/overlap for this class.
            label = splitter_name
        print(label)
        splits = get_splits(splitter, docs)
        all_splits.extend(splits)
        # The id carries the full label (not just the class name) so that two splitters of
        # the same class with different sizes do not produce colliding ids.
        all_ids.extend(f'id_{label}_{ii}' for ii in range(len(splits)))

    df = pd.DataFrame.from_dict({'chunk': all_splits, 'id': all_ids})
    # `to_json` does not create missing directories; make sure the cache folder exists.
    os.makedirs('data', exist_ok=True)
    df.to_json('data/chunks.json')
else:
    # Cache hit: reuse the chunks written by a previous run.
    df = pd.read_json('data/chunks.json')
    all_splits = df['chunk'].tolist()
    all_ids = df['id'].tolist()

print(all_splits[0])
print(len(all_splits))
print(all_splits[:5])
# print(all_ids)

In [None]:
# from sklearn.feature_extraction.text import TfidfVectorizer
# from scipy import sparse


# class BM25(object):
#     def __init__(self, b=0.75, k1=1.6):
#         self.vectorizer = TfidfVectorizer(norm=None, smooth_idf=False)
#         self.b = b
#         self.k1 = k1

#     def fit(self, X):
#         ''' Fit IDF to documents X '''
#         self.vectorizer.fit(X)
#         y = super(TfidfVectorizer, self.vectorizer).transform(X)
#         self.avdl = y.sum(1).mean()

#     def transform(self, q, X):
#         ''' Calculate BM25 between query q and documents X '''
#         b, k1, avdl = self.b, self.k1, self.avdl

#         # apply CountVectorizer
#         X = super(TfidfVectorizer, self.vectorizer).transform(X)
#         len_X = X.sum(1).A1
#         q, = super(TfidfVectorizer, self.vectorizer).transform([q])
#         assert sparse.isspmatrix_csr(q)

#         # convert to csc for better column slicing
#         X = X.tocsc()[:, q.indices]
#         denom = X + (k1 * (1 - b + b * len_X / avdl))[:, None]
#         # idf(t) = log [ n / df(t) ] + 1 in sklearn, so it needs to be converted
#         # to idf(t) = log [ n / df(t) ] by subtracting 1
#         idf = self.vectorizer._tfidf.idf_[None, q.indices] - 1.
#         numer = X.multiply(np.broadcast_to(idf, X.shape)) * (k1 + 1)                                                          
#         return (numer / denom).sum(1).A1


In [None]:
from fast_bm25 import BM25
# Whitespace-tokenize every chunk; the resulting token lists are used both to build
# the BM25 index and as the candidate corpus passed to `get_top_n` in later cells.
my_splits = [chunk.split() for chunk in all_splits]
bm25 = BM25(my_splits)

In [None]:
# Smoke-test the index with a whitespace-tokenized sample query.
sample_query = 'largest city in Japan'.split()
results = bm25.get_top_n(sample_query, my_splits, n=10)
print(results)

In [None]:
# bm25 = BM25()
# bm25.fit(all_splits)

In [None]:
def retrieve_from_query(query, k):
    """Return the text of the top-`k` BM25 hits for a whitespace-tokenized `query`.

    Args:
        query: query string; split on whitespace before being passed to the index.
        k: number of results requested from `bm25.get_top_n`.

    Returns:
        list of strings, one per hit, each rebuilt by joining the hit's tokens with
        single spaces (presumably hits come back as token lists from `my_splits`;
        verify against the fast_bm25 API).
    """
    top_hits = bm25.get_top_n(query.split(), my_splits, n=k)
    return [' '.join(tokens) for tokens in top_hits]

In [None]:
# import chromadb
# chroma_client = chromadb.Client()
# try:
#     collection = chroma_client.create_collection(name='my_collection')
# except:
#     chroma_client.delete_collection(name='my_collection')
#     collection = chroma_client.create_collection(name='my_collection')
# collection.add(
#     documents=all_splits[:100],
#     ids=all_ids[:100]
# )

In [None]:
# results = collection.query(
#     query_texts=['where is Obama from'],
#     n_results=2 # how many results to return
# )
# print(results)

In [None]:

from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoTokenizer

# Number of examples per DataLoader batch.
batch_size = 8
# Tokenizer loaded from the 'facebook/opt-125m' checkpoint.
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')

# Now pass the retrieval results to the LLM (basically, RAG with frozen components)

# Baseline: deduplicate retrieval results somehow

# The neural net we train:
# - For each question:
#    - Encode each chunk and output a score (how to encode?)
#    - Turn all scores over all chunks into a distribution P(d|q)
#    - For each chunk get the NLL of the correct answer as another distribution Q(d|q) (from a RAG 'open domain QA' dataset; use HF transformers to get the NLL)
#    - Minimize the KL div of those two distributions (KL_div(P||Q)) (https://arxiv.org/abs/2301.12652 - REPLUG)
# - Add loss term for the length of the sequence (number of tokens)
# - Train

# For Chroma
# def add_retrieval_results(ex):
#     ex['retrieved'] = collection.query(
#         query_texts=[ex['questions'][0]['input_text']],
#         n_results=10 # the higher the better
#     )['documents']
#     # one alternative here for speed could be to not use Chroma but some kind of sparse search like Elastic
#     return ex

# For (fast) BM25
def add_retrieval_results(ex, k=10):
    """Attach BM25 retrieval results to one (still nested) dataset example.

    Must run before `preprocess`, because it reads the question from
    `ex['questions'][0]['input_text']`.

    Args:
        ex: example dict, modified in place.
        k: number of chunks to retrieve (defaults to the previously hard-coded 10).

    Returns:
        the same dict with a new 'retrieved' entry holding the retrieved chunk texts.
    """
    ex['retrieved'] = retrieve_from_query(ex['questions'][0]['input_text'], k=k)
    return ex

# Data pre-processing (tokenize and de-nesting)
def preprocess(ex):
    """De-nest the question and answer of one example into plain strings.

    Replaces `ex['questions']` with the 'input_text' of its first entry and
    `ex['answers']` with the 'span_text' of its first entry. The example is
    modified in place and returned.

    No tokenization happens here: `training_step` tokenizes whole batches instead.
    """
    question_text = ex['questions'][0]['input_text']
    ex['questions'] = question_text

    answer_text = ex['answers'][0]['span_text']
    ex['answers'] = answer_text

    return ex

In [None]:
def _build_dataset_and_loader(split, shuffle):
    """Run the retrieval + de-nesting pipeline on one dataset split.

    Args:
        split: one split of `ds` (e.g. `ds['train']`).
        shuffle: whether the returned DataLoader shuffles its examples.

    Returns:
        (dataset, loader): the processed `Dataset` and a `DataLoader` over it.
    """
    dataset = Dataset.from_dict(split[:])
    dataset = dataset.map(add_retrieval_results)
    dataset = dataset.map(preprocess)
    # Drop helper columns that training does not read; applied to every split so that
    # train and validation batches have the same fields.
    unused = [c for c in ('contexts', 'has_correct_context', 'name', 'id') if c in dataset.column_names]
    dataset = dataset.remove_columns(unused)
    return dataset, DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


tiny_train_dataset, tiny_train_loader = _build_dataset_and_loader(ds['train'], shuffle=True)
# Validation data is not shuffled so evaluation order is stable between runs.
tiny_valid_dataset, tiny_valid_loader = _build_dataset_and_loader(ds['validation'], shuffle=False)

# train_dataset = ds['train'].map(add_retrieval_results)
# train_dataset = train_dataset.map(tokenize)
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# valid_dataset = ds['valid'].map(add_retrieval_results)
# valid_dataset = valid_dataset.map(tokenize)
# valid_loader = DataLoader(valid_dataset, batch_size=batch_size)

In [None]:
# Inspect the first example of the processed training dataset.
print(tiny_train_dataset[0])

In [None]:
# print(tiny_train_dataset)
# print(tiny_train_dataset['retrieved'])
# print(tiny_train_dataset['input'])
# print(tiny_train_dataset['tokenized_docs'])
# print(tiny_train_dataset[0])
# print(tiny_train_dataset['questions'])

In [None]:
# Pull a single batch from the training loader and look at its question/answer fields.
batch = next(iter(tiny_train_loader))
print(batch['questions'])
# print(batch['retrieved'][0])
# print(batch['input']['input_ids'][0])
print(batch['answers'])

In [None]:
import torch
# The default DataLoader collate function transposes each example's list of retrieved
# chunks, so batch['retrieved'] holds K lists (one per rank) of B strings each.
num_ranks, num_examples = len(batch['retrieved']), len(batch['retrieved'][0])
flat_chunks = [x for retr in batch['retrieved'] for x in retr]  # rank-major: index = k * B + b
print(batch['retrieved'])
print(flat_chunks)
tokenized = tokenizer(flat_chunks, padding=True, return_tensors='pt')  # tokenize once, reuse below
print(tokenized)
print(tokenized['input_ids'].shape)
# The flat order is rank-major, so view it as K x B x L first and only then swap the
# leading axes to get B x K x L. Reshaping straight to (B, K, -1) would pair chunks
# with the wrong questions.
for key in ('input_ids', 'attention_mask'):
    tokenized[key] = torch.reshape(tokenized[key], (num_ranks, num_examples, -1)).transpose(0, 1).contiguous()
print(tokenized['input_ids'].shape)
# print(tokenizer(batch['retrieved']))
# print([tokenizer(x, padding=True, return_tensors='pt') for x in batch['retrieved']])
# torch.stack([tokenizer(x, padding=True, return_tensors='pt')['input_ids'] for x in batch['retrieved']])

In [None]:
import wandb
wandb.init(
    # set the wandb project where this run will be logged
    project='chunking_experiments',

    # track hyperparameters and run metadata
    config={
    # Must mirror the lr passed to AdamW in `configure_optimizers` (1e-4); the previously
    # logged 0.1 did not match what was actually trained with.
    'learning_rate': 1e-4,
    'architecture': 'Transformer',
    'dataset': 'NQ',
    # NOTE(review): informational only; the Trainer in this notebook is created
    # without `max_epochs`, so this value is not enforced. Confirm the intended count.
    'epochs': 10,
    }
)

In [None]:
from transformers import AutoModelForCausalLM, AutoModel
# Checkpoint id shared by the causal LM and its tokenizer.
model_id = 'facebook/opt-125m'
# Causal LM; ReplugTransformer stores it as `self.llm` and disables gradients on its parameters.
llm_model = AutoModelForCausalLM.from_pretrained(model_id)
# Rebinds the notebook-level `tokenizer` using the same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
import torch
import torch.nn as nn
class Encoder(nn.Module):
    """Small Transformer encoder that maps a token sequence to one pooled vector.

    Tokens are embedded, passed through a stack of `nn.TransformerEncoderLayer`s and
    averaged over the sequence dimension, ignoring padded positions.

    NOTE(review): no positional encoding is added to the embeddings, so the encoder
    is insensitive to token order. Confirm this is intended.
    """

    def __init__(self, vocab_size):
        """
        Args:
            vocab_size: number of rows of the token embedding table.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_dim = 256
        self.dim_ff = 1024
        self.n_head = 4
        self.n_layers = 4
        self.embedding = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.emb_dim)
        self.encoding_layer = nn.TransformerEncoderLayer(d_model=self.emb_dim, nhead=self.n_head, dim_feedforward=self.dim_ff, batch_first=True)
        self.encoding_block = nn.TransformerEncoder(self.encoding_layer, num_layers=self.n_layers)

    def forward(self, input, mask):
        """Encode a batch of token ids into pooled embeddings.

        Args:
            input: N x T tensor of token ids.
            mask: N x T padding mask, truthy where a position is padding (the
                convention of `src_key_padding_mask`), or None for no padding.

        Returns:
            N x emb_dim tensor: mean of the encoder outputs over non-padded positions.
        """
        x = self.embedding(input)
        x = self.encoding_block(x, src_key_padding_mask=mask)
        if mask is None:
            return torch.mean(x, dim=1)

        # A plain mean would also average the outputs at padded positions, making the
        # embedding depend on how much padding the batch happened to need.
        is_pad = mask.to(torch.bool).unsqueeze(-1)                   # N x T x 1
        summed = x.masked_fill(is_pad, 0.0).sum(dim=1)               # N x E
        n_tokens = (~is_pad).sum(dim=1).clamp(min=1).to(x.dtype)     # N x 1, avoids 0-division
        return summed / n_tokens

In [None]:
import lightning as L
from tqdm import tqdm

from lightning.pytorch.demos import Transformer

# TODO by Andrew: fill in the rest from here: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
class ReplugTransformer(L.LightningModule):
    """Trainable re-ranker distilled from a frozen causal LM (REPLUG-style, arXiv 2301.12652).

    For every question the trainable `Encoder` scores each retrieved chunk by the dot
    product between the question embedding and the chunk embedding. The frozen LLM
    scores the same chunks by how likely it finds the gold answer when the chunk is
    prepended to the question. Both score vectors are normalised over the K chunks and
    the re-ranker is trained with a KL-divergence loss against the LLM distribution.

    Relies on the notebook-level `llm_model`, `tokenizer` and `wandb` run.
    """

    def __init__(self, vocab_size):
        """
        Args:
            vocab_size: vocabulary size used for the encoder's embedding table.
        """
        super().__init__()
        self.encoder = Encoder(vocab_size)
        self.llm = llm_model  # LLM to get NLLs for reference distribution in KL div #TODO: separate
        for param in self.llm.parameters():
            param.requires_grad = False

    @staticmethod
    def _padding_mask(attention_mask):
        """Turn an attention mask (1 = real token) into a boolean padding mask (True = pad)."""
        # `as_tensor` avoids the copy-construct warning of `torch.tensor(existing_tensor)`.
        return torch.as_tensor(attention_mask, dtype=torch.bool).logical_not()

    @staticmethod
    def _repeat_per_doc(tensor, num_docs):
        """B x N -> (B * num_docs) x N, each row repeated `num_docs` times consecutively."""
        n_rows, width = tensor.shape
        return tensor.unsqueeze(1).expand(-1, num_docs, -1).reshape(n_rows * num_docs, width)

    def forward(self, questions, docs):
        """Score every retrieved chunk against its question.

        Args:
            questions: mapping with 'input_ids' and 'attention_mask', each B x S.
            docs: mapping with 'input_ids' and 'attention_mask', each B x K x L.

        Returns:
            B x K tensor of unnormalised relevance scores (dot products).
        """
        question_ids = questions['input_ids']
        question_pad = self._padding_mask(questions['attention_mask'])
        doc_ids = docs['input_ids']
        doc_pad = self._padding_mask(docs['attention_mask'])

        n_examples, n_docs, doc_len = doc_ids.shape
        q_emb = self.encoder(question_ids, question_pad)                        # B x E

        # Fold the K chunks into the batch dimension so the encoder runs once.
        flat_doc_ids = torch.reshape(doc_ids, (n_examples * n_docs, doc_len))   # BK x L
        flat_doc_pad = torch.reshape(doc_pad, (n_examples * n_docs, -1))        # BK x L
        d_embs = self.encoder(flat_doc_ids, flat_doc_pad)                       # BK x E
        d_embs = torch.reshape(d_embs, (n_examples, n_docs, -1))                # B x K x E

        # Contract only the embedding axis; unlike a bare `squeeze()` this keeps the
        # B x K shape even when B == 1 or K == 1.
        return torch.einsum('bke,be->bk', d_embs, q_emb)

    def new_llm_pass(self, questions, docs, answers):
        """Score each chunk by the frozen LLM's likelihood of the gold answer.

        The LLM input for every (question, chunk) pair is `chunk tokens + question tokens`
        followed by the answer tokens. All answer positions are scored in a single
        teacher-forced forward pass; because the LM is causal this yields the same
        per-token predictions as feeding the answer one token at a time.

        Args:
            questions: mapping with 'input_ids' and 'attention_mask', each B x S.
            docs: mapping with 'input_ids' and 'attention_mask', each B x K x L.
            answers: mapping with 'input_ids' and 'attention_mask', each B x A.

        Returns:
            B x K tensor with the mean log-probability of the non-padded answer tokens
            (higher means the chunk makes the answer more likely). No gradient attached.
        """
        doc_ids, doc_attn = docs['input_ids'], docs['attention_mask']
        question_ids, question_attn = questions['input_ids'], questions['attention_mask']
        answer_ids, answer_attn = answers['input_ids'], answers['attention_mask']

        n_examples, n_docs, doc_len = doc_ids.shape
        answer_len = answer_ids.shape[1]
        n_rows = n_examples * n_docs

        # Context = chunk followed by question, one row per (question, chunk) pair.
        context_ids = torch.cat(
            [torch.reshape(doc_ids, (n_rows, doc_len)), self._repeat_per_doc(question_ids, n_docs)], dim=1
        )                                                                       # BK x (L + S)
        context_attn = torch.cat(
            [torch.reshape(doc_attn, (n_rows, doc_len)), self._repeat_per_doc(question_attn, n_docs)], dim=1
        )

        # Both parts are padded independently, which leaves padding between chunk and
        # question and possibly at the very end of the context. A stable sort on the
        # attention mask moves all padding to the front while keeping the real tokens
        # in order, so the last context position is always a real token.
        order = torch.sort(context_attn, dim=1, stable=True).indices
        context_ids = torch.gather(context_ids, 1, order)
        context_attn = torch.gather(context_attn, 1, order)
        context_len = context_ids.shape[1]

        target_ids = self._repeat_per_doc(answer_ids, n_docs)                   # BK x A
        target_attn = self._repeat_per_doc(answer_attn, n_docs)                 # BK x A
        # NOTE(review): if the tokenizer prepends a special token to every text, that token
        # is scored as the first answer token; confirm whether the answers should be
        # tokenized with add_special_tokens=False.

        llm_input = torch.cat([context_ids, target_ids], dim=1)                 # BK x (L + S + A)
        llm_attn = torch.cat([context_attn, target_attn], dim=1)

        # The LLM is a frozen teacher: no dropout (Lightning puts all sub-modules into
        # train mode during fit) and no autograd graph.
        self.llm.eval()
        with torch.no_grad():
            logits = self.llm(input_ids=llm_input, attention_mask=llm_attn)['logits']  # BK x (L+S+A) x V

        # The prediction for answer token i is read at position context_len - 1 + i.
        step_logits = logits[:, context_len - 1:context_len - 1 + answer_len, :]       # BK x A x V
        # Use log-probabilities rather than raw logits so scores are comparable across chunks.
        log_probs = torch.nn.functional.log_softmax(step_logits.float(), dim=-1)
        token_scores = torch.gather(log_probs, -1, target_ids.unsqueeze(-1)).squeeze(-1)  # BK x A

        # Average over real answer tokens only; padded answer positions carry no signal.
        weights = target_attn.to(token_scores.dtype)
        all_scores = (token_scores * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return torch.reshape(all_scores, (n_examples, n_docs))                  # B x K

    def training_step(self, batch, batch_idx):
        """Compute the KL loss between the re-ranker and LLM distributions for one batch.

        Args:
            batch: collated batch with 'questions' (B strings), 'answers' (B strings) and
                'retrieved' (K lists of B strings, as produced by the default collate).
            batch_idx: index of the batch (unused).

        Returns:
            scalar loss tensor.
        """
        # Pre-process data here (rather than in the dataset) to tokenize whole batches at once.
        questions = batch['questions']
        retrieved = batch['retrieved']
        answers = batch['answers']
        n_docs, n_examples = len(retrieved), len(retrieved[0])
        device = self.device  # follow whatever device Lightning placed the module on

        tokenized_questions = tokenizer(questions, padding=True, return_tensors='pt').to(device)
        flat_docs = [x for retr in retrieved for x in retr]  # rank-major: index = k * B + b
        tokenized_docs = tokenizer(flat_docs, padding=True, return_tensors='pt').to(device)
        # View the rank-major flat batch as K x B x L, then swap to B x K x L. Reshaping
        # directly to (B, K, -1) would pair chunks with the wrong questions.
        for key in ('input_ids', 'attention_mask'):
            tokenized_docs[key] = torch.reshape(tokenized_docs[key], (n_docs, n_examples, -1)).transpose(0, 1).contiguous()
        tokenized_answers = tokenizer(answers, padding=True, return_tensors='pt').to(device)

        # Re-ranker distribution over the K chunks (log-space, as KL-div expects).
        reranker_output = self(tokenized_questions, tokenized_docs)
        rerank_dist = torch.nn.functional.log_softmax(reranker_output, dim=1)

        # Reference distribution from the frozen LLM.
        llm_output = self.new_llm_pass(tokenized_questions, tokenized_docs, tokenized_answers)
        llm_dist = torch.nn.functional.softmax(llm_output, dim=1)

        # Compute loss = kldiv(scores, nlls)
        # see docs for notation https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
        loss = torch.nn.functional.kl_div(rerank_dist, llm_dist, reduction='batchmean')
        # Log with the real batch size; the last batch of an epoch can be smaller.
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=n_examples)
        wandb.log({'train_loss': loss.item()})
        return loss

    def configure_optimizers(self):
        """Optimise only the encoder; the LLM parameters stay frozen."""
        return torch.optim.AdamW(self.encoder.parameters(), lr=1e-4)

# Build the re-ranker with the tokenizer's vocabulary size and fit it with a
# default-configured Lightning trainer on the train/validation loaders.
model = ReplugTransformer(vocab_size=tokenizer.vocab_size)
trainer = L.Trainer()
trainer.fit(
    model=model,
    train_dataloaders=tiny_train_loader,
    val_dataloaders=tiny_valid_loader,
)

In [None]:
#for batch in ds['train']:
#  batch_scores = net(batch) # Bx1 float
#  retrieval_scores = F.softmax(batch_scores)
#  nll_scores = llm(batch).mean() # Bx1 float
#  loss = kldiv(batch_scores, nll_scores)
#  # TODO: loss += weight_coeff * sum([len(x) for x in batch])
#  loss.backward()

# Process:
# - Retrieve with high n_results
# - Rerank using the above neural net
# - Take the top-k from that reranker
# - Give to LLM
# - Evaluate

# Once the above actually works:
# - Add additional deduplication steps?
# - Add a diversity term to the loss?
# - Add metadata to the chunk (e.g. prefix 'TokenTextSplitter' to each such chunk for the net() call, NOT the llm() call)


# Document score side:
# Input: BxS
# After retrieval: (BxS, BxKxL) (assume tokenized)
# After encoding: (BxSxE -> BxE), (BxKxE) (where E is embedding/hidden dim, separate encoders)
# After scoring: BxK (do bmm/einsum to get this)
# After normalize: BxK (logsoftmax, make sure you don't do double log)

# NLL score side:
# Input: BxS
# After encoding: Bx(S+L) -> Bx(Correct Answer)xV -> select logit of correct token, take mean -> Bx1

# 'Context: The largest city in Japan is Tokyo. Question: What is the largest city in Japan? Answer: ' -> what is the logit for Tokyo?
# Autoregressively generate for max_length_of_correct_answers_in_batch, giving you some matrix BxM,
# zero out all the components that you don't need (ie if the correct answer was 1 token, all the other tokens you can zero out),
# then take the mean of the non-zero elements to get Bx1 NLL tensor

# ^ TODO: double check that this is how Replug does it

# After loss: Bx1 (kldiv)

# Experimental protocol:
# 1. Validate that you get the above shapes
# 2. Print the BxK scores over time, make sure they change
# 3. 'Spike' the retrievals with 7 bad examples and 1 correct one, make sure the mass shifts to the correct one
# 4. Make sure train loss goes down (first on the spiked one, then on single batch, then all batches)
# 5. Run with fast BM25
