In [None]:
import os
import copy
import easydict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib
from tqdm import tqdm
from collections import OrderedDict
from maml.utils import load_dataset, load_model, update_parameters

In [None]:
# Default font size for the figures produced by the cells below.
matplotlib.rcParams['font.size'] = 20

---

# Evaluation on 1000 test episodes

In [None]:
def _collect_split(split):
    """Gather one split of a sampled task into batched tensors.

    Args:
        split: Iterable of ``(image, label)`` pairs that also exposes an
            ``index`` sequence; ``split.index[label]`` is recorded as the
            "real" label (presumably the class id in the underlying
            dataset -- confirm against the dataset implementation).

    Returns:
        Tuple ``(images, labels, real_labels)``: the images stacked along a
        new leading dimension, and two ``LongTensor`` label vectors.
    """
    images, labels, real_labels = [], [], []
    for image, label in split:
        images.append(image)
        labels.append(label)
        real_labels.append(split.index[label])
    # A single stack replaces the repeated torch.cat of the original loop,
    # which re-copied the whole batch for every additional image.
    return (torch.stack(images, dim=0),
            torch.tensor(labels).type(torch.LongTensor),
            torch.tensor(real_labels).type(torch.LongTensor))


def make_sample_task(dataset):
    """Sample one task from ``dataset`` and materialise it as tensors.

    Args:
        dataset: Object with a ``sample_task()`` method returning a mapping
            with ``'train'`` (support) and ``'test'`` (query) splits.

    Returns:
        A list ``[s_images, s_labels, s_real_labels,
        q_images, q_labels, q_real_labels]``.
    """
    sample_task = dataset.sample_task()
    s_images, s_labels, s_real_labels = _collect_split(sample_task['train'])
    q_images, q_labels, q_real_labels = _collect_split(sample_task['test'])
    return [s_images, s_labels, s_real_labels, q_images, q_labels, q_real_labels]

def isfloat(value):
    """Return True if ``float(value)`` succeeds, False otherwise.

    Values that ``float`` rejects with ``TypeError`` (for example ``None`` or
    a list) are reported as not-a-float instead of raising.
    """
    try:
        float(value)
    except (TypeError, ValueError):
        return False
    return True

def get_arguments(path, dataset, save_name):
    """Load the saved run arguments of one experiment.

    Reads ``{path}/{dataset}_{save_name}/logs/arguments.txt``, which holds one
    ``key: value`` pair per line, and converts each value to ``int``,
    ``float`` or ``bool`` when its text looks like one; everything else stays
    a string.

    Args:
        path: Root output folder of the experiments.
        dataset: Dataset name used in the run folder name.
        save_name: Run identifier used in the run folder name.

    Returns:
        An ``easydict.EasyDict`` mapping argument names to parsed values.

    Raises:
        ValueError: If a non-blank line does not contain the ``": "``
            separator.
    """
    filename = '{}/{}_{}/logs/arguments.txt'.format(path, dataset, save_name)

    args = easydict.EasyDict()
    with open(filename) as f:
        for line in f:
            line = line.rstrip('\n')
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            # Split on the first separator only, so values that themselves
            # contain ": " are kept intact.
            key, separator, val = line.partition(": ")
            if not separator:
                raise ValueError('Malformed line in {}: {!r}'.format(filename, line))
            if isfloat(val):
                # Try int first so that negative integers such as "-1" stay
                # integers (str.isdigit() rejects the sign).
                try:
                    val = int(val)
                except ValueError:
                    val = float(val)
            elif val == 'True' or val == 'False':
                val = val == 'True'
            args[key] = val
    return args

def _accuracy(predictions, targets):
    """Fraction of positions where ``predictions`` equals ``targets``."""
    return (predictions == targets).float().mean().item()


def _nil_predictions(support_features, support_target, query_features, num_ways):
    """Label each query with the class of its most cosine-similar prototype.

    A prototype is the mean support feature of one class; features are
    expected to be 2-D (``num_examples x feature_dim``).
    """
    prototypes = torch.stack([support_features[support_target == label].mean(dim=0)
                              for label in range(num_ways)])
    # (num_query, 1, D) against (1, num_ways, D) -> (num_query, num_ways)
    similarity = F.cosine_similarity(query_features.unsqueeze(1),
                                     prototypes.unsqueeze(0), dim=2)
    # Row k of ``prototypes`` belongs to class k, so the argmax is the label.
    return torch.argmax(similarity, dim=1)


def _task_accuracies(model, args, support_input, support_target,
                     query_input, query_target, NIL_testing):
    """Return ``[support_accuracy, query_accuracy]`` for the current weights."""
    # Pure evaluation: no autograd graph is needed for these forward passes.
    with torch.no_grad():
        support_features, support_logit = model(support_input)
        _, support_pred_target = torch.max(support_logit, dim=1)

        query_features, query_logit = model(query_input)
        if NIL_testing:
            query_pred_target = _nil_predictions(support_features, support_target,
                                                 query_features, args.num_ways)
        else:
            _, query_pred_target = torch.max(query_logit, dim=1)
    return [_accuracy(support_pred_target, support_target),
            _accuracy(query_pred_target, query_target)]


