In [None]:
import pandas as pd
import numpy as np
import re

# Load the training interactions.
train = pd.read_csv('Netflix-1M_train_original.csv')

# Item popularity: number of (non-null) timestamps recorded per item, broadcast
# back onto every row. transform() keeps train's rows/index and avoids the
# merge + "_x"/"_y" suffix renaming dance.
train['interaction_count'] = train.groupby('item_id:token')['timestamp:float'].transform('count')

# Prompt-ready labels: "<Title (genre: Genres)>" and "<Title>".
train['title_genre'] = '<' + train['title:token'] + ' (genre: ' + train['genres:token'] + ')>'
train['<title:token>'] = '<' + train['title:token'] + '>'

# Per user, most-interacted-with items first; this order carries into the
# per-user movie lists (and therefore the prompt text) below.
train = train.sort_values(by=['user_id:token', 'interaction_count'], ascending=[True, False])

# The per-user arrays are formatted straight into the prompts with str();
# an infinite line width stops numpy from wrapping them across lines.
# NOTE(review): arrays longer than numpy's print threshold are still
# abbreviated with '...' -- confirm no prompt list is that long.
np.set_printoptions(linewidth=np.inf)

RATING_LEVELS = (5, 4, 3, 2, 1)

user_ids = np.unique(train['user_id:token'].values)
user_dict = dict()
rating_count = dict()  # NOTE(review): not populated anywhere in the visible notebook

# user_dict[user_id] = [[titles rated 5], [rated 4], [rated 3], [rated 2], [rated 1], mean rating]
# (each title array is wrapped in a one-element list, hence the [k][0] lookups later).
# One groupby pass replaces a full boolean scan of `train` per user
# (O(users * rows)); with sort=True users come out in the same ascending order
# as np.unique, and each group keeps train's row order.
for user_id, df_user in train.groupby('user_id:token', sort=True):
    ratings = df_user['rating:float']
    values = [[df_user.loc[ratings == level, 'title_genre'].values] for level in RATING_LEVELS]
    values.append(np.mean(ratings.values))
    user_dict[user_id] = values


train.head()

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import torch

# Pin the kernel to one GPU. CUDA_VISIBLE_DEVICES must be set before CUDA is
# initialised in this kernel process, so run this cell before anything else
# touches the GPU (restart the kernel to change it). Override the default
# GPU index with the LLAMA_GPU environment variable.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("LLAMA_GPU", "4")

# Hugging Face cache directory for the weights; override with LLM_CACHE_DIR
# instead of editing the hard-coded machine-specific path.
MODEL_NAME = "meta-llama/Llama-2-7b-hf"
llm_dir = os.environ.get("LLM_CACHE_DIR", '/home/chwchong/_WWW25/LLM/')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=llm_dir)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=llm_dir)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
# Primary few-shot prompt, filled in with str.format():
#   {positive_movies_5}  - the user's 5-star titles (with genres), shown once at the top
#   {candidate_example}  - worked-example candidate list (shuffled mix including the 4-star titles)
#   {only_title_4}       - the example's expected answer: the 4-star titles without genres
#   {negative_movies_32} - the chunk (at most 10 titles) the model must actually classify
# The dashed separator line appears twice; ask_llama1's default max_occurrences=2
# therefore cuts the output at the first separator the model generates itself.
# "Candidate List:" also appears exactly twice; ask_llama1 sizes the generation
# budget from the second one.
prompt1 = """
This dataset is from the Netflix dataset.

User's Positively Rated Movie List:
The following list includes movie titles and genres that the user has rated positively:
{positive_movies_5}

---------------------------------------------

Based on the User's Positively Rated Movie List, analyze the user's preferences and patterns.
You will be provided with Candidate List, which includes movie titles and genres that the user has rated both positively and negatively.
Your task is to strictly select movies and provide only the movie titles from Candidate List that the user is most likely to have rated positively.

Candidate List:
{candidate_example}

Output (Answer):
{only_title_4}

---------------------------------------------

Based on the User's Positively Rated Movie List, analyze the user's preferences and patterns.
You will be provided with Candidate List, which includes movie titles and genres that the user has rated both positively and negatively.
Your task is to strictly select movies and provide only the movie titles from Candidate List that the user is most likely to have rated positively.

Candidate List:
{negative_movies_32}

Output (Answer):
"""


