<a href="https://colab.research.google.com/github/Chrisa142857/geom_tokenizer/blob/master/geom_token_nodelevel_exp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install transformers torch_geometric

Collecting transformers
  Downloading transformers-4.32.1-py3-none-any.whl (7.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.5/7.5 MB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m29.9 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K 

In [None]:
from torch_geometric.data import Data
import torch
import torch.optim as optim
from tqdm import trange, tqdm
import transformers
import random
from datetime import datetime
from itertools import combinations


def geom_tokenizer(node_feat: torch.Tensor, edge_index: torch.Tensor, N: int, dim: int=3):
    """Build geometric tokens for every node of a graph.

    For node ``ni`` the neighbourhood is its ``N`` nearest *other* nodes
    (squared Euclidean distance between feature rows). These are followed by
    its out-neighbours in ``edge_index`` that are not already included,
    nearest first. Every ``(dim - 1)``-combination of neighbourhood positions
    is one token (a "view"). A token's bits are, most significant first:

    * one bit per view node: set if the view node is an out-neighbour of ``ni``;
    * one bit per pair of view nodes ``(a, b)``: set if the edge ``a -> b`` exists.

    The bits are packed into one integer per token with ``bin2dec``.

    Args:
        node_feat: node feature matrix, one row per node (assumed 2-D).
        edge_index: ``2 x E`` tensor of ``(source, target)`` node ids.
        N: number of nearest nodes taken as spatial neighbours.
        dim: geometry order; each token has ``dim - 1`` view nodes.

    Returns:
        geom_tokens: ``(T,)`` packed geometry code of every token, with nodes
            concatenated in id order.
        view_dirs: ``(T, C)`` sum over the view nodes of
            ``node_feat[view] - node_feat[ni]``.
        node_embeds: ``(T, C)`` ``node_feat[ni]`` repeated once per token of ``ni``.
        token_count: ``(num_nodes,)`` float tensor, tokens produced per node.
        dis_sorts: ``(num_nodes, num_nodes)`` node ids ordered by distance to
            each node; the node itself is ranked last.
    """
    num_nodes = len(node_feat)
    # Distances are only used for ranking; integer features would wrap/overflow.
    feat = node_feat if node_feat.is_floating_point() else node_feat.float()
    # Encode directed edges as src * num_nodes + dst so "does a -> b exist" is one isin.
    edge_keys = edge_index[0] * num_nodes + edge_index[1]
    # Positions (within a view) whose pairwise connectivity is encoded.
    nei_pair = list(combinations(range(dim - 1), 2))
    geom_tokens = []
    token_count = []
    dis_sorts = []
    view_embeds = []
    node_embeds = []
    for ni in tqdm(range(num_nodes), desc='Prepare tokens...'):
        # Position j of `distances` is node j, so argsort positions are node ids.
        distances = ((feat - feat[ni]) ** 2).sum(1)
        distances[ni] = float('inf')  # rank the node itself last
        dis_sort = distances.argsort()
        dis_sorts.append(dis_sort)
        ## spatially close nodes are neighborhood (never the node itself)
        nei_nid = dis_sort[dis_sort != ni][:N]
        ## connected nodes are neighborhood
        connected_node = edge_index[1, edge_index[0] == ni]
        # nei_conn[k, j]: spatial neighbour k is out-neighbour j.
        nei_conn = connected_node == nei_nid[:, None]
        connected_node = connected_node[~nei_conn.any(0)]
        connected_node = connected_node[distances[connected_node].argsort()]  # nearest first
        nei_nid = torch.cat([nei_nid, connected_node])
        ## every (dim-1)-combination of neighbour positions is one view;
        ## reshape keeps a (dim-1, 0) shape when there are too few neighbours
        view_id = torch.LongTensor(list(combinations(range(len(nei_nid)), dim - 1))).reshape(-1, dim - 1).T
        view_node = [nei_nid[view_id[di]] for di in range(dim - 1)]
        num_tokens = view_id.shape[1]
        geom_token = nei_nid.new_zeros((dim - 1 + len(nei_pair), num_tokens))
        # Spatial neighbours are flagged if they are also out-neighbours; the
        # appended connected nodes are out-neighbours by construction.
        where_edges = torch.cat([nei_conn.any(1), torch.ones_like(connected_node, dtype=bool)])
        for di in range(dim - 1):
            geom_token[di, where_edges[view_id[di]]] = 1
        ## exact directed-edge test for each pair of view nodes
        for di, (a, b) in enumerate(nei_pair):
            pair_keys = view_node[a] * num_nodes + view_node[b]
            geom_token[dim - 1 + di, torch.isin(pair_keys, edge_keys)] = 1
        geom_tokens.append(bin2dec(geom_token))
        token_count.append(num_tokens)
        ## embed of geom direction from view points to cur node
        view_embeds.append(torch.stack([node_feat[view_node[di]] - node_feat[ni] for di in range(dim - 1)]).sum(0))
        node_embeds.append(node_feat[ni].repeat(num_tokens, 1))

    geom_tokens = torch.cat(geom_tokens)
    dis_sorts = torch.stack(dis_sorts)
    token_count = torch.FloatTensor(token_count)
    view_dirs = torch.cat(view_embeds)
    node_embeds = torch.cat(node_embeds)
    return geom_tokens, view_dirs, node_embeds, token_count, dis_sorts

def geom_tokenizer_onenode(ni: int, node_feat: torch.Tensor, edge_index: torch.Tensor, N: int, dim: int=3):
    """Build the geometric tokens of a single node ``ni``.

    The neighbourhood is the ``N`` nearest *other* nodes (squared Euclidean
    distance between feature rows). It is followed by at most ``N`` of the
    node's out-neighbours that are not already included, nearest first. Each
    ``(dim - 1)``-combination of neighbourhood positions is one token. Its
    bits are packed MSB-first with ``bin2dec``:

    * one bit per view node: set if the view node is an out-neighbour of ``ni``;
    * one bit per pair of view nodes ``(a, b)``: set if the edge ``a -> b`` exists.

    Args:
        ni: node id (int or 0-d integer tensor).
        node_feat: node feature matrix, one row per node (assumed 2-D).
        edge_index: ``2 x E`` tensor of ``(source, target)`` node ids.
        N: spatial neighbours taken, and cap on extra connected neighbours.
        dim: geometry order; each token has ``dim - 1`` view nodes.

    Returns:
        geom_token: ``(M,)`` packed geometry code per token.
        view_dir: ``(M, C)`` sum over view nodes of ``node_feat[view] - node_feat[ni]``.
        node_embed: ``(M, C)`` ``node_feat[ni]`` repeated per token.
        token_count: float tensor ``[M]``.
        dis_sort: ``(num_nodes,)`` node ids by distance to ``ni``; ``ni`` is last.
    """
    num_nodes = len(node_feat)
    # Distances are only used for ranking; integer features would wrap/overflow.
    feat = node_feat if node_feat.is_floating_point() else node_feat.float()
    # Position j is node j, so argsort positions are node ids (works for the last node too).
    distances = ((feat - feat[ni]) ** 2).sum(1)
    distances[ni] = float('inf')  # rank the node itself last
    dis_sort = distances.argsort()
    ## spatially close nodes are neighborhood (never the node itself)
    nei_nid = dis_sort[dis_sort != ni][:N]
    ## connected nodes are neighborhood
    connected_node = edge_index[1, edge_index[0] == ni]
    # nei_conn[k, j]: spatial neighbour k is out-neighbour j.
    nei_conn = connected_node == nei_nid[:, None]
    connected_node = connected_node[~nei_conn.any(0)]
    connected_node = connected_node[distances[connected_node].argsort()[:N]]  # nearest N first
    nei_nid = torch.cat([nei_nid, connected_node])
    ## every (dim-1)-combination of neighbour positions is one view;
    ## reshape keeps a (dim-1, 0) shape when there are too few neighbours
    view_id = torch.LongTensor(list(combinations(range(len(nei_nid)), dim - 1))).reshape(-1, dim - 1).T
    view_node = [nei_nid[view_id[di]] for di in range(dim - 1)]
    num_tokens = view_id.shape[1]
    nei_pair = list(combinations(range(dim - 1), 2))
    geom_token = nei_nid.new_zeros((dim - 1 + len(nei_pair), num_tokens))
    # Spatial neighbours are flagged if they are also out-neighbours; the
    # appended connected nodes are out-neighbours by construction.
    where_edges = torch.cat([nei_conn.any(1), torch.ones_like(connected_node, dtype=bool)])
    for di in range(dim - 1):
        geom_token[di, where_edges[view_id[di]]] = 1
    ## exact directed-edge test for each pair of view nodes
    edge_keys = edge_index[0] * num_nodes + edge_index[1]
    for di, (a, b) in enumerate(nei_pair):
        pair_keys = view_node[a] * num_nodes + view_node[b]
        geom_token[dim - 1 + di, torch.isin(pair_keys, edge_keys)] = 1
    geom_token = bin2dec(geom_token)
    token_count = torch.FloatTensor([num_tokens])
    ## embed of geom direction from view points to cur node
    view_dir = torch.stack([node_feat[view_node[di]] - node_feat[ni] for di in range(dim - 1)]).sum(0)
    node_embed = node_feat[ni].repeat(num_tokens, 1)

    return geom_token, view_dir, node_embed, token_count, dis_sort

def bin2dec(b):
    """Pack each column of a 0/1 matrix into an integer, first row = MSB.

    Args:
        b: ``(bits, batch)`` tensor of 0/1 values.

    Returns:
        ``(batch,)`` tensor of ``b``'s dtype: ``sum_k b[k] * 2**(bits-1-k)``.
        An empty batch (``batch == 0``) yields an empty result.
    """
    bits, _ = b.shape
    weights = 2 ** torch.arange(bits - 1, -1, -1).to(b.device, b.dtype)
    # Broadcast the per-row weight over the batch axis (no stack of `batch` copies).
    return torch.sum(weights[:, None] * b, 0)

def token_zeropad(tokens, token_count, seq_len, datay, *args):
    '''
    Split a flat token stream into per-node sequences of length ``seq_len``.

    tokens: M tokens of all nodes, concatenated in order (1-D, or 2-D with
        a feature axis).
    token_count: number of tokens belonging to each node.
    seq_len: output length; longer segments are truncated, shorter ones are
        right-padded with zeros.
    datay: labels, ``datay[i]`` belongs to segment ``i``.
    *args: accepted and ignored (lets callers share one padder signature).

    Returns (batches, masks, labels): stacked sequences, boolean masks that
    are False on padded positions, and a LongTensor of labels.
    '''
    ends = torch.cumsum(token_count, 0).long()
    seqs, pad_masks, ys = [], [], []
    start = 0
    for i, end in enumerate(ends):
        piece = tokens[start:end.item()][:seq_len]
        start = end
        n_real = len(piece)
        pad_mask = torch.ones(seq_len, dtype=bool, device=piece.device)
        shortfall = seq_len - n_real
        if shortfall > 0:
            # Pad only along the token axis; a 2-D stream keeps its width.
            pad_shape = (shortfall, piece.shape[1]) if piece.dim() > 1 else (shortfall,)
            piece = torch.cat([piece, piece.new_zeros(pad_shape)])
            pad_mask[n_real:] = False
        ys.append(datay[i])
        seqs.append(piece)
        pad_masks.append(pad_mask)
    return torch.stack(seqs), torch.stack(pad_masks), torch.LongTensor(ys)

def token_neighborpad(tokens, token_count, seq_len, datay, dis_sorts):
    '''
    Split a flat token stream into per-node sequences of length ``seq_len``,
    filling short sequences with the tokens of the node's nearest neighbours.

    tokens: tokens of all nodes concatenated in node-id order (1-D, or with
        trailing feature dims).
    token_count: number of tokens of each node (float or int).
    seq_len: output sequence length.
    datay: labels, ``datay[i]`` belongs to node ``i``.
    dis_sorts: row ``i`` lists node ids ordered by distance to node ``i``.

    Neighbours are appended nearest first, skipping the node itself, until at
    least ``seq_len`` tokens are collected. The sequence is then cut to
    ``seq_len``. If all neighbours together are still too few, the remainder is
    zero-padded and marked False in the mask.

    Returns (batches, masks, labels): stacked sequences, boolean masks (True on
    real tokens) and a LongTensor of labels.
    '''
    batches = []
    labels = []
    masks = []  # True on real tokens, False on zero padding
    token_count = token_count.long()
    token_cumsum = torch.cumsum(token_count, 0)
    for i in range(len(token_count)):
        start = 0 if i == 0 else token_cumsum[i - 1]
        seq = tokens[start:token_cumsum[i]]
        if len(seq) < seq_len:
            need = seq_len - len(seq)
            order = dis_sorts[i]
            order = order[order != i]  # never re-append the node's own tokens
            pcumsum = torch.cumsum(token_count[order], 0)
            # pcumsum is non-decreasing: the first neighbour that reaches `need`
            # is at index (#entries < need); include it so the sequence fills.
            take = int((pcumsum < need).sum()) + 1
            pieces = [seq]
            for ni in order[:take]:
                nstart = token_cumsum[ni - 1] if ni != 0 else 0
                pieces.append(tokens[nstart:token_cumsum[ni]])
            seq = torch.cat(pieces)
        # Record the real length before zero padding makes len(seq) == seq_len.
        valid = min(len(seq), seq_len)
        mask = torch.ones(seq_len, dtype=bool, device=seq.device)
        mask[valid:] = False
        if len(seq) < seq_len:
            pad = seq.new_zeros((seq_len - len(seq),) + tuple(seq.shape[1:]))
            seq = torch.cat([seq, pad])
        labels.append(datay[i])
        batches.append(seq[:seq_len])
        masks.append(mask)
    batches = torch.stack(batches)
    masks = torch.stack(masks)
    labels = torch.LongTensor(labels)
    return batches, masks, labels

class ToyModel(torch.nn.Module):
    def __init__(self, node_num, node_channel, geom_dim, cls_num, nhead=8) -> None:
        '''
        Toy transformer for node classification.

        node_num, nhead: not used (kept for signature compatibility).
        node_channel: width of the per-token direction features (and node features).
        geom_dim: geometry order; the geometry-token table has 2**geom_dim rows.
            NOTE(review): the tokenizers emit dim-1 + C(dim-1, 2) bits, which
            equals geom_dim only for geom_dim=3 -- confirm for other orders.
        cls_num: number of output classes.
        '''
        super().__init__()
        # Randomly initialised BERT backbone (built from a default config, no pretrained weights).
        self.encoder = transformers.BertModel(transformers.BertConfig())
        width = self.encoder.config.hidden_size
        n_streams = 1
        # Module creation order matters for seeded initialisation: keep it fixed.
        self.token_embeds = torch.nn.ModuleList([
            torch.nn.Embedding(2 ** geom_dim, width // n_streams),  # geometry token ids
            torch.nn.Linear(node_channel, width // n_streams),      # view direction vectors
        ])
        # Node-feature projection; currently not used in forward().
        self.node_embed = torch.nn.Linear(node_channel, width)
        self.classifier = torch.nn.Linear(width, cls_num)

    def forward(self, inputs, masks=None):
        # inputs = [node_embeds, geom_tokens, view_dirs]; inputs[0] is not used here.
        # Each remaining stream goes through its own embedding and the results are summed.
        per_stream = [embed(stream) for embed, stream in zip(self.token_embeds, inputs[1:])]
        summed = torch.stack(per_stream, 0).sum(0)
        hidden = self.encoder(inputs_embeds=summed, attention_mask=masks)
        # hidden[1] is BERT's pooled output.
        return self.classifier(hidden[1])

def toy_trainval(batches_list, data_idx, train=True, use_mask=True):
    """Run one epoch over pre-built, pre-padded tensors restricted to ``data_idx``.

    Reads module-level globals: ``model``, ``optimizer``, ``loss_fn``,
    ``device``, ``batch_size``, ``masks`` and ``labels``.
    Batches are visited in a random order (``random.shuffle``).

    Returns (mean per-batch loss, accuracy over ``data_idx``).
    """
    if train:
        model.train()
    else:
        model.eval()
    starts = list(range(0, len(data_idx), batch_size))
    random.shuffle(starts)
    seen, batch_preds, batch_losses = [], [], []
    for start in starts:
        idx = data_idx[start:start + batch_size]
        seen.append(idx)
        inputs = [tensor[idx].to(device) for tensor in batches_list]
        pad_mask = masks[idx].to(device)
        target = labels[idx].to(device)
        if not train:
            with torch.no_grad():
                out = model(inputs, pad_mask) if use_mask else model(inputs)
            loss = loss_fn(out, target)
        else:
            out = model(inputs, pad_mask) if use_mask else model(inputs)
            optimizer.zero_grad()
            loss = loss_fn(out, target)
            loss.backward()
            optimizer.step()
        batch_preds.append(out.max(1)[1].detach().cpu())
        batch_losses.append(loss.detach().cpu())
    # Predictions were gathered in shuffled order; compare against labels in that same order.
    visit_order = torch.cat(seen)
    correct = torch.cat(batch_preds).eq(labels[visit_order]).sum().item()
    return torch.stack(batch_losses).mean().item(), correct / len(data_idx)


In [None]:
from torch.utils.data import Dataset, DataLoader

def binary(x, bits):
    """Unpack integers into their ``bits``-wide binary digits, MSB first.

    Returns a uint8 tensor of shape ``x.shape + (bits,)`` holding 0/1.
    """
    weights = (2 ** torch.arange(bits - 1, -1, -1)).to(x.device, x.dtype)
    digits = torch.bitwise_and(x[..., None], weights)
    return (digits != 0).to(torch.uint8)

class DataBatchSet(Dataset):
    """Per-node token sequences, tokenized lazily in ``__getitem__``.

    The mode is chosen by the type of ``node_feat``:
    * a list of per-graph feature tensors -> graph level: one item per
      (graph id, node id) pair, labelled with ``label[graph id]``;
    * a single feature tensor -> node level: one item per selected node,
      chosen by ``node_idx``, else by boolean ``mask``, else all nodes.
    """

    def __init__(self, node_feat, edge_index, label, node_idx=None, mask=None, node_feat2bin=False, N=10, geom_dim=3, seq_len=512) -> None:
        self.node_feat = node_feat
        self.edge_index = edge_index
        self.label = label
        assert len(node_feat) == len(label)
        self.graph_level = isinstance(node_feat, list)
        if self.graph_level:
            # Every (graph id, node id) pair over all graphs.
            self.node_idx = torch.LongTensor(
                [[gi, ni] for gi, feats in enumerate(node_feat) for ni in range(len(feats))]
            )
        elif node_idx is not None:
            self.node_idx = node_idx
        elif mask is not None:
            self.node_idx = torch.where(mask)[0]
        else:
            self.node_idx = torch.arange(len(node_feat))
        self.N = N
        self.dim = geom_dim
        self.seq_len = seq_len
        self.node_feat2bin = node_feat2bin
        if node_feat2bin:
            # Bit width of the largest feature value (strip the '0b' prefix).
            self.node_feat_ch = len(bin(max(f.max() for f in node_feat))) - 2

    def __getitem__(self, i):
        if not self.graph_level:
            ni = self.node_idx[i]
            feats = self.node_feat
            edges = self.edge_index
            datay = self.label[ni:ni+1]
            item_id = ni
        else:
            gi, ni = self.node_idx[i]
            feats = self.node_feat[gi]
            if self.node_feat2bin:
                feats = binary(feats, self.node_feat_ch)
            edges = self.edge_index[gi]
            datay = self.label[gi]
            item_id = gi

        geom_tokens, view_dirs, node_embeds, token_count, distance_sorts = geom_tokenizer_onenode(ni, feats, edges, self.N, self.dim)

        def pad(stream):
            # All three streams share one token_count, so they are padded identically.
            return token_zeropad(stream, token_count, self.seq_len, datay, distance_sorts)

        geom_tokens, masks, labels = pad(geom_tokens)
        view_dirs = pad(view_dirs)[0]
        node_embeds = pad(node_embeds)[0]
        return node_embeds[0], geom_tokens[0], view_dirs[0], masks[0], labels[0], item_id

    def __len__(self):
        return len(self.node_idx)

In [None]:
from torch_geometric.datasets import WebKB, WikipediaNetwork, Actor, ZINC, AQSOL, WikiCS, GNNBenchmarkDataset, Planetoid
import torch
import numpy as np

def get_data_pyg(name, split=0):
    """Load a PyG node-classification graph with one of its fixed splits.

    Args:
        name: 'pubmed', 'cora' or 'citeseer' (Planetoid, geom-gcn splits);
            'chameleon' or 'squirrel' (WikipediaNetwork); 'cornell', 'texas'
            or 'wisconsin' (WebKB); or 'film' (Actor). Data lives under
            ``'../data/' + name``.
        split: index of the train/val/test split to use.

    Returns:
        The graph's ``Data`` object with 1-D boolean ``train_mask``,
        ``val_mask`` and ``test_mask`` for the chosen split.

    Raises:
        ValueError: for any other name (including 'zinc', which has no
            split file handling here).
    """
    path = '../data/' + name
    if name in ('pubmed', 'cora', 'citeseer'):
        # Work on the Data object itself (not the Dataset wrapper), so the
        # masks assigned below sit on the graph callers read x/edge_index/y from.
        data = Planetoid(root=path, name=name, split='geom-gcn')[0]
        # geom-gcn masks hold one column per split.
        data.train_mask = data.train_mask[:, split]
        data.val_mask = data.val_mask[:, split]
        data.test_mask = data.test_mask[:, split]
        return data

    if name in ('chameleon', 'squirrel'):
        dataset = WikipediaNetwork(root=path, name=name)
        split_file = f'{path}/{name}/geom_gcn/raw/{name}_split_0.6_0.2_{split}.npz'
    elif name in ('cornell', 'texas', 'wisconsin'):
        dataset = WebKB(path, name=name)
        split_file = f'{path}/{name}/raw/{name}_split_0.6_0.2_{split}.npz'
    elif name == 'film':
        dataset = Actor(root=path)
        split_file = f'{path}/raw/{name}_split_0.6_0.2_{split}.npz'
    else:
        raise ValueError(f'Unsupported dataset name: {name!r}')

    data = dataset[0]
    # Context manager closes the .npz archive once the masks are read.
    with np.load(split_file) as splits:
        data.train_mask = torch.tensor(splits['train_mask'], dtype=torch.bool)
        data.val_mask = torch.tensor(splits['val_mask'], dtype=torch.bool)
        data.test_mask = torch.tensor(splits['test_mask'], dtype=torch.bool)
    return data


In [None]:

def trainval(loader, train=True, use_mask=True):
    """Run one epoch of the global ``model`` over ``loader``, training or evaluating.

    Reads module-level globals: ``model``, ``optimizer`` and ``loss_fn`` for
    the step, ``device`` for placement, and ``e`` / ``epoch`` for the progress
    bar label. Each batch is expected as
    ``(node_embeds, geom_tokens, view_dirs, pad_mask, label, id)``.

    Args:
        loader: iterable of batches as above.
        train: take an optimizer step per batch if True; otherwise evaluate
            under ``torch.no_grad()``.
        use_mask: pass the padding mask to the model as its second argument.

    Returns:
        (mean per-batch loss, accuracy over all samples), or ``(nan, nan)``
        when ``loader`` yields no batches.
    """
    if train:
        model.train()
    else:
        model.eval()
    losses = []
    preds = []
    labels = []
    for data in tqdm(loader, desc=f'Epoch [{e+1}\t/{epoch}]'):
        pad_mask, label, _ = data[3:]
        batch = [inp.to(device) for inp in data[:3]]
        mask = pad_mask.to(device)
        label = label.to(device)
        if train:
            out = model(batch, mask) if use_mask else model(batch)
            optimizer.zero_grad()
            loss = loss_fn(out, label)
            loss.backward()
            optimizer.step()
        else:
            with torch.no_grad():
                out = model(batch, mask) if use_mask else model(batch)
                loss = loss_fn(out, label)
        preds.append(out.argmax(1).detach().cpu())
        labels.append(label.detach().cpu())
        losses.append(loss.detach().cpu())
    if not losses:
        return float('nan'), float('nan')
    preds = torch.cat(preds)
    labels = torch.cat(labels)
    acc = preds.eq(labels).sum().item() / len(labels)
    return torch.stack(losses).mean().item(), acc

# Train/evaluate ToyModel on the 10 geom-gcn splits of PubMed.
# model, optimizer, loss_fn, device, e and epoch stay module-level on purpose:
# trainval() reads them as globals.
torch.manual_seed(142857)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # fall back to CPU without a GPU
seq_len = 512
batch_size = 16
epoch = 100
lr = 1e-6
use_mask = True
geom_dim = 3  # shared by the tokenizer (DataBatchSet) and the token embedding (ToyModel)

for i in range(10):
    data = get_data_pyg('pubmed', split=i)
    train_set = DataBatchSet(data.x, data.edge_index, data.y, mask=data.train_mask, geom_dim=geom_dim, seq_len=seq_len)
    trainloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_set = DataBatchSet(data.x, data.edge_index, data.y, mask=data.val_mask, geom_dim=geom_dim, seq_len=seq_len)
    valloader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    test_set = DataBatchSet(data.x, data.edge_index, data.y, mask=data.test_mask, geom_dim=geom_dim, seq_len=seq_len)
    testloader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    loss_fn = torch.nn.CrossEntropyLoss()
    # ToyModel does not use its first argument (node_num), hence 0.
    model = ToyModel(0, data.x.shape[1], geom_dim, data.y.max().item()+1).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)  # 1e-4,weight_decay=0.01
    print(f'[start train val on split {i}]')
    for e in range(epoch):
        # The same masking flag is used for train, val and test.
        train_loss, train_acc = trainval(trainloader, train=True, use_mask=use_mask)
        val_loss, val_acc = trainval(valloader, train=False, use_mask=use_mask)
        test_loss, test_acc = trainval(testloader, train=False, use_mask=use_mask)
        log = f'Epoch [{e+1}\t/{epoch}] Train Loss: {train_loss:.03f} \t Train Acc: {train_acc:.06f} \t Val Loss: {val_loss:.03f} \t Val Acc: {val_acc:.06f} \t Test Loss: {test_loss:.03f} \t Test Acc: {test_acc:.06f}'
        print(datetime.now(), log)


Epoch [1	/100]:  53%|█████▎    | 314/592 [12:48<11:11,  2.41s/it]

In [None]:
# torch.manual_seed(142857)
# device = 'cuda:0'
# seq_len = 512
# batch_size = 16
# epoch = 100
# lr = 1e-6
# token_padder = token_zeropad
# # token_padder = token_neighborpad
# use_mask = True
# data = get_data_pyg('pubmed', split=0)
# geom_tokens, view_dirs, node_embeds, token_count, distance_sorts = geom_tokenizer(data.x, data.edge_index, 10, dim=3)
# # geom_tokens_d4, view_dirs, node_embeds, token_count, distance_sorts = geom_tokenizer(data.x, data.edge_index, 10, dim=4)
# geom_batches, masks, labels = token_padder(geom_tokens, token_count, seq_len, data.y, distance_sorts)
# view_batches, _, _ = token_padder(view_dirs, token_count, seq_len, data.y, distance_sorts)
# x_batches, _, _ = token_padder(node_embeds, token_count, seq_len, data.y, distance_sorts)
# for i in range(10):
#     data = get_data_pyg('pubmed', split=i)
#     train_idx = torch.where(data.train_mask)[0]
#     val_idx = torch.where(data.val_mask)[0]
#     test_idx = torch.where(data.test_mask)[0]
#     loss_fn = torch.nn.CrossEntropyLoss()
#     model = ToyModel(len(data.x), data.x.shape[1], 3, data.y.max().item()+1).to(device)
#     optimizer = optim.Adam(model.parameters(),lr=lr) # 1e-4,weight_decay=0.01
#     # optimizer = optim.SGD(model.parameters(),lr=lr) # 1e-3
#     inputs = [x_batches, geom_batches, view_batches]
#     # inputs = [data.x.unsqueeze(1)]
#     print([i.shape for i in inputs])
#     for e in range(epoch):
#         train_loss, train_acc = toy_trainval(inputs, train_idx, train=True, use_mask=use_mask)
#         val_loss, val_acc = toy_trainval(inputs, val_idx, train=False, use_mask=use_mask)
#         test_loss, test_acc = toy_trainval(inputs, test_idx, train=False, use_mask=use_mask)
#         log = f'Epoch [{e+1}\t/{epoch}] Train Loss: {train_loss:.03f} \t Train Acc: {train_acc:.06f} \t Val Loss: {val_loss:.03f} \t Val Acc: {val_acc:.06f} \t Test Loss: {test_loss:.03f} \t Test Acc: {test_acc:.06f}'
#         print(datetime.now(), log)