In [None]:
import os
import easydict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tqdm import tqdm
from maml.utils import load_dataset, load_model, update_parameters

%load_ext autoreload
%autoreload 2

# log graph

In [None]:
def plot_logs(dataset, model, shot1_filename, shot5_filename, mode='valid'):
    """Plot error and accuracy curves read from 1-shot and 5-shot CSV logs.

    The left panel shows the ``<mode>_error`` column and the right panel the
    ``<mode>_accuracy`` column of every log file. Rows whose logged value is
    exactly zero are skipped. 1-shot runs are drawn dashed, 5-shot runs use
    the default line style.

    Args:
        dataset: Dataset name; only ``'fc100'`` changes the accuracy y-limits.
        model: Unused; kept so existing calls keep working.
        shot1_filename: Iterable of CSV log paths for the 1-shot runs.
        shot5_filename: Iterable of CSV log paths for the 5-shot runs.
        mode: Column prefix, e.g. ``'valid'`` selects ``valid_error`` and
            ``valid_accuracy``.
    """
    fig, axes = plt.subplots(1, 2, sharey=False, figsize=(16, 8))
    error = mode + '_error'
    accuracy = mode + '_accuracy'
    colors = ['green', 'blue', 'red']

    # (files, extra plot keyword arguments) for each shot setting.
    groups = ((shot1_filename, {'linestyle': '--'}), (shot5_filename, {}))
    for filenames, line_style in groups:
        for idx, filename in enumerate(filenames):
            file_logs = pd.read_csv(filename)
            # Label each curve with its run directory, which sits two levels
            # above the CSV (<run>/logs/logs.csv), so the label no longer
            # depends on how deep the results folder is.
            label = os.path.basename(os.path.dirname(os.path.dirname(filename)))
            # Cycle through the palette so more than three runs do not fail.
            color = colors[idx % len(colors)]
            for ax, column in ((axes[0], error), (axes[1], accuracy)):
                values = np.array(file_logs[column])
                logged = values.nonzero()[0]
                ax.plot(logged, values[logged], label=label, color=color, **line_style)

    axes[0].set_title(error)
    axes[0].set_ylim([0.0, 3.0])
    axes[0].legend()

    axes[1].set_title(accuracy)
    axes[1].set_ylim([0.2, 0.55] if dataset == 'fc100' else [0.2, 0.8])
    axes[1].legend()

    plt.show()
    plt.close(fig)

In [None]:
# Experiment whose logs are plotted.
dataset = 'miniimagenet'  # miniimagenet, tieredimagenet, cifar_fs, fc100
model = 'smallconv'  # smallconv, largeconv, resnet

# Keep only the run directories whose name mentions both the dataset and the model.
shot1_path = '/home/osilab7/hdd/jhoon_maml_backup/exp1/1shot_results'
shot1_file_list = sorted(
    run_dir for run_dir in os.listdir(shot1_path)
    if dataset in run_dir and model in run_dir
)

shot5_path = '/home/osilab7/hdd/jhoon_maml_backup/exp1/5shot_results'
shot5_file_list = sorted(
    run_dir for run_dir in os.listdir(shot5_path)
    if dataset in run_dir and model in run_dir
)

# Path of the CSV log expected inside each selected run directory.
shot1_filename = [
    '{}/{}/logs/logs.csv'.format(shot1_path, run_dir) for run_dir in shot1_file_list
]
shot5_filename = [
    '{}/{}/logs/logs.csv'.format(shot5_path, run_dir) for run_dir in shot5_file_list
]

plot_logs(dataset, model, shot1_filename, shot5_filename, mode='valid')

# 1000 episodes test

In [None]:
def _collect_split(split, split_name):
    """Stack one split of a sampled task into image and label tensors.

    ``split`` is iterated for ``(image, label)`` pairs and ``split.index`` is
    looked up with each label to obtain the "real" label.
    """
    images, labels, real_labels = [], [], []
    for image, label in split:
        images.append(image)
        labels.append(label)
        real_labels.append(split.index[label])

    if not images:
        raise ValueError("sampled task has an empty '{}' split".format(split_name))

    # One stack at the end replaces the repeated torch.cat, which re-copied
    # the whole batch on every iteration.
    return [torch.stack(images, dim=0),
            torch.tensor(labels).type(torch.LongTensor),
            torch.tensor(real_labels).type(torch.LongTensor)]


def make_sample_task(dataset):
    """Sample one task from ``dataset`` and return it as batched tensors.

    Args:
        dataset: Object whose ``sample_task()`` returns a mapping with
            ``'train'`` (support) and ``'test'`` (query) splits. Each split
            yields ``(image, label)`` pairs and exposes ``index[label]``.

    Returns:
        ``[s_images, s_labels, s_real_labels, q_images, q_labels,
        q_real_labels]``; images are stacked along a new first dimension and
        all label tensors are ``torch.LongTensor``.

    Raises:
        ValueError: If either split yields no samples.
    """
    sample_task = dataset.sample_task()
    support = _collect_split(sample_task['train'], 'train')
    query = _collect_split(sample_task['test'], 'test')
    return support + query

def isfloat(value):
    """Return True if ``float(value)`` succeeds, False otherwise.

    Inputs that ``float`` rejects with ``TypeError`` (e.g. ``None``) or
    ``OverflowError`` (an integer too large for a float) are reported as
    not-a-float instead of propagating the exception.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        return False
    return True

def _parse_argument_value(text):
    """Convert the textual value of one ``key: value`` line to a Python object.

    ``'True'``/``'False'`` become booleans, integer literals (including
    negative ones) become ``int``, other numeric strings become ``float`` and
    anything else is returned unchanged as a string.
    """
    if text in ('True', 'False'):
        return text == 'True'
    try:
        # Try int first: the previous isdigit() test sent "-1" down the float path.
        return int(text)
    except ValueError:
        pass
    if isfloat(text):
        return float(text)
    return text


def get_arguments(path, dataset, save_name):
    """Load the saved arguments of one run into an ``EasyDict``.

    Reads ``<path>/<dataset>_<save_name>/logs/arguments.txt``, which is
    expected to hold one ``key: value`` pair per line.

    Args:
        path: Folder that contains the run directories.
        dataset: Dataset name used as the run-directory prefix.
        save_name: Run name used as the run-directory suffix.

    Returns:
        easydict.EasyDict mapping each key to its parsed value.

    Raises:
        ValueError: If a non-blank line does not contain the ``": "`` separator.
    """
    filename = '{}/{}_{}/logs/arguments.txt'.format(path, dataset, save_name)

    args = easydict.EasyDict()
    with open(filename) as f:
        for line in f:
            line = line.rstrip('\r\n')
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no argument.
                continue
            # Split on the first separator only so values may contain ": ".
            key, separator, val = line.partition(': ')
            if not separator:
                raise ValueError(
                    "malformed line in {}: expected 'key: value', got {!r}".format(filename, line))
            args[key] = _parse_argument_value(val)
    return args

def _batch_accuracy(logit, target):
    """Return the fraction of rows of ``logit`` whose arg-max equals ``target``."""
    _, predicted = torch.max(logit, dim=1)
    return (target == predicted).float().mean().item()


def print_accuracy(args, sample_tasks, meta_mode, inner_update_number, criterion=None):
    """Evaluate the best-validation checkpoint on sampled tasks and save a CSV.

    The checkpoint is selected from ``logs.csv`` by the lowest non-zero
    ``valid_error`` or the highest non-zero ``valid_accuracy``. For every
    task the model is reset to that checkpoint, adapted on the support set
    for ``inner_update_number`` steps, and the support/query accuracy after
    each step is recorded. The table, with extra ``mean`` and ``std`` rows,
    is written to ``<output_folder>/<dataset>_<save_name>/logs/<meta_mode>_results.csv``.

    Args:
        args: Run arguments; reads ``device``, ``output_folder``, ``dataset``,
            ``save_name``, ``train_batches``, ``extractor_step_size``,
            ``classifier_step_size`` and ``first_order``.
        sample_tasks: Sequence of tasks as produced by ``make_sample_task``.
        meta_mode: Tag used in the output file name (e.g. ``'meta_test'``).
        inner_update_number: Number of adaptation steps per task.
        criterion: ``'error'`` or ``'accuracy'``; which validation metric
            selects the checkpoint.

    Raises:
        ValueError: If ``criterion`` is not ``'error'`` or ``'accuracy'``.
    """
    # Validate first: the default (None) used to fail only after the model and
    # logs were loaded, with an unbound-variable error.
    if criterion not in ('error', 'accuracy'):
        raise ValueError("criterion must be 'error' or 'accuracy', got {!r}".format(criterion))

    device = torch.device(args.device)
    sample_number = len(sample_tasks)

    index = ['task{}'.format(i + 1) for i in range(sample_number)]
    columns = ['Accuracy on support set (before adaptation)']
    for step in range(1, inner_update_number + 1):
        columns.append('Accuracy on support set (after {} adaptation)'.format(step))
        columns.append('Accuracy on query set (after {} adaptation)'.format(step))
    run_dir = '{}/{}_{}'.format(args.output_folder, args.dataset, args.save_name)
    filename_pd = '{}/logs/{}_results.csv'.format(run_dir, meta_mode)
    test_pd = pd.DataFrame(np.zeros([sample_number, len(columns)]), index=index, columns=columns)

    model = load_model(args)

    logs = pd.read_csv('{}/logs/logs.csv'.format(run_dir))

    # Rows with a zero entry are ignored; the position of the best remaining
    # entry is converted to a checkpoint number via args.train_batches.
    valid_column = 'valid_' + criterion
    valid_logs = list(logs[logs[valid_column] != 0][valid_column])
    best_value = min(valid_logs) if criterion == 'error' else max(valid_logs)
    best_valid_epochs = (valid_logs.index(best_value) + 1) * args.train_batches

    checkpoint_path = '{}/models/epochs_{}.pt'.format(run_dir, best_valid_epochs)
    checkpoint = torch.load(checkpoint_path, map_location=device)

    cos = torch.nn.CosineSimilarity(dim=1)
    for idx in tqdm(range(sample_number)):
        # Every task starts from the same meta-learned weights.
        model.load_state_dict(checkpoint, strict=True)

        # Diagnostic: mean pairwise cosine similarity between the classifier's
        # weight rows. The row count is read from the weights instead of
        # being fixed to 5.
        weight = model.classifier.weight.data
        num_classes = len(weight)
        distances = torch.zeros([num_classes, num_classes])
        for i in range(num_classes):
            distances[i] = cos(weight[i].unsqueeze(0).expand_as(weight), weight)
            distances[i, i] = 0.
        print(torch.sum(distances) / float(num_classes * num_classes - num_classes))

        model.to(device)

        support_input = sample_tasks[idx][0].to(device)
        support_target = sample_tasks[idx][1].to(device)
        query_input = sample_tasks[idx][3].to(device)
        query_target = sample_tasks[idx][4].to(device)

        model.train()
        # The model is unpacked as a (features, logits) pair; only logits are used.
        _, support_logit = model(support_input)
        task_log = [_batch_accuracy(support_logit, support_target)]

        for _ in range(inner_update_number):
            inner_loss = F.cross_entropy(support_logit, support_target)
            model.zero_grad()

            params = update_parameters(model, inner_loss,
                                       extractor_step_size=args.extractor_step_size,
                                       classifier_step_size=args.classifier_step_size,
                                       first_order=args.first_order)
            model.load_state_dict(params, strict=False)

            _, support_logit = model(support_input)
            _, query_logit = model(query_input)

            task_log.append(_batch_accuracy(support_logit, support_target))
            task_log.append(_batch_accuracy(query_logit, query_target))

        test_pd.iloc[idx] = task_log

    # Compute both summaries before appending so 'std' does not see the 'mean' row.
    mean_row, std_row = test_pd.mean(axis=0), test_pd.std(axis=0)
    test_pd.loc['mean'] = mean_row
    test_pd.loc['std'] = std_row
    test_pd.to_csv(filename_pd)
    
def _mean_offdiagonal(matrix):
    """Return the sum of a square matrix divided by its off-diagonal entry count."""
    size = len(matrix)
    return torch.sum(matrix) / (size * size - size)


def _mean_pairwise_distance(points):
    """Return the mean ``torch.cdist`` distance between distinct rows of ``points``."""
    return _mean_offdiagonal(torch.cdist(points, points))


def _mean_classifier_cosine(weight):
    """Return the mean cosine similarity between distinct rows of ``weight``."""
    num_classes = len(weight)
    cos = torch.nn.CosineSimilarity(dim=1)
    similarities = torch.zeros([num_classes, num_classes])
    for i in range(num_classes):
        similarities[i] = cos(weight[i].unsqueeze(0).expand_as(weight), weight)
        # Self-similarity is excluded from the average.
        similarities[i, i] = 0.
    return _mean_offdiagonal(similarities)


def get_similarity(args, sample_tasks, meta_mode, inner_update_number, checkpoint_epoch):
    """Measure how adaptation changes classifier weights, features and logits.

    For every task the model is reset to the ``checkpoint_epoch`` checkpoint,
    metrics are taken, the model is adapted on the support set for
    ``inner_update_number`` steps, and the metrics are taken again. The "fc"
    metric is the mean cosine similarity between classifier weight rows; the
    feature/logit metrics are mean pairwise ``torch.cdist`` distances.

    The previous version returned from inside the task loop, so only the
    first task was evaluated. The metrics are now averaged over all tasks;
    with a single task the result is that task's values.

    Args:
        args: Run arguments; reads ``device``, ``output_folder``, ``dataset``,
            ``save_name``, ``extractor_step_size``, ``classifier_step_size``
            and ``first_order``.
        sample_tasks: Sequence of tasks as produced by ``make_sample_task``.
        meta_mode: Unused; kept so existing calls keep working.
        inner_update_number: Number of adaptation steps per task.
        checkpoint_epoch: Number in the checkpoint file name ``epochs_<n>.pt``.

    Returns:
        Tuple of ten floats: ``(before_fc, after_fc, before_support_features,
        after_support_features, before_support_logits, after_support_logits,
        before_query_features, after_query_features, before_query_logits,
        after_query_logits)``.

    Raises:
        ValueError: If ``sample_tasks`` is empty.
    """
    sample_number = len(sample_tasks)
    if sample_number == 0:
        raise ValueError('sample_tasks must contain at least one task')

    device = torch.device(args.device)
    model = load_model(args)

    checkpoint_path = '{}/{}_{}/models/epochs_{}.pt'.format(
        args.output_folder, args.dataset, args.save_name, checkpoint_epoch)
    checkpoint = torch.load(checkpoint_path, map_location=device)

    per_task = []
    for idx in range(sample_number):
        # Every task starts from the same checkpoint weights.
        model.load_state_dict(checkpoint, strict=True)
        before_fc = _mean_classifier_cosine(model.classifier.weight.data).item()

        model.to(device)

        support_input = sample_tasks[idx][0].to(device)
        support_target = sample_tasks[idx][1].to(device)
        query_input = sample_tasks[idx][3].to(device)

        model.train()
        # The model is unpacked as a (features, logits) pair; torch.cdist
        # needs both to be at least 2-D -- TODO confirm against the model.
        support_features, support_logit = model(support_input)
        query_features, query_logit = model(query_input)

        before = [_mean_pairwise_distance(t).item()
                  for t in (support_features, support_logit, query_features, query_logit)]

        for _ in range(inner_update_number):
            inner_loss = F.cross_entropy(support_logit, support_target)
            model.zero_grad()

            params = update_parameters(model, inner_loss,
                                       extractor_step_size=args.extractor_step_size,
                                       classifier_step_size=args.classifier_step_size,
                                       first_order=args.first_order)
            model.load_state_dict(params, strict=False)

            support_features, support_logit = model(support_input)
            query_features, query_logit = model(query_input)

        after_fc = _mean_classifier_cosine(model.classifier.weight.data).item()
        after = [_mean_pairwise_distance(t).item()
                 for t in (support_features, support_logit, query_features, query_logit)]

        per_task.append((before_fc, after_fc,
                         before[0], after[0],
                         before[1], after[1],
                         before[2], after[2],
                         before[3], after[3]))

    # Average each metric over the tasks.
    return tuple(sum(values) / sample_number for values in zip(*per_task))

In [None]:
# Task sampling configuration.
sample_number = 1
dataset = 'miniimagenet'  # miniimagenet, tieredimagenet, cifar_fs, fc100
num_shots = 5  # 1, 5
dataset_args = easydict.EasyDict(dict(
    folder='/home/osilab7/hdd/ml_dataset',
    dataset=dataset,
    num_ways=5,
    num_shots=num_shots,
    download=True,
))

# Sample the meta-train tasks first, then the meta-test tasks; the dataset is
# loaded again for every sampled task.
train_tasks, test_tasks = (
    [make_sample_task(load_dataset(dataset_args, split)) for _ in tqdm(range(sample_number))]
    for split in ('meta_train', 'meta_test')
)

inner_update_number = 1
criterion = 'error'  # error, accuracy

# Results folder that matches the shot setting.
if num_shots == 1:
    path = '/home/osilab7/hdd/jhoon_maml_backup/exp1/1shot_results'
else:
    path = '/home/osilab7/hdd/jhoon_maml_backup/exp1/5shot_results'

for model in ['smallconv']:  # smallconv, largeconv, resnet
    for algorithm in ['both']:  # both, extractor, classifier
        save_name = '{}shot_{}_{}'.format(num_shots, model, algorithm)
        args = get_arguments(path, dataset, save_name)
        # Override the machine-specific entries of the saved arguments.
        args.folder = '/home/osilab7/hdd/ml_dataset'
        args.device = 'cuda:0'
        args.output_folder = path
        print_accuracy(args, train_tasks,
                       meta_mode='meta_train',
                       inner_update_number=inner_update_number,
                       criterion=criterion)
        print_accuracy(args, test_tasks,
                       meta_mode='meta_test',
                       inner_update_number=inner_update_number,
                       criterion=criterion)

---

# get similarity

In [None]:
# Task sampling configuration.
sample_number = 1
dataset = 'miniimagenet' # miniimagenet, tieredimagenet, cifar_fs, fc100
num_shots = 5 # 1, 5
dataset_args = easydict.EasyDict({'folder': '/home/osilab7/hdd/ml_dataset',
                                  'dataset': dataset,
                                  'num_ways': 5,
                                  'num_shots': num_shots,
                                  'download': True})

train_tasks = [make_sample_task(load_dataset(dataset_args, 'meta_train')) for _ in tqdm(range(sample_number))]
# NOTE: test_tasks is sampled here but not used by the loop below; only the
# meta-train tasks are evaluated.
test_tasks = [make_sample_task(load_dataset(dataset_args, 'meta_test')) for _ in tqdm(range(sample_number))]

inner_update_number = 1
before_fc_similarity_list = []
after_fc_similarity_list = []
before_support_features_similarity_list = []
after_support_features_similarity_list = []
before_support_logits_similarity_list = []
after_support_logits_similarity_list = []
before_query_features_similarity_list = []
after_query_features_similarity_list = []
before_query_logits_similarity_list = []
after_query_logits_similarity_list = []

# Same order as the tuple returned by get_similarity.
similarity_lists = (
    before_fc_similarity_list, after_fc_similarity_list,
    before_support_features_similarity_list, after_support_features_similarity_list,
    before_support_logits_similarity_list, after_support_logits_similarity_list,
    before_query_features_similarity_list, after_query_features_similarity_list,
    before_query_logits_similarity_list, after_query_logits_similarity_list,
)

path = '/home/osilab7/hdd/jhoon_maml_backup/exp1/1shot_results' if num_shots == 1 else '/home/osilab7/hdd/jhoon_maml_backup/exp1/5shot_results'
for model in ['smallconv']: # TODO: add resnet
    for algorithm in ['both']:
        # The run arguments do not depend on the checkpoint, so the arguments
        # file is parsed once per run instead of once per checkpoint.
        save_name = '{}shot_{}_{}'.format(num_shots, model, algorithm)
        args = get_arguments(path, dataset, save_name)
        args.folder = '/home/osilab7/hdd/ml_dataset'
        args.device = 'cuda:0'
        args.output_folder = path
        for checkpoint_epoch in range(100, 60001, 100):
            outputs = get_similarity(args, train_tasks, meta_mode='meta_train', inner_update_number=inner_update_number, checkpoint_epoch=checkpoint_epoch)
            for similarity_list, value in zip(similarity_lists, outputs):
                similarity_list.append(value)

In [None]:
# before_fc_similarity_list (0)
# after_fc_similarity_list (1)
# before_support_features_similarity_list (2)
# after_support_features_similarity_list (3)
# before_support_logits_similarity_list (4)
# after_support_logits_similarity_list (5)
# before_query_features_similarity_list (6)
# after_query_features_similarity_list (7)
# before_query_logits_similarity_list (8)
# after_query_logits_similarity_list (9)
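# To examine the role of the extractor, look at the distances within the feature space.
# To examine the role of the classifier, look at the feature space and the logit space and how one is mapped to the other, plus the similarity of the fc weights.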