In [117]:
import json
import numpy as np
import pandas as pd
import glob2
import torch
import pickle
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import torch.nn as nn
import torch.optim as optim
import pickle as pkl
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
import os
from transformers import BertTokenizer, VisualBertModel
import re
import PIL
from torchvision.models import vgg16
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from pprint import pprint

In [2]:
# load data
with open("./annotation.json","r") as f:
    annotation=json.load(f)

In [37]:
# getting a rough estimate of pillar-file count
unique_pillars=[v for item in annotation["Pillar_Stances"] for k,v in item.items()]
unique_pillars=[item[0][0] for item in unique_pillars ]
unique_pillars=list(set(unique_pillars))
unique_pillars=[v for item in annotation["Pillar_Stances"] for k,v in item.items()]
unique_pillars=[item[0][0] for item in unique_pillars ]
unique_pillars=list(set(unique_pillars))

pillar_file_dict={}
for item in annotation["Pillar_Stances"]:
    for k,items in item.items():
        for itm in items:
            if(len(set(itm[1]))==len(itm[1])):
                continue
            if(itm[0] not in pillar_file_dict):    
                pillar_file_dict[itm[0]]=[k] 
            else:
                pillar_file_dict[itm[0]].append(k) 
               
{pillar:len(_) for pillar,_ in pillar_file_dict.items()}        

{'Economic Defence': 155,
 'Psychological Defence': 323,
 'Social Defence': 43,
 'Military Defence': 602,
 'Digital Defence': 32,
 'Others': 674,
 'Civil Defence': 101}

In [38]:
# 60-20-20 split
train,val,test=[],[],[]
for pillar,items in pillar_file_dict.items():
    temp=np.random.choice(items,int(0.60*len(items)),replace=False)
    train.extend(temp)
    items=[item for item in items if item not in train]
    temp=np.random.choice(items,int(0.5*len(items)),replace=False)
    val.extend(temp)
    items=[item for item in items if item not in val]
    test.extend(items)
len(train),len(val),len(test)        

(1155, 379, 377)

In [41]:
# train split pillar-wise
# file_pillar_dict={file:k for k,files in pillar_file_dict.items() for file in files}
file_pillar_dict={}
for k,files in pillar_file_dict.items():
    for file in files:
        if(file in file_pillar_dict):
            file_pillar_dict[file].append(k)
        else:
            file_pillar_dict[file]=[k]
            
train_pillar_distr={pillar:0 for pillar in unique_pillars}

for item in train:
    for pillar in file_pillar_dict[item]:
        train_pillar_distr[pillar]+=1
        
val_pillar_distr={pillar:0 for pillar in unique_pillars}
for item in val:
    for pillar in file_pillar_dict[item]:
        val_pillar_distr[pillar]+=1
        
test_pillar_distr={pillar:0 for pillar in unique_pillars}
for item in test:
    for pillar in file_pillar_dict[item]:
        test_pillar_distr[pillar]+=1   
        
train_pillar_distr,val_pillar_distr,test_pillar_distr        

({'Economic Defence': 109,
  'Digital Defence': 20,
  'Military Defence': 372,
  'Others': 404,
  'Social Defence': 32,
  'Psychological Defence': 219,
  'Civil Defence': 61},
 {'Economic Defence': 32,
  'Digital Defence': 6,
  'Military Defence': 123,
  'Others': 135,
  'Social Defence': 10,
  'Psychological Defence': 68,
  'Civil Defence': 21},
 {'Economic Defence': 31,
  'Digital Defence': 7,
  'Military Defence': 117,
  'Others': 135,
  'Social Defence': 10,
  'Psychological Defence': 64,
  'Civil Defence': 24})

In [42]:
# train upsample
train_others=[item for item in train if 'Others' in file_pillar_dict[item]]
train_economic=[item for item in train if 'Economic Defence' in file_pillar_dict[item]]
train_social=[item for item in train if 'Social Defence' in file_pillar_dict[item]]
train_digital=[item for item in train if 'Digital Defence' in file_pillar_dict[item]]
train_psychological=[item for item in train if 'Psychological Defence' in file_pillar_dict[item]]
train_civil=[item for item in train if 'Civil Defence' in file_pillar_dict[item]]
train_military=[item for item in train if 'Military Defence' in file_pillar_dict[item]]

