In [297]:
import sys


In [298]:
import json
import os
import re
import copy
import glob
import time
import random
import logging
from pprint import pprint
from tqdm import tqdm
from IPython.core.debugger import set_trace

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, BertTokenizerFast

from common.utils import Preprocessor, DefaultLogger
from tplinker import (
    HandshakingTaggingScheme,
    DataMaker4Bert,
    TPLinkerBert,
    TPLinkerBiLSTM,
    MetricsCalculator
)


In [299]:
 #Configuration manuelle pour éviter les problèmes de chemin
train_config = {
    "exp_name": "nyt",
    "data_home": ".",  # Utiliser le répertoire courant
    "train_data": "nyt_dataset.json",  # Pointer directement vers le fichier JSON
    "valid_data": "nyt_dataset.json",  # Utiliser le même fichier pour validation
    "rel2id": "rel2id.json",
    "device_num": 0,
    "encoder": "BERT",
    "bert_path": "bert-base-cased",  # Utiliser directement le nom du modèle Hugging Face
    "logger_type": "default",  # Ajout de la clé logger_type au lieu de logger
    "log_path": "./logs",
    "run_name": "nyt_run",
    "run_id": "1",
    "note": "Adaptation du notebook avec le dataset JSON",
    "path_to_save_model": "./models",
    "hyper_parameters": {
        "batch_size": 6,
        "max_seq_len": 100,
        "sliding_len": 20,
        "epochs": 5,
        "seed": 42,
        "shaking_type": "cln",
        "inner_enc_type": "lstm",
        "dist_emb_size": 20,
        "ent_add_dist": True,
        "rel_add_dist": True,
        "lr": 1e-5,
        "weight_decay": 0
    }
}

# Utiliser cette configuration au lieu d'importer config.py
config = train_config
hyper_parameters = config["hyper_parameters"]

In [300]:

# Créer les répertoires nécessaires
os.makedirs("./models", exist_ok=True)
os.makedirs("./logs", exist_ok=True)



In [301]:
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["CUDA_VISIBLE_DEVICES"] = str(config["device_num"])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Utilisation du dispositif: {device}")

Utilisation du dispositif: cuda:0


In [302]:
# for reproductivity
torch.manual_seed(hyper_parameters["seed"]) # pytorch random seed
torch.backends.cudnn.deterministic = True

In [303]:
# Définir les chemins d'accès aux fichiers
# Utiliser des chemins directs pour éviter les problèmes
train_data_path = config["train_data"]
valid_data_path = config["valid_data"]
rel2id_path = config["rel2id"]

print(f"Chemin du fichier d'entraînement: {train_data_path}")
print(f"Chemin du fichier de validation: {valid_data_path}")
print(f"Chemin du fichier rel2id: {rel2id_path}")

Chemin du fichier d'entraînement: nyt_dataset.json
Chemin du fichier de validation: nyt_dataset.json
Chemin du fichier rel2id: rel2id.json


In [304]:
# Configuration du logger simplifiée
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(config["log_path"], "training.log")),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("TPLinker")
model_state_dict_dir = config["path_to_save_model"]
if not os.path.exists(model_state_dict_dir):
    os.makedirs(model_state_dict_dir)

In [305]:
import json
import random
from pprint import pprint
from collections import defaultdict
from transformers import AutoTokenizer

# === CONFIGURATION ===
input_path = "nyt_dataset.json"
output_prefix = "nyt_dataset"
tokenizer_name = "bert-base-cased"  # ou ton tokenizer BiLSTM
train_ratio = 0.8
random_seed = 42

# === CHARGEMENT ===
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
with open(input_path, "r", encoding="utf-8") as f:
    full_data = json.load(f)

def char_to_token(offsets, char_index):
    """Mappe un index de caractère à un index de token BERT (WordPiece)."""
    for idx, (start, end) in enumerate(offsets):
        if start <= char_index < end:
            return idx
    return None

def find_key(possibles, keys):
    for k in possibles:
        if k in keys:
            return k
    return None

fixed_data = []
skipped = 0

