# Importing Libraries

In [None]:
import os
import torch
import torchtext
import numpy as np
import pandas as pd
from PIL import Image
import seaborn as sns
from nltk import word_tokenize
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.models.efficientnet import efficientnet_b1, EfficientNet_B1_Weights

In [None]:
import sys
sys.path.append("../")
from Utils.nlp import pad_sequence, clean_caption, get_word_to_index

# Model Architecture

In [None]:
class MultimodalModel(torch.nn.Module):
    """Image + caption classifier built from EfficientNet-B1 and a GloVe-initialised LSTM.

    The image branch is an EfficientNet-B1 whose classifier is replaced by a
    dropout layer, so it emits the pooled feature vector. The text branch
    embeds token indices with pretrained GloVe vectors and runs them through
    an LSTM; the final hidden state of the last LSTM layer is used as the
    caption representation. Both vectors are concatenated per sample and fed
    through a stack of ``Linear`` + ``SELU`` layers ending in a
    ``num_classes``-way log-softmax.

    Args:
        n_layers: Number of stacked LSTM layers.
        embed_dim: Dimensionality of the GloVe vectors (also the LSTM input size).
        hidden_dim: LSTM hidden size per direction.
        neurons: Widths of the fully connected layers; must not be empty.
        embedding: Name of the GloVe vector set passed to ``torchtext.vocab.GloVe``.
        bidirectionality: Whether the LSTM is bidirectional.
        freeze: Whether the embedding weights are frozen during training.
        weights: Pretrained weights passed to ``efficientnet_b1``.
        num_classes: Number of output classes.

    Raises:
        ValueError: If ``neurons`` is empty.
    """

    # Width of the image feature vector assumed by the first linear layer.
    IMAGE_FEATURE_DIM = 1280

    def __init__(self, n_layers:int, embed_dim:int, hidden_dim:int, neurons:list, embedding:str="twitter.27B", bidirectionality:bool=False, freeze:bool=False, weights=None, num_classes:int=3) -> None:
        super().__init__()
        if not neurons:
            raise ValueError("neurons must contain at least one layer width")

        model = efficientnet_b1(weights=weights)
        # Drop the stock classification layer so the CNN returns its features.
        model.classifier = torch.nn.Sequential(torch.nn.Dropout(0.2,True))
        self.CNN = model

        glove_embeddings = torchtext.vocab.GloVe(embedding, embed_dim)
        # Kept as a Sequential so parameter names stay ``LSTM.0.*`` / ``LSTM.1.*``;
        # forward() calls the two stages separately because the LSTM returns a tuple.
        self.LSTM = torch.nn.Sequential(
            torch.nn.Embedding.from_pretrained(glove_embeddings.vectors, freeze=freeze),
            torch.nn.LSTM(embed_dim, hidden_dim, n_layers, batch_first=True, bidirectional=bidirectionality),
        )

        # A bidirectional LSTM contributes one final hidden state per direction.
        text_dim = 2 * hidden_dim if bidirectionality else hidden_dim

        self.linear_layers = torch.nn.ModuleList()
        self.linear_layers.append(torch.nn.Linear(self.IMAGE_FEATURE_DIM + text_dim, neurons[0]))
        self.linear_layers.append(torch.nn.SELU())
        for i in range(1, len(neurons)):
            self.linear_layers.append(torch.nn.Linear(neurons[i-1], neurons[i]))
            self.linear_layers.append(torch.nn.SELU())
        self.linear_layers.append(torch.nn.Dropout(0.3))
        self.linear_layers.append(torch.nn.Linear(neurons[-1], num_classes))

    def forward(self, text, image):
        """Return per-sample log-probabilities of shape ``(batch, num_classes)``.

        Args:
            text: Integer token indices of shape ``(batch, seq_len)``.
            image: Image batch accepted by the CNN branch.
        """
        image_embeddings = self.CNN(image)

        embedding_layer, recurrent_layer = self.LSTM[0], self.LSTM[1]
        # The LSTM returns (output, (h_n, c_n)); h_n is (layers * directions, batch, hidden).
        _, (hidden_state, _) = recurrent_layer(embedding_layer(text))
        if recurrent_layer.bidirectional:
            # The last two entries are the forward and backward states of the top layer.
            text_embeddings = torch.concat([hidden_state[-2], hidden_state[-1]], dim=1)
        else:
            text_embeddings = hidden_state[-1]

        # Concatenate along the feature axis so the batch dimension is preserved.
        multimodal = torch.concat([image_embeddings, text_embeddings], dim=1)
        for layer in self.linear_layers:
            multimodal = layer(multimodal)
        return torch.nn.functional.log_softmax(multimodal, dim=1)

# Data Preprocessing

