In [3]:
import numpy as np
import torch
from torch import nn
import nltk
nltk.download('stopwords')
import pandas as pd
from nltk.corpus import stopwords
import torch.optim as optim
from tqdm import tqdm
import os

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
contractions = { 
"n't": "not",
"'ve": "have",
"'d": "would",
"'ll": "will",
"'s": "is",
"'m": "am",
"ma'am": "madam",
"'re": "they are"
}

In [0]:
def clean_text(text, remove_stopwords = False):
    '''Remove unwanted characters, stopwords, and format the text to create fewer nulls word embeddings'''
    
    # Convert words to lower case
    text = text.lower()
    
    # Replace contractions with their longer forms 
    if True:
        text = text.split()
        new_text = []
        for word in text:
            if word in contractions:
                new_text.append(contractions[word])
            else:
                new_text.append(word)
        text = " ".join(new_text)
    
    # Format words and remove unwanted characters
    text = re.sub(r'https?:\/\/.*[\r\n]*', '', text, flags=re.MULTILINE)
    text = re.sub(r'\<a href', ' ', text)
    text = re.sub(r'&amp;', '', text) 
    text = re.sub(r'[_"\-;%()|+&=*%.,!?:#$@\[\]/]', ' ', text)
    text = re.sub(r'<br />', ' ', text)
    text = re.sub(r'\'', ' ', text)
    
    # Optionally, remove stop words
    if remove_stopwords:
        text = text.split()
        stops = set(stopwords.words("english"))
        text = [w for w in text if not w in stops]
        text = " ".join(text)

    return text

In [0]:
embeddings_index = {}
with open('/content/drive/My Drive/numberbatch-en.txt', encoding='utf-8') as f:
    for line in f:
        values = line.split(' ')
        word = values[0]
        embedding = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = embedding

In [0]:
training_data=[]
with open('/content/drive/My Drive/train.txt','r') as f:
  line=f.readline()
  while(line):
    line=f.readline()
    if(line.strip()==''):
       training_data.append({'sent':'','label':''})
    else: 
      line=line.split()  
      training_data[-1]['sent']+=' '+str(line[0])
      training_data[-1]['label']+=' '+str(line[-1])

validation_data=[]
with open('/content/drive/My Drive/valid.txt','r') as f:
  line=f.readline()
  while(line):
    line=f.readline()
    if(line.strip()==''):
       validation_data.append({'sent':'','label':''})
    else: 
      line=line.split()  
      validation_data[-1]['sent']+=' '+str(line[0])
      validation_data[-1]['label']+=' '+str(line[-1])   

In [0]:
df=pd.DataFrame(training_data)
df_valid=pd.DataFrame(validation_data)

In [0]:
df['label'].replace('',np.nan,inplace=True)
df_valid['label'].replace('',np.nan,inplace=True)

In [11]:
print(df.isna().sum())
print(df_valid.isna().sum())

sent     0
label    2
dtype: int64
sent     0
label    2
dtype: int64


In [0]:
df.dropna(inplace=True)

In [0]:
def count_words(count_dict, text):
    '''Count the number of occurrences of each word in a set of text'''
    for sentence in text:
        for word in sentence.split():
            if word not in count_dict:
                count_dict[word] = 1
            else:
                count_dict[word] += 1

In [0]:
word_counts = {}

count_words(word_counts, df['sent'].values)

In [15]:
missing_words = 0
threshold = 20

for word, count in word_counts.items():
    if count > threshold:
        if word not in embeddings_index:
            missing_words += 1
            
missing_ratio = round(missing_words/len(word_counts),4)*100
            
print("Number of words missing from CN:", missing_words)
print("Percent of words that are missing from vocabulary: {}%".format(missing_ratio))

Number of words missing from CN: 434
Percent of words that are missing from vocabulary: 1.8399999999999999%


In [16]:
vocab_to_int = {} 

value = 0
for word, count in word_counts.items():
    if count >= threshold or word in embeddings_index:
        vocab_to_int[word] = value
        value += 1

# Special tokens that will be added to our vocab
codes = ["<UNK>","</s>","<s>"]   

# Add codes to vocab
for code in codes:
    vocab_to_int[code] = len(vocab_to_int)

int_to_vocab = {}
for word, value in vocab_to_int.items():
    int_to_vocab[value] = word

usage_ratio = round(len(vocab_to_int) / len(word_counts),4)*100

print("Total number of unique words:", len(word_counts))
print("Number of words we will use:", len(vocab_to_int))
print("Percent of words we will use: {}%".format(usage_ratio))    

Total number of unique words: 23624
Number of words we will use: 9473
Percent of words we will use: 40.1%


In [0]:
def convert_to_ints(sentence):
    sentence_int = []
    # print(sentence)
    # sentence_int.append(vocab_to_int['<s>'])
    for word in sentence.split():
        if word in vocab_to_int:
            sentence_int.append(vocab_to_int[word])
        else:
            sentence_int.append(vocab_to_int["<UNK>"])
        # sentence_int.append(vocab_to_int["</s>"])        
    return sentence_int

