In [None]:
pip install pytorch-ranger

In [None]:
import math
import random

import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.tensorboard import SummaryWriter


In [None]:
from collections import defaultdict
import os
import pickle
import random
import requests
import time
import tqdm

from IPython.core.debugger import set_trace
import numpy as np
import pandas as pd
from pytorch_ranger import Ranger
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.utils.data as td
from torch.utils.tensorboard import SummaryWriter

In [None]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

**USE CUDA**

In [None]:
# Probe for a CUDA device. The result is stored in `use_cuda` but is not used to
# choose the device below.
use_cuda = torch.cuda.is_available()
# device   = torch.device("cuda:0" if use_cuda else "cpu")
# The notebook is pinned to the CPU; swap in the line above to use a GPU when available.
device = 'cpu'

**UTIL**

In [None]:
from collections import defaultdict
import os
import random
import time
import tqdm

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch.utils.data as td


class EvalDataset(td.Dataset):
    """Ranking-evaluation dataset: every positive pair is followed by sampled negatives.

    Rows of ``self.data`` are ``[user, item, label]``. For each positive
    ``(user, item)`` pair there is one row labelled 1 followed by
    ``negative_samples`` rows labelled 0 for the same user.

    Args:
        positive_data: sequence of ``(user, item)`` pairs.
        item_num: exclusive upper bound for sampled negative item ids.
        positive_mat: container supporting ``(user, item) in positive_mat``;
            pairs found in it are never used as negatives.
        negative_samples: number of negatives drawn per positive pair.
    """

    def __init__(self, positive_data, item_num, positive_mat, negative_samples=99):
        super().__init__()
        self.negative_samples = negative_samples
        self.positive_mat = positive_mat
        self.item_num = item_num
        self.positive_data = np.array(positive_data)

        self.reset()

    def reset(self):
        """Re-sample the negatives and rebuild ``self.data``."""
        print("Resetting dataset")
        pairs = np.array(self.create_valid_data())
        group_size = 1 + self.negative_samples
        labels = np.zeros(len(self.positive_data) * group_size)
        # The first row of each group is the positive example.
        labels[::group_size] = 1
        self.data = np.concatenate([pairs, labels.reshape(-1, 1)], axis=1)

    def create_valid_data(self):
        """Return the list of ``[user, item]`` rows (positive first, then its negatives)."""
        rows = []
        for user, positive in self.positive_data:
            rows.append([user, positive])
            remaining = self.negative_samples
            while remaining > 0:
                candidate = np.random.randint(self.item_num)
                if (user, candidate) in self.positive_mat:
                    # Known positive: draw again.
                    continue
                rows.append([user, candidate])
                remaining -= 1
        return rows

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        user, item, label = self.data[idx]
        return {"user": user, "item": item, "label": np.float32(label)}


#https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/ou_strategy.py
class OUNoise(object):
    """Ornstein-Uhlenbeck style noise process added to an action vector.

    Args:
        action_dim: length of the noise vector.
        mu: value the state is reset to and pulled towards.
        theta: strength of the pull towards ``mu``.
        max_sigma: noise scale at ``t = 0``.
        min_sigma: noise scale once ``t >= decay_period``.
        decay_period: number of steps over which sigma moves from max to min.
    """

    def __init__(self, action_dim, mu=0.0, theta=0.15, max_sigma=0.4, min_sigma=0.4, decay_period=100000):
        self.action_dim = action_dim
        self.mu = mu
        self.theta = theta
        self.max_sigma = max_sigma
        self.min_sigma = min_sigma
        self.sigma = max_sigma
        self.decay_period = decay_period
        self.reset()

    def reset(self):
        """Put the internal state back to ``mu`` in every dimension."""
        self.state = self.mu * np.ones(self.action_dim)

    def evolve_state(self):
        """Advance the process by one step and return the new state."""
        drift = self.theta * (self.mu - self.state)
        diffusion = self.sigma * np.random.randn(self.action_dim)
        self.state = self.state + (drift + diffusion)
        return self.state

    def get_action(self, action, t=0):
        """Return ``action`` plus noise as a float tensor with a leading batch axis.

        The state is evolved with the current sigma first; sigma is then
        updated for step ``t`` and applies to the next call.
        """
        ou_state = self.evolve_state()
        progress = min(1.0, t / self.decay_period)
        self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * progress
        return torch.tensor([action + ou_state]).float()


