# Tonks Ensemble Model Training Pipeline

As the seventh step of this tutorial, we will train an ensemble model using the two image models and one text model that we already trained.

This notebook was run on an AWS p3.2xlarge instance.

In [1]:
```python
%load_ext autoreload
%autoreload 2
```

In [2]:
```python
import sys
sys.path.append('../../')
```

In [3]:
```python
import joblib
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from transformers import AdamW, BertTokenizer, get_cosine_schedule_with_warmup

from tonks.learner import MultiTaskLearner, MultiInputMultiTaskLearner
from tonks.dataloader import MultiDatasetLoader
from tonks.ensemble import TonksEnsembleDataset, BertResnetEnsembleForMultiTaskClassification
```

## Load in train and validation datasets

First, we load the CSV files we created in Step 1.
Remember to change the paths if you stored your data somewhere other than the default location.

In [4]:
```python
TRAIN_GENDER_DF = pd.read_csv('/home/ubuntu/fashion_dataset/gender_train.csv')
```

In [5]:
```python
VALID_GENDER_DF = pd.read_csv('/home/ubuntu/fashion_dataset/gender_valid.csv')
```

In [6]:
```python
TRAIN_SEASON_DF = pd.read_csv('/home/ubuntu/fashion_dataset/season_train.csv')
```

In [7]:
```python
VALID_SEASON_DF = pd.read_csv('/home/ubuntu/fashion_dataset/season_valid.csv')
```
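Before building datasets, it can save time to confirm that each CSV has the columns the `TonksEnsembleDataset` calls below read (`productDisplayName`, `image_urls`, and the label column). The helper `check_columns` is hypothetical, not part of Tonks; it is a minimal sketch that assumes null text or image entries are worth catching up front rather than inside a DataLoader worker.

```python
import pandas as pd

# Columns the TonksEnsembleDataset calls in this notebook read from each CSV.
TEXT_COL = 'productDisplayName'
IMAGE_COL = 'image_urls'


def check_columns(df: pd.DataFrame, label_col: str) -> None:
    """Raise ValueError if df lacks a required column or has nulls in one."""
    required = [TEXT_COL, IMAGE_COL, label_col]
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise ValueError(f'missing columns: {missing}')
    # Count nulls across the required columns only.
    n_null = int(df[required].isna().sum().sum())
    if n_null:
        raise ValueError(f'{n_null} null values in {required}')
```

For example, `check_columns(TRAIN_GENDER_DF, 'gender_cat')` and `check_columns(TRAIN_SEASON_DF, 'season_cat')` should return silently if the files are well formed.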

You will most likely have to adjust the batch size to whatever fits in your machine's GPU memory.

In [8]:
```python
batch_size = 128
```
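One way to find a batch size that fits is to start large and halve it whenever a training step runs out of memory. The sketch below assumes a hypothetical callable `try_batch(batch_size)` that runs one forward/backward pass and, like PyTorch on a CUDA out-of-memory failure, raises a `RuntimeError` whose message contains "out of memory". In practice you would also release cached memory between attempts, for example with `torch.cuda.empty_cache()`.

```python
from typing import Callable


def find_batch_size(try_batch: Callable[[int], None], start: int = 512, min_size: int = 1) -> int:
    """Return the first batch size, halving from start, for which try_batch succeeds."""
    size = start
    while size >= min_size:
        try:
            try_batch(size)
            return size
        except RuntimeError as err:
            if 'out of memory' not in str(err):
                raise  # a genuine bug, not a memory limit
            size //= 2  # too big: halve and retry
    raise RuntimeError(f'no batch size >= {min_size} fits in memory')
```

Halving converges in a handful of attempts; you may then want to back off slightly from the result, since memory use can vary from batch to batch.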

In [9]:
```python
bert_tok = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True
)

max_seq_length = 128
```
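`max_seq_length` fixes the length of every tokenized product name. We have not checked exactly how `TonksEnsembleDataset` uses it; the usual BERT recipe, sketched here as an assumption, is to wrap the token IDs in `[CLS]` ... `[SEP]`, truncate to `max_seq_length`, pad the rest, and build an attention mask. The default IDs 101, 102, and 0 are the `[CLS]`, `[SEP]`, and `[PAD]` IDs in the `bert-base-uncased` vocabulary; `pad_or_truncate` itself is a hypothetical helper.

```python
from typing import List, Tuple


def pad_or_truncate(token_ids: List[int], max_seq_length: int,
                    cls_id: int = 101, sep_id: int = 102,
                    pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Return (input_ids, attention_mask), both exactly max_seq_length long."""
    body = token_ids[: max_seq_length - 2]  # leave room for [CLS] and [SEP]
    input_ids = [cls_id] + body + [sep_id]
    attention_mask = [1] * len(input_ids)  # 1 marks a real token
    n_pad = max_seq_length - len(input_ids)
    input_ids += [pad_id] * n_pad
    attention_mask += [0] * n_pad  # 0 marks padding that attention ignores
    return input_ids, attention_mask
```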

In [10]:
```python
gender_train_dataset = TonksEnsembleDataset(
    text_inputs=TRAIN_GENDER_DF['productDisplayName'],
    img_inputs=TRAIN_GENDER_DF['image_urls'],
    y=TRAIN_GENDER_DF['gender_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length,
    transform='train',
    crop_transform='train'
)
gender_valid_dataset = TonksEnsembleDataset(
    text_inputs=VALID_GENDER_DF['productDisplayName'],
    img_inputs=VALID_GENDER_DF['image_urls'],
    y=VALID_GENDER_DF['gender_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length,
    transform='val',
    crop_transform='val'
)

season_train_dataset = TonksEnsembleDataset(
    text_inputs=TRAIN_SEASON_DF['productDisplayName'],
    img_inputs=TRAIN_SEASON_DF['image_urls'],
    y=TRAIN_SEASON_DF['season_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length,
    transform='train',
    crop_transform='train'
)
season_valid_dataset = TonksEnsembleDataset(
    text_inputs=VALID_SEASON_DF['productDisplayName'],
    img_inputs=VALID_SEASON_DF['image_urls'],
    y=VALID_SEASON_DF['season_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length,
    transform='val',
    crop_transform='val'
)
```

We then wrap the datasets in DataLoaders and collect them in dictionaries, with one key per task.

In [11]:
```python
train_dataloaders_dict = {
    'gender': DataLoader(gender_train_dataset, batch_size=batch_size, shuffle=True, num_workers=6),
    'season': DataLoader(season_train_dataset, batch_size=batch_size, shuffle=True, num_workers=6),
}
valid_dataloaders_dict = {
    'gender': DataLoader(gender_valid_dataset, batch_size=batch_size, shuffle=False, num_workers=6),
    'season': DataLoader(season_valid_dataset, batch_size=batch_size, shuffle=False, num_workers=6),
}
```

In [12]:
```python
TrainLoader = MultiDatasetLoader(loader_dict=train_dataloaders_dict)
len(TrainLoader)
```

```
366
```

In [13]:
```python
ValidLoader = MultiDatasetLoader(
    loader_dict=valid_dataloaders_dict,
    shuffle=False
)
len(ValidLoader)
```

```
123
```
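We have not confirmed how `MultiDatasetLoader` schedules batches across tasks. One plausible design, sketched below as the hypothetical class `InterleavedLoader` (not the Tonks implementation), yields `(task, batch)` pairs by interleaving the per-task loaders, shuffling the task order only when `shuffle=True`. Under this design, the length of the combined loader is the total number of batches across all tasks.

```python
import random
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple


class InterleavedLoader:
    """Yield (task, batch) pairs drawn from several per-task loaders."""

    def __init__(self, loader_dict: Dict[str, Iterable], shuffle: bool = True,
                 seed: Optional[int] = None):
        self.loader_dict = loader_dict
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        # Total batches over all tasks in one pass.
        return sum(len(loader) for loader in self.loader_dict.values())

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        # One entry per batch, e.g. ['gender', 'gender', ..., 'season', ...].
        order = [task for task, loader in self.loader_dict.items()
                 for _ in range(len(loader))]
        if self.shuffle:
            self.rng.shuffle(order)  # mix tasks; each task keeps its own batch order
        iterators = {task: iter(loader) for task, loader in self.loader_dict.items()}
        for task in order:
            yield task, next(iterators[task])
```

Shuffling the task order during training avoids long runs of a single task, while `shuffle=False` gives a deterministic pass for validation, which matches how the two loaders above are configured.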

## Create Model and Learner

Since the image side of the model can contain multiple ResNets, each serving a different subset of tasks, we need to create an `image_task_dict` that groups the tasks by the ResNet they use.

This version uses a separate ResNet for gender and for season.

In [14]:
```python
image_task_dict = {
    'gender': {
        'gender': TRAIN_GENDER_DF['gender_cat'].nunique()
    },
    'season': {
        'season': TRAIN_SEASON_DF['season_cat'].nunique()
    }
}
```

We also need to create the `new_task_dict` for the learner, which maps each task to its number of classes.

In [15]:
```python
new_task_dict = {
    'gender': TRAIN_GENDER_DF['gender_cat'].nunique(),
    'season': TRAIN_SEASON_DF['season_cat'].nunique()
}
```
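In this notebook, `new_task_dict` holds exactly the tasks listed in `image_task_dict`, just without the ResNet grouping. A small hypothetical helper (not part of Tonks) can derive one from the other, which keeps the class counts from drifting apart if you edit one dictionary and forget the other. The class counts in the example are illustrative.

```python
from typing import Dict


def flatten_task_dict(image_task_dict: Dict[str, Dict[str, int]]) -> Dict[str, int]:
    """Merge per-ResNet task groups into a single {task: n_classes} mapping."""
    flat: Dict[str, int] = {}
    for tasks in image_task_dict.values():
        for task, n_classes in tasks.items():
            if task in flat:
                raise ValueError(f'task {task!r} appears in more than one group')
            flat[task] = n_classes
    return flat


# Illustrative counts; in the notebook they come from .nunique().
separate = {'gender': {'gender': 5}, 'season': {'season': 4}}
print(flatten_task_dict(separate))  # {'gender': 5, 'season': 4}
```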

In [16]:
```python
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
```

```
cuda:0
```


We first initialize the model, passing `image_task_dict` so that it is built with the right shape.

In [17]:
```python
model = BertResnetEnsembleForMultiTaskClassification(
    image_task_dict=image_task_dict
)
```

We then load the existing models by specifying the folder where they are stored and their model IDs.

In [18]:
```python
resnet_model_id_dict = {
    'gender': 'GENDER_IMAGE_MODEL1',
    'season': 'SEASON_IMAGE_MODEL1'
}
```

In [19]:
```python
model.load_core_models(
    folder='/home/ubuntu/fashion_dataset/models/',
    bert_model_id='TEXT_MODEL1',
    resnet_model_id_dict=resnet_model_id_dict
)
```

We provide helper methods that freeze the core BERT model and ResNets, which is useful if you only want to train the new layers. As with every other aspect of training, you will likely need to experiment with the values in this section (what to freeze, learning rates, scheduler settings) to find ones that work for your particular problem.
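We have not inspected how `freeze_bert` and `freeze_resnets` are implemented, but the standard PyTorch way to freeze a module is to set `requires_grad = False` on its parameters. Frozen parameters get no gradient, and `torch.optim.Adam` skips parameters whose `.grad` is `None`. So, if the helpers work this way, the `lr_main` rate given to the frozen BERT and ResNet parameter groups below would have no effect until those parts are unfrozen. The toy module and the helpers `freeze` and `count_trainable` are hypothetical stand-ins, not the Tonks model.

```python
import torch.nn as nn


def freeze(module: nn.Module) -> None:
    """Stop gradient updates for every parameter in module."""
    for param in module.parameters():
        param.requires_grad = False


def count_trainable(module: nn.Module) -> int:
    """Number of scalar parameters that will still receive gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Toy stand-in: a pretrained "core" plus a new "head".
toy = nn.ModuleDict({
    'core': nn.Linear(10, 8),  # 10*8 weights + 8 biases = 88 parameters
    'head': nn.Linear(8, 3),   # 8*3 weights + 3 biases = 27 parameters
})
freeze(toy['core'])
print(count_trainable(toy))  # 27: only the head is still trainable
```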

In [20]:
```python
model.freeze_bert()
model.freeze_resnets()

loss_function = nn.CrossEntropyLoss()

lr_last = 1e-3
lr_main = 1e-5

# The core BERT/ResNet parts, dropout, and image dense layers use lr_main;
# the ensemble layers and classifiers use the larger lr_last.
lr_list = [
    {'params': model.bert.parameters(), 'lr': lr_main},
    {'params': model.dropout.parameters(), 'lr': lr_main},
    {'params': model.image_resnets.parameters(), 'lr': lr_main},
    {'params': model.image_dense_layers.parameters(), 'lr': lr_main},
    {'params': model.ensemble_layers.parameters(), 'lr': lr_last},
    {'params': model.classifiers.parameters(), 'lr': lr_last},
]

optimizer = optim.Adam(lr_list)

# Multiply every group's learning rate by 0.1 every 4 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
```
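`StepLR` follows a closed form: after `e` scheduler steps, each group's learning rate is its initial rate times `gamma ** (e // step_size)`. Judging by its name, `step_scheduler_on_batch=False` in the `fit` call below steps the scheduler once per epoch; that is an assumption we have not checked against the Tonks source. Under it, the 10-epoch run uses the initial rates for epochs 0-3, one tenth of them for epochs 4-7, and one hundredth for epochs 8-9. The helper `step_lr` is hypothetical and only reproduces the formula.

```python
def step_lr(base_lr: float, epoch: int, step_size: int = 4, gamma: float = 0.1) -> float:
    """Learning rate StepLR gives after `epoch` steps: base_lr * gamma ** (epoch // step_size)."""
    return base_lr * gamma ** (epoch // step_size)


# Schedule for the head (lr_last) and core (lr_main) groups over 10 epochs.
for epoch in range(10):
    print(epoch, step_lr(1e-3, epoch), step_lr(1e-5, epoch))
```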

In [21]:
```python
learn = MultiInputMultiTaskLearner(model, TrainLoader, ValidLoader, new_task_dict)
```

## Train Model

As the model trains, the learner reports how it is performing overall and on each individual task.

In [22]:
```python
learn.fit(
    num_epochs=10,
    loss_function=loss_function,
    scheduler=exp_lr_scheduler,
    step_scheduler_on_batch=False,
    optimizer=optimizer,
    device=device,
    best_model=True
)
```

| epoch | train_loss | val_loss | gender_train_loss | gender_val_loss | gender_acc | season_train_loss | season_val_loss | season_acc | time |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.293111 | 0.322882 | 0.051661 | 0.032669 | 0.989643 | 0.615227 | 0.710022 | 0.715423 | 06:29 |
| 1 | 0.266368 | 0.294048 | 0.032354 | 0.032366 | 0.990431 | 0.578563 | 0.643127 | 0.747259 | 06:30 |
| 2 | 0.258399 | 0.317407 | 0.028151 | 0.043071 | 0.988292 | 0.56557 | 0.683366 | 0.732693 | 06:29 |
| 3 | 0.258816 | 0.305413 | 0.026256 | 0.03147 | 0.990093 | 0.569071 | 0.670849 | 0.736147 | 06:29 |
| 4 | 0.249474 | 0.296359 | 0.024375 | 0.03041 | 0.990544 | 0.549775 | 0.65113 | 0.746508 | 06:29 |
| 5 | 0.245043 | 0.289674 | 0.020839 | 0.029983 | 0.991332 | 0.544149 | 0.636099 | 0.752365 | 06:28 |
| 6 | 0.246843 | 0.289804 | 0.020975 | 0.031275 | 0.990093 | 0.548169 | 0.634677 | 0.749812 | 06:30 |
| 7 | 0.2424 | 0.286017 | 0.0196 | 0.029878 | 0.991107 | 0.539635 | 0.627703 | 0.752816 | 06:29 |
| 8 | 0.24197 | 0.288154 | 0.020137 | 0.029609 | 0.991107 | 0.537913 | 0.633048 | 0.750113 | 06:29 |
| 9 | 0.240077 | 0.287712 | 0.019503 | 0.029321 | 0.991219 | 0.534341 | 0.632402 | 0.752816 | 06:29 |

```
Epoch 7 best model saved with loss of 0.2860171677235031
```
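The "best model" message lines up with the epoch that has the lowest overall validation loss, counting epochs from 0. From this log alone it appears that `best_model=True` keeps the checkpoint with the minimum `val_loss`; that is an inference from the output, not something we checked in the Tonks docs. The check below uses the `val_loss` values copied from the log.

```python
# Overall validation loss per epoch, copied from the training log.
val_losses = [0.322882, 0.294048, 0.317407, 0.305413, 0.296359,
              0.289674, 0.289804, 0.286017, 0.288154, 0.287712]

# Index of the smallest validation loss.
best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
print(best_epoch, val_losses[best_epoch])  # 7 0.286017
```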


Check your specific use case to determine whether it is better to train all of your image tasks in one ResNet or in several. You won't necessarily need one ResNet per task, particularly if some of your tasks are related.
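If you try a different grouping, `image_task_dict` is the place it is expressed. The hypothetical helper below (not part of Tonks) builds that dictionary from a `{task: n_classes}` mapping like `new_task_dict` plus a chosen grouping, and checks that each task is assigned to exactly one ResNet. The class counts are illustrative. A shared grouping would presumably also require image models trained with that grouping in the earlier steps, since the keys of `resnet_model_id_dict` appear to match the group names.

```python
from typing import Dict, List


def build_image_task_dict(task_classes: Dict[str, int],
                          groups: Dict[str, List[str]]) -> Dict[str, Dict[str, int]]:
    """Group tasks by the ResNet that serves them.

    task_classes: {task: n_classes}.
    groups: {resnet_name: [task, ...]}; each task must appear exactly once.
    """
    assigned = [task for tasks in groups.values() for task in tasks]
    if sorted(assigned) != sorted(task_classes):
        raise ValueError('every task must be assigned to exactly one ResNet')
    return {name: {task: task_classes[task] for task in tasks}
            for name, tasks in groups.items()}


task_classes = {'gender': 5, 'season': 4}  # illustrative counts
separate = build_image_task_dict(task_classes, {'gender': ['gender'], 'season': ['season']})
shared = build_image_task_dict(task_classes, {'shared': ['gender', 'season']})
```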

## Checking validation data

We provide a method on the learner called `get_val_preds`, which makes predictions on the validation data. You can then use this to analyze your model's performance in more detail.

In [23]:
```python
pred_dict = learn.get_val_preds(device)
```

In [24]:
```python
pred_dict
```

```
{'gender': {'y_true': array([0., 2., 4., ..., 4., 2., 2.]),
  'y_pred': array([[9.92650092e-01, 3.23335262e-04, 6.87686214e-03, 1.36870003e-04,
          1.29461450e-05],
         [1.35707080e-06, 1.94859410e-07, 9.99997497e-01, 7.13497911e-07,
          2.35319845e-07],
         [7.03864571e-05, 2.32889410e-03, 4.51153865e-05, 6.30916620e-05,
          9.97492552e-01],
         ...,
         [9.33719548e-07, 5.31011210e-05, 5.51037431e-07, 9.38349842e-07,
          9.99944448e-01],
         [1.97838494e-06, 6.37559253e-07, 9.99994874e-01, 1.81650728e-06,
          6.87618694e-07],
         [2.48589913e-06, 6.35381923e-07, 9.99993920e-01, 2.18791865e-06,
          7.60277146e-07]])},
 'season': {'y_true': array([0., 0., 2., ..., 2., 3., 3.]),
  'y_pred': array([[8.31580341e-01, 1.52546994e-03, 1.59035668e-01, 7.85854086e-03],
         [8.96789908e-01, 4.13205649e-04, 1.00355275e-01, 2.44156295e-03],
         [7.72640035e-02, 3.73037672e-03, 8.75569582e-01, 4.34360504e-02],
         ...
```
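Each `y_pred` row looks like a vector of per-class probabilities (the rows shown sum to about 1), and `y_true` holds class indices stored as floats. Assuming that structure, a per-task accuracy and confusion matrix follow from taking the argmax of each row. The helper `task_metrics` is a hypothetical sketch, not a Tonks function; its accuracy can be compared with the `gender_acc` and `season_acc` columns in the training log.

```python
import numpy as np


def task_metrics(y_true: np.ndarray, y_pred: np.ndarray):
    """Return (accuracy, confusion matrix) from class labels and per-class probabilities."""
    pred_labels = y_pred.argmax(axis=1)      # most probable class per row
    true_labels = y_true.astype(int)         # y_true comes back as floats
    n_classes = y_pred.shape[1]
    confusion = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(confusion, (true_labels, pred_labels), 1)  # rows: true, cols: predicted
    accuracy = float((pred_labels == true_labels).mean())
    return accuracy, confusion


# Usage with the notebook's output:
# for task, preds in pred_dict.items():
#     acc, cm = task_metrics(preds['y_true'], preds['y_pred'])
#     print(task, acc)
```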

## Save/Export Model

The ensemble model can also be saved or exported.

In [25]:
```python
model.save(folder='/home/ubuntu/fashion_dataset/models/', model_id='ENSEMBLE_MODEL2')
```

In [26]:
```python
model.export(folder='/home/ubuntu/fashion_dataset/models/', model_id='ENSEMBLE_MODEL2')
```