# MIT dataset

In [78]:
# import some common libraries
import os
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from eval_model import eval_model
from model import CompoResnet
from symnet.utils import dataset
from train import train_with_config, train
# Use the first CUDA GPU when one is visible to torch, otherwise fall back to CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
  
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [79]:
# Field layout for each split:
#   train: [img, attr_id, obj_id, pair_id, img_feature,
#           img, attr_id, obj_id, pair_id, img_feature, aff_mask]
#   test:  [img, attr_id, obj_id, pair_id, img_feature, aff_mask]
# Both loaders share batch size and image loading; only train is shuffled.
loader_kwargs = dict(batchsize=64, with_image=True)
train_dataloader = dataset.get_dataloader('MITg', 'train', shuffle=True, **loader_kwargs)
test_dataloader = dataset.get_dataloader('MITg', 'test', **loader_kwargs)

53753 activations loaded
natural split train
#images = 30338
53753 activations loaded
natural split test
#images = 12995


In [3]:
# Build a fresh model, optimizer and loss, then optionally resume from a
# checkpoint stored under `model_dir`.
num_mlp_layers = 1
resnet_name = 'resnet18'
compoResnet = CompoResnet(resnet_name, num_mlp_layers).to(dev)

# Each history holds two lists that `train` appends to and the plotting cell
# reads; the meaning of each index is defined by `train` (presumably
# [train, validation] — confirm).
obj_loss_history = [[], []]
attr_loss_history = [[], []]
optimizer = torch.optim.AdamW(compoResnet.parameters(), lr=0.0015)
criterion = nn.CrossEntropyLoss()
curr_epoch = 0

model_dir = './models/'
load_model_name = None  # e.g. 'model_35.pt' to resume from a saved checkpoint
model_path = os.path.join(model_dir, load_model_name) if load_model_name else None

if model_path:
  # map_location remaps stored tensors onto `dev`, so a checkpoint written on a
  # GPU can be restored on the CPU fallback chosen above (and vice versa).
  checkpoint = torch.load(model_path, map_location=dev)
  compoResnet.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  obj_loss_history = checkpoint['obj_loss']
  attr_loss_history = checkpoint['attr_loss']

Using cache found in /home/ubuntu/.cache/torch/hub/pytorch_vision_v0.9.0


In [80]:
# Run the training loop. Losses are appended to the two history lists, and
# model_name/model_dir are forwarded to `train` (presumably for checkpoint
# saving — see train.py).
model_name = 'model'
num_epochs = 15
batch_size = 64  # keep in sync with `batchsize` passed to get_dataloader
curr_epoch = 0
train(
    compoResnet,
    optimizer,
    criterion,
    num_epochs,
    obj_loss_history,
    attr_loss_history,
    batch_size,
    train_dataloader,
    test_dataloader=test_dataloader,
    curr_epoch=curr_epoch,
    model_name=model_name,
    model_dir=model_dir,
)

 21%|██        | 100/475 [00:43<02:34,  2.42it/s, Train: epoch 0/15]

[1,   100] obj_loss: 1.672, attr_loss: 1.881


 42%|████▏     | 200/475 [01:24<01:51,  2.47it/s, Train: epoch 0/15]

[1,   200] obj_loss: 1.682, attr_loss: 1.908


 63%|██████▎   | 300/475 [02:05<01:08,  2.54it/s, Train: epoch 0/15]

[1,   300] obj_loss: 1.694, attr_loss: 1.917


 84%|████████▍ | 400/475 [02:44<00:28,  2.62it/s, Train: epoch 0/15]

[1,   400] obj_loss: 1.712, attr_loss: 1.920


100%|██████████| 475/475 [03:12<00:00,  2.47it/s, Train: epoch 0/15]
100%|██████████| 204/204 [00:39<00:00,  5.15it/s]


[1] obj_val_loss: 3.371, attr_val_loss: 4.497
Finished training.


 21%|██        | 100/475 [00:36<02:14,  2.79it/s, Train: epoch 1/15]

[2,   100] obj_loss: 1.648, attr_loss: 1.869


 42%|████▏     | 200/475 [01:12<01:38,  2.78it/s, Train: epoch 1/15]

[2,   200] obj_loss: 1.665, attr_loss: 1.880


 63%|██████▎   | 300/475 [01:49<01:04,  2.73it/s, Train: epoch 1/15]

[2,   300] obj_loss: 1.658, attr_loss: 1.883


 84%|████████▍ | 400/475 [02:25<00:27,  2.78it/s, Train: epoch 1/15]

