In [25]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [38]:
import tensorflow as tf
import pandas as pd
import numpy as np
import os
import sys
import copy
sys.path.append('/content/drive/MyDrive/(22-1)캡스톤/recomm/Recommendation/')
sys.path.append('/content/drive/MyDrive/(22-1)캡스톤/recomm/Recommendation/model/MF')
sys.path.append('/content/drive/MyDrive/(22-1)캡스톤/recomm/Recommendation/model/NCF')
import Loader
import MF
from neuralCF import NCF
import warnings
warnings.filterwarnings('ignore')

In [224]:
class Result:
  def __init__(self, batch_size):
    DIR = '/content/drive/MyDrive/(22-1)캡스톤/recomm/data/preprocessed/'
    loader1 = Loader.Loader(DIR, 1, 4) 
    loader2 = Loader.Loader(DIR, 2, 4)
    loader3 = Loader.Loader(DIR, 3, 4)
    loader4 = Loader.Loader(DIR, 4, 4)
    loader5 = Loader.Loader(DIR, 5, 4)
    loader1.load_dataset()
    loader2.load_dataset()
    loader3.load_dataset()
    loader4.load_dataset()
    loader5.load_dataset()
    self.loaders = [loader1, loader2, loader3, loader4, loader5]


    DIR2 = '/content/drive/MyDrive/(22-1)캡스톤/recomm/data/raw_data/'
    problemMeta = pd.read_csv(os.path.join(DIR2, "problemMeta.csv"))
    self.probid2level = {row[1]:row[6] for row in problemMeta.values}
    self.probid = set(problemMeta['problemId'].tolist())

    #self.mf = MF.MF(loader1.users_no, loader1.prob_no, loader1.useridx2level, loader1.probidx2level)
    self.ncf = NCF(loader1.users_no, loader1.prob_no, loader1.useridx2level, loader1.probidx2level)
    #BESTMODEL_DIR = '/content/drive/MyDrive/(22-1)캡스톤/recomm/Recommendation/model/MF/best_model/cluster'
    BESTMODEL_DIR = '/content/drive/MyDrive/(22-1)캡스톤/recomm/Recommendation/model/NCF/best_model/cluster'
    model1 = tf.keras.models.load_model(BESTMODEL_DIR+'1')
    model2 = tf.keras.models.load_model(BESTMODEL_DIR+'2')
    model3 = tf.keras.models.load_model(BESTMODEL_DIR+'3')
    model4 = tf.keras.models.load_model(BESTMODEL_DIR+'4')
    model5 = tf.keras.models.load_model(BESTMODEL_DIR+'5')
    self.models = [model1, model2, model3, model4, model5]

    self.batch_size = batch_size


  def get_output(self, id, problemIds):
    if len(problemIds) > 15:
      cluster, maxlevel = self.get_cluster(problemIds)
      #print(cluster, maxlevel)
      output = self.goto_model(id, problemIds, self.models[cluster-1], cluster, maxlevel)
      return output
    return []

  def get_cluster(self, problemIds):
   maxlevel = max([self.probid2level[prob] for prob in problemIds if prob in self.probid])
   #print([self.probid2level[prob] for prob in problemIds])
   if maxlevel <= 4 and maxlevel >= 1:
     return 1, maxlevel
   elif maxlevel <= 10:
     return 2, maxlevel
   elif maxlevel <= 13:
     return 3, maxlevel
   elif maxlevel <= 16:
     return 4, maxlevel
   else:
     return 5, maxlevel

  def goto_model(self, id, problemIds, model, cluster, maxlevel):
    usridx, useridx2level = self.get_usr_index(id, cluster, maxlevel)
    probidx = self.get_prb_index(problemIds, cluster)
    neg_probidx = self.get_negative_prob(probidx, cluster)
    
    train_usr = np.array([usridx] * len(probidx)).reshape(-1,1)
    train_prb = np.array(probidx).reshape(-1,1)
    train_entry = np.ones_like(train_usr)

    test_usr = np.array([usridx] * len(neg_probidx)).reshape(-1,1)
    test_prb = np.array(neg_probidx).reshape(-1,1)

    weights = model.get_weights()
    for i in tf.range(0, len(train_usr), self.batch_size):
      idxlist = range(i, min(i+self.batch_size, len(train_usr)-1))
      if len(idxlist) == 0:
        break
      model.fit([train_usr[idxlist], train_prb[idxlist]], train_entry[idxlist], verbose=0)
          
    pred = model.predict([test_usr, test_prb])
    pred = np.concatenate(pred).reshape(-1,1)
    
    filtered = self.ncf.level_filtering(test_usr, test_prb, pred, useridx2level, self.loaders[cluster-1].probidx2level, k=30)
    #filtered = self.mf.level_filtering(test_usr, test_prb, pred, self.loaders[cluster-1].useridx2level, self.loaders[cluster-1].probidx2level, k=30)
    model.set_weights(weights)

    output = self.get_id(filtered[1], cluster)
    return output

  def get_usr_index(self, id, cluster, maxlevel):
    try:
      usridx = self.loaders[cluster-1].userid2idx[id]
      return usridx, self.loaders[cluster-1].useridx2level
    except:
      usridx = self.loaders[cluster-1].users_no + 1
      useridx2lv = copy.deepcopy(self.loaders[cluster-1].useridx2level)
      useridx2lv[usridx] = maxlevel 
      return usridx, useridx2lv
  
  def get_prb_index(self, problemIds, cluster):
    prbidx = [self.loaders[cluster-1].probid2idx[prob] for prob in problemIds if prob in self.probid]
    return prbidx
  
  def get_negative_prob(self, problems, cluster):
    return list(set(range(0, self.loaders[cluster-1].prob_no)) - set(problems))

  def get_id(self,problems, cluster):
    prbid = [self.loaders[cluster-1].probidx2id[prob] for prob in problems]
    return prbid

