In [None]:
# Clone the project repository and install its requirements.
!git clone https://github.com/Reennon/ua-gec-lora.git
!cd ua-gec-lora && pip install -r requirements.txt
!pwd && ls -a
# Install additional libs
# NOTE(review): transformers/peft/accelerate are installed from git HEAD and several
# other packages are unpinned, so a fresh run may pick up incompatible releases;
# consider pinning versions (trl is already pinned to a commit).
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install git+https://github.com/huggingface/trl.git@7630f877f91c556d9e5a3baa4b6e2894d90ff84c
!pip install ua_gec
!pip install datasets==2.16.0
!pip install nltk
!pip install wandb -q -U
# CD into the project directory
%cd ua-gec-lora
# Switch to the research branch so that the `src.packages...` imports below resolve.
!git pull origin "feature/fine-tuning-research"
!git status

In [None]:
# NOTE(review): `pipeline`, `Conversation` and `InstructionTuningGecPrompts` are not
# used in this notebook, and `Conversation` may no longer exist in recent transformers
# releases (the setup cell installs transformers from git) — confirm and drop if the
# import fails.
from transformers import AutoModelForCausalLM, pipeline, Conversation, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from src.packages.constants.error_constants import ErrorConstants
from src.packages.prompts.instruction_tuning_gec_prompts import InstructionTuningGecPrompts
from ua_gec import Corpus
from langchain.prompts import PromptTemplate
from kaggle_secrets import UserSecretsClient
import torch
import nltk
import wandb

# Sentence-tokenizer data for nltk's `sent_tokenize`. The setup cell installs an
# unpinned NLTK, and newer NLTK releases load "punkt_tab" instead of "punkt",
# so download both to work across versions.
nltk.download('punkt')
nltk.download('punkt_tab')

from nltk.tokenize import sent_tokenize

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

Load the Hugging Face and Weights & Biases secrets from Kaggle

In [None]:
# Read the API tokens from Kaggle's secret store (secret names must match exactly).
user_secrets = UserSecretsClient()
secret_hf, secret_wandb = (
    user_secrets.get_secret(secret_name)
    for secret_name in ("HUGGINGFACE_TOKEN", "wandb")
)

Log in to Hugging Face

In [None]:
# NOTE(review): the token is passed as a command-line argument, where it can be visible
# to other processes (e.g. the process list); `huggingface_hub.login(token=...)` avoids that.
!huggingface-cli login --token $secret_hf

Log in to Weights & Biases and start a training run in the project

In [None]:
wandb_project_name = 'UA-GEC LoRA fine tuning mistral 7B'

# Authenticate with the Kaggle-provided key, then open a "training" run in the project.
wandb.login(key=secret_wandb)
wandb_init_kwargs = {
    "project": wandb_project_name,
    "job_type": "training",
    "anonymous": "allow",
}
run = wandb.init(**wandb_init_kwargs)

In [None]:
# Hugging Face Hub id of the base model to fine-tune.
# Alternative local Kaggle copy (v0.1): "/kaggle/input/mistral/pytorch/7b-instruct-v0.1-hf/1"
model_name = "mistralai/Mistral-7B-Instruct-v0.2"

In [None]:
# 4-bit NF4 weights with bfloat16 compute, without double quantization.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

# Mistral is supported natively by transformers, so no remote code is needed;
# leaving `trust_remote_code` off avoids executing code downloaded from the Hub.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    # Place the whole model on the current CUDA device.
    device_map={'': torch.cuda.current_device()},
)

# NOTE(review): `pretraining_tp` is a Llama config field; presumably a no-op for Mistral.
model.config.pretraining_tp = 1
# Recompute activations during backprop to save GPU memory.
model.gradient_checkpointing_enable()

In [None]:
# Prepare the quantized model for k-bit training, then wrap it with LoRA adapters.
model = prepare_model_for_kbit_training(model)

lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"]
peft_config = LoraConfig(
    r=4,
    lora_alpha=16,
    lora_dropout=0.05,  # Conventional
    bias="none",
    target_modules=lora_target_modules,
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
)

peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()

In [None]:
# Instruction prompt in Mistral's [INST] ... [/INST] format. `{error_types}` and
# `{query}` are filled by `it_prompt.format_prompt(...)`; the text after "[/INST]"
# is what the model should produce (training samples append the reference there).
template = """[INST] Given a text ("ORIGINAL_TEXT") in Ukrainian with potential errors, correct them to fulfill the GEC (Grammar Error Correction) Task, especially tailored for Mistral 7B LLM.
Consider the provided set of error types ("ERROR_TYPES"):
{error_types}
When you identify an error ("ERROR") in the text, correct it according to the format:
("ERROR") => ("CORRECTION")
The correction should address the error without providing explicit reasoning for the change.
The resulting text ("FIXED_TEXT") should be error-free, maintaining the original information's semantics.
Focus solely on correcting Ukrainian language errors.
Ensure that the corrected text doesn't include original errors, additional text, comments, or parts of these instructions.

ORIGINAL_TEXT: {query}
FIXED_TEXT:
[/INST]"""

# langchain template exposing the two placeholders above.
it_prompt = PromptTemplate(
    template=template,
    input_variables=['query', 'error_types']
)

In [None]:
# Documents are truncated to their first N sentences (source and target alike).
max_sentences: int = 4

In [None]:
# Build one training-formatted example from the first training document to
# inspect the prompt layout; `doc`, `source`, `prompt` and `sample_text` are
# reused by later cells.
corpus = Corpus(partition="train", annotation_layer="gec-only")
doc = next(iter(corpus))

print("\n---Prompt for training:")
# Keep the first `max_sentences` sentences. `sent_tokenize` returns the sentences
# without the whitespace between them, so join with a space: joining with ""
# glues sentences together ("a.b.") in a way normalization cannot undo.
source = " ".join(sent_tokenize(doc.source)[:max_sentences])
target = " ".join(sent_tokenize(doc.target)[:max_sentences])
prompt = it_prompt.format_prompt(
    query=source,
    error_types=ErrorConstants.ERROR_TYPES
).to_string()
sample_text = ' '.join(prompt.split())
target_text = ' '.join(target.split())
# Separate the closing "[/INST]" from the reference correction.
sample_text += ' ' + target_text
print(sample_text)

In [None]:
# Left padding keeps generated tokens directly after the prompt during batched
# generation; the training cell switches the tokenizer to right padding.
# Mistral's tokenizer is native to transformers, so `trust_remote_code` is not
# needed (and leaving it off avoids executing code downloaded from the Hub).
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side='left',
)

In [None]:
# Token budget for the baseline generation below.
# NOTE: `sample_text` is the prompt *plus* the reference correction, so
# `prompt_len` (and hence `max_new_tokens`) overestimates the output length.
prompt_len = len(tokenizer.tokenize(sample_text))
tokenizer_max_len = 900  # pad/truncate length used when tokenizing model inputs
max_correction_addtional_tokens = 0.1  # corrections may add up to 10% more tokens
# Derive the multiplier from the constant instead of hard-coding 1.1, so that
# changing the constant actually changes the budget.
max_new_tokens = int(prompt_len * (1 + max_correction_addtional_tokens))

print(f"""
Prompt len: {prompt_len}
Tokenize max length: {tokenizer_max_len}
Max token difference because of corrections: {max_correction_addtional_tokens}
Max new tokens len (output without input): {max_new_tokens}
""")

In [None]:
# Fix padding token for Mistral and Phi-2 models
# NOTE(review): "[PAD]" is presumably not in Mistral's vocabulary, in which case
# `tokenizer.pad_token_id` likely resolves to the unknown-token id rather than a
# dedicated pad id — confirm by inspecting `tokenizer.pad_token_id`.
tokenizer.pad_token = "[PAD]"

