In [1]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
import PIL.Image as Image
import torch.nn as nn
import torchvision
import torch.optim as optim
import sys

sys.path.insert(0, '../src')
from bird_dataset import *
from XAI_birds_dataloader import *
from tqdm import tqdm
from models.multi_task_model import *
from XAI_birds_dataloader import *
from XAI_BirdAttribute_dataloader import *

In [2]:
# Dataset metadata object; later cells read bd.images and bd.attributes and pass bd to the dataset wrappers.
bd = BirdDataset()

In [3]:
# Fixed 224x224 resize followed by tensor conversion; shared by both splits.
trans = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
train_bird_dataset, val_bird_dataset = (
    XAI_Birds_Dataset(bd, transform=trans, train=True),
    XAI_Birds_Dataset(bd, transform=trans, train=False, val=True),
)

In [16]:
# Attribute-name list for every key of bd.images (recomputed again in the next cell).
attr_list = [bd.images[image_id]['attributes'] for image_id in bd.images]

In [21]:
attr_list = [bd.images[image_id]['attributes'] for image_id in bd.images]
# Keep only the bill-shape and wing-colour attribute names of each image.
attr_filt_list = [
    [name for name in names if any(key in name for key in ('has_bill_shape', 'wing_color'))]
    for names in attr_list
]
# One entry per image: a NumPy array of that image's filtered attribute names.
filt_df = pd.Series([np.array(names) for names in attr_filt_list])

In [80]:
# Each image's filtered attribute names as a NumPy array, one Series entry per image.
filt_df = pd.Series([np.array(names) for names in attr_filt_list])

In [86]:
# Show the filtered attributes as a Series named 0.
filt_df.rename(0)

