In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
%cd drive/MyDrive/Project/squad/qafcnn

/content/drive/MyDrive/Project/squad/qafcnn


In [1]:
import torch
import numpy as np
import pandas as pd
import pickle
import re, os, string, typing, gc, json
import spacy
from collections import Counter

from torch import nn
import torch
import numpy as np
import pandas as pd
import pickle, time
import re, os, string, typing, gc, json
import torch.nn.functional as F
import spacy
from sklearn.model_selection import train_test_split
from collections import Counter
nlp = spacy.blank('en')

# import sys
# sys.path.insert(1, '%cd /drive/MyDrive/Project/squad')
from drive.MyDrive.Project.squad.preprocess import *

ModuleNotFoundError: ignored

#Data Loader

In [None]:
with open('../QANET/qanetw2id.pickle','rb') as handle:
    word2idx = pickle.load(handle)
with open('../QANET/qanetc2id.pickle','rb') as handle:
    char2idx = pickle.load(handle)

In [None]:
train_df = pd.read_pickle('../QANET/qanettrain.pkl')
valid_df = pd.read_pickle('../QANET/qanetvalid.pkl')

In [None]:
idx2word = {v:k for k,v in word2idx.items()}

In [None]:
class SQuAD:
    def __init__(self, data, batch_size):
        self.batch_size = batch_size
        data = [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]
        self.data = data
        
        
    def __len__(self):
        return len(self.data)
    
    def make_char_vector(self, max_sent_len, sentence, max_word_len=16):
        
        char_vec = torch.zeros(max_sent_len, max_word_len).type(torch.LongTensor)
        
        for i, word in enumerate(nlp(sentence, disable=['parser','tagger','ner'])):
            for j, ch in enumerate(word.text):
                if j == max_word_len:
                    break
                char_vec[i][j] = char2idx.get(ch, 0)
        
        return char_vec     
    
    def get_span(self, text):

        text = nlp(text, disable=['parser','tagger','ner'])
        span = [(w.idx, w.idx+len(w.text)) for w in text]

        return span

    
    def __iter__(self):
        
        for batch in self.data:
            
            spans = []
            ctx_text = []
            answer_text = []
            
             
            for ctx in batch.context:
                ctx_text.append(ctx)
                spans.append(self.get_span(ctx))
            
            for ans in batch.answer:
                answer_text.append(ans)
                
            max_context_len = max([len(ctx) for ctx in batch.context_ids])
            padded_context = torch.LongTensor(len(batch), max_context_len).fill_(1)
            
            for i, ctx in enumerate(batch.context_ids):
                padded_context[i, :len(ctx)] = torch.LongTensor(ctx)
                
            max_word_ctx = 16
          
            char_ctx = torch.zeros(len(batch), max_context_len, max_word_ctx).type(torch.LongTensor)
            for i, context in enumerate(batch.context):
                char_ctx[i] = self.make_char_vector(max_context_len, context)
            
            max_question_len = max([len(ques) for ques in batch.question_ids])
            padded_question = torch.LongTensor(len(batch), max_question_len).fill_(1)
            
            for i, ques in enumerate(batch.question_ids):
                padded_question[i, :len(ques)] = torch.LongTensor(ques)
                
            max_word_ques = 16
            
            char_ques = torch.zeros(len(batch), max_question_len, max_word_ques).type(torch.LongTensor)
            for i, question in enumerate(batch.question):
                char_ques[i] = self.make_char_vector(max_question_len, question)
            
              
            label = torch.LongTensor(list(batch.label_idx))
            ids = list(batch.id)
            
            yield (padded_context, padded_question, char_ctx, char_ques, label, ctx_text, answer_text, ids)
            
         

In [None]:
train_dataset = SQuAD(train_df, 16)
valid_dataset = SQuAD(valid_df, 16)     

#FCNN