In [None]:
# Encode the prompt (without the target) for a baseline generation.
# NOTE(review): this uses the raw `prompt` with the template's newlines, whereas
# training samples are whitespace-normalized — the two formats differ.
encode_kwargs = {
    "max_length": tokenizer_max_len,
    "padding": "max_length",
    "truncation": True,
    "return_tensors": "pt",
}
model_inputs = tokenizer(prompt, **encode_kwargs)

In [None]:
# Switch to inference mode (disables dropout) for the baseline generation.
peft_model = peft_model.train(False)

In [None]:
# Baseline generation (before fine-tuning) for the sample prompt.
generation_inputs = {
    input_name: model_inputs[input_name].to(device)
    for input_name in ("input_ids", "attention_mask")
}
response = peft_model.generate(**generation_inputs, max_new_tokens=max_new_tokens)
response

In [None]:
# Decode the generated ids to text and cut off the echoed prompt.
decoded_outputs = tokenizer.batch_decode(response.detach().cpu().numpy(), skip_special_tokens=True)
# NOTE(review): slicing by `len(prompt)` assumes the decoded text reproduces `prompt`
# character-for-character; tokenizer round-trips can change spacing. Slicing the
# token ids first (`response[:, model_inputs["input_ids"].shape[1]:]`) and then
# decoding would be more robust — TODO confirm.
text = decoded_outputs[0][len(prompt):]
text

In [None]:
from difflib import SequenceMatcher
import re

def normalize_spaces(text):
    """Collapse every run of whitespace into a single space and trim both ends."""
    words = text.split()
    return " ".join(words)

def highlight_changes(text1, text2):
    """Return a word-level diff from ``text1`` to ``text2`` as one ANSI-colored string.

    Both texts are split into word tokens and single punctuation marks, aligned
    with ``difflib.SequenceMatcher``, and re-joined with single spaces:

    * unchanged tokens are kept as-is;
    * deleted tokens (only in ``text1``) are red, bold and struck through;
    * inserted tokens (only in ``text2``) are green and bold;
    * replaced spans show only the new ``text2`` tokens, yellow and bold.
    """
    token_pattern = r'\w+|[^\w\s]'
    words1 = re.findall(token_pattern, text1)
    words2 = re.findall(token_pattern, text2)

    highlighted_text = []
    matcher = SequenceMatcher(None, words1, words2)
    for op, start1, end1, start2, end2 in matcher.get_opcodes():
        if op == 'equal':
            highlighted_text.extend(words1[start1:end1])
        elif op == 'delete':
            for word in words1[start1:end1]:
                # U+0336 (combining long stroke overlay) after every character
                # renders the word struck through.
                struck = '\u0336'.join(word) + '\u0336'
                highlighted_text.append('\033[91m\033[1m' + struck + '\033[0m')
        elif op == 'insert':
            for word in words2[start2:end2]:
                highlighted_text.append('\033[92m\033[1m' + word + '\033[0m')
        elif op == 'replace':
            for word in words2[start2:end2]:
                highlighted_text.append('\033[93m\033[1m' + word + '\033[0m')

    return ' '.join(highlighted_text)

def generate_original_corrected_texts(original_text, corrected_text, highlighted_comparison):
    """Mark removed words in the original text and added words in the corrected text.

    Whitespace-separated words of both texts are aligned positionally with
    ``difflib.SequenceMatcher``:

    * original words that were deleted or replaced are red and bold;
    * corrected words that were inserted or that replace something are green and bold.

    Unlike a set difference, this still marks a word removed in one place when
    the same word also occurs elsewhere in the text.

    Returns:
        ``(marked_original, marked_corrected, highlighted_comparison)``; the last
        element is passed through unchanged.
    """
    original_words = original_text.split()
    corrected_words = corrected_text.split()

    marked_original_text = list(original_words)
    marked_corrected_text = list(corrected_words)

    matcher = SequenceMatcher(None, original_words, corrected_words)
    for op, start1, end1, start2, end2 in matcher.get_opcodes():
        if op in ('delete', 'replace'):
            for i in range(start1, end1):
                marked_original_text[i] = '\033[91m\033[1m' + original_words[i] + '\033[0m'
        if op in ('insert', 'replace'):
            for j in range(start2, end2):
                marked_corrected_text[j] = '\033[92m\033[1m' + corrected_words[j] + '\033[0m'

    return (' '.join(marked_original_text), ' '.join(marked_corrected_text), highlighted_comparison)

