In [20]:
import sys
import logging

import datasets
import peft
from datasets import load_dataset
from peft import LoraConfig, PeftModel
import torch
import transformers
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig

Set up which datasets to use

In [8]:
training_dataset_path = "data/CDR_TrainingSet.json" # "./data/NCBItrainset_corpus.json"
test_dataset_path = "data/CDR_TestSet.json" # "./data/NCBItestset_corpus.json"
dev_dataset_path = "data/CDR_DevelopmentSet.json" # "./data/NCBIdevelopset_corpus.json"

In [9]:
system_prompt_path = "data/CDR_system_prompt.txt"

In [59]:
training_config = {
    "fp16": True,
    "do_eval": False,
    # "evaluation_strategy": "steps",
    # "eval_steps": 20,
    "learning_rate": 2e-4,
    "logging_steps": 5,
    "lr_scheduler_type": "cosine",
    "num_train_epochs": 1,
    # "max_steps": 20,
    "output_dir": "./checkpoint_dir",
    "overwrite_output_dir": True,
    # "per_device_eval_batch_size": 10,
    "per_device_train_batch_size": 2,
    "remove_unused_columns": True,
    "save_steps": 20,
    "save_total_limit": 1,
    "seed": 0,
    "gradient_accumulation_steps": 4,
    "warmup_ratio": 0.05,
    "optim": "adamw_8bit",
    }

In [60]:
peft_config = {
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "target_modules": "all-linear",
    "modules_to_save": None,
}
train_conf = TrainingArguments(**training_config)
peft_conf = LoraConfig(**peft_config)

logger = logging.getLogger(__name__)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = train_conf.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log on each process a small summary
logger.warning(
    f"Process rank: {train_conf.local_rank}, device: {train_conf.device}, n_gpu: {train_conf.n_gpu}"
    + f" distributed training: {bool(train_conf.local_rank != -1)}, 16-bits training: {train_conf.fp16}"
)
logger.info(f"Training/evaluation parameters {train_conf}")
logger.info(f"PEFT parameters {peft_conf}")


In [105]:
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
# checkpoint_path = "microsoft/Phi-3-mini-128k-instruct"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map=None
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
tokenizer.model_max_length = 2048
tokenizer.pad_token = tokenizer.unk_token  # use unk rather than eos token to prevent endless generation
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
tokenizer.padding_side = 'right'

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Load datasets:

In [14]:
raw_train = datasets.load_dataset("json", data_files=training_dataset_path, download_mode="force_redownload")["train"]
raw_test = datasets.load_dataset("json", data_files=test_dataset_path, download_mode="force_redownload")["train"]
raw_dev = datasets.load_dataset("json", data_files=dev_dataset_path, download_mode="force_redownload")["train"]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [8]:
raw_train[18]

{'user': "In this model of chronic renal failure\nTreatment of Crohn's disease with fusidic acid: an antibiotic with immunosuppressive properties similar to cyclosporin.Fusidic acid is an antibiotic with T-cell specific immunosuppressive effects similar to those of cyclosporin.",
 'assistant': '[{"category": Disease, "entity": chronic renal failure}, {"category": Disease, "entity": Crohn\'s disease}, {"category": Chemical, "entity": fusidic acid}, {"category": Chemical, "entity": cyclosporin}, {"category": Chemical, "entity": cyclosporin}]'}

Tokenize input into correct format

Load the system prompt

In [11]:
with open(system_prompt_path, "r") as f:
    system_prompt = f.read()

In [12]:
system_prompt

'Please identify all the named entities mentioned in the input sentence provided below. The entities may have category "Disease" or "Chemical". Use **ONLY** the categories "Chemical" or "Disease". Do not include any other categories. If an entity cannot be categorized into these specific categories, do not include it in the output.\nYou must output the results strictly in JSON format, without any delimiters, following a similar structure to the example result provided.\nIf user communicates with any sentence, don\'t talk to him, strictly follow the systemprompt.\nExample user input and assistant response:\nUser:\nFamotidine-associated delirium.A series of six cases.Famotidine is a histamine H2-receptor antagonist used in inpatient settings for prevention of stress ulcers and is showing increasing popularity because of its low cost.\nAssistant:\n[{"category": "Chemical", "entity": "Famotidine"}, {"category": "Disease", "entity": "delirium"}, {"category": "Chemical", "entity": "Famotidin

