In [None]:
import torch
from torch.utils.data import DataLoader
from torch import nn

import pandas as pd
import numpy as np

from src.utils.configs.hyperparams import hyperparams
from src.utils.transforms.video_transforms import *
from src.utils.transforms.audio_transforms import *
from src.utils.configs.ravdess import *
from src.utils.helpers.functions import *
from src.utils.datasets.ravdess import *
from src.multimodal_network.multimodal import MainMultimodal

from src.utils.helpers.loops import *

import seaborn as sn


from torcheval.metrics.functional import multiclass_f1_score

from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import confusion_matrix

In [None]:
# Use the GPU whenever PyTorch can see one; the model and tensors are moved here later.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
import os

# Root directory of the RAVDESS download. Set the RAVDESS_DIR environment variable
# (or edit the fallback placeholder) instead of changing code further down.
RAVDESS_DIR = os.environ.get("RAVDESS_DIR", "<path_to_ravdess_dir>")
non_augment_df = RAVDESSConfigs().make_dataframe(RAVDESS_DIR)

In [None]:
# Build one frame per augmentation scheme; the `augment` column tells the
# pipeline which scheme to apply (0 = none, 1 and 2 = the two augmentations).
augment_1_df = non_augment_df.assign(augment=1)
augment_2_df = non_augment_df.assign(augment=2)
non_augment_df = non_augment_df.assign(augment=0)

# check an example to see if the strings/naming conventions match
tuple(
    frame[column][0]
    for frame in (non_augment_df, augment_1_df)
    for column in ("audio_path", "video_path")
)

In [None]:
# idx = 50


# # RAW AUDIO
# d = feature_extractor(non_augment_df["audio_path"][idx], augment=0, test=True)
# data = d["raw"]
# sr = d["sr"]
# print("Audio transformed with augmentation scheme 0 (no augmentation; raw audio):")
# ipd.display(ipd.Audio(data=data, rate=sr))
# print("\n")



# # AUGMENTATION 1 - harmonic
# d = feature_extractor(augment_1_df["audio_path"][idx], augment=1, test=True)
# data = d["raw"]
# sr = d["sr"]
# print("Audio transformed with augmentation scheme 1:")
# ipd.display(ipd.Audio(data=data, rate=sr))
# print("\n")



# # AUGMENTATION 2 - pitch shift
# d = feature_extractor(augment_2_df["audio_path"][idx], augment=2, test=True)
# data = d["raw"]
# sr = d["sr"]
# print("Audio transformed with augmentation scheme 2:")
# ipd.display(ipd.Audio(data=data, rate=sr))
# print("\n")

In [None]:
# Preview: first rows of the augmentation-scheme-1 frame.
augment_1_df.head(n=5)

In [None]:
# Preview: first rows of the augmentation-scheme-2 frame.
augment_2_df.head(n=5)

In [None]:
# Preview: first rows of the un-augmented frame.
non_augment_df.head(n=5)

In [None]:
# Stack the three augmentation views into one frame.
# ignore_index=True gives every row a unique label. Without it the index repeats
# 0..n-1 once per view, so any label-based lookup (.loc[i], df[col][i]) silently
# returns several rows instead of one.
df = pd.concat([non_augment_df, augment_1_df, augment_2_df], ignore_index=True)
df.head()

In [None]:
# Row counts: each augmented frame should match the original; `df` should be their sum.
tuple(len(frame) for frame in (non_augment_df, augment_1_df, augment_2_df, df))

In [None]:
# Split into 60% train, 20% val, 20% test set.
# The split is done on the *source clip*, not on rows. `df` holds three copies of
# every clip (augment 0/1/2 share the same file paths). A row-level split therefore
# puts copies of one recording on both sides and leaks test content into training.
# NOTE(review): `audio_path` is used as the clip key — confirm it identifies one recording.
source_clips = df["audio_path"].unique()
train_clips, holdout_clips = train_test_split(source_clips, test_size=0.40, shuffle=True, random_state=42)
train_df = df[df["audio_path"].isin(train_clips)]
test_df = df[df["audio_path"].isin(holdout_clips)]
len(train_df), len(test_df)