# Compare the text sent to the model with the text it generated.
# `source` is exactly the (truncated) text embedded in `prompt`, so the diff is
# against what the model actually saw; re-splitting `doc.source` here risked
# diverging from it (and previously glued sentences together with "").
text1 = normalize_spaces(source)
# `normalize_spaces` already strips leading whitespace; slicing off the first
# character would drop a real letter whenever the output does not start with a space.
text2 = normalize_spaces(text)

highlighted_text = highlight_changes(text1, text2)

original_text, corrected_text, _ = generate_original_corrected_texts(
    original_text=text1,
    corrected_text=text2,
    highlighted_comparison=highlighted_text)

print("Original Text:")
print(original_text)
print()

print("Corrected Text:")
print(corrected_text)
print()

print("Changes comparison:")
print(highlighted_text)


In [None]:
from torch.utils.data import Dataset
from transformers import AutoTokenizer

class UAGECDataset(Dataset):
    """Torch dataset of UA-GEC instruction-tuning samples.

    Each item wraps the (optionally truncated) ``source`` of a document in the
    instruction prompt, appends the equally truncated corrected ``target``, and
    tokenizes the result to a fixed length.

    Args:
        generator: Documents exposing ``.source`` and ``.target`` strings. Lists
            and tuples are used as-is; any other iterable (e.g. a ``Corpus``) is
            materialized into a list so ``len()``, indexing and slicing work.
        device: Device to move encodings to when ``move_to_device`` is True.
        prompt: Object whose ``format_prompt(query=..., error_types=...)`` result
            has ``.to_string()`` (e.g. a langchain ``PromptTemplate``).
        max_sentences: Keep only this many leading sentences of the source and
            the target; ``None``/0 keeps all of them.
        samples: Keep only the first ``samples`` documents; ``None`` keeps all.
        tokenizer: Tokenizer to use; defaults to the notebook-level ``tokenizer``.
        max_length: Padding/truncation length; defaults to the notebook-level
            ``tokenizer_max_len``.
        append_eos: Append the tokenizer's EOS token after the target so the
            model learns where a correction ends.
        move_to_device: Move encodings to ``device`` in ``__getitem__``. Off by
            default: the ``Trainer`` moves batches to the device itself, and its
            DataLoader pins memory by default, which fails for CUDA tensors.
    """

    def __init__(
        self,
        generator: object,
        device: str,
        prompt: object,
        max_sentences=None,
        samples: int = None,  # if None, use all documents
        tokenizer=None,
        max_length: int = None,
        append_eos: bool = True,
        move_to_device: bool = False,
    ):
        self.text_data = generator if isinstance(generator, (list, tuple)) else list(generator)
        if samples is not None:
            self.text_data = self.text_data[:samples]

        self.max_sentences = max_sentences
        self.device = device
        self.prompt = prompt
        self.tokenizer = tokenizer
        self.tokenizer_max_len = max_length
        self.append_eos = append_eos
        self.move_to_device = move_to_device

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, idx):
        sample = self.text_data[idx]

        inputs: str = self._preprocess_text(
            text=sample.source,
            target_text=sample.target
        )
        encodings = self._tokenize_text(text=inputs)
        if self.move_to_device:
            encodings = encodings.to(self.device)

        return {
            'prompt': inputs,
            'input_ids': encodings["input_ids"].squeeze(0),
            'attention_mask': encodings["attention_mask"].squeeze(0),
        }

    def _preprocess_text(self, text: str, target_text: str) -> str:
        """Build the training text: normalized prompt, a space, the target (and EOS)."""
        prompt_text = self._normalize_spaces(
            text=self._format_prompt(text=self._first_sentences(text))
        )
        target_text = self._normalize_spaces(text=self._first_sentences(target_text))
        return self._add_target(text=prompt_text, target_text=target_text)

    def _first_sentences(self, text: str) -> str:
        """Keep the first ``max_sentences`` sentences (all if unset), space-joined."""
        sentences = sent_tokenize(text)
        if self.max_sentences:
            sentences = sentences[:self.max_sentences]
        # sent_tokenize drops inter-sentence whitespace; joining with "" would
        # glue sentences together ("a.b.").
        return " ".join(sentences)

    def _format_prompt(self, text: str) -> str:
        return self.prompt.format_prompt(
            query=text,
            error_types=ErrorConstants.ERROR_TYPES
        ).to_string()

    def _get_tokenizer(self):
        """Injected tokenizer, or the notebook-level ``tokenizer`` as a fallback."""
        return self.tokenizer if self.tokenizer is not None else tokenizer

    def _tokenize_text(self, text: str):
        """Tokenize ``text`` padded/truncated to a fixed length, as PyTorch tensors."""
        max_length = self.tokenizer_max_len if self.tokenizer_max_len is not None else tokenizer_max_len
        return self._get_tokenizer()(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

    def _add_target(self, text: str, target_text: str) -> str:
        """Append the reference correction after the prompt (which ends in "[/INST]")."""
        combined = f"{text} {target_text}"
        if self.append_eos:
            combined += self._get_tokenizer().eos_token or ""
        return combined

    def _normalize_spaces(self, text):
        return ' '.join(text.split())

In [None]:
import itertools

# Read only the documents needed, iterating the corpus once (previously
# `list(train_corpus)` ran twice and materialized the whole partition each time).
n_train_docs, n_val_docs = 500, 50
train_corpus = Corpus(partition="train", annotation_layer="gec-only")
selected_docs = list(itertools.islice(train_corpus, n_train_docs + n_val_docs))
train_list = selected_docs[:n_train_docs]
# Validation slice taken from the train partition, disjoint from `train_list`.
test_list = selected_docs[n_train_docs:]

In [None]:
# Training and validation datasets share the same preprocessing settings.
dataset_kwargs = dict(device=device, prompt=it_prompt, max_sentences=max_sentences)
train_dataset = UAGECDataset(generator=train_list, **dataset_kwargs)
val_dataset = UAGECDataset(generator=test_list, **dataset_kwargs)

In [None]:
# Inspect one encoded training sample: the text prompt plus input_ids/attention_mask.
train_dataset[0]

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments

fine_tuned_model_name = "mistral-7b-ua-gec"

# The KV cache is not needed for training (and conflicts with gradient
# checkpointing). The tokenizer was loaded with left padding for generation;
# switch to right padding for training. UAGECDataset tokenizes lazily in
# __getitem__, so this setting applies to the training batches.
peft_model.config.use_cache = False
tokenizer.padding_side = 'right'

training_arguments = TrainingArguments(
    output_dir=fine_tuned_model_name,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    learning_rate=2e-4,
    logging_steps=25,
    num_train_epochs=5,
    save_total_limit=2,
    save_strategy="no",
    # No checkpoints are saved and no evaluation strategy is configured, so there
    # is never a "best" checkpoint to reload; keep this off rather than implying
    # otherwise. To use `val_dataset`, configure an evaluation strategy plus a
    # matching save strategy, then re-enable.
    load_best_model_at_end=False,
    hub_private_repo=False,
    report_to='wandb',
    optim="paged_adamw_32bit",
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    # NOTE(review): the "constant" scheduler presumably ignores warmup; use
    # "constant_with_warmup" if the warmup is intended — confirm.
    warmup_ratio=0.03,
    # NOTE(review): every sample is padded to `tokenizer_max_len`, so grouping by
    # length presumably has no effect here — confirm.
    group_by_length=True,
    lr_scheduler_type="constant",
)

# No `.to(device)` call: the quantized model was already placed on the GPU by
# `device_map` when it was loaded.
# NOTE(review): `peft_model` is already a PeftModel, so passing `peft_config` again
# may be redundant depending on the pinned trl version; `dataset_text_field` is
# presumably unused because UAGECDataset already returns token ids — confirm.
trainer = SFTTrainer(
    model=peft_model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    peft_config=peft_config,
    dataset_text_field="prompt",
    tokenizer=tokenizer,
    args=training_arguments,
    max_seq_length=tokenizer_max_len,
    packing=False,
)

In [None]:
# Sanity check: the device the Trainer will train on.
training_arguments.device

In [None]:
import gc

def clear_gpu_memory():
    """Garbage-collect unreachable objects, then release cached CUDA memory.

    Prints the number of unreachable objects found by ``gc.collect()``.
    """
    # Collect first: tensors freed by the garbage collector only return their
    # blocks to PyTorch's caching allocator, and empty_cache() then releases them.
    # In the reverse order, memory freed by the collection stayed cached.
    collected = gc.collect()
    torch.cuda.empty_cache()
    print(collected)

In [None]:
# Free unreferenced objects and cached CUDA memory before training.
clear_gpu_memory()

In [None]:
import time
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

def wait_until_enough_gpu_memory(min_memory_available, max_retries=10, sleep_time=5):
    """Poll NVML until the current GPU has at least ``min_memory_available`` free bytes.

    Before each check, ``clear_gpu_memory()`` is called. Between failed checks the
    function sleeps ``sleep_time`` seconds; it does not sleep after the last one.

    Args:
        min_memory_available: Required free GPU memory, in bytes.
        max_retries: Maximum number of checks.
        sleep_time: Seconds to wait between checks.

    Raises:
        RuntimeError: If the memory is still not free after ``max_retries`` checks.
    """
    from pynvml import nvmlShutdown

    nvmlInit()
    try:
        # NOTE(review): assumes the CUDA device index equals the NVML index, which
        # may not hold when CUDA_VISIBLE_DEVICES remaps GPUs — confirm.
        handle = nvmlDeviceGetHandleByIndex(torch.cuda.current_device())
        for attempt in range(1, max_retries + 1):
            clear_gpu_memory()
            if nvmlDeviceGetMemoryInfo(handle).free >= min_memory_available:
                return
            if attempt < max_retries:
                print(f"Waiting for {min_memory_available} bytes of free GPU memory. Retrying in {sleep_time} seconds...")
                time.sleep(sleep_time)
    finally:
        # Pair every nvmlInit() with a shutdown, whether we return or raise.
        nvmlShutdown()
    raise RuntimeError(f"Failed to acquire {min_memory_available} bytes of free GPU memory after {max_retries} retries.")

# Block until at least 2 GiB of GPU memory is free before starting training.
min_memory_available = 2 * 1024 ** 3  # bytes
clear_gpu_memory()
wait_until_enough_gpu_memory(min_memory_available)

In [None]:
# Fine-tune; metrics are reported to Weights & Biases (report_to='wandb').
trainer.train()

In [None]:
# Save the trained model to the local output directory used for training.
trainer.model.save_pretrained(fine_tuned_model_name)
# Close the W&B run opened at the start of the notebook.
wandb.finish()
# Re-enable the KV cache (turned off for training) for later generation.
peft_model.config.use_cache = True

In [None]:
# NOTE(review): TrainingArguments set hub_private_repo=False, so the pushed
# repository is presumably public — confirm this is intended.
trainer.push_to_hub()