In [1]:
import numpy as np
import torch
import torch_geometric as pyg
import graphPINN
import math
import logging
from time import time
from tqdm.notebook import tqdm

In [2]:
def pretty_size(n,pow=0,b=1024,u='B',pre=['']+[p+'i'for p in'KMGTPEZY']):
    """Format a quantity as a human-readable size string such as ``"24.0 GiB"``.

    Args:
        n: the quantity to format.
        pow: power of ``b`` that ``n`` is already expressed in (``n * b**pow`` is the raw value).
        b: base of the unit ladder (1024 by default).
        u: unit suffix appended after the prefix.
        pre: ordered unit prefixes; index ``i`` stands for ``b**i``. Never mutated here.

    Returns:
        The value scaled to the largest prefix that fits, with no decimals when
        no prefix is applied and one decimal otherwise.
    """
    raw = n * b ** pow
    # Clamp to 1 so that zero/negative values land on the first (empty) prefix.
    magnitude = int(math.log(max(raw, 1), b))
    exponent = min(magnitude, len(pre) - 1)
    # 0 -> 0 decimals, any positive exponent -> 1 decimal.
    decimals = abs(exponent % (-exponent - 1))
    template = "%%.%if %%s%%s" % decimals
    return template % (raw / b ** float(exponent), pre[exponent], u)

# Append timestamped INFO-level records to log.log in the working directory.
logging.basicConfig(
    filename='log.log',
    format='%(asctime)s - %(message)s',
    filemode='a+',
    level=logging.INFO,
)
def logfn(message, tq=True):
    """Record ``message`` in the log and echo it to the notebook output.

    Args:
        message: the object to log and display.
        tq: when true the echo goes through ``tqdm.write``; otherwise through ``print``.
    """
    logging.info(message)
    # The conditional only looks up ``tqdm`` when it is actually requested.
    echo = tqdm.write if tq else print
    echo(message)

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Report the total memory of CUDA device 0. The device properties can only be
# queried when CUDA is present, so the CPU-only case is reported separately
# instead of raising.
if torch.cuda.is_available():
    logfn(f"{pretty_size(torch.cuda.get_device_properties(0).total_memory)} of cuda memory")
else:
    logfn("CUDA is not available; running on cpu")

24.0 GiB of cuda memory


In [5]:
# Loader batch size and the k value; k is passed to MHSDataset and is also
# part of the data directory name.
batch_size = 1
k = 50

# NOTE(review): absolute, machine-specific Windows path — the notebook will not
# run elsewhere without editing this line.
dataset = graphPINN.data.MHSDataset(f'D:\\nats ML stuff\\data_k={k}',k=k)


# 80% / 10% / 10% train / validation / test split; the generator is seeded
# with 42 so the split is reproducible.
trainset, validset, testset = torch.utils.data.random_split(dataset,[0.8, 0.1, 0.1],generator=torch.Generator().manual_seed(42))
# Alternative split that keeps only a tiny train/validation subset (disabled):
# trainset, validset, testset = torch.utils.data.random_split(dataset,[0.005, 0.001, 0.994],generator=torch.Generator().manual_seed(42))

Processing...
100%|████████████████████████████████████████████████████████████████████████████████| 644/644 [51:56<00:00,  4.84s/it]
Done!


In [6]:
# Kernel network and graph model that train() optimises. KernelNN receives a
# list of integers and an activation class — presumably the layer widths;
# confirm against graphPINN.KernelNN.
design = [18,9,6,6,3]
kernel = graphPINN.KernelNN(design, torch.nn.ReLU)
convgraph = graphPINN.ConvGraph(kernel).to(device)

# Second kernel/graph pair; train() calls propgraph.forward on each batch
# before convgraph. Note that `design` is rebound here.
design = [9,6,6,3]
propkernel = graphPINN.KernelNN(design, torch.nn.ReLU)
propgraph = graphPINN.BDPropGraph(propkernel).to(device)


