### An encoding class for easy lookup of vocabulary symbols.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter

In [None]:
"""
General purpose data encoding classes
"""

class Vocabulary:
    """Bidirectional mapping between symbols and consecutive integer ids.

    Ids are assigned in insertion order, starting at 0.  The class does not
    add an unknown-word entry by itself: the caller must pass one of
    ``UNK_SYMBOLS`` in ``symbols`` (or ``update`` it later) if lookups of
    unseen symbols are expected to succeed.
    """

    # Accepted spellings of the unknown-word entry, tried in this order.
    # "[UNK]" comes first so vocabularies that already contain it keep
    # resolving unseen symbols to the same id as before.
    UNK_SYMBOLS = ("[UNK]", "<unk>")

    def __init__(self, symbols = None):
        """Create the vocabulary, optionally pre-filled with ``symbols``.

        Args:
            symbols: optional iterable of symbols, registered in order.
        """
        # symbol -> id
        self.word2idx = dict()
        # id -> symbol (the list position is the id)
        self.idx2word = []

        if symbols:
            for sym in symbols:
                self.update(sym)

    def update(self, tok):
        """Register ``tok`` with the next free id; no-op if already known."""
        if tok not in self.word2idx:
            self.word2idx[tok] = len(self.idx2word)
            self.idx2word.append(tok)

    def lookup(self, tok, update = False):
        """Return the id of ``tok``.

        Args:
            tok: the symbol to look up.
            update: when True, an unseen ``tok`` is added to the vocabulary
                and its freshly assigned id is returned.

        Returns:
            The id of ``tok``; for an unseen symbol with ``update=False``,
            the id of the unknown-word entry.

        Raises:
            KeyError: ``tok`` is unseen, ``update`` is False and none of
                ``UNK_SYMBOLS`` is registered.
        """
        if tok in self.word2idx:
            return self.word2idx[tok]

        if update:
            self.update(tok)
            return self.word2idx[tok]

        for unk in self.UNK_SYMBOLS:
            if unk in self.word2idx:
                return self.word2idx[unk]

        raise KeyError(
            f"{tok!r} is not in the vocabulary and no unknown-word symbol "
            f"{self.UNK_SYMBOLS} is registered"
        )

    def rev_lookup(self, idx):
        """Return the symbol stored under id ``idx``."""
        return self.idx2word[idx]

    def __getitem__(self, symbol):
        """``vocab[symbol]``: read-only lookup, unseen symbols map to unknown."""
        return self.lookup(symbol)

    def __len__(self):
        """Number of registered symbols."""
        return len(self.idx2word)
    

### Load the data
The CoNLL-U reader parses the `.conllu` file; the data loader then encodes the data and feeds it to the model (for either training or inference).

In [None]:
"""
Functions for reading and writing UD CONLL data
"""
# Names of CoNLL fields; not referenced by the code visible in this notebook.
CONLL_FIELDS = ["token", "pos", "features", "deprel"]
# The three MWE labels; the same strings readfile() registers in its MWE vocabulary.
MWE_TAGS     = ["outside", "mwehead", "component"]

def readfile(filename):
    """Read a 10-column CoNLL-U style file into flat, parallel lists.

    Every sentence is wrapped in artificial "<bos>" / "<eos>" tokens that
    carry the MWE tag "outside" and the artificial deprel "[S]".  The MWE tag
    of a real token is derived from the prefix of its features column
    ("mwehead", "component", anything else -> "outside").

    Args:
        filename: path of the file to read (decoded as UTF-8).

    Returns:
        A tuple ``(toks, mwe_tags, deprels, tok_vocab, deprel_vocab,
        mwe_vocab)``: three parallel lists of strings followed by the three
        ``Vocabulary`` objects built while reading.

    Raises:
        ValueError: a data line does not have exactly 10 columns.
    """
    toks          = []
    deprels       = []
    mwe_tags      = []
    tok_vocab     = Vocabulary(["<bos>", "<eos>", "<unk>"])
    deprel_vocab  = Vocabulary(["[S]"])   # artificial deprel for artificial toks
    mwe_vocab     = Vocabulary(["outside", "mwehead", "component"])

    def add_boundary(symbol):
        # Artificial sentence-boundary token with neutral annotations.
        toks.append(symbol)
        mwe_tags.append("outside")
        deprels.append("[S]")

    # True while real tokens have been read since the last "<eos>"; this keeps
    # repeated blank lines from producing repeated "<eos>" tokens.
    sentence_open = False

    with open(filename, encoding="utf-8") as istream:
        for line in istream:
            line = line.strip()

            if not line:
                # A blank line ends the current sentence.
                if sentence_open:
                    add_boundary("<eos>")
                    sentence_open = False
                continue

            if line[0] == "#":
                continue

            # Columns are tab-separated, so a token may itself contain spaces;
            # fall back to any-whitespace splitting for files without tabs.
            fields = line.split("\t") if "\t" in line else line.split()
            (tokidx, token, _lemma, _upos, _pos,
             features, _headidx, deprel, _extended, _) = fields

            if tokidx == "1":
                # beginning of sentence
                add_boundary("<bos>")

            toks.append(token)
            tok_vocab.update(token)

            # make simple mwe tags
            if features.startswith("mwehead"):
                mwe_tags.append("mwehead")
            elif features.startswith("component"):
                mwe_tags.append("component")
            else:
                mwe_tags.append("outside")

            # extract deprels information
            deprels.append(deprel)
            deprel_vocab.update(deprel)
            sentence_open = True

    # The file may end without a trailing blank line: close the last sentence.
    if sentence_open:
        add_boundary("<eos>")

    return toks, mwe_tags, deprels, tok_vocab, deprel_vocab, mwe_vocab
