In [1]:
import os
import json
import pandas as pd
import numpy as np
import torch
import re
import random
import pickle
import os
from tqdm import tqdm
random.seed(33)

In [2]:
# Location of the raw .dat/.txt input files and root of the generated outputs.
input_dir = 'OriginalData/'
output_dir = './'
states = [ "warm_up", "user_cold_testing","meta_training"]

# Create "<output_dir>/<state>/" and "<output_dir>/log/<state>" for every split.
# makedirs(exist_ok=True) keeps the cell idempotent: a partially created tree
# (e.g. "log/" present but "meta_training/" missing) no longer raises FileExistsError.
# A dedicated loop variable is used so the notebook-level `state` is not clobbered.
for _state_dir in states:
    os.makedirs("{}/{}/".format(output_dir, _state_dir), exist_ok=True)
    os.makedirs("{}/{}/{}".format(output_dir, "log", _state_dir), exist_ok=True)

In [3]:
# Interaction log: one "user::item::rating::timestamp" record per line.
# The multi-character separator requires pandas' python parsing engine.
ui_data = pd.read_csv(
    input_dir+'ratings.dat',
    sep="::",
    engine='python',
    names=['user', 'item', 'rating', 'timestamp'],
)
len(ui_data)

1000209

In [4]:
# Side-information tables, both "::"-separated and without a header row.
user_columns = ['user', 'gender', 'age', 'occupation_code', 'zip']
item_columns = ['item', 'title', 'year', 'rate', 'released', 'genre',
                'director', 'writer', 'actors', 'plot', 'poster']

user_data = pd.read_csv(
    input_dir+'users.dat',
    sep="::",
    engine='python',
    names=user_columns,
)
item_data = pd.read_csv(
    input_dir+'movies_extrainfos.dat',
    sep="::",
    engine='python',
    encoding="utf-8",
    names=item_columns,
)

In [5]:
# Every id that appears either in the interaction log or in the side-information table.
users_in_ratings = set(ui_data.user.tolist())
users_in_profiles = set(user_data.user)
user_list = list(users_in_ratings | users_in_profiles)

items_in_ratings = set(ui_data.item.tolist())
items_in_catalogue = set(item_data.item)
item_list = list(items_in_ratings | items_in_catalogue)

In [6]:
# Number of distinct users / items; the tuple is displayed as the cell output.
user_num, item_num = len(user_list), len(item_list)
user_num, item_num

(6040, 3881)

### 1. User and item features

In [7]:
def load_list(fname):
    """Read a UTF-8 text file into a list with one entry per line.

    Leading and trailing whitespace (including the newline) is stripped from
    every line; blank lines are kept as empty strings.

    Args:
        fname: path of the text file to read.

    Returns:
        list of str, in file order.
    """
    with open(fname, encoding="utf-8") as handle:
        return [raw_line.strip() for raw_line in handle.readlines()]

In [8]:
# Vocabulary files: one categorical value per line. The position of a value in
# the returned list is the integer code used by item_converting / user_converting.
rate_list = load_list("{}/m_rate.txt".format(input_dir))
genre_list = load_list("{}/m_genre.txt".format(input_dir))
actor_list = load_list("{}/m_actor.txt".format(input_dir))
director_list = load_list("{}/m_director.txt".format(input_dir))
gender_list = load_list("{}/m_gender.txt".format(input_dir))
age_list = load_list("{}/m_age.txt".format(input_dir))
occupation_list = load_list("{}/m_occupation.txt".format(input_dir))
zipcode_list = load_list("{}/m_zipcode.txt".format(input_dir))
# Vocabulary sizes, displayed as the cell output.
len(rate_list), len(genre_list), len(actor_list), len(director_list), len(gender_list), len(age_list), len(occupation_list), len(zipcode_list)

(6, 25, 7978, 2186, 2, 7, 21, 3402)