In [None]:
def train(convgraph, epochs = 1, lbfgs = False, skip_threshold = 1000):
    """Train ``convgraph`` on ``trainset`` and validate on ``validset`` after every epoch.

    Uses the notebook globals ``trainset``, ``validset``, ``batch_size``, ``k``,
    ``device``, ``propgraph`` and ``logfn``.

    Args:
        convgraph: model to train; only ``convgraph.kernel.parameters()`` are
            handed to the optimizer.
        epochs: number of epochs to run. Epochs are numbered ``0 .. epochs-1``,
            which is also the number used in the checkpoint file names.
        lbfgs: use ``torch.optim.LBFGS`` (driven by a closure) instead of
            ``torch.optim.Adam``.
        skip_threshold: a batch whose loss is not strictly below this value
            (this includes NaN) is skipped and excluded from the averages.

    Returns:
        ``(running_loss, validation_loss)``: the summed training loss of the
        last epoch (0 when ``epochs`` is 0) and the list of mean validation
        losses, one entry per epoch.

    Side effects:
        Writes ``epoch-{epoch}_trainsize-{len(trainset)}_k-{k}.pt`` after each
        epoch and logs progress through ``logfn``.
    """

    def prepare(batch):
        """Add dense bd->in edges, flatten to a homogeneous graph and move it to ``device``."""
        batch['bd','propagates','in'].edge_index, batch['bd','propagates','in'].edge_attr = \
                    pyg.utils.dense_to_sparse(
                            torch.ones(batch['bd'].x.shape[0],batch['in'].x.shape[0])
                    )
        graph = batch.to_homogeneous()
        graph = graph.to(device)
        true = [graph.y[:,0:3],graph.y[:,3],graph.y[:,4],graph.y[:,5]]
        # graph.x is returned separately so the forward pass can be repeated
        # from pristine features (needed by the LBFGS closure).
        return graph, graph.x, true

    def forward_loss(graph, x_raw, true, n_iter):
        """Run propgraph then convgraph on ``graph`` and return the MHS loss tuple."""
        graph.x = x_raw.clone()
        graph.x[:,0:3] = propgraph.forward(graph)
        pred = convgraph.forward(graph, iter=n_iter)
        return graphPINN.MHS.loss(pred, true, logfn=None)

    def detached(value):
        """Drop the autograd graph from tensor statistics so accumulators do not retain it."""
        return value.detach() if torch.is_tensor(value) else value

    def mean(total, count):
        """Average that reports NaN instead of dividing by zero when nothing was counted."""
        return total / count if count else float('nan')

    # NOTE(review): propgraph's parameters are not given to the optimizer, so
    # they are never updated here — confirm this is intended.
    if lbfgs:
        optimizer = torch.optim.LBFGS(convgraph.kernel.parameters())
    else:
        optimizer = torch.optim.Adam(convgraph.kernel.parameters())

    # shuffle=False makes the loaders deterministic, so they are built once.
    trainLoader = pyg.loader.DataLoader(trainset, batch_size=batch_size,shuffle=False)
    validLoader = pyg.loader.DataLoader(validset, batch_size=batch_size,shuffle=False)

    validation_loss = [0]*epochs
    running_loss = 0
    for epoch in range(epochs):
        convgraph.train(True)

        running_loss = 0
        running_vec = 0
        running_mhs = 0
        running_div = 0
        skipped = 0
        start_time = time()

        for batch in tqdm(trainLoader):
            graph, x_raw, true = prepare(batch)

            optimizer.zero_grad()
            loss, vec_diff, mhs_diff, div_diff = forward_loss(graph, x_raw, true, 1)

            if loss.item() < skip_threshold:
                if lbfgs:
                    # LBFGS may evaluate the objective several times per step,
                    # so it gets a closure that redoes the full forward pass
                    # for the current batch.
                    def closure():
                        optimizer.zero_grad()
                        closure_loss = forward_loss(graph, x_raw, true, 1)[0]
                        closure_loss.backward()
                        return closure_loss
                    optimizer.step(closure)
                else:
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
                running_vec += detached(vec_diff)
                running_mhs += detached(mhs_diff)
                running_div += detached(div_diff)
            else:
                skipped += 1

        used = len(trainLoader) - skipped
        logfn(f'Epoch {epoch} completed. Loss: {mean(running_loss, used)}; Total skipped: {skipped}; Total time: {time()-start_time}', tq=False)
        logfn(f'running vec: {mean(running_vec, used)}, running mhs: {mean(running_mhs, used)}, running div: {mean(running_div, used)}')

        convgraph.train(False)
        torch.save(convgraph.state_dict(), f'epoch-{epoch}_trainsize-{len(trainset)}_k-{k}.pt')

        start_time = time()
        valid_skipped = 0
        valid_vec = 0
        valid_mhs = 0
        valid_div = 0
        running_valid = 0
        # NOTE(review): validation batches now go through the same preparation
        # as training batches; previously the raw loader batch was passed
        # straight to convgraph. Gradients are left enabled in case the loss
        # relies on autograd — confirm whether torch.no_grad() is safe here.
        for batch in tqdm(validLoader):
            graph, x_raw, true = prepare(batch)
            loss, vec, mhs, div = forward_loss(graph, x_raw, true, 20)
            if loss.item() < skip_threshold:
                running_valid += loss.item()
                valid_vec += detached(vec)
                valid_mhs += detached(mhs)
                valid_div += detached(div)
            else:
                valid_skipped += 1

        valid_used = len(validLoader) - valid_skipped
        validation_loss[epoch] = mean(running_valid, valid_used)
        logfn(f'Validation loss: {validation_loss[epoch]}, validation time = {time()-start_time}', tq=False)
        logfn(f'running vec: {mean(valid_vec, valid_used)}, running mhs: {mean(valid_mhs, valid_used)}, running div: {mean(valid_div, valid_used)}')

    return running_loss, validation_loss

