In [1]:
import torch
import ImageFeature
import AttributeFeature
import TextFeature
import FinalClassifier
import FuseAllFeature
from LoadData import *
from torch.utils.data import Dataset, DataLoader,random_split
import numpy as np

# Target device for the model and every batch: CUDA when torch reports it
# as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class Multimodel(torch.nn.Module):
    """Combine image, attribute and text encoders and classify the fused result.

    The five sub-modules come from the project's ImageFeature,
    AttributeFeature, TextFeature, FuseAllFeature and FinalClassifier modules.
    Attribute names and construction order are kept as-is so saved
    ``state_dict`` keys keep matching.
    """

    def __init__(self):
        """Instantiate the per-modality encoders, the fusion layer and the classifier head."""
        super().__init__()
        self.image = ImageFeature.ExtractImageFeature()
        self.attribute = AttributeFeature.ExtractAttributeFeature()
        self.text = TextFeature.ExtractTextFeature(TEXT_LENGTH, TEXT_HIDDEN)
        self.fuse = FuseAllFeature.ModalityFusion()
        self.final_classifier = FinalClassifier.ClassificationLayer()

    def forward(self, text_index, image_feature, attribute_index):
        """Run one batch through every encoder, fuse the outputs and classify.

        Args:
            text_index: Input handed to the text encoder.
            image_feature: Input handed to the image encoder.
            attribute_index: Input handed to the attribute encoder.

        Returns:
            Whatever ``ClassificationLayer`` produces for the fused features
            (the training code treats it as a probability per sample).
        """
        # Each encoder returns a pair: a summary vector and a sequence tensor.
        image_result, image_seq = self.image(image_feature)
        attribute_result, attribute_seq = self.attribute(attribute_index)
        # The text encoder additionally receives the attribute summary.
        text_result, text_seq = self.text(text_index, attribute_result)

        # Swap the first two axes of the text/attribute sequences before fusion
        # (presumably seq-first <-> batch-first; confirm against ModalityFusion).
        text_seq_swapped = text_seq.permute(1, 0, 2)
        attribute_seq_swapped = attribute_seq.permute(1, 0, 2)

        fusion = self.fuse(
            image_result,
            image_seq,
            text_result,
            text_seq_swapped,
            attribute_result,
            attribute_seq_swapped,
        )
        return self.final_classifier(fusion)
        

In [3]:
# Loss: binary cross-entropy between the model output and the 0/1 label.
loss_fn = torch.nn.BCELoss()

# Step size used by the optimizer.
learning_rate = 0.001

# Initialize the model and move its parameters to the selected device.
model = Multimodel().to(device)

# Adam over all model parameters, with a very small L2 weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-7,
)



In [4]:
number_of_epoch=100
for epoch in range(number_of_epoch):

    # ---- training pass ----
    train_loss = 0.0      # sum of per-batch mean losses, kept as a Python float
    correct_train = 0     # correctly classified training samples
    seen_train = 0        # training samples actually processed this epoch
    model.train()
    for text_index, image_feature, attribute_index, group, _ in train_loader:
        group = group.view(-1,1).to(torch.float32).to(device)
        pred = model(text_index.to(device), image_feature.to(device), attribute_index.to(device))
        loss = loss_fn(pred, group)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() detaches the value from autograd; summing the tensor itself
        # would keep every batch's history alive until the end of the epoch.
        train_loss += loss.item()
        correct_train += (pred.detach().round() == group).sum().item()
        seen_train += group.size(0)

    # ---- validation pass ----
    valid_loss = 0.0
    correct_valid = 0
    seen_valid = 0
    model.eval()
    with torch.no_grad():
        for val_text_index, val_image_feature, val_attribute_index, val_group, _ in val_loader:
            val_group = val_group.view(-1,1).to(torch.float32).to(device)
            val_pred = model(val_text_index.to(device), val_image_feature.to(device), val_attribute_index.to(device))
            valid_loss += loss_fn(val_pred, val_group).item()
            correct_valid += (val_pred.round() == val_group).sum().item()
            seen_valid += val_group.size(0)

    # Loss is averaged over batches; accuracy is divided by the number of
    # samples really seen, so a smaller final batch no longer skews it.
    # max(..., 1) only guards against an empty loader.
    train_batches = max(len(train_loader), 1)
    valid_batches = max(len(val_loader), 1)
    print("epoch: %d train_loss=%.5f train_acc=%.3f valid_loss=%.5f valid_acc=%.3f" % (
        epoch,
        train_loss / train_batches,
        correct_train / max(seen_train, 1),
        valid_loss / valid_batches,
        correct_valid / max(seen_valid, 1),
    ))

epoch: 0 train_loss=0.48695 train_acc=0.759 valid_loss=0.43708 valid_acc=0.796
epoch: 1 train_loss=0.41078 train_acc=0.813 valid_loss=0.40752 valid_acc=0.822
epoch: 2 train_loss=0.37029 train_acc=0.835 valid_loss=0.43805 valid_acc=0.804
epoch: 3 train_loss=0.33806 train_acc=0.852 valid_loss=0.42840 valid_acc=0.805


KeyboardInterrupt: 

