# NEURAL NETWORK COLLABORATIVE FILTERING
This notebook demonstrates our approach for NNCF on a small sample dataset.  
Note that some of the functionality is implemented in other classes:
* `twitter_preproc.py` (preprocessing)
* `nnpreprocessor.py` (one hot encoding)
* `NNCFNet.py` (the neural network class)

Our neural network approach for the full dataset is in `nncf-submit.py`. To reproduce our submission attempts use: `spark-submit nncf-submit.py`. (NOTE: As described in our report, this script never actually ran through, due to memory overload exceptions. It is just here for reproduction purposes)

In [1]:
from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf
import importlib

# Building Spark Context
# conf = SparkConf().setAll([('spark.executor.memory', '32g'), ('spark.executor.instances','8'),('spark.executor.cores', '12'), ('spark.driver.memory','64g'), ('spark.driver.memoryOverhead', '64g')])
conf = SparkConf()
spark = SparkSession.builder.appName("nncf_train").config(conf=conf).getOrCreate()
sc = spark.sparkContext

## Get training data

In [2]:
import twitter_preproc

base = "///tmp/"
one_k = "traintweet_1000.tsv"
ensemble_train = 'supersecret_ensembletrain5k_bootstrap.tsv'
ensemble_test = 'supersecret_test5k_bootstrap.tsv'
choice = ensemble_train

preproc = twitter_preproc.twitter_preproc(spark, sc, base+choice, MF=True)
traindata = preproc.getDF()

## NNCF specific preprocessing (essentially onehot-encoding)

In [3]:
import nnpreprocessor
importlib.reload(nnpreprocessor)

nnp = nnpreprocessor.NNPreprocessor()
engagement = 'retweet_comment'
tweets, users, target = nnp.nn_preprocess(traindata)

## Train the NN

In [4]:
from NNCFNet import Net
import torch
import torch.nn as nn
import torch.optim as optim
import sys, traceback

# Initalize Hyperparameters
k = 32
n_epochs = 2
batch_size = 256

# Initialize Neural Network
net = Net(users.shape[1], tweets.shape[1], k)
optimizer = optim.SGD(net.parameters(), lr=0.001)
criterion = nn.BCELoss()
output = net(users, tweets)

# Start training
for epoch in range(n_epochs):

    permutation = torch.randperm(users.size()[0])

    for i in range(0,users.size()[0], batch_size):
        optimizer.zero_grad()

        indices = permutation[i:i+batch_size]
        batch_x_user = users[indices]
        batch_x_tweet = tweets[indices]
        batch_y = target[indices]

        outputs = net.forward(batch_x_user, batch_x_tweet)
        loss = criterion(outputs,batch_y)
        loss.backward()
        optimizer.step()

## Create & format output

In [5]:
from pyspark.sql.functions import monotonically_increasing_id
import numpy as np

# get predictions
net.eval()
prediction = net(users, tweets)
p_vec = prediction.detach().numpy().flatten()
scaled = (p_vec - np.min(p_vec))/np.ptp(p_vec)
probabilities = [float(x) for x in scaled]

# get original order
order_df = traindata.withColumn("original_order", monotonically_increasing_id())
order_df = order_df.select("engaging_user_id", "tweet_id", 'original_order')
sorting_tweets = nnp.get_id_indices(order_df, id_column='tweet_id')

# rejoin labels
result = order_df.join(sorting_tweets, 'tweet_id').rdd.map(lambda x: (x['engaging_user_id'], x['tweet_id'], probabilities[x['tweet_id_index']]))

In [6]:
result.toDF(["engaging_user_id", "tweet_id", 'target']).collect()

[Row(engaging_user_id='73941037EA0549A897D939F4363C0869', tweet_id='001D1FCF3DEC0B1A66A4EB72BF637ED3', target=0.374126672744751),
 Row(engaging_user_id='34E69C49BB7EF2CE04AB04BD2E2D9D50', tweet_id='0021028C8FCA722B602D9AFBCED72295', target=0.36835700273513794),
 Row(engaging_user_id='A05961628D0667F1B9D90C47BE232395', tweet_id='0023A079D3C166DCB13C5847054A2A52', target=0.37173765897750854),
 Row(engaging_user_id='B29DBE334E22603FD220500FD67669FD', tweet_id='002A8D35B5F158A65526CA24CC0D8BC3', target=0.22465629875659943),
 Row(engaging_user_id='60D743915E1E67DE3ADB0F601FEBEE49', tweet_id='00427E558D9F432D69E09E56D2B0CF04', target=0.2569754421710968),
 Row(engaging_user_id='0517AB3E9E10926CF69F927397EC9407', tweet_id='0045292E6A49E8C14FB5C9692E552087', target=0.3194500803947449),
 Row(engaging_user_id='A939E5AC38764F9CA53FDE40D5F8A6AD', tweet_id='0056700195233A617AE0E786B494BFA5', target=0.5339193344116211),
 Row(engaging_user_id='16EA418AB0457F4AF7CCDEA710B815C7', tweet_id='00670B713247F