def print_accuracy(args, test_dataset, sample_tasks, iteration, inner_update_num, NIL_testing=False):
    """Evaluate a trained model on pre-sampled tasks and save a CSV report.

    For every task the checkpoint ``best_val_acc_model.pt`` is reloaded, the
    support/query accuracy is measured before adaptation and again after each
    of ``inner_update_num`` inner-loop updates.  One row per task plus a
    ``mean`` and a ``std`` row are written to
    ``{output_folder}/{save_dir}/logs/{test_dataset}_[nil_]results_{iteration}.csv``.

    Args:
        args: Run configuration; reads ``device``, ``output_folder``,
            ``save_dir``, ``model``, ``classifier_step_size``,
            ``extractor_step_size``, ``first_order``, ``save_name`` (only for
            the ``'4conv_sep'`` model) and ``num_ways`` (only for NIL testing).
        test_dataset: Dataset name used in the report file name.
        sample_tasks: Sequence of tasks as returned by ``make_sample_task``.
        iteration: Identifier appended to the report file name.
        inner_update_num: Number of inner-loop adaptation steps.
        NIL_testing: If True, query predictions come from the nearest class
            prototype (cosine similarity on features) instead of the logits.
    """
    device = torch.device(args.device)
    sample_number = len(sample_tasks)

    index = ['task{}'.format(i + 1) for i in range(sample_number)]
    columns = ['Accuracy on support set (before adaptation)',
               'Accuracy on query set (before adaptation)']
    for i in range(inner_update_num):
        columns += ['Accuracy on support set (after {} adaptation(s))'.format(i + 1),
                    'Accuracy on query set (after {} adaptation(s))'.format(i + 1)]

    result_tag = 'nil_results' if NIL_testing else 'results'
    filename_pd = '{}/{}/logs/{}_{}_{}.csv'.format(args.output_folder, args.save_dir,
                                                   test_dataset, result_tag, iteration)
    test_pd = pd.DataFrame(np.zeros([sample_number, len(columns)]), index=index, columns=columns)

    model = load_model(args)
    checkpoint = '{}/{}/models/best_val_acc_model.pt'.format(args.output_folder, args.save_dir)
    checkpoint = torch.load(checkpoint, map_location=device)

    # Per-parameter inner-loop learning rates.  For '4conv_sep' only the conv
    # block whose number is the last character of save_name is adapted.
    step_size = OrderedDict()
    for name, _ in model.named_parameters():
        if 'classifier' in name:
            step_size[name] = args.classifier_step_size
        elif args.model != '4conv_sep' or 'conv' + args.save_name[-1] in name:
            step_size[name] = args.extractor_step_size
        else:
            step_size[name] = 0.0

    for idx in tqdm(range(sample_number)):
        # Every task starts from the same meta-trained weights.
        model.load_state_dict(checkpoint, strict=True)
        model.to(device)

        support_input = sample_tasks[idx][0].to(device)
        support_target = sample_tasks[idx][1].to(device)
        query_input = sample_tasks[idx][3].to(device)
        query_target = sample_tasks[idx][4].to(device)

        model.train()

        # before adaptation
        task_log = _task_accuracies(model, args, support_input, support_target,
                                    query_input, query_target, NIL_testing)

        # after adaptation
        params = None
        for _ in range(inner_update_num):
            _, support_logit = model(support_input, params=params)
            inner_loss = F.cross_entropy(support_logit, support_target)

            model.zero_grad()
            params = update_parameters(model=model,
                                       loss=inner_loss,
                                       params=params,
                                       step_size=step_size,
                                       first_order=args.first_order)
            # NOTE(review): strict=True requires ``params`` to cover every
            # state_dict key (i.e. the model has no extra buffers) -- confirm.
            model.load_state_dict(params, strict=True)

            task_log += _task_accuracies(model, args, support_input, support_target,
                                         query_input, query_target, NIL_testing)

        test_pd.iloc[idx] = task_log

    # Compute both summaries before appending either, so the 'std' row is not
    # influenced by the 'mean' row.
    mean_row, std_row = test_pd.mean(axis=0), test_pd.std(axis=0)
    test_pd.loc['mean'] = mean_row
    test_pd.loc['std'] = std_row
    test_pd.to_csv(filename_pd)

In [None]:
# Driver cell: for each (train, test) dataset pair, repetition and shot count,
# sample `sample_number` test tasks and evaluate every listed algorithm on the
# same tasks via print_accuracy (one CSV per algorithm and repetition).
sample_number = 1000
inner_update_num = 1
# (train_dataset, test_dataset) pairs; commented entries are disabled
# experiment configurations.
dataset = [# ('miniimagenet', 'miniimagenet'),
           # ('miniimagenet', 'tieredimagenet'),
           # ('miniimagenet', 'cars'),
           ('cars', 'cars'),
           # ('cars', 'miniimagenet'),
           # ('cars', 'cub')
          ]

model = '4conv'
path = './output_head_abal'

# model = '4conv'
# path = './output_head_abal'

for train_dataset, test_dataset in dataset:
    # `iteration` only tags the output file name; a fresh set of tasks is
    # sampled for each repetition.
    for iteration in [1, 2, 3, 4, 5]:
        for num_shots in [5]: # other values listed by the author: 1, 5
            # NOTE(review): 'folder' is a machine-specific absolute path --
            # adjust before running elsewhere.
            dataset_args = easydict.EasyDict({'folder': '/home/osilab7/hdd/ml_dataset/',
                                              'dataset': test_dataset,
                                              'num_ways': 5,
                                              'num_shots': num_shots,
                                              'download': True})

            # NOTE(review): load_dataset is called once per sampled task
            # (sample_number times). Hoisting it out of the comprehension is
            # presumably equivalent and much faster -- confirm that
            # load_dataset does not reset any random state first.
            sample_tasks = [make_sample_task(load_dataset(dataset_args, 'meta_test')) for _ in tqdm(range(sample_number))]

            for algorithm in ['MAML_s', 'MAML_m', 'MAML_ll']:
                save_name = '{}shot_{}_{}'.format(num_shots, model, algorithm)
                # Arguments are loaded from the run trained on train_dataset;
                # evaluation uses tasks from test_dataset.
                args = get_arguments(path, train_dataset, save_name)
                args.device = 'cuda:0'

                print_accuracy(args, test_dataset, sample_tasks, iteration=iteration, inner_update_num=inner_update_num, NIL_testing=False)
                # print_accuracy(args, test_dataset, sample_tasks, iteration=iteration, inner_update_num=inner_update_num, NIL_testing=True)

