In [None]:
import pyspark
# Single driver-side SparkContext for the whole notebook; 'local[*]' is the
# Spark master URL for running locally with one worker thread per core.
sc = pyspark.SparkContext(master='local[*]')

In [None]:
%load_ext autoreload
%autoreload 2

from os.path import join

import auxiliary as auxi
from cleanData import CleanData
from splitData import SplitData
from itemBasedSim import ItemBasedSim
from extender import ExtendSim

from pyspark.sql import SQLContext

In [None]:
# Root directory that holds every cached artifact read by this notebook.
path_root = "/home/jovyan/work/data"

# Locations of the pickled RDDs, all resolved relative to path_root:
# the train split, the test split, and the baseline item-item similarities.
path_pickle_train, path_pickle_test, path_pickle_baseline_sim = (
    join(path_root, relative_path)
    for relative_path in (
        "cache/two_domain/split_data/train",
        "cache/two_domain/split_data/test",
        "cache/two_domain/item_based_sim/base_sim",
    )
)

In [None]:
# Read back the pickled RDDs: first the test split, then the baseline
# item-to-item similarities.
testRDD, item2item_simRDD = (
    sc.pickleFile(pickle_path)
    for pickle_path in (path_pickle_test, path_pickle_baseline_sim)
)

In [None]:
sqlContext = SQLContext(sc)
itemsim = ItemBasedSim(method='cosine', num_atleast=50)

item2item_simDF = itemsim.build_sim_DF(item2item_simRDD)

# Register the similarity DataFrame so it can be queried with SQL below.
item2item_simDF.registerTempTable("sim_table")

# Collect the distinct id1 values of rows whose label is 1 (presumably the
# "BB" items -- confirm against ItemBasedSim.build_sim_DF).
# The query result is a DataFrame; DataFrame.map no longer exists from
# Spark 2.0 on, so go through the underlying RDD explicitly. `.rdd.map`
# behaves the same as the old `DataFrame.map` shortcut on Spark 1.x.
BB_item_list = sqlContext.sql(
    "SELECT DISTINCT id1 FROM sim_table WHERE label = 1").rdd.map(
    lambda row: row.id1).collect()
# Broadcast the id list so executors can consult it without re-shipping it.
BB_item_bd = sc.broadcast(BB_item_list)

item_simRDD = itemsim.get_item_sim(item2item_simRDD)

In [None]:
# Value handed to ExtendSim (presumably the number of nearest neighbours
# kept per item -- confirm against the ExtendSim implementation).
top_k = 10
extend_sim = ExtendSim(top_k)

# Cached because later steps derive more than one RDD from this result.
classfied_items = (
    extend_sim
    .find_knn_items(item_simRDD, BB_item_bd)
    .cache()
)

In [None]:
def extract_siminfo(sc, classfied_items):
    """split the classified knn records and broadcast knn item information.

    arg:
        sc: context object whose ``broadcast`` method is used to ship the
            neighbour lookups.
        classfied_items: RDD of (iid, (BB_BB, BB_NB), (NB_BB, NB_NN));
            either of the two pairs may be None, in which case the record
            is dropped from the corresponding output.
    return:
        BB_info: cached RDD of (iid, (BB_BB, BB_NB)), non-None pairs only
        NB_info: cached RDD of (iid, (NB_BB, NB_NN)), non-None pairs only
        knn_BB_bd: broadcast of {iid: {neighbour iid: rest of the neighbour
            tuple}} built from BB_info
        knn_NB_bd: broadcast of the same structure built from NB_info

        Each neighbour entry is a sequence whose first element is the
        neighbour iid; everything after it (presumably sim, mutu, frac_mutu
        -- confirm against the producer) becomes the dict value.
    """
    def select_part(position):
        # Pair every iid with one of its two neighbour groups and keep only
        # the records where that group is present.
        return classfied_items.map(
            lambda line: (line[0], line[position])).filter(
            lambda line: line[1] is not None).cache()

    def collect_knn(info):
        # Merge both neighbour lists of each item into a single
        # {neighbour iid: remaining fields} dict and pull it to the driver.
        return info.map(
            lambda line: (line[0], dict(
                (l[0], l[1:]) for l in line[1][0] + line[1][1]))
        ).collectAsMap()

    BB_info = select_part(1)
    NB_info = select_part(2)

    BB_items_knn = collect_knn(BB_info)
    NB_items_knn = collect_knn(NB_info)

    knn_BB_bd = sc.broadcast(BB_items_knn)
    knn_NB_bd = sc.broadcast(NB_items_knn)
    return BB_info, NB_info, knn_BB_bd, knn_NB_bd

