In [2]:
%load_ext autoreload
%autoreload 2

from src.rl.cql_dqn import *
from src.rl.rec_replay_buffer import RecReplayBuffer
from RECE.data import get_dataset, data_to_sequences, SequentialDataset
from RECE.train import prepare_sasrec_model, train_sasrec_epoch, downvote_seen_items, sasrec_model_scoring, topn_recommendations, model_evaluate
import gc
from tqdm import tqdm_notebook as tqdm
from matplotlib import pyplot as plt
from time import time
from clearml import Task, Logger

In [3]:
# Torch device handle that the cells below pass to the SASRec helpers
# (model preparation, training and scoring). Hard-wired to CUDA.
device = torch.device("cuda")

In [15]:
# Hyper-parameters for the SASRec body. The same dict is also handed to
# `pretrain` / `train_agent`, which read `num_epochs`, `skip_epochs`,
# `patience` and `l2_emb` from it.
sasrec_config = {
    # reproducibility
    'manual_seed': 123,
    'sampler_seed': 123,
    # schedule (author's note on other tried settings: 3 10 22 100&dropout0.9&hd32&bs1000)
    'num_epochs': 100,
    # architecture
    'maxlen': 100,
    'hidden_units': 64,
    'dropout_rate': 0.3,
    'num_blocks': 2,
    'num_heads': 1,
    # optimisation
    'batch_size': 128,  # marked DEBUG by the author
    'learning_rate': 1e-3,
    'fwd_type': 'ce',
    'l2_emb': 0,
    # early stopping / evaluation cadence
    'patience': 10,
    'skip_epochs': 1,
    # negative sampling
    'n_neg_samples': 0,
    'sampling': 'no_sampling',
}


# Configuration of the CQL agent; every field not listed keeps its TrainConfig default.
config = TrainConfig(
    # run identity (used for experiment tracking)
    project="CQL-SASREC",
    group="CQL-SASREC",
    name="CQL",
    env="MovieLens",
    # hardware / batching: the agent uses the same batch size as the SASRec body
    device="cuda",
    batch_size=sasrec_config['batch_size'],
    # Q-function
    orthogonal_init=True,
    q_n_hidden_layers=1,
    qf_lr=3e-4,
    # CQL
    bc_steps=100000,
    cql_alpha=100.0,
    # cql_negative_samples = 10
)

In [16]:
# SECURITY: never hard-code the W&B API key in the notebook. The client reads it
# from the WANDB_API_KEY environment variable or from the credentials stored by
# `wandb login`, so nothing is written to the environment here.
# NOTE(review): the key that used to be committed in this cell must be treated as
# leaked and revoked in the W&B account settings.
os.environ["WANDB_MODE"] = "online"

# Seed everything from the agent config before the run is registered.
seed = config.seed
set_seed(seed)
wandb_init(asdict(config))

In [8]:
# Load the dataset with the 'temporal_full' split (q=0.8). The last returned
# element is discarded.
(
    training_temp,
    data_description_temp,
    testset_valid_temp,
    testset,
    holdout_valid_temp,
    _,
) = get_dataset(splitting='temporal_full', q=0.8)

# Display the dataset description as the cell output.
data_description_temp

Filtered 115 invalid observations.
Filtered 11 invalid observations.
Filtered 4 invalid observations.


{'users': 'userid',
 'items': 'itemid',
 'order': 'timestamp',
 'n_users': 5400,
 'n_items': 3658}

In [9]:
# Build the SASRec model together with its batch sampler, the number of
# batches per epoch and the optimizers used by `train_sasrec_epoch`.
sasrec_model, sampler, n_batches, optimizers = prepare_sasrec_model(
    sasrec_config,
    training_temp,
    data_description_temp,
    device,
)

In [10]:
# Optional experiment-tracking handles. The training functions only report
# through `log` when `task` is truthy, so None disables that reporting.
task = None
log = None

