# Reference: https://github.com/ShuoYang-1998/Few_Shot_Distribution_Calibration 

# ===================== Imports =====================

In [None]:
import time
start_time = time.time()
import pickle
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import os
use_gpu = torch.cuda.is_available()
import matplotlib.pyplot as plt

# ============ Distribution Calibration Code ==========

In [None]:
def distribution_calibration(query, base_means, base_cov, k, alpha=0.21):
    """Estimate a Gaussian for one support feature from its nearest base classes.

    The calibrated mean is the average of the k closest base-class means
    (Euclidean distance) together with the query itself; the calibrated
    covariance is the average of those k base covariances plus `alpha`.

    Args:
        query: 1-D feature vector of a support sample.
        base_means: sequence of per-base-class mean vectors (same dim as query).
        base_cov: sequence of per-base-class covariance matrices, aligned with
            `base_means`.
        k: number of nearest base classes to borrow statistics from. If it is
            >= the number of base classes, all of them are used.
        alpha: constant added element-wise to EVERY entry of the averaged
            covariance (not only its diagonal).

    Returns:
        (calibrated_mean, calibrated_cov)

    Raises:
        ValueError: if k < 1.
    """
    query = np.asarray(query)
    if k < 1:
        raise ValueError("k must be >= 1, got {}".format(k))

    dist = [np.linalg.norm(query - m) for m in base_means]
    if k >= len(dist):
        # argpartition would raise for kth out of bounds; just use everything.
        index = np.arange(len(dist))
    else:
        index = np.argpartition(dist, k)[:k]

    # Only gather the k selected statistics instead of converting the whole
    # list to an array on every call (the covariances are dim x dim each).
    nearest_means = np.stack([base_means[j] for j in index])
    nearest_covs = np.stack([base_cov[j] for j in index])

    mean = np.concatenate([nearest_means, query[np.newaxis, :]])
    calibrated_mean = np.mean(mean, axis=0)
    calibrated_cov = np.mean(nearest_covs, axis=0) + alpha
    return calibrated_mean, calibrated_cov

# ========== Setup data, nway-nshot task ==========

In [None]:
# ---- data loading
# Episode configuration: n_ways classes, n_shot labelled and n_queries
# unlabelled samples per class, evaluated over n_runs sampled episodes.
dataset = 'miniImagenet'
n_shot = 1  # alternative setting: 5
n_ways = 5
n_queries = 15
n_runs = 10000
# Derived per-episode sizes: support (labelled) and query (unlabelled) counts.
n_lsamples = n_shot * n_ways
n_usamples = n_queries * n_ways
n_samples = n_usamples + n_lsamples

In [None]:
# ========================================================
#   Usefull paths
_datasetFeaturesFiles = {"miniImagenet": "./dino_features_data/dino_test_miniimagenet.p",}
_cacheDir = "cache"
_maxRuns = 10000
_min_examples = -1.

# ========================================================
#   Module internal functions and variables

_randStates = None
_rsCfg = None


def _load_pickle(file):
    with open(file, 'rb') as f:
        data = pickle.load(f)
        labels = [np.full(shape=len(data[key]), fill_value=key)
                  for key in data]
        data = [features for key in data for features in data[key]]
        dataset = dict()
        dataset['data'] = torch.FloatTensor(np.stack(data, axis=0))
        dataset['labels'] = torch.LongTensor(np.concatenate(labels))
        return dataset


# =========================================================
#    Callable variables and functions from outside the module

# Filled in by loadDataSet(): the active dataset name, the dense
# (classes x _min_examples x dim) feature tensor, and a label tensor.
dsName = None
data = None
labels = None


def loadDataSet(dsname):
    """Load the feature pickle for `dsname` into the module-level `data` tensor.

    Side effects (module globals):
        dsName        -- set to `dsname`.
        data          -- tensor of shape (classes, _min_examples, dim); every
                         class is truncated to the size of the smallest class.
        _min_examples -- number of items kept per class.
        labels        -- left as an empty label tensor.
        _randStates, _rsCfg -- reset so random states are rebuilt/reloaded.

    Class index i in `data` follows the order in which classes first appear in
    the pickle. GenerateRun's cached random draws select classes by that index,
    so the order must stay stable.

    Raises:
        NameError: if `dsname` is not a key of `_datasetFeaturesFiles`.
    """
    if dsname not in _datasetFeaturesFiles:
        raise NameError('Unknwown dataset: {}'.format(dsname))

    global dsName, data, labels, _randStates, _rsCfg, _min_examples
    dsName = dsname
    _randStates = None
    _rsCfg = None

    dataset = _load_pickle(_datasetFeaturesFiles[dsname])
    all_features = dataset["data"]
    all_labels = dataset["labels"]

    # Count items per class in a single pass. The dict keeps first-appearance
    # order, which is the class order used for `data` below.
    counts = {}
    for lab in all_labels.tolist():
        counts[lab] = counts.get(lab, 0) + 1
    _min_examples = min(counts.values(), default=0)
    print("Guaranteed number of items per class: {:d}\n".format(_min_examples))

    # Build the dense tensor in one stack instead of repeated torch.cat.
    if counts:
        data = torch.stack([all_features[all_labels == c][:_min_examples]
                            for c in counts])
    else:
        data = torch.zeros((0, 0, all_features.shape[1]))
    # Preserve the previous end state of this global (all labels consumed).
    labels = all_labels.new_empty((0,))
    print("Total of {:d} classes, {:d} elements each, with dimension {:d}\n".format(
        data.shape[0], data.shape[1], data.shape[2]))


