In [2]:
!pip install tqdm



In [1]:
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from datasets import load_dataset,concatenate_datasets,DatasetDict
import torch
from collections import Counter
import torch
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Subset
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.utils import resample

## load dataset

In [43]:
dataset = load_dataset('multi_nli')
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label'],
        num_rows: 9832
    })
})


In [44]:
label_counts = Counter(dataset['train']['label'])
print(label_counts)

Counter({2: 130903, 1: 130900, 0: 130899})


In [45]:
negation_words = ['nobody','no','never','nothing']


def contains_negation(sentence):
    return int(any(neg_word in sentence.lower().split() for neg_word in negation_words))

def add_negation_label(example):
    example['negation'] = contains_negation(example['hypothesis'])
    return example

dataset = dataset.map(add_negation_label)

def add_combined_label(example):
    num_spurious_features = 2  
    example['group_label'] = example['label'] * num_spurious_features + example['negation']
    return example


dataset = dataset.map(add_combined_label)

In [46]:
print(dataset['train'])

Dataset({
    features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label', 'negation', 'group_label'],
    num_rows: 392702
})


In [47]:
label_counts = Counter(dataset['train']['group_label'])
print(label_counts)

Counter({0: 128262, 2: 127276, 4: 110301, 5: 20602, 3: 3624, 1: 2637})


In [48]:
combine_dataset=concatenate_datasets([
    dataset['train'],
    dataset['validation_matched'],
    dataset['validation_mismatched']
])

In [49]:
combine_dataset = combine_dataset.shuffle(seed=42)

In [50]:

total_size = len(combine_dataset)
train_size = int(0.5 * total_size)
validation_size = int(0.2 * total_size)

train_dataset = combine_dataset.select(range(train_size))
validation_dataset = combine_dataset.select(range(train_size, train_size + validation_size))
test_dataset = combine_dataset.select(range(train_size + validation_size, total_size))


split_datasets = DatasetDict({
    'train': train_dataset,
    'validation': validation_dataset,
    'test': test_dataset
})

In [51]:
print(train_dataset)

Dataset({
    features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label', 'negation', 'group_label'],
    num_rows: 206174
})


In [52]:
label_counts = Counter(train_dataset['group_label'])
print(label_counts)

Counter({0: 67522, 2: 66633, 4: 57967, 5: 10746, 3: 1907, 1: 1399})


## tokenize

In [53]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def tokenize_function(example):
    return tokenizer(example['premise'], example['hypothesis'], truncation=True, padding='max_length', max_length=128,return_tensors="pt")


train_dataset = train_dataset.map(tokenize_function, batched=True)
train_dataset =train_dataset.rename_column("label", "labels")
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

validation_dataset = validation_dataset.map(tokenize_function, batched=True)
validation_dataset = validation_dataset.rename_column("label", "labels")
validation_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

test_dataset = test_dataset.map(tokenize_function, batched=True)
test_dataset =test_dataset.rename_column("label", "labels")
test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

