In [1]:
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import load_dataset, concatenate_datasets, DatasetDict
from peft import LoraConfig
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    AutoTokenizer,
)
from langchain.prompts import PromptTemplate
from trl import SFTTrainer
import os, wandb

In [None]:
# Language of the dataset used to build prompts. generate_base_prompt handles
# "fr" and "en" and returns None for any other value.
dataset_language = "fr"

In [None]:
def get_fr_relation(example):
    """Serialise the relations of ``example`` as a string of French-keyed triples.

    Each entry of ``example["relations"]`` holds ``"object"`` and ``"subject"``
    indices into ``example["entities"]`` plus a ``"predicate"`` value. Every
    relation becomes a dict with the keys ``"Objet"``, ``"Prédicat"`` and
    ``"Sujet"`` (in that order), i.e. the same keys the French prompt template
    shows to the model.

    Args:
        example: Mapping with an ``"entities"`` sequence (items expose
            ``"surfaceform"``) and a ``"relations"`` sequence.

    Returns:
        ``str()`` of the list of relation dicts; ``"[]"`` when there are no
        relations.

    Raises:
        KeyError / IndexError: If a relation is missing a field or points to an
            entity index that does not exist.
    """
    entities_ls = example["entities"]
    relations = []
    for relation in example["relations"]:
        relation_dict = {}
        object_index = relation["object"]
        relation_dict["Objet"] = entities_ls[object_index]["surfaceform"]
        relation_dict["Prédicat"] = relation["predicate"]
        subject_index = relation["subject"]
        # The key must be spelled "Sujet" to match the schema in the French prompt.
        relation_dict["Sujet"] = entities_ls[subject_index]["surfaceform"]
        relations.append(relation_dict)

    return str(relations)


def get_en_relation(example):
    """Serialise the relations of ``example`` as a string of English-keyed triples.

    Each relation's ``"object"`` and ``"subject"`` values are used as indices
    into ``example["entities"]``; the matching ``"surfaceform"`` values are
    combined with the relation's ``"predicate"`` into a dict keyed
    ``"Object"``, ``"Predicate"``, ``"Subject"`` (in that order).

    Returns:
        ``str()`` of the resulting list of dicts.
    """
    mentions = example["entities"]
    triples = [
        {
            "Object": mentions[rel["object"]]["surfaceform"],
            "Predicate": rel["predicate"],
            "Subject": mentions[rel["subject"]]["surfaceform"],
        }
        for rel in example["relations"]
    ]
    return str(triples)


def get_entities(example):
    """Return the distinct entity surface forms of ``example`` as a list string.

    Duplicates are removed while keeping the order of first appearance in
    ``example["entities"]``. Going through ``set`` would make the order depend
    on string hash randomisation, so the generated prompt text could differ
    from one kernel run to the next.

    Args:
        example: Mapping whose ``"entities"`` items expose ``"surfaceform"``.

    Returns:
        ``str()`` of the de-duplicated list; ``"[]"`` when there are no entities.
    """
    # dict keys are unique and keep insertion order, giving a stable dedupe.
    unique_forms = dict.fromkeys(
        entity["surfaceform"] for entity in example["entities"]
    )
    return str(list(unique_forms))

In [None]:
def get_fr_prompt_template():
    """Build the French prompt template for entity and relation extraction.

    The template expects four variables: ``text``, ``eos_token``, ``entities``
    and ``relations``. Doubled braces (``{{`` / ``}}``) are literal braces in
    the rendered prompt. The layout mirrors the English template: the prompt
    starts directly with the instruction (no leading blank line) and each
    section is followed by a blank line.

    Returns:
        A ``PromptTemplate`` wrapping the French instruction text.
    """
    # One literal per rendered line keeps significant whitespace (such as the
    # space after "prédicat :") visible in the source.
    fr_prompt_template = (
        "Vous êtes un expert en data science et en traitement du langage naturel (NLP).\n"
        "Votre tâche consiste à extraire les triplets du TEXTE fourni ci-dessous.\n"
        "Les entités sont le sujet et l'objet d'une phrase, la liste d'entités doit être sous forme:\n"
        "['entité1', 'entité2', 'entité3', ...]\n"
        "Un triplet de connaissances est constitué de 2 entités (sujet et objet) liées par un prédicat : \n"
        '{{"Objet": "","Prédicat": "", "Sujet": "" }}\n'
        "Les triplets multiples doivent être sous forme de liste.\n\n"
        "### TEXTE:\n"
        "{text}{eos_token}\n\n"
        "### ENTITES:\n"
        "{entities}{eos_token}\n\n"
        "### RELATIONS:\n"
        "{relations}{eos_token}\n\n"
    )
    return PromptTemplate(
        template=fr_prompt_template,
        input_variables=["text", "eos_token", "entities", "relations"],
    )


def get_en_prompt_template():
    """Build the English prompt template for entity and relation extraction.

    The template expects four variables: ``text``, ``eos_token``, ``entities``
    and ``relations``. Doubled braces (``{{`` / ``}}``) are literal braces in
    the rendered prompt.

    Returns:
        A ``PromptTemplate`` wrapping the English instruction text.
    """
    # One literal per rendered line; adjacent literals are concatenated, so the
    # resulting string is the same text as a single multi-line literal.
    instruction_text = (
        "You are an expert in data science and natural language processing (NLP).\n"
        "Your task is to extract triples from the TEXT provided below.\n"
        "Entities are the subject and object of a sentence, the list of entities must be in the form:\n"
        "['entity1', 'entity2', 'entity3', ...]\n"
        "A knowledge triplet is made up of 2 entities (subject and object) linked by a predicate: \n"
        '{{"Object": "", "Predicate": "", "Subject": "" }}\n'
        "Multiple triples must be in list form.\n\n"
        "### TEXT:\n"
        "{text}{eos_token}\n\n"
        "### ENTITIES:\n"
        "{entities}{eos_token}\n\n"
        "### RELATIONS:\n"
        "{relations}{eos_token}\n\n"
    )
    expected_variables = ["text", "eos_token", "entities", "relations"]
    return PromptTemplate(
        input_variables=expected_variables,
        template=instruction_text,
    )


def generate_base_prompt(example):
    """Render the full training prompt for one dataset example.

    The module-level ``dataset_language`` selects the prompt template and the
    relation serialiser ("fr" or "en"). The template is then filled with the
    example's text, the tokenizer's EOS token, the entity list and the
    relation list.

    NOTE(review): ``tokenizer`` is read as a global; it must be defined by an
    earlier cell before this function is called.

    Args:
        example: Mapping with ``"text"``, ``"entities"`` and ``"relations"``.

    Returns:
        ``{"text": <rendered prompt>}`` for a supported language, otherwise
        ``None``.
    """
    if dataset_language == "fr":
        make_template, serialise_relations = get_fr_prompt_template, get_fr_relation
    elif dataset_language == "en":
        make_template, serialise_relations = get_en_prompt_template, get_en_relation
    else:
        # Unsupported language: nothing is rendered.
        return None

    prompt_template = make_template()
    rendered_prompt = prompt_template.format(
        text=example["text"],
        eos_token=tokenizer.eos_token,
        entities=get_entities(example),
        relations=serialise_relations(example),
    )
    return {"text": rendered_prompt}