# Tonks Ensemble Model Training Pipeline

As the fourth (and final) step of this tutorial, we will train an ensemble model using the image and text models we've already trained.

This notebook was run on an AWS p3.2xlarge

In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
import sys
sys.path.append('../../')

In [3]:
import joblib
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from transformers import AdamW, BertTokenizer, get_cosine_schedule_with_warmup

from tonks.learner import MultiTaskLearner, MultiInputMultiTaskLearner
from tonks.dataloader import MultiDatasetLoader
from tonks.ensemble import TonksEnsembleDataset, BertResnetEnsembleForMultiTaskClassification

## Load in train and validation datasets

First we load in the csv's we created in Step 1.
Remember to change the path if you stored your data somewhere other than the default.

In [4]:
TRAIN_COLOR_DF = pd.read_csv('data/color_swatches/color_train.csv')

In [5]:
VALID_COLOR_DF = pd.read_csv('data/color_swatches/color_valid.csv')

In [6]:
TRAIN_PATTERN_DF = pd.read_csv('data/pattern_swatches/pattern_train.csv')

In [7]:
VALID_PATTERN_DF = pd.read_csv('data/pattern_swatches/pattern_valid.csv')

You will most likely have to alter this to however big your batches can be on your machine

In [8]:
batch_size = 16

In [9]:
bert_tok = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True
)

max_seq_length = 128 

