# Baseline Assumptions
    - All preprocessing has been completed and the results are stored in a .csv file
    - There is no "bad" data: every label belongs to the known label set
    - The vectors in the training set have the same form as those in the test set
        - There are no errors caused by unseen words

# Imports

In [None]:
# generic
import os

# Data management
import csv
import pandas as pd
from torch.utils.data import Dataset

# Deep learning 
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

# Math and plots
import numpy as np
import random
import matplotlib.pyplot as plt

## GPU

In [None]:
# Use the GPU when one is available; a hard-coded torch.device("cuda") makes
# every later .to(device) fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## Flags

In [None]:
# Build a fresh WhoReddit when True; the False branch is meant to load a saved network.
NEW_MODEL: bool = True
# NOTE(review): not read by any visible cell.
TRAIN: bool = True
# Intended to limit how many k-fold splits are trained before stopping.
# NOTE(review): confirm the training loop's exit condition reads this name.
SPLIT_USE: int = 1

## Data Loader
    - Assumes the data is preprocessed so that no transformation is needed on load
    - Does not convert elements to tensors
    - Returns a descriptor vector and a one-hot encoded label vector

In [None]:
# Locations of the preprocessed CSV files, relative to the working directory.
train_data: str = './Data/name_of_train.csv'  # the k-fold splits are drawn from this file
test_data: str = './Data/name_of_test.csv'    # NOTE(review): not read by any visible cell

In [None]:
# Dataset wrapper for train and test data.
class CommentData(Dataset):
    """Map-style dataset over ``(element, label)`` pairs.

    Elements are returned unchanged (they are not converted to tensors).
    Labels are returned one-hot encoded over the sorted set of distinct
    labels present in ``frames``.

    Args:
        frames: indexable sequence (e.g. a list or a 2-column numpy array)
            whose items unpack into ``(element, label)``.

    Attributes:
        frames: the data passed in.
        labels: sorted numpy array of the distinct labels in ``frames``.
        encoding: identity matrix; row ``k`` is the one-hot vector of
            ``labels[k]``.
    """

    def __init__(self, frames):
        self.frames = frames
        # Vocabulary of labels only. np.unique(frames) would flatten the
        # elements into the vocabulary as well.
        self.labels = np.unique([label for _, label in frames])
        self.encoding = np.eye(len(self.labels))
        # label -> row of self.encoding. numpy arrays have no .index() method,
        # and a dict also gives O(1) lookups per fetched item.
        self._label_index = {
            label: position for position, label in enumerate(self.labels.tolist())
        }

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        element, label = self.frames[idx]
        enc_label = self.encode(label)
        return element, enc_label

    # one-hot encoding on element fetch
    def encode(self, label):
        """Return the one-hot vector (a row of ``self.encoding``) for ``label``.

        Raises:
            ValueError: if ``label`` is not one of ``self.labels``.
        """
        try:
            location = self._label_index[label]
        except KeyError:
            raise ValueError(f"{label!r} is not a known label") from None
        return self.encoding[location]

## K-Fold 

