In [1]:
# Network
import torch
import torch.nn.functional as F

from torch_geometric.nn import GCNConv, ChebConv
from torch_geometric.nn import global_max_pool

# Data
from torch.utils.data.sampler import SubsetRandomSampler

from torch_geometric.data import Batch, Data, Dataset, DataLoader

# General
import numpy as np

# Util
import os.path as osp
import h5py
import pickle

import time


# Config

In [2]:
# Port from https://github.com/tkarras/progressive_growing_of_gans/blob/master/config.py
class EasyDict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes.

    ``d.key`` is equivalent to ``d["key"]``.  A missing key raises
    ``AttributeError`` on attribute access (not ``KeyError``), which is what
    ``hasattr``, ``getattr(obj, name, default)`` and ``copy.deepcopy`` rely on
    when they probe an object for optional attributes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so dict methods
        # (keys, items, ...) are never shadowed by stored items.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

In [3]:
# Run configuration, built in one call; keys keep the order in which they are listed.
config = EasyDict(
    # Model
    model_name="gcn_kipf",

    # Input data and the text files listing the event indices of each split
    data_path="/home/jwalker/NeutronGNN/root_utils/data_1file.h5",
    train_indices_file="train_indicies.txt",
    val_indices_file="validation_indicies.txt",
    test_indices_file="test_indicies.txt",

    # Prefix of the timestamped output directory
    dump_path="dump",

    # Hardware
    num_data_workers=0,  # Sometime crashes if we do multiprocessing
    device='gpu',
    gpu_list=[0],

    # Optimisation
    batch_size=1,  #32 # 256
    lr=0.01,
    weight_decay=5e-4,

    # Schedule (intervals are counted in training iterations)
    epochs=10,
    report_interval=10,  # 100
    num_val_batches=32,
    valid_interval=100,  # 10000
)

### Generate fake indices

In [4]:
#f=h5py.File(config.data_path,'r')
#print( f["event_data"][0] )
#print( f["nhits"][0] )
#print( f["event_data"][0, :f["nhits"][0], :] )
#length = f["event_ids"].len()
#splits = 10
#random_shuffle = np.random.permutation(length)
#validation_indicies = random_shuffle[:length//splits]
#test_indicies = random_shuffle[length//splits:2*length//splits]
#train_indicies = random_shuffle[2*length//splits:]
#
#with open("train_indicies.txt", 'w') as f:
#    f.writelines(["{}\n".format(i) for i in train_indicies])
#    
#with open("validation_indicies.txt", 'w') as f:
#    f.writelines(["{}\n".format(i) for i in validation_indicies])
#
#with open("test_indicies.txt", 'w') as f:
#    f.writelines(["{}\n".format(i) for i in test_indicies])

# Data Loader

In [5]:
class WCH5Dataset(Dataset):
    """
    Graph dataset backed by a Water Cherenkov detector HDF5 file.

    The ``event_data`` array is memory-mapped straight from the file, so it
    must be stored uncompressed and unchunked; events are only read from disk
    in ``get``.  ``labels`` and ``nhits`` are loaded into memory outright.
    No other data is currently loaded.

    Each event becomes a fully connected graph (self-loops included) over its
    first ``nhits`` rows of ``event_data``.
    """

    # Override the default implementation
    def _download(self):
        pass

    def _process(self):
        pass

    def __init__(self, path, train_indices_file, val_indices_file, test_indices_file, nodes=15808,
                 transform=None, pre_transform=None, pre_filter=None,
                 use_node_attr=False, use_edge_attr=False, cleaned=False):
        """
        Args:
            path: HDF5 file containing ``event_data``, ``labels`` and ``nhits``.
            train_indices_file, val_indices_file, test_indices_file: text files
                with one integer event index per line.
            nodes: stored on the instance as ``self.nodes``; not used otherwise here.
            transform, pre_transform, pre_filter: forwarded to the base ``Dataset``.
            use_node_attr, use_edge_attr, cleaned: accepted but unused.

        Raises:
            ValueError: if ``event_data`` has no contiguous file offset and
                therefore cannot be memory-mapped.
        """
        super(WCH5Dataset, self).__init__("", transform, pre_transform,
                                        pre_filter)

        # The HDF5 handle is only needed to read metadata and the small arrays;
        # close it as soon as that is done instead of leaking it.
        with h5py.File(path, 'r') as f:
            hdf5_event_data = f["event_data"]
            hdf5_labels = f["labels"]
            hdf5_nhits = f["nhits"]

            assert hdf5_event_data.shape[0] == hdf5_labels.shape[0]

            event_data_shape = hdf5_event_data.shape
            event_data_offset = hdf5_event_data.id.get_offset()
            event_data_dtype = hdf5_event_data.dtype

            # this will fit easily in memory even for huge datasets
            self.labels = np.array(hdf5_labels)
            self.nhits = np.array(hdf5_nhits)

        # get_offset() returns None when the dataset has no single contiguous
        # location in the file (e.g. chunked storage); fail with a clear message.
        if event_data_offset is None:
            raise ValueError(
                "'event_data' in {} has no contiguous file offset; "
                "it must be stored uncompressed and unchunked".format(path))

        # this creates a memory map - i.e. events are not loaded in memory here,
        # only on get()
        self.event_data = np.memmap(path, mode='r', shape=event_data_shape,
                                    offset=event_data_offset, dtype=event_data_dtype)

        self.nodes = nodes
        #self.load_edges(edge_index_pickle)

        self.transform = transform

        # train/validation/test splits, as lists of event indices
        self.train_indices = self.load_indicies(train_indices_file)
        self.val_indices = self.load_indicies(val_indices_file)
        self.test_indices = self.load_indicies(test_indices_file)

    def load_indicies(self, indicies_file):
        """Read one integer index per line from ``indicies_file``; blank lines are skipped."""
        with open(indicies_file, 'r') as f:
            return [int(line.strip()) for line in f if line.strip()]

    def load_edges(self, nhits):
        """Build the edge list of a fully connected graph on ``nhits`` nodes.

        The result (a 2 x nhits**2 int64 tensor, self-loops included) is stored
        in ``self.edge_index`` and also returned.
        """
        edge_index = torch.ones([nhits, nhits], dtype=torch.int64)
        self.edge_index = edge_index.to_sparse()._indices()
        return self.edge_index

    def get(self, idx):
        """Return event ``idx`` as a ``Data`` object with node features, label and edges."""
        nhits = self.nhits[idx]
        # Only the first nhits rows of the event are used as nodes.
        x = torch.from_numpy(self.event_data[idx, :nhits, :])
        y = torch.tensor([self.labels[idx]], dtype=torch.int64)
        edge_index = self.load_edges(nhits)

        return Data(x=x, y=y, edge_index=edge_index)

    def __len__(self):
        return self.labels.shape[0]