In [None]:
class FCNN(nn.Module):
    def __init__(self, emb_dim, device):
    
        super(FCNN, self).__init__()
        self.device = device

        self.word_embedding = self.get_glove_embedding()

        self.dropout = nn.Dropout()
        self.similarity_weight = nn.Linear(3*(emb_dim), 1, bias=False)
        self.l1 = nn.Linear(4*emb_dim, emb_dim)
        self.l2 = nn.Linear(emb_dim, emb_dim)
        self.l3 = nn.Linear(emb_dim, emb_dim)
        self.output_start = nn.Linear(5*(emb_dim), 1, bias=False)
        self.output_end = nn.Linear(5*(emb_dim), 1, bias=False)

    def get_glove_embedding(self):
        
        weights_matrix = np.load('../qanet/qanetglove.npy')
        num_embeddings, embedding_dim = weights_matrix.shape
        embedding = nn.Embedding.from_pretrained(torch.FloatTensor(weights_matrix).to(self.device),freeze=True)

        return embedding

    def forward(self, ctx, ques, char_ctx, char_ques):
        
        ctx_len = ctx.shape[1]
        ques_len = ques.shape[1]
        
        ctx_word_embed = self.word_embedding(ctx)

        ques_word_embed = self.word_embedding(ques)

        # Similarity Matrix
        c = ctx_word_embed.unsqueeze(2).repeat(1, 1, ques_len, 1)
        q = ques_word_embed.unsqueeze(1).repeat(1, ctx_len, 1, 1)

        elementwise = torch.mul(c, q)
        alpha = torch.cat([c, q, elementwise], dim=3)
        similarity_matrix = self.similarity_weight(alpha).view(-1, ctx_len, ques_len)

        # Context2Query Attention
        a = F.softmax(similarity_matrix, dim=-1)
        c2q = torch.bmm(a, ques_word_embed)

        # Query2Context Attention
        b = F.softmax(torch.max(similarity_matrix, 2)[0], dim=-1)
        b = b.unsqueeze(1)
        q2c = torch.bmm(b, ctx_word_embed)
        q2c = q2c.repeat(1, ctx_len, 1)

        # Query Aware Representation
        G = torch.cat([ctx_word_embed, c2q, torch.mul(ctx_word_embed, c2q), torch.mul(ctx_word_embed, q2c)], dim=2)
        
        M = F.relu(self.l2(F.relu(self.l1(G))))
        M2 = F.relu(self.l3(M))

        p1 = F.softmax(self.output_start(torch.cat([G, M], dim=2)).squeeze(), dim=-1)

        p2 = F.softmax(self.output_start(torch.cat([G, M2], dim=2)).squeeze(), dim=-1)

        return p1, p2





In [None]:
EMB_DIM = 300
device = torch.device('cuda:0')

model = FCNN(EMB_DIM, device).to(device)

In [None]:
import torch.optim as optim
from torch.autograd import Variable
optimizer = optim.Adadelta(model.parameters())

from tqdm import tqdm

In [None]:
def train(model, train_dataset):
    print("training")
    train_loss = 0.
    batch_count = 0
    model.train()

    for batch in tqdm(train_dataset):
        optimizer.zero_grad()

        if batch_count%500 == 0:
            print(f"Starting batch: {batch_count}")

        batch_count += 1

        context, question, char_ctx, char_ques, label, ctx_text, ans, ids = batch
        context, question, char_ctx, char_ques, label = context.to(device), question.to(device), char_ctx.to(device), char_ques.to(device), label.to(device)

        preds = model(context, question, char_ctx, char_ques)

        start_pred, end_pred = preds
        s_idx, e_idx = label[:,0], label[:,1]

        loss = F.cross_entropy(start_pred, s_idx) + F.cross_entropy(end_pred, e_idx)
        loss.backward()

        optimizer.step()
        train_loss += loss.item()

    return train_loss/len(train_dataset)

