In [8]:
import numpy as np
import torch
import torch_geometric as pyg
import graphPINN
import math
import logging
from time import time
from tqdm.notebook import tqdm
from scipy.io import savemat

import os

# Output directory for run.log, the per-epoch checkpoints and the loss .mat file.
# Set GRAPHPINN_RUN_DIR to redirect it without editing this cell; the hard-coded
# default is kept for backward compatibility. Later cells build paths as
# f'{folder}<name>', so the value must end with a path separator.
folder = os.environ.get("GRAPHPINN_RUN_DIR", "C:\\Users\\nhmathew\\Documents\\code\\run-2023-02-23\\")
if not folder.endswith(("\\", "/")):
    folder += os.sep

In [9]:
def pretty_size(n, pow=0, b=1024, u='B', pre=('',) + tuple(p + 'i' for p in 'KMGTPEZY')):
    """Format a size as a human-readable string such as ``"24.0 GiB"``.

    Args:
        n: the size, expressed in units of ``b ** pow`` (``pow=0`` means base units).
        pow: exponent of ``b`` that ``n`` is already scaled by.
        b: multiplier between successive prefixes (1024 for binary prefixes).
        u: unit suffix appended after the prefix.
        pre: prefixes for exponents 0, 1, 2, ...; the largest one is used for
            anything beyond the end of the sequence.

    Returns:
        The value with no decimals when it stays in base units (e.g. ``"512 B"``),
        otherwise with one decimal (e.g. ``"1.5 KiB"``).
    """
    value = n * b ** pow
    # Pick the prefix with integer comparisons rather than int(math.log(value, b)):
    # the float log can land just below an exact power of b (giving "1024.0 TiB"
    # instead of "1.0 PiB") and raises on NaN/inf.
    exponent = 0
    while exponent < len(pre) - 1 and value >= b ** (exponent + 1):
        exponent += 1
    precision = 0 if exponent == 0 else 1
    return f"{value / b ** exponent:.{precision}f} {pre[exponent]}{u}"

# Append timestamped INFO records to run.log inside the run folder.
logging.basicConfig(
    filename=f'{folder}run.log',
    format='%(asctime)s - %(message)s',
    filemode='a+',
    level=logging.INFO,
)


def logfn(message, tq=True):
    """Record `message` in the run log and echo it to the notebook.

    With ``tq=True`` the echo goes through ``tqdm.write`` so it does not collide
    with an active tqdm progress bar; with ``tq=False`` plain ``print`` is used.
    """
    logging.info(message)
    echo = tqdm.write if tq else print
    echo(message)

In [10]:
# Run on the default CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [11]:
# Report the GPU's total memory. Without CUDA there is no device 0 to query and
# get_device_properties(0) would raise, so log the CPU fallback instead.
if torch.cuda.is_available():
    logfn(f"{pretty_size(torch.cuda.get_device_properties(0).total_memory)} of cuda memory")
else:
    logfn("CUDA not available; running on cpu")

24.0 GiB of cuda memory


In [12]:
# One graph per DataLoader batch; k is passed to MHSDataset and also names the data directory.
batch_size = 1
k = 100

data_root = f'D:\\nats ML stuff\\data_k={k}'
dataset = graphPINN.data.MHSDataset(data_root, k=k)

# 80/10/10 train/validation/test split with a fixed seed so the split is reproducible.
split_rng = torch.Generator().manual_seed(42)
trainset, validset, testset = torch.utils.data.random_split(dataset, [0.8, 0.1, 0.1], generator=split_rng)
# Tiny split for quick smoke tests:
# trainset, validset, testset = torch.utils.data.random_split(dataset,[0.005, 0.001, 0.994],generator=torch.Generator().manual_seed(42))

Processing...
100%|██████████████████████████████████████████████████████████████████████████████| 644/644 [1:45:44<00:00,  9.85s/it]
Done!


In [13]:
# Convolution graph: ConvGraph around a KernelNN built from `design` with ReLU activations.
# NOTE(review): `design` is presumably the per-layer width list - confirm in KernelNN.
design = [18, 9, 6, 6, 3]
kernel = graphPINN.KernelNN(design, torch.nn.ReLU)
convgraph = graphPINN.ConvGraph(kernel)
convgraph = convgraph.to(device)

# Boundary-propagation graph: BDPropGraph around its own KernelNN.
design = [12, 6, 6, 3]
propkernel = graphPINN.KernelNN(design, torch.nn.ReLU)
propgraph = graphPINN.BDPropGraph(propkernel)
propgraph = propgraph.to(device)


