Лабораторная работа — дообучение LLM на датасете [DocVQA](https://rrc.cvc.uab.es/?ch=17&com=introduction). <br>
Датасет содержит аннотации с вопросами к документам и результаты OCR с разбиением по словам.

Будем дообучать [phi-2](https://huggingface.co/microsoft/phi-2) — модель с 2,7 млрд параметров.

In [8]:
import os
import json

import numpy as np

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig, DataCollatorForLanguageModeling
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model  
from accelerate import FullyShardedDataParallelPlugin, Accelerator

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

from datasets import load_dataset

from datetime import datetime

import matplotlib.pyplot as plt
%matplotlib inline

### 2. Тренировка модели

In [9]:
# Raw DocVQA annotations; in both JSON files the records live under the "data" field.
annotation_files = {
    'train': '/kaggle/input/docvqa/annotations_train.json',
    'val': '/kaggle/input/docvqa/annotations_val.json',
}
dataset = load_dataset('json', data_files=annotation_files, field="data")

  0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
# Inspect split sizes and available columns of the raw annotations.
dataset

DatasetDict({
    train: Dataset({
        features: ['questionId', 'question', 'question_types', 'image', 'docId', 'ucsf_document_id', 'ucsf_document_page_no', 'answers', 'data_split', 'context'],
        num_rows: 39463
    })
    val: Dataset({
        features: ['questionId', 'question', 'question_types', 'image', 'docId', 'ucsf_document_id', 'ucsf_document_page_no', 'answers', 'data_split', 'context'],
        num_rows: 5349
    })
})

In [11]:
# FSDP checkpoint settings: gather full (unsharded) model and optimizer state,
# offloaded to CPU, on every rank rather than only on rank 0.
state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=state_cfg,
    optim_state_dict_config=optim_cfg,
)

# NOTE(review): `accelerator` is not referenced by any later visible cell
# (training goes through transformers.Trainer) — confirm it is still needed.
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

In [12]:
# Load the phi-2 base model quantized to 8 bits with bitsandbytes;
# device_map="auto" lets transformers place the layers on available devices.
model_path = "microsoft/phi-2"
quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quant_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [13]:
# Left padding keeps the real tokens adjacent to the generation position.
tokenizer_kwargs = dict(
    padding_side="left",
    add_eos_token=True,
    add_bos_token=True,
    trust_remote_code=True,
    use_fast=False,  # needed for now, should be fixed soon
)
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)

# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json:   0%|          | 0.00/7.34k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [16]:
def prepare_prompt(sample):
    """Build the supervised training prompt for one DocVQA record.

    The prompt is: the OCR context plus an instruction on one line, then the
    question and the first reference answer under markdown-style headers.

    Args:
        sample: a dataset row with ``context``, ``question`` and a non-empty
            ``answers`` list.

    Returns:
        The same ``sample`` dict, with the text stored under ``'prompt'``.
    """
    prompt_lines = [
        f"Context:{sample['context']} Instruction: Answer question using context.",
        "### Question:",
        f"{sample['question']}",
        "### Answer:",
        f"{sample['answers'][0]}",
    ]
    # Trailing newline after the answer, as in the original template.
    sample['prompt'] = "\n".join(prompt_lines) + "\n"
    return sample

In [17]:
# Keep only questionId, answers and the new 'prompt' column.
# NOTE: this drops 'context', 'question' and 'question_types' from BOTH
# splits, so the validation prompts cannot be rebuilt from `dataset['val']`.
remove_columns = [
    "context",
    "question",
    "question_types",
    "image",
    "docId",
    "ucsf_document_id",
    "ucsf_document_page_no",
    "data_split",
]
dataset = dataset.map(prepare_prompt, remove_columns=remove_columns)

  0%|          | 0/39463 [00:00<?, ?ex/s]

  0%|          | 0/5349 [00:00<?, ?ex/s]

In [18]:
# Check which columns remain after building the prompts.
dataset

DatasetDict({
    train: Dataset({
        features: ['questionId', 'answers', 'prompt'],
        num_rows: 39463
    })
    val: Dataset({
        features: ['questionId', 'answers', 'prompt'],
        num_rows: 5349
    })
})

In [19]:
max_length = 2048


def tokenize_function(example):
    """Tokenize one prompt into a fixed-length, left-padded training example.

    Prompts start with the (often long) OCR context and END with the question
    and answer. Plain ``truncation=True`` cuts from the right (the tokenizer's
    default truncation side), which would drop exactly the answer the model is
    supposed to learn. Instead, over-long prompts keep their last
    ``max_length`` tokens, so only the beginning of the context is lost.

    Args:
        example: a dataset row with a ``"prompt"`` string.

    Returns:
        dict with ``input_ids`` and ``attention_mask``, each padded (on the
        tokenizer's padding side) to exactly ``max_length`` tokens.
    """
    input_ids = tokenizer(example["prompt"])["input_ids"]
    # Left-truncate: keep the tail (question + answer) of long prompts.
    input_ids = input_ids[-max_length:]
    padded = tokenizer.pad(
        {"input_ids": input_ids},
        padding="max_length",
        max_length=max_length,
    )
    return {
        "input_ids": padded["input_ids"],
        "attention_mask": padded["attention_mask"],
    }