In [9]:
def item_converting(row, rate_list, genre_list, director_list, actor_list):
    """Encode one movie record as index / multi-hot tensors.

    Args:
        row: mapping with the keys 'rate', 'genre', 'director' and 'actors';
            the last three are ", "-separated strings.
        rate_list, genre_list, director_list, actor_list: vocabularies; a
            value's position in its list is its integer code.

    Returns:
        A 4-tuple ``(hete, homo, director_id, actor_id)``:
          * hete: long tensor of shape (1, 26) = [rate code | 25 genre flags]
          * homo: long tensor of shape (1, 1 + 25 + 2186 + 8030) =
            [rate code | genre flags | director flags | actor flags]
          * director_id: list of 1-based director codes, in row order
          * actor_id: list of 1-based actor codes, in row order

    Raises:
        ValueError: if a value is missing from its vocabulary (from list.index).
    """
    rate_pos = rate_list.index(str(row['rate']))
    rate_idx = torch.tensor([[rate_pos]]).long()

    genre_flags = torch.zeros(1, 25).long()
    for genre_name in str(row['genre']).split(", "):
        genre_flags[0, genre_list.index(genre_name)] = 1

    # Director names may carry a parenthesised note; it is removed before lookup.
    paren_note = re.compile(r'\([^()]*\)')
    director_flags = torch.zeros(1, 2186).long()
    director_id = []
    for raw_name in str(row['director']).split(", "):
        pos = director_list.index(paren_note.sub('', raw_name))
        director_flags[0, pos] = 1
        director_id.append(pos + 1)

    # NOTE(review): the width 8030 is hard-coded, while the actor vocabulary loaded
    # in this notebook reports 7978 entries - confirm the intended width.
    actor_flags = torch.zeros(1, 8030).long()
    actor_id = []
    for actor_name in str(row['actors']).split(", "):
        pos = actor_list.index(actor_name)
        actor_flags[0, pos] = 1
        actor_id.append(pos + 1)

    hete = torch.cat((rate_idx, genre_flags), 1)
    homo = torch.cat((hete, director_flags, actor_flags), 1)
    return hete, homo, director_id, actor_id

def user_converting(row, gender_list, age_list, occupation_list, zipcode_list):
    """Encode one user record as a (1, 4) long tensor of vocabulary codes.

    Args:
        row: mapping with the keys 'gender', 'age', 'occupation_code' and 'zip'.
        gender_list, age_list, occupation_list, zipcode_list: vocabularies; a
            value's position in its list is its integer code.

    Returns:
        Long tensor ``[[gender, age, occupation, zip]]``. Only the first five
        characters of the zip value are used for the lookup.

    Raises:
        ValueError: if a value is missing from its vocabulary (from list.index).
    """
    codes = [
        gender_list.index(str(row['gender'])),
        age_list.index(str(row['age'])),
        occupation_list.index(str(row['occupation_code'])),
        zipcode_list.index(str(row['zip'])[:5]),
    ]
    return torch.tensor([codes]).long()

In [10]:
# Per-movie lookups keyed by item id:
#   movie_fea_hete -> rate + genre tensor, movie_fea_homo -> rate + genre + director + actor tensor,
#   m_directors / m_actors -> lists of 1-based director / actor codes.
movie_fea_hete, movie_fea_homo = {}, {}
m_directors, m_actors = {}, {}
for _, movie_row in item_data.iterrows():
    movie_id = movie_row['item']
    hete_fea, homo_fea, director_ids, actor_ids = item_converting(
        movie_row, rate_list, genre_list, director_list, actor_list
    )
    movie_fea_hete[movie_id] = hete_fea
    movie_fea_homo[movie_id] = homo_fea
    m_directors[movie_id] = director_ids
    m_actors[movie_id] = actor_ids

In [11]:
# Per-user feature tensor keyed by user id.
user_fea = {}
for _, user_row in user_data.iterrows():
    user_fea[user_row['user']] = user_converting(
        user_row, gender_list, age_list, occupation_list, zipcode_list
    )

### 2. mp data

In [12]:
# Split names handled in this section (same values as the `states` list defined in the setup cell).
states = [ "warm_up", "user_cold_testing","meta_training"]

In [13]:
import collections
def reverse_dict(d):
    """Invert a ``key -> list of values`` mapping into ``value -> list of keys``.

    Keys are appended in the iteration order of ``d``; a key that lists the same
    value more than once is appended once per occurrence.

    Args:
        d: dict whose values are iterables of hashable items.

    Returns:
        A plain dict mapping each item to the list of keys that contained it.
    """
    inverted = {}
    for key in d:
        for value in d[key]:
            inverted.setdefault(value, []).append(key)
    return inverted

In [14]:
# Inverted indexes: actor code -> movies and director code -> movies.
a_movies, d_movies = reverse_dict(m_actors), reverse_dict(m_directors)
len(a_movies), len(d_movies)

(7978, 2186)

In [15]:
def jsonKeys2int(x):
    """Return ``x`` with its keys converted to int when ``x`` is a dict.

    Used in this notebook as a ``json.load`` ``object_hook``, because JSON object
    keys are always decoded as strings. Non-dict inputs are returned unchanged.

    Raises:
        ValueError: if a key cannot be parsed by ``int``.
    """
    if not isinstance(x, dict):
        return x
    return dict((int(key), value) for key, value in x.items())

In [129]:
state = 'meta_training'

