# OOD detection using SST-2 as in-distribution data and IMDB as out-of-distribution data

In [None]:
# Install the notebook's dependencies. These are IPython `!` shell escapes:
# run them inside Jupyter/Colab, not as part of a plain .py script.
# NOTE(review): `pytorch_transformers` is the legacy predecessor of `transformers`;
# it is needed here only because the import cell takes AdamW and
# WarmupLinearSchedule from it (that later import shadows transformers.AdamW).
!pip install transformers
!pip install datasets
!pip install torch
!pip install pytorch_transformers

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import RobertaForSequenceClassification, RobertaTokenizer, AdamW
from datasets import load_dataset, concatenate_datasets, Dataset
import re
import string
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch import cuda
from sklearn.model_selection import train_test_split
from pytorch_transformers import AdamW, WarmupLinearSchedule

In [None]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Chargement des datasets

In [None]:
# SST-2 is the in-distribution corpus; the IMDB test split is used as
# out-of-distribution data.
sst2 = load_dataset('glue', 'sst2')
inds_set = concatenate_datasets([sst2['train'], sst2['validation'], sst2['test']])
# The GLUE SST-2 test split is released without labels (label == -1). Keeping
# those rows would add a bogus third class to the stratified split and feed an
# invalid target to the 2-label classifier, so only labeled rows are kept.
inds_set = inds_set.filter(lambda example: example['label'] >= 0)
ood_set = load_dataset('imdb', split='test')

In [None]:
inds_df = pd.DataFrame(inds_set)

# Build stratified train / validation / test sets: 70% train, then the
# remaining 30% is split 1/3 validation (10% overall) and 2/3 test (20% overall).
train, rest = train_test_split(
    inds_df, train_size=0.7, random_state=42, stratify=inds_df['label']
)
validation, test = train_test_split(
    rest, train_size=1 / 3, random_state=42, stratify=rest['label']
)

# Convert each split back to a Hugging Face Dataset, dropping the pandas
# index column that Dataset.from_pandas carries over from the split frames.
train_set, validation_set, test_set = [
    Dataset.from_pandas(frame).remove_columns(['__index_level_0__'])
    for frame in (train, validation, test)
]

### Prétraitement du texte

In [None]:
def preprocess_text(text):
    """Clean a raw sentence/review before tokenization.

    Removes, in this order: URLs, HTML tags, hashtags, @-mentions,
    punctuation and digits, then trims surrounding whitespace.
    Inner runs of whitespace left behind by the removals are kept as-is.

    Args:
        text: raw input string.

    Returns:
        The cleaned string.
    """
    text = re.sub(r'http\S+', '', text)  # Remove URLs
    text = re.sub(r'<.*?>', '', text)  # Remove HTML tags (e.g. <br />)
    # Hashtags and mentions must be removed BEFORE punctuation: once '#' and
    # '@' are stripped these patterns can no longer match and the words survive.
    text = re.sub(r'#\w+', '', text)  # Remove hashtags
    text = re.sub(r'@\w+', '', text)  # Remove mentions
    text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation (keeps '_', since '_' is a \w char)
    text = re.sub(r'[0-9]+', '', text)  # Remove digits
    # string.punctuation includes '_', which the regex above keeps; drop it here.
    text = text.translate(str.maketrans('', '', string.punctuation))
    return text.strip()

# Clean the text of every SST-2 split (in-distribution text lives in 'sentence')
train_set, validation_set, test_set = (
    split.map(lambda example: {'label': example['label'],
                               'sentence': preprocess_text(example['sentence'])})
    for split in (train_set, validation_set, test_set)
)

# Clean the IMDB reviews (out-of-distribution text lives in 'text')
ood_set = ood_set.map(
    lambda example: {'label': example['label'], 'text': preprocess_text(example['text'])}
)



### Tokenisation et tenseurs

In [None]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')


def _tokenize_column(column):
    """Build a batched `map` function that tokenizes the `column` field.

    Every sequence is padded to the tokenizer's maximum length and truncated
    beyond it, identically for all splits.
    """
    def tokenize(batch):
        return tokenizer(batch[column], padding='max_length', truncation=True,
                         return_tensors='pt')
    return tokenize


# SST-2 splits keep their text under 'sentence'
train_set = train_set.map(_tokenize_column('sentence'), batched=True)
validation_set = validation_set.map(_tokenize_column('sentence'), batched=True)
test_set = test_set.map(_tokenize_column('sentence'), batched=True)

# The IMDB split keeps its text under 'text'
ood_set = ood_set.map(_tokenize_column('text'), batched=True)

In [None]:
# Convert the tokenized columns of each split into int64 PyTorch tensors.
def _to_long_tensors(split):
    """Return (input_ids, attention_mask, label) of `split` as int64 tensors."""
    return tuple(
        torch.tensor(split[column], dtype=torch.int64)
        for column in ('input_ids', 'attention_mask', 'label')
    )


train_inputs, train_masks, train_labels = _to_long_tensors(train_set)
val_inputs, val_masks, val_labels = _to_long_tensors(validation_set)
test_inputs, test_masks, test_labels = _to_long_tensors(test_set)
ood_inputs, ood_masks, ood_labels = _to_long_tensors(ood_set)



