In [None]:
import random

def getUIdx(row):
	"""Turn one rating record into a ``(uIdx, 1)`` pair.

	``row`` must unpack into exactly four fields
	``(uIdx, mIdx, rate, time)``; only the user id is kept.
	"""
	user, _movie, _rate, _time = row
	return (user, 1)

def getMIdx(row):
	"""Turn one rating record into a ``(mIdx, 1)`` pair.

	``row`` must unpack into exactly four fields
	``(uIdx, mIdx, rate, time)``; only the movie id is kept.
	"""
	_user, movie, _rate, _time = row
	return (movie, 1)

def count(x, y):
	"""Merge two partial values with ``+`` and return the result."""
	total = x + y
	return total

class splitData(object):
	"""Tag rating rows with a key that separates training from held-out data.

	At construction a random subset of the user ids and of the movie ids is
	drawn. A row is held out only when its user AND its movie are both in
	their subsets; every other row is a training row.

	Keys produced by ``split`` and ``compensate``:
		0 -- training row
		1 -- held-out row whose ``int(time)`` is even
		2 -- held-out row whose ``int(time)`` is odd

	Attributes:
		uSet, mSet: frozensets of all distinct user / movie ids.
		uSize, mSize: number of distinct user / movie ids.
		uTestIdx, mTestIdx: randomly drawn id subsets used for hold-out.
		cUIdx, cMIdx: ids missing from the training data; empty until
			``update`` is called.
	"""

	def __init__(self, uList, mList):
		"""Draw the random hold-out id subsets.

		Args:
			uList: iterable of hashable user ids (duplicates allowed).
			mList: iterable of hashable movie ids (duplicates allowed).
		"""
		# De-duplicate once, keeping first-seen order (dicts keep insertion
		# order on Python 3.7+). Sampling from this pool instead of the raw
		# input guarantees the subset really has the requested number of
		# distinct ids, and lets one-shot iterables be passed in. For an
		# already-unique list the sampling population is unchanged.
		uPool = list(dict.fromkeys(uList))
		mPool = list(dict.fromkeys(mList))

		self.uSet = frozenset(uPool)
		self.mSet = frozenset(mPool)
		self.uSize = len(self.uSet)
		self.mSize = len(self.mSet)

		# (1 / 2.2) ** 2 = 0.21: about 1/2.2 of the users and 1/2.2 of the
		# movies are selected, so roughly 21% of (user, movie) pairs have
		# both ids selected.
		self.uTestIdx = set(random.sample(uPool, int(self.uSize // 2.2)))
		self.mTestIdx = set(random.sample(mPool, int(self.mSize // 2.2)))

		# Nothing to compensate until update() is called; starting empty
		# makes compensate() a no-op instead of an AttributeError.
		self.cUIdx = frozenset()
		self.cMIdx = frozenset()

	def split(self, row):
		"""Return ``(key, row)`` for a ``(uIdx, mIdx, rate, time)`` record.

		The key is 0 unless both ids are in the hold-out subsets, in which
		case it is ``int(time) % 2 + 1`` (1 for even time, 2 for odd time).
		"""
		uIdx, mIdx, rate, time = row
		if uIdx in self.uTestIdx and mIdx in self.mTestIdx:
			return (int(time) % 2 + 1, row)
		return (0, row)

	#compensate trainData if trainData is not full in terms of uIdx or mIdx
	def update(self, tUList, tMList):
		"""Record which ids are absent from the training data.

		Args:
			tUList: iterable of user ids present in the training data.
			tMList: iterable of movie ids present in the training data.
		"""
		self.cUIdx = self.uSet - frozenset(tUList)
		self.cMIdx = self.mSet - frozenset(tMList)

	def compensate(self, line):
		"""Re-key a ``(key, row)`` pair so missing ids reach the training data.

		If the row's user or movie is one of the ids recorded by ``update``
		as missing from training, the key becomes 0; otherwise the pair is
		returned with its key unchanged.
		"""
		key, row = line
		uIdx, mIdx, rate, time = row
		if uIdx in self.cUIdx or mIdx in self.cMIdx:
			return (0, row)
		return (key, row)

def getRow(data):
	"""Strip the key from a ``(key, row)`` pair and return ``row``."""
	_key, payload = data
	return payload

In [None]:
# Load the raw ratings file; header=True makes Spark take the column names
# from the first line instead of treating it as a data row.
rateData = spark.read.csv(
	'/user/hz333/data/project/ratings.csv',
	header=True,
)

In [None]:
# --- distinct ids of the full data set --------------------------------------
#(uIdx, mIdx, rate, time) => (uIdx, 1) => (uIdx, count)
uIdx = rateData.rdd.map(getUIdx).reduceByKey(count)
#(uIdx, mIdx, rate, time) => (mIdx, 1) => (mIdx, count)
mIdx = rateData.rdd.map(getMIdx).reduceByKey(count)

#(uIdx, count) => [uIdx], (mIdx, count) => [mIdx]
uList = uIdx.keys().collect()
mList = mIdx.keys().collect()

# --- first split: key 0 = train, key 1 / 2 = held out -----------------------
sp = splitData(uList, mList)
#(uIdx, mIdx, rate, time) => (key, (uIdx, mIdx, rate, time))
data = rateData.rdd.map(sp.split)

#keep key 0 and drop the key => (uIdx, mIdx, rate, time)
trainData = data.filter(lambda kv: kv[0] == 0).map(getRow)
#keep keys 1 and 2, still as (key, (uIdx, mIdx, rate, time))
TVData = data.filter(lambda kv: kv[0] > 0)

# --- which ids did the first split leave out of trainData? ------------------
tUIdx = trainData.map(getUIdx).reduceByKey(count)
tMIdx = trainData.map(getMIdx).reduceByKey(count)

tUList = tUIdx.keys().collect()
tMList = tMIdx.keys().collect()

# --- compensation: held-out rows carrying a missing id go back to train -----
sp.update(tUList, tMList)
#(key, (uIdx, mIdx, rate, time)) => (newKey, (uIdx, mIdx, rate, time))
TVData = TVData.map(sp.compensate)

#union the compensated rows (newKey == 0) into trainData
cTrainData = TVData.filter(lambda kv: kv[0] == 0).map(getRow)
trainData = trainData.union(cTrainData)

# --- remaining held-out rows: key 2 => validData, key 1 => testData ---------
validData = TVData.filter(lambda kv: kv[0] == 2).map(getRow)
testData = TVData.filter(lambda kv: kv[0] == 1).map(getRow)

In [None]:
#recompute the distinct ids of the final trainData so the cells below can
#check that no user id or movie id is missing from it
tUIdx = trainData.map(getUIdx).reduceByKey(count)
tMIdx = trainData.map(getMIdx).reduceByKey(count)

tUList = tUIdx.keys().collect()
tMList = tMIdx.keys().collect()

In [None]:
#True when trainData has as many distinct user ids as the full ratings data
len(tUList) == len(uList)

In [None]:
#True when trainData has as many distinct movie ids as the full ratings data
len(tMList) == len(mList)

In [None]:
#number of rows in validData
validData.count()

In [None]:
#number of rows in testData
testData.count()

In [None]:
#number of rows in trainData
trainData.count()

In [None]:
#number of rows in the full ratings data, for comparison with the three counts above
rateData.count()

In [None]:
#turn each split into a DataFrame and write it without a header line;
#repartition(1) collapses the data into a single partition before writing
trainCSV = spark.createDataFrame(trainData, samplingRatio=1)
(trainCSV
	.repartition(1)
	.write
	.option('header', 'false')
	.csv('/user/hz333/data/project/train.csv'))

testCSV = spark.createDataFrame(testData, samplingRatio=1)
(testCSV
	.repartition(1)
	.write
	.option('header', 'false')
	.csv('/user/hz333/data/project/test.csv'))

validCSV = spark.createDataFrame(validData, samplingRatio=1)
(validCSV
	.repartition(1)
	.write
	.option('header', 'false')
	.csv('/user/hz333/data/project/valid.csv'))