# Classification BERT + CNN
## Pytorch + Lightning (facilitates training on GPU)

In [1]:
import os
import torch
import wandb
import torch.nn as nn
import numpy as np
import random
import pandas as pd
import lightning as L
from torchvision import transforms, datasets
import torchvision.models as models
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger
from transformers import DistilBertModel
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

ModuleNotFoundError: No module named 'wandb'

#### Set seed for reproducibility 
- not working for initial weights yet

In [None]:
# Reproducibility
seed = 1
L.seed_everything(seed=seed, workers=True)
torch.manual_seed(seed)
random.seed(seed)

Seed set to 1


#### Define folders containing images

In [3]:
data_dir = "data"
train_dir = os.path.join(data_dir, "train_images")            # where the images are stored to be trained on
batch_dir = os.path.join(data_dir, "mini_batch")                 # optional to train on a subset of images
test_dir = os.path.join(data_dir, "test_images")              # validate training against test folder to find accuracy
validate_dir = os.path.join(data_dir, "validate_images")      # validate training against validate folder to find accuracy

#### Load dataset

In [4]:
df_train = pd.read_csv(os.path.join(data_dir, 'multimodal_train.tsv'), sep='\t')
df_train.head(1)

Unnamed: 0,author,clean_title,created_utc,domain,hasImage,id,image_url,linked_submission_id,num_comments,score,subreddit,title,upvote_ratio,2_way_label,3_way_label,6_way_label
0,Alexithymia,my walgreens offbrand mucinex was engraved wit...,1551641000.0,i.imgur.com,True,awxhir,https://external-preview.redd.it/WylDbZrnbvZdB...,,2.0,12,mildlyinteresting,My Walgreens offbrand Mucinex was engraved wit...,0.84,1,0,0


#### Filter image ids and labels
* filter valid images

In [5]:
df_train_labels = df_train[['id','clean_title', '2_way_label', '3_way_label', '6_way_label']]
df_train_labels.set_index('id', inplace=True)
img_ids = [img.split('.')[0] for img in os.listdir(train_dir)]
df_train_labels = df_train_labels.loc[img_ids]
print("length of dataframe: ",len(df_train_labels))
df_train_labels.head(5)

length of dataframe:  9993


Unnamed: 0_level_0,clean_title,2_way_label,3_way_label,6_way_label
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ddlxep5,displacement,0,2,4
58xyui,man holding a family size box of cheerios,1,0,0
5sg0yr,reaction to the super bow li,0,2,2
cypg4ku,war of the worlds,0,2,4
ez42yp0,goldilocks and the three bears,0,2,4


#### SAMPLE 50/50 for 2_way_label into mini_batch

In [6]:
import shutil
num_samples = min(6, len(df_train_labels))

if os.path.exists(batch_dir):
    print("recreating batch_directory")
    shutil.rmtree(batch_dir)
    os.makedirs(batch_dir)
else:
    os.makedirs(batch_dir)

if len(os.listdir(batch_dir))  >= num_samples:
    print("Images already exist in mini_batch folder")