# [{"token1": "token", "multiword": "mwe", "mwe lemma": "mwe lemma"}, {"token2": "token", "multiword": "mwe"}, {"token3": "token", "multiword": "mwe"}]




In [None]:
# Parse the corpus once: three parallel lists (tokens, MWE tags, deprels)
# followed by the token, deprel and MWE vocabularies.
toks, mwe_tags, deprels, tok_vocab, deprel_vocab, mwe_vocab = readfile("test.conllu")


# New section

In [None]:
# Report how many distinct symbols each vocabulary holds.
print(len(tok_vocab))
print(len(deprel_vocab))
print(len(mwe_vocab))

5924
27
3


In [None]:
class MWEDataset (Dataset):
    """Token-level dataset for multiword-expression (MWE) tagging.

    Example ``idx`` is the token at position ``idx`` together with its
    ``context_size`` preceding and ``context_size`` following tokens, all
    encoded as vocabulary ids; the target is the id of the token's MWE tag.
    """

    def __init__(self,datafilename = None, lst_toks = None, context_size = 1):
        """
        Build the dataset from a conllu file or from a list of tokens.

        Args:
            datafilename: path to a conllu file, parsed with ``readfile``.
            lst_toks: list of token strings, used when no file is given.
                No gold annotation exists in that case, so every token gets
                the placeholder MWE tag "outside" and the deprel attributes
                are set to None.
            context_size: number of preceding and of following tokens that
                are added to each example.

        Raises:
            ValueError: neither ``datafilename`` nor ``lst_toks`` is given.
        """
        super(MWEDataset, self).__init__()

        if datafilename:
            self.toks, self.mwe_tags, self.deprels, self.tok_vocab, self.deprel_vocab, self.mwe_vocab = readfile(datafilename)

        elif lst_toks is not None:
            self.toks = list(lst_toks)
            self.tok_vocab = Vocabulary(self.toks)
            # The padding symbol must be encodable; appending it last leaves
            # the ids of the real tokens unchanged.
            self.tok_vocab.update("<unk>")
            self.mwe_tags = ["outside"] * len(self.toks)
            self.mwe_vocab = Vocabulary(["outside", "mwehead", "component"])
            self.deprels = None
            self.deprel_vocab = None

        else:
            raise ValueError("MWEDataset needs either datafilename or lst_toks")

        print('token Vocab size',len(self.tok_vocab))
        self.context_size  = context_size
        # Pad both ends so that every position has a full context window.
        self.context = ["<unk>"]*self.context_size + self.toks + ["<unk>"]*self.context_size

    @staticmethod
    def from_string(text,token_vocabulary):
        """
        Build a dataset from raw text instead of a conllu file.

        Args:
            text: raw text; it is tokenized on whitespace.
            token_vocabulary: vocabulary used to encode the tokens, or None
                to keep the vocabulary built from ``text`` itself.

        Returns:
            An MWEDataset with the default context size.
        """
        dataset = MWEDataset(lst_toks = text.split())
        if token_vocabulary is not None:
            dataset.tok_vocab = token_vocabulary

        return dataset

    def __len__(self):
        """Number of examples, i.e. number of tokens."""
        return len(self.toks)

    def __getitem__(self,idx):
        """
        Return ``(X, y_true)`` for the token at position ``idx``.

        X is a 1-D tensor of 2*context_size + 1 token ids: the token itself,
        then its left context, then its right context.
        y_true is the integer id of the token's MWE tag.
        """
        size     = self.context_size
        # self.context is shifted by `size` relative to self.toks, so the
        # token sits at idx + size in the padded list.
        left     = self.context[idx : idx + size]
        right    = self.context[idx + size + 1 : idx + 2*size + 1]
        window   = [self.toks[idx]] + left + right
        X        = torch.tensor([self.tok_vocab[tok] for tok in window])
        # Read the instance's own tags rather than a notebook-level global.
        y_true   = self.mwe_vocab[self.mwe_tags[idx]]

        return X, y_true

    def as_strings(self,batch_tensor):
        """
        Returns a string representation of a tensor of word indexes
        (one list of token strings per row).
        """
        return [[self.tok_vocab.rev_lookup(idx) for idx in line]
                for line in batch_tensor.tolist()]

    def get_loader(self, batch_size=1, num_workers=0, word_dropout=0.):
        """
        Return a DataLoader over this dataset.

        Each batch is a pair ``(X, y)``: the stacked token-id tensors and a
        1-D tensor of MWE tag ids.  ``word_dropout`` is accepted for
        interface compatibility but is currently not used.
        """

        def mk_batch(selected_items):
            X_toks, y_t = zip(*selected_items)
            return torch.stack(X_toks), torch.tensor(y_t)

        return DataLoader(self, batch_size=batch_size, num_workers=num_workers, collate_fn=mk_batch)


