# Collate Function and Dataloader Class

In this notebook, we develop the appropriate collate functions and data classes for this project. A **collate function** defines how the DataLoader combines the individual examples drawn from the dataset into a batch.

`Dataset` and `DataLoader` are PyTorch classes that provide utilities for iterating through and sampling from a dataset.

In [1]:
import torch
import pandas as pd
import numpy as np
import pre_processing as pp
from nltk import word_tokenize
from torchtext.data.utils import get_tokenizer
import dataobject as pdata

tokenizer = get_tokenizer('basic_english')



## 0. Load Data

In [6]:
# Load the tweets table from the local CSV into a DataFrame.
# NOTE(review): the displayed frame includes an 'Unnamed: 0' column, which
# suggests the CSV was saved together with its index; passing `index_col=0`
# would drop it — confirm nothing downstream relies on that column first.
data = pd.read_csv('data/tweets_all.csv')

In [7]:
# Display the loaded frame; pandas elides the middle rows in the rendering.
data

Unnamed: 0,id,topic,source,text,replyCount,vaderMean,vaderStd,vaderCatLabel,vaderCat
0,1377385383168765952,Politics,FoxNews,activists protest renaming chicago school afte...,306,-0.052830,0.445459,medium,1.0
1,1377384607969013765,Violence,FoxNews,border patrol video shows smugglers abandoning...,108,-0.045958,0.495337,medium,1.0
2,1377384339105669122,Media,FoxNews,cause of tiger woods car crash determined but ...,169,-0.034919,0.424833,medium,1.0
3,1377367836046192641,Politics,FoxNews,gop rep urges hhs to halt reported plan to rel...,80,0.043459,0.495874,medium,1.0
4,1377358399759785987,Politics,FoxNews,some democrats trying to stop iowa new hampshi...,96,-0.040135,0.433053,medium,1.0
...,...,...,...,...,...,...,...,...,...
20692,1377415994973376513,Protests,Reuters,u n special envoy tells security council to a...,14,-0.334379,0.346530,low,0.0
20693,1377414604851142662,Business,Reuters,wisconsin high court voids governors mask mand...,12,-0.057192,0.521413,high,2.0
20694,1377412951456411649,Politics,Reuters,analysis biden infrastructure plan bets big o...,38,0.047718,0.422482,medium,1.0
20695,1377411743295541252,Protests,Reuters,analysis deliveroos flop a wake up call for t...,5,-0.171920,0.404383,medium,1.0


## New Example Process
All functions have been placed in dataobject.py 
for ease of replication.

This should be all that is needed in model notebooks:

In [8]:
# Build the three splits from `data` with a single call into dataobject.py.
# 'vaderCat' and 'text' are presumably the target and text column names
# (same order as ProjectDataset's target_col/text_col) — confirm in dataobject.py.
# NOTE(review): split=0.2 appears to hold out 20% of the rows, shared equally
# between validation and test judging by the sizes printed below — confirm.
train, validate, test = pdata.get_datasets(data, 'vaderCat', 'text', 
                                           collate_func='cbow',
                                           batch_size=30,
                                           split=0.2, 
                                           random_seed=42)

downloading GloVe, please wait.
training size:  16560
validation size:  2070
testing size:  2070


## Old Process (functions all moved into dataobject.py)

## 1. Create Custom Dataset Object (PyTorch)

In [None]:
#Updated version imported from dataobject.py

from torch.utils.data import Dataset

