In [1]:
import pandas as pd

df = pd.read_json('/data_vault/hexai/Biolaysum/biolaysumm2024_data/eLife_train.jsonl', lines=True)
df.head(3)

Unnamed: 0,lay_summary,article,headings,keywords,id
0,"In the USA , more deaths happen in the winter ...","In temperate climates , winter deaths exceed s...","[Abstract, Introduction, Results, Discussion, ...",[epidemiology and global health],elife-35500-v1
1,Most people have likely experienced the discom...,Whether complement dysregulation directly cont...,"[Abstract, Introduction, Results, Discussion, ...","[microbiology and infectious disease, immunolo...",elife-48378-v2
2,The immune system protects an individual from ...,Variation in the presentation of hereditary im...,"[Abstract, Introduction, Results, Discussion, ...","[microbiology and infectious disease, immunolo...",elife-04494-v1


In [2]:
val_df = pd.read_json('/data_vault/hexai/Biolaysum/biolaysumm2024_data/eLife_val.jsonl', lines=True)
val_df.head(3)

Unnamed: 0,lay_summary,article,headings,keywords,id
0,The DNA in genes encodes the basic information...,Cell-fate reprograming is at the heart of deve...,"[Abstract, Introduction, Results, Discussion, ...",[developmental biology],elife-15477-v3
1,Klebsiella pneumoniae is a type of bacteria th...,"Klebsiella pneumoniae is a respiratory , blood...","[Abstract, Introduction, Results, Discussion, ...","[microbiology and infectious disease, immunolo...",elife-56656-v2
2,Malaria is one of the world's most deadly infe...,Plasmodium vivax relapse infections occur foll...,"[Abstract, Introduction, Results, Discussion, ...",[epidemiology and global health],elife-04692-v2


In [3]:
from datasets import Dataset

data = Dataset.from_pandas(df[['lay_summary', 'article']])
val_data = Dataset.from_pandas(val_df[['lay_summary', 'article']])

In [4]:
# !huggingface-cli download TheBloke/Orca-2-13B-GGUF orca-2-13b.Q5_K_S.gguf --local-dir . --local-dir-use-symlinks False

In [5]:
model_type = 'gemma2b' # orca13b

if model_type == 'gemma2b':
    model_id = "google/gemma-2b-it"
    
elif model_type == 'orca7b':
    model_id = 'microsoft/Orca-2-7b'

In [6]:
import os
import torch
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [7]:
torch.cuda.set_device(1)

In [8]:
torch.cuda.current_device()

1

In [9]:
from accelerate import PartialState

# model_id = "TheBloke/Orca-2-13B-GGUF"
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_quant_type="nf8",
    bnb_8bit_compute_dtype=torch.float16
)

os.environ['HF_TOKEN'] = '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'
os.environ['HF_HOME'] = '/data_vault/hexai/huggingface/hub/'

tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'], cache_dir=os.environ['HF_HOME'], use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config, device_map={"":PartialState().process_index}, token=os.environ['HF_TOKEN'], cache_dir=os.environ['HF_HOME']
)

lora_config = LoraConfig(
    r=4,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
!nvidia-smi

Wed Apr  3 22:36:07 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.223.02   Driver Version: 470.223.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro RTX 8000     On   | 00000000:17:00.0 Off |                  Off |
| 50%   70C    P2   153W / 182W |  12107MiB / 48601MiB |     46%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Quadro RTX 8000     On   | 00000000:73:00.0  On |                  Off |
| 33%   56C    P2    68W / 182W |    586MiB / 48592MiB |      0%      Default |
|       

In [11]:
def formatting_func(example):
    output_texts = []
    for i in range(len(example['article'])):
        messages = [
            {"role": "user",
             "content": f"""
                Summarize this document. Text: {example['article'][i]}. 
                Summary:
                """},
             {"role": "assistant",
             "content": "{}".format(example['lay_summary'][i])}
         ]
        output_texts.append(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False))
    return output_texts
    
# Print the first training example
print(formatting_func(data[:1])[0])

<bos><start_of_turn>user
Summarize this document. Text: In temperate climates , winter deaths exceed summer ones . However , there is limited information on the timing and the relative magnitudes of maximum and minimum mortality , by local climate , age group , sex and medical cause of death . We used geo-coded mortality data and wavelets to analyse the seasonality of mortality by age group and sex from 1980 to 2016 in the USA and its subnational climatic regions . Death rates in men and women ≥ 45 years peaked in December to February and were lowest in June to August , driven by cardiorespiratory diseases and injuries . In these ages , percent difference in death rates between peak and minimum months did not vary across climate regions , nor changed from 1980 to 2016 . Under five years , seasonality of all-cause mortality largely disappeared after the 1990s . In adolescents and young adults , especially in males , death rates peaked in June/July and were lowest in December/January , d

In [12]:
import transformers
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    train_dataset=data,
    eval_dataset=val_data,
    max_seq_length=700,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        #max_steps=50,
        eval_steps=5, 
        learning_rate=1e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_func
)
trainer.train()


Map:   0%|          | 0/4346 [00:00<?, ? examples/s]

Map:   0%|          | 0/241 [00:00<?, ? examples/s]

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
1,3.2081
2,3.2106
3,3.2593
4,3.1389
5,2.8914
6,3.1609
7,2.9628
8,2.8478
9,2.7377
10,2.7771


config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]



TrainOutput(global_step=1086, training_loss=2.0933922266872327, metrics={'train_runtime': 13061.3795, 'train_samples_per_second': 0.998, 'train_steps_per_second': 0.083, 'total_flos': 1.08728923312128e+17, 'train_loss': 2.0933922266872327, 'epoch': 3.0})

In [13]:
transformers.TrainingArguments

transformers.training_args.TrainingArguments

### Save model

In [14]:
# trainer.model.save_pretrained('/data/vep52/nlp/model/lora_adapter')

# Merge the adapters into the base model so you can use the model like a normal transformers model
model = trainer.model.merge_and_unload()
model.save_pretrained(f'nlp/model/{model_type}')

