In [1]:
import os
import re
import json
import numpy as np
from sklearn import metrics
from sklearn.preprocessing import normalize

# Dataset locations: the image directory and the text annotation file of each split.
img_dir = '/hy-tmp/data/dataset_image'
train_file = '/hy-tmp/data/data-of-multimodal-sarcasm-detection/text/train.txt'
valid_file = '/hy-tmp/data/data-of-multimodal-sarcasm-detection/text/valid2.txt'
test_file = '/hy-tmp/data/data-of-multimodal-sarcasm-detection/text/test2.txt'

# Names of the entries in img_dir; load_all_data checks '<img_id>.jpg' against this list.
image_files = os.listdir(img_dir)


# Per-model prediction files. In the loading cell these are read against the
# records of test_file.
CM_BERT_predicts = '/root/results/CM_BERT_predicts.txt'
CM_BERT_TEXT_IN_IMG_TEXT_predicts = '/root/results/CM_BERT_TEXT_IN_IMG_TEXT_predicts.txt'
CM_VIT_predicts = '/root/results/CM_VIT_predicts.txt'
CM_VIT2_predicts = '/root/results/CM_VIT2_predicts.txt'
# NOTE(review): VIT_64 has no *_val_predicts counterpart below and is not loaded
# by the cells visible in this notebook.
VIT_64_predicts = '/root/results/VIT_64.txt'
CM_GCN_predicts = '/root/results/CM_GCN_predicts.txt'
CM_ATTENTION_predicts = '/root/results/CM_ATTENTION_predicts.txt'
CM_ATTENTION2_predicts = '/root/results/CM_ATTENTION2_predicts.txt'
CM_ATTENTION3_predicts = '/root/results/CM_ATTENTION3_predicts.txt'
CM_ATTENTION4_predicts = '/root/results/CM_ATTENTION4_predicts.txt'
CM_ATTENTION5_predicts = '/root/results/CM_ATTENTION5_predicts.txt'


# Per-model prediction files that the loading cell reads against the records of
# valid_file.
CM_BERT_val_predicts = '/root/results/CM_BERT_val_predicts.txt'
CM_BERT_TEXT_IN_IMG_TEXT_val_predicts = '/root/results/CM_BERT_TEXT_IN_IMG_TEXT_val_predicts.txt'
CM_VIT_val_predicts = '/root/results/CM_VIT_val_predicts.txt'
CM_VIT2_val_predicts = '/root/results/CM_VIT2_val_predicts.txt'
CM_GCN_val_predicts = '/root/results/CM_GCN_val_predicts.txt'
CM_ATTENTION_val_predicts = '/root/results/CM_ATTENTION_val_predicts.txt'
CM_ATTENTION2_val_predicts = '/root/results/CM_ATTENTION2_val_predicts.txt'
CM_ATTENTION3_val_predicts = '/root/results/CM_ATTENTION3_val_predicts.txt'
CM_ATTENTION4_val_predicts = '/root/results/CM_ATTENTION4_val_predicts.txt'
CM_ATTENTION5_val_predicts = '/root/results/CM_ATTENTION5_val_predicts.txt'


In [2]:
def load_all_data(data_file):
    """Load image ids and labels from an annotation file.

    Every non-empty line of ``data_file`` is evaluated as a Python literal
    sequence. A file whose name contains ``'train'`` is read as
    ``(img_id, text, label)`` records; any other file is read as
    ``(img_id, text, label1, label)`` records, of which only the last field is
    kept as the label. Records whose ``<img_id>.jpg`` is not listed in the
    module-level ``image_files`` are dropped.

    Args:
        data_file: Path of the UTF-8 encoded annotation file.

    Returns:
        dict mapping ``img_id`` to ``{'image_file': str, 'label': int}``.

    Raises:
        ValueError: if a record does not have the number of fields expected
            for this kind of file.
    """
    # The layout depends on the file being loaded. The previous code inspected
    # the global ``test_file`` here, so the three-field layout was never chosen.
    is_train = 'train' in os.path.basename(data_file)
    # A set makes each membership test O(1) instead of scanning the whole list.
    available_images = set(image_files)

    all_data = {}
    with open(data_file, 'r', encoding='utf-8') as fin:
        for raw_line in fin:
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            # NOTE(security): eval() executes whatever the line contains. Only
            # use this on trusted annotation files; ast.literal_eval is the safe
            # alternative if every record is a plain literal.
            data = eval(line)
            if is_train:
                img_id, _text, label = data
            else:
                img_id, _text, _label1, label = data

            image_file = img_id + '.jpg'
            if image_file in available_images:
                all_data[img_id] = {'image_file': image_file, 'label': int(label)}
    return all_data