In [0]:
df['encoding']=df['sent'].apply(lambda x:convert_to_ints(x))
df_valid['encoding']=df_valid['sent'].apply(lambda x:convert_to_ints(x))

In [0]:
unique_labels=np.unique(''.join(list(df['label'])).split())
tag_to_index={}
index_to_tag={}
for i,lbl in enumerate(list(unique_labels)):
  tag_to_index[lbl]=i
  index_to_tag[i]=lbl

In [0]:
def argmax(vec):
    # return the argmax as a python int
    _, idx = torch.max(vec, 1)
    return idx.item()
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))    

In [0]:
class BiLSTM_CRF(nn.Module):
    def __init__(self,vocab_size,hidden_size,embedding_size,tag_to_index):
        super(BiLSTM_CRF, self).__init__()
        self.target_size=len(tag_to_index)
        self.embed = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size // 2,num_layers=2, bidirectional=True)
        self.hidden2tag = nn.Linear(hidden_size, self.target_size)
        self.transitions = nn.Parameter(torch.randn(self.target_size, self.target_size))
        self.start_transition=nn.Parameter(torch.randn(1, self.target_size))
        self.end_transition=nn.Parameter(torch.randn(1, self.target_size))
        # self.hidden=self.init_hidden(hidden_size)
    def _get_lstm_features(self,sentence):
        embed=self.embed(sentence)
        output,hidden=self.lstm(embed.view(len(sentence),1,50))
        output=self.hidden2tag(output)
        return output
    def _forward(self,features):   
        temp= self.start_transition.data+features[0,0,:]
        labels=[]
        lbl=torch.argmax(temp,1)
        # score=torch.max(veterbi_var)
        labels.append(lbl)
        # labels.append(argmax(veterbi_var))
        # score=torch.max(veterbi_var)        
        for indx in range(1,features.shape[0]):
            veterbi_var=[]
            for tag in range(features.shape[2]):                             
                veterbi_var.append(log_sum_exp(temp+self.transitions[:,tag].view(1,9)+features[indx,0,tag].expand(1,9)).view(1,1))
            # print(veterbi_var)                   
            temp=torch.cat(veterbi_var,dim=1)
            # print(temp.shape)
            lbl=torch.argmax(temp,1)
                # score=torch.max(veterbi_var)
            labels.append(lbl)
            # score=torch.max(veterbi_var)   
        temp+=self.end_transition.view(1,-1)
        # print(torch.min(temp))
        score=log_sum_exp(temp)
        return score,labels
    def _score_sentence(self,feats,tags):
        score=feats[0,0,tags[0]]+self.start_transition[0,tags[0]]
        for indx in range(1,feats.shape[0]):
            # print(feats.shape,tags[indx],indx,tags.shape)
            score+=feats[indx,0,tags[indx]]+self.transitions[tags[indx-1],tags[indx]]
        score+=self.end_transition[0,tags[-1]] 
        return score
    def _neg_log_likelihood(self,sentence,targets):
        feats = self._get_lstm_features(sentence)
        forward_score,_=self._forward(feats)
        gold_score = self._score_sentence(feats, targets) 
        # print('forward score {}'.format(forward_score),'gold score {}'.format(gold_score)) 
        loss=forward_score - gold_score     
        return loss
    def _viterbi(self,features):
        # print(self.start_transition.data,features)
        temp= self.start_transition.data.view(1,-1)+features[0,0,:].view(1,-1)
        labels=[]
        # print(temp)
        temp,lbl=torch.max(temp,1)
        # print(temp,lbl)
        labels.append(lbl)     
        for indx in range(1,features.shape[0]):
            veterbi_var=[]
            for tag in range(features.shape[2]):                            
                veterbi_var.append(temp.view(1,1)+self.transitions[lbl[-1],tag].view(1)+features[indx,0,tag].view(1))
            # print(veterbi_var)                   
            temp=torch.cat(veterbi_var,dim=1)
            # lbl=torch.argmax(temp,1)
            temp,lbl=torch.max(temp,1)
            labels.append(lbl)  
        temp=temp.view(1,1)+self.end_transition[0,lbl].view(1,1)
        return temp,labels  
    def _predict(self,sentence):
        feats = self._get_lstm_features(sentence)
        score,labels=self._viterbi(feats)
        return labels