In [None]:
# Load the image-label table (d1) and the engineered text table (d2).
# index_col=False stops pandas from using the first column as the index.
d1, d2 = (
    pd.read_csv(csv_path, index_col=False)
    for csv_path in (
        "../Data/Images/ImageLabelsSequenced.csv",
        "../Data/Text/Engineered.csv",
    )
)

In [None]:
# Copy the caption and hashtag columns from the text table onto the image
# table. pandas pairs the rows by index label, not by any shared key column.
d1 = d1.assign(Caption=d2['Caption'], Hashtags=d2['Hashtags'])

In [None]:
# Keep only the image file name, the two text columns and the label.
d1 = d1[['File Name', 'Caption', 'Hashtags', 'LABEL']]
# Bare expression so the notebook renders the resulting frame.
d1

In [None]:
# Save the merged table before any caption cleaning. The same path is written
# again further down, after the captions have been tokenised and padded.
d1.to_csv("../Data/multimodal.csv", index=False)

In [None]:
# Strip every character that is not an ASCII letter, digit or whitespace.
# This also removes '#', '@' and '!', so one pass covers the earlier '[#@!]'
# substitution. The pattern is a raw string so '\s' reaches the regex engine
# without an invalid-escape warning.
# Missing captions become empty strings so later string operations always
# receive a str rather than NaN.
d1['Caption'] = d1['Caption'].fillna('').str.replace(r'[^a-zA-Z0-9\s]', '', regex=True)

In [None]:
# Spot check: tokenise the caption of the row labelled 1 and show the result.
word_tokenize(d1['Caption'][1])

In [None]:
# Replace each caption string with its list of NLTK word tokens.
# (clean_caption from Utils.nlp is imported but not applied here.)
d1['Caption'] = d1['Caption'].map(word_tokenize)

In [None]:
# Show the Python type of the caption in the row labelled 1 after tokenisation.
type(d1['Caption'][1])

In [None]:
# Pass every token list through pad_sequence with a target length of 25.
# NOTE(review): presumably this pads short captions and truncates long ones;
# confirm against Utils.nlp.pad_sequence.
d1['Caption'] = d1['Caption'].map(
    lambda tokens: pad_sequence(tokens, max_seq_length=25)
)

In [None]:
# Number of tokens in each caption.
# NOTE(review): this cell comes after the padding cell. If pad_sequence forces
# every caption to max_seq_length, all lengths are identical and the quantiles
# computed from `a` carry no information. Confirm, and if so run this before
# padding.
a = d1['Caption'].map(len)

In [None]:
# Show the 80th, 90th, 95th and 99th percentiles of the caption lengths in `a`.
a.quantile([0.8,0.9,0.95, 0.99])

In [None]:
# Bare expression so the notebook renders the processed frame.
d1

In [None]:
# Overwrite the CSV with the processed captions. The token lists are written as
# text, which is why the reload step parses them back with ast.literal_eval.
d1.to_csv("../Data/multimodal.csv", index=False)

# Dataset Architecture

In [None]:
import ast

# Read the processed table back and turn the stringified token lists in
# 'Caption' into Python objects again. literal_eval only evaluates literals,
# so it is safe against arbitrary code in the file.
data = pd.read_csv("../Data/multimodal.csv", index_col=False)
data['Caption'] = data['Caption'].map(ast.literal_eval)

In [None]:
class MultimodalDataset(Dataset):
    """Dataset yielding ``(image, text_indices, label)`` triples.

    Args:
        tokens: Indexable collection of token sequences, one per sample.
        image_dir: Directory that the image file names are relative to.
        labels: Indexable collection of labels; its length is the dataset size.
        images: Indexable collection of image file names, one per sample.
        words_to_index: Mapping from token string to embedding index.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, tokens, image_dir, labels, images, words_to_index:dict, transform=None):
        self.tokens = tokens
        self.labels = labels
        self.image_dir = image_dir
        self.words_to_index = words_to_index
        self.images = images
        self.transform = transform

    def __len__(self):
        """Return the number of samples, taken from the labels collection."""
        return len(self.labels)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, text_indices, label)``."""
        img_path = os.path.join(self.image_dir, self.images[index])
        # Image.open is lazy and keeps the file open; convert() reads the pixel
        # data into a new image, so the file can be closed straight afterwards.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        label = self.labels[index]
        # Tokens missing from the vocabulary fall back to index 0.
        # NOTE(review): index 0 may be a real vocabulary entry rather than a
        # dedicated unknown token; confirm against get_word_to_index.
        text_indices = torch.LongTensor(
            [self.words_to_index.get(word, 0) for word in self.tokens[index]]
        )
        if self.transform:
            image = self.transform(image)

        return (image, text_indices, label)

