In [1]:
import sys

import metal
import os
# Import other dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
# Tell MeTaL where its checkout lives.
# NOTE(review): this runs after `import metal` above; if metal reads METALHOME
# at import time the assignment comes too late — confirm.
os.environ['METALHOME'] = '/dfs/scratch1/saelig/slicing/metal/'
# Set random seed for notebook
SEED = 123
# Apply the seed to torch's global generator; assigning the constant alone
# seeds nothing. The trailing `;` keeps the returned Generator out of the
# cell output.
torch.manual_seed(SEED);

In [2]:
# Enable the autoreload extension so edits to imported project modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Load Data


In [3]:
from skimage import io, transform
import torchvision.transforms as transforms
import numpy as np

# Short alias used throughout the notebook for building paths.
opj = os.path.join

# Directory layout: everything lives under HOME_DIR.
HOME_DIR = '/dfs/scratch1/saelig/slicing/'
DATASET_DIR = opj(HOME_DIR,'CUB_200_2011')
IMAGES_DIR = opj(DATASET_DIR, 'images')
TENSORS_DIR = opj(HOME_DIR, 'birds_data')
MODELS_DIR = opj(HOME_DIR, 'birds_models')

# Loaders for the raw dataset index files, kept for reference only; the cell
# reads the serialized tensors below instead.
#image_list = np.loadtxt(os.path.join(DATASET_DIR, 'images.txt'), dtype=str)
#train_test_split = np.loadtxt(os.path.join(DATASET_DIR, 'train_test_split.txt'), dtype=int)
#labels = np.loadtxt(os.path.join(DATASET_DIR, 'image_class_labels.txt'), dtype=int)


def _load_tensor(filename):
    """Load one serialized object from TENSORS_DIR via torch.load."""
    return torch.load(opj(TENSORS_DIR, filename))


# Image ids, inputs and labels for each of the three splits.
train_image_ids = _load_tensor('train_image_ids.pt')
valid_image_ids = _load_tensor('valid_image_ids.pt')
test_image_ids = _load_tensor('test_image_ids.pt')

X_train = _load_tensor('X_train.pt')
X_valid = _load_tensor('X_valid.pt')
X_test = _load_tensor('X_test.pt')

Y_train = _load_tensor('Y_train.pt')
Y_valid = _load_tensor('Y_valid.pt')
Y_test = _load_tensor('Y_test.pt')



Let's create the payloads. First, we need to put the attribute information into a data structure that is easy to work with.

In [4]:
# Read only the first three columns of the attribute label file as integers.
attrs_array = np.loadtxt(
    opj(DATASET_DIR, 'attributes/image_attribute_labels.txt'),
    usecols=(0, 1, 2),
    dtype=int,
)

Let's create a dictionary to make it easier to figure out which samples have which attributes.

In [5]:
NUM_ATTRIBUTES = 312

# Each row of attrs_array is: <image_id>, <attribute_id>, <is_present>

# Maps an attribute id to the set of image ids for which that attribute is
# marked as present.
attrs_dict = {}

for image_id, attr_id, is_present in attrs_array:
    if is_present != 1:
        # Only rows marking the attribute as present contribute.
        continue
    attrs_dict.setdefault(attr_id, set()).add(image_id)

Create payload abstraction for slices based on the binary attributes.

In [6]:
from metal.mmtl.payload import Payload
from metal.mmtl.data import MmtlDataLoader, MmtlDataset
from pprint import pprint

def make_slice_fn(attr_id):
    """Build the slicing function for a single attribute.

    A factory is used so that every returned function keeps its own
    attribute's image-id set. Defining the closure directly inside the loop
    makes all functions resolve `attr_id` when they are called, i.e. they
    would all use the last attribute of the loop.

    Args:
        attr_id: attribute id to look up in `attrs_dict`.

    Returns:
        A function that takes an iterable `x`, tests each element for
        membership in the attribute's image-id set, and returns an (n, 1)
        integer array with 1 for members and 0 otherwise.
    """
    # An attribute that is never marked present simply selects nothing.
    members = attrs_dict.get(attr_id, set())

    def slice_fn(x):
        # Materialize a list first: np.array() over a lazy `map` object gives
        # a 0-d object array, which has no shape[0].
        flags = np.array([1 if item in members else 0 for item in x], dtype=int)
        return np.reshape(flags, (flags.shape[0], 1))

    return slice_fn


