In [1]:
import numpy as np
import torch
import torch_geometric as pyg
import graphPINN
import math

from time import time
from tqdm.notebook import tqdm
from scipy.io import savemat, loadmat
import os

# os.environ['MKL_THREADING_LAYER'] = 'GNU' # fixes a weird intel multiprocessing error with numpy

In [2]:
# Output directory for logs, checkpoints, saved models and .mat loss files.
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere.
folder = "C:\\Users\\NASA\\Documents\\ML_checkpoints\\2023-04-27\\"
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(folder, exist_ok=True)

In [3]:
# Logger passed to graphPINN.learn.train below and used for the notebook's own messages
# (its output appears in the cell outputs). It is constructed with the checkpoint folder,
# so it presumably also writes a log file there — confirm in graphPINN.debug.Logfn.
logfn = graphPINN.debug.Logfn(folder)

In [4]:
# Pick the compute device, then report the total memory of every visible CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for gpu_index in range(torch.cuda.device_count()):
    total_bytes = torch.cuda.get_device_properties(gpu_index).total_memory
    readable_size = graphPINN.debug.pretty_size(total_bytes)
    logfn(f"{gpu_index}: {readable_size} of {device.type} memory")

0: 48.0 GiB of cuda memory
1: 48.0 GiB of cuda memory


In [4]:
# --- Data and model configuration -------------------------------------------
# Graph-construction parameter for MHSDataset; it also selects the data folder.
# Presumably the number of nearest neighbours per node — confirm in MHSDataset.
k = 100
SEED = 314  # shared seed for weight initialisation and the train/valid/test split
# DistributedDataParallel only helps with several GPUs and cannot run on a
# CPU-only machine, so enable it only when more than one CUDA device is visible.
ddp = torch.cuda.device_count() > 1

dataset = graphPINN.data.MHSDataset(f'D:\\scattered_data_k={k}', k=k)

# Layer widths of the two kernel MLPs.
# Smaller alternatives that were also tried: propdesign = [12, 3], convdesign = [18, 3].
propdesign = [12, 6, 3]
convdesign = [18, 9, 6, 3]

# Seed before building the networks so the initial weights are reproducible.
torch.manual_seed(SEED)
propkernel = graphPINN.KernelNN(propdesign, torch.nn.ReLU)
propgraph = graphPINN.BDPropGraph(propkernel)
convkernel = graphPINN.KernelNN(convdesign, torch.nn.ReLU)
convgraph = graphPINN.ConvGraph(convkernel)
model = graphPINN.FullModel(propgraph, convgraph)

# 80/10/10 split with a fixed generator so the split is identical across runs.
# For a quick smoke test, use fractions [0.01, 0.005, 0.985] instead.
trainset, validset, testset = torch.utils.data.random_split(
    dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(SEED))
# trainset, validset, testset = torch.utils.data.random_split(dataset,[0.01, 0.005, 0.985],generator=torch.Generator().manual_seed(314))

In [None]:
# Train in successive stages, saving the mean losses of each stage and the final model.
epochs = 5  # epochs per stage; also used in the final .mat filename (was undefined before)
lossdict = {'index_array': [[0, 1, 2], [3, 4, 5], [0, 1, 2], [3, 4, 5], -1]}
# Label embedded in output filenames. NOTE(review): this is the product of the layer
# widths, not the number of trainable parameters (sum(p.numel() for p in model.parameters())).
param_tag = math.prod(convdesign) + math.prod(propdesign)

for key, index in enumerate(lossdict['index_array']):
    print(f'key {key} - index {index}')
    # NOTE(review): `index` is only printed and never passed to train(), so every stage
    # trains identically. TODO: pass it through if it is meant to select loss components.
    training_loss, validation_loss, state_dict = graphPINN.learn.train(
        model, trainset, validset,
        epochs=epochs, logfn=logfn, checkpointfile=folder, use_ddp=ddp)
    model.load_state_dict(state_dict)
    lossdict[f'train{key}'] = (sum(training_loss) / len(training_loss)).cpu().numpy()
    lossdict[f'valid{key}'] = (sum(validation_loss) / len(validation_loss)).cpu().numpy()
    logfn(f'training loss:\n{lossdict[f"train{key}"]}')
    logfn(f'validation loss:\n{lossdict[f"valid{key}"]}')
    # index_array is ragged (lists mixed with a bare int). numpy >= 1.24 refuses to build an
    # array from it, so savemat can raise. Save a copy with it stringified, like the final save.
    savemat(f'{folder}loss_key-{key}_params-{param_tag}.mat',
            {**lossdict, 'index_array': str(lossdict['index_array'])})

lossdict['index_array'] = str(lossdict['index_array'])
torch.save(model, f'{folder}model_trainsize-{len(trainset)}_k-{k}_params-{param_tag}.pt')
savemat(f'{folder}loss_{epochs}_trainsize-{len(trainset)}_k-{k}_params-{param_tag}.mat', lossdict)

key 0 - index [0, 1, 2]
Starting on rank 0
Starting on rank 1
  [0] iter 1/3, loss 7.329127311706543
  [1] iter 1/3, loss 5.297978401184082
  [0] iter 2/3, loss 5.2879533767700195
  [1] iter 2/3, loss 6.113253593444824
  [0] iter 3/3, loss 7.863614559173584
[0] Epoch 1 completed. Loss: 6.826898574829102; Total time: 782.5795669555664
[0] running vec: 1.3391778469085693, running mhs: 5.487493515014648, running div: 0.00022701657144352794
  [1] iter 3/3, loss 7.451117515563965
[1] Epoch 1 completed. Loss: 6.287449836730957; Total time: 785.504088640213
[1] running vec: 1.2544631958007812, running mhs: 5.032764911651611, running div: 0.00022169403382577002
  [0] iter 1/1, loss 6.318548202514648
[0] Validation loss: 6.318548202514648; validation time: 260.70579767227173
[0] running vec: 1.3763699531555176, running mhs: 4.941967487335205, running div: 0.00021071200899314135
  [1] iter 1/1, loss 5.552245140075684
[1] Validation loss: 5.552245140075684; validation time: 260.92543387413025
[1]

In [None]:
# Print the mean of the training losses, then of the validation losses,
# returned by the most recent graphPINN.learn.train call.
for losses in (training_loss, validation_loss):
    mean_loss = sum(losses) / len(losses)
    print(mean_loss.cpu().numpy())

In [13]:
# Restore weights from an epoch-10 checkpoint in `folder`. It was presumably written by
# graphPINN.learn.train via checkpointfile=folder — confirm.
# map_location="cpu" lets it load on a machine without the saving GPU; load_state_dict
# then copies the tensors onto the model's own device.
model.load_state_dict(torch.load(f"{folder}2023-04-27_1682625968.9837143epoch-10.pt", map_location="cpu"))

<All keys matched successfully>

In [14]:
# Save the whole (unwrapped) model object next to its checkpoint.
# `.module` only exists when `model` is a DDP/DataParallel wrapper. In the visible cells
# `model` is a plain graphPINN.FullModel, so fall back to it when there is no `.module`.
# NOTE(review): assumes FullModel has no submodule literally named `module` — confirm.
torch.save(getattr(model, "module", model), f"{folder}2023-04-27_1682625968.9837143model-10.pt")