else:
    print("Generate even samples in mini_batch folder")

    try:
        # Sample 50 IDs where 2_way_label is 0
        false_label = df_train_labels[df_train_labels['2_way_label'] == 0].sample(n=num_samples//2, random_state=seed).index.values
        # Sample 50 IDs where 2_way_label is 1
        true_label = df_train_labels[df_train_labels['2_way_label'] == 1].sample(n=num_samples//2, random_state=seed).index.values
        sampled_ids = np.concatenate([false_label, true_label])
    except ValueError:
        print("Not enough samples to create even distribution")
        sampled_ids = df_train_labels.index.values
        
    random.shuffle(sampled_ids)
    img_paths = [os.path.join(train_dir, file_name+".jpg") for file_name in sampled_ids]
    
    for i,img_path in enumerate(img_paths):
        # img_name = img_path.split('/')[-1]
        img_name = os.path.basename(img_path)
        shutil.copy(img_path, f"{batch_dir}/{img_name}")
        print(f"Generate even 2_way_label batch data: {i+1}/{len(img_paths)}", end="\r")
    print(f"Generate even 2_way_label batch data: {i+1}/{len(img_paths)}", end="\r")

    # Check distribution of labels
    display(df_train_labels.loc[sampled_ids].groupby('2_way_label').size())

recreating batch_directory
Generate even samples in mini_batch folder
Generate even 2_way_label batch data: 100/100

2_way_label
0    50
1    50
dtype: int64

#### Create the BERT + resnet50 class

In [15]:
class BERT_CNN_Classifier(L.LightningModule):
    
    def __init__(self, input_size, num_layers, num_channels, classes):
        super().__init__()

        ######### Learning rate #########
        self.lr = 0.01
        #self.batch_size = 1

        ######### Parameters #########
        self.input_size = input_size
        self.num_layers = num_layers
        self.num_channels = num_channels
        self.classes = classes

        ######### Criterion #########
        self.criterion = nn.CrossEntropyLoss()


        ########## BERT #########
        self.bert_model = DistilBertModel.from_pretrained('distilbert-base-uncased')
        bert_out_size = self.bert_model.config.hidden_size
    

        ######### CNN #########
        #models.resnet50(weights='ResNet50_Weights.DEFAULT')

        # models.efficientnet.efficientnet_b4(weights=models.EfficientNet_B4_Weights.DEFAULT) 
        self.cnn_model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
        cnn_out_size = self.cnn_model.fc.in_features
        self.cnn_model.fc = nn.Identity()

        #cnn_out_size = self.cnn_model.classifier[-1].in_features
        #self.cnn_model.fc.out_features = nn.Identity()   #.classifier = nn.Identity()

        ######### Dropout #########
        self.dropout = nn.Dropout(0.2)

        ######### Fully connected layers #########
        self.fc1 = nn.Linear(cnn_out_size+bert_out_size, classes)


    def configure_optimizers(self):
        #return torch.optim.Adam(self.parameters(), lr=self.lr)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        return self.dataloader


    # def test_step(self, batch, batch_idx):
    #     tokens, images, labels = batch
    #     predictions = self(tokens, images)
    #     loss = self.criterion(predictions, labels)
        
    #     predicted = torch.argmax(predictions, dim=1)
    #     precision, recall, f1, _ = precision_recall_fscore_support(labels, predicted, average='weighted',zero_division=1)
    #     accuracy = accuracy_score(labels, predicted)

    #     metrics = {"test_loss": loss, "test_accuracy": accuracy, "test_precision": precision, "test_recall": recall, "test_f1": f1}
    #     self.log_dict(metrics)
    #     return metrics
    

    def training_step(self, batch, batch_idx):
        tokens, images, labels = batch
       
        predictions = self(tokens, images)
        loss = self.criterion(predictions, labels)

        # calculate accuracy
        #total = labels.size(0)
        predicted = torch.argmax(predictions, dim=1)
        # correct = torch.eq(predicted, labels).sum().item()
        # accuracy = correct / total
        # self.log("train-accuracy", accuracy)
        #wandb.log({"train-accuracy": accuracy})

        precision, recall, f1, _ = precision_recall_fscore_support(labels, predicted, average='weighted',zero_division=1)
        accuracy = accuracy_score(labels, predicted)

        # Log metrics
        self.log("train-loss", loss, prog_bar=True)
        self.log("train-accuracy", accuracy, prog_bar=True)
        self.log("train-precision", precision, prog_bar=True)
        self.log("train-recall", recall, prog_bar=True)
        self.log("train-f1", f1, prog_bar=True)

        return loss
    
    def forward(self, tokens, image):

        ####### Forward pass through CNN #####
        cnn_x = self.cnn_model(image)
        cnn_x = cnn_x.view(cnn_x.size(0), -1)
        #####################################
        
        ####### Forward pass through BERT #####
        attention_mask = (tokens != 0).to(torch.long)
        bert_x = self.bert_model(tokens, attention_mask=attention_mask).last_hidden_state[:,-1,:]
        #####################################
        
        # Concatenate the outputs of the two models
        concat_x = torch.cat((cnn_x, bert_x), dim=1)

        # Add dropout layer
        concat_x = self.dropout(concat_x)
        
        # Fully connected layer
        out = self.fc1(concat_x)
        
        return out
    

In [16]:
from modules.image_text_dataset import Multi_Modal_Dataset_Tensors
from modules.collate_fn import collate_X_Y_Z


img_to_tensor = transforms.Compose([
    transforms.Resize(232, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


# Create dataset
classes = 2
dataset = Multi_Modal_Dataset_Tensors(
    img_dir=batch_dir,
    classes=classes,
    df_labels=df_train_labels,
    transform=img_to_tensor,
)

multimodal_model = BERT_CNN_Classifier(
    input_size=dataset.max_seq_len,
    num_layers=1,
    num_channels=3,
    classes=classes
)

dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_X_Y_Z,
    num_workers=3,
    persistent_workers=True,
    pin_memory=True
)

#### Test forward pass through model

In [17]:
X, Y, Z = next(iter(dataloader))
multimodal_model(X, Y)

tensor([[0.1133, 0.0422]], grad_fn=<AddmmBackward0>)

#### Setup WandB

In [18]:
import os
import config # file containing secrets

os.environ['WANDB_NOTEBOOK_NAME'] = './BERT_CNN.ipynb'
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY #"7ca1603d429e49e6afc9245a7bde51b76cddb314"

try:
    print("already logged on wandb...")
    wandb.init(project='dat550_Bert_EffNetB4', reinit=True, name=f'run_{"multi_modal"}_{"adam"}')
    wandb_logger = WandbLogger()
except wandb.errors.UsageError:
    print("logging on wandb...")
    wandb.login()
    wandb.init(project='dat550_Bert_EffNetB4', reinit=True, name=f'run_{"multi_modal"}_{"adam"}')
    wandb_logger = WandbLogger()

already logged on wandb...


[34m[1mwandb[0m: Currently logged in as: [33mhakurem[0m. Use [1m`wandb login --relogin`[0m to force relogin


#### Train model

``batch_size``: number of samples to process at once

``epochs``: number of times the model will see the whole dataset

``accelerator``: "auto" --> let Lightning detect GPU to use

``devices``: "auto" --> let Lightning decide how many GPUs are available and use them

``log_every_n_steps``: logging rate in the output


In [19]:
# accelerator: "auto" --> let Lightning detect GPU to use
# devices: "auto" --> let Lightning decide how many GPUs are available and use them
trainer = L.Trainer(
    logger=wandb_logger, 
    max_epochs=10, 
    accelerator="auto", 
    devices="auto",
    log_every_n_steps=len(dataset)//dataloader.batch_size,
    profiler="simple",
    deterministic=True,
    accumulate_grad_batches=len(dataset)//dataloader.batch_size
)

if torch.cuda.is_available():
   multimodal_model.cuda()

# set model to train mode
multimodal_model.train()

trainer.fit(multimodal_model,train_dataloaders=dataloader)
wandb.finish()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/eriktruong/Documents/TrainModel_Gorina/venv/lib/python3.12/site-packages/pytorch_lightning/loggers/wandb.py:391: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

  | Name       | Type             | Params
------------------------------------------------
0 | criterion  | CrossEntropyLoss | 0     
1 | bert_model | DistilBertModel  | 66.4 M
2 | cnn_model  | ResNet           | 23.5 M
3 | dropout    | Dropout          | 0     
4 | fc1        | Linear           | 5.6 K 
------------------------------------------------
89.9 M    Trainable params
0         Non-trainable params
89.9 M    Total params
359.506   Total estimated model params size (MB)


Epoch 4:  83%|████████▎ | 5/6 [00:15<00:03,  0.32it/s, v_num=e7sa, train-loss=0.000206, train-accuracy=1.000, train-precision=1.000, train-recall=1.000, train-f1=1.000]

/Users/eriktruong/Documents/TrainModel_Gorina/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
# Save the weights of the model
if not os.path.exists('trained_models'):
    os.makedirs('trained_models')

custom_args =  {
    "input_size": multimodal_model.input_size,
    "num_layers": multimodal_model.num_layers,
    "num_channels": multimodal_model.num_channels,
    "classes": multimodal_model.classes
}
data_to_save = multimodal_model.state_dict()
data_to_save["__custom_arguments__"] = custom_args
torch.save(data_to_save, 'trained_models/BERT_CNN.pth')

#### Load Model

In [20]:
state_dict = torch.load('trained_models/BERT_CNN.pth')
loaded_model = BERT_CNN_Classifier(**state_dict["__custom_arguments__"])
state_dict.pop('__custom_arguments__', None)
loaded_model.load_state_dict(state_dict)
loaded_model.eval()

BERT_CNN_Classifier(
  (criterion): CrossEntropyLoss()
  (bert_model): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1

#### Find learning rate and batch_size

In [None]:
# tuner = Tuner(trainer)
# multimodal_model.dataloader = dataloader
# tuner.scale_batch_size(multimodal_model, mode="binsearch")
# updated_batch_size = multimodal_model.batch_size-1

#### Calculate Accuracy on test dataset

In [21]:
correct = 0
predictions = {
    0: 0,
    1: 0
}

df_test = pd.read_csv('data/multimodal_test_public.tsv', sep='\t')
df_test_labels = df_test[['id','clean_title', '2_way_label', '3_way_label', '6_way_label']]
df_test_labels.set_index('id', inplace=True)
test_img_ids = [img.split('.')[0] for img in os.listdir(test_dir)]
df_test_labels = df_test_labels.loc[test_img_ids]


test_dataset = Multi_Modal_Dataset_Tensors(
    img_dir=test_dir,
    classes=2,
    df_labels=df_test_labels,
    transform=img_to_tensor,
)

for i, (text_data, img_data, label_data) in enumerate(test_dataset):
    text_data = text_data.unsqueeze(0)
    img_data = img_data.unsqueeze(0)
    label_data = label_data

    out = multimodal_model(text_data, img_data)
    predicted = torch.argmax(out).item()
    predictions[predicted] += 1

    if predicted == label_data:
        correct += 1
    print(f'progress: {i+1}/{len(test_dataset)} | Test Accuracy: {np.round(correct/(i+1),3)}', end="\r")

accuracy = correct / len(test_dataset)
print(f'Accuracy: {accuracy}')
print()
print("predictions: ")
print(predictions)

progress: 18/9994 | Test Accuracy: 0.571

KeyboardInterrupt: 

#### Calculate Accuracy on validation dataset

In [22]:
correct = 0
predictions = {
    0: 0,
    1: 0
}

df_test = pd.read_csv('data/multimodal_test_public.tsv', sep='\t')
df_test_labels = df_test[['id','clean_title', '2_way_label', '3_way_label', '6_way_label']]
df_test_labels.set_index('id', inplace=True)
test_img_ids = [img.split('.')[0] for img in os.listdir(test_dir)]
df_test_labels = df_test_labels.loc[test_img_ids]



test_dataset = Multi_Modal_Dataset_Tensors(
    img_dir=test_dir,
    classes=2,
    df_labels=df_test_labels,
    transform=img_to_tensor,
)

for i, (text_data, img_data, label_data) in enumerate(test_dataset):
    text_data = text_data.unsqueeze(0)
    img_data = img_data.unsqueeze(0)
    label_data = label_data

    out = multimodal_model(text_data, img_data)
    predicted = torch.argmax(out).item()
    predictions[predicted] += 1

    if predicted == label_data:
        correct += 1
    print(f'progress: {i+1}/{len(test_dataset)} | Test Accuracy: {np.round(correct/(i+1),3)}', end="\r")

accuracy = correct / len(test_dataset)
print(f'Accuracy: {accuracy}')
print()
print("predictions: ")
print(predictions)

progress: 8/9994 | Test Accuracy: 0.625

KeyboardInterrupt: 