for sample in full_data:
    keys = sample.keys()
    text_field = find_key(["text", "sentence", "content"], keys)
    if not text_field:
        continue
    text = sample[text_field]
    encoding = tokenizer(text, return_offsets_mapping=True, truncation=True, max_length=512)
    offsets = encoding["offset_mapping"]

    subj_field = find_key(["raw_subjects", "subjects", "subj", "subject"], keys)
    obj_field = find_key(["raw_objects", "objects", "obj", "object"], keys)
    pred_field = find_key(["predicates", "predicate", "relations", "relation"], keys)
    subj_start_field = find_key(["subj_char_span_starts", "subject_start", "subj_starts"], keys)
    subj_end_field = find_key(["subj_char_span_ends", "subject_end", "subj_ends"], keys)
    obj_start_field = find_key(["obj_char_span_starts", "object_start", "obj_starts"], keys)
    obj_end_field = find_key(["obj_char_span_ends", "object_end", "obj_ends"], keys)

    # === GENERATION DE RELATION_LIST ===
    relation_list = []
    if all([subj_field, obj_field, pred_field, subj_start_field, subj_end_field, obj_start_field, obj_end_field]):
        subj_list = sample[subj_field]
        obj_list = sample[obj_field]
        pred_list = sample[pred_field]
        subj_start_list = sample[subj_start_field]
        subj_end_list = sample[subj_end_field]
        obj_start_list = sample[obj_start_field]
        obj_end_list = sample[obj_end_field]
        num_rels = min(len(subj_list), len(obj_list), len(pred_list), len(subj_start_list), len(subj_end_list), len(obj_start_list), len(obj_end_list))
        for i in range(num_rels):
            subj, obj, pred = subj_list[i], obj_list[i], pred_list[i]
            subj_char_start, subj_char_end = subj_start_list[i], subj_end_list[i]
            obj_char_start, obj_char_end = obj_start_list[i], obj_end_list[i]
            subj_tok_start = char_to_token(offsets, subj_char_start)
            subj_tok_end = char_to_token(offsets, subj_char_end - 1)
            obj_tok_start = char_to_token(offsets, obj_char_start)
            obj_tok_end = char_to_token(offsets, obj_char_end - 1)
            if None in (subj_tok_start, subj_tok_end, obj_tok_start, obj_tok_end):
                skipped += 1
                continue
            relation_list.append({
                "predicate": pred,
                "subject": subj,
                "object": obj,
                "subj_char_span": [subj_char_start, subj_char_end],
                "obj_char_span": [obj_char_start, obj_char_end],
                "subj_tok_span": [subj_tok_start, subj_tok_end + 1],  # [start, end)
                "obj_tok_span": [obj_tok_start, obj_tok_end + 1]
            })
    else:
        print(f"WARNING: Could not auto-detect relation fields for sample {sample.get('id', '?')}. Empty relation_list will be created.")

    sample["text"] = text
    sample["relation_list"] = relation_list

    # === entity_list (optionnel mais utile pour analyse ou débogage) ===
    ent_field = find_key(["raw_entities", "entities", "entity"], keys)
    ent_start_field = find_key(["entity_char_span_starts", "entity_start", "ent_starts"], keys)
    ent_end_field = find_key(["entity_char_span_ends", "entity_end", "ent_ends"], keys)
    ent_type_field = find_key(["classes", "entity_types", "types", "labels"], keys)
    entity_list = []
    if ent_field and ent_start_field and ent_end_field:
        ent_list = sample[ent_field]
        ent_start_list = sample[ent_start_field]
        ent_end_list = sample[ent_end_field]
        ent_type_list = sample.get(ent_type_field, ["DEFAULT"] * len(ent_list))
        for i in range(min(len(ent_list), len(ent_start_list), len(ent_end_list), len(ent_type_list))):
            ent = ent_list[i]
            char_span = [ent_start_list[i], ent_end_list[i]]
            tok_start = char_to_token(offsets, char_span[0])
            tok_end = char_to_token(offsets, char_span[1] - 1)
            entity_list.append({
                "text": ent,
                "char_span": char_span,
                "tok_span": [tok_start, tok_end + 1] if tok_start is not None and tok_end is not None else [0, 1],
                "type": ent_type_list[i]
            })
    # fallback: crée les entités à partir des relations (pas obligatoire)
    elif relation_list:
        ents = defaultdict(dict)
        for rel in relation_list:
            for role, key in [("subject", "subj_char_span"), ("object", "obj_char_span")]:
                ent_text = rel[role]
                char_span = rel.get(f"{role[:4]}_char_span", None)
                tok_span = rel.get(f"{role[:4]}_tok_span", None)
                ent_type = "DEFAULT"
                ent_id = (ent_text, tuple(char_span) if char_span else ())
                if ent_id not in ents:
                    ents[ent_id] = {
                        "text": ent_text,
                        "char_span": char_span if char_span else [0, 1],
                        "tok_span": tok_span if tok_span else [0, 1],
                        "type": ent_type
                    }
        entity_list = list(ents.values())
    sample["entity_list"] = entity_list
    if "event_list" not in sample:
        sample["event_list"] = []
    fixed_data.append(sample)

print(f"Processed {len(fixed_data)} samples. Skipped {skipped} relations due to char/token mapping issues.")
# === Sauvegarde et écrasement du nyt_dataset original avec la nouvelle structure ===
with open(input_path, "w", encoding="utf-8") as f:
    json.dump(fixed_data, f, ensure_ascii=False, indent=2)

print(f"\nnyt_dataset.json a été ÉCRASÉ par la version au format TPLinker ({len(fixed_data)} samples).")

# === SHUFFLE & SPLIT ===
random.seed(random_seed)
random.shuffle(fixed_data)
split_index = int(len(fixed_data) * train_ratio)
train_data = fixed_data[:split_index]
valid_data = fixed_data[split_index:]

# === rel2id (pour TPLinker) ===
all_predicates = set()
for sample in train_data + valid_data:
    for rel in sample["relation_list"]:
        all_predicates.add(rel["predicate"])
rel2id = {p: i for i, p in enumerate(sorted(all_predicates))}

with open(f"{output_prefix}_rel2id.json", "w", encoding="utf-8") as f:
    json.dump(rel2id, f, ensure_ascii=False, indent=2)
with open(f"{output_prefix}_train.json", "w", encoding="utf-8") as f:
    json.dump(train_data, f, ensure_ascii=False, indent=2)
with open(f"{output_prefix}_valid.json", "w", encoding="utf-8") as f:
    json.dump(valid_data, f, ensure_ascii=False, indent=2)

print("\nExemple d'entrée du dataset:")
pprint(train_data[0])
# Sauvegarde du mapping relation -> id dans rel2id.json
with open(f"./rel2id.json", "w", encoding="utf-8") as f:
    json.dump(rel2id, f, ensure_ascii=False, indent=2)

# Affichage rel2id (console)
print("\nRelation to ID mapping (rel2id):")
pprint(rel2id)



Processed 100 samples. Skipped 0 relations due to char/token mapping issues.

nyt_dataset.json a été ÉCRASÉ par la version au format TPLinker (100 samples).

