In [1]:
import os
import wandb
import torch
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList 
from typing import Dict, List, Generator
from datasets import load_dataset, load_from_disk
# from peft import PeftModel
from transformers import GenerationConfig
import json 
import ast
from IPython.display import Markdown, display

In [3]:
BASE_MODEL = "mistralai/Mistral-7B-v0.1"
QUANTIZATION_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

## load the model

1. create Model class

In [None]:
class StopGenerationCriteria(StoppingCriteria):
    def __init__(
            self, 
            stop_words: List[str], 
            tokenizer: AutoTokenizer, 
            device: torch.device
            ) -> None:
        stop_words = [' ' + stop_word for stop_word in stop_words]
        stop_token_ids = [tokenizer(t, add_special_tokens=False)['input_ids'][1:] for t in stop_words]
        self.stop_token_ids = [
            torch.LongTensor(x).to(device) for x in stop_token_ids
        ]

    def __call__(
            self, 
            input_ids: torch.LongTensor, 
            scores: torch.FloatTensor, 
            **kwargs
            ) -> bool:
        for stop_ids in self.stop_token_ids:
            if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
                return True
        return False

class Model:
    def __init__(self, checkpoint_dir: str) -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print("THE DEVICE INFERENCE IS RUNNING ON IS: ", self.device)
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.stopping_criteria = None
        self.checkpoint_dir = checkpoint_dir
    
    def get_checkpoint_dir(self):
        run = wandb.init()
        checkpoint = run.use_artifact(self.wandb_checkpoint_name, type='model')
        checkpoint_dir = checkpoint.download()
        return checkpoint_dir

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,  # Mistral, same as before
            quantization_config=QUANTIZATION_CONFIG,  # Same quantization config as before
            # torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )

        # load and merge the checkpoint with base model
        self.ft_model = PeftModel.from_pretrained(self.base_model, self.checkpoint_dir)
        self.ft_model.eval()

        self.gen_cfg = GenerationConfig.from_model_config(self.ft_model.config)
        self.gen_cfg.max_new_tokens = 1024
        self.gen_cfg.temperature = 0.5
        self.gen_cfg.num_return_sequences = 1
        self.gen_cfg.use_cache = True
        self.gen_cfg.min_length = 1

    def predict(self, request: Dict) -> Dict | Generator:
        with torch.no_grad():
            prompt = request.pop("prompt")
            stop_words = []
            if "stop" in request:
                stop_words = request.pop("stop")
            inputs = self.tokenizer(prompt, return_tensors="pt")
            input_ids = inputs["input_ids"].cuda()
            generation_output = self.ft_model.generate(
                input_ids=input_ids,
                pad_token_id=self.tokenizer.eos_token_id,
                stopping_criteria = StoppingCriteriaList(
                    [StopGenerationCriteria(
                        stop_words,
                        self.tokenizer,
                        self.device
                        )]),
                generation_config=self.gen_cfg,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=256
            )
            outputs = []
            for seq in generation_output.sequences:
                output = self.tokenizer.decode(seq)
                outputs.append(output)

            return "\n".join(outputs)

## Process dataset

1. get the testing prompt and ground truth into a list of dictionaries for each example

In [None]:
def get_relation(example: Dict) -> str:
    """
    Extracts and structures relations from a single example within a English dataset.

    Args:
        example (dict): A dictionary containing entities and relations.

    Returns:
        str: A string representation of the extracted relations.

    Example:
        Given an 'example' dictionary containing 'entities' and 'relations', this function
        extracts and structures relations, returning them as a string.

    """
    entities_ls = example["entities"]
    relations = []
    for relation in example["relations"]:
        relation_dict = {}
        object_index = relation["object"]
        relation_dict["Object"] = entities_ls[object_index]["surfaceform"]
        relation_dict["Predicate"] = relation["predicate"]
        subject_index = relation["subject"]
        relation_dict["Subject"] = entities_ls[subject_index]["surfaceform"]
        relations.append(relation_dict)

    return relations

In [None]:
def get_entities(example: Dict) -> List:
    # Extract the list of entities' surface forms without duplicates
    entities = example["entities"]
    unique_entities_surface_forms = list(
        set(entity["surfaceform"] for entity in entities)
    )
    return unique_entities_surface_forms

