<a href="https://colab.research.google.com/github/Ken89MathCompSci/GLC-MATNilm/blob/kengoh-learnable-positional-encoding-only/21-September-2025-Rerunning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pwd

'/content'

In [2]:
cd GLC-MATNilm/

/content/GLC-MATNilm


In [3]:
# Mount the user's Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
%cd /content/GLC-MATNilm

/content/GLC-MATNilm


In [5]:
import os
# Create the nested directory up front; exist_ok=True makes re-running this
# cell a no-op when the directory already exists.
# NOTE(review): presumably the checkpoint directory for the "colab_train" run;
# confirm against utils.mkdirectory.
os.makedirs('/content/GLC-MATNilm/history_model/colab_train/s0', exist_ok=True)

In [6]:
import os
# Create the log directory for the "colab_train" run; exist_ok=True makes
# re-running this cell a no-op when the directory already exists.
# This follows the "log/<subName>" layout that the training script creates.
os.makedirs('/content/GLC-MATNilm/log/colab_train', exist_ok=True)

In [7]:
!python main.py --dataAug --subName colab_train --output_dir "/content/drive/My Drive/GLC-MATNilm_results"

usage: main.py [-h] [--batch BATCH] [--lr LR] [--dropout DROPOUT]
               [--hidden HIDDEN] [--logname LOGNAME] [--subName SUBNAME]
               [--inputLength INPUTLENGTH] [--outputLength OUTPUTLENGTH]
               [--debug] [--dataAug] [--prob0 PROB0] [--prob1 PROB1]
               [--prob2 PROB2] [--prob3 PROB3] [--resume]
               [--checkpoint CHECKPOINT]
main.py: error: unrecognized arguments: --output_dir /content/drive/My Drive/GLC-MATNilm_results


In [None]:
import copy
import os
import utils
import argparse
import joblib
from tqdm import tqdm
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from custom_types import Basic, TrainConfig
from modules import MATconv as MAT
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import sys # Import the sys module