def pretrain(model, config, data_description, testset_valid, holdout_valid):
    """Train the SASRec model with early stopping on validation NDCG@10.

    Runs up to ``config['num_epochs']`` epochs of ``train_sasrec_epoch``. Every
    ``config['skip_epochs']`` epochs the model is scored on ``testset_valid``,
    already-seen items are down-voted, top-10 recommendations are evaluated
    against ``holdout_valid`` and the weights are checkpointed whenever
    NDCG@10 improves. The best checkpoint is loaded back into ``model`` and the
    checkpoint file is removed before the function exits, including when
    training is interrupted or fails.

    Relies on the notebook globals ``n_batches``, ``sampler``, ``optimizers``,
    ``device``, ``task`` and ``log``.

    Args:
        model: model to train; updated in place.
        config: dict providing ``num_epochs``, ``l2_emb``, ``skip_epochs`` and
            ``patience``.
        data_description: dataset description forwarded to the scoring and
            evaluation helpers.
        testset_valid: validation interactions used for scoring.
        holdout_valid: held-out validation items used for evaluation.

    Returns:
        None. Progress and resource usage are printed (and reported through
        ``log`` when ``task`` is set).
    """
    losses = {}
    metrics = {}
    ndcg = {}
    best_ndcg = 0
    wait = 0

    start_time = time()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start_memory = torch.cuda.memory_allocated()

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('./checkpt', exist_ok=True)
    checkpt_path = os.path.join('./checkpt', f'{uuid.uuid4().hex}.chkpt')

    try:
        for epoch in range(config['num_epochs']):
            losses[epoch] = train_sasrec_epoch(
                model, n_batches, config['l2_emb'], sampler, optimizers, device
            )
            if epoch % config['skip_epochs'] == 0:
                val_scores = sasrec_model_scoring(model, testset_valid, data_description, device)
                downvote_seen_items(val_scores, testset_valid, data_description)
                val_recs = topn_recommendations(val_scores, topn=10)
                val_metrics = model_evaluate(val_recs, holdout_valid, data_description)
                metrics[epoch] = val_metrics
                ndcg_ = val_metrics['ndcg@10']
                ndcg[epoch] = ndcg_

                print(f'Epoch {epoch}, NDCG@10: {ndcg_}')

                if task and (epoch % 5 == 0):
                    log.report_scalar("Loss", series='Val', iteration=epoch, value=np.mean(losses[epoch]))
                    log.report_scalar("NDCG", series='Val', iteration=epoch, value=ndcg_)

                if ndcg_ > best_ndcg:
                    best_ndcg = ndcg_
                    torch.save(model.state_dict(), checkpt_path)
                    wait = 0
                elif wait < config['patience'] // config['skip_epochs'] + 1:
                    # Tolerate this many non-improving evaluations; stop on the next one.
                    wait += 1
                else:
                    break

        # Resource statistics are taken before the checkpoint is reloaded.
        torch.cuda.synchronize()
        training_time_sec = time() - start_time
        full_peak_training_memory_bytes = torch.cuda.max_memory_allocated()
        peak_training_memory_bytes = full_peak_training_memory_bytes - start_memory
        training_epoches = len(losses)
    finally:
        # Restore the best weights and clean up even if training was interrupted
        # (e.g. KeyboardInterrupt), so no checkpoint file is left behind.
        if os.path.exists(checkpt_path):
            model.load_state_dict(torch.load(checkpt_path))
            os.remove(checkpt_path)
        else:
            # Nothing was saved: NDCG@10 never exceeded 0 or no evaluation completed.
            print('No best checkpoint was saved; the model keeps its current weights.')

    print()
    print('Peak training memory, mb:', round(full_peak_training_memory_bytes/ 1024. / 1024., 2))
    print('Training epoches:', training_epoches)
    print('Training time, m:', round(training_time_sec/ 60., 2))

    if task and metrics:
        # Position of the best evaluation times the evaluation stride gives its epoch.
        ind_max = np.argmax(list(ndcg.values())) * config['skip_epochs']
        for metric_name, metric_value in metrics[ind_max].items():
            log.report_single_value(name=f'val_{metric_name}', value=round(metric_value, 4))
        log.report_single_value(name='train_peak_mem_mb', value=round(peak_training_memory_bytes/ 1024. / 1024., 2))
        log.report_single_value(name='full_train_peak_mem_mb', value=round(full_peak_training_memory_bytes/ 1024. / 1024., 2))
        log.report_single_value(name='train_epoches', value=training_epoches)
        log.report_single_value(name='train_time_m', value=round(training_time_sec/ 60., 2))

In [160]:
# Pre-train the SASRec body on the temporal split, validating on the held-out items.
pretrain(
    model=sasrec_model,
    config=sasrec_config,
    data_description=data_description_temp,
    testset_valid=testset_valid_temp,
    holdout_valid=holdout_valid_temp,
)

