In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import networkx as nx
from scipy import sparse

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GATConv, GINConv, global_max_pool, GlobalAttention, GatedGraphConv
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import softmax
from torch_geometric.utils.convert import from_scipy_sparse_matrix

from pyscf import gto, scf, tools, ao2mo


import model
import train
from model import SecondNet, SimpleNet
from preprocess import build_graph, build_qm7
from train import train, test
from hf import get_data

Numpy 1.16 has memory leak bug  https://github.com/numpy/numpy/issues/13808
It is recommended to downgrade to numpy 1.15 or older


In [2]:
mols = build_qm7('sto-3g')
#Omit first molecule, outlier geometry
mols = mols[1:80]

In [3]:
#TODO: Encode number of electrons explicitly
#TODO: Encode HF features?
#TODO: Encode the "flavor" of the orbital basis as features as well

#TODO: indicate which orbital is first and second in the pair vertices?  This breaks the symmetry,
#but we might want this anyway if we want to particularly understand one of the orbitals in the pair
#TODO: indicators should be separate features not integer values, stop being lazy
#TODO: ACTUALLY USE GCN MODEL
#TODO: Fix edge features between single and double, currently those are all zero and graph is disconnected!!!

In [4]:
#M: Number of orbitals
#N: Number of electrons
#F: feature vector length

#A is potential matrix: M x M
#U is coulumb 4-tensor: M x M x M x M
#X is additional orbital feature matrix: M x F_1
#Y is additional pairwise orbital feature matrix: M x M x F_2

#E is ground state energy

dataset = []

for mol in mols:
    A, U, X, Y, E = get_data(mol, "AO", predict_correlation = False)
    
    ####COMPLETE HACK
#     E /= 10.
#     np.fill_diagonal(Y[:,:,0], np.diagonal(Y[:,:,0]) / 10.)
#     np.fill_diagonal(Y[:,:,1], np.diagonal(Y[:,:,1]) / 10.)
    
    ####
    
    M = A.shape[0]
    X = np.zeros((M, 1)) #Currently no orbital features
                
    x, edge_index, edge_attr = build_graph(A, U, X, Y)
    
    
#     print("True energy:\t\t {}".format(E))
#     print("Energy via Trace:\t {}".format(np.sum(Y[:,:,0] * Y[:,:,1])))
#     print("Energy via features:\t {}"
#           .format(torch.sum(x[:,1] * x[:,2] * (2 - x[:,5])).item()))
#     print()

        
    data = Data(x = x, edge_index = edge_index, edge_attr = edge_attr, y = E)
    dataset.append(data)

  with h5py.File(chkfile) as fh5:


In [5]:
import random
random.shuffle(dataset)

split = int(0.8 * len(dataset))
train_loader = DataLoader(dataset[:split], batch_size = 6)
test_loader = DataLoader(dataset[split:], batch_size = 6)

In [6]:
import importlib
importlib.reload(model)
from model import SecondNet, SimpleNet


In [7]:
vertex_dim = dataset[0].x.shape[1]
edge_dim = dataset[0].edge_attr.shape[1]
hidden_dim = 20

train_criterion = nn.MSELoss()
test_criterion = nn.L1Loss()


np.set_printoptions(precision=3, suppress=True)

In [8]:
# net = SecondNet(vertex_dim, edge_dim, hidden_dim).double()
net = SimpleNet(vertex_dim, edge_dim, hidden_dim, p = 0.0).double()


losses = train(net, train_loader, lr = 0.002, iterations = 2000, criterion = train_criterion, verbose = True)
print(losses[::10])

loss = test(net, test_loader, test_criterion)
print(loss)