def ask_llama1(question, tokenizer, model, device, stop_token="---------------------------------------------", max_occurrences=2, max_total_tokens=5000):
    """Greedy-decode an answer for a few-shot prompt and trim the continuation.

    The generation budget is the token length of the text from the second
    "Candidate List:" header to the end of ``question``. The answer is meant to
    be a subset of that candidate list, so it should never need more tokens.

    Args:
        question: the fully formatted prompt (e.g. ``prompt1.format(...)``).
        tokenizer, model: a Hugging Face tokenizer / causal LM pair.
        device: device the prompt tensors are moved to (should match ``model``).
        stop_token: separator line that delimits examples in the prompt.
        max_occurrences: number of ``stop_token`` lines that belong to the prompt
            itself. Output is cut at the next one, i.e. when the model starts
            inventing a new example.
        max_total_tokens: requests whose prompt length plus generation budget
            exceeds this are skipped.
            NOTE(review): Llama-2 checkpoints are trained with a 4096-token
            context; confirm the 5000 default is intended.

    Returns:
        The decoded prompt followed by the generated answer, truncated before the
        ``max_occurrences + 1``-th ``stop_token`` line and stripped, or ``None``
        if the request was skipped for length.
    """
    inputs = tokenizer(question, return_tensors="pt").to(device)
    input_ids = inputs['input_ids']

    # The text the model must answer about starts at the second header. If it
    # is missing, fall back to the first header (or the whole question) instead
    # of failing later on an unbound variable.
    header = "Candidate List:"
    start_index = question.find(header)
    second_start_index = question.find(header, start_index + 1) if start_index != -1 else -1
    if second_start_index != -1:
        candidate_list_2_text = question[second_start_index:]
    elif start_index != -1:
        print('Second occurrence of "Candidate List:" cannot be found; budgeting from the first one.')
        candidate_list_2_text = question[start_index:]
    else:
        print('"Candidate List:" cannot be found; budgeting from the whole prompt.')
        candidate_list_2_text = question

    # Only the token count of the budget text is needed, so it is not moved to the device.
    budget_tokens = tokenizer(candidate_list_2_text, return_tensors="pt")['input_ids'].shape[1]
    prompt_tokens = input_ids.shape[1]
    if prompt_tokens + budget_tokens > max_total_tokens:
        print(f"Skipping user due to input length: {prompt_tokens} (+{budget_tokens} generation budget)", end=", ")
        return None
    print(f"input length: {prompt_tokens}", end=", ")

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=inputs['attention_mask'],
        # Equivalent to max_length=prompt_tokens + budget_tokens.
        max_new_tokens=budget_tokens,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        num_beams=1,
        do_sample=False,  # greedy decoding -> deterministic output
        temperature=1,
        top_p=1
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep the prompt's own separator lines; stop at the first extra one the
    # model generates (the start of a hallucinated follow-up example).
    output_lines = []
    stop_token_count = 0
    for line in response.split('\n'):
        if stop_token in line:
            stop_token_count += 1
            if stop_token_count > max_occurrences:
                break
        output_lines.append(line)
    return '\n'.join(output_lines).strip()

In [None]:
# Fallback prompt, used when prompt1's answer does not end with ']'.
# It takes the same placeholders as prompt1, but repeats the instructions and
# the user's 5-star list ({positive_movies_5}) in both the worked example and
# the real query. It contains a single dashed separator line, which matches
# ask_llama2's default max_occurrences=1.
prompt2 = """
This dataset is from the Netflix dataset.
Based on the User's Positively Rated Movie List, analyze the user's preferences and patterns.
You will be provided with Candidate List, which includes movie titles and genres that the user has rated both positively and negatively.
Your task is to strictly select movies and provide only the movie titles from Candidate List that the user is most likely to have rated positively.

User's Positively Rated Movie List:
{positive_movies_5}

Candidate List:
{candidate_example}

Output (Answer):
{only_title_4}

---------------------------------------------

This dataset is from the Netflix dataset.
Based on the User's Positively Rated Movie List, analyze the user's preferences and patterns.
You will be provided with Candidate List, which includes movie titles and genres that the user has rated both positively and negatively.
Your task is to strictly select movies and provide only the movie titles from Candidate List that the user is most likely to have rated positively.

User's Positively Rated Movie List:
{positive_movies_5}

Candidate List:
{negative_movies_32}

Output (Answer):
"""