Epoch 0, NDCG@10: 0.020602853670914675
Epoch 1, NDCG@10: 0.022537926101137903
Epoch 2, NDCG@10: 0.03079701581977307
Epoch 3, NDCG@10: 0.05572657924072975
Epoch 4, NDCG@10: 0.0759097182715343
Epoch 5, NDCG@10: 0.0909473590097159
Epoch 6, NDCG@10: 0.10210885287687085
Epoch 7, NDCG@10: 0.10908023379450235
Epoch 8, NDCG@10: 0.1166471985471137
Epoch 9, NDCG@10: 0.12192767949789424
Epoch 10, NDCG@10: 0.12455379862998373
Epoch 11, NDCG@10: 0.1296894701894721
Epoch 12, NDCG@10: 0.13408232137587925
Epoch 13, NDCG@10: 0.13641051296376924
Epoch 14, NDCG@10: 0.13819931679851163
Epoch 15, NDCG@10: 0.13939482776629394
Epoch 16, NDCG@10: 0.14317104566636943
Epoch 17, NDCG@10: 0.14352014098851246
Epoch 18, NDCG@10: 0.1451760569770132


KeyboardInterrupt: 

In [11]:
# Switch the model's forward mode from the 'ce' value set in sasrec_config to 'embedding'.
# NOTE(review): presumably this makes the body return hidden states that feed the Q-heads
# instead of item logits — confirm against the SASRec implementation.
sasrec_model.fwd_type = 'embedding'

In [12]:

# State and action spaces share one size: the number of items plus two extra slots.
state_dim = action_dim = data_description_temp['n_items'] + 2

# Replay buffer built on top of the SASRec batch sampler.
replay_buffer = RecReplayBuffer(state_dim, action_dim, config.buffer_size, config.device, sampler)

max_action = 1.0

# When a checkpoint directory is configured, create it and store the run config there.
if config.checkpoints_path is not None:
    print(f"Checkpoints path: {config.checkpoints_path}")
    os.makedirs(config.checkpoints_path, exist_ok=True)
    config_dump_path = os.path.join(config.checkpoints_path, "config.yaml")
    with open(config_dump_path, "w") as f:
        pyrallis.dump(config, f)

# Re-apply the configured seed.
seed = config.seed
set_seed(config.seed)


# Input width of the Q-heads. They are applied to the output of the SASRec body
# (see `agent_model_scoring`), so the width is taken from the body configuration
# instead of repeating the literal 64; the two would otherwise drift apart as soon
# as `hidden_units` is changed.
# NOTE(review): assumes the body's state vector has width `hidden_units` — confirm.
q_input_dim = sasrec_config['hidden_units']

# Two identically configured Q-heads, each with its own Adam optimizer.
q_1 = FullyConnectedQFunction(
    q_input_dim,
    action_dim,
    config.orthogonal_init,
    config.q_n_hidden_layers,
).to(config.device)

q_2 = FullyConnectedQFunction(
    q_input_dim,
    action_dim,
    config.orthogonal_init,
    config.q_n_hidden_layers,
).to(config.device)

q_1_optimizer = torch.optim.Adam(list(q_1.parameters()), config.qf_lr)
q_2_optimizer = torch.optim.Adam(list(q_2.parameters()), config.qf_lr)

# Constructor arguments for the CQL trainer. Entries that are plain copies of a
# TrainConfig field are pulled by name; key order matches the explicit listing.
kwargs = {
    # networks and their optimizers
    "body": sasrec_model,
    "body_optimizer": optimizers,
    "q_1": q_1,
    "q_2": q_2,
    "q_1_optimizer": q_1_optimizer,
    "q_2_optimizer": q_2_optimizer,
    **{
        field: getattr(config, field)
        for field in ("discount", "soft_target_update_rate", "device")
    },
    # CQL
    "target_entropy": 1,
    **{
        field: getattr(config, field)
        for field in (
            "alpha_multiplier",
            "use_automatic_entropy_tuning",
            "backup_entropy",
            "policy_lr",
            "qf_lr",
            "bc_steps",
            "target_update_period",
            "cql_n_actions",
            "cql_importance_sample",
            "cql_lagrange",
            "cql_target_action_gap",
            "cql_temp",
            "cql_alpha",
            "cql_max_target_backup",
            "cql_clip_diff_min",
            "cql_clip_diff_max",
        )
    },
    # fixed here rather than read from the config
    "cql_negative_samples": 10,
}

