In [1]:
# !pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
# !pip install ftfy regex tqdm
# !pip install git+https://github.com/openai/CLIP.git
# !pip install pandas
# !pip install transformers
# !pip install scikit-learn

In [2]:
# Move up one directory so the relative "autodl-tmp/..." paths used below resolve from the new working directory
cd ../

/root


In [3]:
# Import all the dependencies
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, transforms
from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
from sklearn.metrics import hamming_loss
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import classification_report
from sklearn.metrics import mean_absolute_error
from sklearn.preprocessing import OneHotEncoder, MultiLabelBinarizer
from pathlib import Path
import clip
import copy

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Let PIL load image files whose data is truncated instead of raising an error
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

cuda


# DataPath

In [4]:
# Load the pre-extracted LOC (image-side) features: OCR/ROI tensor and face tensor.
# MemesDatasetAug indexes both with `int(id) - 1`.
# NOTE(review): torch.load is called without map_location, so the tensors come back on
# whatever device they were saved from — confirm this works on a CPU-only machine.
all_ROI = torch.load("autodl-tmp/Dataset/feature/all_ocr_tensor.pt")
all_face = torch.load("autodl-tmp/Dataset/feature/all_face_tensor.pt")

# Load the pre-extracted ENT (text-side) features, indexed the same way
all_ENT = torch.load("autodl-tmp/Dataset/feature/all_text_tensor.pt")

# Load the pre-extracted caption features
# NOTE(review): all_CAP is not referenced anywhere else in the visible notebook — confirm it is still needed.
all_CAP = torch.load("autodl-tmp/Dataset/feature/all_cap_tensor.pt")

# In[ ]:

# Image directory and the JSON-lines files describing each split
data_dir = "autodl-tmp/Dataset/images"
# Few-shot alternative for the training split (disabled):
# train_path = "autodl-tmp/Dataset/test/fewshot.json"
train_path = "autodl-tmp/Dataset/test/train.json"
dev_path = "autodl-tmp/Dataset/test/val.json"
test_path = "autodl-tmp/Dataset/test/test.json"

In [5]:
# Read the test split and discard rows whose text has 20 or more whitespace-separated
# tokens, then renumber the remaining rows from 0.
test_samples_frame = pd.read_json(test_path, lines=True)
test_samples_frame = (
    test_samples_frame
    .loc[lambda frame: ~(frame.text.str.split().str.len() >= 20)]
    .reset_index(drop=True)
)
test_samples_frame.head()

Unnamed: 0,id,img,labels,text,caption
0,1181,1181.jpg,"[2, 1, 2]",I hate children \n they are loud and messy and...,a brown and white cat with blue eyes being held
1,425,425.jpg,"[1, 1, 1]",One simply does not run load bank cables \n wi...,a man with long hair and a leather jacket
2,12742,12742.jpg,"[0, 0, 0]",IF YOU DON'T FOLLOW THE \n RULES \n YOU'RE GON...,if you dont follow the rules youre gonna have ...
3,11472,11472.jpg,"[1, 1, 1]",YEA LEMME GET THE \n LAST SUPPER FADER \n,a tattoo of the last supper and a black and wh...
4,2741,2741.jpg,"[1, 1, 1]",Mrs. Reimus when someone gets \n the right answer,a portrait of a man in a hat with his hand out


# CLIP

In [6]:
import clip

# Load the pretrained CLIP "ViT-B/32" model on `device`, together with the image
# preprocessing transform that clip.load returns alongside it.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Build the CLIP image input tensor for a single image
def process_image_clip(in_img, datatype):
    """Load one image and turn it into a preprocessed tensor on ``device``.

    Args:
        in_img: Image path (or file object) handed to ``Image.open``.
        datatype: Split flag supplied by the caller. It is not used here and is
            kept only so existing call sites keep working.

    Returns:
        The result of the module-level ``preprocess`` transform applied to the
        RGB-converted image, moved to ``device``.
    """
    # The context manager closes the underlying file deterministically instead of
    # leaving it to garbage collection. convert() returns an independent image,
    # so it remains usable after the file is closed.
    with Image.open(in_img) as raw_image:
        image = raw_image.convert("RGB")

    return preprocess(image).to(device)


