In [1]:
import os
import easydict
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm

from torchmeta.utils.data import BatchMetaDataLoader
from maml.utils import load_dataset, load_model, update_parameters, get_accuracy, get_graph_regularizer

In [28]:
# Baseline configuration for meta-test evaluation; the next cell overrides the
# experiment-specific fields (embedding method, save_name, checkpoint criterion).
args = easydict.EasyDict(dict(
    # data / hardware
    folder='./dataset',
    dataset='miniimagenet',
    device='cuda:1',
    download=True,
    # episode shape: num_ways classes x num_shots examples in each support set
    num_shots=5,
    num_ways=5,
    # optimisation (step_size is the inner-loop step passed to update_parameters)
    meta_lr=1e-3,
    first_order=False,
    step_size=0.7,
    hidden_size=64,
    # where the training run wrote its logs and checkpoints
    output_folder='./output/',
    save_name=None,
    # batching
    batch_size=4,
    batch_iter=1200,
    train_batches=50,
    valid_batches=25,
    test_batches=2500,
    num_workers=1,
    # graph / task-embedding options
    graph_gamma=5.0,
    graph_beta=1e-5,
    graph_regularizer=False,
    fc_regularizer=False,
    task_embedding_method=None,
    edge_generation_method=None,
    # checkpoint selection: True -> lowest valid_error, False -> highest valid_accuracy.
    # NOTE(review): best_valid_accuracy_test is not read anywhere in this notebook.
    best_valid_error_test=False,
    best_valid_accuracy_test=False,
))

In [29]:
# Select the experiment to evaluate: GCN task embedding with max-normalised edge
# generation, and choose its checkpoint by lowest validation error.
args.task_embedding_method, args.edge_generation_method = 'gcn', 'max_normalization'
args.save_name = 'te_gcn_maxnorm_l1_output_normalization'
args.best_valid_error_test = True

In [30]:
# Rebuild the model, pick the best checkpoint from the training logs, and load the meta-test split.
model = load_model(args)

# Run directory derived from the config; equals './output/miniimagenet_<save_name>' for the
# current settings (assumes the training run used the same '<dataset>_<save_name>' layout).
run_dir = os.path.join(args.output_folder, '{}_{}'.format(args.dataset, args.save_name))
logs = pd.read_csv(os.path.join(run_dir, 'logs', 'logs.csv'))

# NOTE(review): rows with a zero validation value are treated as "not validated"; the k-th
# validated row is assumed to correspond to checkpoint epochs_{k * VALID_INTERVAL}.pt — confirm
# against the training script.
VALID_INTERVAL = 50
valid_column = 'valid_error' if args.best_valid_error_test else 'valid_accuracy'
valid_logs = logs.loc[logs[valid_column] != 0, valid_column].tolist()
if not valid_logs:
    raise ValueError('no non-zero {} entries in {}'.format(valid_column, run_dir))
# Lowest error or highest accuracy; list.index picks the earliest epoch on ties.
best_valid_value = min(valid_logs) if args.best_valid_error_test else max(valid_logs)
best_valid_epochs = (valid_logs.index(best_valid_value) + 1) * VALID_INTERVAL

# map_location lets a checkpoint saved on a different GPU load onto args.device.
best_valid_model = torch.load(
    os.path.join(run_dir, 'models', 'epochs_{}.pt'.format(best_valid_epochs)),
    map_location=args.device,
)

mode = 'meta_test'
dataset = load_dataset(args, mode)
dataloader = BatchMetaDataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

In [31]:
# Draw a fixed pool of meta-test tasks up front so the following cells evaluate the same episodes.
# NOTE(review): no seed is set here, so the pool differs between kernel restarts.
sample_number = 100


def _stack_split(split):
    """Stack every (image, label) pair of one task split into tensors.

    Returns (images, labels, real_labels): the images stacked along a new dim 0, the labels
    as yielded by the split, and `split.index[label]` for each label (presumably the class's
    index in the full dataset — confirm against torchmeta's Task). Raises if the split is empty.
    """
    images, labels, real_labels = [], [], []
    for image, label in split:
        images.append(image)
        labels.append(label)
        real_labels.append(split.index[label])
    # A single stack avoids the quadratic copying of torch.cat-per-sample.
    return (torch.stack(images, dim=0),
            torch.tensor(labels, dtype=torch.long),
            torch.tensor(real_labels, dtype=torch.long))