trainer = DQNCQL(**kwargs)

In [13]:
# Run Python's garbage collector first, then call torch.cuda.empty_cache()
# before the agent training below starts.
gc.collect()
torch.cuda.empty_cache()

In [17]:
def train_agent_epoch():
    """Run one epoch of agent updates and return the mean of the step losses.

    Uses the notebook globals ``trainer``, ``sampler``, ``replay_buffer`` and
    ``config``. One epoch is ``len(sampler)`` updates; each update samples a
    batch from the replay buffer, moves it to ``config.device`` and passes it to
    ``trainer.train``. A running mean over the latest 100 losses is printed
    every 100 steps.
    """
    for module in (trainer.q_1, trainer.q_2, trainer.body):
        module.train()

    step_losses = []
    total_steps = len(sampler)
    for step in range(total_steps):
        sampled = replay_buffer.sample(config.batch_size)
        on_device = [part.to(config.device) for part in sampled]
        step_losses.append(trainer.train(on_device)['loss'])
        if step % 100 == 1:
            print(f"Iter {step} of {total_steps}. Train loss: ", np.mean(step_losses[-100:]))
    return np.mean(step_losses)

def agent_model_scoring(data, data_description, device):
    """Score every test sequence with the agent and return the stacked scores.

    Each sequence produced by ``data_to_sequences`` is passed through
    ``trainer.body.score_with_state``; the last element of its result is
    flattened to two dimensions and fed to both Q-heads, whose outputs are
    averaged.

    The Q-heads are taken from ``trainer`` (the same objects that are switched
    to eval mode here) rather than from the notebook globals ``q_1``/``q_2``,
    so re-running the cell that creates those globals cannot make scoring use
    untrained networks.

    Args:
        data: interactions to build the test sequences from.
        data_description: dataset description forwarded to ``data_to_sequences``.
        device: torch device the sequence tensors are created on.

    Returns:
        numpy array with the per-sequence predictions concatenated along axis 0.
    """
    trainer.q_1.eval()
    trainer.q_2.eval()
    trainer.body.eval()
    test_sequences = data_to_sequences(data, data_description)

    scores = []
    # Inference only: a single no_grad context covers the whole loop.
    with torch.no_grad():
        for _, seq in test_sequences.items():
            seq_tensor = torch.tensor(seq, device=device, dtype=torch.long)
            body_out = trainer.body.score_with_state(seq_tensor)[-1]
            body_out = body_out.reshape(-1, body_out.shape[-1])
            predictions = (trainer.q_1(body_out) + trainer.q_2(body_out)) / 2.0
            scores.append(predictions.cpu().numpy())
    return np.concatenate(scores, axis=0)

