Skip to content

Commit

Permalink
Merge pull request #113 from GilesStrong/feat_add_geo_data
Browse files Browse the repository at this point in the history
Feat add geo data
  • Loading branch information
GilesStrong committed Apr 13, 2023
2 parents c040dce + d7dd1a0 commit 92301ba
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 19 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Expand Up @@ -4,11 +4,10 @@ dev/
**/.DS_Store
*~


#Examples
examples/*/*
**/train_weights
**/data
examples/**/data
**/weights

# Byte-compiled / optimized / DLL files
Expand Down
1 change: 1 addition & 0 deletions CHANGES.md
Expand Up @@ -20,6 +20,7 @@
- 'hard_identity' function to replace lambda x: x when required
- `fold2foldfile`, `df2foldfile`, and `add_meta_data` can now deal with targets in the form of multi dimensional tensors, and convert them to sparse COO format
- `df2foldfile` now has the option to not shuffle data into folds and instead split it into contiguous folds
- Limited handling of PyTorch Geometric data: `TorchGeometricFoldYielder`, `TorchGeometricBatchYielder`, `TorchGeometricEvalMetric`

## Removals

Expand Down
76 changes: 68 additions & 8 deletions lumin/nn/data/batch_yielder.py
@@ -1,24 +1,21 @@
from __future__ import annotations
import numpy as np
import pandas as pd
from typing import List, Optional, Union, Tuple
from typing import List, Optional, Union, Tuple, Any, Dict
import math

from ...utils.misc import to_device

from torch import Tensor

__all__ = ['BatchYielder']


'''
Todo
- Improve this/change to dataloader
'''
__all__ = ['BatchYielder', 'TorchGeometricBatchYielder']


class BatchYielder:
r'''
Yields minibatches to model during training. Iteration provides one minibatch as tuple of tensors of inputs, targets, and weights.
TODO: Improve this/change to dataloader
Arguments:
inputs: input array for (sub-)epoch
Expand Down Expand Up @@ -89,9 +86,72 @@ def __iter__(self) -> List[Tensor]:
def __len__(self): return len(self.inputs)//self.bs if self.drop_last else math.ceil(len(self.inputs)/self.bs)

def get_inputs(self, on_device:bool=False) -> Union[Tensor, Tuple[Tensor,Tensor]]:
r'''
Returns all data.
Arguments:
on_device: whether to place tensor on device
Returns:
tuple of inputs, targets, and weights as tensors on device
'''

if on_device:
if self.matrix_inputs is None: return to_device(Tensor(self.inputs))
else: return (to_device(Tensor(self.inputs)), to_device(Tensor(self.matrix_inputs)))
else:
if self.matrix_inputs is None: return self.inputs
else: return (self.inputs, self.matrix_inputs)


class TorchGeometricBatchYielder(BatchYielder):
r'''
:class:`~lumin.nn.data.batch_yielder.BatchYielder` for PyTorch Geometric data. kwargs for compatibility only.
Arguments:
inputs: PyTorch Geometric Dataset containing inputs, weights, and targets
bs: batchsize, number of data to include per minibatch
shuffle: whether to shuffle the data at the beginning of an iteration
exclude_keys: data keys to exclude from inputs
'''

from torch_geometric.data import Dataset

def __init__(self, inputs: Dataset, bs:int, shuffle:bool=True, exclude_keys:Optional[List[str]]=None, use_weights:bool=True, **kwargs:Any):
from torch_geometric.loader import DataLoader
self.loader = DataLoader(inputs, batch_size=bs, shuffle=shuffle, exclude_keys=exclude_keys)
self.use_weights = use_weights

def __iter__(self) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Optional[Dict[str, Tensor]]]:
r'''
Iterate through data in batches.
Returns:
tuple of batches of inputs, targets, and weights as dictionaries of tensors on device
'''

for batch in self.loader:
batch = to_device(batch)
x = {k: batch[k] for k in batch.keys if k not in ['y', 'ptr']}
y = {'y': batch.y, 'batch': batch.batch}
w = {'weight': batch.weight, 'batch': batch.batch} if 'weight' in batch.keys and self.use_weights else None
yield x, y, w

def __len__(self): return len(self.loader)

def get_inputs(self, on_device:bool=False) -> Union[Tensor, Tuple[Tensor,Tensor]]:
r'''
Returns all data.
Arguments:
on_device: whether to place tensor on device
Returns:
tuple of inputs, targets, and weights as dictionaries of tensors on device
'''

if on_device:
x = {k: to_device(self.loader.dataset[k]) for k in self.loader.dataset.keys if k not in ['y', 'ptr']}
else:
x = {k: self.loader.dataset[k] for k in self.loader.dataset.keys if k not in ['y', 'ptr']}
return x
111 changes: 103 additions & 8 deletions lumin/nn/data/fold_yielder.py
@@ -1,3 +1,4 @@
from __future__ import annotations
import numpy as np
import pandas as pd
import h5py
Expand All @@ -10,10 +11,11 @@
from fastcore.all import is_listy
from importlib import import_module
from sklearn.pipeline import Pipeline
from sklearn.model_selection import KFold

from .batch_yielder import BatchYielder

__all__ = ['FoldYielder', 'HEPAugFoldYielder']
__all__ = ['FoldYielder', 'HEPAugFoldYielder', 'TorchGeometricFoldYielder']


class FoldYielder:
Expand Down Expand Up @@ -327,8 +329,8 @@ def get_data(self, n_folds:Optional[int]=None, fold_idx:Optional[int]=None) -> D
Inputs are passed through np.nan_to_num to deal with nans and infs.
Arguments:
n_folds: number of folds to get data from. Default all folds. Not compatable with fold_idx
fold_idx: Only load group from a single, specified fold. Not compatable with n_folds
n_folds: number of folds to get data from. Default all folds. Not compatible with fold_idx
fold_idx: Only load group from a single, specified fold. Not compatible with n_folds
Returns:
tuple of inputs, targets, and weights as Numpy arrays
Expand All @@ -342,19 +344,19 @@ def get_df(self, pred_name:str='pred', targ_name:str='targets', wgt_name:str='we
inc_inputs:bool=False, inc_ignore:bool=False, deprocess:bool=False, verbose:bool=True, suppress_warn:bool=False,
nan_to_num:bool=False, inc_matrix:bool=False) -> pd.DataFrame:
r'''
Get a Pandas DataFrameof the data in the foldfile. Will add columns for inputs (if requested), targets, weights, and predictions (if present)
Get a Pandas DataFrame of the data in the foldfile. Will add columns for inputs (if requested), targets, weights, and predictions (if present)
Arguments:
pred_name: name of prediction group
targ_name: name of target group
wgt_name: name of weight group
n_folds: number of folds to get data from. Default all folds. Not compatable with fold_idx
fold_idx: Only load group from a single, specified fold. Not compatable with n_folds
n_folds: number of folds to get data from. Default all folds. Not compatible with fold_idx
fold_idx: Only load group from a single, specified fold. Not compatible with n_folds
inc_inputs: whether to include input data
inc_ignore: whether to include ignored features
deprocess: whether to deprocess inputs and targets if pipelines have been
verbose: whether to print the number of datapoints loaded
suppress_warn: whether to supress the warning about missing columns
suppress_warn: whether to suppress the warning about missing columns
nan_to_num: whether to pass input data through `np.nan_to_num`
inc_matrix: whether to include flattened matrix data in output, if present
Expand Down Expand Up @@ -528,7 +530,7 @@ def _reflect(self, df:pd.DataFrame, vectors:List[str]) -> None:

