In [94]:
import sys

import metal
import os
# Import other dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
os.environ['METALHOME'] = '/dfs/scratch1/saelig/slicing/metal/'
# Set random seed for notebook
SEED = 123

In [95]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load Data

Here, the train/test split was defined in the dataset. We then split the train set into a train/valid (see next cell)

In [96]:
from skimage import io, transform
import torchvision.transforms as transforms
import numpy as np
DATASET_DIR = '/dfs/scratch1/saelig/slicing/CUB_200_2011/'
IMAGES_DIR = os.path.join(DATASET_DIR, 'images')

#Size of eac
image_list = np.loadtxt(os.path.join(DATASET_DIR, 'images.txt'), dtype=str)
train_test_split = np.loadtxt(os.path.join(DATASET_DIR, 'train_test_split.txt'), dtype=int)
labels = np.loadtxt(os.path.join(DATASET_DIR, 'image_class_labels.txt'), dtype=int)

X = []
Y = []
X_test = []
Y_test = []

#image size (332, 500, 3)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
tt = transforms.ToTensor()

for image_id, image_file in image_list:
    image_id = int(image_id)
    image_data = io.imread(os.path.join(IMAGES_DIR, image_file))
    image_data = transform.resize(image_data, (224,224,3)) #resize all images to 224x224
    image_data = normalize(tt(image_data).type(torch.float32)) #make channel dim first
    label = labels[image_id - 1][1] #Keep 1 the first class since 0 is used for abstain
    if train_test_split[image_id - 1][1] == 1: #put in train
        X.append(image_data)
        Y.append(label)
    else: #put in test
        X_test.append(image_data)
        Y_test.append(label)

X_train = torch.stack(X)
Y_train = np.array(Y)
X_test = torch.stack(X_test)
Y_test = np.array(Y_test)

Now let's convert all the data to tensors, and create a validation set.

In [97]:
from metal.utils import split_data

(X_train, X_valid), (Y_train, Y_valid) = split_data(X_train, Y_train, splits=[0.8,0.2], seed=SEED)

# X_train, X_valid, X_test = torch.tensor(X_train), torch.tensor(X_valid), torch.tensor(X_test)
Y_train, Y_valid, Y_test = torch.tensor(Y_train), torch.tensor(Y_valid), torch.tensor(Y_test)

# X_train = X_train.permute(0,3,1,2)
# X_valid = X_valid.permute(0,3,1,2)
# X_test = X_test.permute(0,3,1,2)

Create a task. Use a resnet50 for now as our input model. Since the resnet already includes the fully connected layer, we don't specify a `head_module`, which defaults to the identity.

In [112]:
#from torchvision.models.resnet import *
from metal.mmtl.slicing.tasks import MultiClassificationTask
from metal.mmtl.metal_model import MetalModel 
from resnet import *

resnet_model = resnet18(num_classes=200).float().cuda()

task0 = MultiClassificationTask(
    name='BirdClassificationTask', 
    input_module=resnet_model, 
)
tasks = [task0]
model = MetalModel(tasks, verbose=False)

Create payload abstraction for our train/valid/test sets.

In [143]:
from metal.mmtl.payload import Payload
from pprint import pprint

payloads = []
splits = ["train", "valid", "test"]
X_splits = X_train, X_valid, X_test
Y_splits = Y_train, Y_valid, Y_test

for i in range(3):
    payload_name = f"Payload{i}_{splits[i]}"
    task_name = task0.name
    #print(X_splits[i].shape)
    payload = Payload.from_tensors(payload_name, {'data': X_splits[i]}, {'labels' : Y_splits[i]}, task_name, splits[i], batch_size=32)
    #payload = Payload.from_tensors(payload_name, X_splits[i], Y_splits[i], task_name, splits[i], batch_size=32)
    payloads.append(payload)

pprint(payloads)
print(payloads[0].data_loader)

[Payload(Payload0_train: labels_to_tasks=[{'labels': 'BirdClassificationTask'}], split=train),
 Payload(Payload1_valid: labels_to_tasks=[{'labels': 'BirdClassificationTask'}], split=valid),
 Payload(Payload2_test: labels_to_tasks=[{'labels': 'BirdClassificationTask'}], split=test)]
<metal.mmtl.data.MmtlDataLoader object at 0x7f4c0e575cc0>


In [128]:
from metal.mmtl.trainer import MultitaskTrainer
trainer = MultitaskTrainer()
resnet_model = resnet18(num_classes=200).float().cuda()

task0 = MultiClassificationTask(
    name='BirdClassificationTask', 
    input_module=resnet_model, 
)
tasks = [task0]
model = MetalModel(tasks, verbose=False)

