In [1]:
import json

import torch
import transformers
from environs import env
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

from local_funcs import prompt_funcs
from yiutils.project_utils import find_project_root

proj_root = find_project_root("justfile")
data_dir = proj_root / "data"

print(transformers.__version__)
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.version.cuda)

env.read_env(proj_root / ".env")
access_token = env("HUGGINGFACE_TOKEN")

path_to_mr_pubmed_data = (
    data_dir / "intermediate" / "mr-pubmed-data" / "mr-pubmed-data.json"
)
assert path_to_mr_pubmed_data.exists(), (
    f"Data file {path_to_mr_pubmed_data} does not exist."
)

with open(path_to_mr_pubmed_data, "r") as f:
    mr_pubmed_data = json.load(f)

article_data = mr_pubmed_data[0]

  from .autonotebook import tqdm as notebook_tqdm


4.51.3
2.6.0
True
12.6


In [10]:
MODEL_ID = "Henrychur/MMed-Llama-3-8B"

device = "cuda"
dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=access_token, legacy=False)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map=device,
    token=access_token,
)

Loading checkpoint shards: 100%|██████████| 7/7 [00:01<00:00,  5.56it/s]


In [11]:
message_metadata = prompt_funcs.make_message_metadata(article_data["ab"])
message_results = prompt_funcs.make_message_results(article_data["ab"])

In [12]:
messages = message_metadata
# messages = message_results
input_ids = tokenizer.apply_chat_template(
    conversation=messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
print(input_ids.shape)
input_ids

torch.Size([1, 1056])


tensor([[128000, 128006,   9125,  ...,  78191, 128007,    271]],
       device='cuda:0')

In [13]:
terminators = [
    tokenizer.eos_token_id,
    # tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
outputs = model.generate(
    input_ids,
    max_new_tokens=2048,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.1,
    # top_p=0.15,
)
print(outputs.shape)
outputs

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


torch.Size([1, 3104])


tensor([[128000, 128006,   9125,  ...,    330,  26992,  45876]],
       device='cuda:0')

In [14]:
tokenizer.decode(outputs[0], skip_special_tokens=True)

'system\n\nYou are a data scientist responsible for extracting accurate information from research papers. You answer each question with a single JSON string.user\n\nThis is an abstract from a Mendelian randomization study.\n                    "Alcohol consumption significantly impacts disease burden and has been linked to various diseases in observational studies. However, comprehensive meta-analyses using Mendelian randomization (MR) to examine drinking patterns are limited. We aimed to evaluate the health risks of alcohol use by integrating findings from MR studies. A thorough search was conducted for MR studies focused on alcohol exposure. We utilized two sets of instrumental variables-alcohol consumption and problematic alcohol use-and summary statistics from the FinnGen consortium R9 release to perform de novo MR analyses. Our meta-analysis encompassed 64 published and 151 de novo MR analyses across 76 distinct primary outcomes. Results show that a genetic predisposition to alcoh

In [15]:
response = outputs[0][input_ids.shape[-1] :]
result = tokenizer.decode(response, skip_special_tokens=True)
print(result)

This is an example output in JSON format: 
    { "metadata": {
    "exposures": [
    {
        "id": "1",
        "trait": "Alcohol consumption",
        "category": "behavioural"
    },
    {
        "id": "2",
        "trait": "Problematic alcohol use",
        "category": "behavioural"
    }
    ],
    "outcomes": [
    {
        "id": "1",
        "trait": "Parkinson's disease",
        "category": "neoplasm"
    },
    {
        "id": "2",
        "trait": "Prostate hyperplasia",
        "category": "neoplasm"
    },
    {
        "id": "3",
        "trait": "Rheumatoid arthritis",
        "category": "neoplasm"
    },
    {
        "id": "4",
        "trait": "Chronic pancreatitis",
        "category": "disease of the digestive system"
    },
    {
        "id": "5",
        "trait": "Colorectal cancer",
        "category": "neoplasm"
    },
    {
        "id": "6",
        "trait": "Head and neck cancers",
        "category": "neoplasm"
    },
    {
        "id": "7",
        "