timestep: 0, loss: 544.3663398203461
timestep: 1, loss: 438.9710992028379
timestep: 2, loss: 398.34177436297233
timestep: 3, loss: 397.36148244007467
timestep: 4, loss: 383.12508624350534
timestep: 5, loss: 375.5904284393231
timestep: 6, loss: 367.5097823907338
timestep: 7, loss: 357.441342985659
timestep: 8, loss: 346.87670452701707
timestep: 9, loss: 335.53939396223194
timestep: 10, loss: 323.19237847023277
timestep: 11, loss: 310.3123538056834
timestep: 12, loss: 296.68559388055314
timestep: 13, loss: 282.76069046828917
timestep: 14, loss: 269.0620824999703
timestep: 15, loss: 255.9572529884762
timestep: 16, loss: 242.7395061053752
timestep: 17, loss: 229.33953419486932
timestep: 18, loss: 215.89603858735316
timestep: 19, loss: 202.24500105520568
timestep: 20, loss: 188.44672205512833
timestep: 21, loss: 174.5515461769024
timestep: 22, loss: 160.72589605317916
timestep: 23, loss: 147.03355199969374
timestep: 24, loss: 133.54198215293928
timestep: 25, loss: 120.39399394050776
timeste

timestep: 211, loss: 3.4295416328618895
timestep: 212, loss: 3.425114228219573
timestep: 213, loss: 3.4156195777662206
timestep: 214, loss: 3.4095071319604027
timestep: 215, loss: 3.3984357770892397
timestep: 216, loss: 3.3921780706592513
timestep: 217, loss: 3.3810928156711637
timestep: 218, loss: 3.3820990181055595
timestep: 219, loss: 3.3881337083841796
timestep: 220, loss: 3.3796175946916964
timestep: 221, loss: 3.367845414320669
timestep: 222, loss: 3.3639562049036793
timestep: 223, loss: 3.3571539702108426
timestep: 224, loss: 3.35638481857158
timestep: 225, loss: 3.354843413222828
timestep: 226, loss: 3.346843026694712
timestep: 227, loss: 3.3325402312359156
timestep: 228, loss: 3.3222545999798183
timestep: 229, loss: 3.3262286265977097
timestep: 230, loss: 3.3285153461029977
timestep: 231, loss: 3.3122055440481772
timestep: 232, loss: 3.3019264648395463
timestep: 233, loss: 3.2899189467027323
timestep: 234, loss: 3.279449690255843
timestep: 235, loss: 3.2712527491359107
timeste

timestep: 419, loss: 2.3883590129460868
timestep: 420, loss: 2.365374553808783
timestep: 421, loss: 2.3481122685700857
timestep: 422, loss: 2.3458890867973192
timestep: 423, loss: 2.3422788306697253
timestep: 424, loss: 2.334234084040452
timestep: 425, loss: 2.326072035895599
timestep: 426, loss: 2.3175941905234296
timestep: 427, loss: 2.3054513727960417
timestep: 428, loss: 2.3049565309003133
timestep: 429, loss: 2.3049980670678347
timestep: 430, loss: 2.3014326626162473
timestep: 431, loss: 2.2949616454260844
timestep: 432, loss: 2.28730545860746
timestep: 433, loss: 2.2743349350540787
timestep: 434, loss: 2.2586836840660953
timestep: 435, loss: 2.2514648764046314
timestep: 436, loss: 2.2543591636011295
timestep: 437, loss: 2.2530551394977754
timestep: 438, loss: 2.2521231939841595
timestep: 439, loss: 2.250123012530189
timestep: 440, loss: 2.2463159323247344
timestep: 441, loss: 2.2391597191144643
timestep: 442, loss: 2.2432041370690574
timestep: 443, loss: 2.2460755303451716
timest

timestep: 626, loss: 1.9093687939193664
timestep: 627, loss: 1.8976024029363747
timestep: 628, loss: 1.8802138613236223
timestep: 629, loss: 1.8596455977570405
timestep: 630, loss: 1.8367326075531694
timestep: 631, loss: 1.7964528936421984
timestep: 632, loss: 1.7537138957565253
timestep: 633, loss: 1.6987890972995698
timestep: 634, loss: 1.6319603578686526
timestep: 635, loss: 1.5709420859494438
timestep: 636, loss: 1.4999501135817135
timestep: 637, loss: 1.4423083378890404
timestep: 638, loss: 1.4040092401293638
timestep: 639, loss: 1.3917035341724988
timestep: 640, loss: 1.3909548271319057
timestep: 641, loss: 1.404255100945402
timestep: 642, loss: 1.434892335519172
timestep: 643, loss: 1.4727687570093548
timestep: 644, loss: 1.5283399147894756
timestep: 645, loss: 1.5882407626559647
timestep: 646, loss: 1.6512585702195386
timestep: 647, loss: 1.7183594870905727
timestep: 648, loss: 1.7777109137610603
timestep: 649, loss: 1.830453665585572
timestep: 650, loss: 1.872860957095867
time

