## Parameter Definition

In [14]:
# N-way, K-shot few-shot learning parameters
ways = 10             # N: number of classes in each task
shots = 5             # K: support examples per class (the same number per class is used for evaluation)

# Meta-learning parameters
meta_lr = 0.001       # Outer-loop (meta) learning rate
fast_lr = 0.1         # Inner-loop (adaptation) learning rate
adapt_steps = 5       # Number of inner-loop update steps per task
meta_batch_size = 32  # Number of tasks sampled per outer-loop iteration
iterations = 5        # Number of outer-loop iterations

# CUDA and random-seed settings
cuda = True           # Use a GPU when one is available
seed = 42

# Dataset parameters: each domain is a different working condition (motor speed, rpm).
# For the CWRU dataset: 0: 1797, 1: 1772, 2: 1750, 3: 1730
train_domain = 1
valid_domain = 2
test_domain = 3

# Path to the data directory
data_dir_path = './data'

# Evaluation / reporting
test_batch_size = 128  # Number of tasks sampled for meta-testing
display_step = 1       # Plot learning curves every `display_step` iterations

# Derived from ways/shots so log, checkpoint and figure names never go stale.
# (Misspelled name kept because later cells reference `expriment_name`.)
expriment_name = 'MAML_CWRU_{}w_{}s'.format(ways, shots)

In [15]:
from datasets.cwru import CWRU
from utils import (setlogger, fast_adapt)

import os
import torch
import logging
import random
import numpy as np
import learn2learn as l2l
import matplotlib.pyplot as plt

from torch import nn
from learn2learn.data.transforms import (
    FusedNWaysKShots,
    LoadData,
    RemapLabels,
    ConsecutiveLabels,
)

## Device Checking and Random Seed Setting

In [16]:
# Seed every random number generator used here (Python, NumPy, PyTorch)
# so runs are reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Attach the file logger BEFORE any logging.info call; otherwise the device
# message below is emitted before the handler exists and may be dropped.
# (Assumes utils.setlogger configures the logger used by `logging.info` — confirm.)
os.makedirs("./logs", exist_ok=True)
setlogger(os.path.join("./logs", expriment_name + '.log'))

# Select the training device, using the GPU if requested and available.
if cuda and torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    device_count = torch.cuda.device_count()
    device = torch.device('cuda')
    logging.info('Training MAML with {} GPU(s).'.format(device_count))
else:
    device = torch.device('cpu')
    logging.info('Training MAML with CPU.')

## Dataset, Meta-Dataset and Task Creation

In [17]:
# Load the raw CWRU data for each working condition (domain).
train_dataset = CWRU(train_domain, data_dir_path)
valid_dataset = CWRU(valid_domain, data_dir_path)
test_dataset = CWRU(test_domain, data_dir_path)

# Wrap each dataset in a MetaDataset so learn2learn can sample by label.
train_dataset = l2l.data.MetaDataset(train_dataset)
valid_dataset = l2l.data.MetaDataset(valid_dataset)
test_dataset = l2l.data.MetaDataset(test_dataset)


def _make_task_transforms(meta_dataset):
    """Build the N-way, 2K-shot task pipeline for `meta_dataset`.

    Each task draws `ways` classes with `2 * shots` samples per class.
    Presumably `fast_adapt` splits these into `shots` adaptation and
    `shots` evaluation samples per class; confirm in utils.
    Labels are remapped to task-local ids and samples are then grouped by
    class. The same order is used for every split so that train, valid and
    test tasks are built identically.
    """
    return [
        FusedNWaysKShots(meta_dataset, n=ways, k=2 * shots),
        LoadData(meta_dataset),
        RemapLabels(meta_dataset),
        ConsecutiveLabels(meta_dataset),
    ]


# Create Meta-Tasks (number of distinct tasks pre-sampled per split).
train_transforms = _make_task_transforms(train_dataset)
train_tasks = l2l.data.Taskset(
    train_dataset,
    task_transforms=train_transforms,
    num_tasks=400,
)

valid_transforms = _make_task_transforms(valid_dataset)
valid_tasks = l2l.data.Taskset(
    valid_dataset,
    task_transforms=valid_transforms,
    num_tasks=100,
)

test_transforms = _make_task_transforms(test_dataset)
test_tasks = l2l.data.Taskset(
    test_dataset,
    task_transforms=test_transforms,
    num_tasks=100,
)

## Training Model Creation

In [18]:
# Create the model: a CNN4 backbone with one output logit per class in a
# task, so the head always matches the N-way setting.
model = l2l.vision.models.CNN4(output_size=ways)
model.to(device)
# Wrap in MAML; first_order=False keeps second-order terms in the meta-gradient.
maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
# The outer-loop optimizer updates the wrapped model's parameters directly.
opt = torch.optim.Adam(model.parameters(), meta_lr)
loss = nn.CrossEntropyLoss(reduction='mean')

## Tracking Accuracy and Loss

In [19]:
# Per-iteration history of mean meta-train / meta-valid accuracy and error.
# The training loop appends to these lists; it plots them and uses the
# error histories to decide when to checkpoint.
train_acc_list, valid_acc_list = [], []
train_err_list, valid_err_list = [], []

# Meta-test history (not appended to in the cells shown here).
test_acc_list, test_err_list = [], []

