In [27]:
#Imports
import random

import jsonlines
import numpy as np
import pandas as pd
from tqdm import tqdm


In [28]:
# Load data: read the three ml-100k files into DataFrames.
# None of the files is read with a header row, so column names are supplied here.

# Ratings table (tab-separated).
df_data = pd.read_csv(
    './ml-100k/u.data',
    sep='\t',
    header=None,
    names=['user_id', 'item_id', 'rating', 'timestamp'],
)

# User table (pipe-separated).
df_user = pd.read_csv(
    './ml-100k/u.user',
    sep='|',
    header=None,
    names=['user_id', 'age', 'gender', 'occupation', 'zip_code'],
)

# Movie table (pipe-separated, decoded as latin-1): id, title, dates, URL,
# followed by one column per genre name.
df_item = pd.read_csv(
    './ml-100k/u.item',
    sep='|',
    encoding='latin-1',
    header=None,
    names=[
        'movieid', 'movietitle', 'releasedate', 'videoreleasedate', 'imdburl',
        'unknown', 'action', 'adventure', 'animation', "children's", 'comedy',
        'crime', 'documentary', 'drama', 'fantasy', 'film-noir', 'horror',
        'musical', 'mystery', 'romance', 'sci-fi', 'thriller', 'war', 'western',
    ],
)


In [29]:
# Constants

# Number of movies sampled for each user by gen_datapoint.
PER_USER_DATAPOINTS = 5

# NOTE(review): not referenced by any cell visible here; presumably the number
# of movie columns per exported record (the export loop writes movie1..movie4)
# -- confirm.
MOVIES_IN_DATAPOINTS = 4

# Segment labels a datapoint can carry; gen_datapoint draws one at random.
TYPE_MOVIE = {'Good', 'Bad'}


In [30]:
#Generate single datapoint

def gen_datapoint(userId, df_data, df_item):
    """Build one labelled sample of rated movies for a user.

    A segment label is drawn at random from ``TYPE_MOVIE``. For 'Good' the
    user's ratings above 3.5 form the pool; for any other label the ratings
    below 2.5 do. ``PER_USER_DATAPOINTS`` movies are then sampled from that
    pool and their titles looked up in ``df_item``.

    Parameters
    ----------
    userId
        Value matched against ``df_data['user_id']``.
    df_data : pandas.DataFrame
        Ratings with 'user_id', 'item_id' and 'rating' columns.
    df_item : pandas.DataFrame
        Movies with 'movieid' and 'movietitle' columns.

    Returns
    -------
    tuple
        ``(segment, datapoint)`` where ``datapoint`` is a list of
        ``PER_USER_DATAPOINTS`` entries, each
        ``[userId, movie_id, movie_rating, movie_title]``.

    Raises
    ------
    ValueError
        If the user has no ratings in the drawn segment.
    IndexError
        If a sampled movie id has no matching row in ``df_item``.
    """
    # All ratings given by this user.
    user_ratings = df_data[df_data['user_id'] == userId]

    # sorted() fixes the candidate order: iteration order of a set of strings
    # can differ between interpreter runs, which would make the draw
    # irreproducible even after random.seed().
    random_user_segment = random.choice(sorted(TYPE_MOVIE))
    if random_user_segment == 'Good':
        segment_ratings = user_ratings[user_ratings['rating'] > 3.5]
    else:
        segment_ratings = user_ratings[user_ratings['rating'] < 2.5]

    # Fail with a readable message instead of the opaque error pandas raises
    # when sampling from an empty frame.
    if segment_ratings.empty:
        raise ValueError(
            f'User {userId} has no {random_user_segment!r} ratings to sample from'
        )

    # Draw every row in one call so the same rating row is never picked twice.
    # Only when the pool is smaller than the requested count do we fall back to
    # sampling with replacement, so the result length stays constant.
    needs_replacement = len(segment_ratings) < PER_USER_DATAPOINTS
    sampled = segment_ratings.sample(n=PER_USER_DATAPOINTS, replace=needs_replacement)

    datapoint = []
    for movie_id, movie_rating in zip(sampled['item_id'].values, sampled['rating'].values):
        # First matching title for this movie id; IndexError if there is none.
        movie_title = df_item[df_item['movieid'] == movie_id]['movietitle'].values[0]
        datapoint.append([userId, movie_id, movie_rating, movie_title])
    return (random_user_segment, datapoint)


