In [132]:
pip install pytorch-ranger

[0mNote: you may need to restart the kernel to use updated packages.


In [133]:
import math
import random

import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.tensorboard import SummaryWriter


In [134]:
from collections import defaultdict
import os
import pickle
import random
import requests
import time
import tqdm

from IPython.core.debugger import set_trace
import numpy as np
import pandas as pd
from pytorch_ranger import Ranger
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.utils.data as td
from torch.utils.tensorboard import SummaryWriter

In [135]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

**USE CUDA**

In [136]:
use_cuda = torch.cuda.is_available()
device   = torch.device("cuda:0" if use_cuda else "cpu")


**ULTIL**

In [137]:
from collections import defaultdict
import os
import random
import time
import tqdm

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch.utils.data as td


class EvalDataset(td.Dataset):
    def __init__(self, positive_data, item_num, positive_mat, negative_samples=99):
        super(EvalDataset, self).__init__()
        self.positive_data = np.array(positive_data)
        self.item_num = item_num
        self.positive_mat = positive_mat
        self.negative_samples = negative_samples
        
        self.reset()
        
    def reset(self):
        print("Resetting dataset")
        data = self.create_valid_data()
        labels = np.zeros(len(self.positive_data) * (1 + self.negative_samples))
        labels[::1+self.negative_samples] = 1
        self.data = np.concatenate([
            np.array(data), 
            np.array(labels)[:, np.newaxis]], 
            axis=1
        )

    def create_valid_data(self):
        valid_data = []
        for user, positive in self.positive_data:
            valid_data.append([user, positive])
            for i in range(self.negative_samples):
                negative = np.random.randint(self.item_num)
                while (user, negative) in self.positive_mat:
                    negative = np.random.randint(self.item_num)
                    
                valid_data.append([user, negative])
        return valid_data
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        user, item, label = self.data[idx]
        output = {
            "user": user,
            "item": item,
            "label": np.float32(label),
        }
        return output


#https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/ou_strategy.py
class OUNoise(object):
    def __init__(self, action_dim, mu=0.0, theta=0.15, max_sigma=0.4, min_sigma=0.4, decay_period=100000):
        self.mu           = mu
        self.theta        = theta
        self.sigma        = max_sigma
        self.max_sigma    = max_sigma
        self.min_sigma    = min_sigma
        self.decay_period = decay_period
        self.action_dim   = action_dim
        self.reset()

    def reset(self):
        self.state = np.ones(self.action_dim) * self.mu

    def evolve_state(self):
        x  = self.state
        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim)
        self.state = x + dx
        return self.state

    def get_action(self, action, t=0):
        ou_state = self.evolve_state()
        self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(1.0, t / self.decay_period)
        return torch.tensor([action + ou_state]).float()


class Prioritized_Buffer(object):
    def __init__(self, capacity, prob_alpha=0.6):
        self.prob_alpha = prob_alpha
        self.capacity   = capacity
        self.buffer     = []
        self.pos        = 0
        self.priorities = np.zeros((capacity,), dtype=np.float32)
    
    def push(self, user, memory, action, reward, next_user, next_memory, done):
        max_prio = self.priorities.max() if self.buffer else 1.0
        
        if len(self.buffer) < self.capacity:
            self.buffer.append((user, memory, action, reward, next_user, next_memory, done))
        else:
            self.buffer[self.pos] = (user, memory, action, reward, next_user, next_memory, done)
        
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity
    
    def sample(self, batch_size, beta=0.4):
        if len(self.buffer) == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[:self.pos]
        
        probs  = prios ** self.prob_alpha
        probs /= probs.sum()
        
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]

        total    = len(self.buffer)
        weights  = (total * probs[indices]) ** (-beta)
        weights /= weights.max()
        weights  = np.array(weights, dtype=np.float32)

        batch       = list(zip(*samples))
        user        = np.concatenate(batch[0])
        memory      = np.concatenate(batch[1])
        action      = batch[2]
        reward      = batch[3]
        next_user   = np.concatenate(batch[4])
        next_memory = np.concatenate(batch[5])
        done        = batch[6]

        return user, memory, action, reward, next_user, next_memory, done

    def update_priorities(self, batch_indices, batch_priorities):
        for idx, prio in zip(batch_indices, batch_priorities):
            self.priorities[idx] = prio

    def __len__(self):
        return len(self.buffer)