In [None]:
# Summary cell: collect the per-repetition result CSVs of each algorithm and
# print mean ± std (in percent) of the final query accuracy.
dataset = [('miniimagenet', 'cars')]

model = '4conv'
path = './output'

for train_dataset, test_dataset in dataset:
    for num_shots in [1, 5]:
        print ('train: {}, test: {}, shot: {}'.format(train_dataset, test_dataset, num_shots))
        for algorithm in ['MAML', 'ANIL', 'BOIL']:
            save_name = '{}shot_{}_{}'.format(num_shots, model, algorithm)
            args = get_arguments(path, train_dataset, save_name)

            log_folder = '{}/{}/logs'.format(args.output_folder, args.save_dir)
            # Only NIL result files are read here; use `'nil' not in d` to
            # summarise the regular (head-based) results instead.
            test_files = sorted(d for d in os.listdir(log_folder)
                                if test_dataset in d and 'nil' in d)

            test_results = []
            for f in test_files:
                report = pd.read_csv(os.path.join(log_folder, f))
                # Second-to-last row is the 'mean' row written by
                # print_accuracy; its last column is the query accuracy after
                # the final adaptation step.
                mean_row = list(report.iloc[-2])
                test_results.append(mean_row[-1])
            test_results = 100*np.array(test_results)
            print ('alg: {}, {:.2f} ± {:.2f}'.format(algorithm,
                                                     np.mean(test_results),
                                                     np.std(test_results)))

---

# Feature space and logit space (cosine similarity)

In [None]:
# Sub-layers of ``model.layer4[0]`` probed for the 'resnet' model, in the
# order they are applied ("_b style": no residual branch).
_RESNET_BLOCK4_STAGES = ('conv1', 'bn1', 'relu1', 'conv2', 'bn2',
                         'relu2', 'conv3', 'bn3', 'relu3', 'maxpool')


def _probe_query_features(model, model_name, query_input, pos):
    """Return the activations of ``query_input`` at probe position ``pos``.

    * ``'4conv'``: ``pos`` 1-4 is the output of ``model.features[pos - 1]``
      (applied cumulatively); ``pos`` 5 is the classifier output.
    * ``'4conv_sep'``: same, using ``model.conv1`` .. ``model.conv4``.
    * ``'resnet'``: the input is passed through ``layer1``-``layer3`` and then
      through the first ``pos`` entries of ``_RESNET_BLOCK4_STAGES`` taken
      from ``model.layer4[0]``.

    The result is flattened to ``(num_queries, -1)``.

    Raises:
        ValueError: For an unsupported model name or ``pos``.
    """
    num_queries = query_input.size(0)

    if model_name in ('4conv', '4conv_sep'):
        if not 1 <= pos <= 5:
            raise ValueError('pos must be in 1..5 for {}, got {}'.format(model_name, pos))
        out = query_input
        for i in range(min(pos, 4)):
            if model_name == '4conv':
                layer = model.features[i]
            else:
                layer = getattr(model, 'conv{}'.format(i + 1))
            out = layer(out)
        if pos == 5:
            out = model.classifier(out.view(num_queries, -1))
    elif model_name == 'resnet':
        if not 1 <= pos <= len(_RESNET_BLOCK4_STAGES):
            raise ValueError('pos must be in 1..{} for resnet, got {}'.format(
                len(_RESNET_BLOCK4_STAGES), pos))
        block = model.layer4[0]
        out = model.layer3(model.layer2(model.layer1(query_input)))
        for stage_name in _RESNET_BLOCK4_STAGES[:pos]:
            out = getattr(block, stage_name)(out)
    else:
        raise ValueError('Unsupported model: {!r}'.format(model_name))

    return out.view(num_queries, -1)


def get_features_logits(args, sample_task, pos):
    """Extract query activations before and after one inner-loop update.

    Loads ``best_val_acc_model.pt`` for the run described by ``args``, probes
    the query images at position ``pos`` (see ``_probe_query_features``),
    performs a single adaptation step on the support set, and probes again.

    Earlier notebook variants that probed whole ResNet blocks or included the
    residual (downsample) branch were disabled and are not reproduced here.

    Args:
        args: Run configuration; reads ``device``, ``output_folder``,
            ``save_dir``, ``model``, ``classifier_step_size``,
            ``extractor_step_size``, ``first_order`` and, for ``'4conv_sep'``,
            ``save_name``.
        sample_task: Task as returned by ``make_sample_task``; entries 0, 1
            and 3 (support images, support labels, query images) are used.
        pos: Probe position, 1-based.

    Returns:
        Tuple ``(before, after)`` of detached CPU tensors of shape
        ``(1, num_queries, feature_dim)``.

    Raises:
        ValueError: For an unsupported ``args.model`` or ``pos``.
    """
    device = torch.device(args.device)

    model = load_model(args)
    checkpoint = '{}/{}/models/best_val_acc_model.pt'.format(args.output_folder, args.save_dir)
    checkpoint = torch.load(checkpoint, map_location=device)

    # Per-parameter inner-loop learning rates.  For '4conv_sep' only the conv
    # block whose number is the last character of save_name is adapted.
    step_size = OrderedDict()
    for name, _ in model.named_parameters():
        if 'classifier' in name:
            step_size[name] = args.classifier_step_size
        elif args.model != '4conv_sep' or 'conv' + args.save_name[-1] in name:
            step_size[name] = args.extractor_step_size
        else:
            step_size[name] = 0.0

    model.load_state_dict(checkpoint, strict=True)
    model.to(device)
    model.train()

    support_input = sample_task[0].to(device)
    support_target = sample_task[1].to(device)
    query_input = sample_task[3].to(device)

    # before adaptation
    _, before_support_logits = model(support_input)
    before_query_features = _probe_query_features(model, args.model, query_input, pos)

    # after adaptation (a single inner-loop step on the support set)
    inner_loss = F.cross_entropy(before_support_logits, support_target)

    model.zero_grad()
    params = update_parameters(model=model,
                               loss=inner_loss,
                               params=None,
                               step_size=step_size,
                               first_order=args.first_order)
    model.load_state_dict(params, strict=True)

    after_query_features = _probe_query_features(model, args.model, query_input, pos)

    return (before_query_features.unsqueeze(0).detach().cpu(),
            after_query_features.unsqueeze(0).detach().cpu())