In [14]:
def _prepare_batch(batch):
    """Add the dense boundary->interior relation to `batch`, flatten it and move it to `device`.

    Every ('bd', 'propagates', 'in') node pair gets an edge with weight 1, built from
    an all-ones (num_bd, num_in) matrix by ``pyg.utils.dense_to_sparse``. The batch is
    then converted with ``to_homogeneous()`` and moved to the module-level ``device``.

    NOTE(review): the node counts are taken over the whole batch, so with
    batch_size > 1 boundary nodes of one graph would also be wired to interior nodes
    of the other graphs in the batch - confirm before raising batch_size.
    """
    batch['bd', 'propagates', 'in'].edge_index, batch['bd', 'propagates', 'in'].edge_attr = \
        pyg.utils.dense_to_sparse(
            torch.ones(batch['bd'].x.shape[0], batch['in'].x.shape[0])
        )
    batch = batch.to_homogeneous()
    batch = batch.to(device)
    return batch


def _targets(batch):
    """Targets in the order passed to ``graphPINN.MHS.loss``: ``y[:, 0:3]``, ``x[:, 3]``, ``x[:, 4]``, ``x[:, 5]``."""
    return [batch.y[:, 0:3], batch.x[:, 3], batch.x[:, 4], batch.x[:, 5]]