class Prioritized_Buffer(object):
    """Fixed-capacity replay buffer that samples transitions by priority.

    New transitions get the current maximum priority (1.0 for the first one).
    ``sample`` keeps its original 7-tuple return value; the indices it drew and
    the matching importance weights are additionally stored on the instance as
    ``last_indices`` / ``last_weights`` so a caller can pass them to
    ``update_priorities``.

    Args:
        capacity: maximum number of stored transitions (oldest are overwritten).
        prob_alpha: exponent applied to priorities before normalising them
            into sampling probabilities.
    """

    def __init__(self, capacity, prob_alpha=0.6):
        self.prob_alpha = prob_alpha
        self.capacity   = capacity
        self.buffer     = []
        self.pos        = 0
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        # Filled in by the most recent call to ``sample``.
        self.last_indices = None
        self.last_weights = None

    def push(self, user, memory, action, reward, next_user, next_memory, done):
        """Store one transition, overwriting the oldest slot once the buffer is full."""
        max_prio = self.priorities.max() if self.buffer else 1.0
        transition = (user, memory, action, reward, next_user, next_memory, done)

        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.pos] = transition

        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        """Draw ``batch_size`` transitions (with replacement) by priority.

        Args:
            batch_size: number of transitions to draw.
            beta: exponent of the importance-sampling correction used for
                ``last_weights``.

        Returns:
            ``(user, memory, action, reward, next_user, next_memory, done)``;
            ``user``, ``memory``, ``next_user`` and ``next_memory`` are
            concatenated arrays, the remaining fields are tuples.

        Raises:
            ValueError: if the buffer is empty.
        """
        total = len(self.buffer)
        if total == 0:
            raise ValueError("cannot sample from an empty replay buffer")

        # Only the filled slots carry meaningful priorities.
        prios = self.priorities[:total]
        probs  = prios ** self.prob_alpha
        probs /= probs.sum()

        indices = np.random.choice(total, batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]

        weights  = (total * probs[indices]) ** (-beta)
        weights /= weights.max()

        # Keep what is needed for a later priority update; the original code
        # computed these values and then dropped them.
        self.last_indices = indices
        self.last_weights = np.array(weights, dtype=np.float32)

        batch       = list(zip(*samples))
        user        = np.concatenate(batch[0])
        memory      = np.concatenate(batch[1])
        action      = batch[2]
        reward      = batch[3]
        next_user   = np.concatenate(batch[4])
        next_memory = np.concatenate(batch[5])
        done        = batch[6]

        return user, memory, action, reward, next_user, next_memory, done

    def update_priorities(self, batch_indices, batch_priorities):
        """Overwrite the priority of each listed slot."""
        for idx, prio in zip(batch_indices, batch_priorities):
            self.priorities[idx] = prio

    def __len__(self):
        return len(self.buffer)


def get_beta(idx, beta_start=0.4, beta_steps=100000):
    """Linearly anneal beta from ``beta_start`` towards 1.0 over ``beta_steps`` steps, capped at 1.0."""
    annealed = beta_start + idx * (1.0 - beta_start) / beta_steps
    return annealed if annealed < 1.0 else 1.0

def _pairs_to_dok(pairs, shape):
    """Build a binary float32 DOK matrix with a 1.0 at every (user, item) pair."""
    if len(pairs) == 0:
        return sp.dok_matrix(shape, dtype=np.float32)
    # De-duplicate so repeated pairs stay 1.0 instead of being summed.
    index = np.unique(np.asarray(pairs, dtype=np.int64), axis=0)
    values = np.ones(len(index), dtype=np.float32)
    coo = sp.coo_matrix((values, (index[:, 0], index[:, 1])), shape=shape, dtype=np.float32)
    return coo.todok()


def preprocess_data(data_dir, train_rating):
    """Load a tab-separated rating file and split its positive interactions.

    Only rows with ``rating > 3`` are kept. 80% of them (fixed
    ``random_state=16``) become training pairs and the rest test pairs.

    Args:
        data_dir: directory containing the rating file.
        train_rating: file name of the rating file; the first three columns
            are read as user, item, rating.

    Returns:
        ``(train_data, train_matrix, test_data, test_matrix, user_num,
        item_num, appropriate_users)`` where the ``*_data`` values are lists of
        ``[user, item]`` pairs, the ``*_matrix`` values are DOK sparse matrices
        of shape ``(user_num, item_num)`` holding 1.0 for each pair, and
        ``appropriate_users`` is the array of user ids with at least 20
        training interactions.
    """
    data = pd.read_csv(os.path.join(data_dir, train_rating),
                       sep='\t', header=None, names=['user', 'item', 'rating'],
                       usecols=[0, 1, 2],
                       dtype={'user': np.int32, 'item': np.int32, 'rating': np.int8})
    data = data[data['rating'] > 3][['user', 'item']]
    user_num = data['user'].max() + 1
    item_num = data['item'].max() + 1

    train_data = data.sample(frac=0.8, random_state=16)
    test_data = data.drop(train_data.index).values.tolist()
    train_data = train_data.values.tolist()

    # Build the matrices through the public COO constructor rather than
    # dict.update(dok_matrix, ...), which only works while dok_matrix is a
    # dict subclass and skips the matrix's own bookkeeping.
    shape = (user_num, item_num)
    train_matrix = _pairs_to_dok(train_data, shape)
    test_matrix = _pairs_to_dok(test_data, shape)

    interactions_per_user = np.asarray(train_matrix.sum(axis=1)).ravel()
    appropriate_users = np.arange(user_num)[interactions_per_user >= 20]

    return (train_data, train_matrix, test_data, test_matrix,
            user_num, item_num, appropriate_users)

