In [2]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
import PIL.Image as Image
import torch.nn as nn
import torch.optim as optim
import sys

sys.path.insert(0, '../src')
from bird_dataset import *
from XAI_birds_dataloader import *
from tqdm import tqdm

In [105]:
# #hide
# from fastai.vision.all import *
# from fastai.text.all import *
# from fastai.collab import *
# from fastai.tabular.all import *

In [3]:
# Build the dataset wrapper with its default arguments.
# BirdDataset comes in through one of the star-imports above (presumably `bird_dataset`).
bd = BirdDataset()

In [82]:
class Bird_Attribute_Loader(XAI_Birds_Dataset):
    '''
    Image dataset whose labels are taken from one or more attribute groups.

    ``attrs`` is either a single attribute-group name (e.g. ``'has_wing_color'``)
    or a list of such names. Only images that carry at least one attribute from
    *every* requested group are kept.

    Samples returned by ``__getitem__``:
      * ``attrs`` is a str  -> ``{'image': ..., 'label': int}`` (first matching attribute)
      * ``attrs`` is a list -> ``{'image': ..., 'labels': [int, ...]}`` (one entry per
        matching attribute, so the list length can differ between images)

    Can be combined with Bill Shape Class into one general class with attribute self.attr = 'has_wing_color'
    '''
    def __init__(self, bd:BirdDataset, attrs, subset=True, transform=None, train=True, val=False, random_seed=42):
        XAI_Birds_Dataset.__init__(self, bd, subset=subset, transform=transform, train=train, val=val, random_seed=random_seed)
        # Reported before filtering: this is the size of the split, not of the final dataset.
        print(f'num_images: {len(self.images)}')
        self.attrs = attrs
        # Maps attribute string -> integer class id (raises ValueError on a bad `attrs`).
        self.class_dict = self._set_classes_attributes()
        # NOTE(review): assumes the parent's __len__ reads len(self.images) - confirm,
        # otherwise the reported length will not reflect the filtering done here.
        self.images, self.attr_indices = self._filter_images_by_attributes()

    def __getitem__(self, idx):
        '''Return the sample dict for position ``idx`` (int or scalar tensor).'''
        if torch.is_tensor(idx):
            idx = idx.tolist()
        record = self.images[idx]
        img_path = os.path.join(self.bd.img_dir, record['filepath'])
        # Close the file handle promptly and force three channels so that
        # grayscale / RGBA files produce the same tensor shape as RGB ones.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')
        if isinstance(self.attrs, str):
            attr = record['attributes'][self.attr_indices[idx]]
            sample = {'image': image, 'label': self.class_dict[attr]}
        elif isinstance(self.attrs, list):
            attrs = [record['attributes'][i] for i in self.attr_indices[idx]]
            sample = {'image': image, 'labels': [self.class_dict[attr] for attr in attrs]}
        else:
            raise ValueError("self.attrs must be a string or a list of strings")
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

    def _requested_attrs(self):
        '''Return ``self.attrs`` as a list of group names; raise ValueError for any other type.'''
        if isinstance(self.attrs, str):
            return [self.attrs]
        if isinstance(self.attrs, list):
            return self.attrs
        raise ValueError("self.attrs must be a string or a list of strings")

    def _set_classes_attributes(self):
        '''
        Build the ``attribute string -> class id`` mapping.

        Every entry of ``self.bd.attributes`` whose text contains one of the
        requested group names gets the next consecutive id, in group order.
        With several groups the ids are shared across groups (a single range
        over the union), not restarted at 0 for each group.
        '''
        pd_attr = pd.Series(self.bd.attributes)
        attrs_dict = dict()
        for attribute in self._requested_attrs():
            attrs_dict.update(pd_attr[pd_attr.str.contains(attribute)].to_dict())
        return dict(zip(attrs_dict.values(), range(len(attrs_dict))))

    def _filter_images_by_attributes(self):
        '''
        Keep only the images that have the requested attribute group(s).

        Returns:
            (filt_images, attr_indices): two lists of equal length. For a str
            ``attrs``, ``attr_indices[k]`` is the position of the first matching
            attribute of ``filt_images[k]``; for a list it is the ascending list
            of positions of all attributes matching any requested group.

        Raises:
            ValueError: if ``self.attrs`` is neither a str nor a list.
        '''
        requested = self._requested_attrs()
        single = isinstance(self.attrs, str)
        filt_images = []
        attr_indices = []
        for img in self.images:
            if single:
                for idx, attr in enumerate(img['attributes']):
                    if self.attrs in attr:
                        filt_images.append(img)
                        attr_indices.append(idx)
                        break
                continue

            # Track *which* groups were seen rather than counting hits: an image
            # with two attributes from one group must not pass for having both groups.
            matched_groups = set()
            attr_index = []
            for idx, attr in enumerate(img['attributes']):
                hits = [attribute for attribute in requested if attribute in attr]
                if hits:
                    matched_groups.update(hits)
                    attr_index.append(idx)
            # Decide once per image, after all attributes were inspected, so the
            # two output lists always stay aligned.
            if matched_groups == set(requested):
                filt_images.append(img)
                attr_indices.append(attr_index)
        return filt_images, attr_indices