class ProjectDataset(Dataset):
    """In-memory dataset of ``[target, tokens]`` pairs built from a DataFrame.

    Every row is converted once, at construction time: the text column is
    passed through ``pp.clean_text`` (with ``lowercase=False``) and then
    ``word_tokenize``; the target column value is stored as read from the row.

    Args:
        data: DataFrame holding one example per row.
        target_col: Name of the column containing the target value.
        text_col: Name of the column containing the text to tokenize.
    """

    def __init__(self, data, target_col, text_col):
        def to_sample(row):
            # Tokenize first, then read the target, and store them as [target, tokens].
            tokens = word_tokenize(pp.clean_text(row[text_col], lowercase=False))
            return [row[target_col], tokens]

        self.samples = [to_sample(row) for _, row in data.iterrows()]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; a slice returns a list of samples."""
        return self.samples[idx]

In [None]:
# NOTE(review): `data_object` is not created in any cell shown in this notebook;
# presumably it is a ProjectDataset built from `data` — define it before running.
len(data_object)

In [None]:
# Inspect the first two entries of the dataset via slicing.
data_object[0:2]

## 2. Collate Function

### 2.1 Bag of Words

In [None]:
#Bag of Words
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

tokenizer = get_tokenizer('basic_english')

def get_vocab(training_data, min_freq=1000):
    """Build a vocabulary from the tokens of the training examples.

    Args:
        training_data: Iterable of ``(label, tokens)`` pairs. Only the tokens
            are counted; the label is ignored.
        min_freq: Minimum token count forwarded to ``Vocab``. Defaults to
            1000, the value this function has always used.

    Returns:
        The ``Vocab`` object built from the token counts.
    """
    counter = Counter()
    # Count tokens from the argument itself rather than from a notebook-level
    # global, so the function works on whatever data the caller passes in.
    for _label, tokens in training_data:
        counter.update(tokens)
    return Vocab(counter, min_freq=min_freq)


def collate_into_bow(batch):
    """Collate ``(label, tokens)`` samples into normalised bag-of-words rows.

    Relies on a notebook-level ``vocab`` (e.g. the result of ``get_vocab``)
    for both the vector width (``len(vocab)``) and the token-to-column
    lookup (``vocab[token]``).

    Args:
        batch: Sequence of ``(label, tokens)`` pairs.

    Returns:
        A ``(labels, bag_vector)`` tuple: a tensor holding ``label - 1`` for
        each sample, and a ``(len(batch), len(vocab))`` float tensor whose
        rows are token counts divided by the row total. A sample without any
        tokens yields an all-zero row instead of NaNs.
    """
    labels = []
    bag_vector = torch.zeros((len(batch), len(vocab)))
    for row, (label, tokens) in enumerate(batch):
        # NOTE(review): labels are shifted down by one, yet the `vaderCat`
        # values displayed earlier in the notebook already start at 0 —
        # confirm this offset is intended for the column being used.
        labels.append(label - 1)
        for token in tokens:
            bag_vector[row, vocab[token]] += 1

    # Non-empty rows have an integer total >= 1, so the clamp only affects
    # empty rows, where it prevents a 0/0 division.
    totals = bag_vector.sum(dim=1, keepdim=True).clamp(min=1)
    return torch.tensor(labels), bag_vector / totals

### 2.2 Continuous Bag of Words using GloVe

In [None]:
from torchtext.vocab import GloVe
# Load the '6B' GloVe vectors; used as a notebook-level global by collate_into_cbow.
glove = GloVe(name='6B') # Slow: this can trigger a long download

def collate_into_cbow(batch):
    """Collate ``(label, tokens)`` samples into mean-of-embeddings rows.

    Each sample's tokens are embedded with the notebook-level ``glove``
    vectors and averaged into one vector (continuous bag of words).

    Args:
        batch: Sequence of ``(label, tokens)`` pairs.

    Returns:
        A ``(labels, cbag_vector)`` tuple: a tensor holding ``label - 1`` for
        each sample, and a tensor with one averaged embedding per row. A
        sample without any tokens is represented by a zero vector. An empty
        batch returns two empty tensors.
    """
    if len(batch) == 0:
        # torch.stack rejects an empty list; keep returning empty tensors.
        return torch.tensor([]), torch.tensor([])

    labels = []
    doc_vectors = []
    for label, tokens in batch:
        labels.append(label - 1)
        if len(tokens) > 0:
            doc_vectors.append(glove.get_vecs_by_tokens(tokens).mean(dim=0))
        else:
            # There is nothing to average; fall back to a zero vector of the
            # embedding width so the batch keeps a rectangular shape.
            doc_vectors.append(torch.zeros(glove.dim))

    # Stack once at the end instead of concatenating inside the loop, which
    # re-copies the growing tensor on every sample.
    return torch.tensor(labels), torch.stack(doc_vectors)

## 3. DataLoader

### 3.1 Load Articles Data

In [None]:
from torch.utils.data import DataLoader
BATCH_SIZE = 30

# One loader per split: same dataset, batch size and collate function,
# differing only in the object passed as `sampler`.
train_dataloader, valid_dataloader, test_dataloader = (
    DataLoader(data_object, batch_size=BATCH_SIZE,
               sampler=split_sampler, collate_fn=collate_into_bow)
    for split_sampler in (data_object.train, data_object.valid, data_object.test)
)

In [None]:
# Report the number of examples in each split. len(loader) counts batches, so
# multiplying it by BATCH_SIZE over-reports whenever the last batch is partial;
# the length of the loader's sampler is the number of examples it will draw.
print("training size: ", len(train_dataloader.sampler))
print("validation size: ", len(valid_dataloader.sampler))
print("testing size: ", len(test_dataloader.sampler))

In [None]:
# Print only the first item yielded by the training split, then stop.
# (A for-loop already obtains an iterator from its iterable.)
for i in data_object.train:
    print(i)
    break

In [None]:
# Fetch every training entry through the dataset's indexing and count them.
len(list(map(data_object.__getitem__, data_object.train)))