In [3]:
from pathlib import Path

import pandas as pd
from sklearn import preprocessing as pp
from sklearn.model_selection import train_test_split

In [17]:
# Input: the raw MovieLens 100K ratings file.
RAW_U_DATA_PATH = '../data/raw/ml-100k/u.data'

# Outputs: train/test interaction files written by the last cell of the notebook.
INTERIM_U_TRAIN_PATH, INTERIM_U_TEST_PATH = (
    '../data/interim/u.train',
    '../data/interim/u.test',
)

In [5]:
# The file is tab-separated and read without a header row, so column names are supplied here.
columns_name = ['user_id', 'item_id', 'rating', 'timestamp']
raw_ratings = pd.read_csv(RAW_U_DATA_PATH, sep="\t", names=columns_name)

# Keep only interactions rated 3 or higher.
df = raw_ratings.loc[raw_ratings['rating'] >= 3]

Perform an 80/20 train/test split on the retained interactions (those rated 3 or higher).

In [10]:
# 80/20 random split of the interactions, with a fixed seed for reproducibility.
# Split the DataFrame itself rather than `df.values`. Going through a single
# NumPy array would coerce every column to one common dtype, and that breaks
# as soon as a non-numeric column is present. reset_index gives each split a
# fresh 0..n-1 index.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=16)
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

In [11]:
# Report the number of interactions that ended up in each split.
for label, frame in (("Train Size  : ", train_df), ("Test Size : ", test_df)):
    print(label, len(frame))

Train Size  :  66016
Test Size :  16504


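Because the train/test split was performed randomly on the interactions, not every user and item is guaranteed to appear in the training set. We therefore relabel users and items with zero-based indices fitted on the training set only. The highest label is then one less than the number of users (respectively, items). Test interactions whose user or item never appears in the training set are dropped, because they cannot be encoded.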

In [12]:
# Fit one encoder per ID column on the training interactions only. Each raw ID
# is mapped to a zero-based index, stored in a new *_idx column.
# NOTE: `pp` is presumably sklearn.preprocessing and must be imported before this cell runs.
le_user = pp.LabelEncoder()
le_item = pp.LabelEncoder()
for raw_col, idx_col, encoder in (
    ('user_id', 'user_id_idx', le_user),
    ('item_id', 'item_id_idx', le_item),
):
    train_df[idx_col] = encoder.fit_transform(train_df[raw_col].values)

In [13]:
# Keep only test interactions whose user AND item both appear in the training
# split. LabelEncoder.transform raises on labels it was not fitted on.
train_user_ids = train_df['user_id'].unique()
train_item_ids = train_df['item_id'].unique()

print(len(train_user_ids), len(train_item_ids))

seen_in_train = (
    test_df['user_id'].isin(train_user_ids)
    & test_df['item_id'].isin(train_item_ids)
)
# .copy() makes this an independent frame, so adding the *_idx columns later
# does not write into a filtered subset (SettingWithCopyWarning).
test_df = test_df[seen_in_train].copy()
print(len(test_df))

943 1546
16472


In [14]:
# Encode test IDs with the encoders fitted on the training split. assign()
# returns a new frame instead of writing columns into a possibly-filtered
# subset, which avoids SettingWithCopyWarning.
test_df = test_df.assign(
    user_id_idx=le_user.transform(test_df['user_id'].values),
    item_id_idx=le_item.transform(test_df['item_id'].values),
)

In [15]:
# Number of distinct encoded users and items in the training split.
n_users, n_items = (train_df[col].nunique() for col in ('user_id_idx', 'item_id_idx'))
print("Number of Unique Users : ", n_users)
print("Number of unique Items : ", n_items)

Number of Unique Users :  943
Number of unique Items :  1546


In [21]:
# Persist both splits as tab-separated files with a header row and no index column.
# Create the output directories first; otherwise to_csv fails when the interim
# data directory does not exist yet (e.g. on a fresh checkout).
for out_path in (INTERIM_U_TRAIN_PATH, INTERIM_U_TEST_PATH):
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
train_df.to_csv(INTERIM_U_TRAIN_PATH, sep="\t", index=False)
test_df.to_csv(INTERIM_U_TEST_PATH, sep="\t", index=False)