def to_np(tensor):
    """Return the tensor's values as a NumPy array (detached from autograd, moved to the CPU)."""
    detached = tensor.detach()
    on_host = detached.cpu()
    return on_host.numpy()

def hit_metric(recommended, actual):
    """Return 1 when ``actual`` is among ``recommended``, otherwise 0."""
    if actual in recommended:
        return 1
    return 0

def dcg_metric(recommended, actual):
    """Return ``1 / log2(rank + 2)`` for the 0-based rank of ``actual`` in ``recommended``, or 0 if absent."""
    if actual not in recommended:
        return 0
    rank = recommended.index(actual)
    discount = np.log2(rank + 2)
    return np.reciprocal(discount)

In [None]:
# Location and file name of the rating file that is downloaded and preprocessed below.
data_dir = "data"
rating = "ml-1m.train.rating"

# Hyper-parameters shared across the notebook.
# NOTE(review): in the code visible in this notebook only 'batch_size',
# 'embedding_dim', 'hidden_dim', 'N', 'log_dir' and 'buffer_size' are read.
# The learning rates are hard-coded to 3e-4 where the optimizers are built, and
# soft_q_update is called with its own default gamma / soft_tau.
params = {
    # Read by the training loop as the minimum replay-buffer size before updates start;
    # the size of the sampled batch comes from the separate `batch_size` variable.
    'batch_size': 512,
    'embedding_dim': 8,
    'hidden_dim': 16,
    'N': 5, # memory size for state_repr
    'ou_noise':False,
    
    'value_lr': 1e-5,
    'value_decay': 1e-4,
    'policy_lr': 1e-5,
    'policy_decay': 1e-6,
    'state_repr_lr': 1e-5,
    'state_repr_decay': 1e-3,
    # Directory for TensorBoard logs and saved checkpoints.
    'log_dir': 'logs/final/',
    'gamma': 0.8,
    'min_value': -10,
    'max_value': 10,
    'soft_tau': 1e-3,
    
    # Capacity of the replay buffer.
    'buffer_size': 1000000
}


In [None]:
# Movielens (1M) data from the https://github.com/hexiangnan/neural_collaborative_filtering
# Create the data directory if needed (no error when it already exists).
os.makedirs(data_dir, exist_ok=True)

file_path = os.path.join(data_dir, rating)
if os.path.exists(file_path):
    print("Skip loading " + file_path)
else:
    print("Load " + file_path)
    r = requests.get(
        "https://raw.githubusercontent.com/hexiangnan/neural_collaborative_filtering/master/Data/" + rating,
        timeout=60,
    )
    # Fail loudly on an HTTP error instead of saving an error page as the data file.
    r.raise_for_status()
    # Open the file only after the download succeeded, so a failed request
    # never leaves an empty file that later runs would skip over.
    with open(file_path, "wb") as tf:
        tf.write(r.content)
(train_data, train_matrix, test_data, test_matrix, 
 user_num, item_num, appropriate_users) = preprocess_data(data_dir, rating)

**Soft Actor-Critic**