def get_beta(idx, beta_start=0.4, beta_steps=100000):
    return min(1.0, beta_start + idx * (1.0 - beta_start) / beta_steps)

def preprocess_data(data_dir, train_rating):
    data = pd.read_csv(os.path.join(data_dir, train_rating), 
                       sep='\t', header=None, names=['user', 'item', 'rating'], 
                       usecols=[0, 1, 2], dtype={0: np.int32, 1: np.int32, 2: np.int8})
    data = data[data['rating'] > 3][['user', 'item']]
    user_num = data['user'].max() + 1
    item_num = data['item'].max() + 1

    train_data = data.sample(frac=0.8, random_state=16)
    test_data = data.drop(train_data.index).values.tolist()
    train_data = train_data.values.tolist()

    train_mat = defaultdict(int)
    test_mat = defaultdict(int)
    for user, item in train_data:
        train_mat[user, item] = 1.0
    for user, item in test_data:
        test_mat[user, item] = 1.0
    train_matrix = sp.dok_matrix((user_num, item_num), dtype=np.float32)
    dict.update(train_matrix, train_mat)
    test_matrix = sp.dok_matrix((user_num, item_num), dtype=np.float32)
    dict.update(test_matrix, test_mat)
    
    appropriate_users = np.arange(user_num).reshape(-1, 1)[(train_matrix.sum(1) >= 20)]
    
    return (train_data, train_matrix, test_data, test_matrix, 
            user_num, item_num, appropriate_users)

def to_np(tensor):
    return tensor.detach().cpu().numpy()

def hit_metric(recommended, actual):
    return int(actual in recommended)

def dcg_metric(recommended, actual):
    if actual in recommended:
        index = recommended.index(actual)
        return np.reciprocal(np.log2(index + 2))
    return 0

In [138]:
data_dir = "data"
rating = "ml-1m.train.rating"

params = {
    'batch_size': 512,
    'embedding_dim': 8,
    'hidden_dim': 16,
    'N': 5, # memory size for state_repr
    'ou_noise':False,
    
    'value_lr': 1e-5,
    'value_decay': 1e-4,
    'policy_lr': 1e-5,
    'policy_decay': 1e-6,
    'state_repr_lr': 1e-5,
    'state_repr_decay': 1e-3,
    'log_dir': 'logs/final/',
    'gamma': 0.8,
    'min_value': -10,
    'max_value': 10,
    'soft_tau': 1e-3,
    
    'buffer_size': 1000000
}


In [139]:
# Movielens (1M) data from the https://github.com/hexiangnan/neural_collaborative_filtering
if not os.path.isdir('./data'):
    os.mkdir('./data')
    
file_path = os.path.join(data_dir, rating)
if os.path.exists(file_path):
    print("Skip loading " + file_path)
else:
    with open(file_path, "wb") as tf:
        print("Load " + file_path)
        r = requests.get("https://raw.githubusercontent.com/hexiangnan/neural_collaborative_filtering/master/Data/" + rating)
        tf.write(r.content)
(train_data, train_matrix, test_data, test_matrix, 
 user_num, item_num, appropriate_users) = preprocess_data(data_dir, rating)

Skip loading data/ml-1m.train.rating