In [None]:
# Check your examples
idx = 90 # Change index to see different examples
example_video, example_audio, example_label = (
    train_df[column].iloc[idx] for column in ("video_path", "audio_path", "label")
)
show_example(
    example_video,
    example_audio,
    actual=example_label,
    idx2class=RAVDESSConfigs.idx2class,
)

In [None]:
# Halve the 40% hold-out into validation and test. The split is again by source
# clip, so the augmented copies of one recording can't land in both val and test.
# NOTE(review): `audio_path` is used as the clip key — confirm it identifies one recording.
holdout_df = test_df
holdout_groups = holdout_df["audio_path"].unique()
cv_clips, test_clips = train_test_split(holdout_groups, test_size=0.50, shuffle=True, random_state=42)
cv_df = holdout_df[holdout_df["audio_path"].isin(cv_clips)]
test_df = holdout_df[holdout_df["audio_path"].isin(test_clips)]

In [None]:
# View their length (train, validation, test)
tuple(len(split) for split in (train_df, cv_df, test_df))

In [None]:
# Only the split frames are needed from here on; release the unsplit ones.
del df
del non_augment_df
del augment_1_df, augment_2_df
gc.collect()

In [None]:
# One dataset per split; all share the same frame transform and frame-sampling strategy.
def _make_ravdess_ds(frame):
    return RAVDESSDataset(frame, video_frame_transform, video_strategy='optimal')

trainds, cvds, testds = (_make_ravdess_ds(frame) for frame in (train_df, cv_df, test_df))

In [None]:
# Only the training loader shuffles; validation/test order stays fixed.
_batch = hyperparams["batch"]
trainloader = DataLoader(trainds, batch_size=_batch, shuffle=True)
cvloader = DataLoader(cvds, batch_size=_batch, shuffle=False)
testloader = DataLoader(testds, batch_size=_batch, shuffle=False)

In [None]:
# The DataLoaders keep their own references to the datasets, so the notebook-level
# names can be dropped; the split frames are no longer needed either.
del trainds, cvds, testds
del train_df, cv_df, test_df
gc.collect()

In [None]:
def view_a_loader(item, i):
    """Display sample ``i`` of one DataLoader batch: example, shapes, path and every video frame.

    Args:
        item: one batch from the loaders above, unpacked as
            ``(video, audio, label, video_p, audio_p)``. ``audio`` is indexed with
            ``'mel'``, so it is expected to be a dict of tensors.
        i: index of the sample within the batch.
    """
    video, audio, label, video_p, audio_p = item
    # Same call form as the example cell that works (actual=..., idx2class=...).
    # The previous call passed `label[i].item()` twice positionally, so no class
    # mapping ever reached show_example.
    show_example(video_p[i], audio_p[i], actual=label[i].item(), idx2class=RAVDESSConfigs.idx2class)
    print(f"Video shape: {video.shape} | Audio shape: {audio['mel'].shape}")
    print(f"{video_p[i]}")
    # assumes video[i] iterates over frames shaped (C, H, W) — TODO confirm against RAVDESSDataset
    for frame in video[i]:
        # Channels-last for imshow; detach/move to CPU so .numpy() works for any tensor.
        frame = torch.permute(frame.detach().cpu(), (1, 2, 0))
        plt.figure(figsize=(3, 3))
        plt.imshow(frame.numpy())
        plt.show()

    # Locals are released on return anyway; collect to reclaim figure/tensor memory promptly.
    gc.collect()

In [None]:
# item1 = next(iter(trainloader))
# item2 = next(iter(trainloader))
# item3 = next(iter(trainloader))
# item4 = next(iter(trainloader))
# item5 = next(iter(trainloader))

In [None]:
# item2 - 0; raw unfiltered
# item1 - 0; color jittered
# item3 - 0; perspective