In [None]:
def get_ground_truth(test_example):
    # try:
    #     prompt, _ = tuple(test_example["text"].split("### ENTITES:"))
    # except:
    #     prompt, _ = tuple(test_example["text"].split("### ENTITIES:"))
    # prompt+="### RELATIONS:\n"
    # test_example["prompt"] = prompt
    test_example["entities_GT"]=get_entities(test_example)
    test_example["relations_GT"]=get_relation(test_example)
    return test_example

In [None]:
def get_entities_prompt_template():
    prompt_template = """### Instruction:
You are an expert in data science and natural language processing (NLP).
Your task is to extract entities from the text provided below.
Entities are the subject and object of a sentence, the list of entities must be in the form:
['entity1', 'entity2', 'entity3', ...]
Text: {text}\n
### Response:"""
    input_variables = ["text"]
    return PromptTemplate(
        template=prompt_template,
        input_variables=input_variables,
    )

In [None]:
def get_relations_prompt_template():
    prompt_template = """### Instruction:
You are an expert in data science and natural language processing (NLP).
Your task is to extract triplets from the text provided below.
A knowledge triplet is made up of 2 entities (subject and object) linked by a predicate: 
{{"Object": "", "Predicate": "", "Subject": "" }}
Multiple triplets must be in list form.
Text: {text}\n
### Response:"""
    input_variables = ["text"]
    return PromptTemplate(
        template=prompt_template,
        input_variables=input_variables,
    )

In [None]:
def get_prompt(test_example):
    text = test_example["text"].split("\nText: ")[1].split("\n\n### Response:")[0]

    entities_prompt_template = get_entities_prompt_template()
    relations_prompt_template = get_relations_prompt_template()
    # Create the full prompt by filling in the template
    entities_prompt = entities_prompt_template.format(
        text=text,
    )
    relations_prompt = relations_prompt_template.format(
        text=text,
    )

    test_example["entities_prompt"] = entities_prompt
    test_example["relations_prompt"] = relations_prompt
    return test_example

In [None]:
# used in solar-serenity-74
def entities_relation_GT(test_example):
    prompt, _ = tuple(test_example["text"].split("### Response:"))
    prompt += "### Response:\n"
    test_example["prompt"] = prompt
    test_example["entities_GT"] = get_entities(test_example)
    test_example["relation_GT"] = get_relation(test_example)
    return test_example

2. get prediction result based previsouly extracted prompt

In [None]:
def get_prediction(test_example):
    output = ft_model.predict(request={
        "prompt": test_example["relations_prompt"],
        "temperature": 0.1, 
        "max_new_tokens": 1024,
        # "stop": ["\n\n"]
        "stop": ["\n\n### Instruction:"]
    })
    # output_dict = string_to_dict(output)
    test_example["prediction"] = output
    return test_example

3. load relation prediction into list of dictionaries 

In [None]:
def string_to_dict(test_example):
    string = test_example["prediction"]
    string = string.split("### RELATIONS:\n")[1]
    string = string.replace("</s>", "").replace("\n", "")
    
        #     try:
        #         k, v = tuple(k_v.split("': "))
        #     except:
        #         print(k_v)
        #     relation[k.strip("'\" ")] = v.strip("'\" ")
        # extracted_relations.append(relation)

    test_example["prediction_dict"] = json.loads(string)
    return test_example   

In [None]:
def GT_string_to_dict(test_example):

    test_example["ground_truth"] = json.loads(test_example["ground_truth"])
    return test_example 

4. count true positive, ground truth and prediction to calculate f1 score

In [None]:
def parse_prediction(test_example):
    string = test_example['prediction'].split("### Response:\nRelations: ")[-1].replace("</s>", "").replace("\n", "")
    prediction_dict = json.loads(string)
    test_example["prediction_dict"] = prediction_dict
    return test_example

