In [None]:
import torch 
import random
from random import randint
from torch.nn import functional as F


# Public export list for `from <module> import *`.
# NOTE(review): no `distribution_matching` is defined in the visible code. If it
# is not defined elsewhere in this module, a star-import will raise
# AttributeError -- confirm the intended exported name.
__all__ = ['distribution_matching']

def compute_metric_class_similarity_score(candidate_feature, buffer_features ):
    """
    Return the smallest cosine similarity between one candidate feature vector
    and every feature vector stored in the buffer for the same class.

    candidate_feature: 1-D feature vector (128 dimensional in the current setup)
    buffer_features: NxD matrix holding all buffered features of the targeted class

    Both inputs are cast to float32 before the comparison. The result is a
    0-dim tensor; an empty buffer (N == 0) is not supported because `min()`
    of an empty tensor raises.
    """
    num_buffered = buffer_features.shape[0]

    # Shape (N, D, 1): one copy of the candidate per buffered feature.
    query = candidate_feature.unsqueeze(0).unsqueeze(2).repeat(num_buffered, 1, 1).float()
    # Shape (N, D, 1): trailing singleton axis so both operands line up.
    reference = buffer_features.unsqueeze(2).float()

    # Reduce over the feature axis -> one similarity per buffered feature.
    similarities = F.cosine_similarity(query, reference, dim=1, eps=1e-6)
    return similarities.min()
    
def compute_image_similiarity_score(candidate_image_features, buffer_features_all):
    """
    Score how similar a candidate image is to the images already in the buffer.

    candidate_image_features: CxD tensor, one feature vector per class. A class
        whose feature vector sums to zero is treated as absent from the image.
    buffer_features_all: NBxCxD tensor with the per-class features of the NB
        buffered images. Rows that sum to zero are treated as absent as well.

    Returns a 0-dim float32 tensor: for every class present in the candidate,
    the minimum cosine similarity to the buffered features of that class is
    accumulated (classes missing from the buffer contribute 0) and the total is
    divided by the number of classes present in the candidate, so images with
    fewer classes are not favoured. Lower similarity is better. A candidate
    without any present class scores 0.
    """
    class_is_present = candidate_image_features.sum(dim=1) != 0
    indices_valid_features_can = torch.where(class_is_present)
    # Debug trace kept from the original implementation.
    print("TRES", indices_valid_features_can)
    present_classes = indices_valid_features_can[0].tolist()

    # Accumulate in a tensor on the candidate's device so the return type does
    # not depend on which branch was taken.
    image_score = candidate_image_features.new_zeros((), dtype=torch.float32)
    if not present_classes:
        # Nothing to compare; also avoids dividing by zero below.
        return image_score

    for cla in present_classes:
        features = buffer_features_all[:, cla, :]
        in_buffer = features.sum(dim=1) != 0
        if not in_buffer.any():
            # This class does not exist in the buffer so far: contributes 0.
            continue
        image_score = image_score + compute_metric_class_similarity_score(
            candidate_image_features[cla], features[in_buffer])

    # Normalize by the NUMBER of present classes. The original called `.sum()`
    # on the tuple returned by `torch.where`, which raises AttributeError.
    return image_score / len(present_classes)
    

def interclass_dissimilarity(latent_features, label_dist, K_return=50, iterations= 1000, early_stopping = 0.00001):
    """Work-in-progress buffer selection based on per-class feature similarity.

    The function draws a random initial selection of ``K_return`` samples and
    then, for a fixed number of trials, picks a random sample outside the
    selection and evaluates its image similarity score against the selected
    buffer. The score is currently not used to update the selection.

    Parameters
    ----------
    latent_features: torch.tensor NRxCx128
    label_dist : torch.tensor NRxC
    K_return : int, optional
            number of returned indices, by default 50
    iterations : int, optional
            currently unused
    early_stopping : float, optional
            currently unused

    Returns
    -------
    tuple
            placeholder ``(True, True)``; the selection mask and metric are not
            produced yet.
    """
    # Per-class share of the total label mass and the resulting per-class
    # quota when K_return samples are selected.
    label_probs = label_dist.sum(dim=0) / label_dist.sum(dim=[0,1])
    allowed_selection = torch.ceil( label_probs * K_return )

    # Classes with a non-zero quota, ordered by ascending quota.
    res = torch.argsort(allowed_selection, dim=-1, descending=False)
    appearing_classes = allowed_selection[res] != 0
    rising_indices_class_count = res[appearing_classes]
    rising_value_class_count = allowed_selection[rising_indices_class_count]

    print(rising_value_class_count)
    print(rising_indices_class_count)

    num_samples = latent_features.shape[0]
    init_selection = torch.randperm( num_samples )[:K_return]
    buffer_features_all = latent_features[init_selection,:,:]

    if num_samples <= K_return:
        # Every sample is already selected, so no replacement candidate exists;
        # the rejection loop below would never terminate.
        return True, True

    # NOTE(review): the trial count is hard-coded; `iterations` is presumably
    # meant to drive this loop once the swap logic is implemented -- confirm.
    for _ in range(10):
        swap_candidate = torch.randint(0,K_return, (1,))
        # Rejection-sample an index that is not part of the current selection.
        while True:
            replacement_candidate = torch.randint(0,num_samples,(1,))[0]
            if replacement_candidate not in init_selection:
                break

        print(swap_candidate, replacement_candidate)
        compute_image_similiarity_score( latent_features[replacement_candidate], buffer_features_all)

    return True, True

