In [1]:
import torch
import torchtext
from torch import nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
from perceiver_pytorch import Perceiver
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
import itertools
from torch.utils.tensorboard import SummaryWriter
import os
from tqdm.notebook import tqdm
from datetime import datetime
from babi_joint import BabiDataset, pad_collate
from torch.utils.data.dataset import Dataset
from glob import glob

In [20]:
# `torch` is already imported in the notebook's first cell; only PerceiverIO
# (which that cell does not import) is pulled in here.
from perceiver_pytorch import PerceiverIO

# Perceiver IO model for the bAbI experiments. `logits_dim` is left unset, so
# the decoder output keeps the query width (179), as the shape printed at the
# end of this cell shows.
model = PerceiverIO(
    dim = 16,                    # dimension of sequence to be encoded
    queries_dim = 179,           # dimension of decoder queries
#     logits_dim = 50,            # dimension of final logits
    depth = 3,                   # depth of net
    num_latents = 13,            # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 16,             # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 32,         # number of dimensions per cross attention head
    latent_dim_head = 32,        # number of dimensions per latent self attention head
    weight_tie_layers = True,    # whether to weight tie layers (optional, as indicated in the diagram)
    self_per_cross_attn = 2,     # number of self attention blocks per cross attention
)

# Library example, kept for reference:
# seq = torch.randn(1, 512, 32)
# queries = torch.randn(1, 128, 32)
# logits = model(seq, queries = queries) # (1, 128, 100) - (batch, decoder seq, logits dim)

# TODO: add bAbI query as latent query
# only one query/question and answer per forward pass

# Smoke test with random placeholders: checks that the configuration above
# accepts a 70*13-token sequence of width `dim` and one decoder query.
seq = torch.randn(1, 70*13, 16) # (batch, input_sequence_length, input_sequence_dim/dim)
# Plain tensor on purpose: torch.autograd.Variable is a deprecated no-op
# wrapper, so it is no longer applied here.
output_query = torch.randn(1, 1, 179) # (batch, output_sequence_length, queries_dim) - should be learned

logits = model(seq, queries=output_query) # (batch, output_sequence_length, output_sequence_dim/logits dim/queries_dim if not)
print(logits.shape)

# Used by the training cell to name the checkpoint directory.
model_name = 'perceiverIO_bAbi'

torch.Size([1, 1, 179])


In [None]:
# NOTE(review): this cell is a line-for-line copy of the model-definition cell
# above it. Running it re-creates `model` with freshly initialised weights, so
# only run it when a reset is intended; consider deleting one of the two cells.
from perceiver_pytorch import PerceiverIO

model = PerceiverIO(
    dim = 16,                    # dimension of sequence to be encoded
    queries_dim = 179,           # dimension of decoder queries
#     logits_dim = 50,            # dimension of final logits
    depth = 3,                   # depth of net
    num_latents = 13,            # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 16,             # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 32,         # number of dimensions per cross attention head
    latent_dim_head = 32,        # number of dimensions per latent self attention head
    weight_tie_layers = True,    # whether to weight tie layers (optional, as indicated in the diagram)
    self_per_cross_attn = 2,     # number of self attention blocks per cross attention
)

# Library example, kept for reference:
# seq = torch.randn(1, 512, 32)
# queries = torch.randn(1, 128, 32)
# logits = model(seq, queries = queries) # (1, 128, 100) - (batch, decoder seq, logits dim)

# TODO: add bAbI query as latent query
# only one query/question and answer per forward pass

# Random placeholder inputs for a single shape check of the forward pass.
seq = torch.randn(1, 70*13, 16) # (batch, input_sequence_length, input_sequence_dim/dim)
# torch.autograd.Variable is deprecated and returns an ordinary tensor, so the
# query is used directly instead of being wrapped.
output_query = torch.randn(1, 1, 179) # (batch, output_sequence_length, queries_dim) - should be learned

logits = model(seq, queries=output_query) # (batch, output_sequence_length, output_sequence_dim/logits dim/queries_dim if not)
print(logits.shape)

# Used by the training cell to name the checkpoint directory.
model_name = 'perceiverIO_bAbi'

In [19]:
# Train the Perceiver IO model on bAbI and checkpoint after every epoch.
#
# Relies on names created by earlier cells: `model`, `model_name`, `seq` and
# `output_query` (model-definition cell) plus the imports from the first cell.
batch_size = 128
max_epochs = 2