# In[ ]:


# Tokenize a single text input for CLIP
def process_text_clip(in_text):
    """Tokenize ``in_text`` with ``clip.tokenize`` and move the result to ``device``.

    Every size-1 dimension of the tokenizer output is squeezed away before the
    transfer.
    """
    tokens = clip.tokenize(in_text)
    return tokens.squeeze().to(device)

# Dataset

In [7]:
# import face_recognition
def check(lab):
    """Binarise a raw label: return 1 when ``lab == 2``, otherwise 0."""
    return 1 if lab == 2 else 0
    
import ast

# Mapping looked up by MemesDatasetAug as num_face[str(image id)].
# The first line of the file holds a dict literal. ast.literal_eval parses it without
# executing arbitrary code, unlike eval, and readline() avoids reading the whole file
# just to take its first line.
with open(r"autodl-tmp/Dataset/feature/num_face.json", "r") as f:
    num_face = ast.literal_eval(f.readline())

In [8]:
class MemesDatasetAug(torch.utils.data.Dataset):
    """Serve one meme per index as a dictionary of model inputs.

    The split is read from a JSON-lines file. Rows whose ``text`` has 20 or more
    whitespace-separated tokens are dropped, the index is renumbered from 0, and
    every ``img`` entry is rewritten to ``img_dir + '/' + img``.

    ``__getitem__`` relies on module-level state defined elsewhere in the
    notebook: ``process_image_clip``, ``process_text_clip``, ``check``,
    ``num_face``, ``all_ROI``, ``all_face``, ``all_ENT`` and ``device``.
    """

    def __init__(
            self,
            data_path,
            img_dir,
            split_flag=None,
    ):
        """
        Args:
            data_path: JSON-lines file. ``id``, ``img``, ``text`` and ``caption``
                columns are read; a ``labels`` column is optional.
            img_dir: Directory prefix prepended (with ``'/'``) to every ``img``.
            split_flag: Stored as ``self.datatype`` and forwarded to
                ``process_image_clip``.
        """
        frame = pd.read_json(data_path, lines=True)

        # Filter without mutating in place. `~(len >= 20)` keeps exactly the rows
        # that the original drop-by-index left behind.
        too_long = frame.text.str.split().str.len() >= 20
        frame = frame.loc[~too_long].reset_index(drop=True)

        # Vectorised path join; unlike a row-wise apply this also works on an empty frame.
        self.samples_frame = frame.assign(img=img_dir + '/' + frame["img"])

        self.datatype = split_flag

    def __len__(self):
        """Number of samples left after filtering."""
        return len(self.samples_frame)

    def __getitem__(self, idx):
        """Build the input dictionary for the sample at position ``idx``.

        Returns:
            dict with keys ``id``, ``image_clip_input``, ``image_loc_feature``,
            ``text_clip_input``, ``text_drob_embedding`` and ``cap_clip_input``.
            ``label`` is added only when the frame has a ``labels`` column.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        frame = self.samples_frame
        img_id = frame.loc[idx, "id"]
        # The pre-extracted feature tensors are addressed by (id - 1).
        feature_index = int(img_id) - 1

        image_clip_input = process_image_clip(frame.loc[idx, "img"], self.datatype)

        # Pre-extracted image-side feature: the ROI feature alone when num_face is 0
        # for this id, otherwise the mean over the stacked ROI and face features.
        if num_face[str(img_id)] == 0:
            image_loc_feature = all_ROI[feature_index].to(device)
        else:
            stacked = torch.vstack([all_ROI[feature_index], all_face[feature_index]])
            image_loc_feature = torch.mean(stacked, dim=0).to(device)

        text_clip_input = process_text_clip(frame.loc[idx, "text"])
        text_drob_feature = all_ENT[feature_index].to(device)
        cap_clip_input = process_text_clip(frame.loc[idx, "caption"])

        # Build the sample unconditionally so unlabeled data no longer fails with
        # UnboundLocalError at the return statement.
        sample = {
            "id": img_id,
            "image_clip_input": image_clip_input,
            "image_loc_feature": image_loc_feature,
            "text_clip_input": text_clip_input,
            "text_drob_embedding": text_drob_feature,
            "cap_clip_input": cap_clip_input,
        }

        if "labels" in frame.columns:
            # Only the first entry of `labels` is used; check() maps it to a 0/1 target.
            labels = frame.loc[idx, "labels"]
            sample["label"] = torch.tensor(check(labels[0])).to(device)

        return sample

In [9]:
# One dataset per split
dataset_train = MemesDatasetAug(train_path, data_dir, split_flag='train')
dataset_val = MemesDatasetAug(dev_path, data_dir, split_flag='val')
dataset_test = MemesDatasetAug(test_path, data_dir, split_flag='test')

# One loader per split: batches of 128, loaded in the main process (num_workers=0).
# Only the training loader shuffles.
dataloader_train = DataLoader(dataset_train, shuffle=True, batch_size=128, num_workers=0)
dataloader_val = DataLoader(dataset_val, shuffle=False, batch_size=128, num_workers=0)
dataloader_test = DataLoader(dataset_test, shuffle=False, batch_size=128, num_workers=0)

# Model

In [10]:
class MVLP(nn.Module):
    """Classifier that fuses CLIP embeddings with pre-extracted features.

    Pipeline of ``forward``:
      1. Encode image, text and caption with a private copy of ``clip_model``.
      2. Project the 4096-d image-side feature and the 768-d text-side feature to
         512-d, and the caption embedding from 512-d to 1024-d.
      3. Gate each (pre-extracted, CLIP) pair into a 1024-d vector
         (``selfatt_a`` for the image side, ``selfatt_b`` for the text side).
      4. Blend the three 1024-d vectors with ``adf`` and classify.

    Attribute names are kept stable so previously saved state dicts still load.
    """

    def __init__(self, n_out):
        """
        Args:
            n_out: Width of the final linear layer, i.e. the number of sigmoid outputs.
        """
        super(MVLP, self).__init__()

        # Projection heads, each assembled with nn.Sequential.
        # 4096-d pre-extracted image-side feature -> 512-d
        self.linear1 = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
        )

        # 768-d pre-extracted text-side feature -> 512-d
        self.linear2 = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(p=0.1),
        )

        # 512-d CLIP caption embedding -> 1024-d, matching the gated pair width
        self.linear3 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )

        # Local-weight head used by adf(). It receives a (batch, 1024, 3) tensor; the two
        # (32, 1) pooling windows shrink the 1024 axis to 32 and then to 1, leaving one
        # 3-vector per sample.
        # NOTE(review): with a 3-D input the pooling layers treat the batch axis as the
        # channel axis of an unbatched image — confirm this is intended.
        self.linear4 = nn.Sequential(
            nn.AvgPool2d((32,1)),
            nn.MaxPool2d((32,1)),
            nn.Linear(3, 3),
            nn.GELU(),
            nn.Linear(3, 3),
            nn.GELU(),
        )

        # Private copy so this model does not share weights with the module-level clip_model
        self.clip = copy.deepcopy(clip_model)

        # Query/key projections for the image-side (L1) and text-side (L2) gates
        self.gen_key_L1 = nn.Linear(512, 256)
        self.gen_query_L1 = nn.Linear(512, 256)
        self.gen_key_L2 = nn.Linear(512, 256)
        self.gen_query_L2 = nn.Linear(512, 256)

        self.soft = nn.Softmax(dim=1)
        # Global fusion logits (softmax-normalised in adf): one per modality in `w`,
        # one per fusion path (local / global) in `fw`.
        self.w = nn.Parameter(torch.ones(3))
        self.fw = nn.Parameter(torch.ones(2))

        pre_output_layers = [nn.Linear(1024, 512), nn.ReLU()]
        pre_output_layers.extend([nn.Linear(512, 128), nn.ReLU()])
        self.fc_out = nn.Sequential(*pre_output_layers)

        self.out = nn.Linear(128, n_out)

    def _pairwise_gate(self, vec1, vec2, gen_query, gen_key):
        """Weight two 512-d vectors by softmaxed cross scores and concatenate them.

        ``score1`` is the per-sample dot product of vec1's query with vec2's key and
        ``score2`` the reverse. The softmax over the two scores gives one weight
        per input.

        Returns:
            ``cat(vec1 * p1, vec2 * p2)`` along dim 1.
        """
        q1 = F.relu(gen_query(vec1))
        k1 = F.relu(gen_key(vec1))
        q2 = F.relu(gen_query(vec2))
        k2 = F.relu(gen_key(vec2))
        # bmm of (B, 1, 256) x (B, 256, 1) is a row-wise dot product -> (B, 1)
        score1 = torch.reshape(torch.bmm(q1.view(-1, 1, 256), k2.view(-1, 256, 1)), (-1, 1))
        score2 = torch.reshape(torch.bmm(q2.view(-1, 1, 256), k1.view(-1, 256, 1)), (-1, 1))
        probs = self.soft(torch.cat((score1, score2), 1).float())
        return torch.cat((vec1 * probs[:, 0:1], vec2 * probs[:, 1:2]), 1)

    def selfatt_a(self, vec1, vec2):
        """Gate ``vec1``/``vec2`` with the L1 query/key projections (image side)."""
        return self._pairwise_gate(vec1, vec2, self.gen_query_L1, self.gen_key_L1)

    def selfatt_b(self, vec1, vec2):
        """Gate ``vec1``/``vec2`` with the L2 query/key projections (text side)."""
        return self._pairwise_gate(vec1, vec2, self.gen_query_L2, self.gen_key_L2)

    # Adaptive Dynamic Fusion
    def adf(self, img_feat, text_feat, cap_feat):
        """Blend three equally shaped feature tensors with local and global weights.

        Two weighted sums are formed: one with per-sample (local) weights predicted
        by ``linear4``, one with the learned global weights ``softmax(self.w)``.
        The two sums are mixed with ``softmax(self.fw)``.
        """
        # Normalised global weights. torch.softmax equals exp(x) / sum(exp(x)) but is
        # numerically stable for large logits.
        w1, w2, w3 = torch.softmax(self.w, dim=0)
        fw1, fw2 = torch.softmax(self.fw, dim=0)

        # Local weights: sigmoid scores, L1-normalised so each row sums to 1
        out_wt = torch.stack((img_feat, text_feat, cap_feat), dim=2)
        wt2 = self.linear4(out_wt.float()).reshape(out_wt.shape[0], 3)
        wt2 = F.normalize(torch.sigmoid(wt2.float()), p=1, dim=1)
        att_wt = torch.chunk(wt2, 3, dim=1)

        # Weighted sums with local and global weights
        out_img_txt1 = img_feat * att_wt[0] + text_feat * att_wt[1] + cap_feat * att_wt[2]
        out_img_txt2 = img_feat * w1 + text_feat * w2 + cap_feat * w3

        # Final fusion output
        return fw1 * out_img_txt1 + fw2 * out_img_txt2

    def forward(self, in_CI, in_Loc, in_CT, in_Drob, in_Cap):
        """
        Args:
            in_CI: Preprocessed images for ``clip.encode_image``.
            in_Loc: Pre-extracted image-side features, last dimension 4096.
            in_CT: Token ids of the meme text for ``clip.encode_text``.
            in_Drob: Pre-extracted text-side features, last dimension 768.
            in_Cap: Token ids of the caption for ``clip.encode_text``.

        Returns:
            Sigmoid outputs produced by the ``n_out``-wide final layer.
        """
        # The CLIP outputs were already cut from the autograd graph with `.data`;
        # running the encoders under no_grad gives the same values without building
        # a graph that is thrown away immediately.
        with torch.no_grad():
            in_CI = self.clip.encode_image(in_CI).float()
            in_CT = self.clip.encode_text(in_CT).float()
            in_Cap = self.clip.encode_text(in_Cap).float()

        Loc_feat = self.linear1(in_Loc)
        Drob_feat = self.linear2(in_Drob)
        Cap_feat = self.linear3(in_Cap)

        out_img = self.selfatt_a(Loc_feat, in_CI)
        out_txt = self.selfatt_b(Drob_feat, in_CT)

        out_img_txt = self.adf(out_img, out_txt, Cap_feat)

        final_out = self.fc_out(out_img_txt)
        return torch.sigmoid(self.out(final_out))

In [11]:
# Experiment configuration
output_size = 1
exp_name = "best"
exp_path = "checkpoint"
lr = 0.001
criterion = nn.BCELoss()

model = MVLP(output_size).to(device)

# Freeze every parameter whose name contains "visual" or "logit_scale"
for name, param in model.named_parameters():
    if "visual" in name or "logit_scale" in name:
        param.requires_grad = False

# Alternative (disabled): freeze the first 14 parameters of the CLIP text transformer
# num = 0
# for name, param in model.named_parameters():
#     if "clip.transformer" in name:
#         num += 1
#         if num <= 14:
#             param.requires_grad = False

# Alternative (disabled): freeze every CLIP parameter
# for name, param in model.named_parameters():
#     if "clip" in name:
#         param.requires_grad = False

# print(model)
# Hand only the parameters that still require gradients to the optimizer
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=lr,
    weight_decay=1e-5,
)
# Halve the learning rate every 10 scheduler steps
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# EarlyStopping

In [12]:
import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # best (negated) validation loss seen so far
        self.early_stop = False   # set to True once `counter` reaches `patience`
        # np.Inf was removed in NumPy 2.0; np.inf is the spelling every version supports.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation result.

        The first call always saves a checkpoint. Later calls save one when
        ``-val_loss >= best_score + delta`` and reset the counter; otherwise the
        counter is incremented and ``early_stop`` is set once it reaches ``patience``.
        """
        # Work with the negated loss so that "higher is better".
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save ``model.state_dict()`` to ``self.path`` and remember ``val_loss`` as the new minimum.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# Train

