## Download Dataset:

In [None]:
# Download the UD v2.7 treebanks, documentation and tools archives from the
# LINDAT repository into the current working directory (one file per URL).
# Only the treebanks archive is unpacked by the next cell.
!curl --remote-name-all https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3424{/ud-treebanks-v2.7.tgz,/ud-documentation-v2.7.tgz,/ud-tools-v2.7.tgz}

In [None]:
# Unpack the treebanks archive into the current working directory.
# NOTE(review): the preprocessing calls below use absolute /content/... paths,
# so this presumably runs with /content as the working directory — confirm.
!tar -xzf ud-treebanks-v2.7.tgz

## PreProcessing Code

In [None]:
import numpy as np
import tqdm.notebook as tqdm

def write_sentences_to_file(s, output_file):
    """Write one sentence per line to a ``.txt`` file derived from ``output_file``.

    The target keeps the directory part of ``output_file``; its file name is
    the final path component cut at its first "." with ``.txt`` appended.

    Args:
        s: iterable of sentence strings.
        output_file: path whose directory and stem name the output file.
    """
    import os  # local import: ``os`` is not provided by this cell's imports

    directory, basename = os.path.split(output_file)
    # Strip the extension from the file name only, so that dots in directory
    # names (e.g. "ud-treebanks-v2.7/") do not truncate the whole path.
    target = os.path.join(directory, basename.split(".")[0] + ".txt")
    # ``with`` guarantees the handle is closed even if a write raises.
    with open(target, "w") as f:
        for sentence in s:
            f.write(sentence + "\n")

def create_matrices(path, punctuation=True):
    """Preprocess one treebank file and persist its matrices and sentences.

    The outputs are named after the last "/"-separated component of ``path``
    (cut at its first "."), followed by ``_punct_`` or ``_nopunct_``. Because
    only that component is used, the matrix file is written relative to the
    current working directory.

    Args:
        path: path of the treebank file to preprocess.
        punctuation: forwarded to ``create_adjacency_matrix``; also selects
            the ``_punct_`` / ``_nopunct_`` suffix of the output names.
    """
    stem = path.split("/")[-1].split(".")[0]
    matrices, sentences = create_adjacency_matrix(path, punctuation=punctuation)

    addendum = "_punct_" if punctuation else "_nopunct_"
    np.save(stem + "{}.npy".format(addendum), matrices)
    write_sentences_to_file(sentences, stem + "{}.txt".format(addendum))

def get_matrix(samples):
    """Convert parsed sentences into dependency adjacency matrices.

    Args:
        samples: list of sentences, each a sequence of
            ``(token_id, word, parent_id)`` entries. A ``parent_id`` of 0
            marks the root word, which gets no outgoing entry.

    Returns:
        ``(matrices, sentences)``: two index-aligned lists. ``matrices[k]`` is
        an ``(n, n)`` array in which entry ``[j][p]`` is 1.0 when the token at
        position ``p`` is the parent of the token at position ``j``;
        ``sentences[k]`` is the space-joined words of the same sample.
        A sample that references a parent id missing from the sample (for
        instance a parent that was filtered out upstream) is dropped from
        BOTH lists so the two stay aligned.
    """
    matrices = []
    sentences = []
    for sample in samples:
        n = len(sample)
        adjacency = np.zeros((n, n))
        # Token ids need not be contiguous, so map each id to its row/column.
        position_of = {token_id: pos for pos, (token_id, *_) in enumerate(sample)}
        complete = True
        for row, (_, _, parent_id) in enumerate(sample):
            if parent_id == 0:  # root word: no parent edge
                continue
            if parent_id not in position_of:
                complete = False
                break
            adjacency[row][position_of[parent_id]] = 1.0
        if complete:
            # Append the sentence only together with its matrix; appending it
            # unconditionally would shift every later sentence/matrix pair.
            matrices.append(adjacency)
            sentences.append(" ".join([entry[1] for entry in sample]))
    return matrices, sentences

def create_adjacency_matrix(filename, punctuation = True):
    """Parse a CoNLL-U style file into adjacency matrices and sentences.

    Lines starting with "#" are skipped, a blank line ends the current
    sentence, and rows whose first column contains "-" or "." are ignored.
    From each remaining tab-separated row the first column (id), second
    column (word) and seventh column (parent id) are kept.

    Args:
        filename: path of the file to parse.
        punctuation: when False, rows whose fourth column is ``'PUNCT'`` are
            dropped before the matrices are built.

    Returns:
        ``(matrices, sentences)`` where ``matrices`` is a 1-D object array
        holding one matrix per sentence and ``sentences`` is the list returned
        by ``get_matrix``.
    """
    # CoNLL-U files are UTF-8 by specification; ``with`` closes the handle
    # even when reading fails.
    with open(filename, "r", encoding="utf-8") as f:
        data = f.readlines()

    samples = []
    sample = []
    for line in tqdm.tqdm(data):
        if line[0] == "#":
            continue
        line = line.rstrip()
        if line != "":
            fields = line.split("\t")
            if "-" in fields[0] or "." in fields[0]:
                continue
            if fields[3] == 'PUNCT' and not punctuation:
                continue
            sample.append([int(fields[0]), fields[1], int(fields[6])]) # Get word and parent id
        else:
            samples.append(sample)
            sample = []
    # Keep the final sentence when the file does not end with a blank line.
    if sample:
        samples.append(sample)

    # Build the matrices once, after parsing. Doing this inside the loop
    # re-processed every sentence for every input line and left the result
    # undefined for an empty file.
    m, s = get_matrix(samples)

    # Fill a 1-D object array explicitly: np.array(m, dtype=object) would
    # collapse into a 3-D array whenever all matrices share the same shape.
    matrices = np.empty(len(m), dtype=object)
    for k, matrix in enumerate(m):
        matrices[k] = matrix
    return matrices, s

In [None]:
# Preprocess every split twice: first keeping punctuation, then dropping it.
_conllu_paths = (
    "/content/ud-treebanks-v2.7/UD_English-EWT/en_ewt-ud-train.conllu",
    "/content/ud-treebanks-v2.7/UD_English-EWT/en_ewt-ud-dev.conllu",
    "/content/ud-treebanks-v2.7/UD_English-EWT/en_ewt-ud-test.conllu",
)
for _conllu_path in _conllu_paths:
    for _keep_punct in (True, False):
        create_matrices(_conllu_path, punctuation=_keep_punct)

HBox(children=(FloatProgress(value=0.0, max=247689.0), HTML(value='')))

## Dataset Class for EWT

In [3]:
import torch 
from torch.utils.data import Dataset
import numpy as np
from os.path import join
from os import listdir
"""
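EWT: a Dataset class for the preprocessed UD English-EWT dependency treebank
path: path to the folder holding the preprocessed en_ewt-ud-* .npy/.txt files
split: which split to load, one of 'train', 'dev' or 'test' (optional, default 'train')
punctuation: load the files built with (True) or without (False) punctuation
tokenizer: optional tokenizer applied to the raw sentences (default None)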
"""

class EWT(Dataset):
    """Dataset over preprocessed English-EWT sentence and matrix files.

    Reads ``en_ewt-ud-<split>_<punct|nopunct>_.txt`` (one sentence per line,
    words separated by single spaces) and the ``.npy`` file of the same stem
    from ``path``. Sentence ``i`` is paired with matrix ``i``.

    Args:
        path: directory containing the two preprocessed files.
        split: split name used inside the file names (default ``'train'``).
        punctuation: select the ``punct`` (True) or ``nopunct`` (False) files.
        tokenizer: optional callable. When given it is called as
            ``tokenizer(lines, add_special_tokens=True, truncation=True)`` and
            the ``'input_ids'`` of the result are stored as ``torch.long``
            tensors, one per sentence.
    """

    def __init__(self, path, split='train', punctuation=True, tokenizer=None):
        super().__init__()
        import warnings  # local import: only needed for the alignment check

        self.sample_ids = []
        self.sample_sentences = []
        self.sample_matrices = []
        self.sample_tokens = []
        self.tokenizer = tokenizer

        variant = "punct" if punctuation else 'nopunct'
        matrices = "en_ewt-ud-{}_{}_.npy".format(split, variant)
        sentences = "en_ewt-ud-{}_{}_.txt".format(split, variant)
        sentence_path = join(path, sentences)
        matrix_path = join(path, matrices)

        # Read and close the sentence file before touching the matrices so the
        # handle cannot leak if loading the matrices fails.
        with open(sentence_path, "r") as sentence_file:
            lines = sentence_file.read().splitlines()
        # NOTE(review): allow_pickle=True unpickles the file contents; only
        # load matrix files that come from a trusted source.
        matrix_file = np.load(matrix_path, allow_pickle=True)

        # if available, tokenize the raw (unsplit) sentences
        if self.tokenizer is not None:
            examples = tokenizer(lines,
                                 add_special_tokens=True,
                                 truncation=True)['input_ids']
            self.sample_tokens.extend(
                torch.tensor(e, dtype=torch.long) for e in examples
            )

        self.sample_sentences.extend(line.split(" ") for line in lines)
        self.sample_matrices.extend(matrix_file)

        # Items are paired purely by index, so differing counts mean that
        # sentences and matrices no longer correspond to each other.
        if len(self.sample_sentences) != len(self.sample_matrices):
            warnings.warn(
                "EWT: {} sentences but {} matrices loaded from '{}' and '{}'; "
                "samples may be misaligned".format(
                    len(self.sample_sentences), len(self.sample_matrices),
                    sentence_path, matrix_path))

    def __len__(self):
        """Return the number of loaded sentences."""
        return len(self.sample_sentences)

    def __getitem__(self, id):
        """Return ``(words, transposed matrix)`` plus the token tensor if tokenized.

        ``words`` is the list of space-separated words of sentence ``id`` and
        the matrix is the stored matrix for the same index, transposed.
        """
        words = self.sample_sentences[id]
        matrix = self.sample_matrices[id].transpose()
        if self.tokenizer is None:
            return words, matrix
        return words, matrix, self.sample_tokens[id]
    


In [6]:
# Load the punctuation-free training split from the "ewt_proc" folder.
# NOTE(review): create_matrices names its outputs from the file's basename
# only, so they are not written into "ewt_proc"; the en_ewt-ud-train_nopunct_
# files are assumed to have been placed there before this cell runs.
ds = EWT("ewt_proc", "train", punctuation=False)

In [None]:
# Sanity check: display the last sample (word list and transposed matrix).
ds[-1]