def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is 0 (every batch was skipped)."""
    return total / count if count else float('nan')


def train(propgraph, convgraph, epochs = 1, lbfgs = False, shuffle = False, skip_threshold = 1e4):
    """Train ``convgraph.kernel`` for `epochs` epochs, validating and checkpointing after each.

    Uses the notebook globals ``trainset``, ``validset``, ``batch_size``, ``folder``,
    ``k`` and ``device``. After every epoch ``convgraph.state_dict()`` is saved to
    ``{folder}epoch-{epoch+1}_trainsize-{len(trainset)}_k-{k}.pt``.

    Args:
        propgraph: boundary-propagation model passed to ``graphPINN.FullModel``.
            NOTE(review): its parameters are not given to the optimizer, so it is
            never updated here - confirm that is intentional.
        convgraph: convolution-graph model; only ``convgraph.kernel`` is optimized.
        epochs: number of passes over ``trainset``.
        lbfgs: use ``torch.optim.LBFGS`` (with a per-batch closure) instead of Adam.
        shuffle: reshuffle the training set each epoch (default False, as before).
        skip_threshold: batches whose total loss is not below this value (including
            NaN) get no optimizer step and are excluded from the epoch averages.

    Returns:
        ``(running_loss, training_loss, validation_loss)``. ``running_loss`` is the
        summed total loss of the kept training batches in the last epoch (0.0 when
        ``epochs == 0``). The other two are ``(4, epochs)`` tensors whose rows are the
        mean vector, mhs, div and total losses over kept batches (NaN if none kept).
    """
    params = list(convgraph.kernel.parameters())
    optimizer = torch.optim.LBFGS(params) if lbfgs else torch.optim.Adam(params)

    # DataLoader re-iterates (and reshuffles, if asked) every epoch, so build them once.
    train_loader = pyg.loader.DataLoader(trainset, batch_size=batch_size, shuffle=shuffle)
    valid_loader = pyg.loader.DataLoader(validset, batch_size=batch_size, shuffle=False)

    training_loss = torch.zeros(4, epochs)
    validation_loss = torch.zeros(4, epochs)
    running_loss = 0.0
    for epoch in range(epochs):
        convgraph.train(True)
        propgraph.train(True)

        running_loss = running_vec = running_mhs = running_div = 0.0
        skipped = 0
        start_time = time()

        for data in tqdm(train_loader):
            data = _prepare_batch(data)
            true = _targets(data)

            optimizer.zero_grad()
            pred = graphPINN.FullModel(data, propgraph, convgraph)
            loss, vec_diff, mhs_diff, div_diff = graphPINN.MHS.loss(pred, true, logfn=None)

            loss_value = loss.item()
            # Written as "not <" so that NaN losses are skipped too.
            if not loss_value < skip_threshold:
                skipped += 1
                continue

            if lbfgs:
                def closure(batch=data, target=true):
                    # LBFGS may evaluate the objective several times per step, so the
                    # closure must redo forward + backward on the current batch.
                    optimizer.zero_grad()
                    closure_pred = graphPINN.FullModel(batch, propgraph, convgraph)
                    closure_loss = graphPINN.MHS.loss(closure_pred, target, logfn=None)[0]
                    closure_loss.backward()
                    return closure_loss
                optimizer.step(closure)
            else:
                loss.backward()
                optimizer.step()

            # float() detaches the components so no autograd graph outlives this batch.
            running_loss += loss_value
            running_vec += float(vec_diff)
            running_mhs += float(mhs_diff)
            running_div += float(div_diff)

        kept = len(train_loader) - skipped
        training_loss[:, epoch] = torch.tensor([
            _mean_or_nan(running_vec, kept),
            _mean_or_nan(running_mhs, kept),
            _mean_or_nan(running_div, kept),
            _mean_or_nan(running_loss, kept),
        ])

        logfn(f'Epoch {epoch+1} completed. Loss: {training_loss[3,epoch]}; Total skipped: {skipped}; Total time: {time()-start_time}', tq=False)
        logfn(f'running vec: {training_loss[0,epoch]}, running mhs: {training_loss[1,epoch]}, running div: {training_loss[2,epoch]}')

        convgraph.train(False)
        propgraph.train(False)
        torch.save(convgraph.state_dict(), f'{folder}epoch-{epoch+1}_trainsize-{len(trainset)}_k-{k}.pt')

        # Validation. Gradients are left enabled because MHS.loss may differentiate
        # through `pred` (PINN residuals); values are converted to floats so no graph
        # is retained across batches.
        # NOTE(review): wrap this loop in torch.no_grad() if MHS.loss needs no autograd.
        start_time = time()
        valid_skipped = 0
        running_valid = valid_vec = valid_mhs = valid_div = 0.0
        for data in tqdm(valid_loader):
            data = _prepare_batch(data)
            true = _targets(data)
            pred = graphPINN.FullModel(data, propgraph, convgraph)
            loss, vec, mhs, div = graphPINN.MHS.loss(pred, true, logfn=None)

            loss_value = loss.item()
            if not loss_value < skip_threshold:
                valid_skipped += 1
                continue
            running_valid += loss_value
            valid_vec += float(vec)
            valid_mhs += float(mhs)
            valid_div += float(div)

        valid_kept = len(valid_loader) - valid_skipped
        validation_loss[:, epoch] = torch.tensor([
            _mean_or_nan(valid_vec, valid_kept),
            _mean_or_nan(valid_mhs, valid_kept),
            _mean_or_nan(valid_div, valid_kept),
            _mean_or_nan(running_valid, valid_kept),
        ])
        logfn(f'Validation loss: {validation_loss[3,epoch]}, total skipped: {valid_skipped}, validation time: {time()-start_time}', tq=False)
        logfn(f'running vec: {validation_loss[0,epoch]}, running mhs: {validation_loss[1,epoch]}, running div: {validation_loss[2,epoch]}')

    return running_loss, training_loss, validation_loss

epochs = 5
loss, training_loss, validation_loss = train(propgraph, convgraph, epochs=epochs)
# Row 3 of each (4, epochs) loss tensor holds the total loss per epoch.
for split_name, split_loss in (('validation', validation_loss), ('training', training_loss)):
    logfn(f'{split_name} loss: {split_loss[3,:]}', tq=False)

  0%|          | 0/3092 [00:00<?, ?it/s]

Epoch 1 completed. Loss: 36.4398078918457; Total skipped: 264; Total time: 23545.09375858307
running vec: 28.127338409423828, running mhs: 0.08386331796646118, running div: 8.228604316711426


  0%|          | 0/386 [00:00<?, ?it/s]

Validation loss: 50.54824447631836, total skipped: 31, validation time: 2915.774275779724
running vec: 37.88922882080078, running mhs: 0.037443857640028, running div: 12.621569633483887


  0%|          | 0/3092 [00:00<?, ?it/s]

Epoch 2 completed. Loss: 38.758628845214844; Total skipped: 264; Total time: 23826.29337835312
running vec: 28.12733268737793, running mhs: 0.047043636441230774, running div: 10.584253311157227


  0%|          | 0/386 [00:00<?, ?it/s]

Validation loss: 57.04352569580078, total skipped: 31, validation time: 2912.829377889633
running vec: 37.88557052612305, running mhs: 0.03768550977110863, running div: 19.120267868041992


  0%|          | 0/3092 [00:00<?, ?it/s]

Epoch 3 completed. Loss: 41.97667694091797; Total skipped: 264; Total time: 23830.65084338188
running vec: 28.1208438873291, running mhs: 0.04396802932024002, running div: 13.811864852905273


  0%|          | 0/386 [00:00<?, ?it/s]

Validation loss: 57.475399017333984, total skipped: 31, validation time: 2921.6089940071106
running vec: 37.8637580871582, running mhs: 0.04037060588598251, running div: 19.571266174316406


  0%|          | 0/3092 [00:00<?, ?it/s]

Epoch 4 completed. Loss: 47.15961837768555; Total skipped: 264; Total time: 23853.660502433777
running vec: 28.118637084960938, running mhs: 0.04212522506713867, running div: 18.998855590820312


  0%|          | 0/386 [00:00<?, ?it/s]

Validation loss: 54.35874557495117, total skipped: 31, validation time: 2920.70063662529
running vec: 37.86005401611328, running mhs: 0.03805379569530487, running div: 16.46063804626465


  0%|          | 0/3092 [00:00<?, ?it/s]

Epoch 5 completed. Loss: 39.78526306152344; Total skipped: 264; Total time: 23871.073615074158
running vec: 28.114511489868164, running mhs: 0.03997684270143509, running div: 11.630776405334473


  0%|          | 0/386 [00:00<?, ?it/s]

Validation loss: 47.924072265625, total skipped: 31, validation time: 2920.46866440773
running vec: 37.8585090637207, running mhs: 0.03745351731777191, running div: 10.028109550476074
validation loss: tensor([50.5482, 57.0435, 57.4754, 54.3587, 47.9241])
training loss: tensor([36.4398, 38.7586, 41.9767, 47.1596, 39.7853])


In [15]:
# Bundle the per-epoch loss curves; 'ordering' names the rows of each (4, epochs) tensor.
lossdict = {
    'training': training_loss,
    'validation': validation_loss,
    'ordering': ['vector', 'mhs', 'div', 'total'],
}
for split_key in ('training', 'validation'):
    print(lossdict[split_key])

tensor([[2.8127e+01, 2.8127e+01, 2.8121e+01, 2.8119e+01, 2.8115e+01],
        [8.3863e-02, 4.7044e-02, 4.3968e-02, 4.2125e-02, 3.9977e-02],
        [8.2286e+00, 1.0584e+01, 1.3812e+01, 1.8999e+01, 1.1631e+01],
        [3.6440e+01, 3.8759e+01, 4.1977e+01, 4.7160e+01, 3.9785e+01]])
tensor([[3.7889e+01, 3.7886e+01, 3.7864e+01, 3.7860e+01, 3.7859e+01],
        [3.7444e-02, 3.7686e-02, 4.0371e-02, 3.8054e-02, 3.7454e-02],
        [1.2622e+01, 1.9120e+01, 1.9571e+01, 1.6461e+01, 1.0028e+01],
        [5.0548e+01, 5.7044e+01, 5.7475e+01, 5.4359e+01, 4.7924e+01]])


In [16]:
# Write the loss curves (rows ordered as lossdict['ordering']) to a MATLAB .mat file.
savemat(file_name=f'{folder}loss_{epochs}_trainsize-{len(trainset)}_k-{k}.mat', mdict=lossdict)

In [None]:
# Current/peak memory allocated by tensors, and memory reserved by PyTorch's caching
# allocator out of the device total. get_device_properties(0) raises without CUDA,
# so guard the report on CPU-only machines.
if torch.cuda.is_available():
    print(f'{pretty_size(torch.cuda.memory_allocated())}/{pretty_size(torch.cuda.max_memory_allocated())} allocated, ' +
          f'{pretty_size(torch.cuda.memory_reserved())}/{pretty_size(torch.cuda.get_device_properties(0).total_memory)} reserved')
else:
    print('CUDA not available; no GPU memory statistics to report')

In [None]:
# convgraph.load_state_dict(torch.load(f'{folder}epoch-1_trainsize-3092_k-100.pt'))  # train() saves 1-based epoch checkpoints under {folder}

In [None]:
sample = dataset[0]
# Same preprocessing as the training loop: wire every 'bd' node to every 'in' node
# via an all-ones dense adjacency converted to edge_index/edge_attr.
n_bd_nodes = sample['bd'].x.shape[0]
n_in_nodes = sample['in'].x.shape[0]
dense_bd_to_in = torch.ones(n_bd_nodes, n_in_nodes)
sample['bd', 'propagates', 'in'].edge_index, sample['bd', 'propagates', 'in'].edge_attr = \
    pyg.utils.dense_to_sparse(dense_bd_to_in)
print(sample)

In [None]:
# Convert the heterogeneous sample into a single homogeneous graph for inspection below.
homosamp = sample.to_homogeneous()

In [None]:
# Inspect the node feature matrix of the homogeneous sample.
print(homosamp.x)

In [None]:
# Edges whose homogeneous edge_type id is 1. Which relation that id maps to follows
# the order of sample.edge_types - confirm before relying on it.
type1_mask = homosamp.edge_type == 1
print(homosamp.edge_index[:, type1_mask])