# Example of Using BatLiNet for Inference

In [1]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = ''

import sys
import torch
import random
import shutil
import pickle
import hashlib
import warnings
import numpy as np

from pathlib import Path
from datetime import datetime

sys.path.append(str(Path.cwd()))
from src.task import Task
from src.builders import MODELS
from src.utils import import_config

In our experiment, there are two methods to train a model.

Method 1: Call the wrapped main function. This involves directly invoking the main function within the encapsulated 'pipeline.py' for both training and evaluation.

Method 2: Pipeline details. This method involves gradually unfolding the contents within the pipeline, such as loading configurations, building the dataset, training, predicting, and so forth.

You are free to choose either of these methods to reproduce the code as per your convenience and requirements.

# Method 1: Calling the wrapped main function


In [2]:
#from scripts.pipeline import main
#for seed in range(8):
#    config_path = "./configs/ablation/diff_branch/batlinet/mix_100.yaml"
#    workspace = "./workspaces/ablation/diff_branch/batlinet/mix_100"
#    # If train is true, the model needs to be trained from scratch, if it is false it will be loaded from checkpoint
#    main(config_path=config_path, workspace=workspace, seed=seed, train=False, evaluate=True, device='cpu')

# Method 2: Pipeline details


## Define helper functions
We first define some helper functions.

In [3]:
# We use these functions to name the dumped files
def hash_string(string):
    sha256_hash = hashlib.sha256()
    sha256_hash.update(string.encode('utf-8'))
    hash_value = sha256_hash.hexdigest()
    truncated_hash = hash_value[:32]
    return truncated_hash


def timestamp(marker: bool = False):
    template = '%Y-%m-%d %H:%M:%S' if marker else '%Y%m%d%H%M%S'
    return datetime.now().strftime(template)

## Set seed
We set the seed of the experiment with the following function. However, some low-level code may still bring in randomness, which may slightly influence the final scores (<10 RMSE, for example).

In [4]:
def set_seed(seed: int):
    print(f'Seed is set to {seed}.')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

## Load config
We use config files to organize our experiments. We will use the following function to load the config files.

In [5]:
CONFIGS = [
    'model',
    'train_test_split',
    'feature',
    'label',
    'feature_transformation',
    'label_transformation'
]


def load_config(config_path: str, workspace: str) -> dict:
    config_path = Path(config_path)
    configs = import_config(config_path, CONFIGS)

    # Determine the workspace
    if configs['model'].get('workspace') is not None:
        workspace = Path(configs['model'].get('workspace'))
    elif workspace is not None:
        if workspace.strip().lower() == 'none':
            workspace = None
        else:
            workspace = Path(workspace)
    else:
        workspace = Path.cwd() / 'workspaces' / config_path.stem
        warnings.warn(f'Setting workspace to {str(workspace)}. If you '
                       'do not want any information to be stored, '
                       'explicitly call with flag `--workspace none`.')

    if workspace is not None and workspace.exists():
        assert workspace.is_dir(), workspace

    if workspace is not None and not workspace.exists():
        os.makedirs(workspace)

    configs['workspace'] = workspace

    return configs

## Build dataset
As the preprocessing of the datasets are time-consuming, we cache the preprocessed data to save both time and memory (when using parallel computation).

In [6]:
def recursive_dump_string(data):
    if isinstance(data, list):
        return '_'.join([recursive_dump_string(x) for x in data])
    if isinstance(data, dict):
        return '_'.join([
            recursive_dump_string(data[key])
            for key in sorted(data.keys())
        ])
    return str(data)


def build_dataset(configs: dict, device: str):
    strings = []
    fields = ['label', 'feature', 'train_test_split',
              'feature_transformation', 'label_transformation']
    for field in fields:
        strings.append(recursive_dump_string(configs[field]))
    filename = hash_string('+'.join(strings))
    cache_dir = Path('cache')
    if not cache_dir.exists():
        cache_dir.mkdir()
    cache_file = Path(cache_dir / f'battery_cache_{filename}.pkl')

    if cache_file.exists():
        warnings.warn(f'Load datasets from cache {str(cache_file)}.')
        with open(cache_file, 'rb') as f:
            dataset = pickle.load(f)
    else:
        dataset = Task(
            label_annotator=configs['label'],
            feature_extractor=configs['feature'],
            train_test_splitter=configs['train_test_split'],
            feature_transformation=configs['feature_transformation'],
            label_transformation=configs['label_transformation']).build()
        # store cache
        with open(cache_file, 'wb') as f:
            pickle.dump(dataset, f)
    return dataset.to(device)

## Train and evaluate 
The following is the main logic of evaluation. We load in the correct config file and then train or evaluate the model to obtain metrics.

In [7]:
config_path = Path("./configs/ablation/diff_branch/batlinet/mix_100.yaml")
workspace = "./workspaces/ablation/diff_branch/batlinet/mix_100"
configs = load_config(config_path, workspace)
metric = ['RMSE', 'MAE', 'MAPE']
device = 'cpu'
train_from_scratch = False  # Whether we train the model from scratch

In [8]:
# We test 8 seeds
for seed in range(8):
    set_seed(seed)

    # Prepare dataset
    dataset = build_dataset(configs, device).to(device)

    # Model preparation
    configs['model']['seed'] = seed
    model = MODELS.build(configs['model'])
    
    if not train_from_scratch:
        # load model from checkpoint
        # checkpoint = next(Path(workspace).glob(f'*seed_{seed}*.ckpt'))
        paths = list(Path(workspace).glob(f'*seed_{seed}*.ckpt'))
        if len(paths) > 1:
            print(f"Warning: finding multiple paths", paths)
        checkpoint = paths[0]
        model.load_checkpoint(checkpoint, device=device)

    model = model.to(device)

    # Store the current config to workspace
    ts = timestamp()
    if model.workspace is not None:
        shutil.copyfile(config_path, model.workspace / f'config_{ts}.yaml')

    if train_from_scratch:
        # train from scratch
        model.fit(dataset, timestamp=ts)

    # Evaluate
    prediction = model.predict(dataset)
    scores = {
        m: dataset.evaluate(prediction, m) for m in metric
    }

    # Save predictions
    if model.workspace is not None:
        obj = {
            'prediction': prediction,
            'scores': scores,
            'data': dataset.to('cpu'),
            'seed': seed,
        }
        with open(
            model.workspace / f'predictions_seed_{seed}_{ts}.pkl', 'wb'
        ) as f:
            pickle.dump(obj, f)

    # Print metrics
    print(' '.join([
        f'{m}: {s:.2f}' for m, s in scores.items()
    ]), flush=True)

Seed is set to 0.


Reading train data:   0%|          | 0/205 [00:00<?, ?it/s]

Reading test data:   0%|          | 0/137 [00:00<?, ?it/s]

Extracting features:   0%|          | 0/205 [00:00<?, ?it/s]

Extracting features:   0%|          | 0/137 [00:00<?, ?it/s]

IndexError: list index out of range

For the complete reproduction of our experiments, please refer to our [detailed guide](README.md). Also, we visualized all the tables and figures [here](notebooks).