In [225]:
result = Result(1024)

In [230]:
# id가 없는 경우 기존 아이디랑 겹치지 않게 임의지정해서 실행
print("listolanic: ", result.get_output("listolanic", [2438, 2439, 11021, 9498, 1330, 2739, 10171, 10172, 2753, 10818, 2884, 14681, 10718, 1000, 1001, 1008, 10869, 10871, 2557]))

listolanic:  [10951, 10950, 10952, 10430, 15552, 2741, 8393, 2742, 3052, 1110, 2588, 11022, 10998, 8958, 2562, 11654, 2577, 11720, 2675, 1152, 1546, 10809, 1712, 15596, 4344, 2908, 1157, 1065, 10926, 4673]


In [231]:
print("gunjung2147: ", result.get_output("gunjung2147", [1152, 2562, 1157, 2438, 2439, 11654, 1546, 11021, 11022, 2577, 9498, 2588, 1316, 1065, 1712, 1330, 2739, 2741, 2742, 10809, 10171, 10172, 10430, 15552, 2753, 4673, 10818, 2884, 10950, 10951, 10952, 8393, 11720, 2941, 1110, 14681, 2908, 10718, 1000, 1001, 3052, 15596, 1008, 2675, 2292, 10869, 5622, 10871, 4344, 10998, 2557, 8958]))

gunjung2147:  [2750, 2869, 2798, 2839, 10872, 1978, 10250, 1193, 2775, 10870, 2231, 4153, 2751, 1085, 1929, 10817, 2581, 11653, 7568, 10828, 1181, 1002, 2920, 1018, 10757, 1427, 2475, 1436, 11047, 3009]


In [232]:
print("guren97: ", result.get_output("guren97", [2562, 1546, 10250, 2577, 2580, 2581, 2588, 14888, 1065, 14889, 10797, 4153, 10809, 1085, 2621, 10814, 4673, 10817, 10818, 1094, 10833, 1110, 2661, 2675, 1652, 10869, 10870, 7287, 10871, 10872, 1152, 1157, 1181, 1193, 1712, 2739, 2741, 2742, 2231, 2743, 2750, 2751, 10430, 2753, 15552, 10950, 10951, 10952, 8393, 1747, 2775, 1759, 15596, 10989, 2798, 2292, 10996, 10998, 4344, 8958, 11021, 11022, 8979, 2839, 9498, 15649, 15650, 15651, 1316, 15652, 1330, 14645, 10039, 14656, 2884, 14681, 2908, 10103, 2941, 11650, 11651, 1924, 2438, 2439, 2440, 2441, 11654, 2446, 7568, 5522, 1427, 2455, 1436, 5543, 2475, 13752, 1978, 10171, 10172, 5565, 9663, 3009, 1987, 11718, 11719, 11720, 11721, 6603, 2523, 3036, 10718, 3046, 1000, 1001, 3052, 1008, 1011, 5622, 1018, 2557, 2558]))