### Création des dataloaders

In [None]:
# Wrap each split's (inputs, masks, labels) tensors and batch them
batch_size = 16

train_data = TensorDataset(train_inputs, train_masks, train_labels)
validation_data = TensorDataset(val_inputs, val_masks, val_labels)
test_data = TensorDataset(test_inputs, test_masks, test_labels)
ood_data = TensorDataset(ood_inputs, ood_masks, ood_labels)

# Only training batches are shuffled; evaluation loaders keep the dataset order
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
validation_loader, test_loader, ood_loader = (
    DataLoader(data, batch_size=batch_size, shuffle=False)
    for data in (validation_data, test_data, ood_data)
)

## Entraînement de RoBERTa pour la détection des OOD

In [None]:
# Load pre-trained RoBERTa and add a 2-way classification head on top
model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2)
# Move the model to the device selected earlier rather than hard-coding CUDA:
# the batches are moved with `.to(device)`, so this keeps both on the same
# device and lets the notebook run on a CPU-only machine.
model.to(device)

In [None]:

# Optimisation hyper-parameters
num_epochs = 4
WEIGHT_DECAY = 0.01
learning_rate = 2e-5
# Warm-up length: 20% of the batches of ONE epoch
WARMUP_STEPS = int(0.2 * len(train_loader))

# Parameters whose name contains one of these markers get no weight decay
no_decay = ['bias', 'LayerNorm.weight']
decayed_params, undecayed_params = [], []
for param_name, param in model.named_parameters():
    if any(marker in param_name for marker in no_decay):
        undecayed_params.append(param)
    else:
        decayed_params.append(param)
optimizer_grouped_parameters = [
    {'params': decayed_params, 'weight_decay': WEIGHT_DECAY},
    {'params': undecayed_params, 'weight_decay': 0.0},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-8)
# Linear warm-up over WARMUP_STEPS, then linear decay across all training steps
total_steps = len(train_loader) * num_epochs
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=WARMUP_STEPS,
                                 t_total=total_steps)

In [None]:
# Fine-tune the model, reporting the mean per-batch training and validation loss each epoch
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch in train_loader:
        # Batch tensors get their own names so the full-dataset tensors
        # (train_inputs / train_masks / train_labels) are not overwritten.
        batch_inputs, batch_masks, batch_labels = (t.to(device) for t in batch)

        optimizer.zero_grad()
        outputs = model(batch_inputs, attention_mask=batch_masks,
                        labels=batch_labels)

        loss = outputs[0]  # with labels passed, the first output is the loss
        train_loss += loss.item()
        loss.backward()

        # optimizer.step() must run before scheduler.step(); the reverse order
        # skips the first value of the learning-rate schedule.
        optimizer.step()
        scheduler.step()

    # Evaluate the model on the validation data
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in validation_loader:
            batch_inputs, batch_masks, batch_labels = (t.to(device) for t in batch)
            outputs = model(batch_inputs, attention_mask=batch_masks,
                            labels=batch_labels)
            val_loss += outputs[0].item()

    # Each loss is already a per-batch mean, so average over the number of batches.
    print(f'Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader)}, '
          f'Val Loss: {val_loss/len(validation_loader)}')
    


#### Évaluation du modèle sur les données de test

In [None]:
# Evaluate the model on the test data
model.eval()
predictions = []
true_labels = []
with torch.no_grad():
    for batch in test_loader:
        # Batch tensors get their own names so the full test tensors are not overwritten
        batch_inputs, batch_masks, batch_labels = (t.to(device) for t in batch)
        outputs = model(batch_inputs, attention_mask=batch_masks)
        logits = outputs[0]  # no labels passed, so the first output is the logits
        predictions.extend(torch.argmax(logits, dim=1).tolist())
        # Collect gold labels alongside predictions so the two lists stay aligned;
        # indexing only the last batch's labels would misalign and overflow.
        true_labels.extend(batch_labels.tolist())

# Report the accuracy on the test data
correct_predictions = sum(
    prediction == label for prediction, label in zip(predictions, true_labels)
)
accuracy = correct_predictions / len(predictions)
print(f'Accuracy on test set: {accuracy}')


In [None]:
# OOD detectors
# NOTE(review): every detector below is commented out, so running this cell
# defines nothing; nothing visible above collects scores/logits for the test
# or OOD loaders yet — TODO before enabling.
# NOTE(review): the usual energy score (Liu et al., 2020) is -logsumexp(logits);
# the draft below sums log(scores) (presumably probabilities) instead —
# confirm the intended definition.
# NOTE(review): the usual Mahalanobis detector (Lee et al., 2018) fits class
# means and a shared covariance on in-distribution features; the draft below
# estimates both from the scored inputs themselves — confirm intent.
#def max_softmax(scores):
   # return np.max(scores)

#def energy_score(scores):
    #return -np.sum(np.log(scores + 1e-6))

#def mahalanobis_score(scores):
    #cov = np.cov(scores, rowvar=False)
    #inv_cov = np.linalg.inv(cov + np.eye(cov.shape[0]) * 1e-6)
    #return np.dot(np.dot(scores - np.mean(scores, axis=0), inv_cov), (scores - np.mean(scores, axis=0)).T)