def GenerateRun(iRun, cfg, regenRState=False, generate=True):
    """Draw episode `iRun`: `cfg['ways']` classes, shot+queries samples each.

    Reads the module-level `data` tensor (classes x _min_examples x dim).
    Unless `regenRState` is true, NumPy's global RNG is first restored to the
    cached state `_randStates[iRun]`. The same RNG draws happen whether or not
    `generate` is set; with `generate=False` nothing is materialised and
    (None, None) is returned.

    Returns:
        (episode, class_ids): episode has shape (ways, shot + queries, dim);
        class_ids is a float tensor of the `ways` chosen class indices.
    """
    if not regenRState:
        np.random.set_state(_randStates[iRun])

    n_way = cfg["ways"]
    per_class = cfg['shot'] + cfg['queries']
    chosen = np.random.permutation(np.arange(data.shape[0]))[:n_way]
    order = np.arange(_min_examples)

    if generate:
        episode = torch.zeros((n_way, per_class, data.shape[2]))
        class_ids = torch.zeros((n_way))
    else:
        episode = class_ids = None

    for w in range(n_way):
        # One shuffle per way, each derived from the previous ordering.
        order = np.random.permutation(order)
        if generate:
            episode[w] = data[chosen[w], order[:per_class]]
            class_ids[w] = chosen[w]
    return episode, class_ids


def ClassesInRun(iRun, cfg):
    """Return the class indices chosen for run `iRun` without building it.

    Restores NumPy's global RNG to `_randStates[iRun]` and repeats the first
    draw GenerateRun makes, so the result equals GenerateRun's class choice.
    """
    np.random.set_state(_randStates[iRun])
    n_classes = data.shape[0]
    shuffled = np.random.permutation(np.arange(n_classes))
    return shuffled[:cfg["ways"]]


def setRandomStates(cfg):
    """Make `_randStates` hold one NumPy RNG state per run for `cfg`.

    States are cached on disk in `_cacheDir` under a name built from the
    dataset and the shot/queries/ways of `cfg`. If the file is missing they
    are regenerated by seeding with 0 and replaying GenerateRun (without
    materialising data) `_maxRuns` times. Otherwise they are reloaded.
    Calls with a cfg equal to the last one are no-ops.
    """
    global _randStates, _maxRuns, _rsCfg
    if _rsCfg == cfg:
        return

    rsFile = os.path.join(_cacheDir, "RandStates_{}_s{}_q{}_w{}".format(
        dsName, cfg['shot'], cfg['queries'], cfg['ways']))
    if not os.path.exists(rsFile):
        print("{} does not exist, regenerating it...".format(rsFile))
        np.random.seed(0)
        _randStates = []
        for iRun in range(_maxRuns):
            _randStates.append(np.random.get_state())
            GenerateRun(iRun, cfg, regenRState=True, generate=False)
        # torch.save does not create missing directories.
        os.makedirs(_cacheDir, exist_ok=True)
        torch.save(_randStates, rsFile)
    else:
        print("reloading random states from file....")
        try:
            # The file is written by this function. torch>=2.6 defaults to
            # weights_only=True, which refuses the NumPy arrays in RNG states.
            _randStates = torch.load(rsFile, weights_only=False)
        except TypeError:
            # Older torch (<1.13) has no `weights_only` argument.
            _randStates = torch.load(rsFile)
    # Keep a copy so later mutation of the caller's dict can't fake a cache hit.
    _rsCfg = dict(cfg)