In [None]:
class ValueNetwork(nn.Module):
    """Two-layer MLP that maps a concatenated (state, action) vector to one scalar.

    Args:
        state_repr_dim: width of the state representation.
        action_emb_dim: width of the action embedding.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, state_repr_dim, action_emb_dim, hidden_dim):
        super().__init__()

        input_dim = state_repr_dim + action_emb_dim
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

        self.initialize()

    def initialize(self):
        """Re-initialise the weight of every linear layer with Kaiming-uniform values."""
        linear_layers = (module for module in self.layers if isinstance(module, nn.Linear))
        for linear in linear_layers:
            nn.init.kaiming_uniform_(linear.weight)

    def forward(self, state, action):
        """Return the scalar estimate for each row of ``state`` paired with ``action``."""
        return self.layers(torch.cat((state, action), dim=1))

In [None]:
class PolicyNetwork(nn.Module):
    """Policy MLP that maps a state representation to an action embedding.

    The input width is ``embedding_dim * 3`` and the output width is
    ``embedding_dim``.

    Args:
        embedding_dim: width of the produced action embedding.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, embedding_dim, hidden_dim):
        super().__init__()
    
        self.layers = nn.Sequential(
            nn.Linear(embedding_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )
        
        self.initialize()
        self.to(device)

    def initialize(self):
        """Re-initialise the weight of every linear layer with Kaiming-uniform values."""
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight)
        
    def evaluate(self, action_emb, epsilon=1e-6):
        """Treat ``action_emb`` as Gaussian parameters and score a tanh-squashed sample.

        The last dimension is split into two halves, used as mean and log-std.

        Returns:
            ``(action_emb, log_prob, z, mean, log_std)`` where ``action_emb``
            is the unchanged input, ``z`` is the pre-tanh sample and
            ``log_prob`` is summed over the last dimension (kept as size 1).

        NOTE(review): ``z`` comes from ``sample()`` (not ``rsample()``), so no
        gradient flows through the sample, and the returned action is the raw
        embedding rather than ``tanh(z)`` - confirm this is intended.
        """
        mean, log_std = torch.chunk(action_emb.float(), 2, dim=-1)
        
        std = log_std.exp()
        normal = Normal(mean, std)
        z = normal.sample()

        act = torch.tanh(z)
        
        # Change-of-variables correction for the tanh squashing.
        log_prob = normal.log_prob(z) - torch.log(1 - act.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)

        return action_emb, log_prob, z, mean, log_std
    

    def forward(self, state):
        """Return the action embedding for ``state``."""
        return self.layers(state)
    
    def get_action(self, user, memory, state_repr_sac, 
                   action_emb,
                   items=None,
                   return_scores=False
                  ):
        """Pick, for each action embedding, the candidate item with the highest dot product.

        Args:
            user, memory: accepted for call compatibility; not used for scoring.
            state_repr_sac: module exposing an ``item_embeddings`` embedding table.
            action_emb: 2-D tensor with one action embedding per row.
            items: 1-D tensor of candidate item ids. When omitted, all ids in
                ``range(item_num)`` (notebook global) are scored.
            return_scores: also return the score matrix.

        Returns:
            The selected item ids (one per row of ``action_emb``), preceded by
            the ``(num_items, num_rows)`` score matrix when ``return_scores``
            is true.
        """
        item_table = state_repr_sac.item_embeddings
        # Score on whatever device the embedding table lives on.
        target_device = item_table.weight.device
        if items is None:
            # Built lazily so the class definition does not depend on globals.
            items = torch.arange(int(item_num))
        items = items.to(target_device)

        # (num_items, emb) @ (emb, num_rows) -> (num_items, num_rows)
        scores = torch.matmul(item_table(items), action_emb.to(target_device).T)
        best_items = torch.gather(items, 0, scores.argmax(0))
        if return_scores:
            return scores, best_items
        else:
            return best_items

