In [2]:
import torch
import os
import pickle
import glob
from PIL import Image
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt 
import cv2
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
%matplotlib inline

In [2]:
def init_protonet():
    '''
    Build a ProtoNet instance and move it to the GPU when one is available.

    Returns:
    model: the ProtoNet instance placed on 'cuda:0' if CUDA is available, otherwise on 'cpu'

    NOTE(review): `ProtoNet` is not imported in the visible part of this notebook
    (only `EncoderProtoNet` is) -- confirm it is defined before calling this.
    '''
    if torch.cuda.is_available():
        target_device = 'cuda:0'
    else:
        target_device = 'cpu'
    return ProtoNet().to(target_device)
def euclidean_dist(x, y):
    '''
    Return the squared Euclidean distance between two tensors as a plain Python number.

    The tensors are subtracted element-wise (normal broadcasting applies), every
    difference is squared and all squares are summed into a single scalar.
    No square root is taken.
    '''
    delta = x - y
    squared_sum = (delta ** 2).sum()
    return squared_sum.detach().cpu().numpy().tolist()

In [5]:
def get_classes_paths(paths):
    '''
    Group paths by a zero-based class index parsed from each path.

    The index is obtained by taking the text before the first '.' in the path,
    keeping its basename, converting it to int and subtracting 1.

    Arguments:
    paths: iterable of path strings
    Returns:
    dict mapping class index -> list of paths, in the order they were given
    '''
    grouped = {}
    for p in paths:
        prefix = p.split('.')[0]
        class_idx = int(os.path.basename(prefix)) - 1
        grouped.setdefault(class_idx, []).append(p)
    return grouped
def load_dict(filename_):
    '''
    Read one pickled object from `filename_` and return it.

    NOTE: unpickling can execute arbitrary code -- only load files you trust.
    '''
    with open(filename_, 'rb') as handle:
        return pickle.load(handle)
def embed_images(paths,model):
    '''
    Load saved tensors from disk, stack them into one batch and run `model` on it.

    Arguments:
    paths: list of files readable by torch.load; the loaded tensors are combined
           with torch.stack, so they must all have the same shape
    model: callable (typically a torch.nn.Module) applied to the stacked batch
    Returns:
    emb_vectors: the model output, detached from autograd and moved to the CPU
    '''
    # Load onto the CPU first so tensors that were saved from a GPU still load
    # on a machine without one; the batch is moved to the right device below.
    images_tensors = [torch.load(path, map_location='cpu') for path in paths]
    batch = torch.stack(images_tensors)

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None) if hasattr(model, 'parameters') else None
    if first_param is not None:
        device = first_param.device
    else:
        # Parameter-free callable: keep the previous "use the GPU" behaviour when possible.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The result is detached anyway, so skip building the autograd graph (saves memory).
    with torch.no_grad():
        emb_vectors = model(batch.to(device))
    return emb_vectors.cpu().detach()
def get_distances_embedding(emb_vectors,ref):
    '''
    Measure every embedding against two reference vectors.

    Arguments:
    emb_vectors: iterable of embedding vectors
    ref: indexable whose first two items are the reference vectors
    Returns:
    (ds1, ds2): two numpy arrays holding, per embedding, the value of
                euclidean_dist to ref[0] and to ref[1] respectively
    '''
    first_ref, second_ref = ref[0], ref[1]
    pairs = [
        (euclidean_dist(vec, first_ref), euclidean_dist(vec, second_ref))
        for vec in emb_vectors
    ]
    to_first = np.array([pair[0] for pair in pairs])
    to_second = np.array([pair[1] for pair in pairs])
    return to_first, to_second

In [4]:
def random_image_display(files_paths):
    '''
    Show one randomly chosen image from `files_paths` and print its array shape.

    Arguments:
    files_paths: non-empty sequence of image file paths readable by cv2.imread
    Raises:
    ValueError: if `files_paths` is empty
    OSError: if the chosen file cannot be read as an image
    '''
    if len(files_paths) == 0:
        raise ValueError('files_paths is empty')
    # randint's upper bound is exclusive, so use len() itself: every file
    # (including the last one) can be drawn and a single-element list works.
    rand_idx = np.random.randint(len(files_paths))
    img = cv2.imread(files_paths[rand_idx])
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError('could not read image: ' + str(files_paths[rand_idx]))
    print(img.shape)
    # OpenCV loads colour images as BGR while matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