In [140]:
class Env():
    def __init__(self, user_item_matrix):
        self.matrix = user_item_matrix
        self.item_count = item_num
        self.memory = np.ones([user_num, params['N']]) * item_num
        # memory is initialized as [item_num] * N for each user
        # it is padding indexes in state_repr and will result in zero embeddings

    def reset(self, user_id):
        self.user_id = user_id
        self.viewed_items = []
        self.related_items = np.argwhere(self.matrix[self.user_id] > 0)[:, 1]
        self.num_rele = len(self.related_items)
        self.nonrelated_items = np.random.choice(
            list(set(range(self.item_count)) - set(self.related_items)), self.num_rele)
        self.available_items = np.zeros(self.num_rele * 2)
        self.available_items[::2] = self.related_items
        self.available_items[1::2] = self.nonrelated_items
        
        return torch.tensor([self.user_id], device=device), torch.tensor(self.memory[[self.user_id], :], device=device)
    
    def step(self, action, action_emb=None, buffer=None):
        initial_user = self.user_id
        initial_memory = self.memory[[initial_user], :]
        
        reward = float(to_np(action)[0] in self.related_items)
        self.viewed_items.append(to_np(action)[0])

        if action.is_cuda:
            Action = action.cpu()
        else:
            Action = action
        
        if reward:
            if len(Action) == 1:
                self.memory[self.user_id] = list(self.memory[self.user_id][1:]) + [Action]
            else:
                self.memory[self.user_id] = list(self.memory[self.user_id][1:]) + [Action[0]]
                
        if len(self.viewed_items) == len(self.related_items):
            done = 1
        else:
            done = 0
            
        if buffer is not None:
            buffer.push(np.array([initial_user]), np.array(initial_memory), to_np(action_emb)[0], 
                        np.array([reward]), np.array([self.user_id]), self.memory[[self.user_id], :], np.array([reward]))

        return torch.tensor([self.user_id], device=device), torch.tensor(self.memory[[self.user_id], :], device=device), reward, done

In [141]:
class State_Repr_Module(nn.Module):
    def __init__(self, user_num, item_num, embedding_dim, hidden_dim):
        super().__init__()
        self.user_embeddings = nn.Embedding(user_num, embedding_dim)
        self.item_embeddings = nn.Embedding(item_num+1, embedding_dim, padding_idx=int(item_num))
        self.drr_ave = torch.nn.Conv1d(in_channels=params['N'], out_channels=1, kernel_size=1)
        
        self.initialize()
        self.to(device)
            
    def initialize(self):
        nn.init.normal_(self.user_embeddings.weight, std=0.01)
        nn.init.normal_(self.item_embeddings.weight, std=0.01)
        self.item_embeddings.weight.data[-1].zero_()
        nn.init.uniform_(self.drr_ave.weight)
        self.drr_ave.bias.data.zero_()

    def forward(self, user, memory):
        user = user.to(device)
        memory = memory.to(device)
        user_embedding = self.user_embeddings(user.long())
        

        item_embeddings = self.item_embeddings(memory.long())
        drr_ave = self.drr_ave(item_embeddings).squeeze(1)
        
        return torch.cat((user_embedding, user_embedding * drr_ave, drr_ave), 1).to(device)


In [142]:
class ValueNetwork(nn.Module):
    def __init__(self, state_repr_dim, action_emb_dim, hidden_dim):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(state_repr_dim + action_emb_dim, hidden_dim), 
            nn.ReLU(), 
            nn.Linear(hidden_dim, 1)
        )

        self.initialize()
        
    def initialize(self):
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight)
        
    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = self.layers(x)
        return x


In [143]:
class PolicyNetwork(nn.Module):
    def __init__(self, embedding_dim, hidden_dim):
        super().__init__()
    
        self.layers = nn.Sequential(
            nn.Linear(embedding_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )
        
        self.initialize()
        self.to(device)

    def initialize(self):
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight)
        
    def evaluate(self, action_emb, epsilon=1e-6):
                
        mean, log_std = torch.chunk(action_emb.float(), 2, dim=-1)
        
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.sample()

        act = torch.tanh(z)
        
        log_prob = normal.log_prob(z) - torch.log(1 - act.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)

        return action_emb, log_prob, z, mean, log_std
    

    def forward(self, state):
        return self.layers(state)
    
    def get_action(self, user, memory, state_repr, 
                   action_emb,
                   items=torch.tensor([i for i in range(item_num)]).to(device),
                   return_scores=False
                  ):
        state = state_repr(user, memory).to(device)
        items = items.to(device)
        scores = torch.bmm(state_repr.item_embeddings(items).unsqueeze(0), 
                         action_emb.T.unsqueeze(0)).squeeze(0).to(device)
        if return_scores:
            return scores, torch.gather(items, 0, scores.argmax(0)).to(device)
        else:
            return torch.gather(items, 0, scores.argmax(0)).to(device)

In [144]:
valid_dataset = EvalDataset(
    np.array(test_data)[np.array(test_data)[:, 0] == 6039], 
    item_num, 
    test_matrix)
valid_loader = td.DataLoader(valid_dataset, batch_size=100, shuffle=False)