In [None]:
# Build the dataset from the corpus file.  With context_size = 3 each example
# holds 2*3 + 1 = 7 token ids (the token plus 3 neighbours on each side).
corpuspath = "test.conllu"
dataset = MWEDataset(corpuspath, context_size = 3)

token Vocab size 5924


In [None]:
# Sanity check: walk over the data in batches of 10 and print, for each batch,
# the tensor shape, the decoded token windows and the MWE tag ids.
for X, y_t in dataset.get_loader(batch_size = 10):
    print(X.shape)
    print(dataset.as_strings(X))
    print(y_t)


[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
[['cette', 'non', 'mentionnés', 'dans', 'notice', ',', 'veuillez'], ['notice', 'mentionnés', 'dans', 'cette', ',', 'veuillez', 'en'], [',', 'dans', 'cette', 'notice', 'veuillez', 'en', 'informer'], ['veuillez', 'cette', 'notice', ',', 'en', 'informer', 'votre'], ['en', 'notice', ',', 'veuillez', 'informer', 'votre', 'médecin'], ['informer', ',', 'veuillez', 'en', 'votre', 'médecin', '.'], ['votre', 'veuillez', 'en', 'informer', 'médecin', '.', '<eos>'], ['médecin', 'en', 'informer', 'votre', '.', '<eos>', '<bos>'], ['.', 'informer', 'votre', 'médecin', '<eos>', '<bos>', 'Dans'], ['<eos>', 'votre', 'médecin', '.', '<bos>', 'Dans', 'cette']]
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
torch.Size([10, 7])
[['<bos>', 'médecin', '.', '<eos>', 'Dans', 'cette', 'notice'], ['Dans', '.', '<eos>', '<bos>', 'cette', 'notice', ':'], ['cette', '<eos>', '<bos>', 'Dans', 'notice', ':', '<eos>'], ['notice', '<bos>', 'Dans', 'cette', ':', '<eos>', '<bos>'], [':', 'Da

# Model

In [None]:
import torch.nn as nn

In [None]:
# Feed-forward classifier over a 7-value input window (the token plus 3
# neighbours on each side) with 3 output classes, one per MWE tag.
# NOTE(review): the training loop feeds raw token ids cast to float; an
# nn.Embedding layer would probably be a better input - confirm the intent.
model = nn.Sequential(
    nn.Linear(7, 100),
    nn.ReLU(),
    nn.Linear(100, 3),
    # NLLLoss expects log-probabilities, so the last layer must be LogSoftmax
    # (plain Softmax would make the loss a negated probability).
    nn.LogSoftmax(dim=1)
)
loss = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

In [30]:
from torch.autograd.grad_mode import no_grad
# Train for `epoch` passes over the data and record the loss of every batch.
epoch = 500
l = []   # per-batch loss history, stored as plain Python floats
# The loader can be iterated once per epoch; no need to rebuild it each time.
loader = dataset.get_loader(batch_size=10)
for i in range(epoch):
    for x, y in loader:
        optimizer.zero_grad()
        # Token ids are integers; the linear layers need a float input.
        y_hat = model(x.float())
        loss_value = loss(y_hat, y)
        # .item() stores a float instead of a tensor that would keep its
        # autograd history referenced for the whole run.
        l.append(loss_value.item())
        loss_value.backward()
        optimizer.step()

KeyboardInterrupt: ignored