guren97:  [1929, 2869, 9012, 10828, 2609, 10757, 11047, 10773, 1920, 1002, 2920, 1003, 2164, 9020, 1463, 4948, 1260, 1149, 10845, 11399, 11729, 2108, 9095, 11653, 2447, 1874, 11866, 1931, 1012, 1541]


In [233]:
print("gunsong2: ", result.get_output("gunsong2", [2562, 10757, 1546, 11279, 2577, 2579, 2581, 10773, 11286, 2588, 1065, 2606, 10809, 10816, 2557, 3649, 4673, 10817, 10818, 10828, 1110, 11866, 10845, 2667, 2156, 10866, 2675, 2164, 10869, 10870, 7287, 1655, 10871, 10872, 1149, 1152, 2178, 1157, 9372, 1182, 1697, 1193, 10926, 1712, 2739, 2741, 2742, 2231, 2750, 2751, 10430, 2753, 15552, 10950, 10951, 10952, 8393, 1753, 1260, 10989, 2798, 15596, 2805, 10998, 4344, 8958, 11021, 11022, 2839, 9498, 15649, 15650, 15651, 1316, 15652, 11047, 1330, 9012, 2869, 5430, 2884, 1874, 18258, 4949, 14681, 2908, 4963, 2920, 1912, 2941, 1920, 2438, 1927, 2439, 1929, 9095, 1931, 1932, 11654, 2447, 7569, 17298, 7576, 1966, 1463, 1978, 10171, 10172, 9663, 11720, 11725, 11726, 11729, 5086, 10718, 1000, 1001, 1002, 1003, 3052, 1008, 1010, 1012, 5622, 1021]))

gunsong2:  [2206, 1717, 12865, 2580, 11053, 12100, 1987, 16236, 2110, 9251, 2252, 1759, 1806, 1011, 11724, 14888, 1107, 11054, 11051, 14502, 11399, 15686, 1644, 14501, 10844, 2193, 11403, 1504, 9461, 10026]