full_dataset = EvalDataset(np.array(test_data), item_num, test_matrix)
full_loader = td.DataLoader(full_dataset, batch_size=100, shuffle=False)

Resetting dataset
Resetting dataset


In [145]:
def run_evaluation(net, state_representation, training_env_memory, loader=valid_loader):
    hits = []
    dcgs = []
    test_env = Env(test_matrix)
    test_env.memory = training_env_memory.copy()
    user, memory = test_env.reset(int(to_np(next(iter(valid_loader))['user'])[0]))
    user.to(device)
    memory.to(device)
    for batch in loader:
        action_emb = net(state_repr(user, memory)).to(device)
        scores, action = net.get_action(
            batch['user'], 
            torch.tensor(test_env.memory[to_np(batch['user']).astype(int), :]).to(device), 
            state_representation, 
            action_emb,
            batch['item'].long(), 
            return_scores=True
        )
        user, memory, reward, done = test_env.step(action)

        
        _, ind = scores[:, 0].topk(10)
        predictions = torch.take(batch['item'].to(device), ind).cpu().numpy().tolist()
        actual = batch['item'][0].item()
        hits.append(hit_metric(predictions, actual))
        dcgs.append(dcg_metric(predictions, actual))
        
    return np.mean(hits), np.mean(dcgs)

In [146]:
train_env = Env(train_matrix)


state_repr = State_Repr_Module(user_num, item_num, params['embedding_dim'] , params['hidden_dim']).to(device)
value_net  = ValueNetwork(params['embedding_dim'] * 3, params['embedding_dim'], params['hidden_dim']).to(device)
target_value_net = ValueNetwork(params['embedding_dim'] * 3, params['embedding_dim'], params['hidden_dim']).to(device)


policy_net = PolicyNetwork(params['embedding_dim'], params['hidden_dim']).to(device)

for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):
    target_param.data.copy_(param.data)
    

value_criterion  = nn.MSELoss()
soft_q_criterion = nn.MSELoss()

value_lr  = 3e-4
soft_q_lr = 3e-4
policy_lr = 3e-4

# value_optimizer  = optim.Adam(value_net.parameters(), lr=value_lr)
# policy_optimizer = optim.Adam(policy_net.parameters(), lr=policy_lr)


replay_buffer_size = params['buffer_size']
replay_buffer = Prioritized_Buffer(replay_buffer_size)
writer = SummaryWriter(log_dir=params['log_dir'])

In [147]:
optimizer_value  = Ranger(value_net.parameters(),  lr=params['value_lr'], 
                          weight_decay=params['value_decay'])
optimizer_policy = Ranger(policy_net.parameters(), lr=params['policy_lr'], 
                          weight_decay=params['policy_decay'])
state_repr_optimizer = Ranger(state_repr.parameters(), lr=params['state_repr_lr'], 
                              weight_decay=params['state_repr_decay'])