In [None]:
def get_true_positive(test_example):
    # gt_ls = test_example["ground_truth"]
    gt_ls = test_example["relations_GT"]
    pred_ls = test_example["prediction_dict"]
    true_positive = 0
    for pred in pred_ls:
        if pred in gt_ls:
            true_positive+=1 
    test_example["correct"] = true_positive
    # test_example["guess"] = len(test_example["ground_truth"])
    test_example["guess"] = len(test_example["relations_GT"])
    test_example["gold"] = len(test_example["prediction_dict"])
    return test_example

## Load checkpoints and dataset

In [16]:
## pious-sound-65 run
# os.environ["WANDB_BASE_URL"]="https://api.wandb.ai"
# run = wandb.init()
# artifact = run.use_artifact('xianli/digital_safety/checkpoint-17x5m17w:v10', type='model')
# artifact_dir = artifact.download()

In [17]:
# os.environ["WANDB_BASE_URL"]="https://api.wandb.ai"
# run = wandb.init()
# artifact = run.use_artifact('xianli/digital_safety/SREDFM-dataset:v3', type='dataset')
# artifact_dir = artifact.download()

In [18]:
## bright-shape-68 run
# os.environ["WANDB_BASE_URL"]="https://api.wandb.ai"
# run = wandb.init()
# artifact = run.use_artifact('xianli/digital_safety/checkpoint-o04konfu:v10', type='model')
# artifact_dir = artifact.download()

### solar-disco-79 run (r=16, alpha=64, lora_dropout=0.1)

In [19]:
# os.environ["WANDB_BASE_URL"]="https://api.wandb.ai"
# run = wandb.init()
# artifact = run.use_artifact('xianli/digital_safety/checkpoint-wvo5zep6:v18', type='model')
# artifact_dir = artifact.download(root="./checkpoints/solar-disco-79/checkpoint-4grkql3s:v18")

In [20]:
# os.environ["WANDB_BASE_URL"]="https://api.wandb.ai"
# run = wandb.init()
# artifact = run.use_artifact('xianli/digital_safety/SREDFM-dataset:v12', type='dataset')
# artifact_dir = artifact.download(root="./datasets/solar-disco-79/SREDFM-dataset:v12/")

checkpoint-4grkql3s:v10

In [22]:
checkpoint_dir = "./checkpoints/solar-disco-79/checkpoint-4grkql3s:v10"
ft_model = Model(checkpoint_dir=checkpoint_dir)
ft_model.load()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


THE DEVICE INFERENCE IS RUNNING ON IS:  cuda


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [23]:
# dataset_path = "./datasets/solar-disco-79/SREDFM-dataset:v12/test/"
dataset_path = "./datasets/SREDFM-dataset-1024/test/"
test_dataset = load_from_disk(dataset_path)
prompt_dataset = test_dataset.map(get_prompt)
pred_dataset = prompt_dataset.map(get_prediction)
pred_dataset.save_to_disk('./datasets/eager-rain-77/pred_dataset')
GT_pred_dataset = pred_dataset.map(get_ground_truth)
GT_pred_dict_dataset = GT_pred_dataset.map(parse_prediction)
metrics_dataset = GT_pred_dict_dataset.map(get_true_positive)

correct = sum(metrics_dataset["correct"])
guess = sum(metrics_dataset["guess"])
gold = sum(metrics_dataset["gold"])

precision = float(correct)/float(guess)
recall = float(correct)/float(gold)
f1_score = 2*precision*recall/(precision+recall)
print(precision, recall, f1_score)

Map:   0%|          | 0/2529 [00:00<?, ? examples/s]

Map:   0%|          | 0/2529 [00:00<?, ? examples/s]



KeyboardInterrupt: 

checkpoint-4grkql3s:v18

In [None]:
checkpoint_dir = "./checkpoints/solar-disco-79/checkpoint-4grkql3s:v18"
ft_model = Model(checkpoint_dir=checkpoint_dir)
ft_model.load()

In [None]:
dataset_path = "./datasets/solar-disco-79/SREDFM-dataset:v12/test/"
test_dataset = load_from_disk(dataset_path)
prompt_dataset = test_dataset.map(get_prompt)
pred_dataset = prompt_dataset.map(get_prediction)
pred_dataset.save_to_disk('./datasets/eager-rain-77/pred_dataset')
GT_pred_dataset = pred_dataset.map(get_ground_truth)
GT_pred_dict_dataset = GT_pred_dataset.map(parse_prediction)
metrics_dataset = GT_pred_dict_dataset.map(get_true_positive)