In [6]:
def get_loaders(path, train_indices_file, val_indices_file, test_indices_file, batch_size, workers):
    """Create the train, validation and test loaders over one shared WCH5Dataset.

    Every loader draws from the same dataset object and differs only in the
    index list handed to its SubsetRandomSampler.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    dataset = WCH5Dataset(path, train_indices_file, val_indices_file, test_indices_file)

    def make_loader(split_indices):
        # Random order within the given split, without replacement.
        split_sampler = SubsetRandomSampler(split_indices)
        return DataLoader(dataset, batch_size=batch_size, num_workers=workers,
                          pin_memory=True, sampler=split_sampler)

    splits = (dataset.train_indices, dataset.val_indices, dataset.test_indices)
    return tuple(make_loader(split) for split in splits)

# Network stuff

In [7]:
class Net(torch.nn.Module):
    """Three-layer GCN graph classifier.

    Two GCNConv + ReLU + dropout layers, a third GCNConv producing one score
    per class for every node, a global max pool over the nodes of each graph,
    and a log-softmax over the classes.

    Args:
        in_channels: number of input features per node (default 8).
        hidden_channels: width of the two hidden layers (default 16).
        num_classes: number of output classes (default 5).
    """

    def __init__(self, in_channels=8, hidden_channels=16, num_classes=5):
        super(Net, self).__init__()
        # cached=False: the graph differs from one batch to the next.
        self.conv1 = GCNConv(in_channels, hidden_channels, cached=False)
        self.conv2 = GCNConv(hidden_channels, hidden_channels, cached=False)
        self.conv3 = GCNConv(hidden_channels, num_classes, cached=False)

    def forward(self, batch):
        """Return per-graph log-probabilities for a batch exposing
        ``x``, ``edge_index`` and ``batch`` attributes."""
        x, edge_index, batch_index = batch.x, batch.edge_index, batch.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)
        # Collapse node scores to one row per graph.
        x = global_max_pool(x, batch_index)
        return F.log_softmax(x, dim=1)

# Engine

In [8]:
# Python standard imports
from abc import ABC, abstractmethod
from time import strftime
from os import stat, mkdir
from math import floor, ceil

# PyTorch imports
from torch import device, load, save
from torch.nn import DataParallel
from torch.cuda import is_available

In [9]:
class Engine(ABC):
    """Abstract training/inference engine.

    Handles device selection, data-loader creation, the output directory,
    CSV logs and checkpoint save/load.  Subclasses implement ``forward`` and
    ``train`` and are expected to set ``self.criterion`` and
    ``self.optimizer`` before ``backward`` is called.
    """

    def __init__(self, model, config):
        """
        Args:
            model: the torch module to train / evaluate.
            config: object exposing (as attributes) device, gpu_list, data_path,
                train_indices_file, val_indices_file, test_indices_file,
                batch_size, num_data_workers, dump_path and model_name.
        """
        super().__init__()

        # Engine attributes
        self.model=model
        self.config=config

        # Requested CUDA device strings; stays empty on the CPU path so the
        # attribute always exists.
        self.devids=[]

        # Determine the device to be used for model training and inference
        if (config.device == 'gpu') and config.gpu_list:
            print("Requesting GPUs. GPU list : " + str(config.gpu_list))
            self.devids=["cuda:{0}".format(x) for x in config.gpu_list]
            print("Main GPU : " + self.devids[0])

            if is_available():
                self.device=device(self.devids[0])
                if len(self.devids) > 1:
                    # Wrap exactly once; a second wrap would make
                    # self.model.module another DataParallel and break
                    # save_state / load_state.
                    print("Using DataParallel on these devices: {}".format(self.devids))
                    self.model=DataParallel(self.model, device_ids=config.gpu_list, dim=0)
                print("CUDA is available")
            else:
                self.device=device("cpu")
                print("CUDA is not available")
        else:
            print("Unable to use GPU")
            self.device=device("cpu")

        # Send the model to the selected device
        self.model.to(self.device)

        # model_accs is the underlying module whose sub-modules get saved / loaded
        if isinstance(self.model, DataParallel):
            self.model_accs=self.model.module
        else:
            self.model_accs=self.model

        # Create the data loaders
        out = get_loaders(config.data_path,
                      config.train_indices_file, config.val_indices_file, config.test_indices_file,
                      config.batch_size, config.num_data_workers)

        self.train_loader, self.val_loader, self.test_loader = out

        # Define the variant dependent attributes
        self.criterion=None

        # Create the directory for saving the log and dump files.
        # NOTE(review): dump_path is used as a plain prefix (no separator before
        # the timestamp), e.g. "dump20200203_101737/" - confirm this is intended.
        self.dirpath=config.dump_path + strftime("%Y%m%d_%H%M%S") + "/"
        try:
            stat(self.dirpath)
        except FileNotFoundError:
            print("Creating a directory for run dump at : {}".format(self.dirpath))
            mkdir(self.dirpath)

        # Logging attributes
        self.train_log=CSVData(self.dirpath + "log_train.csv")
        self.val_log=CSVData(self.dirpath + "log_val.csv")

#         # Save a copy of the config in the dump path
#         save_config(self.config, self.dirpath + "config_file.ini")   # changed

    @abstractmethod
    def forward(self, data, mode):
        """Forward pass on ``data``; ``mode`` selects train / evaluation behaviour."""
        raise NotImplementedError

    def backward(self, predict, expected):
        """Backward pass and optimizer step for one mini-batch.

        Uses ``self.criterion`` and ``self.optimizer``, which the subclass must
        have set.  Returns the loss computed by the criterion.
        """
        self.optimizer.zero_grad()  # Reset gradient accumulation
        loss = self.criterion(predict, expected)
        loss.backward()             # Propagate the loss backwards
        self.optimizer.step()       # Update the model parameters

        return loss

    @abstractmethod
    def train(self):
        """Training loop over the entire dataset for a given number of epochs."""
        raise NotImplementedError

    def save_state(self, mode="latest"):
        """Save the model parameters in a file.

        Writes ``<dirpath><model_name>_<mode>.pth`` containing a dict that maps
        each direct sub-module name to that sub-module's state_dict.

        Args :
        mode -- one of "latest", "best" to differentiate
                the latest model from the model with the
                lowest loss on the validation subset (default "latest")
        """
        path=self.dirpath + self.config.model_name + "_" + mode + ".pth"

        # One state_dict per direct sub-module of the underlying model
        modules=list(self.model_accs._modules.keys())
        state_dict={module: getattr(self.model_accs, module).state_dict() for module in modules}

        # Save the model parameter dict
        save(state_dict, path)

    def load_state(self, path):
        """Load the model parameters from a file.

        Only sub-modules present both in the checkpoint and in the current
        model are loaded; other checkpoint entries are ignored.

        Args :
        path -- path to the .pth file containing the dictionary
        with the model parameters to load from
        """
        # Open a file in read-binary mode
        with open(path, 'rb') as f:

            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            checkpoint=load(f, map_location=self.device)

            print("Loading weights from file : {0}".format(path))

            local_module_keys=list(self.model_accs._modules.keys())
            for module in checkpoint.keys():
                if module in local_module_keys:
                    print("Loading weights for module = ", module)
                    getattr(self.model_accs, module).load_state_dict(checkpoint[module])

        # Parameters replaced since construction may live on another device
        self.model.to(self.device)

In [10]:
# Python standard imports
from sys import stdout
from math import floor, ceil
from time import strftime, localtime

# PyTorch imports
from torch.optim import Adam

# WatChMaL imports
# from training_utils.engine import Engine
# from plot_utils.notebook_utils import CSVData



In [11]:
class CSVData:
    """Minimal CSV logger.

    Holds one row of named values (``record``) and appends it to the output
    file on ``write``.  The file is opened lazily on the first ``write``, at
    which point the header line and the row format are fixed from the keys
    recorded so far.  Every value is written with the ``{:f}`` format.
    """

    def __init__(self,fout):
        self.name  = fout   # output path
        self._fout = None   # file handle, opened on the first write()
        self._str  = None   # row format string, built on the first write()
        self._dict = {}     # current row: column name -> value

    def record(self, keys, vals):
        """Store ``vals[i]`` under ``keys[i]`` for every key."""
        self._dict.update((column, vals[position]) for position, column in enumerate(keys))

    def write(self):
        """Append the current row, writing the header first if this is the first call."""
        if self._str is None:
            self._fout = open(self.name,'w')
            columns = list(self._dict.keys())
            self._fout.write(','.join(columns) + '\n')
            self._str = ','.join(['{:f}'] * len(columns)) + '\n'

        row = self._str.format(*(self._dict.values()))
        self._fout.write(row)

    def flush(self):
        """Flush the file if it has been opened."""
        if not self._fout:
            return
        self._fout.flush()

    def close(self):
        """Close the file if a row has ever been written."""
        if self._str is None:
            return
        self._fout.close()

In [12]:
class EngineGraph(Engine):
    """Engine for graph classification: NLL loss on the model output, Adam optimizer."""

    def __init__(self, model, config):
        super().__init__(model, config)
        self.criterion=F.nll_loss
        # NOTE(review): the notebook config also defines `weight_decay`, but it
        # is not passed to Adam here - confirm whether that is intended.
        self.optimizer=Adam(self.model_accs.parameters(), lr=config.lr)

        # Column names of the train / validation CSV logs
        self.keys = ['iteration', 'epoch', 'loss', 'acc']

    def forward(self, data, mode="train"):
        """Overrides the forward abstract method in Engine.py.

        Args:
        data -- batch passed directly to the model
        mode -- One of 'train', 'validation'; any other value leaves the
                model's train/eval state untouched
        """

        # Set the correct grad_mode given the mode
        if mode == "train":
            self.model.train()
        elif mode in ["validation"]:
            self.model.eval()

        return self.model(data)

    def train(self):
        """Overrides the train method in Engine.py.

        Runs ``config.epochs`` passes over the training loader.  Every
        ``config.valid_interval`` iterations, ``config.num_val_batches``
        validation batches are evaluated, their mean loss / accuracy is logged,
        and the "latest" (and, on improvement, "best") checkpoint is saved.
        Both CSV logs are closed when the loop ends, including on error.

        Args: None
        """

        epochs          = self.config.epochs
        report_interval = self.config.report_interval
        valid_interval  = self.config.valid_interval
        num_val_batches = self.config.num_val_batches

        batches_per_epoch = len(self.train_loader)
        iteration = 0

        # Lowest mean validation loss seen so far
        best_loss = float("inf")

        val_iter = iter(self.val_loader)

        try:
            # An integer epoch counter avoids the rounding drift of summing
            # 1/len(loader) per batch, which could add or repeat an epoch.
            for epoch_idx in range(ceil(epochs)):

                print('Epoch', epoch_idx,
                      'Starting @', strftime("%Y-%m-%d %H:%M:%S", localtime()))

                for batch_idx, data in enumerate(self.train_loader):
                    data = data.to(self.device)

                    iteration += 1
                    # Fractional epoch reached after this batch
                    epoch = epoch_idx + (batch_idx + 1) / batches_per_epoch

                    res = self.forward(data, mode="train")

                    # .item() so the log does not keep the autograd graph alive
                    loss = self.backward(res, data.y).item()
                    acc = res.argmax(1).eq(data.y).sum().item()/data.y.shape[0]

                    # Record the metrics for the mini-batch in the log
                    self.train_log.record(self.keys, [iteration, epoch, loss, acc])
                    self.train_log.write()

                    # Print the metrics at given intervals
                    if iteration % report_interval == 0:
                        print("... Iteration %d ... Epoch %1.2f ... Loss %1.3f ... Acc %1.3f"
                              % (iteration, epoch, loss, acc))

                    # Run validation on given intervals
                    if iteration % valid_interval == 0:
                        with torch.no_grad():
                            val_loss = 0.
                            val_acc = 0.

                            for _ in range(num_val_batches):
                                # Restart the validation loader once exhausted
                                try:
                                    val_data = next(val_iter)
                                except StopIteration:
                                    val_iter = iter(self.val_loader)
                                    val_data = next(val_iter)
                                val_data = val_data.to(self.device)

                                val_res = self.forward(val_data, mode="validation")
                                val_loss += self.criterion(val_res, val_data.y).item()
                                val_acc += (val_res.argmax(1).eq(val_data.y).sum().item()
                                            / val_data.y.shape[0])

                            val_loss /= num_val_batches
                            val_acc /= num_val_batches

                        # Record the validation means (not the last training batch)
                        self.val_log.record(self.keys, [iteration, epoch, val_loss, val_acc])
                        self.val_log.write()

                        # Save the best model
                        if val_loss < best_loss:
                            self.save_state(mode="best")
                            best_loss = val_loss

                        # Save the latest model
                        self.save_state(mode="latest")
        finally:
            self.val_log.close()
            self.train_log.close()

    def validate(self, subset):
        """Evaluate the model on one full pass over a subset.

        Per-batch loss and accuracy are written to
        ``<dirpath><train|valid|test>_validation_log.csv``; the log is kept in
        ``self.log`` and closed when the pass ends.

        Args:
        subset          -- One of 'train', 'validation', 'test' to select the subset to perform validation on

        Returns None (also when ``subset`` is not recognised, after printing a message).
        """
        if subset == "train":
            message = "Validating model on the train set"
            log_file = "train_validation_log.csv"
            data_iter = self.train_loader
        elif subset == "validation":
            message = "Validating model on the validation set"
            log_file = "valid_validation_log.csv"
            data_iter = self.val_loader
        elif subset == "test":
            message = "Validating model on the test set"
            log_file = "test_validation_log.csv"
            data_iter = self.test_loader
        else:
            print("validate() : arg subset has to be one of train, validation, test")
            return None

        print(message)

        self.log = CSVData(self.dirpath + log_file)

        try:
            with torch.no_grad():
                for iteration, data in enumerate(data_iter):
                    data = data.to(self.device)

                    stdout.write("Iteration : {}, Progress {} \n".format(iteration, iteration/len(data_iter)))
                    res = self.forward(data, mode="validation")
                    acc = res.argmax(1).eq(data.y).sum().item()/data.y.shape[0]
                    loss = self.criterion(res, data.y).item()

                    # Log/Report
                    self.log.record(["Iteration", "loss", "acc"], [iteration, loss, acc])
                    self.log.write()
        finally:
            # Make sure buffered rows reach the disk and the handle is released
            self.log.close()


# Testing

### Initialization

In [13]:
# Instantiate the GCN classifier defined in the "Network stuff" section.
model = Net()

In [14]:
# Build the engine: picks the device, creates the data loaders and the
# timestamped dump directory, and sets up the loss and optimizer.
engine = EngineGraph(model, config)

Requesting GPUs. GPU list : [0]
Main GPU : cuda:0
CUDA is available
Creating a directory for run dump at : dump20200203_101737/


### Train

In [15]:
# Run the training loop for config.epochs epochs; metrics are logged to CSV
# files under engine.dirpath.
engine.train()

Epoch 0 Starting @ 2020-02-03 10:17:37
... Iteration 10 ... Epoch 0.01 ... Loss 0.000 ... Acc 1.000
... Iteration 20 ... Epoch 0.02 ... Loss 34.924 ... Acc 0.000
... Iteration 30 ... Epoch 0.02 ... Loss 47.747 ... Acc 0.000
... Iteration 40 ... Epoch 0.03 ... Loss 4.616 ... Acc 0.000
... Iteration 50 ... Epoch 0.04 ... Loss 0.000 ... Acc 1.000
... Iteration 60 ... Epoch 0.05 ... Loss 0.000 ... Acc 1.000
... Iteration 70 ... Epoch 0.05 ... Loss 1.305 ... Acc 0.000
... Iteration 80 ... Epoch 0.06 ... Loss 0.001 ... Acc 1.000
... Iteration 90 ... Epoch 0.07 ... Loss 0.800 ... Acc 0.000
... Iteration 100 ... Epoch 0.08 ... Loss 0.020 ... Acc 1.000
... Iteration 110 ... Epoch 0.08 ... Loss 0.561 ... Acc 1.000
... Iteration 120 ... Epoch 0.09 ... Loss 3.466 ... Acc 0.000
... Iteration 130 ... Epoch 0.10 ... Loss 0.863 ... Acc 0.000
... Iteration 140 ... Epoch 0.11 ... Loss 0.480 ... Acc 1.000
... Iteration 150 ... Epoch 0.11 ... Loss 0.341 ... Acc 1.000
... Iteration 160 ... Epoch 0.12 ... L

... Iteration 1330 ... Epoch 1.02 ... Loss 0.726 ... Acc 0.000
... Iteration 1340 ... Epoch 1.02 ... Loss 0.659 ... Acc 1.000
... Iteration 1350 ... Epoch 1.03 ... Loss 1.150 ... Acc 0.000
... Iteration 1360 ... Epoch 1.04 ... Loss 0.429 ... Acc 1.000
... Iteration 1370 ... Epoch 1.05 ... Loss 0.800 ... Acc 0.000
... Iteration 1380 ... Epoch 1.06 ... Loss 0.687 ... Acc 1.000
... Iteration 1390 ... Epoch 1.06 ... Loss 0.768 ... Acc 0.000
... Iteration 1400 ... Epoch 1.07 ... Loss 0.732 ... Acc 0.000
... Iteration 1410 ... Epoch 1.08 ... Loss 0.696 ... Acc 0.000
... Iteration 1420 ... Epoch 1.09 ... Loss 0.651 ... Acc 1.000
... Iteration 1430 ... Epoch 1.09 ... Loss 0.646 ... Acc 1.000
... Iteration 1440 ... Epoch 1.10 ... Loss 0.678 ... Acc 1.000
... Iteration 1450 ... Epoch 1.11 ... Loss 0.737 ... Acc 0.000
... Iteration 1460 ... Epoch 1.12 ... Loss 0.705 ... Acc 0.000
... Iteration 1470 ... Epoch 1.12 ... Loss 0.977 ... Acc 0.000
... Iteration 1480 ... Epoch 1.13 ... Loss 0.901 ... Ac

... Iteration 2640 ... Epoch 2.02 ... Loss 0.694 ... Acc 1.000
... Iteration 2650 ... Epoch 2.03 ... Loss 0.670 ... Acc 1.000
... Iteration 2660 ... Epoch 2.03 ... Loss 0.884 ... Acc 0.000
... Iteration 2670 ... Epoch 2.04 ... Loss 0.624 ... Acc 1.000
... Iteration 2680 ... Epoch 2.05 ... Loss 0.778 ... Acc 0.000
... Iteration 2690 ... Epoch 2.06 ... Loss 0.777 ... Acc 0.000
... Iteration 2700 ... Epoch 2.06 ... Loss 0.663 ... Acc 1.000
... Iteration 2710 ... Epoch 2.07 ... Loss 0.929 ... Acc 0.000
... Iteration 2720 ... Epoch 2.08 ... Loss 0.946 ... Acc 0.000
... Iteration 2730 ... Epoch 2.09 ... Loss 0.820 ... Acc 0.000
... Iteration 2740 ... Epoch 2.09 ... Loss 0.808 ... Acc 0.000
... Iteration 2750 ... Epoch 2.10 ... Loss 0.927 ... Acc 0.000
... Iteration 2760 ... Epoch 2.11 ... Loss 0.716 ... Acc 0.000
... Iteration 2770 ... Epoch 2.12 ... Loss 0.670 ... Acc 1.000
... Iteration 2780 ... Epoch 2.13 ... Loss 0.635 ... Acc 1.000
... Iteration 2790 ... Epoch 2.13 ... Loss 0.663 ... Ac

... Iteration 3940 ... Epoch 3.01 ... Loss 0.707 ... Acc 0.000
... Iteration 3950 ... Epoch 3.02 ... Loss 0.710 ... Acc 0.000
... Iteration 3960 ... Epoch 3.03 ... Loss 0.696 ... Acc 1.000
... Iteration 3970 ... Epoch 3.04 ... Loss 0.709 ... Acc 0.000
... Iteration 3980 ... Epoch 3.04 ... Loss 0.681 ... Acc 1.000
... Iteration 3990 ... Epoch 3.05 ... Loss 0.692 ... Acc 1.000
... Iteration 4000 ... Epoch 3.06 ... Loss 0.728 ... Acc 0.000
... Iteration 4010 ... Epoch 3.07 ... Loss 0.674 ... Acc 1.000
... Iteration 4020 ... Epoch 3.07 ... Loss 0.678 ... Acc 1.000
... Iteration 4030 ... Epoch 3.08 ... Loss 0.668 ... Acc 1.000
... Iteration 4040 ... Epoch 3.09 ... Loss 0.726 ... Acc 0.000
... Iteration 4050 ... Epoch 3.10 ... Loss 0.702 ... Acc 0.000
... Iteration 4060 ... Epoch 3.10 ... Loss 0.714 ... Acc 0.000
... Iteration 4070 ... Epoch 3.11 ... Loss 0.740 ... Acc 0.000
... Iteration 4080 ... Epoch 3.12 ... Loss 0.733 ... Acc 0.000
... Iteration 4090 ... Epoch 3.13 ... Loss 0.653 ... Ac

... Iteration 5260 ... Epoch 4.02 ... Loss 0.735 ... Acc 0.000
... Iteration 5270 ... Epoch 4.03 ... Loss 0.648 ... Acc 1.000
... Iteration 5280 ... Epoch 4.04 ... Loss 0.622 ... Acc 1.000
... Iteration 5290 ... Epoch 4.04 ... Loss 0.786 ... Acc 0.000
... Iteration 5300 ... Epoch 4.05 ... Loss 0.778 ... Acc 0.000
... Iteration 5310 ... Epoch 4.06 ... Loss 0.621 ... Acc 1.000
... Iteration 5320 ... Epoch 4.07 ... Loss 0.790 ... Acc 0.000
... Iteration 5330 ... Epoch 4.07 ... Loss 0.611 ... Acc 1.000
... Iteration 5340 ... Epoch 4.08 ... Loss 0.595 ... Acc 1.000
... Iteration 5350 ... Epoch 4.09 ... Loss 0.827 ... Acc 0.000
... Iteration 5360 ... Epoch 4.10 ... Loss 0.588 ... Acc 1.000
... Iteration 5370 ... Epoch 4.11 ... Loss 0.579 ... Acc 1.000
... Iteration 5380 ... Epoch 4.11 ... Loss 0.810 ... Acc 0.000
... Iteration 5390 ... Epoch 4.12 ... Loss 0.786 ... Acc 0.000
... Iteration 5400 ... Epoch 4.13 ... Loss 0.619 ... Acc 1.000
... Iteration 5410 ... Epoch 4.14 ... Loss 0.769 ... Ac

... Iteration 6570 ... Epoch 5.02 ... Loss 0.699 ... Acc 0.000
... Iteration 6580 ... Epoch 5.03 ... Loss 0.702 ... Acc 0.000
... Iteration 6590 ... Epoch 5.04 ... Loss 0.662 ... Acc 1.000
... Iteration 6600 ... Epoch 5.05 ... Loss 0.737 ... Acc 0.000
... Iteration 6610 ... Epoch 5.05 ... Loss 0.695 ... Acc 0.000
... Iteration 6620 ... Epoch 5.06 ... Loss 0.643 ... Acc 1.000
... Iteration 6630 ... Epoch 5.07 ... Loss 0.762 ... Acc 0.000
... Iteration 6640 ... Epoch 5.08 ... Loss 0.753 ... Acc 0.000
... Iteration 6650 ... Epoch 5.08 ... Loss 0.760 ... Acc 0.000
... Iteration 6660 ... Epoch 5.09 ... Loss 0.748 ... Acc 0.000
... Iteration 6670 ... Epoch 5.10 ... Loss 0.668 ... Acc 1.000
... Iteration 6680 ... Epoch 5.11 ... Loss 0.684 ... Acc 1.000
... Iteration 6690 ... Epoch 5.11 ... Loss 0.719 ... Acc 0.000
... Iteration 6700 ... Epoch 5.12 ... Loss 0.697 ... Acc 0.000
... Iteration 6710 ... Epoch 5.13 ... Loss 0.697 ... Acc 0.000
... Iteration 6720 ... Epoch 5.14 ... Loss 0.667 ... Ac

... Iteration 7870 ... Epoch 6.02 ... Loss 0.721 ... Acc 0.000
... Iteration 7880 ... Epoch 6.02 ... Loss 0.742 ... Acc 0.000
... Iteration 7890 ... Epoch 6.03 ... Loss 0.754 ... Acc 0.000
... Iteration 7900 ... Epoch 6.04 ... Loss 0.625 ... Acc 1.000
... Iteration 7910 ... Epoch 6.05 ... Loss 0.652 ... Acc 1.000
... Iteration 7920 ... Epoch 6.06 ... Loss 0.709 ... Acc 0.000
... Iteration 7930 ... Epoch 6.06 ... Loss 0.691 ... Acc 1.000
... Iteration 7940 ... Epoch 6.07 ... Loss 0.733 ... Acc 0.000
... Iteration 7950 ... Epoch 6.08 ... Loss 0.747 ... Acc 0.000
... Iteration 7960 ... Epoch 6.09 ... Loss 0.754 ... Acc 0.000
... Iteration 7970 ... Epoch 6.09 ... Loss 0.649 ... Acc 1.000
... Iteration 7980 ... Epoch 6.10 ... Loss 0.742 ... Acc 0.000
... Iteration 7990 ... Epoch 6.11 ... Loss 0.638 ... Acc 1.000
... Iteration 8000 ... Epoch 6.12 ... Loss 0.608 ... Acc 1.000
... Iteration 8010 ... Epoch 6.12 ... Loss 0.593 ... Acc 1.000
... Iteration 8020 ... Epoch 6.13 ... Loss 0.791 ... Ac

... Iteration 9170 ... Epoch 7.01 ... Loss 0.655 ... Acc 1.000
... Iteration 9180 ... Epoch 7.02 ... Loss 0.751 ... Acc 0.000
... Iteration 9190 ... Epoch 7.03 ... Loss 0.647 ... Acc 1.000
... Iteration 9200 ... Epoch 7.03 ... Loss 0.746 ... Acc 0.000
... Iteration 9210 ... Epoch 7.04 ... Loss 0.642 ... Acc 1.000
... Iteration 9220 ... Epoch 7.05 ... Loss 0.743 ... Acc 0.000
... Iteration 9230 ... Epoch 7.06 ... Loss 0.641 ... Acc 1.000
... Iteration 9240 ... Epoch 7.06 ... Loss 0.745 ... Acc 0.000
... Iteration 9250 ... Epoch 7.07 ... Loss 0.729 ... Acc 0.000
... Iteration 9260 ... Epoch 7.08 ... Loss 0.709 ... Acc 0.000
... Iteration 9270 ... Epoch 7.09 ... Loss 0.689 ... Acc 1.000
... Iteration 9280 ... Epoch 7.09 ... Loss 0.672 ... Acc 1.000
... Iteration 9290 ... Epoch 7.10 ... Loss 0.757 ... Acc 0.000
... Iteration 9300 ... Epoch 7.11 ... Loss 0.795 ... Acc 0.000
... Iteration 9310 ... Epoch 7.12 ... Loss 0.803 ... Acc 0.000
... Iteration 9320 ... Epoch 7.13 ... Loss 0.606 ... Ac

... Iteration 10490 ... Epoch 8.02 ... Loss 0.752 ... Acc 0.000
... Iteration 10500 ... Epoch 8.03 ... Loss 0.631 ... Acc 1.000
... Iteration 10510 ... Epoch 8.04 ... Loss 0.629 ... Acc 1.000
... Iteration 10520 ... Epoch 8.04 ... Loss 0.635 ... Acc 1.000
... Iteration 10530 ... Epoch 8.05 ... Loss 0.742 ... Acc 0.000
... Iteration 10540 ... Epoch 8.06 ... Loss 0.660 ... Acc 1.000
... Iteration 10550 ... Epoch 8.07 ... Loss 0.725 ... Acc 0.000
... Iteration 10560 ... Epoch 8.07 ... Loss 0.672 ... Acc 1.000
... Iteration 10570 ... Epoch 8.08 ... Loss 0.708 ... Acc 0.000
... Iteration 10580 ... Epoch 8.09 ... Loss 0.671 ... Acc 1.000
... Iteration 10590 ... Epoch 8.10 ... Loss 0.646 ... Acc 1.000
... Iteration 10600 ... Epoch 8.10 ... Loss 0.617 ... Acc 1.000
... Iteration 10610 ... Epoch 8.11 ... Loss 0.624 ... Acc 1.000
... Iteration 10620 ... Epoch 8.12 ... Loss 0.769 ... Acc 0.000
... Iteration 10630 ... Epoch 8.13 ... Loss 0.776 ... Acc 0.000
... Iteration 10640 ... Epoch 8.13 ... L

... Iteration 11790 ... Epoch 9.01 ... Loss 0.609 ... Acc 1.000
... Iteration 11800 ... Epoch 9.02 ... Loss 0.571 ... Acc 1.000
... Iteration 11810 ... Epoch 9.03 ... Loss 0.871 ... Acc 0.000
... Iteration 11820 ... Epoch 9.04 ... Loss 0.877 ... Acc 0.000
... Iteration 11830 ... Epoch 9.04 ... Loss 0.536 ... Acc 1.000
... Iteration 11840 ... Epoch 9.05 ... Loss 0.535 ... Acc 1.000
... Iteration 11850 ... Epoch 9.06 ... Loss 0.873 ... Acc 0.000
... Iteration 11860 ... Epoch 9.07 ... Loss 0.825 ... Acc 0.000
... Iteration 11870 ... Epoch 9.07 ... Loss 0.775 ... Acc 0.000
... Iteration 11880 ... Epoch 9.08 ... Loss 0.715 ... Acc 0.000
... Iteration 11890 ... Epoch 9.09 ... Loss 0.726 ... Acc 0.000
... Iteration 11900 ... Epoch 9.10 ... Loss 0.640 ... Acc 1.000
... Iteration 11910 ... Epoch 9.11 ... Loss 0.595 ... Acc 1.000
... Iteration 11920 ... Epoch 9.11 ... Loss 0.831 ... Acc 0.000
... Iteration 11930 ... Epoch 9.12 ... Loss 0.830 ... Acc 0.000
... Iteration 11940 ... Epoch 9.13 ... L

### Test

In [16]:
# Evaluate on the test split; per-batch loss/accuracy go to test_validation_log.csv.
engine.validate("test")

Validating model on the test set
Iteration : 0, Progress 0.0 
Iteration : 1, Progress 0.006097560975609756 
Iteration : 2, Progress 0.012195121951219513 
Iteration : 3, Progress 0.018292682926829267 
Iteration : 4, Progress 0.024390243902439025 
Iteration : 5, Progress 0.03048780487804878 
Iteration : 6, Progress 0.036585365853658534 
Iteration : 7, Progress 0.042682926829268296 
Iteration : 8, Progress 0.04878048780487805 
Iteration : 9, Progress 0.054878048780487805 
Iteration : 10, Progress 0.06097560975609756 
Iteration : 11, Progress 0.06707317073170732 
Iteration : 12, Progress 0.07317073170731707 
Iteration : 13, Progress 0.07926829268292683 
Iteration : 14, Progress 0.08536585365853659 
Iteration : 15, Progress 0.09146341463414634 
Iteration : 16, Progress 0.0975609756097561 
Iteration : 17, Progress 0.10365853658536585 
Iteration : 18, Progress 0.10975609756097561 
Iteration : 19, Progress 0.11585365853658537 
Iteration : 20, Progress 0.12195121951219512 
Iteration : 21, Progr

### Save and Load

In [17]:
# Save the current weights as <model_name>_latest.pth in engine.dirpath (default mode "latest").
engine.save_state()

In [18]:
# Display the trained conv1 weights before they are overwritten in the next cell.
model.conv1.weight

Parameter containing:
tensor([[-0.2022, -0.3764, -0.1560, -0.1571, -0.2977,  0.0726,  0.4064,  0.0673,
         -0.4541, -0.1497,  0.2524, -0.1232, -0.7010, -0.4255, -0.2545, -0.3702],
        [-0.4060, -0.1885, -0.1186, -0.2231, -0.2302, -0.2523, -0.1852, -0.2108,
         -0.2764, -0.0751, -0.2591,  0.6272, -0.2774, -0.3769, -0.2426, -0.2058],
        [-0.0785, -0.4351, -0.1136,  0.4038, -0.2329,  0.3133,  0.0311,  0.1175,
          0.4586, -0.0597, -0.2328,  0.3186, -0.3812, -0.1632, -0.0734,  0.2514],
        [-0.3851,  0.5569,  0.4015,  0.1767,  0.4602,  0.3272, -0.0915,  0.1819,
         -0.3966, -0.2142, -0.0629,  0.2916,  0.1185,  0.3014,  0.2812,  0.0142],
        [-0.4202,  0.0276, -0.0707, -0.1376, -0.1610, -0.2124, -0.3863,  0.1287,
         -0.3163, -0.0547, -0.1826, -0.1627,  0.1016, -0.4186, -0.0992, -0.3123],
        [-0.3600,  0.7437,  0.0240,  0.2668, -0.4308, -0.4151,  0.0286, -0.0015,
          0.4695,  0.1462,  0.4536, -0.5511, -0.5494,  0.4633,  0.3983, -0.0885],


In [19]:
# Overwrite conv1's weights with random values so that reloading the saved
# checkpoint can be verified afterwards.
model.conv1.weight = torch.nn.Parameter(torch.Tensor(np.random.rand(8,16)))
model.conv1.weight

Parameter containing:
tensor([[0.6157, 0.2432, 0.8556, 0.3369, 0.3435, 0.7384, 0.5346, 0.1591, 0.8238,
         0.6595, 0.6798, 0.2926, 0.9120, 0.5067, 0.8086, 0.7089],
        [0.8408, 0.2373, 0.4983, 0.8418, 0.7784, 0.7952, 0.4930, 0.4700, 0.7724,
         0.7797, 0.6266, 0.4994, 0.3589, 0.2786, 0.7338, 0.8408],
        [0.5454, 0.5598, 0.3705, 0.9656, 0.0321, 0.6190, 0.9263, 0.2375, 0.3505,
         0.2175, 0.5283, 0.8647, 0.0730, 0.4041, 0.2601, 0.9593],
        [0.9331, 0.7923, 0.7750, 0.0777, 0.7771, 0.7115, 0.1746, 0.7311, 0.8464,
         0.8135, 0.3760, 0.0310, 0.6498, 0.1042, 0.1139, 0.2601],
        [0.9262, 0.9690, 0.4995, 0.4369, 0.4268, 0.2747, 0.8708, 0.1315, 0.2535,
         0.0839, 0.0761, 0.1398, 0.3198, 0.0541, 0.7726, 0.4492],
        [0.5795, 0.4856, 0.6194, 0.8952, 0.9648, 0.6960, 0.3843, 0.3602, 0.8442,
         0.4476, 0.9379, 0.0793, 0.4839, 0.3902, 0.6733, 0.1277],
        [0.4628, 0.1021, 0.2197, 0.0258, 0.9070, 0.3676, 0.2146, 0.7431, 0.8067,
         0.1221

In [20]:
# Reload the checkpoint written by save_state(); load_state also moves the
# model back onto the engine's device.
engine.load_state(osp.join(engine.dirpath, "gcn_kipf_latest.pth"))

Loading weights from file : dump20200203_101737/gcn_kipf_latest.pth
Loading weights for module =  conv1
Loading weights for module =  conv2
Loading weights for module =  conv3


In [21]:
# Display conv1's weights again to compare with the values shown before the overwrite.
model.conv1.weight

Parameter containing:
tensor([[-0.2022, -0.3764, -0.1560, -0.1571, -0.2977,  0.0726,  0.4064,  0.0673,
         -0.4541, -0.1497,  0.2524, -0.1232, -0.7010, -0.4255, -0.2545, -0.3702],
        [-0.4060, -0.1885, -0.1186, -0.2231, -0.2302, -0.2523, -0.1852, -0.2108,
         -0.2764, -0.0751, -0.2591,  0.6272, -0.2774, -0.3769, -0.2426, -0.2058],
        [-0.0785, -0.4351, -0.1136,  0.4038, -0.2329,  0.3133,  0.0311,  0.1175,
          0.4586, -0.0597, -0.2328,  0.3186, -0.3812, -0.1632, -0.0734,  0.2514],
        [-0.3851,  0.5569,  0.4015,  0.1767,  0.4602,  0.3272, -0.0915,  0.1819,
         -0.3966, -0.2142, -0.0629,  0.2916,  0.1185,  0.3014,  0.2812,  0.0142],
        [-0.4202,  0.0276, -0.0707, -0.1376, -0.1610, -0.2124, -0.3863,  0.1287,
         -0.3163, -0.0547, -0.1826, -0.1627,  0.1016, -0.4186, -0.0992, -0.3123],
        [-0.3600,  0.7437,  0.0240,  0.2668, -0.4308, -0.4151,  0.0286, -0.0015,
          0.4695,  0.1462,  0.4536, -0.5511, -0.5494,  0.4633,  0.3983, -0.0885],


### Model works after loading

In [22]:
# Re-run the test-set evaluation with the reloaded weights.
engine.validate("test")

Validating model on the test set
Iteration : 0, Progress 0.0 
Iteration : 1, Progress 0.006097560975609756 
Iteration : 2, Progress 0.012195121951219513 
Iteration : 3, Progress 0.018292682926829267 
Iteration : 4, Progress 0.024390243902439025 
Iteration : 5, Progress 0.03048780487804878 
Iteration : 6, Progress 0.036585365853658534 
Iteration : 7, Progress 0.042682926829268296 
Iteration : 8, Progress 0.04878048780487805 
Iteration : 9, Progress 0.054878048780487805 
Iteration : 10, Progress 0.06097560975609756 
Iteration : 11, Progress 0.06707317073170732 
Iteration : 12, Progress 0.07317073170731707 
Iteration : 13, Progress 0.07926829268292683 
Iteration : 14, Progress 0.08536585365853659 
Iteration : 15, Progress 0.09146341463414634 
Iteration : 16, Progress 0.0975609756097561 
Iteration : 17, Progress 0.10365853658536585 
Iteration : 18, Progress 0.10975609756097561 
Iteration : 19, Progress 0.11585365853658537 
Iteration : 20, Progress 0.12195121951219512 
Iteration : 21, Progr