Map:   0%|          | 0/123706 [00:00<?, ? examples/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def tokenize_function(example):
    # Encode the premise and the hypothesis, with truncation and padding
    return tokenizer(example['premise'], example['hypothesis'], truncation=True, padding='max_length', max_length=128,return_tensors="pt")

In [11]:
dataset2=load_dataset('multi_nli')

In [12]:
metadata=pd.read_csv('metadata_random.csv')
metadata['group_label']=metadata['gold_label']*2+metadata['sentence2_has_negation']
print(metadata.head())

   Unnamed: 0  gold_label  sentence2_has_negation  split  group_label
0           0           2                       0      2            4
1           1           1                       0      0            2
2           2           1                       0      2            2
3           3           1                       0      2            2
4           4           2                       0      0            4


In [13]:
full_dataset=concatenate_datasets([
    dataset2['train'],
    dataset2['validation_matched'],
    dataset2['validation_mismatched']
])

metadata_dict_list = metadata.to_dict(orient='records')

def add_metadata(example, index):
    for key, value in metadata_dict_list[index].items():
        example[key] = value
    return example

full_dataset = full_dataset.map(add_metadata, with_indices=True)

In [14]:
train_dataset2 = full_dataset.filter(lambda example: example['split'] == 0)
validation_dataset2 = full_dataset.filter(lambda example: example['split'] == 1)
test_dataset2 = full_dataset.filter(lambda example: example['split'] == 2)

In [15]:
label_counts = Counter(train_dataset2['group_label'])
print(label_counts)

Counter({2: 67376, 4: 66630, 0: 57498, 1: 11158, 5: 1992, 3: 1521})


In [16]:
train_dataset2 = train_dataset2.map(tokenize_function, batched=True)
train_dataset2 =train_dataset2.rename_column("gold_label", "labels")
train_dataset2.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

validation_dataset2 = validation_dataset2.map(tokenize_function, batched=True)
validation_dataset2 = validation_dataset2.rename_column("gold_label", "labels")
validation_dataset2.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

test_dataset2 = test_dataset2.map(tokenize_function, batched=True)
test_dataset2 =test_dataset2.rename_column("gold_label", "labels")
test_dataset2.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

Map:   0%|          | 0/82462 [00:00<?, ? examples/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Map:   0%|          | 0/123712 [00:00<?, ? examples/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

## training and saving model

In [None]:
#train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

In [17]:
import torch
from transformers import BertForSequenceClassification, AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Subset
from tqdm import tqdm

# Load model and optimizer
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3,output_hidden_states=True)
optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)

train_dataloader = DataLoader(train_dataset2, sampler=RandomSampler(train_dataset2), batch_size=16)
total_batches_per_epoch = len(train_dataset2) // 16 + (len(train_dataset2) % 16 > 0)
total_steps = total_batches_per_epoch * 5
# Scheduler for linear learning rate decay(need to make it varable)
#total_steps = 1000 * 5  # Number of batches * number of epochs
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1 - step / total_steps)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Setup DataLoader
#train_subset = Subset(dataset['train'], indices=range(16000))  #subset for training(only done for testing pourpose)
#train_dataloader = DataLoader(train_subset, sampler=RandomSampler(train_subset), batch_size=16)
#total_batches_per_epoch = len(dataset['train']) // 16 + (len(dataset['train']) % 16 > 0)
#total_steps = total_batches_per_epoch * 5
# Training
model.train()  
for epoch in range(5):  
    total_train_loss = 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", leave=False)

    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}  # Move batch to device

        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss
        total_train_loss += loss.item()

        # Backward pass
        loss.backward()

        # Update parameters and take a step using the computed gradient
        optimizer.step()
        scheduler.step()  # Update learning rate schedule

        # Clear the gradients after updating weights
        optimizer.zero_grad()

        #progress bar
        progress_bar.set_postfix({'loss': loss.item()})

    # Calculate average loss over the training data
    avg_train_loss = total_train_loss / len(train_dataloader)
    print(f"Epoch {epoch+1} | Average Training Loss: {avg_train_loss:.2f}")


# Save the final model
model.save_pretrained('./bert_retrain_on_real_data')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                                           

Epoch 1 | Average Training Loss: 0.59


                                                                           

Epoch 2 | Average Training Loss: 0.41


                                                                           

Epoch 3 | Average Training Loss: 0.29


                                                                           

Epoch 4 | Average Training Loss: 0.21


                                                                            

Epoch 5 | Average Training Loss: 0.16


In [8]:
if torch.cuda.is_available():
    print("CUDA is available. GPU can be used.")
else:
    print("CUDA is not available. Using CPU.")

CUDA is available. GPU can be used.


In [22]:
import torch
print(torch.__version__)
print(torch.cuda.is_available())


2.3.0+cpu
False


In [33]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Sun_Aug_15_21:18:57_Pacific_Daylight_Time_2021
Cuda compilation tools, release 11.4, V11.4.120
Build cuda_11.4.r11.4/compiler.30300941_0


In [65]:
#model_path = './bert_retrain_half'
#model = BertForSequenceClassification.from_pretrained(model_path, output_hidden_states=True)
#model.eval()  
#model.to('cuda') 

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [2]:
import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

2.3.0+cu118
11.8
True


In [18]:

validation_dataset2.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
valid_loader = DataLoader(validation_dataset2, batch_size=16, shuffle=False)  

test_dataset2.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
test_loader = DataLoader(test_dataset2, batch_size=16, shuffle=False)

def evaluate_model(model, data_loader, device):
    model.eval()  

    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad(): 
        for batch in data_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            predictions = torch.argmax(outputs.logits, dim=-1)
            correct_predictions += (predictions == batch['labels']).sum().item()
            total_predictions += predictions.size(0)

    accuracy = correct_predictions / total_predictions
    return accuracy

In [19]:
import torch
from sklearn.metrics import accuracy_score

