<div style="padding: 0.5em; background-color: #1876d1; color: #fff; font-weight: bold; font-size: 1.4em;">
    [Approach 3]  Location Mention Recognition - Fine-tunning LLM
</div>

In this Jupyter notebook, we will use LLM to extract from X (Twitter formely) tweets Location Mention from Emergency Situation.

Step :
* Retreive dataset 
* Prepare prompt for fine-tunning
* Tested model: <span style="color: red;">01-Mistral-7B-Instruct v02</span>

---
<b>#Microsoft Learn Challenge, #Zindi, #Hamad Bin Khalifa University </b>

### **Importing Library**

In [1]:
# !pip install trl
# !pip install peft
# !pip install transformers accelerate bitsandbytes>0.37.0

In [2]:
# general utils
import werpy
import numpy as np
import pandas as pd
import seaborn as sns
import os, sys, requests
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
pd.set_option('display.max_colwidth', 300)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# utils setup
current_directory = os.getcwd()
root_directory = os.path.abspath(os.path.join(current_directory, os.pardir))
sys.path.append(root_directory)

# logging & warning
import wandb, warnings
os.environ["WANDB_NOTEBOOK_NAME"] = "fine-tunne--t5-google.ipynb"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.7"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ['MALLOC_STACK_LOGGING'] = '0'
warnings.filterwarnings("ignore")

# custom utils
from utils.io import LMR_XML_Scrapper
from utils.preprocessing import Preprocess

# hugging face utils
import torch
from peft import LoraConfig
from datasets import Dataset
from dataclasses import dataclass
from trl import SFTConfig, SFTTrainer
from transformers.utils import PaddingStrategy
from transformers.integrations import WandbCallback
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import TFAutoModel, BitsAndBytesConfig

torch.cuda.empty_cache()

### **Preparing Data**

We have to dwonload the dataset and format in such way it an be used by the model

In [3]:
#LMR_XML_Scrapper(output_dir="../data/self_scrapped/raw-llm").run()

- Let concatenate out dataset

In [4]:
train_dfs = []
dev_dfs   = []
test_dfs  = []
path_dfs  = "../data/self_scrapped/raw-llm"
for filename in os.listdir(path_dfs):
    if filename.endswith(".csv"):
        file_path = os.path.join(path_dfs, filename)
        if filename.startswith("train"):
            df = pd.read_csv(file_path)
            train_dfs.append(df)
        elif filename.startswith("dev"):
            df = pd.read_csv(file_path)
            dev_dfs.append(df)
        elif filename.startswith("test_unlabeled"):
            df = pd.read_csv(file_path)
            test_dfs.append(df)

df_train = pd.concat(train_dfs, ignore_index=True) if train_dfs else pd.DataFrame()
df_test  = pd.concat(test_dfs, ignore_index=True) if test_dfs else pd.DataFrame()
df_dev   = pd.concat(dev_dfs, ignore_index=True) if dev_dfs else pd.DataFrame()

df_train.to_csv("../data/transformed/lmr_train.csv")
#df_test.to_csv("../data/transformed/lmr_test.csv")
df_dev.to_csv("../data/transformed/lmr_dev.csv")

print("TRAIN SHAPE: ", df_train.shape)
print("TEST  SHAPE: ", df_test.shape)
print("DEV   SHAPE: ", df_dev.shape)

TRAIN SHAPE:  (14392, 3)
TEST  SHAPE:  (4066, 3)
DEV   SHAPE:  (2056, 3)


- We observe in sentencess that we have hashtag, no-ascii character, stopword , ... we have to clean data 

In [5]:
df_train.head(5)

