In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import csv
from PIL import Image
import matplotlib as mpl
from tqdm import tqdm
from sklearn.manifold import TSNE
import umap 
from sklearn.metrics.pairwise import cosine_distances

# Visualizing the Disregarding classes

### Load data

In [None]:
def access_data(letter,shot):
    """Load the saved tensors for one backbone and one shot setting.

    All files are addressed by relative path and mapped onto the CPU.

    Args:
        letter: backbone identifier inserted into the file names (the notebook
            uses 'A' and 'B').
        shot: number of shots; converted with ``str`` to build the file names.

    Returns:
        Tuple ``(feat, classifier, accuracy, idx)`` holding whatever
        ``torch.load`` returns for the four files, in that order.
    """
    shot_tag = str(shot)
    accuracy_stem = 'complete_class_accuracy' + letter
    feat = torch.load('features' + letter + shot_tag, map_location=torch.device('cpu'))
    classifier = torch.load('classifier' + letter, map_location=torch.device('cpu'))
    accuracy = torch.load(accuracy_stem + shot_tag + 'shots', map_location=torch.device('cpu'))
    idx = torch.load(accuracy_stem + 'idx' + shot_tag + 'shots', map_location=torch.device('cpu'))
    return feat, classifier, accuracy, idx

In [None]:
# Load the backbone-A tensors for the 5-shot setting and show their shapes.
shot=5
letter='A'
feat,classifier,acc,idx = access_data(letter,shot)
for loaded in (acc, feat, classifier, idx):
    print(loaded.shape)

In [None]:
# Load the backbone-B tensors for the 5-shot setting and show their shapes.
# Note: this rebinds the module-level `letter` to 'B' for every later cell.
shot=5
letter='B'
featB,classifierB,accB,idxB = access_data(letter,shot)
for loaded in (accB, featB, classifierB, idxB):
    print(loaded.shape)

In [None]:
# Average the second-to-last axis of the first 64 entries of each feature tensor.
# assumes feat / featB are laid out as (class, sample, feature) -- TODO confirm
base_mean = feat[:64].mean(-2)
base_meanB = featB[:64].mean(-2)
print(base_mean.shape)
# Random control vectors, one per row of base_mean. A private seeded generator
# makes the experiment reproducible after "Restart & Run All" without touching
# the global RNG state, and the shape follows base_mean instead of a hard-coded
# (64, 640).
# NOTE(review): torch.rand samples uniformly from [0, 1), so every component is
# non-negative and the vectors are far from isotropic; consider torch.randn if
# unbiased random directions are intended.
base_random = torch.rand(base_mean.shape, generator=torch.Generator().manual_seed(0))
print(base_random.shape)

In [None]:
def proj_class(i,test_features,letter='A',random=False):
    """Remove from ``test_features`` the component along one reference vector.

    The reference vector ``w`` is row ``i`` of ``base_random`` when ``random``
    is true; otherwise it is row ``i`` of ``base_mean`` when ``letter == 'A'``
    and row ``i`` of ``base_meanB`` for any other ``letter``.  Every feature
    vector ``x`` is replaced by ``x - (x . w / ||w||^2) * w``, i.e. its
    projection onto the subspace orthogonal to ``w``.

    Args:
        i: row index of the reference vector.
        test_features: tensor whose last dimension has the same size as ``w``;
            any number of leading dimensions is accepted.
        letter: selects ``base_mean`` ('A') or ``base_meanB`` (anything else);
            ignored when ``random`` is true.
        random: use ``base_random`` instead of the class means.

    Returns:
        Tensor with the same shape as ``test_features``.
    """
    if random:
        w = base_random[i]
    elif letter == 'A':
        # one projection per base class (64 classes on miniImageNet)
        w = base_mean[i]
    else:
        w = base_meanB[i]
    # Coefficient of each feature vector along w, normalised by ||w||^2.
    coef = torch.matmul(test_features, w) / torch.norm(w) ** 2
    # Broadcasting the trailing singleton axis against w replaces the former
    # try/except around repeat(...), which hard-coded a feature size of 640 and
    # used a bare except to switch between 2-D and 3-D inputs.
    return test_features - coef.unsqueeze(-1) * w