def ask_llama2(question, tokenizer, model, device, stop_token="---------------------------------------------", max_occurrences=1):
    """Query the model with a ``prompt2``-style prompt.

    This was a verbatim copy of ``ask_llama1`` apart from the default
    ``max_occurrences``, so it now delegates instead of duplicating the logic.
    ``prompt2`` contains a single separator line, so the output is cut at the
    first separator the model generates itself.

    Returns:
        The trimmed prompt + answer text, or ``None`` if the request was
        skipped for length.
    """
    return ask_llama1(question, tokenizer, model, device,
                      stop_token=stop_token, max_occurrences=max_occurrences)

In [None]:
# skip: user IDs excluded up front because their very long interaction histories made the LLM hallucinate.
skip = np.array([544, 1871, 2003, 2790, 3101, 4488, 7380, 8382, 9816, 10310, 14308, 14597, 15796, 18913, 20460, 20774, 20789, 25660, 26421, 27310, 28277, 29129, 29665, 31764, 33599, 34552, 34827, 35269, 35986, 37656, 38600, 39469, 40154, 40960, 41132, 41212, 41308, 43751, 44319, 47080, 47248, 47344, 49895, 50081, 51639, 53406, 57736, 63382, 64311, 64404, 68462, 68472, 69169, 69724, 69824, 69848, 72566, 72611, 73573, 73730, 74317, 74439, 76376, 76497, 76709, 78600, 81093, 81134, 81453, 81519, 82369, 82458, 88383, 90912, 93502, 93834, 94462, 95358, 102020, 102706, 108285, 108780, 109585, 113677, 114008, 114216, 115422, 116371, 116858, 117902, 120450, 122955, 122987, 125219, 128460, 133601, 135477, 136827, 139120, 139147, 139657, 139995, 141430, 146806, 149511, 150998, 154368, 154658, 161259, 162988, 163094, 163434, 163695, 163896, 167612, 170814, 173003, 174705, 181323, 185495, 187355, 188813, 189362, 190729, 191125, 191616, 192796, 193953, 195107, 195761, 195934, 196091, 196422, 198630, 199635, 200001, 201837, 203576, 203902, 204318, 205660, 206826, 207967, 213420, 215739, 216277, 220223, 220872, 225880, 228092, 231052, 231110, 231614, 231915, 232987, 234252, 234817, 235466, 236574, 237752, 238188, 238576, 240953, 242724, 244718, 246687, 247089, 248547, 249189, 250252, 252251, 261362, 262264, 262765, 266946, 267611, 270600, 271463, 273741, 277943, 278096, 278375, 279263, 282737, 283067, 285349, 285597, 287179, 288159, 289045, 290960, 291035, 291292, 291602, 294099, 295855, 304243, 304390, 308843, 309630, 314159, 314374, 315168, 315705, 319513, 323118, 324371, 326359, 328094, 330080, 331005, 331716, 332269, 333766, 334561, 336645, 337001, 337273, 337376, 341215, 341666, 342035, 343857, 351238, 351321, 352345, 353249, 353368, 354138, 355855, 358934, 358976, 359679, 360131, 361450, 362811, 365984, 367049, 367941, 369041, 369335, 369558, 370864, 371474, 371489, 378420, 379563, 381954, 382545, 382747, 384188, 386246, 386619, 388227, 388341, 396153, 397216, 398641, 
399170, 399823, 400690, 400880, 401884, 402697, 404798, 410421, 411304, 413420, 413609, 413718, 414542, 414787, 418975, 421869, 423778, 425907, 425935, 426981, 428311, 428322, 428826, 430011, 431588, 433843, 438491, 439124, 441597, 441730, 442361, 442669, 443054, 443619, 443972, 444229, 446469, 446562, 452787, 452860, 454612, 460007, 460660, 461171, 461390, 462934, 466436, 468218, 472700, 473699, 474420, 474613, 475874, 478630, 480305, 483204, 484675, 484774, 484934, 485249, 488188, 493922, 495366, 496173, 498341, 499222, 500353, 504440, 505015, 507498, 509396, 510103, 513830, 513842, 516042, 517647, 518271, 519760, 520241, 521443, 523150, 523890, 525805, 526856, 527849, 534508, 535789, 536316, 537341, 538881, 543370, 543526, 544329, 545592, 548665, 552739, 553915, 556956, 560142, 561494, 562974, 563186, 563575, 563632, 564660, 564989, 567509, 568045, 573283, 576432, 577109, 577317, 577981, 578713, 581065, 582685, 584096, 585599, 585655, 586068, 587917, 591273, 591772, 594525, 596255, 596860, 597193, 597709, 602465, 603874, 604404, 605687, 608576, 610083, 610101, 610174, 615510, 616015, 617275, 618820, 619245, 622278, 628120, 628966, 631075, 631226, 631272, 631381, 633045, 633151, 633277, 635202, 637597, 638231, 638614, 638729, 642596, 643214, 644251, 645523, 645846, 647492, 648128, 648130, 648196, 650247, 655479, 656304, 656824, 657641, 657712, 659867, 660924, 661712, 662441, 663830, 664911, 665996, 671869, 671960, 674221, 676732, 677430, 677522, 677748, 678501, 679255, 679849, 680352, 683159, 685947, 687215, 688599, 692305, 692635, 694235, 695396, 695462, 696052, 696811, 697172, 698813, 699166, 699825, 702258, 704220, 712426, 712624, 720107, 720151, 721554, 724509, 725549, 726652, 728231, 728413, 732183, 732337, 732742, 732938, 733049, 733187, 733809, 734412, 735262, 735351, 738804, 740172, 744010, 744346, 744735, 745720, 746892, 748763, 749593, 749854, 750008, 750674, 754001, 754082, 755616, 755854, 756117, 756746, 757582, 757948, 758360, 760816, 762363, 766065, 
769235, 770028, 770622, 770641, 771638, 772116, 776235, 778806, 785657, 786827, 789625, 790635, 792309, 792534, 794439, 794636, 798524, 798797, 800138, 801062, 802045, 803481, 807603, 807892, 809089, 810777, 816707, 818987, 819855, 820707, 821625, 822414, 823824, 825576, 825603, 825779, 827422, 827594, 828201, 833951, 834751, 835854, 837893, 838309, 839232, 840289, 841019, 842153, 844567, 846928, 849929, 850368, 851702, 853287, 854661, 857413, 857469, 859502, 859593, 860839, 861911, 862110, 862888, 864084, 864290, 864549, 866294, 868795, 869130, 873713, 873769, 873795, 874176, 876035, 876890, 877068, 877366, 878976, 880312, 880432, 883036, 884294, 887234, 887915, 890925, 891357, 891561, 892103, 893392, 893466, 893764, 894474, 894681, 895641, 896955, 897125, 900046, 901075, 901211, 901459, 902807, 903002, 903378, 904506, 906101, 906602, 906959, 911197, 911270, 911416, 913516, 916262, 917537, 918983, 919219, 919250, 931232, 939502, 941624, 942554, 943129, 943997, 944149, 947643, 947833, 950251, 950269, 950524, 951682, 954045, 954968, 957287, 957673, 957814, 959019, 959675, 959714, 959957, 960725, 961310, 961541, 962997, 963113, 963997, 966493, 966669, 969386, 970751, 971434, 972267, 973119, 973280, 973711, 974103, 975911, 977127, 979865, 980280, 981049, 981593, 982171, 983123, 984770, 984786, 986472, 989598, 992790, 994599, 995963, 995971, 996379, 997969, 998888, 1000892, 1003239, 1003901, 1008936, 1009667, 1010127, 1011674, 1013570, 1015836, 1017694, 1020195, 1021146, 1024806, 1025705, 1025838, 1025990, 1026814, 1027550, 1028729, 1028818, 1031887, 1032704, 1032784, 1034604, 1035057, 1036783, 1037347, 1038515, 1038786, 1041143, 1043056, 1045220, 1045306, 1046081, 1047641, 1047834, 1049825, 1052738, 1054137, 1056156, 1061602, 1061673, 1062368, 1063109, 1066774, 1068466, 1070065, 1070197, 1070610, 1072069, 1075922, 1075945, 1077184, 1079020, 1079173, 1079264, 1079644, 1082150, 1082857, 1083952, 1084685, 1084814, 1087055, 1094578, 1097495, 1100363, 1100958, 1101586, 
1101849, 1102159, 1102362, 1103965, 1105839, 1106390, 1107747, 1108077, 1112565, 1115476, 1117757, 1120126, 1120591, 1121211, 1126733, 1127442, 1129869, 1130281, 1131398, 1132526, 1133422, 1134307, 1134932, 1137389, 1141553, 1142611, 1142985, 1144464, 1144523, 1144639, 1145948, 1146323, 1146580, 1147594, 1147730, 1149285, 1150057, 1152696, 1155304, 1161257, 1161299, 1165087, 1165790, 1169575, 1171904, 1172252, 1173488, 1173837, 1175472, 1176838, 1178494, 1179642, 1180755, 1182139, 1183532, 1185600, 1185696, 1186156, 1187195, 1187282, 1189426, 1190391, 1191044, 1191116, 1191122, 1191216, 1191337, 1192239, 1193479, 1193751, 1194891, 1198766, 1199235, 1201142, 1202231, 1204005, 1205334, 1207408, 1209170, 1210528, 1211183, 1213717, 1214219, 1214278, 1215167, 1215825, 1219748, 1221963, 1224756, 1226073, 1228549, 1228999, 1229977, 1230055, 1230782, 1231594, 1232267, 1232442, 1233302, 1234205, 1234403, 1236665, 1238243, 1244618, 1248503, 1249596, 1249660, 1251170, 1252571, 1254193, 1254305, 1257106, 1257665, 1258532, 1260494, 1262157, 1262477, 1263062, 1263899, 1264254, 1266342, 1267382, 1267764, 1268417, 1271900, 1274580, 1275804, 1276275, 1282942, 1283365, 1283826, 1285169, 1285292, 1288465, 1299403, 1300759, 1305167, 1305838, 1305914, 1306363, 1307282, 1312013, 1312099, 1312946, 1316452, 1317345, 1318210, 1322746, 1325693, 1329133, 1329799, 1329886, 1330760, 1336119, 1336141, 1337104, 1343079, 1343091, 1343848, 1344505, 1346327, 1346977, 1348662, 1348729, 1349619, 1354435, 1354827, 1357619, 1359540, 1369309, 1370466, 1372722, 1373690, 1374765, 1375445, 1377886, 1381159, 1381395, 1385106, 1385267, 1385369, 1390665, 1390692, 1393896, 1394756, 1401192, 1401475, 1401592, 1404139, 1405915, 1408156, 1408623, 1409021, 1409486, 1413286, 1415299, 1415619, 1417275, 1417539, 1418354, 1422960, 1423788, 1427482, 1429580, 1429789, 1430601, 1432781, 1436768, 1437156, 1438768, 1441387, 1441694, 1442622, 1443126, 1443380, 1445631, 1448440, 1454460, 1458700, 1461921, 1464163, 1466349, 
1466434, 1470460, 1472622, 1473776, 1474278, 1476078, 1477713, 1480084, 1484886, 1486472, 1486920, 1487394, 1493298, 1494853, 1495711, 1497284, 1498317, 1504833, 1505356, 1506188, 1508863, 1509362, 1510107, 1511920, 1512540, 1519636, 1520724, 1523008, 1523359, 1524188, 1524903, 1525460, 1526758, 1528530, 1529082, 1530338, 1530681, 1533057, 1533559, 1534004, 1538341, 1543179, 1543432, 1543838, 1544338, 1544785, 1545049, 1546515, 1551924, 1553247, 1560840, 1561880, 1563881, 1566002, 1566022, 1566596, 1572100, 1573392, 1574043, 1579110, 1580159, 1585056, 1587665, 1588314, 1589107, 1589989, 1590642, 1590921, 1591526, 1592001, 1593431, 1595068, 1595112, 1595693, 1597868, 1601556, 1601989, 1605674, 1606587, 1607029, 1607317, 1607747, 1608091, 1613132, 1613574, 1614405, 1615216, 1620158, 1621367, 1624150, 1625497, 1628816, 1630760, 1635967, 1636375, 1636381, 1637504, 1640400, 1640511, 1643323, 1647142, 1647370, 1647640, 1648822, 1652307, 1652468, 1653637, 1654509, 1656759, 1657798, 1658467, 1661121, 1663268, 1663459, 1668386, 1670143, 1672270, 1675445, 1676477, 1679190, 1681941, 1681954, 1683315, 1684982, 1685995, 1691434, 1691916, 1695030, 1696181, 1704018, 1705359, 1705636, 1706279, 1709921, 1709936, 1710210, 1712681, 1716002, 1717175, 1718751, 1719571, 1720068, 1722113, 1726339, 1726499, 1730413, 1730614, 1732034, 1737423, 1742543, 1743162, 1746464, 1748373, 1748546, 1748740, 1750115, 1751276, 1751835, 1754660, 1759800, 1760670, 1769654, 1774018, 1774783, 1776547, 1776887, 1777401, 1778764, 1780157, 1783108, 1784858, 1784928, 1791977, 1792324, 1793673, 1794245, 1794535, 1795110, 1797072, 1797262, 1798758, 1799399, 1801452, 1801961, 1802972, 1803892, 1806070, 1807946, 1808030, 1808651, 1810936, 1812060, 1815027, 1815535, 1819146, 1820291, 1821568, 1822524, 1823308, 1823426, 1824145, 1825484, 1827429, 1828042, 1828493, 1828986, 1830165, 1833347, 1834155, 1836648, 1838530, 1839501, 1839612, 1847890, 1848018, 1848735, 1850328, 1853443, 1855673, 1857268, 1860864, 1861258, 
1863477, 1865263, 1865783, 1866053, 1866792, 1867308, 1870867, 1871513, 1874396, 1874620, 1882347, 1882837, 1885155, 1885703, 1886300, 1887322, 1888070, 1888966, 1891868, 1892183, 1894134, 1895489, 1896653, 1898454, 1899578, 1900059, 1907386, 1910134, 1915501, 1918443, 1922072, 1922583, 1924464, 1928676, 1929215, 1929564, 1932348, 1932394, 1934894, 1937301, 1938313, 1941042, 1941596, 1943292, 1943761, 1944965, 1945525, 1946537, 1947136, 1947644, 1948627, 1948652, 1948777, 1949092, 1949215, 1955923, 1956058, 1956440, 1957133, 1964181, 1964939, 1964978, 1966615, 1970073, 1970207, 1970442, 1971100, 1971220, 1971863, 1975103, 1977327, 1977604, 1978009, 1978607, 1981351, 1982517, 1984891, 1985466, 1986732, 1987201, 1989121, 1989301, 2003213, 2004856, 2005184, 2008172, 2008539, 2009114, 2009618, 2010259, 2013987, 2014706, 2017869, 2018452, 2018754, 2018908, 2021326, 2022496, 2022868, 2023351, 2024102, 2026855, 2027967, 2030978, 2031068, 2031413, 2032398, 2037043, 2043517, 2047761, 2051865, 2052030, 2053456, 2055030, 2055983, 2056056, 2056315, 2056324, 2057056, 2060560, 2066475, 2070436, 2071210, 2072647, 2080240, 2081231, 2083121, 2083766, 2086957, 2088679, 2088949, 2092413, 2093403, 2093984, 2101124, 2101288, 2103318, 2104077, 2105288, 2108400, 2109649, 2110280, 2112541, 2113383, 2113606, 2115736, 2117885, 2118838, 2121049, 2121917, 2122847, 2123026, 2125060, 2128541, 2130918, 2135410, 2137525, 2137538, 2139217, 2141585, 2141635, 2142814, 2145206, 2145856, 2146200, 2146283, 2147714, 2148823, 2150937, 2151593, 2151938, 2152497, 2152776, 2153351, 2153809, 2165396, 2166500, 2168669, 2170654, 2172616, 2173008, 2174552, 2178000, 2178672, 2179604, 2181770, 2183930, 2184488, 2184882, 2185127, 2185704, 2188674, 2191700, 2195478, 2195600, 2200358, 2201264, 2201937, 2201939, 2202671, 2208368, 2209480, 2215790, 2216145, 2216856, 2217043, 2222044, 2223998, 2226693, 2227880, 2228502, 2229105, 2229377, 2231824, 2233936, 2235910, 2236382, 2241136, 2241466, 2242154, 2242595, 2243423, 
2243654, 2245381, 2252432, 2254015, 2254321, 2254513, 2255453, 2255822, 2256471, 2258212, 2260087, 2260281, 2260407, 2261254, 2261989, 2263129, 2264779, 2267067, 2269807, 2270073, 2278455, 2278949, 2278995, 2279048, 2280765, 2282088, 2282605, 2283112, 2283705, 2286414, 2287964, 2288057, 2290963, 2296924, 2298202, 2300139, 2301000, 2301123, 2303565, 2304914, 2306346, 2308382, 2310496, 2311187, 2314321, 2316233, 2316678, 2319800, 2321454, 2322418, 2322442, 2326029, 2330735, 2332346, 2334959, 2335231, 2337017, 2338157, 2341605, 2347743, 2349683, 2354051, 2356346, 2357468, 2361462, 2363063, 2363228, 2365320, 2366219, 2367275, 2367308, 2370951, 2375160, 2375550, 2379179, 2379329, 2379545, 2380939, 2381036, 2381833, 2384128, 2385663, 2386808, 2388694, 2391898, 2392144, 2396995, 2399097, 2403969, 2405969, 2407033, 2409222, 2409280, 2409746, 2410780, 2413936, 2414659, 2415419, 2418437, 2418924, 2420366, 2424626, 2427285, 2429990, 2431172, 2432352, 2432714, 2432882, 2434290, 2436670, 2438956, 2446344, 2453949, 2456188, 2456906, 2457012, 2458442, 2459091, 2460722, 2461695, 2463091, 2464759, 2468067, 2471125, 2471950, 2473252, 2475251, 2480036, 2482518, 2488478, 2490389, 2494418, 2495256, 2496025, 2496353, 2496410, 2497577, 2500594, 2502216, 2502422, 2504163, 2504752, 2506646, 2508234, 2510967, 2514465, 2515522, 2515558, 2520431, 2522355, 2528342, 2528762, 2529941, 2534085, 2535858, 2540042, 2540652, 2545316, 2545709, 2546736, 2547406, 2548847, 2550355, 2554945, 2555902, 2557944, 2559895, 2563065, 2564240, 2569291, 2570164, 2572562, 2574728, 2577872, 2579012, 2580154, 2580594, 2580895, 2581590, 2584438, 2589002, 2592194, 2593856, 2595332, 2595864, 2597544, 2601021, 2606108, 2611670, 2611947, 2615036, 2615291, 2615548, 2616836, 2620107, 2620980, 2621119, 2621443, 2621825, 2624128, 2624550, 2626280, 2627366, 2629271, 2631339, 2631986, 2634788, 2635210, 2635288, 2635326, 2637783, 2637930, 2638364, 2639937, 2642389])
# Keep every user in user_dict that is not on the skip list; setdiff1d returns
# the surviving IDs sorted and de-duplicated.
user_ids = np.setdiff1d(np.asarray([*user_dict]), skip)
len(user_ids), user_ids