correct = sum(metrics_dataset["correct"])
guess = sum(metrics_dataset["guess"])
gold = sum(metrics_dataset["gold"])

precision = float(correct)/float(guess)
recall = float(correct)/float(gold)
f1_score = 2*precision*recall/(precision+recall)
print(precision, recall, f1_score)

### eager-rain-77 run (changed the prompt into dialogue with packing=True)

In [15]:
# os.environ["WANDB_BASE_URL"]="https://api.wandb.ai"
# run = wandb.init()
# artifact = run.use_artifact('xianli/digital_safety/checkpoint-wvo5zep6:v10', type='model')
# artifact_dir = artifact.download(root="./checkpoints/")

In [6]:
# os.environ["WANDB_BASE_URL"]="https://api.wandb.ai"
# run = wandb.init()
# artifact = run.use_artifact('xianli/digital_safety/SREDFM-dataset:v11', type='dataset')
# artifact_dir = artifact.download(root="./datasets/eager-rain-77/SREDFM-dataset:v11/")

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113743588888761, max=1.0…

wandb: Network error (ConnectionError), entering retry loop.


Problem at: /tmp/ipykernel_9610/1890665174.py 2 <module>


KeyboardInterrupt: 

In [18]:
artifact_dir = "./checkpoints/eager-rain-77/checkpoint-wvo5zep6:v10"
ft_model = Model(checkpoint_dir=artifact_dir)
ft_model.load()

THE DEVICE INFERENCE IS RUNNING ON IS:  cuda


