In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
# !pip install transformers

In [3]:
from transformers import (
    RobertaConfig,
    RobertaModel,
    AutoTokenizer,
    pipeline,
    AutoModel,
    RobertaTokenizerFast,
    RobertaForQuestionAnswering,
    AutoModelForQuestionAnswering,
    AutoConfig
)

In [4]:
from pathlib import Path

# Link PrimeQA

In [5]:
import sys
# Make the local PrimeQA checkout importable by adding it to the module search path.
# NOTE(review): this is an absolute path on one user's machine; the notebook will not
# find `primeqa` elsewhere unless this is adjusted or the package is installed.
sys.path.append("/Users/nishparadox/dev/uah/nasa-impact/llm-experiments/ibm-llm/primeqa")

In [6]:
from primeqa.mrc.models.task_model import ModelForDownstreamTasks

In [7]:
from primeqa.mrc.models.heads.extractive import ExtractiveQAHead, EXTRACTIVE_HEAD

# Load model

In [8]:
def load_model(config, pth, device="cpu", freeze_base=False, freeze_llm=False):
    """Build the downstream-task model, select its QA head and optionally freeze parts.

    Args:
        config: Model configuration handed to ``ModelForDownstreamTasks.from_config``.
        pth: Passed through as ``pretrained_model_name_or_path``.
        device: Device the model is moved to via ``model.to``.
        freeze_base: When true, every parameter under ``model.roberta`` gets
            ``requires_grad = False``.
        freeze_llm: When true, every parameter under ``model.lm_head`` gets
            ``requires_grad = False``.
            NOTE(review): assumes the model exposes an ``lm_head`` attribute — confirm.

    Returns:
        The model with the ``"qa_head"`` task head set, placed on ``device``.
    """
    model = ModelForDownstreamTasks.from_config(
        config,
        pretrained_model_name_or_path=pth,
        task_heads=EXTRACTIVE_HEAD,
    )
    model.set_task_head("qa_head")
    model.to(device)

    # Report the freezing choices before applying them.
    print(f"freeze_base={freeze_base}, freeze_llm={freeze_llm}")

    # Collect the sub-module names to freeze, base encoder first.
    submodules_to_freeze = []
    if freeze_base:
        submodules_to_freeze.append("roberta")
    if freeze_llm:
        submodules_to_freeze.append("lm_head")

    for submodule_name in submodules_to_freeze:
        for weight in getattr(model, submodule_name).parameters():
            weight.requires_grad = False

    return model

In [9]:
# Checkpoint directory used by the rest of the notebook. Alternatives stay commented
# out so exactly one assignment is live (previously the earlier ones were dead stores).
# checkpoint_path = "/Users/nishparadox/dev/uah/nasa-impact/llm-experiments/nasa_wiki_v6/sq2v6/train-watbertv6-squad-2ep/"
# checkpoint_path = "tmp/checkpoint-679/"
# checkpoint_path = "tmp/checkpoint-1359/"
checkpoint_path = "tmp/checkpoint-5438/"

In [10]:
# Read the configuration from the checkpoint directory, then build the model from
# the weights file in that same directory and place it on the "mps" device.
config = AutoConfig.from_pretrained(Path(checkpoint_path) / "config.json")
model = load_model(config, Path(checkpoint_path) / "pytorch_model.bin", device="mps")

{"time":"2023-05-04 13:55:51,877", "name": "ExtractiveQAHead", "level": "INFO", "message": "Loading dropout value 0.1 from config attribute 'hidden_dropout_prob'"}
{"time":"2023-05-04 13:55:52,427", "name": "RobertaModelForDownstreamTasks", "level": "INFO", "message": "Setting task head for first time to 'None'"}
freeze_base=False, freeze_llm=False


In [11]:
# Load the tokenizer from the same checkpoint directory the model was read from.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)

In [12]:
# model = pipeline("question-answering").model
# tokenizer = pipeline("question-answering").tokenizer

In [13]:
model

