In [None]:
from itertools import product
from copy import copy
import random
from itertools import groupby
import os
import pickle
import json

import numpy as np
from tqdm.notebook import tqdm, trange
import matplotlib.pyplot as plt

In [None]:
class ASD:
    """A family of attribute sets, reduced to its inclusion-maximal members.

    ``sets`` is stored as a frozenset of frozensets in which no member is a
    proper subset of another member.

    Instances hash by their contents but define no ``__eq__``, so equality
    stays identity-based: two ASDs with identical ``sets`` are still distinct
    dict keys / set members. Code in this notebook stores ASD instances in
    dicts and sets, so introducing value equality would change its results.
    """

    def __init__(self, sets=()):
        """Build the family from an iterable of iterables of hashable items.

        Every member that is a proper subset of another member is dropped;
        duplicates collapse into one.
        """
        # Freeze before filtering: this lets one-shot iterables be scanned
        # twice and guarantees that `>` below is the proper-superset test of
        # sets (on lists it would be a lexicographic comparison).
        candidates = [frozenset(s) for s in sets]
        self.sets = frozenset(
            s1 for s1 in candidates
            if not any(s2 > s1 for s2 in candidates)
        )

    def __le__(self, other):
        """True when every set of ``self`` is a subset of some set of ``other``."""
        return all(any(s1 <= s2 for s2 in other.sets) for s1 in self.sets)

    def __hash__(self):
        """Hash of the stored family (equality itself remains identity-based)."""
        return hash(self.sets)

    def __and__(self, other):
        """ASD built from all pairwise intersections of the two families."""
        return ASD(sets=[s1 & s2 for s1, s2 in product(self.sets, other.sets)])

    def __str__(self):
        return str(self.sets)

    def __repr__(self):
        return str(self)

    def __len__(self):
        """Number of sets kept in the family."""
        return len(self.sets)

In [None]:
def similarity(r1, r2):
    """Symmetric best-match Jaccard similarity between two set families.

    For every set of one family the highest Jaccard index against any set of
    the other family is taken; these best scores are averaged per family and
    the two directional averages are averaged again.

    Parameters
    ----------
    r1, r2 : objects exposing ``.sets`` (an iterable of sets) and ``len()``.

    Returns
    -------
    float in [0, 1]. Degenerate inputs are defined rather than raising:
    two families without any sets score 1.0, exactly one family without
    sets scores 0.0, and the Jaccard index of two empty sets counts as 1.0.
    """
    def jaccard(a, b):
        union = a | b
        # Two empty sets are identical; avoid the 0/0 division.
        return len(a & b) / len(union) if union else 1.0

    if not r1.sets and not r2.sets:
        return 1.0
    if not r1.sets or not r2.sets:
        # max() over an empty family has no best match to report.
        return 0.0

    r1_sim = sum(max(jaccard(s1, s2) for s2 in r2.sets)
                 for s1 in r1.sets) / len(r1)
    r2_sim = sum(max(jaccard(s1, s2) for s1 in r1.sets)
                 for s2 in r2.sets) / len(r2)
    return (r1_sim + r2_sim) / 2

In [None]:
def compute_ccds(pos_ids, pos_asds, neg_ids, neg_asds):
    """Grow one description/cluster pair from every positive datapoint.

    Starting from a positive's own ASD, the description is repeatedly
    intersected (``&``) with the most similar remaining positive ASD. A
    candidate description is accepted only if no negative datapoint satisfies
    it, i.e. the set of datapoints whose ASD is ``>=`` the candidate contains
    no negative id. After each accepted step the remaining candidates are
    re-ranked against the new description.

    Parameters
    ----------
    pos_ids, pos_asds : parallel iterables of positive ids and their ASDs.
    neg_ids, neg_asds : parallel iterables of negative ids and their ASDs.

    Returns
    -------
    list of ``{'description': ASD, 'cluster': set_of_ids}``, one entry per
    positive, in input order. If no candidate is ever accepted the entry
    keeps the seed's own ASD and the singleton cluster ``{pos_id}``.
    """
    pos_ids, pos_asds = list(pos_ids), list(pos_asds)
    neg_id_list = list(neg_ids)

    # Pair every id with its ASD directly. Recovering ids through an
    # ASD-keyed dict would merge datapoints whose ASDs act as the same key
    # and would require ASDs to be hashable.
    labelled = list(zip(pos_ids + neg_id_list, pos_asds + list(neg_asds)))
    negative_ids = set(neg_id_list)

    def rank(description, candidates):
        # Ascending by similarity, so list.pop() yields the most similar one.
        return sorted(((asd, similarity(description, asd)) for asd in candidates),
                      key=lambda pair: pair[1])

    ccds = []

    for pos_id, pos_asd in tqdm(list(zip(pos_ids, pos_asds))):
        description = pos_asd
        cluster = {pos_id}
        ranked = rank(description, pos_asds)

        while ranked:
            candidate, _ = ranked.pop()
            ncd = description & candidate
            new_cluster = {i for i, asd in labelled if ncd <= asd}
            # Accept only descriptions that no negative datapoint satisfies.
            if not new_cluster & negative_ids:
                description = ncd
                cluster = new_cluster
                ranked = rank(description, [asd for asd, _ in ranked])

        ccds.append({'description': description, 'cluster': cluster})

    return ccds

