In [1]:
!pip install conllu

Collecting conllu
  Downloading conllu-4.5.3-py2.py3-none-any.whl (16 kB)
Installing collected packages: conllu
Successfully installed conllu-4.5.3


In [2]:
from conllu import parse,TokenList,Token
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torchtext.vocab import build_vocab_from_iterator,Vocab
from torch.utils.data import DataLoader
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score,recall_score,precision_score,accuracy_score,confusion_matrix

In [3]:
# Special vocabulary entries: UNKNOWN_TOKEN replaces low-frequency training words and
# is the vocabulary's default index; PAD_TOKEN fills sentences up to the batch length.
UNKNOWN_TOKEN = "<unk>"
PAD_TOKEN = "<pad>"
# Model sizes passed to LSTMModel: embedding width, LSTM hidden width, and the
# number of stacked LSTM layers.
EMBEDDING_SIZE=256
HIDDEN_DIM=256
NUM_STACKS=2
# Number of sentences per DataLoader batch.
BATCH_SIZE=256
# Learning rate handed to the Adam optimizer.
lrate=0.001
# Number of passes over the training set.
EPOCHS=20

In [4]:
def filter_sentences_by_tag(sentences, pos_tags, tag_to_exclude='SYM'):
    """Drop every sentence whose tag sequence contains `tag_to_exclude`.

    Args:
        sentences: sequence of token lists.
        pos_tags: sequence of tag lists, aligned with `sentences`.
        tag_to_exclude: tag whose presence disqualifies a sentence (default 'SYM').

    Returns:
        A `(filtered_sentences, filtered_pos_tags)` pair of lists holding only
        the surviving sentences and their tags, in the original order.
    """
    kept_pairs = [
        (sentence, tags)
        for sentence, tags in zip(sentences, pos_tags)
        if tag_to_exclude not in tags
    ]
    filtered_sentences = [sentence for sentence, _ in kept_pairs]
    filtered_pos_tags = [tags for _, tags in kept_pairs]
    return filtered_sentences, filtered_pos_tags
def extract_tokens_and_tags(sentences):
    """Split parsed sentences into parallel word-form and UPOS-tag sequences.

    Args:
        sentences: iterable of sentences, each an iterable of tokens that can be
            indexed with "form" and "upos".

    Returns:
        A `(token_sequences, tag_sequences)` pair: one list of forms and one list
        of tags per input sentence, in the same order.
    """
    token_sequences, tag_sequences = [], []
    for sentence in sentences:
        # Read both fields in a single pass over the sentence.
        pairs = [(token["form"], token["upos"]) for token in sentence]
        token_sequences.append([form for form, _ in pairs])
        tag_sequences.append([upos for _, upos in pairs])
    return token_sequences, tag_sequences
def replace_low_frequency_words(sentences, threshold=3, unknown_token=None):
    """Replace rare words with an unknown-word marker.

    Word frequencies are counted over all of `sentences`; every word seen fewer
    than `threshold` times is replaced. The input lists are not modified.

    Args:
        sentences: list of token lists.
        threshold: minimum corpus frequency a word needs to be kept (default 3).
        unknown_token: replacement string; when None (the default) the
            module-level UNKNOWN_TOKEN is used, as before.

    Returns:
        A new list of token lists with rare words replaced.
    """
    if unknown_token is None:
        # Resolved at call time so the module constant stays the default.
        unknown_token = UNKNOWN_TOKEN
    word_counts = Counter(word for sentence in sentences for word in sentence)
    return [
        [word if word_counts[word] >= threshold else unknown_token for word in sentence]
        for sentence in sentences
    ]

