## Import packages

In [42]:
from __future__ import print_function
import torch.utils.data
from scipy import misc
from torch import optim
from torchvision.utils import save_image
import numpy as np
import pickle
import time
import random
import os
import torch
from torch import nn
from torch.nn import functional as F
from tqdm.auto import tqdm, trange
from torchvision import transforms
import pandas as pd
import torchvision.datasets as datasets
import torch.utils.data as data
import copy
from torch.autograd import Variable
from torch.utils.data import Dataset
from skimage import io
from PIL import Image

from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
from transformers import AutoTokenizer, BertTokenizer, BertModel, BertForSequenceClassification
from collections import namedtuple
import torchvision.models as models
from sentence_transformers import SentenceTransformer
from laserembeddings import Laser

from os import listdir

In [2]:
tqdm.pandas()

## Enable gpu device

In [3]:
device = torch.device('cuda:0')

## Set random seed

In [43]:
SEED = 8899
# SEED = 8888

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## Load data

In [129]:
train_img_dir = os.path.join('../../..', 'MVAE', 'mediaeval2016', 'images_train')
test_img_dir = os.path.join('../../..', 'MVAE', 'mediaeval2016', 'images_test')

# load training label
train_df = pd.read_csv('~/adversarial_learning/MVAE/mediaeval2016/train_posts.txt', sep='\t')
# load test label
test_df = pd.read_csv('~/adversarial_learning/MVAE/mediaeval2016/test_posts.txt', sep='\t')
# test_df = pd.read_csv('~/adversarial_learning/MVAE/mediaeval2016/test_posts_attacks/test_posts_notext.txt', sep='\t')
# test_df = pd.read_csv('~/adversarial_learning/MVAE/mediaeval2016/test_posts_attacks/test_posts_all_fake.txt', sep='\t')
# test_df = pd.read_csv('~/adversarial_learning/MVAE/mediaeval2016/test_posts_charswap_rm_error_line.txt', sep='\t')

In [6]:
train_df.describe()

Unnamed: 0,post_id,user_id
count,15629.0,15629.0
mean,3.516204e+17,461083800.0
std,1.299838e+17,659292000.0
min,2.61573e+17,1117.0
25%,2.631248e+17,68196220.0
50%,2.641468e+17,244215800.0
75%,5.100905e+17,485236300.0
max,5.956985e+17,3229452000.0


In [7]:
train_df.head(5)

Unnamed: 0,post_id,post_text,user_id,image_id(s),username,timestamp,label
0,324597532548276224,Don't need feds to solve the #bostonbombing wh...,886672620,"boston_fake_03,boston_fake_35",SantaCruzShred,Wed Apr 17 18:57:37 +0000 2013,fake
1,325145334739267584,PIC: Comparison of #Boston suspect Sunil Tripa...,21992286,boston_fake_23,Oscar_Wang,Fri Apr 19 07:14:23 +0000 2013,fake
2,325152091423248385,I'm not completely convinced that it's this Su...,16428755,boston_fake_34,jamwil,Fri Apr 19 07:41:14 +0000 2013,fake
3,324554646976868352,Brutal lo que se puede conseguir en colaboraci...,303138574,"boston_fake_03,boston_fake_35",rubenson80,Wed Apr 17 16:07:12 +0000 2013,fake
4,324315545572896768,4chan and the bombing. just throwing it out th...,180460772,boston_fake_15,Slimlenny,Wed Apr 17 00:17:06 +0000 2013,fake


In [7]:
test_df.head(5)

Unnamed: 0,post_id,post_text,user_id,username,image_id(s),timestamp,label
0,651118294447951872,blank,383409700.0,AlexArtAndros,airstrikes_1,Mon Oct 05 19:34:33 +0000 2015,fake
1,651115824065830912,blank,383409700.0,AlexArtAndros,airstrikes_1,Mon Oct 05 19:24:44 +0000 2015,fake
2,651095856662360065,blank,2712310000.0,NataYaraya,airstrikes_1,Mon Oct 05 18:05:23 +0000 2015,fake
3,651086828234104832,blank,36905720.0,Alltecz,airstrikes_1,Mon Oct 05 17:29:31 +0000 2015,fake
4,651034616007106560,blank,1070959000.0,msojormsojor,airstrikes_1,Mon Oct 05 14:02:02 +0000 2015,fake


In [45]:
def return_first_image(row):
    return row['image_id(s)'].split(',')[0].strip()

In [46]:
train_df['first_image_id'] = train_df.progress_apply (lambda row: return_first_image(row),axis=1)

  0%|          | 0/15629 [00:00<?, ?it/s]

