In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import networkx as nx
from scipy import sparse

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GATConv, GINConv, global_max_pool, GlobalAttention, GatedGraphConv
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import softmax
from torch_geometric.utils.convert import from_scipy_sparse_matrix

from pyscf import gto, scf, tools, ao2mo


import model
import train
from model import SecondNet, SimpleNet
from preprocess import build_graph, build_qm7
from train import train, test
from hf import get_data

In [None]:
mols = build_qm7('sto-3g')
# Omit the first molecule (outlier geometry) and keep all the others.
# For a quick single-molecule smoke test, slice further (e.g. mols[100:101]),
# but note that with one graph the 80/20 split (int(0.8 * 1) == 0) leaves the
# training set empty.
mols = mols[1:]

In [None]:
#TODO: Encode number of electrons explicitly
#TODO: Encode HF features?
#TODO: Encode the "flavor" of the orbital basis as features as well

#TODO: indicate which orbital is first and second in the pair vertices?  This breaks the symmetry,
#but we might want this anyway if we want to particularly understand one of the orbitals in the pair
#TODO: indicators should be separate features not integer values, stop being lazy
#TODO: ACTUALLY USE GCN MODEL
#TODO: Fix edge features between single and double, currently those are all zero and graph is disconnected!!!

In [None]:
# Notation used in this cell
#   M : number of orbitals
#   N : number of electrons
#   F : feature vector length
#   A : potential matrix, M x M
#   U : Coulomb 4-tensor, M x M x M x M
#   X : additional orbital feature matrix, M x F_1
#   Y : additional pairwise orbital feature matrix, M x M x F_2
#   E : ground state energy


def _molecule_to_graph(mol):
    """Build one `Data` graph for `mol`, with the energy E stored as the target `y`."""
    # The orbital features returned by get_data are not used: they are
    # replaced by the all-zero placeholder built below.
    A, U, _, Y, E = get_data(mol, "AO", predict_correlation=False)

    # Disabled rescaling hack, kept for reference:
    #     E /= 10.
    #     np.fill_diagonal(Y[:,:,0], np.diagonal(Y[:,:,0]) / 10.)
    #     np.fill_diagonal(Y[:,:,1], np.diagonal(Y[:,:,1]) / 10.)

    # Currently no orbital features: a single zero column, one row per orbital.
    X = np.zeros((A.shape[0], 1))

    x, edge_index, edge_attr = build_graph(A, U, X, Y)

    # Disabled energy sanity check, kept for reference:
    #     print("True energy:\t\t {}".format(E))
    #     print("Energy via Trace:\t {}".format(np.sum(Y[:,:,0] * Y[:,:,1])))
    #     print("Energy via features:\t {}"
    #           .format(torch.sum(x[:,1] * x[:,2] * (2 - x[:,5])).item()))
    #     print()

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=E)


# One graph per molecule, in the same order as `mols`.
dataset = [_molecule_to_graph(mol) for mol in mols]

In [None]:
import random
import warnings

# Fixed seed so the train/test partition is the same on every run.
SPLIT_SEED = 0

# Shuffle a copy with a private RNG. Shuffling `dataset` in place would
# reshuffle already-shuffled data whenever this cell is re-run, silently
# moving graphs between the train and test splits.
shuffled = list(dataset)
random.Random(SPLIT_SEED).shuffle(shuffled)

# First 80% of the shuffled graphs train the model, the remainder evaluates it.
split = int(0.8 * len(shuffled))
if split == 0:
    # int(0.8 * n) is 0 for n <= 1, which would otherwise pass unnoticed.
    warnings.warn(
        "Training split is empty: dataset has only {} graph(s)".format(len(shuffled))
    )

train_loader = DataLoader(shuffled[:split], batch_size = 1)
test_loader = DataLoader(shuffled[split:], batch_size = 1)

In [None]:
# For every graph, show the shapes of its `x` and `edge_attr` tensors.
for graph in dataset:
    print(graph.x.shape, graph.edge_attr.shape)

In [None]:
# Peek at the first training batch only; prints nothing when the loader is empty.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    for field in ("x", "edge_index", "edge_attr"):
        print(getattr(first_batch, field).shape)

In [None]:
import importlib

# Re-execute the `model` module so edits to it take effect without restarting
# the kernel, then rebind the class names to the reloaded definitions.
model = importlib.reload(model)
from model import SimpleNet, SecondNet


In [None]:
# Feature widths are read off the first graph in the dataset.
first_graph = dataset[0]
vertex_dim = first_graph.x.shape[1]
edge_dim = first_graph.edge_attr.shape[1]
hidden_dim = 20

# Mean squared error drives training; mean absolute error is reported at test time.
train_criterion, test_criterion = nn.MSELoss(), nn.L1Loss()

# Print numpy arrays with three decimals and no scientific notation.
np.set_printoptions(suppress=True, precision=3)

In [None]:
# Seed torch's global RNG before the network is constructed so that weight
# initialisation (and therefore the reported losses) is repeatable across runs.
torch.manual_seed(0)

# net = SecondNet(vertex_dim, edge_dim, hidden_dim).double()
# NOTE(review): `p` is presumably a dropout probability -- confirm in model.py.
net = SimpleNet(vertex_dim, edge_dim, hidden_dim, p = 0.0).double()


losses = train(net, train_loader, lr = 0.002, iterations = 2000, criterion = train_criterion, verbose = True)
# Show every 10th entry of what train() returned to keep the output short.
print(losses[::10])

loss = test(net, test_loader, test_criterion)
print(loss)


In [None]:
# Compare the prediction with the target for the first test batch only.
# Gradients are not needed for inspection, so skip building the autograd graph.
with torch.no_grad():
    for data in test_loader:
        target = data.y.double()
        output = net(data)
        loss = test_criterion(output, target)
        print(output)
        print(target)
        # The loss was previously computed but never displayed.
        print(loss.item())
        break