In [1]:
# Network
import torch
from torch.nn import Linear
import torch.nn.functional as F

from torch_geometric.nn import GCNConv, ChebConv
from torch_geometric.nn import global_max_pool

# Data
from torch.utils.data.sampler import SubsetRandomSampler, Sampler

from torch_geometric.data import Batch, Data, Dataset, DataLoader

# General
import numpy as np

# Util
import os.path as osp
import h5py
import pickle

import time

### Generate random placeholder train/validation/test split indices

In [2]:
# length = 2937
# splits = 10
# random_shuffle = np.random.permutation(length)
# validation_indicies = random_shuffle[:length//splits]
# test_indicies = random_shuffle[length//splits:2*length//splits]
# train_indicies = random_shuffle[2*length//splits:]

# with open("train_indicies.txt", 'w') as f:
#     f.writelines(["{}\n".format(i) for i in train_indicies])
    
# with open("validation_indicies.txt", 'w') as f:
#     f.writelines(["{}\n".format(i) for i in validation_indicies])

# with open("test_indicies.txt", 'w') as f:
#     f.writelines(["{}\n".format(i) for i in test_indicies])

# Config

In [3]:
# Port from https://github.com/tkarras/progressive_growing_of_gans/blob/master/config.py
class EasyDict(dict):
    """Dictionary whose keys can also be read, written and deleted as attributes.

    A missing key raises ``AttributeError`` (not ``KeyError``) on attribute
    access or deletion. This is the contract ``hasattr``,
    ``getattr(obj, name, default)`` and ``copy.deepcopy`` rely on when they probe
    optional attributes such as ``__deepcopy__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so dict methods still win.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

In [4]:
config = EasyDict(
    model_name="gcn_kipf",

    # Input data, split files and graph connectivity
    data_path="/app/test_data/IWCDmPMT_4pi_fulltank_test_graphnet.h5",
    train_indices_file="/app/test_data/IWCDmPMT_4pi_fulltank_test_splits/train.txt",
    val_indices_file="/app/test_data/IWCDmPMT_4pi_fulltank_test_splits/val.txt",
    test_indices_file="/app/test_data/IWCDmPMT_4pi_fulltank_test_splits/test.txt",
    edge_index_pickle="/app/GraphNets/metadata/edges_dict.pkl",

    # Prefix of the run directory; a timestamp is appended when an engine is created
    dump_path="/app/GraphNets/dump/gcn",

    # Hardware
    num_data_workers=0,  # Sometime crashes if we do multiprocessing
    device='gpu',
    gpu_list=[0],

    # Optimisation
    batch_size=32,  # 256
    validate_batch_size=32,
    lr=0.01,
    weight_decay=5e-4,

    # Schedule
    epochs=1,
    report_interval=10,  # 100
    num_val_batches=32,
    valid_interval=100,  # 10000
)

# Data Loader

In [5]:
class WCH5Dataset(Dataset):
    """
    Graph dataset of Water Cherenkov detector events stored in an HDF5 file.

    ``event_data`` is memory-mapped directly from the HDF5 file, so events are
    only read from disk when ``get`` is called. The memory map is built from
    the dataset's byte offset in the file, so the detector data must be stored
    uncompressed and unchunked. ``labels`` are loaded into memory outright.
    No other data is currently loaded.

    Every event shares the same graph connectivity (``self.edge_index``), which
    is built once from a pickled ``{node: iterable_of_neighbour_nodes}`` dict.
    Train/validation/test splits are read from text files with one index per line.
    """

    # Override the default implementation: nothing to download or pre-process.
    def _download(self):
        pass

    def _process(self):
        pass

    def __init__(self, path, train_indices_file, val_indices_file, test_indices_file,
                 edge_index_pickle, nodes=15808,
                 transform=None, pre_transform=None, pre_filter=None,
                 use_node_attr=False, use_edge_attr=False, cleaned=False):

        super(WCH5Dataset, self).__init__("", transform, pre_transform,
                                          pre_filter)

        # h5py is only used for metadata and the labels; close the handle when
        # done (it previously leaked). The memmap below opens the file on its own.
        with h5py.File(path, 'r') as f:
            hdf5_event_data = f["event_data"]
            hdf5_labels = f["labels"]

            assert hdf5_event_data.shape[0] == hdf5_labels.shape[0]

            event_data_shape = hdf5_event_data.shape
            event_data_offset = hdf5_event_data.id.get_offset()
            event_data_dtype = hdf5_event_data.dtype

            # this will fit easily in memory even for huge datasets
            self.labels = np.array(hdf5_labels)

        # get_offset() returns None when the data has no single contiguous byte
        # range (e.g. chunked/compressed), which np.memmap cannot map.
        if event_data_offset is None:
            raise ValueError("'event_data' in {} is not stored contiguously "
                             "(chunked or compressed); it cannot be memory-mapped".format(path))

        # this creates a memory map - i.e. events are not loaded in memory here
        # only on get_item
        self.event_data = np.memmap(path, mode='r', shape=event_data_shape,
                                    offset=event_data_offset, dtype=event_data_dtype)

        self.nodes = nodes
        self.load_edges(edge_index_pickle)

        self.transform = transform

        # Train/test/validation splits (shuffling is left to the samplers).
        self.train_indices = self.load_indicies(train_indices_file)
        self.val_indices = self.load_indicies(val_indices_file)
        self.test_indices = self.load_indicies(test_indices_file)

    def load_indicies(self, indicies_file):
        """Read one integer index per line from ``indicies_file``.

        Blank or whitespace-only lines (e.g. a trailing newline) are skipped
        instead of crashing ``int()``.
        """
        with open(indicies_file, 'r') as f:
            return [int(line) for line in f if line.strip()]

    def load_edges(self, edge_index_pickle):
        """Build ``self.edge_index`` (shape ``[2, E]``, int64) from a pickled adjacency dict.

        Each ``(k, v)`` pair with ``v`` in ``edges[k]`` becomes one edge; duplicates
        are merged and edges are sorted by (source, target). This is the same
        result that coalescing a dense ``nodes x nodes`` matrix would give, but it
        avoids allocating that matrix (about 2 GB for the default 15808 nodes).

        Raises:
            IndexError: if a node id lies outside ``[0, self.nodes)``.
        """
        # NOTE: pickle.load can execute arbitrary code; only load trusted metadata files.
        with open(edge_index_pickle, 'rb') as f:
            edges = pickle.load(f)

        pairs = set()
        for k, vs in edges.items():
            for v in vs:
                src, dst = int(k), int(v)
                if not (0 <= src < self.nodes and 0 <= dst < self.nodes):
                    raise IndexError("edge ({}, {}) out of range for {} nodes"
                                     .format(k, v, self.nodes))
                pairs.add((src, dst))

        if pairs:
            self.edge_index = torch.tensor(sorted(pairs), dtype=torch.int64).t().contiguous()
        else:
            self.edge_index = torch.empty((2, 0), dtype=torch.int64)

    def get(self, idx):
        """Return event ``idx`` as a ``Data`` graph with the shared ``edge_index``."""
        # np.array copies the memmap slice so the tensor owns writable memory;
        # torch.from_numpy on the read-only memmap view warns and aliases the mapping.
        x = torch.from_numpy(np.array(self.event_data[idx]))
        y = torch.tensor([self.labels[idx]], dtype=torch.int64)

        return Data(x=x, y=y, edge_index=self.edge_index)

    def len(self):
        """Number of events (the hook newer torch_geometric versions call)."""
        return self.labels.shape[0]

    def __len__(self):
        return self.len()

In [6]:
def get_loaders(path, train_indices_file, val_indices_file, test_indices_file, edges_dict_pickle, batch_size, workers):
    """Build the dataset plus randomly-sampled train and validation loaders.

    The test split file is read into ``dataset.test_indices``, but no test
    loader is created here.

    Returns:
        tuple: ``(train_loader, val_loader, dataset)``
    """
    dataset = WCH5Dataset(path, train_indices_file, val_indices_file,
                          test_indices_file, edges_dict_pickle)

    def _subset_loader(subset_indices):
        # Each loader draws a fresh random permutation of its own subset.
        return DataLoader(dataset, batch_size=batch_size, num_workers=workers,
                          pin_memory=True, sampler=SubsetRandomSampler(subset_indices))

    train_loader = _subset_loader(dataset.train_indices)
    val_loader = _subset_loader(dataset.val_indices)

    return train_loader, val_loader, dataset

In [7]:
class SubsetSequentialSampler(Sampler):
    r"""Yields the given indices in order, without shuffling or replacement.

    Counterpart of ``SubsetRandomSampler`` for deterministic evaluation: the
    output order matches ``indices``, so the i-th sample produced by a loader
    corresponds to ``indices[i]``. The sampler can be iterated any number of times.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

