In [None]:
import logging
import math
import os
from time import time

import numpy as np
import torch
import torch_geometric as pyg
from scipy.io import savemat
from tqdm.notebook import tqdm

import graphPINN

# Output directory for the run log, per-epoch checkpoints and the loss .mat file.
# Set the GRAPHPINN_RUN_DIR environment variable to use another location without
# editing the notebook. Paths are built below as f'{folder}<name>', so the value
# must end with a path separator; one is appended if missing.
folder = os.environ.get("GRAPHPINN_RUN_DIR", "C:\\Users\\nhmathew\\Documents\\code\\run-2023-02-21\\")
if not folder.endswith(("\\", "/")):
    folder += os.sep

In [None]:
def pretty_size(n, pow=0, b=1024, u='B', pre=[''] + [p + 'i' for p in 'KMGTPEZY']):
    """Format a quantity with a power-of-``b`` prefix, e.g. ``1536 -> '1.5 KiB'``.

    Args:
        n: the quantity, expressed in units of ``b**pow`` (so ``pretty_size(5, pow=2)``
            means 5 MiB with the default base).
        pow: exponent of ``b`` that ``n`` is already scaled by.
        b: ratio between successive prefixes (1024 for binary sizes).
        u: unit suffix appended after the prefix.
        pre: prefix for each exponent; larger values reuse the last prefix.

    Returns:
        ``"<value> <prefix><unit>"``. The value has no decimals when no prefix is
        used and one decimal otherwise. Values below ``b`` (including 0 and
        negatives) are printed without a prefix.
    """
    n = n * b ** pow
    # Largest exponent e with b**e <= n, found by exact comparison. The previous
    # int(math.log(n, b)) could round down at exact powers of the base
    # (math.log(1000, 10) == 2.9999999999999996) and failed on inf/NaN.
    exponent = 0
    while exponent < len(pre) - 1 and n >= b ** (exponent + 1):
        exponent += 1
    decimals = 0 if exponent == 0 else 1
    return "%.*f %s%s" % (decimals, n / b ** float(exponent), pre[exponent], u)

# Append timestamped INFO records to run.log inside the run folder.
# (basicConfig only takes effect if the root logger has no handlers yet.)
logging.basicConfig(
    filename=f'{folder}run.log',
    filemode='a+',
    format='%(asctime)s - %(message)s',
    level=logging.INFO,
)

def logfn(message, tq=True):
    """Log ``message`` to the run log and echo it to the notebook output.

    Args:
        message: text to record.
        tq: echo through ``tqdm.write`` (safe while progress bars are active)
            when True, plain ``print`` otherwise.
    """
    logging.info(message)
    echo = tqdm.write if tq else print
    echo(message)

In [None]:
# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Report the memory of the training device. torch.cuda.get_device_properties
# raises when CUDA is unavailable, so only query it when a GPU is present.
if torch.cuda.is_available():
    logfn(f"{pretty_size(torch.cuda.get_device_properties(0).total_memory)} of cuda memory")
else:
    logfn("CUDA not available; running on cpu")

In [None]:
batch_size = 1
k = 50  # passed to MHSDataset and used in the data/checkpoint names; presumably a neighbour count -- TODO confirm

dataset = graphPINN.data.MHSDataset(f'D:\\nats ML stuff\\data_k={k}', k=k)

# Train/validation/test fractions for random_split. This is a small debug subset;
# the full run used [0.8, 0.1, 0.1]. The seeded generator makes the split reproducible.
split_fractions = [0.005, 0.001, 0.994]
split_generator = torch.Generator().manual_seed(42)
trainset, validset, testset = torch.utils.data.random_split(
    dataset, split_fractions, generator=split_generator
)

In [None]:
# Kernel network for the convolution graph. `design` is presumably the list of
# layer widths of graphPINN.KernelNN -- TODO confirm.
design = [18, 9, 6, 6, 3]
kernel = graphPINN.KernelNN(design, torch.nn.ReLU)
convgraph = graphPINN.ConvGraph(kernel)
convgraph = convgraph.to(device)

# Kernel network for the boundary-propagation graph.
design = [12, 6, 6, 3]
propkernel = graphPINN.KernelNN(design, torch.nn.ReLU)
propgraph = graphPINN.BDPropGraph(propkernel)
propgraph = propgraph.to(device)


