In [22]:
import conllu
import torch
from torch.utils.data import Dataset

In [23]:
def get_data(data_file, vocab_index, pos_tag_index):
    """Encode every sentence of a CoNLL-U file as word ids and UPOS tag ids.

    Args:
        data_file: path to a UTF-8 encoded CoNLL-U file.
        vocab_index: mapping from a token's ``form`` string to its integer id.
        pos_tag_index: mapping from a token's ``upos`` tag to its integer id.

    Returns:
        ``(Sentences, Tag_Sequences)``: two parallel lists with one entry per
        sentence. Each entry is a list of ids, one per token, in file order, so
        ``Sentences[i]`` and ``Tag_Sequences[i]`` always have the same length.

    Raises:
        KeyError: if a token's form or UPOS tag is missing from the mapping.
    """
    sentences = []
    tag_sequences = []
    # ``with`` guarantees the file handle is closed, even when a lookup raises.
    with open(data_file, "r", encoding="utf-8") as handle:
        for token_list in conllu.parse_incr(handle):
            word_ids = []
            tag_ids = []
            for token in token_list:
                word_ids.append(vocab_index[token["form"]])
                tag_ids.append(pos_tag_index[token["upos"]])
            sentences.append(word_ids)
            tag_sequences.append(tag_ids)
    return sentences, tag_sequences


def get_vocab_index(data_file):
    """Assign a unique integer id to every distinct token form in a CoNLL-U file.

    Ids are handed out in order of first appearance, starting from 0.

    Args:
        data_file: path to a UTF-8 encoded CoNLL-U file.

    Returns:
        dict mapping each token ``form`` string to its id.
    """
    vocab_index = {}
    # ``with`` guarantees the file handle is closed once parsing finishes.
    with open(data_file, "r", encoding="utf-8") as handle:
        for token_list in conllu.parse_incr(handle):
            for token in token_list:
                # len() is evaluated before insertion, so a new form gets the next free id.
                vocab_index.setdefault(token["form"], len(vocab_index))
    return vocab_index

In [24]:
# Encode the ATIS dev split: build a word vocabulary from the file, map the 17
# UPOS tags to ids 0-16 (in the listed order), then convert every sentence.
data_file = "./UD_English-Atis/en_atis-ud-dev.conllu"

vocab_index = get_vocab_index(data_file)
pos_tag_index = {
    tag: tag_id
    for tag_id, tag in enumerate(
        ("ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
         "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X")
    )
}

Sentences, Tag_Sequences = get_data(data_file, vocab_index, pos_tag_index)


In [25]:
# Dump every token as "<word id> <tag id>", one token per line, sentence by sentence.
for sent_idx, word_ids in enumerate(Sentences):
    for tok_idx, word_id in enumerate(word_ids):
        print(word_id, Tag_Sequences[sent_idx][tok_idx])

0 10
1 3
2 15
3 5
4 0
5 7
6 1
7 11
8 1
9 11
10 15
11 7
12 8
13 0
14 4
15 15
16 7
17 0
0 10
18 15
19 5
5 7
6 1
20 11
8 1
21 11
22 1
23 15
24 5
25 2
26 1
27 8
28 7
29 15
30 10
31 7
32 7
33 7
6 1
34 11
8 1
35 11
10 15
36 0
37 7
14 4
15 15
3 5
38 7
39 1
40 10
41 3
42 11
43 11
44 15
45 1
45 1
46 7
29 15
30 10
47 5
33 7
6 1
9 11
8 1
48 11
49 5
50 15
39 1
51 8
52 2
28 7
3 5
38 7
39 1
53 7
40 5
5 7
6 1
48 11
8 1
7 11
23 15
54 0
55 1
3 5
56 7
57 3
0 10
58 15
59 7
60 1
33 7
6 1
61 11
62 11
8 1
63 11
64 11
10 15
65 1
66 7
67 15
3 5
68 11
69 11
33 7
65 1
70 11
71 15
47 5
33 7
10 15
6 1
72 11
8 1
73 11
74 11
65 1
75 11
69 7
76 15
77 1
55 1
78 11
79 11
40 10
80 3
81 7
82 11
0 10
83 15
84 7
45 1
85 7
86 7
48 11
87 11
40 10
88 3
3 5
89 0
33 7
6 1
90 11
91 11
8 1
92 11
93 15
65 1
94 7
16 7
12 8
95 8
0 10
1 3
2 15
8 1
96 15
19 5
5 7
65 1
97 7
98 0
6 1
99 11
100 11
8 1
7 11
55 1
3 5
101 0
102 7
40 10
80 3
103 7
104 7
105 11
14 4
106 2
29 15
30 10
85 7
86 7
22 1
0 10
57 3
107 15
55 1
108 11
101 0
109 7
29

In [26]:
class PosTagDataset(Dataset):
    """PyTorch ``Dataset`` of encoded sentences from a CoNLL-U file for POS tagging.

    Each item is one sentence: a pair of 1-D ``torch.long`` tensors of equal
    length, holding word ids (from ``self.vocab_index``) and UPOS tag ids
    (from ``self.pos_tag_index``).
    """

    def __init__(self, data_file):
        """Read ``data_file``, build its vocabulary and encode every sentence.

        Args:
            data_file: path to a UTF-8 encoded CoNLL-U file.
        """
        # NOTE(review): the vocabulary is built from this file alone, so datasets
        # built from different splits will not share word ids.
        self.vocab_index = get_vocab_index(data_file)
        self.pos_tag_index = { "ADJ": 0, "ADP": 1, "ADV": 2, "AUX": 3, "CCONJ": 4, "DET": 5, "INTJ": 6, "NOUN": 7, "NUM": 8, "PART": 9, "PRON": 10, "PROPN": 11, "PUNCT": 12, "SCONJ": 13, "SYM": 14, "VERB": 15, "X": 16}
        # Use this instance's own lookup tables, not notebook globals. The dataset
        # is then correct on a fresh kernel and for any ``data_file``.
        self.Sentences, self.Tag_Sequences = get_data(
            data_file, self.vocab_index, self.pos_tag_index
        )

    def __len__(self):
        """Return the number of sentences in the dataset."""
        return len(self.Sentences)

    def __getitem__(self, idx):
        """Return ``(word_ids, tag_ids)`` for sentence ``idx`` as ``torch.long`` tensors."""
        return (
            torch.tensor(self.Sentences[idx], dtype=torch.long),
            torch.tensor(self.Tag_Sequences[idx], dtype=torch.long),
        )

In [27]:
# Wrap the same CoNLL-U file in a torch Dataset (one item per sentence).
dataset = PosTagDataset(data_file=data_file)

In [29]:
# Inspect the first item, a (word-id tensor, tag-id tensor) pair. A bare last
# expression lets Jupyter render it with its rich display instead of print().
dataset[0]

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17]), tensor([10,  3, 15,  5,  0,  7,  1, 11,  1, 11, 15,  7,  8,  0,  4, 15,  7,  0]))