# Network definition

In [8]:
class Net(torch.nn.Module):
    """Three-layer GCN graph classifier.

    Node features pass through GCN layers of width 8 -> 32 -> 128 (ReLU and
    dropout after the first two), are max-pooled per graph, and are mapped to
    ``num_classes`` log-probabilities.

    Args:
        in_channels: number of input features per node (default 2, the
            original hard-coded value; presumably the per-PMT feature count --
            confirm against the data).
        num_classes: number of output classes (default 3).
        dropout: dropout probability between GCN layers (default 0.5, matching
            ``F.dropout``'s default used previously).

    The submodule names (conv1, conv2, conv3, linear) are kept unchanged so
    existing checkpoints keyed by module name still load.
    """

    def __init__(self, in_channels=2, num_classes=3, dropout=0.5):
        super(Net, self).__init__()
        # Plain float attribute, not a submodule, so it does not appear in _modules.
        self.dropout_p = dropout
        self.conv1 = GCNConv(in_channels, 8, cached=False)
        self.conv2 = GCNConv(8, 32, cached=False)
        self.conv3 = GCNConv(32, 128, cached=False)
        self.linear = Linear(128, num_classes)

    def forward(self, batch):
        """Return per-graph log-probabilities of shape ``[num_graphs, num_classes]``."""
        x, edge_index, batch_index = batch.x, batch.edge_index, batch.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv3(x, edge_index)
        # One feature vector per graph in the batch.
        x = global_max_pool(x, batch_index)
        x = self.linear(x)
        return F.log_softmax(x, dim=1)