In [None]:
def _cut_answer(response, prompt, marker="Output (Answer):"):
    """Return only the model's answer from a decoded ``prompt + answer`` string.

    ``response`` is the stripped text returned by ``ask_llama1``/``ask_llama2``.
    Slicing at ``len(prompt) - 1`` is only correct if tokenize/decode reproduces
    the prompt exactly and ``strip`` removed just its leading newline. Instead,
    locate the prompt's last ``marker`` (its n-th occurrence, where n is the count
    in ``prompt``) and return what follows it. If the marker cannot be found that
    many times, fall back to the old offset.
    """
    n_markers = prompt.count(marker)
    pos = -1
    for _ in range(n_markers):
        pos = response.find(marker, pos + 1)
        if pos == -1:
            break
    if n_markers == 0 or pos == -1:
        return response[len(prompt) - 1:]
    answer = response[pos + len(marker):]
    # The prompt ends with the marker plus one newline; drop that newline as the old slice did.
    return answer[1:] if answer.startswith('\n') else answer


CANDIDATE_CHUNK_SIZE = 10  # candidates shown to the model per query

skipped_users = []

with open('llama_distinguish_answer.txt', 'w') as f_cut, open('llama_distinguish_full.txt', 'w') as f_full:
    for user_id in user_ids:
        # Re-seed per user so each user's shuffles are reproducible independently.
        np.random.seed(2024)
        user_entry = user_dict[user_id]
        mean_rating = user_entry[5]
        positive_movies_5 = user_entry[0][0]
        positive_movies_4 = user_entry[1][0]
        negative_movies_3 = user_entry[2][0]
        negative_movies_2 = user_entry[3][0]
        negative_movies_1 = user_entry[4][0]

        print('mean_rating %.2f' %mean_rating, end=" ")
        # The worked example mixes 4-star titles (its expected answer) with
        # low-rated ones. The real query covers 3-star titles, plus 2-star ones
        # for users whose mean rating is below 3.
        if mean_rating >= 3:
            candidate_example = np.array(list(positive_movies_4) + list(negative_movies_2) + list(negative_movies_1))
            # Copy: shuffling the array stored in user_dict in place would make
            # this cell non-idempotent (a re-run would start from the shuffled order).
            candidate_real = np.array(negative_movies_3)
        else:
            candidate_example = np.array(list(positive_movies_4) + list(negative_movies_1))
            candidate_real = np.array(list(negative_movies_3) + list(negative_movies_2))
        np.random.shuffle(candidate_example)
        np.random.shuffle(candidate_real)

        # Expected example answer: the 4-star titles in shuffled order, with the
        # " (genre: ...)" suffix replaced by ">" so only "<Title>" remains.
        only_title_4 = candidate_example[np.isin(candidate_example, positive_movies_4)]
        if len(only_title_4) > 0:
            only_title_4 = np.vectorize(lambda item: re.sub(r'\s*\(genre:.*$', '>', item))(only_title_4).astype(object)

        # Chunks of at most CANDIDATE_CHUNK_SIZE; always at least one (possibly empty) chunk.
        chunked_candidate_real = [
            candidate_real[start:start + CANDIDATE_CHUNK_SIZE]
            for start in range(0, max(len(candidate_real), 1), CANDIDATE_CHUNK_SIZE)
        ]

        for i, chunk in enumerate(chunked_candidate_real):
            tag = 'user ' + str(user_id) + '[' + str(i) + ']'
            fields = dict(
                positive_movies_5=positive_movies_5,
                candidate_example=candidate_example,
                only_title_4=only_title_4,
                negative_movies_32=chunk)
            prompt = prompt1.format(**fields)
            response = ask_llama1(prompt, tokenizer, model, device)

            if response is None:
                print(tag + ' skip')
                skipped_users.append(user_id)
                continue
            # A well-formed answer should end with ']' (the expected answer is a
            # printed array); anything else is treated as a hallucination.
            # endswith() also copes with an empty response.
            if not response.endswith(']'):
                print(tag + ' hallucination, retrying with prompt2')
                prompt = prompt2.format(**fields)
                response = ask_llama2(prompt, tokenizer, model, device)
                if response is None:
                    print(tag + ' skip')
                    skipped_users.append(user_id)
                    continue
                if not response.endswith(']'):
                    print(tag + ' skip (hallucination)')
                    skipped_users.append(user_id)
                    continue

            print(tag + ' complete!')
            user_number_full = f"LLaMA's full recommendation for user {user_id}:"
            f_full.write(user_number_full + '\n')
            f_full.write(response + '\n\n\n\n\n\n\n\n\n')
            user_number_cut = f"LLaMA's cut recommendation for user {user_id}:"
            f_cut.write(user_number_cut + '\n')
            f_cut.write(_cut_answer(response, prompt) + '\n\n\n\n\n\n')


if skipped_users:
    skipped_users = np.unique(np.array(skipped_users))
    print(f"Skipped users (input too long or hallucination): {skipped_users}")
    print(f"len(skipped_users): {len(skipped_users)}")
# Full run took 1273m 49.0s.