<a href="https://colab.research.google.com/github/annanasnas/askqe/blob/main/baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BioMQM + gemma-2-2b-it

In [1]:
import os
from google.colab import userdata

GH_TOKEN = userdata.get('GH_TOKEN')
HF_TOKEN = userdata.get('HF_TOKEN')

In [2]:
from huggingface_hub import login
login(token=HF_TOKEN)

In [3]:
!git clone https://{GH_TOKEN}@github.com/annanasnas/askqe.git

Cloning into 'askqe'...
remote: Enumerating objects: 1199, done.[K
remote: Counting objects: 100% (110/110), done.[K
remote: Compressing objects: 100% (105/105), done.[K
remote: Total 1199 (delta 58), reused 7 (delta 3), pack-reused 1089 (from 2)[K
Receiving objects: 100% (1199/1199), 52.72 MiB | 10.33 MiB/s, done.
Resolving deltas: 100% (928/928), done.
Updating files: 100% (1042/1042), done.


In [14]:
!pip install -q -U bitsandbytes accelerate

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25h

## 1. Question Generation (QG)

### 1.1 Fact Generation

In [4]:
import json
from tqdm.notebook import tqdm

In [None]:
import torch
from transformers import pipeline, BitsAndBytesConfig
if torch.cuda.is_available():
    torch.cuda.empty_cache()

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

pipe = pipeline(
    "text-generation",
    model="google/gemma-2-2b-it",
    model_kwargs={"quantization_config": quantization_config},
    device_map="auto",
)
pipe.tokenizer.padding_side = "left"

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Device set to use cuda:0


In [None]:
atomic_fact_prompt = """Task: You will be given an English sentence. Your goal is to identify a list of atomic facts from the sentence. Atomic fact is a short sentence conveying one piece of information. Output the list of atomic facts in Python list format without giving any additional explanation. Do not output as code format (```python```).

*** Example Starts ***
Sentence: The number of accessory proteins and their function is unique depending on the specific coronavirus.
Atomic facts: ['The number of accessory proteins is unique depending on the specific coronavirus.', 'The function of accessory proteins is unique depending on the specific coronavirus.']
*** Example Ends ***

Sentence: {{sentence}}
Atomic facts: """

In [12]:
from collections import deque

BATCH_SIZE = 100
input_file = f"/content/askqe/biomqm/dev_with_backtranslation.jsonl"
output_file = f"/content/askqe/baseline/askqe_atomic_facts.jsonl"

data_buffer = deque()

processed_count = 0 # checkpoint
if os.path.exists(output_file):
    with open(output_file, "r", encoding="utf-8") as f:
        processed_count = sum(1 for _ in f)

with open(input_file, "r", encoding="utf-8") as f:
    total_lines = sum(1 for _ in f)


def data_generator():
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            yield data

def prompt_generator(source_gen):
    generated_so_far = 0
    for data in source_gen:
        if "src" in data:
            ################# DELETE LATER #################
            if data.get("lang_tgt") != "es": #temporary
                continue
            ################################################
            if generated_so_far < processed_count:
                generated_so_far += 1
                continue
            data_buffer.append(data)
            prompt = atomic_fact_prompt.replace("{{sentence}}", data["src"])
            full_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
            yield full_prompt
        else:
            continue

data_gen = data_generator()
prompts_gen = prompt_generator(data_gen)


with open(output_file, "a", encoding="utf-8") as f_out:
    pipeline_iterator = pipe(prompts_gen, batch_size=BATCH_SIZE, max_new_tokens=1024, return_full_text=False)
    for out in tqdm(pipeline_iterator, total=total_lines, initial=processed_count):
          current_data = data_buffer.popleft()

          response = out[0]["generated_text"].strip()
          if response.endswith("<end_of_turn>"):
              response = response[:-len("<end_of_turn>")].strip()

          current_data["atomic_facts"] = response

          f_out.write(json.dumps(current_data, ensure_ascii=False) + "\n")


In [4]:
import pandas as pd

file_path = "/content/askqe/baseline/askqe_atomic_facts.jsonl"

df = pd.read_json(file_path, lines=True)

df.head()

Unnamed: 0,src,tgt,ref,system,lang_src,lang_tgt,annotator,errors_src,errors_tgt,doc_id,bt_tgt,atomic_facts
0,"However, in the last years several step forwar...","Sin embargo, en los últimos años se han dado p...","[Sin embargo, en los últimos años se han dado ...",talp_upc_run2,en,es,AB/enes,[],"[{'term': 'años', 'startIndex': 28, 'endIndex'...",doc98,"However, in recent years steps have been taken...",['Several steps forwards in the field of preci...
1,In this review we focused on some of these ele...,En esta revisión nos centraremos en algunos de...,[En esta revisión nos centramos en algunos de ...,talp_upc_run2,en,es,AB/enes,[],"[{'term': 'revisión', 'startIndex': 8, 'endInd...",doc98,In this review we will focus on some of these ...,['This review focuses on some of these element...
2,"Although several progresses have been made, at...","Aunque se han producido varios avances, en el ...","[Aunque se han logrado varios progresos, en es...",talp_upc_run2,en,es,AB/enes,[],"[{'term': 'momento', 'startIndex': 46, 'endInd...",doc98,"Although several advances have been made, at p...","['Several progresses have been made.', 'At the..."
3,In this review we focused on some of these ele...,En esta revisión nos centraremos en algunos de...,[En esta revisión nos centramos en algunos de ...,talp_upc_run2,en,es,AB/enes,[],"[{'term': 'aumentada', 'startIndex': 38, 'endI...",doc98,In this review we will focus on some of these ...,['This review focused on some of these element...
4,Transurethral resection of the bladder represe...,La resección transuretral de la vejiga represe...,[La resección transuretral de la vejiga repres...,talp_upc_run2,en,es,AB/enes,[],"[{'term': 'manejo', 'startIndex': 102, 'endInd...",doc98,Transurethral resection of the bladder represe...,['Transurethral resection of the bladder is a ...


### 1.2 Entailment classification

In [15]:
import json, os
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
import torch
import ast


input_file = f"/content/askqe/baseline/askqe_atomic_facts.jsonl"
output_file = f"/content/askqe/baseline/askqe_atomic_facts_filtered.jsonl"

with open(input_file, "r") as f:
    total_lines = sum(1 for _ in f)

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("potsawee/deberta-v3-large-mnli")
model = AutoModelForSequenceClassification.from_pretrained("potsawee/deberta-v3-large-mnli").to(device)

with open(input_file, "r") as f_in, open(output_file, "w") as f_out:
    for line in tqdm(f_in, total=total_lines):
        data = json.loads(line)
        if "src" in data:
            src = data.get("src")
            facts = src.split("', '")
            clean_facts = []

            for i, item in enumerate(facts):

                fact = item.strip()

                if i == 0:
                    fact = fact.lstrip('[')
                if i == len(facts) - 1:
                    fact = fact.rstrip(']')

                if fact[0] in ["'", '"']:
                    fact = fact[1:]
                if fact[-1] in ["'", '"']:
                    fact = fact[1:]

                inputs = tokenizer.batch_encode_plus(
                    batch_text_or_text_pairs=[(src, fact)],
                    add_special_tokens=True, return_tensors="pt",
                ).to(device)
                with torch.no_grad():
                    logits = model(**inputs).logits # neutral is already removed
                    probs = torch.softmax(logits, dim=-1)[0]
                    if probs[1] < 0.5:
                        clean_facts.append(fact)
                    else:
                        print(f"Contradict: {src}, {fact}")
            data["atomic_facts"] = clean_facts
        f_out.write(json.dumps(data, ensure_ascii=False) + "\n")


100%|██████████| 801/801 [01:25<00:00,  9.41it/s]