In [235]:
print("ksejun95: ", result.get_output("ksejun95",[10250, 2083, 2089, 16430, 4153, 2118, 18512, 2153, 2154, 2156, 2161, 2163, 2164, 8320, 2178, 2193, 14490, 2206, 14496, 14501, 14502, 2217, 2225, 2231, 10430, 2239, 8393, 2250, 10448, 4307, 14556, 2292, 2293, 2294, 4344, 2302, 2309, 2338, 2355, 2357, 10569, 2399, 2407, 14696, 10610, 2437, 2438, 2439, 2440, 2441, 2442, 2443, 2444, 2445, 2446, 2447, 2455, 4504, 2458, 2468, 2475, 2477, 18868, 2490, 2491, 6588, 2493, 2504, 6603, 10699, 2522, 2523, 10718, 2531, 2555, 2556, 2557, 2558, 2559, 12796, 2562, 2563, 2564, 2565, 10757, 2573, 12813, 2576, 2577, 2578, 2579, 2581, 10773, 2583, 2588, 14888, 14889, 14891, 2605, 2606, 10797, 10798, 2609, 10799, 2615, 10808, 10809, 16953, 10814, 10815, 10816, 4673, 10817, 10818, 2628, 10820, 10824, 10825, 2635, 2636, 10828, 10829, 2644, 10844, 10845, 12904, 2667, 2668, 2669, 10866, 2675, 10867, 10868, 10869, 10870, 10871, 10872, 10886, 2729, 17069, 10926, 17070, 2739, 2741, 2742, 2743, 2745, 2747, 4796, 2748, 2749, 2750, 2751, 2752, 2753, 10943, 10950, 10951, 10952, 10953, 2775, 10972, 10973, 10974, 2789, 10987, 10988, 10989, 2798, 17135, 10996, 10998, 17143, 17144, 2812, 11004, 8958, 11005, 2822, 11021, 11022, 8979, 2839, 4889, 2851, 11047, 11048, 2857, 2858, 11050, 11052, 11053, 2863, 2864, 11057, 9012, 2869, 2873, 2875, 9020, 2884, 4948, 4949, 2902, 2908, 2914, 4963, 2920, 2921, 2941, 9093, 4999, 9095, 2953, 2959, 2960, 2961, 17298, 2965, 5014, 2985, 2997, 2999, 3009, 13249, 5063, 3040, 3046, 3047, 1000, 1001, 1002, 1003, 1004, 3049, 3052, 3053, 1008, 1009, 1010, 1012, 9205, 13300, 1015, 1016, 15353, 1018, 1021, 17406, 1024, 1026, 17413, 1032, 1037, 1049, 1057, 3109, 1062, 1065, 1072, 1074, 1075, 1076, 1080, 1085, 17471, 17472, 1091, 1094, 1100, 1110, 1120, 17504, 5218, 9316, 11365, 9322, 17521, 1138, 7287, 1145, 1149, 1152, 1157, 1158, 1159, 11399, 11403, 13458, 1181, 1193, 1194, 1205, 1212, 1213, 15552, 1237, 5338, 1244, 1247, 1254, 1260, 1261, 15596, 11505, 1267, 9461, 1271, 9465, 1297, 5397, 15641, 9498, 1309, 15649, 15650, 15651, 1316, 15652, 15654, 15655, 15656, 15657, 1325, 15661, 15663, 15664, 15665, 1330, 15666, 5430, 1339, 1342, 15685, 1350, 15686, 1356, 1357, 9550, 1371, 1373, 1389, 1406, 11650, 11651, 11652, 13701, 11654, 11655, 11656, 7562, 9613, 7569, 3474, 1427, 1431, 7576, 1436, 5532, 5543, 1449, 5557, 1463, 9655, 9656, 5565, 9663, 1475, 1476, 11718, 11719, 11720, 11721, 11724, 11725, 11726, 11727, 5585, 11729, 1500, 5597, 5598, 5622, 7682, 1541, 1543, 5639, 1546, 1547, 1550, 15897, 1562, 1568, 11816, 1592, 15961, 11866, 1629, 15969, 1644, 1652, 15988, 15989, 1676, 11931, 1697, 1699, 1700, 1706, 1707, 1712, 1717, 9933, 9935, 1748, 1753, 1759, 1764, 1783, 1786, 1789, 16134, 1821, 1834, 1837, 3895, 10039, 16194, 12101, 1863, 1874, 18258, 3933, 16236, 3955, 10101, 1912, 1916, 1918, 1919, 1920, 1924, 1929, 1931, 1932, 1934, 3985, 3986, 1937, 1946, 1964, 1965, 1966, 10157, 10158, 1969, 18352, 10163, 1975, 1977, 1978, 10171, 10172, 1987, 1992, 2004, 2010, 2011, 2012, 2018, 10219]))

ksejun95:  [2252, 2042, 11054, 9251, 12865, 1806, 1011, 1967, 1197, 10942, 11404, 1654, 11049, 2580, 1520, 1504, 12100, 1238, 1167, 12015, 6549, 1655, 10026, 1904, 1927, 11657, 2805, 11286, 1922, 11051]


