# *Notebook* à utiliser pour le travail pratique n° 3 sur l'analyse d'incidents





## Imports

In [1]:
# Importation des bibliothèques 
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AdamW
import torch
import json
import re
from torch.utils.data import Dataset, DataLoader
import string
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.nn.utils.rnn import pad_sequence
from collections import Counter

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Chargement des données, du modèle et du tokenizer

In [3]:
# Load the development examples from disk.
# The encoding is set explicitly because open() otherwise uses the platform's
# locale encoding, which can corrupt non-ASCII characters in a UTF-8 JSON file.
with open('data/dev_examples.json', 'r', encoding='utf-8') as file:
  data = json.load(file)

In [None]:
# Load the pretrained GPT-2 language model together with its tokenizer
model_name = "gpt2"  # or "gpt2-medium", depending on the available resources
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# GPT-2 ships without a padding token, so tokenizer.pad_token_id would be None.
# Reuse the end-of-sequence token for padding so that every later cell reading
# tokenizer.pad_token_id receives a valid token id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model.to(device)


## Création du dataset

In [5]:
# Data formatting for GPT-2
def format_data_for_gpt2(data):
    """Flatten incident records into plain strings for GPT-2.

    Args:
        data: iterable of records, each providing a 'text' entry and an
            'arguments' mapping.

    Returns:
        A list with one string per record: the record text, a single space,
        then the str() of every argument value joined by spaces (in the
        mapping's iteration order).
    """
    def _flatten(record):
        # Read the text first, then the arguments, and join them with a space
        text = record['text']
        joined_arguments = ' '.join(map(str, record['arguments'].values()))
        return f"{text} {joined_arguments}"

    return [_flatten(record) for record in data]
# Flatten every loaded record into a single training string
formatted_data = format_data_for_gpt2(data)

# Dataset definition
class IncidentDataset(Dataset):
    """Dataset that tokenizes pre-formatted incident strings on access.

    Args:
        tokenizer: object exposing ``encode(text, max_length=..., return_tensors=...,
            truncation=...)``.
        formatted_data: indexable collection of strings to tokenize.
        max_token_length: upper bound passed to the tokenizer as ``max_length``.
    """

    def __init__(self, tokenizer, formatted_data, max_token_length=512):
        self.tokenizer = tokenizer
        self.data = formatted_data
        self.max_token_length = max_token_length

    def __len__(self):
        """Return the number of stored strings."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize the string at ``idx`` and return it as ``{'input_ids': tensor}``."""
        encode_options = {
            'max_length': self.max_token_length,
            'return_tensors': 'pt',
            'truncation': True,
        }
        token_ids = self.tokenizer.encode(self.data[idx], **encode_options)

        # Drop the leading dimension of the encoder output when it has size 1
        return {'input_ids': token_ids.squeeze(0)}
        
        

class Collator:
    """Pad a list of dataset items into one rectangular batch.

    Args:
        tokenizer: tokenizer whose ``pad_token_id`` (or, failing that,
            ``eos_token_id``) is used as the padding value.

    Calling the collator with a list of ``{'input_ids': 1-D tensor}`` items
    returns a dict with:
        'input_ids': sequences right-padded to the longest one in the batch,
        'attention_mask': 1 on real tokens and 0 on padding,
        'labels': a copy of 'input_ids' with padding replaced by -100.
    """

    # Default ignore_index of torch.nn.CrossEntropyLoss: these positions
    # contribute nothing to the loss.
    IGNORE_INDEX = -100

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def _padding_id(self):
        """Pick the padding id: pad token, else end-of-sequence token, else 0."""
        for attribute in ('pad_token_id', 'eos_token_id'):
            token_id = getattr(self.tokenizer, attribute, None)
            if token_id is not None:
                return token_id
        return 0

    def __call__(self, batch):
        sequences = [item['input_ids'] for item in batch]

        # Pad with an explicit padding id instead of pad_sequence's default 0,
        # which is an ordinary vocabulary token for GPT-2.
        input_ids = pad_sequence(
            sequences, batch_first=True, padding_value=self._padding_id()
        )

        # Derive the mask from the true lengths rather than from the padding
        # id, because that id (e.g. end-of-sequence) may also occur in real text.
        lengths = torch.tensor(
            [sequence.size(0) for sequence in sequences], device=input_ids.device
        )
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

        labels = input_ids.masked_fill(attention_mask == 0, self.IGNORE_INDEX)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }

# ...

# Build the collator that pads each batch to a common length
collator = Collator(tokenizer)

# Hold out 20% of the examples for validation (fixed seed for reproducibility)
train_data, val_data = train_test_split(formatted_data, test_size=0.2, random_state=42)
train_dataset = IncidentDataset(tokenizer, train_data, max_token_length=512)
val_dataset = IncidentDataset(tokenizer, val_data, max_token_length=512)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collator)
# The validation loader needs the same collator: the default collate function
# stacks tensors and fails on sequences of different lengths.
val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collator)

# Optimizer configuration
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

In [None]:
# Optimizer configuration for GPT-2.
# torch.optim.AdamW is used instead of the deprecated transformers.AdamW, which
# also keeps this cell consistent with the `optimizer` defined earlier.
optimizer_gpt2 = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Loss function that skips padded positions.
# GPT-2's tokenizer may have no padding token (pad_token_id is then None), in
# which case fall back to -100, the default ignore_index of CrossEntropyLoss.
_ignore_index = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100
loss_fn_gpt2 = torch.nn.CrossEntropyLoss(ignore_index=_ignore_index)

## Les fonctions d'évaluation

In [7]:

def normalize_answer(s):
    """Normalize an answer string for comparison.

    Lowercases the text, strips ASCII punctuation, replaces the standalone
    words "a", "an" and "the" with a space, and collapses whitespace runs into
    single spaces.
    """
    lowered = s.lower()
    without_punctuation = lowered.translate(str.maketrans('', '', string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punctuation)
    return ' '.join(without_articles.split())

def f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and one reference answer.

    Both strings are normalized with normalize_answer and split on whitespace.
    If either side has no tokens, the result is 1 when both are empty and 0
    otherwise. With no shared tokens the result is 0; otherwise it is the
    harmonic mean of token precision and recall.
    """
    predicted_tokens = normalize_answer(prediction).split()
    reference_tokens = normalize_answer(ground_truth).split()

    if not reference_tokens or not predicted_tokens:
        return int(reference_tokens == predicted_tokens)

    # Multiset intersection counts each shared token at most as often as it
    # appears on both sides.
    overlap = sum((Counter(predicted_tokens) & Counter(reference_tokens)).values())
    if overlap == 0:
        return 0

    precision = 1.0 * overlap / len(predicted_tokens)
    recall = 1.0 * overlap / len(reference_tokens)
    return (2 * precision * recall) / (precision + recall)

def exact_match_score(prediction, ground_truth):
    """Return True when both strings are identical after normalize_answer."""
    normalized_prediction = normalize_answer(prediction)
    normalized_reference = normalize_answer(ground_truth)
    return normalized_prediction == normalized_reference

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score a prediction against every reference and keep the best score.

    Args:
        metric_fn: callable taking (prediction, ground_truth) and returning a score.
        prediction: the predicted answer.
        ground_truths: iterable of reference answers.

    Returns:
        The largest score returned by metric_fn over all references.
    """
    return max(metric_fn(prediction, reference) for reference in ground_truths)

# Full evaluation function
def evaluate_model(model, val_loader, device):
    """Compute the mean token-level F1 and exact-match scores on a loader.

    For each batch the model's greedy (argmax) next-token predictions are
    decoded and compared with the decoded reference tokens. Decoding uses the
    notebook-level ``tokenizer``.

    Args:
        model: causal language model whose output exposes ``.logits``.
        val_loader: iterable of dict batches containing 'input_ids' and,
            optionally, 'attention_mask' and 'labels'.
        device: device on which the tensors are placed.

    Returns:
        ``(mean_f1, mean_exact_match)``; ``(0.0, 0.0)`` when the loader yields
        no examples.
    """
    model.eval()
    total_f1, total_exact_match, total_count = 0, 0, 0

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)

            # Both keys are optional: without a mask the model attends to every
            # position, and without labels the inputs themselves are the reference.
            attention_mask = batch.get('attention_mask')
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            labels = batch.get('labels')
            labels = input_ids if labels is None else labels.to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1)

            # The logits at position t predict the token at position t + 1, so
            # align predictions[:-1] with labels[1:].
            shifted_preds = preds[:, :-1]
            shifted_labels = labels[:, 1:]

            # Keep only real target positions: drop ignored labels (-100) and,
            # when a mask is available, padded positions.
            keep = shifted_labels != -100
            if attention_mask is not None:
                keep = keep & attention_mask[:, 1:].bool()

            for pred_row, label_row, keep_row in zip(shifted_preds, shifted_labels, keep):
                pred_text = tokenizer.decode(pred_row[keep_row], skip_special_tokens=True)
                label_text = tokenizer.decode(label_row[keep_row], skip_special_tokens=True)
                total_f1 += metric_max_over_ground_truths(f1_score, pred_text, [label_text])
                total_exact_match += metric_max_over_ground_truths(exact_match_score, pred_text, [label_text])
                total_count += 1

    # Avoid a ZeroDivisionError when the loader is empty
    if total_count == 0:
        return 0.0, 0.0

    return total_f1 / total_count, total_exact_match / total_count






## Entraînement

In [None]:
# Number of fine-tuning epochs for GPT-2
num_epochs_gpt2 = 5

# Best validation F1 observed so far
best_val_f1 = 0

for epoch in range(num_epochs_gpt2):
    model.train()
    total_loss_gpt2 = 0

    for batch in tqdm(train_loader):
        optimizer.zero_grad()

        input_ids_gpt2 = batch['input_ids'].to(device)

        # The mask and labels are optional in the batch; without explicit
        # labels the inputs themselves serve as language-modelling targets.
        attention_mask_gpt2 = batch.get('attention_mask')
        if attention_mask_gpt2 is not None:
            attention_mask_gpt2 = attention_mask_gpt2.to(device)
        labels_gpt2 = batch.get('labels')
        labels_gpt2 = input_ids_gpt2 if labels_gpt2 is None else labels_gpt2.to(device)

        # Labels must be passed for the model to return a loss; the model
        # shifts them internally for next-token prediction.
        outputs_gpt2 = model(
            input_ids=input_ids_gpt2,
            attention_mask=attention_mask_gpt2,
            labels=labels_gpt2,
        )
        loss_gpt2 = outputs_gpt2.loss
        total_loss_gpt2 += loss_gpt2.item()

        loss_gpt2.backward()
        optimizer.step()

    # Evaluate the GPT-2 model on the validation set
    val_f1, val_exact_match = evaluate_model(model, val_loader, device)
    print(f"Validation F1 Score for Epoch {epoch + 1}: {val_f1}")
    print(f"Validation Exact Match for Epoch {epoch + 1}: {val_exact_match}")
    print(f"Training Loss for Epoch {epoch + 1}: {total_loss_gpt2 / len(train_loader)}")

    # Keep track of the best validation F1 across epochs
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        print(f"New best validation F1: {best_val_f1}")

    