def load_predicts(predicts_file, all_data):
    """Read the two logits stored for each image in a prediction file.

    Each non-empty line must split on whitespace into exactly five fields:
    ``img_id predict label logit1 logit2``. The first number found in each
    logit field is parsed as a float, so surrounding punctuation such as
    brackets or commas is ignored. Lines whose ``img_id`` is not a key of
    ``all_data`` are skipped.

    Args:
        predicts_file: Path of the UTF-8 encoded prediction file.
        all_data: Container of the image ids to keep (only membership is used).

    Returns:
        dict mapping ``img_id`` to ``[logit1, logit2]`` as floats.

    Raises:
        ValueError: if a non-empty line does not have exactly five fields.
        IndexError: if a logit field of a kept line contains no number.
    """
    # Raw string avoids invalid-escape warnings; the optional exponent keeps
    # values written in scientific notation (e.g. 1.5e-03) from being truncated
    # to their mantissa.
    number = re.compile(r'-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?')

    logits = {}
    with open(predicts_file, 'r', encoding='utf-8') as fin:
        for raw_line in fin:
            fields = raw_line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            img_id, _predict, _label, logit1, logit2 = fields

            if img_id in all_data:
                logits[img_id] = [
                    float(number.findall(logit1)[0]),
                    float(number.findall(logit2)[0]),
                ]

    return logits

def evaluate_acc_f1(logits_list, all_data, method='mean', weights=None, norm=None):
    """Fuse the logits of several models and score the fused predictions.

    Args:
        logits_list: Sequence with one dict per model, each mapping ``img_id``
            to that model's list of per-class logits.
        all_data: dict mapping ``img_id`` to a record with a ``'label'`` entry.
            Its iteration order defines the sample order.
        method: Fusion applied across models: ``'sum'``, ``'max'`` or
            ``'mean'``. Any other value prints a notice and falls back to sum.
        weights: Optional per-model multipliers (list, tuple or 1-D array), one
            per entry of ``logits_list``. ``None`` or an empty sequence means
            no weighting.
        norm: Optional norm name forwarded to ``sklearn.preprocessing.normalize``.
            Normalization is applied per model along axis 0, i.e. each logit
            column is scaled over all samples, before weighting.

    Returns:
        Tuple ``(accuracy, f1, precision, recall)`` computed with the
        scikit-learn metric functions at their default settings.

    Raises:
        KeyError: if a model has no logits for an ``img_id`` of ``all_data``.
        AssertionError: if ``weights`` is given with a length different from
            the number of models.
    """
    num_model = len(logits_list)
    img_ids = list(all_data)
    labels = [all_data[img_id]['label'] for img_id in img_ids]

    # One (n_samples, n_logits) array per model, rows ordered like ``labels``.
    per_model = []
    for model_logits in logits_list:
        model_array = np.array([model_logits[img_id] for img_id in img_ids])
        assert len(labels) == model_array.shape[0]
        per_model.append(model_array)

    if norm:
        per_model = [normalize(arr, norm=norm, axis=0) for arr in per_model]

    # Explicit None/length test: ``if weights:`` raises for NumPy arrays with
    # more than one element, while an empty sequence still means "no weights".
    if weights is not None and len(weights) > 0:
        assert len(weights) == num_model
        per_model = [arr * weight for arr, weight in zip(per_model, weights)]

    stacked_logits = np.stack(per_model, axis=0)
    if method == 'sum':
        logits = np.sum(stacked_logits, axis=0)
    elif method == 'max':
        logits = np.max(stacked_logits, axis=0)
    elif method == 'mean':
        logits = np.mean(stacked_logits, axis=0)
    else:
        print('fusion method not find, use sum methon')
        logits = np.sum(stacked_logits, axis=0)

    # The predicted class is the index of the largest fused logit.
    predicts = np.argmax(logits, axis=1)
    acc = metrics.accuracy_score(labels, predicts)
    f1 = metrics.f1_score(labels, predicts)
    precision = metrics.precision_score(labels, predicts)
    recall = metrics.recall_score(labels, predicts)

    return acc, f1, precision, recall

In [3]:
# Ground-truth records of the two evaluation splits, keyed by image id.
all_data = load_all_data(test_file)
val_all_data = load_all_data(valid_file)

print('len of all test data:', len(all_data))

