## 參考
- [BST - keras.io](https://keras.io/examples/structured_data/movielens_recommendations_transformers/)

In [9]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [10]:
import tensorflow as tf
from tensorflow import keras 
from tensorflow.keras import layers

In [71]:
# 讀取資料 from .dat

data_dir = 'ml-1m/ml-1m'

users = pd.read_csv(data_dir+'/users.dat', delimiter='::', names=['user_id', 'gender', 'age', 'occupation', 'zip_code'])
movies = pd.read_csv(data_dir+'/movies.dat', delimiter='::', names=['movie_id', 'title', 'genres'])
ratings = pd.read_csv(data_dir+'/ratings.dat', delimiter='::', names=['user_id', 'movie_id', 'rating', 'timestamp'])

  """
  
  import sys


In [72]:
users.head()

Unnamed: 0,user_id,gender,age,occupation,zip_code
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117
3,4,M,45,7,2460
4,5,M,25,20,55455


In [73]:
movies.head()

Unnamed: 0,movie_id,title,genres
0,1,Toy Story (1995),Animation|Children's|Comedy
1,2,Jumanji (1995),Adventure|Children's|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama
4,5,Father of the Bride Part II (1995),Comedy


In [74]:
ratings.head()

Unnamed: 0,user_id,movie_id,rating,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291


In [75]:
ratings.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000209 entries, 0 to 1000208
Data columns (total 4 columns):
 #   Column     Non-Null Count    Dtype
---  ------     --------------    -----
 0   user_id    1000209 non-null  int64
 1   movie_id   1000209 non-null  int64
 2   rating     1000209 non-null  int64
 3   timestamp  1000209 non-null  int64
dtypes: int64(4)
memory usage: 30.5 MB


In [76]:
# 先透過df frame 排序, 之後用group by就可以得到正確順序(會保持原順序)

ratings.sort_values(by=['user_id', 'timestamp'], inplace=True, ascending=True)

In [77]:
ratings

Unnamed: 0,user_id,movie_id,rating,timestamp
31,1,3186,4,978300019
22,1,1270,5,978300055
27,1,1721,4,978300055
37,1,1022,5,978300055
24,1,2340,3,978300103
...,...,...,...,...
1000019,6040,2917,4,997454429
999988,6040,1921,4,997454464
1000172,6040,1784,3,997454464
1000167,6040,161,3,997454486


In [79]:
# 取得重要參數, 就是總movie 數目, +1即可用於[mask]

num_movies = len(movies['movie_id'].unique())
num_users = len(users['user_id'].unique())

print(f'電影數量: {num_movies}\n使用者數量: {num_users}')

電影數量: 3883
使用者數量: 6040


In [83]:
# 中間有斷, 有點麻煩, 需要重新排序, 或者將數量提升, 沒有的就沒用到而已, okay這樣比較方便。

num_movies = movies['movie_id'].max() + 1
mask_token_id = movies['movie_id'].max() + 1
movies['movie_id'].max()

3952

In [93]:
# 整理成論文需要的資料
# [movie1, movie2, movie3, movie4, movie5...100]
# N = 100
from pprint import pprint

# 透過ratings, 利用groupby user_id, timestamp


groups = ratings.groupby(by=['user_id']).groups

In [95]:
lengths = 0

for key in groups:
    print(f'user_id(${key}): 長度(${len(groups[key])})')
    lengths += len(groups[key])
    print('-'*15)
    pprint(groups[key])

user_id($1): 長度($53)
---------------
Int64Index([31, 22, 27, 37, 24, 36,  3,  7, 47,  0, 21, 44,  9, 51, 43, 41, 48,
            18, 11, 14, 42, 17, 39, 45, 26,  2,  6, 19, 38, 52,  1, 13, 49, 50,
            15, 20, 46,  5,  8, 12, 28, 23, 10, 16, 29, 33, 40,  4, 30, 35, 32,
            34, 25],
           dtype='int64')
user_id($2): 長度($129)
---------------
Int64Index([130,  64,  71, 134,  88, 170, 106, 120, 172,  70,
            ...
             92, 129, 167,  74,  80, 133,  66,  73,  87, 136],
           dtype='int64', length=129)
user_id($3): 長度($51)
---------------
Int64Index([217, 202, 186, 230, 190, 225, 209, 216, 218, 226, 212, 201, 211,
            228, 214, 215, 223, 189, 195, 198, 199, 205, 207, 194, 197, 210,
            227, 222, 229, 204, 206, 224, 184, 188, 231, 203, 182, 185, 193,
            220, 200, 221, 192, 196, 208, 183, 213, 219, 187, 191, 232],
           dtype='int64')
user_id($4): 長度($21)
---------------
Int64Index([234, 244, 233, 240, 249, 238, 241, 242, 248

---------------
Int64Index([30182, 30190, 30230, 30245, 30275, 30225, 30292, 30176, 30294,
            30266,
            ...
            30174, 30260, 30173, 30290, 30213, 30205, 30233, 30281, 30228,
            30244],
           dtype='int64', length=127)
user_id($204): 長度($446)
---------------
Int64Index([30610, 30649, 30679, 30715, 30396, 30709, 30304, 30428, 30510,
            30677,
            ...
            30723, 30514, 30606, 30353, 30385, 30483, 30555, 30685, 30517,
            30564],
           dtype='int64', length=446)
user_id($205): 長度($153)
---------------
Int64Index([30749, 30809, 30845, 30846, 30871, 30746, 30767, 30780, 30873,
            30884,
            ...
            30837, 30861, 30891, 30862, 30762, 30805, 30821, 30836, 30881,
            30801],
           dtype='int64', length=153)
user_id($206): 長度($30)
---------------
Int64Index([30900, 30904, 30917, 30899, 30902, 30907, 30910, 30912, 30914,
            30916, 30920, 30922, 30905, 30927, 30898, 30908, 

Int64Index([62083, 62097, 62127, 62079, 62099, 62107, 62138, 62070, 62089,
            62111,
            ...
            62166, 62122, 62133, 62168, 62125, 62126, 62135, 62145, 62090,
            62071],
           dtype='int64', length=108)
user_id($420): 長度($64)
---------------
Int64Index([62201, 62238, 62208, 62216, 62203, 62223, 62177, 62186, 62185,
            62190, 62219, 62228, 62229, 62195, 62232, 62182, 62183, 62213,
            62224, 62236, 62239, 62198, 62209, 62180, 62227, 62187, 62188,
            62206, 62231, 62233, 62221, 62217, 62220, 62235, 62197, 62200,
            62222, 62192, 62193, 62199, 62211, 62237, 62212, 62218, 62178,
            62184, 62204, 62207, 62214, 62234, 62240, 62205, 62181, 62191,
            62196, 62226, 62179, 62194, 62202, 62210, 62215, 62225, 62189,
            62230],
           dtype='int64')
user_id($421): 長度($20)
---------------
Int64Index([62245, 62248, 62255, 62257, 62259, 62260, 62256, 62251, 62253,
            62243, 62250, 62254, 

user_id($615): 長度($94)
---------------
Int64Index([92272, 92304, 92312, 92253, 92287, 92333, 92315, 92324, 92284,
            92273, 92281, 92292, 92285, 92243, 92262, 92275, 92276, 92286,
            92327, 92279, 92251, 92254, 92289, 92300, 92326, 92267, 92293,
            92332, 92258, 92331, 92335, 92269, 92261, 92319, 92249, 92325,
            92329, 92265, 92252, 92302, 92256, 92257, 92299, 92336, 92320,
            92244, 92296, 92322, 92271, 92314, 92259, 92255, 92282, 92277,
            92283, 92309, 92250, 92245, 92303, 92317, 92248, 92298, 92328,
            92308, 92247, 92263, 92311, 92321, 92260, 92306, 92278, 92323,
            92274, 92246, 92297, 92290, 92310, 92313, 92316, 92266, 92295,
            92307, 92291, 92318, 92280, 92288, 92294, 92305, 92264, 92270,
            92330, 92268, 92334, 92301],
           dtype='int64')
user_id($616): 長度($44)
---------------
Int64Index([92349, 92359, 92367, 92379, 92370, 92339, 92347, 92353, 92374,
            92338, 92341, 9234

           dtype='int64', length=126)
user_id($844): 長度($84)
---------------
Int64Index([130295, 130311, 130333, 130334, 130306, 130272, 130348, 130289,
            130296, 130305, 130273, 130297, 130325, 130307, 130299, 130345,
            130326, 130281, 130294, 130284, 130274, 130319, 130280, 130353,
            130283, 130339, 130288, 130271, 130276, 130300, 130338, 130344,
            130318, 130298, 130291, 130308, 130323, 130270, 130315, 130340,
            130352, 130349, 130310, 130350, 130316, 130330, 130287, 130278,
            130279, 130282, 130347, 130335, 130332, 130301, 130312, 130329,
            130324, 130336, 130286, 130314, 130337, 130341, 130346, 130351,
            130302, 130309, 130317, 130285, 130293, 130303, 130313, 130327,
            130342, 130277, 130275, 130292, 130290, 130321, 130343, 130322,
            130331, 130328, 130304, 130320],
           dtype='int64')
user_id($845): 長度($52)
---------------
Int64Index([130355, 130363, 130371, 130365, 130374, 1

           dtype='int64', length=337)
user_id($1035): 長度($91)
---------------
Int64Index([162476, 162488, 162462, 162470, 162503, 162506, 162500, 162518,
            162481, 162514, 162443, 162498, 162440, 162441, 162475, 162497,
            162451, 162478, 162517, 162466, 162474, 162520, 162501, 162438,
            162479, 162480, 162477, 162527, 162444, 162485, 162490, 162524,
            162528, 162493, 162526, 162492, 162525, 162522, 162504, 162507,
            162439, 162483, 162469, 162459, 162508, 162464, 162489, 162491,
            162465, 162482, 162452, 162453, 162463, 162454, 162455, 162523,
            162442, 162446, 162486, 162509, 162484, 162472, 162513, 162447,
            162449, 162448, 162502, 162458, 162494, 162511, 162473, 162468,
            162519, 162487, 162460, 162516, 162512, 162499, 162510, 162496,
            162471, 162505, 162515, 162445, 162450, 162456, 162461, 162521,
            162457, 162467, 162495],
           dtype='int64')
user_id($1036): 長度($50)

           dtype='int64', length=126)
user_id($1177): 長度($79)
---------------
Int64Index([189801, 189824, 189826, 189835, 189857, 189786, 189789, 189791,
            189815, 189837, 189856, 189808, 189809, 189858, 189860, 189787,
            189802, 189812, 189850, 189803, 189836, 189863, 189864, 189805,
            189834, 189859, 189818, 189817, 189847, 189806, 189855, 189799,
            189814, 189822, 189792, 189804, 189840, 189795, 189861, 189823,
            189844, 189839, 189845, 189851, 189854, 189788, 189798, 189831,
            189842, 189848, 189849, 189790, 189813, 189825, 189853, 189794,
            189821, 189832, 189841, 189819, 189846, 189862, 189797, 189830,
            189828, 189833, 189843, 189811, 189820, 189829, 189852, 189793,
            189796, 189800, 189810, 189827, 189838, 189807, 189816],
           dtype='int64')
user_id($1178): 長度($36)
---------------
Int64Index([189891, 189884, 189890, 189892, 189900, 189882, 189899, 189865,
            189871, 189872,

           dtype='int64', length=132)
user_id($1339): 長度($213)
---------------
Int64Index([220846, 220850, 220855, 220889, 220954, 220891, 220893, 220926,
            220831, 220767,
            ...
            220798, 220851, 220761, 220777, 220820, 220967, 220958, 220847,
            220822, 220755],
           dtype='int64', length=213)
user_id($1340): 長度($805)
---------------
Int64Index([220991, 221220, 221465, 221101, 221591, 221713, 220979, 221308,
            221696, 220983,
            ...
            221393, 221271, 220987, 221244, 221493, 221736, 221242, 221222,
            221142, 221234],
           dtype='int64', length=805)
user_id($1341): 長度($87)
---------------
Int64Index([221821, 221855, 221792, 221797, 221826, 221827, 221845, 221776,
            221781, 221785, 221795, 221818, 221834, 221777, 221815, 221839,
            221848, 221852, 221854, 221857, 221783, 221836, 221853, 221790,
            221816, 221823, 221835, 221837, 221841, 221843, 221850, 221782,
          

           dtype='int64')
user_id($1539): 長度($27)
---------------
Int64Index([253214, 253216, 253219, 253222, 253200, 253204, 253201, 253215,
            253223, 253226, 253221, 253207, 253211, 253213, 253224, 253225,
            253202, 253203, 253205, 253206, 253212, 253209, 253210, 253208,
            253218, 253217, 253220],
           dtype='int64')
user_id($1540): 長度($118)
---------------
Int64Index([253244, 253277, 253282, 253307, 253312, 253232, 253228, 253229,
            253341, 253342,
            ...
            253315, 253329, 253240, 253246, 253275, 253313, 253317, 253324,
            253233, 253267],
           dtype='int64', length=118)
user_id($1541): 長度($26)
---------------
Int64Index([253364, 253346, 253361, 253365, 253353, 253356, 253360, 253368,
            253357, 253350, 253359, 253362, 253348, 253354, 253367, 253351,
            253355, 253349, 253352, 253366, 253347, 253358, 253363, 253369,
            253370, 253345],
           dtype='int64')
user_id($1542): 

           dtype='int64')
user_id($1729): 長度($58)
---------------
Int64Index([289681, 289686, 289691, 289671, 289683, 289697, 289717, 289669,
            289696, 289720, 289668, 289694, 289670, 289692, 289721, 289665,
            289666, 289699, 289706, 289711, 289713, 289716, 289667, 289714,
            289676, 289690, 289684, 289698, 289710, 289682, 289672, 289685,
            289703, 289718, 289680, 289693, 289687, 289719, 289689, 289678,
            289708, 289673, 289700, 289715, 289674, 289695, 289704, 289688,
            289712, 289705, 289722, 289677, 289709, 289679, 289707, 289702,
            289675, 289701],
           dtype='int64')
user_id($1730): 長度($24)
---------------
Int64Index([289733, 289745, 289734, 289726, 289725, 289723, 289724, 289729,
            289730, 289728, 289727, 289735, 289743, 289744, 289746, 289732,
            289738, 289739, 289731, 289736, 289737, 289740, 289742, 289741],
           dtype='int64')
user_id($1731): 長度($68)
---------------
Int64Index([

           dtype='int64')
user_id($1964): 長度($22)
---------------
Int64Index([333048, 333052, 333054, 333055, 333047, 333056, 333035, 333051,
            333036, 333040, 333041, 333045, 333050, 333039, 333049, 333037,
            333042, 333046, 333038, 333044, 333043, 333053],
           dtype='int64')
user_id($1965): 長度($64)
---------------
Int64Index([333063, 333110, 333086, 333080, 333117, 333072, 333075, 333078,
            333093, 333057, 333099, 333120, 333107, 333074, 333077, 333091,
            333101, 333112, 333064, 333089, 333095, 333069, 333098, 333088,
            333102, 333108, 333115, 333070, 333106, 333114, 333073, 333065,
            333094, 333096, 333109, 333092, 333104, 333081, 333111, 333066,
            333085, 333058, 333079, 333087, 333103, 333097, 333067, 333076,
            333084, 333071, 333113, 333060, 333105, 333116, 333118, 333082,
            333083, 333119, 333068, 333090, 333100, 333062, 333059, 333061],
           dtype='int64')
user_id($1966): 長度($

           dtype='int64', length=293)
user_id($2186): 長度($292)
---------------
Int64Index([375017, 375075, 375203, 375223, 375000, 375005, 375090, 375205,
            375237, 374967,
            ...
            375179, 375073, 375173, 375032, 375016, 375112, 375162, 375018,
            375038, 375021],
           dtype='int64', length=292)
user_id($2187): 長度($137)
---------------
Int64Index([375330, 375314, 375376, 375321, 375351, 375299, 375334, 375341,
            375369, 375354,
            ...
            375358, 375264, 375301, 375385, 375392, 375277, 375327, 375342,
            375322, 375316],
           dtype='int64', length=137)
user_id($2188): 長度($405)
---------------
Int64Index([375588, 375403, 375571, 375423, 375724, 375531, 375696, 375430,
            375525, 375589,
            ...
            375702, 375458, 375620, 375786, 375760, 375676, 375574, 375698,
            375650, 375729],
           dtype='int64', length=405)
user_id($2189): 長度($35)
---------------
Int64Index

           dtype='int64')
user_id($2387): 長度($62)
---------------
Int64Index([400751, 400795, 400749, 400760, 400762, 400767, 400791, 400788,
            400740, 400753, 400776, 400785, 400758, 400792, 400743, 400763,
            400778, 400784, 400766, 400790, 400736, 400780, 400789, 400739,
            400750, 400768, 400770, 400772, 400781, 400765, 400782, 400742,
            400757, 400764, 400759, 400769, 400771, 400773, 400738, 400755,
            400786, 400747, 400761, 400783, 400746, 400794, 400744, 400752,
            400754, 400774, 400787, 400741, 400745, 400777, 400793, 400796,
            400775, 400756, 400779, 400737, 400748, 400797],
           dtype='int64')
user_id($2388): 長度($22)
---------------
Int64Index([400809, 400810, 400816, 400806, 400805, 400802, 400815, 400817,
            400799, 400800, 400804, 400818, 400819, 400807, 400808, 400814,
            400798, 400803, 400801, 400811, 400812, 400813],
           dtype='int64')
user_id($2389): 長度($299)
-----------

           dtype='int64')
user_id($2597): 長度($65)
---------------
Int64Index([427630, 427660, 427661, 427643, 427622, 427638, 427650, 427651,
            427665, 427626, 427639, 427603, 427609, 427646, 427654, 427607,
            427618, 427625, 427636, 427658, 427635, 427664, 427605, 427610,
            427613, 427655, 427602, 427621, 427657, 427608, 427629, 427620,
            427633, 427612, 427624, 427645, 427616, 427642, 427634, 427641,
            427632, 427644, 427619, 427640, 427614, 427604, 427623, 427663,
            427652, 427606, 427649, 427611, 427659, 427615, 427653, 427627,
            427662, 427628, 427648, 427637, 427656, 427631, 427617, 427601,
            427647],
           dtype='int64')
user_id($2598): 長度($27)
---------------
Int64Index([427687, 427675, 427672, 427668, 427671, 427678, 427682, 427691,
            427680, 427690, 427673, 427677, 427676, 427684, 427688, 427689,
            427692, 427674, 427667, 427669, 427670, 427683, 427685, 427679,
           

---------------
Int64Index([457151, 457149, 457176, 457144, 457145, 457150, 457186, 457146,
            457166, 457170, 457163, 457184, 457152, 457157, 457210, 457155,
            457214, 457190, 457133, 457154, 457201, 457160, 457177, 457196,
            457204, 457142, 457189, 457192, 457211, 457137, 457139, 457138,
            457159, 457191, 457187, 457134, 457199, 457147, 457143, 457175,
            457193, 457141, 457178, 457195, 457198, 457140, 457156, 457164,
            457213, 457181, 457200, 457202, 457168, 457148, 457179, 457194,
            457135, 457165, 457205, 457203, 457161, 457153, 457174, 457183,
            457185, 457182, 457188, 457173, 457197, 457206, 457167, 457158,
            457169, 457171, 457180, 457209, 457207, 457162, 457136, 457208,
            457212, 457172],
           dtype='int64')
user_id($2818): 長度($293)
---------------
Int64Index([457255, 457345, 457452, 457379, 457351, 457431, 457450, 457485,
            457307, 457270,
            ...
        

---------------
Int64Index([496045, 496047, 496051, 496057, 496066, 496071, 496053, 496059,
            496060, 496068, 496073, 496083, 496100, 496102, 496039, 496055,
            496067, 496098, 496105, 496048, 496088, 496074, 496078, 496096,
            496097, 496084, 496085, 496086, 496092, 496094, 496064, 496040,
            496087, 496089, 496038, 496042, 496052, 496065, 496090, 496091,
            496093, 496099, 496082, 496104, 496062, 496050, 496079, 496103,
            496041, 496061, 496070, 496080, 496049, 496046, 496101, 496056,
            496095, 496043, 496075, 496037, 496044, 496054, 496069, 496063,
            496077, 496081, 496072, 496076, 496058, 496106],
           dtype='int64')
user_id($3048): 長度($70)
---------------
Int64Index([496131, 496145, 496166, 496129, 496142, 496155, 496109, 496112,
            496114, 496125, 496132, 496150, 496168, 496169, 496170, 496144,
            496147, 496160, 496117, 496120, 496128, 496157, 496107, 496119,
            496135, 4

           dtype='int64')
user_id($3249): 長度($65)
---------------
Int64Index([525903, 525921, 525920, 525926, 525870, 525892, 525910, 525923,
            525868, 525883, 525891, 525880, 525922, 525872, 525904, 525924,
            525877, 525925, 525929, 525873, 525900, 525886, 525871, 525902,
            525919, 525867, 525896, 525869, 525866, 525928, 525899, 525906,
            525874, 525912, 525907, 525898, 525901, 525911, 525865, 525888,
            525881, 525893, 525875, 525895, 525908, 525905, 525917, 525894,
            525897, 525927, 525913, 525889, 525885, 525887, 525915, 525878,
            525879, 525914, 525909, 525890, 525884, 525916, 525876, 525882,
            525918],
           dtype='int64')
user_id($3250): 長度($124)
---------------
Int64Index([526033, 525941, 526044, 526002, 526049, 525982, 525976, 525977,
            525930, 525934,
            ...
            526028, 525958, 525964, 525974, 525962, 526018, 526013, 525999,
            525942, 526016],
           dt

           dtype='int64', length=362)
user_id($3455): 長度($23)
---------------
Int64Index([561774, 561777, 561779, 561783, 561781, 561776, 561778, 561784,
            561766, 561773, 561772, 561775, 561782, 561769, 561771, 561763,
            561785, 561764, 561767, 561768, 561780, 561765, 561770],
           dtype='int64')
user_id($3456): 長度($55)
---------------
Int64Index([561809, 561824, 561835, 561798, 561807, 561814, 561833, 561787,
            561790, 561795, 561796, 561805, 561810, 561811, 561819, 561820,
            561823, 561831, 561832, 561806, 561817, 561825, 561822, 561792,
            561794, 561828, 561800, 561801, 561802, 561804, 561815, 561818,
            561821, 561839, 561797, 561827, 561803, 561799, 561786, 561791,
            561808, 561826, 561829, 561834, 561836, 561837, 561838, 561793,
            561812, 561813, 561816, 561840, 561788, 561789, 561830],
           dtype='int64')
user_id($3457): 長度($160)
---------------
Int64Index([561850, 561931, 561951, 561979,

           dtype='int64')
user_id($3681): 長度($668)
---------------
Int64Index([605543, 605555, 605809, 605845, 605895, 606096, 606121, 606122,
            605512, 605588,
            ...
            605786, 605968, 606012, 605751, 605725, 605789, 605531, 606031,
            605811, 605517],
           dtype='int64', length=668)
user_id($3682): 長度($91)
---------------
Int64Index([606203, 606228, 606184, 606196, 606247, 606219, 606221, 606209,
            606215, 606242, 606166, 606173, 606231, 606180, 606176, 606225,
            606207, 606252, 606174, 606178, 606254, 606187, 606191, 606195,
            606235, 606239, 606194, 606217, 606230, 606165, 606233, 606202,
            606241, 606186, 606246, 606244, 606167, 606192, 606171, 606179,
            606229, 606204, 606212, 606189, 606200, 606253, 606250, 606181,
            606205, 606248, 606169, 606182, 606236, 606255, 606177, 606190,
            606214, 606224, 606175, 606234, 606188, 606222, 606211, 606218,
            606172, 60

user_id($3908): 長度($160)
---------------
Int64Index([648661, 648710, 648642, 648654, 648734, 648742, 648760, 648652,
            648656, 648782,
            ...
            648752, 648786, 648735, 648717, 648665, 648695, 648636, 648774,
            648640, 648731],
           dtype='int64', length=160)
user_id($3909): 長度($78)
---------------
Int64Index([648808, 648833, 648843, 648796, 648817, 648795, 648823, 648829,
            648861, 648838, 648847, 648853, 648856, 648806, 648809, 648851,
            648852, 648854, 648866, 648800, 648810, 648798, 648822, 648840,
            648841, 648814, 648839, 648845, 648855, 648865, 648803, 648812,
            648813, 648846, 648848, 648857, 648793, 648794, 648818, 648821,
            648849, 648858, 648816, 648824, 648842, 648859, 648832, 648870,
            648820, 648805, 648862, 648835, 648844, 648863, 648864, 648867,
            648868, 648819, 648831, 648860, 648811, 648828, 648834, 648801,
            648830, 648797, 648799, 648804, 6488

           dtype='int64', length=310)
user_id($4125): 長度($45)
---------------
Int64Index([689648, 689664, 689690, 689681, 689666, 689671, 689668, 689654,
            689657, 689687, 689649, 689650, 689679, 689684, 689646, 689670,
            689672, 689655, 689658, 689659, 689663, 689667, 689680, 689651,
            689665, 689673, 689689, 689686, 689661, 689678, 689682, 689688,
            689652, 689656, 689674, 689677, 689683, 689685, 689647, 689653,
            689660, 689675, 689676, 689662, 689669],
           dtype='int64')
user_id($4126): 長度($358)
---------------
Int64Index([689701, 689903, 689783, 689788, 689846, 689692, 689800, 690022,
            689806, 689823,
            ...
            689897, 689970, 689695, 689881, 689727, 689760, 689904, 689718,
            689883, 689884],
           dtype='int64', length=358)
user_id($4127): 長度($89)
---------------
Int64Index([690095, 690133, 690134, 690079, 690102, 690051, 690057, 690064,
            690109, 690124, 690125, 690130,

           dtype='int64', length=480)
user_id($4319): 長度($51)
---------------
Int64Index([721582, 721593, 721603, 721606, 721617, 721574, 721623, 721575,
            721579, 721584, 721587, 721580, 721581, 721591, 721598, 721599,
            721595, 721590, 721596, 721578, 721605, 721609, 721613, 721607,
            721612, 721619, 721602, 721576, 721577, 721618, 721600, 721594,
            721608, 721615, 721588, 721589, 721614, 721616, 721585, 721592,
            721583, 721597, 721621, 721622, 721611, 721624, 721601, 721586,
            721604, 721620, 721610],
           dtype='int64')
user_id($4320): 長度($81)
---------------
Int64Index([721645, 721674, 721691, 721641, 721678, 721682, 721634, 721640,
            721642, 721665, 721680, 721695, 721626, 721658, 721661, 721664,
            721668, 721694, 721662, 721676, 721687, 721700, 721632, 721650,
            721653, 721660, 721686, 721696, 721630, 721654, 721656, 721659,
            721699, 721703, 721704, 721643, 721646, 721655,

user_id($4572): 長度($145)
---------------
Int64Index([767370, 767293, 767337, 767333, 767308, 767380, 767394, 767371,
            767313, 767319,
            ...
            767415, 767285, 767309, 767381, 767389, 767335, 767416, 767379,
            767321, 767303],
           dtype='int64', length=145)
user_id($4573): 長度($154)
---------------
Int64Index([767531, 767544, 767491, 767506, 767523, 767471, 767441, 767443,
            767568, 767485,
            ...
            767567, 767572, 767435, 767436, 767503, 767518, 767546, 767470,
            767466, 767465],
           dtype='int64', length=154)
user_id($4574): 長度($50)
---------------
Int64Index([767595, 767601, 767617, 767585, 767618, 767627, 767626, 767580,
            767587, 767603, 767588, 767597, 767590, 767607, 767608, 767621,
            767623, 767589, 767600, 767615, 767582, 767593, 767628, 767586,
            767594, 767598, 767613, 767614, 767592, 767606, 767616, 767604,
            767609, 767629, 767583, 767610, 7676

           dtype='int64', length=280)
user_id($4835): 長度($83)
---------------
Int64Index([807857, 807875, 807870, 807827, 807831, 807833, 807840, 807879,
            807893, 807887, 807891, 807894, 807896, 807819, 807898, 807888,
            807889, 807890, 807892, 807895, 807878, 807897, 807845, 807824,
            807858, 807838, 807864, 807884, 807899, 807822, 807856, 807820,
            807837, 807846, 807860, 807867, 807882, 807885, 807862, 807900,
            807876, 807881, 807854, 807863, 807880, 807849, 807839, 807832,
            807877, 807859, 807823, 807861, 807873, 807851, 807872, 807835,
            807874, 807865, 807843, 807841, 807869, 807844, 807847, 807836,
            807855, 807866, 807829, 807868, 807821, 807825, 807826, 807871,
            807834, 807852, 807886, 807828, 807842, 807850, 807830, 807848,
            807853, 807883, 807901],
           dtype='int64')
user_id($4836): 長度($65)
---------------
Int64Index([807907, 807919, 807904, 807910, 807912, 807914,

---------------
Int64Index([840396, 840424, 840434, 840466, 840404, 840460, 840511, 840467,
            840401, 840512,
            ...
            840515, 840431, 840530, 840550, 840561, 840517, 840537, 840486,
            840443, 840463],
           dtype='int64', length=192)
user_id($5050): 長度($85)
---------------
Int64Index([840626, 840575, 840647, 840650, 840581, 840605, 840618, 840629,
            840589, 840599, 840612, 840648, 840583, 840573, 840654, 840637,
            840639, 840632, 840609, 840627, 840651, 840625, 840642, 840587,
            840602, 840606, 840634, 840577, 840604, 840614, 840620, 840631,
            840580, 840610, 840613, 840635, 840571, 840588, 840596, 840607,
            840619, 840630, 840646, 840576, 840623, 840628, 840572, 840579,
            840584, 840598, 840617, 840633, 840582, 840603, 840643, 840591,
            840638, 840585, 840641, 840574, 840601, 840586, 840608, 840611,
            840616, 840644, 840593, 840594, 840645, 840649, 840615, 84059

           dtype='int64')
user_id($5300): 長度($282)
---------------
Int64Index([877068, 877045, 876950, 876965, 877024, 876925, 876926, 876969,
            877015, 877091,
            ...
            877196, 876964, 876971, 877199, 876932, 876933, 876968, 876983,
            877066, 877075],
           dtype='int64', length=282)
user_id($5301): 長度($139)
---------------
Int64Index([877222, 877262, 877224, 877240, 877220, 877221, 877232, 877237,
            877334, 877294,
            ...
            877317, 877202, 877266, 877208, 877287, 877278, 877274, 877212,
            877286, 877282],
           dtype='int64', length=139)
user_id($5302): 長度($238)
---------------
Int64Index([877361, 877558, 877475, 877472, 877391, 877415, 877345, 877348,
            877356, 877469,
            ...
            877539, 877540, 877548, 877394, 877534, 877544, 877354, 877478,
            877427, 877511],
           dtype='int64', length=238)
user_id($5303): 長度($45)
---------------
Int64Index([877615, 87

Int64Index([914415, 914457, 914459, 914389, 914409, 914410, 914412, 914423,
            914388, 914405, 914426, 914452, 914454, 914394, 914397, 914399,
            914428, 914395, 914411, 914432, 914444, 914387, 914396, 914419,
            914443, 914447, 914391, 914453, 914390, 914430, 914455, 914398,
            914404, 914417, 914424, 914445, 914393, 914436, 914460, 914413,
            914400, 914420, 914427, 914434, 914442, 914449, 914418, 914431,
            914438, 914416, 914446, 914456, 914402, 914425, 914441, 914451,
            914458, 914429, 914392, 914401, 914422, 914439, 914408, 914435,
            914437, 914407, 914433, 914450, 914461, 914406, 914414, 914421,
            914448, 914403, 914440],
           dtype='int64')
user_id($5529): 長度($21)
---------------
Int64Index([914462, 914463, 914464, 914465, 914468, 914473, 914476, 914480,
            914481, 914482, 914466, 914467, 914469, 914470, 914471, 914474,
            914475, 914477, 914478, 914479, 914472],
        

           dtype='int64')
user_id($5747): 長度($564)
---------------
Int64Index([950938, 951052, 951080, 951200, 951306, 951141, 951142, 951301,
            951303, 951304,
            ...
            951447, 951297, 951456, 951072, 951057, 951243, 951240, 951448,
            950996, 951441],
           dtype='int64', length=564)
user_id($5748): 長度($54)
---------------
Int64Index([951495, 951519, 951503, 951510, 951529, 951532, 951500, 951523,
            951524, 951488, 951504, 951499, 951494, 951511, 951526, 951533,
            951506, 951483, 951491, 951520, 951490, 951482, 951509, 951525,
            951518, 951534, 951486, 951501, 951530, 951502, 951508, 951498,
            951513, 951514, 951489, 951527, 951528, 951484, 951481, 951507,
            951512, 951497, 951493, 951505, 951515, 951487, 951516, 951485,
            951522, 951517, 951492, 951531, 951521, 951496],
           dtype='int64')
user_id($5749): 長度($407)
---------------
Int64Index([951569, 951667, 951677, 951700, 95

           dtype='int64', length=142)
user_id($5948): 長度($398)
---------------
Int64Index([984145, 984221, 983965, 984180, 984274, 983963, 984115, 984162,
            983921, 983931,
            ...
            983996, 984079, 984043, 984001, 984067, 984182, 984003, 984268,
            984025, 983967],
           dtype='int64', length=398)
user_id($5949): 長度($181)
---------------
Int64Index([984391, 984364, 984388, 984437, 984365, 984297, 984412, 984381,
            984459, 984405,
            ...
            984382, 984409, 984372, 984474, 984368, 984362, 984329, 984385,
            984325, 984460],
           dtype='int64', length=181)
user_id($5950): 長度($435)
---------------
Int64Index([984822, 984909, 984761, 984588, 984717, 984853, 984876, 984534,
            984857, 984700,
            ...
            984847, 984842, 984651, 984598, 984597, 984731, 984682, 984475,
            984660, 984733],
           dtype='int64', length=435)
user_id($5951): 長度($53)
---------------
Int64Index

In [96]:
# 抓出單一個group

list(groups[1])

[31,
 22,
 27,
 37,
 24,
 36,
 3,
 7,
 47,
 0,
 21,
 44,
 9,
 51,
 43,
 41,
 48,
 18,
 11,
 14,
 42,
 17,
 39,
 45,
 26,
 2,
 6,
 19,
 38,
 52,
 1,
 13,
 49,
 50,
 15,
 20,
 46,
 5,
 8,
 12,
 28,
 23,
 10,
 16,
 29,
 33,
 40,
 4,
 30,
 35,
 32,
 34,
 25]

In [97]:
lengths

1000209

In [98]:
# 平均看影片長度

lengths // len(groups)

165

In [99]:
# 大概會有幾筆資料

mean_length = 120
L = 50
num_users = len(users)

(mean_length - L) * num_users * 0.15

63420.0

In [None]:
# 輸入長度L: 50
# 不足長度怎麼處理?
# 1. 補0
# 2. 不使用
# 爆了, 檔案太大!


L = 50

def get_raw_input(groups):
    global L
    rows, temp = [], []
    for key in groups:
        if len(groups[key]) < L:
            continue
            # temp = list(groups) + [0]*(L-len(groups[key]))
        else:
            temp = list(groups)
        for i in range(0, len(temp)-L):
            # 將資料包裝
            rows.append(temp[i:i+L])
            
    return rows


rows = get_raw_input(groups)

In [None]:
for i in range(3):
    pprint(rows[i])

## 模型架構
- Encoder part
- Classifier part

In [84]:
# config, 透過dataclass decorator自動加入__init__

from dataclasses import dataclass


@dataclass
class BERT4RecConfig:
    max_len: int=128
    batch_size: int=128
    lr: float=0.001
    item_size: int=num_movies
    embed_dim: int=16
    num_heads: int=8
    ff_dim: int=16
    num_layers: int=2
    dropout_rate: float=0.2
    
    
config = BERT4RecConfig()
config

BERT4RecConfig(max_len=128, batch_size=128, lr=0.001, item_size=3953, embed_dim=16, num_heads=8, ff_dim=16, num_layers=2, dropout_rate=0.2)

![Trm - 
Encoder](trm_encoder.png)

---

![image.png](bert4rec.png)




In [85]:
def trm_encoder(q, k, v, i):
    """
        Encoder part
    """
    attention_output = layers.MultiHeadAttention(
        num_heads=config.num_heads,
        key_dim=config.embed_dim // config.num_heads,
        name=f'encoder_{i}/multiheadattention'
    )(q, k, v)
    attention_output = layers.Dropout(rate=config.dropout_rate, name=f'encoder_{i}/att_dropout')(attention_output)
    attention_output = layers.LayerNormalization(
        epsilon=1e-1, name=f'encoder_{i}/att_layernormalization'
    )(q + attention_output)                 # short-cut

    # ffn
    ffn = keras.Sequential([
        layers.Dense(config.ff_dim, activation='relu'),
        layers.Dense(config.embed_dim),          
    ], name=f"encoder_{i}/ffn")
    ffn_output = ffn(attention_output)
    ffn_output = layers.Dropout(0.1, name=f"encoder_{i}/ffn_dropout")(
        ffn_output
    )
    sequence_output = layers.LayerNormalization(
        epsilon=1e-6, name=f"encoder_{i}/ffn_layernormalization"
    )(attention_output + ffn_output)            # short-cut
    return sequence_output

In [86]:
# 透過繼承於 keras.Model 去細緻調整fit過程
loss_fn = keras.losses.SparseCategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE
)
loss_tracker = tf.keras.metrics.Mean(name="loss")


class BERT4RecMLM(keras.Model):
    def train_step(self, inputs):
        if len(inputs) == 2:
            features, labels = inputs
            sample_weight = None
        elif len(inputs) == 3:
            features, labels, sample_weight = inputs
        else:
            print('輸入有誤...(BERT4RecMLM)')
        
        with tf.GradientTape() as tape:
            predictions = self(features, training=True)
            loss = loss_fn(labels, predictions, sample_weight=sample_weight)
        
        # 計算梯度
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        
        # 參數更新
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        # 計算 metrics
        loss_tracker.update_state(loss, sample_weight=sample_weight)
        
        return {'loss': loss_tracker.result()}
    
    @property
    def metrics(self):
        return [loss_tracker]

In [87]:
def create_bert4rec():
    # input
    inputs = keras.Input(shape=(config.max_len,), dtype='int64', name='input')
    # Embedding(item + pos)
    item_embedding_layer = layers.Embedding(
        input_dim=config.item_size,
        output_dim=config.embed_dim,
        name='item_embedding'
    )
    item_embedding = item_embedding_layer(inputs)
    pos_embedding = layers.Embedding(
        input_dim=config.max_len,
        output_dim=config.embed_dim,
        name='pos_embedding'
    )(tf.range(start=0, limit=config.max_len, delta=1))
    embedding = item_embedding + pos_embedding
    
    # Encoder
    for i in range(config.num_layers):
        embedding = trm_encoder(embedding, embedding, embedding, i)
    
    # output layer, share embedding(這邊不是layer應該不能直接操作...不符合functional API) ---> 發現可以!!!
    outputs = tf.matmul(embedding, tf.transpose(a=item_embedding_layer.trainable_variables[0]))
    outputs = layers.Activation('softmax')(outputs)
    
    model = BERT4RecMLM(inputs, outputs)
    model.compile(optimizer=keras.optimizers.Adam(config.lr))
    return model

    
bert4rec = create_bert4rec()
bert4rec.summary()

Model: "ber_t4rec_mlm_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(None, 128)]        0                                            
__________________________________________________________________________________________________
item_embedding (Embedding)      (None, 128, 16)      63248       input[0][0]                      
__________________________________________________________________________________________________
tf.__operators__.add_26 (TFOpLa (None, 128, 16)      0           item_embedding[0][0]             
__________________________________________________________________________________________________
encoder_0/multiheadattention (M (None, 128, 16)      1088        tf.__operators__.add_26[0][0]    
                                                                 tf.__operators__.ad

## embedding 輸入輸出共享測試

In [88]:
# 找出 共享 embedding 方法, 需要轉置
emd = layers.Embedding(10000, 32)
embedding = emd(np.array([1,2, 3]))

emd.weights

[<tf.Variable 'embedding_9/embeddings:0' shape=(10000, 32) dtype=float32, numpy=
 array([[-0.04565385,  0.0228415 ,  0.01408012, ...,  0.02277399,
          0.03881348, -0.02892498],
        [ 0.00531975,  0.01409591,  0.04157652, ..., -0.02709459,
         -0.04011021, -0.02030834],
        [ 0.04395641,  0.03908107,  0.03616558, ..., -0.04538272,
         -0.02270526, -0.04050515],
        ...,
        [ 0.03277861,  0.04361615,  0.03682769, ..., -0.0142963 ,
         -0.00953268,  0.00386209],
        [-0.04249771,  0.02699311, -0.03973543, ..., -0.04423676,
          0.01094314,  0.03270457],
        [ 0.00862191,  0.01632955,  0.03811376, ...,  0.04757805,
          0.00394714,  0.04658416]], dtype=float32)>]

In [89]:
# 轉置! 但這需要該layer 要記住名稱

tf.transpose(a=emd.trainable_variables[0])

<tf.Tensor: shape=(32, 10000), dtype=float32, numpy=
array([[-0.04565385,  0.00531975,  0.04395641, ...,  0.03277861,
        -0.04249771,  0.00862191],
       [ 0.0228415 ,  0.01409591,  0.03908107, ...,  0.04361615,
         0.02699311,  0.01632955],
       [ 0.01408012,  0.04157652,  0.03616558, ...,  0.03682769,
        -0.03973543,  0.03811376],
       ...,
       [ 0.02277399, -0.02709459, -0.04538272, ..., -0.0142963 ,
        -0.04423676,  0.04757805],
       [ 0.03881348, -0.04011021, -0.02270526, ..., -0.00953268,
         0.01094314,  0.00394714],
       [-0.02892498, -0.02030834, -0.04050515, ...,  0.00386209,
         0.03270457,  0.04658416]], dtype=float32)>

In [90]:
# 可以使用!, 但應該要在call裡面使用, 所以需要用繼承keras.Model模式來寫!

tf.matmul(embedding, tf.transpose(a=emd.trainable_variables[0]))

<tf.Tensor: shape=(3, 10000), dtype=float32, numpy=
array([[ 0.00295702,  0.02385878,  0.00536622, ...,  0.00704855,
         0.00206965,  0.00226169],
       [-0.00061949,  0.00536622,  0.02979579, ..., -0.00245284,
        -0.00624401,  0.00031275],
       [ 0.00046163, -0.00302564, -0.0011524 , ..., -0.00433287,
         0.00389158,  0.00151961]], dtype=float32)>

In [91]:
# model 輸入測試確保可以使用

inputs = np.array([[1000]*128], dtype=np.int64)
bert4rec(inputs)

<tf.Tensor: shape=(1, 128, 3953), dtype=float32, numpy=
array([[[0.00020181, 0.00022673, 0.00022804, ..., 0.00030195,
         0.00026357, 0.00021186],
        [0.00021734, 0.00025212, 0.00020987, ..., 0.00028276,
         0.00026578, 0.00024568],
        [0.0002291 , 0.00020995, 0.00027308, ..., 0.00027523,
         0.0002666 , 0.00022405],
        ...,
        [0.00022932, 0.00021925, 0.00025138, ..., 0.0002734 ,
         0.00029029, 0.00022199],
        [0.00022149, 0.0002334 , 0.00020401, ..., 0.00027082,
         0.00031555, 0.00025754],
        [0.00022589, 0.00023743, 0.00023314, ..., 0.00027141,
         0.00026577, 0.00022947]]], dtype=float32)>

In [92]:
# 預測

np.argmax(tf.nn.softmax(bert4rec(inputs)), axis=2)

array([[3573, 3768, 3682,   73, 3719, 2715, 3573,  122,  377, 3573, 3573,
        3573, 3299, 3573, 2356, 3573, 3573, 3573, 2204, 1109,   73, 1304,
        3573, 1109,  449, 3573, 2343, 1304, 3118, 1073, 3573, 2343, 3573,
          73, 1861,  903, 1007,  733, 3385, 3573, 1109,  810, 3573, 1109,
        2343, 1290, 3901, 1304, 1109, 3385, 3573,  377, 3573, 1304, 1382,
        3719, 2285,   73, 3573, 3573, 1085, 3441, 1011, 1992,  794, 2715,
        2385,  794, 3573,  733,  733,   73, 1534, 3573, 3026, 2382, 3131,
        1109,  794, 2715, 3573, 2715, 3118, 2193, 3573, 2718,  377, 3078,
        3118,  377, 3573, 1304, 3056,  903, 2715, 1109, 1304, 1109, 3385,
        3573,   73, 2343, 1109, 3573, 1853, 3719, 3719, 3594, 2343, 3901,
        1109, 2715, 3573, 2382, 2382, 3573,   73,  429, 3573,   73, 3901,
        2969, 1304, 3573, 1109, 1109, 3476,  903]], dtype=int64)

## **接著需要**

- 將[mask] token 放入 vocabulary(可以採用與text 同樣方式, 透過TextVectorization方式 or not)
- 將輸入資料mask: portion = 0.4
- 測試訓練
- 負樣本
- 正式訓練

In [1]:
# [mask] token ---> ok

# 將輸入資料mask: 先將資料轉換成seq data, 但要直接用numpy 還是 用成 檔案閱讀 成dataset呢?
# 以擴充來說, dataset更好, 但目前以實現為原則, 先用numpy ---> 記憶體爆開
# 那採取 寫成txt檔案, user_id,裡面每一row為 [movied_id1, movie_id2,  ... movied_id_length]




NameError: name 'config' is not defined