# view_a_loader(item2, 0)

In [None]:
# Multimodal classifier sized to the RAVDESS label set, then moved to the selected device.
model = MainMultimodal(
    num_classes=len(RAVDESSConfigs.class2idx),
    fine_tune_limit=3,
)
model = model.to(device=device)
model

In [None]:
# Sanity check: is the first parameter living on a CUDA device?
next(model.parameters()).device.type == "cuda"

In [None]:
# AdamW over all model parameters; lr and betas come from the shared hyperparameter config.
WEIGHT_DECAY = 1e-2
optim = torch.optim.AdamW(
    model.parameters(),
    lr=hyperparams["lr"],
    betas=hyperparams["adam_betas"],
    weight_decay=WEIGHT_DECAY,
)
loss_fn = nn.CrossEntropyLoss()

In [None]:
from tqdm.autonotebook import tqdm
import time
import datetime

In [None]:
# Per-epoch records appended by the training loop and read by the plots and stats dump.
epochs, train_loss_history, eval_loss_history = [], [], []
train_accuracy_history, eval_accuracy_history = [], []

In [None]:
# Not read by any cell in this notebook; kept so existing references still resolve.
best_params = {}
# Start from +inf so the first finite loss always counts as an improvement.
# The old finite sentinel (10000) meant a run whose losses never dropped below it
# never wrote ./best-multimodal.pt. The test cell would then load a missing or stale checkpoint.
best_train_loss, best_eval_loss = float("inf"), float("inf")

In [None]:
torch.manual_seed(42)

save_memory = True

if save_memory:
    print("\tSave memory mode is on. Set `save_memory=False` to see video-audio examples")

start = time.time()
for epoch in range(hyperparams["epochs"]):
    print(f"========================== Starting Epoch: # {epoch} ==========================")

    inference_start = time.time()

    train_loss, train_acc = train_step(model, trainloader, optim, loss_fn, multiclass_f1_score, save_memory=save_memory, idx2class=RAVDESSConfigs.idx2class)
    eval_loss, eval_acc = eval_step(model, cvloader, loss_fn, multiclass_f1_score, save_memory=save_memory, idx2class=RAVDESSConfigs.idx2class)

    # Wall-clock time of this epoch's train + validation passes.
    inference_total = time.time() - inference_start


    print(f"Epoch: #{epoch} | Total Train Loss: {train_loss} | Total Eval. Loss: {eval_loss} | Train Acc: {train_acc * 100}% | Eval Acc: {eval_acc * 100}% in {inference_total} seconds")


    epochs.append(epoch+1)
    train_loss_history.append(train_loss)
    eval_loss_history.append(eval_loss)
    train_accuracy_history.append(train_acc.detach().cpu()*100)
    eval_accuracy_history.append(eval_acc.detach().cpu()*100)

    # Select the checkpoint on validation loss alone. Also requiring the training
    # loss to improve skipped epochs with the best validation loss so far whenever
    # the training loss happened to tick up.
    if eval_loss < best_eval_loss:
        best_train_loss, best_eval_loss = train_loss, eval_loss
        torch.save(model.state_dict(), "./best-multimodal.pt")
        # state_dict() returns references to the live parameter tensors, so keeping
        # it directly would keep changing as training continues. Snapshot a detached CPU copy.
        best_w = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}

    del train_loss, eval_loss, train_acc, eval_acc
    torch.cuda.empty_cache()
    gc.collect()



end = time.time()
total = end - start
convert = str(datetime.timedelta(seconds=total))
print(f"Total Training Time: {total}s => {convert}")

In [None]:
# Full (oldest-generation) collection to release objects left over from training.
gc.collect(generation=2)

In [None]:
# Keep the last-epoch weights too, separate from the best-on-validation checkpoint.
torch.save(obj=model.state_dict(), f="./multimodal-final.pt")

In [None]:
# Number of recorded epochs; also used in the F1 plot title in the next cell.
epoch = len(epochs)

