In [2]:
import os
import sys
import logging
import random
from time import strftime, localtime

from sklearn import metrics
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

from transformers import ViTFeatureExtractor, ViTModel

# Global random seed; main() feeds it to random / numpy / torch.
seed = 777

# Root logger at INFO level, echoed to stdout so messages show up in the cell output.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Re-running this cell must not stack a second stdout handler on the root logger,
# otherwise every message is emitted once per execution of the cell.
if not any(isinstance(handler, logging.StreamHandler)
           and getattr(handler, 'stream', None) is sys.stdout
           for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler(sys.stdout))

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Local directory holding the pretrained ViT checkpoint and its preprocessing config.
pretrained_vit_name = '/hy-tmp/models/vit-base-patch16-224'
feature_extractor = ViTFeatureExtractor.from_pretrained(pretrained_vit_name)

# Target (height, width) taken from the checkpoint's preprocessing config.
# NOTE(review): this assumes `size` is a dict with 'height'/'width' keys; an earlier
# variant of this cell used `feature_extractor.size` directly — confirm against the
# installed transformers version.
crop_size = tuple(feature_extractor.size[side] for side in ('height', 'width'))

# Normalisation statistics are read from the same preprocessing config.
normalize = Normalize(feature_extractor.image_mean, feature_extractor.image_std)