# Engine

In [9]:
# Python standard imports
from abc import ABC, abstractmethod
from time import strftime
from os import stat, mkdir
from math import floor, ceil

# PyTorch imports
from torch import device, load, save
from torch.nn import DataParallel
from torch.cuda import is_available

In [10]:
class Engine(ABC):
    """Base class holding the model, device placement, data loaders and run logs.

    Subclasses implement ``forward`` and ``train``. They must also set
    ``self.criterion`` and ``self.optimizer`` before ``backward`` is called.
    """

    def __init__(self, model, config):
        super().__init__()

        # Engine attributes
        self.model = model
        self.config = config

        # Requested CUDA device names. Always defined (empty on the CPU path);
        # previously it was unset there, which made model placement raise AttributeError.
        self.devids = []

        # Determine the device to be used for model training and inference
        if (config.device == 'gpu') and config.gpu_list:
            print("Requesting GPUs. GPU list : " + str(config.gpu_list))
            self.devids = ["cuda:{0}".format(x) for x in config.gpu_list]
            print("Main GPU : " + self.devids[0])

            if is_available():
                self.device = device(self.devids[0])
                if len(self.devids) > 1:
                    print("Using DataParallel on these devices: {}".format(self.devids))
                    # NOTE(review): torch.nn.DataParallel scatters tensors; it may not
                    # split torch_geometric Batch objects -- verify multi-GPU runs.
                    self.model = DataParallel(self.model, device_ids=config.gpu_list, dim=0)
                print("CUDA is available")
            else:
                self.device = device("cpu")
                print("CUDA is not available")
        else:
            print("Unable to use GPU")
            self.device = device("cpu")

        # Send the model to the selected device. Multi-GPU wrapping happens only
        # in the branch above, so the model is never wrapped twice.
        self.model.to(self.device)

        # Parameters are saved/loaded through the unwrapped module
        self.model_accs = self.model.module if isinstance(self.model, DataParallel) else self.model

        # Create the dataset object
        out = get_loaders(config.data_path,
                          config.train_indices_file, config.val_indices_file, config.test_indices_file,
                          config.edge_index_pickle, config.batch_size, config.num_data_workers)

        self.train_loader, self.val_loader, self.dataset = out
        # Define the variant dependent attributes
        self.criterion = None

        # Create the directory for saving the log and dump files
        self.dirpath = config.dump_path + strftime("%Y%m%d_%H%M%S") + "/"
        try:
            stat(self.dirpath)
        except FileNotFoundError:
            print("Creating a directory for run dump at : {}".format(self.dirpath))
            mkdir(self.dirpath)

        # Logging attributes
        self.train_log = CSVData(self.dirpath + "log_train.csv")
        self.val_log = CSVData(self.dirpath + "log_val.csv")

        # TODO: save a copy of the config in the dump path once save_config is available:
        # save_config(self.config, self.dirpath + "config_file.ini")

    @abstractmethod
    def forward(self, data, mode):
        """Forward pass using self.data as input."""
        raise NotImplementedError

    def backward(self, predict, expected):
        """Backward pass using the loss computed for a mini-batch.

        Returns the loss tensor (still attached to the autograd graph).
        """
        self.optimizer.zero_grad()  # Reset gradient accumulation
        loss = self.criterion(predict, expected)
        loss.backward()             # Propagate the loss backwards
        self.optimizer.step()       # Update the optimizer parameters

        return loss

    @abstractmethod
    def train(self):
        """Training loop over the entire dataset for a given number of epochs."""
        raise NotImplementedError

    def save_state(self, mode="latest"):
        """Save the model parameters in a file.

        Args :
        mode -- one of "latest", "best" to differentiate
                the latest model from the model with the
                lowest loss on the validation subset (default "latest")
        """
        path = self.dirpath + self.config.model_name + "_" + mode + ".pth"

        # One state_dict per top-level submodule, keyed by submodule name
        state_dict = {name: getattr(self.model_accs, name).state_dict()
                      for name in self.model_accs._modules.keys()}

        # Save the model parameter dict
        save(state_dict, path)

    def load_state(self, path):
        """Load the model parameters from a file.

        Args :
        path -- absolute path to the .pth file containing the dictionary
        with the model parameters to load from

        Only entries whose key matches a top-level submodule of the model are
        loaded; other keys in the checkpoint are ignored.
        """
        # Open a file in read-binary mode
        with open(path, 'rb') as f:

            # Interpret the file using torch.load() (pickle-based: trusted files only)
            checkpoint = load(f, map_location=self.device)

            print("Loading weights from file : {0}".format(path))

            local_module_keys = list(self.model_accs._modules.keys())
            for module in checkpoint.keys():
                if module in local_module_keys:
                    print("Loading weights for module = ", module)
                    getattr(self.model_accs, module).load_state_dict(checkpoint[module])

        self.model.to(self.device)