Unnamed: 0,tweet_id,plain_text,xml_text
0,ID_1022420413882744832,Nearly half of #houses checked in #fire-stricken areas deemed #uninhabitable #GO #PrayForGreece #PrayForAthens #AthensFires ἞C἟7,Nearly half of #houses checked in #fire-stricken areas deemed #uninhabitable #GO #PrayForGreece #PrayForAthens #AthensFires ἞C἟7
1,ID_1021778661895294976,RT @anadoluagency: #Greece: Death toll from wildfires hits 74,RT @anadoluagency: #<COUNTRY>Greece</COUNTRY>: Death toll from wildfires hits 74
2,ID_1022015997740503042,When the essence of cooperation meets the sad reality of lifeThe IPA partner country offers financial aid to Greece to handle disaster @InterregIPACBC #Greecefires,When the essence of cooperation meets the sad reality of lifeThe IPA partner country offers financial aid to <COUNTRY>Greece</COUNTRY> to handle disaster @InterregIPACBC #Greecefires
3,ID_1022557424585240576,We are live from the Lureio Idrima the orphanage and nursing home operared by the nuns of the Holy Trinity Monastery that was destroyed by the fire in Neos Voutzas. Here too the scene is apocalyptic.,We are live from the Lureio Idrima the orphanage and nursing home operared by the nuns of the <HUMAN-MADE-POINT-OF-INTEREST>Holy Trinity Monastery</HUMAN-MADE-POINT-OF-INTEREST> that was destroyed by the fire in <NEIGHBORHOOD>Neos Voutzas</NEIGHBORHOOD>. Here too the scene is apocalyptic.
4,ID_1021749412639457280,RT @AP: Greek prime minister declares 3-day national mourning period for dozens killed by wildfires near Athens.,RT @AP: Greek prime minister declares 3-day national mourning period for dozens killed by wildfires near <CITY>Athens</CITY>.


In [6]:
df_train = Preprocess.remove_non_ascii(df_train, column_name='plain_text')
df_train = Preprocess.remove_non_ascii(df_train, column_name='xml_text')
df_dev   = Preprocess.remove_non_ascii(df_dev, column_name='plain_text')
df_dev   = Preprocess.remove_non_ascii(df_dev, column_name='xml_text')

In [7]:
df_train.head(5)

Unnamed: 0,tweet_id,plain_text,xml_text
0,ID_1022420413882744832,Nearly half of #houses checked in #fire-stricken areas deemed #uninhabitable #GO #PrayForGreece #PrayForAthens #AthensFires C7,Nearly half of #houses checked in #fire-stricken areas deemed #uninhabitable #GO #PrayForGreece #PrayForAthens #AthensFires C7
1,ID_1021778661895294976,RT @anadoluagency: #Greece: Death toll from wildfires hits 74,RT @anadoluagency: #<COUNTRY>Greece</COUNTRY>: Death toll from wildfires hits 74
2,ID_1022015997740503042,When the essence of cooperation meets the sad reality of lifeThe IPA partner country offers financial aid to Greece to handle disaster @InterregIPACBC #Greecefires,When the essence of cooperation meets the sad reality of lifeThe IPA partner country offers financial aid to <COUNTRY>Greece</COUNTRY> to handle disaster @InterregIPACBC #Greecefires
3,ID_1022557424585240576,We are live from the Lureio Idrima the orphanage and nursing home operared by the nuns of the Holy Trinity Monastery that was destroyed by the fire in Neos Voutzas. Here too the scene is apocalyptic.,We are live from the Lureio Idrima the orphanage and nursing home operared by the nuns of the <HUMAN-MADE-POINT-OF-INTEREST>Holy Trinity Monastery</HUMAN-MADE-POINT-OF-INTEREST> that was destroyed by the fire in <NEIGHBORHOOD>Neos Voutzas</NEIGHBORHOOD>. Here too the scene is apocalyptic.
4,ID_1021749412639457280,RT @AP: Greek prime minister declares 3-day national mourning period for dozens killed by wildfires near Athens.,RT @AP: Greek prime minister declares 3-day national mourning period for dozens killed by wildfires near <CITY>Athens</CITY>.