[2,   400] obj_loss: 1.668, attr_loss: 1.882


100%|██████████| 475/475 [02:52<00:00,  2.76it/s, Train: epoch 1/15]
100%|██████████| 204/204 [00:41<00:00,  4.92it/s]


[2] obj_val_loss: 3.418, attr_val_loss: 4.408
Finished training.


 21%|██        | 100/475 [00:36<02:15,  2.76it/s, Train: epoch 2/15]

[3,   100] obj_loss: 1.609, attr_loss: 1.832


 42%|████▏     | 200/475 [01:13<01:38,  2.79it/s, Train: epoch 2/15]

[3,   200] obj_loss: 1.598, attr_loss: 1.831


 63%|██████▎   | 300/475 [01:49<01:03,  2.77it/s, Train: epoch 2/15]

[3,   300] obj_loss: 1.619, attr_loss: 1.849


 84%|████████▍ | 400/475 [02:26<00:27,  2.77it/s, Train: epoch 2/15]

[3,   400] obj_loss: 1.639, attr_loss: 1.871


100%|██████████| 475/475 [02:53<00:00,  2.74it/s, Train: epoch 2/15]
100%|██████████| 204/204 [00:41<00:00,  4.95it/s]


[3] obj_val_loss: 3.385, attr_val_loss: 4.433
Finished training.


 21%|██        | 100/475 [00:36<02:14,  2.79it/s, Train: epoch 3/15]

[4,   100] obj_loss: 1.560, attr_loss: 1.776


 42%|████▏     | 200/475 [01:12<01:39,  2.75it/s, Train: epoch 3/15]

[4,   200] obj_loss: 1.551, attr_loss: 1.777


 63%|██████▎   | 300/475 [01:49<01:03,  2.77it/s, Train: epoch 3/15]

[4,   300] obj_loss: 1.583, attr_loss: 1.807


 84%|████████▍ | 400/475 [02:25<00:27,  2.78it/s, Train: epoch 3/15]

[4,   400] obj_loss: 1.600, attr_loss: 1.823


100%|██████████| 475/475 [02:52<00:00,  2.75it/s, Train: epoch 3/15]
100%|██████████| 204/204 [00:41<00:00,  4.93it/s]


[4] obj_val_loss: 3.401, attr_val_loss: 4.468
Finished training.


 21%|██        | 100/475 [00:36<02:14,  2.80it/s, Train: epoch 4/15]

[5,   100] obj_loss: 1.541, attr_loss: 1.764


 42%|████▏     | 200/475 [01:13<01:38,  2.79it/s, Train: epoch 4/15]

[5,   200] obj_loss: 1.564, attr_loss: 1.793


 63%|██████▎   | 300/475 [01:50<01:02,  2.80it/s, Train: epoch 4/15]

[5,   300] obj_loss: 1.581, attr_loss: 1.810


 84%|████████▍ | 400/475 [02:27<00:26,  2.79it/s, Train: epoch 4/15]

[5,   400] obj_loss: 1.601, attr_loss: 1.834


100%|██████████| 475/475 [02:54<00:00,  2.73it/s, Train: epoch 4/15]
100%|██████████| 204/204 [00:41<00:00,  4.91it/s]


[5] obj_val_loss: 3.425, attr_val_loss: 4.512
Finished training.


 21%|██        | 100/475 [00:36<02:14,  2.79it/s, Train: epoch 5/15]

[6,   100] obj_loss: 1.542, attr_loss: 1.789


 42%|████▏     | 200/475 [01:13<01:39,  2.76it/s, Train: epoch 5/15]

[6,   200] obj_loss: 1.546, attr_loss: 1.789


 63%|██████▎   | 300/475 [01:49<01:03,  2.77it/s, Train: epoch 5/15]

[6,   300] obj_loss: 1.562, attr_loss: 1.799


 84%|████████▍ | 400/475 [02:25<00:26,  2.80it/s, Train: epoch 5/15]

[6,   400] obj_loss: 1.571, attr_loss: 1.797


100%|██████████| 475/475 [02:52<00:00,  2.75it/s, Train: epoch 5/15]
100%|██████████| 204/204 [00:41<00:00,  4.92it/s]


[6] obj_val_loss: 3.514, attr_val_loss: 4.512
Finished training.


 21%|██        | 100/475 [00:36<02:13,  2.80it/s, Train: epoch 6/15]

[7,   100] obj_loss: 1.520, attr_loss: 1.759


 42%|████▏     | 200/475 [01:12<01:38,  2.79it/s, Train: epoch 6/15]

