# *Notebook* à utiliser pour faire le travail pratique n° 3 sur l'analyse d'incidents.





## Imports

In [None]:
# Importation des bibliothèques nécessaires
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import json
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm

## Chargements Modèles & Tokenizers & Données

In [None]:
# Load the pretrained T5 "large" checkpoint and its matching tokenizer.
# The canonical (lowercase) Hub identifier is used so the name resolves the
# same way everywhere and maps to a single local cache entry.
model_name = "t5-large"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)

In [None]:
# Load the annotated examples that are later turned into question/answer pairs.
file_path = 'data/dev_examples.json'
# JSON files are UTF-8; be explicit instead of relying on the platform's
# default text encoding (which is not UTF-8 on every system).
with open(file_path, 'r', encoding='utf-8') as file:
    data = json.load(file)

# Chargement des données

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The training loop sends every batch to `device`, so the model weights must
# live there as well; otherwise the forward pass fails with a device mismatch
# as soon as CUDA is available. This is done here, before the optimizer is
# built. (Assigning the result avoids echoing the model repr in the notebook.)
model = model.to(device)

# Formatage des données

In [None]:
def format_data_for_t5(data):
    """Turn annotated incidents into (input_text, target_text) pairs for T5.

    Each element of ``data`` is read through its ``'text'`` and
    ``'arguments'`` keys. ``'arguments'`` is iterated with ``.items()``; every
    value it maps to is itself iterated, and one pair is produced per
    (argument name, value) combination, in the order encountered.

    Returns:
        A list of ``(input_text, target_text)`` tuples, where ``input_text``
        is ``"question: What is the <name> in the incident? context: <text>"``
        and ``target_text`` is the value itself.
    """
    pairs = []

    for incident in data:
        context = incident['text']
        roles = incident['arguments']

        for role, answers in roles.items():
            # The prompt only depends on the argument name and the text, so it
            # is built once and shared by every answer of that argument.
            question = f"What is the {role} in the incident?"
            prompt = f"question: {question} context: {context}"
            pairs.extend((prompt, answer) for answer in answers)

    return pairs

# Build the (input_text, target_text) pairs from the loaded examples.
formatted_data = format_data_for_t5(data)

# Création du Dataset

In [None]:
class IncidentDataset(Dataset):
    """Dataset that tokenizes (input_text, target_text) pairs for T5 training.

    Args:
        tokenizer: Callable tokenizer (e.g. a ``T5Tokenizer``) exposing
            ``pad_token_id``.
        formatted_data: Indexable collection of ``(input_text, target_text)``
            pairs.
        max_token_length: Length every input and target sequence is padded or
            truncated to.

    Each item is a dict with 1-D tensors ``input_ids``, ``attention_mask`` and
    ``labels``. Padding positions in ``labels`` are set to ``-100`` so that the
    loss computed by the model from ``labels`` skips them; replace ``-100``
    with ``pad_token_id`` again before decoding labels back to text.
    """

    def __init__(self, tokenizer, formatted_data, max_token_length=512):
        self.tokenizer = tokenizer
        self.data = formatted_data
        self.max_token_length = max_token_length

    def __len__(self):
        """Return the number of (input_text, target_text) pairs."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize ``text`` into fixed-length PyTorch tensors (batch dim of 1)."""
        return self.tokenizer(
            text,
            max_length=self.max_token_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        """Return the tokenized tensors for the pair stored at ``idx``."""
        input_text, target_text = self.data[idx]

        input_encoding = self._encode(input_text)
        target_encoding = self._encode(target_text)

        labels = target_encoding['input_ids'].squeeze(0)
        # Targets are padded up to max_token_length; without masking, the
        # padding positions would be scored like real tokens and dominate the
        # training loss. -100 is the label value the loss ignores.
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_token_id is not None:
            labels = labels.masked_fill(labels == pad_token_id, -100)

        return {
            'input_ids': input_encoding['input_ids'].squeeze(0),
            'attention_mask': input_encoding['attention_mask'].squeeze(0),
            'labels': labels,
        }

# Split the pairs into training (80%) and validation (20%) sets; the fixed
# random_state keeps the split identical from one run to the next.
# NOTE(review): the split is done over question/answer pairs, so pairs built
# from the same incident text can end up in both sets — confirm this is intended.
train_data, val_data = train_test_split(formatted_data, test_size=0.2, random_state=42)

In [None]:
# Wrap each split in a dataset that tokenizes its pairs on access,
# padding/truncating every sequence to 512 tokens.
train_dataset = IncidentDataset(
    tokenizer=tokenizer,
    formatted_data=train_data,
    max_token_length=512,
)
val_dataset = IncidentDataset(
    tokenizer=tokenizer,
    formatted_data=val_data,
    max_token_length=512,
)

# Batches of 8 examples; only the training loader shuffles its data.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=8, shuffle=False)

In [None]:
# Cross-entropy criterion that skips positions whose target is the pad token.
# NOTE(review): the visible training loop reads outputs.loss and never calls
# loss_fn — confirm whether it is used elsewhere.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# AdamW over all model parameters with a learning rate of 5e-5.
optimizer = AdamW(params=model.parameters(), lr=5e-5)

# Entraînement du modèle

In [None]:
# Number of full passes over the training loader.
num_epochs = 15

for epoch in range(num_epochs):
    # Put the model in training mode and restart the loss accumulator.
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader):
        # Drop the gradients left over from the previous step.
        optimizer.zero_grad()

        # Send the three tensors of the batch to the compute device.
        input_ids, attention_mask, labels = (
            batch[field].to(device)
            for field in ('input_ids', 'attention_mask', 'labels')
        )

        # Because labels are supplied, the loss is read directly from the
        # model output.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        total_loss += loss.item()

        # Backpropagate, then apply the parameter update.
        loss.backward()
        optimizer.step()

    # Report the mean batch loss of the epoch.
    print(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(train_loader)}")