In [1]:
# Code to mount the drive 
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
!pip install -U sentence-transformers

Collecting sentence-transformers
  Using cached sentence-transformers-2.1.0.tar.gz (78 kB)
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 4.2 MB/s 
Building wheels for collected packages: sentence-transformers
  Building wheel for sentence-transformers (setup.py) ... [?25l[?25hdone
  Created wheel for sentence-transformers: filename=sentence_transformers-2.1.0-py3-none-any.whl size=121000 sha256=0bf4b8107722ec5796b53c1914eecb5fca9eb1e16c24c168f923ed49aeba5e03
  Stored in directory: /root/.cache/pip/wheels/90/f0/bb/ed1add84da70092ea526466eadc2bfb197c4bcb8d4fa5f7bad
Successfully built sentence-transformers
Installing collected packages: sentencepiece, sentence-transformers
Successfully installed sentence-transformers-2.1.0 sentencepiece-0.1.96


In [3]:
cd '/content/drive/MyDrive/685/catr'

/content/drive/.shortcut-targets-by-id/12c1zkm0_oa8VcOfsUa8Tn_YcYdyPpSlP/685/catr


In [4]:
!pip install -r requirements.txt



In [5]:
# from transformers import ViTModel, ViTConfig, ViTFeatureExtractor,BertTokenizer,BertForMaskedLM
from transformers import DeiTFeatureExtractor, DeiTModel #AutoFeatureExtractor, DeiTForImageClassificationWithTeacher, 
from sentence_transformers import SentenceTransformer
import torch
from sklearn.model_selection import KFold
from sklearn.metrics import auc, precision_score, recall_score,roc_auc_score
import xgboost as xgb
from PIL import Image
import argparse
import glob
import json
import pandas as pd
import numpy as np
from models import caption
from datasets import coco, utils
from tqdm import tqdm
from configuration import Config
from xgboost import XGBClassifier
import os

In [6]:
def extract_vision_transformer_feats(image_path,nsel_st=None,nsel_end=None):
  """
  Function to extract the features from vision transformers
  """
  if not nsel_st:
    img_files_list = glob.glob(image_path+"*")
  else:
    img_files_list = glob.glob(image_path+"*")[nsel_st:nsel_end]
  
  # Create image batch array for Vision Transformer
  img_batch = []
  for file in tqdm(img_files_list):
    img = np.asarray(Image.open(file))
    newsize = (240, 240, 3)
    img = np.resize(img,newsize)
    img_batch.append(img)
  print("Creation of image batches to be used for Vision Transformer complete")

  
  # Extract features from ViTModel
  # feature_extractor   = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
  # model               = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
  feature_extractor   =  DeiTFeatureExtractor.from_pretrained('facebook/deit-tiny-distilled-patch16-224')
  model               = DeiTModel.from_pretrained('facebook/deit-tiny-distilled-patch16-224')
  inputs              = feature_extractor(images=img_batch, return_tensors="pt")
  outputs             = model(**inputs)

  # Get image representations and their corresponding hashes i.e. get [CLS] token representation for each image
  img_representations = outputs.last_hidden_state[:,0,:]
  img_hash_li         = []
  for file in img_files_list:
    img_hash_li.append(file.split("/")[-1][:-4])
  
  # Create column names for image dimensions 
  col_img  = ["imdim_"+str(i) for i in list(range(192))]

  # Create a dataframe of image features
  img_data = pd.DataFrame(img_representations.detach().numpy(),columns = col_img)
  img_data['img_hash'] = img_hash_li

  return img_data

def extract_sentence_transformer_feats(reference_file_pth = '/content/drive/MyDrive/685/emogen/Classifier/train/caption/',csv_file = 'train_caption.csv',lab_assign=0):
  """
  Extract features for emotion related texts
  """
  # Get reference text data
  combined_path = reference_file_pth + csv_file
  #file = open(combined_path)
  #data_json = json.load(file)
  data = pd.read_csv(combined_path)#pd.DataFrame.from_dict(data_json['annotations']).reset_index(drop=True)
  try:
    del data['Unnamed: 0']
  except:
    pass
  
  # Use Sentence Transformer to extract features
  model = SentenceTransformer('all-mpnet-base-v2')
  sentence_embeddings = model.encode(data['comment'])
  col_text = ["tdim_"+str(i) for i in list(range(768))]

  # Create text feature dataframe
  text_data = pd.DataFrame(sentence_embeddings,columns=col_text)
  text_data['img_hash'] = data['image_hash']
  text_data['label'] = lab_assign

  return text_data

def xgb_train_kfold(X_trn, y_trn,n_splits=5,max_depth=3,n_estimator=150, rand_st=3815):
  """
  Perform training with XGBoost and evaluate in K-Fold cross-validation settings 
  """
  errors    = []
  precision = []
  recall    = []
  auc       = []
  kf = KFold(n_splits=n_splits, shuffle=True, random_state=3815)

  for train_index, test_index in tqdm(kf.split(X_trn)):
     X_train_n, X_test_n = X_trn[train_index], X_trn[test_index]
     y_train_n, y_test_n = y_trn[train_index], y_trn[test_index]

     model = XGBClassifier(
         max_depth=max_depth, n_estimators=n_estimator, random_state = rand_st
     )
     model.fit(X_train_n, y_train_n)
     y_pred = model.predict(X_test_n)
     accuracy = (sum(y_pred == y_test_n))/len(y_test_n)
     errors.append(1 - accuracy)
     precision.append(precision_score(y_test_n,y_pred))
     recall.append(recall_score(y_test_n,y_pred))
     auc.append(roc_auc_score(y_test_n,y_pred))

  return model, errors, precision, recall, auc

LOAD AND MERGE SARCASM DATA

In [7]:
sarc_im_feats = pd.read_csv("/content/drive/MyDrive/685/sarcasm/evaluation/Vision_trans/sarc_image_feats.csv")
#del sarc_im_feats['Unnamed: 0']
sarc_text_feats = pd.read_csv("/content/drive/MyDrive/685/sarcasm/evaluation/Vision_trans/sarc_text_feats.csv")
sarc_data = sarc_im_feats.merge(sarc_text_feats,how='left',on='img_hash')


In [8]:
sarc_data.head()

Unnamed: 0.1,Unnamed: 0,imdim_0,imdim_1,imdim_2,imdim_3,imdim_4,imdim_5,imdim_6,imdim_7,imdim_8,imdim_9,imdim_10,imdim_11,imdim_12,imdim_13,imdim_14,imdim_15,imdim_16,imdim_17,imdim_18,imdim_19,imdim_20,imdim_21,imdim_22,imdim_23,imdim_24,imdim_25,imdim_26,imdim_27,imdim_28,imdim_29,imdim_30,imdim_31,imdim_32,imdim_33,imdim_34,imdim_35,imdim_36,imdim_37,imdim_38,...,tdim_729,tdim_730,tdim_731,tdim_732,tdim_733,tdim_734,tdim_735,tdim_736,tdim_737,tdim_738,tdim_739,tdim_740,tdim_741,tdim_742,tdim_743,tdim_744,tdim_745,tdim_746,tdim_747,tdim_748,tdim_749,tdim_750,tdim_751,tdim_752,tdim_753,tdim_754,tdim_755,tdim_756,tdim_757,tdim_758,tdim_759,tdim_760,tdim_761,tdim_762,tdim_763,tdim_764,tdim_765,tdim_766,tdim_767,label
0,0,-0.22587,-0.072451,0.150466,-0.310028,0.181258,-0.111621,-0.218812,0.188915,-0.22375,0.293365,0.054259,-0.003534,0.028545,0.135578,-0.102298,0.115907,-0.234384,0.228016,-0.186387,-0.079287,-0.029969,0.191448,-0.033465,-0.026993,0.118669,-0.11538,-0.102989,-0.210903,-0.009456,0.212527,-0.157775,-0.191561,0.009277,-0.233902,-0.349467,-0.059341,0.006363,0.0114,-0.134651,...,0.019627,0.00907,0.00061,-0.037166,-0.010591,-0.080931,0.052879,-0.067864,-0.037016,0.012656,0.002938,-0.025422,-0.062199,0.014316,0.00541,-0.008976,0.024902,-0.040613,-0.046314,-0.063835,-0.031816,0.039533,0.007972,0.044383,0.015157,0.058999,0.007948,1.223041e-34,0.01829,-0.03181,0.040135,0.034153,-0.015313,0.03322,0.045355,-0.087061,-0.011676,-0.02905,0.017642,1
1,1,0.155836,0.183927,-0.113086,-0.211043,-0.050718,0.087473,-0.119086,0.113785,-0.135824,0.022124,-0.026887,0.222732,-0.308995,-0.017612,-0.045683,-0.149274,0.325512,0.137425,-0.154343,0.0828,-0.096126,-0.241386,0.21912,0.047626,-0.1321,-0.158255,0.001754,-0.297031,0.071361,0.346881,-0.365549,0.022101,0.324265,-0.233232,-0.02551,0.122959,-0.292333,-0.05207,-0.027558,...,-0.005163,-0.002975,-0.019531,-0.048703,-0.064064,0.058073,0.042521,-0.02694,-0.044119,0.041224,0.080253,0.028264,-0.035621,0.031429,0.011012,-0.012145,-0.007684,0.030985,0.015418,-0.029362,-0.021399,0.062678,-0.023348,-0.01158,0.014696,-0.098282,-0.052649,1.206995e-34,-0.021449,-0.030386,0.058879,-0.01231,-0.002889,0.016357,0.033415,0.011526,0.029874,0.005155,0.026764,1
2,2,0.144119,0.014812,-0.247498,0.014472,0.338527,0.047474,0.063485,0.213259,0.047971,0.255797,0.059354,0.295847,-0.028295,0.179616,-0.440915,-0.072989,-0.154069,0.093858,0.024484,0.268147,-0.079271,-0.15071,0.162194,-0.305082,0.054961,0.105624,-0.294796,-0.145289,-0.22864,-0.023027,-0.040286,-0.441639,0.283062,0.080178,0.169652,0.037028,0.328282,-0.115819,0.020436,...,-0.031204,0.01888,-0.037276,0.032894,-0.043752,-0.032232,-0.03915,-0.068707,0.01413,0.053535,-0.03242,0.025046,0.002365,-0.003655,0.043359,-0.006344,0.004472,0.003814,-0.006376,0.023385,-0.027005,0.065918,-0.01007,0.077457,0.048885,-0.008758,-0.013489,1.678468e-34,-0.014916,-0.005187,0.053153,0.004884,0.054973,0.004999,0.021592,-0.032027,0.081238,-0.028547,0.003424,1
3,3,-0.471492,0.52079,-0.215634,0.106593,-0.195663,-0.385999,-0.0171,0.061736,-0.044317,-0.158382,-0.238115,0.06172,-0.185766,0.225334,-0.187225,-0.343042,-0.200512,0.073748,-0.276734,-0.044295,-0.121544,0.066809,0.09179,-0.061537,-0.107678,-0.088795,-0.380247,-0.142605,-0.200039,0.135739,0.039265,-0.05932,-0.034385,0.006086,0.155901,-0.238889,-0.161051,-0.271378,0.125379,...,-0.004707,-0.135561,0.058689,0.059448,0.025717,0.034499,-0.021846,-0.044487,-0.025175,-0.060257,-0.015774,-0.000877,-0.019358,0.089503,-0.017624,0.048682,0.044661,-0.004987,-0.006545,0.038779,0.005664,0.041696,0.002619,-0.038359,-0.026387,-0.032622,0.008995,8.694685999999999e-35,-0.037594,-0.009374,0.042049,-0.020785,0.002107,-0.005249,0.036761,-0.064584,0.028553,-0.08572,-0.00635,1
4,4,0.050874,0.070601,-0.182553,-0.184191,0.137736,-0.07895,0.106485,-0.11117,-0.081032,-0.00642,-0.091333,-0.135893,-0.462799,-0.083964,-0.427585,-0.028577,0.088369,-0.050035,0.123073,-0.049001,0.02501,-0.293847,-0.069006,-0.191857,0.029378,-0.09717,-0.035348,-0.320575,-0.056257,0.144316,0.309451,0.003123,0.093818,0.00905,0.106148,-0.173324,0.089836,-0.17861,0.146477,...,0.032927,-0.050616,0.030051,0.012412,0.050817,-0.082734,-0.032172,0.039175,-0.040866,-0.006501,-0.003492,0.023298,-0.009949,-0.008447,0.028277,-0.013939,-0.044873,-0.032059,0.017057,-0.016643,-0.049195,0.071174,-0.011549,-0.041143,0.003758,-0.021582,-0.0416,1.00707e-34,-0.021432,0.021747,-0.036588,0.057104,0.01681,-0.013712,-0.019614,0.032964,-0.008239,-0.04123,-0.017547,1


LOAD AND MERGE NON-SARCASM DATA

In [9]:
non_sarc_im_feats = pd.read_csv("/content/drive/MyDrive/685/sarcasm/evaluation/Vision_trans/non_sarc_image_feats.csv")
#del non_sarc_im_feats['personality']
non_sarc_text_feats = pd.read_csv("/content/drive/MyDrive/685/sarcasm/evaluation/Vision_trans/non_sarc_text_feats.csv")
non_sarc_data = non_sarc_im_feats.merge(non_sarc_text_feats, how='left', on='img_hash')

In [10]:
non_sarc_data.shape

(986, 1539)

In [11]:
non_sarc_data.head()

Unnamed: 0.1,Unnamed: 0,imdim_0,imdim_1,imdim_2,imdim_3,imdim_4,imdim_5,imdim_6,imdim_7,imdim_8,imdim_9,imdim_10,imdim_11,imdim_12,imdim_13,imdim_14,imdim_15,imdim_16,imdim_17,imdim_18,imdim_19,imdim_20,imdim_21,imdim_22,imdim_23,imdim_24,imdim_25,imdim_26,imdim_27,imdim_28,imdim_29,imdim_30,imdim_31,imdim_32,imdim_33,imdim_34,imdim_35,imdim_36,imdim_37,imdim_38,...,tdim_729,tdim_730,tdim_731,tdim_732,tdim_733,tdim_734,tdim_735,tdim_736,tdim_737,tdim_738,tdim_739,tdim_740,tdim_741,tdim_742,tdim_743,tdim_744,tdim_745,tdim_746,tdim_747,tdim_748,tdim_749,tdim_750,tdim_751,tdim_752,tdim_753,tdim_754,tdim_755,tdim_756,tdim_757,tdim_758,tdim_759,tdim_760,tdim_761,tdim_762,tdim_763,tdim_764,tdim_765,tdim_766,tdim_767,label
0,0,-0.154419,0.237147,-0.260658,0.079626,0.174629,-0.037283,0.181275,-0.091481,0.075997,0.222657,0.262198,-0.004729,0.085308,0.160205,-0.334296,-0.66796,-0.058677,-0.00447,-0.38603,0.089395,-0.019256,-0.282389,0.41744,-0.034918,0.079087,-0.243365,-0.051581,0.129912,-0.207757,-0.026039,-0.153776,0.213071,-0.000963,-0.242769,-0.015171,-0.316231,0.026008,-0.164314,0.062776,...,-0.026012,0.07594,-0.05826,-0.01321,0.021054,0.016669,-0.037853,-0.035031,-0.058536,-0.035216,0.007962,0.011898,0.013748,0.026339,-0.000724,0.010676,0.031866,0.020021,-0.028645,0.05024,0.025347,0.058877,-0.044578,0.048524,-0.032181,-0.013451,-0.045294,1.426474e-34,-0.005135,-0.029989,-0.020345,0.069942,-0.012176,-0.007818,-0.037268,0.021345,0.027128,-0.010693,-0.013343,0
1,1,-0.161675,-0.257757,0.129757,-0.131377,0.103156,0.2195,0.019781,-0.121115,0.002601,0.195443,-0.095846,0.208402,0.151647,0.319363,-0.152885,-0.255986,-0.056179,0.0818,-0.269129,0.215181,-0.110072,-0.174208,0.065126,-0.061155,-0.013843,-0.184014,0.206003,-0.064154,0.366466,-0.198762,-0.439766,-0.055843,-0.11324,-0.324072,0.054468,-0.1374,0.183585,-0.259974,0.030896,...,-0.043704,0.040994,-0.030339,-0.050234,0.012174,-0.059127,0.057507,0.030304,0.011525,0.061023,-0.040212,0.033082,-0.000639,-0.099027,0.001826,0.013611,0.053717,0.02413,0.022648,0.014245,0.001712,-0.052697,0.001435,0.000319,-0.062704,-0.030512,-0.005182,1.2041019999999999e-34,-0.00302,0.007886,0.011674,-0.074185,0.014488,-0.00027,0.072506,-0.042129,0.043323,0.019933,0.00334,0
2,2,0.13147,-0.335437,0.321686,0.187901,0.015675,-0.128293,0.221141,0.092105,0.246844,0.173362,0.007657,-0.106794,0.466101,-0.154448,0.049409,0.043912,0.018389,-0.153249,-0.140493,-0.135432,0.170132,-0.131044,0.325657,0.023387,-0.173288,0.059496,0.430015,-0.270507,0.180084,0.061132,-0.077753,-0.007201,0.042139,-0.012061,-0.109858,0.359709,-0.403691,0.016272,0.047788,...,0.042483,-0.001541,0.020764,0.004387,0.038008,0.060733,0.010056,0.070263,-0.005679,0.019557,0.041257,-0.054194,-0.051864,-0.047664,-0.037677,-0.038524,0.016576,0.021669,0.009959,-0.066785,-0.021846,0.004668,-0.05915,0.008056,-0.039005,0.028288,-0.024557,1.288498e-34,-0.034598,-0.04603,0.000561,0.025131,-0.016563,0.015124,0.003021,0.007628,-0.047651,0.007275,-0.035924,0
3,3,-0.105391,0.213757,-0.336909,-0.043461,-0.173466,0.245793,0.022145,-0.223629,-0.032584,-0.092182,-0.123961,0.006555,0.030328,-0.010635,-0.186472,-0.454358,-0.272072,0.173445,-0.213262,0.174757,-0.017561,0.012228,-0.061416,-0.085261,-0.170512,-0.128842,-0.120785,-0.219215,-0.154824,0.087677,-0.10743,-0.030405,0.19494,-0.141551,-0.277199,0.058098,0.14979,0.136208,0.399637,...,-0.02781,0.032527,-0.007415,-0.040548,0.020756,-0.055996,-0.025245,0.021376,0.03752,0.023097,0.029806,0.027258,-0.051254,0.003571,0.004211,-0.022645,0.007237,0.016698,0.012267,-0.061543,-0.036716,0.047153,-0.016412,0.020594,-0.024371,-0.006575,-0.007845,1.552384e-34,-0.035044,0.002738,0.004355,-0.01225,-0.035933,-0.013462,0.029829,0.002006,0.024476,0.018645,0.007668,0
4,4,-0.031014,0.438352,-0.158522,0.049595,-0.111022,-0.186727,-0.082618,0.039569,0.028709,-0.188519,-0.117406,0.100266,-0.306316,0.16741,-0.221251,-0.197687,-0.309434,0.010931,-0.013964,0.080365,-0.033406,0.108233,0.11942,-0.146304,-0.16811,-0.040105,-0.176627,-0.177462,0.024066,0.165547,-0.153441,0.062883,-0.145174,0.191917,0.330765,-0.050291,-0.029352,-0.349351,0.026374,...,0.006671,-0.029744,0.02576,0.011779,-0.018666,-0.006933,0.03958,0.076663,0.043282,0.045108,0.075613,-0.008196,-0.019812,-0.043344,0.029961,-0.037226,0.025132,2.5e-05,-0.008822,0.045939,-0.030032,0.044008,0.071944,-0.056872,0.013245,-0.117237,0.015699,1.615427e-34,0.046269,0.042866,0.048755,-0.023851,-7.2e-05,0.013613,0.032727,-0.050776,-0.054672,0.019995,0.002507,0


MERGE SARCASM AND NON-SARCASM DATA TO CREATE A COMBINED TRAINING DATA

In [12]:
sarcasm_combo_data = pd.concat([sarc_data,non_sarc_data],axis=0)

In [13]:
sarcasm_combo_data.to_csv("/content/drive/MyDrive/685/sarcasm/evaluation/Vision_trans/non_sarc_img_text_binary_class_data.csv",index=False)

READ SARCASM INFO 

In [14]:
sarcasm_mmodal_data = pd.read_csv("/content/drive/MyDrive/685/sarcasm/evaluation/Vision_trans/non_sarc_img_text_binary_class_data.csv")

In [15]:
sarcasm_mmodal_data['label'].value_counts(True)*100

1    50.352467
0    49.647533
Name: label, dtype: float64

PREPARE DATA FOR TRAINING 

In [16]:
from sklearn.model_selection import train_test_split
X_trn, X_tst, y_trn, y_tst = train_test_split(sarcasm_mmodal_data.drop(columns=['label','img_hash'],axis=1), sarcasm_mmodal_data['label'], test_size=0.15)

In [17]:
X_train = X_trn.values
y_train = y_trn.values
X_test  = X_tst.values
y_test  = y_tst.values 

TRAIN XGBOOST MODEL ON MULTIMODAL FEATURES

In [18]:
model, errors, precision, recall, auc = xgb_train_kfold(X_train,y_train,n_estimator=500,max_depth=5,n_splits=10)

10it [13:44, 82.48s/it]


In [19]:
print("The training errors on average is: ", np.round(np.mean(errors),4)*100)
print("The Precision on average is: ", np.round(np.mean(precision),4)*100)
print("The Recall on average is: ", np.round(np.mean(recall),4)*100)
print("The AUC Score on average is: ", np.round(np.mean(auc),4)*100)

The training errors on average is:  32.940000000000005
The Precision on average is:  66.71000000000001
The Recall on average is:  69.66
The AUC Score on average is:  67.08


In [20]:
pred_test = model.predict(X_test)
pred_prob = model.predict_proba(X_test)
accuracy = (sum(pred_test == y_test))*100/len(y_test)
precision = precision_score(y_test,pred_test)
recall    = recall_score(y_test,pred_test)
auc       = roc_auc_score(y_test,pred_test)
accuracy,precision,recall, auc

(64.09395973154362, 0.6184971098265896, 0.722972972972973, 0.6414864864864865)

SAVE AND LOAD MODEL 

In [21]:
import pickle
file_name = "/content/drive/MyDrive/685/sarcasm/evaluation/Vision_trans/xgb_vis_cls_sarcasm_mmodal.pkl"

# save
pickle.dump(model, open(file_name, "wb"))

# load
xgb_model_loaded = pickle.load(open(file_name, "rb"))

In [22]:
# test
# ind = 1
# test = X_test
# xgb_model_loaded.predict(test)