In this notebook, I add the missing roles to the temporary NoUniteD dataset.

# Prerequisites

In [1]:
import os, sys

import numpy as np
import tqdm
import torch

  from .autonotebook import tqdm as notebook_tqdm


Important paths for the notebook:

In [2]:
# Root folders: SRL datasets live under temp_files, model checkpoints under checkpoints.
datasets_root_path = './temp_files/'
checkpoints_dir_path = './checkpoints/'
srl_dataset_path = os.path.join(datasets_root_path, 'maven_nounited_srl')
model_dir_path = os.path.join(checkpoints_dir_path, 'models_nounited_maven')

# Index every dataset file as
#   srl_dataset_dict_paths[<language folder>][<file name up to its first '.'>] -> file path
srl_dataset_dict_paths = {}
for lang in os.listdir(srl_dataset_path):
    dataset_lang_path = os.path.join(srl_dataset_path, lang)
    if not os.path.isdir(dataset_lang_path):
        continue  # ignore stray files sitting next to the language folders
    srl_dataset_dict_paths[lang] = {
        file_name.partition('.')[0]: os.path.join(dataset_lang_path, file_name)
        for file_name in os.listdir(dataset_lang_path)
    }

In [3]:
# Load IPython's autoreload extension and use mode 2, which re-imports all
# modules before executing a cell, so edits to imported source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bare expression: displays the selected device as the cell output.
device

'cuda'

Setting the seed for reproducibility:

In [4]:
SEED = 28

# Seed the RNGs this notebook relies on so repeated runs give the same results.
# random.seed(SEED) # not used
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic kernels and explicitly disable its autotuner:
# with benchmarking enabled cuDNN may select different algorithms between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [31]:
# Load the parameters stored next to the model checkpoints.
# NOTE(security): allow_pickle=True unpickles arbitrary Python objects, so only
# load a global_params.npy that comes from a trusted source.
global_params_path = os.path.join(model_dir_path, 'global_params.npy')
global_params_array = np.load(global_params_path, allow_pickle=True)
# .tolist() converts the loaded array back into plain Python objects.
global_params = global_params_array.tolist()

## Adding final roles to the dataset

In [6]:
from code_files.datasets.dataset_nounited import DatasetNoUniteD

In [18]:
# Suffix passed to DatasetNoUniteD as `split_type_to_use`.
# NOTE(review): presumably '_n' selects the 'predicates_n' annotations rather
# than 'predicates_v' — confirm against DatasetNoUniteD.
split_type = '_n'

In [34]:
def _load_en_split(split_name):
    """Build the English DatasetNoUniteD for `split_name` with the notebook's split settings."""
    return DatasetNoUniteD(
        srl_dataset_dict_paths['EN'][split_name],
        split_predicates=True,
        split_type_to_use=split_type,
    )


# English train / dev splits, both built with the same predicate-splitting options.
dataset_train_en_aic = _load_en_split('train')
dataset_dev_en_aic = _load_en_split('dev')

In [57]:
# Display the first training sample to inspect the fields the dataset exposes.
dataset_train_en_aic.data[0]

{'words': ['The',
  '2006',
  'Pangandaran',
  'earthquake',
  'and',
  'tsunami',
  'occurred',
  'on',
  'July',
  '17',
  'at',
  'along',
  'a',
  'subduction',
  'zone',
  'off',
  'the',
  'coast',
  'of',
  'west',
  'and',
  'central',
  'Java',
  ',',
  'a',
  'large',
  'and',
  'densely',
  'populated',
  'island',
  'in',
  'the',
  'Indonesian',
  'archipelago',
  '.'],
 'predicates': ['_',
  '_',
  '_',
  'MOVE-ONESELF',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_'],
 'predicates_v': ['_',
  '_',
  '_',
  '_',
  '_',
  '_',
  'HAPPEN_OCCUR',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_',
  '_'],
 'predicates_n': ['_',
  '_',
  '_',
  'MOVE-ONESELF',
  '_',
  '_',
 

In [33]:
# Instantiate the model, passing the loaded `global_params` as its hyper-parameters.
from code_files.models.model_aic import ModelAIC
model_aic = ModelAIC(hparams = global_params)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [37]:
# Load the weights stored in the model checkpoint directory into model_aic.
# The file name has no placeholders, so it is a plain string rather than an f-string.
# NOTE(review): this loads the "_v" checkpoint while split_type is '_n' — confirm
# that this pairing is intentional.
model_aic.load_weights(os.path.join(model_dir_path, 'aic_transformer_nounited_v.pth'))

In [63]:
# Pick one training sample to sanity-check the model's predictions.
ssmpl = 3
samppll = dataset_train_en_aic.data[ssmpl]

# Show the sentence with every token equal to the predicate word wrapped in [brackets].
highlighted_words = (
    f"[{w}]" if w == samppll['predicate_word'][0] else w
    for w in samppll['words']
)
print(" ".join(highlighted_words))

# predict() is given a batch of one sentence and one predicate; print its first result.
predicted_labels = model_aic.predict([samppll['words']], [samppll['predicate_word']])
print( predicted_labels[0] )

The July 2006 [earthquake] was also centered in the Indian Ocean , from the coast of Java , and had a duration of more than three minutes .
['_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_']
