In [None]:
import os
import easydict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tqdm import tqdm
from maml.utils import load_dataset, load_model, update_parameters

%load_ext autoreload
%autoreload 2

# Training log graphs

In [None]:
def plot_logs(log_type, logs, ax):
    """Plot train/valid curves and the test point(s) of one metric for a run.

    Parameters
    ----------
    log_type : str
        Either ``'error'`` or ``'accuracy'``. Selects the ``train_*``,
        ``valid_*`` and ``test_*`` columns of ``logs`` and the y-axis range.
    logs : pandas.DataFrame or mapping of column name -> sequence
        Training log. Zero entries mean "not logged at this row" and are
        skipped when plotting.
    ax : matplotlib.axes.Axes
        Axes to draw on.

    Raises
    ------
    ValueError
        If ``log_type`` is not ``'error'`` or ``'accuracy'``.
    """
    # Fixed y-axis range per metric so graphs of different runs are comparable.
    y_limits = {'error': (0.0, 3.0), 'accuracy': (0.0, 1.0)}
    if log_type not in y_limits:
        raise ValueError("log_type must be 'error' or 'accuracy', got {!r}".format(log_type))
    ymin, ymax = y_limits[log_type]

    train_log = np.array(logs['train_' + log_type])
    valid_log = np.array(logs['valid_' + log_type])
    test_log = np.array(logs['test_' + log_type])

    # The last test entry is reported in the title (e.g. 'accuracy graph, test accuracy: 0.65').
    ax.set_title('{} graph, test {}: {}'.format(log_type, log_type, round(test_log[-1], 4)))

    for log, label in ((train_log, 'train'), (valid_log, 'valid')):
        steps = np.flatnonzero(log)
        ax.plot(steps, log[steps], label=label)
    test_steps = np.flatnonzero(test_log)
    ax.scatter(test_steps, test_log[test_steps], s=100, color='red', label='test')

    ax.set_ylim([ymin, ymax])
    ax.legend()

def _run_label(filename):
    """Legend label for a ``<output>/<dataset>_<variant>/logs/logs.csv`` path (the ``<variant>`` part)."""
    run_dir = os.path.basename(os.path.dirname(os.path.dirname(filename)))
    # Drop the leading "<dataset>_" token, e.g. 'miniimagenet_both_inner' -> 'both_inner'.
    return run_dir.partition('_')[2]


def plot_diff_logs(filename_list, mode='valid'):
    """Overlay one split's error and accuracy curves from several runs.

    Parameters
    ----------
    filename_list : list of str
        Paths to ``logs.csv`` files laid out as ``<output>/<run_dir>/logs/logs.csv``.
        Each curve is labelled with ``run_dir`` minus its first ``_``-separated
        token. The run directory is taken from the path structure, not from a
        fixed split index, so relative and absolute paths both work.
    mode : str, default 'valid'
        Split prefix. The ``{mode}_error`` and ``{mode}_accuracy`` columns are
        plotted, and zero entries are treated as "not logged" and skipped.
    """
    fig, (error_ax, accuracy_ax) = plt.subplots(1, 2, sharey=False, figsize=(16, 8))
    # (axes, column, y-axis range) for each panel.
    panels = ((error_ax, mode + '_error', (0.0, 3.0)),
              (accuracy_ax, mode + '_accuracy', (0.0, 1.0)))

    for filename in filename_list:
        file_logs = pd.read_csv(filename)
        label = _run_label(filename)
        for ax, column, _ in panels:
            values = np.array(file_logs[column])
            steps = np.flatnonzero(values)
            ax.plot(steps, values[steps], label=label)

    for ax, column, y_limits in panels:
        ax.set_title(column.replace('_', ' '))
        ax.set_ylim(list(y_limits))
        ax.legend()

    plt.show()
    plt.close(fig)

In [None]:
# Every training run writes its own directory under the output folder; keep
# only the miniImageNet runs, in alphabetical order.
path = './output/'
file_list = sorted(run_dir for run_dir in os.listdir(path) if 'miniimagenet' in run_dir)
print(file_list)

In [None]:
# Plot the error and accuracy curves of a single run (the 'both_inner' variant).
target_run = 'both_inner'
matching_runs = [f for f in file_list if target_run in f]
if not matching_runs:
    raise FileNotFoundError('no run directory containing {!r} found in {}'.format(target_run, path))