payloads = []
splits = ["train", "valid", "test"]
X_splits = X_train, X_valid, X_test
Y_splits = Y_train, Y_valid, Y_test
image_id_splits = train_image_ids, valid_image_ids, test_image_ids

task_name = 'BirdClassificationTask'
labels_to_tasks = {"labelset_gold": task_name}
slice_sizes = {}  # attribute id -> percentage of the training split in the slice
# The slicing functions do not depend on the split, so build them once.
slice_fns = {
    str(attr_id): make_slice_fn(attr_id)
    for attr_id in range(1, NUM_ATTRIBUTES + 1)
}

for i in range(3):
    payload_name = f"Payload{i}_{splits[i]}"
    X_dict = {'data': X_splits[i]}
    Y_dict = {'labelset_gold': Y_splits[i]}
    # Plain Python ints, converted once per split instead of once per attribute.
    image_ids = image_id_splits[i].tolist()

    for attr_id in range(1, NUM_ATTRIBUTES + 1):
        members = attrs_dict.get(attr_id, set())
        mask = [1 if image_id in members else 0 for image_id in image_ids]
        if splits[i] == 'train':
            slice_sizes[attr_id] = sum(mask) / len(mask) * 100.0

        mask = torch.tensor(mask)

        # Prediction labels: the gold label inside the slice, 0 outside it.
        pred_labelset_name = f"labelset:{attr_id}:pred"
        Y_dict[pred_labelset_name] = mask * Y_splits[i]
        labels_to_tasks[pred_labelset_name] = task_name

        # Indicator labels: 1 inside the slice, 2 outside it.
        ind_labels = mask.clone()
        ind_labels[ind_labels == 0] = 2  # to follow Metal convention
        ind_labelset_name = f"labelset:{attr_id}:ind"
        Y_dict[ind_labelset_name] = ind_labels
        labels_to_tasks[ind_labelset_name] = None

    dataset = MmtlDataset(X_dict, Y_dict)
    data_loader = MmtlDataLoader(dataset, batch_size=32)
    payload = Payload(payload_name, data_loader, labels_to_tasks, splits[i])
    payloads.append(payload)


Let's load our baseline model

In [15]:
# Load the serialized baseline model from MODELS_DIR.
# NOTE(review): torch.load relies on pickle, so only load checkpoints from a
# trusted location.
model = torch.load(opj(MODELS_DIR,'resnet18_lr_1e-3_patience10_shuffled_fc_separated.pt')) #achieves 17% accuracy on test set

In [16]:
# payloads[1] is the "valid" payload (payloads are appended in train/valid/test order).
accs_per_slice = model.score(payloads[1], metrics=[]) #score on validation set
print(accs_per_slice)
# Drop the overall (gold labelset) entry so only per-slice entries remain; its
# key contains no ':' and would break the attribute-id parsing in the next cell.
del accs_per_slice['BirdClassificationTask/Payload1_valid/labelset_gold/accuracy']

