In [None]:
## Standard libraries
import os
import json
import math
import numpy as np
import time
import pandas as pd
from google.colab import drive
import re
import nltk
from functools import partial

## Imports and global configuration for plotting
import matplotlib.pyplot as plt
# Explicitly select the inline backend (harmless where it is already the default).
%matplotlib inline
# NOTE(review): recent IPython releases deprecate this import location in
# favour of matplotlib_inline.backend_inline — confirm against the runtime.
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # Figure formats used for inline display / export
from matplotlib.colors import to_rgb
import matplotlib
# Default line width requested for subsequent plots.
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
# NOTE(review): sns.set() applies seaborn's theme rcParams after the line
# width was set above and may override it — confirm the intended order.
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# PyTorch Lightning
try:
    
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning wandb
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

try:
  from datasets import load_dataset
except ModuleNotFoundError:
  !pip install --quiet datasets
  from datasets import load_dataset

try:
  import spacy
  import es_core_news_sm
except ModuleNotFoundError:
  !pip install --quiet spacy
  import spacy
  !python -m spacy download es_core_news_sm
  import es_core_news_sm

import unicodedata
import re
import string
from collections import Counter
import torchmetrics
from torchmetrics import AUROC, Precision, Recall, F1Score, Accuracy

try:
  from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup, AutoModel, AutoTokenizer
  from transformers import AutoModelWithHeads
except ModuleNotFoundError:
  !pip install --quiet -U adapter-transformers
  from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup, AutoModel, AutoTokenizer
  from transformers import AutoModelWithHeads
  
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import datasets
# from datasets import Dataset, DatasetDict
from transformers import AutoConfig, AutoAdapterModel


[K     |████████████████████████████████| 708 kB 5.1 MB/s 
[K     |████████████████████████████████| 1.8 MB 62.8 MB/s 
[K     |████████████████████████████████| 5.9 MB 84.4 MB/s 
[K     |████████████████████████████████| 419 kB 80.4 MB/s 
[K     |████████████████████████████████| 162 kB 83.3 MB/s 
[K     |████████████████████████████████| 181 kB 75.4 MB/s 
[K     |████████████████████████████████| 63 kB 2.1 MB/s 
[K     |████████████████████████████████| 158 kB 63.5 MB/s 
[K     |████████████████████████████████| 157 kB 81.3 MB/s 
[K     |████████████████████████████████| 157 kB 83.8 MB/s 
[K     |████████████████████████████████| 157 kB 7.3 MB/s 
[K     |████████████████████████████████| 157 kB 87.1 MB/s 
[K     |████████████████████████████████| 157 kB 86.7 MB/s 
[K     |████████████████████████████████| 157 kB 88.5 MB/s 
[K     |████████████████████████████████| 157 kB 84.5 MB/s 
[K     |████████████████████████████████| 156 kB 88.8 MB/s 
[?25h  Building wheel for p

In [None]:
# Mount Google Drive at /drive (Colab only; `drive` comes from google.colab)
# and make the dataset folder the current working directory.
drive.mount('/drive')
%cd '/drive/MyDrive/qa_datasets_spanish'

Mounted at /drive
/drive/MyDrive/qa_datasets_spanish


In [None]:
from torch.nn import CrossEntropyLoss, MultiheadAttention

class CrossAttentionLayer(nn.Module):
    """Pairwise cross-attention between three sequences named qa, r and c.

    One ``MultiheadAttention`` module is owned per *source* sequence (the
    sequence supplying keys and values). Each source is attended to by the
    two other sequences acting as queries, which yields six outputs.

    Args:
        d_model_size: Embedding dimension handed to every attention module.
        num_heads: Number of attention heads handed to every attention module.
    """

    def __init__(self, d_model_size, num_heads):
        super().__init__()
        # Creation order (qa, c, r) is kept so seeded initialisation is stable.
        self.attn_qa = MultiheadAttention(d_model_size, num_heads)
        self.attn_c = MultiheadAttention(d_model_size, num_heads)
        self.attn_r = MultiheadAttention(d_model_size, num_heads)

    def forward(self, qa_seq_output, r_seq_output, c_seq_output, qa_mask, r_mask, c_mask):
        """Run all six source/query combinations.

        Args:
            qa_seq_output, r_seq_output, c_seq_output: 3-D tensors whose first
                two axes are swapped before attention (batch-first input is
                assumed, since the modules use the default sequence-first
                layout).
            qa_mask, r_mask, c_mask: Passed as ``key_padding_mask`` whenever the
                matching sequence is the source.

        Returns:
            Tuple ``(qa_c, qa_r, c_qa, c_r, r_qa, r_c)`` where ``x_y`` is source
            ``x`` attended by query ``y``, each swapped back to the input axis
            order.
        """

        def swap_leading_axes(tensor):
            # (a, b, d) -> (b, a, d); applied on the way in and on the way out.
            return tensor.permute([1, 0, 2])

        def attend(attention, source, source_mask, target):
            # Attention weights are returned by the module but not needed here.
            attended, _ = attention(
                query=target, key=source, value=source, key_padding_mask=source_mask
            )
            return swap_leading_axes(attended)

        qa_states = swap_leading_axes(qa_seq_output)
        r_states = swap_leading_axes(r_seq_output)
        c_states = swap_leading_axes(c_seq_output)

        return (
            attend(self.attn_qa, qa_states, qa_mask, c_states),
            attend(self.attn_qa, qa_states, qa_mask, r_states),
            attend(self.attn_c, c_states, c_mask, qa_states),
            attend(self.attn_c, c_states, c_mask, r_states),
            attend(self.attn_r, r_states, r_mask, qa_states),
            attend(self.attn_r, r_states, r_mask, c_states),
        )


In [None]:
def separate_sequences(sequence_output, flat_input_ids, tokenizer):
    """Split encoder states into three [SEP]-delimited segments (qa, r, c).

    Every row of ``flat_input_ids`` must contain exactly three occurrences of
    ``tokenizer.sep_token_id``. With the separators at positions s0 < s1 < s2:

    * qa = positions 1 .. s0-1 (position 0 is skipped)
    * r  = positions s0+1 .. s1-1
    * c  = positions s1+1 .. s2-1

    Each segment is copied to the start of its own zero-filled tensor of the
    same size as ``sequence_output``, so all three outputs keep the full
    sequence length.

    Args:
        sequence_output: Tensor indexed as ``[row, position, ...]``
            (e.g. an encoder's last hidden state).
        flat_input_ids: 2-D tensor of token ids aligned with the first two
            axes of ``sequence_output``.
        tokenizer: Any object exposing ``sep_token_id``.

    Returns:
        ``(qa_seq_output, r_seq_output, c_seq_output, qa_mask, r_mask, c_mask)``.
        The masks are boolean ``(rows, positions)`` tensors that are ``False``
        on copied tokens and ``True`` elsewhere, i.e. directly usable as
        ``key_padding_mask``.

    Raises:
        AssertionError: If a row does not contain exactly three separators.
    """
    num_rows, seq_len = sequence_output.shape[0], sequence_output.shape[1]

    def _blank_segment():
        # Zero-filled states plus an "everything is padding" mask.
        states = sequence_output.new_zeros(sequence_output.size())
        padding = torch.ones((num_rows, seq_len),
                             device=sequence_output.device,
                             dtype=torch.bool)
        return states, padding

    qa_seq_output, qa_mask = _blank_segment()
    r_seq_output, r_mask = _blank_segment()
    c_seq_output, c_mask = _blank_segment()

    # One vectorised comparison for the whole batch instead of a Python-level
    # comparison (and implicit bool conversion) per token.
    is_sep = flat_input_ids == tokenizer.sep_token_id

    for i in range(flat_input_ids.size(0)):
        sep_lst = is_sep[i].nonzero(as_tuple=True)[0].tolist()
        assert len(sep_lst) == 3, (
            f"expected exactly 3 separator tokens in row {i}, found {len(sep_lst)}"
        )
        first_sep, second_sep, third_sep = sep_lst

        qa_len = first_sep - 1  # position 0 is not part of the qa segment
        qa_seq_output[i, :qa_len] = sequence_output[i, 1:first_sep]
        qa_mask[i, :qa_len] = False

        r_len = second_sep - first_sep - 1
        r_seq_output[i, :r_len] = sequence_output[i, first_sep + 1: second_sep]
        r_mask[i, :r_len] = False

        c_len = third_sep - second_sep - 1
        c_seq_output[i, :c_len] = sequence_output[i, second_sep + 1: third_sep]
        c_mask[i, :c_len] = False

    return qa_seq_output, r_seq_output, c_seq_output, qa_mask, r_mask, c_mask

In [None]:
# Checkpoint shared by the encoder and its tokenizer.
# NOTE(review): the notebook works in a folder named qa_datasets_spanish and
# imports a Spanish spaCy model, while this checkpoint is presumably
# English-only — confirm that is intended.
name = "bert-base-uncased"
bert = AutoModel.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True)
# Cross-attention block sized from the encoder's hidden size and head count.
cross = CrossAttentionLayer(bert.config.hidden_size, bert.config.num_attention_heads)

Downloading config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/455k [00:00<?, ?B/s]

In [None]:
bert.config.hidden_size

768

In [None]:
bert.config.num_attention_heads

12

# Unit testing

In [None]:
# Toy example laid out as: question/option [SEP] reason [SEP] context.
context = "This is a test text."
question_option = "quesiton option"
reason = "reason "
text = tokenizer.sep_token.join([question_option, reason, context])

# Pad/truncate to a fixed length of 20 so padding positions are present too.
encoding = tokenizer.encode_plus(
    text,
    add_special_tokens=True,
    max_length=20,
    return_token_type_ids=False,
    padding="max_length",
    truncation=True,
    return_attention_mask=True,
    return_tensors="pt",
)

input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
# Expected token layout (word-piece splits not shown):
# [CLS] quesiton option [SEP] reason [SEP] this is a test text. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD]

In [None]:

# Encode the toy example with BERT.
bert_output = bert(input_ids=input_ids, attention_mask=attention_mask)

# Last-layer token states: (batch_size, seq_len, hidden_dim) = (1, 20, 768) here.
hidden = bert_output.last_hidden_state
# Split the states into the qa / r / c segments; each stays zero-padded to the
# full sequence length and comes with a mask that is True on unused positions.
qa_seq_output, r_seq_output, c_seq_output, qa_mask, r_mask, c_mask = separate_sequences(hidden, input_ids, tokenizer)
# Six cross-attended views, one per (source, query) pair of segments.
enc_output_qa_c, enc_output_qa_r, enc_output_c_qa, enc_output_c_r, enc_output_r_qa, enc_output_r_c = cross(qa_seq_output, r_seq_output, c_seq_output, qa_mask, r_mask, c_mask)

# Concatenate the six views along the sequence axis, then mean-pool over it.
# NOTE(review): the mean also covers padded query positions of every view —
# confirm this is intended rather than a masked mean.
fused_output = torch.cat([enc_output_qa_c, enc_output_qa_r, enc_output_c_qa, enc_output_c_r, enc_output_r_qa, enc_output_r_c], dim=1)
pooled_output = torch.mean(fused_output, dim=1)

# Recorded output for this example: torch.Size([1, 768]).
print(pooled_output.shape)


torch.Size([1, 768])


In [None]:
# ->input_ids
# torch.Size([1, 20])
# batch_size, seq_len

# ->bert.last_hidden
# torch.Size([1, 20, 768])
# batch_size, seq_len, hidden_dim


# ############
# separate
# ###########
# ->qa_seq_output
# tensor of zeros (segment states are copied into its leading positions)
# torch.Size([1, 20, 768])

# ->qa_mask
# boolean tensor
# torch.Size([1, 20])

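# ->p_seq_output (presumably r_seq_output / c_seq_output as returned by separate_sequences)
# tensor of zeros (segment states are copied into its leading positions)
# torch.Size([1, 20, 768])

# ->p_mask (presumably r_mask / c_mask as returned by separate_sequences)
# boolean tensor
# torch.Size([1, 20])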