In [None]:
def filter_ccds(ccds):
    """Greedily select CCDs until their clusters cover every datapoint.

    At each step the CCD whose cluster adds the most not-yet-covered
    datapoints is selected; ties go to the CCD with the larger cluster, then
    to the earlier one in ``ccds``. Summary statistics are printed.

    Parameters
    ----------
    ccds : iterable of dicts with a ``'cluster'`` entry (a set of datapoints).

    Returns
    -------
    list of the selected dicts in selection order. Empty when ``ccds`` is
    empty or every cluster is empty (the size statistics then print 0).
    """
    ccds = list(ccds)
    all_datapoints = set().union(*(ccd['cluster'] for ccd in ccds))

    # Stable sort, largest cluster first: max() below returns the first
    # maximum, which makes the larger cluster win ties on coverage gain.
    sorted_ccds = sorted(ccds, key=lambda x: len(x['cluster']), reverse=True)

    selected_ccds = []
    covered_datapoints = set()

    # While something is uncovered, some cluster still has a positive gain,
    # so already-covered CCDs (gain 0) can never be picked and need no
    # separate filtering.
    while not all_datapoints <= covered_datapoints:
        selected_ccd = max(sorted_ccds,
                           key=lambda ccd: len(ccd['cluster'] - covered_datapoints))

        selected_ccds.append(selected_ccd)
        covered_datapoints.update(selected_ccd['cluster'])

    print("Initial number of ccds: " + str(len(ccds)))
    print("New number of ccds: " + str(len(selected_ccds)))

    # default=0 keeps the report working when nothing was selected.
    max_cluster_size = max((len(ccd['cluster']) for ccd in selected_ccds), default=0)
    min_cluster_size = min((len(ccd['cluster']) for ccd in selected_ccds), default=0)

    print("Most datapoints in a cluster: " + str(max_cluster_size))
    print("Least datapoints in a cluster: " + str(min_cluster_size))

    return selected_ccds

## CUB

In [None]:
cub_dir = '../CUB/CUB_200_2011/'
attributes_dir = os.path.join(cub_dir, 'CUB_200_2011', 'attributes')


def attributes_from_str(s):
    """Parse one space-separated label line into an (int, int, bool) triple.

    Only the first three fields are read; the flag is True exactly when the
    third field is the string '1'.
    """
    fields = s.strip().split(' ')
    return (int(fields[0]), int(fields[1]), fields[2] == '1')


# Per-line label triples; later cells group them by their first field and
# look the second field up in `attributes_by_id`.
with open(os.path.join(attributes_dir, 'image_attribute_labels.txt'), 'r') as fp:
    image_attribute_labels_strings = fp.readlines()
    image_attribute_labels = list(map(attributes_from_str,
                                      image_attribute_labels_strings))

# Integer attribute id -> rest of the line (split at the first space only).
with open(os.path.join(attributes_dir, 'attributes.txt'), 'r') as fp:
    attribute_strings = fp.readlines()
    attributes_by_id = dict(
        (int(key), name.strip())
        for key, name in (line.split(' ', 1) for line in attribute_strings)
    )

# Integer image id -> image file name.
with open(os.path.join(cub_dir, 'CUB_200_2011', 'images.txt'), 'r') as fp:
    image_ids_strings = fp.readlines()
    image_ids = [line.strip().split(' ') for line in image_ids_strings]
    image_id_to_fn = {int(image_id): fn for image_id, fn in image_ids}
    # image_fn_to_id = {s: int(i) for i, s in image_ids}

# Integer image id -> integer class label, plus the inverse grouping
# class label -> list of image ids.
with open(os.path.join(cub_dir, 'CUB_200_2011', 'image_class_labels.txt'), 'r') as fp:
    image_id_to_class_strings = fp.readlines()
    image_id_to_class = {
        int(image_id): int(label)
        for image_id, label in (line.split(' ') for line in image_id_to_class_strings)
    }
    # groupby only merges adjacent items, hence the sort by class label first.
    image_ids_by_class = {
        label: [image_id for image_id, _ in members]
        for label, members in groupby(
            sorted(image_id_to_class.items(), key=lambda item: item[1]),
            key=lambda item: item[1],
        )
    }

In [None]:
# Group the label triples by their first field (sorted first, because groupby
# only merges adjacent items). The result is a single-pass iterator: once it
# has been consumed it yields nothing more until this statement is re-run.
attribute_labels_by_image = groupby(
    sorted(image_attribute_labels, key=lambda label: label[0]),
    key=lambda label: label[0],
)