0                         [has_bill_shape::hooked_seabird]
1                                   [has_wing_color::grey]
2        [has_bill_shape::hooked_seabird, has_wing_colo...
3        [has_bill_shape::hooked_seabird, has_wing_colo...
4        [has_bill_shape::hooked_seabird, has_wing_colo...
                               ...                        
11783       [has_bill_shape::dagger, has_wing_color::grey]
11784    [has_bill_shape::all-purpose, has_wing_color::...
11785    [has_bill_shape::all-purpose, has_wing_color::...
11786    [has_bill_shape::all-purpose, has_wing_color::...
11787    [has_bill_shape::all-purpose, has_wing_color::...
Name: 0, Length: 11788, dtype: object

In [69]:
import torch
import os
import PIL.Image as Image
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
import sys

sys.path.insert(0, '../src')
from bird_dataset import *
from XAI_birds_dataloader import *

class Bird_Attribute_Loader(XAI_Birds_Dataset):
    '''
    Dataset that pairs each bird image with labels for one or more attribute groups.

    attrs : str or list of str
        str  -- an image is kept if any of its attribute names contains ``attrs``
                (the first such attribute is used). Each sample is
                ``{'image': ..., 'label': int}``, where ``label`` is the position of
                that attribute among the entries of ``bd.attributes`` containing ``attrs``.
        list -- each entry names an attribute group, e.g. ``'has_bill_shape'``. It must
                equal the text before ``'::'`` in the attribute names, because labels
                are looked up by that prefix. An image is kept only if it has at least
                one attribute from every group. Each sample is
                ``{'image': ..., 'labels': [int, ...]}``, where each label is the position
                of an attribute value within its group (in ``bd.attributes`` order).
                Labels are ordered by group (following ``attrs``), then by attribute name.
                NOTE(review): an image with several values in one group contributes
                several labels, so ``labels`` can be longer than ``len(attrs)``.
                PyTorch's default collate needs equal-length lists -- confirm how
                batches are built.
    The remaining arguments are forwarded unchanged to ``XAI_Birds_Dataset``.

    Raises ValueError if ``attrs`` is neither a string nor a list of strings.
    '''
    def __init__(self, bd:BirdDataset, attrs, subset=True, transform=None, train=True, val=False, random_seed=42):
        XAI_Birds_Dataset.__init__(self, bd, subset=subset, transform=transform, train=train, val=val, random_seed=random_seed)
        self.attrs = attrs
        # Built before filtering: maps attribute names to integer labels and rejects invalid attrs.
        self.class_dict = self._set_classes_attributes()
        # Replace the image list set up by the parent with only the images carrying the requested attributes.
        self.images, self.attr_indices = self._filter_images_by_attributes()

    def __getitem__(self, idx):
        '''Return {'image', 'label'} (str attrs) or {'image', 'labels'} (list attrs) for item idx.'''
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img = self.images[idx]
        img_path = os.path.join(self.bd.img_dir, img['filepath'])
        image = Image.open(img_path)
        if isinstance(self.attrs, str):
            attr = img['attributes'][self.attr_indices[idx]]
            sample = {'image': image, 'label': self.class_dict[attr]}
        else:
            # Order labels by their group's position in self.attrs, then by name, so the
            # label order follows the requested groups, not the alphabetical order of group names.
            group_rank = {group: rank for rank, group in enumerate(self.attrs)}
            attrs = sorted(
                (img['attributes'][i] for i in self.attr_indices[idx]),
                key=lambda a: (group_rank.get(a.split('::')[0], len(group_rank)), a),
            )
            labels = [self.class_dict[attr.split('::')[0]][attr] for attr in attrs]
            sample = {'image': image, 'labels': labels}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

    def _set_classes_attributes(self):
        '''
        Build the attribute-name -> integer-label lookup.

        Returns {name: label} for a str attrs, or {group: {name: label}} for a list.
        Labels are numbered in bd.attributes order among the matching names.
        Raises ValueError for any other attrs.
        '''
        pd_attr = pd.Series(self.bd.attributes)

        def _labels_for(fragment):
            # Plain substring match (not regex), consistent with the `in` test used when filtering images.
            names = pd_attr[pd_attr.str.contains(fragment, regex=False)]
            return {name: label for label, name in enumerate(names)}

        if isinstance(self.attrs, str):
            # Keyed by attribute name because __getitem__ looks the label up by name.
            return _labels_for(self.attrs)
        if isinstance(self.attrs, list) and all(isinstance(a, str) for a in self.attrs):
            return {attribute: _labels_for(attribute) for attribute in self.attrs}
        raise ValueError("self.attrs must be a string or a list of strings")

    def _filter_images_by_attributes(self):
        '''
        Keep only the images that carry the requested attributes.

        Returns (filt_images, attr_indices):
          str attrs  -- attr_indices[k] is the index (into the image's 'attributes')
                        of the first attribute containing attrs.
          list attrs -- attr_indices[k] lists every attribute index that matches any
                        requested group; an image is kept only if every group matched.
        '''
        filt_images = []
        attr_indices = []
        if isinstance(self.attrs, str):
            for img in self.images:
                first = next((idx for idx, attr in enumerate(img['attributes']) if self.attrs in attr), None)
                if first is not None:
                    filt_images.append(img)
                    attr_indices.append(first)
            return filt_images, attr_indices
        if not isinstance(self.attrs, list):
            raise ValueError("self.attrs must be a string or a list of strings")
        for img in self.images:
            attr_index = []
            matched_groups = set()
            for idx, attr in enumerate(img['attributes']):
                for attribute in self.attrs:
                    if attribute in attr:
                        attr_index.append(idx)
                        matched_groups.add(attribute.split('::')[0])
            # Require a match for every requested group (counted as distinct group names).
            if len(matched_groups) >= len(self.attrs):
                filt_images.append(img)
                attr_indices.append(attr_index)
        return filt_images, attr_indices

In [75]:
# Inspect the id -> name mapping of all attributes (names look like 'group::value').
bd.attributes

{1: 'has_bill_shape::curved_(up_or_down)',
 2: 'has_bill_shape::dagger',
 3: 'has_bill_shape::hooked',
 4: 'has_bill_shape::needle',
 5: 'has_bill_shape::hooked_seabird',
 6: 'has_bill_shape::spatulate',
 7: 'has_bill_shape::all-purpose',
 8: 'has_bill_shape::cone',
 9: 'has_bill_shape::specialized',
 10: 'has_wing_color::blue',
 11: 'has_wing_color::brown',
 12: 'has_wing_color::iridescent',
 13: 'has_wing_color::purple',
 14: 'has_wing_color::rufous',
 15: 'has_wing_color::grey',
 16: 'has_wing_color::yellow',
 17: 'has_wing_color::olive',
 18: 'has_wing_color::green',
 19: 'has_wing_color::pink',
 20: 'has_wing_color::orange',
 21: 'has_wing_color::black',
 22: 'has_wing_color::white',
 23: 'has_wing_color::red',
 24: 'has_wing_color::buff',
 25: 'has_upperparts_color::blue',
 26: 'has_upperparts_color::brown',
 27: 'has_upperparts_color::iridescent',
 28: 'has_upperparts_color::purple',
 29: 'has_upperparts_color::rufous',
 30: 'has_upperparts_color::grey',
 31: 'has_upperparts_col

In [70]:
# 224x224 resize + tensor conversion, rebuilt here so this cell does not depend on earlier ones.
trans = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
# Train and validation splits of the two-group (bill shape, wing colour) attribute loader.
train_bird_dataset, val_bird_dataset = (
    Bird_Attribute_Loader(
        bd,
        attrs=['has_bill_shape', 'has_wing_color'],
        transform=trans,
        train=is_train,
        val=is_val,
    )
    for is_train, is_val in ((True, False), (False, True))
)

In [71]:
# Size of the filtered training split (assumes XAI_Birds_Dataset.__len__ counts self.images -- not visible here).
len(train_bird_dataset)

1338

In [72]:
# Size of the filtered validation split (assumes XAI_Birds_Dataset.__len__ counts self.images -- not visible here).
len(val_bird_dataset)

339

In [73]:
# First training sample: a dict holding the transformed 'image' tensor and its attribute 'labels'.
train_bird_dataset[0]

{'image': tensor([[[0.6353, 0.7686, 0.8392,  ..., 0.4824, 0.4510, 0.4510],
          [0.7137, 0.8431, 0.9098,  ..., 0.4000, 0.3922, 0.4118],
          [0.7765, 0.9020, 0.9529,  ..., 0.3373, 0.3529, 0.3922],
          ...,
          [0.4118, 0.4275, 0.4549,  ..., 0.3451, 0.3373, 0.3373],
          [0.4431, 0.4745, 0.4706,  ..., 0.3608, 0.3490, 0.3765],
          [0.4863, 0.5333, 0.5255,  ..., 0.3647, 0.3529, 0.4000]],
 
         [[0.7137, 0.8471, 0.8863,  ..., 0.4941, 0.4627, 0.4667],
          [0.8000, 0.8980, 0.9294,  ..., 0.4431, 0.4275, 0.4353],
          [0.8431, 0.9333, 0.9686,  ..., 0.3882, 0.4039, 0.4235],
          ...,
          [0.4431, 0.4549, 0.4667,  ..., 0.3451, 0.3412, 0.3412],
          [0.4784, 0.4863, 0.4784,  ..., 0.3647, 0.3608, 0.3804],
          [0.5216, 0.5451, 0.5176,  ..., 0.3765, 0.3765, 0.4118]],
 
         [[0.5333, 0.7059, 0.8431,  ..., 0.4549, 0.4235, 0.4235],
          [0.6431, 0.8314, 0.9373,  ..., 0.3686, 0.3569, 0.3725],
          [0.7451, 0.9059, 0.97

In [38]:
# Images with more than two filtered attribute names (some group has multiple values).
filt_df[filt_df.map(len).gt(2)]

2        [has_bill_shape::hooked_seabird, has_wing_colo...
4        [has_bill_shape::hooked_seabird, has_wing_colo...
5        [has_bill_shape::hooked_seabird, has_wing_colo...
7        [has_bill_shape::spatulate, has_wing_color::br...
12       [has_bill_shape::spatulate, has_wing_color::bl...
                               ...                        
11782    [has_bill_shape::all-purpose, has_wing_color::...
11784    [has_bill_shape::all-purpose, has_wing_color::...
11785    [has_bill_shape::all-purpose, has_wing_color::...
11786    [has_bill_shape::all-purpose, has_wing_color::...
11787    [has_bill_shape::all-purpose, has_wing_color::...
Length: 6614, dtype: object

In [18]:
# pd.DataFrame([bd.images[i]['attributes'] for i in bd.images]).apply(lambda x: pd.Series([j.split('::')[0] for j in x]), axis=1)