In [47]:
images_train_dataset = [i for i in train_df['first_image_id'].tolist()]
images_train_folder = [i.split('.')[0].strip() for i in os.listdir(train_img_dir)]
images_train_not_available = set(images_train_dataset)-set(images_train_folder)
images_train_not_available

{'boston_fake_35',
 'eclipse_video_01',
 'sandy_real_09',
 'sandy_real_10',
 'sandy_real_4',
 'sandy_real_6',
 'sandy_real_90',
 'syrianboy_1',
 'varoufakis_1'}

In [48]:
images_test_dataset = [i.split(',')[0].strip() for i in test_df['image_id(s)'].tolist()]
images_test_folder = [i.split('.')[0].strip() for i in listdir(test_img_dir)]
images_test_not_available = set(images_test_dataset)-set(images_test_folder)
images_test_not_available

{'airstrikes_1',
 'american_soldier_quran_1',
 'ankara_explosions_1',
 'ankara_explosions_2',
 'ankara_explosions_3',
 'attacks_paris_16',
 'attacks_paris_24',
 'attacks_paris_7',
 'boko_haram_1',
 'brussels_car_metro_1',
 'brussels_car_metro_2',
 'brussels_car_metro_3',
 'brussels_explosions_1',
 'brussels_explosions_2',
 'brussels_explosions_3',
 'convoy_explosion_turkey_1',
 'convoy_explosion_turkey_2',
 'convoy_explosion_turkey_3',
 'donald_trump_attacker_1',
 'eagle_kid_1',
 'immigrants_1',
 'immigrants_2',
 'immigrants_3',
 'immigrants_4',
 'immigrants_5',
 'immigrants_6',
 'immigrants_7',
 'immigrants_8',
 'isis_children_1',
 'isis_children_2',
 'nazi_submarine_2',
 'pope_francis_1',
 'snowboard_girl_1',
 'snowboard_girl_2',
 'syrian_children_2'}

In [49]:
images_train_not_available.add('boston_fake_10')

In [50]:
train_df = train_df[~train_df['first_image_id'].isin(images_train_not_available)]

In [130]:
test_df = test_df[~test_df['image_id(s)'].isin(images_test_not_available)]

In [13]:
train_df.shape

(13407, 8)

In [14]:
test_df.shape

(1040, 7)

In [None]:
train_df

In [52]:
# already exclude domain names that don't have image accessible 
train_domain_names_list = [img_id.split('_')[0] for img_id in train_df['first_image_id'].tolist()]
train_domain_names_set = set(train_domain_names_list)
train_domain_names_dict = dict()
domain_label_id = 0

for domain_name in train_domain_names_set:
    train_domain_names_dict[domain_name] = domain_label_id
    domain_label_id += 1

train_domain_names_dict

{'eclipse': 0,
 'columbianChemicals': 1,
 'passport': 2,
 'underwater': 3,
 'nepal': 4,
 'sandy': 5,
 'boston': 6,
 'garissa': 7,
 'samurai': 8,
 'livr': 9,
 'elephant': 10,
 'malaysia': 11,
 'bringback': 12,
 'pigFish': 13,
 'sochi': 14}

In [107]:
test_df

Unnamed: 0,post_id,post_text,user_id,username,image_id(s),timestamp,label
60,665333038944002048,"Tristesse...😢🙏 \n#Bataclan sold out, musiciens...",5.547267e+08,Louise_Officiel,attacks_paris_1,Sat Nov 14 00:58:52 +0000 2015,fake
61,665324167785410560,RT @Proyecto40: #ÚltimaHora Espectacular fotog...,2.827010e+09,MarinoCarril,attacks_paris_1,Sat Nov 14 00:23:37 +0000 2015,fake
62,665333370205765632,RT @Javivi1976: #Bataclan esta noche antes de ...,3.776037e+08,MireiaMaroto,attacks_paris_1,Sat Nov 14 01:00:11 +0000 2015,fake
63,665326188735295489,RT @Pizzigatas: El hombre tiene que establecer...,1.102841e+08,maria_cabrera,attacks_paris_1,Sat Nov 14 00:31:39 +0000 2015,fake
64,665389341141659648,🇫🇷 #Paris https://t.co/zjjRPC7USm,1.653118e+07,RaycMolina,attacks_paris_1,Sat Nov 14 04:42:35 +0000 2015,fake
...,...,...,...,...,...,...,...
2172,700662094107119616,14 children! 14 different fathers! All of them...,6.973105e+07,ErrBodyLuvsCris,woman_14_children_1,Fri Feb 19 12:43:55 +0000 2016,fake
2173,700638809155710977,Meet The Woman Who Has Given Birth To 14 Child...,4.912947e+09,EyoawansBlog,woman_14_children_1,Fri Feb 19 11:11:23 +0000 2016,fake
2174,700618091886047232,"RT @Viasat1Ghana: Woman, 36, gives birth to 14...",3.245410e+08,NanaEssilfie,woman_14_children_1,Fri Feb 19 09:49:04 +0000 2016,fake
2175,700569502874992640,Woman Breaks World Record With 14 Children fro...,4.068154e+08,Abiola_j360ent,woman_14_children_2,Fri Feb 19 06:36:00 +0000 2016,fake


