In [1]:
import os
from typing import Iterator, List, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from allennlp.data import Instance
from allennlp.data.fields import TextField, LabelField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers.token import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer
from allennlp.predictors import SentenceTaggerPredictor

# Seed PyTorch's global RNG before any module is built. Only torch is seeded
# here; no seed is set for NumPy or Python's `random` in this notebook.
torch.manual_seed(1)

class LinguoDatasetReader(DatasetReader):
    """Dataset reader for preprocessed sentences (tokens separated by spaces).

    Each non-blank line of the input file is split on whitespace and read as::

        <label index> <ungrammaticality type> <token> <token> ...

    * field 0 is an integer index into ``GRAMMATICALITY_labels``;
    * field 1 is used as the ungrammaticality type only when the sentence is
      ungrammatical; for grammatical sentences it is skipped and the type is
      set to ``"G"``;
    * fields 2 onward are the sentence tokens.

    Every instance carries a ``"sentence"`` text field and, when the
    corresponding values are supplied, ``"g_label"`` and ``"ug_type"`` label
    fields.
    """
    # Position i holds the label name for lines whose first field is the integer i.
    GRAMMATICALITY_labels = ["ungrammatical","grammatical"]
    # Ungrammaticality-type codes. "G" is the value _read assigns to grammatical
    # sentences; this list itself is not consulted anywhere in this class.
    UG_TYPE_labels = ["WS","VA","AA","RV","G"]

    def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
        """Create an eager (non-lazy) reader.

        Parameters
        ----------
        token_indexers : Dict[str, TokenIndexer], optional
            Indexers applied to the sentence tokens. Defaults to a single
            ``SingleIdTokenIndexer`` registered under the key ``"tokens"``.
        """
        super().__init__(lazy=False)
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    def text_to_instance(self, tokens:List[Token], glabel:str=None, ugType:str=None ) -> Instance:
        """Build one ``Instance`` from a tokenised sentence and optional labels.

        Parameters
        ----------
        tokens : List[Token]
            The sentence tokens.
        glabel : str, optional
            Grammaticality label name (one of ``GRAMMATICALITY_labels`` when
            called from ``_read``). Omitted from the instance when ``None``.
        ugType : str, optional
            Ungrammaticality type code. Omitted from the instance when ``None``.

        Returns
        -------
        Instance
            Always has ``"sentence"``; has ``"g_label"`` / ``"ug_type"`` only
            for the labels that were provided.
        """
        fields = {"sentence": TextField(tokens, self.token_indexers)}
        # Compare against None explicitly so that a label is only left out
        # when the caller really did not supply one.
        if glabel is not None:
            fields["g_label"] = LabelField(label=glabel,
                                           label_namespace="grammaticality_labels")
        if ugType is not None:
            fields["ug_type"] = LabelField(label=ugType,
                                           label_namespace="ugtype_labels")
        return Instance(fields)

    def _read(self, file_path:str, label:str=None, ugType:str=None) -> Iterator[Instance]:
        """Yield one ``Instance`` per non-blank line of ``file_path``.

        Parameters
        ----------
        file_path : str
            Path of the corpus file, read as UTF-8 text.
        label, ugType : str, optional
            Kept only for signature compatibility; both values are derived
            from each line and whatever is passed in is ignored.

        Raises
        ------
        ValueError
            If the first field of a line is not an integer.
        IndexError
            If the label index is out of range, or an ungrammatical line has
            no second field.
        """
        ungrammatical = self.GRAMMATICALITY_labels[0]
        # The corpus contains accented characters, so do not rely on the
        # platform's default encoding.
        with open(file_path, encoding="utf-8") as infile:
            for line in infile:
                elements = line.split()
                if not elements:
                    # Blank line (e.g. a trailing newline at end of file).
                    continue
                line_label = self.GRAMMATICALITY_labels[int(elements[0])]
                line_ug_type = elements[1] if line_label == ungrammatical else "G"
                sentence = elements[2:]
                yield self.text_to_instance([Token(word) for word in sentence],
                                            line_label,
                                            line_ug_type)
                