def get_similarity(outputs):
    """Pairwise cosine similarity between the rows of ``outputs``.

    Args:
        outputs: 2-D tensor with one feature vector per row.

    Returns:
        A ``len(outputs) x len(outputs)`` tensor (allocated with
        ``torch.zeros``) whose entry ``[i, j]`` is the cosine similarity
        between rows ``i`` and ``j``.
    """
    num_rows = len(outputs)
    similarity = torch.zeros([num_rows, num_rows])
    for row_idx, anchor in enumerate(outputs):
        # Compare every row against the anchor row (broadcast over rows).
        similarity[row_idx] = F.cosine_similarity(outputs, anchor.unsqueeze(0), dim=1, eps=1e-8)
    return similarity

def get_mean(similarity_matrices, num_classes=5, num_images=15):
    """Summarise intra-class and inter-class similarity.

    Rows/columns are expected to be grouped class by class: the first
    ``num_images`` rows belong to class 0, the next ``num_images`` to class 1,
    and so on (NOTE(review): confirm the query set is ordered this way).

    Entries are selected with boolean masks, so similarities that happen to
    be exactly 0.0 are included in the statistics, and the input tensor is
    left unmodified.

    Args:
        similarity_matrices: Square similarity matrix of size
            ``num_classes * num_images``.
        num_classes: Number of classes (default 5).
        num_images: Number of images per class (default 15).

    Returns:
        Tuple ``(same_mean, same_std, different_mean, different_std)`` of
        Python floats.  "Same" covers pairs of distinct images of one class
        (self-similarity on the diagonal is excluded); "different" covers
        pairs of images from different classes.

    Raises:
        ValueError: If the matrix is not ``(num_classes * num_images)`` square.
    """
    total = num_classes * num_images
    if tuple(similarity_matrices.shape) != (total, total):
        raise ValueError('Expected a {0}x{0} similarity matrix, got shape {1}'.format(
            total, tuple(similarity_matrices.shape)))

    device = similarity_matrices.device
    class_ids = torch.arange(num_classes, device=device).repeat_interleave(num_images)
    positions = torch.arange(total, device=device)

    same_class = class_ids.unsqueeze(0) == class_ids.unsqueeze(1)
    off_diagonal = positions.unsqueeze(0) != positions.unsqueeze(1)

    same_class_distance_list = similarity_matrices[same_class & off_diagonal]
    different_class_distance_list = similarity_matrices[~same_class]

    return (torch.mean(same_class_distance_list).item(),
            torch.std(same_class_distance_list).item(),
            torch.mean(different_class_distance_list).item(),
            torch.std(different_class_distance_list).item())

In [None]:
# Task configuration for the representation-similarity analysis below.
# NOTE(review): 'folder' is a machine-specific absolute path.
num_shots = 5
dataset = 'miniimagenet'
dataset_args = easydict.EasyDict(folder='/home/osilab7/hdd/ml_dataset/',
                                 dataset=dataset,
                                 num_ways=5,
                                 num_shots=num_shots,
                                 download=True)

In [None]:
# Smaller default font for the multi-panel similarity figures.
matplotlib.rcParams['font.size'] = 12

In [None]:
# Plot, per algorithm, the mean cosine similarity between query images of
# different classes and of the same class at each probe position, before and
# after one adaptation step.
#
# Alternative configuration (4conv):
# model = '4conv'
# path = './output'
# algorithms = ['MAML', 'ANIL', 'BOIL']
model = 'resnet'
path = './output_resnet'
# Other resnet selections:
# algorithms = ['MAML_a', 'MAML_b', 'ANIL_a', 'ANIL_b', 'BOIL_a', 'BOIL_b']
# algorithms = ['MAML_a', 'ANIL_a', 'BOIL_a']
algorithms = ['MAML_b', 'ANIL_b', 'BOIL_b']

sample_task = make_sample_task(load_dataset(dataset_args, 'meta_train'))