run_dir = matching_runs[0]  # file_list is sorted, so this is the alphabetically first match
filename = '{}{}/logs/logs.csv'.format(path, run_dir)
logs = pd.read_csv(filename)

fig, axes = plt.subplots(1, 2, sharey=False, figsize=(16, 8))

plot_logs(log_type='error', logs=logs, ax=axes[0])
plot_logs(log_type='accuracy', logs=logs, ax=axes[1])

fig.suptitle(run_dir)
plt.show()
plt.close(fig)

In [None]:
# Compare the validation curves of the three inner-loop variants side by side.
variants = ['both_inner', 'classifier_inner', 'extractor_inner']
filename_list = []
for variant in variants:
    matching_runs = [f for f in file_list if variant in f]
    if not matching_runs:
        raise FileNotFoundError('no run directory containing {!r} found in {}'.format(variant, path))
    # file_list is sorted, so the alphabetically first matching run is used.
    filename_list.append('{}{}/logs/logs.csv'.format(path, matching_runs[0]))

plot_diff_logs(filename_list, mode='valid')

# Overfitting tests
### Overfitting 1: meta-train tasks vs. meta-test tasks
### Overfitting 2: support set vs. query set on meta-test tasks

In [None]:
def _collect_split(split):
    """Turn one split of a sampled task into ``(images, labels, real_labels)`` tensors.

    ``split`` must yield ``(image, label)`` pairs and expose ``split.index``.
    ``split.index[label]`` is recorded as the "real" label.

    Raises
    ------
    ValueError
        If the split yields no samples.
    """
    images, labels, real_labels = [], [], []
    for image, label in split:
        images.append(image)
        labels.append(label)
        real_labels.append(split.index[label])
    if not images:
        raise ValueError('sampled task split is empty')
    # One torch.stack replaces repeated torch.cat(..., image.unsqueeze(0)),
    # which re-copied the whole batch on every sample.
    return (torch.stack(images, dim=0),
            torch.tensor(labels).type(torch.LongTensor),
            torch.tensor(real_labels).type(torch.LongTensor))


def make_sample_task(dataset):
    """Sample one task from ``dataset`` and return it as stacked tensors.

    Parameters
    ----------
    dataset : object
        Anything with a ``sample_task()`` method that returns a mapping with
        ``'train'`` (support) and ``'test'`` (query) splits. Each split yields
        ``(image, label)`` pairs and has an ``index`` lookup keyed by label.

    Returns
    -------
    list
        ``[s_images, s_labels, s_real_labels, q_images, q_labels, q_real_labels]``.
        Images are stacked along a new leading dimension. Label tensors are
        CPU ``LongTensor``s, and ``*_real_labels`` hold ``split.index[label]``.

    Raises
    ------
    ValueError
        If either split is empty.
    """
    sample_task = dataset.sample_task()
    s_images, s_labels, s_real_labels = _collect_split(sample_task['train'])
    q_images, q_labels, q_real_labels = _collect_split(sample_task['test'])
    return [s_images, s_labels, s_real_labels, q_images, q_labels, q_real_labels]

In [None]:
# Pre-sample a fixed pool of tasks from each meta-split. The evaluation cell
# below feeds these same task lists to print_accuracy.
sample_number = 1000
dataset_args = easydict.EasyDict({'folder': './dataset',
                                  'dataset': 'miniimagenet',
                                  'num_ways': 5,
                                  'num_shots': 5,
                                  'download': False})
# Build each meta-split dataset once and sample all tasks from it, instead of
# rebuilding the dataset for every sampled task.
# NOTE(review): assumes load_dataset has no per-call effect that sample_task
# relies on (e.g. reseeding); confirm. No RNG seed is set, so the sampled
# tasks differ between kernel restarts.
meta_train_dataset = load_dataset(dataset_args, 'meta_train')
meta_test_dataset = load_dataset(dataset_args, 'meta_test')
train_tasks = [make_sample_task(meta_train_dataset) for _ in tqdm(range(sample_number), desc='meta_train tasks')]
test_tasks = [make_sample_task(meta_test_dataset) for _ in tqdm(range(sample_number), desc='meta_test tasks')]