In [11]:
# Python standard imports
from sys import stdout
from math import floor, ceil
from time import strftime, localtime

# PyTorch imports
from torch.optim import Adam

# WatChMaL imports
# from training_utils.engine import Engine
# from plot_utils.notebook_utils import CSVData



In [12]:
class CSVData:
    """Minimal append-only CSV logger.

    Values are staged with ``record`` and emitted as one row by ``write``.
    The header and the row format are fixed by the keys present at the first
    ``write``; every column is rendered with ``'{:f}'``. The file is opened
    lazily at that first write and flushed after every row.
    """

    def __init__(self, fout):
        self.name = fout
        self._fout = None   # file handle, opened on the first write()
        self._str = None    # row format string, built on the first write()
        self._dict = {}     # latest staged value for each column

    def record(self, keys, vals):
        """Stage ``vals[i]`` under column ``keys[i]`` for the next ``write``."""
        self._dict.update((column, vals[pos]) for pos, column in enumerate(keys))

    def write(self):
        """Append the staged values as one row, emitting the header the first time."""
        if self._str is None:
            columns = list(self._dict.keys())
            self._fout = open(self.name, 'w')
            self._fout.write(','.join(columns) + '\n')
            self._str = ','.join('{:f}' for _ in columns) + '\n'

        row = self._str.format(*self._dict.values())
        self._fout.write(row)
        self.flush()

    def flush(self):
        """Flush the file if it has been opened."""
        if self._fout is not None:
            self._fout.flush()

    def close(self):
        """Close the file if a row was ever written."""
        if self._str is None:
            return
        self._fout.close()

