<a href="https://colab.research.google.com/github/antsh3k/NN-learning/blob/master/6_Entity_Extraction_RNN_Diseases.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Switch to GPU
---

## Imports and globals

In [0]:
# IMPORTS (try to organize/group your imports)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import re
import json
import spacy
from os import path

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

import sklearn.metrics
from sklearn.metrics import classification_report

In [0]:
# Any global variables
SEED = 15
DATA_PATH = '/content/' # Used for Colab
MAX_SEQ_LEN = 600
nlp = spacy.load('en_core_web_sm')
DEVICE = 'cuda'

# Set SEEDs
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

## Download the data

---
Notes:
- Change directory accordingly if working locally or not on Colab

In [0]:
# Download the data from: https://github.com/w-is-h/DeepLearningNLP/raw/master/Session_7/data/diseases_medmentions.txt
!wget https://github.com/w-is-h/DeepLearningNLP/raw/master/Session_7/data/diseases_medmentions.txt -P /content/
# The file is served as plain text (see the wget log), so no decompression with gunzip is needed
# Show a couple of lines from the downloaded file
print("*"*200)
print("Head of the data: \n\n")
!head /content/diseases_medmentions.txt

--2019-10-22 13:38:54--  https://github.com/w-is-h/DeepLearningNLP/raw/master/Session_7/data/diseases_medmentions.txt
Resolving github.com (github.com)... 192.30.253.112
Connecting to github.com (github.com)|192.30.253.112|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/w-is-h/DeepLearningNLP/master/Session_7/data/diseases_medmentions.txt [following]
--2019-10-22 13:38:54--  https://raw.githubusercontent.com/w-is-h/DeepLearningNLP/master/Session_7/data/diseases_medmentions.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4942710 (4.7M) [text/plain]
Saving to: ‘/content/diseases_medmentions.txt’


2019-10-22 13:38:54 (40.9 MB/s) - ‘/content/diseases_medmentions.txt’ saved [4942710/4942710]

gzip: 

## We need to put the data into the right format (x, y)

The input data in `diseases_medmentions.txt` has the form:
```
<token>\t<label>
<token>\t<label>
.
.
.
```
The documents are separated by a blank line. We need to read the tokens into `x,y` so that one row in `x,y` is one document:
```
x = [[<token>, <token>, <token>, ...],
     [<token>, <token>, <token>, ...],
     ...]

y = [[<label>, <label>, <label>, ...], 
     [<label>, <label>, <label>, ...], 
     ...]
``` 

---
Notes:
- **We will also lowercase the text**
- Usually we would not load the whole dataset into memory, but this one is small (about 4.7 MB), so it is fine here
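
As a minimal sketch of that layout, the snippet below parses a tiny in-memory sample instead of the real file. The sample text and the helper name `read_documents` are made up for illustration; only the tab-separated, blank-line-delimited format comes from the description above.

```python
import io

def read_documents(handle):
    """Group tab-separated token/label lines into one list per document."""
    x, y = [[]], [[]]
    for line in handle:
        if line.strip():
            token, label = line.rstrip("\n").split("\t")
            x[-1].append(token.lower())  # lowercase, as noted above
            y[-1].append(int(label))
        elif x[-1]:
            # A blank line closes the current document (repeated blanks are skipped)
            x.append([])
            y.append([])
    if not x[-1]:
        # Drop the empty document left behind by a trailing blank line
        del x[-1], y[-1]
    return x, y

# Hypothetical two-document sample in the same format as the data file
sample = "Cystic\t1\nfibrosis\t1\nis\t0\n\nNo\t0\nfever\t1\n\n"
x_demo, y_demo = read_documents(io.StringIO(sample))
print(x_demo)  # [['cystic', 'fibrosis', 'is'], ['no', 'fever']]
print(y_demo)  # [[1, 1, 0], [0, 1]]
```

Row `i` of `x_demo` and row `i` of `y_demo` always have the same length, which is the property the sanity check in the next cell relies on.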

In [0]:
# Load the data
x = []
y = []

# Open the file, read the data and put it into 'x' and 'y'.
# A blank line marks the end of a document.
x.append([])
y.append([])
with open(DATA_PATH + "diseases_medmentions.txt") as f:
  for line in f:
    if line.strip():
      parts = line.split("\t")
      token = parts[0].lower()
      label = int(parts[1])

      x[-1].append(token)
      y[-1].append(label)
    else:
      x.append([])
      y.append([])

# Drop the empty document left behind if the file ends with a blank line
if not x[-1]:
  del x[-1]
  del y[-1]
# Random sanity check
assert len(x[12]) == len(y[12])

# Print an example
print("Sample at position 0 for both x and y")
print(x[0])
print(y[0])

Sample at position 0 for both x and y
['dctn4', 'as', 'a', 'modifier', 'of', 'chronic', 'pseudomonas', 'aeruginosa', 'infection', 'in', 'cystic', 'fibrosis', 'pseudomonas', 'aeruginosa', '(', 'pa', ')', 'infection', 'in', 'cystic', 'fibrosis', '(', 'cf', ')', 'patients', 'is', 'associated', 'with', 'worse', 'long', '-', 'term', 'pulmonary', 'disease', 'and', 'shorter', 'survival', ',', 'and', 'chronic', 'pa', 'infection', '(', 'cpa', ')', 'is', 'associated', 'with', 'reduced', 'lung', 'function', ',', 'faster', 'rate', 'of', 'lung', 'decline', ',', 'increased', 'rates', 'of', 'exacerbations', 'and', 'shorter', 'survival', '.', 'by', 'using', 'exome', 'sequencing', 'and', 'extreme', 'phenotype', 'design', ',', 'it', 'was', 'recently', 'shown', 'that', 'isoforms', 'of', 'dynactin', '4', '(', 'dctn4', ')', 'may', 'influence', 'pa', 'infection', 'in', 'cf', ',', 'leading', 'to', 'worse', 'respiratory', 'disease', '.', 'the', 'purpose', 'of', 'this', 'study', 'was', 'to', 'investigate', 'th

## Download the word embeddings

The links below point to the pretrained GloVe model (glove.840B.300d.zip). I have only converted it into gensim `KeyedVectors` and saved it on AWS to speed things up.


---
Notes: 
- Change directory accordingly if working locally or not on Colab
- Pretrained word embeddings taken from: https://nlp.stanford.edu/projects/glove/

In [0]:
!wget https://zkcl.s3-eu-west-1.amazonaws.com/keyed_vectors_840_300.dat -P /content/
!wget https://zkcl.s3-eu-west-1.amazonaws.com/keyed_vectors_840_300.dat.vectors.npy -P /content/

--2019-10-22 14:41:11--  https://zkcl.s3-eu-west-1.amazonaws.com/keyed_vectors_840_300.dat
Resolving zkcl.s3-eu-west-1.amazonaws.com (zkcl.s3-eu-west-1.amazonaws.com)... 52.218.97.48
Connecting to zkcl.s3-eu-west-1.amazonaws.com (zkcl.s3-eu-west-1.amazonaws.com)|52.218.97.48|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 120497892 (115M) [application/x-www-form-urlencoded]
Saving to: ‘/content/keyed_vectors_840_300.dat.1’


2019-10-22 14:41:16 (26.6 MB/s) - ‘/content/keyed_vectors_840_300.dat.1’ saved [120497892/120497892]

--2019-10-22 14:41:17--  https://zkcl.s3-eu-west-1.amazonaws.com/keyed_vectors_840_300.dat.vectors.npy
Resolving zkcl.s3-eu-west-1.amazonaws.com (zkcl.s3-eu-west-1.amazonaws.com)... 52.218.36.67
Connecting to zkcl.s3-eu-west-1.amazonaws.com (zkcl.s3-eu-west-1.amazonaws.com)|52.218.36.67|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2635219328 (2.5G) [application/x-www-form-urlencoded]
Saving to: ‘/content/keyed_

## Load the pretrained embeddings into a KeyedVectors Model

In [0]:
from gensim.models import KeyedVectors
keyed_vectors = KeyedVectors.load(DATA_PATH + "keyed_vectors_840_300.dat")

# Sanity check
keyed_vectors.most_similar("fibrosis")


[('cystic', 0.8054730892181396),
 ('pulmonary', 0.7173466086387634),
 ('emphysema', 0.699813961982727),
 ('cirrhosis', 0.6976104974746704),
 ('lung', 0.6918838620185852),
 ('Fibrosis', 0.6711395978927612),
 ('sarcoidosis', 0.6456377506256104),
 ('bronchiectasis', 0.6421423554420471),
 ('renal', 0.6387313604354858),
 ('sclerosis', 0.6375002861022949)]

## Subset the pretrained embeddings to only the ones we need
Create a vocabulary based on our current dataset `x`. We use this to subset the full `keyed_vectors` from GloVe.

---
Notes:
- Usually we would not do this, unless we know that our current dataset contains all the words we are ever going to see in the future. Here we are only doing it to speed things up.

In [0]:
current_vocab = []
#? TODO
for sample in x:
  for token in sample:
    current_vocab.append(token)

# Convert to set to remove duplicates
current_vocab = set(current_vocab)
print(len(current_vocab))

29262


## From the gensim model take only what we need

`embeddings` - a list of vectors, where each row represents the embedding of one word

`id2word` - map from index in the embeddings list to words

`word2id` - map from word to index in the embeddings list


Once done we should be able to get the embedding for word "house" like this:
`embeddings[word2id['house']]`

---
Notes:
- We only want to have the embeddings for the words in `current_vocab`
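
To make the three structures concrete, here is a toy version with made-up 3-dimensional vectors. The dictionary `pretrained`, the `<UNK>` vector and all `*_demo` names are assumptions for illustration; the real notebook uses the 300-dimensional GloVe vectors instead.

```python
# Hypothetical 3-dimensional "pretrained" vectors standing in for the GloVe model
pretrained = {
    "house": [0.1, 0.2, 0.3],
    "lung": [0.4, 0.5, 0.6],
    "cat": [0.7, 0.8, 0.9],
}
vocab = {"house", "lung", "dctn4"}  # 'dctn4' has no pretrained vector

emb_demo, id2word_demo, word2id_demo = [], {}, {}
for word in sorted(vocab):  # sorted only to make the IDs reproducible
    if word in pretrained:
        idx = len(emb_demo)  # position of this word in the embeddings list
        id2word_demo[idx] = word
        word2id_demo[word] = idx
        emb_demo.append(pretrained[word])

# Special tokens go last: <UNK> for unknown words, <PAD> (all zeros) for padding
for special, vector in (("<UNK>", [0.5, 0.5, 0.5]), ("<PAD>", [0.0, 0.0, 0.0])):
    idx = len(emb_demo)
    id2word_demo[idx] = special
    word2id_demo[special] = idx
    emb_demo.append(vector)

print(emb_demo[word2id_demo["house"]])  # [0.1, 0.2, 0.3]
print(word2id_demo)  # {'house': 0, 'lung': 1, '<UNK>': 2, '<PAD>': 3}
```

Note that `cat` is dropped because it is not in the vocabulary, and `dctn4` gets no entry because it has no pretrained vector; it will later be mapped to `<UNK>`.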

In [0]:
embeddings = [] # A list of embeddings, one for each word of current_vocab found in GloVe

# Embeddings is a list, meaning we know that embeddings[1] is the vector for the
# word with ID=1, but we don't know which word that is. That is why we need
# the id2word and word2id mappings.
id2word = {}
word2id = {}

for word in current_vocab:
  if word in keyed_vectors:  # Membership test; avoids the version-specific .vocab attribute
    idx = len(embeddings)  # Position of this word in the embeddings list
    id2word[idx] = word  # Add mapping from ID to word
    word2id[word] = idx  # Add mapping from word to ID
    embeddings.append(keyed_vectors[word])  # Add the embedding for 'word' from keyed_vectors
  
# Add <UNK> and <PAD>
word = "<UNK>"
id2word[len(embeddings)] = word
word2id[word] = len(embeddings)
embeddings.append(np.random.rand(len(embeddings[0])))
  
# For the word '<PAD>', embedding is set to all zeros
word = "<PAD>"
id2word[len(embeddings)] = word
word2id[word] = len(embeddings)
embeddings.append(np.zeros(len(embeddings[0])))


# Convert the embeddings list into a tensor (going through one numpy array is much faster)
embeddings = torch.tensor(np.array(embeddings), dtype=torch.float32)

# Sanity
assert len(embeddings) == len(id2word) == len(word2id)
assert keyed_vectors['house'][0] == embeddings[word2id['house']][0]


## Convert x into a list of indices instead of words

---
Notes:
- If a word does not exist in our `word2id` map we use the index for `<UNK>`
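
A small sketch of that fallback, using a hypothetical four-entry mapping (`toy_word2id` and `to_indices` are illustrative names, not part of the notebook):

```python
# Hypothetical mapping; in the notebook this role is played by word2id
toy_word2id = {"cystic": 0, "fibrosis": 1, "<UNK>": 2, "<PAD>": 3}

def to_indices(tokens, word2id):
    """Replace each token by its ID, falling back to the <UNK> ID for unseen words."""
    unk = word2id["<UNK>"]
    return [word2id.get(token, unk) for token in tokens]

ids_demo = to_indices(["dctn4", "cystic", "fibrosis"], toy_word2id)
print(ids_demo)  # [2, 0, 1]
```

The output list always has the same length as the input, so the labels in `y` stay aligned with the indices.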

In [0]:
x_ind = [] # TODO
for sample in x:
  n_sample = []
  for word in sample:
    if word in word2id:
      n_sample.append(word2id[word])
    else:
      n_sample.append(word2id['<UNK>'])
  x_ind.append(n_sample)
    
    
print(x[0])
print(x_ind[0])
assert len(x_ind[0]) == len(y[0])

['dctn4', 'as', 'a', 'modifier', 'of', 'chronic', 'pseudomonas', 'aeruginosa', 'infection', 'in', 'cystic', 'fibrosis', 'pseudomonas', 'aeruginosa', '(', 'pa', ')', 'infection', 'in', 'cystic', 'fibrosis', '(', 'cf', ')', 'patients', 'is', 'associated', 'with', 'worse', 'long', '-', 'term', 'pulmonary', 'disease', 'and', 'shorter', 'survival', ',', 'and', 'chronic', 'pa', 'infection', '(', 'cpa', ')', 'is', 'associated', 'with', 'reduced', 'lung', 'function', ',', 'faster', 'rate', 'of', 'lung', 'decline', ',', 'increased', 'rates', 'of', 'exacerbations', 'and', 'shorter', 'survival', '.', 'by', 'using', 'exome', 'sequencing', 'and', 'extreme', 'phenotype', 'design', ',', 'it', 'was', 'recently', 'shown', 'that', 'isoforms', 'of', 'dynactin', '4', '(', 'dctn4', ')', 'may', 'influence', 'pa', 'infection', 'in', 'cf', ',', 'leading', 'to', 'worse', 'respiratory', 'disease', '.', 'the', 'purpose', 'of', 'this', 'study', 'was', 'to', 'investigate', 'the', 'role', 'of', 'dctn4', 'missense',

In [0]:
# Sanity check: convert the indices for x_ind[0] back to words
" ".join([id2word[i] for i in x_ind[0]])

'<UNK> as a modifier of chronic pseudomonas aeruginosa infection in cystic fibrosis pseudomonas aeruginosa ( pa ) infection in cystic fibrosis ( cf ) patients is associated with worse long - term pulmonary disease and shorter survival , and chronic pa infection ( cpa ) is associated with reduced lung function , faster rate of lung decline , increased rates of exacerbations and shorter survival . by using exome sequencing and extreme phenotype design , it was recently shown that isoforms of dynactin 4 ( <UNK> ) may influence pa infection in cf , leading to worse respiratory disease . the purpose of this study was to investigate the role of <UNK> missense variants on pa infection incidence , age at first pa infection and chronic pa infection incidence in a cohort of adult cf patients from a single centre . polymerase chain reaction and direct sequencing were used to screen dna samples for <UNK> variants . a total of 121 adult cf patients from the cochin hospital cf centre have been inclu

## Print some statistics

In [0]:
doc_lengths = [len(doc) for doc in x_ind] # Document lengths
pos = sum(sum(labels) for labels in y) # Number of positive examples (y is ragged, so avoid np.sum on it)
neg = np.sum([len(one) for one in y]) - pos # Number of negative examples 
avg = np.average(doc_lengths) # Average doc length
md = np.median(doc_lengths) # Median doc length
mx = np.max(doc_lengths) # Maximum doc length
mi = np.min(doc_lengths) # Minimum doc length

print("Number of positive examples: {}".format(pos))
print("Number of negative examples: {}".format(neg))
print("Average length of the doc:   {:.2f}".format(avg))
print("Median length of the doc:    {}".format(md))
print("Max length of the doc:       {}".format(mx))
print("Min length of the doc:       {}".format(mi))

Number of positive examples: 19780
Number of negative examples: 535604
Average length of the doc:   288.51
Median length of the doc:    294.0
Max length of the doc:       833
Min length of the doc:       29


## Trim sentences, create masks and add padding

We are doing three things in this step:
1. Limiting the size of documents to `MAX_SEQ_LEN`
2. Adding padding to short documents to match `MAX_SEQ_LEN`
3. Creating a mask for each input document to have `1` where a real word from the document is and `0` where padding is.

An example (here I will show 'words' but normally we work with indices):
```
MAX_SEQ_LEN = 5

doc = ['i', 'was', 'running']
y = [0, 0, 1]

doc_pad = ['i', 'was', 'running', '<PAD>', '<PAD>']
mask = [1, 1, 1, 0, 0]
# We have to add padding to y, which is 0s
y = [0, 0, 1, 0, 0]
```

---
Notes:
- Documents of length 0 are not allowed here; usually we would remove them beforehand. In our case that was already done in the data preparation phase.
- Whatever change to the shape of `x` is done, the same must be done to `y`
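
The three steps can be sketched for a single document as follows. This is a plain-Python illustration with a hypothetical helper `pad_and_mask` and a made-up `PAD_ID`; the notebook cells below do the same thing with list comprehensions over all documents.

```python
def pad_and_mask(doc, labels, pad_id, max_seq_len):
    """Trim or pad one document to max_seq_len and build its 0/1 mask."""
    n_real = min(len(doc), max_seq_len)  # real tokens that survive trimming
    n_pad = max_seq_len - n_real         # padding positions to append
    doc_pad = doc[:n_real] + [pad_id] * n_pad
    labels_pad = labels[:n_real] + [0] * n_pad  # padding gets label 0
    mask = [1] * n_real + [0] * n_pad
    return doc_pad, labels_pad, mask

PAD_ID = 99  # hypothetical ID for '<PAD>'
short = pad_and_mask([4, 8, 15], [0, 0, 1], PAD_ID, max_seq_len=5)
long_ = pad_and_mask([4, 8, 15, 16, 23, 42], [0, 0, 1, 0, 0, 1], PAD_ID, max_seq_len=5)
print(short)  # ([4, 8, 15, 99, 99], [0, 0, 1, 0, 0], [1, 1, 1, 0, 0])
print(long_)  # ([4, 8, 15, 16, 23], [0, 0, 1, 0, 0], [1, 1, 1, 1, 1])
```

Trimming silently discards labels beyond `max_seq_len` (the final `1` of the long document above), which is worth remembering when interpreting recall on long documents.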

In [0]:
# Add padding and limit to MAX_SEQ_LEN
x_ind_pad = [(sample + [word2id['<PAD>']] * max(0, MAX_SEQ_LEN - len(sample)))[0:MAX_SEQ_LEN] for sample in x_ind]
y_pad = [(sample + [0] * max(0, MAX_SEQ_LEN - len(sample)))[0:MAX_SEQ_LEN] for sample in y] #? TODO

# Sanity
assert x_ind_pad[-1][0] == x_ind[-1][0]
assert np.average([len(s) for s in x_ind_pad]) == MAX_SEQ_LEN
assert np.average([len(s) for s in y_pad]) == MAX_SEQ_LEN

In [0]:
# Create a mask for each document in 'x'
masks = [[1]*min(MAX_SEQ_LEN, doc_len)+[0]*max(0, MAX_SEQ_LEN - doc_len) for doc_len in doc_lengths]

# Sanity
assert np.sum(masks[-1]) == doc_lengths[-1]
assert np.average([len(mask) for mask in masks]) == MAX_SEQ_LEN

## Split into train/test and convert to tensors

---
Notes:
- We do not move the tensors to `device` yet; that happens per batch in `get_batch`

In [0]:
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test, masks_train, masks_test = train_test_split(x_ind_pad, y_pad, masks, test_size=0.2, random_state=SEED)

x_train = torch.tensor(x_train, dtype=torch.long)
y_train = torch.tensor(y_train, dtype=torch.long)
masks_train = torch.tensor(masks_train, dtype=torch.float32)

x_test = torch.tensor(x_test, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)
masks_test = torch.tensor(masks_test, dtype=torch.float32)

## Make a function that creates batches

To create a batch we need to know:

`ind` - index of this batch in the whole data

`batch_size` - size of the batch

`x, y` - the data and labels/targets

`masks` - for choosing the right masks for the new batch

`device` - for moving the batch onto the right device


An example (only for masks, same is for everything else):
```
masks = [[1, 1, 1, 1, 0],
         [1, 1, 1, 0, 0],
         [1, 1, 1, 1, 1]]

batch_size = 2
ind = 0

# Calculate given the formulas in the get_batch function
start = 0
end = 2

# Now we can see that for the first batch we can reduce the MAX_SEQ_LEN to 4 instead of 5 as the last column is '0' for all the selected rows.
mask_batch = [[1, 1, 1, 1],
              [1, 1, 1, 0]]
```
---
Notes:
- When working with PyTorch, batches can have different shapes
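
The example above can be reproduced with plain lists. The helper `get_batch_lists` below is a hypothetical, tensor-free stand-in for `get_batch`, meant only to show the start/end arithmetic and the trimming to the longest sequence in the batch.

```python
def get_batch_lists(ind, batch_size, rows, masks):
    """Select batch number `ind` and cut every row to the longest real sequence."""
    start = ind * batch_size
    end = (ind + 1) * batch_size  # slicing past the end is safe for the last batch
    rows_batch = rows[start:end]
    mask_batch = masks[start:end]
    # Each mask sums to the number of real tokens in that row
    max_len = max(sum(mask) for mask in mask_batch)
    return ([row[:max_len] for row in rows_batch],
            [mask[:max_len] for mask in mask_batch])

masks_demo = [[1, 1, 1, 1, 0],
              [1, 1, 1, 0, 0],
              [1, 1, 1, 1, 1]]
rows_demo = [[5, 6, 7, 8, 0],
             [9, 10, 11, 0, 0],
             [12, 13, 14, 15, 16]]

first = get_batch_lists(0, 2, rows_demo, masks_demo)
second = get_batch_lists(1, 2, rows_demo, masks_demo)
print(first)   # ([[5, 6, 7, 8], [9, 10, 11, 0]], [[1, 1, 1, 1], [1, 1, 1, 0]])
print(second)  # ([[12, 13, 14, 15, 16]], [[1, 1, 1, 1, 1]])
```

The two batches end up with widths 4 and 5, which is the sense in which batches can have different shapes.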

In [0]:
def get_batch(ind, batch_size, x, y, masks, device):
  # Get the start/end index for this batch
  start = ind * batch_size
  end = (ind+1) * batch_size
    
  # Get the batch
  x_batch = x[start:end]
  y_batch = y[start:end] #? TODO
  mask_batch = masks[start:end] #? TODO
  
  # Get the length of the longest sequence in this batch (as a plain Python int)
  max_len = int(mask_batch.sum(1).max())
  
  # Cut off the unnecessary part
  x_batch = x_batch[:, 0:max_len]
  y_batch = y_batch[:, 0:max_len] #? TODO
  mask_batch = mask_batch[:, 0:max_len] #? TODO
  
  # Return and move the batches to the right device
  return x_batch.to(device), y_batch.to(device), mask_batch.to(device) #? TODO: order is x, y, mask

## Create the network

In [0]:
class RNN(nn.Module):
  def __init__(self, embeddings, padding_idx):
    super(RNN, self).__init__()
    # Get the required sizes
    vocab_size = len(embeddings)
    embedding_size = len(embeddings[0])
    
    # Create embeddings
    self.embeddings = nn.Embedding(vocab_size, embedding_size, padding_idx=padding_idx) #?
    # Load existing weights
    self.embeddings.load_state_dict({'weight': embeddings}) #?
    # Disable training for the embeddings - IMPORTANT
    self.embeddings.weight.requires_grad = False #?
    
 
    hidden_size = 300
    bid = True # Is the network bidirectional

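    # Create the RNN cell - divide hidden_size by 2 when bidirectional, so that the
    # concatenated forward/backward outputs have hidden_size features in total
    self.rnn = nn.LSTM(input_size=embedding_size, 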
                       hidden_size=hidden_size // (2 if bid else 1), 
                       num_layers=2, 
                       dropout=0.5, 
                       bidirectional=bid)
    self.fc1 = nn.Linear(hidden_size, 2)

  def forward(self, x, mask):
    # Embed the input: from id -> vec
    x = self.embeddings(x) # x.shape = batch_size x sequence_length x emb_size

    # Tell the RNN to ignore padding; batch_first=True because x is batch x seq x emb.
    # Recent PyTorch versions require the lengths to be on the CPU.
    lengths = mask.sum(1).long().cpu()
    x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True,
                                          enforce_sorted=False)

    # Run 'x' through the RNN
    x, hidden = self.rnn(x)

    # Add the padding again (the second return value holds the sequence lengths)
    x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True,
                                            total_length=mask.shape[1])

    # Push x through the fc network
    x = self.fc1(x) #? TODO

    # Multiply by mask to ignore places where we have padding - IMPORTANT
    x *= mask.unsqueeze(2)
    
    return x
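
One consequence of multiplying the logits by the mask is worth keeping in mind: a padded position ends up with logits `[0, 0]`, which cross entropy still scores as `ln 2 ≈ 0.693`. That value is a constant with no gradient for the network, so it does not affect training, but it is included in every loss averaged over padded positions. The pure-Python sketch below only reproduces that arithmetic; the helper `cross_entropy` is illustrative and not part of the notebook.

```python
import math

def cross_entropy(logits, target):
    """Cross entropy for one position: -log(softmax(logits)[target])."""
    shift = max(logits)  # subtract the max for numerical stability
    log_sum = shift + math.log(sum(math.exp(v - shift) for v in logits))
    return log_sum - logits[target]

padded = cross_entropy([0.0, 0.0], target=0)      # masked position, padded label 0
confident = cross_entropy([4.0, -4.0], target=0)  # a confident, correct prediction

print(padded, math.log(2))  # both are ln 2
print(confident)            # close to 0
```

If loss values that reflect only real tokens are wanted, one option (not used in this notebook) is to average the per-token losses over the positions where the mask is 1.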

## Create the device, network, criterion and optimizer

In [0]:
device = torch.device(DEVICE) # Create a torch device
net = RNN(embeddings, padding_idx=word2id['<PAD>']) # Create an instance of the RNN; check which input parameters it requires
criterion = nn.CrossEntropyLoss() # Set the criterion to Cross Entropy Loss
parameters = filter(lambda p: p.requires_grad, net.parameters()) # Get only the parameters that require training
optimizer = optim.Adam(parameters, lr=0.001) # Set the optimizer to Adam with lr = 0.001
net.to(device) # Move the network to device

RNN(
  (embeddings): Embedding(22766, 300, padding_idx=22765)
  (rnn): LSTM(300, 150, num_layers=2, dropout=0.5, bidirectional=True)
  (fc1): Linear(in_features=300, out_features=2, bias=True)
)

In [0]:
batch_size = 40
# Calculate the number of batches given training size len(x_train)
num_batches = int(np.ceil(len(x_train) / batch_size))
for epoch in range(50):
  # Switch network to train mode
  net.train()

  # Create the running loss array
  running_loss = []
  for i in range(num_batches):
    x_train_batch, y_train_batch, mask_train_batch = get_batch(ind=i, 
                                                               batch_size=batch_size, 
                                                               x=x_train, 
                                                               y=y_train, 
                                                               masks=masks_train, 
                                                               device=device)
    # zero gradients
    optimizer.zero_grad() #? TODO
    # Get outputs for our batch
    outputs = net(x_train_batch, mask_train_batch) #? TODO
    # Get loss
    loss = criterion(outputs.view(-1,2), y_train_batch.view(-1))#? TODO
    # Do the backward step
    loss.backward() #? TODO
    
    # Clip grads
    parameters = filter(lambda p: p.requires_grad, net.parameters())
    torch.nn.utils.clip_grad_norm_(parameters, 0.25)
    
    # Do the optimizer step
    optimizer.step() #? TODO

    # Add the loss to the running_loss
    running_loss.append(loss.item()) #? TODO

  if epoch % 5 == 0:
    net.eval()
    x_test_batch, y_test_batch, masks_test_batch = get_batch(ind=0, 
                                                              batch_size=len(x_test), 
                                                              x=x_test, 
                                                              y=y_test, 
                                                              masks=masks_test, 
                                                              device=device)
    # No gradients are needed for evaluation, which saves memory on the full test set
    with torch.no_grad():
      outputs_test = net(x_test_batch, masks_test_batch)
      outputs_test = outputs_test.view(-1, 2)
      test_loss = criterion(outputs_test, y_test_batch.view(-1)).item()

    train_loss = np.average(running_loss)

    print("Train Loss: {:5}\nTest Loss:  {:5}\n\n".format(train_loss, test_loss))
    print(sklearn.metrics.classification_report(y_test_batch.view(-1).cpu().numpy(), 
                                                torch.max(outputs_test, 1)[1].cpu().numpy()))
print('Finished Training')

Train Loss: 0.38484168434754396
Test Loss:  0.3705970048904419


              precision    recall  f1-score   support

           0       0.99      1.00      0.99    202549
           1       0.73      0.24      0.36      3811

    accuracy                           0.98    206360
   macro avg       0.86      0.62      0.68    206360
weighted avg       0.98      0.98      0.98    206360

Train Loss: 0.31017564886655563
Test Loss:  0.35322946310043335


              precision    recall  f1-score   support

           0       0.99      1.00      0.99    202549
           1       0.74      0.68      0.71      3811

    accuracy                           0.99    206360
   macro avg       0.87      0.84      0.85    206360
weighted avg       0.99      0.99      0.99    206360

Train Loss: 0.3008638471364975
Test Loss:  0.35312479734420776


              precision    recall  f1-score   support

           0       1.00      0.99      0.99    202549
           1       0.72      0.76      0.

KeyboardInterrupt: ignored

## Print one example from the test set - Sanity check

In [0]:
net.eval()
s_ind = 1 # Sample index
predictions = torch.max(torch.softmax(net(x_test_batch, masks_test_batch), dim=2), 2)[1].cpu().detach().numpy()
for ind, i in enumerate(x_test[s_ind].detach().cpu().numpy()[0:20]): # 20 tokens
  print(id2word[i], predictions[s_ind][ind], y_test_batch[s_ind].detach().cpu().numpy()[ind])

evaluation 0 0
of 0 0
choroidal 0 0
thickness 0 0
in 0 0
patients 0 0
with 0 0
pseudoexfoliation 1 1
syndrome 1 1
and 0 0
pseudoexfoliation 1 1
glaucoma 1 1
purpose 0 0
. 0 0
to 0 0
compare 0 0
the 0 0
macular 0 0
and 0 0
peripapillary 0 0


## Write our own document and do disease detection

In [0]:
# Write a document that contains a disease 
document = "The patient suffered from a non-epileptic seizure on wednesday he suffers from epilepsy and Creutzfeldt–Jakob disease (CJD)" #? TODO
# Convert to tokens and then indices - take care of OOV
tkns = [token.lower() for token in document.split(" ")] #? TODO
ind_tkns = [word2id.get(tkn, word2id['<UNK>']) for tkn in tkns] #? TODO

# Convert to torch tensor - note that it has to be a batch
doc_input = torch.tensor([ind_tkns], dtype=torch.long).to(device)
# Create the mask (all ones, as there is no padding); float32 to match the training masks
doc_mask = torch.tensor([[1] * len(ind_tkns)], dtype=torch.float32).to(device)

# Evaluate
net.eval()
with torch.no_grad():
  predictions = torch.max(torch.softmax(net(doc_input, doc_mask), dim=2), 2)[1].cpu().numpy()[0]

In [0]:
# Print the tokens and predictions
for ind, tkn in enumerate(tkns):
  print("{:10} - {}".format(tkn, ("disease" if predictions[ind] else "")))

the        - 
patient    - 
suffered   - 
from       - 
a          - 
non-epileptic - 
seizure    - 
on         - 
wednesday  - 
he         - 
suffers    - 
from       - 
epilepsy   - disease
and        - 
creutzfeldt–jakob - disease
disease    - disease
(cjd)      - 
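
Splitting on spaces leaves tokens such as `(cjd)` and `non-epileptic` glued together, whereas the training data shown earlier separates punctuation (`'(', 'pa', ')'` and `'long', '-', 'term'`). Glued tokens are unlikely to be in the vocabulary and fall back to `<UNK>`. A rough sketch of a tokenizer that mimics that separation follows; the regular expression and the name `simple_tokenize` are assumptions, and the original preprocessing of the dataset may differ.

```python
import re

def simple_tokenize(text):
    """Lowercase the text and split it into word tokens and single punctuation characters."""
    return re.findall(r"\w+|[^\w\s]", text.lower())

tokens_demo = simple_tokenize("He suffers from non-epileptic seizures (CJD).")
print(tokens_demo)
# ['he', 'suffers', 'from', 'non', '-', 'epileptic', 'seizures', '(', 'cjd', ')', '.']
```

Replacing `document.split(" ")` with such a tokenizer would change the predictions printed above, so the outputs would need to be regenerated.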