In [None]:
# Locations of the miniImageNet split files and of the image folder.
# NOTE(review): these are absolute paths on one machine; adjust them before
# running the notebook elsewhere.
filenametrain = '/home/r21lafar/Documents/dataset/miniimagenetimages/train.csv'
filenametest = '/home/r21lafar/Documents/dataset/miniimagenetimages/test.csv'
# `directory` is concatenated directly with a file name, so it must end with '/'.
directory = '/home/r21lafar/Documents/dataset/miniimagenetimages/images/'
def opencsv(filename):
    """Read a CSV file, print its header row and return the remaining rows.

    Args:
        filename: path of the CSV file.

    Returns:
        List of rows (each a list of strings) following the header.  An empty
        file yields an empty list and prints nothing.
    """
    # The context manager closes the file (the handle used to leak), and
    # newline='' is the mode the csv module documents for reader input.
    with open(filename, newline='') as file:
        csvreader = csv.reader(file)
        header = next(csvreader, None)
        if header is not None:
            print(header)
        return list(csvreader)
# Load the rows of both split files (each call also prints that file's header).
# `openimg` reads the first column of these rows as the image file name.
test = opencsv(filenametest)
train = opencsv(filenametrain)
def openimg(cl,title):
    """Show one randomly picked image of class ``cl`` in a new 5x5 figure.

    Class ids below 64 are looked up in the ``train`` rows; ids from 80 upwards
    are looked up in the ``test`` rows after subtracting 80.  The row is chosen
    as ``int((cl + 0.5) * 600)`` plus a random offset in [-100, 100).
    # assumes each class occupies 600 consecutive CSV rows -- TODO confirm

    Args:
        cl: integer class id (Python or numpy integer).  Any other type is
            ignored and nothing is drawn.
        title: title of the figure.

    Raises:
        ValueError: if ``cl`` is an integer in [64, 80); no rows are loaded for
            those ids (this used to surface as an UnboundLocalError).
    """
    if not isinstance(cl, (int, np.integer)):
        # Non-integer ids are silently ignored; numpy integers are now accepted.
        return
    if cl < 64:
        src = train
    elif cl >= 80:
        src = test
        cl -= 80
    else:
        raise ValueError('openimg: class ' + str(cl) + ' is neither a train (<64) nor a test (>=80) class')
    plt.figure(figsize=(5,5))
    row = int((cl + 0.5) * 600) + np.random.randint(-100, 100)
    filename = src[row][0]
    # Convert to an array inside the context manager so the file is closed.
    with Image.open(directory + filename) as im:
        pixels = np.array(im)
    plt.title(title)
    plt.imshow(pixels)

In [None]:
def distance_from_base(proj,run,plot=False,letter='A'):
    """Pairwise distances between few-shot prototypes and the base-class means.

    The feature tensor is ``feat`` when ``letter == 'A'`` and ``featB``
    otherwise; the base means are ``base_mean`` / ``base_meanB`` accordingly.

    * When ``proj == -1 and run == -1`` the prototypes are the means over
      axis 1 of the last 20 entries of the feature tensor.
    * Otherwise ``acc[0, 0, run]`` indexes the first axis of the feature
      tensor, ``idx[0, run]`` gathers along axis 1, and the prototypes are the
      means of the first ``shot`` gathered elements.
    * Whenever ``proj != 0`` the prototypes go through
      ``proj_class(proj - 1, ...)`` first (this includes ``proj == -1``).

    NOTE(review): ``acc`` and ``idx`` are the backbone-A tensors even when
    ``letter`` is not 'A' -- confirm that both backbones share the same runs.

    Args:
        proj: 0 for no projection, otherwise ``proj - 1`` is handed to
            ``proj_class``.
        run: run index into ``acc`` / ``idx``.
        plot: draw the distance matrix as an image when true.
        letter: backbone selector.

    Returns:
        ``torch.cdist`` of the prototypes against the base means.
    """
    use_a = letter == 'A'
    features = feat if use_a else featB
    fs_run = features[acc[0,0,run].long()]
    if proj == -1 and run == -1:
        proto_fs = features[-20:].mean(1)
    else:
        # Expand the per-run indices over the (hard-coded) 640 feature columns.
        gather_index = idx[0,run].unsqueeze(-1).repeat(1,1,640).long()
        fs_run = torch.gather(fs_run, dim=1, index=gather_index)
        proto_fs = fs_run[:, :shot].mean(1)
    if proj != 0:
        proto_fs = proj_class(proj-1, proto_fs, letter=letter)
    D = torch.cdist(proto_fs, base_mean if use_a else base_meanB)
    if not plot:
        return D
    plt.figure()
    plt.imshow(D.detach().numpy(), aspect='auto')
    plt.colorbar()
    plt.title('distance between FS class mean and base class '+letter+' mean \n (whole base dataset) projection ' +str(proj) + ' (0 is no projection)')
    plt.xlabel('64 base class mean')
    plt.ylabel('FS prototype of class')
    return D