## Define image transform

In [53]:
# pretrained_size = 128
# pretrained_size = 256
pretrained_size = 224
pretrained_means = [0.485, 0.456, 0.406]
pretrained_stds= [0.229, 0.224, 0.225]

# train_transforms = transforms.Compose([
#                            transforms.Resize(pretrained_size),
#                            transforms.RandomRotation(5),
#                            transforms.RandomHorizontalFlip(0.5),
#                            transforms.RandomCrop(pretrained_size, padding = 10),
#                            transforms.ToTensor(),
#                            transforms.Normalize(mean = pretrained_means, 
#                                                 std = pretrained_stds)
#                        ])

train_transforms = transforms.Compose([
                           transforms.ToTensor()
                       ])

# trial_transforms = transforms.Compose([
#                            transforms.Resize(pretrained_size),
#                            transforms.CenterCrop(pretrained_size),
#                            transforms.ToTensor(),
#                            transforms.Normalize(mean = pretrained_means, 
#                                                 std = pretrained_stds)
#                        ])

test_transforms = train_transforms

## Load pre-trained model

In [54]:
# Clip pretrained model for image encoding

clip_pretrained = SentenceTransformer('clip-ViT-B-32')

# Laser model for text encoding

laser_model = Laser()

## Self-define Dataset class

In [55]:
images = listdir(train_img_dir)
images.extend(listdir(test_img_dir))

jpg = []
png=[]
jpeg=[]
gif = []

for i in images:
    name,ext = i.split('.')[0],i.split('.')[-1]
    eval(ext).append(name)

In [56]:
def get_extension_of_file(file_name):
    if file_name in jpg:
        return '.jpg'
    elif file_name in png:
        return '.png'
    elif file_name in jpeg:
        return '.jpeg'
    else:
        return '.gif'