def evaluate_model_by_group(model, data_loader, device):
    model.eval()
    group_accuracies = {0: [], 1: [], 2: [], 3: [], 4: [], 5: []}  # Store predictions and labels for each group

    with torch.no_grad():
        for batch in data_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels' and k != 'group_label'}
            labels = batch['labels'].to(device)
            groups = batch['group_label'].to(device)
            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=-1)

            # Group predictions and labels
            for group in group_accuracies.keys():
                group_mask = (groups == group)
                if group_mask.any():  # Only calculate if there are any items in this group
                    group_preds = predictions[group_mask]
                    group_labels = labels[group_mask]
                    group_accuracies[group].append((group_preds, group_labels))

    # Calculate accuracy for each group
    for group, data in group_accuracies.items():
        if data:
            all_group_preds = torch.cat([item[0] for item in data])
            all_group_labels = torch.cat([item[1] for item in data])
            group_accuracy = accuracy_score(all_group_labels.cpu().numpy(), all_group_preds.cpu().numpy())
            group_accuracies[group] = group_accuracy
        else:
            group_accuracies[group] = float('nan')  # No data for this group

    return group_accuracies

In [20]:
model.to('cuda')  

accuracy = evaluate_model(model, valid_loader, 'cuda')
print(f"Validation Accuracy: {accuracy:.4f}")

Validation Accuracy: 0.8196


In [22]:
validation_dataset2.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels','group_label'])
valid_loader = DataLoader(validation_dataset2, batch_size=16, shuffle=False) 

In [24]:
test_dataset2.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels','group_label'])
test_loader = DataLoader(test_dataset2, batch_size=16, shuffle=False) 

In [25]:
group_accuracies = evaluate_model_by_group(model, test_loader, 'cuda')
for group, accuracy in group_accuracies.items():
    print(f"Group {group} Accuracy: {accuracy:.4f}")

Group 0 Accuracy: 0.8095
Group 1 Accuracy: 0.9500
Group 2 Accuracy: 0.8410
Group 3 Accuracy: 0.7799
Group 4 Accuracy: 0.7932
Group 5 Accuracy: 0.6577


In [74]:
test_dataset2.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels','group_label'])

## get embeddings

In [26]:
def extract_embeddings(dataloader, model):
    model.eval()
    embeddings = []
    label_list = []
    group_label_list = []
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(model.device) for k, v in batch.items() if k != 'labels' and k != 'group_label'}
            labels = batch['labels'].to(model.device)  # Main classification labels
            group_labels = batch['group_label'].to(model.device)  # Group labels

            # Forward pass to extract embeddings
            outputs = model(**inputs)
            batch_embeddings = outputs.hidden_states[-2].mean(dim=1)  # Mean across the sequence length
            embeddings.append(batch_embeddings)
            label_list.append(labels)
            group_label_list.append(group_labels)

    # Concatenate all lists into single tensors
    embeddings = torch.cat(embeddings, dim=0)
    labels = torch.cat(label_list, dim=0)
    group_labels = torch.cat(group_label_list, dim=0)

    return embeddings, labels, group_labels



In [27]:
validation_embeddings, validation_labels,validation_groups = extract_embeddings(valid_loader, model)
test_embeddings, test_labels,test_groups = extract_embeddings(test_loader, model)

In [28]:
print(validation_groups)

tensor([4, 2, 4,  ..., 0, 0, 0], device='cuda:0')


## Impelementing logistic regression without tunning and rebalancing

In [29]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Split data into training and test sets
X_train, X_test, y_train, y_test, group_train, group_test = train_test_split(
    validation_embeddings.cpu().numpy(), validation_labels.cpu().numpy(), validation_groups.cpu().numpy(), test_size=0.2, random_state=42)

# Train logistic regression
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)

# Predict and evaluate
predictions = clf.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f"Logistic Regression Accuracy: {accuracy:.4f}")


unique_groups = set(group_test)
for group in unique_groups:
    group_indices = [i for i, g in enumerate(group_test) if g == group]
    group_accuracy = accuracy_score([y_test[i] for i in group_indices], [predictions[i] for i in group_indices])
    print(f"Group {group} Accuracy: {group_accuracy:.4f}")


Logistic Regression Accuracy: 0.7991
Group 0 Accuracy: 0.7756
Group 1 Accuracy: 0.9494
Group 2 Accuracy: 0.8410
Group 3 Accuracy: 0.7642
Group 4 Accuracy: 0.7569
Group 5 Accuracy: 0.6039


## tunning the hyperparameter

In [30]:
from sklearn.preprocessing import StandardScaler



scaler = StandardScaler()
embeddings_scaled = scaler.fit_transform(validation_embeddings.cpu().numpy()) 


In [31]:
from sklearn.model_selection import train_test_split

