In [1]:
# Automatically re-import modules whose source changed (e.g. edits to the local
# flax_recsys package) before each cell executes, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import mlflow
import optax
import polars as pl
from flax import nnx
from flax_trainer.trainer import Trainer
from sklearn.model_selection import train_test_split

from flax_recsys.encoder import SequentialEncoder
from flax_recsys.evaluator import SequentialEvaluator
from flax_recsys.loader import SequentialLoader
from flax_recsys.loss_fn import cross_entropy_loss
from flax_recsys.model import GRU4Rec

  from .autonotebook import tqdm as notebook_tqdm


### Load data

In [3]:
DATASET_PATH = "/workspace/dataset/amazon-m2/sessions_train.csv"
LOCALE = "ES"

# `prev_items` arrives as one string in which every item ID is wrapped in single
# quotes (assumed numpy-style repr such as "['A' 'B']" — confirm against the raw
# CSV). Splitting on "'" therefore leaves the IDs at odd positions and the
# brackets/separators at even positions, so only odd positions are kept.
# `item_ids` is the full session: the previous items followed by `next_item`.
dataset_df = (
    pl.scan_csv(DATASET_PATH)
    # Filtering lazily lets polars push the predicate into the CSV scan, so rows
    # from other locales are never materialised in memory.
    .filter(pl.col("locale") == LOCALE)
    .with_columns(
        pl.col("prev_items")
        .str.split(by="'")
        .list.eval(pl.element().filter(pl.int_range(0, pl.len()) % 2 == 1))
    )
    .with_columns(
        pl.concat_list(pl.col("prev_items"), pl.col("next_item")).alias("item_ids")
    )
    .collect()
)
display(dataset_df)

# Hold out 10% of sessions for validation; the fixed seed makes the split reproducible.
train_df, valid_df = train_test_split(
    dataset_df, test_size=0.1, random_state=0, shuffle=True
)
assert len(train_df) + len(valid_df) == len(dataset_df)

# The item vocabulary is fitted on training sessions only. Validation sessions may
# therefore contain unseen items; how SequentialEncoder maps those is not visible
# here — TODO confirm. The trailing ';' hides the encoder repr returned by fit().
encoder = SequentialEncoder()
encoder.fit(train_df.get_column("item_ids").to_list());

prev_items,next_item,locale,item_ids
list[str],str,str,list[str]
"[""B08MV5B53K"", ""B08MV4RCQR"", ""B08MV5B53K""]","""B012408XPC""","""ES""","[""B08MV5B53K"", ""B08MV4RCQR"", … ""B012408XPC""]"
"[""B07JGW4QWX"", ""B085VCXHXL""]","""B07JFPYN5P""","""ES""","[""B07JGW4QWX"", ""B085VCXHXL"", ""B07JFPYN5P""]"
"[""B08BFQ52PR"", ""B08LVSTZVF"", ""B08BFQ52PR""]","""B08NJP3KT6""","""ES""","[""B08BFQ52PR"", ""B08LVSTZVF"", … ""B08NJP3KT6""]"
"[""B08PPBF9C6"", ""B08PPBF9C6"", … ""B08PPBF9C6""]","""B08PP6BLLK""","""ES""","[""B08PPBF9C6"", ""B08PPBF9C6"", … ""B08PP6BLLK""]"
"[""B0B6W67XCR"", ""B0B712FY2M"", ""B0B6ZYJ3S2""]","""B09SL4MBM2""","""ES""","[""B0B6W67XCR"", ""B0B712FY2M"", … ""B09SL4MBM2""]"
…,…,…,…
"[""B08LR8CH7S"", ""B00CWKFYES""]","""B00C2U6794""","""ES""","[""B08LR8CH7S"", ""B00CWKFYES"", ""B00C2U6794""]"
"[""B08KH2MTSS"", ""B08KJP91X2""]","""B0BGYMJM5S""","""ES""","[""B08KH2MTSS"", ""B08KJP91X2"", ""B0BGYMJM5S""]"
"[""B09ZV92J5P"", ""B06ZY1MXNG"", ""B07W6NS1VC""]","""B09TDTC96J""","""ES""","[""B09ZV92J5P"", ""B06ZY1MXNG"", … ""B09TDTC96J""]"
"[""B08V6Q3V25"", ""B08V6Q3V25"", ""B09D7MPCDH""]","""B09D7CXVJZ""","""ES""","[""B08V6Q3V25"", ""B08V6Q3V25"", … ""B09D7CXVJZ""]"


<flax_recsys.encoder.sequential_encoder.SequentialEncoder at 0xffff44378f90>

In [4]:
# One batch size shared by training and validation.
batch_size = 512

