In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import networkx as nx
from scipy import sparse

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GATConv, GINConv, global_max_pool, GlobalAttention, GatedGraphConv
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import softmax
from torch_geometric.utils.convert import from_scipy_sparse_matrix


from model import GCNNet
from preprocess import build_graph
from train import train

In [21]:
# --- Synthetic-data configuration -------------------------------------------
# M: total number of orbitals
# N: total number of particles (documented for context; not used in this cell)
# F: feature vector length
#
# NOTE(review): assigning `F` here rebinds the name that the import cell bound
# to `torch.nn.functional`. Nothing visible in this notebook calls the
# functional API after this point, but any later `F.relu(...)`-style call would
# fail against this integer. The name is kept because the training cell reads it.
M = 2
F = 3

NUM_SAMPLES = 4  # number of random graphs to generate
BATCH_SIZE = 2   # graphs per mini-batch handed to the DataLoader
DATA_SEED = 0    # fixed seed so a fresh Restart & Run All rebuilds the same data

# Random inputs drawn for every sample and passed to build_graph:
#   A is the potential matrix: M x M
#   U is the Coulomb 4-tensor: M x M x M x M
#   X is the additional orbital feature matrix: M x F
#   Y is the additional pairwise orbital feature matrix: M x M x F
#   E is a scalar drawn uniformly from [0, 1), stored as the graph label `y`

# All draws below go through NumPy's global RNG, so seeding it once makes the
# whole dataset reproducible.
np.random.seed(DATA_SEED)

data_list = []

for i in range(NUM_SAMPLES):
    A = np.random.normal(size=(M, M))
    U = np.random.normal(size=(M, M, M, M))
    X = np.random.normal(size=(M, F))
    Y = np.random.normal(size=(M, M, F))

    E = np.random.random()

    # build_graph (project-local) returns W and Z; below, W is treated as a
    # matrix convertible to scipy COO form and Z as the per-node feature array.
    # Presumably W is a weighted adjacency matrix -- TODO confirm in preprocess.py.
    W, Z = build_graph(A, U, X, Y)
    Z = torch.from_numpy(Z)
    edge_index, edge_attr = from_scipy_sparse_matrix(sparse.coo_matrix(W))
    # NOTE(review): `y` is stored as a plain Python float, not a tensor; its
    # dtype after batching is decided by torch_geometric's collation -- confirm
    # it matches what `train` expects.
    data = Data(x=Z, edge_index=edge_index, edge_attr=edge_attr, y=E)
    data_list.append(data)

loader = DataLoader(data_list, batch_size=BATCH_SIZE)

In [28]:
# Seed torch's global RNG before constructing the model so that any random
# parameter initialisation inside GCNNet (and any torch-side randomness in
# `train`) is repeatable across Restart & Run All.
torch.manual_seed(0)

# `F` is the integer feature-vector length assigned in the data-generation
# cell, not the `torch.nn.functional` alias from the import cell (that alias is
# shadowed by the integer).
# NOTE(review): the second positional argument (10) is defined by
# model.GCNNet -- presumably a hidden/embedding width; TODO confirm.
# `.double()` is presumably nn.Module.double(), casting parameters to float64
# to match the float64 arrays NumPy produces -- confirm GCNNet is an nn.Module.
model = GCNNet(F, 10).double()

# `train` (project-local) runs the optimisation over `loader` and returns the
# recorded loss values, which are printed below.
losses = train(model, loader)
print(losses)

[0.14294144796979588, 0.12728628128438282, 0.04797815326771675, 0.11894373360319394, 0.009504410346480524, 0.11545914688322013, 0.007416488466565711, 0.10699878754387576, 0.013894658364339726, 0.09112305203423568, 0.011109902235926684, 0.07422455294243532, 0.00459504209272669, 0.0629133503769032, 0.0013316045964158634, 0.05917773381060876, 0.001383269855816409, 0.05669847426445927, 0.0014346451362396875, 0.05343794027068823, 0.00041723848653123864, 0.047104826972250415, 0.0005603725679975883, 0.039805082586003834, 0.003676420568259651, 0.033648195216382154, 0.008121894441155129, 0.029365080083012616, 0.009937027822774905, 0.026840146647495712, 0.007840845686176342, 0.025511417041186202, 0.004340372651884093, 0.025200608715962136, 0.0018315569737035887, 0.02464092832261934, 0.0007444225115359909, 0.022791846472097414, 0.00042764429076201584, 0.019645653132787384]