#     return selected, old_metric
# Runs the selection at cell level.
# NOTE(review): `latent_features` and `label_dist` are not assigned anywhere above
# this line in the visible code; they are loaded with torch.load in the cell
# further below, so this call presumably relies on that cell having been executed
# first -- confirm the intended execution order.
selected, metric = interclass_dissimilarity(latent_features, label_dist)


In [None]:
# Questions we have to answer:
"""
Should we use the pixelwise information -> We kind of do this already when we check how we sample it
-> Can we utilize the network to check if the selected sample is a good one?
-> Could we just maximize the intra class dissimilarity?


"""



# def test():
# Ad-hoc driver (stand-in for the commented-out `test()` function): loads stored
# tensors, runs `interclass_dissimilarity`, times it and plots class-count bars.
if True:
    from PIL import Image
    import os, sys
    
    import time
    
    # Make the project importable; paths are hard-coded to one machine.
    os.chdir('/home/jonfrey/ASL')
    sys.path.append('/home/jonfrey/ASL')
    sys.path.append('/home/jonfrey/ASL/src')
    from visu import Visualizer
    
    vis = Visualizer('/home/jonfrey/tmp', logger=None, epoch=0, store=True, num_classes=41)

    # Tensors dumped by an earlier run; presumably NRxC labels, global sample
    # indices and NRxCx128 latent features -- TODO confirm shapes.
    label_dist = torch.load( '/media/scratch1/jonfrey/models/master_thesis/dev/uncertainty_integration3/labels_tensor_0.pt')
    globale_indices = torch.load( '/media/scratch1/jonfrey/models/master_thesis/dev/uncertainty_integration3/indices_tensor_0.pt')
    latent_features = torch.load( '/media/scratch1/jonfrey/models/master_thesis/dev/uncertainty_integration3/latent_feature_tensor_0.pt')
    
    # Time the selection.
    st = time.time()
    selected, metric = interclass_dissimilarity(latent_features, label_dist)
    
    
    
    
    
    # NOTE(review): the `interclass_dissimilarity` visible in this file returns the
    # placeholder (True, True); the indexing and `selected.sum()` below presumably
    # expect a tensor mask -- confirm once the function returns real results.
    selected_globale_indices = globale_indices[selected]
    t =  time.time()-st
    print(selected.sum(), metric, 'Total time',t)                 
    # features = features_l.clone()
    # globale_indices = globale_indices_l.clone()
    # K_return=50
    # print(new_metric)
    
    # NOTE(review): `features` is not assigned in the visible code (only in the
    # commented-out line above); presumably `label_dist` is meant -- confirm.
    # Bar plot of the per-label counts over all samples.
    res = vis.plot_bar(features.sum(dim=0), x_label='Label', y_label='Count',
            sort=False, reverse=True, title= 'All', 
            tag=f'Pixelwise_Class_Count_Task', method='left',jupyter=True)
    # `display` is not imported here; presumably the IPython notebook built-in.
    display(Image.fromarray(res))


    # Same plot restricted to the selected samples.
    res = vis.plot_bar(features[selected,:].sum(dim=0), x_label='Label', y_label='Count',
            sort=False, reverse=True, title= 'Selected',
            tag=f'Pixelwise_Class_Count_Task', method='left',jupyter=True)

    
    display(Image.fromarray(res))

# test()