## Create few-shot (FS) scenarios, or runs
### 2 ways

In [None]:
# Few-shot evaluation settings read as globals by ncm / generate_runs / define_runs.
n_runs = 500               # total number of few-shot runs
batch_few_shot_runs = 10   # number of runs processed together in one batch
n_ways = 5                 # number of classes per run
def ncm(train_features, features, run_classes, run_indices, n_shots,i_proj):
    """Nearest-class-mean classification over all few-shot runs.

    Runs are processed in ``n_runs // batch_few_shot_runs`` batches obtained
    from ``generate_runs``.  In every run the first ``n_shots`` samples of each
    class are averaged into a class mean, and each remaining sample is assigned
    to the class whose mean is closest in Euclidean distance.

    Args:
        train_features: accepted for interface compatibility; not used.
        features: tensor whose third dimension is the feature size.
        run_classes: per-run class ids, forwarded to ``generate_runs``.
        run_indices: per-run sample ids, forwarded to ``generate_runs``.
        n_shots: number of leading samples per class used to build the mean.
        i_proj: accepted for interface compatibility; not used.

    Returns:
        ``(full_accuracy, full_mean)``: a boolean tensor of shape
        (runs, n_ways, queries) telling whether each query was classified
        correctly, and the class means of shape (runs, n_ways, dim).

    Raises:
        ValueError: if ``n_runs`` is smaller than ``batch_few_shot_runs``, so
            that no batch would be evaluated.
    """
    # Kept on purpose: the predictions of the last batch stay readable at
    # module level, as before.
    global winners
    n_batches = n_runs // batch_few_shot_runs
    if n_batches == 0:
        raise ValueError('n_runs must be at least batch_few_shot_runs')
    with torch.no_grad():
        dim = features.shape[2]
        targets = torch.arange(n_ways).unsqueeze(1).unsqueeze(0)
        batch_accuracies = []
        batch_means = []
        for batch_idx in range(n_batches):
            runs = generate_runs(features, run_classes, run_indices, batch_idx)
            means = torch.mean(runs[:,:,:n_shots], dim = 2)
            # queries: (batch, true class, 1, query, dim); means: (batch, 1, candidate class, 1, dim)
            queries = runs[:,:,n_shots:].reshape(batch_few_shot_runs, n_ways, 1, -1, dim)
            candidates = means.reshape(batch_few_shot_runs, 1, n_ways, 1, dim)
            distances = torch.norm(queries - candidates, dim = 4, p = 2)
            winners = torch.min(distances, dim = 2)[1]
            batch_accuracies.append(winners == targets)
            batch_means.append(means)
        # Concatenate once instead of re-copying the accumulated result per batch.
        return torch.cat(batch_accuracies, dim=0), torch.cat(batch_means, dim=0)

    
def generate_runs(data, run_classes, run_indices, batch_idx):
    """Materialise the samples of one batch of few-shot runs.

    Args:
        data: tensor indexed as ``data[class, sample, feature]``.
        run_classes: integer tensor (runs, ways) of class ids.
        run_indices: integer tensor (runs, ways, samples) of sample ids within
            the corresponding class.
        batch_idx: which block of ``batch_few_shot_runs`` consecutive runs to
            build.

    Returns:
        Tensor of shape (runs in batch, ways, samples, feature) where entry
        ``[b, w, s]`` equals ``data[run_classes[b, w], run_indices[b, w, s]]``.
    """
    start = batch_idx * batch_few_shot_runs
    stop = start + batch_few_shot_runs
    batch_classes = run_classes[start:stop].long()
    batch_indices = run_indices[start:stop].long()
    # Advanced indexing broadcasts (runs, ways, 1) against (runs, ways, samples).
    # It yields the same result as the former double gather without first
    # copying `data` batch_few_shot_runs times.
    return data[batch_classes.unsqueeze(-1), batch_indices]