- Get a look of extracted tags and prepare list of entities

In [8]:
tags_list = pd.read_csv("../utils/tag_description.csv")
tags_list

Unnamed: 0,original_type,xml_tag
0,State,STATE
1,County,COUNTY
2,City/town,CITY
3,Road/street,ROAD
4,Island,ISLAND
5,Human-made Point-of-Interest,HUMAN-MADE-POINT-OF-INTEREST
6,Continent,CONTINENT
7,Neighborhood,NEIGHBORHOOD
8,Natural Point-of-Interest,NATURAL-POINT-OF-INTEREST
9,District,DISTRICT


In [9]:
"""
tags_description = [
    "CONTINENT: A large continuous landmass on the Earth's surface",
    "COUNTRY: A nation or territory recognized as an independent state",
    "COUNTY: A geographical region within a country, often a subdivision of a state or province.",
    "DISTRICT: An administrative division within a city, county, or country.",
    "STATE: A larger administrative division within a country",
    "CITY: A large and significant urban area",
    "ROAD: A street, avenue, or highway that connects different locations.",
    "ISLAND: A landmass completely surrounded by water.",
    "NEIGHBORHOOD: A localized community within a city or town.",
    "HUMAN-MADE-POINT-OF-INTEREST: A landmark created by humans, such as monuments or buildings.",
    "NATURAL-POINT-OF-INTEREST: A location of natural significance, like mountains or rivers.",
]"""
tags_description = [
    "CONTINENT: A large landmass on Earth.",
    "COUNTRY: An independent nation.",
    "COUNTY: A region within a country.",
    "DISTRICT: An administrative area within a city",
    "STATE: A large administrative division within a country.",
    "CITY: A significant urban area.",
    "ROAD: A street, avenue connecting locations.",
    "ISLAND: A landmass surrounded by water.",
    "NEIGHBORHOOD: A community within a city.",
    "HUMAN-MADE-POINT-OF-INTEREST: A landmark created by humans, like monuments.",
    "NATURAL-POINT-OF-INTEREST: A naturally significant location, like mountains."
]

### **Init model utils**

Init model and tokenizer

In [10]:
compute_dtype = getattr(torch, 'float16')
quant_config =  BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True
)

In [14]:
special_tokens = ['<STATE>', '</STATE>', '<ROAD>', '</ROAD>', '<COUNTY>', '</COUNTY>', '<CONTINENT>', '</CONTINENT>', '<NATURAL-POINT-OF-INTEREST>', '</NATURAL-POINT-OF-INTEREST>', '<NEIGHBORHOOD>', '</NEIGHBORHOOD>', '<COUNTRY>', '</COUNTRY>', '<CITY>', '</CITY>', '<DISTRICT>', '</DISTRICT>', '<ISLAND>', '</ISLAND>', '<HUMAN-MADE-POINT-OF-INTEREST>', '</HUMAN-MADE-POINT-OF-INTEREST>', '[INST]', '[/INST]', '[PAD]']

#tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", additional_special_tokens=special_tokens)
#model     = TFAutoModel.from_pretrained("google-t5/t5-small", return_dict=True)

#tokenizer = T5Tokenizer.from_pretrained('t5-small', additional_special_tokens=special_tokens)
#model = T5ForConditionalGeneration.from_pretrained('t5-small', return_dict=True)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", additional_special_tokens=special_tokens)
model     = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", quantization_config=quant_config, device_map='auto')
#model     = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")

RuntimeError: No GPU found. A GPU is needed for quantization.

In [11]:
tokenizer.pad_token = "[PAD]"
tokenizer.padding_side = "right"
model.resize_token_embeddings(len(tokenizer))

Embedding(32125, 512)

### **Including Chain-Of-Thought in the prompt design**



