In [1]:
import os
import sys
import logging
import random
import numpy as np
import pickle
from time import strftime, localtime
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import BertTokenizer, BertModel
from sklearn import metrics
import spacy
import logging

seed = 777

logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))

transformers.logging.set_verbosity_error()

pretrained_bert_name = '/hy-tmp/models/bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_bert_name)
max_seq_len = 100

img_dir = '/hy-tmp/data/dataset_image'
train_file = '/hy-tmp/data/data-of-multimodal-sarcasm-detection/text/train.txt'
valid_file = '/hy-tmp/data/data-of-multimodal-sarcasm-detection/text/valid2.txt'
test_file = '/hy-tmp/data/data-of-multimodal-sarcasm-detection/text/test2.txt'

text_in_imgs_file = "/hy-tmp/data/str_in_images.pkl"

model_name = 'CM_BERT_TEXT_IN_IMG_TEXT'
check_point_path = '/hy-tmp/models'
log_file = f'/root/logs/{model_name}-{strftime("%y%m%d-%H%M", localtime())}.log'
result_file = f'/root/results/{model_name}_predicts.txt'
model_checkpoint = f'{check_point_path}/best_state/{model_name}'

inputs_cols = ['texts_merge', 'labels']
# inputs_cols = ['texts', 'texts_in_img', 'texts_merge', 'labels']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

bert_dim = 768
polarities_dim = 2

sp_nlp = spacy.load('en_core_web_sm')

filenames = os.listdir(img_dir)

In [2]:
class bert_Dataset(Dataset):
    def __init__(self, data_file, text_in_imgs_file):
        self.all_data = []
        with open(text_in_imgs_file,'rb') as fin:
            text_in_imgs = pickle.load(fin)
        with open(data_file,'r',encoding='utf-8') as fin:
            lines = fin.readlines()
            lines = [x.strip() for x in lines]
            for i in range(len(lines)):
                line = lines[i]
                data = eval(line)
                if 'train' in data_file:
                    img_id,text,label = data
                else:
                    img_id,text,label1,label = data
                
                filename = img_id+'.jpg'
                if filename in filenames:
                    text_in_img = text_in_imgs[img_id]
                    self.all_data.append({'img_id':str(img_id), 'text': text, 'text_in_img':text_in_img, 'label':int(label)})
                
    def __len__(self):
        return len(self.all_data)
    
    def __getitem__(self, idx):
        img_id = self.all_data[idx]['img_id']
        text = self.all_data[idx]['text']
        text_in_img = self.all_data[idx]['text_in_img']
        label = self.all_data[idx]['label']
        
        text_doc,_,text_token = get_doc(text)
        text_in_img_doc,_,text_in_img_token = get_doc(text_in_img)
        if not text_token:
            text_token = ['']
        if not text_in_img_token:
            text_in_img_token = ['']
        
        return {'img_id':img_id,
                'text_token': text_token,
                'text_in_img_token': text_in_img_token,
                'label':label,
        }

def text_to_indices(text, text_pair=None):
    if text_pair is None:
        encoded_dict = tokenizer(
                            text,                      # Sentence to encode.
                            add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                            padding = 'max_length',
                            truncation = True,
                            max_length = max_seq_len,    # Pad & truncate all sentences.
                            return_attention_mask = True,   # Construct attn. masks.
                            return_tensors = 'np',     # Return pytorch tensors.
                            return_length = True,
                            is_split_into_words = True,
                       )

    else:
        encoded_dict = tokenizer(
                        text,                      # Sentence to encode.
                        text_pair,
                        add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                        padding = 'max_length',
                        truncation = 'longest_first',
                        max_length = max_seq_len,    # Pad & truncate all sentences.
                        return_attention_mask = True,   # Construct attn. masks.
                        return_tensors = 'np',     # Return pytorch tensors.
                        return_length = True,
                        is_split_into_words = True,
                   )
    return encoded_dict

def bert_collate_fn(data):
    b_img_ids = []
    b_text_tokens = []
    b_text_in_img_tokens = []
    b_labels = []

    for item in data:
        b_img_ids.append(item['img_id'])
        b_text_tokens.append(item['text_token'])
        b_text_in_img_tokens.append(item['text_in_img_token'])
        b_labels.append(item['label'])
        
    text_encoded_dict = text_to_indices(b_text_tokens)
    text_in_img_encoded_dict = text_to_indices(b_text_in_img_tokens)
    text_merge_encoded_dict = text_to_indices(b_text_in_img_tokens, b_text_tokens)
        
    return {
            'labels': torch.tensor(b_labels),
            'texts': torch.tensor(text_encoded_dict.input_ids),
            'texts_in_img': torch.tensor(text_in_img_encoded_dict.input_ids),
            'texts_merge': torch.tensor(text_merge_encoded_dict.input_ids),
            'img_ids': b_img_ids,
            }

def get_doc(text, max_len=0):
    token_list = []
    text = text.lower().strip()
    
    document = sp_nlp(text)
    spacy_token = [str(x) for x in document]
    spacy_len = len(spacy_token)
    
    # if max_len > 0:
    #     if spacy_len > max_len:
    #         spacy_token = spacy_token[:max_len]

    s = ''
    for token in spacy_token:
        s = s + ' ' + token
    # document = sp_nlp(s)
    # spacy_token = [str(x) for x in document]
    return document, s.strip(), spacy_token

In [3]:
train_dataset = bert_Dataset(data_file=train_file, text_in_imgs_file=text_in_imgs_file)
valid_dataset = bert_Dataset(data_file=valid_file, text_in_imgs_file=text_in_imgs_file)
test_dataset = bert_Dataset(data_file=test_file, text_in_imgs_file=text_in_imgs_file)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=bert_collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, collate_fn=bert_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=bert_collate_fn)