def cub_asd_from_attributes(attributes):
    """Build a single-set ASD from the attributes flagged as present.

    Parameters
    ----------
    attributes : iterable of ``(_, attribute_id, is_present)`` triples; the
        first field is ignored here.

    Returns
    -------
    ASD holding one set: the ``attributes_by_id`` names of every triple
    whose ``is_present`` flag is truthy.
    """
    # Build a real (frozen) set instead of a list bound to a local named
    # `set`, which shadowed the builtin inside this function.
    present = frozenset(
        attributes_by_id[attribute_id]
        for _, attribute_id, is_present in attributes
        if is_present
    )
    return ASD([present])

# Group key -> ASD of that group's label triples. Each group is turned into
# its ASD before the iterator advances, as groupby requires.
id_to_asd = dict(
    (image_id, cub_asd_from_attributes(labels))
    for image_id, labels in attribute_labels_by_image
)
# Reverse mapping from ASD instance back to its key.
asd_to_id = dict((asd, image_id) for image_id, asd in id_to_asd.items())

In [None]:
classes = [(6, 7), (23, 25), (47, 180), (60, 64)]
ccds = {}

for class_pair in classes:
    print(class_pair)
    ccds[class_pair] = {}
    pos_cl, neg_cl = class_pair

    # Two passes per pair: the roles are swapped at the end of each pass, so
    # each class of the pair is treated as the positive class exactly once.
    for _ in range(2):
        pos_ids = image_ids_by_class[pos_cl]
        pos_asds = [id_to_asd[image_id] for image_id in pos_ids]
        neg_ids = image_ids_by_class[neg_cl]
        neg_asds = [id_to_asd[image_id] for image_id in neg_ids]

        # Store only plain containers: the raw frozensets of each description
        # and the image file names of each cluster.
        ccds[class_pair][pos_cl] = [
            {
                'description': selected['description'].sets,
                'cluster': set(image_id_to_fn[image_id]
                               for image_id in selected['cluster'])
            }
            for selected in filter_ccds(
                compute_ccds(pos_ids, pos_asds, neg_ids, neg_asds)
            )
        ]
        pos_cl, neg_cl = neg_cl, pos_cl
        print()

In [None]:
# Persist the CUB results. Create the output directory first so a fresh
# checkout does not lose the computed results to a FileNotFoundError.
os.makedirs('results', exist_ok=True)
with open(os.path.join('results', 'CUB_ccds.pickle'), 'wb') as fp:
    pickle.dump(ccds, fp)

## CLEVR-Hans3

In [None]:
clevrhans_dir = '../CLEVR-Hans3/'

# Parsed JSON attribute file, read from the working directory (not from
# `clevrhans_dir`).
with open('CLEVR-Hans3_attributes.json', 'r') as fp:
    clevrhans_attributes = json.loads(fp.read())

In [None]:
# One ASD per key of `clevrhans_attributes`: every entry of its object list
# is converted with set() and the resulting sets form the ASD.
fn_to_asd = dict(
    (fn, ASD(list(map(set, obj_list))))
    for fn, obj_list in clevrhans_attributes.items()
)

# All ASD instances collected into a set.
all_asds = {asd for asd in fn_to_asd.values()}

In [None]:
ccds = {}

for cl in [0, 1, 2]:
    with open(os.path.join(clevrhans_dir, 'class_{}_positives.txt'.format(cl)), 'r') as fp:
        positives = list(fp)
    pos_ids = [p.strip() for p in positives]
    pos_id_set = set(pos_ids)
    # Negatives are all remaining file names, taken in fn_to_asd's order.
    # Ids and ASDs come from the same pass so they stay positionally aligned,
    # which compute_ccds relies on when it zips them together.
    neg_ids = [fn for fn in fn_to_asd if fn not in pos_id_set]
    pos_asds = [fn_to_asd[p] for p in pos_ids]
    neg_asds = [fn_to_asd[n] for n in neg_ids]

    # NOTE(review): only the first 80 positives are used as seeds, and the
    # remaining positives are in neither list; presumably this bounds the
    # runtime, confirm the intended sample size.
    ccds[cl] = [
        {
            'description': ccd['description'].sets,
            'cluster': set(ccd['cluster'])
        }
        for ccd in filter_ccds(
            compute_ccds(pos_ids[:80], pos_asds[:80], neg_ids, neg_asds)
        )
    ]

In [None]:
# Persist the CLEVR-Hans3 results. Create the output directory first so a
# fresh checkout does not lose the computed results to a FileNotFoundError.
os.makedirs('results', exist_ok=True)
with open(os.path.join('results', 'CLEVR-Hans3_ccds.pickle'), 'wb') as fp:
    pickle.dump(ccds, fp)

In [None]:
# Display whatever `ccds` is bound to at this point; when the cells are run
# top to bottom that is the CLEVR-Hans3 dictionary built above.
ccds