In [None]:
def isfloat(value):
    """Report whether ``float(value)`` succeeds.

    Returns False only when the conversion raises ``ValueError``. Any other
    exception (for example ``TypeError`` for ``None``) propagates unchanged.
    """
    try:
        float(value)
    except ValueError:
        return False
    return True

def _parse_argument_value(val):
    """Convert one serialized argument value to int, float or bool when it looks like one."""
    # int first so that integer literals, including negative ones like '-1',
    # stay ints instead of becoming floats.
    for cast in (int, float):
        try:
            return cast(val)
        except ValueError:
            pass
    if val in ('True', 'False'):
        return val == 'True'
    # NOTE(review): 'None' stays the string 'None'; confirm no caller expects None.
    return val


def get_arguments(dataset, save_name):
    """Rebuild a run's training arguments from its saved ``arguments.txt``.

    Reads ``./output/{dataset}_{save_name}/logs/arguments.txt``. Each non-blank
    line must look like ``key: value`` and is split on the first ``': '`` only,
    so values may themselves contain ``': '``. Values are converted as follows:
    - integer literals become ``int``;
    - other float-parsable strings become ``float``;
    - ``'True'``/``'False'`` become ``bool``;
    - anything else stays a string.

    Parameters
    ----------
    dataset : str
        Dataset prefix of the run directory (e.g. ``'miniimagenet'``).
    save_name : str
        Run name suffix of the run directory (e.g. ``'both_inner'``).

    Returns
    -------
    easydict.EasyDict
        Attribute-accessible mapping of argument name to parsed value.

    Raises
    ------
    ValueError
        If a non-blank line has no ``': '`` separator.
    """
    filename = './output/' + '{}_{}/logs/arguments.txt'.format(dataset, save_name)

    args = easydict.EasyDict()
    with open(filename) as f:
        for raw_line in f:
            line = raw_line.rstrip('\n')
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            key, sep, val = line.partition(': ')
            if not sep:
                raise ValueError('malformed line in {}: {!r}'.format(filename, raw_line))
            args[key] = _parse_argument_value(val)
    return args

def _accuracy(logits, target):
    """Fraction of rows whose arg-max over dim 1 of ``logits`` equals ``target``, as a Python float."""
    _, pred = torch.max(logits, dim=1)
    return (pred == target).float().mean().item()


def _best_checkpoint_epoch(logs, criterion, valid_interval):
    """Epoch number of the checkpoint with the best validation score in ``logs``.

    Rows whose validation value is 0 are treated as "not evaluated". The k-th
    (0-based) evaluated row maps to epoch ``(k + 1) * valid_interval``. Ties
    resolve to the earliest evaluation.
    """
    if criterion == 'error':
        column, pick = 'valid_error', min
    elif criterion == 'accuracy':
        column, pick = 'valid_accuracy', max
    else:
        raise ValueError("criterion must be 'error' or 'accuracy', got {!r}".format(criterion))
    valid_logs = list(logs[logs[column] != 0][column])
    if not valid_logs:
        raise ValueError('no non-zero entries in column {!r}'.format(column))
    return (valid_logs.index(pick(valid_logs)) + 1) * valid_interval


