## Import packages

In [1]:
from __future__ import print_function
import torch.utils.data
from scipy import misc
from torch import optim
from torchvision.utils import save_image
import numpy as np
import pickle
import time
import random
import os
import torch
from torch import nn
from torch.nn import functional as F
from tqdm.auto import tqdm, trange
from torchvision import transforms
import pandas as pd
import torchvision.datasets as datasets
import torch.utils.data as data
import copy
from torch.autograd import Variable
from torch.utils.data import Dataset
from skimage import io
from PIL import Image

from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
from transformers import AutoTokenizer, BertTokenizer, BertModel, BertForSequenceClassification
from collections import namedtuple
import torchvision.models as models
from sentence_transformers import SentenceTransformer
from laserembeddings import Laser

from os import listdir

In [2]:
# Registers tqdm's `progress_apply`/`progress_map` on pandas objects.
# NOTE(review): no progress_* call appears in this notebook; this cell may be removable.
tqdm.pandas()

## Select the GPU device

In [3]:
# Use the first GPU when one is available; otherwise fall back to CPU so the
# notebook does not crash on import-time device selection.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Set random seed

In [4]:
SEED = 8888  # single seed for every RNG below

# Seed Python's `random`, NumPy's legacy global RNG (used by np.random.shuffle
# further down) and torch on CPU and all visible GPUs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Ask cuDNN for deterministic algorithms (can be slower).
torch.backends.cudnn.deterministic = True

## Load data

In [5]:
# Raw Weibo tweet files: '|'-separated, no header row.
# "fake" frames come from the rumour files, "real" frames from the non-rumour files.
TWEETS_DIR = '~/adversarial_learning/weibo/tweets'


def _read_tweets(file_name):
    """Read one tweet file from TWEETS_DIR into a DataFrame with integer column labels."""
    return pd.read_csv(f'{TWEETS_DIR}/{file_name}', sep='|', header=None)


# training split
train_df_fake = _read_tweets('train_rumor.txt')
train_df_real = _read_tweets('train_nonrumor.txt')
# test split
test_df_fake = _read_tweets('test_rumor.txt')
test_df_real = _read_tweets('test_nonrumor.txt')

In [6]:
# For every non-rumour test row, scan columns 1-9 and append any cell that
# mentions 'sinaimg' (presumably an image URL split off by the '|' separator)
# to column 0, comma-separated.
# NOTE(review): not idempotent -- re-running this cell appends the URLs again.
# NOTE(review): raises if the frame has fewer than 10 columns, or if a non-NaN
# cell in columns 1-9 is not a string. The rumour frames are not processed
# this way -- confirm that is intended.
for index, row in test_df_real.iterrows():
    for col in range(1, 10):
        if not pd.isna(row[col]) and 'sinaimg' in row[col]:
            test_df_real.at[index, 0] = test_df_real.at[index, 0] + ',' + row[col]
#             row[0] = row[0] + ',' + row[col]

In [7]:
# Same URL merge as for the test non-rumour frame: any cell in columns 1-9 that
# mentions 'sinaimg' is appended to column 0, comma-separated.
# NOTE(review): not idempotent -- re-running this cell appends the URLs again.
for index, row in train_df_real.iterrows():
    for col in range(1, 10):
        if not pd.isna(row[col]) and 'sinaimg' in row[col]:
            train_df_real.at[index, 0] = train_df_real.at[index, 0] + ',' + row[col]
#             row[0] = row[0] + ',' + row[col]

In [8]:
# news text
# Column 0 of each frame as a flat list of lines; posts span several lines and
# are regrouped into 3-line blocks below.
train_fake = train_df_fake[0].tolist()
train_real = train_df_real[0].tolist()
test_fake = test_df_fake[0].tolist()
test_real = test_df_real[0].tolist()

In [9]:
# raw line count of the test rumour list (before grouping into 3-line posts)
len(test_fake)

2934

In [9]:
# ensure each news has 3 lines
def fix_offset(list_):
    """Pad a flat line list so every post occupies exactly 3 entries.

    Whenever an image-URL line (a string containing 'sinaimg.cn') is directly
    followed by a pure-digit line, a ``None`` placeholder is inserted between
    them. Downstream, posts are read as ``[?, image_urls, text]`` triples, so the
    placeholder stands in for a missing text line (presumably the digit line
    starts the next post -- confirm against the data).

    A single pass is enough: an inserted ``None`` can never create a new match.
    This replaces the original restart-from-scratch loop (O(n^2) with a deepcopy
    per pass). A URL line in the last position no longer raises ``IndexError``.

    Args:
        list_: flat list of lines (strings or None). Modified in place.

    Returns:
        The same list object, padded.
    """
    padded = []
    last = len(list_) - 1
    for i, line in enumerate(list_):
        padded.append(line)
        if i == last or not isinstance(line, str) or 'sinaimg.cn' not in line:
            continue
        nxt = list_[i + 1]
        if isinstance(nxt, str) and nxt.isdigit():
            padded.append(None)
    # Keep the original in-place contract for any caller holding a reference.
    list_[:] = padded
    return list_

In [10]:
# Pad each flat line list so every post occupies exactly 3 entries.
train_fake = fix_offset(train_fake) 
train_real = fix_offset(train_real)
test_fake = fix_offset(test_fake)
test_real = fix_offset(test_real)

In [12]:
# line count after padding (should be a multiple of 3)
len(train_fake)

11244

In [11]:
# 3 lines -> 1 news
def break_in_block(list_):
    """Split a flat list into consecutive 3-element chunks.

    The final chunk is shorter when ``len(list_)`` is not a multiple of 3.
    """
    return [list_[start:start + 3] for start in range(0, len(list_), 3)]

In [12]:
# Group each padded list into 3-line posts.
train_fake = break_in_block(train_fake)
train_real = break_in_block(train_real)
test_fake = break_in_block(test_fake)
test_real = break_in_block(test_real)

In [13]:
# number of 3-line posts per split
len(train_fake),len(train_real),len(test_fake),len(test_real)

(3748, 3783, 1000, 996)

In [13]:
def get_image_and_text_list(blocks_list):
    """Extract image-URL lists and texts from 3-entry post blocks.

    Each block is read as ``[_, image_urls, text]``, where ``image_urls`` is a
    comma-separated string of URLs (the first entry is not used here).

    Dropped blocks:
    - blocks whose text is ``None`` (the placeholder inserted by ``fix_offset``);
    - incomplete trailing blocks with fewer than 3 entries. The original read
      ``block[-1]`` for these, which returned the image line as the text (or
      raised ``IndexError`` for 1-entry blocks).

    Returns:
        image_list: one list of URL strings per kept block.
        text_list: the matching texts, aligned with ``image_list``.
    """
    kept = [block for block in blocks_list
            if len(block) == 3 and block[2] is not None]
    image_list = [block[1].split(',') for block in kept]
    text_list = [block[2] for block in kept]
    return image_list, text_list

In [14]:
# Per split: list of image-URL lists and the aligned list of post texts.
train_fake_image,train_fake_text = get_image_and_text_list(train_fake)
train_real_image,train_real_text = get_image_and_text_list(train_real)
test_fake_image,test_fake_text = get_image_and_text_list(test_fake)
test_real_image,test_real_text = get_image_and_text_list(test_real)

In [None]:
# spot-check parsed URL lists (consider a smaller slice to keep output short)
train_real_image[:100]

In [16]:
# fake training posts with text
len(train_fake_text)

3698

In [17]:
# real training posts with text
len(train_real_text)

3783

In [18]:
# fake test posts with text
len(test_fake_text)

934

In [19]:
# real test posts with text
len(test_real_text)

996

In [15]:
# Labels: 1 = fake (rumour file), 0 = real (non-rumour file).
train_fake_Y = [1]*len(train_fake_image)
train_real_Y = [0]*len(train_real_image)
test_fake_Y = [1]*len(test_fake_image)
test_real_Y = [0]*len(test_real_image)

In [16]:
# Concatenate fake then real samples; images, texts and labels stay aligned.
# (Only the training split is shuffled later.)
train_images = train_fake_image+train_real_image
train_text = train_fake_text + train_real_text
trainY = train_fake_Y+train_real_Y

test_images = test_fake_image+test_real_image
test_text = test_fake_text+test_real_text
testY = test_fake_Y+test_real_Y

In [22]:
# sanity check: images, texts and labels must have equal lengths per split
len(train_images),len(train_text),len(trainY),len(test_images),len(test_text),len(testY)

(7481, 7481, 7481, 1930, 1930, 1930)

In [17]:
# Convert to arrays for fancy indexing / np.delete. Image entries are lists of
# differing length (one per URL), hence dtype=object.
train_images = np.array(train_images, dtype=object)
train_text = np.array(train_text)
trainY = np.array(trainY)
test_images = np.array(test_images, dtype=object)
test_text = np.array(test_text)
testY = np.array(testY)

In [18]:
def index_to_delete(list_, images_dir='../../../weibo/images'):
    """Choose one usable image file per post and list the posts that have none.

    For each post (a list of image URLs), the URLs are scanned in order:
    - the first URL whose file name exists in ``images_dir`` is chosen;
    - if a URL from the hard-coded GIF blacklist is reached first, the post is
      excluded;
    - if nothing matches, the post is excluded.

    NOTE(review): the on-disk check runs before the blacklist check, so a
    blacklisted GIF that exists in ``images_dir`` is still chosen (unchanged
    from the original ordering).

    Args:
        list_: iterable of URL lists, one per post.
        images_dir: directory holding the downloaded images.

    Returns:
        index: positions in ``list_`` of excluded posts (suitable for np.delete).
        image_name_list: exactly one entry per post -- the chosen file name or
            "excluded" -- so it stays aligned with ``list_``. The original
            appended both the GIF name and "excluded" for blacklisted posts,
            which misaligned this list with the input.
    """
    # A set gives O(1) lookups (the directory listing can be large).
    available = set(listdir(images_dir))
    gif_blacklist = {'957e1cf2tw1e5foxts295g206o03p4qp.gif', 'a716fd45jw1ev0cgf8j46g209505zh4i.gif',
                     '005vnhZYgw1evupo8ttddg308w06o4qp.gif', '7da75521gw1ele2jvi85rg2096056u0x.gif'}
    index = []
    image_name_list = []
    for i, urls in enumerate(list_):
        chosen = None
        for url in urls:
            name = url.split('/')[-1]
            if name in available:
                chosen = name
                break
            if name in gif_blacklist:
                break  # blacklisted GIF reached first: exclude this post
        if chosen is None:
            image_name_list.append("excluded")
            index.append(i)
        else:
            image_name_list.append(chosen)

    return index, image_name_list

In [19]:
# Replace each URL list with one chosen image file name (or "excluded") and
# collect the indices of posts without a usable image.
train_delete_index, train_images =index_to_delete(train_images)
test_delete_index, test_images = index_to_delete(test_images)
# total number of posts that will be dropped
len(train_delete_index)+len(test_delete_index)

1558

In [None]:
# inspect which test posts are dropped
test_delete_index

In [20]:
# Drop posts without a usable image from images, texts and labels in lockstep.
train_images = np.delete(train_images,train_delete_index)
train_text = np.delete(train_text,train_delete_index)
trainY = np.delete(trainY,train_delete_index)
test_images = np.delete(test_images,test_delete_index)
test_text = np.delete(test_text,test_delete_index)
testY = np.delete(testY,test_delete_index)

In [23]:
# real news in train
(trainY==0).sum()

2807

In [26]:
# fake news in train
(trainY==1).sum()

3347

In [27]:
# real news in test
(testY==0).sum()

835

In [28]:
# fake news in test
(testY==1).sum()

864

In [21]:
# Shuffle the training split once (NumPy global RNG, seeded above) so fake and
# real posts are interleaved; one shared permutation keeps the arrays aligned.
shuffle_index= np.arange(len(train_images))
np.random.shuffle(shuffle_index)
train_images = train_images[shuffle_index]
train_text = train_text[shuffle_index]
trainY = trainY[shuffle_index]

In [25]:
# inspect the permutation applied to the training split
shuffle_index

array([1370, 4821, 4817, ..., 2120, 3299, 4483])

In [22]:
# sanity check after filtering and shuffling: lengths must still match
len(train_images),len(train_text),len(trainY),len(test_images),len(test_text),len(testY)

(6154, 6154, 6154, 1699, 1699, 1699)

In [103]:
# 1-D array of post texts
train_text.shape

(6154,)

## Define image transform

In [23]:
# pretrained_size = 128
# pretrained_size = 256
# Input size and the usual ImageNet normalisation statistics.
# NOTE(review): pretrained_size/means/stds are not used anywhere below.
pretrained_size = 224
pretrained_means = [0.485, 0.456, 0.406]
pretrained_stds= [0.229, 0.224, 0.225]

# NOTE(review): the datasets are built without `transform=`, so these transforms
# are currently unused.
train_transforms = transforms.Compose([
                           transforms.ToTensor()
                       ])

test_transforms = train_transforms

## Load pre-trained models (CLIP for images, LASER for text)

In [22]:
# Clip pretrained model for image encoding
# (used as a global by the Dataset classes' __getitem__)
clip_pretrained = SentenceTransformer('clip-ViT-B-32')

# Laser model for text encoding
# (also used as a global by the Dataset classes)
laser_model = Laser()

## Custom Dataset classes

In [23]:
class WeiboDataset(Dataset):
    """Weibo dataset.

    Items are encoded lazily in ``__getitem__``:
    - the image file goes through the global CLIP model ``clip_pretrained``;
    - the post text goes through the global LASER model ``laser_model``.
    Both models are loaded in an earlier cell.
    """

    def __init__(self, texts, images, labels, root_dir, transform=None):
        """
        Args:
            texts (np.array): train texts
            images (np.array): train image file names, relative to ``root_dir``
            labels (np.array): train labels
            root_dir (string): root directory of images
            transform (callable, optional): Optional transform applied to the
                image embedding tensor (the CLIP output) of each sample.
        """
        self.texts = texts
        self.images = images
        self.labels = labels
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return self.texts.shape[0]

    def read_and_process_image(self, list_of_images):
        # NOTE(review): unused. Depends on `cv2`, `length` and `width`, none of
        # which are defined in this notebook, so calling it raises NameError.
        X = []
        for image in tqdm(list_of_images):
            X.append(cv2.resize(cv2.imread(image, cv2.IMREAD_COLOR), (length,width), interpolation=cv2.INTER_CUBIC))
        return X

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.images[idx])
        # Encode inside `with` so the file handle is released deterministically
        # (Pillow can keep multi-frame files such as GIFs open otherwise).
        # NOTE(review): images are not converted to RGB before encoding.
        with Image.open(img_name) as img:
            image = torch.Tensor(clip_pretrained.encode(img))
        label = self.labels[idx]

        text = self.texts[idx]   # post_text
        text = torch.Tensor(laser_model.embed_sentences(text, lang='zh'))

        sample = {'news_image': image, 'label': label, 'news_text': text}

        if self.transform:
            # The original referenced an undefined `news_image` here (NameError).
            sample['news_image'] = self.transform(sample['news_image'])

        return sample

In [24]:
class WeiboTestset(Dataset):
    """Weibo dataset testset.

    Mirrors ``WeiboDataset``: images are encoded on access with the global CLIP
    model ``clip_pretrained`` and texts with the global LASER model ``laser_model``.
    """

    def __init__(self, texts, images, labels, root_dir, transform=None):
        """
        Args:
            texts (np.array): test texts
            images (np.array): test image file names, relative to ``root_dir``
            labels (np.array): test labels
            root_dir (string): root directory of images
            transform (callable, optional): Optional transform applied to the
                image embedding tensor (the CLIP output) of each sample.
        """
        self.texts = texts
        self.images = images
        self.labels = labels
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return self.texts.shape[0]

    def read_and_process_image(self, list_of_images):
        # NOTE(review): unused. Depends on `cv2`, `length` and `width`, none of
        # which are defined in this notebook, so calling it raises NameError.
        X = []
        for image in tqdm(list_of_images):
            X.append(cv2.resize(cv2.imread(image, cv2.IMREAD_COLOR), (length,width), interpolation=cv2.INTER_CUBIC))
        return X

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.images[idx])
        # Encode inside `with` so the file handle is released deterministically
        # (Pillow can keep multi-frame files such as GIFs open otherwise).
        # NOTE(review): images are not converted to RGB before encoding.
        with Image.open(img_name) as img:
            image = torch.Tensor(clip_pretrained.encode(img))
        label = self.labels[idx]

        text = self.texts[idx]   # post_text
        text = torch.Tensor(laser_model.embed_sentences(text, lang='zh'))

        sample = {'news_image': image, 'label': label, 'news_text': text}

        if self.transform:
            # The original referenced an undefined `news_image` here (NameError).
            sample['news_image'] = self.transform(sample['news_image'])

        return sample

## Instantiate the train and test datasets

In [25]:
# Directory holding the downloaded images (same path index_to_delete scans).
root_dir = '../../../weibo/images'
# No transform is passed, so samples carry the raw CLIP / LASER embeddings.
train_data = WeiboDataset(train_text, train_images, trainY, root_dir)

In [26]:
# Test split, built the same way as the training dataset.
test_data = WeiboTestset(test_text, test_images, testY, root_dir)

## Train/validation split (currently disabled)

In [88]:
# VALID_RATIO = 0.9

# n_train_examples = int(len(train_data) * VALID_RATIO)
# n_valid_examples = len(train_data) - n_train_examples

# train_data, valid_data = data.random_split(train_data, 
#                                            [n_train_examples, n_valid_examples])

In [28]:
# Dataset sizes after filtering.
print(f'Number of training examples: {len(train_data)}')
print(f'Number of testing examples: {len(test_data)}')

Number of training examples: 6154
Number of testing examples: 1699


## Create batch iterators

In [27]:
BATCH_SIZE = 64  # shared by the train and test loaders

In [28]:
# Training loader: reshuffled every epoch (torch RNG, seeded above).
train_iterator = data.DataLoader(train_data, 
                                 shuffle = True, 
                                 batch_size = BATCH_SIZE)

In [29]:
# Test loader: not shuffled, so prediction order matches test_data / testY.
test_iterator = data.DataLoader(test_data, 
                                batch_size = BATCH_SIZE)

## Gate module

In [30]:
class GateModule(nn.Module):
    """Scalar gate: one linear projection to a logit, then a sigmoid.

    Maps input of shape (..., in_dim) to a score in (0, 1) of shape (..., 1).
    """

    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.gate = nn.Linear(in_dim, 1)   # features -> single logit
        self.prob = nn.Sigmoid()           # logit -> (0, 1)

    def forward(self, x):
        return self.prob(self.gate(x))

## Construct the VAE class
- Variant 1: the gate is applied only to the text input.

In [31]:
class VAE(nn.Module):
    '''
    Multimodal VAE whose gate is applied only to the text input.

    Expected inputs (from the layer sizes below):
      x_img: image features with last dimension 512 (e.g. [bs, 512]).
      x_txt: text features with 1024 values per sample; reshaped to [bs, 1024].

    NOTE(review): a second ``VAE`` class is defined in a later cell and shadows
    this one once that cell has run.
    '''
    def __init__(self, zsize, output_dim=2):
        super(VAE, self).__init__()
        
        self.zsize = zsize
        # posterior heads over the fused [bs, zsize] feature: fc1 -> mu, fc2 -> logvar
        self.fc1 = nn.Linear(zsize, zsize)
        self.fc2 = nn.Linear(zsize, zsize)

        ######
        # multi-task sub-networks; only fc_news is used in forward(). The other
        # heads are kept so parameter creation order and state_dict keys match.
        self.fc_news = nn.Linear(zsize, output_dim)
        self.fc_shaming = nn.Linear(zsize, output_dim)
        self.fc_stereotype = nn.Linear(zsize, output_dim)
        self.fc_objectification = nn.Linear(zsize, output_dim)
        self.fc_violence = nn.Linear(zsize, output_dim)
        
        
        # encoder layers: each modality -> 0.5 * zsize (concatenated to zsize)
        self.enc_txt_fc = nn.Linear(1024, int(0.5 * zsize))
        self.enc_img_fc1 = nn.Linear(512, int(0.5 * zsize))
#         self.enc_img_fc2 = nn.Linear(1024, int(0.5 * zsize))
        
        # decoder layers: zsize -> original feature sizes
        self.dec_txt_fc = nn.Linear(zsize, 1024)
        self.dec_img_fc1 = nn.Linear(zsize, 512)
#         self.dec_img_fc2 = nn.Linear(1024, 2048)

        # batch normalizations
        self.enc_txt_bn = nn.BatchNorm1d(num_features=int(0.5 * zsize))
        self.enc_img_bn1 = nn.BatchNorm1d(num_features=int(0.5 * zsize))
#         self.enc_img_bn2 = nn.BatchNorm1d(num_features=int(0.5 * zsize))
        
        self.dec_txt_bn = nn.BatchNorm1d(num_features=1024)
        self.dec_img_bn1 = nn.BatchNorm1d(num_features=512)
#         self.dec_img_bn2 = nn.BatchNorm1d(num_features=2048)
        
        # dropout
        self.dropout_txt_enc = nn.Dropout(0.2)
        self.dropout_img_enc = nn.Dropout(0.2)
        self.dropout_txt_dec = nn.Dropout(0.2)
        self.dropout_img_dec = nn.Dropout(0.2)
        
        # gated unit over the raw 1024-d text embedding
        self.gated_unit = GateModule(1024)
        
        
    def img_encode(self, x_img):
        """Encode image features to [bs, 0.5 * zsize]."""
        x_img = F.relu(self.dropout_img_enc(self.enc_img_bn1(self.enc_img_fc1(x_img))))
        return x_img   # [bs, 0.5 * zsize]

    def txt_encode(self, x_txt):
        """Gate, flatten and encode text features; returns (encoding, gate)."""
        # gate has x_txt's shape with the last dim replaced by 1
        gate = self.gated_unit(x_txt)
        x_txt = x_txt * gate
        x_txt = x_txt.view(x_txt.shape[0], 1024)
        x_txt = F.relu(self.dropout_txt_enc(self.enc_txt_bn(self.enc_txt_fc(x_txt))))
        return x_txt, gate   # [bs, 0.5 * zsize]

    def encode(self, x_img, x_txt):
        """Return (mu, logvar, gate) from the concatenated [text, image] encodings."""
        x_img = self.img_encode(x_img)
        x_txt, gate = self.txt_encode(x_txt)
        
        # concat x_txt and x_img -> [bs, zsize]
        x = torch.cat((x_txt, x_img), 1)
        
        h1 = self.fc1(x)   # mu
        h2 = self.fc2(x)   # logvar
        return h1, h2, gate
    
    def subtask_news(self, z):
        """Fake/real news logits, [bs, output_dim]."""
        h = self.fc_news(z)
        return h

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std while training; return mu in eval mode."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, x):
        """Decode z to (image features [bs, 512], text features [bs, 1024])."""
        # Decoding txt
        dec_x_txt = F.relu(self.dropout_txt_dec(self.dec_txt_bn(self.dec_txt_fc(x))))
        
        # Decoding img
        dec_x_img = F.relu(self.dropout_img_dec(self.dec_img_bn1(self.dec_img_fc1(x))))
        
        return dec_x_img, dec_x_txt

    def forward(self, x_img, x_txt):
        """Return (dec_x_img, dec_x_txt, mu, logvar, y_pred, gate)."""
        mu, logvar, p = self.encode(x_img, x_txt)
        # mu/logvar are already [bs, zsize]. The former .squeeze() was a no-op
        # for bs > 1 but dropped the batch dim for bs == 1, which gave
        # y_pred["news"] the shape [output_dim] instead of [1, output_dim].
        z = self.reparameterize(mu, logvar)

        y_pred = dict()
        y_pred["news"] = self.subtask_news(z)
        
        dec_x_img, dec_x_txt = self.decode(z.view(-1, self.zsize))
        
        return dec_x_img, dec_x_txt, mu, logvar, y_pred, p

    def weight_init(self, mean, std):
        """Apply normal_init to each direct child module.

        Nested modules (e.g. the Linear inside GateModule) are not reached.
        """
        for m in self._modules:
            normal_init(self._modules[m], mean, std)


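- Variant 2: the gate is applied to both inputs, weighting the text encoding by p and the image encoding by (1 - p). This definition replaces variant 1 once its cell runs.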

In [44]:
class VAE(nn.Module):
    
    '''
    Multimodal VAE with a gate over both modalities.

    Text and image features are each encoded to 0.5 * zsize. A scalar gate p,
    computed from their concatenation, weights the text half by p and the image
    half by (1 - p) before the posterior heads.

    Expected inputs (from the layer sizes below):
      x_img: image features with last dimension 512 (e.g. [bs, 512]).
      x_txt: text features with 1024 values per sample; reshaped to [bs, 1024].

    NOTE(review): this redefines the text-only-gate ``VAE`` from an earlier cell.
    '''
    
    def __init__(self, zsize, output_dim=2):
        super(VAE, self).__init__()
        
        self.zsize = zsize
        # posterior heads over the fused [bs, zsize] feature: fc1 -> mu, fc2 -> logvar
        self.fc1 = nn.Linear(zsize, zsize)
        self.fc2 = nn.Linear(zsize, zsize)

        ######
        # multi-task sub-networks
        self.fc_news = nn.Linear(zsize, output_dim)
        # 15-way domain head: computed in forward(), but main() adds no loss for it
        self.fc_domain = nn.Linear(zsize, 15)
        
        
        # encoder layers: each modality -> 0.5 * zsize (concatenated to zsize)
        self.enc_txt_fc = nn.Linear(1024, int(0.5 * zsize))
        self.enc_img_fc1 = nn.Linear(512, int(0.5 * zsize))
#         self.enc_img_fc2 = nn.Linear(1024, int(0.5 * zsize))
        
        # decoder layers: zsize -> original feature sizes
        self.dec_txt_fc = nn.Linear(zsize, 1024)
        self.dec_img_fc1 = nn.Linear(zsize, 512)
#         self.dec_img_fc2 = nn.Linear(1024, 2048)

        # batch normalizations
        self.enc_txt_bn = nn.BatchNorm1d(num_features=int(0.5 * zsize))
        self.enc_img_bn1 = nn.BatchNorm1d(num_features=int(0.5 * zsize))
#         self.enc_img_bn2 = nn.BatchNorm1d(num_features=int(0.5 * zsize))
        
        self.dec_txt_bn = nn.BatchNorm1d(num_features=1024)
        self.dec_img_bn1 = nn.BatchNorm1d(num_features=512)
#         self.dec_img_bn2 = nn.BatchNorm1d(num_features=2048)
        
        # dropout
        self.dropout_txt_enc = nn.Dropout(0.2)
        self.dropout_img_enc = nn.Dropout(0.2)
        self.dropout_txt_dec = nn.Dropout(0.2)
        self.dropout_img_dec = nn.Dropout(0.2)
        
        # gated unit over the concatenated [text, image] encodings (zsize wide)
        self.gate = GateModule(zsize)
        
        
    def img_encode(self, x_img):
        """Encode image features to [bs, 0.5 * zsize]."""
        x_img = F.relu(self.dropout_img_enc(self.enc_img_bn1(self.enc_img_fc1(x_img))))
        return x_img   # [bs, 0.5 * zsize]

    def txt_encode(self, x_txt):
        """Flatten text features to [bs, 1024] and encode to [bs, 0.5 * zsize]."""
        x_txt = x_txt.view(x_txt.shape[0], 1024)
        x_txt = F.relu(self.dropout_txt_enc(self.enc_txt_bn(self.enc_txt_fc(x_txt))))
        return x_txt  # [bs, 0.5 * zsize]

    def encode(self, x_img, x_txt):
        """Return (mu, logvar, p), where p ([bs, 1]) weights text vs. image."""
        x_img = self.img_encode(x_img)
        x_txt = self.txt_encode(x_txt)
        
        # concat x_txt and x_img -> [bs, zsize]
        x = torch.cat((x_txt, x_img), 1)
        
        p = self.gate(x)   # [bs, 1]
        
        # text scaled by p, image by (1 - p)
        x = torch.cat((p * x_txt, (1 - p) * x_img), 1)
        
        h1 = self.fc1(x)   # mu
        h2 = self.fc2(x)   # logvar
        return h1, h2, p
    
    def subtask_news(self, z):
        """Fake/real news logits, [bs, output_dim]."""
        h = self.fc_news(z)
        return h
    
    def subtask_domain(self, z):
        """Domain logits, [bs, 15]."""
        h = self.fc_domain(z)
        return h

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std while training; return mu in eval mode."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, x):
        """Decode z to (image features [bs, 512], text features [bs, 1024])."""
        # Decoding txt
        dec_x_txt = F.relu(self.dropout_txt_dec(self.dec_txt_bn(self.dec_txt_fc(x))))
        
        # Decoding img
        dec_x_img = F.relu(self.dropout_img_dec(self.dec_img_bn1(self.dec_img_fc1(x))))
        
        return dec_x_img, dec_x_txt

    def forward(self, x_img, x_txt):
        """Return (dec_x_img, dec_x_txt, mu, logvar, y_pred, p)."""
        mu, logvar, p = self.encode(x_img, x_txt)
        # mu/logvar are already [bs, zsize]. The former .squeeze() was a no-op
        # for bs > 1 but dropped the batch dim for bs == 1, which broke the
        # classifier output shapes.
        z = self.reparameterize(mu, logvar)

        y_pred = dict()
        y_pred["news"] = self.subtask_news(z)
        y_pred["domain"] = self.subtask_domain(z)
        
        dec_x_img, dec_x_txt = self.decode(z.view(-1, self.zsize))
        
        return dec_x_img, dec_x_txt, mu, logvar, y_pred, p

    def weight_init(self, mean, std):
        """Apply normal_init to each direct child module.

        Nested modules (e.g. the Linear inside GateModule) are not reached.
        """
        for m in self._modules:
            normal_init(self._modules[m], mean, std)


In [32]:
def normal_init(m, mean, std):
    """Initialise an ``nn.Linear`` in place: weight ~ N(mean, std), bias = 0.

    Any other module type is left untouched; children are not visited.
    Linear layers built with ``bias=False`` are supported (the original raised
    AttributeError on ``None.data``).
    """
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

## Training

In [35]:
def loss_function(recon_x_img, recon_x_txt, x_img, x_txt, mu, logvar, kl_weight=0.1):
    """VAE loss terms: image/text reconstruction MSE and a weighted KL term.

    Args:
        recon_x_img, recon_x_txt: decoder outputs, e.g. [bs, 512] / [bs, 1024].
        x_img, x_txt: targets. Any shape with the same number of elements as the
            matching reconstruction is accepted. Text batches presumably arrive
            as [bs, 1, 1024] (the VAE flattens them with view(bs, 1024)).
        mu, logvar: posterior parameters, [bs, zsize].
        kl_weight: multiplier on the KL term (0.1 was the hard-coded value).

    Returns:
        (img_mse, txt_mse, kl_weight * KLD) as scalar tensors. The first two
        were previously named BCE but have always been mean-squared errors.
    """
    # Reshape targets to the reconstruction's shape. Without this, a [bs, 1024]
    # reconstruction minus a [bs, 1, 1024] target broadcasts to [bs, bs, 1024],
    # which compares every reconstruction against every target in the batch.
    mse_img = F.mse_loss(recon_x_img, x_img.reshape(recon_x_img.shape))
    mse_txt = F.mse_loss(recon_x_txt, x_txt.reshape(recon_x_txt.shape))

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), averaged here instead of summed
    kld = -0.5 * torch.mean(torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), 1))
    return mse_img, mse_txt, kld * kl_weight

In [41]:
# Training log written by main(); closed in the cell that calls main().
# NOTE(review): the handle leaks if main() raises -- consider try/finally or `with`.
f = open("VAE_weibo_p(1-p)_gated_result.txt", "w")

In [42]:
def main():
    """Train the VAE and log progress to the already-open global log file ``f``.

    Relies on notebook globals: ``VAE``, ``loss_function``, ``evaluate``,
    ``train_iterator``, ``test_iterator``, ``device`` and ``f``.

    Each epoch does the following:
    - apply step LR decay (divide by 4 every 8 epochs);
    - make one pass over the training loader;
    - evaluate on the test loader.
    Whenever the test F1 (in %) ties or beats the best so far, the weights are
    saved to "VAEmodel-weibo-p(1-p)-gated-epoch-<epoch>.pkl".

    Returns:
        The VAE after the final epoch (not necessarily the best checkpoint). It
        is left in eval mode by the last ``evaluate`` call.
    """
    z_size = 512
    vae = VAE(z_size)
    vae.to(device)   # honour the global `device` instead of a hard-coded .cuda()
    vae.train()
    vae.weight_init(mean=0, std=0.02)

    lr = 0.0001
    vae_optimizer = optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)

    criterion = nn.CrossEntropyLoss()
    criterion.to(device)

    train_epoch = 30
    dataloader = train_iterator
    f1_max = 0

    for epoch in range(train_epoch):
        vae.train()   # evaluate() switched the model to eval mode last epoch

        rec_txt_loss = 0
        rec_img_loss = 0
        kl_loss = 0
        subtask_news_loss = 0
        num_correct = 0
        num = 0

        epoch_start_time = time.time()

        # Step decay: divide the learning rate by 4 every 8 epochs.
        if (epoch + 1) % 8 == 0:
            vae_optimizer.param_groups[0]['lr'] /= 4
            # Log the rate actually in use. The old message printed the initial
            # `lr` with %1.4f, which also rounds small rates to 0.0000.
            f.write("learning rate change! The learning rate is %.2e now\n"
                    % vae_optimizer.param_groups[0]['lr'])

        for _, batch in tqdm(enumerate(dataloader, 0), desc='iterations'):
            img_inputs = batch['news_image'].to(device)
            txt_inputs = batch["news_text"].to(device)
            classes_news = batch['label'].to(device)

            vae.zero_grad()
            rec_img, rec_txt, mu, logvar, y_pred, p = vae(img_inputs, txt_inputs)

            loss_re_img, loss_re_txt, loss_kl = loss_function(rec_img, rec_txt, img_inputs, txt_inputs, mu, logvar)
            loss_subtask_news = criterion(y_pred["news"], classes_news)

            (loss_re_img + loss_re_txt + loss_kl + loss_subtask_news).backward()
            vae_optimizer.step()

            rec_img_loss += loss_re_img.item()
            rec_txt_loss += loss_re_txt.item()
            kl_loss += loss_kl.item()
            subtask_news_loss += loss_subtask_news.item()

            # running training accuracy: top-1 prediction vs. label
            _, top_pred = y_pred["news"].topk(1, 1)
            num_correct += (top_pred.view(-1) == classes_news).sum().item()
            num += classes_news.shape[0]

        per_epoch_ptime = time.time() - epoch_start_time

        # End-of-epoch report. The old code did this inside the batch loop when
        # the batch counter reached len(dataloader), i.e. on the last batch.
        m = len(dataloader)
        if m > 0 and num > 0:
            rec_txt_loss /= m
            rec_img_loss /= m
            kl_loss /= m
            subtask_news_loss /= m

            f.write('\n[%d/%d] - ptime: %.2f, rec img loss: %.9f, rec txt loss: %.9f, KL loss: %.9f, news loss: %.9f\n' % (
                (epoch + 1), train_epoch, per_epoch_ptime, rec_img_loss, rec_txt_loss, kl_loss, subtask_news_loss))

            with torch.no_grad():
                test_loss, test_accuracy, test_f1, test_recall, test_precision = evaluate(vae, test_iterator, criterion, device)
            f.write(f'Test subtask news loss: {test_loss:.3f} | Test Acc @1: {test_accuracy*100:6.2f}%\n')
            f.write(f'Test subtask news accuracy: {test_accuracy*100:6.2f}%\n')
            f.write(f'Test subtask news f1: {test_f1*100:6.2f}%\n')
            f.write(f'Test subtask news recall: {test_recall*100:6.2f}%\n')
            f.write(f'Test subtask news precision: {test_precision*100:6.2f}%\n')

            acc = num_correct / num
            # The old code printed the accuracy under the `num_correct` label.
            print(f'num_correct: {num_correct}\n')
            print(f'total_num: {num}\n')
            f.write(f'Training accuracy: {acc*100:6.2f}%\n')

            if test_f1 * 100 >= f1_max:
                torch.save(vae.state_dict(), "VAEmodel-weibo-p(1-p)-gated-epoch-%d.pkl" % (epoch+1))
                f.write("Epoch [%d/%d]: test f1 on news improves, saving training results\n" % (epoch+1, train_epoch))
                f1_max = test_f1 * 100

        f.flush()

    f.write("Training finish!... save training results\n")
    return vae

In [36]:
def calculate_accuracy(y_pred, y):
    """Top-1 accuracy plus sklearn binary metrics for one batch.

    Args:
        y_pred: logits, [bs, n_classes].
        y: integer labels, [bs], on the same device as ``y_pred``.

    Returns:
        (acc_1, accuracy, f1, recall, precision, cm). ``acc_1`` is a 1-element
        tensor on the input device; the rest come from sklearn (on CPU copies).
    """
    n = y.shape[0]
    with torch.no_grad():
        top1 = y_pred.topk(1, 1)[1].t()                      # [1, bs] predicted ids
        hits = top1.eq(y.view(1, -1).expand_as(top1))
        acc_1 = hits[:1].reshape(-1).float().sum(0, keepdim=True) / n

    guess = top1.cpu().view(n)
    truth = y.cpu()

    accuracy = accuracy_score(truth, guess)
    f1 = f1_score(truth, guess)
    recall = recall_score(truth, guess)
    precision = precision_score(truth, guess)
    cm = confusion_matrix(truth, guess)
    return acc_1, accuracy, f1, recall, precision, cm

In [33]:
def evaluate(model, iterator, criterion, device, subtask_name="news"):
    """Evaluate ``model`` on ``iterator`` for the fake/real news task.

    The model is switched to eval mode (and left there); no gradients are
    computed.

    Args:
        model: a VAE whose forward returns a 6-tuple with y_pred at index 4.
        iterator: DataLoader yielding dicts with 'news_image', 'news_text', 'label'.
        criterion: loss applied to y_pred["news"] and the labels.
        device: device the inputs are moved to.
        subtask_name: currently unused; predictions always come from y_pred["news"].

    Returns:
        (mean per-batch loss, accuracy, f1, recall, precision), with the metrics
        computed by sklearn over all samples (binary averaging).
    """
    model.eval()
    total_loss = 0
    labels = []
    preds = []

    with torch.no_grad():
        for _, batch in tqdm(enumerate(iterator), desc='iterations'):
            x_img = batch['news_image'].to(device)
            x_txt = batch['news_text'].to(device)
            y = batch['label']

            logits = model(x_img, x_txt)[4]["news"]

            labels.extend(y.tolist())
            top1 = logits.topk(1, 1)[1].t().cpu().view(y.shape[0])
            preds.extend(top1.tolist())

            total_loss += criterion(logits, y.to(device)).item()

    return (total_loss / len(iterator),
            accuracy_score(labels, preds),
            f1_score(labels, preds),
            recall_score(labels, preds),
            precision_score(labels, preds))

In [None]:
# Train, then close the log file opened earlier.
# NOTE(review): if main() raises, f is never closed -- wrap in try/finally.
vae = main()

f.close()

## Testing

In [34]:
def test(model, iterator, device):
    """Predict fake (1) / real (0) for every sample and collect gate scores.

    Args:
        model: a VAE whose forward returns a 6-tuple with y_pred at index 4 and
            the gate score at index 5.
        iterator: DataLoader yielding dicts with 'news_image' and 'news_text'.
        device: device for the model and inputs. The original ignored this and
            always called ``model.cuda()``.

    Returns:
        y_test: {"news": [0/1 per sample, in iterator order]}. A sample is 1 only
            when not (logit0 >= logit1), matching the original comparison.
        scores: list of per-batch gate tensors, squeezed on dim 1 and moved to CPU.
    """
    y_test = {"news": []}

    model.to(device)
    model.eval()
    scores = []

    with torch.no_grad():
        for _, batch in tqdm(enumerate(iterator), desc='iterations'):
            x_img = batch['news_image'].to(device)
            x_txt = batch['news_text'].to(device)

            _, _, _, _, y_pred, gated_score = model(x_img, x_txt)
            scores.append(gated_score.squeeze(1).cpu())

            # Vectorised replacement for the per-row `dp[0] >= dp[1]` loop.
            logits = y_pred["news"].cpu()
            is_fake = ~(logits[:, 0] >= logits[:, 1])
            y_test["news"].extend(is_fake.long().tolist())

    return y_test, scores

In [46]:
# `best_VAE` was never defined in this notebook. Rebuild it from the newest
# checkpoint written by main(): a file is saved only when the test F1 ties or
# improves, so the highest epoch number is the best model of the run.
# NOTE(review): stale checkpoints from earlier runs in the working directory are
# picked up too -- clear them before training.
CKPT_PREFIX = "VAEmodel-weibo-p(1-p)-gated-epoch-"
ckpt_epochs = [int(name[len(CKPT_PREFIX):-len(".pkl")])
               for name in listdir('.')
               if name.startswith(CKPT_PREFIX) and name.endswith(".pkl")
               and name[len(CKPT_PREFIX):-len(".pkl")].isdigit()]
best_VAE = VAE(512)  # must match z_size in main()
# torch.load unpickles the file -- only load checkpoints you produced yourself.
best_VAE.load_state_dict(torch.load("%s%d.pkl" % (CKPT_PREFIX, max(ckpt_epochs)), map_location=device))
y_test, scores = test(best_VAE, test_iterator, device)

iterations: 0it [00:00, ?it/s]

In [None]:
from sklearn.metrics import precision_recall_fscore_support

# `y_true` / `y_pred` were undefined. test_iterator is not shuffled, so the
# predictions from test() line up with testY.
y_true = testY
y_pred = y_test["news"]
precision_recall_fscore_support(y_true, y_pred, average=None, labels=[0, 1])