def train_agent(config, data_description, testset_valid, holdout_valid):
    """Train the CQL agent with early stopping on validation NDCG@10.

    Runs up to ``config['num_epochs']`` epochs of ``train_agent_epoch``. Every
    ``config['skip_epochs']`` epochs the agent is scored on ``testset_valid``
    and evaluated against ``holdout_valid``. Whenever NDCG@10 improves, the
    weights of ``trainer.body``, ``trainer.q_1`` and ``trainer.q_2`` are
    checkpointed. The best checkpoint is loaded back and the file is removed
    before the function exits, including when training is interrupted, so the
    trainer ends with the best validated weights rather than the last ones.

    Relies on the notebook globals ``trainer``, ``device``, ``task``, ``log``
    and on ``wandb`` for logging.

    Args:
        config: dict providing ``num_epochs``, ``skip_epochs`` and ``patience``
            (the SASRec config dict in this notebook, not the TrainConfig).
        data_description: dataset description forwarded to the scoring and
            evaluation helpers.
        testset_valid: validation interactions used for scoring.
        holdout_valid: held-out validation items used for evaluation.

    Returns:
        None. Progress and resource usage are printed and logged.
    """
    losses = {}
    metrics = {}
    ndcg = {}
    best_ndcg = 0
    wait = 0

    start_time = time()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start_memory = torch.cuda.memory_allocated()

    os.makedirs('./checkpt', exist_ok=True)
    checkpt_path = os.path.join('./checkpt', f'{uuid.uuid4().hex}.chkpt')

    try:
        for epoch in range(config['num_epochs']):
            losses[epoch] = train_agent_epoch()
            wandb.log({
                "train_loss": losses[epoch]
            }, step=trainer.total_it)
            if epoch % config['skip_epochs'] == 0:
                val_scores = agent_model_scoring(testset_valid, data_description, device)
                downvote_seen_items(val_scores, testset_valid, data_description)
                val_recs = topn_recommendations(val_scores, topn=10)
                val_metrics = model_evaluate(val_recs, holdout_valid, data_description)
                metrics[epoch] = val_metrics
                ndcg_ = val_metrics['ndcg@10']
                ndcg[epoch] = ndcg_

                print(f'Epoch {epoch}, NDCG@10: {ndcg_}')
                wandb.log({
                    "valid NDCG@10": ndcg_
                }, step=trainer.total_it)

                if task and (epoch % 5 == 0):
                    log.report_scalar("Loss", series='Val', iteration=epoch, value=np.mean(losses[epoch]))
                    log.report_scalar("NDCG", series='Val', iteration=epoch, value=ndcg_)

                if ndcg_ > best_ndcg:
                    best_ndcg = ndcg_
                    # Snapshot every network used for scoring.
                    torch.save(
                        {
                            'body': trainer.body.state_dict(),
                            'q_1': trainer.q_1.state_dict(),
                            'q_2': trainer.q_2.state_dict(),
                        },
                        checkpt_path,
                    )
                    wait = 0
                elif wait < config['patience'] // config['skip_epochs'] + 1:
                    # Tolerate this many non-improving evaluations; stop on the next one.
                    wait += 1
                else:
                    break

        # Resource statistics are taken before the checkpoint is reloaded.
        torch.cuda.synchronize()
        training_time_sec = time() - start_time
        full_peak_training_memory_bytes = torch.cuda.max_memory_allocated()
        peak_training_memory_bytes = full_peak_training_memory_bytes - start_memory
        training_epoches = len(losses)
    finally:
        # Restore the best weights and clean up even if training was interrupted.
        if os.path.exists(checkpt_path):
            best_state = torch.load(checkpt_path)
            trainer.body.load_state_dict(best_state['body'])
            trainer.q_1.load_state_dict(best_state['q_1'])
            trainer.q_2.load_state_dict(best_state['q_2'])
            os.remove(checkpt_path)
        else:
            # Nothing was saved: NDCG@10 never exceeded 0 or no evaluation completed.
            print('No best checkpoint was saved; the agent keeps its current weights.')

    print()
    print('Peak training memory, mb:', round(full_peak_training_memory_bytes/ 1024. / 1024., 2))
    print('Training epoches:', training_epoches)
    print('Training time, m:', round(training_time_sec/ 60., 2))

    if task and metrics:
        # Position of the best evaluation times the evaluation stride gives its epoch.
        ind_max = np.argmax(list(ndcg.values())) * config['skip_epochs']
        for metric_name, metric_value in metrics[ind_max].items():
            log.report_single_value(name=f'val_{metric_name}', value=round(metric_value, 4))
        log.report_single_value(name='train_peak_mem_mb', value=round(peak_training_memory_bytes/ 1024. / 1024., 2))
        log.report_single_value(name='full_train_peak_mem_mb', value=round(full_peak_training_memory_bytes/ 1024. / 1024., 2))
        log.report_single_value(name='train_epoches', value=training_epoches)
        log.report_single_value(name='train_time_m', value=round(training_time_sec/ 60., 2))

In [19]:
# Train the CQL agent using the epoch / patience settings of the SASRec config.
train_agent(
    config=sasrec_config,
    data_description=data_description_temp,
    testset_valid=testset_valid_temp,
    holdout_valid=holdout_valid_temp,
)

100

In [171]:
alpha = 1.0, lr = 3e-4, pretrained
Epoch 3, NDCG@10: 0.09942672790465602


torch.Size([128, 100, 1])

In [None]:
alpha = 100.0, lr = 3e-4, pretrained
Epoch 23, NDCG@10: 0.1265350095744121


In [None]:
# NOTE(review): no `df` is defined in the visible cells, so this looks like a leftover
# scratch cell that would raise NameError on a fresh kernel — confirm and remove.
df