## Setup

In [11]:
import pandas as pd
import transformers
import textwrap
from transformers import LlamaTokenizer, LlamaForCausalLM
import os
import sys
from typing import List

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
)


import torch
from datasets import load_dataset
import pandas as pd
import json


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

  from .autonotebook import tqdm as notebook_tqdm



Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home/yangj13/miniconda3/envs/falcon-2/lib/python3.11/site-packages/bitsandbytes/libbitsandbytes_cpu.so
CUDA SETUP: Highest compute capability among GPUs detected: 8.0
CUDA SETUP: Detected CUDA version 117
CUDA SETUP: Loading binary /home/yangj13/miniconda3/envs/falcon-2/lib/python3.11/site-packages/bitsandbytes/libbitsandbytes_cpu.so...


'cuda'

## Data

In [12]:
path = '/mnt/isilon/wang_lab/jingye/projects/data/biolarkgsc_locs.csv'
biolark = pd.read_csv(path, sep='\t')

In [13]:
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(biolark, test_size=0.1, random_state=42)
len(train_df), len(test_df)

(205, 23)

In [14]:
def preprocess(train_df):
    rows = []
    for i, row in train_df.iterrows():
#         bad examples in train_df
#         if i in {160,7,23,76}:
#             continue
        list_hpo = row.labels.split(';')
        list_hpo_pair = [[_.split('|')[0], _.split('|')[1].split(':')]  for _ in list_hpo]
        # print(list_hpo_pair)
        # for hpo, [s,e] in list_hpo_pair:
        #     print(row.text[int(s)-50:int(e)+50])
        # break
        ans = ""
        hpos = set()
        for hpo, [s,e] in list_hpo_pair:
            if hpo in hpos:
                continue
            ans += ' ' + row.text[int(s):int(e)] + ' | ' + hpo + '\n'
            hpos.add(hpo)
        ans += ' END'
        text = {'input': row.text, 'output': ans}
        rows.append(text)
#         rows.append({"prompt":f"{row.text}\n\n###\n\n", "completion":f" {ans}"})
    return rows

In [15]:
biolark_ft = preprocess(train_df)
# biolark_ft

In [16]:
biolark_test = preprocess(test_df)

In [17]:
biolark_ft[2], biolark_test[0]

({'input': 'Hereditary isolated brachydactyly type C (OMIM 113100) mostly follows an autosomal dominant pattern of inheritance with a marked variability in expression. This phenotype has been mapped to two different loci on chromosomes 12q24 and 20q11.2. The latter locus contains the cartilage-derived morphogenetic protein (CDMP)1 gene, in which a null mutation has been found in patients with malformations restricted to the upper limbs. A more complex brachydactyly type C phenotype has been mapped to chromosome 12q24. Differences in complexity of these phenotypes have been attributed to locus heterogeneity. Clinical subclassification based on the degree of complexity of the phenotype has therefore been suggested. We present patients with a complex brachydactyly type C phenotype in whom there is considerable intra- and interfamilial variability in expression. We show that clinical subclassification based on the complexity of the brachydactyly type C phenotype related to the genetic defe

In [18]:
with open("biolark_train.json", "w") as f:
    json.dump(biolark_ft, f)
with open("biolark_test.json", "w") as f:
    json.dump(biolark_test, f)

## Dataset

In [19]:
BASE_MODEL = "/mnt/isilon/wang_lab/jingye/projects/gpt/llama_hf"

model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map='auto',
)

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

tokenizer.pad_token_id = (
    0  # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left"

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.91s/it]
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


In [20]:
train_data = load_dataset("json", data_files="biolark_train.json", split = 'train')

Downloading and preparing dataset json/default to /home/yangj13/.cache/huggingface/datasets/json/default-246a7faddfad9c69/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 11522.81it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 488.11it/s]
                                                        

Dataset json downloaded and prepared to /home/yangj13/.cache/huggingface/datasets/json/default-246a7faddfad9c69/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96. Subsequent calls will reuse this data.




In [21]:
test_data = load_dataset("json", data_files="biolark_test.json", split = 'train')

Downloading and preparing dataset json/default to /home/yangj13/.cache/huggingface/datasets/json/default-6ba8aa92f9fe6390/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 13617.87it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 739.87it/s]
                                                        

Dataset json downloaded and prepared to /home/yangj13/.cache/huggingface/datasets/json/default-6ba8aa92f9fe6390/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96. Subsequent calls will reuse this data.