In [13]:
class EngineGraph(Engine):
    """Training / validation engine for graph-classification models.

    Uses ``F.nll_loss`` as the criterion, so the model must return
    log-probabilities. The optimizer is Adam, configured from ``config.lr``
    and ``config.weight_decay`` (0 when absent).
    """

    def __init__(self, model, config):
        super().__init__(model, config)
        self.criterion = F.nll_loss
        # config.weight_decay used to be defined but silently ignored by the optimizer.
        self.optimizer = Adam(self.model_accs.parameters(), lr=config.lr,
                              weight_decay=config.get("weight_decay", 0))

        self.keys = ['iteration', 'epoch', 'loss', 'acc']

    def forward(self, data, mode="train"):
        """Overrides the forward abstract method in Engine.py.

        Args:
        mode -- One of 'train', 'validation'; switches the model to
                train()/eval() mode before the forward pass.
        """
        if mode == "train":
            self.model.train()
        elif mode == "validation":
            self.model.eval()

        return self.model(data)

    @staticmethod
    def _batch_accuracy(log_probs, targets):
        """Fraction of rows whose argmax prediction equals the target."""
        return log_probs.argmax(1).eq(targets).sum().item() / targets.shape[0]

    def _next_val_batch(self):
        """Return the next validation batch, restarting the loader when exhausted."""
        try:
            return next(self._val_iter)
        except StopIteration:
            self._val_iter = iter(self.val_loader)
            return next(self._val_iter)

    def _run_validation(self, num_val_batches):
        """Evaluate on ``num_val_batches`` batches; return (mean loss, mean accuracy) as floats."""
        val_loss = 0.
        val_acc = 0.
        with torch.no_grad():
            for _ in range(num_val_batches):
                batch = self._next_val_batch().to(self.device)
                res = self.forward(batch, mode="validation")
                val_loss += self.criterion(res, batch.y).item()
                val_acc += self._batch_accuracy(res, batch.y)
        return val_loss / num_val_batches, val_acc / num_val_batches

    def train(self):
        """Overrides the train method in Engine.py.

        Runs ``ceil(config.epochs)`` passes over the train loader. Per-iteration
        metrics go to the train log. Every ``valid_interval`` iterations it
        validates on ``num_val_batches`` batches, logs the validation metrics and
        saves the "latest" checkpoint, plus "best" when validation loss improves.

        Args: None
        """
        epochs          = self.config.epochs
        report_interval = self.config.report_interval
        valid_interval  = self.config.valid_interval
        num_val_batches = self.config.num_val_batches

        iteration = 0
        best_loss = float("inf")
        batches_per_epoch = len(self.train_loader)
        self._val_iter = iter(self.val_loader)

        # Integer epoch counter: accumulating 1/len(loader) in a float could leave
        # floor(epoch) one short after a full pass and trigger an extra epoch
        # (or loop forever on an empty loader).
        for epoch_idx in range(ceil(epochs)):

            print('Epoch', epoch_idx,
                  'Starting @', strftime("%Y-%m-%d %H:%M:%S", localtime()))

            for batch_idx, data in enumerate(self.train_loader):
                data = data.to(self.device)

                # Fractional epoch, derived from integer counters
                epoch = epoch_idx + (batch_idx + 1) / batches_per_epoch
                iteration += 1

                res = self.forward(data, mode="train")
                loss = self.backward(res, data.y).item()
                acc = self._batch_accuracy(res, data.y)

                # Record the metrics for the mini-batch in the log
                self.train_log.record(self.keys, [iteration, epoch, loss, acc])
                self.train_log.write()

                if iteration % report_interval == 0:
                    print("... Iteration %d ... Epoch %1.2f ... Loss %1.3f ... Acc %1.3f"
                          % (iteration, epoch, loss, acc))

                if iteration % valid_interval == 0:
                    val_loss, val_acc = self._run_validation(num_val_batches)

                    # Record the validation stats (not the training ones) to the csv
                    self.val_log.record(self.keys, [iteration, epoch, val_loss, val_acc])
                    self.val_log.write()

                    if val_loss < best_loss:
                        self.save_state(mode="best")
                        best_loss = val_loss

                    self.save_state(mode="latest")

        self.val_log.close()
        self.train_log.close()

    def validate(self, subset):
        """Overrides the validate method in Engine.py.

        Evaluates the model on a whole subset in index order and writes one CSV row
        per event: index, label, prediction and per-class probabilities.

        Args:
        subset -- One of 'train', 'validation', 'test' to select the subset to perform validation on

        Returns:
        (overall accuracy, overall mean loss) as floats, or None if ``subset``
        is invalid or empty.
        """
        log_names = {"train": "train_validation_log.csv",
                     "validation": "valid_validation_log.csv",
                     "test": "test_validation_log.csv"}
        if subset not in log_names:
            print("validate() : arg subset has to be one of train, validation, test")
            return None

        print("Validating model on the {} set".format(subset))

        validate_indices = {"train": self.dataset.train_indices,
                            "validation": self.dataset.val_indices,
                            "test": self.dataset.test_indices}[subset]
        if len(validate_indices) == 0:
            print("validate() : the {} subset is empty".format(subset))
            return None

        self.log = CSVData(self.dirpath + log_names[subset])

        data_iter = DataLoader(self.dataset, batch_size=self.config.validate_batch_size,
                               num_workers=self.config.num_data_workers,
                               pin_memory=True, sampler=SubsetSequentialSampler(validate_indices))

        headers = ["index", "label", "pred"]
        headers += ["pred_val{}".format(i) for i in range(int(np.max(self.dataset.labels)) + 1)]

        total_loss = 0.
        total_correct = 0
        num_batches = len(data_iter)
        # The sequential sampler yields events in validate_indices order, so this
        # iterator pairs each output row with its dataset index.
        indices_iter = iter(validate_indices)

        with torch.no_grad():
            for iteration, data in enumerate(data_iter):
                gpu_data = data.to(self.device)

                stdout.write("Iteration : {}, Progress {} \n".format(iteration, iteration / num_batches))
                res = self.forward(gpu_data, mode="validation")
                preds = res.argmax(1)

                total_correct += preds.eq(gpu_data.y).sum().item()
                # nll_loss is a batch mean; scale back to a sum for the overall average
                total_loss += self.criterion(res, gpu_data.y).item() * gpu_data.y.shape[0]

                for label, pred, probs in zip(gpu_data.y.tolist(), preds.tolist(), res.exp().tolist()):
                    self.log.record(headers, [next(indices_iter), label, pred] + probs)
                    self.log.write()

        avg_acc = total_correct / len(validate_indices)
        avg_loss = total_loss / len(validate_indices)

        stdout.write("Overall acc : {}, Overall loss : {}\n".format(avg_acc, avg_loss))
        self.log.close()
        return avg_acc, avg_loss