# Training pipeline: random-resized crop + horizontal flip, then tensor + normalise.
train_transforms = Compose([
    RandomResizedCrop(crop_size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Validation / test pipeline: resize + centre crop (no random ops), then tensor + normalise.
val_transforms = Compose([
    Resize(crop_size),
    CenterCrop(crop_size),
    ToTensor(),
    normalize,
])

# --- Dataset locations --------------------------------------------------------
img_dir = '/hy-tmp/data/dataset_image'
train_file, valid_file, test_file = (
    '/hy-tmp/data/data-of-multimodal-sarcasm-detection/text/train.txt',
    '/hy-tmp/data/data-of-multimodal-sarcasm-detection/text/valid2.txt',
    '/hy-tmp/data/data-of-multimodal-sarcasm-detection/text/test2.txt',
)

# --- Model dimensions ---------------------------------------------------------
vit_dim = 768        # input width of the linear classification head
polarities_dim = 2   # number of output classes

# --- Run artefacts ------------------------------------------------------------
model_name = 'CM_VIT'
check_point_path = '/hy-tmp/models'
# The log file name carries a minute-resolution timestamp taken when this cell runs.
log_file = '/root/logs/{}-{}.log'.format(model_name, strftime("%y%m%d-%H%M", localtime()))
result_file = '/root/results/{}_predicts.txt'.format(model_name)
model_checkpoint = '{}/best_state/{}'.format(check_point_path, model_name)

# Batch keys that are moved to the device and handed to the model, in this order.
inputs_cols = ['images', 'labels']

In [3]:
class vit_Dataset(Dataset):
    """Image-only dataset built from an annotation file and an image directory.

    Every non-blank line of ``data_file`` must be a Python literal sequence.
    Files whose path contains ``'train'`` hold ``(img_id, text, label)``; all
    other files hold ``(img_id, text, label1, label)`` and the last element is
    used as the label. Entries whose ``<img_id>.jpg`` is not present in
    ``img_dir`` are dropped.

    Args:
        img_dir: directory containing the ``.jpg`` images.
        data_file: path of the annotation file (read as UTF-8).
        transform: optional callable applied to the PIL image in ``__getitem__``.
    """

    def __init__(self, img_dir, data_file, transform=None):
        import ast  # local import so this block does not depend on the import cell

        self.transform = transform
        self.img_dir = img_dir
        # A set gives O(1) membership tests instead of scanning a list per sample.
        available = set(os.listdir(img_dir))
        is_train = 'train' in data_file
        self.all_data = []
        with open(data_file, 'r', encoding='utf-8') as fin:
            for raw_line in fin:
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines instead of failing to parse them.
                    continue
                # literal_eval only accepts literals, so a crafted data file
                # cannot execute code the way eval() would.
                data = ast.literal_eval(line)
                if is_train:
                    img_id, text, label = data
                else:
                    img_id, text, label1, label = data

                filename = str(img_id) + '.jpg'
                if filename in available:
                    self.all_data.append({'img_id': str(img_id),
                                          'image_file': filename,
                                          'label': int(label)})

    def __len__(self):
        """Number of annotated samples whose image file exists."""
        return len(self.all_data)

    def __getitem__(self, idx):
        """Return ``{'img_id', 'image', 'label'}`` for sample ``idx``."""
        entry = self.all_data[idx]
        # convert('RGB') guarantees three channels (grayscale / RGBA / CMYK files
        # would otherwise break the 3-channel normalisation and batch stacking);
        # the `with` block releases the file handle once the pixels are loaded.
        with Image.open(os.path.join(self.img_dir, entry['image_file'])) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        return {'img_id': entry['img_id'],
                'image': image,
                'label': entry['label'],
        }

def vit_collate_fn(data):
    """Merge a list of dataset samples into one batch dictionary.

    Args:
        data: sequence of dicts with the keys ``'img_id'``, ``'image'`` and ``'label'``.

    Returns:
        dict with ``'labels'`` (tensor built from the labels), ``'images'``
        (the image tensors stacked along a new leading dimension) and
        ``'img_ids'`` (plain list, in sample order).
    """
    img_ids = [sample['img_id'] for sample in data]
    images = [sample['image'] for sample in data]
    labels = [sample['label'] for sample in data]

    batch = {}
    batch['labels'] = torch.tensor(labels)
    batch['images'] = torch.stack(images, dim=0)
    batch['img_ids'] = img_ids
    return batch


In [4]:
# One dataset per split; only the training split uses the augmenting transforms.
train_dataset = vit_Dataset(img_dir, train_file, transform=train_transforms)
valid_dataset = vit_Dataset(img_dir, valid_file, transform=val_transforms)
test_dataset = vit_Dataset(img_dir, test_file, transform=val_transforms)

# Only the training loader shuffles; validation and test keep file order.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=vit_collate_fn,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=vit_collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=vit_collate_fn,
)

# Sanity check: number of usable samples per split.
print(len(train_dataset), len(valid_dataset), len(test_dataset))

19816 2410 2409


In [5]:
class CM_VIT(torch.nn.Module):
    """Pretrained ViT encoder followed by one linear classification layer."""

    def __init__(self, pretrained_vit_name):
        """Load the ViT weights from ``pretrained_vit_name`` and create the linear head."""
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_vit_name)
        self.fc = nn.Linear(vit_dim, polarities_dim)

    def forward(self, inputs):
        """Return the head's output for a batch.

        Args:
            inputs: two-element sequence ``(images, labels)``; the second
                element is unpacked but not used.

        Returns:
            Output of the linear head applied to the encoder's ``pooler_output``.
        """
        images, _ = inputs
        pooled = self.vit(images, output_hidden_states=False).pooler_output
        return self.fc(pooled)

    def reset_params(self):
        """Re-initialise the head's weight matrix with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.fc.weight)

In [6]:
def eval_(model, data_loader, save_path=None):
    """Evaluate ``model`` on ``data_loader`` and optionally dump per-sample results.

    Args:
        model: module called with the list of device tensors selected by
            ``inputs_cols``; its output is arg-maxed over the last dimension.
        data_loader: iterable of batch dicts providing the ``inputs_cols``
            keys plus ``'labels'`` and ``'img_ids'``.
        save_path: optional file path. When given, one line per sample is
            written as ``"<img_id> <predict> <label> <raw output list> "``.

    Returns:
        Tuple ``(accuracy, f1, precision, recall)``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    targets_chunks, outputs_chunks = [], []
    img_ids_all = []  # fresh list: never aliases (and mutates) a batch's own id list
    model.eval()

    with torch.no_grad():
        for t_batch in data_loader:
            t_inputs = [t_batch[col].to(device) for col in inputs_cols]
            t_outputs = model(t_inputs)

            # Keep per-batch results on the CPU and concatenate once after the
            # loop, instead of re-concatenating a growing tensor every batch.
            targets_chunks.append(t_batch['labels'].cpu())
            outputs_chunks.append(t_outputs.cpu())
            img_ids_all.extend(t_batch['img_ids'])

    if not outputs_chunks:
        raise ValueError('eval_ received an empty data_loader; nothing to evaluate')

    targets_all = torch.cat(targets_chunks, dim=0)
    outputs_all = torch.cat(outputs_chunks, dim=0)
    predicts_all = torch.argmax(outputs_all, -1)  # computed once, reused below

    if save_path:
        # Create the output directory on demand so the write cannot fail on a
        # missing folder.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        predicts_list = predicts_all.numpy().tolist()
        labels_list = targets_all.numpy().tolist()
        outputs_list = outputs_all.numpy().tolist()
        assert len(img_ids_all) == len(predicts_list) == len(labels_list) == len(outputs_list)

        with open(save_path, 'w', encoding='utf-8') as fout:
            for img_id, predict, label, output in zip(img_ids_all, predicts_list,
                                                      labels_list, outputs_list):
                fout.write(f'{str(img_id)} {str(predict)} {str(label)} {str(output)} \n')

    y_true = targets_all.numpy()
    y_pred = predicts_all.numpy()
    acc = (predicts_all == targets_all).sum().item() / len(outputs_all)
    f1 = metrics.f1_score(y_true, y_pred)
    precision = metrics.precision_score(y_true, y_pred)
    recall = metrics.recall_score(y_true, y_pred)
    return acc, f1, precision, recall

def train(model, train_data_loader, val_data_loader, test_data_loader,
          max_epochs=100, eval_every=20, patience=3):
    """Fine-tune ``model``, keep the best-validation checkpoint, then test it.

    The model is validated every ``eval_every`` optimisation steps. Whenever the
    validation accuracy improves, the weights are written to ``model_checkpoint``.
    Training stops once ``patience`` epochs have passed since the last improvement
    (or after ``max_epochs``). The saved weights are then reloaded and evaluated
    on the test loader, with predictions written to ``result_file``.

    Args:
        model: network exposing ``vit``, ``fc`` and ``reset_params()``.
        train_data_loader: loader used for optimisation.
        val_data_loader: loader used for checkpoint selection.
        test_data_loader: loader used for the final evaluation.
        max_epochs: upper bound on training epochs (default 100).
        eval_every: validate every this many optimisation steps (default 20).
        patience: epochs without improvement before early stopping (default 3).

    Returns:
        Tuple ``(test_acc, test_f1, test_precision, test_recall)``.
    """
    criterion = nn.CrossEntropyLoss()
    # Small learning rate for the pretrained encoder, larger one for the new head.
    optimizer = torch.optim.Adam(
        [{'params': model.vit.parameters(), 'lr': 2e-5},
         {'params': model.fc.parameters(), 'lr': 1e-3}],
        lr=1e-3, weight_decay=1e-5)
    global_step = 0
    max_val_acc = 0
    max_val_f1 = 0
    max_val_epoch = 0
    checkpoint_saved = False

    # Make sure the checkpoint directory exists before the first save.
    checkpoint_dir = os.path.dirname(model_checkpoint)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    model.reset_params()

    for i_epoch in range(max_epochs):
        logger.info('>' * 100)
        logger.info('epoch: {}'.format(i_epoch))
        n_correct, n_total, loss_total = 0, 0, 0

        for i_batch, batch in enumerate(train_data_loader):
            # eval_() leaves the model in eval mode, so re-enable train mode each step.
            model.train()
            global_step += 1

            inputs = [batch[col].to(device) for col in inputs_cols]
            outputs = model(inputs)
            targets = batch['labels'].to(device)

            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
            n_total += len(outputs)
            loss_total += loss.item() * len(outputs)

            train_acc = n_correct / n_total
            train_loss = loss_total / n_total
            logger.info('loss: {:.4f}, acc: {:.4f}'.format(train_loss, train_acc))

            if global_step % eval_every == 0:
                val_acc, val_f1, val_precision, val_recall = eval_(model, val_data_loader)
                logger.info('> max_val_f1: {:.4f}, max_val_acc: {:.4f}'.format(max_val_f1, max_val_acc))
                logger.info('> val_acc: {:.4f}, val_f1: {:.4f}, val_precision: {:.4f}, val_recall: {:.4f}'.format(val_acc, val_f1, val_precision, val_recall))

                if val_acc > max_val_acc:
                    max_val_f1 = val_f1
                    max_val_acc = val_acc
                    max_val_epoch = i_epoch

                    torch.save(model.state_dict(), model_checkpoint)
                    checkpoint_saved = True
                    logger.info(f'>> saved: {model_checkpoint}')

        # No unconditional save here: writing the end-of-epoch weights to the same
        # path would overwrite the best-validation checkpoint that is reloaded below.
        if i_epoch - max_val_epoch >= patience:
            logger.info('>> early stop.')
            break

    if checkpoint_saved:
        # map_location keeps the load working when the checkpoint was written on
        # a different device type than the current one.
        model.load_state_dict(torch.load(model_checkpoint, map_location=device))
    else:
        logger.info('>> no validation improvement was recorded; testing the current weights.')
    model = model.to(device)

    test_acc, test_f1, test_precision, test_recall = eval_(model, test_data_loader, save_path=result_file)

    logger.info(f"{test_acc} {test_f1} {test_precision} {test_recall}")

    return (test_acc, test_f1, test_precision, test_recall)

In [9]:
def main(do_train=False):
    """Seed every RNG, build the model, optionally train, then score the test split.

    Args:
        do_train: when True, run ``train`` before the final evaluation. The
            default (False) only loads the weights stored at
            ``model_checkpoint`` and evaluates them.

    The metrics tuple returned by ``eval_`` is printed, and per-sample
    predictions are written to ``result_file``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE(review): PYTHONHASHSEED is presumably read at interpreter start-up, so
    # setting it here likely only affects child processes — confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Optional: also write log records to a file.
    # logger.addHandler(logging.FileHandler(log_file))

    model = CM_VIT(pretrained_vit_name).to(device)

    if do_train:
        train(model, train_loader, valid_loader, test_loader)

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host,
    # matching the CPU fallback used when `device` is chosen.
    model.load_state_dict(torch.load(model_checkpoint, map_location=device))
    model = model.to(device)
    print(eval_(model, test_loader, save_path=result_file))


main()

Some weights of the model checkpoint at /hy-tmp/models/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at /hy-tmp/models/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


(0.7231216272312163, 0.65847414234511, 0.6468812877263581, 0.670490093847758)