# Run training for epochs=5 and log the per-epoch validation losses it returns.
loss, validation_loss = train(convgraph, epochs=5)
logfn(validation_loss, tq=False)

In [None]:
# Current/peak allocated and reserved/total CUDA memory. Device properties can
# only be queried when CUDA is present, so the CPU-only case is reported
# instead of raising.
if torch.cuda.is_available():
    print(f'{pretty_size(torch.cuda.memory_allocated())}/{pretty_size(torch.cuda.max_memory_allocated())} allocated, ' +
          f'{pretty_size(torch.cuda.memory_reserved())}/{pretty_size(torch.cuda.get_device_properties(0).total_memory)} reserved')
else:
    print('CUDA is not available; no GPU memory statistics to report')

In [None]:
def fullValidate(epochs, filePrefix='epoch-', filePostfix='_trainsize-3092_k-50.pt', skip_threshold=100):
    """Evaluate saved checkpoints on ``validset``.

    Uses the notebook globals ``convgraph``, ``propgraph``, ``validset``,
    ``batch_size``, ``device`` and ``logfn``. ``convgraph`` is left holding the
    weights of the last checkpoint loaded.

    Args:
        epochs: sequence of epoch numbers; each is inserted between
            ``filePrefix`` and ``filePostfix`` to form a checkpoint file name.
        filePrefix: checkpoint file name prefix.
        filePostfix: checkpoint file name suffix.
        skip_threshold: a batch whose loss is not strictly below this value
            (this includes NaN) is skipped and excluded from the mean.

    Returns:
        List of mean validation losses, one per entry of ``epochs`` in the
        same order (NaN when every batch of a checkpoint was skipped).
    """
    convgraph.train(False)
    validLoader = pyg.loader.DataLoader(validset, batch_size=batch_size,shuffle=False)
    start_time = time()

    validation_loss = [0]*len(epochs)
    # Results are stored by position, so `epochs` need not start at 0 or be contiguous.
    for slot, epoch in enumerate(epochs):
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
        convgraph.load_state_dict(torch.load(f'{filePrefix}{epoch}{filePostfix}', map_location=device))
        logfn(f'Loaded epoch {epoch}')
        valid_skipped = 0
        loss_sum = 0
        # NOTE(review): batches are prepared the same way the training loop
        # prepares them (dense bd->in edges, homogeneous graph, propgraph
        # pass); previously the raw loader batch was passed to convgraph.
        for batch in tqdm(validLoader):
            batch['bd','propagates','in'].edge_index, batch['bd','propagates','in'].edge_attr = \
                        pyg.utils.dense_to_sparse(
                                torch.ones(batch['bd'].x.shape[0],batch['in'].x.shape[0])
                        )
            graph = batch.to_homogeneous()
            graph = graph.to(device)
            true = [graph.y[:,0:3],graph.y[:,3],graph.y[:,4],graph.y[:,5]]
            graph.x[:,0:3] = propgraph.forward(graph)
            pred = convgraph.forward(graph, iter=20)
            result = graphPINN.MHS.loss(pred,true, logfn = logging.info)
            # The training loop unpacks this call as a 4-tuple with the total loss first.
            loss = result[0] if isinstance(result, (tuple, list)) else result
            if loss.item() < skip_threshold:
                loss_sum += loss.item()
            else:
                logging.info('  skipped')
                valid_skipped += 1
        # Average once per checkpoint, after all batches have been seen.
        counted = len(validLoader) - valid_skipped
        validation_loss[slot] = loss_sum / counted if counted else float('nan')
        logfn(f'Validation loss: {validation_loss[slot]}, validation time = {time()-start_time}')
    return validation_loss