train_upsample=[]
train_upsample.extend(train_others)
# train_upsample.extend(np.random.choice(train_others,315-236))

train_upsample.extend(train_economic)
train_upsample.extend(np.random.choice(train_economic,404-72))

train_upsample.extend(train_social)
train_upsample.extend(np.random.choice(train_social,404-14)) 

train_upsample.extend(train_digital)
train_upsample.extend(np.random.choice(train_digital,404-16))

train_upsample.extend(train_psychological)
train_upsample.extend(np.random.choice(train_psychological,404-166))

train_upsample.extend(train_civil)
train_upsample.extend(np.random.choice(train_civil,404-55)) 

train_upsample.extend(train_military)
train_upsample.extend(np.random.choice(train_military,404-348)) 

In [29]:
data={}
pillars=[file for item in annotation["Pillar_Stances"] for file,_ in item.items()]
for item in annotation["Text"]:
    for file,text in item.items():
        if(file not in pillars):
            continue
        data[file]=(text,"./TD_Memes/{}".format(file))    

In [19]:
pillar_labels_dict={pillar:i for i,pillar in enumerate(unique_pillars)}
unique_stances=[pillar_stance for item in annotation["Pillar_Stances"] \
                 for _,pillar_stance in item.items() if len(pillar_stance)==1]
               
unique_stances=set([stance for item in unique_stances for stance in item[0][1]])   
unique_stances={stance:i for i,stance in enumerate(unique_stances)}
stance_labels_dict={}
index=0
for pillar,_ in pillar_labels_dict.items():
    stance_labels_dict[pillar]=[]
    for stance in unique_stances:
        stance_labels_dict[pillar].append(index)
        index+=1
print(pillar_labels_dict,stance_labels_dict)  

{'Economic Defence': 0, 'Digital Defence': 1, 'Military Defence': 2, 'Others': 3, 'Social Defence': 4, 'Psychological Defence': 5, 'Civil Defence': 6} {'Economic Defence': [0, 1, 2], 'Digital Defence': [3, 4, 5], 'Military Defence': [6, 7, 8], 'Others': [9, 10, 11], 'Social Defence': [12, 13, 14], 'Psychological Defence': [15, 16, 17], 'Civil Defence': [18, 19, 20]}


In [45]:
file_pillar_stance={file:pillar_stance for item in annotation["Pillar_Stances"] \
                 for file,pillar_stance in item.items()}
file_pillar_stance={file:{item[0]:stance for item in pillar_stance if len(set(item[1]))<len(item[1])\
                          for stance in item[1] if item[1].count(stance)>1}\
                    for file,pillar_stance in file_pillar_stance.items()\
             }

In [144]:
class ClassifierDataset(Dataset):
    def __init__(self,items):
        self.items=items
#         self.transform=transforms.Compose([
#             transforms.Resize((224,224)),
# #         transforms.CenterCrop((224,224)),
#             transforms.ToTensor(),
#             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#             ])

    def __len__(self): 
        return len(self.items)

    def __getitem__(self,index):        
        with open("./vgg_embeddings/{}".\
                  format(self.items[index]),"rb") as f:
            encoding=pkl.load(f)
        encoding=encoding.view(-1)    
#         encoding=torch.cat((encoding['text_embeds'],encoding['image_embeds']),1)  
        
#         img=PIL.Image.open(self.items[index]['image_path']).convert('RGB')
#         img=self.transform(img)
#         encoding=encoding.view(-1,4096)
#         encoding=resnet(img.unsqueeze(0)).view(-1,512)
        pillar_labels=torch.zeros(7)  
        stance_labels=torch.zeros(21)
        for k,v in file_pillar_stance[self.items[index]].items():
            pillar_labels[pillar_labels_dict[k]]=1
            if(v=="Supportive"):
                stance_labels[stance_labels_dict[k][0]]=1
            elif(v=="Neutral"):
                stance_labels[stance_labels_dict[k][1]]=1
            elif(v=="Against"):
                stance_labels[stance_labels_dict[k][2]]=1
                 
        return encoding,pillar_labels,stance_labels