def get_fold(self, idx:int) -> Dict[str,np.ndarray]:
r'''
Get data for single fold applying random train-time data augmentaion. Data consists of dictionary of inputs, targets, and weights.
Get data for single fold applying random train-time data augmentation. Data consists of dictionary of inputs, targets, and weights.
Accounts for ignored features.
Inputs, except for matrix data, are passed through np.nan_to_num to deal with nans and infs.
Expand Down Expand Up @@ -638,3 +640,96 @@ def get_test_fold(self, idx:int, aug_idx:int) -> Dict[str, np.ndarray]:
targets = targets[self.targ_feats]
data['targets'] = np.nan_to_num(targets.values)
return self._append_matrix(data, idx) if self.has_matrix and self.yield_matrix else data


class TorchGeometricFoldYielder(FoldYielder):
r'''
Interface class for accessing data from PyTorch Geometric datasets.
Dataset will be split into sub-folds; either provide a value for the `fold_indices` argument with your own split as a list of lists of indices,
or specify the number of folds for a random split (`n_folds`)
..warning::
Much functionality has yet to be implemented for this class
Arguments:
dataset: PyTorch Geometric Dataset containing inputs, weights, and targets
n_folds: number of folds in which to randomly split the dataset. Must provide either this or `fold_indices`
fold_indices: list of lists of indices; each list of indices is a fold. Must provide either this or `n_folds`
batch_yielder_type: Class of :class:`~lumin.nn.data.batch_yielder.BatchYielder` to instantiate to yield inputs
'''

from torch_geometric.data import Dataset
from .batch_yielder import TorchGeometricBatchYielder

def __init__(self, dataset:Dataset, n_folds:Optional[int], fold_indices:Optional[List[List[int]]]=None, batch_yielder_type:Type[BatchYielder]=TorchGeometricBatchYielder):
self.dataset = dataset
self.batch_yielder_type = batch_yielder_type
self._set_folds(n_folds, fold_indices)

self.cont_feats,self.cat_feats,self.input_pipe,self.output_pipe = [],[],None,None
self.yield_matrix,self.matrix_pipe = True,None
self.augmented,self.aug_mult,self.train_time_aug,self.test_time_aug = False,0,False,False
self.input_feats = self.cont_feats + self.cat_feats
self.orig_cont_feats,self.orig_cat_feat,self._ignore_feats = self.cont_feats,self.cat_feats,[]

def __repr__(self) -> str: return f'FoldYielder with {self.n_folds} folds'

def __len__(self) -> int: return self.n_folds

def __getitem__(self, idx:int) -> Dataset: return self.get_fold(idx)

def __iter__(self) -> Dataset:
for i in range(self.n_folds): yield self.get_fold(i)

def _set_folds(self, n_folds:Optional[int], fold_indices:Optional[List[List[int]]]=None) -> None:
if fold_indices is None:
kf = KFold(n_splits=n_folds, shuffle=True)
fold_indices = [f[1] for f in kf.split(X=np.arange(len(self.dataset)))]
self.n_folds = n_folds
else:
self.n_folds = len(fold_indices)

self.fold_indices = fold_indices
self.fld_szs = {i:len(f) for i,f in enumerate(self.fold_indices)}

def columns(self) -> List[str]:
raise NotImplementedError()

def add_ignore(self, feats:Union[str,List[str]]) -> None:
raise NotImplementedError()

def _set_foldfile(self, foldfile:Union[str,Path,h5py.File]) -> None:
raise NotImplementedError()

def _append_matrix(self, data, idx) -> Dict[str,np.ndarray]:
raise NotImplementedError()

def close(self) -> None:
pass

def get_fold(self, idx:int) -> Dict[str,np.ndarray]:
r'''
Get data for single fold. Data consists of a slice of a PyTorch Geometric Dataset.
Arguments:
idx: fold index to load
Returns:
PyTorch Geometric Dataset slice
'''

return {'inputs':self.dataset[self.fold_indices[idx]]}

def get_column(self, column:str, n_folds:Optional[int]=None, fold_idx:Optional[int]=None, add_newaxis:bool=False) -> Union[np.ndarray, None]:
raise NotImplementedError()

def get_data(self, n_folds:Optional[int]=None, fold_idx:Optional[int]=None) -> Dict[str,np.ndarray]:
raise NotImplementedError()

def get_df(self, pred_name:str='pred', targ_name:str='targets', wgt_name:str='weights', n_folds:Optional[int]=None, fold_idx:Optional[int]=None,
inc_inputs:bool=False, inc_ignore:bool=False, deprocess:bool=False, verbose:bool=True, suppress_warn:bool=False,
nan_to_num:bool=False, inc_matrix:bool=False) -> pd.DataFrame:
raise NotImplementedError()

def save_fold_pred(self, pred:np.ndarray, fold_idx:int, pred_name:str='pred') -> None:
raise NotImplementedError()
52 changes: 51 additions & 1 deletion lumin/nn/metrics/eval_metric.py
Expand Up @@ -11,7 +11,7 @@
from ..callbacks.callback import Callback
from ...utils.misc import to_np

__all__ = ['EvalMetric']
__all__ = ['EvalMetric', 'TorchGeometricEvalMetric']


class EvalMetric(Callback, metaclass=ABCMeta):
Expand Down Expand Up @@ -167,3 +167,53 @@ def get_df(self) -> pd.DataFrame:
else:
df['pred'] = self.preds.squeeze()
return df


class TorchGeometricEvalMetric(EvalMetric):
r'''
Abstract class for evaluating performance of a model using some metric and PyTorch Geometric data
Arguments:
name: optional name for metric, otherwise will be inferred from class
lower_metric_better: whether a lower metric value should be treated as representing better perofrmance
main_metric: whether this metic should be treated as the primary metric for SaveBest and EarlyStopping
Will automatically set the first EvalMetric to be main if multiple primary metrics are submitted
'''

def on_epoch_begin(self) -> None:
r'''
Resets prediction tracking
'''

self.preds, self.targets, self.batches, self.weights, self.metric = [],[],[],[],None
self.batch_cnt = 0

def on_forwards_end(self) -> None:
r'''
Save predictions from batch
'''

if self.model.fit_params.state == 'valid':
self.preds.append(self.model.fit_params.y_pred.cpu().detach())
self.targets.append(self.model.fit_params.y['y'].cpu().detach())
self.batches.append(self.model.fit_params.y['batch'].cpu().detach()+self.batch_cnt)
self.batch_cnt = self.batches[-1].max()
if self.model.fit_params.w is not None:
self.weights.append(self.model.fit_params.w.cpu().detach())

def on_epoch_end(self) -> None:
r'''
Compute metric using saved predictions
'''

if self.model.fit_params.state != 'valid': return
self.preds = torch.cat(self.preds, dim=0)
if 'multiclass' in self.model.objective: self.preds = torch.exp(self.preds)
self.targets = torch.cat(self.targets, dim=0)
self.batches = torch.cat(self.batches, dim=0)
self.weights = torch.cat(self.weights, dim=0) if len(self.weights) > 0 else None
self.metric = self.evaluate()
del self.preds
del self.targets
del self.batches
del self.weights

0 comments on commit 92301ba

Please sign in to comment.