# Load the support/query movie lists and their ratings for this split.
# Each file is a JSON object keyed by user id; jsonKeys2int restores integer keys.
# Context managers close the files (the handles were previously left open).
with open(output_dir+state+'/support_u_movies.json','r') as f:
    support_u_movies = json.load(f, object_hook=jsonKeys2int)
with open(output_dir+state+'/query_u_movies.json','r') as f:
    query_u_movies = json.load(f, object_hook=jsonKeys2int)
with open(output_dir+state+'/support_u_movies_y.json','r') as f:
    support_u_movies_y = json.load(f, object_hook=jsonKeys2int)
with open(output_dir+state+'/query_u_movies_y.json','r') as f:
    query_u_movies_y = json.load(f, object_hook=jsonKeys2int)

# Fail loudly on inconsistent inputs instead of continuing with an undefined
# (or stale, from a previous run) `u_id_list`.
if support_u_movies.keys() != query_u_movies.keys():
    raise ValueError(
        "support and query files of split '{}' cover different users".format(state)
    )
u_id_list = support_u_movies.keys()
print(len(u_id_list))

# Full training history per user: support movies followed by query movies,
# with the ratings concatenated in the same order.
train_u_movies = {}
train_u_movies_y = {}
for u_id in tqdm(u_id_list):
    train_u_movies[int(u_id)] = support_u_movies[u_id] + query_u_movies[u_id]
    train_u_movies_y[int(u_id)] = support_u_movies_y[u_id] + query_u_movies_y[u_id]
len(train_u_movies),len(train_u_movies_y)


