In [1]:
import pickle
from itertools import groupby
import os

from tqdm.notebook import tqdm
from skimage import io

from parse_dataset import CUB
from utils import distance_between_samples

## CUB

In [2]:
def _article(word):
    """Return 'an' when ``word`` starts with a, e, i, o or u, otherwise 'a'."""
    return 'an' if word[0] in ('a', 'e', 'i', 'o', 'u') else 'a'


def _join_with_and(items):
    """Render a non-empty list as 'x', 'x and y' or 'x, y and z'."""
    if len(items) == 1:
        return items[0]
    return '{} and {}'.format(', '.join(items[:-1]), items[-1])


def beautify_attributes(attribute_list):
    """Turn raw ``'<property>::<value>'`` attribute strings into English sentences.

    Each attribute is split on ``'::'`` and classified by the suffix of its
    property name (``primary_color``, ``color``, ``shape``, ``pattern``);
    everything else, including the literal property ``has_shape``, is handled
    as a special case (``shape``, ``bill length``, ``size``). The first
    underscore-separated token of the property name is always dropped.

    Args:
        attribute_list: iterable of strings of the form ``'property::value'``.

    Returns:
        A list of sentences, ordered as: primary colours, per-colour
        sentences, per-pattern sentences, special-case attributes, and finally
        one sentence collecting all part shapes.

    Raises:
        RuntimeError: if an attribute falls into the special-case bucket but
            is not one of ``shape``, ``bill length`` or ``size``.
        ValueError: if an attribute does not split into exactly two parts on
            ``'::'`` (raised by tuple unpacking).
    """
    primary_colors = []
    color_attributes = []
    shape_attributes = []
    pattern_attributes = []
    other_attributes = []

    for prop, value in [att.split('::') for att in attribute_list]:
        # Drop redundant part names embedded in the value (e.g. 'rounded-wings').
        value = value.replace('-wings', '').replace('_tail', '')
        if prop.endswith('primary_color'):
            primary_colors.append(value)
        elif prop.endswith('color'):
            color_attributes.append([' '.join(prop.split('_')[1:-1]),
                                     ' '.join(value.split('_'))])
        elif prop != 'has_shape' and prop.endswith('shape'):
            shape_attributes.append([' '.join(prop.split('_')[1:-1]),
                                     ' '.join(value.split('_'))])
        elif prop.endswith('pattern'):
            pattern_attributes.append([' '.join(prop.split('_')[1:-1]),
                                       ' '.join(value.split('_'))])
        else:
            other_attributes.append([' '.join(prop.split('_')[1:]), value])

    beautified_attributes = []

    if primary_colors:
        beautified_attributes.append(
            'The bird is primarily {}.'.format(_join_with_and(primary_colors))
        )

    def by_value(pair):
        return pair[1]

    # One sentence per colour, listing every body part that has that colour.
    for color, group in groupby(sorted(color_attributes, key=by_value), key=by_value):
        parts = [name + ('s' if name in ('wing', 'eye', 'leg') else '')
                 for name, _ in group]
        if 'primary' in parts:
            beautified_attributes.append('The bird is primarily {}.'.format(color))
            parts = [part for part in parts if part != 'primary']
        if not parts:
            # Skip only this colour; the remaining colour groups must still be emitted.
            continue
        # Plural / mass nouns take no article; a singular part takes 'a'/'an'.
        article = ''
        if parts[0] not in ('wings', 'eyes', 'legs', 'underparts', 'upperparts'):
            article = _article(color) + ' '
        beautified_attributes.append(
            'The bird has {}{} {}.'.format(article, color, _join_with_and(parts))
        )

    # One sentence per pattern, listing every body part that shows it.
    for pattern, group in groupby(sorted(pattern_attributes, key=by_value), key=by_value):
        parts = [name + ('s' if name == 'wing' else '') for name, _ in group]
        beautified_attributes.append(
            'The bird has {} {} pattern on its {}.'.format(
                _article(pattern), pattern, _join_with_and(parts))
        )

    for prop, value in other_attributes:
        if prop == 'shape':
            if value == 'chicken-like-marsh':
                beautified_attributes.append(
                    'The bird is shaped like a marsh chicken.'
                )
            else:
                value = value[:-5]  # strip the trailing '-like'
                if value in ['long-legged', 'perching', 'tree-clinging',
                             'upright-perching_water']:
                    # These values are adjectives, so they need the noun 'bird'.
                    value = value.replace('_', ' ')
                    beautified_attributes.append(
                        'The bird is shaped like {} {} bird.'.format(
                            _article(value), value)
                    )
                else:
                    beautified_attributes.append(
                        'The bird is shaped like {} {}.'.format(
                            _article(value), value)
                    )
        elif prop == 'bill length':
            if value == 'about_the_same_as_head':
                beautified_attributes.append(
                    'The bird\'s bill is about as long as its head.'
                )
            else:
                # e.g. 'longer_than_head' -> 'longer than'
                beautified_attributes.append(
                    'The bird\'s bill is {} its head.'.format(
                        ' '.join(value.split('_')[:-1]))
                )
        elif prop == 'size':
            # The last four tokens form the measurement, the rest the size word(s).
            tokens = value.split('_')
            beautified_attributes.append(
                'The bird has a {} size {}.'.format(
                    ' '.join(tokens[:-4]),
                    ' '.join(tokens[-4:])
                )
            )
        else:
            # Previously a bare `raise`, which produced an opaque RuntimeError.
            raise RuntimeError(
                'Unsupported attribute {!r} with value {!r}'.format(prop, value)
            )

    shape_phrases = [
        '{} {} {}-shape'.format(_article(value), value, part)
        for part, value in shape_attributes
    ]
    if shape_phrases:
        beautified_attributes.append(
            'The bird has {}.'.format(_join_with_and(shape_phrases))
        )

    return beautified_attributes