[7,   200] obj_loss: 1.546, attr_loss: 1.787


 63%|██████▎   | 300/475 [01:48<01:02,  2.78it/s, Train: epoch 6/15]

[7,   300] obj_loss: 1.549, attr_loss: 1.795


 84%|████████▍ | 400/475 [02:25<00:26,  2.79it/s, Train: epoch 6/15]

[7,   400] obj_loss: 1.565, attr_loss: 1.803


100%|██████████| 475/475 [02:52<00:00,  2.75it/s, Train: epoch 6/15]
100%|██████████| 204/204 [00:41<00:00,  4.91it/s]


[7] obj_val_loss: 3.474, attr_val_loss: 4.530
Finished training.


 21%|██        | 100/475 [00:36<02:15,  2.76it/s, Train: epoch 7/15]

[8,   100] obj_loss: 1.524, attr_loss: 1.732


 42%|████▏     | 200/475 [01:13<01:39,  2.77it/s, Train: epoch 7/15]

[8,   200] obj_loss: 1.532, attr_loss: 1.741


 63%|██████▎   | 300/475 [01:51<01:03,  2.78it/s, Train: epoch 7/15]

[8,   300] obj_loss: 1.539, attr_loss: 1.764


 84%|████████▍ | 400/475 [02:27<00:27,  2.77it/s, Train: epoch 7/15]

[8,   400] obj_loss: 1.536, attr_loss: 1.765


100%|██████████| 475/475 [02:54<00:00,  2.72it/s, Train: epoch 7/15]
100%|██████████| 204/204 [00:41<00:00,  4.90it/s]


[8] obj_val_loss: 3.515, attr_val_loss: 4.565
Finished training.


 21%|██        | 100/475 [00:36<02:14,  2.79it/s, Train: epoch 8/15]

[9,   100] obj_loss: 1.486, attr_loss: 1.686


 42%|████▏     | 200/475 [01:13<01:50,  2.48it/s, Train: epoch 8/15]

[9,   200] obj_loss: 1.513, attr_loss: 1.740


 63%|██████▎   | 300/475 [01:50<01:02,  2.78it/s, Train: epoch 8/15]

[9,   300] obj_loss: 1.523, attr_loss: 1.752


 84%|████████▍ | 400/475 [02:26<00:26,  2.78it/s, Train: epoch 8/15]

[9,   400] obj_loss: 1.536, attr_loss: 1.761


100%|██████████| 475/475 [02:53<00:00,  2.74it/s, Train: epoch 8/15]
100%|██████████| 204/204 [00:40<00:00,  4.98it/s]


[9] obj_val_loss: 3.496, attr_val_loss: 4.602
Finished training.


 21%|██        | 100/475 [00:38<02:15,  2.78it/s, Train: epoch 9/15]

[10,   100] obj_loss: 1.493, attr_loss: 1.729


 42%|████▏     | 200/475 [01:14<01:38,  2.78it/s, Train: epoch 9/15]

[10,   200] obj_loss: 1.508, attr_loss: 1.739


 63%|██████▎   | 300/475 [01:51<01:02,  2.79it/s, Train: epoch 9/15]

[10,   300] obj_loss: 1.509, attr_loss: 1.746


 84%|████████▍ | 400/475 [02:28<00:27,  2.77it/s, Train: epoch 9/15]

[10,   400] obj_loss: 1.517, attr_loss: 1.755


100%|██████████| 475/475 [02:55<00:00,  2.71it/s, Train: epoch 9/15]
100%|██████████| 204/204 [00:40<00:00,  5.04it/s]


[10] obj_val_loss: 3.516, attr_val_loss: 4.556
Finished training.


 21%|██        | 100/475 [00:37<02:15,  2.77it/s, Train: epoch 10/15]

[11,   100] obj_loss: 1.468, attr_loss: 1.732


 42%|████▏     | 200/475 [01:13<01:39,  2.77it/s, Train: epoch 10/15]

[11,   200] obj_loss: 1.490, attr_loss: 1.732


 63%|██████▎   | 300/475 [01:49<01:05,  2.69it/s, Train: epoch 10/15]

[11,   300] obj_loss: 1.489, attr_loss: 1.742


 84%|████████▍ | 400/475 [02:26<00:26,  2.78it/s, Train: epoch 10/15]

[11,   400] obj_loss: 1.495, attr_loss: 1.749


100%|██████████| 475/475 [02:53<00:00,  2.74it/s, Train: epoch 10/15]
100%|██████████| 204/204 [00:40<00:00,  4.98it/s]