In [22]:
test_data[0]

{'output': ' basal cell carcinoma | HP_0002671\n exotropia | HP_0000577\n cystic mass in the left posterior alveolar ridge | HP_0006477\n odontogenic keratocysts | HP_0010603\n palmar and plantar pitting | HP_0010612\n END',
 'input': 'Nevoid basal cell carcinoma syndrome (NBCCS) is rare in black persons. We describe an 11-year-old black boy with NBCCS who presented with exotropia and a painful, expanding, cystic mass in the left posterior alveolar ridge. Further examination revealed odontogenic keratocysts with palmar and plantar pitting. Less than 5% of reported patients with NBCCS are black. To our knowledge, this is the first report of a black patient with NBCCS presenting with exotropia and an impacted molar displaced into the orbit by an odontogenic keratocyst.'}

In [23]:
CUTOFF_LEN = 512

In [24]:
def generate_prompt(data_point):
    return f"""Input:
{data_point["input"]}
### Response:
{data_point["output"]}"""

#     return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
# ### Instruction:
# Detect the phenotype entities of the clinical abstract.
# ### Input:
# {data_point["input"]}
# ### Response:
# {data_point["output"]}"""

In [25]:
def tokenize(prompt, add_eos_token=True):
    # there's probably a way to do this with the tokenizer settings
    # but again, gotta move fast
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN,
        padding=False,
        return_tensors=None,
    )
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < CUTOFF_LEN
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)

    result["labels"] = result["input_ids"].copy()

    return result

def generate_and_tokenize_prompt(data_point):
    full_prompt = generate_prompt(data_point)
    tokenized_full_prompt = tokenize(full_prompt)
    return tokenized_full_prompt

In [26]:
# train_val = data["train"].train_test_split(
#     test_size=200, shuffle=True, seed=42
# )
train_data = (
    train_data.map(generate_and_tokenize_prompt)
)
val_data = (
    test_data.map(generate_and_tokenize_prompt)
)

                                                              

## Alpaca LoRa

In [29]:
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT= 0.05
LORA_TARGET_MODULES = [
    "q_proj",
    "v_proj",
]

BATCH_SIZE = 128
MICRO_BATCH_SIZE = 128
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LEARNING_RATE = 3e-4
TRAIN_STEPS = 40
OUTPUT_DIR = "experiments_debug"

In [30]:
model = prepare_model_for_int8_training(model)
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=LORA_TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
# model.print_trainable_parameters()



In [31]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [32]:
print_trainable_parameters(model)

trainable params: 4194304 || all params: 6742609920 || trainable%: 0.06220594176090199


In [33]:
# Verifying the datatypes.
dtypes = {}
for _, p in model.named_parameters():
    dtype = p.dtype
    if dtype not in dtypes:
        dtypes[dtype] = 0
    dtypes[dtype] += p.numel()
total = 0
for k, v in dtypes.items():
    total += v
for k, v in dtypes.items():
    print(k, v, v / total)

torch.float32 266604544 0.03954025921167333
torch.int8 6476005376 0.9604597407883266


## Training

In [38]:
training_arguments = transformers.TrainingArguments(
    per_device_train_batch_size=64, #MICRO_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    warmup_steps=10,
    max_steps=50, #TRAIN_STEPS,
    learning_rate=LEARNING_RATE,
    fp16=True,
    logging_steps=10,
    optim="adamw_torch",
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=10,
    save_steps=10,
    output_dir=OUTPUT_DIR,
    save_total_limit=3,
    load_best_model_at_end=False, #True,
#     report_to="tensorboard"
)

In [39]:
data_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

In [40]:
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=training_arguments,
    data_collator=data_collator
)
model.config.use_cache = False
# old_state_dict = model.state_dict
# model.state_dict = (
#     lambda self, *_, **__: get_peft_model_state_dict(
#         self, old_state_dict()
#     )
# ).__get__(model, type(model))

# model = torch.compile(model)



In [41]:
trainer.train()


Step,Training Loss,Validation Loss


In [47]:
model.save_pretrained(OUTPUT_DIR)

## Inference

In [2]:
BASE_MODEL = "/mnt/isilon/wang_lab/jingye/projects/gpt/llama_hf"
lora_weights = './experiments/'

In [3]:
load_8bit = False

In [4]:
model = LlamaForCausalLM.from_pretrained(
            BASE_MODEL,
            load_in_8bit=load_8bit,
#             torch_dtype=torch.float16,
            device_map = "auto"
        )