In [3]:
def read_cropped_sample(dataset, image_id):
    """Read the cropped image that belongs to ``image_id``.

    The image name is looked up in ``dataset.image_id_to_image_name``, joined
    onto the module-level ``cropped_dataset_path``, and every ".jpg" in the
    resulting path is replaced by ".JPEG" before the file is read with
    ``io.imread``. Whatever ``io.imread`` returns is passed back unchanged.
    """
    relative_name = dataset.image_id_to_image_name[image_id]
    jpeg_path = os.path.join(cropped_dataset_path, relative_name).replace(".jpg", ".JPEG")
    return io.imread(jpeg_path)

In [4]:
def find_prototype(dataset, ccd, cover_set):
    """Score every image in ``cover_set`` against the first element of ``ccd``.

    For each image id, its annotations (``dataset.image_id_to_annotations``)
    are turned into a set and compared with the reference element through
    ``distance_between_samples(reference, annotations, True, False)``.

    Returns:
        A list of ``[image_id, cost, edits]`` entries in ``cover_set`` order.
    """
    reference = list(ccd)[0]
    scored = []
    for candidate_id in cover_set:
        candidate_annotations = set(dataset.image_id_to_annotations[candidate_id])
        cost, edits = distance_between_samples(
            reference, candidate_annotations, True, False
        )
        scored.append([candidate_id, cost, edits])
    return scored

In [5]:
def find_costs(ccds):
    """Return, for each entry of ``ccds``, its three lowest-cost images.

    For every entry, the images of the module-level ``dataset`` whose name is
    contained in ``entry['cluster']`` are scored with ``find_prototype``
    against ``entry['description']``; the scores are sorted by ascending cost
    and the first three (or fewer) are kept.

    Returns:
        A list with one list of ``[image_id, cost, edits]`` entries per input
        entry, in input order.
    """
    top_matches = []
    for entry in ccds:
        cover_set = []
        for image_id in dataset.image_id_to_annotations:
            if dataset.image_id_to_image_name[image_id] in entry['cluster']:
                cover_set.append(image_id)
        ranked = sorted(
            find_prototype(dataset, entry['description'], cover_set),
            key=lambda scored: scored[1],
        )
        top_matches.append(ranked[:3])
    return top_matches

In [6]:
# Path handed to the CUB parser below.
cub_dir = '../CUB/CUB_200_2011/CUB_200_2011/'
# Directory that read_cropped_sample joins image names onto.
cropped_dataset_path = "CUB_200_2011/datasets/cub200_cropped/train_cropped"
# Pickle file that is loaded into `ccds` further down.
ccds_path = os.path.join('results', 'CUB_ccds.pickle')

# Module-level dataset object; find_costs reads it as a global.
dataset = CUB(cub_dir)

In [7]:
# Load the precomputed `ccds` object; later code indexes it as
# ccds[class_pair][cl] and expects entries with 'cluster' and 'description' keys.
# NOTE(review): pickle.load can execute arbitrary code -- only load files from
# a trusted source.
with open(ccds_path, 'rb') as fp:
    ccds = pickle.load(fp)

In [8]:
# Maps class_pair -> class -> list of (beautified description, matches), where
# each match is (image_name, image_id, cost, edits).
prototypical_instances = {}

for class_pair in tqdm(ccds):
    prototypical_instances[class_pair] = {}
    for cl in class_pair:
        cl_ccds = ccds[class_pair][cl]
        # Up to three lowest-cost [image_id, cost, edits] entries per CCD.
        all_costs = find_costs(cl_ccds)
        filtered_costs = []
        for costs in all_costs:
            fc = []
            for cost in costs:
                # Keep only candidates whose cropped image can actually be read;
                # any exception from read_cropped_sample silently drops the candidate.
                try:
                    read_cropped_sample(dataset, cost[0])
                    fc.append(cost)
                except Exception as e:
                    continue
            if fc:
                filtered_costs.append(fc)
            else:
                # Placeholder keeps filtered_costs aligned with cl_ccds for the zip below.
                filtered_costs.append(None)
        # NOTE: cl_ccds is rebound here from the raw entries to lists of
        # beautified descriptions (one list of sentences per description).
        cl_ccds = [[beautify_attributes(list(s)) for s in ccd['description']]
                    for ccd in cl_ccds]
        # Only the first beautified description of each CCD is stored; CCDs
        # without any readable candidate image are skipped.
        prototypical_instances[class_pair][cl] = [
            (ccd[0], [(dataset.image_id_to_image_name[cost[0]], *cost) for cost in costs])
            for ccd, costs in zip(cl_ccds, filtered_costs)
            if costs is not None
        ]

  0%|          | 0/4 [00:00<?, ?it/s]

In [45]:
# Persist the collected prototypes as a pickle under results/.
with open(os.path.join('results', 'CUB_prototypes.pickle'), 'wb') as fp:
    pickle.dump(prototypical_instances, fp)