[11] obj_val_loss: 3.511, attr_val_loss: 4.620
Finished training.


 21%|██        | 100/475 [00:37<02:17,  2.73it/s, Train: epoch 11/15]

[12,   100] obj_loss: 1.480, attr_loss: 1.718


 42%|████▏     | 200/475 [01:13<01:39,  2.76it/s, Train: epoch 11/15]

[12,   200] obj_loss: 1.494, attr_loss: 1.726


 63%|██████▎   | 300/475 [01:49<01:03,  2.76it/s, Train: epoch 11/15]

[12,   300] obj_loss: 1.501, attr_loss: 1.734


 84%|████████▍ | 400/475 [02:26<00:26,  2.78it/s, Train: epoch 11/15]

[12,   400] obj_loss: 1.500, attr_loss: 1.742


100%|██████████| 475/475 [02:53<00:00,  2.73it/s, Train: epoch 11/15]
100%|██████████| 204/204 [00:41<00:00,  4.91it/s]


[12] obj_val_loss: 3.542, attr_val_loss: 4.597
Finished training.


 21%|██        | 100/475 [00:37<02:14,  2.78it/s, Train: epoch 12/15]

[13,   100] obj_loss: 1.435, attr_loss: 1.720


 42%|████▏     | 200/475 [01:14<01:39,  2.75it/s, Train: epoch 12/15]

[13,   200] obj_loss: 1.448, attr_loss: 1.723


 63%|██████▎   | 300/475 [01:51<01:02,  2.78it/s, Train: epoch 12/15]

[13,   300] obj_loss: 1.473, attr_loss: 1.720


 84%|████████▍ | 400/475 [02:27<00:27,  2.77it/s, Train: epoch 12/15]

[13,   400] obj_loss: 1.489, attr_loss: 1.731


100%|██████████| 475/475 [02:54<00:00,  2.72it/s, Train: epoch 12/15]
100%|██████████| 204/204 [00:41<00:00,  4.92it/s]


[13] obj_val_loss: 3.580, attr_val_loss: 4.663
Finished training.


 21%|██        | 100/475 [00:36<02:15,  2.78it/s, Train: epoch 13/15]

[14,   100] obj_loss: 1.413, attr_loss: 1.672


 42%|████▏     | 200/475 [01:13<01:38,  2.79it/s, Train: epoch 13/15]

[14,   200] obj_loss: 1.432, attr_loss: 1.692


 63%|██████▎   | 300/475 [01:49<01:02,  2.79it/s, Train: epoch 13/15]

[14,   300] obj_loss: 1.454, attr_loss: 1.719


 84%|████████▍ | 400/475 [02:26<00:26,  2.79it/s, Train: epoch 13/15]

[14,   400] obj_loss: 1.472, attr_loss: 1.729


100%|██████████| 475/475 [02:53<00:00,  2.73it/s, Train: epoch 13/15]
100%|██████████| 204/204 [00:41<00:00,  4.97it/s]


[14] obj_val_loss: 3.552, attr_val_loss: 4.649
Finished training.


 21%|██        | 100/475 [00:36<02:13,  2.80it/s, Train: epoch 14/15]

[15,   100] obj_loss: 1.484, attr_loss: 1.719


 42%|████▏     | 200/475 [01:13<01:38,  2.78it/s, Train: epoch 14/15]

[15,   200] obj_loss: 1.465, attr_loss: 1.698


 63%|██████▎   | 300/475 [01:49<01:02,  2.79it/s, Train: epoch 14/15]

[15,   300] obj_loss: 1.471, attr_loss: 1.706


 84%|████████▍ | 400/475 [02:26<00:27,  2.71it/s, Train: epoch 14/15]

[15,   400] obj_loss: 1.481, attr_loss: 1.715


100%|██████████| 475/475 [02:53<00:00,  2.74it/s, Train: epoch 14/15]
100%|██████████| 204/204 [00:41<00:00,  4.90it/s]


[15] obj_val_loss: 3.628, attr_val_loss: 4.720
Finished training.


In [81]:
# Evaluate the trained model on the test split. The train loader is passed as
# well (presumably so eval_model can tell seen from unseen pairs — confirm).
obj_acc, attr_acc, report_cw, report_ow = eval_model(compoResnet, test_dataloader, train_dataloader)

