# CAT: Multiclass Multilabel Multimodal Classification Models


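Please download the dataset archive (`multi-label-classification-competition-2023.zip`) from the [Kaggle competition page](https://www.kaggle.com/competitions/multi-label-classification-competition-2023/data). The Setup cells below expect this file in the working directory before it is unzipped.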

## Setup

In [None]:
# Clone the project repository and switch the notebook's working directory
# into it (presumably so the `cat.*` imports further down resolve from the
# repository root — confirm).
!git clone https://github.com/XavierSpycy/CAT-ImageTextIntegrator.git
%cd CAT-ImageTextIntegrator

In [None]:
# Colab-only: open the browser file picker and upload the chosen files into
# the runtime's current directory.
# NOTE(review): presumably used to upload the dataset zip that a later cell
# unzips — confirm. This cell will not work outside Google Colab.
from google.colab import files
files.upload()

In [None]:
# Install the tokenizer / optimizer / scheduler dependency quietly (-qq).
# NOTE(review): the version is unpinned, and this notebook imports `AdamW`
# from `transformers`, which newer releases deprecate — confirm the installed
# version still exports it, or pin a known-good version.
!pip install -qq transformers # HuggingFace transformers

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup
from cat.datasets import DatasetProcessor, MultimodalDataset
from cat.trainer import mul_clf_train
from cat.multimodal import WWDBert
from cat.evaluator import model_size, mul_model_f1_score_
from cat.predict import mul_clf_predict

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"We are using {device}.")

In [None]:
# Create ./data and, only if that succeeds (&&), quietly unzip the competition
# archive into it. Re-running this cell when ./data already exists makes
# `mkdir` fail, so the unzip step is skipped.
!mkdir data && unzip -q multi-label-classification-competition-2023.zip -d data
# Lift the contents of the nested COMP5329S1A2Dataset folder up into ./data ...
!mv data/COMP5329S1A2Dataset/* data/
# ... and remove the nested folder afterwards.
!rm -rf data/COMP5329S1A2Dataset

## Example: The Optimal Multimodal Model

In [None]:
# Build the dataset helper and fetch the labelled data as three groups of
# (image ids, captions, label tensor). Going by the variable names these are,
# in order, the full raw set, the training split and the validation split.
data_processor = DatasetProcessor()
split_entire, split_train, split_valid = data_processor.get_train_validate()
imgid_raw, caption_raw, label_binary_tensor = split_entire
imgid_train, caption_train, label_train_tensor = split_train
imgid_valid, caption_valid, label_valid_tensor = split_valid

# Test data comes without labels: image ids and captions only.
imgid_test, caption_test = data_processor.get_test()

# Settings read from the processor and reused by the datasets and the model.
num_classes = data_processor.num_classes
img_folder = data_processor.image_folder
max_length = data_processor.max_length

In [None]:
# Pack every split into a list of (image id, caption, label) triples, the
# sample format handed to MultimodalDataset below.
imgid_txt_label_train = list(zip(imgid_train, caption_train, label_train_tensor))
imgid_txt_label_valid = list(zip(imgid_valid, caption_valid, label_valid_tensor))
imgid_txt_label_entire = list(zip(imgid_raw, caption_raw, label_binary_tensor))

# One tokenizer instance is shared by all three datasets.
tokenizer = BertTokenizer.from_pretrained("huawei-noah/TinyBERT_General_4L_312D")

# Training data in 'augment' mode with random_swap_ enabled.
mul_train = MultimodalDataset(
    imgid_txt_label_train,
    img_folder,
    tokenizer,
    max_length,
    'augment',
    random_swap_=True,
)
# Validation data, and a second view of the training data, both in
# 'normalize' mode; the latter is used for scoring the training set.
mul_valid = MultimodalDataset(
    imgid_txt_label_valid,
    img_folder,
    tokenizer,
    max_length,
    'normalize',
)
mul_train_eval = MultimodalDataset(
    imgid_txt_label_train,
    img_folder,
    tokenizer,
    max_length,
    'normalize',
)

### Training

In [None]:
# Model under training, placed on the selected device.
wwdbert = WWDBert(num_classes).to(device)

# Small shuffled batches for optimisation; larger fixed-order batches for
# validation.
train_loader = DataLoader(
    mul_train,
    batch_size=16,
    shuffle=True,
    num_workers=2,
)
valid_loader = DataLoader(
    mul_valid,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

# Length of the run and the loss: BCE-with-logits, one independent sigmoid
# per label, which fits a multi-label target.
epochs = 100
criterion = nn.BCEWithLogitsLoss()

# NOTE(review): `AdamW` here is the class imported from `transformers` (hence
# the `correct_bias` argument); newer transformers releases deprecate it —
# confirm the installed version still provides it.
optimizer = AdamW(
    wwdbert.parameters(),
    lr=1e-5,
    correct_bias=False,
)

# Linear learning-rate decay with no warm-up, sized to batches-per-epoch times
# epochs (assumes the trainer steps the scheduler once per batch — confirm).
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

mul_clf_train(
    wwdbert,
    train_loader,
    valid_loader,
    optimizer,
    criterion,
    scheduler,
    epochs,
)

### Performance

In [None]:
# Scoring loaders: both iterate in a fixed order with the same batch settings.
# The training loader is rebound to the 'normalize'-mode view of the training
# data (mul_train_eval) instead of the 'augment'-mode one used for fitting.
eval_loader_options = {"batch_size": 100, "shuffle": False, "num_workers": 2}
train_loader = DataLoader(mul_train_eval, **eval_loader_options)
valid_loader = DataLoader(mul_valid, **eval_loader_options)

# Measure the trained model: its size, then F1 on each split at a 0.40
# decision threshold.
model = wwdbert
size = model_size(model)
f1_train = mul_model_f1_score_(model, train_loader, threshold=0.40)
f1_valid = mul_model_f1_score_(model, valid_loader, threshold=0.40)

# One report line per metric.
print(
    f"Model size: {size:.2f}MB;",
    f"Model F1 score on the training set: {f1_train:.4f};",
    f"Model F1 score on the validation set: {f1_valid:.4f}.",
    sep="\n",
)

### Predicting

In [None]:
# The dataset takes (image id, caption, label) triples; test items have no
# label, so 0 is passed as a placeholder in the third slot.
imgid_txt_label_test = [
    (image_id, caption, 0)
    for image_id, caption in zip(imgid_test, caption_test)
]

tokenizer = BertTokenizer.from_pretrained("huawei-noah/TinyBERT_General_4L_312D")
mul_test = MultimodalDataset(
    imgid_txt_label_test,
    img_folder,
    tokenizer,
    max_length,
    'normalize',
)
mul_label_test = DataLoader(
    mul_test,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

# Predict with the same 0.40 threshold used for scoring, then turn the
# predictions back into label strings via the data processor.
label_test = mul_clf_predict(
    model,
    'wwdbert',
    mul_label_test,
    threshold=0.40,
    device=device,
)
label_str = data_processor.decode(label_test)

# Two-column table: one row per test image.
pred_dict = dict(ImageID=imgid_test, Labels=label_str)
df = pd.DataFrame(pred_dict)
# Left disabled in the notebook; uncomment to write the predictions to disk.
#df.to_csv("predictions.csv", index=False)