0it [00:00, ?it/s][A
2749it [00:00, 271278.30it/s][A

2749
2749


(2749, 2749)

In [104]:
# Snapshot of the training user ids as an independent list (u_id_list is rebound later).
train_u_id_list = [train_u_id for train_u_id in u_id_list]
len(train_u_id_list)

2749

In [109]:
print(state)

# Neighbour sets, keyed as result[user][movie] -> set of movie ids (always containing `movie`):
#   u_m_u_movies: the user's support movies plus the support movies of every user
#                 whose support rating of `movie` equals this user's rating of it
#   u_m_a_movies: movies sharing at least one actor code with `movie`
#   u_m_d_movies: movies sharing at least one director code with `movie`
u_m_u_movies = {}
u_m_a_movies = {}
u_m_d_movies = {}

# movie -> users that have it in their support set
support_m_users = reverse_dict(support_u_movies)
for u in tqdm(u_id_list, leave=False, ncols=0):
    u_m_u_movies[u] = {}
    u_m_a_movies[u] = {}
    u_m_d_movies[u] = {}
    # Per-user lookups are loop-invariant, so they are fetched once here.
    u_m = support_u_movies[u]
    u_m_y = support_u_movies_y[u]
    u_m_q = query_u_movies[u]
    u_m_y_q = query_u_movies_y[u]

    for m in u_m:
        u_m_a_movies[u][m] = set([m])
        for _a in m_actors[m]:
            u_m_a_movies[u][m].update(a_movies[_a])

        u_m_d_movies[u][m] = set([m])
        for _d in m_directors[m]:
            u_m_d_movies[u][m].update(d_movies[_d])

        u_m_u_movies[u][m] = set([m])
        u_m_u_movies[u][m].update(u_m)
        if m in support_m_users:
            # list.index returns the first occurrence, as in the original lookup.
            m_rating = u_m_y[u_m.index(m)]
            for _u in support_m_users[m]:
                cur_ms = support_u_movies[_u]
                if support_u_movies_y[_u][cur_ms.index(m)] == m_rating:
                    u_m_u_movies[u][m].update(cur_ms)

    for m in u_m_q:
        # A query movie that already has an entry (it is also a support movie, or it
        # is repeated in the query list) is reported and the remaining query movies
        # of this user are skipped.
        if m in u_m_a_movies[u] or m in u_m_d_movies[u] or m in u_m_u_movies[u]:
            print('error!!!')
            break
        u_m_a_movies[u][m] = set([m])
        for _a in m_actors[m]:
            u_m_a_movies[u][m].update(a_movies[_a])

        u_m_d_movies[u][m] = set([m])
        for _d in m_directors[m]:
            u_m_d_movies[u][m].update(d_movies[_d])

        u_m_u_movies[u][m] = set([m])
        u_m_u_movies[u][m].update(u_m)
        if m in support_m_users:
            # Neighbours come from support sets only; the comparison uses the query rating of `m`.
            m_rating = u_m_y_q[u_m_q.index(m)]
            for _u in support_m_users[m]:
                cur_ms = support_u_movies[_u]  # list
                if support_u_movies_y[_u][cur_ms.index(m)] == m_rating:
                    u_m_u_movies[u][m].update(cur_ms)

print(len(u_m_u_movies), len(u_m_a_movies), len(u_m_d_movies))


  0% 0/2749 [00:00<?, ?it/s][A
  0% 12/2749 [00:00<00:23, 115.61it/s][A

meta_training



  1% 19/2749 [00:00<00:28, 96.03it/s] [A
  1% 32/2749 [00:00<00:26, 103.26it/s][A
  1% 41/2749 [00:00<00:27, 97.82it/s] [A
  2% 50/2749 [00:00<00:28, 93.87it/s][A
  2% 60/2749 [00:00<00:28, 95.60it/s][A
  3% 69/2749 [00:00<00:33, 79.30it/s][A
  3% 77/2749 [00:00<00:36, 73.89it/s][A
  3% 85/2749 [00:01<00:39, 66.81it/s][A
  3% 92/2749 [00:01<00:52, 50.90it/s][A
  4% 102/2749 [00:01<00:44, 59.49it/s][A
  4% 111/2749 [00:01<00:40, 65.86it/s][A
  4% 123/2749 [00:01<00:35, 73.73it/s][A
  5% 136/2749 [00:01<00:31, 83.41it/s][A
  5% 146/2749 [00:02<00:52, 49.50it/s][A
  6% 158/2749 [00:02<00:46, 55.62it/s][A
  6% 166/2749 [00:02<00:49, 52.68it/s][A
  6% 173/2749 [00:02<00:46, 55.90it/s][A
  7% 181/2749 [00:02<00:42, 60.52it/s][A
  7% 192/2749 [00:02<00:37, 68.45it/s][A
  7% 203/2749 [00:02<00:33, 76.57it/s][A
  8% 212/2749 [00:02<00:32, 77.87it/s][A
  8% 221/2749 [00:03<00:34, 73.66it/s][A
  8% 229/2749 [00:03<00:36, 69.23it/s][A
  9% 241/2749 [00:03<00:32, 78.33it/s]

2749


In [110]:
# For every user write twelve pickles to "<output_dir>/<state>/":
#   {support,query}_x    - one row per movie: movie rate+genre features followed by the user features
#   {support,query}_y    - the ratings as a FloatTensor
#   {support,query}_um   - per movie, the stacked features of the user's support movies
#   {support,query}_umum / _umam / _umdm - per movie, the stacked features of its
#                          u_m_u / u_m_a / u_m_d neighbour set
# Files are numbered by the position of the user in u_id_list.
for idx, u_id in tqdm(enumerate(u_id_list), total=len(u_id_list)):
    support_x_rows = []
    support_um_app = []
    support_umum_app = []
    support_umam_app = []
    support_umdm_app = []
    for m_id in support_u_movies[u_id]:
        support_x_rows.append(torch.cat((movie_fea_hete[m_id], user_fea[u_id]), 1))

        support_um_app.append(torch.cat([movie_fea_hete[x] for x in support_u_movies[u_id]], dim=0))  # each element: (#neighbor, 26=1+25)
        support_umum_app.append(torch.cat([movie_fea_hete[x] for x in u_m_u_movies[u_id][m_id]], dim=0))
        support_umam_app.append(torch.cat([movie_fea_hete[x] for x in u_m_a_movies[u_id][m_id]], dim=0))
        support_umdm_app.append(torch.cat([movie_fea_hete[x] for x in u_m_d_movies[u_id][m_id]], dim=0))
    # Rows are concatenated once. The previous bare `except:` seeding trick could hide
    # genuine torch.cat failures and silently drop the rows accumulated so far.
    # An empty movie list still yields None, as before.
    support_x_app = torch.cat(support_x_rows, 0) if support_x_rows else None
    support_y_app = torch.FloatTensor(support_u_movies_y[u_id])

    for name, obj in (
        ("support_x", support_x_app),
        ("support_y", support_y_app),
        ("support_um", support_um_app),
        ("support_umum", support_umum_app),
        ("support_umam", support_umam_app),
        ("support_umdm", support_umdm_app),
    ):
        # `with` closes each file; the handles were previously leaked.
        with open("{}/{}/{}_{}.pkl".format(output_dir, state, name, idx), "wb") as f:
            pickle.dump(obj, f)

    query_x_rows = []
    query_um_app = []
    query_umum_app = []
    query_umam_app = []
    query_umdm_app = []

    for m_id in query_u_movies[u_id]:
        query_x_rows.append(torch.cat((movie_fea_hete[m_id], user_fea[u_id]), 1))

        query_um_app.append(torch.cat([movie_fea_hete[x] for x in support_u_movies[u_id]], dim=0))  # each element: (#neighbor, 26=1+25)
        query_umum_app.append(torch.cat([movie_fea_hete[x] for x in u_m_u_movies[u_id][m_id]], dim=0))
        query_umam_app.append(torch.cat([movie_fea_hete[x] for x in u_m_a_movies[u_id][m_id]], dim=0))
        query_umdm_app.append(torch.cat([movie_fea_hete[x] for x in u_m_d_movies[u_id][m_id]], dim=0))
    query_x_app = torch.cat(query_x_rows, 0) if query_x_rows else None
    query_y_app = torch.FloatTensor(query_u_movies_y[u_id])

    for name, obj in (
        ("query_x", query_x_app),
        ("query_y", query_y_app),
        ("query_um", query_um_app),
        ("query_umum", query_umum_app),
        ("query_umam", query_umam_app),
        ("query_umdm", query_umdm_app),
    ):
        with open("{}/{}/{}_{}.pkl".format(output_dir, state, name, idx), "wb") as f:
            pickle.dump(obj, f)
print(idx)


0it [00:00, ?it/s][A
1it [00:01,  1.97s/it][A
2it [00:02,  1.41s/it][A
7it [00:02,  1.00it/s][A
9it [00:02,  1.40it/s][A
13it [00:02,  1.95it/s][A
16it [00:02,  2.70it/s][A
19it [00:02,  3.42it/s][A
24it [00:03,  4.67it/s][A
29it [00:03,  6.35it/s][A
33it [00:03,  8.14it/s][A
37it [00:03, 10.63it/s][A
41it [00:03, 12.02it/s][A
46it [00:03, 15.29it/s][A
52it [00:03, 19.55it/s][A
56it [00:04, 21.07it/s][A
60it [00:04, 19.52it/s][A
63it [00:04, 18.53it/s][A
66it [00:06,  4.44it/s][A
70it [00:06,  6.03it/s][A
73it [00:06,  7.39it/s][A
76it [00:06,  9.53it/s][A
79it [00:06, 11.53it/s][A
82it [00:07,  6.14it/s][A
87it [00:08,  8.23it/s][A
90it [00:08,  9.56it/s][A
93it [00:08, 11.93it/s][A
96it [00:08, 10.44it/s][A
101it [00:08, 13.67it/s][A
104it [00:09, 10.70it/s][A
107it [00:09,  8.45it/s][A
110it [00:09, 10.70it/s][A
113it [00:10, 12.61it/s][A
116it [00:10, 14.55it/s][A
119it [00:10, 12.58it/s][A
123it [00:10, 10.96it/s][A
125it [00:11,  8.77it/s][A


1096it [03:11, 10.39it/s][A
1099it [03:11, 11.52it/s][A
1101it [03:11, 12.14it/s][A
1104it [03:11, 12.30it/s][A
1106it [03:11, 10.36it/s][A
1108it [03:12,  9.55it/s][A
1110it [03:12,  8.45it/s][A
1112it [03:12,  7.38it/s][A
1113it [03:12,  7.78it/s][A
1114it [03:12,  8.28it/s][A
1115it [03:13,  4.37it/s][A
1117it [03:13,  5.63it/s][A
1119it [03:13,  6.50it/s][A
1121it [03:14,  5.86it/s][A
1122it [03:14,  6.29it/s][A
1123it [03:15,  1.67it/s][A
1124it [03:16,  2.10it/s][A
1125it [03:16,  2.49it/s][A
1126it [03:16,  2.49it/s][A
1128it [03:17,  2.78it/s][A
1129it [03:17,  3.16it/s][A
1130it [03:17,  2.90it/s][A
1131it [03:18,  3.44it/s][A
1133it [03:18,  4.26it/s][A
1134it [03:18,  3.82it/s][A
1136it [03:18,  4.95it/s][A
1139it [03:18,  6.39it/s][A
1141it [03:19,  7.52it/s][A
1143it [03:19,  6.56it/s][A
1145it [03:19,  5.98it/s][A
1146it [03:20,  5.72it/s][A
1148it [03:20,  4.96it/s][A
1149it [03:21,  3.74it/s][A
1150it [03:21,  3.20it/s][A
1151it [03:21,

1965it [06:01,  4.34it/s][A
1966it [06:01,  5.16it/s][A
1968it [06:01,  5.43it/s][A
1969it [06:02,  4.84it/s][A
1971it [06:02,  6.01it/s][A
1972it [06:02,  6.41it/s][A
1973it [06:02,  5.11it/s][A
1975it [06:02,  6.01it/s][A
1976it [06:03,  2.24it/s][A
1978it [06:04,  2.88it/s][A
1979it [06:04,  3.33it/s][A
1980it [06:04,  3.87it/s][A
1981it [06:05,  2.88it/s][A
1982it [06:05,  2.86it/s][A
1983it [06:05,  3.63it/s][A
1985it [06:05,  4.42it/s][A
1987it [06:05,  5.63it/s][A
1988it [06:06,  5.97it/s][A
1989it [06:06,  5.03it/s][A
1990it [06:06,  4.32it/s][A
1991it [06:06,  4.18it/s][A
1992it [06:07,  4.86it/s][A
1993it [06:07,  5.31it/s][A
1994it [06:07,  5.87it/s][A
1996it [06:07,  7.26it/s][A
1997it [06:07,  6.30it/s][A
1999it [06:07,  6.82it/s][A
2000it [06:09,  1.35it/s][A
2001it [06:10,  1.62it/s][A
2002it [06:10,  1.90it/s][A
2003it [06:11,  1.81it/s][A
2004it [06:11,  2.38it/s][A
2005it [06:11,  2.83it/s][A
2007it [06:11,  3.21it/s][A
2009it [06:12,

2748


In [161]:
# Select the evaluation split by uncommenting one of the lines below before running
# this cell and the two cells that follow it.
# NOTE(review): with both lines commented out, `state` keeps whatever value an earlier
# cell assigned (e.g. 'meta_training') - confirm the intended split before running.
# state = 'warm_up'
# state = 'user_cold_testing'

# Context managers close the files (the handles were previously left open).
with open(output_dir+state+'/support_u_movies.json','r') as f:
    support_u_movies = json.load(f, object_hook=jsonKeys2int)
with open(output_dir+state+'/query_u_movies.json','r') as f:
    query_u_movies = json.load(f, object_hook=jsonKeys2int)
with open(output_dir+state+'/support_u_movies_y.json','r') as f:
    support_u_movies_y = json.load(f, object_hook=jsonKeys2int)
with open(output_dir+state+'/query_u_movies_y.json','r') as f:
    query_u_movies_y = json.load(f, object_hook=jsonKeys2int)

# Fail loudly on inconsistent inputs instead of silently reusing a stale `u_id_list`.
if support_u_movies.keys() != query_u_movies.keys():
    raise ValueError(
        "support and query files of split '{}' cover different users".format(state)
    )
u_id_list = support_u_movies.keys()
print(len(u_id_list))

# Observable interactions for this split: the training histories plus the support
# set of every user in the split.
# The per-user lists are copied too: dict.copy() is shallow, so the in-place `+=`
# below would otherwise extend the lists stored in train_u_movies / train_u_movies_y
# for users present in both, and re-running the cell would keep growing them.
cur_train_u_movies = {u: list(ms) for u, ms in train_u_movies.items()}
cur_train_u_movies_y = {u: list(ys) for u, ys in train_u_movies_y.items()}

for u_id in tqdm(u_id_list):
    if u_id not in cur_train_u_movies:
        cur_train_u_movies[u_id] = []
        cur_train_u_movies_y[u_id] = []
    cur_train_u_movies[u_id] += support_u_movies[u_id]
    cur_train_u_movies_y[u_id] += support_u_movies_y[u_id]

print(len(cur_train_u_movies),  len(train_u_movies))
print(len(cur_train_u_movies_y),  len(train_u_movies_y))
# Number of users shared between the training split and this split.
print(len(set(train_u_id_list) & set(u_id_list)))

# Sanity check displayed as the cell output: |split| + |train| - |overlap| == |combined|.
(len(u_id_list) +  len(train_u_movies) - len(set(train_u_id_list) & set(u_id_list))) == len(set(cur_train_u_movies))


0it [00:00, ?it/s][A
731it [00:00, 419832.43it/s][A

731
731
3480 2749
3480 2749
0


True

In [163]:
# Neighbour sets for the current split, keyed as result[user][movie] -> set of movie ids
# (always containing `movie`):
#   u_m_u_movies: the user's entries in cur_train_u_movies plus the movies of every user
#                 in cur_train_u_movies whose rating of `movie` equals this user's rating of it
#   u_m_a_movies: movies sharing at least one actor code with `movie`
#   u_m_d_movies: movies sharing at least one director code with `movie`
u_m_u_movies = {}
u_m_a_movies = {}
u_m_d_movies = {}
# movie -> users that have it in cur_train_u_movies
cur_train_m_users = reverse_dict(cur_train_u_movies)

for u in tqdm(u_id_list, leave=False, ncols=0):
    u_m_u_movies[u] = {}
    u_m_a_movies[u] = {}
    u_m_d_movies[u] = {}
    # Per-user lookups are loop-invariant, so they are fetched once here.
    u_m = support_u_movies[u]
    u_m_y = support_u_movies_y[u]
    u_m_q = query_u_movies[u]
    u_m_y_q = query_u_movies_y[u]

    for m in u_m:
        u_m_u_movies[u][m] = set([m])
        u_m_u_movies[u][m].update(cur_train_u_movies[u])
        if m in cur_train_m_users:
            # list.index returns the first occurrence, as in the original lookup.
            m_rating = u_m_y[u_m.index(m)]
            for _u in cur_train_m_users[m]:
                cur_ms = cur_train_u_movies[_u]
                if cur_train_u_movies_y[_u][cur_ms.index(m)] == m_rating:
                    u_m_u_movies[u][m].update(cur_ms)

        u_m_a_movies[u][m] = set([m])
        for _a in m_actors[m]:
            u_m_a_movies[u][m].update(a_movies[_a])

        u_m_d_movies[u][m] = set([m])
        for _d in m_directors[m]:
            u_m_d_movies[u][m].update(d_movies[_d])

    for m in u_m_q:
        # A query movie that already has an entry (it is also a support movie, or it
        # is repeated in the query list) is reported and the remaining query movies
        # of this user are skipped.
        if m in u_m_a_movies[u] or m in u_m_d_movies[u] or m in u_m_u_movies[u]:
            print('error!!!')
            break

        u_m_u_movies[u][m] = set([m])
        u_m_u_movies[u][m].update(cur_train_u_movies[u])
        # Neighbours are restricted to what cur_train_u_movies holds; the comparison
        # uses the query rating of `m`.
        if m in cur_train_m_users:
            m_rating = u_m_y_q[u_m_q.index(m)]
            for _u in cur_train_m_users[m]:
                cur_ms = cur_train_u_movies[_u]  # list
                if cur_train_u_movies_y[_u][cur_ms.index(m)] == m_rating:
                    u_m_u_movies[u][m].update(cur_ms)

        u_m_a_movies[u][m] = set([m])
        for _a in m_actors[m]:
            u_m_a_movies[u][m].update(a_movies[_a])

        u_m_d_movies[u][m] = set([m])
        for _d in m_directors[m]:
            u_m_d_movies[u][m].update(d_movies[_d])
print(len(u_m_u_movies), len(u_m_a_movies), len(u_m_d_movies))


  0% 0/731 [00:00<?, ?it/s][A
  1% 10/731 [00:00<00:08, 82.51it/s][A

user_and_item_cold_testing



  2% 17/731 [00:00<00:09, 76.12it/s][A
  4% 27/731 [00:00<00:08, 81.88it/s][A
  5% 36/731 [00:00<00:08, 83.36it/s][A
  6% 45/731 [00:00<00:08, 83.15it/s][A
  8% 55/731 [00:00<00:07, 85.88it/s][A
  9% 64/731 [00:00<00:07, 83.84it/s][A
 10% 75/731 [00:00<00:07, 89.59it/s][A
 11% 84/731 [00:00<00:07, 87.35it/s][A
 13% 93/731 [00:01<00:07, 85.81it/s][A
 14% 102/731 [00:01<00:07, 83.11it/s][A
 15% 111/731 [00:01<00:07, 81.82it/s][A
 16% 120/731 [00:01<00:07, 76.51it/s][A
 18% 130/731 [00:01<00:07, 79.91it/s][A
 19% 139/731 [00:01<00:07, 79.04it/s][A
 20% 148/731 [00:01<00:07, 80.88it/s][A
 21% 157/731 [00:01<00:07, 78.58it/s][A
 23% 167/731 [00:02<00:06, 83.12it/s][A
 24% 176/731 [00:02<00:06, 80.50it/s][A
 25% 185/731 [00:02<00:07, 77.48it/s][A
 26% 193/731 [00:02<00:07, 74.77it/s][A
 28% 203/731 [00:02<00:06, 79.55it/s][A
 29% 212/731 [00:02<00:06, 80.46it/s][A
 31% 223/731 [00:02<00:06, 84.04it/s][A
 32% 232/731 [00:03<00:09, 54.43it/s][A
 33% 240/731 [00:03<00:0

731


In [164]:
# For every user of the current split write twelve pickles to "<output_dir>/<state>/":
#   {support,query}_x    - one row per movie: movie rate+genre features followed by the user features
#   {support,query}_y    - the ratings as a FloatTensor
#   support_um           - per movie, the stacked features of the user's cur_train_u_movies entries
#   query_um             - same, with the query movie itself appended
#   {support,query}_umum / _umam / _umdm - per movie, the stacked features of its
#                          u_m_u / u_m_a / u_m_d neighbour set
# Files are numbered by the position of the user in u_id_list.
for idx, u_id in tqdm(enumerate(u_id_list), total=len(u_id_list)):
    support_x_rows = []
    support_um_app = []
    support_umum_app = []
    support_umam_app = []
    support_umdm_app = []
    for m_id in support_u_movies[u_id]:
        support_x_rows.append(torch.cat((movie_fea_hete[m_id], user_fea[u_id]), 1))

        support_um_app.append(torch.cat([movie_fea_hete[x] for x in cur_train_u_movies[u_id]], dim=0))  # each element: (#neighbor, 26=1+25)
        support_umum_app.append(torch.cat([movie_fea_hete[x] for x in u_m_u_movies[u_id][m_id]], dim=0))
        support_umam_app.append(torch.cat([movie_fea_hete[x] for x in u_m_a_movies[u_id][m_id]], dim=0))
        support_umdm_app.append(torch.cat([movie_fea_hete[x] for x in u_m_d_movies[u_id][m_id]], dim=0))
    # Rows are concatenated once. The previous bare `except:` seeding trick could hide
    # genuine torch.cat failures and silently drop the rows accumulated so far.
    # An empty movie list still yields None, as before.
    support_x_app = torch.cat(support_x_rows, 0) if support_x_rows else None
    support_y_app = torch.FloatTensor(support_u_movies_y[u_id])

    for name, obj in (
        ("support_x", support_x_app),
        ("support_y", support_y_app),
        ("support_um", support_um_app),
        ("support_umum", support_umum_app),
        ("support_umam", support_umam_app),
        ("support_umdm", support_umdm_app),
    ):
        # `with` closes each file; the handles were previously leaked.
        with open("{}/{}/{}_{}.pkl".format(output_dir, state, name, idx), "wb") as f:
            pickle.dump(obj, f)

    query_x_rows = []
    query_um_app = []
    query_umum_app = []
    query_umam_app = []
    query_umdm_app = []
    for m_id in query_u_movies[u_id]:
        query_x_rows.append(torch.cat((movie_fea_hete[m_id], user_fea[u_id]), 1))

        query_um_app.append(torch.cat([movie_fea_hete[x] for x in cur_train_u_movies[u_id]+[m_id]], dim=0))  # each element: (#neighbor, 26=1+25)
        query_umum_app.append(torch.cat([movie_fea_hete[x] for x in u_m_u_movies[u_id][m_id]], dim=0))
        query_umam_app.append(torch.cat([movie_fea_hete[x] for x in u_m_a_movies[u_id][m_id]], dim=0))
        query_umdm_app.append(torch.cat([movie_fea_hete[x] for x in u_m_d_movies[u_id][m_id]], dim=0))
    query_x_app = torch.cat(query_x_rows, 0) if query_x_rows else None
    query_y_app = torch.FloatTensor(query_u_movies_y[u_id])

    for name, obj in (
        ("query_x", query_x_app),
        ("query_y", query_y_app),
        ("query_um", query_um_app),
        ("query_umum", query_umum_app),
        ("query_umam", query_umam_app),
        ("query_umdm", query_umdm_app),
    ):
        with open("{}/{}/{}_{}.pkl".format(output_dir, state, name, idx), "wb") as f:
            pickle.dump(obj, f)
print(idx)


0it [00:00, ?it/s][A
6it [00:00, 52.23it/s][A
10it [00:00, 45.09it/s][A
14it [00:00, 42.47it/s][A
19it [00:00, 43.42it/s][A
25it [00:00, 46.49it/s][A
29it [00:00, 42.58it/s][A
33it [00:00, 38.38it/s][A
40it [00:00, 42.50it/s][A
46it [00:01, 46.20it/s][A
53it [00:01, 48.93it/s][A
59it [00:01, 51.23it/s][A
65it [00:01, 49.45it/s][A
71it [00:01, 52.04it/s][A
78it [00:01, 54.85it/s][A
84it [00:02, 29.94it/s][A
89it [00:02, 18.57it/s][A
93it [00:03,  9.35it/s][A
96it [00:04,  5.36it/s][A
98it [00:06,  2.71it/s][A
100it [00:06,  3.40it/s][A
102it [00:09,  1.37it/s][A
103it [01:01, 16.01s/it][A
109it [01:01, 11.22s/it][A
111it [01:01,  7.88s/it][A
112it [01:02,  5.81s/it][A
113it [01:04,  4.52s/it][A
114it [01:05,  3.34s/it][A
115it [01:06,  2.67s/it][A
116it [01:06,  1.99s/it][A
117it [01:06,  1.46s/it][A
118it [01:10,  2.11s/it][A
119it [01:10,  1.53s/it][A
120it [01:10,  1.14s/it][A
121it [01:11,  1.10it/s][A
122it [01:11,  1.43it/s][A
123it [01:12,  1.

730