In [12]:
def apply_chat_template(messages):
    formatted_messages = []
    for message in messages:
        if message['role'] == 'user':
            formatted_messages.append(f"[INST] {message['content']} [/INST]")
        elif message['role'] == 'assistant':
            formatted_messages.append(f"{message['content']}")
    return '<s>' + ''.join(formatted_messages) + '</s>'

In [13]:
messages = [
    {"role": "user", "content": "usr_msg1"},
    {"role": "assistant", "content": "asst_msg1"},
    {"role": "user", "content": "usr_msg2"},
    {"role": "assistant", "content": "asst_msg2"},
]
#tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
apply_chat_template(messages)

'<s>[INST] usr_msg1 [/INST]asst_msg1[INST] usr_msg2 [/INST]asst_msg2</s>'

In [14]:
def formating_prompt(rule_set: List[str], input_str: str, label_str: str, tokenizer: PreTrainedTokenizerBase) -> torch.Tensor:
    rule_str = "\n".join(rule_set)

    # Message 1
    usr_msg1 = "Identify and tag all location mentions in the microblogging post using the provided location entity types." \
               f"\n\nEntity List:\n{rule_str}\n\n" \
               "Are the instructions clear?"
    asst_msg1 = "Yes, I will tag all location entities as specified, keeping the rest of the text unchanged."

    # Message 2
    usr_msg2 = "Florida Bahamas flooding shuts down train lines."
    asst_msg2 = "<STATE>Florida</STATE> <ISLAND>Bahamas</ISLAND> flooding shuts down train lines."

    # Message 3
    usr_msg3 = "Explain why your answer is correct."
    asst_msg3 = "I tagged 'Florida' as <STATE> and 'Bahamas' as <ISLAND>, following the location entities list."

    # Message 4
    usr_msg4 = "Now, tag another user post according to the same instructions. No explanation needed."
    asst_msg4 = "Sure, please provide the user post."

    # Encode in conversation
    messages = [
        {"role": "user", "content": usr_msg1},
        {"role": "assistant", "content": asst_msg1},
        {"role": "user", "content": usr_msg2},
        {"role": "assistant", "content": asst_msg2},
        {"role": "user", "content": usr_msg3},
        {"role": "assistant", "content": asst_msg3},
        {"role": "user", "content": usr_msg4},
        {"role": "assistant", "content": asst_msg4},
        {"role": "user", "content": input_str},
        {"role": "assistant", "content": label_str},
    ]

    return messages

In [15]:
def create_prompt_df(df: pd.DataFrame, rule_set: List[str], tokenizer: PreTrainedTokenizerBase) -> pd.DataFrame:
    prompt_list = []
    for _, row in df.iterrows():
        input_str = row['plain_text']
        label_str = row['xml_text']
        prompt_dict = formating_prompt(rule_set, input_str, label_str, tokenizer)
        prompt_list.append(prompt_dict)
    
    dataset = Dataset.from_dict({"prompt": prompt_list})
    #dataset = dataset.map(lambda x: {"formatted_prompt": tokenizer.apply_chat_template(x["prompt"], tokenize=False, add_generation_prompt=False)})
    dataset = dataset.map(lambda x: {"formatted_prompt": apply_chat_template(x["prompt"])})
    return dataset

### **Preparing model for fine-tunning**

In [16]:
dataset = {} 
dataset['train'] = create_prompt_df(df_train, tags_description, tokenizer)
dataset['eval']  = create_prompt_df(df_dev, tags_description, tokenizer)

Map:   0%|          | 0/14392 [00:00<?, ? examples/s]

Map:   0%|          | 0/2056 [00:00<?, ? examples/s]

In [17]:
max_formatted_prompt_length = max(len(prompt) for prompt in dataset['train']['formatted_prompt'])
max_formatted_prompt_length

3135

- Data Collator