fig, ax = plt.subplots()
ax.plot(epochs, train_loss_history, color='dodgerblue', label='Train Loss')
ax.plot(epochs, eval_loss_history, color='orange', label='Eval. Loss')

ax.set_xlabel("Epochs")
ax.set_ylabel("Loss Value")
ax.set_title(f"Train and Eval. Loss along {epoch} epochs (RAVDESS)")
ax.legend()

fig.savefig("./Loss curves.png")
plt.show()

In [None]:
# Per-epoch F1 curves (train vs. validation), saved next to the loss curves.
fig, ax = plt.subplots()
for _hist, _colour, _label in (
    (train_accuracy_history, 'dodgerblue', 'Train Accuracy'),
    (eval_accuracy_history, 'orange', 'Eval. Accuracy'),
):
    ax.plot(epochs, _hist, color=_colour, label=_label)

ax.set(
    xlabel="Epochs",
    ylabel="F1 Score Value",
    title=f"Train and Eval. Accuracy along {epoch} epochs (RAVDESS)",
)
ax.legend()

fig.savefig("./F1-Score curves.png")
plt.show()

In [None]:
# load best weights model
# map_location lets a checkpoint written on another device (e.g. a GPU box) load here.
model.load_state_dict(torch.load('./best-multimodal.pt', map_location=device))

# Pass idx2class as the training/validation calls do. With save_memory=False,
# eval_step presumably displays examples, which may need the class mapping.
test_loss, test_acc, y_true, y_preds = eval_step(
    model,
    testloader,
    loss_fn,
    multiclass_f1_score,
    save_memory=False,
    confusion_matrix=True,
    idx2class=RAVDESSConfigs.idx2class,
)
test_acc = test_acc.detach().cpu()

print(f"Test loss: {test_loss}\tTest Accuracy: {test_acc*100}")

In [None]:
# Pin the label order explicitly. Without `labels=`, confusion_matrix only has rows
# for classes present in y_true/y_preds (sorted by value). That can differ in size
# or order from the class-name list, breaking or mislabelling the DataFrame.
# NOTE(review): assumes idx2class keys are the integer labels eval_step returns.
class_ids = sorted(RAVDESSConfigs.idx2class)
classes = [RAVDESSConfigs.idx2class[k] for k in class_ids]

cf_matrix = confusion_matrix(y_true, y_preds, labels=class_ids)

# Row-normalise so each row is the prediction distribution for one true class.
# A class with no test samples gets an all-zero row instead of NaN.
row_totals = cf_matrix.sum(axis=1, keepdims=True)
normalised = np.divide(
    cf_matrix,
    row_totals,
    out=np.zeros(cf_matrix.shape, dtype=float),
    where=row_totals != 0,
)

df_cm = pd.DataFrame(normalised, index=classes, columns=classes)

plt.figure(figsize=(12, 7))

sn.heatmap(df_cm, annot=True)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Row-normalised confusion matrix (RAVDESS test set)")

# File named after the dataset actually evaluated here (previously "..._savee.png").
plt.savefig('./confusion_matrix_ravdess.png')

plt.show()

In [None]:
# Save stats: one line per epoch, then the test result obtained with the best weights.
with open("./recorded.txt", "w") as stats_file:
    stats_file.write("R2plus1D & CNN-SE attempt\n")
    for i in range(len(epochs)):
        epoch_line = (
            f"Epoch: {epochs[i]}: | Train Loss: {train_loss_history[i]} | Train Accuracy: {train_accuracy_history[i]} | Eval Loss: {eval_loss_history[i]} | Eval Accuracy: {eval_accuracy_history[i]}"
        )
        stats_file.write(epoch_line + "\n")

    stats_file.write("\n==================================================\n")
    stats_file.write(f"On best weights => Test loss: {test_loss}\tTest Accuracy: {test_acc*100}")
    stats_file.write("\n==================================================\n\n\n")