# Data Preparation

In [None]:
# Hold out 20% of the rows as the test set, stratified on the label.
train, test = train_test_split(
    data,
    test_size=0.2,
    random_state=1,
    shuffle=True,
    stratify=data['LABEL'],
)
# Take 12.5% of the remaining 80% (10% of all rows) as the validation set,
# which gives roughly a 70/10/20 train/val/test split.
train, val = train_test_split(
    train,
    test_size=0.125,
    random_state=1,
    shuffle=True,
    stratify=train['LABEL'],
)

In [None]:
# Build the token -> index lookup from the 25-dimensional GloVe Twitter file.
# NOTE(review): presumably the indices follow the row order of the vectors that
# torchtext loads for the model's embedding layer; verify in Utils.nlp.
word_to_index = get_word_to_index(".vector_cache/glove.twitter.27B.25d.txt")

In [None]:
# Show the preprocessing pipeline bundled with the IMAGENET1K_V2 weights.
EfficientNet_B1_Weights.IMAGENET1K_V2.transforms()

In [None]:
# Spot check: the label at position 3 of the training split as a NumPy value.
np.array(train['LABEL'])[3]

In [None]:
def _make_split_dataset(frame):
    """Wrap one split of the table as a MultimodalDataset over ../Data/Images."""
    return MultimodalDataset(
        tokens=np.array(frame['Caption']),
        image_dir="../Data/Images",
        labels=np.array(frame['LABEL']),
        images=np.array(frame['File Name']),
        words_to_index=word_to_index,
        # Each dataset gets its own instance of the weights' preprocessing.
        transform=EfficientNet_B1_Weights.IMAGENET1K_V2.transforms(),
    )


train_set = _make_split_dataset(train)
val_set = _make_split_dataset(val)

In [None]:
# Batches of 32 samples. The training loader reshuffles every epoch so the
# optimiser does not see the same batch composition and order each time.
# The validation loader keeps a fixed order.
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)

# Training

## Training Loop

In [None]:
def _run_eval_pass(model, dataloader, device, criterion=None):
    """Run the model over ``dataloader``; return (summed loss, true labels, predicted labels).

    The loss is only computed when ``criterion`` is given; otherwise 0.0 is returned for it.
    """
    loss_sum = 0.0
    true_labels = []
    pred_labels = []
    for (images, texts, labels) in dataloader:
        images = images.to(device)
        texts = texts.to(device)
        labels = labels.to(device, dtype=torch.long)
        outputs = model(texts, images)
        if criterion is not None:
            loss_sum += criterion(outputs, labels).item()
        true_labels.extend(labels.cpu().numpy())
        pred_labels.extend(torch.argmax(outputs, dim=1).cpu().numpy())
    return loss_sum, np.array(true_labels), np.array(pred_labels)


def _plot_history(total_train_loss, total_val_loss, train_accuracies, val_accuracies):
    """Draw side-by-side loss and accuracy curves, one point per completed epoch."""
    x_train = np.arange(len(total_train_loss))
    x_val = np.arange(len(total_val_loss))

    sns.set_style('whitegrid')
    plt.figure(figsize=(14,5))

    plt.subplot(1,2,1)
    sns.lineplot(x=x_train, y=total_train_loss, label='Training Loss')
    sns.lineplot(x=x_val, y=total_val_loss, label='Validation Loss')
    plt.title("Loss over {} Epochs".format(len(total_train_loss)))
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.xticks(np.arange(len(total_train_loss)))

    plt.subplot(1,2,2)
    sns.lineplot(x=x_train, y=train_accuracies, label='Training Accuracy')
    sns.lineplot(x=x_val, y=val_accuracies, label='Validation Accuracy')
    plt.title("Accuracy over {} Epochs".format(len(total_train_loss)))
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.xticks(np.arange(len(total_train_loss)))

    plt.show()


