In [2]:
import os
import tqdm
import seaborn as sns

import pandas as pd
import numpy as np
from lightfm import LightFM
from lightfm.data import Dataset

# Import LightFM's evaluation metrics
from lightfm.evaluation import precision_at_k

%matplotlib inline
SEED = 42
np.random.seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)



In [3]:
K = 50 #recall number
EPOCHS = 10


In [4]:
main_dir = "./dataset"
user = pd.read_csv(main_dir+"/user.csv")
item = pd.read_csv(main_dir+"/item.csv", dtype={'article_id': int})
train = pd.read_csv(main_dir+'/transaction.csv', usecols=['t_dat','customer_id','item_id'] ,dtype={'article_id': int}, parse_dates=['t_dat'])

In [5]:
train['article_id'] = train['item_id']

In [5]:
train.head(2)

Unnamed: 0,t_dat,customer_id,item_id,article_id
0,2019-09-01,000f7535bdc611ad136a9f04746d6b1431f50a7f60fbbe...,727880001,727880001
1,2019-09-01,000f7535bdc611ad136a9f04746d6b1431f50a7f60fbbe...,767869001,767869001


In [5]:
item.head(2)

Unnamed: 0,article_id,product_code,prod_name,product_type_no,product_type_name,product_group_name,graphical_appearance_no,graphical_appearance_name,colour_group_code,colour_group_name,...,department_name,index_code,index_name,index_group_no,index_group_name,section_no,section_name,garment_group_no,garment_group_name,detail_desc
0,108775015,108775,Strap top,253,Vest top,Garment Upper body,1010016,Solid,9,Black,...,Jersey Basic,A,Ladieswear,1,Ladieswear,16,Womens Everyday Basics,1002,Jersey Basic,Jersey top with narrow shoulder straps.
1,108775044,108775,Strap top,253,Vest top,Garment Upper body,1010016,Solid,10,White,...,Jersey Basic,A,Ladieswear,1,Ladieswear,16,Womens Everyday Basics,1002,Jersey Basic,Jersey top with narrow shoulder straps.


In [6]:
user.head(2)

Unnamed: 0,customer_id,FN,Active,club_member_status,fashion_news_frequency,age,postal_code
0,00000dbacae5abe5e23885899a1fa44253a17956c6d1c3...,,,ACTIVE,NONE,49.0,52043ee2162cf5aa7ee79974281641c6f11a68d276429a...
1,0000423b00ade91418cceaf3b26c6af3dd342b51fd051e...,,,ACTIVE,NONE,25.0,2973abc54daa8a5f8ccfe9362140c63247c5eee03f1d93...


In [6]:
train['t_dat'] = pd.to_datetime(train['t_dat'])
last_date = train['t_dat'].max()
first_date = train['t_dat'].min()
print(first_date,last_date)

2019-09-01 00:00:00 2020-09-22 00:00:00


In [7]:
# calculate week number
train['week'] = (last_date - pd.to_datetime(train['t_dat'])).dt.days // 7

In [8]:
# set item features
def create_item_features(row):
    features = [
        f"product_code:{row['product_code']}",
        f"section_no:{row['section_no']}",
        f"colour_group_code:{row['colour_group_code']}",
        f"perceived_colour_value_id:{row['perceived_colour_value_id']}",
        f"index_code:{row['index_code']}",
        f"product_type_no:{row['product_type_no']}",
        f"department_no:{row['department_no']}",
        f"garment_group_no:{row['garment_group_no']}",
        f"graphical_appearance_no:{row['graphical_appearance_no']}",
    ]
    return features


item['features'] = item.apply(create_item_features, axis=1)


In [9]:
all_features = set()
item['features'].apply(lambda x: all_features.update(x))

0        None
1        None
2        None
3        None
4        None
         ... 
70936    None
70937    None
70938    None
70939    None
70940    None
Name: features, Length: 70941, dtype: object

In [10]:
# set item features
def create_user_features(row):
    features = [
        f"FN:{row['FN']}",
        f"Active:{row['Active']}",
        f"club_member_status:{row['club_member_status']}",
        f"fashion_news_frequency:{row['fashion_news_frequency']}",
        f"age:{row['age']}",
        f"postal_code:{row['postal_code']}",
      
    ]
    return features

user['features'] = user.apply(create_user_features, axis=1)
all_user_features = set()
user['features'].apply(lambda x: all_user_features.update(x))


0         None
1         None
2         None
3         None
4         None
          ... 