In [None]:
def valid(model, valid_dataset):
    print("Validating....")
    valid_loss = 0.
    batch_count = 0
    f1, em = 0., 0.
    model.eval()
    predictions = {}

    for batch in valid_dataset:
        if batch_count % 500 == 0:
            print(f"Starting batch {batch_count}")
        batch_count += 1

        context, question, char_ctx, char_ques, label, ctx, answers, ids = batch
        context, question, char_ctx, char_ques, label = context.to(device), question.to(device), char_ctx.to(device), char_ques.to(device), label.to(device)

        with torch.no_grad():
            s_idx, e_idx = label[:, 0], label[:, 0]
            preds = model(context, question, char_ctx, char_ques)

            p1, p2 = preds
            loss = F.cross_entropy(p1, s_idx) + F.cross_entropy(p2, e_idx)
            valid_loss += loss.item()

            batch_size, c_len = p1.size()
            ls = nn.LogSoftmax(dim=1)
            mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
            score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
            score, s_idx = score.max(dim=1)
            score, e_idx = score.max(dim=1)
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()
            
           
            for i in range(batch_size):
                id = ids[i]
                pred = context[i][s_idx[i]:e_idx[i]+1]
                pred = ' '.join([idx2word[idx.item()] for idx in pred])
                predictions[id] = pred   

    
    em, f1 = evaluate(predictions)
    print("done")
    return valid_loss/len(valid_dataset), em, f1

In [None]:
def normalize_answer(s):
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
        
    return max(scores_for_ground_truths)


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
def evaluate(predictions):

    with open("../data/squad_dev.json", "r", encoding="utf-8") as f:
        dataset = json.load(f)

    dataset = dataset["data"]
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraphs in article["paragraphs"]:
            for qa in paragraphs["qas"]:
                total+=1
                if qa["id"] not in predictions:
                    continue
                
                ground_truths = list(map(lambda x: x['text'], qa['answers']))
                prediction = predictions[qa['id']]
                
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)
                
    
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    
    return exact_match, f1

In [None]:
train_losses = []
valid_losses = []
ems = []
f1s = []
epochs = 5
for epoch in range(epochs):
    print(f"Epoch {epoch+1}")
    start_time = time.time()
    
    train_loss = train(model, train_dataset)
    valid_loss, em, f1 = valid(model, valid_dataset)
    
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': valid_loss,
            'em':em,
            'f1':f1,
            }, 'bidaf_run4_{}.pth'.format(epoch))
    
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    ems.append(em)
    f1s.append(f1)

    print(f"Epoch train loss : {train_loss}| Time: {epoch_mins}m {epoch_secs}s")
    print(f"Epoch valid loss: {valid_loss}")
    print(f"Epoch EM: {em}")
    print(f"Epoch F1: {f1}")
    print("---------------Done---------------")

Epoch 1
training


  0%|          | 1/5397 [00:00<15:19,  5.87it/s]

Starting batch: 0


  9%|▉         | 501/5397 [00:39<04:51, 16.82it/s]

Starting batch: 500


 19%|█▊        | 1003/5397 [01:13<04:32, 16.14it/s]

Starting batch: 1000


 28%|██▊       | 1501/5397 [01:46<03:58, 16.31it/s]

Starting batch: 1500


 37%|███▋      | 2003/5397 [02:28<04:14, 13.31it/s]

Starting batch: 2000


 46%|████▋     | 2503/5397 [03:10<03:53, 12.38it/s]

Starting batch: 2500


 56%|█████▌    | 3002/5397 [03:52<03:47, 10.51it/s]

Starting batch: 3000


 65%|██████▍   | 3502/5397 [04:34<03:17,  9.61it/s]

Starting batch: 3500


 74%|███████▍  | 4002/5397 [05:15<02:00, 11.61it/s]

Starting batch: 4000


 83%|████████▎ | 4502/5397 [05:58<01:16, 11.74it/s]

Starting batch: 4500


 93%|█████████▎| 5003/5397 [06:39<00:28, 13.97it/s]

Starting batch: 5000


100%|██████████| 5397/5397 [07:12<00:00, 12.47it/s]


Validating....
Starting batch 0
Starting batch 500
Starting batch 1000
Starting batch 1500
Starting batch 2000
done
Epoch train loss : 10.304877542093195| Time: 9m 52s
Epoch valid loss: 9.932044783220604
Epoch EM: 7.729422894985809
Epoch F1: 12.983051456152799
Epoch 2
training


  0%|          | 2/5397 [00:00<08:15, 10.89it/s]

Starting batch: 0


  9%|▉         | 501/5397 [00:39<07:08, 11.42it/s]

Starting batch: 500


 19%|█▊        | 1003/5397 [01:13<04:31, 16.21it/s]

Starting batch: 1000


 28%|██▊       | 1502/5397 [01:46<03:41, 17.60it/s]