Exemple d'entrée du dataset:
{'classes': ['DEFAULT', 'DEFAULT'],
 'entity_char_span_ends': [367, 248],
 'entity_char_span_starts': [360, 238],
 'entity_list': [{'char_span': [360, 367],
                  'text': 'Angeles',
                  'tok_span': [74, 75],
                  'type': 'DEFAULT'},
                 {'char_span': [238, 248],
                  'text': 'California',
                  'tok_span': [50, 51],
                  'type': 'DEFAULT'}],
 'event_list': [],
 'id': 'train_42',
 'obj_char_span_ends': [248],
 'obj_char_span_starts': [238],
 'predicates': ['/location/location/contains'],
 'raw_entities': ['Angeles', 'California'],
 'raw_objects': ['California'],
 'raw_subjects': ['Angeles'],
 'relation_list': [{'obj_char_span': [238, 248],
                    'obj_tok_span': [50, 51],
                    'object': 

In [306]:
# @specific
if config["encoder"] == "BERT":
    # Correction de l'appel du tokenizer pour éviter l'erreur add_special_tokens
    tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"])
    tokenize = tokenizer.tokenize
    get_tok2char_span_map = lambda text: tokenizer.encode_plus(text, return_offsets_mapping=True)["offset_mapping"]
elif config["encoder"] in {"BiLSTM", }:
    tokenize = lambda text: text.split(" ")
    def get_tok2char_span_map(text):
        tokens = text.split(" ")
        tok2char_span = []
        char_num = 0
        for tok in tokens:
            tok2char_span.append((char_num, char_num + len(tok)))
            char_num += len(tok) + 1 # +1: whitespace
        return tok2char_span

In [307]:
preprocessor = Preprocessor(tokenize_func = tokenize, 
                            get_tok2char_span_map_func = get_tok2char_span_map)

In [308]:
# train and valid max token num
max_tok_num = 0
all_data = train_data + valid_data 
    
for sample in all_data:
    tokens = tokenize(sample["text"])
    max_tok_num = max(max_tok_num, len(tokens))
print(f"Nombre maximum de tokens: {max_tok_num}")

Nombre maximum de tokens: 118


In [309]:
if max_tok_num > hyper_parameters["max_seq_len"]:
    train_data = preprocessor.split_into_short_samples(train_data, 
                                                          hyper_parameters["max_seq_len"], 
                                                          sliding_len = hyper_parameters["sliding_len"], 
                                                          encoder = config["encoder"]
                                                         )
    valid_data = preprocessor.split_into_short_samples(valid_data, 
                                                          hyper_parameters["max_seq_len"], 
                                                          sliding_len = hyper_parameters["sliding_len"], 
                                                          encoder = config["encoder"]
                                                         )

Splitting into subtexts: 100%|██████████| 80/80 [00:00<00:00, 3118.35it/s]
Splitting into subtexts: 100%|██████████| 20/20 [00:00<00:00, 3231.23it/s]


In [310]:
print("train: {}".format(len(train_data)), "valid: {}".format(len(valid_data)))

train: 81 valid: 20


In [311]:
max_seq_len = min(max_tok_num, hyper_parameters["max_seq_len"])
rel2id = json.load(open("rel2id.json", "r", encoding = "utf-8"))
handshaking_tagger = HandshakingTaggingScheme(rel2id = rel2id, max_seq_len = max_seq_len)

In [312]:
if config["encoder"] == "BERT":
    # Correction de l'appel du tokenizer pour éviter l'erreur add_special_tokens
    tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"])
    data_maker = DataMaker4Bert(tokenizer, handshaking_tagger)

In [313]:
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

In [314]:
indexed_train_data = data_maker.get_indexed_data(train_data, max_seq_len)
indexed_valid_data = data_maker.get_indexed_data(valid_data, max_seq_len)

Generate indexed train or valid data: 81it [00:00, 4807.53it/s]
Generate indexed train or valid data: 20it [00:00, 4339.01it/s]


In [315]:
import torch
from torch.utils.data import Dataset, DataLoader


# Exemple simple de Dataset
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __getitem__(self, index):
        return self.data[index]  # Doit être un dict !
    def __len__(self):
        return len(self.data)



train_dataloader = DataLoader(
    MyDataset(indexed_train_data),
    batch_size=hyper_parameters["batch_size"],
    shuffle=True,
    num_workers=0,
    drop_last=False,
    collate_fn=data_maker.generate_batch,
)

valid_dataloader = DataLoader(
    MyDataset(indexed_valid_data),
    batch_size=hyper_parameters["batch_size"],
    shuffle=True,
    num_workers=0,
    drop_last=False,
    collate_fn=data_maker.generate_batch,
)

# --- Visualisation d'un batch de données ---
train_data_iter = iter(train_dataloader)
batch_data = next(train_data_iter)
text_list, batch_input_ids, \
batch_attention_mask, batch_token_type_ids, \
tok2char_span_list, batch_ent_shaking_tag, \
batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_data

print("\nExemple de texte:")
print(text_list[0])
print()
print("Tailles des tenseurs:")
print(f"batch_input_ids: {batch_input_ids.size()}")
print(f"batch_attention_mask: {batch_attention_mask.size()}")
print(f"batch_token_type_ids: {batch_token_type_ids.size()}")
print(f"offset_map_list: {len(tok2char_span_list)}")
print(f"batch_ent_shaking_tag: {len(batch_ent_shaking_tag) if batch_ent_shaking_tag is not None else None}")
print(f"batch_head_rel_shaking_tag: {len(batch_head_rel_shaking_tag) if batch_head_rel_shaking_tag is not None else None}")
print(f"batch_tail_rel_shaking_tag: {len(batch_tail_rel_shaking_tag) if batch_tail_rel_shaking_tag is not None else None}")



Exemple de texte:
{'id': 'train_41', 'text': '', 'tok_offset': 0, 'char_offset': 0, 'event_list': [], 'entity_list': [{'text': 'Washington', 'char_span': [89, 99], 'tok_span': [18, 19], 'type': 'DEFAULT'}, {'text': 'Seattle', 'char_span': [51, 58], 'tok_span': [13, 14], 'type': 'DEFAULT'}], 'relation_list': [{'predicate': '/location/location/contains', 'subject': 'Washington', 'object': 'Seattle', 'subj_char_span': [89, 99], 'obj_char_span': [51, 58], 'subj_tok_span': [18, 19], 'obj_tok_span': [13, 14]}]}

Tailles des tenseurs:
batch_input_ids: torch.Size([6, 100])
batch_attention_mask: torch.Size([6, 100])
batch_token_type_ids: torch.Size([6, 100])
offset_map_list: 6
batch_ent_shaking_tag: 6
batch_head_rel_shaking_tag: 6
batch_tail_rel_shaking_tag: 6


In [316]:
if config["encoder"] == "BERT":
    encoder = AutoModel.from_pretrained(config["bert_path"])
    hidden_size = encoder.config.hidden_size
    rel_extractor = TPLinkerBert(encoder, 
                                 len(rel2id), 
                                 hyper_parameters["shaking_type"],
                                 hyper_parameters["inner_enc_type"],
                                 hyper_parameters["dist_emb_size"],
                                 hyper_parameters["ent_add_dist"],
                                 hyper_parameters["rel_add_dist"],
                                )

In [317]:
rel_extractor = rel_extractor.to(device)

# Définition de l'optimiseur
optimizer = optim.Adam(rel_extractor.parameters(), lr = float(hyper_parameters["lr"]), weight_decay = float(hyper_parameters["weight_decay"]))

In [318]:
# Fonction d'entraînement
def train(dataloader, ep):
    rel_extractor.train()
    
    total_loss = 0.
    total_steps = len(dataloader)
    for batch_ind, batch_data in enumerate(tqdm(dataloader, desc = "Training")):
        text_id_list, text_list, batch_input_ids, \
        batch_attention_mask, batch_token_type_ids, \
        offset_map_list, batch_ent_shaking_tag, \
        batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_data
        
        batch_input_ids, batch_attention_mask, batch_token_type_ids = (
            batch_input_ids.to(device),
            batch_attention_mask.to(device),
            batch_token_type_ids.to(device),
        )
        batch_ent_shaking_tag, batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = (
            batch_ent_shaking_tag.to(device),
            batch_head_rel_shaking_tag.to(device),
            batch_tail_rel_shaking_tag.to(device),
        )
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward
        loss = rel_extractor(
            batch_input_ids, 
            batch_attention_mask, 
            batch_token_type_ids, 
            batch_ent_shaking_tag,
            batch_head_rel_shaking_tag,
            batch_tail_rel_shaking_tag,
        )
        
        # Backward
        loss.backward()
        
        # Update parameters
        optimizer.step()
        
        total_loss += loss.item()
        
        # Log
        if batch_ind % 10 == 0:
            print(f"Epoch {ep}, Batch {batch_ind}/{total_steps}, Loss: {loss.item():.4f}")
    
    avg_loss = total_loss / total_steps
    print(f"Epoch {ep}, Average Loss: {avg_loss:.4f}")
    return avg_loss

In [319]:
# Sauvegarde du modèle
save_path = os.path.join(model_state_dict_dir, "model.pt")
torch.save(rel_extractor.state_dict(), save_path)
print(f"Modèle sauvegardé à {save_path}")

Modèle sauvegardé à ./models/model.pt


In [320]:
# try:
#     from yaml import CLoader as Loader, CDumper as Dumper
# except ImportError:
#     from yaml import Loader, Dumper
# config = yaml.load(open("train_config.yaml", "r"), Loader = yaml.FullLoader)

In [None]:
# Configuration manuelle pour éviter les problèmes de chemin
train_config = {
    "exp_name": "nyt",
    "data_home": ".",  # Utiliser le répertoire courant
    "train_data": "nyt_dataset.json",  # Pointer directement vers le fichier JSON
    "valid_data": "nyt_dataset.json",  # Utiliser le même fichier pour validation
    "rel2id": "./rel2id.json",
    "device_num": 0,
    "encoder": "BERT",
    "model_state_dict_path": "./models/model.pt",
    "bert_path": "bert-base-cased",
    "log_path": "./logs/train.log",
    "log_path": "./tplinker/logs",
    "run_name": "nyt_run",
    "run_id": "1",
    "fr_scratch":False,
    "note": "Adaptation du notebook avec le dataset JSON",
    "path_to_save_model": "./models",
    "logger": "wandb",
    "hyper_parameters": {
        "batch_size": 6,
        "max_seq_len": 100,
        "sliding_len": 20,
        "epochs": 5,
        "seed": 42,
        "shaking_type": "cln",
        "inner_enc_type": "lstm",
        "dist_emb_size": 20,
        "ent_add_dist": True,
        "rel_add_dist": True,
        "lr": 1e-5,
        "weight_decay": 0,
        "scheduler":"CAWR"
    }
}

# Utiliser cette configuration au lieu d'importer config.py
config = train_config
hyper_parameters = config["hyper_parameters"]

In [322]:

# Créer les répertoires nécessaires
os.makedirs("./models", exist_ok=True)
os.makedirs("./logs", exist_ok=True)



In [323]:
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["CUDA_VISIBLE_DEVICES"] = str(config["device_num"])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Utilisation du dispositif: {device}")

Utilisation du dispositif: cuda:0


In [324]:
# for reproductivity
torch.manual_seed(hyper_parameters["seed"]) # pytorch random seed
torch.backends.cudnn.deterministic = True

In [325]:
# Utiliser des chemins directs pour éviter les problèmes
train_data_path = config["train_data"]
valid_data_path = config["valid_data"]
rel2id_path = config["rel2id"]

print(f"Chemin du fichier d'entraînement: {train_data_path}")
print(f"Chemin du fichier de validation: {valid_data_path}")
print(f"Chemin du fichier rel2id: {rel2id_path}")

Chemin du fichier d'entraînement: nyt_dataset.json
Chemin du fichier de validation: nyt_dataset.json
Chemin du fichier rel2id: ./rel2id.json


In [326]:
data_home = config["data_home"]
experiment_name = config["exp_name"]    
train_data_path = os.path.join(data_home, experiment_name, config["train_data"])
valid_data_path = os.path.join(data_home, experiment_name, config["valid_data"])
rel2id_path = os.path.join(data_home, experiment_name, config["rel2id"])

In [327]:
import wandb
if config["logger"] == "wandb":
    # init wandb
    wandb.init(project = experiment_name, 
               name = config["run_name"],
               config = hyper_parameters # Initialize config
              )

    wandb.config.note = config["note"]          

    model_state_dict_dir = wandb.run.dir
    logger = wandb
else:
    logger = DefaultLogger(config["log_path"], experiment_name, config["run_name"], config["run_id"], hyper_parameters)
    model_state_dict_dir = config["path_to_save_model"]
    if not os.path.exists(model_state_dict_dir):
        os.makedirs(model_state_dict_dir)

# Load Data

In [328]:
import json
import random
import os

# Charger le dataset complet
data_path = "./nyt_dataset.json"
with open(data_path, "r", encoding="utf-8") as f:
    full_data = json.load(f)

# Mélanger les données
random.seed(42)
random.shuffle(full_data)

# Split 80% train / 20% valid
split_index = int(len(full_data) * 0.8)
train_data = full_data[:split_index]
valid_data = full_data[split_index:]

# Sauvegarder dans le même dossier
with open("./nyt_dataset_train.json", "w", encoding="utf-8") as f:
    json.dump(train_data, f, ensure_ascii=False, indent=2)

with open("./nyt_dataset_valid.json", "w", encoding="utf-8") as f:
    json.dump(valid_data, f, ensure_ascii=False, indent=2)


# Split

In [329]:
# @specific
if config["encoder"] == "BERT":
    tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"], do_lower_case=False)
    tokenize = tokenizer.tokenize
    get_tok2char_span_map = lambda text: tokenizer.encode_plus(
        text, 
        return_offsets_mapping=True, 
        add_special_tokens=False
    )["offset_mapping"]

elif config["encoder"] in {"BiLSTM", }:
    tokenize = lambda text: text.split(" ")
    def get_tok2char_span_map(text):
        tokens = text.split(" ")
        tok2char_span = []
        char_num = 0
        for tok in tokens:
            tok2char_span.append((char_num, char_num + len(tok)))
            char_num += len(tok) + 1 # +1: whitespace
        return tok2char_span

In [330]:
preprocessor = Preprocessor(tokenize_func = tokenize, 
                            get_tok2char_span_map_func = get_tok2char_span_map)

In [331]:
# train and valid max token num
max_tok_num = 0
all_data = train_data + valid_data 
    
for sample in all_data:
    tokens = tokenize(sample["text"])
    max_tok_num = max(max_tok_num, len(tokens))
max_tok_num

118

In [332]:
if max_tok_num > hyper_parameters["max_seq_len"]:
    train_data = preprocessor.split_into_short_samples(train_data, 
                                                          hyper_parameters["max_seq_len"], 
                                                          sliding_len = hyper_parameters["sliding_len"], 
                                                          encoder = config["encoder"]
                                                         )
    valid_data = preprocessor.split_into_short_samples(valid_data, 
                                                          hyper_parameters["max_seq_len"], 
                                                          sliding_len = hyper_parameters["sliding_len"], 
                                                          encoder = config["encoder"]
                                                         )

Splitting into subtexts: 100%|██████████| 80/80 [00:00<00:00, 2703.97it/s]
Splitting into subtexts: 100%|██████████| 20/20 [00:00<00:00, 3055.96it/s]


In [333]:
print("train: {}".format(len(train_data)), "valid: {}".format(len(valid_data)))

train: 81 valid: 20


# Tagger (Decoder)

In [334]:
max_seq_len = min(max_tok_num, hyper_parameters["max_seq_len"])
rel2id = json.load(open("./rel2id.json", "r", encoding = "utf-8"))
handshaking_tagger = HandshakingTaggingScheme(rel2id = rel2id, max_seq_len = max_seq_len)

# Dataset

In [335]:
if config["encoder"] == "BERT":
    tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"], do_lower_case=False)
    data_maker = DataMaker4Bert(tokenizer, handshaking_tagger)

    
elif config["encoder"] in {"BiLSTM", }:
    token2idx_path = os.path.join(data_home, experiment_name, config["token2idx"])
    token2idx = json.load(open(token2idx_path, "r", encoding = "utf-8"))
    idx2token = {idx:tok for tok, idx in token2idx.items()}
    def text2indices(text, max_seq_len):
        input_ids = []
        tokens = text.split(" ")
        for tok in tokens:
            if tok not in token2idx:
                input_ids.append(token2idx['<UNK>'])
            else:
                input_ids.append(token2idx[tok])
        if len(input_ids) < max_seq_len:
            input_ids.extend([token2idx['<PAD>']] * (max_seq_len - len(input_ids)))
        input_ids = torch.tensor(input_ids[:max_seq_len])
        return input_ids
    data_maker = DataMaker4BiLSTM(text2indices, get_tok2char_span_map, handshaking_tagger)

In [336]:
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

In [337]:

indexed_train_data = data_maker.get_indexed_data(train_data, max_seq_len)
indexed_valid_data = data_maker.get_indexed_data(valid_data, max_seq_len)

Generate indexed train or valid data: 81it [00:00, 1872.11it/s]
Generate indexed train or valid data: 20it [00:00, 1574.59it/s]


In [338]:
train_dataloader = DataLoader(MyDataset(indexed_train_data), 
                                  batch_size = hyper_parameters["batch_size"], 
                                  shuffle = True, 
                                  num_workers = 6,
                                  drop_last = False,
                                  collate_fn = data_maker.generate_batch,
                                 )
valid_dataloader = DataLoader(MyDataset(indexed_valid_data), 
                          batch_size = hyper_parameters["batch_size"], 
                          shuffle = True, 
                          num_workers = 6,
                          drop_last = False,
                          collate_fn = data_maker.generate_batch,
                         )

In [339]:
# # have a look at dataloader
# train_data_iter = iter(train_dataloader)
# batch_data = next(train_data_iter)
# text_id_list, text_list, batch_input_ids, \
# batch_attention_mask, batch_token_type_ids, \
# offset_map_list, batch_ent_shaking_tag, \
# batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_data

# print(text_list[0])
# print()
# print(tokenizer.decode(batch_input_ids[0].tolist()))
# print(batch_input_ids.size())
# print(batch_attention_mask.size())
# print(batch_token_type_ids.size())
# print(len(offset_map_list))
# print(batch_ent_shaking_tag.size())
# print(batch_head_rel_shaking_tag.size())
# print(batch_tail_rel_shaking_tag.size())

# Model

In [340]:
if config["encoder"] == "BERT":
    encoder = AutoModel.from_pretrained(config["bert_path"])
    hidden_size = encoder.config.hidden_size
    rel_extractor = TPLinkerBert(encoder, 
                                 len(rel2id), 
                                 hyper_parameters["shaking_type"],
                                 hyper_parameters["inner_enc_type"],
                                 hyper_parameters["dist_emb_size"],
                                 hyper_parameters["ent_add_dist"],
                                 hyper_parameters["rel_add_dist"],
                                )
    
elif config["encoder"] in {"BiLSTM", }:
    glove = Glove()
    glove = glove.load(config["pretrained_word_embedding_path"])
    
    # prepare embedding matrix
    word_embedding_init_matrix = np.random.normal(-1, 1, size=(len(token2idx), hyper_parameters["word_embedding_dim"]))
    count_in = 0

    # 在预训练词向量中的用该预训练向量
    # 不在预训练集里的用随机向量
    for ind, tok in tqdm(idx2token.items(), desc="Embedding matrix initializing..."):
        if tok in glove.dictionary:
            count_in += 1
            word_embedding_init_matrix[ind] = glove.word_vectors[glove.dictionary[tok]]

    print("{:.4f} tokens are in the pretrain word embedding matrix".format(count_in / len(idx2token))) # 命中预训练词向量的比例
    word_embedding_init_matrix = torch.FloatTensor(word_embedding_init_matrix)
    
    rel_extractor = TPLinkerBiLSTM(word_embedding_init_matrix, 
                                   hyper_parameters["emb_dropout"], 
                                   hyper_parameters["enc_hidden_size"], 
                                   hyper_parameters["dec_hidden_size"],
                                   hyper_parameters["rnn_dropout"],
                                   len(rel2id), 
                                   hyper_parameters["shaking_type"],
                                   hyper_parameters["inner_enc_type"],
                                   hyper_parameters["dist_emb_size"],
                                   hyper_parameters["ent_add_dist"],
                                   hyper_parameters["rel_add_dist"],
                                  )

rel_extractor = rel_extractor.to(device)

In [341]:
# all_paras = sum(x.numel() for x in rel_extractor.parameters())
# enc_paras = sum(x.numel() for x in encoder.parameters())

In [342]:
# print(all_paras, enc_paras)
# print(all_paras - enc_paras)

# Metrics

In [343]:
def bias_loss(weights = None):
    if weights is not None:
        weights = torch.FloatTensor(weights).to(device)
    cross_en = nn.CrossEntropyLoss(weight = weights)  
    return lambda pred, target: cross_en(pred.view(-1, pred.size()[-1]), target.view(-1))
loss_func = bias_loss()

In [344]:
metrics = MetricsCalculator(handshaking_tagger)

# Train

In [345]:
# train step
def train_step(batch_train_data, optimizer, loss_weights):
    if config["encoder"] == "BERT":
        sample_list, batch_input_ids, \
        batch_attention_mask, batch_token_type_ids, \
        tok2char_span_list, batch_ent_shaking_tag, \
        batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_train_data
        
        batch_input_ids, \
        batch_attention_mask, \
        batch_token_type_ids, \
        batch_ent_shaking_tag, \
        batch_head_rel_shaking_tag, \
        batch_tail_rel_shaking_tag = (batch_input_ids.to(device), 
                                      batch_attention_mask.to(device), 
                                      batch_token_type_ids.to(device), 
                                      batch_ent_shaking_tag.to(device), 
                                      batch_head_rel_shaking_tag.to(device), 
                                      batch_tail_rel_shaking_tag.to(device)
                                     )
        
    elif config["encoder"] in {"BiLSTM", }:
        sample_list, batch_input_ids, \
        tok2char_span_list, batch_ent_shaking_tag, \
        batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_train_data
        
        batch_input_ids, \
        batch_ent_shaking_tag, \
        batch_head_rel_shaking_tag, \
        batch_tail_rel_shaking_tag = (batch_input_ids.to(device), 
                                      batch_ent_shaking_tag.to(device), 
                                      batch_head_rel_shaking_tag.to(device), 
                                      batch_tail_rel_shaking_tag.to(device)
                                     )
    

    # zero the parameter gradients
    optimizer.zero_grad()
    
    if config["encoder"] == "BERT":
        ent_shaking_outputs, \
        head_rel_shaking_outputs, \
        tail_rel_shaking_outputs = rel_extractor(batch_input_ids, 
                                                  batch_attention_mask, 
                                                  batch_token_type_ids, 
                                                 )
    elif config["encoder"] in {"BiLSTM", }:
        ent_shaking_outputs, \
        head_rel_shaking_outputs, \
        tail_rel_shaking_outputs = rel_extractor(batch_input_ids)
    
    w_ent, w_rel = loss_weights["ent"], loss_weights["rel"]
    loss = w_ent * loss_func(ent_shaking_outputs, batch_ent_shaking_tag) + \
            w_rel * loss_func(head_rel_shaking_outputs, batch_head_rel_shaking_tag) + \
            w_rel * loss_func(tail_rel_shaking_outputs, batch_tail_rel_shaking_tag)
    
    loss.backward()
    optimizer.step()
    
    ent_sample_acc = metrics.get_sample_accuracy(ent_shaking_outputs, 
                                          batch_ent_shaking_tag)
    head_rel_sample_acc = metrics.get_sample_accuracy(head_rel_shaking_outputs, 
                                             batch_head_rel_shaking_tag)
    tail_rel_sample_acc = metrics.get_sample_accuracy(tail_rel_shaking_outputs, 
                                             batch_tail_rel_shaking_tag)
    
    return loss.item(), ent_sample_acc.item(), head_rel_sample_acc.item(), tail_rel_sample_acc.item()

# valid step
def valid_step(batch_valid_data):
    if config["encoder"] == "BERT":
        sample_list, batch_input_ids, \
        batch_attention_mask, batch_token_type_ids, \
        tok2char_span_list, batch_ent_shaking_tag, \
        batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_valid_data
        
        batch_input_ids, \
        batch_attention_mask, \
        batch_token_type_ids, \
        batch_ent_shaking_tag, \
        batch_head_rel_shaking_tag, \
        batch_tail_rel_shaking_tag = (batch_input_ids.to(device), 
                                      batch_attention_mask.to(device), 
                                      batch_token_type_ids.to(device), 
                                      batch_ent_shaking_tag.to(device), 
                                      batch_head_rel_shaking_tag.to(device), 
                                      batch_tail_rel_shaking_tag.to(device)
                                     )
        
    elif config["encoder"] in {"BiLSTM", }:
        sample_list, batch_input_ids, \
        tok2char_span_list, batch_ent_shaking_tag, \
        batch_head_rel_shaking_tag, batch_tail_rel_shaking_tag = batch_valid_data
        
        batch_input_ids, \
        batch_ent_shaking_tag, \
        batch_head_rel_shaking_tag, \
        batch_tail_rel_shaking_tag = (batch_input_ids.to(device), 
                                      batch_ent_shaking_tag.to(device), 
                                      batch_head_rel_shaking_tag.to(device), 
                                      batch_tail_rel_shaking_tag.to(device)
                                     )
    
    with torch.no_grad():
        if config["encoder"] == "BERT":
            ent_shaking_outputs, \
            head_rel_shaking_outputs, \
            tail_rel_shaking_outputs = rel_extractor(batch_input_ids, 
                                                      batch_attention_mask, 
                                                      batch_token_type_ids, 
                                                     )
        elif config["encoder"] in {"BiLSTM", }:
            ent_shaking_outputs, \
            head_rel_shaking_outputs, \
            tail_rel_shaking_outputs = rel_extractor(batch_input_ids)

    
    ent_sample_acc = metrics.get_sample_accuracy(ent_shaking_outputs, 
                                          batch_ent_shaking_tag)
    head_rel_sample_acc = metrics.get_sample_accuracy(head_rel_shaking_outputs, 
                                             batch_head_rel_shaking_tag)
    tail_rel_sample_acc = metrics.get_sample_accuracy(tail_rel_shaking_outputs, 
                                             batch_tail_rel_shaking_tag)
    
    rel_cpg = metrics.get_rel_cpg(sample_list, tok2char_span_list, 
                                    ent_shaking_outputs,
                                    head_rel_shaking_outputs,
                                    tail_rel_shaking_outputs,
                                    hyper_parameters["match_pattern"]
                                    )
    
    return ent_sample_acc.item(), head_rel_sample_acc.item(), tail_rel_sample_acc.item(), rel_cpg

In [346]:
max_f1 = 0.
def train_n_valid(train_dataloader, dev_dataloader, optimizer, scheduler, num_epoch):  
    def train(dataloader, ep):
        # train
        rel_extractor.train()
        
        t_ep = time.time()
        start_lr = optimizer.param_groups[0]['lr']
        total_loss, total_ent_sample_acc, total_head_rel_sample_acc, total_tail_rel_sample_acc = 0., 0., 0., 0.
        for batch_ind, batch_train_data in enumerate(dataloader):
            t_batch = time.time()
            z = (2 * len(rel2id) + 1)
            steps_per_ep = len(dataloader)
            total_steps = hyper_parameters["loss_weight_recover_steps"] + 1 # + 1 avoid division by zero error
            current_step = steps_per_ep * ep + batch_ind
            w_ent = max(1 / z + 1 - current_step / total_steps, 1 / z)
            w_rel = min((len(rel2id) / z) * current_step / total_steps, (len(rel2id) / z))
            loss_weights = {"ent": w_ent, "rel": w_rel}
            
            loss, ent_sample_acc, head_rel_sample_acc, tail_rel_sample_acc = train_step(batch_train_data, optimizer, loss_weights)
            scheduler.step()
            
            total_loss += loss
            total_ent_sample_acc += ent_sample_acc
            total_head_rel_sample_acc += head_rel_sample_acc
            total_tail_rel_sample_acc += tail_rel_sample_acc
            
            avg_loss = total_loss / (batch_ind + 1)
            avg_ent_sample_acc = total_ent_sample_acc / (batch_ind + 1)
            avg_head_rel_sample_acc = total_head_rel_sample_acc / (batch_ind + 1)
            avg_tail_rel_sample_acc = total_tail_rel_sample_acc / (batch_ind + 1)
            
            batch_print_format = "\rproject: {}, run_name: {}, Epoch: {}/{}, batch: {}/{}, train_loss: {}, " + \
                                "t_ent_sample_acc: {}, t_head_rel_sample_acc: {}, t_tail_rel_sample_acc: {}," + \
                                 "lr: {}, batch_time: {}, total_time: {} -------------"
                    
            print(batch_print_format.format(experiment_name, config["run_name"], 
                                            ep + 1, num_epoch, 
                                            batch_ind + 1, len(dataloader), 
                                            avg_loss, 
                                            avg_ent_sample_acc,
                                            avg_head_rel_sample_acc,
                                            avg_tail_rel_sample_acc,
                                            optimizer.param_groups[0]['lr'],
                                            time.time() - t_batch,
                                            time.time() - t_ep,
                                           ), end="")
            
            if config["logger"] == "wandb" and batch_ind % hyper_parameters["log_interval"] == 0:
                logger.log({
                    "train_loss": avg_loss,
                    "train_ent_seq_acc": avg_ent_sample_acc,
                    "train_head_rel_acc": avg_head_rel_sample_acc,
                    "train_tail_rel_acc": avg_tail_rel_sample_acc,
                    "learning_rate": optimizer.param_groups[0]['lr'],
                    "time": time.time() - t_ep,
                })
                
        if config["logger"] != "wandb": # only log once for training if logger is not wandb
                logger.log({
                    "train_loss": avg_loss,
                    "train_ent_seq_acc": avg_ent_sample_acc,
                    "train_head_rel_acc": avg_head_rel_sample_acc,
                    "train_tail_rel_acc": avg_tail_rel_sample_acc,
                    "learning_rate": optimizer.param_groups[0]['lr'],
                    "time": time.time() - t_ep,
                }) 
            
        
    def valid(dataloader, ep):
        # valid
        rel_extractor.eval()
        
        t_ep = time.time()
        total_ent_sample_acc, total_head_rel_sample_acc, total_tail_rel_sample_acc = 0., 0., 0.
        total_rel_correct_num, total_rel_pred_num, total_rel_gold_num = 0, 0, 0
        for batch_ind, batch_valid_data in enumerate(tqdm(dataloader, desc = "Validating")):
            ent_sample_acc, head_rel_sample_acc, tail_rel_sample_acc, rel_cpg = valid_step(batch_valid_data)

            total_ent_sample_acc += ent_sample_acc
            total_head_rel_sample_acc += head_rel_sample_acc
            total_tail_rel_sample_acc += tail_rel_sample_acc
            
            total_rel_correct_num += rel_cpg[0]
            total_rel_pred_num += rel_cpg[1]
            total_rel_gold_num += rel_cpg[2]

        avg_ent_sample_acc = total_ent_sample_acc / len(dataloader)
        avg_head_rel_sample_acc = total_head_rel_sample_acc / len(dataloader)
        avg_tail_rel_sample_acc = total_tail_rel_sample_acc / len(dataloader)
        
        rel_prf = metrics.get_prf_scores(total_rel_correct_num, total_rel_pred_num, total_rel_gold_num)
        
        log_dict = {
                        "val_ent_seq_acc": avg_ent_sample_acc,
                        "val_head_rel_acc": avg_head_rel_sample_acc,
                        "val_tail_rel_acc": avg_tail_rel_sample_acc,
                        "val_prec": rel_prf[0],
                        "val_recall": rel_prf[1],
                        "val_f1": rel_prf[2],
                        "time": time.time() - t_ep,
                    }
        logger.log(log_dict)
        pprint(log_dict)
        
        return rel_prf[2]
        
    for ep in range(num_epoch):
        train(train_dataloader, ep)   
        valid_f1 = valid(valid_dataloader, ep)
        
        global max_f1
        if valid_f1 >= max_f1: 
            max_f1 = valid_f1
            if valid_f1 > config["f1_2_save"]: # save the best model
                modle_state_num = len(glob.glob(model_state_dict_dir + "/model_state_dict_*.pt"))
                torch.save(rel_extractor.state_dict(), os.path.join(model_state_dict_dir, "model_state_dict_{}.pt".format(modle_state_num)))
#                 scheduler_state_num = len(glob.glob(schedule_state_dict_dir + "/scheduler_state_dict_*.pt"))
#                 torch.save(scheduler.state_dict(), os.path.join(schedule_state_dict_dir, "scheduler_state_dict_{}.pt".format(scheduler_state_num))) 
        print("Current avf_f1: {}, Best f1: {}".format(valid_f1, max_f1))

In [347]:
# Hyperparamètres
hyper_parameters = {
    "lr": 1e-4,
    "scheduler": "CAWR",    # ou "Step"
    "T_mult": 2,
    "rewarm_epoch_num": 2,
    "decay_rate": 0.5,
    "decay_steps": 5
}

# Initialisation de l'optimizer
init_learning_rate = float(hyper_parameters["lr"])
optimizer = torch.optim.Adam(rel_extractor.parameters(), lr=init_learning_rate)

# Scheduler
if hyper_parameters.get("scheduler", None) == "CAWR":
    T_mult = hyper_parameters.get("T_mult", 1)  # valeur par défaut 1 si absent
    rewarm_epoch_num = hyper_parameters.get("rewarm_epoch_num", 2)
    first_cycle_steps = len(train_dataloader) * rewarm_epoch_num
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=first_cycle_steps,
        T_mult=T_mult
    )
    print(f"Scheduler: CosineAnnealingWarmRestarts, T_0={first_cycle_steps}, T_mult={T_mult}")

elif hyper_parameters.get("scheduler", None) == "Step":
    decay_rate = hyper_parameters.get("decay_rate", 0.5)
    decay_steps = hyper_parameters.get("decay_steps", 5)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=decay_steps,
        gamma=decay_rate
    )
    print(f"Scheduler: StepLR, step_size={decay_steps}, gamma={decay_rate}")
else:
    scheduler = None
    print("Scheduler not specified or not recognized.")


Scheduler: CosineAnnealingWarmRestarts, T_0=28, T_mult=2


In [348]:
if not config["fr_scratch"]:
    model_state_path = config["model_state_dict_path"]
    rel_extractor.load_state_dict(torch.load(model_state_path))
    print("------------model state {} loaded ----------------".format(model_state_path.split("/")[-1]))

train_n_valid(train_dataloader, valid_dataloader, optimizer, scheduler, hyper_parameters["epochs"])

  rel_extractor.load_state_dict(torch.load(model_state_path))


------------model state model.pt loaded ----------------


KeyError: 'epochs'