# Fine-Tuning the CNN Filter

In [None]:
import numpy

print(numpy.__version__)
import networkx

print(networkx.__version__)

from nilmtk.api import API
import warnings

# Silence warnings for the rest of the notebook. This runs after
# nilmtk.api has been imported, so warnings raised during that import
# are not affected.
warnings.filterwarnings("ignore")
from nilmtk.disaggregate import GaterCNN
import nilmtk.utils as utils

import torch

# Request the GPU; fall back to the CPU when CUDA is not available.
# torch.cuda.is_available() is only queried when USE_GPU is truthy.
USE_GPU = True
if USE_GPU and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(torch.__version__, device)

## Data Selection

We begin by selecting the dataset, house, and target appliance, together with the time windows used for fine-tuning and testing.

The fine-tuning process is performed per appliance per house, i.e., one appliance in one house at a time.

The fine-tuning window can have any duration and start date, provided that the following conditions are met:
1. The fine-tuning dataset includes at least one active usage event of the target appliance.
2. There is no overlap between the fine-tuning data and the testing data.

In [None]:
DATASET_NAME = 'redd'
HOUSE = 3

# Whether to fine-tune the CNN filter.
TUNE = True
APPLIANCE = "dish washer"

# Fine-tunes with one day data (1) or seven days data (7).
days = 1

MODEL_NOTE = f"{days}days_ft_{HOUSE}"

# Fine-tuning window (start_time, end_time) per house, keyed by `days`.
# The one-day windows for houses 2 and 3 are the ones used for the dish washer
# (D.W.) and microwave (M.V.). For the washing machine (W.M.) use instead:
#   house 2 -> ('2011-04-23', '2011-04-24')
#   house 3 -> ('2011-04-19', '2011-04-20')
# NOTE(review): house 1's 7-day window overlaps its test window (which starts
# 2011-04-24) and is rejected by the overlap check below; choose a different
# window or shift TEST before using it.
FINE_TUNE_WINDOWS = {
    1: {7: ('2011-04-19', '2011-04-26'), 1: ('2011-04-23', '2011-04-24')},
    2: {7: ('2011-04-18', '2011-04-25'), 1: ('2011-04-18', '2011-04-19')},
    3: {7: ('2011-04-17', '2011-04-24'), 1: ('2011-04-17', '2011-04-18')},
}

# Testing window (start_time, end_time) per house.
TEST_WINDOWS = {
    1: ('2011-04-24', '2011-05-26'),
    2: ('2011-04-25', '2011-05-22'),
    3: ('2011-04-24', '2011-05-30'),
}

if HOUSE not in FINE_TUNE_WINDOWS:
    raise ValueError(
        f"Unsupported HOUSE {HOUSE!r}; expected one of {sorted(FINE_TUNE_WINDOWS)}"
    )
if days not in FINE_TUNE_WINDOWS[HOUSE]:
    raise ValueError(
        f"Unsupported days {days!r}; expected one of {sorted(FINE_TUNE_WINDOWS[HOUSE])}"
    )

train_start, train_end = FINE_TUNE_WINDOWS[HOUSE][days]
test_start, test_end = TEST_WINDOWS[HOUSE]

# Enforce condition 2: no overlap between fine-tuning and testing data.
# ISO-8601 date strings compare chronologically. A fine-tuning window that
# ends exactly where the test window starts is allowed, matching the original
# one-day windows (assumes end_time is exclusive — confirm against nilmtk).
if train_end > test_start:
    raise ValueError(
        f"Fine-tuning window {train_start}..{train_end} overlaps the test "
        f"window starting {test_start} for house {HOUSE}"
    )

TRAIN = {HOUSE: {'start_time': train_start, 'end_time': train_end}}
TEST = {HOUSE: {'start_time': test_start, 'end_time': test_end}}

## Fine-Tune

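Next, we configure and run the experiment. The GaterCNN model trained on UK-DALE is loaded and, when `TUNE` is enabled, fine-tuned on the REDD window selected above. It is then evaluated on the test window using the F1 score.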

In [None]:
# HDF5 file holding DATASET_NAME.
# NOTE(review): assumes datasets are stored as ../mnt/<name>.h5 — confirm
# before switching DATASET_NAME away from 'redd'.
DATA_PATH = f'../mnt/{DATASET_NAME}.h5'


def _datasets(buildings):
    """Return the experiment 'datasets' entry for DATASET_NAME restricted to *buildings*."""
    return {DATASET_NAME: {'path': DATA_PATH, 'buildings': buildings}}


e = {
    # Specify power type, sample rate and disaggregated appliance
    'power': {
        'mains_train': ['apparent'],
        'mains_transfer': ['apparent'],
        'mains_test': ['apparent'],
        'appliance': ['active'],
    },
    'sample_rate': 6,
    'appliances': [APPLIANCE],
    # Universally no pre-training
    'pre_trained': False,
    "app_meta": utils.GENERAL_APP_META,
    # Tag for saved outputs; the '-no' suffix marks runs with TUNE disabled.
    'save_note': f'ft-{HOUSE}' if TUNE else f'ft-{HOUSE}-no',
    # Specify algorithm hyper-parameters
    'methods': {
        "GaterCNN": GaterCNN(
            {
                'n_epochs': 3,
                'batch_size': 16,
                'sequence_length': 720,
                'appliance_length': 720,
                # In fine-tuning mode, set 'test_only' to True to avoid the model from
                # being trained by the source dataset again.
                'test_only': True,
                'fine_tune': TUNE,
                'note': MODEL_NOTE,
                # Source dataset whose trained model is loaded.
                'load_from': 'ukdale',
                'patience': 3,
            }
        )
    },
    # The fine-tuning window feeds both 'train' and 'transfer'.
    'train': {'datasets': _datasets(TRAIN)},
    'transfer': {'datasets': _datasets(TRAIN)},
    'test': {
        'datasets': _datasets(TEST),
        # Specify evaluation metrics
        'metrics': ['f1score'],
    },
}

API(e)