594376    None
594377    None
594378    None
594379    None
594380    None
Name: features, Length: 594381, dtype: object

In [11]:
listBin = [-1, 19, 29, 39, 49, 59, 69, 119]
user['age'] = pd.cut(user['age'], listBin)

In [12]:
dataset = Dataset()
dataset.fit(users=user['customer_id'], 
            user_features=(x for x in all_user_features),
            items=item['article_id'],
            item_features=(x for x in all_features))
num_users, num_topics = dataset.interactions_shape()
print(f'Number of users: {num_users}, Number of topics: {num_topics}.')

Number of users: 594381, Number of topics: 70941.


In [13]:
#build item & user Features
item_features_data = ((row['article_id'], row['features']) for index, row in item.iterrows())
item_features = dataset.build_item_features(item_features_data)
user_features_data = ((row['customer_id'], row['features']) for index, row in user.iterrows())
user_features = dataset.build_user_features(user_features_data)

In [14]:
from tqdm.notebook import tqdm
tqdm.pandas()

def generate_predictions(dataset, model, val_set, K, batch_size=100, type='user-itemCF'):
    """
    Optimized function to generate top K predictions for all users in the validation set.

    Args:
        dataset: The dataset object containing mappings.
        model: The trained recommendation model.
        val_set: The validation set containing customer IDs.
        K: The number of top items to predict.
        batch_size: The number of user IDs to process in each batch.

    Returns:
        DataFrame: The validation set with an additional column for predictions.
    """

    uid_map, _, iid_map, _ = dataset.mapping()
    inv_uid_map = {v: k for k, v in uid_map.items()}
    inv_iid_map = {v: k for k, v in iid_map.items()}
    
    # Preparing item data
    all_item_ids_model = list(iid_map.values())
    item_ids = np.array(all_item_ids_model)
    
    predictions = []

    customer_ids = val_set['customer_id'].unique()
    num_batches = len(customer_ids) // batch_size + (1 if len(customer_ids) % batch_size != 0 else 0)
    
    for i in tqdm(range(num_batches), desc="Predicting"):
        start_index = i * batch_size
        end_index = start_index + batch_size
        batch_customer_ids = customer_ids[start_index:end_index]
        
        batch_user_ids_model = np.array([uid_map[cid] for cid in batch_customer_ids])
        
        # Repeating user IDs for each item
        user_ids = np.repeat(batch_user_ids_model, len(all_item_ids_model))
        item_ids_batch = np.tile(item_ids, len(batch_customer_ids))
        if type == 'user-itemCF':
            batch_predictions = model.predict(
                user_ids=user_ids,
                item_ids=item_ids_batch,
                user_features=user_features,
                item_features=item_features,
                num_threads=16
            )
        elif type == 'itemCF':
            batch_predictions = model.predict(
                user_ids=user_ids,
                item_ids=item_ids_batch,
                item_features=item_features,
                num_threads=16
            )
        elif type == 'userCF':
            batch_predictions = model.predict(
                user_ids=user_ids,
                item_ids=item_ids_batch,
                user_features=user_features,
                num_threads=16
            )
        else:
            batch_predictions = None
        
        batch_predictions = batch_predictions.reshape(len(batch_customer_ids), len(all_item_ids_model))
        
        # Extracting top K items for each user in the batch
        for user_predictions in batch_predictions:
            top_k_item_indices = np.argsort(-user_predictions)[:K]
            top_k_item_ids_model = item_ids[top_k_item_indices]
            top_k_item_ids_original = [inv_iid_map[item_id] for item_id in top_k_item_ids_model]
            prediction_str = ' '.join(str(item_id) for item_id in top_k_item_ids_original)
            predictions.append(prediction_str)
    
    val_set[f'prediction_{type}'] = predictions
    return val_set



In [17]:
from joblib import dump, load


In [22]:
recall_weeks = [0,1,2,3,4]
week_nums = 23

In [23]:
#model prediction
for week in recall_weeks:
    label_set = train[train['week']==week]
    label_set_grouped = label_set.groupby('customer_id')['article_id'].agg(list).reset_index()
    label_set = label_set.drop('article_id', axis=1).merge(label_set_grouped, on='customer_id')
    label_set['label'] = label_set['article_id'].apply(lambda x:list(set(x)))
    label_set = label_set.drop(['item_id','t_dat','article_id'],axis=1).drop_duplicates(subset=['customer_id', 'week'])
    #start to train the recall model and predict the candidates.
    #label_set = label_set.head(100)
    userCF_model = load(f'userCF_recall_{week}.joblib')
    itemCF_model = load(f'itemCF_recall_{week}.joblib')
    uiCF_model = load(f'uiCF_recall_{week}.joblib')
    label_set = generate_predictions(dataset, userCF_model, label_set, K,5120,'userCF')
    label_set = generate_predictions(dataset, itemCF_model, label_set, K,5120,'itemCF')
    label_set = generate_predictions(dataset, uiCF_model, label_set, K,5120,'user-itemCF')
    label_set.to_parquet(f'./dataset/recall_CF_week{week}.pt')

    
        
    