def print_accuracy(args, sample_tasks, meta_mode, inner_update_number, criterion=None, valid_interval=50):
    """Evaluate a run's best checkpoint on pre-sampled tasks and save per-task accuracies.

    Nothing is printed despite the name. The results table is written to
    ``./output/{args.dataset}_{args.save_name}/logs/{meta_mode}_results.csv``
    and also returned.

    For every task the model is reset to the selected checkpoint. It then takes
    ``inner_update_number`` inner-loop updates on the support set via
    ``update_parameters``. Support accuracy is recorded before adaptation, and
    support and query accuracy after each update. Two extra rows, ``'mean'``
    and ``'std'``, summarize every column.

    Parameters
    ----------
    args : EasyDict
        Run arguments, e.g. from ``get_arguments``. Reads ``device``, ``dataset``,
        ``save_name``, ``output_folder``, ``extractor_step_size``,
        ``classifier_step_size`` and ``first_order``, and is passed to ``load_model``.
    sample_tasks : list
        Tasks in the ``make_sample_task`` format:
        ``[s_images, s_labels, s_real_labels, q_images, q_labels, q_real_labels]``.
    meta_mode : str
        Tag used in the output CSV name (e.g. ``'meta_train'``).
    inner_update_number : int
        Number of inner-loop adaptation steps per task.
    criterion : {'error', 'accuracy'}
        Checkpoint selection rule: lowest ``valid_error`` or highest
        ``valid_accuracy`` in the run's ``logs.csv``.
    valid_interval : int, default 50
        Epoch spacing between validation entries and saved checkpoints. The
        k-th non-zero validation entry is mapped to
        ``models/epochs_{(k + 1) * valid_interval}.pt``.

    Returns
    -------
    pandas.DataFrame
        One row per task (``'task1'``, ...) plus ``'mean'`` and ``'std'`` rows.

    Raises
    ------
    ValueError
        If ``criterion`` is not ``'error'``/``'accuracy'`` or no validation
        entries were logged.
    """
    device = torch.device(args.device)
    sample_number = len(sample_tasks)
    # Use args.dataset here too so these paths match the checkpoint path below.
    logs_dir = './output/{}_{}/logs/'.format(args.dataset, args.save_name)

    logs = pd.read_csv(logs_dir + 'logs.csv')
    best_valid_epochs = _best_checkpoint_epoch(logs, criterion, valid_interval)

    index = ['task{}'.format(i + 1) for i in range(sample_number)]
    columns = ['Accuracy on support set (before adaptation)']
    for i in range(inner_update_number):
        columns.append('Accuracy on support set (after {} adaptation)'.format(i + 1))
        columns.append('Accuracy on query set (after {} adaptation)'.format(i + 1))
    test_pd = pd.DataFrame(np.zeros([sample_number, len(columns)]), index=index, columns=columns)

    model = load_model(args)
    checkpoint_path = (args.output_folder + '{}_{}/'.format(args.dataset, args.save_name)
                       + 'models/epochs_{}.pt'.format(best_valid_epochs))
    checkpoint = torch.load(checkpoint_path, map_location=device)

    for idx in tqdm(range(sample_number)):
        # Reset to the selected checkpoint so every task adapts from the same weights.
        model.load_state_dict(checkpoint, strict=True)
        model.to(device)

        task = sample_tasks[idx]
        support_input = task[0].to(device)
        support_target = task[1].to(device)
        query_input = task[3].to(device)
        query_target = task[4].to(device)

        # NOTE(review): evaluation runs in train() mode, so layers such as
        # BatchNorm (if the model has any) use per-batch statistics. Confirm
        # this is intended.
        model.train()
        _, support_logit = model(support_input)
        task_log = [_accuracy(support_logit, support_target)]

        for _ in range(inner_update_number):
            inner_loss = F.cross_entropy(support_logit, support_target)
            model.zero_grad()

            params = update_parameters(model, inner_loss, extractor_step_size=args.extractor_step_size,
                                       classifier_step_size=args.classifier_step_size,
                                       first_order=args.first_order)
            model.load_state_dict(params, strict=True)

            # Support logits feed the next inner loss, so they keep their graph.
            # Query logits are only scored, so no graph is needed.
            _, support_logit = model(support_input)
            with torch.no_grad():
                _, query_logit = model(query_input)

            task_log.append(_accuracy(support_logit, support_target))
            task_log.append(_accuracy(query_logit, query_target))

        test_pd.iloc[idx] = task_log

    # Append summary rows; both are computed from the per-task rows only.
    test_pd = pd.concat([test_pd, test_pd.agg(['mean', 'std'])])
    test_pd.to_csv(logs_dir + '{}_results.csv'.format(meta_mode))
    return test_pd

In [None]:
# Score the selected 'both_inner' checkpoint on the pre-sampled meta-train and
# meta-test tasks. print_accuracy writes one results CSV per meta split.
dataset = dataset_args.dataset
save_name = 'both_inner'
inner_update_number = 1
# Checkpoint selection: 'error' = lowest valid error, 'accuracy' = highest valid accuracy.
criterion = 'error'
args = get_arguments(dataset, save_name)

for meta_mode, tasks in (('meta_train', train_tasks), ('meta_test', test_tasks)):
    print_accuracy(args, tasks, meta_mode=meta_mode, inner_update_number=inner_update_number, criterion=criterion)

---