In [145]:
class Pillar_Stance_Classifier(nn.Module):
    def __init__(self):
        super(Pillar_Stance_Classifier, self).__init__()
#         self.pretrained_model=VisualBertModel.from_pretrained('uclanlp/visualbert-nlvr2-coco-pre')
#         for param in self.pretrained_model.parameters():
#             param.requires_grad = True
#         self.pillar_cls_layer = torch.nn.Sequential(nn.Linear(25088, 4096),nn.ReLU(),nn.Linear(4096, 512),\
#                                                    nn.ReLU(),nn.Linear(512,7))#
#         self.stance_cls_layer = torch.nn.Sequential(nn.Linear(25088, 4096),nn.ReLU(),nn.Linear(4096, 512),\
#                                                    nn.ReLU(),nn.Linear(512,21))
        
        self.pillar_cls_layer = torch.nn.Sequential(nn.Linear(8*512, 64),nn.ReLU(),nn.Linear(64,7))#
        self.stance_cls_layer = torch.nn.Sequential(nn.Linear(8*512, 64),nn.ReLU(),nn.Linear(64,21))

    def forward(self,inp):
#         output=self.pretrained_model(inp)
        # for visualBERT
#         output=output.last_hidden_state
#         print(inp.shape)
        output_pillar=self.pillar_cls_layer(inp)
        output_stance=self.stance_cls_layer(inp)
        return output_pillar,output_stance

In [146]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model=Pillar_Stance_Classifier().to(device)
model.train()
train_set=ClassifierDataset(train_upsample)
val_set=ClassifierDataset(val)
train_loader=DataLoader(train_set, batch_size = 64)
val_loader = DataLoader(val_set, batch_size = 64)   
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr = 1e-3)

In [147]:
val_losses=[]
for _e in range(100):
    train_loss=0
    for t, (embedding, pillar_labels,stance_labels) in enumerate(train_loader):
        embedding=embedding.to(device)
#         input_ids=embedding['input_ids'].to(device)
#         token_type_ids=embedding['token_type_ids'].to(device)
#         attention_mask=embedding['attention_mask'].to(device)
#         visual_embeds=embedding['visual_embeds'].to(device)
        
        pillar_labels =pillar_labels.to(device) 
        stance_labels =stance_labels.to(device) 

        optimizer.zero_grad()
        pillar_logits,stance_logits=model(embedding)
        pillar_loss = criterion(pillar_logits.squeeze(1), pillar_labels.float())
        stance_loss = criterion(stance_logits.squeeze(1), stance_labels.float())
        loss=pillar_loss+stance_loss
        train_loss+=loss.data.item()
        loss.backward()
        optimizer.step()
    train_loss= np.mean(train_loss)
    val_loss=0
    for t, (embedding,pillar_labels,stance_labels) in enumerate(val_loader):
        embedding=embedding.to(device)
        pillar_labels =pillar_labels.to(device) 
        stance_labels =stance_labels.to(device)  

        optimizer.zero_grad()
        pillar_logits,stance_logits=model(embedding)
        pillar_loss = criterion(pillar_logits.squeeze(1), pillar_labels.float())
        stance_loss = criterion(stance_logits.squeeze(1), stance_labels.float())
        val_loss+=(pillar_loss+stance_loss).data.item()
    val_loss= np.mean(val_loss)   
    if(len(val_losses)>0 and val_loss<min(val_losses)):
          torch.save(model.state_dict(), 'model_vgg.pt')  
    val_losses.append(val_loss)      
    print('training loss:{} validation loss:{}'.format(train_loss,val_loss))

training loss:46.44727572798729 validation loss:4.816993623971939
training loss:35.94687816500664 validation loss:4.2128506898880005
training loss:27.825812816619873 validation loss:4.093301296234131
training loss:23.760528579354286 validation loss:3.922414720058441
training loss:20.794313341379166 validation loss:3.8788132071495056
training loss:18.481240831315517 validation loss:3.896069347858429
training loss:16.747145354747772 validation loss:3.8111177384853363
training loss:14.803797036409378 validation loss:3.817850649356842
training loss:13.234844997525215 validation loss:3.8421100676059723
training loss:12.025866065174341 validation loss:3.760339319705963
training loss:10.84834874048829 validation loss:3.804943561553955
training loss:9.825224686414003 validation loss:3.806150257587433
training loss:8.939231241121888 validation loss:3.8802632689476013
training loss:8.114565830677748 validation loss:3.976857304573059
training loss:7.493706060573459 validation loss:3.9789298772811