RobertaModelForDownstreamTasks(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(65536, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768, padding_idx=0)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (

# QA pipeline

In [14]:
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.schema import HumanMessage, SystemMessage

{"time":"2023-05-04 13:55:59,177", "name": "numexpr.utils", "level": "INFO", "message": "Note: NumExpr detected 10 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8."}
{"time":"2023-05-04 13:55:59,177", "name": "numexpr.utils", "level": "INFO", "message": "NumExpr defaulting to 8 threads."}


# Putting it together

In [15]:
import warnings
# Silence warnings whose message starts with "Length of IterableDataset".
warnings.filterwarnings("ignore", message="Length of IterableDataset")

In [16]:
from typing import List

In [17]:
import json

In [18]:
from qa_gen import (
    LangchainSimpleQuestionGenerator,
    CachedLangchainSimpleQuestionGenerator,
    QuestionAnswerGenerator
)

In [19]:
def cmr_document_iterator(path):
    """Yield the stripped, non-empty ``"text"`` value of each record in a JSON file.

    The file at ``path`` is expected to contain a JSON array of objects. It is read
    and closed up front, so a partially consumed generator (e.g. a single ``next()``
    call or an early ``break``) does not keep a file handle open.

    Args:
        path: Location of the JSON file.

    Yields:
        str: The whitespace-stripped ``"text"`` field of each record. Records whose
        ``"text"`` is missing, ``null`` or blank are skipped.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        records = json.load(f)

    for record in records:
        # `or ""` also covers an explicit JSON null, which `.get("text", "")` would
        # return as None and then fail on `.strip()`.
        text = (record.get("text") or "").strip()
        if text:
            yield text

In [20]:
def load_document(filepath):
    """Return the ``page_content`` of the first document ``TextLoader`` loads from ``filepath``."""
    loaded_documents = TextLoader(filepath).load()
    first_document = loaded_documents[0]
    return first_document.page_content

In [None]:
# i = 0
# for text in cmr_document_iterator("data/cmr.json"):
#     if i > 3:
#         break
#     print(text[:128])
#     i += 1

In [None]:
cmr_iterator = cmr_document_iterator("data/cmr.json")

In [None]:
# document = load_document("data/test.md")
document = next(cmr_iterator)
document

In [None]:
CachedLangchainSimpleQuestionGenerator._PROMPT_SYSTEM_QUESTION

In [None]:
# question_generator = LangchainSimpleQuestionGenerator(ChatOpenAI(temperature=0.0), n_questions=10)
question_generator = CachedLangchainSimpleQuestionGenerator(ChatOpenAI(temperature=0.0), n_questions=10)

In [None]:
question_generator.generate_questions_from_text(document)

In [None]:
qa_generator = QuestionAnswerGenerator(
    question_generator=question_generator,
    model=model,
    tokenizer=tokenizer
)

In [None]:
qa_data = qa_generator.generate_questions_from_text(document, cutoff_threshold=0.1)

In [None]:
len(qa_data)

In [None]:
# for q, a in zip(question_generator.generate_questions_from_text(document), qa_data):
#     print(f"Q = {q}")
#     print(a)
#     print("-"*20)

In [None]:
# Print one "question | answer | (score)" line per generated QA record.
# A named loop variable is used instead of `_`: the value is actually read, and
# binding `_` at the notebook's top level shadows IPython's last-output variable.
for qa in qa_data:
    print(qa["question"], "|", qa["answer"], "|", f"({qa['score']})")

### SQuAD v2 (sq2) conversion

In [21]:
from itertools import groupby
import uuid

In [22]:
import json

In [23]:
def convert_to_sq2(data):
    """Convert flat QA records into a SQuAD v2.0-style nested dictionary.

    Args:
        data: Iterable of mappings, each with the keys ``"context"``, ``"question"``,
            ``"answer"`` and ``"start"``.

    Returns:
        dict: ``{"version": "v2.0", "data": [...]}`` with one entry per distinct
        context, in first-seen order. Each entry has a ``title`` (SHA-256 hex digest
        of the context, so it is stable across runs) and a single paragraph holding
        the context and all of its questions. Every question is marked
        ``is_impossible=False`` and receives a fresh random ``id``.
    """
    import hashlib  # local import keeps this cell runnable on its own

    # Group by context with a dict rather than itertools.groupby: groupby only merges
    # adjacent records, so unsorted input would yield duplicate entries per context.
    qas_by_context = {}
    for record in data:
        qas_by_context.setdefault(record["context"], []).append(
            dict(
                # A real boolean: the string "false" is truthy and serialises to a
                # JSON string instead of the JSON literal `false`.
                is_impossible=False,
                question=record["question"],
                answers=[dict(text=record["answer"], answer_start=record["start"])],
                id=uuid.uuid4().hex,
            )
        )

    articles = [
        dict(
            # Built-in hash() of a str is salted per process, so it cannot serve as a
            # reproducible identifier; a content digest can.
            title=hashlib.sha256(context.encode("utf-8")).hexdigest(),
            paragraphs=[dict(qas=qas, context=context)],
        )
        for context, qas in qas_by_context.items()
    ]
    return dict(version="v2.0", data=articles)

In [24]:
# with open("data/dump.json", "w") as f:
#     json.dump(convert_to_sq2(qa_data), f)

# In Bulk

In [25]:
from tqdm import tqdm

In [26]:
import pickle

In [27]:
def bulk_qa_generator(doc_iterator, question_generator, qa_generator, n_docs=10, cutoff_threshold=0.1):
    """Generate QA pairs for up to ``n_docs`` documents and return them in SQuAD v2 form.

    Args:
        doc_iterator: Iterable yielding document texts.
        question_generator: Object exposing ``generate_questions_from_text(document)``.
        qa_generator: Object exposing
            ``generate_questions_from_text(document, cutoff_threshold=...)`` that
            returns an iterable of QA records.
        n_docs: Maximum number of documents to take from ``doc_iterator``. Exactly this
            many are processed (the earlier counter-based loop processed one extra).
            ``None`` processes every document.
        cutoff_threshold: Forwarded unchanged to ``qa_generator``.

    Returns:
        The result of ``convert_to_sq2`` applied to all collected QA records.
    """
    from itertools import islice  # local import keeps this cell runnable on its own

    qas = []
    # islice caps the number of documents without a manual counter and never pulls
    # more than n_docs items from the underlying iterator.
    for document in tqdm(islice(doc_iterator, n_docs), total=n_docs):
        # NOTE(review): the returned questions are not used here. The call is retained
        # because it may populate the generator's cache — confirm whether qa_generator
        # already triggers it, in which case this line can be dropped.
        question_generator.generate_questions_from_text(document)
        qas.extend(
            qa_generator.generate_questions_from_text(
                document,
                cutoff_threshold=cutoff_threshold,
            )
        )
    return convert_to_sq2(qas)

In [30]:
question_generator = CachedLangchainSimpleQuestionGenerator(ChatOpenAI(temperature=0.0), n_questions=10)

In [31]:
qa_generator = QuestionAnswerGenerator(
    question_generator=question_generator,
    model=model,
    tokenizer=tokenizer
)

In [32]:
cmr_iterator = cmr_document_iterator("data/cmr.json")

In [33]:
cmr_qa_sq2 = bulk_qa_generator(
    cmr_iterator,
    question_generator,
    qa_generator,
    n_docs=10
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:20<00:00, 14.01s/it]


In [34]:
len(cmr_qa_sq2["data"])

9

In [36]:
# cmr_qa_sq2

In [37]:
# Serialise the question generator object to disk with pickle.
# NOTE(review): presumably this is done to keep the generator's cached questions
# between sessions — confirm. Only unpickle this file from a trusted source.
with open("data/cached_question_generator.pkl", "wb") as f:
    pickle.dump(question_generator, f)

In [39]:
with open("data/cmr_qa_sqv2.json", "w") as f:
    json.dump(cmr_qa_sq2, f)

# Dataset test

In [None]:
# !pip install datasets==2.3.2

In [None]:
from primeqa.mrc.processors.preprocessors.squad import SQUADPreprocessor

In [None]:
from datasets import load_dataset

In [None]:
sq2 = load_dataset("squad_v2")

In [None]:
# NOTE(review): this rebinds `sq2`, replacing the "squad_v2" dataset loaded in the
# previous cell. It reads data/dump.json, but the only cell that writes that file is
# commented out, while the bulk export is written to data/cmr_qa_sqv2.json — confirm
# which file is intended here.
sq2 = load_dataset("json", "plain_text", data_files="data/dump.json")["train"]

In [None]:
sq2[0]

In [None]:
preprocessor = SQUADPreprocessor(
    stride=128,
    tokenizer=tokenizer,
    load_from_cache_file=False,
    negative_sampling_prob_when_has_answer=1.0,
    negative_sampling_prob_when_no_answer=1.0,   
)

In [None]:
examples_train, data_train = preprocessor.process_train(sq2)

In [None]:
examples_train[0]