def define_runs(n_ways, n_shots, n_queries, num_classes, elements_per_class):
    """Draw the classes and sample ids of ``n_runs`` random few-shot runs.

    For every run, ``n_ways`` distinct classes are drawn from
    ``range(num_classes)``; for every drawn class, ``n_shots + n_queries``
    distinct sample ids are drawn from ``range(elements_per_class[class])``.
    The number of runs comes from the module-level ``n_runs``.

    Args:
        n_ways: classes per run.
        n_shots: support samples per class.
        n_queries: query samples per class.
        num_classes: number of classes to draw from.
        elements_per_class: sequence giving the number of samples of each class.

    Returns:
        ``(run_classes, run_indices)``: int64 tensors of shape
        (n_runs, n_ways) and (n_runs, n_ways, n_shots + n_queries).

    Raises:
        ValueError: if a drawn class has fewer than ``n_shots + n_queries``
            samples.
    """
    n_samples = n_shots + n_queries
    # torch.empty with an explicit dtype replaces the legacy LongTensor(sizes)
    # constructor; every entry is overwritten below.
    run_classes = torch.empty((n_runs, n_ways), dtype=torch.long)
    run_indices = torch.empty((n_runs, n_ways, n_samples), dtype=torch.long)
    for i in range(n_runs):
        run_classes[i] = torch.randperm(num_classes)[:n_ways]
        for j in range(n_ways):
            available = elements_per_class[run_classes[i, j]]
            if available < n_samples:
                raise ValueError('class ' + str(int(run_classes[i, j])) + ' has fewer than '
                                 + str(n_samples) + ' elements')
            run_indices[i, j] = torch.randperm(available)[:n_samples]
    return run_classes, run_indices

In [None]:
# Draw the runs: 5 shots and 500 queries per class, 20 candidate classes with
# 600 elements each (argument order follows the define_runs signature).
run_classes, run_indices = define_runs(n_ways, 5, 500, 20, [600] * 20)

In [None]:
# Evaluate every run once without projection (i == 0) and once per random
# vector (i >= 1 removes the component along base_random[i-1]).
# NOTE(review): with range(64) only base_random[0..62] are used; the last random
# vector is never projected out.
# The backbone-B variant would repeat the same steps with featB and letter 'B'.
accuracy_per_projection = []
mean_per_projection = []
for i in tqdm(range(64)):
    feature = feat if i == 0 else proj_class(i-1, feat, 'A', random=True)
    A, meanA = ncm(feature[:64], feature[-20:], run_classes, run_indices, 5, 0)
    accuracy_per_projection.append(A)
    mean_per_projection.append(meanA)
# Stack once at the end; axis 0 indexes baseline + projections.
fullA = torch.stack(accuracy_per_projection, dim=0)
fullmeanA = torch.stack(mean_per_projection, dim=0)

In [None]:
def what_proj(run):
    """Rank the projections of one run from lowest to highest accuracy.

    Args:
        run: index along the second axis of ``fullA``.

    Returns:
        Integer tensor holding the indices along the first axis of ``fullA``
        sorted by ascending mean accuracy, shifted by -1 (so -1 denotes the
        first entry, i.e. no projection).
    """
    accuracy_per_projection = fullA[:, run].float().mean(-1).mean(-1)
    ranking = accuracy_per_projection.argsort()
    return ranking - 1

In [None]:
%matplotlib inline
run = 0
mk_size=4
plt.figure()
plt.plot(fullA[:,run].float().mean(-1).mean(-1),'.')
plt.hlines(y=fullA[0,run].float().mean(),xmin = 0 ,xmax =64
           ,label='baseline no proj')
plt.xlabel('projection (0 is no projection)')
plt.ylabel('accuracy')



In [None]:
# Overall accuracy of the first entry of fullA (no projection) versus the
# accuracy pooled over all other entries (all projections); the cell displays
# their difference.
baseline = fullA[0].float().mean()
projected = fullA[1:].float().mean()
projected-baseline

In [None]:
# Inspect the layout of the stacked accuracy tensor (axis 0: baseline + projections).
fullA.shape