In [26]:
import pickle
from torch.nn.utils import clip_grad_norm_
vocab_size=len(vocab_to_int)
hidden_size=4
embedding_size=50
df_valid.dropna(inplace=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model=BiLSTM_CRF(len(vocab_to_int),hidden_size,embedding_size,tag_to_index).to(device)
# if(os.path.exists('/content/drive/My Drive/model.pt')):
#   model.load_state_dict(torch.load('/content/drive/My Drive/model.pt'))
optimizer = optim.SGD(model.parameters(), lr=1e-1, weight_decay=1e-4)
losses=[]
# if(os.path.exists('/content/drive/My Drive/losses.pkl')):
#   with open('/content/drive/My Drive/losses.pkl','rb') as f:
#     losses=pickle.load(f)
for epoch in tqdm(range(500)): 
    i=0
    for index,row in df.iterrows():
        model.zero_grad()
        sent=torch.tensor(row['encoding'],dtype=torch.long).to(device)
        targets = torch.tensor([tag_to_index[t] for t in row['label'].split()], dtype=torch.int32).to(device)
        loss = model._neg_log_likelihood(sent, targets)
        loss.backward()
        # print(loss.data.item())
        clip_grad_norm_(model.parameters(), 1)
        optimizer.step() 
        i+=1
        if(i==10):
          break
    print(loss.data.item())  
        # losses.append(loss)
    # val_loss=0    
    # for index,row in df_valid.iterrows():
    #     model.zero_grad()
    #     sent=torch.tensor(row['encoding'],dtype=torch.long).to(device)
    #     targets = torch.tensor([tag_to_index[t] for t in row['label'].split()], dtype=torch.int32).to(device)
    #     val_loss+= model._neg_log_likelihood(sent, targets).item()    
    # val_loss=val_loss/len(df_valid)
    # print('training loss:{}'.format(loss),'validation loss:{}'.format(val_loss))
    # if(len(losses)>0 and val_loss<min(losses)):
    #   torch.save(model.state_dict(), '/content/drive/My Drive/model.pt')            
    # losses.append(val_loss)


  0%|          | 0/500 [00:00<?, ?it/s][A
  0%|          | 1/500 [00:01<09:52,  1.19s/it][A

71.82562255859375



  0%|          | 2/500 [00:02<09:52,  1.19s/it][A

33.57405090332031



  1%|          | 3/500 [00:03<09:52,  1.19s/it][A

19.319229125976562



  1%|          | 4/500 [00:04<09:51,  1.19s/it][A

17.236106872558594



  1%|          | 5/500 [00:05<09:50,  1.19s/it][A

16.41014862060547



  1%|          | 6/500 [00:07<09:47,  1.19s/it][A

15.756370544433594



  1%|▏         | 7/500 [00:08<09:44,  1.19s/it][A

15.200782775878906



  2%|▏         | 8/500 [00:09<09:41,  1.18s/it][A

14.724571228027344



  2%|▏         | 9/500 [00:10<09:41,  1.18s/it][A

14.314743041992188



  2%|▏         | 10/500 [00:11<09:43,  1.19s/it][A

13.957733154296875



  2%|▏         | 11/500 [00:13<09:40,  1.19s/it][A

13.639045715332031



  2%|▏         | 12/500 [00:14<09:39,  1.19s/it][A

13.3487548828125



  3%|▎         | 13/500 [00:15<09:42,  1.20s/it][A

13.077713012695312



  3%|▎         | 14/500 [00:16<09:42,  1.20s/it][A

12.816200256347656



  3%|▎         | 15/500 [00:17<09:39,  1.20s/it][A

12.554840087890625



  3%|▎         | 16/500 [00:19<09:37,  1.19s/it][A

12.284225463867188



  3%|▎         | 17/500 [00:20<09:35,  1.19s/it][A

11.995559692382812



  4%|▎         | 18/500 [00:21<09:40,  1.20s/it][A

11.681243896484375



  4%|▍         | 19/500 [00:22<09:43,  1.21s/it][A

11.337486267089844



  4%|▍         | 20/500 [00:23<09:40,  1.21s/it][A

10.965080261230469



  4%|▍         | 21/500 [00:25<09:39,  1.21s/it][A

10.570449829101562



  4%|▍         | 22/500 [00:26<09:36,  1.21s/it][A

10.160171508789062



  5%|▍         | 23/500 [00:27<09:38,  1.21s/it][A

9.739425659179688



  5%|▍         | 24/500 [00:28<09:36,  1.21s/it][A

9.311660766601562



  5%|▌         | 25/500 [00:29<09:33,  1.21s/it][A

8.8765869140625



  5%|▌         | 26/500 [00:31<09:30,  1.20s/it][A

8.434310913085938



  5%|▌         | 27/500 [00:32<09:28,  1.20s/it][A

7.992218017578125



  6%|▌         | 28/500 [00:33<09:27,  1.20s/it][A

7.560028076171875



  6%|▌         | 29/500 [00:34<09:24,  1.20s/it][A

7.150604248046875



  6%|▌         | 30/500 [00:35<09:24,  1.20s/it][A

6.7721405029296875



  6%|▌         | 31/500 [00:37<09:20,  1.20s/it][A

6.4216156005859375



  6%|▋         | 32/500 [00:38<09:19,  1.20s/it][A

6.0907440185546875



  7%|▋         | 33/500 [00:39<09:17,  1.19s/it][A

5.775054931640625



  7%|▋         | 34/500 [00:40<09:21,  1.20s/it][A

5.4776458740234375



  7%|▋         | 35/500 [00:41<09:21,  1.21s/it][A

5.200469970703125



  7%|▋         | 36/500 [00:43<09:22,  1.21s/it][A

4.9509429931640625



  7%|▋         | 37/500 [00:44<09:18,  1.21s/it][A

4.7282257080078125



  8%|▊         | 38/500 [00:45<09:20,  1.21s/it][A

4.5217742919921875



  8%|▊         | 39/500 [00:46<09:18,  1.21s/it][A

4.330047607421875



  8%|▊         | 40/500 [00:48<09:17,  1.21s/it][A

4.151641845703125



  8%|▊         | 41/500 [00:49<09:13,  1.21s/it][A

3.9855499267578125



  8%|▊         | 42/500 [00:50<09:10,  1.20s/it][A

3.83392333984375



  9%|▊         | 43/500 [00:51<09:06,  1.20s/it][A

3.6945037841796875



  9%|▉         | 44/500 [00:52<09:11,  1.21s/it][A

3.5709075927734375



  9%|▉         | 45/500 [00:54<09:06,  1.20s/it][A

3.4357147216796875



  9%|▉         | 46/500 [00:55<09:06,  1.20s/it][A

3.3234100341796875



  9%|▉         | 47/500 [00:56<09:01,  1.19s/it][A

3.129241943359375



 10%|▉         | 48/500 [00:57<08:58,  1.19s/it][A

3.207763671875



 10%|▉         | 49/500 [00:58<08:58,  1.19s/it][A

2.97979736328125



 10%|█         | 50/500 [00:59<08:54,  1.19s/it][A

3.0217132568359375



 10%|█         | 51/500 [01:01<08:50,  1.18s/it][A

2.831878662109375



 10%|█         | 52/500 [01:02<08:49,  1.18s/it][A

2.8529052734375



 11%|█         | 53/500 [01:03<08:48,  1.18s/it][A

2.702972412109375



 11%|█         | 54/500 [01:04<08:45,  1.18s/it][A

2.6878662109375



 11%|█         | 55/500 [01:05<08:46,  1.18s/it][A

2.56951904296875



 11%|█         | 56/500 [01:07<08:43,  1.18s/it][A

2.538726806640625



 11%|█▏        | 57/500 [01:08<08:42,  1.18s/it][A

2.46160888671875



 12%|█▏        | 58/500 [01:09<08:40,  1.18s/it][A

2.552947998046875



 12%|█▏        | 59/500 [01:10<08:40,  1.18s/it][A

2.423980712890625



 12%|█▏        | 60/500 [01:11<08:43,  1.19s/it][A

2.359130859375



 12%|█▏        | 61/500 [01:12<08:42,  1.19s/it][A

2.331695556640625



 12%|█▏        | 62/500 [01:14<08:40,  1.19s/it][A

2.381195068359375



 13%|█▎        | 63/500 [01:15<08:42,  1.20s/it][A

2.447967529296875



 13%|█▎        | 64/500 [01:16<08:41,  1.20s/it][A

2.336273193359375



 13%|█▎        | 65/500 [01:17<08:40,  1.20s/it][A

2.254364013671875



 13%|█▎        | 66/500 [01:18<08:37,  1.19s/it][A

2.246917724609375



 13%|█▎        | 67/500 [01:20<08:36,  1.19s/it][A

2.399017333984375



 14%|█▎        | 68/500 [01:21<08:37,  1.20s/it][A

2.37139892578125



 14%|█▍        | 69/500 [01:22<08:41,  1.21s/it][A

2.31085205078125



 14%|█▍        | 70/500 [01:23<08:38,  1.21s/it][A

2.290435791015625



 14%|█▍        | 71/500 [01:25<08:39,  1.21s/it][A

2.2945556640625



 14%|█▍        | 72/500 [01:26<08:36,  1.21s/it][A

2.342987060546875



 15%|█▍        | 73/500 [01:27<08:33,  1.20s/it][A

2.304901123046875



 15%|█▍        | 74/500 [01:28<08:30,  1.20s/it][A

2.3603515625



 15%|█▌        | 75/500 [01:29<08:27,  1.19s/it][A

2.347381591796875



 15%|█▌        | 76/500 [01:30<08:25,  1.19s/it][A

2.527740478515625



 15%|█▌        | 77/500 [01:32<08:23,  1.19s/it][A

2.434600830078125



 16%|█▌        | 78/500 [01:33<08:25,  1.20s/it][A

2.70263671875



 16%|█▌        | 79/500 [01:34<08:23,  1.20s/it][A

2.496734619140625



 16%|█▌        | 80/500 [01:35<08:23,  1.20s/it][A

2.787841796875



 16%|█▌        | 81/500 [01:36<08:18,  1.19s/it][A

2.570556640625



 16%|█▋        | 82/500 [01:38<08:16,  1.19s/it][A

2.8609619140625



 17%|█▋        | 83/500 [01:39<08:14,  1.19s/it][A

2.62774658203125



 17%|█▋        | 84/500 [01:40<08:14,  1.19s/it][A

2.8740234375



 17%|█▋        | 85/500 [01:41<08:15,  1.19s/it][A

2.684326171875



 17%|█▋        | 86/500 [01:42<08:15,  1.20s/it][A

2.88958740234375



 17%|█▋        | 87/500 [01:44<08:14,  1.20s/it][A

2.73760986328125



 18%|█▊        | 88/500 [01:45<08:14,  1.20s/it][A

2.8876953125



 18%|█▊        | 89/500 [01:46<08:14,  1.20s/it][A

2.7874755859375



 18%|█▊        | 90/500 [01:47<08:12,  1.20s/it][A

2.88006591796875



 18%|█▊        | 91/500 [01:48<08:09,  1.20s/it][A

2.828125



 18%|█▊        | 92/500 [01:50<08:09,  1.20s/it][A

2.8653564453125



 19%|█▊        | 93/500 [01:51<08:06,  1.20s/it][A

2.856781005859375



 19%|█▉        | 94/500 [01:52<08:11,  1.21s/it][A

2.845794677734375



 19%|█▉        | 95/500 [01:53<08:13,  1.22s/it][A

2.8695068359375



 19%|█▉        | 96/500 [01:54<08:12,  1.22s/it][A

2.82049560546875



 19%|█▉        | 97/500 [01:56<08:08,  1.21s/it][A

2.8597412109375



 20%|█▉        | 98/500 [01:57<08:04,  1.21s/it][A

2.788330078125



 20%|█▉        | 99/500 [01:58<08:02,  1.20s/it][A

2.8238525390625



 20%|██        | 100/500 [01:59<08:00,  1.20s/it][A

2.745880126953125



 20%|██        | 101/500 [02:00<08:00,  1.20s/it][A

2.76458740234375



 20%|██        | 102/500 [02:02<07:59,  1.21s/it][A

2.68682861328125



 21%|██        | 103/500 [02:03<08:04,  1.22s/it][A

2.68255615234375



 21%|██        | 104/500 [02:04<08:04,  1.22s/it][A

2.608062744140625



 21%|██        | 105/500 [02:05<08:04,  1.23s/it][A

2.578369140625



 21%|██        | 106/500 [02:07<08:01,  1.22s/it][A

2.514617919921875



 21%|██▏       | 107/500 [02:08<08:01,  1.23s/it][A

2.463043212890625



 22%|██▏       | 108/500 [02:09<08:00,  1.23s/it][A

2.410186767578125



 22%|██▏       | 109/500 [02:10<07:59,  1.23s/it][A

2.326568603515625



 22%|██▏       | 110/500 [02:12<07:58,  1.23s/it][A

2.29248046875



 22%|██▏       | 111/500 [02:13<07:56,  1.23s/it][A

2.1688232421875



 22%|██▏       | 112/500 [02:14<07:52,  1.22s/it][A

2.15264892578125



 23%|██▎       | 113/500 [02:15<07:49,  1.21s/it][A

2.007904052734375



 23%|██▎       | 114/500 [02:16<07:46,  1.21s/it][A

1.99639892578125



 23%|██▎       | 115/500 [02:18<07:43,  1.20s/it][A

1.853302001953125



 23%|██▎       | 116/500 [02:19<07:38,  1.19s/it][A

1.82891845703125



 23%|██▎       | 117/500 [02:20<07:36,  1.19s/it][A

1.7010498046875



 24%|██▎       | 118/500 [02:21<07:36,  1.19s/it][A

1.714385986328125



 24%|██▍       | 119/500 [02:22<07:38,  1.20s/it][A

1.60504150390625



 24%|██▍       | 120/500 [02:24<07:35,  1.20s/it][A

1.597930908203125



 24%|██▍       | 121/500 [02:25<07:36,  1.21s/it][A

1.514739990234375



 24%|██▍       | 122/500 [02:26<07:33,  1.20s/it][A

1.529266357421875



 25%|██▍       | 123/500 [02:27<07:29,  1.19s/it][A

1.43072509765625



 25%|██▍       | 124/500 [02:28<07:27,  1.19s/it][A

1.53826904296875



 25%|██▌       | 125/500 [02:29<07:24,  1.19s/it][A

1.487396240234375



 25%|██▌       | 126/500 [02:31<07:23,  1.19s/it][A

1.502532958984375



 25%|██▌       | 127/500 [02:32<07:22,  1.19s/it][A

1.398468017578125



 26%|██▌       | 128/500 [02:33<07:23,  1.19s/it][A

1.55322265625



 26%|██▌       | 129/500 [02:34<07:21,  1.19s/it][A

1.4500732421875



 26%|██▌       | 130/500 [02:35<07:23,  1.20s/it][A

1.4215087890625



 26%|██▌       | 131/500 [02:37<07:20,  1.19s/it][A

2.02471923828125



 26%|██▋       | 132/500 [02:38<07:16,  1.19s/it][A

1.4912109375



 27%|██▋       | 133/500 [02:39<07:14,  1.18s/it][A

1.389862060546875



 27%|██▋       | 134/500 [02:40<07:12,  1.18s/it][A

1.502349853515625



 27%|██▋       | 135/500 [02:41<07:12,  1.18s/it][A

1.34173583984375



 27%|██▋       | 136/500 [02:43<07:10,  1.18s/it][A

1.362396240234375



 27%|██▋       | 137/500 [02:44<07:08,  1.18s/it][A

1.29327392578125



 28%|██▊       | 138/500 [02:45<07:09,  1.19s/it][A

1.3436279296875



 28%|██▊       | 139/500 [02:46<07:07,  1.18s/it][A

1.429840087890625



 28%|██▊       | 140/500 [02:47<07:06,  1.18s/it][A

1.52203369140625



 28%|██▊       | 141/500 [02:48<07:04,  1.18s/it][A

1.4913330078125



 28%|██▊       | 142/500 [02:50<07:02,  1.18s/it][A

1.413116455078125



 29%|██▊       | 143/500 [02:51<07:02,  1.18s/it][A

1.352508544921875



 29%|██▉       | 144/500 [02:52<07:05,  1.20s/it][A

1.438018798828125



 29%|██▉       | 145/500 [02:53<07:04,  1.20s/it][A

1.667266845703125



 29%|██▉       | 146/500 [02:54<07:03,  1.20s/it][A

1.449920654296875



 29%|██▉       | 147/500 [02:56<07:00,  1.19s/it][A

1.602752685546875



 30%|██▉       | 148/500 [02:57<06:56,  1.18s/it][A

1.5833740234375



 30%|██▉       | 149/500 [02:58<06:55,  1.18s/it][A

1.633270263671875



 30%|███       | 150/500 [02:59<06:56,  1.19s/it][A

1.562103271484375



 30%|███       | 151/500 [03:00<06:54,  1.19s/it][A

1.3975830078125



 30%|███       | 152/500 [03:02<06:51,  1.18s/it][A

1.40179443359375



 31%|███       | 153/500 [03:03<06:50,  1.18s/it][A

1.384033203125



 31%|███       | 154/500 [03:04<06:49,  1.18s/it][A

1.359405517578125



 31%|███       | 155/500 [03:05<06:49,  1.19s/it][A

1.345428466796875



 31%|███       | 156/500 [03:06<06:46,  1.18s/it][A

1.339324951171875



 31%|███▏      | 157/500 [03:07<06:44,  1.18s/it][A

1.29229736328125



 32%|███▏      | 158/500 [03:09<06:42,  1.18s/it][A

1.2587890625



 32%|███▏      | 159/500 [03:10<06:43,  1.18s/it][A

1.25457763671875



 32%|███▏      | 160/500 [03:11<06:42,  1.18s/it][A

1.2008056640625



 32%|███▏      | 161/500 [03:12<06:42,  1.19s/it][A

1.17742919921875



 32%|███▏      | 162/500 [03:13<06:40,  1.18s/it][A

1.16253662109375



 33%|███▎      | 163/500 [03:15<06:38,  1.18s/it][A

1.12103271484375



 33%|███▎      | 164/500 [03:16<06:38,  1.19s/it][A

1.106109619140625



 33%|███▎      | 165/500 [03:17<06:38,  1.19s/it][A

1.097320556640625



 33%|███▎      | 166/500 [03:18<06:36,  1.19s/it][A

1.093963623046875



 33%|███▎      | 167/500 [03:19<06:34,  1.19s/it][A

1.040374755859375



 34%|███▎      | 168/500 [03:20<06:37,  1.20s/it][A

1.072052001953125



 34%|███▍      | 169/500 [03:22<06:38,  1.20s/it][A

0.9925537109375



 34%|███▍      | 170/500 [03:23<06:42,  1.22s/it][A

1.0013427734375



 34%|███▍      | 171/500 [03:24<06:41,  1.22s/it][A

0.982818603515625



 34%|███▍      | 172/500 [03:25<06:41,  1.23s/it][A

0.946075439453125



 35%|███▍      | 173/500 [03:27<06:38,  1.22s/it][A

0.870147705078125



 35%|███▍      | 174/500 [03:28<06:35,  1.21s/it][A

0.80950927734375



 35%|███▌      | 175/500 [03:29<06:32,  1.21s/it][A

0.743133544921875



 35%|███▌      | 176/500 [03:30<06:31,  1.21s/it][A

0.713409423828125



 35%|███▌      | 177/500 [03:31<06:30,  1.21s/it][A

0.646484375



 36%|███▌      | 178/500 [03:33<06:30,  1.21s/it][A

0.65338134765625



 36%|███▌      | 179/500 [03:34<06:26,  1.20s/it][A

0.6646728515625



 36%|███▌      | 180/500 [03:35<06:22,  1.20s/it][A

0.69219970703125



 36%|███▌      | 181/500 [03:36<06:20,  1.19s/it][A

0.66790771484375



 36%|███▋      | 182/500 [03:37<06:16,  1.19s/it][A

0.65673828125



 37%|███▋      | 183/500 [03:39<06:14,  1.18s/it][A

0.64093017578125



 37%|███▋      | 184/500 [03:40<06:13,  1.18s/it][A

0.62591552734375



 37%|███▋      | 185/500 [03:41<06:13,  1.18s/it][A

0.6004638671875



 37%|███▋      | 186/500 [03:42<06:14,  1.19s/it][A

0.591064453125



 37%|███▋      | 187/500 [03:43<06:14,  1.20s/it][A

0.5601806640625



 38%|███▊      | 188/500 [03:45<06:11,  1.19s/it][A

0.55548095703125



 38%|███▊      | 189/500 [03:46<06:11,  1.19s/it][A

0.52294921875



 38%|███▊      | 190/500 [03:47<06:10,  1.20s/it][A

0.5224609375



 38%|███▊      | 191/500 [03:48<06:08,  1.19s/it][A

0.49017333984375



 38%|███▊      | 192/500 [03:49<06:06,  1.19s/it][A

0.490966796875



 39%|███▊      | 193/500 [03:50<06:05,  1.19s/it][A

0.45745849609375



 39%|███▉      | 194/500 [03:52<06:06,  1.20s/it][A

0.4656982421875



 39%|███▉      | 195/500 [03:53<06:07,  1.21s/it][A

0.43994140625



 39%|███▉      | 196/500 [03:54<06:04,  1.20s/it][A

0.44140625



 39%|███▉      | 197/500 [03:55<06:02,  1.20s/it][A

0.41229248046875



 40%|███▉      | 198/500 [03:56<05:58,  1.19s/it][A

0.42816162109375



 40%|███▉      | 199/500 [03:58<05:56,  1.18s/it][A

0.40838623046875



 40%|████      | 200/500 [03:59<05:55,  1.19s/it][A

0.414306640625



 40%|████      | 201/500 [04:00<05:53,  1.18s/it][A

0.39208984375



 40%|████      | 202/500 [04:01<05:52,  1.18s/it][A

0.4114990234375



 41%|████      | 203/500 [04:02<05:55,  1.20s/it][A

0.3883056640625



 41%|████      | 204/500 [04:04<05:53,  1.19s/it][A

0.40069580078125



 41%|████      | 205/500 [04:05<05:51,  1.19s/it][A

0.3756103515625



 41%|████      | 206/500 [04:06<05:50,  1.19s/it][A

0.38580322265625



 41%|████▏     | 207/500 [04:07<05:48,  1.19s/it][A

0.36419677734375



 42%|████▏     | 208/500 [04:08<05:45,  1.18s/it][A

0.385986328125



 42%|████▏     | 209/500 [04:10<05:44,  1.18s/it][A

0.36053466796875



 42%|████▏     | 210/500 [04:11<05:42,  1.18s/it][A

0.366455078125



 42%|████▏     | 211/500 [04:12<05:43,  1.19s/it][A

0.3460693359375



 42%|████▏     | 212/500 [04:13<05:42,  1.19s/it][A

0.37982177734375



 43%|████▎     | 213/500 [04:14<05:40,  1.19s/it][A

0.34814453125



 43%|████▎     | 214/500 [04:15<05:40,  1.19s/it][A

0.35430908203125



 43%|████▎     | 215/500 [04:17<05:38,  1.19s/it][A

0.33331298828125



 43%|████▎     | 216/500 [04:18<05:37,  1.19s/it][A

0.3704833984375



 43%|████▎     | 217/500 [04:19<05:35,  1.19s/it][A

0.33447265625



 44%|████▎     | 218/500 [04:20<05:36,  1.19s/it][A

0.34857177734375



 44%|████▍     | 219/500 [04:21<05:34,  1.19s/it][A

0.3231201171875



 44%|████▍     | 220/500 [04:23<05:36,  1.20s/it][A

0.35693359375



 44%|████▍     | 221/500 [04:24<05:34,  1.20s/it][A

0.3211669921875



 44%|████▍     | 222/500 [04:25<05:34,  1.20s/it][A

0.34619140625



 45%|████▍     | 223/500 [04:26<05:31,  1.20s/it][A

0.3143310546875



 45%|████▍     | 224/500 [04:27<05:28,  1.19s/it][A

0.3446044921875



 45%|████▌     | 225/500 [04:29<05:26,  1.19s/it][A

0.30987548828125



 45%|████▌     | 226/500 [04:30<05:24,  1.19s/it][A

0.34661865234375



 45%|████▌     | 227/500 [04:31<05:23,  1.19s/it][A

0.3060302734375



 46%|████▌     | 228/500 [04:32<05:23,  1.19s/it][A

0.3375244140625



 46%|████▌     | 229/500 [04:33<05:22,  1.19s/it][A

0.30084228515625



 46%|████▌     | 230/500 [04:35<05:20,  1.19s/it][A

0.3443603515625



 46%|████▌     | 231/500 [04:36<05:19,  1.19s/it][A

0.297607421875



 46%|████▋     | 232/500 [04:37<05:17,  1.18s/it][A

0.3365478515625



 47%|████▋     | 233/500 [04:38<05:16,  1.18s/it][A

0.29296875



 47%|████▋     | 234/500 [04:39<05:14,  1.18s/it][A

0.341552734375



 47%|████▋     | 235/500 [04:40<05:13,  1.18s/it][A

0.2891845703125



 47%|████▋     | 236/500 [04:42<05:15,  1.20s/it][A

0.34161376953125



 47%|████▋     | 237/500 [04:43<05:16,  1.20s/it][A

0.2852783203125



 48%|████▊     | 238/500 [04:44<05:14,  1.20s/it][A

0.345458984375



 48%|████▊     | 239/500 [04:45<05:15,  1.21s/it][A

0.2825927734375



 48%|████▊     | 240/500 [04:47<05:13,  1.21s/it][A

0.34942626953125



 48%|████▊     | 241/500 [04:48<05:11,  1.20s/it][A

0.28106689453125



 48%|████▊     | 242/500 [04:49<05:08,  1.19s/it][A

0.3438720703125



 49%|████▊     | 243/500 [04:50<05:06,  1.19s/it][A

0.27813720703125



 49%|████▉     | 244/500 [04:51<05:04,  1.19s/it][A

0.34381103515625



 49%|████▉     | 245/500 [04:52<05:06,  1.20s/it][A

0.27569580078125



 49%|████▉     | 246/500 [04:54<05:03,  1.19s/it][A

0.33880615234375



 49%|████▉     | 247/500 [04:55<05:02,  1.19s/it][A

0.272216796875



 50%|████▉     | 248/500 [04:56<04:59,  1.19s/it][A

0.33489990234375



 50%|████▉     | 249/500 [04:57<04:57,  1.19s/it][A

0.2689208984375



 50%|█████     | 250/500 [04:58<04:55,  1.18s/it][A

0.33538818359375



 50%|█████     | 251/500 [05:00<04:53,  1.18s/it][A

0.26531982421875



 50%|█████     | 252/500 [05:01<04:51,  1.18s/it][A

0.32574462890625



 51%|█████     | 253/500 [05:02<04:51,  1.18s/it][A

0.261474609375



 51%|█████     | 254/500 [05:03<04:49,  1.18s/it][A

0.3326416015625



 51%|█████     | 255/500 [05:04<04:48,  1.18s/it][A

0.25787353515625



 51%|█████     | 256/500 [05:05<04:48,  1.18s/it][A

0.3121337890625



 51%|█████▏    | 257/500 [05:07<04:46,  1.18s/it][A

0.25390625


KeyboardInterrupt: ignored

In [0]:
# for index,row in df_valid.iterrows():
#   try:
#     df_valid.iloc[index]['label'].split()
#   except:
#     print(index,df_valid.iloc[index]['label'])  

# print(3465,df_valid.iloc[index])      

3465 sent           
label       NaN
encoding     []
Name: 3466, dtype: object


# Testing

In [0]:
def process_input(file):
    valid_data=[]
    with open(file,'r') as f:
      line=f.readline()
      while(line):
        line=f.readline()
        if(line.strip()==''):
          valid_data.append({'sent':'','label':''})
        else: 
          line=line.split()  
          valid_data[-1]['sent']+=' '+str(line[0])
          valid_data[-1]['label']+=' '+str(line[-1])
    df=pd.DataFrame(valid_data)
    df['label'].replace('',np.nan,inplace=True)
    df.dropna(inplace=True)
    df['encoding']=df['sent'].apply(lambda x:convert_to_ints(x))
    target=[]
    return df    

In [27]:
vocab_size=len(vocab_to_int)
hidden_size=4
embedding_size=50
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model=BiLSTM_CRF(len(vocab_to_int),hidden_size,embedding_size,tag_to_index)
# model.load_state_dict(torch.load('/content/drive/My Drive/model.pt'))
# model.to(device)
valid_df=process_input('/content/drive/My Drive/train.txt')
i=0
for index,row in valid_df.iterrows():
    # print(row['sent'])
    prediction=model._predict(torch.tensor(row['encoding']).to(device))
    print(row['sent'])
    print(row['label'])
    # print(row['encoding'])
    print([index_to_tag[int(t)] for t in prediction])
    # print(row['label'])
    i+=1
    if(i%10==0):
      break

 EU rejects German call to boycott British lamb .
 B-ORG O B-MISC O O O B-MISC O O
['O', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
 Peter Blackburn
 B-PER I-PER
['O', 'B-PER']
 BRUSSELS 1996-08-22
 B-LOC O
['O', 'O']
 The European Commission said on Thursday it disagreed with German advice to consumers to shun British lamb until scientists determine whether mad cow disease can be transmitted to sheep .
 O B-ORG I-ORG O O O O O O B-MISC O O O O O B-MISC O O O O O O O O O O O O O O
['O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
 Germany 's representative to the European Union 's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer .
 B-LOC O O O O B-ORG I-ORG O O O B-PER I-PER O O O O O O O O O O O B-LOC O O O O O O O
['O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 

In [28]:
index_to_tag

{0: 'B-LOC',
 1: 'B-MISC',
 2: 'B-ORG',
 3: 'B-PER',
 4: 'I-LOC',
 5: 'I-MISC',
 6: 'I-ORG',
 7: 'I-PER',
 8: 'O'}