In [None]:
import conllu
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

In [None]:
def get_data(data_file, vocab_index, pos_tag_index):
    """Read a CoNLL-U file and convert it into parallel word-id / tag-id sequences.

    Args:
        data_file: path to a UTF-8 encoded CoNLL-U file.
        vocab_index: mapping from a token's ``form`` to its integer word id.
        pos_tag_index: mapping from a token's ``upos`` tag to its integer tag id.

    Returns:
        A pair ``(Sentences, Tag_Sequences)`` of equal-length lists. Element ``i``
        of each is a list of ints for the i-th sentence, with one entry per token.

    Raises:
        KeyError: if a token's form is missing from ``vocab_index`` or its UPOS
            tag is missing from ``pos_tag_index``.
    """
    Sentences = []
    Tag_Sequences = []
    # Context manager so the handle is closed once parsing is done; previously
    # the file object passed to parse_incr was never closed.
    with open(data_file, "r", encoding="utf-8") as conllu_file:
        for TokenList in conllu.parse_incr(conllu_file):
            Sentences.append([vocab_index[token["form"]] for token in TokenList])
            Tag_Sequences.append([pos_tag_index[token["upos"]] for token in TokenList])
    return Sentences, Tag_Sequences


def get_vocab_index(data_file):
    """Build a word-form -> integer-id vocabulary from a CoNLL-U file.

    Ids are assigned consecutively from 0, in order of first appearance of each
    distinct ``form`` in the file.

    Args:
        data_file: path to a UTF-8 encoded CoNLL-U file.

    Returns:
        dict mapping each distinct token form to its integer id.
    """
    vocab_index = {}
    # Context manager so the handle is closed once parsing is done; previously
    # the file object passed to parse_incr was never closed.
    with open(data_file, "r", encoding="utf-8") as conllu_file:
        for TokenList in conllu.parse_incr(conllu_file):
            for token in TokenList:
                # len() is evaluated before insertion, so a new form gets the next free id;
                # forms already seen keep their existing id.
                vocab_index.setdefault(token["form"], len(vocab_index))
    return vocab_index


def custom_collate(batch, word_pad_value=0, tag_pad_value=0):
    """Collate ``(sentence_ids, tag_ids)`` pairs into two right-padded 2-D tensors.

    Args:
        batch: sequence of ``(sentence_ids, tag_ids)`` pairs of 1-D tensors, as
            returned by ``PosTagDataset.__getitem__``.
        word_pad_value: value written into padded word positions.
        tag_pad_value: value written into padded tag positions.

    Returns:
        ``(Sentences, PosTags)``, each of shape ``(len(batch), longest_length)``
        (batch-first).

    Note:
        The default padding value of 0 is also a real word id (the first
        vocabulary entry) and the id of the "ADJ" tag, so padded positions are
        indistinguishable from real data unless masked. To have e.g.
        ``nn.CrossEntropyLoss`` skip padded tags, bind a distinct value such as
        ``functools.partial(custom_collate, tag_pad_value=-100)`` as the
        DataLoader's ``collate_fn``.
    """
    Sentences = [sample[0] for sample in batch]
    PosTags = [sample[1] for sample in batch]

    Sentences = pad_sequence(Sentences, batch_first=True, padding_value=word_pad_value)
    PosTags = pad_sequence(PosTags, batch_first=True, padding_value=tag_pad_value)
    return Sentences, PosTags


class PosTagDataset(Dataset):
    """Map-style dataset of (word ids, UPOS tag ids) pairs read from one CoNLL-U file.

    The vocabulary is built from ``data_file`` itself, so every form in the file
    has an id.
    """

    def __init__(self, data_file):
        self.vocab_index = get_vocab_index(data_file)
        self.pos_tag_index = { "ADJ": 0, "ADP": 1, "ADV": 2, "AUX": 3, "CCONJ": 4, "DET": 5, "INTJ": 6, "NOUN": 7, "NUM": 8, "PART": 9, "PRON": 10, "PROPN": 11, "PUNCT": 12, "SCONJ": 13, "SYM": 14, "VERB": 15, "X": 16}
        # Use this instance's own mappings. Passing the bare names `vocab_index` /
        # `pos_tag_index` picked up notebook globals instead. That raised NameError
        # on a fresh kernel, and could silently use a vocabulary built from a different file.
        self.Sentences, self.Tag_Sequences = get_data(data_file, self.vocab_index, self.pos_tag_index)

    def __len__(self):
        """Number of sentences in the file."""
        return len(self.Sentences)

    def __getitem__(self, idx):
        """Return ``(word_ids, tag_ids)`` for sentence ``idx`` as 1-D LongTensors."""
        return torch.LongTensor(self.Sentences[idx]), torch.LongTensor(self.Tag_Sequences[idx])
    
    

In [None]:
data_file = "./UD_English-Atis/en_atis-ud-dev.conllu"

vocab_index = get_vocab_index(data_file)
pos_tag_index = { "ADJ": 0, "ADP": 1, "ADV": 2, "AUX": 3, "CCONJ": 4, "DET": 5, "INTJ": 6, "NOUN": 7, "NUM": 8, "PART": 9, "PRON": 10, "PROPN": 11, "PUNCT": 12, "SCONJ": 13, "SYM": 14, "VERB": 15, "X": 16}

Sentences, Tag_Sequences = get_data(data_file, vocab_index, pos_tag_index)

# Invariant: one tag sequence per sentence, and one tag per token.
assert len(Sentences) == len(Tag_Sequences)
assert all(len(words) == len(tags) for words, tags in zip(Sentences, Tag_Sequences))

# Preview only the first few sentences rather than printing every token in the file.
N_PREVIEW_SENTENCES = 3
for words, tags in zip(Sentences[:N_PREVIEW_SENTENCES], Tag_Sequences[:N_PREVIEW_SENTENCES]):
    for word_id, tag_id in zip(words, tags):
        print(word_id, tag_id)

In [None]:
dataset = PosTagDataset(data_file)
print(dataset[0])

# NOTE(review): data_file points at the dev split; confirm this loader is meant for training.
train_dataloader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=custom_collate)
# Inspect a single batch instead of printing every batch in the dataset.
batch = next(iter(train_dataloader))
print(batch[0].shape, batch[1].shape)
print(batch)