In [20]:
# Replace the text columns with token ids / attention mask; the train split's
# column list is applied to both splits (they share the same schema).
text_columns = dataset["train"].column_names
tokenized_dataset = dataset.map(tokenize_function, remove_columns=text_columns)

  0%|          | 0/39463 [00:00<?, ?ex/s]

  0%|          | 0/5349 [00:00<?, ?ex/s]

In [21]:
# Confirm only model inputs (input_ids, attention_mask) remain.
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 39463
    })
    val: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 5349
    })
})

In [22]:
# Prepare the quantized model for training, with gradient checkpointing on.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# Adapter settings: LoRA on modules named q_proj / k_proj / v_proj / dense;
# lm_head and embed_tokens are listed in modules_to_save, so they are trained
# and saved in full alongside the adapters.
lora_targets = ["q_proj", "k_proj", "v_proj", "dense"]
fully_trained_modules = ["lm_head", "embed_tokens"]
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    target_modules=lora_targets,
    modules_to_save=fully_trained_modules,
)
model = get_peft_model(model, lora_config)

# KV cache is disabled while training; re-enable it for inference.
model.config.use_cache = False

In [23]:
# With several GPUs the layers are already spread by device_map="auto";
# mark the model as model-parallel so Trainer does not try to replicate it.
multi_gpu = torch.cuda.device_count() > 1
if multi_gpu:
    model.is_parallelizable = True
    model.model_parallel = True

In [24]:
project = "docvqa_8bit_150steps"
base_model_name = "phi2"
run_name = f"{base_model_name}-{project}"
output_dir = f"./{run_name}"
timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M')

train_args = TrainingArguments(
    output_dir=output_dir,
    run_name=f"{run_name}-{timestamp}",
    # Effective batch: 1 sample per device x 8 accumulation steps.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    max_steps=150,
    warmup_steps=5,
    learning_rate=2.5e-5,
    optim="paged_adamw_8bit",
    # Logging to TensorBoard every 5 steps; checkpoint every 50 steps.
    logging_steps=5,
    logging_dir="./logs",
    report_to="tensorboard",
    save_strategy="steps",
    save_steps=50,
)

In [25]:
# Causal-LM collator (mlm=False): labels are copied from input_ids and
# positions equal to pad_token_id are masked out of the loss.
# NOTE(review): pad_token == eos_token here, so EOS positions are masked too.
lm_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# NOTE(review): tokenized_dataset['val'] is not passed as eval_dataset, so no
# validation loss is computed during training.
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=tokenized_dataset['train'],
    data_collator=lm_collator,
)

model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

In [None]:
# Run the fine-tuning loop configured in train_args (150 optimizer steps).
trainer.train()



Step,Training Loss
5,3.7237
10,3.4771
15,3.6669
20,3.4313
25,3.5404
30,3.3689


In [None]:
# Reload the raw validation split. `dataset['val']` can't be used here: the
# prompt-building cell removed 'context', 'question' and 'question_types',
# which prepare_val_prompt and the inference loop need, and its 'prompt'
# column already contains the reference answer.
validation_ds = load_dataset(
    'json',
    data_files={'val': '/kaggle/input/docvqa/annotations_val.json'},
    field="data",
)['val']

In [None]:
# Inspect the validation split before building inference prompts.
validation_ds

In [None]:
def prepare_val_prompt(sample):
    """Build the inference prompt for one DocVQA record.

    Identical to the training prompt up to and including the ``### Answer:``
    line, but without the answer, so the model generates it.
    (The original f-string was never closed, which made this cell a
    SyntaxError.)

    Args:
        sample: a dataset row with ``context`` and ``question``.

    Returns:
        The same ``sample`` dict, with the text stored under ``'prompt'``.
    """
    full_prompt = f"""Context:{sample['context']} Instruction: Answer question using context.
### Question:
{sample["question"]}
### Answer:
"""
    sample['prompt'] = full_prompt
    return sample

In [None]:
# Add an answer-free 'prompt' column to every validation record.
validation_ds = validation_ds.map(prepare_val_prompt)

In [None]:
# Generate a short answer (at most 10 new tokens) for every validation
# question and save the outputs together with the references to JSON.
answers = []
model.eval()
for item in validation_ds:
    prompt = tokenizer(item['prompt'], return_tensors="pt").to('cuda')
    prompt_len = prompt['input_ids'].shape[1]
    generated_output = model.generate(**prompt, use_cache=True, max_new_tokens=10)
    cur_dict = {}
    # Full decoded sequence (prompt + continuation), kept for compatibility.
    cur_dict['output'] = tokenizer.batch_decode(generated_output)[0]
    cur_dict['answers'] = item['answers']
    cur_dict['questionId'] = item['questionId']
    cur_dict['question_types'] = item['question_types']
    # Only the newly generated tokens: generate() returns the prompt tokens
    # followed by the continuation for decoder-only models.
    cur_dict['prediction'] = tokenizer.decode(
        generated_output[0, prompt_len:], skip_special_tokens=True
    ).strip()
    answers.append(cur_dict)

# Write once after generation so a crash mid-loop doesn't leave an empty file.
# The model was loaded with load_in_8bit=True, hence the "8bit" label.
with open('phi2-150steps-8bit.json', 'w', encoding='utf-8') as f:
    json.dump(answers, fp=f)