for algorithm in algorithms:
    save_name = '{}shot_{}_{}'.format(num_shots, model, algorithm)
    args = get_arguments(path, dataset, save_name)
    args.device = 'cuda:0'

    fig, axes = plt.subplots(1, 2, sharey=True, figsize=(20, 3)) # (10, 3), (8, 3)

    # Both panels share the same cosmetic setup.
    for ax, title in zip(axes, ('Before adaptation', 'After adaptation')):
        ax.set_title(title)
        ax.set_ylim([0.0-0.05, 1.0+0.05])
        ax.tick_params(axis='both', which='major')
        ax.grid(True)

    # x-axis labels and the probe positions passed to get_features_logits.
    if model == '4conv':
        xrange = ['conv1', 'conv2', 'conv3', 'conv4']
        pos_list = [1,2,3,4]
    elif model == 'resnet':
        xrange = ['conv1', 'bn1', 'relu1', 'conv2', 'bn2', 'relu2', 'conv3', 'bn3', 'relu3', 'maxpool']
        pos_list = [1,2,3,4,5,6,7,8,9,10]

    before_same_list, before_same_std_list = [], []
    before_different_list, before_different_std_list = [], []

    after_same_list, after_same_std_list = [], []
    after_different_list, after_different_std_list = [], []

    for pos in pos_list:
        before_f, after_f = get_features_logits(args, sample_task, pos=pos)

        before = get_similarity(outputs=before_f.squeeze())
        after = get_similarity(outputs=after_f.squeeze())

        # get_mean returns (same mean, same std, different mean, different std)
        before_stats = get_mean(before)
        after_stats = get_mean(after)

        before_same_list.append(before_stats[0])
        before_same_std_list.append(before_stats[1])
        before_different_list.append(before_stats[2])
        before_different_std_list.append(before_stats[3])

        after_same_list.append(after_stats[0])
        after_same_std_list.append(after_stats[1])
        after_different_list.append(after_stats[2])
        after_different_std_list.append(after_stats[3])

    # Draw "different class" first, then "same class", on each panel.
    errorbar_style = dict(fmt='-o', capsize=6, capthick=2)
    axes[0].errorbar(xrange, before_different_list, yerr=before_different_std_list, **errorbar_style)
    axes[0].errorbar(xrange, before_same_list, yerr=before_same_std_list, **errorbar_style)

    axes[1].errorbar(xrange, after_different_list, yerr=after_different_std_list, **errorbar_style)
    axes[1].errorbar(xrange, after_same_list, yerr=after_same_std_list, **errorbar_style)

    # Use plt.show() instead of savefig to inspect the figure inline.
    plt.subplots_adjust(wspace=0.2)
    plt.savefig('./src/{}_cosine_in_block4_miniimagenet.pdf'.format(algorithm), bbox_inches='tight', format='pdf')
    plt.close()

---
# Gradient norm during adaptation

In [None]:
# Recorded gradient norms per weight group (columns follow `xrange`).  Each
# table has five rows of measurements (presumably one per run -- the notebook
# does not say); the plotted value is their mean, with the population standard
# deviation as the error bar.
xrange = ['conv1_w', 'conv2_w', 'conv3_w', 'conv4_w', 'head_w']

maml_runs = np.array([[0.1327, 0.0942, 0.1181, 0.1028, 3.3497],
                      [0.1706, 0.1169, 0.1470, 0.1172, 3.2524],
                      [0.1356, 0.1059, 0.1390, 0.1231, 3.2256],
                      [0.1641, 0.1065, 0.1270, 0.1115, 3.1296],
                      [0.1375, 0.1045, 0.1297, 0.1140, 3.4852]])
maml_values = maml_runs.mean(axis=0)
maml_values_std = maml_runs.std(axis=0)

anil_runs = np.array([[0., 0., 0., 0., 3.4998],
                      [0., 0., 0., 0., 3.4090],
                      [0., 0., 0., 0., 3.3367],
                      [0., 0., 0., 0., 3.1945],
                      [0., 0., 0., 0., 3.5396]])
anil_values = anil_runs.mean(axis=0)
anil_values_std = anil_runs.std(axis=0)

boil_runs = np.array([[0.3532, 0.3685, 0.5297, 4.4506, 0.],
                      [0.5029, 0.4703, 0.6391, 5.0119, 0.],
                      [0.3648, 0.3765, 0.5192, 4.4990, 0.],
                      [0.3271, 0.3796, 0.5237, 4.2251, 0.],
                      [0.4050, 0.4030, 0.5288, 4.3817, 0.]])
boil_values = boil_runs.mean(axis=0)
boil_values_std = boil_runs.std(axis=0)

In [None]:
# Gradient norm per weight group for each algorithm, with error bars taken
# from the *_values_std arrays.
plt.figure(figsize=(12,4))
plt.errorbar(xrange, maml_values, yerr=maml_values_std, fmt='-o', capsize=6, capthick=2, label='MAML', color='blue')
plt.errorbar(xrange, anil_values, yerr=anil_values_std, fmt='-o', capsize=6, capthick=2, label='ANIL', color='green')
plt.errorbar(xrange, boil_values, yerr=boil_values_std, fmt='-o', capsize=6, capthick=2, label='BOIL', color='red')

plt.legend()
# Use plt.show() instead of savefig to inspect the figure inline.
# The file name has no placeholder, so the former `.format(model)` call was a
# no-op that only made this cell depend on the global `model`.
plt.savefig('./src/grad_norm.pdf', bbox_inches='tight', format='pdf')
plt.close()

---

# CCA/CKA

In [None]:
def gram_linear(x):
    """Build the linear-kernel Gram matrix of a feature matrix.

    Args:
    x: A num_examples x num_features matrix of features.

    Returns:
    The num_examples x num_examples matrix of pairwise inner products
    between the rows of `x`.
    """
    x_transposed = x.T
    return x.dot(x_transposed)