# Training batches over the training sessions. The RNG is seeded, presumably so
# that any shuffling inside the loader is reproducible (TODO confirm).
train_sequences = train_df.get_column("item_ids").to_list()
loader = SequentialLoader(
    sequences=train_sequences,
    encoder=encoder,
    batch_size=batch_size,
    rngs=nnx.Rngs(0),
)

# Validation over the held-out sessions, encoded with the same fitted vocabulary.
valid_sequences = valid_df.get_column("item_ids").to_list()
evaluator = SequentialEvaluator(
    sequences=valid_sequences,
    encoder=encoder,
    batch_size=batch_size,
)

                                                                      

### Training

In [5]:
# Vocabulary size used for both the input embedding and the output layer: all
# encoded item IDs plus one extra slot — presumably an index reserved by
# SequentialEncoder for padding/unknown items (TODO confirm).
vocab_size = encoder.item_num + 1

# GRU4Rec with 30-dim embeddings and presumably a single GRU layer and a single
# feed-forward layer of width 30 each; the output dimension equals vocab_size,
# i.e. one logit per item ID.
model = GRU4Rec(
    item_num=vocab_size,
    embed_dim=30,
    gru_layer_dims=[30],
    ff_layer_dims=[30],
    output_layer_dim=vocab_size,
    rngs=nnx.Rngs(0),
    max_batch_size=batch_size,
)

In [6]:
# Honour the standard MLFLOW_TRACKING_URI environment variable when it is set;
# otherwise fall back to the local tracking server this notebook was run against.
mlflow.set_tracking_uri(
    uri=os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:8080")
)
mlflow.set_experiment("GRU4Rec")

with mlflow.start_run() as run:
    trainer = Trainer(
        model=model,
        # AdamW: Adam with decoupled weight decay.
        optimizer=optax.adamw(learning_rate=0.001, weight_decay=0.001),
        train_loader=loader,
        loss_fn=cross_entropy_loss,
        valid_evaluator=evaluator,
        # NOTE(review): in the logged run below, validation loss bottomed out at
        # epoch 9 and training stopped at epoch 19 (= 9 + patience), although
        # hit_10 was still improving. Early stopping therefore appears to monitor
        # the loss rather than hit_10 — confirm, and decide whether the ranking
        # metric should drive model selection instead.
        early_stopping_patience=10,
        epoch_num=32,
        active_run=run,
    )
    # Rebind `trainer` to whatever fit() returns.
    trainer = trainer.fit()

100%|██████████| 113/113 [00:04<00:00, 23.84it/s]


[VALID 000]: loss=10.621649742126465, metrics={'hit_10': 0.0002447680744808167, 'cross_entropy': 10.621649742126465}


[TRAIN 001]: 100%|██████████| 599/599 [00:35<00:00, 17.01it/s, batch_loss=9.884598]  
100%|██████████| 113/113 [00:04<00:00, 25.46it/s]


[VALID 001]: loss=9.517483711242676, metrics={'hit_10': 0.050758782774209976, 'cross_entropy': 9.517483711242676}


[TRAIN 002]: 100%|██████████| 588/588 [00:33<00:00, 17.50it/s, batch_loss=8.924423] 
100%|██████████| 113/113 [00:04<00:00, 24.80it/s]


[VALID 002]: loss=8.884390830993652, metrics={'hit_10': 0.15555012226104736, 'cross_entropy': 8.884390830993652}


[TRAIN 003]: 100%|██████████| 590/590 [00:32<00:00, 18.12it/s, batch_loss=8.225981] 
100%|██████████| 113/113 [00:04<00:00, 25.59it/s]


[VALID 003]: loss=8.407713890075684, metrics={'hit_10': 0.24663443863391876, 'cross_entropy': 8.407713890075684}


[TRAIN 004]: 100%|██████████| 591/591 [00:32<00:00, 18.21it/s, batch_loss=7.440843] 
100%|██████████| 113/113 [00:04<00:00, 25.45it/s]


[VALID 004]: loss=8.096561431884766, metrics={'hit_10': 0.31467998027801514, 'cross_entropy': 8.096561431884766}


[TRAIN 005]: 100%|██████████| 591/591 [00:33<00:00, 17.90it/s, batch_loss=6.855351] 
100%|██████████| 113/113 [00:04<00:00, 24.85it/s]


[VALID 005]: loss=7.9163641929626465, metrics={'hit_10': 0.36323583126068115, 'cross_entropy': 7.9163641929626465}


[TRAIN 006]: 100%|██████████| 603/603 [00:33<00:00, 18.01it/s, batch_loss=6.28156]   
100%|██████████| 113/113 [00:04<00:00, 26.13it/s]


[VALID 006]: loss=7.780933380126953, metrics={'hit_10': 0.3998592495918274, 'cross_entropy': 7.780933380126953}