In [236]:
print("임의지정: ", result.get_output("임의지정",[10250, 2083, 2089, 16430, 4153, 2118, 18512, 2153, 2154, 2156, 2161, 2163, 2164, 8320, 2178, 2193, 14490, 2206, 14496, 14501, 14502, 2217, 2225, 2231, 10430, 2239, 8393, 2250, 10448, 4307, 14556, 2292, 2293, 2294, 4344, 2302, 2309, 2338, 2355, 2357, 10569, 2399, 2407, 14696, 10610, 2437, 2438, 2439, 2440, 2441, 2442, 2443, 2444, 2445, 2446, 2447, 2455, 4504, 2458, 2468, 2475, 2477, 18868, 2490, 2491, 6588, 2493, 2504, 6603, 10699, 2522, 2523, 10718, 2531, 2555, 2556, 2557, 2558, 2559, 12796, 2562, 2563, 2564, 2565, 10757, 2573, 12813, 2576, 2577, 2578, 2579, 2581, 10773, 2583, 2588, 14888, 14889, 14891, 2605, 2606, 10797, 10798, 2609, 10799, 2615, 10808, 10809, 16953, 10814, 10815, 10816, 4673, 10817, 10818, 2628, 10820, 10824, 10825, 2635, 2636, 10828, 10829, 2644, 10844, 10845, 12904, 2667, 2668, 2669, 10866, 2675, 10867, 10868, 10869, 10870, 10871, 10872, 10886, 2729, 17069, 10926, 17070, 2739, 2741, 2742, 2743, 2745, 2747, 4796, 2748, 2749, 2750, 2751, 2752, 2753, 10943, 10950, 10951, 10952, 10953, 2775, 10972, 10973, 10974, 2789, 10987, 10988, 10989, 2798, 17135, 10996, 10998, 17143, 17144, 2812, 11004, 8958, 11005, 2822, 11021, 11022, 8979, 2839, 4889, 2851, 11047, 11048, 2857, 2858, 11050, 11052, 11053, 2863, 2864, 11057, 9012, 2869, 2873, 2875, 9020, 2884, 4948, 4949, 2902, 2908, 2914, 4963, 2920, 2921, 2941, 9093, 4999, 9095, 2953, 2959, 2960, 2961, 17298, 2965, 5014, 2985, 2997, 2999, 3009, 13249, 5063, 3040, 3046, 3047, 1000, 1001, 1002, 1003, 1004, 3049, 3052, 3053, 1008, 1009, 1010, 1012, 9205, 13300, 1015, 1016, 15353, 1018, 1021, 17406, 1024, 1026, 17413, 1032, 1037, 1049, 1057, 3109, 1062, 1065, 1072, 1074, 1075, 1076, 1080, 1085, 17471, 17472, 1091, 1094, 1100, 1110, 1120, 17504, 5218, 9316, 11365, 9322, 17521, 1138, 7287, 1145, 1149, 1152, 1157, 1158, 1159, 11399, 11403, 13458, 1181, 1193, 1194, 1205, 1212, 1213, 15552, 1237, 5338, 1244, 1247, 1254, 1260, 1261, 15596, 11505, 1267, 9461, 1271, 9465, 1297, 5397, 15641, 9498, 1309, 15649, 15650, 15651, 1316, 15652, 15654, 15655, 15656, 15657, 1325, 15661, 15663, 15664, 15665, 1330, 15666, 5430, 1339, 1342, 15685, 1350, 15686, 1356, 1357, 9550, 1371, 1373, 1389, 1406, 11650, 11651, 11652, 13701, 11654, 11655, 11656, 7562, 9613, 7569, 3474, 1427, 1431, 7576, 1436, 5532, 5543, 1449, 5557, 1463, 9655, 9656, 5565, 9663, 1475, 1476, 11718, 11719, 11720, 11721, 11724, 11725, 11726, 11727, 5585, 11729, 1500, 5597, 5598, 5622, 7682, 1541, 1543, 5639, 1546, 1547, 1550, 15897, 1562, 1568, 11816, 1592, 15961, 11866, 1629, 15969, 1644, 1652, 15988, 15989, 1676, 11931, 1697, 1699, 1700, 1706, 1707, 1712, 1717, 9933, 9935, 1748, 1753, 1759, 1764, 1783, 1786, 1789, 16134, 1821, 1834, 1837, 3895, 10039, 16194, 12101, 1863, 1874, 18258, 3933, 16236, 3955, 10101, 1912, 1916, 1918, 1919, 1920, 1924, 1929, 1931, 1932, 1934, 3985, 3986, 1937, 1946, 1964, 1965, 1966, 10157, 10158, 1969, 18352, 10163, 1975, 1977, 1978, 10171, 10172, 1987, 1992, 2004, 2010, 2011, 2012, 2018, 10219]))