timestep: 832, loss: 0.36480915903728134
timestep: 833, loss: 0.32222634103448455
timestep: 834, loss: 0.3151563327395382
timestep: 835, loss: 0.3215556555258851
timestep: 836, loss: 0.33430220921997905
timestep: 837, loss: 0.3582420912686205
timestep: 838, loss: 0.3940167242817916
timestep: 839, loss: 0.44803697195154074
timestep: 840, loss: 0.5203029732865256
timestep: 841, loss: 0.6152502872840445
timestep: 842, loss: 0.734514673274861
timestep: 843, loss: 0.8840016435007261
timestep: 844, loss: 1.0766011363745358
timestep: 845, loss: 1.3243882540301355
timestep: 846, loss: 1.6395765122696473
timestep: 847, loss: 2.0266477639424787
timestep: 848, loss: 2.502990440722496
timestep: 849, loss: 3.0779209214471575
timestep: 850, loss: 3.6535239975143585
timestep: 851, loss: 4.108087522364664
timestep: 852, loss: 4.232371917528979
timestep: 853, loss: 3.8666878049377615
timestep: 854, loss: 3.018257904370089
timestep: 855, loss: 2.006587918319728
timestep: 856, loss: 1.1593247224994137
ti

timestep: 1032, loss: 0.08824367008154756
timestep: 1033, loss: 0.08438263273375837
timestep: 1034, loss: 0.08073169615737387
timestep: 1035, loss: 0.07726833843250724
timestep: 1036, loss: 0.07460368691090058
timestep: 1037, loss: 0.07177422940921398
timestep: 1038, loss: 0.06835789279544588
timestep: 1039, loss: 0.06593128977102115
timestep: 1040, loss: 0.06358677722927246
timestep: 1041, loss: 0.061478425539646046
timestep: 1042, loss: 0.05957468013273545
timestep: 1043, loss: 0.05782845844139053
timestep: 1044, loss: 0.05622406731634539
timestep: 1045, loss: 0.05478390664509725
timestep: 1046, loss: 0.05310945034465009
timestep: 1047, loss: 0.05178454265900283
timestep: 1048, loss: 0.05029575006721649
timestep: 1049, loss: 0.049613344575967555
timestep: 1050, loss: 0.04752193622154841
timestep: 1051, loss: 0.04928766985912004
timestep: 1052, loss: 0.04877866084766578
timestep: 1053, loss: 0.07177788920989084
timestep: 1054, loss: 0.14173530148438518
timestep: 1055, loss: 0.54041866

timestep: 1228, loss: 0.08045913075096986
timestep: 1229, loss: 0.07036071679643062
timestep: 1230, loss: 0.3388963001030165
timestep: 1231, loss: 0.6676612740500084
timestep: 1232, loss: 1.8284863038044377
timestep: 1233, loss: 0.8686945426857423
timestep: 1234, loss: 0.09867739391485385
timestep: 1235, loss: 0.15403134881341338
timestep: 1236, loss: 0.034305595217246763
timestep: 1237, loss: 0.05341003430140721
timestep: 1238, loss: 0.05890100349878209
timestep: 1239, loss: 0.031193887130542542
timestep: 1240, loss: 0.04495992633355244
timestep: 1241, loss: 0.03592628769796157
timestep: 1242, loss: 0.039170292021012215
timestep: 1243, loss: 0.03700236405320625
timestep: 1244, loss: 0.03642616165430158
timestep: 1245, loss: 0.03442866089585506
timestep: 1246, loss: 0.03723964906738221
timestep: 1247, loss: 0.029456975386890127
timestep: 1248, loss: 0.04192619133844267
timestep: 1249, loss: 0.026808409061867454
timestep: 1250, loss: 0.08166200060238361
timestep: 1251, loss: 0.106712807