# Two explicit accuracies followed by eight report values; this assumes
# report_cw and report_ow each unpack to four numbers (seen, unseen, HM, AUC).
summary_fmt = ('A:{:.3f}|O:{:.3f}|'
               'CwSeen:{:.3f}|CwUnseen:{:.3f}|CwHM:{:.3f}|CwAUC:{:.3f}|'
               'OpSeen:{:.3f}|OpUnseen:{:.3f}|OpHM:{:.3f}|OpAUC:{:.3f}|')
print(summary_fmt.format(attr_acc, obj_acc, *report_cw, *report_ow))

100%|██████████| 204/204 [01:25<00:00,  2.39it/s]

A:0.227|O:0.338|CwSeen:0.110|CwUnseen:0.152|CwHM:0.077|CwAUC:0.011|OpSeen:0.086|OpUnseen:0.041|OpHM:0.040|OpAUC:0.001|





In [None]:
# 2x2 grid of the recorded loss curves. The top row shows object losses and the
# bottom row attribute losses; the column is the index into each history list
# (same panel layout as subplot positions 1..4).
fig, axes = plt.subplots(2, 2, figsize=(10, 6))
for row, (label, history) in enumerate([('obj', obj_loss_history),
                                        ('attr', attr_loss_history)]):
    for col in range(2):
        ax = axes[row, col]
        ax.plot(history[col])
        ax.set_title(f'{label}_loss_history[{col}]')
        ax.set_ylabel('loss')
fig.tight_layout()
plt.show()

In [None]:
# Ray Tune search space: a log-uniform learning rate, plus the backbone and MLP
# depth chosen from fixed candidate lists.
# NOTE(review): `tune` is not imported in this notebook's imports cell
# (presumably `from ray import tune`), so this cell raises NameError on a
# fresh kernel. Add the import to the imports cell.
config = dict(
    lr=tune.loguniform(1e-4, 1e-1),
    resnet=tune.choice(['resnet18', 'resnet50', 'resnet101']),
    num_mlp_layers=tune.choice([1, 2, 4, 6]),
)

In [None]:
# Hyper-parameter search with Ray Tune. `num_samples` trials are drawn from
# `config`, and each runs `train_with_config` for at most `num_epochs` epochs.
# ASHA terminates trials early when their reported "loss" lags behind.
# NOTE(review): ASHAScheduler and CLIReporter are not imported in this notebook
# (presumably from ray.tune.schedulers / ray.tune). Add them to the imports cell.
num_samples = 12
num_epochs = 6
scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=num_epochs,
    grace_period=1,
    reduction_factor=2,
)
reporter = CLIReporter(metric_columns=["loss", "accuracy", "training_iteration"])

# A fractional GPU request lets up to three trials share one GPU. When CUDA is
# unavailable (the CPU fallback chosen for `dev`), a GPU request can never be
# satisfied and the trials would stay pending, so request none.
gpus_per_trial = 0.32 if torch.cuda.is_available() else 0
result = tune.run(
    partial(train_with_config, num_epochs=num_epochs),
    resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
    config=config,
    num_samples=num_samples,
    scheduler=scheduler,
    progress_reporter=reporter,
)

In [None]:
# Report the best Ray Tune trial, rebuild its model, restore its weights, and
# evaluate it on the test split with the same `eval_model` used for the
# interactively trained model.
best_trial = result.get_best_trial("loss", "min", "last")
print("Best trial config: {}".format(best_trial.config))
print("Best trial final validation loss: {}".format(
    best_trial.last_result["loss"]))
print("Best trial final validation accuracy: {}".format(
    best_trial.last_result["accuracy"]))

# Build the model with the constructor signature that is actually used (and
# runs) earlier in this notebook: CompoResnet(resnet_name, num_mlp_layers).
best_trained_model = CompoResnet(
    best_trial.config["resnet"], best_trial.config["num_mlp_layers"]).to(dev)

best_checkpoint_dir = best_trial.checkpoint.value
# map_location remaps tensors onto `dev`, in case the trial ran on a different
# device than the one evaluating here.
model_state = torch.load(
    os.path.join(best_checkpoint_dir, "checkpoint"),
    map_location=dev)['model_state_dict']
best_trained_model.load_state_dict(model_state)

best_obj_acc, best_attr_acc, best_report_cw, best_report_ow = eval_model(
    best_trained_model, test_dataloader, train_dataloader)
print("\nBest trial test set accuracy: A:{:.3f}|O:{:.3f}".format(
    best_attr_acc, best_obj_acc))

Matches from a separate run (MLP2, 30 epochs, Adam):

[0.30456985, 0.15528112, 0.02720025]