{'BirdClassificationTask/Payload1_valid/labelset_gold/accuracy': 0.2268557130942452, 'BirdClassificationTask/Payload1_valid/labelset:1:pred/accuracy': 0.25925925925925924, 'BirdClassificationTask/Payload1_valid/labelset:2:pred/accuracy': 0.30344827586206896, 'BirdClassificationTask/Payload1_valid/labelset:3:pred/accuracy': 0.3157894736842105, 'BirdClassificationTask/Payload1_valid/labelset:4:pred/accuracy': 0.16, 'BirdClassificationTask/Payload1_valid/labelset:5:pred/accuracy': 0.16901408450704225, 'BirdClassificationTask/Payload1_valid/labelset:6:pred/accuracy': 0.2222222222222222, 'BirdClassificationTask/Payload1_valid/labelset:7:pred/accuracy': 0.20512820512820512, 'BirdClassificationTask/Payload1_valid/labelset:8:pred/accuracy': 0.2314540059347181, 'BirdClassificationTask/Payload1_valid/labelset:9:pred/accuracy': 0.3103448275862069, 'BirdClassificationTask/Payload1_valid/labelset:10:pred/accuracy': 0.2459016393442623, 'BirdClassificationTask/Payload1_valid/labelset:11:pred/accuracy

In [17]:

import csv

# Reduce each metric key to its attribute id (the text between the first two
# ':' characters) and order the slices from lowest to highest accuracy.
slice_accuracies = sorted(
    (
        (metric_key.split(':')[1], accuracy)
        for metric_key, accuracy in accs_per_slice.items()
    ),
    key=lambda pair: pair[1],
)

# One CSV row per slice: attribute id, accuracy, and the slice's share of the
# training split.
with open('slices.csv', 'w', newline='') as csv_out:
    writer = csv.writer(csv_out, delimiter=',')
    writer.writerow(['Attribute ID', 'Accuracy','Size (%)'])
    for slice_attr, slice_acc in slice_accuracies:
        writer.writerow([slice_attr, slice_acc, slice_sizes[int(slice_attr)]])



## Slicing Models

In [18]:
# Y_dict is left over from the last iteration of the payload-building loop,
# so this shows the label sets of the "test" split.
print(Y_dict)

{'labelset_gold': tensor([  1,   1,   1,  ..., 200, 200, 200]), 'labelset:1:pred': tensor([0, 0, 0,  ..., 0, 0, 0]), 'labelset:1:ind': tensor([2, 2, 2,  ..., 2, 2, 2]), 'labelset:2:pred': tensor([0, 0, 0,  ..., 0, 0, 0]), 'labelset:2:ind': tensor([2, 2, 2,  ..., 2, 2, 2]), 'labelset:3:pred': tensor([0, 0, 0,  ..., 0, 0, 0]), 'labelset:3:ind': tensor([2, 2, 2,  ..., 2, 2, 2]), 'labelset:4:pred': tensor([0, 0, 0,  ..., 0, 0, 0]), 'labelset:4:ind': tensor([2, 2, 2,  ..., 2, 2, 2]), 'labelset:5:pred': tensor([1, 1, 1,  ..., 0, 0, 0]), 'labelset:5:ind': tensor([1, 1, 1,  ..., 2, 2, 2]), 'labelset:6:pred': tensor([0, 0, 0,  ..., 0, 0, 0]), 'labelset:6:ind': tensor([2, 2, 2,  ..., 2, 2, 2]), 'labelset:7:pred': tensor([  0,   0,   0,  ..., 200, 200, 200]), 'labelset:7:ind': tensor([2, 2, 2,  ..., 1, 1, 1]), 'labelset:8:pred': tensor([0, 0, 0,  ..., 0, 0, 0]), 'labelset:8:ind': tensor([2, 2, 2,  ..., 2, 2, 2]), 'labelset:9:pred': tensor([0, 0, 0,  ..., 0, 0, 0]), 'labelset:9:ind': tensor([2, 2,

In [31]:
from metal.mmtl.metal_model import MetalModel
from metal.mmtl.slicing.slice_model import SliceModel, SliceRepModel
from metal.mmtl.slicing.moe_model import MoEModel
# All model configurations to test.
# all_slice_funcs = {"slice_1": slice_1, "slice_2": slice_2, "BASE": identity_fn}

# Attribute ids (as strings, matching the keys of slice_fns) of the slices to
# train on.
slice_train_attrs = ['23', '24']
slice_train_funcs = {attr : slice_fns[attr] for attr in slice_train_attrs}


def identity_fn(x):
    """Return an all-True boolean mask with one entry per row of `x`."""
    # Builtin `bool` replaces the `np.bool` alias, which NumPy deprecated in
    # 1.20 and removed in 1.24.
    return np.ones(x.shape[0], dtype=bool)


# The BASE slice covers every example.
slice_train_funcs['BASE'] = identity_fn

# Same loss weight for every slice; computed after BASE is added, so BASE is
# included in the count.
slice_loss_multiplier = 1.0 / (2*len(slice_train_funcs))
slice_weights = {attr:slice_loss_multiplier for attr in slice_train_attrs}
slice_weights_w_base = dict(slice_weights)
slice_weights_w_base['BASE'] = slice_loss_multiplier

# Maps a configuration name to the settings read by the training cell.
# Commented-out entries are alternative configurations that are disabled.
model_configs = {
#     'soft_param': {
#         'slice_funcs': slice_train_funcs,
#         'create_ind': True,
#         'model_class': SliceModel,
#         'slice_weights' : slice_weights_w_base
# #         'slice_weights': {
# #             'BASE': slice_loss_multiplier,
# #             'slice_1': slice_loss_multiplier, 'slice_2': slice_loss_multiplier
# #         }
#     },   
    'soft_param_rep': {
        'slice_funcs': slice_train_funcs,
        'create_ind': True,
        'create_preds': False,
        'model_class': SliceRepModel,
        'slice_weights' : slice_weights_w_base,
#         'slice_weights': {
#             'BASE': slice_loss_multiplier,
#             'slice_1': slice_loss_multiplier, 'slice_2': slice_loss_multiplier
#         },
        'h_dim': 2
    },
    'hard_param': {
        'slice_funcs': slice_train_funcs,
        'create_ind': False,
        'model_class': MetalModel,
        'slice_weights' : slice_weights
        #'slice_weights': {'slice_1': slice_loss_multiplier, 'slice_2': slice_loss_multiplier}
    },
#     'manual_reweighting': {
#         'slice_funcs': slice_train_funcs,
#         'create_ind': False,
#         'slice_weights': {"slice_1": 50}, # 10x weight of slice_2, everything else default
#         'model_class': MetalModel        
#     },
#     'moe': {
#         'slice_funcs': {},
#         'create_ind': False,
#         'model_class': MoEModel
#     },
#     'naive': {
#         'slice_funcs': {},
#         'model_class': MetalModel   
#     }
}

In [34]:
import copy

from metal.mmtl.slicing.tasks import MultiClassificationTask
from metal.mmtl.trainer import MultitaskTrainer

# Train one model per entry of model_configs, each starting from the
# pretrained baseline `model`, and collect the results in trained_models.
#
# NOTE(review): `train_kwargs` is not defined in any visible cell and must be
# defined before this cell runs.
# NOTE(review): the traceback recorded under this cell reports that a slice
# task named 'BirdClassificationTask:BASE' is required. `tasks` below holds
# only the base task, and slice_funcs / config_slice_weights / create_ind /
# create_preds are read from the config but never used, so the slice tasks
# still have to be constructed.
trained_models = {}
for model_name, config in model_configs.items():
    # Deep-copy the baseline's modules so every configuration starts from the
    # same pretrained weights and training one model cannot alter the
    # baseline or the next configuration.
    pretrained_input_module = copy.deepcopy(
        model.input_modules['BirdClassificationTask'].module.module
    )
    pretrained_head_module = copy.deepcopy(
        model.head_modules['BirdClassificationTask'].module.module
    )
    task0 = MultiClassificationTask(
        name=task_name,
        input_module=pretrained_input_module,
        head_module=pretrained_head_module,
    )
    # just the one task
    tasks = [task0]
    print(f"{'='*10}Initializing + Training {model_name}{'='*10}")
    slice_funcs = config['slice_funcs']
    model_class = config['model_class']
    # Local name, so the notebook-level `slice_weights` dict is not rebound.
    config_slice_weights = config.get("slice_weights", {})
    create_ind = config.get("create_ind", True)
    create_preds = config.get("create_preds", True)
    h_dim = config.get("h_dim", None)
    # The payloads come from the payload-building cell above.

    if model_name == 'moe':
        # NOTE(review): this branch is unreachable while the 'moe' config is
        # commented out; all_slice_funcs, train_slice_experts, uid_lists, Xs
        # and Ys are not defined in any visible cell.
        # train for same total num epochs
        expert_train_kwargs = copy.deepcopy(train_kwargs)
        expert_train_kwargs['n_epochs'] = int(train_kwargs['n_epochs'] / (len(all_slice_funcs) + 1))
        experts = train_slice_experts(
            uid_lists, Xs, Ys, MetalModel, all_slice_funcs, **expert_train_kwargs
        )
        slice_model = model_class(tasks, experts, verbose=False, seed=SEED)
        trainer = MultitaskTrainer(seed=SEED)
        metrics_dict = trainer.train_model(slice_model, payloads, **expert_train_kwargs)
    else:
        # A separate name keeps `model` bound to the pretrained baseline, so
        # later iterations (and re-runs of this cell) still read the baseline.
        slice_model = model_class(tasks, h_dim=h_dim, verbose=True)
        trainer = MultitaskTrainer()
        metrics_dict = trainer.train_model(slice_model, payloads, **train_kwargs)
    print(metrics_dict) 
    trained_models[model_name] = slice_model



ValueError: There must be a `slice_task` designated to operate on the entire labelset with name 'BirdClassificationTask:BASE'.