# MIT dataset

In [None]:
# import some common libraries
import numpy as np
import os, json, cv2, random, io
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from PIL import Image
import tqdm
import matplotlib.pyplot as plt

from functools import partial
import numpy as np
import torchvision
import torchvision.transforms as transforms
import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

# Device string used for every .to(...) call below: the first CUDA GPU when
# one is visible to torch, otherwise the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
from symnet.utils import dataset

In [None]:
# Batch layouts yielded by symnet's MIT dataloaders:
#   train [img, attr_id, obj_id, pair_id, img_feature, img, attr_id, obj_id, pair_id, img_feature, aff_mask]
#   test  [img, attr_id, obj_id, pair_id, img_feature, aff_mask]
# The training / evaluation code in this notebook only reads the first three
# entries (img, attr_id, obj_id).
train_dataloader = dataset.get_dataloader(
    'MIT', 'train', batchsize=64, with_image=True, shuffle=True)
test_dataloader = dataset.get_dataloader(
    'MIT', 'test', batchsize=64, with_image=True)

# Head output sizes = number of entries in the dataset's label-to-index maps.
obj_class = len(train_dataloader.dataset.obj2idx.keys())
attr_class = len(train_dataloader.dataset.attr2idx.keys())

print(f"obj_class: {obj_class}, attr_class: {attr_class}")

# ResNet

In [None]:
class Identity(nn.Module):
    """Pass-through module whose ``forward`` returns its input object unchanged.

    ``CompoResnet`` installs it as the backbone's ``fc`` layer so the backbone
    emits its penultimate features instead of class logits.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` as-is."""
        return x

class MLP(nn.Module):
    """Two-layer head: a same-width hidden layer followed by an output projection.

    ``self.mlp`` is ``Linear(in, in) -> BatchNorm1d(in) -> ReLU ->
    Dropout(p=0.5) -> Linear(in, out)``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        layers = [
            nn.Linear(in_features, in_features),
            nn.BatchNorm1d(in_features),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(in_features, out_features),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Map features ``x`` to ``out_features`` logits."""
        return self.mlp(x)


class MLP2(nn.Module):
    """Three-layer head that narrows the width to 1/2 and then 1/4 of the input.

    Each hidden stage is ``Linear -> BatchNorm1d -> ReLU -> Dropout(p=0.5)``;
    a final ``Linear`` produces ``out_features`` logits.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        half = in_features // 2
        quarter = in_features // 4
        self.mlp = nn.Sequential(
            nn.Linear(in_features, half), nn.BatchNorm1d(half), nn.ReLU(), nn.Dropout(),
            nn.Linear(half, quarter), nn.BatchNorm1d(quarter), nn.ReLU(), nn.Dropout(),
            nn.Linear(quarter, out_features),
        )

    def forward(self, x):
        """Map features ``x`` to ``out_features`` logits."""
        return self.mlp(x)


class MLP3(nn.Module):
    """Four-layer head with fixed hidden widths 1400 -> 800 -> 400.

    Each hidden stage is ``Linear -> BatchNorm1d -> ReLU -> Dropout(p=0.5)``;
    a final ``Linear`` maps the 400-wide features to ``out_features`` logits.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        widths = [in_features, 1400, 800, 400]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(),
            ]
        layers.append(nn.Linear(widths[-1], out_features))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Map features ``x`` to ``out_features`` logits."""
        return self.mlp(x)


class HalvingMLP(nn.Module):
    """MLP head whose hidden layers halve the feature width at every stage.

    With ``num_layers = k`` the head has ``k`` hidden stages, each
    ``Linear(w, w // 2) -> BatchNorm1d(w // 2) -> ReLU -> Dropout(p=0.5)``,
    followed by a final ``Linear`` to ``out_features``.

    Args:
        in_features: width of the input features.
        out_features: number of output logits.
        num_layers: number of halving hidden stages. ``None`` (the default)
            means no hidden stages, i.e. a single ``Linear`` layer.

    Raises:
        ValueError: if halving ``in_features`` ``num_layers`` times would
            reach a hidden width of 0.
    """

    def __init__(self, in_features, out_features, num_layers=None):
        super().__init__()
        if num_layers is None:
            num_layers = 0
        layers = []
        width = in_features
        for _ in range(num_layers):
            half = width // 2
            if half < 1:
                raise ValueError(
                    f"HalvingMLP: in_features={in_features} is too small for "
                    f"num_layers={num_layers} (hidden width would reach 0)")
            # Each stage is its own nested Sequential, so state_dict keys look
            # like mlp.<stage>.<layer>.* (the original layout).
            layers.append(nn.Sequential(
                nn.Linear(width, half),
                nn.BatchNorm1d(half),
                nn.ReLU(),
                nn.Dropout()))
            width = half
        layers.append(nn.Linear(width, out_features))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Map features ``x`` to ``out_features`` logits."""
        return self.mlp(x)

def frozen(model):
    """Turn off gradient tracking for every parameter of ``model``.

    The module is modified in place and also returned, so the call can wrap
    a constructor expression.
    """
    for p in model.parameters():
        p.requires_grad_(False)
    return model

class CompoResnet(nn.Module):
    """ResNet backbone with separate object and attribute classification heads.

    The backbone's ``fc`` layer is replaced in place by ``Identity``. The
    backbone therefore outputs features of width ``resnet.fc.in_features``
    (2048 for resnet101), and two heads built by ``MLP`` consume them.

    Args:
        resnet: backbone exposing an ``fc`` layer with ``in_features``. It is
            mutated: its ``fc`` is overwritten.
        obj_class: output size of the object head.
        attr_class: output size of the attribute head.
        MLP: head factory, called as ``MLP(in_features, out_features)``.
    """

    def __init__(self, resnet, obj_class, attr_class, MLP=MLP):
        super().__init__()
        feat_dim = resnet.fc.in_features
        resnet.fc = Identity()
        self.resnet = resnet
        self.obj_fc = MLP(feat_dim, obj_class)
        self.attr_fc = MLP(feat_dim, attr_class)

    def forward(self, x):
        """Return ``(obj_logits, attr_logits)`` for the image batch ``x``."""
        features = self.resnet(x)
        return self.obj_fc(features), self.attr_fc(features)

In [None]:
def train_with_config(config, checkpoint_dir=None, num_epochs=1):
    """Ray Tune trainable: build a CompoResnet from ``config`` and train it.

    Args:
        config: dict with keys ``'lr'`` (Adam learning rate), ``'resnet'``
            (torch.hub model name, e.g. ``'resnet18'``) and ``'num_mlp_layers'``
            (HalvingMLP depth). An optional ``'batch_size'`` entry overrides the
            default of 64.
        checkpoint_dir: directory holding a ``"checkpoint"`` file written by
            ``train()``, or None for a fresh start.
        num_epochs: number of epochs forwarded to ``train()``.
    """
    lr = config['lr']
    resnet_name = config['resnet']
    mlp = partial(HalvingMLP, num_layers=config['num_mlp_layers'])
    batch_size = config.get('batch_size', 64)

    resnet = frozen(torch.hub.load('pytorch/vision:v0.9.0', resnet_name, pretrained=True))
    compoResnet = CompoResnet(resnet, obj_class, attr_class, mlp).to(dev)
    obj_loss_history = [[], []]
    attr_loss_history = [[], []]
    optimizer = optim.Adam(compoResnet.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    if checkpoint_dir:
        # train() saves a dict. Unpacking it positionally (as a 4-tuple) would
        # yield its key strings instead of the saved states.
        checkpoint = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
        compoResnet.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        obj_loss_history = checkpoint['obj_loss']
        attr_loss_history = checkpoint['attr_loss']
        # NOTE(review): the epoch counter is not restored here, so a resumed
        # trial runs num_epochs more epochs numbered from 0. Confirm intended.

    train(compoResnet, optimizer, criterion, num_epochs, obj_loss_history,
          attr_loss_history, batch_size, use_tune=True)

def _checkpoint_state(net, optimizer, obj_loss_history, attr_loss_history, next_epoch):
    """Build the dict that is written to disk as a training checkpoint."""
    return {
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'obj_loss': obj_loss_history,
        'attr_loss': attr_loss_history,
        # Number of completed epochs, i.e. the epoch index to resume from.
        'epoch': next_epoch,
    }


def _run_train_epoch(net, optimizer, criterion, train_dataloader, epoch, last_epoch,
                     obj_loss_history, attr_loss_history, use_tune):
    """Run one optimisation pass of ``net`` over ``train_dataloader``.

    Minimises the sum of the object and attribute losses. Every 100 batches
    the running mean losses are printed and appended to
    ``obj_loss_history[0]`` / ``attr_loss_history[0]``.
    """
    epoch_steps = 0
    obj_running_loss = 0.0
    attr_running_loss = 0.0
    net.train()
    # NOTE(review): net.train() also puts any normalisation layers inside the
    # frozen backbone into training mode. frozen() only disables gradients, so
    # their running statistics may still update. Confirm this is intended.
    for i, batch in tqdm.tqdm(
            enumerate(train_dataloader),
            total=len(train_dataloader),
            disable=use_tune,
            position=0,
            leave=True,
            postfix='Train: epoch %d/%d' % (epoch + 1, last_epoch)):
        img, attr_id, obj_id = batch[:3]
        if len(img) == 1:
            # BatchNorm1d in the MLP heads rejects a training batch of size 1.
            continue
        optimizer.zero_grad()
        obj_pred, attr_pred = net(img.to(dev))
        obj_loss = criterion(obj_pred, obj_id.to(dev))
        attr_loss = criterion(attr_pred, attr_id.to(dev))
        (obj_loss + attr_loss).backward()
        optimizer.step()

        obj_running_loss += obj_loss.item()
        attr_running_loss += attr_loss.item()
        epoch_steps += 1
        if i % 100 == 99:
            obj_mean = obj_running_loss / epoch_steps
            attr_mean = attr_running_loss / epoch_steps
            print("[%d, %5d] obj_loss: %.3f, attr_loss: %.3f" % (epoch + 1, i + 1, obj_mean, attr_mean))
            obj_loss_history[0].append(obj_mean)
            attr_loss_history[0].append(attr_mean)


def _run_validation(net, criterion, val_dataloader, use_tune):
    """Return the mean ``(obj_loss, attr_loss)`` of ``net`` over ``val_dataloader``."""
    obj_val_loss = 0.0
    attr_val_loss = 0.0
    val_steps = 0
    net.eval()
    with torch.no_grad():
        for batch in tqdm.tqdm(
                val_dataloader,
                total=len(val_dataloader),
                disable=use_tune,
                position=0,
                leave=True):
            img, attr_id, obj_id = batch[:3]
            obj_pred, attr_pred = net(img.to(dev))
            obj_val_loss += criterion(obj_pred, obj_id.to(dev)).item()
            attr_val_loss += criterion(attr_pred, attr_id.to(dev)).item()
            val_steps += 1
    # Avoid ZeroDivisionError when the validation split is empty.
    val_steps = max(val_steps, 1)
    return obj_val_loss / val_steps, attr_val_loss / val_steps


def train(net, optimizer, criterion, num_epochs, obj_loss_history, attr_loss_history, batch_size, curr_epoch=0, use_tune=False, model_dir=None):
    """Train ``net`` on an 80/20 train/validation split of the MIT training set.

    A fresh random split is drawn on every call. Each epoch runs one
    optimisation pass and then computes the mean validation losses. These
    are printed and appended to ``obj_loss_history[1]`` /
    ``attr_loss_history[1]``; index 0 receives the periodic training losses.
    Both history lists are mutated in place.

    Args:
        net: model returning ``(obj_logits, attr_logits)``.
        optimizer: optimiser over ``net``'s parameters.
        criterion: loss applied to both heads.
        num_epochs: number of epochs to run in this call.
        obj_loss_history, attr_loss_history: ``[train_losses, val_losses]``.
        batch_size: batch size of the train and validation loaders.
        curr_epoch: index of the first epoch; used to continue numbering.
        use_tune: if True, save a checkpoint via ``tune.checkpoint_dir`` and
            report loss/accuracy to Ray Tune each epoch, and hide progress bars.
        model_dir: when not using Tune, directory (created if missing) where
            ``model_<epoch>.pt`` is saved after every epoch. None disables saving.
    """
    dset = dataset.get_dataloader('MIT', 'train', with_image=True).dataset
    n_train = int(len(dset) * 0.8)
    train_subset, val_subset = random_split(dset, [n_train, len(dset) - n_train])
    train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    # Order does not affect the averaged validation loss or the accuracy.
    val_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    last_epoch = curr_epoch + num_epochs
    for epoch in range(curr_epoch, last_epoch):
        _run_train_epoch(net, optimizer, criterion, train_dataloader, epoch, last_epoch,
                         obj_loss_history, attr_loss_history, use_tune)

        obj_val_loss, attr_val_loss = _run_validation(net, criterion, val_dataloader, use_tune)
        print("[%d] obj_val_loss: %.3f, attr_val_loss: %.3f" % (epoch + 1, obj_val_loss, attr_val_loss))
        obj_loss_history[1].append(obj_val_loss)
        attr_loss_history[1].append(attr_val_loss)

        if use_tune:
            with tune.checkpoint_dir(epoch) as checkpoint_dir:
                path = os.path.join(checkpoint_dir, "checkpoint")
                torch.save(_checkpoint_state(net, optimizer, obj_loss_history,
                                             attr_loss_history, epoch + 1), path)
            acc = calc_acc(net, val_dataloader, use_tune)
            tune.report(loss=(obj_val_loss + attr_val_loss), accuracy=acc)
            print("accuracy: ", acc)
        elif model_dir:
            os.makedirs(model_dir, exist_ok=True)
            # `epoch` already includes the curr_epoch offset.
            model_path = os.path.join(model_dir, f"model_{epoch}.pt")
            torch.save(_checkpoint_state(net, optimizer, obj_loss_history,
                                         attr_loss_history, epoch + 1), model_path)
    print("Finished training.")

In [None]:
def calc_acc(model, test_dataloader, use_tune=False):
    """Measure object, attribute and composition accuracy of ``model``.

    A prediction is the argmax over the last axis of each head's output. The
    composition score counts samples whose object and attribute are both
    correct. Puts ``model`` in eval mode and leaves it there.

    Args:
        model: returns ``(obj_logits, attr_logits)`` for an image batch.
        test_dataloader: yields batches whose first three entries are
            ``(img, attr_id, obj_id)``.
        use_tune: if True, hide the progress bar.

    Returns:
        ``np.ndarray`` ``[obj_acc, attr_acc, comp_acc]``: match counts divided
        by ``len(test_dataloader.dataset)``.
    """
    n_obj = n_attr = n_both = 0
    with torch.no_grad():
        model.eval()
        progress = tqdm.tqdm(
            enumerate(test_dataloader),
            total=len(test_dataloader),
            disable=use_tune,
            position=0,
            leave=True)
        for _, batch in progress:
            img, attr_id, obj_id = batch[:3]
            obj_logits, attr_logits = model(img.to(dev))
            obj_hit = torch.argmax(obj_logits.to('cpu'), axis=-1) == obj_id
            attr_hit = torch.argmax(attr_logits.to('cpu'), axis=-1) == attr_id
            n_obj += torch.sum(obj_hit)
            n_attr += torch.sum(attr_hit)
            n_both += torch.sum(obj_hit * attr_hit)
    return np.array([n_obj, n_attr, n_both]) / len(test_dataloader.dataset)

In [None]:
# Manual (non-Tune) run: frozen ResNet-101 backbone with 2-stage HalvingMLP heads.
resnet = frozen(torch.hub.load('pytorch/vision:v0.9.0', 'resnet101', pretrained=True))
mlp = partial(HalvingMLP, num_layers=2)
compoResnet = CompoResnet(resnet, obj_class, attr_class, mlp).to(dev)

obj_loss_history = [[], []]
attr_loss_history = [[], []]
optimizer = torch.optim.Adam(compoResnet.parameters())
criterion = nn.CrossEntropyLoss()
curr_epoch = 0

# Set model_name to a file saved by train() in model_dir (e.g. "model_3.pt")
# to resume from that checkpoint.
model_dir = './models/'
model_name = None
model_path = os.path.join(model_dir, model_name) if model_name else None

if model_path:
    # map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=dev)
    compoResnet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Checkpoints saved without an 'epoch' entry resume from epoch 0.
    curr_epoch = checkpoint.get('epoch', 0)
    obj_loss_history = checkpoint['obj_loss']
    attr_loss_history = checkpoint['attr_loss']

In [None]:
# One epoch with batch size 10. curr_epoch comes from the loaded checkpoint
# (0 for a fresh model), so epoch numbering and checkpoint filenames continue.
train(compoResnet, optimizer, criterion, 1, obj_loss_history, attr_loss_history, 10, curr_epoch=curr_epoch, use_tune=False, model_dir=model_dir)

In [None]:
# Ray Tune search space; each trial samples one value per key.
config = dict(
    lr=tune.loguniform(1e-4, 1e-1),
    resnet=tune.choice(['resnet18', 'resnet50', 'resnet101']),
    num_mlp_layers=tune.choice([1, 2, 4, 6]),
)

In [None]:
# Launch the hyper-parameter search. The ASHA scheduler ranks trials by the
# "loss" that train() reports (sum of the validation losses). It compares them
# after at least grace_period epochs and caps each trial at max_t epochs.
num_samples = 12
num_epochs = 6
scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=num_epochs,
    grace_period=1,
    reduction_factor=2,
)
reporter = CLIReporter(
    metric_columns=["loss", "accuracy", "training_iteration"],
)
# A fractional GPU request of 0.32 lets up to three trials share one GPU.
result = tune.run(
    partial(train_with_config, num_epochs=num_epochs),
    config=config,
    num_samples=num_samples,
    scheduler=scheduler,
    progress_reporter=reporter,
    resources_per_trial={"cpu": 1, "gpu": 0.32},
)

In [None]:
# Trial with the lowest loss in its last reported result.
best_trial = result.get_best_trial("loss", "min", "last")
print("Best trial config: {}".format(best_trial.config))
print("Best trial final validation loss: {}".format(best_trial.last_result["loss"]))
print("Best trial final validation accuracy: {}".format(best_trial.last_result["accuracy"]))

# Rebuild the winning architecture, then load the weights from its checkpoint
# (the dict written by train(); only the model state is needed here).
resnet = frozen(torch.hub.load(
    'pytorch/vision:v0.9.0', best_trial.config["resnet"], pretrained=True))
best_mlp = partial(HalvingMLP, num_layers=best_trial.config["num_mlp_layers"])
best_trained_model = CompoResnet(resnet, obj_class, attr_class, best_mlp).to(dev)

best_checkpoint_dir = best_trial.checkpoint.value
model_state = torch.load(
    os.path.join(best_checkpoint_dir, "checkpoint"))['model_state_dict']
best_trained_model.load_state_dict(model_state)

# Evaluate on the held-out MIT test split.
test_acc = calc_acc(best_trained_model, test_dataloader)
print("\nBest trial test set accuracy: {}".format(test_acc))

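Accuracy from `calc_acc`: the fraction of samples with correct [object, attribute, composition (object and attribute both correct)] predictions.

[0.30456985, 0.15528112, 0.02720025] : MLP2 heads, 30 epochs, Adam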