In [31]:
# Spot-check gen_datapoint on a user taken from one randomly sampled row of
# the ratings table (so users with more ratings are more likely to be drawn).
gen_datapoint(
    userId=df_data['user_id'].sample().values[0],
    df_data=df_data,
    df_item=df_item,
)

('Bad',
 [[760, 300, 1, 'Air Force One (1997)'],
  [760, 723, 2, 'Boys on the Side (1995)'],
  [760, 216, 2, 'When Harry Met Sally... (1989)'],
  [760, 183, 2, 'Alien (1979)'],
  [760, 183, 2, 'Alien (1979)']])

In [32]:
def prompt(data):
    """Render a datapoint as a text prompt.

    ``data`` is a ``(rating, datapoints)`` pair: ``rating`` is interpolated
    into the header sentence, and each entry of ``datapoints`` contributes one
    line built from its index 3 (shown as the movie) and index 2 (shown as the
    rating).

    Returns the header followed by one ``\\nMovie: ... Rating: ...`` line per
    entry, as a single string.
    """
    segment, rated_movies = data
    header = f'Given a user, here is the list of movies that they have rated as {segment}. \n### User Input:'
    movie_lines = [f'\nMovie: {row[3]} Rating: {row[2]}' for row in rated_movies]
    return header + ''.join(movie_lines)


In [33]:
# Sanity check: render the prompt for a single hand-written 'Good' datapoint.
prompt((
    'Good',
    [[747, 390, 4, 'Fear of a Black Hat (1993)']],
))

'Given a user, here is the list of movies that they have rated as Good. \n### User Input:\nMovie: Fear of a Black Hat (1993) Rating: 4'

In [35]:
# Generate one output record per user listed in df_user.
unique_user_ids = df_user['user_id'].unique().tolist()

data_list = []          # records of the form {'rating': ..., 'movie1': ..., ...}
skipped_user_ids = []   # users for which no record could be built

for userId in tqdm(unique_user_ids, desc='Generating prompts'):
    try:
        rating, datapoints = gen_datapoint(userId, df_data, df_item)
        # Keep the titles (index 3) of the first MOVIES_IN_DATAPOINTS sampled
        # movies under the keys movie1, movie2, ...
        record = {'rating': rating}
        for position in range(MOVIES_IN_DATAPOINTS):
            record[f'movie{position + 1}'] = datapoints[position][3]
    except Exception as err:
        # Best effort: skip this user but say why. Catching Exception rather
        # than using a bare except lets KeyboardInterrupt stop the loop.
        skipped_user_ids.append(userId)
        # tqdm.write prints without breaking up the progress bar.
        tqdm.write(f'Error with user {userId}: {err}')
    else:
        data_list.append(record)

print(f'Generated {len(data_list)} records; skipped {len(skipped_user_ids)} users')

    


Generating prompts:   4%|▍         | 40/943 [00:00<00:04, 195.49it/s]

Error with user 10
Error with user 33


Generating prompts:  20%|██        | 189/943 [00:00<00:03, 244.37it/s]

Error with user 170


Generating prompts:  42%|████▏     | 400/943 [00:01<00:02, 264.65it/s]

Error with user 355
Error with user 359
Error with user 376
Error with user 384


Generating prompts:  48%|████▊     | 452/943 [00:01<00:02, 244.29it/s]

Error with user 420


Generating prompts:  54%|█████▎    | 505/943 [00:02<00:01, 249.34it/s]

Error with user 469
Error with user 477
Error with user 501


Generating prompts:  73%|███████▎  | 688/943 [00:02<00:00, 255.58it/s]

Error with user 636


Generating prompts:  78%|███████▊  | 740/943 [00:03<00:00, 254.56it/s]

Error with user 694


Generating prompts:  87%|████████▋ | 819/943 [00:03<00:00, 251.24it/s]

Error with user 784
Error with user 810


Generating prompts:  92%|█████████▏| 871/943 [00:03<00:00, 246.70it/s]

Error with user 849
Error with user 888


Generating prompts: 100%|██████████| 943/943 [00:03<00:00, 242.52it/s]

Error with user 928





In [36]:

def create_jsonl_file(datapoints, file_path):
    """Write every item of ``datapoints`` to ``file_path`` using jsonlines.

    The file is opened in 'w' mode, so existing content at ``file_path`` is
    replaced. Items are written in iteration order.
    """
    with jsonlines.open(file_path, mode='w') as jsonl_writer:
        jsonl_writer.write_all(datapoints)

# Export the generated records to a JSON Lines file in the working directory.
file_path = 'output.jsonl'
create_jsonl_file(datapoints=data_list, file_path=file_path)
