## Libraries

In [27]:
import datasets
import torch
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet as wn

[nltk_data] Downloading package wordnet to /home/leonardo/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## GPU Settings


In [28]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

GPU check

In [29]:
print(torch.cuda.is_available())
print(torch.cuda.current_device())
print(torch.cuda.get_device_name(0))

True
0
NVIDIA GeForce RTX 3060 Ti


releasing memory allocated on the GPU

In [30]:
torch.cuda.empty_cache()

## Datasets creation

In [31]:
dataset = datasets.load_dataset(
                        path = 'parquet',
                        data_files = {
                                    'train': 'data/train-00000-of-00001.parquet',
                                    'validation' : 'data/validation-00000-of-00001.parquet'
                                },
                    )

## WSD data augmentation

In [32]:
wsd_dataset = dataset.remove_columns(['srl'])

### Synonyms

####  substituting all synonyms for each word in each sample premise

In [33]:
def get_synonyms_by_offset(wn_synset_offset, original_word, index):
    """This function returns a list of synonyms for a given wordnet synset offset."""
    try:
        if wn_synset_offset == 'O':
            return []
        
        if wn_synset_offset[-1].isalpha() and wn_synset_offset[:-1].isdigit():
            offset_str, pos_indicator = wn_synset_offset[:-1], wn_synset_offset[-1]
        else: 
            raise ValueError(f"wn_synset_offset is not in the correct format at id {index}.")
        
        offset_int = int(offset_str)
            
        synset = wn.synset_from_pos_and_offset(pos_indicator, offset_int)
        if synset is None:
            return []
        
        synonyms = set(lemma.name().replace('_', ' ') for lemma in synset.lemmas() if lemma.name().lower() != original_word.lower())
        return list(synonyms)
    except Exception as e:
        return []

In [34]:
def wsd_mapping(sample):
    """This function create a new sample for each synonym of each word. If no synonym is found, the sample is discarded."""
    index = sample['id']
    for word in sample['wsd']['premise']:
        wn_synset_offset = word['wnSynsetOffset']
        original_word = word['text']

        if not wn_synset_offset or wn_synset_offset == 'O':
            continue
        else:
            synonyms = get_synonyms_by_offset(wn_synset_offset, original_word, index)
            if synonyms:
                sample['premise'] = sample['premise'].replace(original_word, synonyms[0])
            else:
                continue
    return sample

In [35]:
synonym_dataset = wsd_dataset.map(wsd_mapping)

Augmentation

In [36]:
wsd_aug_train_dataset = datasets.concatenate_datasets([wsd_dataset['train'], synonym_dataset['train']])
wsd_aug_train_dataset = wsd_aug_train_dataset.remove_columns(['wsd'])

wsd_aug_val_dataset = datasets.concatenate_datasets([wsd_dataset['validation'], synonym_dataset['validation']])
wsd_aug_val_dataset = wsd_aug_val_dataset.remove_columns(['wsd'])

#### We now save the augmented dataset

In [37]:
wsd_aug_train_dataset.to_parquet('wsd_aug_train_dataset.parquet')
wsd_aug_val_dataset.to_parquet('wsd_aug_val_dataset.parquet')

Creating parquet from Arrow format: 100%|██████████| 103/103 [00:00<00:00, 915.58ba/s]
Creating parquet from Arrow format: 100%|██████████| 5/5 [00:00<00:00, 953.99ba/s]


1906335

## SRL data augmentation

Data preparation

In [38]:
srl_dataset = dataset.remove_columns(['wsd'])

### Switching agent and patient

Generation of samples with switched agent and patient