Downloading (…)okenizer_config.json:   0%|          | 0.00/966 [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Downloading (…)lve/main/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

KeyboardInterrupt: 

In [12]:
dataset_path = "../../artifacts/datasets/SREDFM-dataset-1024/test"
test_dataset = load_from_disk(dataset_path)
display(Markdown(test_dataset[2]["text"]))

### Instruction:
You are an expert in data science and natural language processing (NLP).
Your task is to extract entities from the text provided below.
Entities are the subject and object of a sentence, the list of entities must be in the form:
['entity1', 'entity2', 'entity3', ...]
Text: La Grosse Combine (titre original : "") est un film franco-italien réalisé par Bruno Corbucci et sorti en 1971.

### Response:
Entities: ["Bruno Corbucci", "film franco", "1971", "italien", "La Grosse Combine"]</s>

### Instruction:
Now based on the entities that you extracted before, you should extract all triplets including every extracted entity.
A knowledge triplet is made up of 2 entities (subject and object) linked by a predicate: 
{"Object": "", "Predicate": "", "Subject": "" }
Entities can be related to many other entities.
Multiple triplets must be in list form.

### Response:
Relations: [{"Object": "Bruno Corbucci", "Predicate": "director", "Subject": "La Grosse Combine"}]</s>



In [13]:
test = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a classification task by clustering the given list of items.\n\n### Input:\nApples, oranges, bananas, strawberries, pineapples\n\n### Response:\nClass 1: Apples, Oranges\nClass 2: Bananas, Strawberries\nClass 3: Pineapples"
display(Markdown(test))

Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Create a classification task by clustering the given list of items.

### Input:
Apples, oranges, bananas, strawberries, pineapples

### Response:
Class 1: Apples, Oranges
Class 2: Bananas, Strawberries
Class 3: Pineapples

In [8]:
dataset_path = "../../artifacts/datasets/eager-rain-77/SREDFM-dataset:v11/test"
test_dataset = load_from_disk(dataset_path)
display(Markdown(test_dataset[2]["text"]))

### Instruction:
You are an expert in data science and natural language processing (NLP).
Your task is to extract entities from the text provided below.
Entities are the subject and object of a sentence, the list of entities must be in the form:
['entity1', 'entity2', 'entity3', ...]
Text: La Grosse Combine (titre original : "") est un film franco-italien réalisé par Bruno Corbucci et sorti en 1971.

### Response:
Entities: ["Bruno Corbucci", "1971", "italien", "film franco", "La Grosse Combine"]</s>

### Instruction:
You are an expert in data science and natural language processing (NLP).
Your task is to extract triplets from the text provided below.
A knowledge triplet is made up of 2 entities (subject and object) linked by a predicate: 
{"Object": "", "Predicate": "", "Subject": "" }
Multiple triplets must be in list form.

### Response:
Relations: [{"Object": "Bruno Corbucci", "Predicate": "director", "Subject": "La Grosse Combine"}]</s>



In [27]:
prompt_dataset = test_dataset.map(get_prompt)

In [28]:
print(prompt_dataset[0]["entities_prompt"])

### Instruction:
You are an expert in data science and natural language processing (NLP).
Your task is to extract entities from the text provided below.
Entities are the subject and object of a sentence, the list of entities must be in the form:
['entity1', 'entity2', 'entity3', ...]
Text: Khattiya Sawasdiphol (thaï : พลตรี ขัตติยะ สวัสดิผล), alias Seh Daeng (thaï : เสธ. แดง ; français : « Le commandant rouge »), né le et tué le , est un général de division de l'armée thaïlandaise, autrefois affecté au commandement des opérations de sécurité interne. 

### Response:


In [31]:
pred_dataset = prompt_dataset.map(get_prediction)
pred_dataset.save_to_disk('./datasets/eager-rain-77/pred_dataset')

Map:   0%|          | 0/2529 [00:00<?, ? examples/s]



NameError: name 'GT_pred_dataset' is not defined

In [18]:
pred_dataset = load_from_disk('./datasets/eager-rain-77/pred_dataset')
pred_dataset

Dataset({
    features: ['docid', 'title', 'uri', 'text', 'entities', 'relations', 'entities_prompt', 'relations_prompt', 'prediction'],
    num_rows: 2529
})

In [21]:
GT_pred_dataset = pred_dataset.map(get_ground_truth)
GT_pred_dataset

Dataset({
    features: ['docid', 'title', 'uri', 'text', 'entities', 'relations', 'entities_prompt', 'relations_prompt', 'prediction', 'entities_GT', 'relations_GT'],
    num_rows: 2529
})

In [23]:
print(GT_pred_dataset[0]["prediction"])

<s> ### Instruction:
You are an expert in data science and natural language processing (NLP).
Your task is to extract triplets from the text provided below.
A knowledge triplet is made up of 2 entities (subject and object) linked by a predicate: 
{"Object": "", "Predicate": "", "Subject": "" }
Multiple triplets must be in list form.
Text: Khattiya Sawasdiphol (thaï : พลตรี ขัตติยะ สวัสดิผล), alias Seh Daeng (thaï : เสธ. แดง ; français : « Le commandant rouge »), né le et tué le , est un général de division de l'armée thaïlandaise, autrefois affecté au commandement des opérations de sécurité interne. 

### Response:
Relations: [{"Object": "thaïlandaise", "Predicate": "country of citizenship", "Subject": "Khattiya Sawasdiphol"}]</s>


In [28]:
GT_pred_dict_dataset = GT_pred_dataset.map(parse_prediction)
GT_pred_dict_dataset.map()

Map:   0%|          | 0/2529 [00:00<?, ? examples/s]

Dataset({
    features: ['docid', 'title', 'uri', 'text', 'entities', 'relations', 'entities_prompt', 'relations_prompt', 'prediction', 'entities_GT', 'relations_GT', 'prediction_dict'],
    num_rows: 2529
})

In [33]:
metrics_dataset = GT_pred_dict_dataset.map(get_true_positive)

correct = sum(metrics_dataset["correct"])
guess = sum(metrics_dataset["guess"])
gold = sum(metrics_dataset["gold"])

precision = float(correct)/float(guess)
recall = float(correct)/float(gold)
f1_score = 2*precision*recall/(precision+recall)
print(precision, recall, f1_score)

Map:   0%|          | 0/2529 [00:00<?, ? examples/s]

0.37620849096258935 0.581924577373212 0.45698238447791684


In [None]:
# GT_pred_dict_dataset = GT_pred_dataset.map(string_to_dict)
# metrics_dataset = GT_pred_dict_dataset.map(get_true_positive)

# correct = sum(metrics_dataset["correct"])
# guess = sum(metrics_dataset["guess"])
# gold = sum(metrics_dataset["gold"])

# precision = float(correct)/float(guess)
# recall = float(correct)/float(gold)
# f1_score = 2*precision*recall/(precision+recall)
# print(precision, recall, f1_score)

In [None]:
GT_test_dataset = test_dataset.map(entities_relation_GT)
GT_pred_dataset = GT_test_dataset.map(get_prediction)
GT_pred_dataset.save_to_disk("./datasets/bright-shape-68/GT_pred_dataset")

In [69]:
# def get_relation_redfm(example):

#   RELATION_NAMES=['country', 'place of birth', 'spouse', 'country of citizenship', 'instance of',
#     'capital', 'child', 'shares border with', 'author', 'director', 'occupation',
#       'founded by', 'league', 'owned by', 'genre', 'named after', 'follows',
#       'headquarters location', 'cast member', 'manufacturer',
#         'located in or next to body of water', 'location', 'part of', 
#         'mouth of the watercourse', 'member of', 'sport', 'characters',
#           'participant', 'notable work', 'replaces', 'sibling', 'inception']
#   relations = []
#   for relation in example['relations']:
#     relation_dict = {}
#     relation_dict["object"] = relation['object']['surfaceform']
#     relation_dict["subject"] = relation['subject']['surfaceform']
#     relation_dict["predicate"] = RELATION_NAMES[relation['predicate']]
#     relations.append(relation_dict)
#   return relations

In [83]:
prompt = """### Instruction:
You are an expert in data science and natural language processing (NLP).
Your task is to extract ALL triplets from the text provided below.
A knowledge triplet is made up of 2 entities (subject and object) linked by a predicate: 
{"Object": "", "Predicate": "", "Subject": "" }
Multiple triplets must be in list form.
Text: Le chat mange la souris.

### Response:"""

In [88]:
get_relation_redfm(prompt_dataset[12])

[{'object': 'Somme', 'subject': 'Mers-les-Bains', 'predicate': 'location'},
 {'object': 'années 1960', 'subject': 'années 1970', 'predicate': 'follows'}]

### solar-serenity-74 run (used datacollator with instruction and response template)

1. download the checkpoint and dataset if haven't

In [15]:
## solar-serenity-74 run
# os.environ["WANDB_BASE_URL"]="https://api.wandb.ai"
# run = wandb.init()
# artifact = run.use_artifact('xianli/digital_safety/checkpoint-ql6d9loh:v10', type='model')
# artifact_dir = artifact.download(root="./checkpoints/")

In [16]:
# os.environ["WANDB_BASE_URL"]="https://api.wandb.ai"
# run = wandb.init()
# artifact = run.use_artifact('xianli/digital_safety/SREDFM-dataset:v9', type='dataset')
# artifact_dir = artifact.download(root="./datasets/")

2. load the checkpoint and merge with the base model

In [17]:
artifact_dir = "./checkpoints/solar-serenity-74/checkpoint-ql6d9loh:v10"
ft_model = Model(checkpoint_dir=artifact_dir)
ft_model.load()

THE DEVICE INFERENCE IS RUNNING ON IS:  cuda


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
dataset_path = "./datasets/solar-serenity-74/SREDFM-dataset:v9/test"
test_dataset = load_from_disk(dataset_path)
GT_test_dataset = test_dataset.map(entities_relation_GT)
GT_pred_dataset = GT_test_dataset.map(get_prediction)
GT_pred_dataset.save_to_disk("./datasets/bright-shape-68/GT_pred_dataset")
# GT_pred_dict_dataset = GT_pred_dataset.map(string_to_dict)
# metrics_dataset = GT_pred_dict_dataset.map(get_true_positive)

# correct = sum(metrics_dataset["correct"])
# guess = sum(metrics_dataset["guess"])
# gold = sum(metrics_dataset["gold"])

# precision = float(correct)/float(guess)
# recall = float(correct)/float(gold)
# f1_score = 2*precision*recall/(precision+recall)
# print(precision, recall, f1_score)

Map:   0%|          | 0/2529 [00:00<?, ? examples/s]



Saving the dataset (0/1 shards):   0%|          | 0/2529 [00:00<?, ? examples/s]

In [21]:
print(GT_pred_dataset[0]['text'])

### Instruction:
You are an expert in data science and natural language processing (NLP).
Your task is to extract triples from the text provided below.
Entities are the subject and object of a sentence, the list of entities must be in the form:
['entity1', 'entity2', 'entity3', ...]
A knowledge triplet is made up of 2 entities (subject and object) linked by a predicate: 
{"Object": "", "Predicate": "", "Subject": "" }
Multiple triplets must be in list form.

Text: Khattiya Sawasdiphol (thaï : พลตรี ขัตติยะ สวัสดิผล), alias Seh Daeng (thaï : เสธ. แดง ; français : « Le commandant rouge »), né le et tué le , est un général de division de l'armée thaïlandaise, autrefois affecté au commandement des opérations de sécurité interne. 

### Response:
Entities: ["général de division", "armée thaïlandaise", "Khattiya Sawasdiphol"]

Relations: [{"Object": "armée thaïlandaise", "Predicate": "military branch", "Subject": "Khattiya Sawasdiphol"}]</s>




In [20]:
print(GT_pred_dataset[0]['prediction'])

<s> ### Instruction:
You are an expert in data science and natural language processing (NLP).
Your task is to extract triples from the text provided below.
Entities are the subject and object of a sentence, the list of entities must be in the form:
['entity1', 'entity2', 'entity3', ...]
A knowledge triplet is made up of 2 entities (subject and object) linked by a predicate: 
{"Object": "", "Predicate": "", "Subject": "" }
Multiple triplets must be in list form.

Text: Khattiya Sawasdiphol (thaï : พลตรี ขัตติยะ สวัสดิผล), alias Seh Daeng (thaï : เสธ. แดง ; français : « Le commandant rouge »), né le et tué le , est un général de division de l'armée thaïlandaise, autrefois affecté au commandement des opérations de sécurité interne. 

### Response:
Entities: ["thaïlandaise", "Khattiya Sawasdiphol", "thaï"]

Relations: [{"Object": "thaïlandaise", "Predicate": "country of citizenship", "Subject": "Khattiya Sawasdiphol"}, {"Object": "thaïlandaise", "Predicate": "country", "Subject": "thaï"}]


### bright-shape-68 : training with data collator

In [17]:
dataset_path = "./datasets/bright-shape-68/SREDFM-dataset:v5/test/"
test_dataset = load_from_disk(dataset_path)
GT_test_dataset = test_dataset.map(get_ground_truth)
GT_pred_dataset = GT_test_dataset.map(get_prediction)
GT_pred_dataset.save_to_disk("./datasets/bright-shape-68/GT_pred_dataset")
GT_pred_dict_dataset = GT_pred_dataset.map(string_to_dict)
metrics_dataset = GT_pred_dict_dataset.map(get_true_positive)

correct = sum(metrics_dataset["correct"])
guess = sum(metrics_dataset["guess"])
gold = sum(metrics_dataset["gold"])

precision = float(correct)/float(guess)
recall = float(correct)/float(gold)
f1_score = 2*precision*recall/(precision+recall)
print(precision, recall, f1_score)

Map:   0%|          | 0/2525 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2525 [00:00<?, ? examples/s]

Map:   0%|          | 0/2525 [00:00<?, ? examples/s]

Map:   0%|          | 0/2525 [00:00<?, ? examples/s]

0.24878358366828857 0.45440494590417313 0.32153110047846895


Metrics of pions-sound-65

In [23]:
correct = sum(metrics_dataset["correct"])
guess = sum(metrics_dataset["guess"])
gold = sum(metrics_dataset["gold"])

precision = float(correct)/float(guess)
recall = float(correct)/float(gold)
f1_score = 2*precision*recall/(precision+recall)
print(precision, recall, f1_score)

0.3334038502221282 0.5485555168813088 0.4147368421052632