In [106]:
# Instantiate torchvision's VGG-16 (batch-norm variant) and request pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
# `weights=` - confirm against the installed version.
vgg16 = models.vgg16_bn(pretrained=True)


In [84]:
# vgg16

In [85]:
# bd.images

In [86]:
# [bd.images[i]['attributes'] for i in list(bd.images.keys()) if 'has_bill_shape' in bd.images[i]['attributes']]

In [107]:
# Input size and per-channel statistics documented for torchvision's
# ImageNet-pretrained classification models.
IMG_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Attribute groups to predict; defined once so both splits share one label space.
TARGET_ATTRS = ['has_wing_color', 'has_bill_shape']

trans = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    # The pretrained encoder expects ImageNet-normalised, 3-channel input.
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
train_bird_dataset = Bird_Attribute_Loader(bd, attrs=TARGET_ATTRS, transform=trans, train=True)
val_bird_dataset = Bird_Attribute_Loader(bd, attrs=TARGET_ATTRS, transform=trans, train=False, val=True)

num_images: 1466
num_images: 367


In [74]:
# len(train_bird_dataset)
# len(val_bird_dataset)

In [108]:
# for sample in train_bird_dataset:
#     print(sample)
#     break

In [110]:
class MultiTaskModel(nn.Module):
    """
    Two-headed multi-task classifier on top of a shared encoder.

    The encoder output goes through a ReLU and is then fed to two independent
    linear heads, one for bill shape and one for wing colour.

    Args:
        model: encoder module. Its output must have ``in_features`` values in
            the last dimension.
        ps: dropout probability. Currently unused; kept so existing callers
            that pass it keep working.
        in_features: size of the encoder output (default 1000, the value that
            was previously hard-coded).
        num_bill_shapes: number of bill-shape classes (default 9).
        num_wing_colors: number of wing-colour classes (default 15).
    """
    def __init__(self, model, ps=0.5, in_features=1000, num_bill_shapes=9, num_wing_colors=15):
        super(MultiTaskModel, self).__init__()

        self.encoder = model                                # shared feature extractor

        self.fc1 = nn.Linear(in_features, num_bill_shapes)  # bill-shape head
        self.fc2 = nn.Linear(in_features, num_wing_colors)  # wing-colour head

    def forward(self, x):
        """Return ``(bill_shape_logits, wing_color_logits)`` for the batch ``x``."""
        # Apply the activation functionally: `nn.ReLU(tensor)` would construct a
        # layer object (with the tensor as its `inplace` flag) instead of a tensor.
        x = torch.relu(self.encoder(x))

        bill_shape = self.fc1(x)
        wing_color = self.fc2(x)

        return bill_shape, wing_color

In [111]:
class MultiTaskLossWrapper(nn.Module):
    """
    Combined loss for the two classification tasks.

    Computes a cross-entropy loss per task and returns their unweighted sum.
    """
    def __init__(self):
        super(MultiTaskLossWrapper, self).__init__()
        # TODO: optional learned per-task weighting, e.g.
        #   self.log_vars = nn.Parameter(torch.zeros(task_num))
        #   loss_i = torch.exp(-self.log_vars[i]) * loss_i + self.log_vars[i]

    def forward(self, preds, bill_shape, wing_color):
        """
        Args:
            preds: sequence ``(bill_shape_logits, wing_color_logits)``.
            bill_shape: target class indices for the first head.
            wing_color: target class indices for the second head.

        Returns:
            Scalar tensor: sum of the two mean cross-entropy losses.
        """
        # Use the functional form: `nn.CrossEntropyLoss(logits, target)` would
        # construct a loss module (taking the tensors as constructor options)
        # rather than evaluate the loss.
        loss0 = nn.functional.cross_entropy(preds[0], bill_shape)
        loss1 = nn.functional.cross_entropy(preds[1], wing_color)

        return loss0 + loss1

In [112]:
# Use the VGG-16 instance created earlier as the shared encoder of the two-head model.
model = MultiTaskModel(vgg16)

In [113]:
# Bare expression: the notebook renders the module tree (encoder layers plus both heads).
model

MultiTaskModel(
  (encoder): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace=True)
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): ReLU(inplace=True)
      (13): MaxPool2d(kernel_size=2, stride=2, paddin

In [None]:
        # NOTE(review): looks like a leftover fragment - `crossEntropy`, `ethnicity` and a
        # third prediction (`preds[2]`) are not defined anywhere in the visible notebook,
        # and the line is indented with no enclosing block. Confirm and remove.
        loss2 = crossEntropy(preds[2],ethnicity)