In [None]:
class SoftQNetwork(nn.Module):
    """Three-layer MLP that maps a concatenated (state, action) vector to one scalar.

    Args:
        num_inputs: width of the state vector.
        num_actions: width of the action vector.
        hidden_size: width of both hidden layers.
        init_w: the last layer's weight and bias are drawn uniformly from
            ``[-init_w, init_w]``.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        super().__init__()

        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)

        # Small uniform initialisation for the output layer (weight first, then bias).
        for tensor in (self.linear3.weight, self.linear3.bias):
            tensor.data.uniform_(-init_w, init_w)
        self.to(device)

    def forward(self, state, action):
        """Return the scalar estimate for each (state, action) row."""
        hidden = torch.cat([state, action], 1)
        for layer in (self.linear1, self.linear2):
            hidden = F.relu(layer(hidden))
        return self.linear3(hidden)


In [None]:
def soft_q_update(batch_size, 
           gamma=0.99,
           mean_lambda=1e-3,
           std_lambda=1e-3,
           z_lambda=0.0,
           soft_tau=1e-2,
          ):
    """Run one update of the Q, value and policy networks from a replay-buffer batch.

    Reads the notebook globals ``replay_buffer``, ``state_repr_sac``,
    ``soft_q_net``, ``value_net``, ``target_value_net``, ``policy_net``, their
    criteria and optimizers, and ``device``.

    Args:
        batch_size: number of transitions sampled from ``replay_buffer``.
        gamma: discount factor for the bootstrapped Q target.
        mean_lambda: weight of the squared-mean regulariser on the policy loss.
        std_lambda: weight of the squared-log-std regulariser.
        z_lambda: weight of the squared pre-tanh sample regulariser.
        soft_tau: interpolation factor for the target value network update.
    """
    user, memory, action, reward, next_user, next_memory, done = replay_buffer.sample(batch_size)

    user        = torch.FloatTensor(user).to(device)
    memory      = torch.FloatTensor(memory).to(device)
    action      = torch.FloatTensor(action).to(device)
    # reward and done are forced to (batch, 1) so they line up with the
    # (batch, 1) network outputs. The previous ``unsqueeze(1)`` turned an
    # already (batch, 1) ``done`` into (batch, 1, 1), which silently
    # broadcast the TD target to (batch, batch, 1).
    reward      = torch.FloatTensor(reward).reshape(-1, 1).to(device)
    next_user   = torch.FloatTensor(next_user).to(device)
    next_memory = torch.FloatTensor(next_memory).to(device)
    done        = torch.FloatTensor(np.float32(done)).reshape(-1, 1).to(device)
    
    state      = state_repr_sac(user, memory)
    next_state = state_repr_sac(next_user, next_memory)

    expected_q_value = soft_q_net(state, action)
    expected_value   = value_net(state, policy_net(state))
    
    # Reuse the state computed above instead of running the state module again.
    action_emb = policy_net(state)
    new_action, log_prob, z, mean, log_std = policy_net.evaluate(action_emb)

    # Q target: reward plus discounted target-network value of the next state.
    target_value = target_value_net(next_state, policy_net(next_state))
    next_q_value = reward + (1 - done) * gamma * target_value
    q_value_loss = soft_q_criterion(expected_q_value, next_q_value.detach())

    # Value target: Q of the freshly evaluated action minus its log-probability.
    expected_new_q_value = soft_q_net(state, new_action)
    next_value = expected_new_q_value - log_prob
    value_loss = value_criterion(expected_value, next_value.detach())

    log_prob_target = expected_new_q_value - expected_value
    policy_loss = (log_prob * (log_prob - log_prob_target).detach()).mean()
    

    mean_loss = mean_lambda * mean.pow(2).mean()
    std_loss  = std_lambda  * log_std.pow(2).mean()
    z_loss    = z_lambda    * z.pow(2).sum(1).mean()

    policy_loss += mean_loss + std_loss + z_loss

    # The three losses share the state-representation graph, so the first two
    # backward passes must retain it; the last one may free it.
    soft_q_optimizer.zero_grad()
    q_value_loss.backward(retain_graph=True)
    soft_q_optimizer.step()

    value_optimizer.zero_grad()
    value_loss.backward(retain_graph=True)
    value_optimizer.step()

    policy_optimizer.zero_grad()
    policy_loss.backward()
    policy_optimizer.step()
    
    
    # Polyak averaging of the target value network towards the value network.
    for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - soft_tau) + param.data * soft_tau
        )

In [None]:
class Env():
    """Recommendation environment backed by a user-item interaction matrix.

    Keeps, for every user, a memory of the last ``params['N']`` items that
    earned a reward. Relies on the notebook globals ``item_num``, ``user_num``,
    ``params``, ``device`` and ``to_np``.

    Args:
        user_item_matrix: matrix indexed by user whose positive entries mark
            the items relevant to that user.
    """

    def __init__(self, user_item_matrix):
        self.matrix = user_item_matrix
        self.item_count = item_num
        self.memory = np.ones([user_num, params['N']]) * item_num
        # memory is initialized as [item_num] * N for each user
        # it is padding indexes in state_repr and will result in zero embeddings

    def reset(self, user_id):
        """Start an episode for ``user_id``.

        Collects the user's relevant items, samples the same number of
        non-relevant ones and interleaves both into ``available_items``.

        Returns:
            ``(user, memory)`` tensors for the selected user.
        """
        self.user_id = user_id
        self.viewed_items = []
        self.related_items = np.argwhere(self.matrix[self.user_id] > 0)[:, 1]
        self.num_rele = len(self.related_items)
        self.nonrelated_items = np.random.choice(
            list(set(range(self.item_count)) - set(self.related_items)), self.num_rele)
        self.available_items = np.zeros(self.num_rele * 2)
        self.available_items[::2] = self.related_items
        self.available_items[1::2] = self.nonrelated_items
        
        return torch.tensor([self.user_id], device=device), torch.tensor(self.memory[[self.user_id], :], device=device)
    
    def step(self, action, action_emb=None, buffer=None):
        """Apply a recommended item and return the next observation.

        Args:
            action: tensor whose first element is the recommended item id.
            action_emb: action embedding stored in the buffer (needed only
                when ``buffer`` is given).
            buffer: optional replay buffer; when given, the transition is
                pushed to it.

        Returns:
            ``(user, memory, reward, done)`` where ``reward`` is 1.0 if the
            item is relevant to the user (else 0.0) and ``done`` is 1 once as
            many items have been viewed as the user has relevant items.
        """
        initial_user = self.user_id
        # Fancy indexing copies, so this is the memory *before* the update below.
        initial_memory = self.memory[[initial_user], :]
        
        item = to_np(action)[0]
        reward = float(item in self.related_items)
        self.viewed_items.append(item)

        if reward:
            # Slide the window: drop the oldest item and append the new one as
            # a plain number (not a tensor) so the row assignment stays homogeneous.
            self.memory[self.user_id] = np.append(self.memory[self.user_id][1:], item)
                
        done = int(len(self.viewed_items) == len(self.related_items))
            
        if buffer is not None:
            # The last field is the episode-termination flag; the reward was
            # previously stored there by mistake.
            buffer.push(np.array([initial_user]), np.array(initial_memory), to_np(action_emb)[0], 
                        np.array([reward]), np.array([self.user_id]), self.memory[[self.user_id], :], np.array([done]))

        return torch.tensor([self.user_id], device=device), torch.tensor(self.memory[[self.user_id], :], device=device), reward, done

In [None]:
class State_Repr_Module(nn.Module):
    """Builds the state vector from a user id and the user's item memory.

    The output concatenates the user embedding, its element-wise product with
    the pooled item memory, and the pooled item memory itself, giving a width
    of ``3 * embedding_dim``. Reads the notebook globals ``params['N']`` (memory
    length) and ``device``.

    Args:
        user_num: number of user embeddings.
        item_num: number of real items; index ``item_num`` is the padding row.
        embedding_dim: embedding width.
        hidden_dim: accepted but not used by this module.
    """

    def __init__(self, user_num, item_num, embedding_dim, hidden_dim):
        super().__init__()
        self.user_embeddings = nn.Embedding(user_num, embedding_dim)
        self.item_embeddings = nn.Embedding(item_num+1, embedding_dim, padding_idx=int(item_num))
        # 1x1 convolution over the N memory slots: a learned weighted sum.
        self.drr_ave = torch.nn.Conv1d(in_channels=params['N'], out_channels=1, kernel_size=1)

        self.initialize()
        self.to(device)

    def initialize(self):
        """Initialise the embeddings and pooling weights; the padding row is reset to zero."""
        for table in (self.user_embeddings, self.item_embeddings):
            nn.init.normal_(table.weight, std=0.01)
        self.item_embeddings.weight.data[-1].zero_()
        nn.init.uniform_(self.drr_ave.weight)
        self.drr_ave.bias.data.zero_()

    def forward(self, user, memory):
        """Return the state representation for each (user, memory) row."""
        user_vec = self.user_embeddings(user.to(device).long())
        history = self.item_embeddings(memory.to(device).long())
        pooled = self.drr_ave(history).squeeze(1)

        features = (user_vec, user_vec * pooled, pooled)
        return torch.cat(features, 1).to(device)

In [None]:
# Small validation set: only the test pairs of user 6039.
# NOTE(review): 6039 is a hard-coded user id - presumably the last user of ML-1M; confirm.
valid_dataset = EvalDataset(
    np.array(test_data)[np.array(test_data)[:, 0] == 6039], 
    item_num, 
    test_matrix)
# batch_size=100 matches one positive plus the default 99 negatives, so each
# batch holds exactly one positive pair together with its negatives.
valid_loader = td.DataLoader(valid_dataset, batch_size=100, shuffle=False)

# Full evaluation set over every test pair.
full_dataset = EvalDataset(np.array(test_data), item_num, test_matrix)
full_loader = td.DataLoader(full_dataset, batch_size=100, shuffle=False)

In [None]:
def run_evaluation(net, state_representation, training_env_memory, loader=valid_loader):
    """Compute the mean hit rate and DCG at 10 of ``net`` over ``loader``.

    Each batch is expected to hold the candidates of a single user with the
    positive item in the first row (this is how the loaders in this notebook
    are built: batch size 100 = 1 positive + 99 negatives).

    Args:
        net: policy network providing ``__call__`` and ``get_action``.
        state_representation: state module used for both the policy input and
            the item embeddings.
        training_env_memory: per-user item memory; it is copied, not modified.
        loader: data loader yielding dicts with ``'user'`` and ``'item'``.

    Returns:
        ``(mean hit rate, mean DCG)`` over all batches.
    """
    hits = []
    dcgs = []
    test_env = Env(test_matrix)
    test_env.memory = training_env_memory.copy()
    current_user = None
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for batch in loader:
            # Build the state for the user this batch belongs to; the env is
            # reset whenever the user changes (previously the state of the
            # first validation user was used for every batch).
            batch_user = int(to_np(batch['user'])[0])
            if batch_user != current_user:
                user, memory = test_env.reset(batch_user)
                current_user = batch_user

            # Use the module passed in, not the global training module.
            action_emb = net(state_representation(user, memory)).to(device)
            scores, action = net.get_action(
                batch['user'], 
                torch.tensor(test_env.memory[to_np(batch['user']).astype(int), :]).to(device), 
                state_representation, 
                action_emb,
                batch['item'].long(), 
                return_scores=True
            )
            user, memory, reward, done = test_env.step(action)

            _, ind = scores[:, 0].topk(10)
            predictions = torch.take(batch['item'].to(device), ind).cpu().numpy().tolist()
            actual = batch['item'][0].item()
            hits.append(hit_metric(predictions, actual))
            dcgs.append(dcg_metric(predictions, actual))
        
    return np.mean(hits), np.mean(dcgs)

In [None]:
# Training environment over the training interactions.
train_env = Env(train_matrix)


# Networks. The state representation is 3 * embedding_dim wide, which is the
# input width expected by the value, Q and policy networks below.
state_repr_sac = State_Repr_Module(user_num, item_num, params['embedding_dim'] , params['hidden_dim']).to(device)
value_net  = ValueNetwork(params['embedding_dim'] * 3, params['embedding_dim'], params['hidden_dim']).to(device)
target_value_net = ValueNetwork(params['embedding_dim'] * 3, params['embedding_dim'], params['hidden_dim']).to(device)

soft_q_net = SoftQNetwork(params['embedding_dim'] * 3, params['embedding_dim'], params['hidden_dim']).to(device)

policy_net = PolicyNetwork(params['embedding_dim'], params['hidden_dim']).to(device)

# Start the target value network as an exact copy of the value network.
for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):
    target_param.data.copy_(param.data)
    

value_criterion  = nn.MSELoss()
soft_q_criterion = nn.MSELoss()

# NOTE(review): these learning rates are hard-coded; the *_lr entries of `params` are not used here.
value_lr  = 3e-4
soft_q_lr = 3e-4
policy_lr = 3e-4

# NOTE(review): no optimizer receives state_repr_sac.parameters(), so the user/item
# embeddings and the pooling layer are not updated by these optimizers - confirm this is intended.
value_optimizer  = optim.Adam(value_net.parameters(), lr=value_lr)
soft_q_optimizer = optim.Adam(soft_q_net.parameters(), lr=soft_q_lr)
policy_optimizer = optim.Adam(policy_net.parameters(), lr=policy_lr)


# Replay buffer and TensorBoard writer.
replay_buffer_size = params['buffer_size']
replay_buffer = Prioritized_Buffer(replay_buffer_size)
writer = SummaryWriter(log_dir=params['log_dir'])

In [None]:
# Bookkeeping for the training loop: metric histories, counters and the user order.
np.random.seed(16)
batch_size = 128
steps, rewards = [], []
hits, dcgs = [], []
hits_all, dcgs_all = [], []
step = best_step = best_step_all = 0
users = np.random.permutation(appropriate_users)
ou_noise = OUNoise(params['embedding_dim'], decay_period=10)

In [None]:
# Online training: for every user, recommend one item per step from the
# not-yet-viewed candidates, store the transition and update the networks.
step = 0
users = np.random.permutation(appropriate_users)
for u in tqdm.tqdm(users):
    user, memory = train_env.reset(u)
    user = user.to(device)
    memory = memory.to(device)
    for t in range(int(train_matrix[u].sum())):
        action_emb = policy_net(state_repr_sac(user, memory))

        # Candidates are the available items not viewed yet; built once per
        # step, with a set for constant-time membership checks.
        viewed = set(train_env.viewed_items)
        items = torch.tensor(
            [item for item in train_env.available_items if item not in viewed]
        ).long().to(device)
        action = policy_net.get_action(
            user,
            torch.tensor(train_env.memory[to_np(user).astype(int), :]).to(device),
            state_repr_sac,
            action_emb,
            items,
        )
        user, memory, reward, done = train_env.step(action, action_emb, buffer=replay_buffer)
        user = user.to(device)
        memory = memory.to(device)

        if len(replay_buffer) > params['batch_size']:
            soft_q_update(batch_size)

        # Quick validation every 100 steps.
        if step % 100 == 0 and step > 0:
            hit, dcg = run_evaluation(policy_net, state_repr_sac, train_env.memory)
            writer.add_scalar('hit', hit, step)
            writer.add_scalar('dcg', dcg, step)
            hits.append(hit)
            dcgs.append(dcg)
            steps.append(step)
            # best_step is an index into hits/dcgs. The first evaluation always
            # writes a checkpoint; later ones only when the mean of the hit and
            # dcg improvements over the best entry is positive.
            latest = len(hits) - 1
            if latest == 0 or np.mean(np.array([hit, dcg]) - np.array([hits[best_step], dcgs[best_step]])) > 0:
                best_step = latest
                torch.save(policy_net.state_dict(), params['log_dir'] + 'policy_net.pth')
                torch.save(value_net.state_dict(), params['log_dir'] + 'value_net.pth')
                torch.save(state_repr_sac.state_dict(), params['log_dir'] + 'state_repr.pth')

        # Full evaluation every 10000 steps.
        if step % 10000 == 0 and step > 0:
            hit, dcg = run_evaluation(policy_net, state_repr_sac, train_env.memory, full_loader)
            writer.add_scalar('hit_all', hit, step)
            writer.add_scalar('dcg_all', dcg, step)
            hits_all.append(hit)
            dcgs_all.append(dcg)
            latest_all = len(hits_all) - 1
            if latest_all == 0 or np.mean(np.array([hit, dcg]) - np.array([hits_all[best_step_all], dcgs_all[best_step_all]])) > 0:
                best_step_all = latest_all
                torch.save(policy_net.state_dict(), params['log_dir'] + 'best_policy_net.pth')
                torch.save(value_net.state_dict(), params['log_dir'] + 'best_value_net.pth')
                torch.save(state_repr_sac.state_dict(), params['log_dir'] + 'best_state_repr.pth')
        step += 1

In [None]:
import matplotlib.pyplot as plt

# Validation metrics recorded during training, plotted against the training step.
fig, ax = plt.subplots()
for series, series_label in ((dcgs, 'nDCG'), (hits, 'Hit rates')):
    ax.plot(steps, series, label=series_label)

ax.set_xlabel('Step')
ax.set_title('Hit Rates and nDCG in Each Step')

ax.legend()
plt.show()

In [None]:
# Best hit rate over the full-evaluation runs (hits_all only gets an entry every 10000 steps).
np.max(hits_all)

In [None]:
# Best DCG over the full-evaluation runs (dcgs_all only gets an entry every 10000 steps).
np.max(dcgs_all)

In [None]:
# Persist the final weights of the policy, value and state-representation networks.
for filename, module in (
    ('policy_net_final.pth', policy_net),
    ('value_net_final.pth', value_net),
    ('state_repr_final.pth', state_repr_sac),
):
    torch.save(module.state_dict(), params['log_dir'] + filename)

In [None]:
# The per-user memory of the training environment is needed for evaluation, so it
# is saved here and can be reloaded later without repeating the training run.
with open('logs/memory.pickle', 'wb') as f:
    pickle.dump(train_env.memory, f)
    
# Reload it right away. Note that this rebinds the global name `memory`, which
# the training loop used for a tensor. Only unpickle files you created yourself.
with open('logs/memory.pickle', 'rb') as f:
    memory = pickle.load(f)

In [None]:
# Rebuild the networks and load the final weights written at the end of training.
# NOTE(review): the directory is spelled out as 'logs/final/' here; it matches params['log_dir'].
no_ou_state_repr = State_Repr_Module(user_num, item_num, params['embedding_dim'], params['hidden_dim'])
no_ou_policy_net = PolicyNetwork(params['embedding_dim'], params['hidden_dim'])
no_ou_state_repr.load_state_dict(torch.load('logs/final/' + 'state_repr_final.pth'))
no_ou_policy_net.load_state_dict(torch.load('logs/final/' + 'policy_net_final.pth'))
    
# Evaluate on the full test loader with the reloaded per-user memory.
hit, dcg = run_evaluation(no_ou_policy_net, no_ou_state_repr, memory, full_loader)
print('hit rate: ', hit, 'dcg: ', dcg)

In [None]:
# Same evaluation, this time with the 'best_*' checkpoints written by the training
# loop after a full evaluation. These files exist only if training ran long enough
# to write them.
no_ou_state_repr = State_Repr_Module(user_num, item_num, params['embedding_dim'], params['hidden_dim'])
no_ou_policy_net = PolicyNetwork(params['embedding_dim'], params['hidden_dim'])
no_ou_state_repr.load_state_dict(torch.load('logs/final/' + 'best_state_repr.pth'))
no_ou_policy_net.load_state_dict(torch.load('logs/final/' + 'best_policy_net.pth'))
    
hit, dcg = run_evaluation(no_ou_policy_net, no_ou_state_repr, memory, full_loader)
print('hit rate: ', hit, 'dcg: ', dcg)

In [None]:
# Pick one user id uniformly at random for the qualitative example below.
random_user = np.random.randint(user_num)
print(random_user)

In [None]:
# Movie metadata ('::'-separated id, name, genre), read from a Kaggle input path.
movies = pd.read_csv('/kaggle/input/movielens/movies.dat', sep='::', header=None, engine='python', names=['id', 'name', 'genre'],  encoding='ISO-8859-1')
# Item ids in this notebook start at 0, hence the +1 before matching against 'id'.
# NOTE(review): this assumes item index i corresponds to movie id i + 1 - verify against the dataset's id mapping.
movies[movies['id'].isin(np.argwhere(test_matrix[random_user] > 0)[:, 1] + 1)]

In [None]:
# Qualitative example: let the loaded policy recommend three items for `random_user`.
predictions = []
# This binding is overwritten by the loop variable of the same name below.
model = no_ou_policy_net

for model, state_representation in zip([no_ou_policy_net], [no_ou_state_repr]):
    example_env = Env(test_matrix)
    user, memory = example_env.reset(random_user)

    # Two hard-coded items are stepped first; they only enter the user's memory
    # if they are relevant to this user.
    # NOTE(review): the ids 3706 and 1584 look tied to one specific user - confirm they are still meaningful.
    user, memory, reward, _ = example_env.step(torch.tensor([3706]))
    user, memory, reward, _ = example_env.step(torch.tensor([1584]))

    preds = []
    for _ in range(3):
        action_emb = model(state_representation(user, memory))
        # Score only the available items that have not been viewed yet.
        action = model.get_action(
            user, 
            torch.tensor(example_env.memory[to_np(user).astype(int), :]), 
            state_representation, 
            action_emb,
            torch.tensor(
                [item for item in example_env.available_items 
                if item not in example_env.viewed_items]
            ).long()
        )
        user, memory, reward, _ = example_env.step(action)
        preds.append(action)

    predictions.append(preds)



In [None]:
# The three recommended item tensors produced for the single evaluated model.
print(predictions[0])