def convert_paths_to_jpg(paths):
    '''
    Replace the file extension of every path with '.jpg'.

    Arguments:
    paths: iterable of path strings
    Returns:
    list of paths with the last extension (as found by os.path.splitext) swapped for '.jpg'
    '''
    return [os.path.splitext(p)[0]+'.jpg' for p in paths]

In [31]:
def get_files(filepath,expression='*.json'):
    '''
    Walks over a directory and all of its sub-directories and collects the
    paths whose names match a glob expression.
    Arguments:
    filepath: string that specifies the path to the data parent directory
    expression: glob pattern applied inside every visited directory (default '*.json')
    Returns:
    all_files: List of absolute paths of all entries matching the expression
    '''
    all_files = []
    for root, _dirs, _files in os.walk(filepath):
        # Escape the directory part so that folder names containing glob
        # metacharacters ('[', ']', '*', '?') are matched literally; only
        # `expression` is interpreted as a pattern.
        pattern = os.path.join(glob.escape(root), expression)
        all_files.extend(os.path.abspath(match) for match in glob.glob(pattern))
    return all_files
def save_dict(di_, filename_):
    '''
    Pickle `di_` and write it to `filename_`, overwriting any existing file.
    '''
    serialized = pickle.dumps(di_)
    with open(filename_, 'wb') as handle:
        handle.write(serialized)

In [6]:
# Move two levels up so the relative paths used below ('Data', 'checkpoints',
# 'configs', the `src` package) resolve from the project root.
# Plain os.chdir is used instead of the `cd` automagic with a backslash path so
# the cell also works outside Windows and when automagic is disabled.
# NOTE: this is relative to the current directory -- run it once per kernel;
# re-running it climbs two more levels.
os.chdir(os.path.join(os.pardir, os.pardir))
os.getcwd()


E:\CVprojects\Butterflies


In [8]:
from src.models.EncoderProtoNet import EncoderProtoNet

In [11]:
# Use the GPU when there is one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EncoderProtoNet(proto_x_dim=128)
weights_path = os.path.join('checkpoints','best_model_95val_82tr.pth')
# map_location lets a checkpoint that was saved from a GPU be loaded on a
# machine without one.
model.load_state_dict(torch.load(weights_path, map_location=device))
model = model.to(device)
# Inference only: switch to evaluation mode (also displays the module summary).
model.eval()