In [148]:
def ppo_update(train_env, policy_net, value_net, state_repr, optimizer_policy, optimizer_value, replay_buffer, clip_param=0.1, ppo_epoch=4, mini_batch_size=32, value_coef=0.5, entropy_coef=0.01):
    # Sample data from the replay buffer
    user, memory, action, reward, next_user, next_memory, done = replay_buffer.sample(mini_batch_size)

    # Convert data to tensors and move to the device
    user = torch.tensor(user, dtype=torch.long).to(device)
    memory = torch.tensor(memory, dtype=torch.float32).to(device)
    action = torch.tensor(action, dtype=torch.float32).to(device)
    reward = torch.tensor(reward, dtype=torch.float32).to(device)
    next_user = torch.tensor(next_user, dtype=torch.long).to(device)
    next_memory = torch.tensor(next_memory, dtype=torch.float32).to(device)
    done = torch.tensor(done, dtype=torch.float32).to(device)

    # Compute the current state and value predictions for the sampled data
    state = state_repr(user, memory)
    value = value_net(state, policy_net(state))

    # Compute the next state and value predictions for the sampled data
    with torch.no_grad():
        next_state = state_repr(next_user, next_memory)
        next_value = value_net(next_state, policy_net(next_state))
        mask = 1 - done

        # Compute the target value
        target_value = reward + mask * next_value

    # Compute the advantage
    advantage = target_value - value

    # Create mini-batches for training
    batch_size = user.shape[0]
    num_samples = user.shape[0]
    indices = torch.randperm(num_samples)
    mini_batch_count = num_samples // mini_batch_size

    for batch in range(mini_batch_count):
        # Compute the actions and log probabilities for the sampled data
        batch_indices = indices[batch * mini_batch_size : (batch + 1) * mini_batch_size]
        sampled_state = state[batch_indices]
        sampled_memory = memory[batch_indices]
        sampled_action = action[batch_indices]
        sampled_old_log_prob, _, _, _, _ = policy_net.evaluate(sampled_action)

        new_log_prob, _, _, _, _ = policy_net.evaluate(sampled_action)
        ratio = (new_log_prob - sampled_old_log_prob).exp()

        # Compute the clipped surrogate loss
        policy_loss1 = ratio * advantage[indices]
        policy_loss2 = torch.clamp(ratio, 1 - clip_param, 1 + clip_param) * advantage[indices]
        policy_loss = -torch.min(policy_loss1, policy_loss2).mean()

        # Compute the value loss
        sampled_target_value = target_value[indices]
        sampled_value = value[indices]
        clipped_value = sampled_value + torch.clamp(sampled_target_value - sampled_value, -clip_param, clip_param)
        value_loss1 = (sampled_value - sampled_target_value).pow(2)
        value_loss2 = (clipped_value - sampled_target_value).pow(2)
        value_loss = 0.5 * torch.max(value_loss1, value_loss2).mean()

        # Compute the entropy loss
        entropy_loss = -(new_log_prob.exp() * new_log_prob).mean()

        # Compute the total loss
        loss = policy_loss + value_coef * value_loss - entropy_coef * entropy_loss
        
        
        state_repr_optimizer.zero_grad()
        
        # Update the policy network
        optimizer_policy.zero_grad()
        loss.backward(retain_graph=True)
        optimizer_policy.step()

        # Update the value network
        optimizer_value.zero_grad()
        value_loss.backward(retain_graph=True)
        optimizer_value.step()

        state_repr_optimizer.step()


In [None]:
train_env = Env(train_matrix)
hits, dcgs = [], []
hits_all, dcgs_all = [], []
step, best_step, best_step_all = 0, 0, 0
users = np.random.permutation(appropriate_users)
ou_noise = OUNoise(params['embedding_dim'], decay_period=10)

for u in tqdm.tqdm(users):
    user, memory = train_env.reset(u)
    if params['ou_noise']:
        ou_noise.reset()
    for t in range(int(train_matrix[u].sum())):
        state = state_repr(user, memory)
        action_emb_old = policy_net(state)
        if params['ou_noise']:
            action_emb_old = ou_noise.get_action(action_emb_old.detach().cpu().numpy()[0], t)
        action_old = policy_net.get_action(
            user,
            torch.tensor(train_env.memory[to_np(user).astype(int), :]),
            state_repr,
            action_emb_old,
            torch.tensor(
                [item for item in train_env.available_items
                if item not in train_env.viewed_items]
            ).long()
        )
        user, memory, reward, done = train_env.step(
            action_old,
            action_emb_old,
            buffer=replay_buffer
        )

        if len(replay_buffer) > params['batch_size']:
            ppo_update(train_env, policy_net, value_net, state_repr, policy_optimizer, value_optimizer, replay_buffer)

        if step % 100 == 0 and step > 0:
            hit, dcg = run_evaluation(policy_net, state_repr, train_env.memory)
            writer.add_scalar('hit', hit, step)
            writer.add_scalar('dcg', dcg, step)
            hits.append(hit)
            dcgs.append(dcg)
