In [1]:
import os
import wandb
import torch
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList 
from typing import Dict, List, Generator
from datasets import load_dataset, load_from_disk, concatenate_datasets
from peft import PeftModel
from transformers import GenerationConfig
import json 
import ast
from IPython.display import Markdown, display

In [2]:
BASE_MODEL = "mistralai/Mistral-7B-v0.1"
QUANTIZATION_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [3]:
class StopGenerationCriteria(StoppingCriteria):
    def __init__(
            self, 
            stop_words: List[str], 
            tokenizer: AutoTokenizer, 
            device: torch.device
            ) -> None:
        stop_words = [' ' + stop_word for stop_word in stop_words]
        stop_token_ids = [tokenizer(t, add_special_tokens=False)['input_ids'][1:] for t in stop_words]
        self.stop_token_ids = [
            torch.LongTensor(x).to(device) for x in stop_token_ids
        ]

    def __call__(
            self, 
            input_ids: torch.LongTensor, 
            scores: torch.FloatTensor, 
            **kwargs
            ) -> bool:
        for stop_ids in self.stop_token_ids:
            if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
                return True
        return False

class Model:
    def __init__(self, checkpoint_dir: str) -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print("THE DEVICE INFERENCE IS RUNNING ON IS: ", self.device)
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.stopping_criteria = None
        self.checkpoint_dir = checkpoint_dir
    
    def get_checkpoint_dir(self):
        run = wandb.init()
        checkpoint = run.use_artifact(self.wandb_checkpoint_name, type='model')
        checkpoint_dir = checkpoint.download()
        return checkpoint_dir

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,  # Mistral, same as before
            quantization_config=QUANTIZATION_CONFIG,  # Same quantization config as before
            # torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )

        # load and merge the checkpoint with base model
        self.ft_model = PeftModel.from_pretrained(self.base_model, self.checkpoint_dir)
        self.ft_model.eval()

        self.gen_cfg = GenerationConfig.from_model_config(self.ft_model.config)
        self.gen_cfg.max_new_tokens = 1024
        self.gen_cfg.temperature = 0.5
        self.gen_cfg.num_return_sequences = 1
        self.gen_cfg.use_cache = True
        self.gen_cfg.min_length = 1

    def predict(self, request: Dict) -> Dict | Generator:
        with torch.no_grad():
            prompt = request.pop("prompt")
            stop_words = []
            if "stop" in request:
                stop_words = request.pop("stop")
            inputs = self.tokenizer(prompt, return_tensors="pt")
            input_ids = inputs["input_ids"].cuda()
            generation_output = self.ft_model.generate(
                input_ids=input_ids,
                pad_token_id=self.tokenizer.eos_token_id,
                stopping_criteria = StoppingCriteriaList(
                    [StopGenerationCriteria(
                        stop_words,
                        self.tokenizer,
                        self.device
                        )]),
                generation_config=self.gen_cfg,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=256
            )
            outputs = []
            for seq in generation_output.sequences:
                output = self.tokenizer.decode(seq)
                outputs.append(output)

            return "\n".join(outputs)

In [13]:
def get_relation(example):
    RELATION_NAMES=['country', 'place of birth', 'spouse', 'country of citizenship', 'instance of',
            'capital', 'child', 'shares border with', 'author', 'director', 'occupation',
              'founded by', 'league', 'owned by', 'genre', 'named after', 'follows',
                'headquarters location', 'cast member', 'manufacturer',
                  'located in or next to body of water', 'location', 'part of', 
                  'mouth of the watercourse', 'member of', 'sport', 'characters',
                    'participant', 'notable work', 'replaces', 'sibling', 'inception']
    relations = []
    for relation in example['relations']:
        relation_dict = {}
        relation_dict["object"] = relation['object']['surfaceform']
        relation_dict["subject"] = relation['subject']['surfaceform']
        relation_dict["predicate"] = RELATION_NAMES[relation['predicate']]
        relations.append(relation_dict)

    return relations

In [5]:
def get_prompt_template() -> PromptTemplate:
    prompt_template = """### Instruction:
You are an expert in data science and natural language processing (NLP).
Your task is to extract triplets from the text provided below.
A knowledge triplet is made up of 2 entities (subject and object) linked by a predicate: 
{{"Object": "", "Predicate": "", "Subject": "" }}
Multiple triplets must be in list form.
Text: {text}\n
### Response:"""
    input_variables = ["text"]
    return PromptTemplate(
        template=prompt_template,
        input_variables=input_variables,
    )

In [6]:
def get_prompt(test_example):
    prompt_template = get_prompt_template()

    relations_prompt = prompt_template.format(
        text=test_example["text"],
    )

    test_example["prompt"] = relations_prompt
    return test_example

## Use checkpoint to make predictions (solar-disco-79, checkpoint-4grkql3s:v10)

In [7]:
checkpoint_dir = "./checkpoints/solar-disco-79/checkpoint-4grkql3s:v10"
ft_model = Model(checkpoint_dir=checkpoint_dir)
ft_model.load()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