In [None]:
# K-fold splitter. Uses pandas for a fast CSV load, then stores and slices
# the rows as a numpy array (pd.read_csv yields shape (rows, columns)).
class kFold:
    """Shuffle a CSV dataset and produce ``numFolds`` train/validation/test splits.

    For each fold, that fold is halved into a validation part and a test part,
    and all remaining folds are concatenated row-wise into the training part.
    With the default 5 folds this gives an 80/10/10 split.

    Args:
        filepath: path of the CSV file to load. ``pd.read_csv`` treats the
            first row as the header.
        numFolds: number of folds; must be >= 1.

    Attributes:
        data: the loaded rows as a numpy array. ``generateSplits`` shuffles it
            in place.
        numFolds: number of folds.
        splits: list of ``(train, validation, test)`` numpy arrays, filled by
            ``generateSplits``.

    Raises:
        ValueError: if ``numFolds`` < 1.
    """

    def __init__(self, filepath, numFolds=5):
        if numFolds < 1:
            raise ValueError(f"numFolds must be >= 1, got {numFolds}")
        self.data = np.asarray(pd.read_csv(filepath))
        self.numFolds = numFolds
        self.splits = []

    def generateSplits(self):
        """Shuffle ``self.data`` row-wise and rebuild ``self.splits``.

        Returns:
            The list stored in ``self.splits`` (one triple per fold).
        """
        np.random.shuffle(self.data)  # shuffles rows (axis 0) in place

        n_rows = self.data.shape[0]
        fold_size = n_rows // self.numFolds
        # Equal-sized folds; the last fold also absorbs the remainder rows.
        bounds = [k * fold_size for k in range(self.numFolds)] + [n_rows]
        # Copies, so a later reshuffle of self.data cannot alter earlier splits.
        folds = [self.data[bounds[k]:bounds[k + 1]].copy()
                 for k in range(self.numFolds)]

        # Regenerate rather than accumulate when called more than once.
        self.splits = []
        for i, held_out in enumerate(folds):
            # Halve the held-out fold itself (not the number of folds).
            half = len(held_out) // 2
            validation = held_out[:half]
            test = held_out[half:]

            train_parts = [fold for k, fold in enumerate(folds) if k != i]
            # Stack rows (axis 0); np.hstack would join columns instead.
            # With a single fold there is no training data, so return an
            # empty array with the right number of columns.
            train = (np.concatenate(train_parts, axis=0)
                     if train_parts else self.data[:0])
            self.splits.append((train, validation, test))
        return self.splits

## Model

