In [None]:
from torch_geometric.data import Data
import json
from collections import Counter
import torch
from tqdm import tqdm

In [None]:
def read_json_to_list(pth):
    """
    Reads a "|"-delimited text file and parses the JSON held in the second
    field of every line.

    Each non-blank line is split on "|" and the piece at index 1 is passed to
    ``json.loads``. Blank / whitespace-only lines are skipped.

    Args:
        pth: path of the file to read.

    Returns:
        list with one parsed JSON value per non-blank line, in file order.

    Raises:
        IndexError: if a non-blank line contains no "|".
        json.JSONDecodeError: if the second field is not valid JSON.
    """
    res = []
    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(pth, "r", encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # e.g. an empty trailing line; it carries no record
                continue
            res.append(json.loads(line.split("|")[1]))
    return res

In [None]:
# Load the JSON parsed from the second "|" field of every line of the metadata file.
# NOTE(review): absolute, machine-specific path -- this cell only runs where that file exists.
data = read_json_to_list("/usr0/home/amadaan/data/audio/LJSpeech-1.1/metadata_hrg_3.csv")

In [None]:
def get_tokens_from_hrg(hrgs):
    """
    Flattens a collection of HRGs into one list of string tokens.

    For every word entry the output holds, in order:
        * the value under "word"
        * for each daughter: the "syll" value of each of its items, followed
          by the concatenation of those values

    Args:
        hrgs: iterable of HRGs; each HRG is an iterable of dicts that have a
            "word" entry and a "daughters" entry.

    Returns:
        list of tokens, repeated occurrences included.
    """
    def _word_tokens(entry):
        collected = [entry["word"]]
        for daughter in entry["daughters"]:
            pieces = [item["syll"] for item in daughter]
            collected.extend(pieces)
            # one extra token for the daughter as a whole
            collected.append("".join(pieces))
        return collected

    return [
        token
        for hrg in hrgs
        for entry in hrg
        for token in _word_tokens(entry)
    ]

In [None]:
def init_vocab(hrgs, min_count=2, n_positions=20):
    """
    Builds the token <-> id vocabulary for a collection of HRGs.

    The vocabulary holds, in this order:
        1. every token returned by ``get_tokens_from_hrg(hrgs)`` that occurs
           at least ``min_count`` times
        2. the position tokens "0" ... str(n_positions - 1)
        3. the special tokens "<W>", "<SYLL>" and "<UNK>"

    Args:
        hrgs: collection of HRGs accepted by ``get_tokens_from_hrg``.
        min_count: minimum number of occurrences for a corpus token to be
            kept (default 2, i.e. tokens seen only once are dropped).
        n_positions: how many position tokens to add (default 20).

    Returns:
        tok2id: dict mapping token -> id
        id2tok: dict mapping id -> token
        n_vocab: number of vocabulary entries; ids are exactly 0 ... n_vocab - 1
    """
    counts = Counter(get_tokens_from_hrg(hrgs))
    tokens = [tok for tok, count in counts.items() if count >= min_count]
    tokens.extend(str(i) for i in range(n_positions))  # position
    tokens.extend(["<W>", "<SYLL>", "<UNK>"])
    # A corpus token can coincide with a position or special token (e.g. "1").
    # Drop repeats, keeping the first occurrence, before numbering; otherwise
    # the ids would have gaps and the largest id would be >= n_vocab.
    tokens = list(dict.fromkeys(tokens))
    tok2id = {tok: i for i, tok in enumerate(tokens)}
    id2tok = {i: tok for tok, i in tok2id.items()}
    n_vocab = len(tok2id)
    return tok2id, id2tok, n_vocab


def get_tok2id(tok):
    """
    Returns the vocab id of ``tok`` from the module-level ``tok2id`` mapping.

    A token that is not in the mapping resolves to the id of "<UNK>".
    """
    key = tok if tok in tok2id else "<UNK>"
    return tok2id[key]


In [None]:
# Build the vocabulary from all loaded records. `tok2id` has to be a
# module-level name because get_tok2id() reads it as a global.
tok2id, id2tok, n_vocab = init_vocab(data)

In [None]:
def hrg_to_graph(hrg):
    """
    Converts the HRG to a torch_geometric ``Data`` graph.

    Every word entry contributes a three-level tree:
        word node -> one syll node per daughter -> one phone node per daughter item

    NOTE: idx -> index, a way to identify each node in the graph
            (it is the node's position in ``x``)
        ids -> id for a token returned by the vocab (``get_tok2id``).
        Idxs are primarily used for specifying the connectivity of the graph

    Args:
        hrg: iterable of dicts with a "word" entry and a "daughters" entry;
            each daughter is a sequence of dicts that have a "syll" entry.

    Returns:
        Data with
            x: (num_nodes,) long tensor holding the vocab id of every node
            edge_index: (2, num_edges) long tensor of parent -> child edges
            syll_nodes: long tensor with the idxs of the phone (leaf) nodes
    """
    x = []  # vocab id per node; a node's idx is its position in this list
    edges = []  # [parent_idx, child_idx] pairs
    syll_node_idxs = []

    for word_rep in hrg:
        # Idxs are taken from len(x) rather than from name-keyed lookups, so
        # two nodes can never end up sharing an idx.
        word_idx = len(x)
        x.append(get_tok2id(word_rep["word"]))

        for daughter in word_rep["daughters"]:
            labels = [syll["syll"] for syll in daughter]

            # make syll node: its token is the concatenation of its children
            parent_idx = len(x)
            x.append(get_tok2id("".join(labels)))
            edges.append([word_idx, parent_idx])

            # now prepare phone nodes
            for label in labels:
                leaf_idx = len(x)
                x.append(get_tok2id(label))
                syll_node_idxs.append(leaf_idx)
                edges.append([parent_idx, leaf_idx])

    # reshape(-1, 2) keeps the result two-dimensional when there are no edges,
    # so edge_index is (2, 0) instead of (0,). Transposing before calling
    # contiguous() makes the final (2, num_edges) tensor contiguous in memory.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

    return Data(x=torch.tensor(x, dtype=torch.long), edge_index=edge_index,
                syll_nodes=torch.tensor(syll_node_idxs, dtype=torch.long))


In [None]:
# Convert the first loaded record to a graph and display the result.
hrg_to_graph(data[0])

In [None]:
# Convert only the first two entries of data[3], giving a smaller example graph.
hrg_to_graph(data[3][:2])

In [None]:
# Build the two-entry graph from data[3] and keep it in `d`; the trailing `d` displays it.
d = hrg_to_graph(data[3][:2]); d

In [None]:
# Connectivity (parent -> child node idxs) of the example graph `d`.
d.edge_index

In [None]:
# Vocab ids stored as the node features of the example graph `d`.
d.x