def TrainLoop(
    model,
    optimizer:torch.optim.Optimizer,
    criterion:torch.nn.Module,
    train_dataloader:torch.utils.data.DataLoader,
    val_dataloader:torch.utils.data.DataLoader,
    scheduler:torch.optim.lr_scheduler.ReduceLROnPlateau,
    num_epochs:int=20,
    early_stopping_rounds:int=5,
    return_best_model:bool=True,
    device:str='cpu'
):
    """Train ``model`` with early stopping, then plot loss and accuracy curves.

    Each epoch runs one optimisation pass over ``train_dataloader``, then
    evaluates on ``val_dataloader`` and again on ``train_dataloader`` in eval
    mode to obtain accuracies. The scheduler is stepped with the summed
    validation loss.

    Args:
        model: Module called as ``model(texts, images)``.
        optimizer: Optimiser updating the model parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        train_dataloader: Yields ``(images, texts, labels)`` training batches.
        val_dataloader: Yields ``(images, texts, labels)`` validation batches.
        scheduler: Stepped once per epoch with the validation loss.
        num_epochs: Maximum number of epochs.
        early_stopping_rounds: Stop once this many consecutive epochs have
            passed without a new best validation loss.
        return_best_model: If true, restore the weights of the best epoch
            into ``model`` before returning.
        device: Device the model and batches are moved to.

    Returns:
        None. The model is updated in place and the curves are shown.
    """
    # Imported locally so the function does not depend on a file-level import.
    import copy

    model.to(device)
    best_val_loss = float('inf')
    epochs_without_improvement = 0

    total_train_loss = []
    total_val_loss = []
    # state_dict() returns references to the live parameters, so a deep copy is
    # needed to keep a snapshot that later optimiser steps cannot overwrite.
    best_model_weights = copy.deepcopy(model.state_dict())

    train_accuracies = []
    val_accuracies = []

    for epoch in tqdm(range(num_epochs)):
        model.train()
        print("\nEpoch {}\n----------".format(epoch))
        train_loss = 0.0
        for i, (images, texts, labels) in enumerate(train_dataloader):
            images = images.to(device)
            texts = texts.to(device)
            labels = labels.to(device, dtype=torch.long)
            optimizer.zero_grad()
            outputs = model(texts, images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Accumulate a plain float so no autograd graph is kept alive.
            batch_loss = loss.item()
            train_loss += batch_loss
            print("Loss for batch {} = {}".format(i, batch_loss))

        print("\nTraining Loss for epoch {} = {}\n".format(epoch, train_loss))
        # NOTE(review): the per-batch losses are summed and divided by the number
        # of samples. With a mean-reducing criterion this is not the mean
        # per-sample loss; kept as-is so the plotted scale is unchanged.
        total_train_loss.append(train_loss/len(train_dataloader.dataset))

        model.eval()
        with torch.inference_mode():
            validation_loss, val_true_labels, val_pred_labels = _run_eval_pass(
                model, val_dataloader, device, criterion
            )
            # Second pass over the training data, in eval mode, for accuracy only.
            _, train_true_labels, train_pred_labels = _run_eval_pass(
                model, train_dataloader, device
            )

        if validation_loss < best_val_loss:
            best_val_loss = validation_loss
            epochs_without_improvement = 0
            best_model_weights = copy.deepcopy(model.state_dict())
        else:
            epochs_without_improvement += 1

        train_accuracy = accuracy_score(train_true_labels, train_pred_labels)
        val_accuracy = accuracy_score(val_true_labels, val_pred_labels)

        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)

        print(f"Current Validation Loss = {validation_loss}")
        print(f"Best Validation Loss = {best_val_loss}")
        print(f"Epochs without Improvement = {epochs_without_improvement}")

        print(f"Train Accuracy: {train_accuracy * 100:.2f}%")
        print(f"Validation Accuracy: {val_accuracy * 100:.2f}%")

        total_val_loss.append(validation_loss/len(val_dataloader.dataset))
        scheduler.step(validation_loss)
        if epochs_without_improvement == early_stopping_rounds:
            break

    if return_best_model:
        model.load_state_dict(best_model_weights)

    _plot_history(
        np.array(total_train_loss),
        np.array(total_val_loss),
        np.array(train_accuracies),
        np.array(val_accuracies),
    )

## Model 1

In [None]:
# Model 1: four-layer bidirectional LSTM (hidden size 256) over 25-d GloVe
# vectors, one 512-unit dense layer, ImageNet-pretrained EfficientNet-B1.
model_1 = MultimodalModel(
    n_layers=4,
    embed_dim=25,
    hidden_dim=256,
    neurons=[512],
    bidirectionality=True,
    weights=EfficientNet_B1_Weights.IMAGENET1K_V2,
)
optimizer = torch.optim.NAdam(model_1.parameters(), lr=0.001)
# The model ends in log_softmax, so the negative log-likelihood loss applies.
loss_fun = torch.nn.NLLLoss()
# Multiply the learning rate by 0.4 once the monitored loss has not improved
# for 8 epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.4, patience=8)

In [None]:
# Total number of scalar parameters in model_1, whether trainable or not.
sum(p.numel() for p in model_1.parameters())

In [None]:
# Train for at most 100 epochs on the CPU, stopping after 20 epochs without a
# new best validation loss, and restore the best weights afterwards.
TrainLoop(
    model_1,
    optimizer,
    loss_fun,
    train_loader,
    val_loader,
    scheduler,
    num_epochs=100,
    early_stopping_rounds=20,
    return_best_model=True,
    device='cpu',
)