Predicting:   0%|          | 0/11 [00:00<?, ?it/s]

Predicting:   0%|          | 0/11 [00:00<?, ?it/s]

Predicting:   0%|          | 0/11 [00:00<?, ?it/s]

Predicting:   0%|          | 0/12 [00:00<?, ?it/s]

Predicting:   0%|          | 0/12 [00:00<?, ?it/s]

Predicting:   0%|          | 0/12 [00:00<?, ?it/s]

Predicting:   0%|          | 0/12 [00:00<?, ?it/s]

Predicting:   0%|          | 0/12 [00:00<?, ?it/s]

Predicting:   0%|          | 0/12 [00:00<?, ?it/s]

Predicting:   0%|          | 0/13 [00:00<?, ?it/s]

Predicting:   0%|          | 0/13 [00:00<?, ?it/s]

Predicting:   0%|          | 0/13 [00:00<?, ?it/s]

Predicting:   0%|          | 0/12 [00:00<?, ?it/s]

Predicting:   0%|          | 0/12 [00:00<?, ?it/s]

Predicting:   0%|          | 0/12 [00:00<?, ?it/s]

In [24]:
label_set

Unnamed: 0,customer_id,week,label,prediction_userCF,prediction_itemCF,prediction_user-itemCF
0,002156b708c7c6dd8afe31a743131d13b1e5dcbf2ce8c4...,4,"[896152002, 897146002]",523489001 872233003 783346018 573085045 568601...,715624008 715624012 685814048 815264002 715624...,621914009 405636022 621914001 749088001 541308...
2,00ad9e5d82fc8ad18e1fac84f515ab735bd516df32b8ca...,4,"[572998009, 920752001, 902161006]",892850001 844874008 781758033 865440001 816166...,843753004 829977003 689389047 829977002 828930...,730454048 756356003 732409001 845790003 711957...
5,00d4f8759e569b9e63a6fecb0c2ed5802174c2c2ad64ce...,4,"[915526001, 557994014]",856121001 865440001 892794001 880089001 900176...,715624008 876410001 751429003 592959008 815669...,697201002 727947006 690621001 825744002 590489...
7,0119826e13f3ef7fb3fb84c778a883710cc859de4b1886...,4,"[808659001, 832307003, 750424017]",875451002 850249003 816841006 861731002 843362...,833975002 868042001 870958004 501620040 501620...,659460002 659460003 659460001 513696004 887904...
10,011fc4c3387f8c6eba0e7062aa47750b65d4dc2d5d6148...,4,"[762600009, 817086002, 535455003]",849468001 845285001 814817001 831412002 891669...,868042001 838182002 689389047 658880001 699620...,824352001 840856001 513696004 824352003 824352...
...,...,...,...,...,...,...
206536,fb04e98ab39aff596c8dee1f20c5863a12f68d2064ae72...,4,"[873217001, 873217004, 868034001, 791033010, 7...",875728001 781758057 897693003 915019001 781758...,685814001 685814048 715624012 685814003 505882...,920529004 920529001 788464001 744306007 744306...
206541,fd344f39be798bc456fd3c041b6cef4933ab0e5875189a...,4,"[799365028, 886737001, 906305002, 891898001, 7...",787285001 871346001 624251002 874465001 804551...,756415005 699923113 654854001 572127013 917720...,830874001 830874002 612136002 664074021 664074...
206546,fe5648cc03e5337ce28d4ba24cebdf57247c093a937ee7...,4,"[715624001, 448509001, 910601001, 826646001, 9...",877643001 817361007 846581002 840947003 859399...,640331001 820097001 685814048 626263019 464297...,830874002 664074037 830874001 513696004 683356...
206554,fec9fcd8d529ecd32485518de4ea12f196b9f7126a5c1e...,4,"[866714016, 918516001, 896152002, 866714017, 8...",399256005 826492007 859105002 858052005 399256...,828499001 828930003 655710008 751257007 815669...,830874001 830874002 664074043 664074021 664074...