THE DEVICE INFERENCE IS RUNNING ON IS:  cuda


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [15]:
import wandb
run = wandb.init()
artifact = run.use_artifact('xianli/digital_safety/model-4grkql3s:v0', type='model')
artifact_dir = artifact.download()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mxianli[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact model-4grkql3s:v0, 164.08MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:1.5


Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f2dda2a5690>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f2ddc293490, execution_count=15 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 7f2ddc291390, raw_cell="import wandb
run = wandb.init()
artifact = run.use.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2B52.57.169.192/home/ubuntu/fine-tuning/generate_dpo_dataset.ipynb#X44sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given

Load the REDFM dataset

In [5]:
train_fr_dataset = load_dataset("Babelscape/REDFM", language="fr", split="train")
train_en_dataset = load_dataset("Babelscape/REDFM", language="en", split="train")
assert train_fr_dataset.features.type == train_en_dataset.features.type
sliced_dataset_en = train_en_dataset.shuffle(seed=42).select(
    range(train_fr_dataset.num_rows)
)
concat_dataset = concatenate_datasets(
    [sliced_dataset_en, train_fr_dataset]
).shuffle(seed=80)

In [7]:
concat_dataset

Dataset({
    features: ['docid', 'title', 'uri', 'text', 'entities', 'relations'],
    num_rows: 3730
})

In [11]:
prompt_dataset = concat_dataset.map(get_prompt)
prompt_dataset

Dataset({
    features: ['docid', 'title', 'uri', 'text', 'entities', 'relations', 'prompt'],
    num_rows: 3730
})

In [12]:
def get_prediction(test_example):
    output = ft_model.predict(request={
        "prompt": test_example["prompt"],
        "temperature": 0.1, 
        "max_new_tokens": 1024,
        # "stop": ["\n\n"]
        "stop": ["\n\n### Instruction:"]
    })
    test_example["prediction"] = output
    return test_example

In [13]:
pred_dataset = prompt_dataset.map(get_prediction)
pred_dataset.save_to_disk("./datasets/redfm_pred_dataset")

Map:   0%|          | 0/3730 [00:00<?, ? examples/s]



Saving the dataset (0/1 shards):   0%|          | 0/3730 [00:00<?, ? examples/s]

In [4]:
pred_dataset = load_from_disk("./datasets/redfm_pred_dataset")
pred_dataset

Dataset({
    features: ['docid', 'title', 'uri', 'text', 'entities', 'relations', 'prompt', 'prediction'],
    num_rows: 3730
})

In [5]:
def string_to_dict(string):
    string = string.replace("</s>", "").replace("Relations: ", "").replace("\n", "")
    relations = string.split("}, ")
    relations_dict = []
    for re in relations:
        re = re.strip("[]").strip("\{\}")
        re_dict = {}
        for k_v in re.split(", "):
            k, v = tuple(k_v.split(": "))
            re_dict[k.replace("\"", "")] = v.replace("\"", "")
        relations_dict.append(re_dict)
    return relations_dict

In [6]:
def pred_to_dict(example):
    string = example["prediction"].split("\n\n### Response:")[-1]
    string = string.replace("</s>", "").replace("Relations: ", "")
    try:
        pred_dict = json.loads(string)
    except:
        pred_dict = string_to_dict(string)
    example["pred_dict"] = pred_dict
    return example

In [7]:
pred_dict_dataset = pred_dataset.map(pred_to_dict)
pred_dict_dataset

Map:   0%|          | 0/3730 [00:00<?, ? examples/s]

Dataset({
    features: ['docid', 'title', 'uri', 'text', 'entities', 'relations', 'prompt', 'prediction', 'pred_dict'],
    num_rows: 3730
})

In [10]:
def are_lists_of_dicts_identical(list1, list2):
    # Check if the lengths of the lists are the same
    if len(list1) != len(list2):
        return False

    # Iterate through each dictionary in the lists and compare them
    for dict1, dict2 in zip(list1, list2):
        if dict1 != dict2:
            return False

    # If all dictionaries are identical, return True
    return True

In [11]:
def return_prompt_and_responses(example):
    ground_truth = get_relation(example)
    prediction = example["pred_dict"]
    if are_lists_of_dicts_identical(ground_truth, prediction):
        example["chosen"] =  json.dumps(ground_truth, ensure_ascii=False)
        example["rejected"] = ""
    else:
        example["chosen"] =  json.dumps(ground_truth, ensure_ascii=False)
        example["rejected"] = json.dumps(prediction, ensure_ascii=False)
    return example

In [15]:
json.loads(return_prompt_and_responses(pred_dict_dataset[10])["rejected"])

[{'Object': 'Loki', 'Predicate': 'parent', 'Subject': 'Jörmungand'}]

In [36]:
json.loads(return_prompt_and_responses(pred_dict_dataset[0])["rejected"])

[{'Object': 'Rouen', 'Predicate': 'capital', 'Subject': 'Normandy'}]

In [16]:
dpo_dataset = pred_dict_dataset.map(return_prompt_and_responses)
dpo_dataset

Map:   0%|          | 0/3730 [00:00<?, ? examples/s]

Dataset({
    features: ['docid', 'title', 'uri', 'text', 'entities', 'relations', 'prompt', 'prediction', 'pred_dict', 'chosen', 'rejected'],
    num_rows: 3730
})

In [17]:
dpo_dataset[10]

{'docid': '86182-1',
 'title': 'Jörmungand',
 'uri': 'Q181227',
 'text': 'Selon l’"Edda de Snorri", il est le fils du dieu malin Loki et de la géante Angrboda, et le frère du loup Fenrir ainsi que de la déesse du monde des morts Hel. Peu après sa naissance, le dieu Odin jette Jörmungand dans la mer qui encercle Midgard, puisque les prophéties racontent qu\'il causera de grands dégâts chez les dieux. Mais ce dernier grandit tellement qu\'il finit par entourer le monde et se mordre la queue, d\'où son autre nom, Midgardsorm "(Miðgarðsormr)", .',
 'entities': [{'uri': 'Q205882',
   'surfaceform': 'Edda de Snorri',
   'type': 'MEDIA',
   'start': 9,
   'end': 23},
  {'uri': 'Q133147',
   'surfaceform': 'Loki',
   'type': 'MISC',
   'start': 55,
   'end': 59},
  {'uri': 'Q210053',
   'surfaceform': 'géante',
   'type': 'Concept',
   'start': 69,
   'end': 75},
  {'uri': 'Q371828',
   'surfaceform': 'Angrboda',
   'type': 'MISC',
   'start': 76,
   'end': 84},
  {'uri': 'Q182560',
   'surfac

In [18]:
dpo_dataset.filter(lambda example: example["rejected"]!="")

Filter:   0%|          | 0/3730 [00:00<?, ? examples/s]

Dataset({
    features: ['docid', 'title', 'uri', 'text', 'entities', 'relations', 'prompt', 'prediction', 'pred_dict', 'chosen', 'rejected'],
    num_rows: 3730
})

In [19]:
dpo_dataset = dpo_dataset.remove_columns(['docid', 'title', 'uri', 'text', 'entities', 'relations', 'prediction', 'pred_dict'])
dpo_dataset

Dataset({
    features: ['prompt', 'chosen', 'rejected'],
    num_rows: 3730
})

In [20]:
dpo_dataset[10]

{'prompt': '### Instruction:\nYou are an expert in data science and natural language processing (NLP).\nYour task is to extract triplets from the text provided below.\nA knowledge triplet is made up of 2 entities (subject and object) linked by a predicate: \n{"Object": "", "Predicate": "", "Subject": "" }\nMultiple triplets must be in list form.\nText: Selon l’"Edda de Snorri", il est le fils du dieu malin Loki et de la géante Angrboda, et le frère du loup Fenrir ainsi que de la déesse du monde des morts Hel. Peu après sa naissance, le dieu Odin jette Jörmungand dans la mer qui encercle Midgard, puisque les prophéties racontent qu\'il causera de grands dégâts chez les dieux. Mais ce dernier grandit tellement qu\'il finit par entourer le monde et se mordre la queue, d\'où son autre nom, Midgardsorm "(Miðgarðsormr)", .\n\n### Response:',
 'chosen': '[{"object": "géante", "subject": "Loki", "predicate": "member of"}, {"object": "Fenrir", "subject": "Loki", "predicate": "child"}, {"object"

In [21]:
dpo_dataset = dpo_dataset.filter(lambda example: len(example["prompt"]) < 1024, load_from_cache_file=False)
dpo_dataset

Filter:   0%|          | 0/3730 [00:00<?, ? examples/s]

Dataset({
    features: ['prompt', 'chosen', 'rejected'],
    num_rows: 3079
})

In [22]:
dpo_dataset.save_to_disk("./datasets/dpo_dataset/")

Saving the dataset (0/1 shards):   0%|          | 0/3079 [00:00<?, ? examples/s]

In [None]:
model = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=script_args.beta,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

In [23]:
dpo_dataset = load_from_disk(".//datasets/dpo_dataset")
dpo_dataset[10]

{'prompt': '### Instruction:\nYou are an expert in data science and natural language processing (NLP).\nYour task is to extract triplets from the text provided below.\nA knowledge triplet is made up of 2 entities (subject and object) linked by a predicate: \n{"Object": "", "Predicate": "", "Subject": "" }\nMultiple triplets must be in list form.\nText: Selon l’"Edda de Snorri", il est le fils du dieu malin Loki et de la géante Angrboda, et le frère du loup Fenrir ainsi que de la déesse du monde des morts Hel. Peu après sa naissance, le dieu Odin jette Jörmungand dans la mer qui encercle Midgard, puisque les prophéties racontent qu\'il causera de grands dégâts chez les dieux. Mais ce dernier grandit tellement qu\'il finit par entourer le monde et se mordre la queue, d\'où son autre nom, Midgardsorm "(Miðgarðsormr)", .\n\n### Response:',
 'chosen': '[{"object": "géante", "subject": "Loki", "predicate": "member of"}, {"object": "Fenrir", "subject": "Loki", "predicate": "child"}, {"object"