In [57]:
class TwitterDataset(Dataset):
    """Twitter dataset."""

    def __init__(self, df, root_dir, domain_labels_dict, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images
            transform (callable, optional): Optional transform to be applied
                to the image.
        """
        self.df = df
        self.root_dir = root_dir
        self.transform = transform
        self.domain_labels_dict = domain_labels_dict

    def __len__(self):
        return len(self.df)
    
    def read_and_process_image(self, list_of_images):
        X = [] 
        for image in tqdm(list_of_images):
            X.append(cv2.resize(cv2.imread(image, cv2.IMREAD_COLOR), (length,width), interpolation=cv2.INTER_CUBIC))  
        return X
    
    def get_news_domain(self, idx):
        img_id = self.df.iloc[idx, -1]
        news_domain = img_id.split('_')[0]
        return news_domain
    
    def get_extension_of_file(self, file_name):
        if file_name in jpg:
            return '.jpg'
        elif file_name in png:
            return '.png'
        elif file_name in jpeg:
            return '.jpeg'
        else:
            return '.gif'

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        img_name = os.path.join(self.root_dir,
                                self.df.iloc[idx, -1]) + self.get_extension_of_file(self.df.iloc[idx, -1])  # "image_id(s)"
        news_image = Image.open(img_name)#.convert("RGB")   # convert to RGB is important
        news_image = torch.Tensor(clip_pretrained.encode(news_image))
        labels = self.df.iloc[idx, -2]   # label
        
        if labels == 'real':
            labels = 0
        else:
            labels = 1
            
        labels = np.array(labels)
        labels = labels.astype('long')
        
        news_text = self.df.iloc[idx, 1]   # post_text
        news_text = torch.Tensor(laser_model.embed_sentences(news_text, lang='en'))
        
        news_domain = self.get_news_domain(idx)
        news_domain_label = self.domain_labels_dict[news_domain]
        
        news_domain_label = np.array(news_domain_label)
        news_domain_label = news_domain_label.astype('long')
        
        sample = {'news_image': news_image, 'labels': labels, 'news_text': news_text, 'news_domain_label': news_domain_label}

        if self.transform:
            sample['news_image'] = self.transform(news_image)

        return sample

In [58]:
class TwitterTestset(Dataset):
    """Twitter dataset testset."""

    def __init__(self, df, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.df = df
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)
    
    def read_and_process_image(self, list_of_images):
        X = [] 
        for image in tqdm(list_of_images):
            X.append(cv2.resize(cv2.imread(image, cv2.IMREAD_COLOR), (length,width), interpolation=cv2.INTER_CUBIC))  
        return X
    
    def get_extension_of_file(self, file_name):
        if file_name in jpg:
            return '.jpg'
        elif file_name in png:
            return '.png'
        elif file_name in jpeg:
            return '.jpeg'
        else:
            return '.gif'

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.df.iloc[idx, 4]) + self.get_extension_of_file(self.df.iloc[idx, 4])
        news_image = Image.open(img_name)#.convert("RGB")   # convert to RGB is important
        news_image = torch.Tensor(clip_pretrained.encode(news_image))
        labels = self.df.iloc[idx, -1]   # label
        
        if labels == 'real':
            labels = 0
        else:
            labels = 1
            
        labels = np.array(labels)
        labels = labels.astype('long')
        
        news_text = self.df.iloc[idx, 1]   # post_text
        news_text = torch.Tensor(laser_model.embed_sentences(news_text, lang='en'))
        
        
        sample = {'news_image': news_image, 'news_text': news_text, 'labels': labels}

        if self.transform:
            sample['news_image'] = self.transform(news_image)

        return sample

## Instantiate the train and trial dataset

In [59]:
root_dir = '~/adversarial_learning/MVAE/mediaeval2016/'
# train_data = MAMIDataset(train_root_dir + 'training.csv', train_root_dir, train_transforms)
train_data = TwitterDataset(train_df, train_img_dir, train_domain_names_dict)

In [131]:
# train_data = MAMIDataset(train_root_dir + 'training.csv', train_root_dir, train_transforms)
test_data = TwitterTestset(test_df, test_img_dir)

## Train valid split

In [28]:
# VALID_RATIO = 0.9

# n_train_examples = int(len(train_data) * VALID_RATIO)
# n_valid_examples = len(train_data) - n_train_examples

# train_data, valid_data = data.random_split(train_data, 
#                                            [n_train_examples, n_valid_examples])

In [29]:
# valid_data = copy.deepcopy(valid_data)
# valid_data.dataset.transform = trial_transforms

In [61]:
print(f'Number of training examples: {len(train_data)}')
# print(f'Number of validation examples: {len(valid_data)}')
print(f'Number of testing examples: {len(test_data)}')

Number of training examples: 13407
Number of testing examples: 1059


## Create batch iterators

In [62]:
BATCH_SIZE = 64

In [63]:
train_iterator = data.DataLoader(train_data, 
                                 shuffle = True, 
                                 batch_size = BATCH_SIZE)

# valid_iterator = data.DataLoader(valid_data, 
#                                  batch_size = BATCH_SIZE)

In [132]:
test_iterator = data.DataLoader(test_data, 
                                batch_size = BATCH_SIZE)

## Construct VAE class

In [65]:
class GateModule(nn.Module):
    
    '''
    A simple gate implemented with
    - a linear layer
    - a sigmoid layer
    '''
    
    def __init__(self, in_dim):
        super(GateModule, self).__init__()
        
        self.in_dim = in_dim
        self.gate = nn.Linear(in_dim, 1)
        self.prob = nn.Sigmoid()
        
        
    def forward(self, x):
        score = self.gate(x)
        score = self.prob(score)
        
        return score

In [66]:
class VAE(nn.Module):
    
    '''
    Gate is applied only to the text input
    '''
    
    def __init__(self, zsize, output_dim=2):
        super(VAE, self).__init__()
        
        self.zsize = zsize
        self.fc1 = nn.Linear(zsize, zsize)   # 4 * 4 is the current size of the image
        self.fc2 = nn.Linear(zsize, zsize)

        ######
        # multi-tasks sub-networks
        self.fc_news = nn.Linear(zsize, output_dim)
        self.fc_domain = nn.Linear(zsize, 15)
        
        
        # encoder layers
        self.enc_txt_fc = nn.Linear(1024, int(0.5 * zsize))
        self.enc_img_fc1 = nn.Linear(512, int(0.5 * zsize))
#         self.enc_img_fc2 = nn.Linear(1024, int(0.5 * zsize))
        
        # decoder layers
        self.dec_txt_fc = nn.Linear(zsize, 1024)
        self.dec_img_fc1 = nn.Linear(zsize, 512)
#         self.dec_img_fc2 = nn.Linear(1024, 2048)

        # batch normalizations
        self.enc_txt_bn = nn.BatchNorm1d(num_features=int(0.5 * zsize))
        self.enc_img_bn1 = nn.BatchNorm1d(num_features=int(0.5 * zsize))
#         self.enc_img_bn2 = nn.BatchNorm1d(num_features=int(0.5 * zsize))
        
        self.dec_txt_bn = nn.BatchNorm1d(num_features=1024)
        self.dec_img_bn1 = nn.BatchNorm1d(num_features=512)
#         self.dec_img_bn2 = nn.BatchNorm1d(num_features=2048)
        
        # dropout
        self.dropout_txt_enc = nn.Dropout(0.2)
        self.dropout_img_enc = nn.Dropout(0.2)
        self.dropout_txt_dec = nn.Dropout(0.2)
        self.dropout_img_dec = nn.Dropout(0.2)
        
        # gated unit
        self.gated_unit = GateModule(1024)  # text embedding dimension
        
        
    def img_encode(self, x_img):
#         _, x_img = self.resnet_pretrained(x_img)
        x_img = F.relu(self.dropout_img_enc(self.enc_img_bn1(self.enc_img_fc1(x_img))))
#         x_img = F.relu(self.enc_img_fc2(x_img))
        
        return x_img   # [bs, 2048]

    def txt_encode(self, x_txt):

        gate = self.gated_unit(x_txt)
        x_txt = x_txt * gate
        
        x_txt = x_txt.view(x_txt.shape[0], 1024)
        x_txt = F.relu(self.dropout_txt_enc(self.enc_txt_bn(self.enc_txt_fc(x_txt))))
        return x_txt, gate   # [bs, 0.5 * zsize]

    def encode(self, x_img, x_txt):
        
        x_img = self.img_encode(x_img)
        
        x_txt, gate = self.txt_encode(x_txt)
        
        # concate x_img and x_txt
        x = torch.cat((x_txt, x_img), 1)
        
        h1 = self.fc1(x)   # mu
        h2 = self.fc2(x)   # logvar
        return h1, h2, gate
    
    def subtask_news(self, z):
        
        h = self.fc_news(z)
        return h
    
    def subtask_domain(self, z):
        
        h = self.fc_domain(z)
        return h

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, x):
#         x = x.view(x.shape[0], self.zsize)   # flatten

        # Decoding txt
        dec_x_txt = F.relu(self.dropout_txt_dec(self.dec_txt_bn(self.dec_txt_fc(x))))
        
        # Decoding img
        dec_x_img = F.relu(self.dropout_img_dec(self.dec_img_bn1(self.dec_img_fc1(x))))
#         dec_x_img = F.relu(self.dec_img_fc2(dec_x_img))
        
        return dec_x_img, dec_x_txt

    def forward(self, x_img, x_txt):
        mu, logvar, gate = self.encode(x_img, x_txt)
        mu = mu.squeeze()
        logvar = logvar.squeeze()
        z = self.reparameterize(mu, logvar)

        y_news = self.subtask_news(z)
        y_domain = self.subtask_domain(z)
        
        y_pred = dict()
        y_pred["news"] = y_news
        y_pred["domain"] = y_domain
        
        dec_x_img, dec_x_txt = self.decode(z.view(-1, self.zsize))
        
        return dec_x_img, dec_x_txt, mu, logvar, y_pred, gate

    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)


In [119]:
class VAE(nn.Module):
    
    '''
    Gate is applied to both the text input & image input
    '''
    
    def __init__(self, zsize, output_dim=2):
        super(VAE, self).__init__()
        
        self.zsize = zsize
        self.fc1 = nn.Linear(zsize, zsize)   # 4 * 4 is the current size of the image
        self.fc2 = nn.Linear(zsize, zsize)

        ######
        # multi-tasks sub-networks
        self.fc_news = nn.Linear(zsize, output_dim)
        self.fc_domain = nn.Linear(zsize, 15)
        
        
        # encoder layers
        self.enc_txt_fc = nn.Linear(1024, int(0.5 * zsize))
        self.enc_img_fc1 = nn.Linear(512, int(0.5 * zsize))
#         self.enc_img_fc2 = nn.Linear(1024, int(0.5 * zsize))
        
        # decoder layers
        self.dec_txt_fc = nn.Linear(zsize, 1024)
        self.dec_img_fc1 = nn.Linear(zsize, 512)
#         self.dec_img_fc2 = nn.Linear(1024, 2048)

        # batch normalizations
        self.enc_txt_bn = nn.BatchNorm1d(num_features=int(0.5 * zsize))
        self.enc_img_bn1 = nn.BatchNorm1d(num_features=int(0.5 * zsize))
#         self.enc_img_bn2 = nn.BatchNorm1d(num_features=int(0.5 * zsize))
        
        self.dec_txt_bn = nn.BatchNorm1d(num_features=1024)
        self.dec_img_bn1 = nn.BatchNorm1d(num_features=512)
#         self.dec_img_bn2 = nn.BatchNorm1d(num_features=2048)
        
        # dropout
        self.dropout_txt_enc = nn.Dropout(0.2)
        self.dropout_img_enc = nn.Dropout(0.2)
        self.dropout_txt_dec = nn.Dropout(0.2)
        self.dropout_img_dec = nn.Dropout(0.2)
        
        # gated unit
        self.gate = GateModule(zsize)  # text embedding dimension
        
        
    def img_encode(self, x_img):
#         _, x_img = self.resnet_pretrained(x_img)
        x_img = F.relu(self.dropout_img_enc(self.enc_img_bn1(self.enc_img_fc1(x_img))))
#         x_img = F.relu(self.enc_img_fc2(x_img))
        
        return x_img   # [bs, 2048]

    def txt_encode(self, x_txt):

        x_txt = x_txt.view(x_txt.shape[0], 1024)
        x_txt = F.relu(self.dropout_txt_enc(self.enc_txt_bn(self.enc_txt_fc(x_txt))))
        return x_txt  # [bs, 0.5 * zsize]

    def encode(self, x_img, x_txt):
        
        x_img = self.img_encode(x_img)
        
        x_txt = self.txt_encode(x_txt)
        
        # concate x_img and x_txt
        x = torch.cat((x_txt, x_img), 1)
        
        p = self.gate(x)
        
        x = torch.cat((p * x_txt, (1 - p) * x_img), 1)
        
        h1 = self.fc1(x)   # mu
        h2 = self.fc2(x)   # logvar
        return h1, h2, p
    
    def subtask_news(self, z):
        
        h = self.fc_news(z)
        return h
    
    def subtask_domain(self, z):
        
        h = self.fc_domain(z)
        return h

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, x):
#         x = x.view(x.shape[0], self.zsize)   # flatten

        # Decoding txt
        dec_x_txt = F.relu(self.dropout_txt_dec(self.dec_txt_bn(self.dec_txt_fc(x))))
        
        # Decoding img
        dec_x_img = F.relu(self.dropout_img_dec(self.dec_img_bn1(self.dec_img_fc1(x))))
#         dec_x_img = F.relu(self.dec_img_fc2(dec_x_img))
        
        return dec_x_img, dec_x_txt

    def forward(self, x_img, x_txt):
        mu, logvar, p = self.encode(x_img, x_txt)
        mu = mu.squeeze()
        logvar = logvar.squeeze()
        z = self.reparameterize(mu, logvar)

        y_news = self.subtask_news(z)
        y_domain = self.subtask_domain(z)
        
        y_pred = dict()
        y_pred["news"] = y_news
        y_pred["domain"] = y_domain
        
        dec_x_img, dec_x_txt = self.decode(z.view(-1, self.zsize))
        
        return dec_x_img, dec_x_txt, mu, logvar, y_pred, p

    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)


In [67]:
def normal_init(m, mean, std):
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

# Training

In [68]:
def loss_function(recon_x_img, recon_x_txt, x_img, x_txt, mu, logvar):
    BCE_img = torch.mean((recon_x_img - x_img)**2)
    BCE_txt = torch.mean((recon_x_txt - x_txt)**2)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.mean(torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), 1))
    return BCE_img, BCE_txt, KLD * 0.1

In [69]:
f = open("GatedVAE-uni-all-fake.txt", "w")

In [70]:
def main():
    
    name_dict = dict()
    name_dict["news"] = 0
    name_dict["domain"] = 1
    
    #batch_size = 32
    z_size = 512
#     z_size = 1024
    vae = VAE(z_size)
    vae.cuda()
    vae.train()
    vae.weight_init(mean=0, std=0.02)

    lr = 0.0001

    vae_optimizer = optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
    
    criterion = nn.CrossEntropyLoss()
    criterion.to(device)
 
    train_epoch = 30

    
    dataloader = train_iterator
    
    f1_max = 0
    max_acc = 0
    
    for epoch in range(train_epoch):
        vae.train()

        rec_txt_loss = 0
        rec_img_loss = 0
        kl_loss = 0
        subtask_news_loss = 0
        subtask_domain_loss = 0

        epoch_start_time = time.time()

        if (epoch + 1) % 8 == 0:
            vae_optimizer.param_groups[0]['lr'] /= 4
#             print("learning rate change!")
            f.write("learning rate change! The learning rate is %1.4f now\n" % (lr))

#         i = 0
        acc = 0
        num = 0
        for i, data in tqdm(enumerate(dataloader, 0), desc='iterations'):
        #for x in batches:
            vae.train()
            
            #inputs, classes = data
            img_inputs = data['news_image']
            img_inputs = img_inputs.to(device)

            txt_inputs = data["news_text"]
            txt_inputs = txt_inputs.to(device)
            
            classes = data['labels']
            
            domain_classes = data['news_domain_label']
            
            # multi-task labels
            classes_news = classes
            classes_domain = domain_classes
            
            img_inputs, txt_inputs, classes_news = Variable(img_inputs), Variable(txt_inputs), Variable(classes_news)
            classes_domain = Variable(classes_domain)
        
            img_inputs = img_inputs.to(device)
            txt_inputs = txt_inputs.to(device)
            classes_news = classes_news.to(device)
            classes_domain = classes_domain.to(device)
            
            vae.zero_grad()
            rec_img, rec_txt, mu, logvar, y_pred, p = vae(img_inputs, txt_inputs)
#             if i == 0:
#                 print(p[:10])

            loss_re_img, loss_re_txt, loss_kl = loss_function(rec_img, rec_txt, img_inputs, txt_inputs, mu, logvar)
            loss_subtask_news = criterion(y_pred["news"], classes_news)
            loss_subtask_domain = criterion(y_pred["domain"], classes_domain)

            
            (loss_re_img + loss_re_txt + loss_kl + loss_subtask_news + loss_subtask_domain).backward() #\
#              + loss_subtask_shaming + loss_subtask_stereotype + loss_subtask_objectification\
#              + loss_subtask_violence).backward()
            
            vae_optimizer.step()
            rec_img_loss += loss_re_img.item()
            rec_txt_loss += loss_re_txt.item()
            
            kl_loss += loss_kl.item()
            subtask_news_loss += loss_subtask_news.item()
            subtask_domain_loss += loss_subtask_domain.item()
            
            # Calculate batch accuracy
            _, top_pred = y_pred["news"].topk(1, 1)
            y = classes_news.cpu()
            batch_size = y.shape[0]
            top_pred = top_pred.cpu().view(batch_size)
            acc += sum(top_pred == y).item()
            num += batch_size

            #############################################

            epoch_end_time = time.time()
            per_epoch_ptime = epoch_end_time - epoch_start_time

            # report losses and save samples each 60 iterations
            m = len(dataloader)
            i += 1
            if i % m == 0:
                rec_txt_loss /= m
                rec_img_loss /= m
                kl_loss /= m
                subtask_news_loss /= m
                subtask_domain_loss /= m
                
#                 print('\n[%d/%d] - ptime: %.2f, rec img loss: %.9f, rec txt loss: %.9f, KL loss: %.9f, misogynous loss: %.9f, shaming loss: %.9f, stereotype loss: %.9f, objectification loss: %.9f, violence loss: %.9f' % (
#                     (epoch + 1), train_epoch, per_epoch_ptime, rec_img_loss, rec_txt_loss, kl_loss, subtask_misogynous_loss, subtask_shaming_loss, subtask_stereotype_loss, subtask_objectification_loss, subtask_violence_loss))

#                 f.write('\n[%d/%d] - ptime: %.2f, rec img loss: %.9f, rec txt loss: %.9f, KL loss: %.9f, misogynous loss: %.9f, shaming loss: %.9f, stereotype loss: %.9f, objectification loss: %.9f, violence loss: %.9f\n' % (
#                     (epoch + 1), train_epoch, per_epoch_ptime, rec_img_loss, rec_txt_loss, kl_loss, subtask_misogynous_loss, subtask_shaming_loss, subtask_stereotype_loss, subtask_objectification_loss, subtask_violence_loss))
                f.write('\n[%d/%d] - ptime: %.2f, rec img loss: %.9f, rec txt loss: %.9f, KL loss: %.9f, news loss: %.9f, domain loss: %.9f\n' % (
                    (epoch + 1), train_epoch, per_epoch_ptime, rec_img_loss, rec_txt_loss, kl_loss, subtask_news_loss, subtask_domain_loss))
                rec_txt_loss = 0
                rec_img_loss = 0
                kl_loss = 0
                with torch.no_grad():
#                     test_loss, test_acc, test_accuracy, test_f1, test_recall, test_precision = evaluate(vae, valid_iterator, criterion, device, "misogynous")
                    test_loss, test_acc, test_accuracy, test_f1, test_recall, test_precision = evaluate(vae, test_iterator, criterion, device)
                    f.write(f'Test subtask news loss: {test_loss:.3f} | Test Acc @1: {test_acc*100:6.2f}%\n')
                    f.write(f'Test subtask news accuracy: {test_accuracy*100:6.2f}%\n')
                    f.write(f'Test subtask news f1: {test_f1*100:6.2f}%\n')
                    f.write(f'Test subtask news recall: {test_recall*100:6.2f}%\n')
                    f.write(f'Test subtask news precision: {test_precision*100:6.2f}%\n')
                    
                    
                    acc /= num
                    print(f'num_correct: {acc}\n')
                    print(f'total_num: {num}\n')
                    f.write(f'Training accuracy: {acc*100:6.2f}%\n')
                    
                    if test_f1*100 >= f1_max:
                        
                        torch.save(vae.state_dict(), "GatedVAE-uni-all-fake-%d.pkl" % (epoch+1))
                        f.write("Epoch [%d/%d]: test f1 on news improves from %.2f to %.2f, saving training results\n" % (epoch+1, train_epoch, f1_max, test_f1))
                        f1_max = test_f1*100

        f.flush()

    f.write("Training finish!... save training results\n")
    return vae

In [71]:
def calculate_accuracy(y_pred, y):
    with torch.no_grad():
        batch_size = y.shape[0]
        _, top_pred = y_pred.topk(1, 1)
        top_pred = top_pred.t()
        correct = top_pred.eq(y.view(1, -1).expand_as(top_pred))
        correct_1 = correct[:1].reshape(-1).float().sum(0, keepdim = True)
        acc_1 = correct_1 / batch_size
    
    top_pred = top_pred.cpu().view(batch_size)
    y = y.cpu()
    
    accuracy = accuracy_score(y, top_pred)
    #print("accuracy: {}".format(accuracy))

    f1 = f1_score(y, top_pred)
#     print(top_pred)
    #print("f1: {}".format(f1))

    recall = recall_score(y, top_pred)
    #print("recall: {}".format(recall))

    precision = precision_score(y, top_pred)
    #print("precision: {}".format(precision))

    cm = confusion_matrix(y, top_pred)
    #print("cm: {}".format(cm))
    return acc_1, accuracy, f1, recall, precision, cm

In [72]:
def evaluate(model, iterator, criterion, device, subtask_name="news"):
    
    epoch_loss = 0
    
    model.eval()
    
    y_true_all = []
    y_pred_all = []
    
    with torch.no_grad():
        
        for i, data in tqdm(enumerate(iterator, 0), desc='iterations'):

            x_img = data['news_image']
            x_img = x_img.to(device)
            
            x_txt = data['news_text']
            x_txt = x_txt.to(device)
            
            
            y = data['labels']
            batch_size = y.shape[0]
            
            x_img, x_txt = x_img.to(device), x_txt.to(device)

            _, _, _, _, y_pred, gate = model(x_img, x_txt)
                
                
            y_true_all += y.tolist()
            _, top_pred = y_pred["news"].topk(1, 1)
            top_pred = top_pred.t()
            top_pred = top_pred.cpu().view(batch_size)
            y_pred_all += top_pred.tolist()
            
            y = y.to(device)
            loss = criterion(y_pred["news"], y)
            epoch_loss += loss.item()
                
                
    epoch_accuracy = accuracy_score(y_true_all, y_pred_all)
    #print("accuracy: {}".format(accuracy))

    epoch_f1 = f1_score(y_true_all, y_pred_all)
#     print(top_pred)
    #print("f1: {}".format(f1))

    epoch_recall = recall_score(y_true_all, y_pred_all)
    #print("recall: {}".format(recall))

    epoch_precision = precision_score(y_true_all, y_pred_all)
    #print("precision: {}".format(precision))
                
        
    return epoch_loss / len(iterator), epoch_accuracy, epoch_accuracy, epoch_f1, epoch_recall, epoch_precision

In [None]:
vae = main()

f.close()

# Testing

In [74]:
def test(model, iterator, device):
    
    name_dict = dict()
    name_dict["news"] = 0
    
    y_test = dict()
    y_test["news"] = []
    
    y_true = []
    scores = []
    
    model.cuda()
    model.eval()
    
    with torch.no_grad():
        
        for i, data in tqdm(enumerate(iterator, 0), desc='iterations'):

            x_img = data['news_image']
            x_img = x_img.to(device)
            
            x_txt = data['news_text']
            label = data['labels']
            y_true += label.tolist()
            
            x_img, x_txt = x_img.to(device), x_txt.to(device)

            _, _, _, _, y_pred, gate_score = model(x_img, x_txt)
            
            scores.append(gate_score.squeeze(axis=1))
            
            
            for subtask_name, subtask_index in name_dict.items():
                subtask_y = y_pred[subtask_name].cpu()
                for dp in subtask_y:
                    if dp[0] >= dp[1]:
                        y_test[subtask_name].append(0)
                    else:
                        y_test[subtask_name].append(1)
        
    return y_test, y_true, scores

In [133]:
y_test, y_true, scores = test(best_VAE, test_iterator, device)

iterations: 0it [00:00, ?it/s]

In [None]:
from sklearn.metrics import precision_recall_fscore_support

precision_recall_fscore_support(y_true, y_pred, average=None, labels=[0, 1])