def gram_rbf(x, threshold=1.0):
    """Build the RBF-kernel Gram matrix of a feature matrix.

    The kernel bandwidth is derived from the data: it is `threshold` times the
    square root of the median entry of the pairwise squared-distance matrix.

    Args:
    x: A num_examples x num_features matrix of features.
    threshold: Multiplier applied to the median-distance bandwidth heuristic.

    Returns:
    A num_examples x num_examples Gram matrix of examples.
    """
    inner = x.dot(x.T)
    squared_norms = np.diag(inner)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, evaluated for every pair of rows.
    pairwise_sq = -2 * inner + squared_norms[:, None] + squared_norms[None, :]
    median_sq = np.median(pairwise_sq)
    scale = 2 * threshold ** 2 * median_sq
    return np.exp(-pairwise_sq / scale)


def center_gram(gram, unbiased=False):
    """Center a symmetric Gram matrix.

    This is equivalent to centering the (possibly infinite-dimensional) features
    induced by the kernel before computing the Gram matrix.

    Args:
    gram: A num_examples x num_examples symmetric matrix. Integer and boolean
      inputs are promoted to float64; floating-point inputs keep their dtype.
    unbiased: Whether to adjust the Gram matrix in order to compute an unbiased
      estimate of HSIC. Note that this estimator may be negative.

    Returns:
    A symmetric matrix with centered columns and rows. The input array is
    never modified.

    Raises:
    ValueError: If `gram` is not symmetric, or if `unbiased` is requested with
      fewer than 3 examples (the estimator divides by n - 2 and n - 1).
    """
    if not np.allclose(gram, gram.T):
        raise ValueError('Input must be a symmetric matrix.')

    # Work on a private copy. The in-place subtractions below produce
    # fractional values, so non-floating inputs must be promoted first;
    # otherwise `gram -= means` fails with a casting error.
    if np.issubdtype(gram.dtype, np.inexact):
        gram = gram.copy()
    else:
        gram = gram.astype(np.float64)

    if unbiased:
        # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M.
        # L. (2014). Partial distance correlation with methods for dissimilarities.
        # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically
        # stable than the alternative from Song et al. (2007).
        n = gram.shape[0]
        if n < 3:
            raise ValueError(
                'The unbiased estimator requires at least 3 examples.')
        np.fill_diagonal(gram, 0)
        means = np.sum(gram, 0, dtype=np.float64) / (n - 2)
        means -= np.sum(means) / (2 * (n - 1))
        gram -= means[:, None]
        gram -= means[None, :]
        np.fill_diagonal(gram, 0)
    else:
        means = np.mean(gram, 0, dtype=np.float64)
        # Half of the grand mean is removed here because `means` is subtracted
        # twice below (once along rows, once along columns).
        means -= np.mean(means) / 2
        gram -= means[:, None]
        gram -= means[None, :]

    return gram

def cka(gram_x, gram_y, debiased=False):
    """Compute centered kernel alignment (CKA) from two Gram matrices.

    Args:
    gram_x: A num_examples x num_examples Gram matrix.
    gram_y: A num_examples x num_examples Gram matrix.
    debiased: Center with the unbiased HSIC estimator. CKA may still be biased.

    Returns:
    The value of CKA between X and Y.
    """
    centered_x = center_gram(gram_x, unbiased=debiased)
    centered_y = center_gram(gram_y, unbiased=debiased)

    # The inner product of the flattened centered Gram matrices is HSIC up to a
    # factor ((n-1)**2 biased, n*(n-3) unbiased); that factor cancels in the
    # ratio below, so it is never applied.
    unnormalized_hsic = centered_x.ravel().dot(centered_y.ravel())

    norm_x = np.linalg.norm(centered_x)
    norm_y = np.linalg.norm(centered_y)
    return unnormalized_hsic / (norm_x * norm_y)


def _debiased_dot_product_similarity_helper(xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y, n):
    """Helper for computing debiased dot product similarity (i.e. linear HSIC)."""
    # This formula can be derived by manipulating the unbiased estimator from
    # Song et al. (2007).
    return (
      xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y)
      + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))


def feature_space_linear_cka(features_x, features_y, debiased=False):
    """Compute linear-kernel CKA directly from feature matrices.

    Working with num_features x num_features products instead of Gram matrices
    is typically faster when there are fewer features than examples.

    Args:
    features_x: A num_examples x num_features matrix of features.
    features_y: A num_examples x num_features matrix of features.
    debiased: Use unbiased estimator of dot product similarity. CKA may still be
      biased. Note that this estimator may be negative.

    Returns:
    The value of CKA between X and Y.
    """
    # Center every feature column; the inputs themselves are left untouched.
    centered_x = features_x - np.mean(features_x, 0, keepdims=True)
    centered_y = features_y - np.mean(features_y, 0, keepdims=True)

    similarity = np.linalg.norm(centered_x.T.dot(centered_y)) ** 2
    norm_x = np.linalg.norm(centered_x.T.dot(centered_x))
    norm_y = np.linalg.norm(centered_y.T.dot(centered_y))

    if not debiased:
        return similarity / (norm_x * norm_y)

    num_examples = centered_x.shape[0]
    # Row-wise squared norms, computed without materialising centered ** 2.
    row_sq_x = np.einsum('ij,ij->i', centered_x, centered_x)
    row_sq_y = np.einsum('ij,ij->i', centered_y, centered_y)
    total_sq_x = np.sum(row_sq_x)
    total_sq_y = np.sum(row_sq_y)

    similarity = _debiased_dot_product_similarity_helper(
        similarity, row_sq_x, row_sq_y, total_sq_x, total_sq_y, num_examples)
    norm_x = np.sqrt(_debiased_dot_product_similarity_helper(
        norm_x ** 2, row_sq_x, row_sq_x, total_sq_x, total_sq_x, num_examples))
    norm_y = np.sqrt(_debiased_dot_product_similarity_helper(
        norm_y ** 2, row_sq_y, row_sq_y, total_sq_y, total_sq_y, num_examples))

    return similarity / (norm_x * norm_y)