In [None]:
def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is 0 (every batch was skipped)."""
    return total / count if count else float('nan')


def _prepare_batch(data):
    """Add all-to-all 'bd'->'in' edges, flatten the batch and move it to ``device``.

    Every 'bd' node is connected to every 'in' node (``dense_to_sparse`` of an
    all-ones matrix, so every edge_attr is 1). NOTE(review): node counts are taken
    over the whole batch, so with batch_size > 1 edges would also link different
    graphs of the batch -- confirm before raising batch_size.

    Returns:
        (data, true): the homogeneous graph on ``device`` and the target list
        ``[y[:, 0:3], x[:, 3], x[:, 4], x[:, 5]]`` passed to ``graphPINN.MHS.loss``.
    """
    all_pairs = torch.ones(data['bd'].x.shape[0], data['in'].x.shape[0])
    data['bd', 'propagates', 'in'].edge_index, data['bd', 'propagates', 'in'].edge_attr = \
        pyg.utils.dense_to_sparse(all_pairs)
    data = data.to_homogeneous()
    data.to(device)
    true = [data.y[:, 0:3], data.x[:, 3], data.x[:, 4], data.x[:, 5]]
    return data, true


def train(propgraph, convgraph, epochs = 1, lbfgs = False, skip_threshold = 1000):
    """Train ``convgraph`` on ``trainset`` and evaluate on ``validset`` every epoch.

    Uses the notebook globals ``trainset``, ``validset``, ``batch_size``, ``folder``,
    ``k`` and ``device``. After each epoch the convgraph state_dict is saved to
    ``{folder}epoch-{epoch+1}_trainsize-{len(trainset)}_k-{k}.pt``.

    Args:
        propgraph: passed to ``graphPINN.FullModel`` together with ``convgraph``.
            NOTE(review): its parameters are not given to the optimizer and it is
            not checkpointed, so it keeps its initial weights -- confirm intended.
        convgraph: graph whose ``kernel`` parameters are optimized.
        epochs: number of passes over the training set.
        lbfgs: use ``torch.optim.LBFGS`` instead of ``torch.optim.Adam``.
        skip_threshold: batches whose total loss is not below this value (NaN
            included) get no parameter update and are excluded from the averages.

    Returns:
        (running_loss, training_loss, validation_loss). ``running_loss`` is the
        summed total loss of the kept training batches in the last epoch (0 when
        ``epochs`` is 0). The other two are ``(4, epochs)`` tensors. Their rows are
        the mean vector, MHS, divergence and total loss per epoch, or NaN when
        every batch of that epoch was skipped.
    """
    params = convgraph.kernel.parameters()
    optimizer = torch.optim.LBFGS(params) if lbfgs else torch.optim.Adam(params)

    training_loss   = torch.zeros(4,epochs)
    validation_loss = torch.zeros(4,epochs)

    # shuffle=False yields the same order every epoch, so the loaders are built once.
    trainLoader = pyg.loader.DataLoader(trainset, batch_size=batch_size, shuffle=False)
    validLoader = pyg.loader.DataLoader(validset, batch_size=batch_size, shuffle=False)

    running_loss = 0
    for epoch in range(epochs):
        convgraph.train(True)

        running_loss = 0
        running_vec = 0
        running_mhs = 0
        running_div = 0
        skipped = 0
        start_time = time()

        for data in tqdm(trainLoader):
            data, true = _prepare_batch(data)

            optimizer.zero_grad()
            pred = graphPINN.FullModel(data, propgraph, convgraph)
            loss, vec_diff, mhs_diff, div_diff = graphPINN.MHS.loss(pred, true, logfn=None)
            loss_value = loss.item()

            # `<` is False for NaN, so NaN losses are skipped as well.
            if loss_value < skip_threshold:
                if lbfgs:
                    # LBFGS re-evaluates the objective several times per step, so it
                    # needs a closure that recomputes this batch's loss and gradients.
                    def closure(data=data, true=true):
                        optimizer.zero_grad()
                        closure_pred = graphPINN.FullModel(data, propgraph, convgraph)
                        closure_loss = graphPINN.MHS.loss(closure_pred, true, logfn=None)[0]
                        closure_loss.backward()
                        return closure_loss
                    optimizer.step(closure)
                else:
                    loss.backward()
                    optimizer.step()

                # float() keeps plain numbers, so the sums cannot hold autograd
                # graphs alive if MHS.loss returns tensors.
                running_loss += loss_value
                running_vec += float(vec_diff)
                running_mhs += float(mhs_diff)
                running_div += float(div_diff)
            else:
                skipped += 1

        kept = len(trainLoader) - skipped
        training_loss[:, epoch] = torch.tensor((_mean_or_nan(running_vec, kept),
                                                _mean_or_nan(running_mhs, kept),
                                                _mean_or_nan(running_div, kept),
                                                _mean_or_nan(running_loss, kept)))

        logfn(f'Epoch {epoch+1} completed. Loss: {training_loss[3,epoch]}; Total skipped: {skipped}; Total time: {time()-start_time}', tq=False)
        logfn(f'running vec: {training_loss[0,epoch]}, running mhs: {training_loss[1,epoch]}, running div: {training_loss[2,epoch]}')

        convgraph.train(False)
        torch.save(convgraph.state_dict(), f'{folder}epoch-{epoch+1}_trainsize-{len(trainset)}_k-{k}.pt')

        start_time = time()
        valid_skipped = 0
        valid_vec = 0
        valid_mhs = 0
        valid_div = 0
        running_valid = 0
        # NOTE(review): torch.no_grad() here would save memory, but only if
        # graphPINN.MHS.loss does not rely on autograd (e.g. derivative terms) -- confirm first.
        for data in tqdm(validLoader):
            data, true = _prepare_batch(data)
            pred = graphPINN.FullModel(data, propgraph, convgraph)

            loss, vec, mhs, div = graphPINN.MHS.loss(pred, true, logfn=None)
            loss_value = loss.item()

            if loss_value < skip_threshold:
                running_valid += loss_value
                valid_vec += float(vec)
                valid_mhs += float(mhs)
                valid_div += float(div)
            else:
                valid_skipped += 1

        kept = len(validLoader) - valid_skipped
        validation_loss[:, epoch] = torch.tensor((_mean_or_nan(valid_vec, kept),
                                                  _mean_or_nan(valid_mhs, kept),
                                                  _mean_or_nan(valid_div, kept),
                                                  _mean_or_nan(running_valid, kept)))
        logfn(f'Validation loss: {validation_loss[3,epoch]}, total skipped: {valid_skipped}, validation time: {time()-start_time}', tq=False)
        logfn(f'running vec: {validation_loss[0,epoch]}, running mhs: {validation_loss[1,epoch]}, running div: {validation_loss[2,epoch]}')

    return running_loss, training_loss, validation_loss

epochs = 5
train_results = train(propgraph, convgraph, epochs=epochs)
loss, training_loss, validation_loss = train_results

# Row 3 of each history tensor holds the per-epoch total loss.
logfn(f'validation loss: {validation_loss[3,:]}', tq=False)
logfn(f'training loss: {training_loss[3,:]}', tq=False)

In [None]:
# Loss histories keyed by split; 'ordering' names the rows of each (4, epochs) tensor.
lossdict = {
    'training': training_loss,
    'validation': validation_loss,
    'ordering': ['vector', 'mhs', 'div', 'total'],
}
for split_name in ('training', 'validation'):
    print(lossdict[split_name])

In [None]:
# Save the loss histories as a MATLAB .mat file next to the checkpoints.
# (Fixes the `foler` NameError.) Tensors are converted to numpy arrays explicitly,
# so scipy writes them as numeric matrices rather than guessing how to serialize them.
savemat(
    f'{folder}loss_{epochs}_trainsize-{len(trainset)}_k-{k}.mat',
    {
        key: value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else value
        for key, value in lossdict.items()
    },
)

In [None]:
# Current/peak allocated and reserved/total GPU memory. get_device_properties
# raises without CUDA, so the summary is only printed when a GPU is present.
if torch.cuda.is_available():
    print(f'{pretty_size(torch.cuda.memory_allocated())}/{pretty_size(torch.cuda.max_memory_allocated())} allocated, ' +
          f'{pretty_size(torch.cuda.memory_reserved())}/{pretty_size(torch.cuda.get_device_properties(0).total_memory)} reserved')
else:
    print('CUDA not available; no GPU memory to report')

In [None]:
# convgraph.load_state_dict(torch.load('epoch-0_trainsize-3092_k-100.pt'))

In [None]:
# Take the first sample and fully connect its 'bd' nodes to its 'in' nodes, the
# same way the training loop prepares each batch (every edge_attr is 1).
sample = dataset[0]
n_bd_nodes = sample['bd'].x.shape[0]
n_in_nodes = sample['in'].x.shape[0]
bd_in_index, bd_in_attr = pyg.utils.dense_to_sparse(torch.ones(n_bd_nodes, n_in_nodes))
sample['bd', 'propagates', 'in'].edge_index = bd_in_index
sample['bd', 'propagates', 'in'].edge_attr = bd_in_attr
print(sample)

In [None]:
# Flatten the heterogeneous sample into one homogeneous graph; the original
# edge type of each edge is kept in `edge_type`.
homosamp = sample.to_homogeneous()

In [None]:
# Node feature matrix of the flattened sample.
node_features = homosamp.x
print(node_features)

In [None]:
# Edges whose homogeneous edge_type is 1. Which heterogeneous edge type maps to 1
# depends on the order of sample.edge_types -- TODO confirm it is ('bd','propagates','in').
type1_mask = homosamp.edge_type == 1
print(homosamp.edge_index[:, type1_mask])