Starting batch: 1500


 37%|███▋      | 2002/5397 [02:27<03:59, 14.20it/s]

Starting batch: 2000


 46%|████▋     | 2503/5397 [03:11<03:51, 12.52it/s]

Starting batch: 2500


 56%|█████▌    | 3001/5397 [03:52<03:38, 10.98it/s]

Starting batch: 3000


 65%|██████▍   | 3502/5397 [04:34<02:20, 13.51it/s]

Starting batch: 3500


 74%|███████▍  | 4002/5397 [05:14<01:57, 11.83it/s]

Starting batch: 4000


 83%|████████▎ | 4501/5397 [05:58<01:14, 12.11it/s]

Starting batch: 4500


 93%|█████████▎| 5002/5397 [06:41<00:27, 14.19it/s]

Starting batch: 5000


100%|██████████| 5397/5397 [07:13<00:00, 12.45it/s]


Validating....
Starting batch 0
Starting batch 500
Starting batch 1000
Starting batch 1500
Starting batch 2000
done
Epoch train loss : 10.224057499117071| Time: 9m 55s
Epoch valid loss: 9.910852672460493
Epoch EM: 8.732261116367077
Epoch F1: 14.89093346190782
Epoch 3
training


  0%|          | 1/5397 [00:00<10:13,  8.80it/s]

Starting batch: 0


  9%|▉         | 501/5397 [00:40<04:56, 16.49it/s]

Starting batch: 500


 19%|█▊        | 1004/5397 [01:14<04:32, 16.09it/s]

Starting batch: 1000


 28%|██▊       | 1502/5397 [01:49<03:46, 17.19it/s]

Starting batch: 1500


 37%|███▋      | 2003/5397 [02:30<04:09, 13.61it/s]

Starting batch: 2000


 46%|████▋     | 2503/5397 [03:15<04:02, 11.93it/s]

Starting batch: 2500


 56%|█████▌    | 3001/5397 [03:58<03:29, 11.46it/s]

Starting batch: 3000


 65%|██████▍   | 3502/5397 [04:41<02:23, 13.24it/s]

Starting batch: 3500


 74%|███████▍  | 4003/5397 [05:24<01:59, 11.63it/s]

Starting batch: 4000


 83%|████████▎ | 4502/5397 [06:06<01:17, 11.49it/s]

Starting batch: 4500


 93%|█████████▎| 5003/5397 [06:50<00:29, 13.26it/s]

Starting batch: 5000


100%|██████████| 5397/5397 [07:23<00:00, 12.16it/s]


Validating....
Starting batch 0
Starting batch 500
Starting batch 1000
Starting batch 1500
Starting batch 2000
done
Epoch train loss : 10.188865631227916| Time: 10m 9s
Epoch valid loss: 9.892804049773956
Epoch EM: 9.24314096499527
Epoch F1: 16.327613237658493
Epoch 4
training


  0%|          | 0/5397 [00:00<?, ?it/s]

Starting batch: 0


  9%|▉         | 501/5397 [00:40<05:06, 15.95it/s]

Starting batch: 500


 19%|█▊        | 1004/5397 [01:16<04:35, 15.96it/s]

Starting batch: 1000


 28%|██▊       | 1501/5397 [01:49<03:59, 16.24it/s]

Starting batch: 1500


 37%|███▋      | 2002/5397 [02:32<04:08, 13.66it/s]

Starting batch: 2000


 46%|████▋     | 2502/5397 [03:16<04:12, 11.47it/s]

Starting batch: 2500


 56%|█████▌    | 3001/5397 [03:58<03:28, 11.49it/s]

Starting batch: 3000


 65%|██████▍   | 3502/5397 [04:41<02:22, 13.33it/s]

Starting batch: 3500


 74%|███████▍  | 4003/5397 [05:23<01:57, 11.84it/s]

Starting batch: 4000


 83%|████████▎ | 4502/5397 [06:06<01:17, 11.56it/s]

Starting batch: 4500


 93%|█████████▎| 5001/5397 [06:49<00:39,  9.94it/s]

Starting batch: 5000


100%|██████████| 5397/5397 [07:22<00:00, 12.19it/s]


