#### Reference: https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#learning-methods-on-graphs

In [None]:
from torch_geometric.data import Data

In [None]:
import json
from collections import Counter
import torch
from tqdm import tqdm

In [None]:
def read_json_to_list(pth):
    """Read a JSON Lines file into a list of parsed objects.

    Args:
        pth: Path to a file containing one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    res = []
    # Explicit UTF-8 (the JSON standard encoding) so parsing does not depend
    # on the platform's locale encoding.
    with open(pth, "r", encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            # Skip blank lines (e.g. an extra empty line at the end of the
            # file), which json.loads would otherwise reject.
            if line:
                res.append(json.loads(line))
    return res

In [None]:
# Load every HRG record from the JSONL dump (one JSON object per line).
# NOTE(review): absolute, user-specific path — consider making it configurable.
hrgs = read_json_to_list("/usr0/home/amadaan/data/audio/LJSpeech-1.1/TTS/hrg.jsonl")

In [None]:
# Sanity check: number of HRG records loaded (shown as the cell output).
len(hrgs)

## Build the vocabulary and initialize random embeddings

In [None]:
def get_tokens(hrgs):
    """Flatten every HRG into a single list of word and syllable tokens.

    Words are visited in document order; each contributes its own token
    followed by its syllable tokens (see ``get_tokens_from_word_rep``).
    """
    all_tokens = []
    for entry in (w for record in hrgs for w in record["hrg"]):
        all_tokens += get_tokens_from_word_rep(entry)
    return all_tokens

def get_tokens_from_word_rep(word_rep):
    """Return ``[word, syll_1, syll_2, ...]`` for a single word entry."""
    return [word_rep["word"], *(child["syll"] for child in word_rep["daughters"])]

In [None]:
def make_vocab(hrgs, min_count=2, num_positions=20):
    """Build a token <-> id vocabulary from the words and syllables in ``hrgs``.

    The vocabulary contains, in this order: every word/syllable token that
    occurs at least ``min_count`` times, the position tokens
    ``"0" .. str(num_positions - 1)``, and the specials ``"<W>"``,
    ``"<SYLL>"`` and ``"<UNK>"``. Each token appears exactly once, so ids
    are contiguous ``0 .. len(tok2id) - 1`` and can index a table with
    ``len(tok2id)`` rows.

    Args:
        hrgs: HRG records, each with an ``"hrg"`` list of word entries.
        min_count: minimum corpus frequency for a token to get its own id
            (default 2, i.e. tokens seen more than once).
        num_positions: number of position tokens to add (default 20).

    Returns:
        ``(tok2id, id2tok)`` dictionaries that are exact inverses.
    """
    counts = Counter(get_tokens(hrgs))
    tokens = [tok for tok, count in counts.items() if count >= min_count]
    tokens.extend(str(i) for i in range(num_positions))  # position tokens
    tokens.extend(["<W>", "<SYLL>", "<UNK>"])
    # A corpus token can coincide with a position/special token (e.g. "1").
    # Without de-duplication the later copy overwrites the earlier one in
    # tok2id, leaving gaps in the id range, so the largest id (that of
    # "<UNK>") can exceed len(tok2id) - 1. De-duplicate, keeping first
    # occurrences in order.
    tokens = list(dict.fromkeys(tokens))
    tok2id = {tok: i for i, tok in enumerate(tokens)}
    id2tok = {i: tok for tok, i in tok2id.items()}

    return tok2id, id2tok
    

In [None]:
# Build the vocabulary once; get_tok2id and the embedding table below read
# these module-level globals.
tok2id, id2tok = make_vocab(hrgs)

In [None]:
def get_tok2id(tok):
    """Return the vocabulary id of ``tok``, or the ``"<UNK>"`` id if unseen.

    Reads the module-level ``tok2id`` mapping.
    """
    try:
        return tok2id[tok]
    except KeyError:
        return tok2id["<UNK>"]

In [None]:
# Width of each token embedding; also the node-feature size fed to the GCN.
n_embed = 64

In [None]:
# Random (untrained) embedding table: one row of size n_embed per vocab id,
# values uniform in [0, 1). A dedicated fixed-seed generator makes the node
# features reproducible across runs without touching the global torch RNG.
embeddings = torch.rand(len(tok2id), n_embed, generator=torch.Generator().manual_seed(0))

### Convert HRGs to PyTorch Geometric `Data` objects

In [None]:
def hrg_to_graph(hrg):
    """Convert one HRG record into edge / node-feature tensors.

    Every word in ``hrg["hrg"]`` becomes a node. Every syllable in that
    word's ``"daughters"`` also becomes a node, with a directed edge from
    the word to the syllable. Node features are rows of the global
    ``embeddings`` table, looked up via ``get_tok2id``. Nodes are numbered
    in visit order: word, its syllables, next word, ...

    Returns:
        Edge index: LongTensor of shape (num_edges, 2); each row is
            ``[word_node, syllable_node]``. The shape is (0, 2) when the
            record has no syllables.
        Node features: FloatTensor of shape (num_nodes, feature_dim).

    Raises:
        RuntimeError: from ``torch.stack`` if ``hrg["hrg"]`` is empty
            (there are no nodes to stack).
    """
    edges = []
    node_features = []
    for word_rep in hrg["hrg"]:
        # A node's id is its position in node_features. Ids and feature rows
        # therefore always stay aligned; string keys like "<token>-<i>" could
        # collide for tokens that contain "-".
        word_node = len(node_features)
        node_features.append(embeddings[get_tok2id(word_rep["word"]), :])
        for syll in word_rep["daughters"]:
            syll_node = len(node_features)
            node_features.append(embeddings[get_tok2id(syll["syll"]), :])
            edges.append([word_node, syll_node])
    # view(-1, 2) keeps the (num_edges, 2) shape even with zero edges;
    # torch.tensor([]) alone would be 1-D with shape (0,).
    edge_index = torch.tensor(edges, dtype=torch.long).view(-1, 2)
    return edge_index, torch.stack(node_features).float()

In [None]:
# Build one PyTorch Geometric Data object per HRG record.
# y is set to the graph's node count (a plain int).
py_geom_graphs = []
for hrg in tqdm(hrgs, total=len(hrgs)):
    # Convert the current record; converting hrgs[0] here would make every
    # graph a copy of the first HRG.
    edge_index, node_features = hrg_to_graph(hrg)
    # hrg_to_graph returns edges as rows of [src, dst]; transpose to the
    # (2, num_edges) layout PyG's edge_index uses.
    data = Data(x=node_features, edge_index=edge_index.t().contiguous(), y=node_features.shape[0])
    py_geom_graphs.append(data)
    

In [None]:
from torch_geometric.data import DataLoader

In [None]:
# Iterate over the graphs in shuffled mini-batches of 32.
# NOTE(review): newer PyG releases moved DataLoader to torch_geometric.loader
# and deprecate the torch_geometric.data import — confirm the installed version.
loader = DataLoader(py_geom_graphs, batch_size=32, shuffle=True)


### Sample GCN

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

In [None]:
class Net(torch.nn.Module):
    """Two-layer GCN producing per-node log-probabilities.

    Layers: GCNConv(in -> hidden) -> ReLU -> dropout -> GCNConv(hidden -> out),
    followed by log_softmax over the output channels.

    Args:
        in_channels: input feature size. Defaults to the notebook-level
            ``n_embed``, read when the model is constructed.
        hidden_channels: width of the hidden layer (default 16).
        out_channels: number of output channels (default 200).
        dropout: dropout probability between the layers (default 0.5,
            matching ``F.dropout``'s default).
    """

    def __init__(self, in_channels=None, hidden_channels=16, out_channels=200, dropout=0.5):
        super().__init__()
        if in_channels is None:
            in_channels = n_embed
        self.dropout = dropout
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Return log-probabilities over the output channels for each node of ``data``."""
        x, edge_index = data.x, data.edge_index

        x = F.relu(self.conv1(x, edge_index))
        # Active only in training mode (self.training is toggled by .train()/.eval()).
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

In [None]:
# Standalone first GCN layer (n_embed -> 16 channels) for inspecting shapes
# on one batch. Use the GPU when one is available; an unconditional .cuda()
# raises on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
conv1 = GCNConv(n_embed, 16).to(device)

In [None]:
# Take one mini-batch from the loader (`batch` is not defined in any visible
# cell) and move it to the device conv1's parameters live on. This avoids a
# CPU/GPU device mismatch in the forward pass below.
batch = next(iter(loader)).to(next(conv1.parameters()).device)
x, edge_index = batch.x, batch.edge_index

In [None]:
# Inspect the batched input node features.
x

In [None]:
# Presumably (total nodes across the batch, n_embed), since node features
# are rows of the n_embed-wide embedding table — confirm when run.
x.shape

In [None]:
# Apply the standalone first GCN layer (built with 16 output channels).
x = conv1(x, edge_index)

In [None]:
# Inspect the layer output.
x

In [None]:
# Expected to be (total nodes across the batch, 16) — conv1 maps to 16 channels.
x.shape