In [18]:
@dataclass
class CustomDataCollatorWithPadding:
    """
    Data collator that will dynamically pad the inputs received.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.

            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        labels = batch["input_ids"].clone()
        
        # Set loss mask for all pad tokens
        labels[labels == self.tokenizer.pad_token_id] = -100
        
        # Compute loss mask for appropriate tokens only
        for i in range(batch['input_ids'].shape[0]):
            
            # Decode the training input
            text_content = self.tokenizer.decode(batch['input_ids'][i][1:])  # slicing from [1:] is important because tokenizer adds bos token
            
            # Extract substrings for prompt text in the training input
            # The training input ends at the last user msg ending in [/INST]
            prompt_gen_boundary = text_content.rfind("[/INST]") + len("[/INST]")
            prompt_text = text_content[:prompt_gen_boundary]
            
            # print(f"""PROMPT TEXT:\n{prompt_text}""")
            
            # retokenize the prompt text only
            prompt_text_tokenized = self.tokenizer(
                prompt_text,
                return_overflowing_tokens=False,
                return_length=False,
            )
            # compute index where prompt text ends in the training input
            prompt_tok_idx = len(prompt_text_tokenized['input_ids'])
            
            # Set loss mask for all tokens in prompt text
            labels[i][range(prompt_tok_idx)] = -100
            
            # print("================DEBUGGING INFORMATION===============")
            # for idx, tok in enumerate(labels[i]):
            #     token_id = batch['input_ids'][i][idx]
            #     decoded_token_id = self.tokenizer.decode(batch['input_ids'][i][idx])
            #     print(f"""TOKID: {token_id} | LABEL: {tok} || DECODED: {decoded_token_id}""")
                    
        batch["labels"] = labels
        return batch

- Model Building

In [19]:
max_seq_length = pd.concat([df_train, df_dev])['plain_text'].apply(len).max() + 400
max_seq_length

1345

In [20]:
wandb.init(project="T5-GOOGLE-MPS-TRAIN", name="Location Mention Recognition")
wandb.config.update({
    "max_seq_length": max_seq_length,
    "output_dir": "T5-Output",
})

[34m[1mwandb[0m: Currently logged in as: [33mgenereux-akotenou[0m ([33mgenereux-akotenou-local[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [21]:
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)

In [24]:
# Metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

# Args
tokenizer.padding_side = 'right'
max_seq_length = pd.concat([df_train, df_dev])['plain_text'].apply(len).max()
training_arguments = SFTConfig(
    fp16=False,
    bf16=False,
    packing=False,
    dataset_text_field="formatted_prompt",
    output_dir="T5-Output",
    max_seq_length=max_seq_length,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    #gradient_checkpointing=True,
    report_to="wandb",
    logging_dir="./logs",
    eval_strategy="steps",
    eval_steps=50,
    save_steps=100,
)

model.gradient_checkpointing_enable()

# Trainer
trainer = SFTTrainer(
    model=model.to(torch.device('mps')),
    train_dataset=dataset['train'],
    eval_dataset=dataset['eval'],
    peft_config=peft_config,
    tokenizer=tokenizer,
    args=training_arguments,
    data_collator=CustomDataCollatorWithPadding(
        tokenizer=tokenizer, 
        padding="longest", 
        max_length=max_seq_length, 
        return_tensors="pt"
    ),
    compute_metrics=compute_metrics,
    callbacks=[WandbCallback()]
)

Map:   0%|          | 0/14392 [00:00<?, ? examples/s]

Map:   0%|          | 0/2056 [00:00<?, ? examples/s]

You are adding a <class 'transformers.integrations.integration_utils.WandbCallback'> to the callbacks of this Trainer, but there is already one. The currentlist of callbacks is
:DefaultFlowCallback
WandbCallback


#### **Model Training**

In [25]:
trainer.train()



  0%|          | 0/10794 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


  0%|          | 0/2056 [00:00<?, ?it/s]

python(92142) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(92147) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(92154) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(92160) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(92171) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(92181) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(92189) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(92195) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(92202) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(92208) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(92221) Malloc

RuntimeError: Invalid buffer size: 8.03 GB

In [None]:
***

#### **Make Inference**

In [None]:
df_context = pd.read_csv('../data/provided/Test.csv')

In [None]:
pretrained_model     = trainer.model
pretrained_tokenizer = trainer.tokenizer

In [None]:
text = (f"Extract from this input all location mention by surrounding these location with xml tag. Sentence: '{df_context.iloc[0].text}'")
text = "Translate in french: 'SCHOOL'"
inputs = pretrained_tokenizer.encode(text,
                                     return_tensors='pt',
                                     max_length=512,
                                     truncation=True).to(torch.device('cuda'))

generated_ids = pretrained_model.generate(inputs, 
                                          max_length=80, 
                                          min_length=40, 
                                          length_penalty=5.0, 
                                          num_beams=2)

generated = pretrained_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print("[INPUT_]: ", text, '\n')
print("[RESULT]: ", generated)

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5Config

def load_pretrained_(model_path='T5-LMR'):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
    return model, tokenizer

def load_pretrained(model_path='T5-LMR'):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = T5Config.from_pretrained(model_path)
    config.vocab_size = tokenizer.vocab_size  # Adjust vocab size based on tokenizer
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path, config=config, torch_dtype=torch.float16, device_map="auto")
    return model, tokenizer

def load_pretrained(model_path='T5-LMR'):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto", ignore_mismatched_sizes=True)
    return model, tokenizer
    
"""
def make_inference_with_pretrained(model, tokenizer, prompt, max_seq_length=max_seq_length):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_seq_length)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(device)
    model.to(device)
    
    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs.to(device))
    logits = outputs.logits
    predicted_ids = torch.argmax(logits, dim=-1)
    predicted_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
    
    print(predicted_text)