sample_tasks = []
for _ in range(sample_number):
    sample_task = dataset.sample_task()
    s_images, s_labels, s_real_labels = _stack_split(sample_task['train'])  # support set in meta_test
    q_images, q_labels, q_real_labels = _stack_split(sample_task['test'])   # query set in meta_test
    sample_tasks.append((s_images, s_labels, s_real_labels, q_images, q_labels, q_real_labels))
print (sample_tasks[0][0].shape, sample_tasks[0][1].shape, sample_tasks[0][2].shape)
print (sample_tasks[0][3].shape, sample_tasks[0][4].shape, sample_tasks[0][5].shape)

torch.Size([25, 3, 84, 84]) torch.Size([25]) torch.Size([25])
torch.Size([75, 3, 84, 84]) torch.Size([75]) torch.Size([75])


In [32]:
def _accuracy(logits, targets):
    """Fraction of rows whose highest-scoring logit index equals the target, as a Python float."""
    _, predictions = torch.max(logits, dim=1)
    # Tensor ops instead of builtin sum(), which would iterate the (GPU) tensor element by element.
    return (predictions == targets).float().mean().item()


# Accuracy per task: support before adaptation, support after one inner step, query after it.
before_support_accuracy_list = []
after_support_accuracy_list = []
after_query_accuracy_list = []

# Moved once: load_state_dict below copies into the existing parameters, keeping their device.
model.to(args.device)
for support_input, support_target, support_real_target, query_input, query_target, query_real_target in sample_tasks[:sample_number]:
    # Restore the selected checkpoint so every task adapts from the same initialisation.
    model.load_state_dict(best_valid_model)

    support_input = support_input.to(args.device)
    support_target = support_target.to(args.device)
    query_input = query_input.to(args.device)
    query_target = query_target.to(args.device)

    model.train()
    support_features, support_logit = model(support_input)
    before_support_accuracy_list.append(_accuracy(support_logit, support_target))

    # One inner-loop step on the support loss; the adapted weights are returned in `params`.
    inner_loss = F.cross_entropy(support_logit, support_target)
    model.zero_grad()
    params = update_parameters(model, inner_loss, step_size=args.step_size, first_order=args.first_order)

    # Author's note (translated): a proper task embedding only comes out once the weights have
    # been updated at least once via the inner loss (i.e. when params=params is passed) — but
    # even then, how to combine it is still an open question.
    support_features, support_logit = model(support_input, params=params)
    after_support_accuracy_list.append(_accuracy(support_logit, support_target))

    query_features, query_logit = model(query_input, params=params)
    after_query_accuracy_list.append(_accuracy(query_logit, query_target))

print (np.mean(before_support_accuracy_list), np.mean(after_support_accuracy_list), np.mean(after_query_accuracy_list))

0.19919999405741692 0.9775999885797501 0.6581333476305008


In [None]:
# Input for PCA: for now, the mean of a task's support-set features is used, arbitrarily,
# as that task's embedding (computed with the selected checkpoint, without adaptation).
embedded_tasks = []         # one mean-feature tensor per task, detached and on the CPU
embedded_task_classes = []  # the matching real class ids, e.g. for colouring a PCA plot

# Reload explicitly so this cell does not depend on state left by the evaluation cell.
model.load_state_dict(best_valid_model)
model.to(args.device)
model.train()  # same mode the evaluation cell uses for its forward passes
for support_images, _, support_real_labels, _, _, _ in sample_tasks:
    support_features, _ = model(support_images.to(args.device))
    embedded_tasks.append(torch.mean(support_features, dim=0).detach().cpu())
    embedded_task_classes.append(support_real_labels)

task_embeddings = torch.stack(embedded_tasks, dim=0)  # tasks along dim 0
print (task_embeddings.shape)