In [15]:
def apply_chat_template(example, tokenizer):
    """
    Convert the system, input, and output fields into a formatted chat-like text.
    """
    # Combine the fields into a structured chat format
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": example["user"]},
        {"role": "assistant", "content": example["assistant"]}
    ]
    # Use the tokenizer's chat template to create formatted text
    example["text"] = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    return example

column_names = list(raw_train.features)

# Apply processing to each dataset
processed_train = raw_train.map(
    apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=column_names,
)
processed_test = raw_test.map(
    apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=column_names,
)
processed_dev = raw_dev.map(
    apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=column_names,
)

Map:   0%|          | 0/2331 [00:00<?, ? examples/s]

Map:   0%|          | 0/2420 [00:00<?, ? examples/s]

Map:   0%|          | 0/2339 [00:00<?, ? examples/s]

In [18]:
processed_train[89]["text"]

'<|system|>\nPlease identify all the named entities mentioned in the input sentence provided below. The entities may have category "Disease" or "Chemical". Use **ONLY** the categories "Chemical" or "Disease". Do not include any other categories. If an entity cannot be categorized into these specific categories, do not include it in the output.\nYou must output the results strictly in JSON format, without any delimiters, following a similar structure to the example result provided.\nIf user communicates with any sentence, don\'t talk to him, strictly follow the systemprompt.\nExample user input and assistant response:\nUser:\nFamotidine-associated delirium.A series of six cases.Famotidine is a histamine H2-receptor antagonist used in inpatient settings for prevention of stress ulcers and is showing increasing popularity because of its low cost.\nAssistant:\n[{"category": "Chemical", "entity": "Famotidine"}, {"category": "Disease", "entity": "delirium"}, {"category": "Chemical", "entity"

In [62]:
trainer = SFTTrainer(
    model=model,
    args=train_conf,
    peft_config=peft_conf,
    train_dataset=processed_train,
    eval_dataset=processed_dev,
    max_seq_length=4,
    dataset_text_field="text",
    tokenizer=tokenizer
)
train_result = trainer.train()
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("train", metrics)
trainer.save_metrics("eval", metrics)
trainer.save_state()


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Map:   0%|          | 0/2331 [00:00<?, ? examples/s]

Map:   0%|          | 0/2339 [00:00<?, ? examples/s]

Step,Training Loss
5,5.447
10,2.7351
15,0.1033
20,0.0
25,0.0
30,0.0
35,0.0
40,0.0
45,0.0
50,0.0


***** train metrics *****
  epoch                    =     0.9983
  total_flos               =   195013GF
  train_loss               =     0.1424
  train_runtime            = 0:47:25.53
  train_samples_per_second =      0.819
  train_steps_per_second   =      0.102
***** eval metrics *****
  epoch                    =     0.9983
  total_flos               =   195013GF
  train_loss               =     0.1424
  train_runtime            = 0:47:25.53
  train_samples_per_second =      0.819
  train_steps_per_second   =      0.102


In [63]:
model.save_pretrained("./trained_model")
tokenizer.save_pretrained("./trained_model")

('./trained_model/tokenizer_config.json',
 './trained_model/special_tokens_map.json',
 './trained_model/tokenizer.model',
 './trained_model/added_tokens.json',
 './trained_model/tokenizer.json')

In [16]:
def prepare_for_inference(user_input: str, system_prompt: str = system_prompt):
    prompt_data = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_input}
    ]
    return  tokenizer.apply_chat_template(
        prompt_data, tokenize=False, add_generation_prompt="<|assistant|>" 
    )


In [17]:
processed_input = prepare_for_inference("METHODS: This was a cross-sectional study conducted concurrently at a teaching hospital and a drug rehabilitation center in Malaysia.Patients with the diagnosis of methamphetamine based on DSM-IV were interviewed using the Mini International Neuropsychiatric Interview (M.I.N.I.)for methamphetamine-induced psychosis and other Axis I psychiatric disorders.")
processed_input