Loading checkpoint shards: 100%|██████████| 2/2 [00:32<00:00, 16.36s/it]


In [5]:
from peft import PeftModel
model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )

In [6]:
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


In [7]:
model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2

In [23]:
# if not load_8bit:
#     model.half()  # seems to fix bugs for some users.



In [9]:
# model.eval()
# if torch.__version__ >= "2" and sys.platform != "win32":
#     model = torch.compile(model)

RuntimeError: Python 3.11+ not yet supported for torch.compile

In [8]:
model.eval()


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 4096, padding_idx=0)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): Linear(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (v_proj): 

In [9]:
from transformers import GenerationConfig
generation_config = GenerationConfig(
            temperature=0.1,
            top_p=0.5,
#             top_k=40,
#             num_beams=4,
#             **kwargs,
#             bad_words_ids = [[2]]
        )

In [24]:
prompt = f"""Three successive generations in two families affected with the popliteal pterygium syndrome are reported. While expression of the syndrome was relatively mild in the first and second generation, the patients in the third generation showed the full-blown syndrome. Differential diagnosis between mildly affected patients with the popliteal pterygium syndrome and those with Van der Woude syndrome is difficult and may even be impossible. The present observations further support the hypothesis that both syndromes may in fact represent variants of the same condition.
### Response:
"""

In [25]:
inputs = tokenizer(prompt, return_tensors="pt")

In [26]:
input_ids = inputs["input_ids"].to('cuda')

In [27]:
# model.to(DEVICE)

In [28]:
with torch.no_grad():
    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=300,
    )

In [29]:
generation_output

GreedySearchDecoderOnlyOutput(sequences=tensor([[    1, 12753,  2551,   573,  1176,   800,   297,  1023, 13175, 15201,
           411,   278,  1835, 29880,   568,   284,   282,   357,  4790,  1974,
         22898,  4871,   526,  8967, 29889,  5806,  4603,   310,   278, 22898,
          4871,   471, 13774,   286,   789,   297,   278,   937,   322,  1473,
         12623, 29892,   278, 22069,   297,   278,  4654, 12623, 10018,   278,
          2989, 29899, 29890,   677, 29876, 22898,  4871, 29889,   360,  8349,
          2556, 24876, 19263,  1546,   286,   789,   368, 15201, 22069,   411,
           278,  1835, 29880,   568,   284,   282,   357,  4790,  1974, 22898,
          4871,   322,  1906,   411,  6556,   589,   399,   283,   311, 22898,
          4871,   338,  5189,   322,  1122,  1584,   367,  9301, 29889,   450,
          2198, 13917,  4340,  2304,   278, 20051,   393,  1716, 22898,   456,
           267,  1122,   297,  2114,  2755, 29161,   310,   278,  1021,  4195,
         298

In [30]:
s = generation_output.sequences[0]
output = tokenizer.decode(s)
output

'<s> Three successive generations in two families affected with the popliteal pterygium syndrome are reported. While expression of the syndrome was relatively mild in the first and second generation, the patients in the third generation showed the full-blown syndrome. Differential diagnosis between mildly affected patients with the popliteal pterygium syndrome and those with Van der Woude syndrome is difficult and may even be impossible. The present observations further support the hypothesis that both syndromes may in fact represent variants of the same condition.\n### Response:\n popliteal pterygium | HP_0009762\n Pterygium | HP_0001175\n END</s>'

In [55]:
biolark_test[22]['input']

'We compared epilepsy phenotypes with genotypes of Angelman syndrome (AS), including chromosome 15q11-13 deletions (class I), uniparental disomy (class II), methylation imprinting abnormalities (class III), and mutation in the UBE3A gene (class IV). Twenty patients were prospectively selected based on clinical cytogenetic and molecular diagnosis of AS. All patients had 6 to 72 hours of closed-circuit television videotaping and digitized electroencephalogrpahic (EEG) telemetry. Patients from all genotypic classes had characteristic EEGs with diffuse bifrontally dominant high-amplitude 1- to 3-Hz notched or triphasic or polyphasic slow waves, or slow and sharp waves. Class I patients had severe intractable epilepsy, most frequently with atypical absences and myoclonias and less frequently with generalized extensor tonic seizures or flexor spasms. Epileptic spasms were recorded in AS patients as old as 41 years. Aged-matched class II, III, and IV patients had either no epilepsy or drug-re