In [1]:
from mpl_toolkits.mplot3d import Axes3D
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt  # plotting
import numpy as np  # linear algebra
import os  # accessing directory structure
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import time
import copy
import math
from typing import List
from unicodedata import bidirectional
import numpy as np

from nltk.tokenize import word_tokenize
from tqdm import tqdm
from gensim.models import KeyedVectors
from gensim.test.utils import datapath
import torch
from torch.utils.data import Dataset
from torch import nn
import torch.nn.functional as F
import torch.optim as optimizer
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence


In [None]:
# How many CSV rows to load; set to None to read the whole file.
nRowsRead = 2500
# train-balanced-sarcasm.csv has 1010826 rows in reality; only the first
# `nRowsRead` rows are loaded here to keep the notebook quick to run.
sarc_data = pd.read_csv('train-balanced-sarcasm.csv',
                        delimiter=',', nrows=nRowsRead)
sarc_data.dataframeName = 'train-balanced-sarcasm.csv'
nRow, nCol = sarc_data.shape
print(f'There are {nRow} rows and {nCol} columns')
# Last expression -> rich (HTML) preview of the first rows.
sarc_data.head(10)


In [None]:


class SARCDataset(Dataset):
    """Sentence-selection dataset built from SQuAD2.0-style records.

    Each item corresponds to one record of ``data``. It holds:
      - averaged word-vector representations of the record's questions,
      - averaged word-vector representations of its context sentences,
      - a 0/1 label matrix marking, for each question, the sentence that
        contains its answer (all zeros for unanswerable questions),
      - the answer sentence index per question (-1 when unanswerable).

    NOTE(review): the class name refers to the SARC sarcasm data loaded earlier
    in the notebook, but the records consumed here use SQuAD2.0-style keys
    ('qas', 'context'). Confirm which data this is meant for.
    """

    DEFAULT_WV_PATH = 'drive/MyDrive/Assignment 4/glove.kv'

    def __init__(self, data: List[dict], wv_path: str = DEFAULT_WV_PATH,
                 wv_data=None, tokenizer=None):
        """Pre-compute the tensors for every record in ``data``.

        Args:
            data: records with a ``'qas'`` list and a ``'context'`` sequence of
                sentence strings. Each ``'qas'`` entry has ``'question'`` (str),
                ``'is_impossible'`` (bool) and, when answerable,
                ``['answer']['sentence_id']`` (index into ``'context'``).
            wv_path: path of the saved gensim ``KeyedVectors``. Only used when
                ``wv_data`` is not given.
            wv_data: an already-loaded ``KeyedVectors`` (or any object that
                supports ``in``, ``[]`` and exposes a 2-D ``vectors`` array).
                Passing it lets several datasets share one embedding table
                instead of reloading it from disk.
            tokenizer: callable turning a string into a list of tokens.
                Defaults to ``nltk.word_tokenize``.
        """
        super().__init__()
        self.wv_data = wv_data if wv_data is not None else KeyedVectors.load(wv_path)
        self.data = data
        self._tokenize = tokenizer if tokenizer is not None else word_tokenize

        # Fallback vector for out-of-vocabulary tokens: the mean of the whole
        # embedding table. `.vectors` is available on gensim 3.x and 4.x, while
        # `index2word` (used previously) was removed in gensim 4.
        self._unk = np.asarray(self.wv_data.vectors).mean(axis=0)
        # Embedding width, taken from the table instead of hard-coding 300.
        self._dim = self._unk.shape[0]

        self.dataset = [self._build_item(d) for d in data]

    def _embed(self, text: str) -> np.ndarray:
        """Return the mean word vector of ``text``'s tokens as a float32 array.

        Unknown tokens contribute the unk vector. Text with no tokens yields
        the unk vector instead of dividing by zero.
        """
        tokens = list(self._tokenize(text))
        if not tokens:
            return self._unk.astype(np.float32, copy=True)
        # Must be an ndarray: `list += ndarray` would extend the list, not add.
        total = np.zeros(self._dim, dtype=np.float32)
        for w in tokens:
            total += self.wv_data[w] if w in self.wv_data else self._unk
        return total / len(tokens)

    def _build_item(self, d: dict) -> dict:
        """Convert one record into the tensor dict returned by ``__getitem__``."""
        qas = d['qas']

        # (num_questions, dim) and (num_sentences, dim); reshape keeps the
        # 2-D shape even when a record has no questions or no sentences.
        questions = np.array([self._embed(q['question']) for q in qas],
                             dtype=np.float32).reshape(-1, self._dim)
        context = np.array([self._embed(c) for c in d['context']],
                           dtype=np.float32).reshape(-1, self._dim)

        # Answer sentence index per question; -1 marks unanswerable ones.
        sentence_ids = np.array(
            [-1 if q['is_impossible'] else q['answer']['sentence_id'] for q in qas],
            dtype=np.int64)

        # labels[i, j] == 1 iff question i is answerable and sentence j holds
        # its answer. -1 never matches a sentence index, so impossible rows are 0.
        labels = (sentence_ids[:, None]
                  == np.arange(context.shape[0])[None, :]).astype(np.float32)

        return {
            'questions': torch.tensor(questions),
            'context': torch.tensor(context),
            'labels': torch.tensor(labels),
            'sentence_ids': torch.tensor(sentence_ids),
        }

    def __len__(self):
        """Number of records (one item per record of ``data``)."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the pre-computed tensor dict for record ``idx``."""
        return self.dataset[idx]


def basic_collate_fn(batch):
    """Merge a list of dataset items (tensor dicts) into one batch.

    Questions and contexts are returned as plain Python lists with one tensor
    per item (presumably because their first dimension varies between items).
    Each item's labels are flattened and all of them are concatenated into one
    1-D tensor, as are the sentence ids.

    Returns:
        Tuple ``(questions, context, labels, sent_ids)``.
    """
    questions, context, flat_labels, ids = [], [], [], []
    for item in batch:
        questions.append(item['questions'])
        context.append(item['context'])
        flat_labels.append(item['labels'].reshape(-1))
        ids.append(item['sentence_ids'])
    return questions, context, torch.cat(flat_labels), torch.cat(ids)