In [10]:
color_train_dataset = TonksEnsembleDataset(
    text_inputs=TRAIN_COLOR_DF['complex_color'],
    img_inputs=TRAIN_COLOR_DF['image_locs'],
    y=TRAIN_COLOR_DF['simple_color_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length,
    transform='train',
    crop_transform='train'

)
color_valid_dataset = TonksEnsembleDataset(
    text_inputs=VALID_COLOR_DF['complex_color'],
    img_inputs=VALID_COLOR_DF['image_locs'],
    y=VALID_COLOR_DF['simple_color_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length,
    transform='val',
    crop_transform='val'

)

pattern_train_dataset = TonksEnsembleDataset(
    text_inputs=VALID_PATTERN_DF['fake_text'],
    img_inputs=VALID_PATTERN_DF['image_locs'],
    y=VALID_PATTERN_DF['pattern_type_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length,
    transform='train',
    crop_transform='train'

)
pattern_valid_dataset = TonksEnsembleDataset(
    text_inputs=VALID_PATTERN_DF['fake_text'],
    img_inputs=VALID_PATTERN_DF['image_locs'],
    y=VALID_PATTERN_DF['pattern_type_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length,
    transform='val',
    crop_transform='val'
)

We then put the datasets into a dictionary of dataloaders.

Each task is a key.

In [11]:
train_dataloaders_dict = {
    'color': DataLoader(color_train_dataset, batch_size=batch_size, shuffle=True, num_workers=2),
    'pattern': DataLoader(pattern_train_dataset, batch_size=batch_size, shuffle=True, num_workers=2),
}
valid_dataloaders_dict = {
    'color': DataLoader(color_valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2),
    'pattern': DataLoader(pattern_valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2),
}

In [12]:
TrainLoader = MultiDatasetLoader(loader_dict=train_dataloaders_dict)
len(TrainLoader)

23

In [13]:
ValidLoader = MultiDatasetLoader(
    loader_dict=valid_dataloaders_dict,
    shuffle=False
)
len(ValidLoader)

9

Create Model and Learner
===

Since the image model could potentially have multiple Resnets for different subsets of tasks, we need to create an `image_task_dict` that splits up the tasks grouped by the Resnet they use.

In [14]:
image_task_dict = {
    'color_pattern': {
        'color': TRAIN_COLOR_DF['simple_color_cat'].nunique(),
        'pattern': TRAIN_PATTERN_DF['pattern_type_cat'].nunique()
    }  
}

We still need to create the `new_task_dict` for the learner.

In [15]:
new_task_dict = {
    'color': TRAIN_COLOR_DF['simple_color_cat'].nunique(),
    'pattern': TRAIN_PATTERN_DF['pattern_type_cat'].nunique()
}

In [16]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


We first initialize the model by setting up the right shape with the image_task_dict.

In [17]:
model = BertResnetEnsembleForMultiTaskClassification(
    image_task_dict=image_task_dict
)

We then load in the existing models by specifying the folder where the models live and their id's.

In [18]:
resnet_model_id_dict = {
    'color_pattern': 'IMAGE_MODEL1'
}

In [19]:
model.load_core_models(
    folder='models/',
    bert_model_id='TEXT_MODEL1',
    resnet_model_id_dict=resnet_model_id_dict
)

We've set some helper methods that will freeze the core bert and resnets for you if you only want to train the new layers. As with all other aspects of training, this is likely to require some experimentation to determine what works for your problem.

You will likely need to explore different values in this section to find some that work
for your particular model.

In [20]:
model.freeze_bert()
model.freeze_resnets()

loss_function = nn.CrossEntropyLoss()

lr_last = 1e-3
lr_main = 1e-5

lr_list = [
    {'params': model.bert.parameters(), 'lr': lr_main},
    {'params': model.dropout.parameters(), 'lr': lr_main},   
    {'params': model.image_resnets.parameters(), 'lr': lr_main},
    {'params': model.image_dense_layers.parameters(), 'lr': lr_main},
    {'params': model.ensemble_layers.parameters(), 'lr': lr_last},
    {'params': model.classifiers.parameters(), 'lr': lr_last},
]

optimizer = optim.Adam(lr_list)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size= 4, gamma= 0.1)

In [21]:
learn = MultiInputMultiTaskLearner(model, TrainLoader, ValidLoader, new_task_dict)

Train Model
===

As your model trains, you can see some output of how the model is performing overall and how it is doing on each individual task.

In [22]:
learn.fit(
    num_epochs=10,
    loss_function=loss_function,
    scheduler=exp_lr_scheduler,
    step_scheduler_on_batch=False,
    optimizer=optimizer,
    device=device,
    best_model=True
)

train_loss,val_loss,color_train_loss,color_val_loss,color_acc,pattern_train_loss,pattern_val_loss,pattern_acc,time
0.269334,0.372842,0.227687,0.264316,0.889908,0.809089,0.846015,0.52,00:03
0.223326,0.397131,0.184951,0.286414,0.87156,0.720665,0.879854,0.52,00:03
0.202404,0.357665,0.172665,0.27039,0.87156,0.58783,0.738185,0.52,00:03
0.23221,0.367504,0.207638,0.275364,0.880734,0.550674,0.769239,0.52,00:03
0.133768,0.376696,0.097655,0.275917,0.87156,0.601798,0.81609,0.52,00:03
0.147174,0.388062,0.121395,0.296142,0.862385,0.481277,0.788833,0.52,00:03
0.171947,0.370653,0.141226,0.272476,0.880734,0.570095,0.798706,0.52,00:03
0.123373,0.375156,0.10867,0.278227,0.880734,0.313924,0.79777,0.52,00:03
0.117283,0.361361,0.098623,0.260629,0.880734,0.359123,0.80055,0.52,00:03
0.144764,0.368404,0.123345,0.265812,0.880734,0.42236,0.815706,0.52,00:03


Epoch 2 best model saved with loss of 0.3576649023748156


Ideally the ensemble would perform better than either the image or text model alone, but our performance is probably suffering due to this being synthetic data.

Checking validation data
===

We provide a method on the learner called `get_val_preds`, which makes predictions on the validation data. You can then use this to analyze your model's performance in more detail.

In [23]:
pred_dict = learn.get_val_preds(device)

In [24]:
pred_dict

{'color': {'y_true': array([1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0., 0., 1.,
         0., 0., 1., 1., 1., 1., 1., 1., 0., 0., 1., 0., 0., 0., 1., 1., 0.,
         1., 1., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 0., 1., 1., 0., 1.,
         1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0., 1., 0., 1., 1.,
         1., 0., 1., 0., 1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0., 0.,
         0., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0.,
         1., 1., 0., 1., 1., 1., 1.]),
  'y_pred': array([[2.41741794e-03, 9.97582555e-01],
         [2.05460470e-03, 9.97945368e-01],
         [2.27170484e-03, 9.97728288e-01],
         [9.93301749e-01, 6.69822237e-03],
         [1.75855076e-03, 9.98241425e-01],
         [7.11645419e-03, 9.92883563e-01],
         [8.93120281e-03, 9.91068721e-01],
         [1.27694395e-03, 9.98723090e-01],
         [1.07391253e-01, 8.92608702e-01],
         [8.78677983e-03, 9.91213262e-01],
         [9.64594662e-01, 3.54053192e-02

Save/Export Model
===

The ensemble model can also be saved or exported.

In [25]:
model.save(folder='models/', model_id='ENSEMBLE_MODEL1')

In [26]:
model.export(folder='models/', model_id='ENSEMBLE_MODEL1')