def GenerateRunSet(start=None, end=None, cfg=None):
    """Generate episodes for runs ``start`` (default 0) to ``end`` (default _maxRuns).

    Args:
        start, end: half-open range of run indices.
        cfg: episode config dict with 'shot', 'ways', 'queries'; defaults to
            1-shot, 5-way, 15 queries.

    Returns:
        (dataset, actual_classes): tensors of shape
        (end - start, ways, shot + queries, dim) and (end - start, ways).

    The results are returned only. They used to be written to module globals
    as well, which overwrote the module-level dataset-name string `dataset`.
    """
    if start is None:
        start = 0
    if end is None:
        end = _maxRuns
    if cfg is None:
        cfg = {"shot": 1, "ways": 5, "queries": 15}

    setRandomStates(cfg)
    print("generating task from {} to {}".format(start, end))

    n_tasks = end - start
    dataset = torch.zeros((n_tasks, cfg['ways'], cfg['shot'] + cfg['queries'], data.shape[2]))
    actual_classes = torch.zeros((n_tasks, cfg['ways']))

    for iRun in range(n_tasks):
        dataset[iRun], actual_classes[iRun] = GenerateRun(start + iRun, cfg)

    return dataset, actual_classes

In [None]:
# Episode configuration consumed by the task generator above.
cfg = {'shot': n_shot, 'ways': n_ways, 'queries': n_queries}

loadDataSet(dataset)
setRandomStates(cfg)

ndatas, actual_classes = GenerateRunSet(end=n_runs, cfg=cfg)
# (runs, ways, shot+queries, dim) -> (runs, shot+queries, ways, dim), then
# flatten so each run's first n_lsamples rows are the support samples.
ndatas = ndatas.permute(0, 2, 1, 3).reshape(n_runs, n_samples, -1)
# Matching labels: way ids 0..n_ways-1 tiled once per sample slot.
labels = torch.arange(n_ways).repeat(n_runs, n_shot + n_queries)

# ===== Get the base-class stats for Distribution Calibration =======

In [None]:
# ---- Base class statistics
# Per-base-class mean and covariance, consumed by distribution_calibration.
base_means = []
base_cov = []
# Re-assert the dataset name in case an earlier cell rebound `dataset`.
dataset = 'miniImagenet'
base_features_path = "./dino_features_data/dino_train_miniimagenet.p"
with open(base_features_path, 'rb') as f:
    # Separate name: rebinding `data` here would overwrite the feature
    # tensor that loadDataSet stored in the module-level `data`.
    base_features = pickle.load(f)

for key in base_features:
    feature = np.array(base_features[key])
    base_means.append(np.mean(feature, axis=0))
    # np.cov treats rows as variables, hence the transpose.
    base_cov.append(np.cov(feature.T))

# ========== Logging ==========

In [None]:
import logging
import os
import sys

# logging.FileHandler does not create missing directories.
os.makedirs('logs', exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    # Previous format opened '[' without closing it; restore the full prefix.
    format='[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(filename='logs/experiment_errors.log'),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger('LOGGER_NAME')

# ========== DC + Logistic Regression ==========

In [None]:
# ---- classification for each task
acc_list = []

print('Start classification for %d tasks...'%(n_runs))

# Synthetic features drawn per support sample. Each class has n_shot support
# samples, so this gives about 750 sampled features per class for any n_shot.
num_sampled = int(750/n_shot)

for i in tqdm(range(n_runs)):
    support_data = ndatas[i][:n_lsamples].numpy()
    support_label = labels[i][:n_lsamples].numpy()
    query_data = ndatas[i][n_lsamples:].numpy()
    query_label = labels[i][n_lsamples:].numpy()

    # ---- Tukey's transform (currently disabled; features are used as-is)
    # beta = 0.5
    # support_data = np.power(support_data[:, ], beta)
    # query_data = np.power(query_data[:, ], beta)

    # ---- distribution calibration and feature sampling
    sampled_data = []
    sampled_label = []
    # Use `j` so the run index `i` is not shadowed.
    for j in range(n_lsamples):
        mean, cov = distribution_calibration(support_data[j], base_means, base_cov, k=2)
        sampled_data.append(np.random.multivariate_normal(mean=mean, cov=cov, size=num_sampled))
        sampled_label.extend([support_label[j]] * num_sampled)
    # (n_lsamples * num_sampled, dim), support-sample-major as before.
    sampled_data = np.concatenate(sampled_data, axis=0)

    X_aug = np.concatenate([support_data, sampled_data])
    Y_aug = np.concatenate([support_label, sampled_label])

    # ---- train classifier
    classifier = LogisticRegression(max_iter=1000).fit(X=X_aug, y=Y_aug)
    predicts = classifier.predict(query_data)

    acc = np.mean(predicts == query_label)
    acc_list.append(acc)

    if i % 100 == 0:
        logger.info('Iteration  %d %f %f', i, acc, np.mean(acc_list))

mean_acc = float(np.mean(acc_list))
logger.info('Final: Mini-Imagenet')
logger.info('%s %d way %d shot  ACC : %f', dataset, n_ways, n_shot, mean_acc)
logger.info('Time needed: %f', time.time() - start_time)
print('%s %d way %d shot  ACC : %f' % (dataset, n_ways, n_shot, mean_acc))