def get_args(argv=None):
    """Build the argument parser and parse the training options.

    Args:
        argv: Optional list of argument strings to parse, e.g.
            ``["--dataAug", "--subName", "colab_train"]``. When omitted and
            the code runs as a script, the process command line is parsed.
            When omitted inside a notebook kernel (no ``__file__`` in the
            module globals), an empty argument list is parsed so that every
            option takes its default value.

    Returns:
        argparse.Namespace: the parsed options (``batch``, ``lr``,
        ``dropout``, ``hidden``, ``logname``, ``subName``, ``inputLength``,
        ``outputLength``, ``debug``, ``dataAug``, ``prob0``..``prob3``,
        ``resume``, ``checkpoint``, ``output_dir``).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--batch", type=int, default=32, help="batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--dropout", type=float, default=0.1, help="dropout")
    parser.add_argument("--hidden", type=int, default=32, help="encoder decoder hidden size")
    parser.add_argument("--logname", action="store", default='root', help="name for log")
    parser.add_argument("--subName", action="store", type=str, default='test', help="name of the directory of current run")
    parser.add_argument("--inputLength", type=int, default=864, help="input length for the model")
    parser.add_argument("--outputLength", type=int, default=864, help="output length for the model")
    parser.add_argument("--debug", action="store_true", help="debug mode")
    parser.add_argument("--dataAug", action="store_true", help="data augmentation mode")
    parser.add_argument("--prob0", type=float, default=0.3, help="augment probability for Dishwasher")
    parser.add_argument("--prob1", type=float, default=0.6, help="weight")
    parser.add_argument("--prob2", type=float, default=0.3, help="weight")
    parser.add_argument("--prob3", type=float, default=0.3, help="weight")
    parser.add_argument("--resume", action="store_true", help="resume training from checkpoint")
    parser.add_argument("--checkpoint", type=str, default="All_best_onoff.ckpt", help="checkpoint file name to resume from")
    parser.add_argument("--output_dir", type=str, default=".", help="directory to save output results")

    # Inside a notebook kernel sys.argv holds the kernel launcher's own
    # arguments, which argparse would reject. Parse an empty list in that case
    # instead of overwriting the process-wide sys.argv.
    if argv is None and '__file__' not in globals():
        argv = []

    # argv=None makes argparse fall back to sys.argv[1:] (script usage).
    return parser.parse_args(argv)


def train(t_net, train_Dataloader, vali_Dataloader, config, criterion, modelDir, epo=200):
    """Train ``t_net`` with early stopping and return the best saved network.

    Relies on the module-level ``logger`` and ``device`` objects that the
    ``__main__`` section creates before calling this function.

    Args:
        t_net: network wrapper exposing ``model`` (the module being trained)
            and ``model_opt`` (its optimizer).
        train_Dataloader: iterable yielding 5-tuples; the last three items are
            used as the model input, the regression target and the on/off
            target (the first two items are ignored).
        vali_Dataloader: loader handed to ``utils.evaluateResult`` after every
            epoch.
        config: run configuration; ``config.dataAug`` switches augmentation on,
            and the object is forwarded to the ``utils`` helpers.
        criterion: two-element sequence; ``criterion[0]`` scores the first
            model output against the regression target and ``criterion[1]``
            scores the second output against the on/off target.
        modelDir: directory in which checkpoints are written.
        epo: maximum number of epochs.

    Returns:
        A deep copy of ``t_net`` loaded with the weights stored in
        ``All_best_onoff.ckpt`` and with its model switched to eval mode.
    """
    early_stopping_all = utils.EarlyStopping(logger, patience=30, verbose=True)

    if config.dataAug:
        sigClass = utils.sigGen(config)

    path_all = os.path.join(modelDir, "All_best_onoff.ckpt")

    for e_i in range(epo):

        logger.info(f"# of epoches: {e_i}")

        # Validation at the end of the previous epoch may have left the model
        # in eval mode (not visible from here), so re-enable training mode.
        t_net.model.train()

        # Collect losses per epoch so the logged value is this epoch's mean,
        # not a running mean over every iteration since training started.
        iter_loss = []
        for _, _, X_scaled, Y_scaled, Y_of in tqdm(train_Dataloader):
            if config.dataAug:
                X_scaled, Y_scaled, Y_of = utils.dataAug(X_scaled.clone(), Y_scaled.clone(), Y_of.clone(), sigClass, config)

            t_net.model_opt.zero_grad(set_to_none=True)

            X_scaled = X_scaled.type(torch.FloatTensor).to(device, non_blocking=True)
            Y_scaled = Y_scaled.type(torch.FloatTensor).to(device, non_blocking=True)
            Y_of = Y_of.type(torch.FloatTensor).to(device, non_blocking=True)

            y_pred_dish_r, y_pred_dish_c = t_net.model(X_scaled)

            loss_r = criterion[0](y_pred_dish_r, Y_scaled)
            loss_c = criterion[1](y_pred_dish_c, Y_of)

            loss = loss_r + loss_c
            loss.backward()

            t_net.model_opt.step()
            iter_loss.append(loss.item())

        epoch_losses = np.average(iter_loss)

        logger.info(f"Validation: ")
        # Evaluate the network that was passed in, not a same-named global.
        maeScore, y_vali_ori, y_vali_pred_d_update, _, _, _ = utils.evaluateResult(t_net, config, vali_Dataloader, logger)
        val_loss = criterion[0](y_vali_ori, y_vali_pred_d_update)
        logger.info(f"Epoch {e_i:d}, train loss: {epoch_losses:3.3f}, val loss: {val_loss:3.3f}.")

        # Periodic snapshot every 10 epochs, independent of early stopping.
        if e_i % 10 == 0:
            checkpointName = os.path.join(modelDir, "checkpoint_" + str(e_i) + '.ckpt')
            utils.saveModel(logger, t_net, checkpointName)

        logger.info(f"Early stopping overall: ")
        early_stopping_all(np.mean(maeScore), t_net, path_all)
        if early_stopping_all.early_stop:
            print("Early stopping")
            break

    # Reload the checkpoint at path_all into a copy so the caller gets the
    # early-stopping checkpoint rather than the last-epoch weights.
    net_all = copy.deepcopy(t_net)
    checkpoint_all = torch.load(path_all, map_location=device)
    utils.loadModel(logger, net_all, checkpoint_all)
    net_all.model.eval()

    return net_all

if __name__ == '__main__':
    # Script entry point: parse options, load the data, build the model,
    # train with early stopping, then evaluate on validation and test data.
    args = get_args()
    utils.mkdir("log/" + args.subName)
    logger = utils.setup_log(args.subName, args.logname)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using computation device: {device}")
    logger.info(args)

    # Debug mode only runs a couple of epochs as a smoke test.
    epo = 2 if args.debug else 200

    # Data loading
    logger.info("loading data")
    train_data, val_data, test_data = utils.data_loader(args)
    logger.info("loading data finished")

    config_dict = {
        "input_size": 1,
        "batch_size": args.batch,
        "hidden": args.hidden,
        "lr": args.lr,
        "dropout": args.dropout,
        "logname": args.logname,
        "outputLength": args.outputLength,
        "inputLength": args.inputLength,
        "subName": args.subName,
        "dataAug": args.dataAug,
        "prob0": args.prob0,
        "prob1": args.prob1,
        "prob2": args.prob2,
        "prob3": args.prob3,
        "output_dir": args.output_dir,
    }

    config = TrainConfig.from_dict(config_dict)
    # TODO: modelDir is still derived from subName only; args.output_dir is
    # stored in the config but not yet used to decide where models are saved.
    modelDir = utils.mkdirectory(config.subName, saveModel=True)
    joblib.dump(config, os.path.join(modelDir, "config.pkl"))

    logger.info(f"Training size: {train_data.cumulative_sizes[-1]:d}.")

    # Settings shared by the train, validation and test loaders.
    loader_kwargs = dict(batch_size=config.batch_size, num_workers=1, pin_memory=True)

    # Training draws every index in random order.
    index = np.arange(0, train_data.cumulative_sizes[-1])
    train_subsampler = torch.utils.data.SubsetRandomSampler(index)
    train_Dataloader = DataLoader(train_data, sampler=train_subsampler, **loader_kwargs)

    # Validation and test use the project's testSampler, built from the
    # dataset size and the model output length.
    sampler = utils.testSampler(val_data.cumulative_sizes[-1], config.outputLength)
    sampler_test = utils.testSampler(test_data.cumulative_sizes[-1], config.outputLength)

    vali_Dataloader = DataLoader(val_data, sampler=sampler, **loader_kwargs)
    test_Dataloader = DataLoader(test_data, sampler=sampler_test, **loader_kwargs)

    logger.info("Initialize model")
    model = MAT(config).to(device)
    logger.info("Model MAT")

    # Named `optimizer` so the imported `torch.optim` module is not rebound to
    # an optimizer instance.
    optimizer = optim.Adam(params=[p for p in model.parameters() if p.requires_grad], lr=config.lr)
    net = Basic(model, optimizer)

    # Resume from checkpoint if specified
    if args.resume:
        checkpoint_path = os.path.join(modelDir, args.checkpoint)
        if os.path.exists(checkpoint_path):
            logger.info(f"Resuming training from checkpoint: {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            net = utils.loadModel(logger, net, checkpoint)
        else:
            # A missing checkpoint is reported but not fatal: train from scratch.
            logger.error(f"Checkpoint file not found: {checkpoint_path}")
            logger.info("Starting training from scratch")

    # criterion[0] is applied to the regression output, criterion[1] to the
    # on/off output (see train()).
    criterion_r = nn.MSELoss()
    criterion_c = nn.BCELoss()
    criterion = [criterion_r, criterion_c]

    logger.info("Training start")
    net_all = train(net, train_Dataloader, vali_Dataloader, config, criterion, modelDir, epo=epo)
    logger.info("Training end")

    logger.info("validation start")
    utils.evaluateResult(net_all, config, vali_Dataloader, logger)
    logger.info("test start")
    utils.evaluateResult(net_all, config, test_Dataloader, logger)

INFO:root:Using computation device: cuda:0
2025-09-20 18:16:35,469 - root - INFO - Using computation device: cuda:0
INFO:root:Namespace(batch=32, lr=0.001, dropout=0.1, hidden=32, logname='root', subName='test', inputLength=864, outputLength=864, debug=False, dataAug=False, prob0=0.3, prob1=0.6, prob2=0.3, prob3=0.3, resume=False, checkpoint='All_best_onoff.ckpt', output_dir='.')
2025-09-20 18:16:35,470 - root - INFO - Namespace(batch=32, lr=0.001, dropout=0.1, hidden=32, logname='root', subName='test', inputLength=864, outputLength=864, debug=False, dataAug=False, prob0=0.3, prob1=0.6, prob2=0.3, prob3=0.3, resume=False, checkpoint='All_best_onoff.ckpt', output_dir='.')
INFO:root:loading data
2025-09-20 18:16:35,472 - root - INFO - loading data
INFO:root:loading data finished
2025-09-20 18:16:35,520 - root - INFO - loading data finished
INFO:root:Training size: 27937.
2025-09-20 18:16:35,529 - root - INFO - Training size: 27937.
INFO:root:Initialize model
2025-09-20 18:16:35,531 - roo