In [None]:
import sklearn
import seaborn as sns
def validation_metrics (model, dataset):
    """Evaluate ``model`` on every batch yielded by ``dataset``.

    Uses the module-level ``loss_fn`` and ``device``.

    Args:
        model: Callable taking ``(text_index, image_feature, attribute_index)``
            and returning one prediction per sample.
        dataset: Iterable of 5-tuples
            ``(text_index, image_feature, attribute_index, group, id)``;
            the last element is ignored here.

    Returns:
        Tuple ``(loss_avg, acc, confusion_matrix_sum)``: the mean of the
        per-batch losses as a float, the fraction of correctly classified
        samples, and the 2x2 confusion matrix (rows = true label, columns =
        predicted label, label order ``[0, 1]``) summed over all batches.

    Raises:
        ValueError: If ``dataset`` yields no samples.
    """
    # Imported explicitly: a bare `import sklearn` may not load the
    # `sklearn.metrics` submodule.
    from sklearn.metrics import confusion_matrix as _confusion_matrix

    model.eval()
    total = 0
    correct = 0
    batches = 0
    loss_sum = 0.0
    confusion_matrix_sum = None
    with torch.no_grad():
        for text_index, image_feature, attribute_index, group, _ in dataset:
            group = group.view(-1,1).to(torch.float32).to(device)
            pred = model(text_index.to(device), image_feature.to(device), attribute_index.to(device))
            loss_sum += loss_fn(pred, group).item()
            batches += 1

            # Threshold once and reuse for both accuracy and the confusion matrix.
            pred_label = pred.round()
            correct += (pred_label == group).sum().item()
            total += group.size(0)

            # sklearn needs flat, discrete labels; raw probabilities are rejected.
            batch_matrix = _confusion_matrix(
                group.reshape(-1).long().cpu().numpy(),
                pred_label.reshape(-1).long().cpu().numpy(),
                labels=[0,1],
            )
            if confusion_matrix_sum is None:
                confusion_matrix_sum = batch_matrix
            else:
                confusion_matrix_sum = confusion_matrix_sum + batch_matrix

    if total == 0:
        raise ValueError("validation_metrics: dataset yielded no samples")

    acc = correct / total
    loss_avg = loss_sum / batches
    return loss_avg, acc, confusion_matrix_sum

def plot_confusion_matrix(confusion_matrix):
    """Draw ``confusion_matrix`` as an annotated heatmap and show it.

    Args:
        confusion_matrix: 2x2 array-like; rows are labelled as the true class
            and columns as the predicted class, in the order
            ``['not sarcasm', 'sarcasm']``.
    """
    # Import pyplot locally: no cell above this one binds `plt`, and the only
    # notebook-level binding is `import matplotlib as plt`, which is the
    # top-level package and has no ylabel/xlabel/show.
    import matplotlib.pyplot as plt

    class_names = ['not sarcasm', 'sarcasm']
    # fmt="g" prints counts in full; the default ".2g" turns e.g. 1234 into 1.2e+03.
    sns.heatmap(
        confusion_matrix,
        annot=True,
        fmt="g",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
# Evaluate the current model on `test_loader`, print the scalar metrics and
# plot the confusion matrix.
# NOTE(review): `test_loader` is not defined in the visible cells — presumably
# it comes from `from LoadData import *`; confirm.
loss, acc, confusion_matrix=validation_metrics (model, test_loader)
print("loss:",loss,"accuracy:",acc)
plot_confusion_matrix(confusion_matrix)

In [None]:
import matplotlib as plt
def validation_metrics (model, dataset, max_examples=5):
    """Display a few samples with their ground truth and the model's prediction.

    NOTE(review): this reuses the name of the metrics function defined in an
    earlier cell and silently replaces it once this cell runs; consider
    renaming one of the two.

    Uses the module-level ``device``.

    Args:
        model: Callable taking ``(text_index, image_feature, attribute_index)``.
        dataset: Iterable of 5-tuples
            ``(text_index, image_feature, attribute_index, group, id)`` that
            also exposes ``image_loader``, ``text_loader`` and ``label_loader``
            lookups by id.
            NOTE(review): the caller passes a loader here — confirm it really
            provides these three methods rather than its underlying dataset.
        max_examples: Number of batches to display; defaults to 5.
    """
    # The notebook binds `plt` via `import matplotlib as plt`, which is the
    # top-level package and has no imshow/show; use pyplot explicitly.
    import matplotlib.pyplot as plt

    model.eval()
    with torch.no_grad():
        for shown, (text_index, image_feature, attribute_index, group, sample_id) in enumerate(dataset):
            if shown >= max_examples:
                break
            # Number the examples from 1 (the header used to always say 1).
            print(f">>>Example {shown + 1}<<<")
            img=dataset.image_loader(sample_id)
            # Move the first axis of the first image to the end for imshow
            # (presumably channels-first -> channels-last; confirm).
            plt.imshow(img[0].permute(1,2,0))
            plt.show()
            print("Text: ",dataset.text_loader(sample_id))
            print("Labels: ",dataset.label_loader(sample_id))
            print(f"Truth:{' not ' if group[0]==0 else ' '}sarcasm")
            pred = model(text_index.to(device), image_feature.to(device), attribute_index.to(device))
            # Convert to a Python float first: builtin round() does not work
            # on a tensor element.
            first_score = pred[0,0].item()
            print(f"Preduct:{' not ' if round(first_score)==0 else ' '}sarcasm")

# Show example predictions for batches drawn from `play_loader`.
# NOTE(review): `play_loader` is not defined in the visible cells — presumably
# it comes from `from LoadData import *`; confirm.
validation_metrics (model, play_loader)