# Dataset directory. Defaults to the original location; set the BABI_DIR
# environment variable to run the notebook on another machine.
babi_dir = os.environ.get('BABI_DIR', '/home/gabriel/Documents/datasets/bAbi/en')
babi_dataset = BabiDataset(ds_path=babi_dir + '/qa{}_*',
                           vocab_path=babi_dir + '/babi{}_vocab.pkl')
vocab_size = len(babi_dataset.QA.VOCAB)
print('len(babi_dataset) train', len(babi_dataset))
print('vocab_size', vocab_size)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print(device)
torch.backends.cudnn.benchmark = True
model.to(device)

# reduction='sum' returns the summed loss of a batch; the epoch figures below
# divide by the number of batches, i.e. they are average per-batch sums, not
# per-sample means.
criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Start at +inf so the first epoch always counts as an improvement. (The
# previous -inf made the "best checkpoint" branch unreachable.)
val_loss_min = float('inf')

writer = SummaryWriter()
# makedirs also creates the parent `checkpoints/` directory when it is missing.
os.makedirs(f'checkpoints/{model_name}', exist_ok=True)
now = datetime.now().strftime("%d_%m_%Y__%H_%M_%S")

# The placeholder inputs were created on the CPU; move them next to the model.
placeholder_seq = seq.to(device)
placeholder_query = output_query.to(device)

for epoch in range(max_epochs):
    train_loss = 0.0
    val_loss = 0.0
    model.train()

    babi_dataset.set_mode('train')
    train_loader = DataLoader(
        babi_dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_collate
    )
    # Capture the length while the dataset is in train mode: the loader reads
    # the dataset length lazily, and the same dataset object is switched to
    # 'valid' below (presumably changing its length - verify in BabiDataset).
    n_train_batches = len(train_loader)

    for batch_idx, data in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{max_epochs}")):
        optimizer.zero_grad()

        contexts, questions, answers = data
        contexts = contexts.long().to(device)
        questions = questions.long().to(device)
        answers = answers.to(device)

        # NOTE(review): the forward pass still feeds the fixed random
        # placeholders instead of this batch. The model-definition cell prints
        # logits of shape (1, 1, 179) while the captured output below shows
        # `answers` with one entry per sample, so the shapes passed to the
        # criterion do not line up. TODO: embed `contexts`/`questions` with an
        # embedding layer created once outside the loop (and registered with
        # the optimizer), and turn the questions into the decoder query.
        logits = model(placeholder_seq, queries=placeholder_query)
        loss = criterion(logits, answers)
        loss.backward()
        optimizer.step()

        # .item() yields a Python float, so no tensors are kept alive between steps.
        batch_loss = loss.item()
        train_loss += batch_loss
        writer.add_scalars("losses_step", {"train_loss": batch_loss}, epoch * n_train_batches + batch_idx)

    model.eval()

    babi_dataset.set_mode('valid')
    val_loader = DataLoader(
        babi_dataset, batch_size=batch_size, shuffle=False, collate_fn=pad_collate
    )
    n_val_batches = len(val_loader)

    with torch.no_grad():
        # Each batch is a (contexts, questions, answers) triple, exactly as in
        # the training loop; unpack the validation batch itself rather than
        # the last training batch.
        for contexts, questions, answers in val_loader:
            contexts = contexts.long().to(device)
            questions = questions.long().to(device)
            answers = answers.to(device)

            # NOTE(review): same placeholder inputs as in the training loop.
            logits = model(placeholder_seq, queries=placeholder_query)
            val_loss += criterion(logits, answers).item()

    train_loss = train_loss / n_train_batches
    val_loss = val_loss / n_val_batches

    # Log the validation loss at the global step of the epoch's last training batch.
    writer.add_scalars("losses_step", {"val_loss": val_loss}, (epoch + 1) * n_train_batches - 1)

    writer.add_scalars("losses_epoch", {"train_loss": train_loss}, epoch)
    writer.add_scalars("losses_epoch", {"val_loss": val_loss}, epoch)

    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch+1,
        train_loss,
        val_loss
    ))

    checkpoint = {
        'epoch': epoch + 1,
        'valid_loss_min': val_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }

    # Always keep the latest checkpoint; additionally keep the best one so far.
    torch.save(checkpoint, f'checkpoints/{model_name}/checkpoint_{now}.pt')
    if val_loss <= val_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(val_loss_min,val_loss))
        torch.save(checkpoint, f'checkpoints/{model_name}/best_checkpoint_{now}.pt')
        val_loss_min = val_loss