def cca(features_x, features_y):
    """Compute the mean squared CCA correlation (R^2_{CCA}).

    Args:
    features_x: A num_examples x num_features matrix of features.
    features_y: A num_examples x num_features matrix of features.

    Returns:
    The mean squared CCA correlations between X and Y.
    """
    # Orthonormal bases of the two column spaces (reduced QR; a thin SVD
    # would serve the same purpose).
    basis_x = np.linalg.qr(features_x)[0]
    basis_y = np.linalg.qr(features_y)[0]
    squared_overlap = np.linalg.norm(basis_x.T.dot(basis_y)) ** 2
    smaller_width = min(features_x.shape[1], features_y.shape[1])
    return squared_overlap / smaller_width

In [None]:
# Per-layer linear CKA experiment.
#
# For every sampled task and every trained algorithm, compare the two feature
# sets returned by get_features_logits at each probed position and store one
# row of CKA values per task in the matching result DataFrame.
# NOTE(review): before_f / after_f presumably are the features before and after
# inner-loop adaptation -- confirm against get_features_logits.

# --- Configuration ----------------------------------------------------------
dataset = 'cars'
num_shots = 5
# NOTE(review): 'folder' is a machine-specific absolute path; adjust it when
# running on another machine.
dataset_args = easydict.EasyDict({'folder': '/home/osilab7/hdd/ml_dataset/',
                                  'dataset': dataset,
                                  'num_ways': 5,
                                  'num_shots': num_shots,
                                  'download': True})

# Alternative configuration (plain 4conv backbone). When enabling it, also
# enable `anil_df` below: the '4conv' branch of the loop writes to it.
# model = '4conv'
# path = './output'
# algorithms = ['MAML', 'ANIL', 'BOIL']

# Active configuration.
model = '4conv_sep'
path = './output_conv_abal'
algorithms = ['MAML_3', 'BOIL_3']

# One row per sampled task, one column per probed position.
maml_df = pd.DataFrame(columns=['1','2','3','4','head'])
# anil_df = pd.DataFrame(columns=['1','2','3','4','head'])
boil_df = pd.DataFrame(columns=['1','2','3','4','head'])

# Alternative configuration (resnet backbone).
# model = 'resnet'
# path = './output_resnet'
# algorithms = ['MAML_a', 'MAML_b', 'ANIL_a', 'ANIL_b', 'BOIL_a', 'BOIL_b']

# maml_a_df = pd.DataFrame(columns=['1','2','3','4','head'])
# maml_b_df = pd.DataFrame(columns=['1','2','3','4','head'])
# anil_a_df = pd.DataFrame(columns=['1','2','3','4','head'])
# anil_b_df = pd.DataFrame(columns=['1','2','3','4','head'])
# boil_a_df = pd.DataFrame(columns=['1','2','3','4','head'])
# boil_b_df = pd.DataFrame(columns=['1','2','3','4','head'])

# The probed positions depend only on the backbone, so they are chosen once.
if model == '4conv' or model == '4conv_sep':
    xrange = ['conv1', 'conv2', 'conv3', 'conv4', 'head']
elif model == 'resnet':
    xrange = ['block1', 'block2', 'block3', 'block4', 'head']
else:
    raise ValueError('Unsupported model: {}'.format(model))
pos_list = [1,2,3,4,5]

# --- Measurement ------------------------------------------------------------
for i in tqdm(range(1)):
    # The same sampled task is shared by all algorithms in this iteration.
    sample_task = make_sample_task(load_dataset(dataset_args, 'meta_train'))

    for algorithm in algorithms:
        save_name = '{}shot_{}_{}'.format(num_shots, model, algorithm)
        args = get_arguments(path, dataset, save_name)
        args.device = 'cuda:0'

        cka_list = []
        for pos in pos_list:
            before_f, after_f = get_features_logits(args, sample_task, pos=pos)

            before_f = before_f.squeeze(0).numpy()
            after_f = after_f.squeeze(0).numpy()

            # Append directly instead of binding the result to a name `cka`,
            # which would shadow the module-level cka() function.
            cka_list.append(feature_space_linear_cka(before_f, after_f))

        # Store this task's row in the frame that belongs to the algorithm.
        if model == '4conv':
            if algorithm == 'MAML':
                maml_df.loc[i] = cka_list
            elif algorithm == 'ANIL':
                anil_df.loc[i] = cka_list
            elif algorithm == 'BOIL':
                boil_df.loc[i] = cka_list
        elif model == '4conv_sep':
            if 'MAML' in algorithm:
                maml_df.loc[i] = cka_list
            elif 'BOIL' in algorithm:
                boil_df.loc[i] = cka_list
        elif model == 'resnet':
            if algorithm == 'MAML_a':
                maml_a_df.loc[i] = cka_list
            elif algorithm == 'MAML_b':
                maml_b_df.loc[i] = cka_list
            elif algorithm == 'ANIL_a':
                anil_a_df.loc[i] = cka_list
            elif algorithm == 'ANIL_b':
                anil_b_df.loc[i] = cka_list
            elif algorithm == 'BOIL_a':
                boil_a_df.loc[i] = cka_list
            elif algorithm == 'BOIL_b':
                boil_b_df.loc[i] = cka_list

# --- Persist results --------------------------------------------------------
maml_df.to_csv('maml_3_df.csv')
# anil_df.to_csv('anil_df.csv')
boil_df.to_csv('boil_3_df.csv')