EncoderProtoNet(
  (encoder): SubResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True

In [12]:
# Collect the absolute paths of every '*sub.pt' file found under 'Data'
# (searched recursively, relative to the current working directory).
files_paths = get_files('Data','*sub.pt')

In [13]:
# Sanity check: how many matching files were found.
len(files_paths)

25279

In [14]:
# Group the files by the zero-based class index that get_classes_paths parses
# from each path.
classes_to_paths = get_classes_paths(files_paths)

In [15]:
# Release cached GPU memory held by PyTorch's allocator before the embedding loop.
torch.cuda.empty_cache()

In [16]:
# Cluster the embeddings of each class separately and bucket the file paths by
# the cluster label they received.
max_clust_num = 3            # number of k-means clusters fitted per class
num_classes = 200            # classes are indexed 0 .. num_classes-1
max_embed_per_class = 100    # only the first N files of a class are embedded
clusters_paths = {cluster_id: [] for cluster_id in range(max_clust_num)}
torch.backends.cudnn.enabled = False
for class_idx in tqdm(range(num_classes)):
    paths = classes_to_paths[class_idx][:max_embed_per_class]
    embs = embed_images(paths,model)
    torch.cuda.empty_cache()
    # A fresh KMeans is fitted for every class, so cluster id k of one class is
    # unrelated to cluster id k of another class.
    kmeans = KMeans(n_clusters=max_clust_num, random_state=0).fit(embs)
    # Distinct loop names: the original reused `i` here, shadowing the class index.
    for sample_idx, cluster_id in enumerate(kmeans.labels_):
        clusters_paths[cluster_id].append(paths[sample_idx])
    

100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [11:03<00:00,  3.32s/it]


In [17]:
# `paths` is left over from the last iteration of the clustering loop, so this
# shows how many files were embedded for the final class.
len(paths)

64

In [18]:
# ks[j] maps class index -> paths of that class that fell into cluster j.
ks = [get_classes_paths(clusters_paths[j]) for j in range(max_clust_num)]

In [19]:
# Cap every class at `max_images_per_class` samples while spreading the budget
# as evenly as possible over its clusters.
max_images_per_class = 100
num_classes = 200
# Copy each per-cluster dict so trimming does not modify `ks` itself
# (a plain `new_ks = ks` would only alias the same objects).
new_ks = [dict(cluster_classes) for cluster_classes in ks]
# Seeds numpy's global RNG (used by later np.random calls in this notebook).
np.random.seed(0)
for i in range(num_classes):
    # Number of samples this class has in every cluster; a class that is
    # absent from a cluster counts as 0 instead of raising KeyError.
    ls = [len(ks[j].get(i, [])) for j in range(max_clust_num)]
    cur_max_images_per_class = max_images_per_class
    cur_max_clust_num = max_clust_num

    # Visit the clusters from the fewest samples to the most, so that quota a
    # small cluster cannot fill is handed on to the larger ones.
    for idx in np.argsort(ls):
        cur_required_samples_per_clust = cur_max_images_per_class // cur_max_clust_num
        num_samples = min(cur_required_samples_per_clust, ls[idx])
        new_ks[idx][i] = ks[idx].get(i, [])[:num_samples]

        cur_max_images_per_class -= num_samples
        cur_max_clust_num -= 1
        
        

In [20]:
# Split every (cluster, class) group into train / val / test so that each
# cluster of a class contributes to all three splits.
val_rat = 0.25
test_rat = 0.25
max_samples_per_class = 90
train_paths = []
val_paths = []
test_paths = []
slow_test_paths = []
for i in range(max_clust_num):
    classes_samples = new_ks[i]
    for key in classes_samples:
        cur_paths = np.array(classes_samples[key])
        l = len(cur_paths)
        l2 = min(l,max_samples_per_class)
        num_val = int(l2*val_rat)
        # Use the test ratio here (the original reused val_rat, which only
        # worked because both ratios happen to be 0.25).
        num_test = int(l2*test_rat)

        # Random order of the group's samples (np.random global state).
        all_indeces = np.arange(l)
        np.random.shuffle(all_indeces)

        # Anything beyond max_samples_per_class is kept aside for the
        # "slow test" list and never enters train/val/test.
        if len(all_indeces)>max_samples_per_class:
            indeces_slow_test = all_indeces[max_samples_per_class:]
            all_indeces_down = all_indeces[:max_samples_per_class]
        else:
            indeces_slow_test = []
            all_indeces_down = all_indeces

        val_indeces = all_indeces_down[:num_val]
        test_indeces = all_indeces_down[num_val:num_val+num_test]
        train_indeces = all_indeces_down[num_val+num_test:]

        train_paths.extend(cur_paths[train_indeces])
        val_paths.extend(cur_paths[val_indeces])
        test_paths.extend(cur_paths[test_indeces])
        # slow test = regular test samples + the overflow samples
        slow_test_paths.extend(cur_paths[test_indeces])
        slow_test_paths.extend(cur_paths[indeces_slow_test])

In [21]:
# Absolute prefix (with trailing separator) of the class folders, derived from
# the current working directory instead of a hard-coded local drive path.
# With the project root as cwd this yields the same string as before; it also
# avoids the invalid '\D' escape sequence the literal contained.
base = os.path.join(os.path.abspath('Data'), 'images_small') + os.sep
# Hand-designed splits for a few classes.
# Structure: split name -> {class folder path: list of image numbers}.
# The numbers are later zero-padded to three digits and suffixed with 'sub.pt'
# to form file names inside the class folder.
designed_splits = {
    'train':{
    base+'110.Chitoria_ulupi':[1,2,3,4,7,10,13,17,20,27,30,31,36],
        base+'113.Sasakia_charonda':[1,4,7,10,13,16,19,20,22,25,28,31,33,34,37,39,42,44,48,53,61,63,68,70,72,76,79,81],
        base+'133.Damora_sagana':[2,4,1,10,16,17,19,21,24,27,30,33,38,42,46,49,50,52,54,55],
        base+'006.Graphium_agamemnon':[1,2,5,8,9,10,12,18,22,26,28,31,34,37,40,42,43,45,49,55,58,65],
        base+'175.Rapala_nissa':[1,2,3,4,9,10,11,12,14,17,19,23,24,27,32,36,38,39,40,41,42,45,50,52,54,57,64,65,66,67],
        base+'014.Meandrusa_sciron':[1,2,3,5,6,10,11,12,13,14,15,18,19,21,24,25,26,29,34,35,36,39,43],
        base+'118.Euthalia_niepelti':[1,2,3,5,6,8,11,13,16,18,22,24,27,28,29,30,32],
        base+'142.Doleschallia_bisaltide':[1,4,6,8,9,10,11,15,17,19,21,24,26,29,31,33,35,45,47,48,51,54,57],
        base+'155.Libythea_myrrha':[1,2,5,7,10,11,14,15,18,19,20,21,25,27,31,32,37,39,40,42,46,49,50,53,58,59,62,63,66,69,73,76,79,80],
        base+'174.Mahathala_ameria':[1,2,3,4,5,6,12,14,19,23,24,25,29,31,32,34,37,38,40,43,44,45,46,47],
        base+'179.Zizeeria_maha':[1,2,3,4,6,7,11,14,17,18,21,25,26,28,29,30,34,37,38,39,46,51,52,53,54,59,60,61,64,66,69,72,73,74,75,76,78,79,80,84],
        base+'166.Arhopala_rama':[1,2,4,6,7,10,11,15,16,18,22,25,26,27,28,29,32,36,37,38,41,44]
        
    },
    'val':{
    base+'110.Chitoria_ulupi':[5,8,11,12,18,21,25,28,32],
        base+'113.Sasakia_charonda':[2,5,8,11,14,17,21,23,26,29,32,35,38,40,43,46,49,55,56,59,62,64,73,75,80,83,84],
        # NOTE(review): 34 is listed twice below -- possibly a typo for another number; verify.
        base+'133.Damora_sagana':[3,5,6,11,18,20,22,28,32,34,37,39,41,34,44,51,57,58,59],
        base+'006.Graphium_agamemnon':[3,6,11,14,16,19,21,25,29,33,35,38,41,44,46,52,53,54,59,61],
        base+'175.Rapala_nissa':[5,7,13,15,22,25,28,30,31,34,37,46,47,51,53,55,60,68,70],
        base+'014.Meandrusa_sciron':[7,16,20,22,27,33,37,42,44,45],
        base+'118.Euthalia_niepelti':[4,7,9,10,14,19,23,25,31],
        base+'142.Doleschallia_bisaltide':[2,5,14,18,20,25,27,30,34,37,39,41,43,46,52,55,58,60],
        base+'155.Libythea_myrrha':[3,6,9,12,16,22,26,30,33,36,38,44,47,48,52,55,70,71,72,74,77,81],
        base+'174.Mahathala_ameria':[7,8,10,20,22,26,28,33,35,41,48,50,],
        base+'179.Zizeeria_maha':[5,9,10,15,16,19,22,27,33,35,41,43,47,50,55,57,62,67,77,81,83],
        # NOTE(review): image 1 below is also in this class's 'train' list (train/val overlap); verify.
        base+'166.Arhopala_rama':[3,12,17,1,21,24,30,33,34,39,42,45]
    },
    'test':{
    base+'110.Chitoria_ulupi':[6,9,15,19,22,23,24,29,35],
        base+'113.Sasakia_charonda':[3,6,9,12,15,18,24,27,30,36,41,47,50,51,54,57,60,65,69,71,74,77,82,85],
        base+'133.Damora_sagana':[7,12,14,23,25,29,31,35,36,40,45,48,53,56],
        base+'006.Graphium_agamemnon':[4,7,13,15,17,20,23,24,30,32,36,39,47,48,50,51,56,62,63,64,70],
        base+'175.Rapala_nissa':[6,8,16,18,20,26,29,33,35,43,44,49,56,58,59,61,62,69,71],
        base+'014.Meandrusa_sciron':[9,17,23,28,30,32,40,46,47,48],
        base+'118.Euthalia_niepelti':[12,15,17,20,21,26],
        # NOTE(review): 65 sits between 53 and 59 in an otherwise ascending list -- possibly a typo; verify.
        base+'142.Doleschallia_bisaltide':[3,7,16,22,23,28,32,36,38,40,42,44,49,50,53,65,59,61],
        base+'155.Libythea_myrrha':[4,8,13,17,23,28,34,35,41,43,45,51,60,61,64,65,67,68,69,75,78],
        base+'174.Mahathala_ameria':[9,11,13,21,27,30,36,42,49,51,54,56,57],
        base+'179.Zizeeria_maha':[8,12,13,20,23,24,31,32,36,42,44,45,48,49,56,58,63,65,68,71,82],
        base+'166.Arhopala_rama':[13,14,19,20,23,31,35,40,43]
        
    }
}

In [22]:
def format_number(num,suffix = ''):
    '''
    Render `num` as text left-padded with '0' to at least three characters,
    followed by `suffix`.
    '''
    padded = str(num).rjust(3, '0')
    return padded + suffix

format_number(122)

'122'

In [23]:
def _designed_split_paths(split_name):
    '''Expand designed_splits[split_name] into a flat list of '<class folder>/<NNN>sub.pt' paths.'''
    return [
        os.path.join(path, format_number(num, 'sub.pt'))
        for path, nums in designed_splits[split_name].items()
        for num in nums
    ]

# Each list is built from its OWN split. The original val/test lines read the
# image numbers from designed_splits['train'], so the designed val and test
# sets were copies of the designed training set.
train_paths_designed = _designed_split_paths('train')
val_paths_designed = _designed_split_paths('val')
test_paths_designed = _designed_split_paths('test')


In [24]:
def get_class(path):
    '''
    Parse the zero-based class index from a path.

    Takes the text before the first '.' in `path`, keeps its basename,
    converts it to int and subtracts 1.
    '''
    before_first_dot = path.split('.')[0]
    class_number = int(os.path.basename(before_first_dot))
    return class_number - 1

In [25]:
# Class indices of the classes that have a hand-designed split
# (iterating the dict yields its class-folder keys).
design_classes = list(map(get_class, designed_splits['train']))
design_classes

[109, 112, 132, 5, 174, 13, 117, 141, 154, 173, 178, 165]

In [26]:
# Drop the automatically split samples of every class that has a hand-designed split.
train_paths = list(filter(lambda p: get_class(p) not in design_classes, train_paths))
val_paths = list(filter(lambda p: get_class(p) not in design_classes, val_paths))
test_paths = list(filter(lambda p: get_class(p) not in design_classes, test_paths))


In [27]:
# Append the hand-designed samples to each split (in place).
# NOTE: re-running this cell appends them again -- rebuild the lists first.
train_paths.extend(train_paths_designed)
val_paths.extend(val_paths_designed)
test_paths.extend(test_paths_designed)

In [28]:
# Store every path relative to the current working directory.
train_paths, val_paths, test_paths = (
    [os.path.relpath(p, '.') for p in split_paths]
    for split_paths in (train_paths, val_paths, test_paths)
)

In [29]:
# Peek at one entry to check the relative-path format.
train_paths[0]

'Data\\images_small\\001.Atrophaneura_horishanus\\015sub.pt'

In [30]:
# Bundle the three splits and pickle them under configs/splits/.
split_dict = dict(train=train_paths, val=val_paths, test=test_paths)
delim = os.sep
split_dict_name = os.path.join('configs', 'splits', "split_dict_hybrid_clust.pkl")
save_dict(split_dict,split_dict_name)


In [40]:
# cd ..\..

In [41]:
# delim = os.sep
# d = load_dict('configs'+delim+'splits'+delim+"split_dict_hybrid_clust.pkl")
# l = len('sub.pt')
# len(d['train'])+len(d['val'])+len(d['test'])
# for key in d:
#     d[key] = [x[:-l]+'.jpg'for x in d[key]]
# save_dict(d,'configs'+delim+'splits'+delim+"split_dict_hybrid_clust.pkl")