In [5]:
class EntityDataset_LSTM(Dataset):
  """Sentences with per-token integer labels, served as index tensors.

  Label id 0 is used as the padding label by `collate`, so real label ids are
  expected to start at 1.
  """

  def __init__(self, sent, labs, vocabulary:Vocab|None=None):
    """Store the data and build or reuse a vocabulary.

    Args:
      sent: list of token lists.
      labs: list of integer label lists, aligned with `sent`.
      vocabulary: vocabulary to reuse; when None a new one is built from `sent`
        with UNKNOWN_TOKEN and PAD_TOKEN as specials, and unknown words map to
        UNKNOWN_TOKEN.
    """
    self.sentences = sent
    self.labels = labs
    if vocabulary is None:
      self.vocabulary = build_vocab_from_iterator(self.sentences, specials=[UNKNOWN_TOKEN, PAD_TOKEN])
      # Out-of-vocabulary lookups fall back to the unknown token's index.
      self.vocabulary.set_default_index(self.vocabulary[UNKNOWN_TOKEN])
    else:
      self.vocabulary = vocabulary

  def __len__(self) -> int:
    """Returns number of datapoints."""
    return len(self.sentences)

  def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (token-index tensor, label tensor) for the sentence at `index`."""
    token_ids = self.vocabulary.lookup_indices(self.sentences[index])
    return torch.tensor(token_ids), torch.tensor(self.labels[index])

  def collate(self, batch: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
    """Pad a list of datapoints to a common length and stack them batch-first.

    Sentences are padded with the PAD_TOKEN index, labels with 0.
    """
    sentences = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    padded_sentences = pad_sequence(sentences, batch_first=True, padding_value=self.vocabulary[PAD_TOKEN])
    # pad_sequence takes a plain number as padding_value; 0 is the padding label.
    padded_labels = pad_sequence(labels, batch_first=True, padding_value=0)
    return padded_sentences, padded_labels

In [6]:
def get_datasets_LSTM(data_dir='/content'):
  """Build the train/dev/test datasets from the en_atis-ud CoNLL-U files.

  The training split defines both the vocabulary (after low-frequency words are
  replaced by the unknown token) and the tag -> id mapping. Tag ids start at 1.
  Dev and test sentences containing the 'SYM' tag are dropped before their tags
  are mapped to ids.

  Args:
    data_dir: directory holding `en_atis-ud-{train,dev,test}.conllu`
      (default '/content', the location used originally).

  Returns:
    `(train_dataset, val_dataset, test_dataset, tag_to_id)`.
  """
  def read_split(split):
    # Read one split and return its (token sequences, tag sequences).
    # The encoding is explicit so the result does not depend on the platform default.
    with open(f'{data_dir}/en_atis-ud-{split}.conllu', encoding='utf-8') as f:
      return extract_tokens_and_tags(parse(f.read()))

  tok_seq,tag_seq=read_split('train')
  tok_seq=replace_low_frequency_words(tok_seq)
  unique_tags = set(tag for tags in tag_seq for tag in tags)
  # Sorted for a deterministic mapping; ids start at 1.
  tag_to_id = {tag: idx+1 for idx, tag in enumerate(sorted(unique_tags))}
  print(tag_to_id)
  new_tag_seq=[[tag_to_id[tag] for tag in tags] for tags in tag_seq]
  train_dataset=EntityDataset_LSTM(tok_seq,new_tag_seq)

  test_toks,test_tags=read_split('test')
  val_toks,val_tags=read_split('dev')
  # Apply the same SYM filter to both evaluation splits so neither can hit a
  # KeyError on a tag that is missing from tag_to_id.
  val_toks,val_tags=filter_sentences_by_tag(val_toks,val_tags)
  test_toks,test_tags=filter_sentences_by_tag(test_toks,test_tags)
  test_tags=[[tag_to_id[tag] for tag in tags] for tags in test_tags]
  val_tags=[[tag_to_id[tag] for tag in tags] for tags in val_tags]
  # Evaluation splits reuse the training vocabulary so token indices agree.
  val_dataset=EntityDataset_LSTM(val_toks,val_tags,vocabulary=train_dataset.vocabulary)
  test_dataset=EntityDataset_LSTM(test_toks,test_tags,vocabulary=train_dataset.vocabulary)
  return train_dataset,val_dataset,test_dataset,tag_to_id

In [7]:
# Build the three dataset splits and the tag -> integer id mapping
# (the mapping is also printed by the call).
train_dataset,val_dataset,test_dataset,tag_to_id=get_datasets_LSTM()

{'ADJ': 1, 'ADP': 2, 'ADV': 3, 'AUX': 4, 'CCONJ': 5, 'DET': 6, 'INTJ': 7, 'NOUN': 8, 'NUM': 9, 'PART': 10, 'PRON': 11, 'PROPN': 12, 'VERB': 13}


In [8]:
# One DataLoader per split; each uses its dataset's own collate method to pad
# the batch. Only the training loader shuffles.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=train_dataset.collate,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=val_dataset.collate,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=test_dataset.collate,
)

In [9]:
class LSTMModel(torch.nn.Module):
  """Embedding -> stacked LSTM -> linear layer producing per-token tag scores."""

  def __init__(self, embedding_dim: int, hidden_dim: int, vocabulary_size: int, tagset_size: int, stacks: int):
    """
    Args:
      embedding_dim: size of each token embedding.
      hidden_dim: size of the LSTM hidden state.
      vocabulary_size: number of rows in the embedding table.
      tagset_size: number of output scores per token.
      stacks: number of stacked LSTM layers.
    """
    super().__init__()
    self.embedding_module = torch.nn.Embedding(vocabulary_size, embedding_dim)
    # batch_first=True: inputs arrive as (batch, seq). Without it the LSTM would
    # treat the batch axis as time and recur across sentences instead of tokens.
    self.lstm = torch.nn.LSTM(embedding_dim, hidden_dim, stacks, batch_first=True)
    self.hidden_to_tag = torch.nn.Linear(hidden_dim, tagset_size)

  def forward(self, sentence: torch.Tensor):
    """Map a (batch, seq) tensor of token indices to (batch, seq, tagset_size) scores."""
    embeddings = self.embedding_module(sentence)
    lstm_out, _ = self.lstm(embeddings)
    return self.hidden_to_tag(lstm_out)

In [10]:
# Pick an accelerator that is actually present. Each backend name is chosen only
# after checking that same backend, with CPU as the fallback.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    # Older torch builds have no torch.backends.mps, hence the getattr guard.
    device = "mps"
else:
    device = "cpu"
device

'cpu'

In [11]:
# Train the tagger and track accuracy on the dev set after every epoch.
# One extra class: label id 0 is the padding label, real tag ids start at 1.
num_classes = len(tag_to_id) + 1
entity_predictor=LSTMModel(EMBEDDING_SIZE,HIDDEN_DIM,len(train_dataset.vocabulary),num_classes,NUM_STACKS)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(entity_predictor.parameters(),lr=lrate)
entity_predictor = entity_predictor.to(device)
dev_set_acc=[]
for epoch_num in range(EPOCHS):
  entity_predictor.train()
  for words, tags in train_dataloader:
    (words, tags) = (words.to(device), tags.to(device))
    pred = entity_predictor(words)  # (batch, seq, num_classes)
    # CrossEntropyLoss reads the class scores from dim 1, so flatten to
    # (batch*seq, num_classes) with integer targets of shape (batch*seq,).
    # Passing (batch, seq, classes) directly applies the softmax over the
    # sequence axis instead of over the classes.
    loss = loss_fn(pred.reshape(-1, num_classes), tags.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  entity_predictor.eval()
  predictions=[]
  true_vals=[]
  with torch.no_grad():
    for words, tags in val_dataloader:
      (words, tags) = (words.to(device), tags.to(device))
      pred = entity_predictor(words)
      pred_max_index = torch.argmax(pred, dim=2)
      # Collect whole batches; concatenated once after the loop.
      true_vals.append(tags.flatten().cpu())
      predictions.append(pred_max_index.flatten().cpu())
  predictions=torch.cat(predictions).numpy()
  true_vals=torch.cat(true_vals).numpy()
  # Score real tokens only; padding positions (label 0) would distort accuracy.
  real_tokens = true_vals != 0
  dev_set_acc.append(accuracy_score(true_vals[real_tokens],predictions[real_tokens]))
plt.bar(range(1,EPOCHS+1), dev_set_acc)
plt.xlabel('epoch')
plt.ylabel('Accuracy on dev set')
plt.title('epoch number vs dev set accuracy for LSTM POS Tagger')
plt.show()

In [12]:
# Evaluate the current model on the dev set.
entity_predictor.eval()
predictions=[]
true_vals=[]
with torch.no_grad():
  for words, tags in val_dataloader:
    (words, tags) = (words.to(device), tags.to(device))
    pred = entity_predictor(words)
    pred_max_index = torch.argmax(pred, dim=2)
    true_vals.append(tags.flatten().cpu())
    predictions.append(pred_max_index.flatten().cpu())
predictions=torch.cat(predictions).numpy()
true_vals=torch.cat(true_vals).numpy()
# Label 0 marks padding added by the collate function; leave those positions
# out so the metrics describe real tokens only.
non_pad = true_vals != 0
true_vals = true_vals[non_pad]
predictions = predictions[non_pad]
f1_micro=f1_score(true_vals,predictions,average='micro')
f1_macro=f1_score(true_vals,predictions,average='macro')
rec_micro=recall_score(true_vals,predictions,average='micro')
rec_macro=recall_score(true_vals,predictions,average='macro')
pre_micro=precision_score(true_vals,predictions,average='micro')
pre_macro=precision_score(true_vals,predictions,average='macro')
print(f'Accuracy Score on dev set: {accuracy_score(true_vals,predictions)}')
print(f'F1-Score(micro) on dev set: {f1_micro}')
print(f'F1-Score(macro) on dev set: {f1_macro}')
print(f'Recall(micro) Score on dev set: {rec_micro}')
print(f'Recall(macro) Score on dev set: {rec_macro}')
print(f'Precision(micro) Score on dev set: {pre_micro}')
print(f'Precision(macro) Score on dev set: {pre_macro}')
print(f'Confusion matrix for dev set:\n {confusion_matrix(true_vals,predictions)}')

In [13]:
# Evaluate the current model on the test set.
entity_predictor.eval()
predictions=[]
true_vals=[]
with torch.no_grad():
  for words, tags in test_dataloader:
    (words, tags) = (words.to(device), tags.to(device))
    pred = entity_predictor(words)
    pred_max_index = torch.argmax(pred, dim=2)
    true_vals.append(tags.flatten().cpu())
    predictions.append(pred_max_index.flatten().cpu())
predictions=torch.cat(predictions).numpy()
true_vals=torch.cat(true_vals).numpy()
# Label 0 marks padding added by the collate function; leave those positions
# out so the metrics describe real tokens only.
non_pad = true_vals != 0
true_vals = true_vals[non_pad]
predictions = predictions[non_pad]
f1_micro=f1_score(true_vals,predictions,average='micro')
f1_macro=f1_score(true_vals,predictions,average='macro')
rec_micro=recall_score(true_vals,predictions,average='micro')
rec_macro=recall_score(true_vals,predictions,average='macro')
pre_micro=precision_score(true_vals,predictions,average='micro')
pre_macro=precision_score(true_vals,predictions,average='macro')
print(f'Accuracy Score on test set: {accuracy_score(true_vals,predictions)}')
print(f'F1-Score(micro) on test set: {f1_micro}')
print(f'F1-Score(macro) on test set: {f1_macro}')
print(f'Recall(micro) Score on test set: {rec_micro}')
print(f'Recall(macro) Score on test set: {rec_macro}')
print(f'Precision(micro) Score on test set: {pre_micro}')
print(f'Precision(macro) Score on test set: {pre_macro}')
print(f'Confusion matrix for test set:\n {confusion_matrix(true_vals,predictions)}')

In [14]:
# for BATCH_SIZE in [8,16,32,64,128]:
#   for EMBEDDING_SIZE in [32,64,128,256]:
#     for HIDDEN_DIM in [128,256,512]:
#       train_dataset,val_dataset,test_dataset,tag_to_id=get_datasets_LSTM()
#       train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,collate_fn=train_dataset.collate)
#       val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE,collate_fn=val_dataset.collate)
#       test_dataloader=DataLoader(test_dataset, batch_size=BATCH_SIZE,collate_fn=test_dataset.collate)
#       entity_predictor=LSTMModel(EMBEDDING_SIZE,HIDDEN_DIM,len(train_dataset.vocabulary),len(tag_to_id)+1,NUM_STACKS)
#       loss_fn = torch.nn.CrossEntropyLoss()
#       optimizer = torch.optim.Adam(entity_predictor.parameters(),lr=lrate)
#       entity_predictor = entity_predictor.to(device)

#       for epoch_num in range(EPOCHS):
#         entity_predictor.train()
#         for batch_num, (words, tags) in enumerate(train_dataloader):
#           (words, tags) = (words.to(device), tags.to(device))
#           one_hot_tags=(torch.nn.functional.one_hot(tags, num_classes=len(tag_to_id)+1)).float()
#           pred = entity_predictor(words)
#           loss = loss_fn(pred, one_hot_tags)
#           optimizer.zero_grad()
#           loss.backward()
#           optimizer.step()
#       entity_predictor.eval()
#       predictions=[]
#       true_vals=[]
#       with torch.no_grad():
#         for batch_num, (words, tags) in enumerate(test_dataloader):
#           (words, tags) = (words.to(device), tags.to(device))
#           pred = entity_predictor(words)
#           pred_max_index = torch.argmax(pred, dim=2)
#           true_vals.extend(tags.flatten().cpu())
#           predictions.extend(pred_max_index.flatten().cpu())
#       predictions=torch.stack(predictions).numpy()
#       true_vals=torch.stack(true_vals).numpy()
#       non_zero_indices = np.where(true_vals != 0)
#       true_vals = true_vals[non_zero_indices]
#       predictions = predictions[non_zero_indices]
#       # f1_micro=f1_score(true_vals,predictions,average='micro')
#       # f1_macro=f1_score(true_vals,predictions,average='macro')
#       # rec_micro=recall_score(true_vals,predictions,average='micro')
#       # rec_macro=recall_score(true_vals,predictions,average='macro')
#       # pre_micro=precision_score(true_vals,predictions,average='micro')
#       # pre_macro=precision_score(true_vals,predictions,average='macro')
#       print(f'Accuracy Score on test set: {BATCH_SIZE, EMBEDDING_SIZE, HIDDEN_DIM, accuracy_score(true_vals,predictions)}')
#       # print(f'F1-Score(micro) on test set: {f1_micro}')
#       # print(f'F1-Score(macro) on test set: {f1_macro}')
#       # print(f'Recall(micro) Score on test set: {rec_micro}')
#       # print(f'Recall(macro) Score on test set: {rec_macro}')
#       # print(f'Precision(micro) Score on test set: {pre_micro}')
#       # print(f'Precision(macro) Score on test set: {pre_macro}')
#       # print(f'Confusion matrix for test set:\n {confusion_matrix(true_vals,predictions)}')

In [15]:
# NOTE(review): torch.load unpickles an arbitrary Python object (here a whole
# model), so only load checkpoint files you created or trust.
# map_location='cpu' lets the load succeed on machines without the device the
# checkpoint was saved from, and matches the CPU tensors fed to the model in the
# interactive inference cell.
entity_predictor=torch.load('LSTM_POS_Tagger.pt', map_location='cpu').eval()

In [35]:
import nltk
from nltk.tokenize import word_tokenize
import string
from nltk.tokenize import RegexpTokenizer

In [36]:
def tokenise(text):
    r"""Return the list of substrings of `text` matched by the pattern ``\w+``."""
    return RegexpTokenizer(r'\w+').tokenize(text)
def reverse_dict(original_dict):
    """Return a new dict that maps each value of `original_dict` to its key.

    When several keys share a value, the key seen last in iteration order wins.
    """
    inverted = {}
    for key, value in original_dict.items():
        inverted[value] = key
    return inverted
# Inverse of tag_to_id (integer id -> tag name), used to turn predicted ids
# back into readable tags.
rev_tag_to_id=reverse_dict(tag_to_id)
print(rev_tag_to_id)

{1: 'ADJ', 2: 'ADP', 3: 'ADV', 4: 'AUX', 5: 'CCONJ', 6: 'DET', 7: 'INTJ', 8: 'NOUN', 9: 'NUM', 10: 'PART', 11: 'PRON', 12: 'PROPN', 13: 'VERB'}


In [37]:
# Read one sentence from the user and split it into tokens.
user_inp = input()
tok_sent = tokenise(user_inp)
# Separate copy of the tokens, printed next to the predicted tags later on.
orig_sent = list(tok_sent)
print(tok_sent)

hi, they go .
['hi', 'they', 'go']


In [33]:
# Tag the user's sentence with the current model.
ntok_sent=tok_sent
with torch.no_grad():
  if ntok_sent:
    # Build the input on whatever device the model's parameters live on.
    model_device = next(entity_predictor.parameters()).device
    temp=torch.tensor([train_dataset.vocabulary.lookup_indices(ntok_sent)], device=model_device)
    res=entity_predictor(temp)
    # Label 0 only marks padding and is never a valid tag for a real token, so
    # take the best-scoring class among ids 1.. (shift by 1 to restore the ids).
    pred_max_index = (res[:, :, 1:].argmax(dim=2) + 1).cpu()
  else:
    # Nothing to tag: keep the (1, seq) layout with an empty sequence.
    pred_max_index = torch.empty((1, 0), dtype=torch.long)
  print(pred_max_index)

tensor([[11,  4,  6, 12,  8,  2, 12]])


In [34]:
# Print each input token followed by the name of its predicted tag.
for word, tag_id in zip(orig_sent, pred_max_index[0]):
    print(f"{word} {rev_tag_to_id[int(tag_id)]}")

what PRON
are AUX
the DET
coach PROPN
flights NOUN
between ADP
dallas PROPN