In [20]:
# Meta-training loop. Each outer iteration:
#   1. samples `meta_batch_size` training tasks, adapts a fresh clone of the
#      meta-model on each (via `fast_adapt`) and back-propagates the
#      post-adaptation error, accumulating meta-gradients in `model`;
#   2. evaluates one meta-validation task per training task (no backward,
#      monitoring only);
#   3. checkpoints the model when both mean errors are at their best so far;
#   4. plots the learning curves every `display_step` iterations;
#   5. averages the accumulated gradients and takes one optimizer step.
for iteration in range(1, iterations + 1):
    opt.zero_grad()
    meta_train_err_sum = 0.0
    meta_train_acc_sum = 0.0
    meta_valid_err_sum = 0.0
    meta_valid_acc_sum = 0.0

    for task in range(meta_batch_size):
        # Meta-training: adapt a clone, then accumulate the meta-gradient.
        learner = maml.clone()
        batch = train_tasks.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adapt_steps,
                                                           shots,
                                                           ways,
                                                           device)
        evaluation_error.backward()
        meta_train_err_sum += evaluation_error.item()
        meta_train_acc_sum += evaluation_accuracy.item()

        # Meta-validation: adapt a fresh clone. There is no backward call,
        # so the accumulated meta-gradient is unaffected.
        learner = maml.clone()
        batch = valid_tasks.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adapt_steps,
                                                           shots,
                                                           ways,
                                                           device)
        meta_valid_err_sum += evaluation_error.item()
        meta_valid_acc_sum += evaluation_accuracy.item()

    # Mean metrics over the meta-batch.
    meta_train_acc = meta_train_acc_sum / meta_batch_size
    meta_train_err = meta_train_err_sum / meta_batch_size
    meta_valid_acc = meta_valid_acc_sum / meta_batch_size
    meta_valid_err = meta_valid_err_sum / meta_batch_size

    # Checkpoint when both errors match or beat their best so far. The saved
    # weights are the ones the metrics above were measured with, because the
    # optimizer step happens at the end of the iteration. `default=inf` makes
    # the first iteration always save; min() on the still-empty history
    # would raise ValueError.
    best_train_err = min(train_err_list, default=float('inf'))
    best_valid_err = min(valid_err_list, default=float('inf'))
    if meta_train_err <= best_train_err and meta_valid_err <= best_valid_err:
        os.makedirs("./models", exist_ok=True)
        torch.save(model.state_dict(), './models/' + expriment_name + '_best.pth')

    train_acc_list.append(meta_train_acc)
    train_err_list.append(meta_train_err)
    valid_acc_list.append(meta_valid_acc)
    valid_err_list.append(meta_valid_err)

    # Log the metrics for this iteration.
    logging.info('Iteration {}:'.format(iteration))
    logging.info('Meta Train Error: {}.'.format(meta_train_err))
    logging.info('Meta Train Accuracy: {}.'.format(meta_train_acc))
    logging.info('Meta Valid Error: {}.'.format(meta_valid_err))
    logging.info('Meta Valid Accuracy: {}.\n'.format(meta_valid_acc))

    # ========================= plot ==========================
    if iteration % display_step == 0:
        # Plot against 1-based iteration numbers so points line up with the ticks.
        completed = np.arange(1, len(train_acc_list) + 1)
        x_ticks = np.arange(1, iterations + 1)
        fig, (ax_acc, ax_err) = plt.subplots(1, 2, figsize=(12, 4))

        ax_acc.plot(completed, train_acc_list, '-o', label="train acc")
        ax_acc.plot(completed, valid_acc_list, '-o', label="valid acc")
        ax_acc.set_xticks(x_ticks)
        ax_acc.set_xlabel('Training iteration')
        ax_acc.set_ylabel('Accuracy')
        ax_acc.set_title("Accuracy Curve by Iteration")
        ax_acc.legend()

        ax_err.plot(completed, train_err_list, '-o', label="train loss")
        ax_err.plot(completed, valid_err_list, '-o', label="valid loss")
        ax_err.set_xticks(x_ticks)
        ax_err.set_xlabel('Training iteration')
        ax_err.set_ylabel('Loss')
        ax_err.set_title("Loss Curve by Iteration")
        ax_err.legend()

        fig.suptitle("CWRU Bearing Fault Diagnosis {}way-{}shot".format(ways, shots))
        # savefig does not create missing directories.
        os.makedirs('./results', exist_ok=True)
        fig.savefig('./results/' + expriment_name + '_{}.png'.format(iteration))
        plt.show()
        # Release the figure; one is created every display_step iterations.
        plt.close(fig)

    # Average the accumulated gradients over the meta-batch and optimize.
    # Parameters that received no gradient (grad is None) are skipped.
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is not None:
                p.grad.mul_(1.0 / meta_batch_size)
    opt.step()

KeyboardInterrupt: 

In [None]:
# Meta-testing: reload the best checkpoint, adapt a fresh clone on each of
# `test_batch_size` sampled test tasks, and report the mean post-adaptation
# error and accuracy.
meta_test_error = 0.0
meta_test_accuracy = 0.0
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load('./models/' + expriment_name + '_best.pth',
                                 map_location=device))
for task in range(test_batch_size):
    learner = maml.clone()
    batch = test_tasks.sample()
    evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                       learner,
                                                       loss,
                                                       adapt_steps,
                                                       shots,
                                                       ways,
                                                       device)
    meta_test_error += evaluation_error.item()
    meta_test_accuracy += evaluation_accuracy.item()
# Average over the number of test tasks evaluated (test_batch_size), not
# the training meta-batch size.
logging.info('Meta Test Error: {}.'.format(meta_test_error / test_batch_size))
logging.info('Meta Test Accuracy: {}.\n'.format(meta_test_accuracy / test_batch_size))