timestep: 1425, loss: 0.29195673382257625
timestep: 1426, loss: 0.306060925411325
timestep: 1427, loss: 0.30370326178566653
timestep: 1428, loss: 0.3214565456466682
timestep: 1429, loss: 0.3269546495763302
timestep: 1430, loss: 0.3284072370831929
timestep: 1431, loss: 0.3433537882784786
timestep: 1432, loss: 0.3479746527776222
timestep: 1433, loss: 0.3431582973812821
timestep: 1434, loss: 0.33223825290892584
timestep: 1435, loss: 0.3152547929198617
timestep: 1436, loss: 0.29028243223952543
timestep: 1437, loss: 0.26725397837928605
timestep: 1438, loss: 0.2347287641867234
timestep: 1439, loss: 0.2045666876901861
timestep: 1440, loss: 0.17930689628134067
timestep: 1441, loss: 0.1573288873646628
timestep: 1442, loss: 0.14309126373618636
timestep: 1443, loss: 0.13475131993543082
timestep: 1444, loss: 0.12562912455407854
timestep: 1445, loss: 0.11693703232579145
timestep: 1446, loss: 0.10952565130294438
timestep: 1447, loss: 0.1023991415157485
timestep: 1448, loss: 0.0981303388705634
timest

timestep: 1622, loss: 0.029294042382312556
timestep: 1623, loss: 0.030573009848203655
timestep: 1624, loss: 0.03073461748135228
timestep: 1625, loss: 0.02937055698094621
timestep: 1626, loss: 0.02782476153167228
timestep: 1627, loss: 0.02770931570149009
timestep: 1628, loss: 0.029236034623001936
timestep: 1629, loss: 0.0308303680985509
timestep: 1630, loss: 0.030843787963262672
timestep: 1631, loss: 0.029508560564838512
timestep: 1632, loss: 0.02817469459028803
timestep: 1633, loss: 0.028011780699901324
timestep: 1634, loss: 0.02896844915650161
timestep: 1635, loss: 0.030219679072036553
timestep: 1636, loss: 0.031057369008957378
timestep: 1637, loss: 0.031627147170409944
timestep: 1638, loss: 0.03193785682737398
timestep: 1639, loss: 0.032046895123988194
timestep: 1640, loss: 0.032307185368239726
timestep: 1641, loss: 0.03271090263160391
timestep: 1642, loss: 0.03298935167342984
timestep: 1643, loss: 0.03277827533116013
timestep: 1644, loss: 0.03242445642421243
timestep: 1645, loss: 0.

timestep: 1816, loss: 0.018463373982809214
timestep: 1817, loss: 0.01878354018580094
timestep: 1818, loss: 0.017204370608443067
timestep: 1819, loss: 0.01744986580291404
timestep: 1820, loss: 0.016994208487676885
timestep: 1821, loss: 0.016785075244538282
timestep: 1822, loss: 0.016344519925506017
timestep: 1823, loss: 0.01588035068182366
timestep: 1824, loss: 0.01545714653204673
timestep: 1825, loss: 0.015076790741028016
timestep: 1826, loss: 0.014348171071763026
timestep: 1827, loss: 0.01415240294813297
timestep: 1828, loss: 0.013696030739120674
timestep: 1829, loss: 0.013698332086259436
timestep: 1830, loss: 0.012862015845651819
timestep: 1831, loss: 0.01357490072870683
timestep: 1832, loss: 0.011798328394667275
timestep: 1833, loss: 0.015152488462725607
timestep: 1834, loss: 0.011071020211049935
timestep: 1835, loss: 0.03226634434646768
timestep: 1836, loss: 0.046862439882478964
timestep: 1837, loss: 0.26102452937025317
timestep: 1838, loss: 0.7874955479415007
timestep: 1839, loss:

In [12]:
for data in test_loader:
    output = net(data)
    loss = test_criterion(output, data.y.double())
    print(output)
    print(data.y.double())
    break

tensor([-110.7541, -148.3292,  -87.5184, -148.3689, -111.8256, -131.5436],
       dtype=torch.float64, grad_fn=<SqueezeBackward1>)
tensor([-110.6895, -148.0932,  -87.4793, -148.3788, -111.5629, -131.4503],
       dtype=torch.float64)