# Testing

### Initialization

In [14]:
# Build the engine and restore weights from an earlier training run
# (uncomment engine.train() below to retrain from scratch instead).
checkpoint_path = "/app/GraphNets/dump/gcn20191115_064647/gcn_kipf_best.pth"
model = Net()
engine = EngineGraph(model, config)
# engine.train()
engine.load_state(checkpoint_path)

Requesting GPUs. GPU list : [0]
Main GPU : cuda:0
CUDA is available
Creating a directory for run dump at : /app/GraphNets/dump/gcn20191115_070248/
Loading weights from file : /app/GraphNets/dump/gcn20191115_064647/gcn_kipf_best.pth
Loading weights for module =  conv1
Loading weights for module =  conv2
Loading weights for module =  conv3
Loading weights for module =  linear


# Testing the new validation routine

In [15]:
engine.validate(subset="validation")

Validating model on the validation set
Iteration : 0, Progress 0.0 
Iteration : 1, Progress 0.03571428571428571 
Iteration : 2, Progress 0.07142857142857142 
Iteration : 3, Progress 0.10714285714285714 
Iteration : 4, Progress 0.14285714285714285 
Iteration : 5, Progress 0.17857142857142858 
Iteration : 6, Progress 0.21428571428571427 
Iteration : 7, Progress 0.25 
Iteration : 8, Progress 0.2857142857142857 
Iteration : 9, Progress 0.32142857142857145 
Iteration : 10, Progress 0.35714285714285715 
Iteration : 11, Progress 0.39285714285714285 
Iteration : 12, Progress 0.42857142857142855 
Iteration : 13, Progress 0.4642857142857143 
Iteration : 14, Progress 0.5 
Iteration : 15, Progress 0.5357142857142857 
Iteration : 16, Progress 0.5714285714285714 
Iteration : 17, Progress 0.6071428571428571 
Iteration : 18, Progress 0.6428571428571429 
Iteration : 19, Progress 0.6785714285714286 
Iteration : 20, Progress 0.7142857142857143 
Iteration : 21, Progress 0.75 
Iteration : 22, Progress 0.78