In [None]:
class WhoReddit(nn.Module):
    """Three-stage 1-D "mini inception" CNN classifier.

    Each stage runs three parallel Conv1d branches with kernel sizes 3, 5 and
    7 and length-preserving padding. Every branch is followed by BatchNorm,
    ReLU, MaxPool1d(3, 3) and dropout. The branch outputs are concatenated on
    the channel axis and passed through a blending conv block. The final
    ``merge`` block average-pools the sequence axis to length 1 and maps the
    256 channels to class logits.

    Input shape is ``(batch, 1, length)``. Six MaxPool1d(3, 3) layers are
    applied in sequence, each dividing the length by 3 (floor), so ``length``
    must be at least 3**6 = 729.

    Args:
        num_classes: number of output logits (default 20).
    """

    def __init__(self, num_classes=20):
        super(WhoReddit, self).__init__()

        # mini inception net block 1
        self.convA1 = nn.Conv1d(1, 64, 3, padding=1)
        self.normA1 = nn.BatchNorm1d(64)
        self.reluA1 = nn.ReLU(True)
        self.poolA1 = nn.MaxPool1d(3, 3)

        self.convB1 = nn.Conv1d(1, 64, 5, padding=2)
        self.normB1 = nn.BatchNorm1d(64)
        self.reluB1 = nn.ReLU(True)
        self.poolB1 = nn.MaxPool1d(3, 3)

        self.convC1 = nn.Conv1d(1, 64, 7, padding=3)
        self.normC1 = nn.BatchNorm1d(64)
        self.reluC1 = nn.ReLU(True)
        self.poolC1 = nn.MaxPool1d(3, 3)

        self.blend1 = nn.Sequential(
            nn.Conv1d(3 * 64, 96, 3, padding=1),
            nn.BatchNorm1d(96),
            nn.ReLU(True),
            nn.MaxPool1d(3, 3),
            nn.Dropout(0.5),
        )

        # mini inception net block 2
        self.convA2 = nn.Conv1d(96, 128, 3, padding=1)
        self.normA2 = nn.BatchNorm1d(128)
        self.reluA2 = nn.ReLU(True)
        self.poolA2 = nn.MaxPool1d(3, 3)

        self.convB2 = nn.Conv1d(96, 128, 5, padding=2)
        self.normB2 = nn.BatchNorm1d(128)
        self.reluB2 = nn.ReLU(True)
        self.poolB2 = nn.MaxPool1d(3, 3)

        self.convC2 = nn.Conv1d(96, 128, 7, padding=3)
        self.normC2 = nn.BatchNorm1d(128)
        self.reluC2 = nn.ReLU(True)
        self.poolC2 = nn.MaxPool1d(3, 3)

        self.blend2 = nn.Sequential(
            nn.Conv1d(3 * 128, 196, 3, padding=1),
            nn.BatchNorm1d(196),
            nn.ReLU(True),
            nn.MaxPool1d(3, 3),
        )

        # mini inception net block 3. BatchNorm widths must match the 256
        # output channels of the convs (they were 128).
        self.convA3 = nn.Conv1d(196, 256, 3, padding=1)
        self.normA3 = nn.BatchNorm1d(256)
        self.reluA3 = nn.ReLU(True)
        self.poolA3 = nn.MaxPool1d(3, 3)

        self.convB3 = nn.Conv1d(196, 256, 5, padding=2)
        self.normB3 = nn.BatchNorm1d(256)
        self.reluB3 = nn.ReLU(True)
        self.poolB3 = nn.MaxPool1d(3, 3)

        self.convC3 = nn.Conv1d(196, 256, 7, padding=3)
        self.normC3 = nn.BatchNorm1d(256)
        self.reluC3 = nn.ReLU(True)
        self.poolC3 = nn.MaxPool1d(3, 3)

        # Head: blend the 3 * 256 concatenated branch channels, average-pool
        # the sequence axis to length 1, flatten to (batch, 256) and map to
        # logits of shape (batch, num_classes).
        self.merge = nn.Sequential(
            nn.Conv1d(3 * 256, 256, 3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.MaxPool1d(3, 3),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )

        # Dropout shared by every branch output.
        self.dropout = nn.Dropout(0.2)

    def _branch(self, x, conv, norm, relu, pool):
        """Apply one inception branch: conv -> batch-norm -> ReLU -> max-pool -> dropout."""
        return self.dropout(pool(relu(norm(conv(x)))))

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: float tensor of shape ``(batch, 1, length)`` with length >= 729.
        """
        x = torch.cat((
            self._branch(x, self.convA1, self.normA1, self.reluA1, self.poolA1),
            self._branch(x, self.convB1, self.normB1, self.reluB1, self.poolB1),
            self._branch(x, self.convC1, self.normC1, self.reluC1, self.poolC1),
        ), dim=1)
        x = self.blend1(x)

        x = torch.cat((
            self._branch(x, self.convA2, self.normA2, self.reluA2, self.poolA2),
            self._branch(x, self.convB2, self.normB2, self.reluB2, self.poolB2),
            self._branch(x, self.convC2, self.normC2, self.reluC2, self.poolC2),
        ), dim=1)
        x = self.blend2(x)

        x = torch.cat((
            self._branch(x, self.convA3, self.normA3, self.reluA3, self.poolA3),
            self._branch(x, self.convB3, self.normB3, self.reluB3, self.poolB3),
            self._branch(x, self.convC3, self.normC3, self.reluC3, self.poolC3),
        ), dim=1)
        return self.merge(x)


In [None]:
# Build (or eventually load) the network and place it on the training device.
if NEW_MODEL:
    net = WhoReddit()
else:
    # TODO: load a saved network. Raise explicitly until this is implemented;
    # an `else:` containing only a comment is a syntax error.
    raise NotImplementedError("loading a saved network is not implemented yet")

# The training loop sends inputs with .to(device), so the parameters must
# live on the same device.
net = net.to(device)
print(net)

### Loss and Optimizer

In [None]:
# Classification criterion and optimizer. Module.type() expects a dtype or a
# tensor-type string, not a torch.device, so move the criterion with .to().
loss = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-5)

## Training

In [None]:
# --- k-fold split initialisation ---------------------------------------------
# Shuffle the training CSV and cut it into (train, validation, test) triples.
commentFolds = kFold(train_data)
commentFolds.generateSplits()

# Triples consumed one by one by the training loop below.
splits = commentFolds.splits

In [None]:
epochs: int = 50  # training passes over each split's training data

In [None]:
# Train on each k-fold split, reporting training/validation accuracy every
# epoch and test accuracy once per split, after training.


def _to_batch(comment, label):
    """Move one collated batch to ``device`` in the form the model and loss expect.

    Returns:
        ``(inputs, targets)``: float32 inputs with a channel axis and integer
        class indices recovered from the one-hot labels.
    """
    inputs = torch.as_tensor(comment, dtype=torch.float32).to(device)
    if inputs.dim() == 2:
        # assumes each element is a flat feature vector, so the batch is
        # (batch, length); Conv1d needs (batch, 1, length) — TODO confirm
        # against the preprocessed data
        inputs = inputs.unsqueeze(1)
    # CommentData yields one-hot rows; CrossEntropyLoss and the accuracy
    # comparison both need class indices.
    targets = torch.as_tensor(label).argmax(dim=1).to(device)
    return inputs, targets


def _accuracy(loader):
    """Return the fraction of samples in ``loader`` that ``net`` classifies correctly."""
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # evaluation only; no autograd graph needed
        for comment, label in loader:
            inputs, targets = _to_batch(comment, label)
            preds_cls = net(inputs).argmax(dim=1)
            correct += (preds_cls == targets).sum().item()
            total += targets.numel()
    return correct / total if total else float("nan")


# Starting weights. Every split restarts from them; otherwise a split would
# continue from the previous split's model, whose training rows include this
# split's validation/test rows.
initial_state = {k: v.detach().clone() for k, v in net.state_dict().items()}

for idx, (train, validation, test) in enumerate(splits):  # split
    net.load_state_dict(initial_state)
    optimizer.state.clear()  # drop Adam moment estimates from the previous split

    # One CommentData over the whole split, so train/validation/test share a
    # single label vocabulary (CommentData derives it from the frames it is
    # given). Subsets then select each part by row index.
    n_train, n_val = len(train), len(validation)
    split_data = CommentData(np.concatenate((train, validation, test), axis=0))
    train_set = torch.utils.data.Subset(split_data, range(n_train))
    val_set = torch.utils.data.Subset(split_data, range(n_train, n_train + n_val))
    test_set = torch.utils.data.Subset(split_data, range(n_train + n_val, len(split_data)))

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=2, num_workers=8, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, num_workers=8)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, num_workers=8)

    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for i, (comment, label) in enumerate(train_loader):
            inputs, targets = _to_batch(comment, label)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            output = net(inputs)
            error = loss(output, targets)
            error.backward()
            optimizer.step()

            # print statistics (`loss` is the criterion; the batch value is `error`)
            running_loss += error.item()
            if i % 50 == 49:    # print every 50 mini-batches
                print('[%d, %5d] loss: %.5f' %
                      (epoch + 1, i + 1, running_loss / 50))
                running_loss = 0.0

            # Count correct predictions (argmax of logits == argmax of softmax)
            correct += (output.argmax(dim=1) == targets).sum().item()
            total += targets.numel()

        train_acc = correct / total if total else float("nan")
        print("Epoch:", epoch+1,"Training Acc:",train_acc)

        valid_acc = _accuracy(val_loader)
        print("Epoch:", epoch+1,"Validation Acc:",valid_acc)

    # Score the held-out test part once, after training on this split.
    test_acc = _accuracy(test_loader)
    print("Split:", idx+1, "Test Acc:", test_acc)

    print('Finished Training')

    # terminate cycle after SPLIT_USE splits
    if idx + 1 >= SPLIT_USE:
        break