As the third step of this tutorial, we will train a text model. This notebook can be run in parallel with Step 2 (training the image model), and many of its cells are similar to those in the Step 2 notebook.

This notebook was run on an AWS p3.2xlarge.

# Tonks Text Model Training Pipeline

In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
import sys
sys.path.append('../../')

In [3]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, BertTokenizer, get_cosine_schedule_with_warmup

Note: for text, we use the `MultiTaskLearner`, since there is only one input: the text.

In [4]:
from tonks import MultiTaskLearner, MultiDatasetLoader
from tonks.text.dataset import TonksTextDataset
from tonks.text.models.multi_task_bert import BertForMultiTaskClassification

Our BERT model needs a tokenizer. We'll use the one from Hugging Face's `transformers` library.

In [5]:
bert_tok = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True
)

## Load in train and validation datasets

First, we load the CSVs we created in Step 1.
Remember to change the paths if you stored your data somewhere other than the default location.

In [6]:
TRAIN_GENDER_DF = pd.read_csv('/home/ec2-user/fashion_dataset/gender_train.csv')

In [7]:
VALID_GENDER_DF = pd.read_csv('/home/ec2-user/fashion_dataset/gender_valid.csv')

In [8]:
TRAIN_SEASON_DF = pd.read_csv('/home/ec2-user/fashion_dataset/season_train.csv')

In [9]:
VALID_SEASON_DF = pd.read_csv('/home/ec2-user/fashion_dataset/season_valid.csv')

You will most likely need to change the batch size to whatever fits in your machine's memory.

In [10]:
batch_size = 64

We use the `TonksTextDataset` class to create training and validation datasets for each task.

Check out the documentation for information about the `tokenizer` and `max_seq_length` arguments.

In [11]:
max_seq_length = 128
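`max_seq_length` fixes the number of token positions each product name is mapped to. As a rough illustration of what such a limit usually means for BERT inputs, here is a minimal sketch that truncates a list of token IDs, wraps it in `[CLS]` and `[SEP]`, and pads it to a fixed length with an attention mask. The helper `pad_and_truncate` is hypothetical and `TonksTextDataset` may implement this differently; the IDs 101, 102 and 0 are the `[CLS]`, `[SEP]` and `[PAD]` IDs in the `bert-base-uncased` vocabulary, while the word-piece IDs in the usage line are made up.

```python
from typing import List, Tuple

# Special-token IDs from the bert-base-uncased vocabulary.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def pad_and_truncate(token_ids: List[int], max_seq_length: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: build fixed-length BERT input IDs and an attention mask."""
    # Reserve two positions for [CLS] and [SEP], dropping any overflow tokens.
    body = token_ids[: max_seq_length - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]
    # 1 marks real tokens; 0 marks padding the model should ignore.
    attention_mask = [1] * len(input_ids)
    padding = max_seq_length - len(input_ids)
    input_ids += [PAD_ID] * padding
    attention_mask += [0] * padding
    return input_ids, attention_mask


# Usage with made-up word-piece IDs and a small max_seq_length.
ids, mask = pad_and_truncate([2304, 3797, 1010], max_seq_length=8)
print(ids)   # [101, 2304, 3797, 1010, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

A larger `max_seq_length` keeps more of long product names but increases memory use per batch, which interacts with the `batch_size` chosen above.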

In [12]:
gender_train_dataset = TonksTextDataset(
    x=TRAIN_GENDER_DF['productDisplayName'],
    y=TRAIN_GENDER_DF['gender_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length
)
gender_valid_dataset = TonksTextDataset(
    x=VALID_GENDER_DF['productDisplayName'],
    y=VALID_GENDER_DF['gender_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length
)

season_train_dataset = TonksTextDataset(
    x=TRAIN_SEASON_DF['productDisplayName'],
    y=TRAIN_SEASON_DF['season_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length
)
season_valid_dataset = TonksTextDataset(
    x=VALID_SEASON_DF['productDisplayName'],
    y=VALID_SEASON_DF['season_cat'],
    tokenizer=bert_tok,
    max_seq_length=max_seq_length
)

We then put the datasets into a dictionary of dataloaders.

Each task is a key.

In [13]:
train_dataloaders_dict = {
    'gender': DataLoader(gender_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4),
    'season': DataLoader(season_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4),
}
valid_dataloaders_dict = {
    'gender': DataLoader(gender_valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4),
    'season': DataLoader(season_valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4),
}

The dictionary of dataloaders is then put into an instance of the Tonks `MultiDatasetLoader` class.

In [14]:
TrainLoader = MultiDatasetLoader(loader_dict=train_dataloaders_dict)
len(TrainLoader)

730

In [15]:
ValidLoader = MultiDatasetLoader(loader_dict=valid_dataloaders_dict, shuffle=False)
len(ValidLoader)

244
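`MultiDatasetLoader` combines the per-task dataloaders into a single loader. The sketch below is not Tonks' implementation: `interleave_batches` is a hypothetical generator that assumes every batch of every task is visited exactly once per epoch, in shuffled order, with each batch tagged by its task name. Under that assumption the combined length is the sum of the per-task batch counts, and a PyTorch `DataLoader` with the default `drop_last=False` contributes `ceil(len(dataset) / batch_size)` batches.

```python
import math
import random
from typing import Any, Dict, Iterator, List, Tuple


def num_batches(num_samples: int, batch_size: int) -> int:
    # A DataLoader with drop_last=False keeps the final, partial batch.
    return math.ceil(num_samples / batch_size)


def interleave_batches(
    loader_dict: Dict[str, Any], shuffle: bool = True, seed: int = 0
) -> Iterator[Tuple[str, Any]]:
    """Hypothetical sketch: yield each batch of each task once, tagged with its task."""
    # One entry per batch, naming the task that batch comes from.
    order: List[str] = [task for task, loader in loader_dict.items() for _ in range(len(loader))]
    if shuffle:
        # Mix tasks so consecutive batches can come from different tasks.
        random.Random(seed).shuffle(order)
    iterators = {task: iter(loader) for task, loader in loader_dict.items()}
    for task in order:
        yield task, next(iterators[task])


# Plain lists stand in for DataLoaders here; each element plays the role of one batch.
toy_loaders = {'gender': ['g0', 'g1', 'g2'], 'season': ['s0', 's1']}
pairs = list(interleave_batches(toy_loaders))
print(len(pairs))             # 5 (3 gender batches + 2 season batches)
print(num_batches(1000, 64))  # 16
```

Shuffling only the order of task labels, rather than the batches themselves, keeps each task's own `DataLoader` in charge of shuffling its examples.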

To create our model, we need a dictionary that maps each task to its number of unique classes.

In [16]:
new_task_dict = {
    'gender': TRAIN_GENDER_DF['gender_cat'].nunique(),
    'season': TRAIN_SEASON_DF['season_cat'].nunique(),
}

In [17]:
new_task_dict

{'gender': 5, 'season': 4}
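Using `nunique()` as the number of output classes assumes that the encoded labels in `gender_cat` and `season_cat` are the integers `0` to `K - 1`, with every class present in the training split. A quick pandas check like the hypothetical `check_label_encoding` below can catch gaps or out-of-range labels before training starts.

```python
import pandas as pd


def check_label_encoding(train_labels: pd.Series, valid_labels: pd.Series) -> int:
    """Hypothetical helper: confirm labels are 0..K-1 and return K."""
    num_classes = train_labels.nunique()
    expected = set(range(num_classes))
    # Training labels should cover exactly 0..K-1 with no gaps.
    if set(train_labels.unique()) != expected:
        raise ValueError(f"training labels are not 0..{num_classes - 1}")
    # Validation labels must not contain classes the model cannot output.
    if not set(valid_labels.unique()) <= expected:
        raise ValueError("validation labels contain classes not seen in training")
    return num_classes


# Usage with toy label columns.
print(check_label_encoding(pd.Series([0, 1, 2, 1]), pd.Series([2, 0])))  # 3
```

In this notebook, `check_label_encoding(TRAIN_GENDER_DF['gender_cat'], VALID_GENDER_DF['gender_cat'])` would either return the same value as `new_task_dict['gender']` or raise.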

In [18]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


Create Model and Learner
===

These are completely new tasks, so we use `new_task_dict`. If we had already trained a model on some tasks, we would use `pretrained_task_dict`.

We use the pretrained BERT weights from the `transformers` library.

In [19]:
model = BertForMultiTaskClassification.from_pretrained(
    'bert-base-uncased',
    new_task_dict=new_task_dict
)

You will likely need to experiment with the values in this section to find ones that work
for your particular model.

In [20]:
lr = 1e-5
num_total_steps = len(TrainLoader)
num_warmup_steps = int(len(TrainLoader) * 0.1)

optimizer = AdamW(model.parameters(), lr=lr, correct_bias=True)

scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_total_steps
)
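`get_cosine_schedule_with_warmup` scales the base learning rate `lr` by a multiplier that rises linearly from 0 to 1 over `num_warmup_steps` and then follows half a cosine wave down to 0 at `num_training_steps`. The pure-Python function below is a sketch of that multiplier for the default single half-cosine (`num_cycles=0.5`), useful for inspecting a schedule rather than replacing the library function. Because the schedule is counted in calls to `scheduler.step()`, both step arguments should be in the same unit as those calls: batches if the scheduler is stepped per batch, epochs if it is stepped per epoch (which `step_scheduler_on_batch=False` suggests).

```python
import math


def cosine_warmup_multiplier(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """Learning-rate multiplier for linear warmup followed by a half-cosine decay."""
    if step < num_warmup_steps:
        # Linear warmup from 0 to 1.
        return step / max(1, num_warmup_steps)
    # Fraction of the post-warmup phase completed, from 0 to 1.
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    # Half a cosine wave: 1 at the end of warmup, 0 at num_training_steps.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Same values as the cell above: 730 total steps, 10% of them (73) for warmup.
lr, num_total_steps = 1e-5, 730
num_warmup_steps = int(num_total_steps * 0.1)
for step in (0, 10, 73, 400, 730):
    print(step, lr * cosine_warmup_multiplier(step, num_warmup_steps, num_total_steps))
```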

In [21]:
loss_function_dict = {'gender': 'categorical_cross_entropy', 'season': 'categorical_cross_entropy'}
metric_function_dict = {'gender': 'multi_class_acc', 'season': 'multi_class_acc'}
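`MultiTaskLearner` receives `new_task_dict`, `loss_function_dict` and `metric_function_dict` as separate arguments, so it is worth confirming that they all use the same task names. The hypothetical helper `check_task_keys` below is plain Python, not part of Tonks, and catches a misspelled or missing task before training starts.

```python
from typing import Dict


def check_task_keys(task_dict: Dict[str, int], **other_dicts: Dict[str, str]) -> None:
    """Hypothetical helper: raise if any dictionary's keys differ from task_dict's."""
    expected = set(task_dict)
    for name, d in other_dicts.items():
        missing = expected - set(d)
        extra = set(d) - expected
        if missing or extra:
            raise KeyError(f"{name}: missing={sorted(missing)}, extra={sorted(extra)}")


new_task_dict = {'gender': 5, 'season': 4}
check_task_keys(
    new_task_dict,
    loss_function_dict={'gender': 'categorical_cross_entropy', 'season': 'categorical_cross_entropy'},
    metric_function_dict={'gender': 'multi_class_acc', 'season': 'multi_class_acc'},
)  # passes silently when every dictionary covers the same tasks
```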

In [22]:
learn = MultiTaskLearner(model, TrainLoader, ValidLoader, new_task_dict, loss_function_dict, metric_function_dict)

Train Model
===

As the model trains, the learner reports how it is performing overall and on each individual task.

In [23]:
learn.fit(
    num_epochs=10,
    scheduler=scheduler,
    step_scheduler_on_batch=False,
    optimizer=optimizer,
    device=device,
    best_model=True
)

| train_loss | val_loss | gender_train_loss | gender_val_loss | gender_multi_class_accuracy | season_train_loss | season_val_loss | season_multi_class_accuracy | time |
|---|---|---|---|---|---|---|---|---|
| 1.587308 | 0.024508 | 1.616462 | 0.024402 | 0.411798 | 1.548419 | 0.024651 | 0.068028 | 06:22 |
| 1.293475 | 0.018337 | 1.25692 | 0.017395 | 0.524372 | 1.342238 | 0.019594 | 0.481153 | 06:22 |
| 1.069644 | 0.014979 | 0.940256 | 0.011591 | 0.848475 | 1.24224 | 0.019499 | 0.481153 | 06:22 |
| 0.81146 | 0.011308 | 0.514079 | 0.005714 | 0.909377 | 1.208151 | 0.01877 | 0.481153 | 06:22 |
| 0.684814 | 0.01008 | 0.315183 | 0.003763 | 0.923337 | 1.177884 | 0.018507 | 0.481153 | 06:22 |
| 0.622825 | 0.009299 | 0.218015 | 0.002568 | 0.963188 | 1.162821 | 0.018277 | 0.483706 | 06:22 |
| 0.581187 | 0.008984 | 0.155033 | 0.002133 | 0.971406 | 1.149656 | 0.018124 | 0.48731 | 06:22 |
| 0.553077 | 0.008425 | 0.117599 | 0.001418 | 0.985478 | 1.133983 | 0.017772 | 0.497222 | 06:22 |
| 0.518735 | 0.008281 | 0.093014 | 0.001378 | 0.986829 | 1.086625 | 0.01749 | 0.477249 | 06:22 |
| 0.470713 | 0.008089 | 0.077829 | 0.001109 | 0.989193 | 0.994801 | 0.017399 | 0.499324 | 06:22 |


Epoch 9 best model saved with loss of 0.008088728412985802
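The message counts epochs from 0, so "Epoch 9" is the tenth and last epoch. Assuming `best_model=True` keeps the checkpoint with the lowest validation loss, which is consistent with this log, the selection amounts to an argmin over the `val_loss` column; `best_epoch` is a hypothetical helper for illustration.

```python
from typing import List


def best_epoch(val_losses: List[float]) -> int:
    """Index (0-based) of the epoch with the lowest validation loss; ties keep the earliest."""
    return min(range(len(val_losses)), key=val_losses.__getitem__)


# val_loss column from the training log above.
val_losses = [0.024508, 0.018337, 0.014979, 0.011308, 0.01008,
              0.009299, 0.008984, 0.008425, 0.008281, 0.008089]
print(best_epoch(val_losses))  # 9
```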


Validate Model
===

We provide a method on the learner called `get_val_preds`, which makes predictions on the validation data. You can then use this to analyze your model's performance in more detail.

In [27]:
pred_dict = learn.get_val_preds(device)

In [28]:
pred_dict

{'gender': {'y_true': array([4, 2, 2, ..., 2, 1, 2]),
  'y_pred': array([[0.00402687, 0.00640622, 0.00899468, 0.00694958, 0.9736226 ],
         [0.00183889, 0.00166438, 0.9929559 , 0.00185515, 0.0016856 ],
         [0.21054211, 0.038667  , 0.65535367, 0.07843401, 0.01700323],
         ...,
         [0.00215105, 0.00220033, 0.99170154, 0.00226145, 0.00168571],
         [0.10741836, 0.7006959 , 0.05383366, 0.04526137, 0.09279067],
         [0.00235924, 0.00255873, 0.99062705, 0.00258605, 0.00186897]],
        dtype=float32)},
 'season': {'y_true': array([0, 2, 2, ..., 1, 3, 2]),
  'y_pred': array([[0.09411913, 0.04445902, 0.28530663, 0.57611525],
         [0.32891008, 0.06075584, 0.36052248, 0.24981157],
         [0.09471669, 0.04817087, 0.2728372 , 0.5842753 ],
         ...,
         [0.08327069, 0.04614794, 0.27187034, 0.598711  ],
         [0.05318563, 0.06048654, 0.1256119 , 0.76071596],
         [0.08759815, 0.04440476, 0.25398678, 0.61401033]], dtype=float32)}}
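For each task, `pred_dict` holds `y_true` (integer labels) and `y_pred` (one row of class probabilities per example), so the predicted class is the `argmax` of each row. As a starting point for a more detailed analysis, the sketch below computes per-task accuracy and a confusion matrix with NumPy; the helper name `summarize_task` is hypothetical.

```python
import numpy as np


def summarize_task(y_true: np.ndarray, y_pred: np.ndarray):
    """Hypothetical helper: accuracy and confusion matrix from class probabilities."""
    predicted = y_pred.argmax(axis=1)  # most probable class per example
    accuracy = float((predicted == y_true).mean())
    num_classes = y_pred.shape[1]
    # confusion[i, j] counts examples with true class i predicted as class j.
    confusion = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(confusion, (y_true, predicted), 1)
    return accuracy, confusion


# Toy example with 3 examples and 2 classes.
acc, cm = summarize_task(
    np.array([0, 1, 1]),
    np.array([[0.9, 0.1], [0.3, 0.7], [0.6, 0.4]]),
)
print(acc)          # 0.6666666666666666
print(cm.tolist())  # [[1, 0], [1, 1]]
```

In the notebook, looping over `pred_dict.items()` and calling `summarize_task(preds['y_true'], preds['y_pred'])` for each task gives one summary per task; the confusion matrix shows which classes are being confused with each other.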

Save/Export Model
===

Once we are happy with the trained model, we can save it with the `save` method or export it with the `export` method.

See the docs for the difference between `save` and `export`.

We will need the saved model later to build the ensemble model.

In [29]:
model.save(folder='/home/ec2-user/fashion_dataset/models/', model_id='TEXT_MODEL1')

In [30]:
model.export(folder='/home/ec2-user/fashion_dataset/models/', model_id='TEXT_MODEL1')

Now that we have an image model and a text model, we can move to `Step4_train_ensemble_model`.