# scores = trainer.train_model(
#     model, 
#     payloads, 
#     n_epochs=30, 
#     log_every=2,
#     lr=0.001,
#     progress_bar=False,
# #     lr_scheduler='reduce_on_plateau',
# #     patience = 3,
#     checkpoint_every = 2,
#     checkpoint_metric='BirdClassificationTask/Payload1_valid/labels/accuracy',
#     checkpoint_metric_mode='max',
# )
scores = trainer.train_model(
    model, 
    payloads, 
    n_epochs=30, 
    log_every=2,
    lr=0.001,
    progress_bar=False,
    lr_scheduler='reduce_on_plateau',
    patience = 10,
    checkpoint_every = 2,
    checkpoint_metric='BirdClassificationTask/Payload1_valid/labels/accuracy',
    checkpoint_metric_mode='max',
)

CONFIG:  {'verbose': True, 'seed': 531396, 'commit_hash': None, 'ami': None, 'progress_bar': False, 'n_epochs': 30, 'l2': 0.0, 'grad_clip': 1.0, 'optimizer_config': {'optimizer': 'adam', 'optimizer_common': {'lr': 0.001}, 'sgd_config': {'momentum': 0.9}, 'adam_config': {'betas': (0.9, 0.999)}, 'rmsprop_config': {}}, 'lr_scheduler': 'reduce_on_plateau', 'lr_scheduler_config': {'warmup_steps': 0.0, 'warmup_unit': 'epochs', 'min_lr': 1e-06, 'exponential_config': {'gamma': 0.999}, 'plateau_config': {'factor': 0.5, 'patience': 10, 'threshold': 0.0001}}, 'metrics_config': {'task_metrics': [], 'trainer_metrics': ['model/valid/all/loss'], 'aggregate_metric_fns': [], 'max_valid_examples': 0, 'valid_split': 'valid', 'test_split': 'test'}, 'task_scheduler': 'proportional', 'logger': True, 'logger_config': {'log_unit': 'epochs', 'log_every': 2, 'score_every': -1.0, 'log_lr': True}, 'writer': None, 'writer_config': {'log_dir': '/dfs/scratch1/saelig/slicing/metal//logs', 'run_dir': None, 'run_name':

Exception ignored in: <generator object tqdm_notebook.__iter__ at 0x7f4c0e587e08>
Traceback (most recent call last):
  File "/dfs/scratch0/saelig/miniconda3/envs/slicing/lib/python3.6/site-packages/tqdm/_tqdm_notebook.py", line 226, in __iter__
    self.sp(bar_style='danger')
AttributeError: 'tqdm_notebook' object has no attribute 'sp'


KeyboardInterrupt: 

Now let's see where the model struggles to make accurate predictions by sweeping over the binary attributes. We can then use these to idenitfy potentially useful slices.

In [161]:
NUM_ATTRIBUTES = 312
attributes_array = np.loadtxt(os.path.join(DATASET_DIR, 'attributes/image_attribute_labels.txt'), usecols=(0,1,2))
model = torch.load('resnet18_lr_1e-3_patience10.pt')

predictions = torch.tensor(model.predict(payloads[2], task_name='BirdClassificationTask'))
print((predictions == (Y_test)).sum())
incorrect_predictions = (predictions != Y_test).nonzero().flatten().tolist() #get indices of incorrect predictions
print(len(incorrect_predictions))    


tensor(1012)
4782


In [166]:
test_ids = train_test_split[train_test_split[:,1] == 0][:,0]
counter = [0] * NUM_ATTRIBUTES
for id in incorrect_predictions:
    image_id = id + 1
    attributes_for_image = attributes_array[attributes_array[:, 0] == image_id][:,1:]
    for i in range(len(attributes_for_image)):
        if attributes_for_image[i,1] == 1: #attribute is present
            counter[i] += 1


l = list(map(lambda x: (x[0] + 1, x[1]), enumerate(counter)))
print('(attribute id, num misclassifed images)')
print(sorted(l, key=lambda x: x[1], reverse=True))

(attribute id, num misclassifed images)
[(146, 3818), (55, 2664), (245, 2653), (219, 2244), (152, 2186), (290, 2180), (21, 2146), (150, 2128), (36, 2004), (260, 1966), (237, 1875), (241, 1828), (236, 1725), (179, 1703), (305, 1688), (164, 1686), (52, 1673), (102, 1654), (210, 1607), (118, 1586), (133, 1533), (70, 1505), (309, 1495), (7, 1479), (312, 1470), (91, 1452), (221, 1429), (194, 1396), (15, 1388), (261, 1382), (254, 1368), (30, 1364), (275, 1361), (132, 1312), (22, 1238), (213, 1237), (269, 1211), (284, 1194), (76, 1191), (117, 1152), (64, 1132), (37, 1123), (195, 1104), (11, 1070), (8, 1058), (244, 1049), (51, 1021), (250, 1016), (240, 984), (188, 981), (299, 971), (58, 956), (26, 952), (158, 928), (214, 884), (78, 881), (311, 875), (180, 862), (173, 860), (209, 851), (111, 833), (45, 827), (85, 822), (60, 790), (222, 780), (126, 770), (248, 751), (71, 749), (165, 747), (306, 734), (101, 731), (203, 712), (2, 684), (263, 683), (184, 679), (295, 677), (24, 666), (92, 655), (154