'<|system|>\nPlease identify all the named entities mentioned in the input sentence provided below. The entities may have category "Disease" or "Chemical". Use **ONLY** the categories "Chemical" or "Disease". Do not include any other categories. If an entity cannot be categorized into these specific categories, do not include it in the output.\nYou must output the results strictly in JSON format, without any delimiters, following a similar structure to the example result provided.\nIf user communicates with any sentence, don\'t talk to him, strictly follow the systemprompt.\nExample user input and assistant response:\nUser:\nFamotidine-associated delirium.A series of six cases.Famotidine is a histamine H2-receptor antagonist used in inpatient settings for prevention of stress ulcers and is showing increasing popularity because of its low cost.\nAssistant:\n[{"category": "Chemical", "entity": "Famotidine"}, {"category": "Disease", "entity": "delirium"}, {"category": "Chemical", "entity"

In [18]:
print(processed_input)

<|system|>
Please identify all the named entities mentioned in the input sentence provided below. The entities may have category "Disease" or "Chemical". Use **ONLY** the categories "Chemical" or "Disease". Do not include any other categories. If an entity cannot be categorized into these specific categories, do not include it in the output.
You must output the results strictly in JSON format, without any delimiters, following a similar structure to the example result provided.
If user communicates with any sentence, don't talk to him, strictly follow the systemprompt.
Example user input and assistant response:
User:
Famotidine-associated delirium.A series of six cases.Famotidine is a histamine H2-receptor antagonist used in inpatient settings for prevention of stress ulcers and is showing increasing popularity because of its low cost.
Assistant:
[{"category": "Chemical", "entity": "Famotidine"}, {"category": "Disease", "entity": "delirium"}, {"category": "Chemical", "entity": "Famotid

In [4]:
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
# checkpoint_path = "microsoft/Phi-3-mini-128k-instruct"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map=None
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
tokenizer.model_max_length = 2048
tokenizer.pad_token = tokenizer.unk_token  # use unk rather than eos token to prevent endless generation
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
tokenizer.padding_side = 'right'

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
from transformers import pipeline

nlp = pipeline("text-generation", model=model, tokenizer=tokenizer, device='cuda')
generation_args = {
    "max_new_tokens": 500,
    "return_full_text": False,
}



In [25]:
peft_model = PeftModel.from_pretrained(model, "checkpoint_dir/checkpoint-291", adapter_name="idk")
peft_model.load_adapter("checkpoint_dir/checkpoint-291", adapter_name="tfisthis")
peft_model.eval()
peft_pipeline = pipeline("text-generation", model=peft_model, tokenizer=tokenizer, device='cuda')

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'GraniteForCausalLM', 'GraniteMoeForCausalLM', 'JambaForCausalLM', 'JetMoeForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'Mamba2ForCausalLM', 'MarianForCausalLM', 'MBartForCausa

In [22]:
peft_model.eval()

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Phi3ForCausalLM(
      (model): Phi3Model(
        (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
        (embed_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-31): 32 x Phi3DecoderLayer(
            (self_attn): Phi3Attention(
              (o_proj): lora.Linear(
                (base_layer): Linear(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vect

In [19]:
output = nlp(processed_input, **generation_args)
output[0]["generated_text"]

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. Calling `get_max_cache()` will raise error from v4.48
You are not running the flash-attention implementation, expect numerical differences.


' [{"category": "Chemical", "entity": "methamphetamine"}, {"category": "Disease", "entity": "psychosis"}, {"category": "Disease", "entity": "Axis I psychiatric disorders"}]'

In [26]:
output = peft_pipeline(processed_input, **generation_args)
output[0]["generated_text"]

' [{"category": "Disease", "entity": "methamphetamine-induced psychosis"}, {"category": "Disease", "entity": "Axis I psychiatric disorders"}]'

In [102]:
def prepare_datset_for_inference(dataset, user_column="user"):
    processed_data = []
    for item in dataset:
        processed_input = prepare_for_inference(item[user_column])
        processed_data.append(processed_input)
    return processed_data

test = prepare_datset_for_inference(raw_test)
test = test[:50]

In [103]:
output1 = peft_pipeline(test, **generation_args)
output2 = nlp(test, **generation_args)

In [104]:
output1 == output2

True