# Collate Function and Dataloader Class

In this notebook, we develop the appropriate collate function and data classes for this project. A **collate function**  is how the dataloader will processes each example from the dataset.

Dataset and DataLoader are PyTorch classes that provides utilities for iterating through and sampling from a dataset.

In [6]:
import torch
import pandas as pd
import numpy as np
import pre_processing as pp
from nltk import word_tokenize
from torchtext.data.utils import get_tokenizer

tokenizer = get_tokenizer('basic_english')

## 0. Download Data

In [14]:
articles = pd.read_csv('data/articles.csv')
scores = pd.read_csv('data/articles_w_scores.csv')
data = scores[['id','vaderMean','vaderStd', 'vaderCat']].merge(articles, left_on='id', right_on='articleID')

In [15]:
data

Unnamed: 0,id,vaderMean,vaderStd,vaderCat,articleID,snippet,topic
0,5adf6684068401528a2aa69b,0.345973,0.511952,2.0,5adf6684068401528a2aa69b,"“I understand that they could meet with us, pa...",Culture/Education
1,5adf653f068401528a2aa697,0.044260,0.566761,2.0,5adf653f068401528a2aa697,The agency plans to publish a new regulation T...,Elections
2,5adf4626068401528a2aa628,0.488810,0.538851,2.0,5adf4626068401528a2aa628,What’s it like to eat at the second incarnatio...,International
3,5adf40d2068401528a2aa619,0.042040,0.625098,2.0,5adf40d2068401528a2aa619,President Trump welcomed President Emmanuel Ma...,International
4,5adf3d64068401528a2aa60f,-0.316441,0.623894,2.0,5adf3d64068401528a2aa60f,"Alek Minassian, 25, a resident of Toronto’s Ri...",Gun Crimes
...,...,...,...,...,...,...,...
3355,5abfcca647de81a90121a899,0.317276,0.578538,2.0,5abfcca647de81a90121a899,Can post-Christian spirituality make a bridge ...,Politics
3356,5abfcca747de81a90121a89a,0.188099,0.612754,2.0,5abfcca747de81a90121a89a,"A tale of two Kushner brothers, on either side...",Leisure
3357,5abfd02c47de81a90121a8af,-0.013215,0.631028,2.0,5abfd02c47de81a90121a8af,The president sticks his head in the sand as h...,Social Issues
3358,5abfd3ae47de81a90121a8bf,-0.057539,0.714906,2.0,5abfd3ae47de81a90121a8bf,"In 1968, Spiro Agnew was the man of the white,...",Social Media


## 1. Create Custom Dataset Object (PyTorch)

In [8]:
from torch.utils.data import Dataset

In [27]:
class ProjectDataset(Dataset):
    def __init__(self, data, target_col, text_col):
        data_lists = []
        
        for index, row in data.iterrows():
            text = pp.clean_text(row[text_col], lowercase=False)
            text = word_tokenize(text)
            target = row[target_col]
            data_lists.append([target, text])
        
        self.samples = data_lists
        
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [28]:
data_object = ProjectDataset(data, 'vaderStd', 'snippet')

In [29]:
len(data_object)

3360

In [30]:
data_object[0:2]

[[0.5119524453713237,
  ['I',
   'understand',
   'that',
   'they',
   'could',
   'meet',
   'with',
   'us',
   'patronize',
   'us',
   'and',
   'do',
   'nothing',
   'in',
   'the',
   'end',
   'their',
   'lawyer',
   'says']],
 [0.5667606324178027,
  ['The',
   'agency',
   'plans',
   'to',
   'publish',
   'a',
   'new',
   'regulation',
   'Tuesday',
   'that',
   'would',
   'restrict',
   'the',
   'kinds',
   'of',
   'scientific',
   'studies',
   'the',
   'agency',
   'can',
   'use',
   'when',
   'it',
   'develops',
   'policies']]]

## 2. Collate Function

### 2.1 Bag of Words

In [65]:
#Bag of Words
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

tokenizer = get_tokenizer('basic_english')

def get_vocab(training_data):
    counter = Counter()
    for (label, line) in train_iter:
        counter.update(line)
    vocab = Vocab(counter, min_freq=1000)
    return vocab


def collate_into_bow(batch):  
    labels = []
    bag_vector = torch.zeros((len(batch),len(vocab)))
    for i, (label, line) in enumerate(batch):
        labels.append(label-1)
        for w in line:            
            bag_vector[i, vocab[w]] += 1
    
    bag_vector = (bag_vector/bag_vector.sum(axis=1, keepdim=True))
    return torch.tensor(labels), bag_vector

### 2.2 Continuous Bag of Words using GloVe

In [66]:
from torchtext.vocab import GloVe
glove = GloVe(name='6B') #Takes long to download

def collate_into_cbow(batch):
    cbag_vector = torch.tensor([])
    labels = []
    for i, (label, line) in enumerate(batch):
        labels.append(label-1)
        vecs = glove.get_vecs_by_tokens(line)
        vecs = vecs.sum(axis=0)/vecs.shape[0]
        cbag_vector = torch.cat([cbag_vector, vecs.view(1, -1)])
    
    return torch.tensor(labels), cbag_vector

.vector_cache/glove.6B.zip: 862MB [07:51, 1.83MB/s]                               
100%|█████████▉| 399999/400000 [00:32<00:00, 12128.48it/s]


## 3. DataLoader

### 3.1 Load Articles Data

In [68]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(data_object, batch_size=30,
                              shuffle=True, 
                              collate_fn=collate_into_bow)

In [71]:
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7fdad80b1cd0>