# maml_a_df.to_csv('maml_a_df.csv')
# maml_b_df.to_csv('maml_b_df.csv')
# anil_a_df.to_csv('anil_a_df.csv')
# anil_b_df.to_csv('anil_b_df.csv')
# boil_a_df.to_csv('boil_a_df.csv')
# boil_b_df.to_csv('boil_b_df.csv')

In [None]:
# Per-position mean and standard deviation over the rows of each result frame.
# NOTE: despite the `_acc` suffix, maml_df / boil_df hold CKA values.
layer_columns = ['1', '2', '3', '4', 'head']

maml_acc = [np.mean(maml_df[column]) for column in layer_columns]
maml_std = [np.std(maml_df[column]) for column in layer_columns]
# anil_acc = [np.mean(anil_df[column]) for column in layer_columns]
# anil_std = [np.std(anil_df[column]) for column in layer_columns]
boil_acc = [np.mean(boil_df[column]) for column in layer_columns]
boil_std = [np.std(boil_df[column]) for column in layer_columns]

In [None]:
# Plot the per-position CKA mean with std error bars for MAML and BOIL.
# Requires maml_acc / maml_std / boil_acc / boil_std from the statistics cell.
fig, ax = plt.subplots(1, 1, figsize=(8,6))
xrange = ['conv1', 'conv2', 'conv3', 'conv4', 'head']

ax.set_title('CKA')
ax.set_ylim([0.0-0.05, 1.0+0.05])
ax.tick_params(axis='both', which='major')
ax.grid(True)

ax.errorbar(xrange, maml_acc, yerr=maml_std, fmt='-o', capsize=6, capthick=2, label='MAML', color='blue')
# ax.errorbar(xrange, anil_acc, yerr=anil_std, fmt='-o', capsize=6, capthick=2, label='ANIL', color='green')
ax.errorbar(xrange, boil_acc, yerr=boil_std, fmt='-o', capsize=6, capthick=2, label='BOIL', color='red')

ax.legend()
# Layout changes (and any savefig) must happen before show(): once the figure
# has been shown, pyplot calls may act on a new, empty figure instead.
fig.subplots_adjust(wspace=0.2)
# fig.savefig('./src/{}_pos3_cka.pdf'.format(model), bbox_inches='tight', format='pdf')
plt.show()
plt.close(fig)

In [None]:
# Per-position mean and standard deviation for the resnet result frames.
# NOTE: the *_a_df / *_b_df frames only exist when the resnet configuration
# (commented out in the experiment cell) has been enabled and run.
# Despite the `_acc` suffix, these frames hold CKA values.
resnet_columns = ['1', '2', '3', '4', 'head']

maml_a_acc = [np.mean(maml_a_df[column]) for column in resnet_columns]
maml_a_std = [np.std(maml_a_df[column]) for column in resnet_columns]
maml_b_acc = [np.mean(maml_b_df[column]) for column in resnet_columns]
maml_b_std = [np.std(maml_b_df[column]) for column in resnet_columns]
anil_a_acc = [np.mean(anil_a_df[column]) for column in resnet_columns]
anil_a_std = [np.std(anil_a_df[column]) for column in resnet_columns]
anil_b_acc = [np.mean(anil_b_df[column]) for column in resnet_columns]
anil_b_std = [np.std(anil_b_df[column]) for column in resnet_columns]
boil_a_acc = [np.mean(boil_a_df[column]) for column in resnet_columns]
boil_a_std = [np.std(boil_a_df[column]) for column in resnet_columns]
boil_b_acc = [np.mean(boil_b_df[column]) for column in resnet_columns]
boil_b_std = [np.std(boil_b_df[column]) for column in resnet_columns]

In [None]:
# Plot the per-block CKA mean with std error bars for the resnet variants.
# Solid lines with circles are the "w/ LSC" runs, dashed lines with crosses
# the "w/o LSC" runs. Requires the *_a_* / *_b_* statistics lists from the
# preceding statistics cell.
fig, ax = plt.subplots(1, 1, figsize=(8,6))
xrange = ['block1', 'block2', 'block3', 'block4', 'head']

ax.set_title('CKA')
ax.set_ylim([0.0-0.05, 1.0+0.05])
ax.tick_params(axis='both', which='major')
ax.grid(True)

ax.errorbar(xrange, maml_a_acc, yerr=maml_a_std, fmt='-o', capsize=6, capthick=2, label='MAML w/ LSC', color='blue')
ax.errorbar(xrange, anil_a_acc, yerr=anil_a_std, fmt='-o', capsize=6, capthick=2, label='ANIL w/ LSC', color='green')
ax.errorbar(xrange, boil_a_acc, yerr=boil_a_std, fmt='-o', capsize=6, capthick=2, label='BOIL w/ LSC', color='red')

ax.errorbar(xrange, maml_b_acc, yerr=maml_b_std, fmt='--x', capsize=6, capthick=2, label='MAML w/o LSC', color='blue')
ax.errorbar(xrange, anil_b_acc, yerr=anil_b_std, fmt='--x', capsize=6, capthick=2, label='ANIL w/o LSC', color='green')
ax.errorbar(xrange, boil_b_acc, yerr=boil_b_std, fmt='--x', capsize=6, capthick=2, label='BOIL w/o LSC', color='red')

ax.legend()
# Layout changes (and any savefig) must happen before show(): once the figure
# has been shown, pyplot calls may act on a new, empty figure instead.
fig.subplots_adjust(wspace=0.2)
# fig.savefig('./src/{}_cka.pdf'.format(model), bbox_inches='tight', format='pdf')
plt.show()
plt.close(fig)