print(train_dataset.__len__(), valid_dataset.__len__(), test_dataset.__len__())

19816 2410 2409


In [4]:
class CM_BERT_TEXT_IN_IMG_TEXT(torch.nn.Module):
    def __init__(self, pretrained_vit_name):
        super(CM_BERT_TEXT_IN_IMG_TEXT,self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_bert_name)
        self.fc = nn.Linear(bert_dim, polarities_dim)
  
    def forward(self, inputs):
        texts_merge, labels = inputs
        texts_out = self.bert(texts_merge, output_hidden_states=False)
        features = texts_out.pooler_output
        
        outputs = self.fc(features)
        
        return outputs
    
    def reset_params(self):
        nn.init.xavier_uniform_(self.fc.weight)

In [5]:
def eval_(model, data_loader, save_path=None):
    n_correct, n_total = 0, 0
    t_targets_all, t_outputs_all = None, None
    model.eval()
    
    with torch.no_grad():
        for i_batch, t_batch in enumerate(data_loader):
            t_inputs = [t_batch[col].to(device)   for col in inputs_cols]
            t_targets = t_batch['labels'].to(device)
            t_img_ids = t_batch['img_ids']
            
            t_outputs = model(t_inputs)

            n_correct += (torch.argmax(t_outputs, -1) == t_targets).sum().item()
            n_total += len(t_outputs)

            if t_targets_all is None:
                t_targets_all = t_targets
                t_outputs_all = t_outputs
                t_img_ids_all = t_img_ids
            else:
                t_targets_all = torch.cat((t_targets_all, t_targets), dim=0)
                t_outputs_all = torch.cat((t_outputs_all, t_outputs), dim=0)
                t_img_ids_all += t_img_ids
    
    if save_path:
        with open(save_path,'w',encoding='utf-8') as fout:
            img_ids_all = t_img_ids_all
            predicts_all = torch.argmax(t_outputs_all, -1).cpu().numpy().tolist()
            labels_all = t_targets_all.cpu().numpy().tolist()
            outputs_all = t_outputs_all.cpu().numpy().tolist()
            assert len(img_ids_all) == len(predicts_all) == len(labels_all) == len(outputs_all)
            
            for i in range(len(img_ids_all)):
                img_id = img_ids_all[i]
                predict = predicts_all[i]
                label = labels_all[i]
                output = outputs_all[i]
                fout.write(f'{str(img_id)} {str(predict)} {str(label)} {str(output)} \n')

    acc = n_correct / n_total
    f1 = metrics.f1_score(t_targets_all.cpu(), torch.argmax(t_outputs_all, -1).cpu())
    precision =  metrics.precision_score(t_targets_all.cpu(),torch.argmax(t_outputs_all, -1).cpu())
    recall = metrics.recall_score(t_targets_all.cpu(),torch.argmax(t_outputs_all, -1).cpu())
    return acc, f1 ,precision, recall

def train(model, train_data_loader, val_data_loader, test_data_loader):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam([{'params':model.bert.parameters(),'lr':2e-5},
                            {'params':model.fc.parameters(),'lr':1e-3} ], lr=1e-3, weight_decay=1e-5)
    global_step = 0
    max_val_acc = 0
    max_val_f1 = 0
    max_val_epoch = 0
    
    model.reset_params()
    
    for i_epoch in range(100):
        logger.info('>' * 100)
        logger.info('epoch: {}'.format(i_epoch))
        n_correct, n_total, loss_total = 0, 0, 0

        for i_batch, batch in enumerate(train_data_loader):
            model.train()
            global_step += 1

            inputs = [batch[col].to(device)   for col in inputs_cols]
            outputs = model(inputs)
            targets = batch['labels'].to(device)

            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
            n_total += len(outputs)
            loss_total += loss.item() * len(outputs)

            train_acc = n_correct / n_total
            train_loss = loss_total / n_total
            logger.info('loss: {:.4f}, acc: {:.4f}'.format(train_loss, train_acc))

            if global_step % 20 == 0:
                val_acc, val_f1,val_precision,val_recall = eval_(model, val_data_loader)
                logger.info('> max_val_f1: {:.4f}, max_val_acc: {:.4f}'.format(max_val_f1,max_val_acc))
                logger.info('> val_acc: {:.4f}, val_f1: {:.4f}, val_precision: {:.4f}, val_recall: {:.4f}'.format(val_acc,val_f1,val_precision,val_recall))

                if val_acc > max_val_acc:
                    max_val_f1 = val_f1
                    max_val_acc = val_acc
                    max_val_epoch = i_epoch
                    
                    torch.save(model.state_dict(), model_checkpoint)
                    logger.info(f'>> saved: {model_checkpoint}')

        torch.save(model.state_dict(), model_checkpoint)
        if i_epoch - max_val_epoch >= 3:
            logger.info('>> early stop.')
            break

    model.load_state_dict(torch.load(model_checkpoint))
    model = model.to(device)

    test_acc, test_f1,test_precision,test_recall = eval_(model, test_data_loader, save_path=result_file)
    
    logger.info(f"{test_acc} {test_f1} {test_precision} {test_recall}")

    return (test_acc, test_f1,test_precision,test_recall)

In [6]:
def main():
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    
    # logger.addHandler(logging.FileHandler(log_file))
    
    model = CM_BERT_TEXT_IN_IMG_TEXT(pretrained_bert_name).to(device)
    
    # train(model, train_loader, valid_loader, test_loader)
    
    model.load_state_dict(torch.load(model_checkpoint))
    model = model.to(device)
    print(eval_(model, test_loader, save_path=result_file))
    
main()

(0.8530510585305106, 0.8171487603305785, 0.8096212896622313, 0.8248175182481752)