임의지정:  [1126, 2316, 2820, 1014, 11376, 10999, 2188, 6086, 2261, 4013, 3653, 11375, 1168, 3176, 11003, 1761, 1948, 1725, 11438, 2568, 1708, 11658, 10090, 2150, 13548, 11400, 1067, 1102, 2213, 3665]


# test

In [91]:
import ast

In [92]:
DIR = '/content/drive/MyDrive/(22-1)캡스톤/recomm/data/preprocessed/'
user = pd.read_csv(os.path.join(DIR, "total_user_info.csv"))

In [93]:
sampling = user.sample(3000)
sampling

Unnamed: 0,handle,problemIds,max_level,cluster
56954,xogn13,"[2562, 14852, 10757, 17413, 1546, 10250, 1037,...",13,3.0
40530,poll2565,"[1152, 2753, 2884, 10171, 11654, 1000, 1546, 2...",5,2.0
33852,louisdebroglie,"[18436, 2056, 10256, 2092, 2098, 14389, 14390,...",21,5.0
2236,alex4242,"[1920, 1026, 2438, 2439, 2822, 2441, 11399, 11...",11,3.0
6271,cbqnk9,"[2562, 2438, 2439, 11021, 11022, 2577, 9498, 2...",5,2.0
...,...,...,...,...
18527,hachi557,"[2562, 1541, 2565, 10757, 1546, 10250, 1037, 2...",13,3.0
16454,ghbae1798,"[2562, 3078, 1546, 10250, 1037, 2573, 2577, 25...",15,4.0
20498,hollom,"[1922, 11266, 2824, 11659, 1932, 11404, 11660,...",17,5.0
11814,dltjgus0709,"[12865, 2504, 1929, 14888, 1292, 9996, 10828, ...",11,3.0


In [94]:
handles = sampling['handle'].tolist()
problemIds = sampling['problemIds'].apply(lambda x: ast.literal_eval(x)).tolist()
type(problemIds[0])

list

In [229]:
for i in range(len(handles)):
  print("[", i, "] ", handles[i], ": ", result.get_output(handles[i], problemIds[i]))

[ 0 ]  xogn13 :  [9012, 11729, 2164, 15650, 1931, 7576, 1541, 1920, 10773, 2447, 2579, 15649, 15651, 2108, 1874, 1932, 1012, 2920, 1904, 10816, 15652, 2606, 9663, 9461, 11866, 1966, 10817, 1654, 14888, 11050]
[ 1 ]  poll2565 :  []
[ 2 ]  louisdebroglie :  [9251, 1967, 2357, 1011, 7569, 5430, 2164, 1904, 1005, 11286, 9252, 1065, 1874, 5052, 1966, 1427, 9020, 2565, 4673, 2941, 1181, 1157, 3190, 2869, 1725, 1110, 2839, 4948, 8958, 4344]
[ 3 ]  alex4242 :  [11057, 2583, 2589, 1916, 17070, 6588, 6064, 9465, 14891, 2636, 1068, 2493, 2573, 2668, 11000, 2631, 16234, 11403, 14502, 14719, 14503, 12865, 20055, 14002, 13549, 3055, 11559, 17406, 2512, 17471]
[ 4 ]  cbqnk9 :  [3059, 11508, 6996, 1439, 2828, 1652, 14916, 9237, 15904, 7785, 2750, 14487, 19939, 2563, 1543, 18406, 19947, 1834, 11586, 8979, 10163, 2217, 2331, 13458, 2846, 19532, 5576, 15008, 11728, 2669]
[ 5 ]  xcrypt0r :  [5347, 2312, 2156, 11441, 1652, 1149, 2667, 3943, 11051, 1753, 1927, 1940, 8979, 2606, 1012, 1016, 2178, 6502, 2688,

KeyboardInterrupt: ignored

In [240]:
set(result.get_output(handles[0], problemIds[0])) & set(result.get_output("임의지정", problemIds[0]))

{1932, 7576, 9663}