# Per-model logits restricted to the image ids present in all_data (from test_file).
CM_BERT_logits = load_predicts(CM_BERT_predicts, all_data)
CM_BERT_TEXT_IN_IMG_TEXT_logits = load_predicts(CM_BERT_TEXT_IN_IMG_TEXT_predicts, all_data)
CM_VIT_logits = load_predicts(CM_VIT_predicts, all_data)
CM_VIT2_logits = load_predicts(CM_VIT2_predicts, all_data)
CM_GCN_logits = load_predicts(CM_GCN_predicts, all_data)
CM_ATTENTION_logits = load_predicts(CM_ATTENTION_predicts, all_data)
CM_ATTENTION2_logits = load_predicts(CM_ATTENTION2_predicts, all_data)
CM_ATTENTION3_logits = load_predicts(CM_ATTENTION3_predicts, all_data)
CM_ATTENTION4_logits = load_predicts(CM_ATTENTION4_predicts, all_data)
CM_ATTENTION5_logits = load_predicts(CM_ATTENTION5_predicts, all_data)

# Per-model logits restricted to the image ids present in val_all_data (from valid_file).
CM_BERT_val_logits = load_predicts(CM_BERT_val_predicts, val_all_data)
CM_VIT_val_logits = load_predicts(CM_VIT_val_predicts, val_all_data)
CM_VIT2_val_logits = load_predicts(CM_VIT2_val_predicts, val_all_data)
CM_BERT_TEXT_IN_IMG_TEXT_val_logits = load_predicts(CM_BERT_TEXT_IN_IMG_TEXT_val_predicts, val_all_data)
CM_GCN_val_logits = load_predicts(CM_GCN_val_predicts, val_all_data)
CM_ATTENTION_val_logits = load_predicts(CM_ATTENTION_val_predicts, val_all_data)
CM_ATTENTION2_val_logits = load_predicts(CM_ATTENTION2_val_predicts, val_all_data)
CM_ATTENTION3_val_logits = load_predicts(CM_ATTENTION3_val_predicts, val_all_data)
CM_ATTENTION4_val_logits = load_predicts(CM_ATTENTION4_val_predicts, val_all_data)
CM_ATTENTION5_val_logits = load_predicts(CM_ATTENTION5_val_predicts, val_all_data)

len of all test data: 2409


In [15]:
# Per-model fusion weights, in the same order as the model lists below.
# (An earlier setting that was immediately overwritten has been removed.)
weights = [1.0, 1.0, 1.5, 1.6]

# Ensemble members for the valid_file records; CM_BERT is left out.
val_fusion_models = [
    CM_VIT2_val_logits,
    CM_BERT_TEXT_IN_IMG_TEXT_val_logits,
    CM_GCN_val_logits,
    CM_ATTENTION4_val_logits,
]
# The same ensemble members for the test_file records.
fusion_models = [
    CM_VIT2_logits,
    CM_BERT_TEXT_IN_IMG_TEXT_logits,
    CM_GCN_logits,
    CM_ATTENTION4_logits,
]

# Each split is scored twice: plain mean fusion first, then mean fusion of
# L2-normalized, weighted logits. Every result is (acc, f1, precision, recall).
result = evaluate_acc_f1(val_fusion_models, val_all_data, method='mean', weights=None, norm=None)
print(result)
result = evaluate_acc_f1(val_fusion_models, val_all_data, method='mean', weights=weights, norm='l2')
print(result)

result = evaluate_acc_f1(fusion_models, all_data, method='mean', weights=None, norm=None)
print(result)
result = evaluate_acc_f1(fusion_models, all_data, method='mean', weights=weights, norm='l2')
print(result)

(0.8721991701244813, 0.8402489626556017, 0.8359133126934984, 0.8446298227320125)
(0.8780082987551867, 0.8473520249221183, 0.843846949327818, 0.8508863399374348)
(0.8725612287256123, 0.8415074858027878, 0.8333333333333334, 0.8498435870698644)
(0.8779576587795765, 0.8479834539813856, 0.841025641025641, 0.8550573514077163)


In [6]:
# Alternative ensemble for the test_file records (CM_BERT is left out).
# Note that this rebinds ``fusion_models``.
fusion_models = [CM_VIT_logits, CM_BERT_TEXT_IN_IMG_TEXT_logits, CM_GCN_logits, CM_ATTENTION5_logits]

# Unweighted mean fusion: first on the raw logits, then on L1-normalized ones.
# Each printed tuple is (acc, f1, precision, recall).
for fusion_norm in (None, 'l1'):
    result = evaluate_acc_f1(fusion_models, all_data, method='mean', weights=None, norm=fusion_norm)
    print(result)

(0.871315898713159, 0.8383733055265902, 0.8383733055265902, 0.8383733055265902)
(0.8700705687007056, 0.8372334893395734, 0.8350622406639004, 0.8394160583941606)