class AllenLinguo(Model):
    """Sentence classifier: embed the tokens, encode them into one vector, project to class logits.

    Parameters
    ----------
    word_embeddings : TextFieldEmbedder
        Embeds the indexed ``sentence`` field.
    encoder : Seq2VecEncoder
        Encodes the embedded sequence; its ``get_output_dim()`` determines the
        input width of the output layer.
    vocab : Vocabulary
        Provides the number of output classes through the
        ``"grammaticality_labels"`` namespace.
    """

    def __init__(self,word_embeddings : TextFieldEmbedder,
                encoder : Seq2VecEncoder,
                vocab: Vocabulary) -> None:
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder

        # One logit per label in the "grammaticality_labels" namespace.
        self.hidden2decision = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                              out_features=vocab.get_vocab_size("grammaticality_labels"))
        self.loss_function = nn.CrossEntropyLoss()
        self.accuracy = CategoricalAccuracy()

    def forward(self,
               sentence: Dict[str, torch.Tensor],
               g_label: torch.Tensor = None,
               ug_type: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Score a batch of sentences and, if labels are given, compute the loss.

        Parameters
        ----------
        sentence : Dict[str, torch.Tensor]
            Indexed text field as produced for the ``"sentence"`` field.
        g_label : torch.Tensor, optional
            Gold grammaticality label indices. When provided, the accuracy
            metric is updated and a loss is added to the output.
        ug_type : torch.Tensor, optional
            Accepted but not used. The dataset reader emits a ``"ug_type"``
            field, so this is presumably here so that batches containing it
            can be passed by keyword -- TODO confirm.

        Returns
        -------
        Dict[str, torch.Tensor]
            ``"tag_logits"`` always; ``"loss"`` only when ``g_label`` is given.
        """
        mask = get_text_field_mask(sentence)

        embeddings = self.word_embeddings(sentence)

        encoder_out = self.encoder(embeddings, mask)

        tag_logits = self.hidden2decision(encoder_out)

        output = {"tag_logits": tag_logits}

        if g_label is not None:
            # Side effect: accumulates into the running accuracy metric.
            self.accuracy(tag_logits, g_label)
            output["loss"] = self.loss_function(tag_logits, g_label)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running accuracy as ``{"accuracy": value}``; ``reset`` is passed to the metric."""
        return {"accuracy": self.accuracy.get_metric(reset)}
            
        
                

In [2]:
# Directory holding the toy corpus. The default is the original author's local
# checkout; set the LINGUO_DATA_DIR environment variable to run this notebook
# on another machine without editing the cell.
DATA_DIR = os.environ.get(
    "LINGUO_DATA_DIR",
    "/Users/pablo/Dropbox/workspace/darth_linguo/Data/toy_corpus",
)
training_fn = os.path.join(DATA_DIR, "toy_training-GvsWS")
testing_fn = os.path.join(DATA_DIR, "toy_testing-GvsWS")

reader = LinguoDatasetReader()

train_dataset = reader.read(training_fn)
validation_dataset = reader.read(testing_fn)

# The vocabulary is built from both splits, so tokens and labels that occur
# only in the validation file are also indexed.
vocab = Vocabulary.from_instances(train_dataset + validation_dataset)


1483it [00:00, 11193.97it/s]
371it [00:00, 18373.70it/s]
100%|██████████| 1854/1854 [00:00<00:00, 50295.84it/s]


In [3]:
# Sanity check: show the token field and the grammaticality label of one
# validation instance.
for field_name in ("sentence", "g_label"):
    print(validation_dataset[1].fields[field_name])

TextField of length 86 with text: 
 		[en, este, parlamento, en, el, marco, en, el, ámbito, de, los, países, de, europa, central, y,
		oriental, ;, un, capítulo, especial, de, su, coordinación, ,, es, decir, ,, la, <unk>, <unk>, los,
		aspectos, <unk>, regionales, y, locales, de, los, fondos, estructurales, y, su, <unk>, en, que, no,
		han, realizado, <unk>, <unk>, ,, la, comisión, <unk>, <unk>, de, <unk>, <unk>, de, nuevo, ,, pues,
		,, necesario, <unk>, esa, flexibilidad, de, los, datos, y, ,, claro, está, ,, sin, embargo, ,,
		<unk>, <unk>, en, la, comisión, ., <eos>]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
LabelField with label: ungrammatical in namespace: 'grammaticality_labels'.'


In [4]:
# Model sizes: width of the token embeddings and of the LSTM hidden state.
EMBEDDING_DIM = 32
HIDDEN_DIM = 32

# One embedding row per entry in the vocabulary's 'tokens' namespace.
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM)