In [None]:
# Accuracy of every (projection, run) pair, averaged over the two trailing axes.
per_run_accuracy = fullA.float().mean(-1).mean(-1)
# Entry 0 is the run without projection.
baseline = per_run_accuracy[0]
# For each run, the best accuracy reached by any projection (values and indices).
best_acc = per_run_accuracy[1:].max(dim = 0)
best_boost = best_acc.values - baseline

In [None]:
# Distribution over runs of the best accuracy gain obtained by any projection.
plt.hist(best_boost.detach().numpy(),bins=20)
plt.xlabel('best boost')
plt.ylabel('frequency')
# The counts are read from fullA: its first entry is the unprojected baseline,
# so the number of random vectors actually evaluated is fullA.shape[0] - 1
# (the former hard-coded title overstated it).
plt.title(f'{fullA.shape[0] - 1} random vectors {fullA.shape[1]} runs');

In [None]:
# Norm of every row of base_mean, with the average norm as a horizontal line.
base_norms = torch.norm(base_mean, dim= 1)
plt.plot(base_norms,'.')
plt.hlines(y=base_norms.mean(),xmin=0,xmax=64)

In [None]:
# Pairwise cosine distances between the rows of base_mean.
dist_base_base = cosine_distances(base_mean,base_mean)
plt.imshow(dist_base_base)
plt.colorbar()

In [None]:
# Cosine distances between the rows of base_mean and the random vectors.
dist_base_random = cosine_distances(base_mean,base_random)
plt.imshow(dist_base_random)
plt.colorbar()

In [None]:
# Pairwise cosine distances between the random vectors.
dist_random_random = cosine_distances(base_random,base_random)
plt.imshow(dist_random_random)
plt.colorbar()

# Analysis shot by shot

In [None]:
%matplotlib inline
run = 0
mk_size=4
plt.figure()
plt.plot(fullA[:,run].float().mean(-1).mean(-1),'.')
plt.hlines(y=fullA[0,run].float().mean(),xmin = 0 ,xmax =64
           ,label='baseline no proj')
plt.xlabel('projection (0 is no projection)')
plt.ylabel('accuracy')



In [None]:
# Shape of the sample-id tensor returned by define_runs.
run_indices.shape

In [None]:
# Shape of the class-id tensor returned by define_runs.
run_classes.shape

In [None]:
run = 0
# run_classes holds ids in range(20) and was evaluated against feature[-20:] in
# the ncm call, so the same slice must be indexed here; passing the whole `feat`
# would pick the first 20 classes instead.
featb1 = generate_runs(feat[-20:], run_classes, run_indices, 0)
# First 5 classes x first 5 samples of the run, flattened to one row per sample.
feature = featb1[run,:5,:5].reshape(-1,640)
plt.figure()
plt.imshow(cosine_distances(feature, base_random))
plt.colorbar()

# Accuracy of the same run for every entry of fullA (a stray empty
# plt.figure() that produced a blank figure was removed).
plt.figure()
plt.plot(fullA[:,run].float().mean(-1).mean(-1),'.')
plt.hlines(y=fullA[0,run].float().mean(),xmin = 0 ,xmax =64
           ,label='baseline no proj')
plt.xlabel('projection (0 is no projection)')
plt.ylabel('accuracy')
# Without a legend the label given to hlines is never displayed.
plt.legend();

In [None]:
run = 4
# run_classes holds ids in range(20) and was evaluated against feature[-20:] in
# the ncm call, so the same slice must be indexed here.
featb1 = generate_runs(feat[-20:], run_classes, run_indices, 0)
# First 5 classes x first 5 samples of the run, flattened to one row per sample.
feature = featb1[run,:5,:5].reshape(-1,640)

# Mean (not sum) over the samples, matching the axis label below.
cs = cosine_distances(feature, base_random).mean(0)
# fullA[0] is the unprojected baseline and fullA[k] was computed with
# base_random[k-1], so random vector k pairs with fullA[k+1]; vectors that were
# never evaluated are dropped.
n_projections = fullA.shape[0] - 1
plt.figure()

plt.plot(cs[:n_projections],fullA[1:,run].float().mean(-1).mean(-1),'.')

plt.xlabel('mean cosine distance between feature and vector')
plt.ylabel('accuracy');