Validating....
Starting batch 0
Starting batch 500
Starting batch 1000
Starting batch 1500
Starting batch 2000
done
Epoch train loss : 10.159802252967555| Time: 10m 6s
Epoch valid loss: 9.882048839246723
Epoch EM: 9.943235572374645
Epoch F1: 17.631802074602028
Epoch 5
training


  0%|          | 1/5397 [00:00<13:09,  6.83it/s]

Starting batch: 0


  9%|▉         | 501/5397 [00:40<05:06, 15.98it/s]

Starting batch: 500


 19%|█▊        | 1003/5397 [01:15<04:38, 15.78it/s]

Starting batch: 1000


 28%|██▊       | 1501/5397 [01:50<04:02, 16.09it/s]

Starting batch: 1500


 37%|███▋      | 2003/5397 [02:31<04:10, 13.53it/s]

Starting batch: 2000


 46%|████▋     | 2503/5397 [03:15<03:53, 12.39it/s]

Starting batch: 2500


 56%|█████▌    | 3001/5397 [03:57<03:25, 11.67it/s]

Starting batch: 3000


 65%|██████▍   | 3502/5397 [04:38<02:18, 13.66it/s]

Starting batch: 3500


 74%|███████▍  | 4002/5397 [05:20<01:56, 11.95it/s]

Starting batch: 4000


 83%|████████▎ | 4502/5397 [06:02<01:16, 11.66it/s]

Starting batch: 4500


 93%|█████████▎| 5002/5397 [06:44<00:28, 13.97it/s]

Starting batch: 5000


100%|██████████| 5397/5397 [07:16<00:00, 12.37it/s]


Validating....
Starting batch 0
Starting batch 500
Starting batch 1000
Starting batch 1500
Starting batch 2000
done
Epoch train loss : 10.146143917775008| Time: 9m 58s
Epoch valid loss: 9.880302768805777
Epoch EM: 9.593188268684957
Epoch F1: 16.830685366590135


In [None]:
!touch main.py

In [29]:
!ls

bidaf_run4_0.pth  bidaf_run4_2.pth  bidaf_run4_4.pth
bidaf_run4_1.pth  bidaf_run4_3.pth  main.py


In [30]:
%cd ../

/content/drive/MyDrive/Project/squad


In [31]:
!ls

bidaf		  bidaf_run4_2.pth  data	 main.py	qafcnn
bidaf_run4_0.pth  bidaf_run4_3.pth  evaluate.py  preprocess.py	qanet
bidaf_run4_1.pth  bidaf_run4_4.pth  GloVe	 __pycache__


In [32]:
%cd bidaf

/content/drive/MyDrive/Project/squad/bidaf


In [33]:
!ls

bidafc2id.pickle   bidafglove.npy  bidafvalid.pkl    main2.py
bidafglove100.npy  bidaftrain.pkl  bidafw2id.pickle  main.py


In [39]:
!touch main_rnn.py

In [40]:
!ls

bidafc2id.pickle   bidaftrain.pkl    main2.py	  preprocess.py
bidafglove100.npy  bidafvalid.pkl    main.py	  __pycache__
bidafglove.npy	   bidafw2id.pickle  main_rnn.py


In [41]:
!python main_rnn.py

Epoch 1
training
  0% 0/10827 [00:00<?, ?it/s]Starting batch: 0
  1% 56/10827 [00:07<23:28,  7.65it/s]
Traceback (most recent call last):
  File "main_rnn.py", line 470, in <module>
  File "main_rnn.py", line 309, in train
    for batch in tqdm(train_dataset):
  File "/usr/local/lib/python3.7/dist-packages/tqdm/std.py", line 1195, in __iter__
    for obj in iterable:
  File "main_rnn.py", line 95, in __iter__
    char_ctx[i] = self.make_char_vector(max_context_len, max_word_ctx, context)
  File "main_rnn.py", line 56, in make_char_vector
    char_vec[i][j] = char2idx.get(ch, 0)
KeyboardInterrupt


In [5]:
import math