#             print(f'Hit rate: {hit}, dcg: {dcg}')
            if np.mean(np.array([hit, dcg]) - np.array([hits[best_step], dcgs[best_step]])) > 0:
                best_step = step // 100
                torch.save(policy_net.state_dict(), params['log_dir'] + 'policy_net.pth')
                torch.save(value_net.state_dict(), params['log_dir'] + 'value_net.pth')
                torch.save(state_repr.state_dict(), params['log_dir'] + 'state_repr.pth')
        if step % 10000 == 0 and step > 0:
            hit, dcg = run_evaluation(policy_net, state_repr, train_env.memory, full_loader)
            print(f'Hit {hit}, DCG {dcg}')
            writer.add_scalar('hit_all', hit, step)
            writer.add_scalar('dcg_all', dcg, step)
            hits_all.append(hit)
            dcgs_all.append(dcg)
            if np.mean(np.array([hit, dcg]) - np.array([hits_all[best_step_all], dcgs_all[best_step_all]])) > 0:
                best_step_all = step // 10000
                torch.save(policy_net.state_dict(), params['log_dir'] + 'best_policy_net.pth')
                torch.save(value_net.state_dict(), params['log_dir'] + 'best_value_net.pth')
                torch.save(state_repr.state_dict(), params['log_dir'] + 'best_state_repr.pth')
        step += 1


  2%|▏         | 114/4699 [04:23<57:27:22, 45.11s/it]

Hit 0.04381630386300627, DCG 0.017724674975541557


  5%|▍         | 225/4699 [06:20<2:25:18,  1.95s/it] 

Hit 0.019258183853560838, DCG 0.007700193490892519


  7%|▋         | 324/4699 [11:26<1:30:09,  1.24s/it] 

Hit 0.011334516927436353, DCG 0.004658194546915398


  9%|▉         | 438/4699 [16:09<1:40:49,  1.42s/it] 

Hit 0.008500887695577265, DCG 0.0035258384099512343


 12%|█▏        | 551/4699 [23:38<52:27:58, 45.53s/it]

Hit 0.006611801541004539, DCG 0.0028333169801887394


 14%|█▍        | 652/4699 [25:53<1:02:33,  1.08it/s] 

Hit 0.005571055002142713, DCG 0.002409691092017178


 16%|█▋        | 773/4699 [30:51<58:09,  1.12it/s]   

Hit 0.004958851155753404, DCG 0.0021033038548802204


 19%|█▉        | 894/4699 [37:01<45:38,  1.39it/s]   

Hit 0.004425359232471292, DCG 0.0018777882125966604


 21%|██▏       | 1001/4699 [41:46<54:35,  1.13it/s]  

Hit 0.003909358847657446, DCG 0.0016537260794509392


 24%|██▎       | 1107/4699 [46:38<1:12:36,  1.21s/it] 

Hit 0.003655731539867589, DCG 0.0015154499484454317


 25%|██▌       | 1193/4699 [55:28<45:21:54, 46.58s/it]

Hit 0.0032971550012681366, DCG 0.0013594984497872543


 27%|██▋       | 1282/4699 [58:14<1:18:25,  1.38s/it] 

Hit 0.0028773580780297532, DCG 0.0011926457052620916


 29%|██▉       | 1381/4699 [1:03:26<1:11:44,  1.30s/it] 

Hit 0.0025887476933033646, DCG 0.001098686101630114


 31%|███       | 1467/4699 [1:09:09<59:45,  1.11s/it]   

Hit 0.0024400696163231037, DCG 0.0010357510293893075


 33%|███▎      | 1571/4699 [1:16:54<40:30:14, 46.62s/it]

Hit 0.0024400696163231037, DCG 0.0010291329393013538


 36%|███▌      | 1696/4699 [1:19:33<38:44,  1.29it/s]   

Hit 0.0023263746162793747, DCG 0.0009813367773389063


 38%|███▊      | 1792/4699 [1:24:56<44:04,  1.10it/s]   

Hit 0.002265154231640444, DCG 0.0009492803091702799


 41%|████      | 1911/4699 [1:30:13<1:21:21,  1.75s/it] 

Hit 0.002221425385469779, DCG 0.0009201580974813043


 43%|████▎     | 2010/4699 [1:35:49<3:24:55,  4.57s/it] 

Hit 0.00210773038542605, DCG 0.000880694578888594


 45%|████▌     | 2121/4699 [1:41:30<2:17:25,  3.20s/it] 

Hit 0.0020640015392553853, DCG 0.0008609345359482524


 47%|████▋     | 2227/4699 [1:50:56<38:57:31, 56.74s/it]

Hit 0.002142713462362582, DCG 0.0008738356589754199


 50%|█████     | 2358/4699 [1:54:09<20:49,  1.87it/s]   

Hit 0.002160205000830848, DCG 0.0008698665600078085


 52%|█████▏    | 2459/4699 [2:00:22<56:07,  1.50s/it]   