In [39]:
def switcherSRL(dataset):

    """This function creates new samples that are used to generate a new dataset in order to augment the starting one. 
    The new samples are created by switching the position of the agent and the patient in the premise of the original samples.
    Spans corresponding to the agent and the patient are extracted from the SRL annotations and used to switch the words in the premise.
    Which is done by creating a new premise from scratch.
    Also samples missing both roles (agent and patient) are discarded and counted in order to quantify the loss of samples."""
    
    new_samples = []
    missing_sample_count = 0

    for sample in dataset:
        agent_span = []
        patient_span = []
        words = []

        # spans extraction
        for annotation in sample['srl']['premise'].get('annotations', []):
            if 'verbatlas' in annotation:
                for role in annotation['verbatlas']['roles']:
                    if role['role'] == 'Agent':
                        agent_span = role['span']
                    if role['role'] == 'Patient':
                        patient_span = role['span']

        # words extraction
        for token in sample['srl']['premise'].get('tokens', []):
            words.append(token['rawText'])

        # New premise creation
        if agent_span and patient_span:
            agent_words = words[agent_span[0]:agent_span[1]+1]
            patient_words = words[patient_span[0]:patient_span[1]+1]
            
            first_span, second_span = sorted([agent_span, patient_span], key=lambda x: x[0])
            first_words, second_words = (agent_words, patient_words) if first_span == agent_span else (patient_words, agent_words)
            
            before_first_span = words[:first_span[0]]
            between_spans = words[first_span[1]+1:second_span[0]]
            after_second_span = words[second_span[1]+1:]
            
            if first_span == agent_span:
                new_words = before_first_span + second_words + between_spans + first_words + after_second_span
            else:
                new_words = before_first_span + first_words + between_spans + second_words + after_second_span
            
            new_premise = ' '.join(new_words)
            
            new_sample = {
                'id': sample['id'],
                'premise': new_premise,
                'hypothesis': sample['hypothesis'],
                'label': sample['label']
            }
            new_samples.append(new_sample)
        
        # missing samples count
        else:
            missing_sample_count += 1

    print(f"Missing roles in {missing_sample_count} samples.")
    return new_samples


In [40]:
switched_srl_train = switcherSRL(srl_dataset['train'])
switched_srl_val = switcherSRL(srl_dataset['validation'])

Missing roles in 29700 samples.
Missing roles in 1455 samples.


### Augmented dataset creation

In [41]:
"""In order to create a new dataset, the samples obtained from the switcher function are stored in dictionaries."""

dataset_dict_srl_train = {
    'id': [sample['id'] for sample in switched_srl_train],
    'premise': [sample['premise'] for sample in switched_srl_train],
    'hypothesis': [sample['hypothesis'] for sample in switched_srl_train],
    'label': [sample['label'] for sample in switched_srl_train]
}

dataset_dict_srl_val = {
    'id': [sample['id'] for sample in switched_srl_val],
    'premise': [sample['premise'] for sample in switched_srl_val],
    'hypothesis': [sample['hypothesis'] for sample in switched_srl_val],
    'label': [sample['label'] for sample in switched_srl_val]
}

In [42]:
switched_srl_train_dataset = datasets.Dataset.from_dict(dataset_dict_srl_train)
switched_srl_val_dataset = datasets.Dataset.from_dict(dataset_dict_srl_val)

In [43]:
srl_dataset = srl_dataset.remove_columns(['srl'])

In [44]:
srl_aug_train_dataset = datasets.concatenate_datasets([srl_dataset['train'], switched_srl_train_dataset])
srl_aug_val_dataset = datasets.concatenate_datasets([srl_dataset['validation'], switched_srl_val_dataset])

Dataset saving

In [45]:
srl_aug_train_dataset.to_parquet('srl_aug_train_dataset.parquet')
srl_aug_val_dataset.to_parquet('srl_aug_val_dataset.parquet')

Creating parquet from Arrow format: 100%|██████████| 73/73 [00:00<00:00, 923.99ba/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:00<00:00, 1042.84ba/s]


1367667

## SRL + WSD data augmentation

In [46]:
srl_wsd_aug_train_dataset = datasets.concatenate_datasets([switched_srl_train_dataset, wsd_aug_train_dataset])
srl_wsd_aug_val_dataset = datasets.concatenate_datasets([switched_srl_val_dataset, wsd_aug_val_dataset])

Dataset saving

In [47]:
srl_wsd_aug_train_dataset.to_parquet('srl_wsd_aug_train_dataset.parquet')
srl_wsd_aug_val_dataset.to_parquet('srl_wsd_aug_val_dataset.parquet')

Creating parquet from Arrow format: 100%|██████████| 124/124 [00:00<00:00, 830.26ba/s]
Creating parquet from Arrow format: 100%|██████████| 6/6 [00:00<00:00, 962.00ba/s]


2356278