In [166]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, dropout=0.1, num_heads=10):
        super(MultiHeadAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_size = int(hidden_size / num_heads)
        assert self.head_size * num_heads == self.hidden_size

        self.query = nn.Linear(hidden_size, self.hidden_size)
        self.key = nn.Linear(hidden_size, self.hidden_size)
        self.value = nn.Linear(hidden_size, self.hidden_size)

        self.dense = nn.Linear(2*hidden_size, hidden_size)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_heads, self.head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, memory, attention_mask):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(memory)
        mixed_value_layer = self.value(memory)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        print(value_layer.size())

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        print(attention_scores.size())
        attention_scores = attention_scores / torch.sqrt(torch.FloatTensor([self.head_size])).cuda()
        print(attention_scores.size())
        attention_scores = attention_scores - 1e30 * (1 - attention_mask[:,None,None, :])

        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = nn.Dropout(self.dropout)(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        print(context_layer.size())
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        print(context_layer.size())
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        print(context_layer.size())

        ou = torch.cat([hidden_states, context_layer], dim=-1)

        attention_output = self.dense(ou)
        return F.relu(attention_output)

In [167]:
device = torch.device('cuda:0')

In [168]:
mha = MultiHeadAttention(200).to(device)

In [169]:
msk = (torch.randn((12, 40)) > 0).float().to(device)

In [170]:
inp = torch.randn((12, 40, 200)).to(device)
inp.size() + 

torch.Size([12, 40, 200])

In [171]:
oup = mha(inp, inp, msk)

torch.Size([12, 10, 40, 20])
torch.Size([12, 10, 40, 40])
torch.Size([12, 10, 40, 40])
torch.Size([12, 10, 40, 20])
torch.Size([12, 40, 10, 20])
torch.Size([12, 40, 200])


In [172]:
oup.size()

torch.Size([12, 40, 200])

In [174]:
jk = torch.randn([12, 10, 40, 20])
jk.size()[:-2] + (800,)

torch.Size([12, 10, 800])

In [132]:
jk - msk[:, None]

tensor([[[ 0.0126, -2.1837, -0.5596,  ..., -0.0681, -0.9773,  0.3364],
         [ 0.8043, -3.4792,  0.7756,  ..., -1.5371, -0.3824, -1.8916],
         [-0.7206, -1.7320,  0.4998,  ...,  0.2066,  0.0157, -0.0585],
         ...,
         [-0.2690, -0.3023, -2.0801,  ..., -1.4509,  0.6430, -1.8389],
         [-0.7647,  1.5724,  0.7006,  ..., -1.2573,  0.4690, -3.2597],
         [-1.5683, -2.3968,  0.2595,  ..., -1.3187,  0.8512, -0.8413]],

        [[-1.0174, -2.5141, -0.7108,  ..., -1.4397,  0.5852, -1.8142],
         [-0.9676,  0.6741, -0.7844,  ...,  0.1234, -1.2307,  0.2743],
         [ 0.4299,  0.9110,  0.1502,  ..., -1.3556,  2.0762, -1.6457],
         ...,
         [-1.7616, -1.2087, -1.0000,  ..., -1.6305, -1.1972, -1.2158],
         [-2.7566,  0.1825, -0.1549,  ..., -2.1681, -0.5929,  0.5348],
         [-0.5984, -1.2707, -0.4227,  ..., -2.2108,  0.0104,  1.1768]],

        [[-0.5635, -0.7812, -1.7948,  ..., -3.0906,  1.6500, -0.1906],
         [ 1.8300, -0.3608, -1.7437,  ..., -2

In [130]:
msk = (torch.randn(12, 40) > 0).float()
msk[:, None], msk

(tensor([[[0., 1., 0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
           1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1.,
           1., 1., 1., 1., 0., 1.]],
 
         [[1., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0.,
           1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0.,
           0., 0., 0., 1., 1., 0.]],
 
         [[0., 1., 1., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1.,
           0., 1., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1., 0.,
           1., 1., 0., 1., 0., 1.]],
 
         [[1., 1., 1., 1., 0., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 0., 1.,
           1., 1., 0., 0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 0., 0., 1., 1.,
           1., 1., 0., 0., 0., 1.]],
 
         [[0., 1., 1., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 0., 0.,
           1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 0.,
           0., 1., 1., 1., 1., 0.]],
 
         [[0., 

In [99]:
torch.eq(torch.tensor([[1, 2], [3, 4]]), 1)

tensor([[ True, False],
        [False, False]])