Hit 0.002247662693172178, DCG 0.0008870459323876657


 54%|█████▍    | 2560/4699 [2:06:25<1:54:44,  3.22s/it] 

Hit 0.0023701034624500398, DCG 0.0009179746489783227


 57%|█████▋    | 2679/4699 [2:12:01<46:48,  1.39s/it]   

Hit 0.0026412223087081623, DCG 0.0009813594386858283


 59%|█████▉    | 2790/4699 [2:17:53<53:34,  1.68s/it]   

Hit 0.0030697650011806787, DCG 0.0011004351580624688


 62%|██████▏   | 2899/4699 [2:23:45<51:41,  1.72s/it]   

Hit 0.0038481384630185147, DCG 0.0013296797135272174


 64%|██████▍   | 3015/4699 [2:29:43<35:26,  1.26s/it]   

Hit 0.004434105001705425, DCG 0.0015089836129924528


 66%|██████▌   | 3111/4699 [2:36:28<27:22,  1.03s/it]   

Hit 0.005256207309713926, DCG 0.0017466282875363947


 68%|██████▊   | 3206/4699 [2:42:24<1:06:08,  2.66s/it] 

Hit 0.006113292694658959, DCG 0.002001679685851996


 70%|███████   | 3293/4699 [2:48:53<41:31,  1.77s/it]   

Hit 0.01610096116003883, DCG 0.005095897644907555


 73%|███████▎  | 3409/4699 [2:55:01<36:32,  1.70s/it]   

Hit 0.019126997315048844, DCG 0.0060385471200033805


 75%|███████▌  | 3530/4699 [3:01:22<16:56,  1.15it/s]   

Hit 0.02312381385504762, DCG 0.007315381305519839


 77%|███████▋  | 3629/4699 [3:07:39<37:00,  2.07s/it]   

Hit 0.02690198616419307, DCG 0.008520789553227895


 80%|███████▉  | 3740/4699 [3:14:38<24:30,  1.53s/it]   

Hit 0.03072388731950919, DCG 0.009758634953284766


 82%|████████▏ | 3846/4699 [3:21:16<46:21,  3.26s/it]   

Hit 0.03446707655171811, DCG 0.01100225402278662


 84%|████████▍ | 3957/4699 [3:27:39<20:10,  1.63s/it]   

Hit 0.038009113091541966, DCG 0.012187708590194723


 87%|████████▋ | 4068/4699 [3:34:10<42:43,  4.06s/it]  

Hit 0.04016057232313868, DCG 0.012946955785816845


 87%|████████▋ | 4109/4699 [3:38:22<09:52,  1.00s/it]  

In [None]:
import matplotlib.pyplot as plt
plt.plot(steps, dcgs, label='nDCG')
plt.plot(steps, hits, label='Hit rates')


plt.xlabel('Step')
plt.title('Hit Rates and nDCG in Each Step')

plt.legend()
plt.show()

In [None]:
np.max(hits_all)

In [None]:
np.max(dcgs_all)

In [None]:
torch.save(policy_net.state_dict(), params['log_dir'] + 'policy_net_final.pth')
torch.save(value_net.state_dict(), params['log_dir'] + 'value_net_final.pth')
torch.save(state_repr_sac.state_dict(), params['log_dir'] + 'state_repr_final.pth')

In [None]:
# we need memory for validation, so it's better to save it and not wait next time 
with open('logs/memory.pickle', 'wb') as f:
    pickle.dump(train_env.memory, f)
    
with open('logs/memory.pickle', 'rb') as f:
    memory = pickle.load(f)

In [None]:
no_ou_state_repr = State_Repr_Module(user_num, item_num, params['embedding_dim'], params['hidden_dim'])
no_ou_policy_net = PolicyNetwork(params['embedding_dim'], params['hidden_dim'])
no_ou_state_repr.load_state_dict(torch.load('logs/final/' + 'best_state_repr.pth'))
no_ou_policy_net.load_state_dict(torch.load('logs/final/' + 'best_policy_net.pth'))
    
hit, dcg = run_evaluation(no_ou_policy_net, no_ou_state_repr, memory, full_loader)
print('hit rate: ', hit, 'dcg: ', dcg)