In [13]:
def train_model(model, patience, n_epochs):
    """Train ``model`` for up to ``n_epochs`` epochs with early stopping.

    Relies on module-level objects defined elsewhere in the notebook:
    ``dataloader_train``, ``dataloader_val``, ``criterion``, ``optimizer``,
    ``scheduler``, ``device``, ``exp_path``, ``exp_name`` and ``EarlyStopping``.

    After every epoch the current weights are written to ``<exp_path>/final.pt``.
    ``EarlyStopping`` additionally keeps ``<exp_path>/checkpoint_<exp_name>.pt`` for
    the lowest validation loss seen.

    Args:
        model: Network called as ``model(image, loc, text, drob, caption)``.
        patience: Epochs without validation improvement before stopping.
        n_epochs: Maximum number of epochs.

    Returns:
        ``(model, train_acc_list, val_acc_list, train_loss_list, val_loss_list,
        last_epoch)``. ``last_epoch`` is the zero-based index of the last epoch
        that ran, or -1 when ``n_epochs`` is 0. Accuracies are percentages and
        losses are per-sample averages.
    """
    train_acc_list, val_acc_list = [], []
    train_loss_list, val_loss_list = [], []

    # initialize the experiment path
    Path(exp_path).mkdir(parents=True, exist_ok=True)
    # initialize early_stopping object
    chk_file = os.path.join(exp_path, 'checkpoint_'+exp_name+'.pt')
    early_stopping = EarlyStopping(patience=patience, verbose=True, path=chk_file)

    last_epoch = -1  # stays -1 when the loop body never runs
    for last_epoch in range(n_epochs):
        # ---------------- training pass ----------------
        model.train()
        total_acc_train = 0
        total_loss_train = 0.0
        total_train = 0
        for data in dataloader_train:
            label = data['label'].to(device)
            batch_size = label.size(0)

            model.zero_grad()
            # squeeze(-1) removes only the output dimension. A bare squeeze() would also
            # drop the batch dimension of a size-1 batch and break the shape match that
            # BCELoss requires between input and target.
            output = model(
                data['image_clip_input'],
                data['image_loc_feature'],
                data['text_clip_input'],
                data['text_drob_embedding'],
                data['cap_clip_input'],
            ).squeeze(-1)

            loss = criterion(output, label.float())
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                predicted_train = (output > 0.5).long()
                total_train += batch_size
                total_acc_train += (predicted_train == label).sum().item()
                # criterion returns a batch mean; weight it by the batch size so that
                # dividing by the sample count yields a true per-sample average.
                total_loss_train += loss.item() * batch_size

        train_acc = 100 * total_acc_train/total_train
        train_loss = total_loss_train/total_train

        # ---------------- validation pass ----------------
        model.eval()
        total_acc_val = 0
        total_loss_val = 0.0
        total_val = 0
        with torch.no_grad():
            for data in dataloader_val:
                label = data['label'].to(device)
                batch_size = label.size(0)

                output = model(
                    data['image_clip_input'],
                    data['image_loc_feature'],
                    data['text_clip_input'],
                    data['text_drob_embedding'],
                    data['cap_clip_input'],
                ).squeeze(-1)

                batch_loss = criterion(output, label.float())
                predicted_val = (output > 0.5).long()
                total_val += batch_size
                total_acc_val += (predicted_val == label).sum().item()
                total_loss_val += batch_loss.item() * batch_size
        print("Saving model...")

        torch.save(model.state_dict(), os.path.join(exp_path, "final.pt"))

        val_acc = 100 * total_acc_val/total_val
        val_loss = total_loss_val/total_val

        train_acc_list.append(train_acc)
        val_acc_list.append(val_acc)
        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

        early_stopping(val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

        print(f'Epoch {last_epoch+1}: train_loss: {train_loss:.4f} train_acc: {train_acc:.4f} | val_loss: {val_loss:.4f} val_acc: {val_acc:.4f}')
        # Report the softmax-normalised fusion weights ('w': per-feature, 'fw': final mix)
        with torch.no_grad():
            for name, p in model.named_parameters():
                if name == 'w':
                    print("特征权重: ", name)
                    w0, w1, w2 = torch.softmax(p, dim=0).tolist()
                    print("w0={} w1={} w2={}".format(w0, w1,w2))
                if name == 'fw':
                    print("最终特征权重: ", name)
                    w0, w1 = torch.softmax(p, dim=0).tolist()
                    print("fw0={} fw1={}".format(w0, w1))
                    print("")
        scheduler.step()
        torch.cuda.empty_cache()

    # Optionally reload the checkpoint with the best validation loss (disabled):
    #     model.load_state_dict(torch.load(chk_file))

    return  model, train_acc_list, val_acc_list, train_loss_list, val_loss_list, last_epoch

In [14]:
# Maximum number of training epochs
n_epochs = 20
# early stopping patience; how long to wait after last time validation loss improved.
# NOTE(review): with patience equal to n_epochs the early-stopping counter can reach at
# most n_epochs - 1, so early stopping never triggers — confirm this is intended.
patience = 20
# Training call (disabled):
# model, train_acc_list, val_acc_list, train_loss_list, val_loss_list, epoch_num = train_model(model, patience, n_epochs)

# Test

In [None]:
# Build a fresh model and restore trained weights. map_location remaps the stored
# tensors onto `device`, so a checkpoint saved on a GPU also loads on a CPU-only machine.
model_test = MVLP(output_size).to(device)
# model_test.load_state_dict(torch.load('checkpoint/checkpoint_best.pt'))
model_test.load_state_dict(
    torch.load('autodl-tmp/path_to_saved_files/checkpoint/final_8546.pt', map_location=device)
)
# model_test.load_state_dict(torch.load('checkpoint/final.pt'))
def test_model(model):
    """Evaluate ``model`` on ``dataloader_test`` and print accuracy, loss and fusion weights.

    Relies on the module-level ``dataloader_test``, ``criterion`` and ``device``.

    Args:
        model: Network called as ``model(image, loc, text, drob, caption)``.

    Returns:
        list: One NumPy array per test sample holding the raw sigmoid output(s),
        in loader order.
    """
    model.eval()
    total_acc_test = 0
    total_loss_test = 0.0
    total_test = 0
    outputs = []
    with torch.no_grad():
        for data in dataloader_test:
            label = data['label'].to(device)
            batch_size = label.size(0)

            out = model(
                data['image_clip_input'],
                data['image_loc_feature'],
                data['text_clip_input'],
                data['text_drob_embedding'],
                data['cap_clip_input'],
            )

            # Keep the un-squeezed rows so each stored element stays a 1-D array
            outputs += list(out.cpu().numpy())

            # squeeze(-1) removes only the output dimension. A bare squeeze() would also
            # drop the batch dimension of a size-1 batch and break the shape match that
            # BCELoss requires between input and target.
            out = out.squeeze(-1)
            loss = criterion(out, label.float())
            predicted_test = (out > 0.5).long()

            total_test += batch_size
            total_acc_test += (predicted_test == label).sum().item()
            # criterion returns a batch mean; weight it by the batch size so that
            # dividing by the sample count yields a true per-sample average.
            total_loss_test += loss.item() * batch_size
    print(total_acc_test)

    acc_test = total_acc_test/total_test
    loss_test = total_loss_test/total_test
    print(f'acc: {acc_test:.4f} loss: {loss_test:.4f}')
    # Report the softmax-normalised fusion weights ('w': per-feature, 'fw': final mix)
    with torch.no_grad():
        for name, p in model.named_parameters():
            if name == 'w':
                print("特征权重: ", name)
                w0, w1, w2 = torch.softmax(p, dim=0).tolist()
                print("w0={} w1={} w2={}".format(w0, w1,w2))
                print("")
            if name == 'fw':
                print("最终特征权重: ", name)
                w0, w1 = torch.softmax(p, dim=0).tolist()
                print("fw0={} fw1={}".format(w0, w1))
                print("")
    return outputs
# Run the evaluation once and keep the raw per-sample model outputs for the metric cells below
outputs = test_model(model_test)

In [None]:
# Threshold the raw scores at 0.5 into hard 0.0/1.0 predictions (same shape as the scores)
np_out = np.array(outputs)
y_pred = np.where(np_out > 0.5, 1.0, 0.0)

# Ground truth: a sample is positive (1) when the first entry of its `labels` list is 2
test_labels = [
    1 if labels[0] == 2 else 0
    for labels in test_samples_frame['labels']
]
def calculate_mmae(expected, predicted, classes):
    """Macro-averaged mean absolute error.

    The absolute error ``|expected[i] - predicted[i]|`` is averaged separately for
    each true class ``0 .. len(classes) - 1``. The per-class averages are summed and
    divided by ``len(classes)``; classes with no samples contribute 0 to the sum
    but still count in the divisor.

    Args:
        expected: Indexable sequence of true class indices.
        predicted: Indexable sequence of predictions, aligned with ``expected``.
        classes: Only its length is used, as the number of classes.
    """
    n_classes = len(classes)
    error_sum = {label: 0.0 for label in range(n_classes)}
    sample_count = {label: 0 for label in range(n_classes)}

    for pos in range(len(expected)):
        truth = expected[pos]
        error_sum[truth] += abs(truth - predicted[pos])
        sample_count[truth] += 1

    per_class_mae = [
        1.0 * error_sum[label] / sample_count[label]
        for label in range(n_classes)
        if sample_count[label] != 0
    ]
    return sum(per_class_mae, 0.0) / n_classes

In [None]:
# Test-set metrics, each rounded to 4 decimals; precision/recall/F1 are macro-averaged
acc = np.round(accuracy_score(test_labels, y_pred), decimals=4)
prec = np.round(precision_score(test_labels, y_pred, average="macro"), decimals=4)
rec = np.round(recall_score(test_labels, y_pred, average="macro"), decimals=4)
f1 = np.round(f1_score(test_labels, y_pred, average="macro"), decimals=4)
mae = np.round(mean_absolute_error(test_labels, y_pred), decimals=4)
mmae = np.round(calculate_mmae(test_labels, y_pred, [0, 1]), decimals=4)
print(classification_report(test_labels, y_pred))

# In[ ]:

# One-line summary of the test metrics, printed in the order named by the header line
print("Acc, F1,  Rec, Prec, MAE, MMAE")
print(acc, f1, rec, prec, mae, mmae)