# Unpack the two cached RDDs and the two broadcast neighbour lookups.
(BB_info, NB_info,
 knn_BB_bd, knn_NB_bd) = extract_siminfo(sc, classfied_items)

In [None]:
def combine_BB_withother_in_singledomain(iter_items):
    """link every BB neighbour of an item to that item and its NN neighbours.

    arg:
        iter_items: iterator of (iid, (NB_BB, NB_NN)) where
            NB_BB: [(BB iid, sim, mutu, frac_mutu)*]
            NB_NN: [(NN iid, sim, mutu, frac_mutu)*]
    yield:
        BB iid, [(iid, [NN iid*])]
        one pair per entry of NB_BB; nothing is produced for an item whose
        NB_BB is empty.
    """
    for iid, (bb_neighbours, nn_neighbours) in iter_items:
        for bb_entry in bb_neighbours:
            # A fresh id list is built for each emitted pair, so no list
            # object is shared between two results.
            nn_ids = [nn_entry[0] for nn_entry in nn_neighbours]
            yield bb_entry[0], [(iid, nn_ids)]

# Regroup NB_info by BB iid: each partition emits (BB iid, [(iid, [NN iid*])])
# pairs and the per-key lists are concatenated.
BB_other_intra = (
    NB_info
    .mapPartitions(combine_BB_withother_in_singledomain)
    .reduceByKey(lambda left, right: left + right)
    .cache()
)

In [None]:
        def extend_BB_source(sourceRDD):
            """re-key source-domain BB records by (BB_target, BB_source).

            For every (iid, connections) record in sourceRDD, iid is looked
            up in the broadcast knn_BB_bd map and one
            ((neighbour, iid), connections) pair is emitted for each
            neighbour id that contains "T:".
            """
            return sourceRDD.mapPartitions(
                lambda partition: (
                    ((neighbour, iid), connections)
                    for iid, connections in partition
                    for neighbour in knn_BB_bd.value[iid]
                    if "T:" in neighbour
                )
            )

        def extend_BB_target(rdd):
            """re-key target-domain BB records by (BB_target, BB_source).

            For every (iid, connections) record in rdd, iid is looked up in
            the broadcast knn_BB_bd map and one
            ((iid, neighbour), connections) pair is emitted for each
            neighbour id that contains "S:".
            """
            return rdd.mapPartitions(
                lambda partition: (
                    ((iid, neighbour), connections)
                    for iid, connections in partition
                    for neighbour in knn_BB_bd.value[iid]
                    if "S:" in neighbour
                )
            )
        
# Split the intra-domain records by the marker found in the BB item id:
# "S:" ids go to the source side, "T:" ids to the target side.
BB_other_intra_source = BB_other_intra.filter(lambda record: "S:" in record[0])
BB_other_intra_target = BB_other_intra.filter(lambda record: "T:" in record[0])

# Both sides are re-keyed by (BB_target, BB_source) so they can be joined
# on that shared pair.
extended_BB_source = extend_BB_source(BB_other_intra_source)
extended_BB_target = extend_BB_target(BB_other_intra_target)
joined_extended_BB = (
    extended_BB_source
    .join(extended_BB_target)
    .cache()
)

In [None]:
# Peek at a single record to sanity-check the re-keyed target-side output.
extended_BB_target.take(num=1)