# Named Entity Recognition on BC5CDR (Disease Corpus) with BioBERT





Notebook to train/fine-tune a BioBERT model to perform named entity recognition (NER). 

The [dataset](https://github.com/aczgh/NER/blob/main/data.txt) used is a pre-processed version of the BC5CDR dataset (BioCreative V CDR task corpus: a resource for chemical disease relation extraction) from [Li et al. (2016)](https://github.com/aczgh/NER/blob/main/data.txt). The copy used in this notebook is downloaded from the BioFLAIR repository.


Our model trained on top of BioBERT has an F1-score of **97.7%** 

The notebook is structured as follows:
* Setting up the GPU Environment
* Getting Data
* Training and Testing the Model
* Using the Model (Running Inference)

#### Task Description

> Named entity recognition (NER) is the task of tagging entities in text with their corresponding type. Approaches typically use BIO notation, which differentiates the beginning (B) and the inside (I) of entities. O is used for non-entity tokens.
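To make the BIO scheme concrete, here is a small, library-free sketch that groups a sentence's BIO tags into entity spans. The first three tokens come from the dev set shown later; `lithium toxicity` is an illustrative addition. This is not part of the training pipeline, only an illustration of how `B-`/`I-`/`O` tags delimit entities.

```python
from typing import List, Optional, Tuple


def bio_to_spans(tokens: List[str], tags: List[str]) -> List[Tuple[str, str]]:
    """Group BIO-tagged tokens into (entity_text, entity_type) pairs."""
    spans: List[Tuple[str, str]] = []
    current_tokens: List[str] = []
    current_type: Optional[str] = None
    for token, tag in zip(tokens, tags):
        if tag.startswith("B-") or (tag.startswith("I-") and tag[2:] != current_type):
            # A B- tag, or an I- tag that does not continue the open entity, starts a new entity
            if current_tokens:
                spans.append((" ".join(current_tokens), current_type))
            current_tokens, current_type = [token], tag[2:]
        elif tag.startswith("I-"):
            # I- tag of the same type: extend the open entity
            current_tokens.append(token)
        else:
            # O tag: close any open entity
            if current_tokens:
                spans.append((" ".join(current_tokens), current_type))
            current_tokens, current_type = [], None
    if current_tokens:
        spans.append((" ".join(current_tokens), current_type))
    return spans


tokens = ["Tricuspid", "valve", "regurgitation", "and", "lithium", "toxicity"]
tags = ["B-Entity", "I-Entity", "I-Entity", "O", "B-Entity", "I-Entity"]
print(bio_to_spans(tokens, tags))
# [('Tricuspid valve regurgitation', 'Entity'), ('lithium toxicity', 'Entity')]
```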

# Setting up the GPU Environment

#### Ensure we have a GPU runtime

If you're running this notebook in Google Colab, select `Runtime` > `Change Runtime Type` from the menubar. Ensure that `GPU` is selected as the `Hardware accelerator`. This will allow us to use the GPU to train the model subsequently.
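A quick way to confirm that the runtime change took effect is to ask PyTorch whether it can see a CUDA device. This is a minimal sketch; the helper name `gpu_available` is our own, and it simply returns `False` if PyTorch is not installed.

```python
def gpu_available() -> bool:
    """Return True if PyTorch is installed and reports a usable CUDA device."""
    try:
        import torch  # PyTorch comes preinstalled on Colab
    except ImportError:
        return False
    return torch.cuda.is_available()


if gpu_available():
    print("GPU runtime detected")
else:
    print("No GPU found: check Runtime > Change Runtime Type")
```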

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
!python3 -m pip install setuptools==59.5.0

!pip3 install --upgrade numpy==1.20.3


Collecting setuptools==59.5.0
  Downloading setuptools-59.5.0-py3-none-any.whl (952 kB)
[?25l[K     |▍                               | 10 kB 27.4 MB/s eta 0:00:01[K     |▊                               | 20 kB 27.0 MB/s eta 0:00:01[K     |█                               | 30 kB 23.3 MB/s eta 0:00:01[K     |█▍                              | 40 kB 11.7 MB/s eta 0:00:01[K     |█▊                              | 51 kB 11.2 MB/s eta 0:00:01[K     |██                              | 61 kB 13.0 MB/s eta 0:00:01[K     |██▍                             | 71 kB 12.9 MB/s eta 0:00:01[K     |██▊                             | 81 kB 12.5 MB/s eta 0:00:01[K     |███                             | 92 kB 13.7 MB/s eta 0:00:01[K     |███▍                            | 102 kB 12.2 MB/s eta 0:00:01[K     |███▉                            | 112 kB 12.2 MB/s eta 0:00:01[K     |████▏                           | 122 kB 12.2 MB/s eta 0:00:01[K     |████▌                           | 133 kB 

#### Install Dependencies and Restart Runtime

In [4]:
!pip install -q transformers
!pip install -q simpletransformers


In [5]:
## Import libraries
import urllib.request
from pathlib import Path

import numpy as np
import pandas as pd

In [6]:
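def read_conll(filename):
    """Read a tab-separated CoNLL file into a DataFrame with one row per token."""
    df = pd.read_csv(filename,
                     sep='\t', header=None, keep_default_na=False,
                     names=['words', 'pos', 'chunk', 'labels'],
                     quoting=3, skip_blank_lines=False)  # quoting=3 is csv.QUOTE_NONE
    # Remove the -DOCSTART- document headers
    df = df[~df['words'].astype(str).str.startswith('-DOCSTART-')].copy()
    # Blank lines separate sentences, so a running count of them gives each token its sentence id
    df['sentence_id'] = (df.words == '').cumsum()
    # Drop the blank separator rows themselves
    return df[df.words != '']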

In [7]:
def download_file(url, output_file):
    """Download `url` to `output_file`, creating parent directories as needed."""
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(url, output_file)

download_file('https://raw.githubusercontent.com/shreyashub/BioFLAIR/master/data/ner/bc5cdr/train.txt', '/content/data/train.txt')
download_file('https://raw.githubusercontent.com/shreyashub/BioFLAIR/master/data/ner/bc5cdr/test.txt', '/content/data/test.txt')
download_file('https://raw.githubusercontent.com/shreyashub/BioFLAIR/master/data/ner/bc5cdr/dev.txt', '/content/data/dev.txt')

In [8]:
train_df = read_conll('/content/data/train.txt')
test_df = read_conll('/content/data/test.txt')
dev_df = read_conll('/content/data/dev.txt')

In [9]:
train_df['labels'].value_counts()

O           82026
I-Entity    10931
B-Entity     2413
Name: labels, dtype: int64

In [10]:
## Count each label in the training set
print(len(train_df[train_df['labels'] == "B-Entity"]))
print(len(train_df[train_df['labels'] == "I-Entity"]))
print(len(train_df[train_df['labels'] == "O"]))

2413
10931
82026



In [12]:
## Count each label in the test set
print(len(test_df[test_df['labels'] == "B-Entity"]))
print(len(test_df[test_df['labels'] == "I-Entity"]))
print(len(test_df[test_df['labels'] == "O"]))

2246
11114
85331


In [13]:
## Count each label in the dev set
print(len(dev_df[dev_df['labels'] == "B-Entity"]))
print(len(dev_df[dev_df['labels'] == "I-Entity"]))
print(len(dev_df[dev_df['labels'] == "O"]))

2317
10934
81186


In [14]:
## Total B-Entity tokens across train, test and dev
2413+2246+2317

6976

In [15]:
## Total I-Entity tokens across train, test and dev
10931+11114+10934

32979

In [16]:

dev_df.head(3)

|   | words | pos | chunk | labels | sentence_id |
|---|---|---|---|---|---|
| 2 | Tricuspid | ADJ | O | B-Entity | 1 |
| 3 | valve | NOUN | O | I-Entity | 1 |
| 4 | regurgitation | NOUN | O | I-Entity | 1 |


In [17]:
## Read the combined disease corpus (one token per row, B/I/O labels) from Google Drive
di = pd.read_csv('/content/drive/MyDrive/NER-project/dataset_analysis/compineAllDisease.csv')
di['labels'].value_counts()

O    440304
B     17163
I     14700
Name: labels, dtype: int64

In [18]:
di.columns

Index(['words', 'labels'], dtype='object')

In [19]:
di.head(25)

|   | words | labels |
|---|---|---|
| 0 | BRCA1 | O |
| 1 | is | O |
| 2 | secreted | O |
| 3 | and | O |
| 4 | exhibits | O |
| 5 | properties | O |
| 6 | of | O |
| 7 | a | O |
| 8 | granin | O |
| 9 | . | O |


In [20]:

# Collect the distinct words tagged I (inside a disease mention)
di_3 = di[di['labels'] == "I"]
di_4 = list(set(di_3['words']))

print(len(di_4))

1614


In [21]:
print("length disease 2 is : ", len(di_4))

length disease 2 is :  1614


In [22]:
di_4[:16]

['headaches',
 'tachycardias',
 'fire',
 'cataract',
 'infarcts',
 'deformity',
 'glomerular',
 'meningitis',
 'diarrhea',
 'accidents',
 'rate',
 'thromboembolism',
 'psychosis',
 'thrombo',
 'neoplasias',
 'adrenal']

In [23]:
newalldiseses = list(di_4)  # copy of the distinct I-tagged words

In [24]:
print(len(newalldiseses))

1614


In [25]:
## Read the disease and drug name dataset from Google Drive
disease_drug = pd.read_csv('/content/drive/MyDrive/NER-project/dataset_analysis/diseaseDrug.csv')
disease_drug.head(10)

|   | words | pos | chunk | labels |
|---|---|---|---|---|
| 0 | Valsartan | PROPN | O | drug |
| 1 | Guanfacine | PROPN | O | drug |
| 2 | Lybrel | PROPN | O | drug |
| 3 | Ortho Evra | PROPN | O | drug |
| 4 | Buprenorphine / naloxone | PROPN | O | drug |
| 5 | Cialis | PROPN | O | drug |
| 6 | Levonorgestrel | PROPN | O | drug |
| 7 | Aripiprazole | PROPN | O | drug |
| 8 | Keppra | PROPN | O | drug |
| 9 | Ethinyl estradiol / levonorgestrel | PROPN | O | drug |


In [26]:
disease_drug['labels'].value_counts()

drug       215063
disease    215063
Name: labels, dtype: int64

In [27]:
## Keep only the drug rows
drug = disease_drug[disease_drug['labels'] == "drug"]
drug.head(3)

|   | words | pos | chunk | labels |
|---|---|---|---|---|
| 0 | Valsartan | PROPN | O | drug |
| 1 | Guanfacine | PROPN | O | drug |
| 2 | Lybrel | PROPN | O | drug |


In [28]:
drug.shape

(215063, 4)

In [29]:
## Keep only the disease rows
disease = disease_drug[disease_drug['labels'] == "disease"]
disease.head(3)

|   | words | pos | chunk | labels |
|---|---|---|---|---|
| 215063 | Left Ventricular Dysfunction | PROPN | O | disease |
| 215064 | ADHD | PROPN | O | disease |
| 215065 | Birth Control | PROPN | O | disease |


In [30]:
disease['labels'].value_counts()

disease    215063
Name: labels, dtype: int64

In [31]:
drug = list(drug['words'])
disease = list(disease['words'])

In [32]:
disease[:5]

['Left Ventricular Dysfunction',
 'ADHD',
 'Birth Control',
 'Birth Control',
 'Opiate Dependence']

In [33]:
print(len(disease))
print(len(drug))

215063
215063


In [34]:
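# Remove duplicates; dict.fromkeys keeps first-seen order, unlike set(), whose order can change between runs
disease = list(dict.fromkeys(disease))
drug = list(dict.fromkeys(drug))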

In [35]:
print("len unique disease is : " , len(disease))
print("len unique drug is : " , len(drug))

len unique disease is :  917
len unique drug is :  3671


In [36]:
## Read an extra list of disease names (one per line) from Google Drive
with open('/content/drive/MyDrive/NER-project/dataset_analysis/disease.txt') as f:
  alist = [line.rstrip() for line in f]

In [37]:
new_diseae = list(alist)

In [38]:
print(len(new_diseae))

796


In [39]:
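# Add the I-tagged corpus words and the extra disease names to the disease pool.
# Duplicates are not removed here, so the "unique" count printed next may include repeats.
disease.extend(newalldiseses)
disease.extend(new_diseae)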

In [40]:
print("len unique disease is : " , len(disease))
print("len unique drug is : " , len(drug))

len unique disease is :  3327
len unique drug is :  3671


In [41]:
# Repeat the name lists so there are enough names for every entity token:
# 9 x 3671 = 33039 drug names (>= 32979 I-Entity tokens),
# 2 x 3327 = 6654 disease names (topped up in the next cell)
listDrug = drug * 9
listDisease = disease * 2

In [42]:
# 6654 disease names are still fewer than the 6976 B-Entity tokens, so append 370 more
new = listDisease[200:570]
listDisease.extend(new)

In [43]:
print("listDisease is : " , len(listDisease))
print("listDrug is : " , len(listDrug))

listDisease is :  7024
listDrug is :  33039
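The repeat counts above (9 copies of the drug names, 2 copies of the disease names plus a 370-name top-up) were chosen so that each list has at least as many names as there are `I-Entity` (32979) and `B-Entity` (6976) tokens across the three splits. A more general way to build a pool of exactly the required length is to cycle through the names. The helper `build_name_pool` is our own sketch, not what the notebook runs.

```python
from itertools import cycle, islice
from math import ceil
from typing import List


def build_name_pool(names: List[str], needed: int) -> List[str]:
    """Cycle through `names` in order until exactly `needed` entries are collected."""
    if not names:
        raise ValueError("names must not be empty")
    return list(islice(cycle(names), needed))


# Copies needed = ceil(entity tokens / distinct names), using the counts printed in this notebook
print(ceil(32979 / 3671))  # 9 copies of the drug names
print(ceil(6976 / 3327))   # 3 copies of the disease names (the notebook uses 2 plus a 370-name top-up)
```

With this helper, `build_name_pool(drug, 32979)` would yield exactly one drug name per `I-Entity` token, using every name before any is reused.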


In [44]:
print(len(list(set(listDrug))))


3671


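Label totals across the train, test and dev splits:

| Label | Count |
|---|---|
| O | 248543 |
| I-Entity | 32979 |
| B-Entity | 6976 |

In the replacement steps that follow, `I-Entity` tokens are filled with drug names and `B-Entity` tokens with disease names.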




In [46]:
listDisease[8]

'Urticaria'

In [47]:
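import random

# Shuffle the name pools so that names are assigned to tokens in random order.
# No seed is set, so the assignment differs between runs; call random.seed(...) first for reproducibility.
random.shuffle(listDisease)
random.shuffle(listDrug)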


Training-set label counts before replacement: `O` 82026, `I-Entity` 10931, `B-Entity` 2413.

In [48]:
train_df['labels'].value_counts()

O           82026
I-Entity    10931
B-Entity     2413
Name: labels, dtype: int64

In [49]:
## Read the training data and swap entity tokens for names from the pools
data = read_conll('/content/data/train.txt').reset_index()

# I-Entity tokens take drug names, in order, from the start of listDrug.
# Assigning through .loc avoids the chained-assignment SettingWithCopyWarning.
drug_mask = data['labels'] == "I-Entity"
data.loc[drug_mask, 'words'] = listDrug[:drug_mask.sum()]

# B-Entity tokens take disease names, in order, from the start of listDisease
disease_mask = data['labels'] == "B-Entity"
data.loc[disease_mask, 'words'] = listDisease[:disease_mask.sum()]

train_df = data


In [50]:
test_df['labels'].value_counts()

O           85331
I-Entity    11114
B-Entity     2246
Name: labels, dtype: int64
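Each split must draw a fresh slice of the shuffled name lists, so the test and dev cells below start reading `listDrug` at positions 10931 and 22045, and `listDisease` at 2413 and 4659. Those start positions are running totals of the per-split label counts printed earlier, which this short check reproduces (the function name `start_offsets` is ours):

```python
from itertools import accumulate
from typing import List


def start_offsets(counts: List[int]) -> List[int]:
    """Start index of each split in a shared pool: 0, then the running totals."""
    return [0] + list(accumulate(counts))[:-1]


drug_counts = [10931, 11114, 10934]   # I-Entity tokens in train, test, dev
disease_counts = [2413, 2246, 2317]   # B-Entity tokens in train, test, dev

print(start_offsets(drug_counts))     # [0, 10931, 22045]
print(start_offsets(disease_counts))  # [0, 2413, 4659]
# Both pools are large enough: 32979 <= 33039 and 6976 <= 7024
print(sum(drug_counts) <= 33039, sum(disease_counts) <= 7024)  # True True
```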

In [51]:
## Read the test data; its names continue where the training split stopped
data = read_conll('/content/data/test.txt').reset_index()

# Drug names from position 10931 onwards (the training split used the first 10931)
drug_mask = data['labels'] == "I-Entity"
data.loc[drug_mask, 'words'] = listDrug[10931:10931 + drug_mask.sum()]

# Disease names from position 2413 onwards (the training split used the first 2413)
disease_mask = data['labels'] == "B-Entity"
data.loc[disease_mask, 'words'] = listDisease[2413:2413 + disease_mask.sum()]

test_df = data


In [52]:
## Read the dev data; its names continue where the test split stopped
data = read_conll('/content/data/dev.txt').reset_index()

# Drug names from position 22045 onwards (train and test used 10931 + 11114 = 22045)
drug_mask = data['labels'] == "I-Entity"
data.loc[drug_mask, 'words'] = listDrug[22045:22045 + drug_mask.sum()]

# Disease names from position 4659 onwards (train and test used 2413 + 2246 = 4659)
disease_mask = data['labels'] == "B-Entity"
data.loc[disease_mask, 'words'] = listDisease[4659:4659 + disease_mask.sum()]

dev_df = data


In [53]:
# Rename the tags: I-Entity tokens now hold drug names and B-Entity tokens disease names
train_df['labels'] = train_df['labels'].map({"I-Entity": "drug", "B-Entity": "disease", "O": "O"})
test_df['labels'] = test_df['labels'].map({"I-Entity": "drug", "B-Entity": "disease", "O": "O"})
dev_df['labels'] = dev_df['labels'].map({"I-Entity": "drug", "B-Entity": "disease", "O": "O"})

In [54]:
## Drop the 'index' column added by reset_index()
train_df.drop('index' , axis = 1 , inplace = True)
test_df.drop('index' , axis = 1 , inplace = True)
dev_df.drop('index' , axis = 1 , inplace = True)

In [55]:
dev_df.head(4)

|   | words | pos | chunk | labels | sentence_id |
|---|---|---|---|---|---|
| 0 | ICA | ADJ | O | disease | 1 |
| 1 | Lidocaine / prilocaine | NOUN | O | drug | 1 |
| 2 | Hydrochlorothiazide / irbesartan | NOUN | O | drug | 1 |
| 3 | and | CCONJ | O | O | 1 |


In [56]:
## Check whether train, test and dev contain null values
print(train_df.isnull().sum())
print(test_df.isnull().sum())
print(dev_df.isnull().sum())

words          0
pos            0
chunk          0
labels         0
sentence_id    0
dtype: int64
words          0
pos            0
chunk          0
labels         0
sentence_id    0
dtype: int64
words          2
pos            0
chunk          0
labels         0
sentence_id    0
dtype: int64


In [57]:
## Drop rows with null values; the two missing dev words are NaN entries that came from the disease-name pool
train_df.dropna(axis = 0 , inplace = True)
test_df.dropna(axis = 0 , inplace = True)
dev_df.dropna(axis = 0 , inplace = True)
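`dropna` removes whole token rows, which leaves gaps inside the affected dev sentences. A gentler option, sketched here on the assumption that the missing words are NaN entries in the replacement name pools (they appear only after the replacement step, and `read_conll` itself uses `keep_default_na=False`), is to clean the pools before building `listDrug` and `listDisease`. The helper `clean_names` is hypothetical, not part of the notebook.

```python
from typing import Iterable, List


def clean_names(names: Iterable[object]) -> List[str]:
    """Keep non-blank strings (stripped), in order; drop NaN, None and empty entries."""
    # pandas reads empty CSV cells as float NaN, so anything that is not a
    # non-blank string is treated as missing and dropped
    return [name.strip() for name in names if isinstance(name, str) and name.strip()]


print(clean_names(["Urticaria", float("nan"), None, "  ", " ADHD "]))  # ['Urticaria', 'ADHD']
```

Applied as `disease = clean_names(disease)` and `drug = clean_names(drug)` before the repeat step, every entity token would receive a real name.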

In [58]:
## Check again that no null values remain
print(train_df.isnull().sum())
print(test_df.isnull().sum())
print(dev_df.isnull().sum())

words          0
pos            0
chunk          0
labels         0
sentence_id    0
dtype: int64
words          0
pos            0
chunk          0
labels         0
sentence_id    0
dtype: int64
words          0
pos            0
chunk          0
labels         0
sentence_id    0
dtype: int64


In [59]:
# Number of sentences (distinct sentence_id values) in each split
print("len of train_df is : " , train_df['sentence_id'].nunique())
print("len of test_df is : " , test_df['sentence_id'].nunique())
print("len of dev_df is : " , dev_df['sentence_id'].nunique())

len of train_df is :  3942
len of test_df is :  4139
len of dev_df is :  3949


In [60]:
print("value counts for label in train is : " , train_df['labels'].value_counts())
print("\n\n\n")
print("value counts for label in test is : " , test_df['labels'].value_counts())
print("\n\n\n")
print("value counts for label in dev_df is : " , dev_df['labels'].value_counts())

value counts for label in train is :  O          82026
drug       10931
disease     2413
Name: labels, dtype: int64




value counts for label in test is :  O          85331
drug       11114
disease     2246
Name: labels, dtype: int64




value counts for label in dev_df is :  O          81186
drug       10934
disease     2315
Name: labels, dtype: int64


In [61]:
train_df['sentence_id'].nunique()

3942

In [62]:
dev_df.head(3)

|   | words | pos | chunk | labels | sentence_id |
|---|---|---|---|---|---|
| 0 | ICA | ADJ | O | disease | 1 |
| 1 | Lidocaine / prilocaine | NOUN | O | drug | 1 |
| 2 | Hydrochlorothiazide / irbesartan | NOUN | O | drug | 1 |


In [63]:
dev_df['labels'].value_counts()

O          81186
drug       10934
disease     2315
Name: labels, dtype: int64

In [64]:
data = [[train_df['sentence_id'].nunique(), test_df['sentence_id'].nunique(), dev_df['sentence_id'].nunique()]]

# Number of sentences in each split (train, test and dev)
pd.DataFrame(data, columns=["Train", "Test", "Dev"])

|   | Train | Test | Dev |
|---|---|---|---|
| 0 | 3942 | 4139 | 3949 |


In [65]:
test_df.head(4)

|   | words | pos | chunk | labels | sentence_id |
|---|---|---|---|---|---|
| 0 | Lo Loestrin Fe | PROPN | O | drug | 1 |
| 1 | - | PUNCT | O | O | 1 |
| 2 | associated | VERB | O | O | 1 |
| 3 | Nabilone | NOUN | O | drug | 1 |




# Training and Testing the Model

#### Set up the Training Arguments

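We set up the training arguments. Here we train for 5 epochs (`num_train_epochs`), over which the dev-set F1 score rises from 96.8% after the first epoch to 97.8% after the fifth. The train, test and dev sets are relatively small, so training does not take long. We enable `sliding_window` because NER sequences can be longer than `max_seq_length`, and with limited GPU memory we cannot raise `max_seq_length` much beyond 128 tokens. `fp16` turns on mixed-precision training, and `evaluate_during_training` scores the model on the dev set as training proceeds.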

In [73]:
train_args = {
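    'reprocess_input_data': True,      # rebuild features from the DataFrames instead of reusing a cache
    'overwrite_output_dir': True,      # allow re-running training into the same output_dir
    'sliding_window': True,            # split sequences longer than max_seq_length into windows
    'max_seq_length': 128,             # maximum number of sub-word tokens per input
    'num_train_epochs': 5,
    'train_batch_size': 32,
    'fp16': True,                      # mixed-precision training on the GPU
    'output_dir': '/content/drive/MyDrive/NER-project/modelTraining',
    'best_model_dir': '/content/drive/MyDrive/NER-project/modelTraining/best_model/',
    'evaluate_during_training': True,  # evaluate on eval_data (the dev set) while training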
}
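A sliding window splits a sequence that is longer than the model's limit into overlapping chunks, so no tokens are simply truncated away. The sketch below illustrates the idea on a plain token list; the `window` and `stride` values are arbitrary, and Simple Transformers' own implementation works on sub-word features and may differ in its details.

```python
from typing import List, Sequence


def sliding_windows(tokens: Sequence[str], window: int, stride: int) -> List[List[str]]:
    """Split `tokens` into chunks of at most `window` tokens, each starting `stride` after the previous one."""
    if not 0 < stride <= window:
        raise ValueError("stride must be in (0, window]")
    windows = []
    start = 0
    while True:
        windows.append(list(tokens[start:start + window]))
        if start + window >= len(tokens):
            break  # this window already reaches the end of the sequence
        start += stride
    return windows


print(sliding_windows(list("abcdefg"), window=4, stride=3))
# [['a', 'b', 'c', 'd'], ['d', 'e', 'f', 'g']]
```

A stride smaller than the window makes consecutive chunks overlap (here by one token), so tokens near a chunk boundary still get context from both sides in at least one chunk.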

The following line of code saves a list of all the distinct NER labels in the dataset to the variable `custom_labels`.

In [74]:
custom_labels = list(train_df['labels'].unique())
print(custom_labels)

['drug', 'O', 'disease']


In [75]:
custom_labels = ['drug', 'O', 'disease'] 
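Hard-coding `custom_labels` keeps the label order fixed across runs, whereas the order returned by `unique()` depends on which label happens to appear first in the data. We assume the model maps each label to its index in this list, so reloading the saved model later with the same list keeps predictions decoded consistently. The helpers below (our own names) show that mapping and check that the fixed list still covers every label seen in the data.

```python
from typing import Dict, Iterable, List, Set


def label_map(labels: List[str]) -> Dict[str, int]:
    """Map each label to its position in the list."""
    return {label: i for i, label in enumerate(labels)}


def missing_labels(fixed: Iterable[str], observed: Iterable[str]) -> Set[str]:
    """Labels that occur in the data but are absent from the fixed list."""
    return set(observed) - set(fixed)


labels = ["drug", "O", "disease"]
print(label_map(labels))                                    # {'drug': 0, 'O': 1, 'disease': 2}
print(missing_labels(labels, ["O", "drug", "disease", "O"]))  # set()
```

In the notebook, `missing_labels(custom_labels, train_df['labels'])` returning an empty set confirms the hard-coded list is complete.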

#### Train the Model

Once we have set up the `train_args` dictionary, the next step is to train the model. We use the pre-trained BioBERT model (by [DMIS Lab, Korea University](https://huggingface.co/dmis-lab)) from the [Hugging Face Transformers](https://github.com/huggingface/transformers) library as the base, and use the [Simple Transformers library](https://simpletransformers.ai/docs/classification-models/) on top of it so that we can train the NER (sequence tagging) model with just a few lines of code.

In [76]:
import logging

from simpletransformers.ner import NERModel

# DEBUG also prints HTTP connection messages while the model downloads; logging.INFO is quieter
logging.basicConfig(level=logging.DEBUG)
transformers_logger = logging.getLogger('transformers')
transformers_logger.setLevel(logging.WARNING)



In [77]:

# Load the pre-trained BioBERT v1.1 weights; a new token-classification head is created for custom_labels
model = NERModel('bert', 'dmis-lab/biobert-v1.1', labels=custom_labels, args=train_args)



Some weights of BertForTokenClassification were not initialized from the model checkpoint at dmis-lab/biobert-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

In [78]:
# Train the model
# Early stopping is not enabled here; see https://simpletransformers.ai/docs/tips-and-tricks/#using-early-stopping
model.train_model(train_df, eval_data=dev_df)

INFO:simpletransformers.ner.ner_model: Converting to features started.



INFO:simpletransformers.ner.ner_model: Training of bert model complete. Saved to /content/drive/MyDrive/NER-project/modelTraining.


(620,
 defaultdict(list,
             {'eval_loss': [0.015534223348073386,
               0.0165804699045774,
               0.01681740663309161,
               0.016067232578360808,
               0.0168092842671208],
              'f1_score': [0.9675296932410492,
               0.9694817820610364,
               0.9736573759347381,
               0.9777114851907095,
               0.9782144382742417],
              'global_step': [124, 248, 372, 496, 620],
              'precision': [0.9710977701543739,
               0.9627266621893888,
               0.9718405428329092,
               0.9788377847939244,
               0.9816529492455418],
              'recall': [0.9639877405074068,
               0.9763323684658607,
               0.9754810148135535,
               0.9765877745615529,
               0.9747999318917078],
              'train_loss': [0.010975207202136517,
               0.01730622909963131,
               0.00024192218552343547,
               0.0017937022494152188
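`train_model` returns the global step and a history of the metrics logged at each evaluation. Which evaluation counts as "best" depends on the metric: by dev F1 it is the last one (step 620), while by `eval_loss` it is the first (step 124). We believe Simple Transformers chooses the checkpoint saved to `best_model_dir` using its `early_stopping_metric` setting, which defaults to `eval_loss`, so it is worth checking which checkpoint that directory actually holds. The helper `best_step` is our own, applied here to the numbers printed above.

```python
from typing import Dict, List, Tuple


def best_step(history: Dict[str, List[float]], metric: str, minimize: bool) -> Tuple[int, float]:
    """Return (global_step, value) of the best logged entry for `metric`."""
    scores = history[metric]
    choose = min if minimize else max
    best_idx = choose(range(len(scores)), key=scores.__getitem__)
    return history["global_step"][best_idx], scores[best_idx]


# Values copied from the training output above
history = {
    "global_step": [124, 248, 372, 496, 620],
    "eval_loss": [0.015534223348073386, 0.0165804699045774, 0.01681740663309161,
                  0.016067232578360808, 0.0168092842671208],
    "f1_score": [0.9675296932410492, 0.9694817820610364, 0.9736573759347381,
                 0.9777114851907095, 0.9782144382742417],
}

print(best_step(history, "f1_score", minimize=False))  # (620, 0.9782144382742417)
print(best_step(history, "eval_loss", minimize=True))  # (124, 0.015534223348073386)
```

In the notebook itself, capturing the return value as `global_step, training_details = model.train_model(train_df, eval_data=dev_df)` gives a `training_details` dictionary that can be passed in place of `history`.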

In [None]:
model.results

In [None]:

# Evaluate on the test set: the result holds the loss, precision, recall and F1 score (not accuracy)
result, model_outputs, preds_list = model.eval_model(test_df)
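The F1 score reported by `eval_model` (and during training) is the harmonic mean of precision $P$ and recall $R$, $F_1 = 2PR/(P+R)$. A quick check against the dev-set precision and recall logged for the final training epoch reproduces the logged F1:

```python
def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Precision and recall from the final epoch of the training output above
p, r = 0.9816529492455418, 0.9747999318917078
print(round(f1_from(p, r), 4))  # 0.9782, matching the logged f1_score
```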

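#### Load the Best Saved Model

During training, Simple Transformers saved the model to `output_dir` and the best checkpoint to `best_model_dir`; the next cell reloads that best checkpoint.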

In [None]:
model = NERModel('bert', '/content/drive/MyDrive/NER-project/modelTraining/best_model', labels=custom_labels, args=train_args)


In [None]:
test_df.head(3)

In [None]:
sample = test_df[test_df.sentence_id == 6].words.str.cat(sep=' ')
print(sample)

In [None]:
# Rebuild test sentences 1-6 as space-separated strings
sample1, sample2, sample3, sample4, sample5, sample6 = (
    test_df[test_df.sentence_id == i].words.str.cat(sep=' ') for i in range(1, 7)
)

print(sample1)

# Using the Model (Running Inference)

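Running predictions (inference) with the model is as simple as calling `model.predict(samples)`. We pass in the first test sentence (`sample1`) and print the label predicted for each word. Because `predict` splits its input on spaces, a multi-word replacement name such as `Lidocaine / prilocaine` comes back as several separate words.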

In [None]:
samples = [sample1]
predictions, _ = model.predict(samples)
for idx, sample in enumerate(samples):
  print('{}: '.format(idx))
  for word in predictions[idx]:
    print('{}'.format(word))
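`model.predict` returns, for each input sentence, a list of one-entry `{word: label}` dictionaries, which is what the loop above prints. Since this notebook's labels (`drug`, `disease`, `O`) carry no B-/I- prefixes, a simple way to turn a prediction into entity strings is to merge consecutive words that share a non-`O` label. The helper `group_entities` and the example prediction below are hand-made illustrations, not output from the trained model; note that two adjacent entities of the same type would be merged into one under this scheme.

```python
from typing import Dict, List, Tuple


def group_entities(prediction: List[Dict[str, str]]) -> List[Tuple[str, str]]:
    """Merge runs of consecutive words with the same non-'O' label into (text, label) pairs."""
    entities: List[Tuple[str, str]] = []
    words: List[str] = []
    current = "O"
    for item in prediction:
        (word, label), = item.items()  # each item holds exactly one word -> label pair
        if label == current and label != "O":
            words.append(word)  # continue the open entity
            continue
        if words:
            entities.append((" ".join(words), current))  # close the open entity
        words, current = ([word], label) if label != "O" else ([], "O")
    if words:
        entities.append((" ".join(words), current))
    return entities


example = [{"Nabilone": "drug"}, {"-": "O"}, {"associated": "O"}, {"Urticaria": "disease"}]
print(group_entities(example))  # [('Nabilone', 'drug'), ('Urticaria', 'disease')]
```

With real output, `group_entities(predictions[0])` would list the entities found in `sample1`.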