X_train_val, X_tune_val, y_train_val, y_tune_val, groups_train_val, groups_tune_val = train_test_split(
    embeddings_scaled, validation_labels.cpu().numpy(), validation_groups.cpu().numpy(), test_size=0.5, random_state=42)


In [32]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

C_values = [1.0, 0.7, 0.3, 0.1, 0.07, 0.03, 0.01]
best_c = 0
best_worst_group_accuracy = float('-inf')

for c in C_values:
    clf = LogisticRegression(penalty='l1', C=c, solver='liblinear')  # L1 penalty
    clf.fit(X_train_val, y_train_val)

    # Evaluate on the tuning validation set
    predictions = clf.predict(X_tune_val)
    unique_groups = set(groups_tune_val)
    worst_group_accuracy = float('inf')

    # Calculate worst-group performance
    for group in unique_groups:
        group_indices = [i for i, g in enumerate(groups_tune_val) if g == group]
        group_accuracy = accuracy_score(y_tune_val[group_indices], predictions[group_indices])
        worst_group_accuracy = min(worst_group_accuracy, group_accuracy)
    
    # Update the best C if current one is better
    if worst_group_accuracy > best_worst_group_accuracy:
        best_worst_group_accuracy = worst_group_accuracy
        best_c = c



## Rebalancing dataset

In [33]:

from sklearn.utils import resample

def rebalance_data(X, y, groups):
    unique_groups = np.unique(groups)
    group_sizes = {group: (groups == group).sum() for group in unique_groups}
    min_size = min(group_sizes.values())

    X_balanced = []
    y_balanced = []

    for group in unique_groups:
        mask = (groups == group)
        X_group = X[mask]
        y_group = y[mask]
        X_resampled, y_resampled = resample(X_group, y_group, replace=False, n_samples=min_size, random_state=None)
        X_balanced.append(X_resampled)
        y_balanced.append(y_resampled)

    return np.vstack(X_balanced), np.hstack(y_balanced)


## DFR

In [34]:

def DFR(X, y, groups, C, num_trials=10):
    coefs = []
    intercepts = []

    for _ in range(num_trials):
        X_balanced, y_balanced = rebalance_data(X, y, groups)
        clf = LogisticRegression(penalty='l1', C=C, solver='liblinear')
        clf.fit(X_balanced, y_balanced)
        coefs.append(clf.coef_)
        intercepts.append(clf.intercept_)

    mean_coefs = np.mean(coefs, axis=0)
    mean_intercepts = np.mean(intercepts, axis=0)
    return mean_coefs, mean_intercepts


In [89]:
X_train
X_test, y_train, y_test, group_train, group_test =validation_embeddings.cpu().numpy(), validation_labels.cpu().numpy(), validation_groups.cpu().numpy(), test_embeddings.cpu().numpy(), test_labels.cpu().numpy(), test_groups.cpu().numpy()
print(X_test)

[2 0 0 ... 0 0 1]


In [35]:
#X_train, X_test, y_train, y_test, group_train, group_test =validation_embeddings.cpu().numpy(), validation_labels.cpu().numpy(), validation_groups.cpu().numpy(), test_embeddings.cpu().numpy(), test_labels.cpu().numpy(), test_groups.cpu().numpy()



mean_coefs, mean_intercepts = DFR(validation_embeddings.cpu().numpy(),validation_labels.cpu().numpy(), validation_groups.cpu().numpy(), best_c)


final_model = LogisticRegression()
final_model.coef_ = mean_coefs
final_model.intercept_ = mean_intercepts
final_model.classes_ = np.unique(validation_labels.cpu().numpy())


predictions = final_model.predict(test_embeddings.cpu().numpy())


In [39]:
y_test=test_labels.cpu().numpy()

accuracy = accuracy_score(y_test, predictions)
print(f"Accuracy: {accuracy:.4f}")


unique_groups = np.unique(test_groups.cpu().numpy())
for group in unique_groups:
    idx = (test_groups.cpu().numpy() == group)
    group_accuracy = accuracy_score(y_test[idx], predictions[idx])
    print(f"Group {group} Accuracy: {group_accuracy:.4f}")

Accuracy: 0.7825
Group 0 Accuracy: 0.7414
Group 1 Accuracy: 0.8747
Group 2 Accuracy: 0.8231
Group 3 Accuracy: 0.8341
Group 4 Accuracy: 0.7625
Group 5 Accuracy: 0.7099


In [38]:
test_groups.cpu().numpy()

array([4, 2, 2, ..., 2, 4, 2], dtype=int64)