# The "tokens" key matches the key of the reader's default token indexer.
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

# Wrap a plain PyTorch LSTM so it can be passed where AllenLinguo expects a
# Seq2VecEncoder.
lstm = PytorchSeq2VecWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))

model = AllenLinguo(word_embeddings, lstm, vocab)

optimizer = optim.SGD(model.parameters(), lr=0.1)

# Batches of two instances, bucketed on the token count of the "sentence" field.
iterator = BucketIterator(batch_size=2, sorting_keys=[("sentence", "num_tokens")])

# The iterator needs the vocabulary to turn instances into index tensors.
iterator.index_with(vocab)

# Up to 25 epochs; `patience=10` presumably stops training early when the
# validation metric has not improved for 10 epochs -- confirm against the
# Trainer documentation.
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  patience=10,
                  num_epochs=25)

trainer.train()

# NOTE(review): a sentence-*tagger* predictor is reused here for a
# sentence-level classifier; it is only used below to obtain 'tag_logits'.
predictor = SentenceTaggerPredictor(model, dataset_reader=reader)

# Disabled prediction snippet, kept for reference.
# NOTE(review): the last line looks labels up in a 'labels' namespace, but the
# reader in this notebook only registers "grammaticality_labels" and
# "ugtype_labels" -- fix the namespace before re-enabling.
# tag_logits = predictor.predict("La proxima vez no habrá otra opción")['tag_logits']

#tag_ids = np.argmax(tag_logits, axis=-1)

# print([model.vocab.get_token_from_index(i, 'labels') for i in tag_ids])

accuracy: 0.4983, loss: 0.7023 ||: 100%|██████████| 742/742 [00:06<00:00, 119.94it/s]
accuracy: 0.5013, loss: 0.6936 ||: 100%|██████████| 186/186 [00:00<00:00, 360.23it/s]
accuracy: 0.4976, loss: 0.7013 ||: 100%|██████████| 742/742 [00:06<00:00, 122.03it/s]
accuracy: 0.5013, loss: 0.6940 ||: 100%|██████████| 186/186 [00:00<00:00, 387.66it/s]
accuracy: 0.4990, loss: 0.7003 ||: 100%|██████████| 742/742 [00:06<00:00, 122.44it/s]
accuracy: 0.5013, loss: 0.6946 ||: 100%|██████████| 186/186 [00:00<00:00, 377.17it/s]
accuracy: 0.5051, loss: 0.6998 ||: 100%|██████████| 742/742 [00:06<00:00, 121.21it/s]
accuracy: 0.4987, loss: 0.6944 ||: 100%|██████████| 186/186 [00:00<00:00, 383.68it/s]
accuracy: 0.6298, loss: 0.6084 ||: 100%|██████████| 742/742 [00:06<00:00, 134.98it/s]
accuracy: 0.5013, loss: 0.6921 ||: 100%|██████████| 186/186 [00:00<00:00, 383.97it/s]
accuracy: 0.8247, loss: 0.3583 ||: 100%|██████████| 742/742 [00:06<00:00, 118.84it/s]
accuracy: 0.9677, loss: 0.1199 ||: 100%|██████████| 18

In [6]:

# Score one raw sentence with the trained model, then pick the index of the
# largest logit.
prediction = predictor.predict("La proxima vez no habrá otra opción")
tag_logits = prediction['tag_logits']

tag_ids = np.asarray(tag_logits).argmax(axis=-1)

In [7]:
# Display the argmax over the predicted logits, i.e. the index of the
# highest-scoring class for the sentence above.
tag_ids

1