"""

def make_inference_with_pretrained_(model, tokenizer, prompt, max_seq_length=max_seq_length):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_seq_length)
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
    
    # Move model and inputs to the appropriate device
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    decoder_input_ids = decoder_input_ids.to(device)
    
    # Perform inference
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'], decoder_input_ids=decoder_input_ids)
    
    # Get logits and predict the most likely token IDs
    logits = outputs.logits
    predicted_ids = torch.argmax(logits, dim=-1)
    
    # Decode the predicted token IDs to text
    predicted_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
    
    #print(predicted_text)
    return predicted_text

def make_inference_with_pretrained(model, tokenizer, prompt, max_seq_length=max_seq_length, max_length=max_seq_length):
    # Tokenize input prompt
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_seq_length)
    
    # Move model and inputs to the appropriate device
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    
    # Perform inference using the model's generate method for autoregressive decoding
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs['input_ids'],
            max_length=max_length,
            num_beams=5,  # Beam search for better results; you can adjust or remove this
            early_stopping=True
        )

    print(generated_ids)
    # Decode the generated token IDs to text
    predicted_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    return predicted_text

In [None]:
#mlr_model, mlr_tokenizer = load_pretrained(model_path="T5-Output/checkpoint-6")

In [None]:
df_context = pd.read_csv('../data/provided/Test.csv')
df_context.head()

In [None]:
text = f"Extract from this input all location mention by surrounding these location with xml tag. Sentence: '{df_context.iloc[0].text}'"
text = f"Translate in french: 'You are strong'"
text

In [None]:
make_inference_with_pretrained(trainer.model, trainer.tokenizer, text)

In [None]:
model = AutoModelForCausalLM.from_pretrained('Mistral-7B-LMR', torch_dtype=torch.float16, device_map="auto")
#model.to("cuda")

prompt = "please help and donate to local charities"
tokens = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_seq_length)
generated_ids = model.generate(tokens, max_new_tokens=1000, do_sample=True)

# decode with mistral tokenizer
result = tokenizer.decode(generated_ids[0].tolist())
print(result)