# Flush pending TensorBoard events to disk.
writer.close()

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([32, 910]) torch.Size([32, 13]) torch.Size([32])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([1

torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size([128])
torch.Size([128, 910]) torch.Size([128, 13]) torch.Size

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence, then apply dropout.

    A table ``pe`` of shape ``[max_len, 1, d_model]`` is precomputed and stored
    as a buffer: even feature indices hold sines and odd indices hold cosines
    of ``position * 10000 ** (-i / d_model)`` for ``i = 0, 2, 4, ...``.

    Args:
        d_model: size of the feature (last) dimension; may be even or odd.
        dropout: dropout probability applied after the addition.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        # Function-scope import: the notebook's import cell does not import `math`.
        import math

        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd index than even index, so the
        # cosine block is trimmed to fit; for even d_model the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor of the same shape as ``x``: the input plus the first
            ``seq_len`` rows of the encoding table, passed through dropout.
        """
        # pe[:seq_len] is [seq_len, 1, d_model] and broadcasts over the batch axis.
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    

In [None]:
# use_cuda = torch.cuda.is_available()
# device = torch.device("cuda:0" if use_cuda else "cpu")
# model.to(device)
# # checkpoint = torch.load(f'checkpoints/{model_name}/best_checkpoint.pt')
# # model.load_state_dict(checkpoint['state_dict'])
# model.eval()
# m = nn.Softmax(dim=1)

# babi_dataset.set_mode('test')
# test_loader = DataLoader(
#     babi_dataset, batch_size=batch_size, shuffle=False, collate_fn=pad_collate
# )

# y_pred_extended = []
# y_true_extended = []
# with torch.no_grad():
#     for batch_idx, data in enumerate(tqdm(test_loader, desc=f"Inference")):
#         contexts, questions, answers = data
#         contexts = Variable(contexts.long().to(device))
#         questions = Variable(questions.long().to(device))
#         answers = Variable(answers.to(device))

# #         make questions into input_query and also pass to model
# #         output_query = torch.randn(1, 1, 179) # (batch, output_sequence_length, queries_dim) - should be learned
# #         output_query = Variable(output_query) # .cuda()            
            
#         logits = model(seq, queries=output_query)
#         y_pred = m(logits)
#         y_pred = np.argmax(y_pred, axis=1)
#         y_pred_extended.extend(y_pred)
#         y_true_extended.extend(answers.cpu())
        
# def plot_confusion_matrix(cm, classes,
#                           normalize=False,
#                           title='Confusion matrix',
#                           cmap=plt.cm.Blues):
#     """
#     This function prints and plots the confusion matrix.
#     Normalization can be applied by setting `normalize=True`.
#     """
#     if normalize:
#         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
#         #print("Normalized confusion matrix")
#     #else:
#         #print('Confusion matrix, without normalization')

#     #print(cm)
    
#     plt.figure(figsize=(25,25))
#     plt.imshow(cm, interpolation='nearest', cmap=cmap)
#     plt.title(title)
#     plt.colorbar()
#     tick_marks = np.arange(len(classes))
#     plt.xticks(tick_marks, classes, rotation=45)
#     plt.yticks(tick_marks, classes)

#     fmt = '.2f' if normalize else 'd'
#     thresh = cm.max() / 2.
#     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
#         plt.text(j, i, format(cm[i, j], fmt),
#                  horizontalalignment="center",
#                  color="white" if cm[i, j] > thresh else "black")

#     plt.ylabel('True label')
#     plt.xlabel('Predicted label')
#     plt.tight_layout()

# accuracy = accuracy_score(y_pred_extended, y_true_extended)
# cr = classification_report(y_true_extended, y_pred_extended)
# print(accuracy)
# print(cr)
# cnf_matrix = confusion_matrix(y_true_extended, y_pred_extended)
# # Plot non-normalized confusion matrix
# plt.figure()
# plot_confusion_matrix(cnf_matrix, classes=[str(i) for i in list(range(10))], title = ('Confusion Matrix'))
# plt.show()