# Validate the checkpoints of epochs 0-4 using fullValidate's default file name pattern.
validloss_full = fullValidate([0,1,2,3,4])

In [None]:
# Restore convgraph from a saved checkpoint.
# NOTE(review): the file name says k-100 while this notebook sets k = 50 —
# confirm this is the intended checkpoint.
convgraph.load_state_dict(torch.load('epoch-0_trainsize-3092_k-100.pt'))

In [None]:
# Take the first dataset item and connect every 'bd' node to every 'in' node
# (same edge construction as the training loop), then show the result.
sample = dataset[0]
sample['bd','propagates','in'].edge_index, sample['bd','propagates','in'].edge_attr = \
            pyg.utils.dense_to_sparse(
                        torch.ones(sample['bd'].x.shape[0],sample['in'].x.shape[0])
            )
print(sample)

In [None]:
# Flatten the heterogeneous sample into a single homogeneous graph.
homosamp = sample.to_homogeneous()

In [None]:
# Inspect the node feature matrix of the homogeneous sample.
print(homosamp.x)

In [None]:
# Show only the edges whose edge_type is 1.
# NOTE(review): which relation type id 1 corresponds to is not visible here — confirm.
print(homosamp.edge_index[:,homosamp.edge_type==1])

In [None]:
# Build the sparse form of an all-ones 1063 x 22862 matrix and print it.
# NOTE(review): 1063 and 22862 are hard-coded — presumably the 'bd' and 'in'
# node counts of one sample; confirm.
edges = pyg.utils.dense_to_sparse(torch.ones(1063,22862))
print(edges)

In [None]:
class BDPropGraph(pyg.nn.MessagePassing):
    """Message-passing layer that sends every boundary node's features to every interior node.

    Each message is produced by ``kernel`` from the concatenation of the
    target position, the source position and the source features; messages
    arriving at a node are averaged (``aggr='mean'``).

    NOTE(review): this notebook-local class has a different ``forward``
    signature from the ``graphPINN.BDPropGraph`` instance used earlier in the
    notebook (which is called with a single data object).
    """

    def __init__(self, kernel):
        """Store ``kernel``, the module applied to each edge's concatenated inputs."""
        super().__init__(aggr='mean')
        self.kernel = kernel

    def reset_parameters(self):
        """Re-initialise the kernel's parameters."""
        self.kernel.reset_parameters()

    def message(self, x_i, x_j, pos_i, pos_j, edge_attr=None):
        """Compute one message per edge from ``(pos_i, pos_j, x_j)``.

        ``x_i`` and ``edge_attr`` are accepted but unused; ``edge_attr`` is
        optional because ``forward`` does not supply edge attributes.
        """
        return self.kernel.forward(torch.cat((pos_i,pos_j,x_j),1))

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out

    def forward(self, x_bd, x_int, pos_bd, pos_int):
        """Propagate boundary features to interior nodes over a dense bipartite graph.

        Args:
            x_bd: features of the boundary (source) nodes.
            x_int: features of the interior (target) nodes.
            pos_bd: positions of the boundary nodes; first dimension is the node count.
            pos_int: positions of the interior nodes; first dimension is the node count.

        Returns:
            The mean-aggregated kernel output, one row per interior node.
        """
        M = pos_int.shape[0]
        N = pos_bd.shape[0]
        # Every boundary node (row 0, source) is connected to every interior
        # node (row 1, target).
        sources = torch.arange(N, device=pos_bd.device).repeat_interleave(M)
        targets = torch.arange(M, device=pos_bd.device).repeat(N)
        edge_index = torch.stack((sources, targets))
        # For bipartite propagation, size is (num_source_nodes, num_target_nodes),
        # matching the (source, target) order of the x and pos tuples.
        return self.propagate(edge_index, size=(N,M), x=(x_bd,x_int), pos=(pos_bd,pos_int))