[TRAIN 007]: 100%|██████████| 592/592 [00:32<00:00, 18.14it/s, batch_loss=5.903362]  
100%|██████████| 113/113 [00:04<00:00, 24.24it/s]


[VALID 007]: loss=7.708841323852539, metrics={'hit_10': 0.4261106252670288, 'cross_entropy': 7.708841323852539}


[TRAIN 008]: 100%|██████████| 592/592 [00:33<00:00, 17.91it/s, batch_loss=5.716528]  
100%|██████████| 113/113 [00:04<00:00, 25.13it/s]


[VALID 008]: loss=7.684292316436768, metrics={'hit_10': 0.4444376528263092, 'cross_entropy': 7.684292316436768}


[TRAIN 009]: 100%|██████████| 596/596 [00:34<00:00, 17.42it/s, batch_loss=5.0491924] 
100%|██████████| 113/113 [00:04<00:00, 24.75it/s]


[VALID 009]: loss=7.677452564239502, metrics={'hit_10': 0.4584506154060364, 'cross_entropy': 7.677452564239502}


[TRAIN 010]: 100%|██████████| 605/605 [00:33<00:00, 18.05it/s, batch_loss=5.187661]   
100%|██████████| 113/113 [00:04<00:00, 24.49it/s]


[VALID 010]: loss=7.679049968719482, metrics={'hit_10': 0.4694957733154297, 'cross_entropy': 7.679049968719482}


[TRAIN 011]: 100%|██████████| 603/603 [00:37<00:00, 16.21it/s, batch_loss=4.755828]    
100%|██████████| 113/113 [00:04<00:00, 23.67it/s]


[VALID 011]: loss=7.7042236328125, metrics={'hit_10': 0.47803205251693726, 'cross_entropy': 7.7042236328125}


[TRAIN 012]: 100%|██████████| 600/600 [00:34<00:00, 17.23it/s, batch_loss=4.5527043]  
100%|██████████| 113/113 [00:04<00:00, 25.09it/s]


[VALID 012]: loss=7.729147911071777, metrics={'hit_10': 0.48595643043518066, 'cross_entropy': 7.729147911071777}


[TRAIN 013]: 100%|██████████| 593/593 [00:34<00:00, 17.13it/s, batch_loss=4.4096193] 
100%|██████████| 113/113 [00:04<00:00, 24.59it/s]


[VALID 013]: loss=7.756795406341553, metrics={'hit_10': 0.49020928144454956, 'cross_entropy': 7.756795406341553}


[TRAIN 014]: 100%|██████████| 592/592 [00:34<00:00, 17.16it/s, batch_loss=4.4204063] 
100%|██████████| 113/113 [00:04<00:00, 24.23it/s]


[VALID 014]: loss=7.781134605407715, metrics={'hit_10': 0.4946456849575043, 'cross_entropy': 7.781134605407715}


[TRAIN 015]: 100%|██████████| 590/590 [00:34<00:00, 16.96it/s, batch_loss=3.9354303]
100%|██████████| 113/113 [00:04<00:00, 25.03it/s]


[VALID 015]: loss=7.809829235076904, metrics={'hit_10': 0.49957165122032166, 'cross_entropy': 7.809829235076904}


[TRAIN 016]: 100%|██████████| 593/593 [00:34<00:00, 17.01it/s, batch_loss=3.8762136] 
100%|██████████| 113/113 [00:04<00:00, 24.48it/s]


[VALID 016]: loss=7.852991104125977, metrics={'hit_10': 0.5014685988426208, 'cross_entropy': 7.852991104125977}


[TRAIN 017]: 100%|██████████| 595/595 [00:35<00:00, 16.79it/s, batch_loss=3.6959248] 
100%|██████████| 113/113 [00:04<00:00, 23.71it/s]


[VALID 017]: loss=7.873661041259766, metrics={'hit_10': 0.5046811699867249, 'cross_entropy': 7.873661041259766}


[TRAIN 018]: 100%|██████████| 599/599 [00:35<00:00, 16.65it/s, batch_loss=3.6972241]  
100%|██████████| 113/113 [00:04<00:00, 25.41it/s]


[VALID 018]: loss=7.906472682952881, metrics={'hit_10': 0.5063333511352539, 'cross_entropy': 7.906472682952881}


[TRAIN 019]: 100%|██████████| 588/588 [00:35<00:00, 16.59it/s, batch_loss=3.409321] 
100%|██████████| 113/113 [00:04<00:00, 24.33it/s]


[VALID 019]: loss=7.932363033294678, metrics={'hit_10': 0.5087810754776001, 'cross_entropy': 7.932363033294678}
🏃 View run unique-koi-969 at: http://localhost:8080/#/experiments/756717060270598937/runs/e6a5d809a8ff4f3ebba70b75d481ca49
🧪 View experiment at: http://localhost:8080/#/experiments/756717060270598937