In [148]:
# Evaluation on test set
model=Pillar_Stance_Classifier().to(device)
test_set=ClassifierDataset(test)
test_loader=DataLoader(test_set, batch_size = 1,shuffle=False)
model.load_state_dict(torch.load("model_vgg.pt"))
model.eval()
preds=[]
pillar_outputs=[]
stance_outputs=[]
pillar_labels_gt=[]
stance_labels_gt=[]
stance_output_proba=[]
filenames=[]
with torch.no_grad():
    for t, (embedding,pillar_labels,stance_labels) in enumerate(test_loader):
        embedding=embedding.to(device)
        pillar_labels =pillar_labels.to(device) 
        stance_labels =stance_labels.to(device)  
        pillar_logits,stance_logits=model(embedding)
        pillar_output,stance_output=torch.sigmoid(pillar_logits),torch.sigmoid(stance_logits)
        temp=[0 if item<0.5 else 1 for item in pillar_output.view(-1).tolist()]
        pillar_outputs.append(temp) 
        stance_outputs.append([0 if item<0.5 else 1 for item in stance_output.view(-1).tolist()])
        stance_output_proba.append(stance_output.view(-1).tolist())
        pillar_labels_gt.append(pillar_labels.view(-1).tolist())
        stance_labels_gt.append(stance_labels.view(-1).tolist())

In [149]:
gt=[[item[n*3:(n+1)*3] for n in range(7)] for item in stance_labels_gt]
stance_outputs=[[item[n*3:(n+1)*3] for n in range(7)] for item in stance_output_proba]

pred_filtered=[]
gt_filtered=[]
for i,item in enumerate(pillar_outputs):
    temp=np.array(item)
    indices=[it for item in np.where(temp==1) for it in item]
    for indx in indices:
        arg=np.argmax(stance_outputs[i][indx])
        
        if(arg==0):
            pred_filtered.append(1)
        elif(arg==1):
            pred_filtered.append(2)
        elif(arg==2):
            pred_filtered.append(3)
        if(gt[i][indx]==[0,0,0]):
            gt_filtered.append(0)
        elif(gt[i][indx]==[1,0,0]):
            gt_filtered.append(1)
        elif(gt[i][indx]==[0,1,0]):
            gt_filtered.append(2)
        elif(gt[i][indx]==[0,0,1]):
            gt_filtered.append(3)    

pprint(classification_report(gt_filtered,pred_filtered))

('              precision    recall  f1-score   support\n'
 '\n'
 '           0       0.00      0.00      0.00        91\n'
 '           1       1.00      0.18      0.31        11\n'
 '           2       0.41      0.76      0.54        25\n'
 '           3       0.24      0.93      0.39        28\n'
 '\n'
 '    accuracy                           0.30       155\n'
 '   macro avg       0.41      0.47      0.31       155\n'
 'weighted avg       0.18      0.30      0.18       155\n')


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [150]:
pprint(classification_report(pillar_labels_gt,pillar_outputs))

('              precision    recall  f1-score   support\n'
 '\n'
 '           0       0.50      0.13      0.21        31\n'
 '           1       0.00      0.00      0.00         7\n'
 '           2       0.51      0.19      0.28       117\n'
 '           3       0.00      0.00      0.00       135\n'
 '           4       1.00      0.30      0.46        10\n'
 '           5       0.35      0.55      0.42        64\n'
 '           6       0.00      0.00      0.00        24\n'
 '\n'
 '   micro avg       0.41      0.16      0.24       388\n'
 '   macro avg       0.34      0.17      0.20       388\n'
 'weighted avg       0.28      0.16      0.18       388\n'
 ' samples avg       0.15      0.15      0.15       388\n')


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [151]:
accuracy_score(pillar_labels_gt,pillar_outputs)

0.14058355437665782

In [152]:
accuracy_score(gt_filtered,pred_filtered)

0.3032258064516129