# REINFORCE

The following code contains an implementation of the REINFORCE algorithm, **without Off Policy Correction, LSTM state encoder, and Noise Contrastive Estimation**. Look for these in other notebooks.

Also, I am not Google staff, and unlike the paper authors, I cannot collect online feedback on the recommendations.

**I use an actor-critic setup for reward assignment.** In a real-world scenario the reward would come from interactive user feedback, but here I use a neural network (the critic) that aims to emulate it.

### **Note on these tutorials:**
**They mostly contain low level implementations explaining what is going on inside the library.**

**Most of the stuff explained here is already available out of the box for your usage.**

If you do not care about the detailed implementation with code, go to [Library Basics]/algorithms how to/reinforce; there is a 20-line version there.

In [1]:
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torch.distributions import Categorical
import torch_optimizer as optim

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from time import gmtime, strftime

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline


# == recnn ==
import sys
sys.path.append("../../")
import recnn

# Device handle used for every network created further down in this notebook.
cuda = torch.device('cuda')

# --- run configuration ---
frame_size = 10    # passed to FrameEnv and to the batch embedder as the frame length
batch_size = 10    # passed to FrameEnv
n_epochs   = 100   # number of passes of the outer training loop
plot_every = 30    # the training loop prints a progress line every `plot_every` steps
step       = 0     # global step counter; reset to 0 again right before training
num_items    = 5000 # n items to recommend. Can be adjusted for your vram
# --- end of run configuration ---
# Registers tqdm's progress-reporting `progress_apply` on pandas objects.
tqdm.pandas()


from jupyterthemes import jtplot
jtplot.style(theme='grade3')

## I will drop low-frequency items because the full item set does not fit into my video card's VRAM

In [2]:
from typing import List, Dict, Callable

# Plain args. Shouldn't be mutated
class DataFuncKwargs:
    """
        Named settings shared by every function of a data pipeline.

        Values are stored in a plain dict (``self.kwargs``). Pipeline functions are expected
        to read them with :meth:`get`; the user fills them in with :meth:`set` before the
        pipeline runs.
    """

    # Shown inside the error raised by `get`; the single '{}' is replaced by the missing name.
    _USAGE_EXAMPLE = """
                # example on how to use kwargs:
                def prepare_dataset(args_mut, kwargs):
                    kwargs.set('{}', your_value) # set kwargs for your functions here!
                    pipeline = [recnn.data.truncate_dataset, recnn.data.prepare_dataset]
                    recnn.data.build_data_pipeline(pipeline, kwargs, args_mut)
            """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def get(self, name: str):
        """
            Return the value stored under ``name``.

            :param name: key previously passed to the constructor or to :meth:`set`
            :raises AttributeError: if ``name`` was never set; the message contains a usage example
        """
        if name not in self.kwargs:
            # The original formatted the example with an undefined variable, which turned
            # every missing key into a NameError instead of this AttributeError.
            raise AttributeError(
                "No kwarg with name {} found!\n{}".format(name, self._USAGE_EXAMPLE.format(name))
            )
        return self.kwargs[name]

    def set(self, name: str, value):
        """
            Store ``value`` under ``name``, overwriting any previous value.
        """
        self.kwargs[name] = value

# Used for returning, arguments are mutable
class DataFuncArgsMut:
    """
        Mutable bundle of data handed from one pipeline function to the next.

        Holds the ratings frame (``df``), the environment base object (``base``), the list of
        user ids (``users``) and the per-user dictionary (``user_dict``). Pipeline functions
        are free to reassign any of these attributes.
    """

    def __init__(self, df, base, users: List[int], user_dict: Dict[int, Dict[str, np.ndarray]]):
        # Plain attribute storage; nothing is copied or validated.
        self.df, self.base = df, base
        self.users, self.user_dict = users, user_dict

In [3]:
def prepare_dataset(args_mut: DataFuncArgsMut, kwargs: DataFuncKwargs):

    """
        Basic prepare dataset function. Automatically makes the item index linear: in ml20,
        movie indices look like [1, 34, 123, 2000]; recnn makes them look like [0,1,2,3] for you.

        Reads ``frame_size`` from ``kwargs``, and ``base.key_to_id`` and ``df`` from ``args_mut``.
        Rescales the ``rating`` column, remaps ``movieId`` through ``key_to_id``, then builds:

        * ``args_mut.users``     - ids of users with more than ``frame_size`` rows, most active first
        * ``args_mut.user_dict`` - ``{user_id: {'items': ..., 'ratings': ...}}`` in timestamp order

        Note that the ``rating`` and ``movieId`` columns of ``args_mut.df`` are overwritten in place.

        NOTE(review): ``try_progress_apply`` is not defined in the visible part of this file, and
        ``pd.get_type`` is not part of stock pandas - presumably both come from recnn's data
        module / pandas backend wrapper; confirm before running this cell standalone.

        :return: the ``(args_mut, kwargs)`` pair
    """

    # get args
    frame_size = kwargs.get('frame_size')
    key_to_id = args_mut.base.key_to_id
    df = args_mut.df
    
    # rating range mapped from [0, 5] to [-5, 5]
    df['rating'] = try_progress_apply(df['rating'], lambda i: 2 * (i - 2.5))
    # id's tend to be inconsistent and sparse so they are remapped here
    # (presumably key_to_id is a dict, so ids missing from it become None - verify)
    df['movieId'] = try_progress_apply(df['movieId'], lambda i: key_to_id.get(i))

    # number of rows per user; keep only users with more than frame_size rows,
    # ordered from the most to the least active
    users = df[['userId', 'movieId']].groupby(['userId']).size()
    users = users[users > frame_size].sort_values(ascending=False).index

    if pd.get_type() == "modin": df = df._to_pandas() # pandas groupby is sync and doesn't affect performance
    # chronological order per user; the timestamp column itself is dropped after sorting
    ratings = df.sort_values(by='timestamp').set_index('userId').drop('timestamp', axis=1).groupby('userId')

    # Groupby user
    user_dict = {}

    def app(x):
        # x is one user's group; the frame is indexed by userId, so the first label is the user id
        userid = x.index[0]
        user_dict[int(userid)] = {}
        user_dict[int(userid)]['items'] = x['movieId'].values
        user_dict[int(userid)]['ratings'] = x['rating'].values

    # called for its side effect on user_dict; the return value is discarded
    try_progress_apply(ratings, app)

    args_mut.user_dict = user_dict
    args_mut.users = users

    return args_mut, kwargs

In [4]:
def truncate_dataset(args_mut: DataFuncArgsMut, kwargs: DataFuncKwargs):
    """
        Truncate the number of items to ``reduce_items_to`` provided in the kwargs.

        Keeps the ``reduce_items_to`` most frequently rated ``movieId`` values of ``args_mut.df``,
        drops every row of the frame that refers to any other item (in place), removes the same
        items from ``args_mut.base.movie_embeddings_key_dict`` (in place) and rebuilds
        ``base.embeddings`` / ``base.key_to_id`` / ``base.id_to_key`` from the remaining entries.

        NOTE(review): because of Python's ``[-0:]`` slicing, ``reduce_items_to == 0`` keeps every
        item instead of none. This matches the previous behaviour and is left unchanged.

        :return: the ``(args_mut, kwargs)`` pair
    """

    # here are adjusted n items to keep
    num_items = kwargs.get('reduce_items_to')
    df = args_mut.df

    # Frequencies sorted ascending: rarest items first, most popular last. Computed once
    # (it used to be computed and sorted twice) and sliced positionally through the index.
    counts = df['movieId'].value_counts().sort_values()
    to_remove = counts.index[:-num_items]
    to_keep = counts.index[-num_items:]
    num_removed = len(to_remove)

    to_remove_indices = df[df['movieId'].isin(to_remove)].index
    df.drop(to_remove_indices, inplace=True)

    # Delete in place so that other references to the embeddings dict see the truncation.
    kept_ids = set(to_keep)
    for i in list(args_mut.base.movie_embeddings_key_dict.keys()):
        if i not in kept_ids:
            del args_mut.base.movie_embeddings_key_dict[i]

    args_mut.base.embeddings, args_mut.base.key_to_id, \
    args_mut.base.id_to_key = recnn.data.make_items_tensor(args_mut.base.movie_embeddings_key_dict)
    args_mut.df = df

    # Report the real sizes: when the frame holds fewer distinct items than requested,
    # the number kept is smaller than num_items.
    print('action space is reduced to {} - {} = {}'.format(len(counts), num_removed,
                                                           len(to_keep)))

    return args_mut, kwargs

In [5]:
def batch_contstate_discaction(batch, item_embeddings_tensor, frame_size, num_items, *args, **kwargs):
    """
    Embed Batch: continuous state, discrete action.

    Turns a raw batch into a transition dict with the keys ``state``, ``action``, ``reward``,
    ``next_state``, ``done`` and ``meta``. The state is the flattened item embeddings
    concatenated with the ratings; the action is a one-hot vector of width ``num_items``.

    Assumes the items/ratings tensors returned by ``get_irsu`` are 2-D (rows, positions) and
    that the item ids are valid indices into ``item_embeddings_tensor`` - TODO confirm.
    """

    from recnn.data.utils import get_irsu

    item_ids, rating_vals, seq_sizes, user_ids = get_irsu(batch)
    n_rows = rating_vals.size(0)
    embedded = item_embeddings_tensor[item_ids.long()]

    # Current frame drops the last position, next frame drops the first one.
    state = torch.cat([embedded[:, :-1, :].view(n_rows, -1), rating_vals[:, :-1]], 1)
    next_state = torch.cat([embedded[:, 1:, :].view(n_rows, -1), rating_vals[:, 1:]], 1)

    # The last position of every row supplies the action taken and the reward received.
    last_item = item_ids[:, -1]
    last_rating = rating_vals[:, -1]

    # Flag the row at the end of each cumulative (size - frame_size) span as terminal.
    done = torch.zeros(n_rows)
    terminal_rows = torch.cumsum(seq_sizes - frame_size, dim=0) - 1
    done[terminal_rows] = 1

    one_hot_action = torch.zeros(last_item.size(0), num_items)
    one_hot_action.scatter_(1, last_item.view(-1, 1), 1)

    return {
        'state': state,
        'action': one_hot_action,
        'reward': last_rating,
        'next_state': next_state,
        'done': done,
        'meta': {'users': user_ids, 'sizes': seq_sizes},
    }

def embed_batch(batch, item_embeddings_tensor, *args, **kwargs):
    """
    Embed a batch using the notebook-level ``frame_size`` and ``num_items`` settings.

    Extra positional and keyword arguments are accepted but not forwarded.
    """
    settings = {'frame_size': frame_size, 'num_items': num_items}
    return batch_contstate_discaction(batch, item_embeddings_tensor, **settings)

In [8]:
def build_data_pipeline(chain, kwargs: DataFuncKwargs, args_mut: DataFuncArgsMut):
    """
        Run every callable of ``chain`` in order, threading ``args_mut`` through them.

        :param chain: iterable of callables; each is invoked as ``call(args_mut, kwargs)`` and
                      must return a two-element ``(args_mut, kwargs)`` pair
        :param kwargs: settings handed unchanged to every callable
        :param args_mut: initial mutable arguments; replaced by whatever each callable returns
        :return: ``(kwargs, args_mut)`` - the original kwargs and the final mutable arguments
    """
    print(chain)
    current = args_mut
    for stage in chain:
        # The kwargs a stage hands back are dropped on purpose: every stage sees the
        # same kwargs object that was passed in.
        current, _discarded = stage(current, kwargs)
    return kwargs, current

def embed_batch(batch, item_embeddings_tensor, *args, **kwargs):
    """
    Batch embedder handed to the environment: delegates to ``batch_contstate_discaction``
    with the notebook-level ``frame_size`` and ``num_items``.

    Extra positional and keyword arguments are accepted but not forwarded.
    """
    # frame_size and num_items are the 3rd and 4th parameters of batch_contstate_discaction.
    return batch_contstate_discaction(batch, item_embeddings_tensor, frame_size, num_items)

    
def prepare_dataset(args_mut, kwargs):
    """
    Dataset hook handed to the environment: truncate the item set to ``num_items``,
    then run recnn's standard dataset preparation.

    Works purely through mutation of ``args_mut`` / ``kwargs``; nothing is returned.
    """
    kwargs.set('reduce_items_to', num_items)  # set kwargs for your functions here!
    build_data_pipeline(
        [recnn.data.truncate_dataset, recnn.data.prepare_dataset],
        kwargs,
        args_mut,
    )
    

# embeddings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL
# Build the frame environment from the pickled item embeddings and the ratings CSV,
# plugging in the custom batch embedder and dataset-preparation hook defined above.
# The training loop iterates over `env.train_dataloader`.
env = recnn.data.env.FrameEnv('../../data/embeddings/ml20_pca128.pkl',
                              '../../data/ml-20m/ratings.csv', frame_size, batch_size,
                              embed_batch=embed_batch, prepare_dataset=prepare_dataset,
                              num_workers=0)

[<function truncate_dataset at 0x7f4675108c10>, <function prepare_dataset at 0x7f4675108940>]
  0%|          | 0/18946308 [00:00<?, ?it/s]action space is reduced to 26744 - 21744 = 5000
executed!
100%|██████████| 18946308/18946308 [00:12<00:00, 1503667.62it/s]
100%|██████████| 18946308/18946308 [00:14<00:00, 1315333.93it/s]
100%|██████████| 138493/138493 [00:07<00:00, 19649.79it/s]
executed!


In [7]:
class DiscreteActor(nn.Module):
    """
    Two-layer policy network over a discrete action space.

    ``forward`` maps a state to a probability distribution over ``num_actions`` actions;
    ``select_action`` samples from that distribution and records the log-probability of
    the sample in ``saved_log_probs``. ``rewards`` is a plain list that callers append to.
    """

    def __init__(self, hidden_size, num_inputs, num_actions):
        super().__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, num_actions)

        # Episode buffers consumed (and cleared) by the REINFORCE update.
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, inputs):
        """
        Return action probabilities for ``inputs``; they sum to 1 along the last dimension.
        """
        hidden = F.relu(self.linear1(inputs))
        action_scores = self.linear2(hidden)
        # Normalise over the action axis explicitly. Omitting `dim` is deprecated, emits a
        # warning on every call, and picks axis 0 for 3-D input.
        return F.softmax(action_scores, dim=-1)

    def select_action(self, state):
        """
        Sample an action for ``state``.

        Appends the log-probability of the sampled action to ``saved_log_probs``.

        :return: ``(action, probs)`` - the sampled action index/indices and the full distribution
        """
        probs = self.forward(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs.append(m.log_prob(action))
        return action, probs

### Because I do not have a dynamic environment, I will also include a critic. If you have a real, non-static environment, you can do without the critic.

In [8]:
class ChooseREINFORCE:
    """
    Callable REINFORCE update with a pluggable loss function.

    Calling an instance turns ``policy.rewards`` into discounted returns, computes the
    policy loss with ``self.method``, optionally takes an optimizer step and finally
    clears the policy's ``rewards`` and ``saved_log_probs`` buffers.
    """

    def __init__(self, method=None, gamma=0.99):
        """
        :param method: loss function called as ``method(policy, returns)``;
                       defaults to :meth:`basic_reinforce`
        :param gamma: discount factor applied when accumulating returns
        """
        if method is None:
            # The previous default referred to a non-existent `reinforce` attribute.
            method = ChooseREINFORCE.basic_reinforce
        self.method = method
        self.gamma = gamma

    @staticmethod
    def basic_reinforce(policy, returns, *args, **kwargs):
        """
        Vanilla REINFORCE loss: sum over steps of ``-log_prob * return``.
        """
        policy_loss = [-log_prob * R for log_prob, R in zip(policy.saved_log_probs, returns)]
        return torch.cat(policy_loss).sum()

    @staticmethod
    def reinforce_with_correction():
        """
        Placeholder for the off-policy-corrected variant.

        :raises NotImplementedError: always
        """
        # `NotImplemented` is a comparison sentinel, not an exception; raising it is a TypeError.
        raise NotImplementedError

    def __call__(self, policy, optimizer, learn=True):
        """
        Run one policy update.

        :param policy: object exposing the lists ``rewards`` and ``saved_log_probs``
        :param optimizer: optimizer for the policy parameters; only used when ``learn`` is true
        :param learn: when false the loss is computed but no gradient step is taken
        :return: the policy loss tensor
        """
        # Discounted returns, accumulated from the last reward backwards.
        running = 0
        returns = []
        for reward in reversed(policy.rewards):
            running = reward + self.gamma * running
            returns.append(running)
        returns.reverse()

        returns = torch.tensor(returns)
        # The standard deviation of a single value is NaN and would poison the loss,
        # so returns are only normalised when there is more than one of them.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 0.0001)

        policy_loss = self.method(policy, returns)

        if learn:
            optimizer.zero_grad()
            policy_loss.backward()
            optimizer.step()

        # Clear in place so that the policy object keeps the same list instances.
        del policy.rewards[:]
        del policy.saved_log_probs[:]

        return policy_loss

In [9]:
# === reinforce settings ===

# Hyper-parameters shared by the update functions below.
params = {
    # callable that performs the policy update; invoked as params['reinforce'](policy, optimizer, learn=...)
    'reinforce': ChooseREINFORCE(ChooseREINFORCE.basic_reinforce),
    # gamma / min_value / max_value are not read anywhere in the visible code -
    # presumably consumed by recnn.nn.value_update; verify. ChooseREINFORCE does not read 'gamma'.
    'gamma'      : 0.99,
    'min_value'  : -10,
    'max_value'  : 10,
    # the policy (and the target networks) are updated every `policy_step` steps
    'policy_step': 10,
    # interpolation factor passed to recnn.utils.soft_update
    'soft_tau'   : 0.001,
    
    # learning rates of the policy and value optimizers
    'policy_lr'  : 1e-5,
    'value_lr'   : 1e-5,
    # weight-init arguments passed to the recnn.nn.Critic constructors
    'actor_weight_init': 54e-2,
    'critic_weight_init': 6e-1,
}

# === end ===

In [10]:
# Networks. 1290 is the state width fed to the networks - presumably
# frame_size * (128-dim embedding + 1 rating) = 10 * 129; confirm against the embeddings file.
# NOTE(review): the target critic is initialised with 'actor_weight_init' while the online
# critic uses 'critic_weight_init' - looks like it may be a copy-paste slip; confirm.
nets = {
    'value_net': recnn.nn.Critic(1290, num_items, 2048, params['critic_weight_init']).to(cuda),
    'target_value_net': recnn.nn.Critic(1290, num_items, 2048, params['actor_weight_init']).to(cuda).eval(),

    'policy_net':  DiscreteActor(2048, 1290, num_items).to(cuda),
    'target_policy_net': DiscreteActor(2048, 1290, num_items).to(cuda).eval(),
}


# from good to bad: Ranger Radam Adam RMSprop
optimizer = {
    'value_optimizer': optim.Ranger(nets['value_net'].parameters(),
                                          lr=params['value_lr'], weight_decay=1e-2),

    'policy_optimizer': optim.Ranger(nets['policy_net'].parameters(),
                                           lr=params['policy_lr'], weight_decay=1e-5)
}


# Loss history shared with the plotter: one list per metric, for the test and train phases.
loss = {
    'test': {'value': [], 'policy': [], 'step': []},
    'train': {'value': [], 'policy': [], 'step': []}
    }

debug = {}

# TensorBoard writer passed to reinforce_update by the training loop as `writer`.
# It used to be assigned to `reinforce.writer`, but no `reinforce` object exists in this
# notebook, which raised NameError here and left `writer` undefined for the training loop.
writer = SummaryWriter(log_dir='../../runs/Reinforce{}/'.format(strftime("%H_%M", gmtime())))
plotter = recnn.utils.Plotter(loss, [['value', 'policy']],)
device = cuda

In [11]:
def reinforce_update(batch, params, nets, optimizer,
                     device=torch.device('cpu'),
                     debug=None, writer=recnn.utils.DummyWriter(),
                     learn=False, step=-1):
    """
    One actor-critic REINFORCE step on ``batch``.

    Every call samples an action from the policy, scores the resulting action distribution
    with the critic, stores the mean critic value as the policy's reward and updates the
    critic through ``recnn.nn.value_update``. Every ``params['policy_step']`` steps
    (``step > 0``) the policy itself is updated, the target networks are soft-updated and
    the losses are written to ``writer``.

    :return: ``{'value': ..., 'policy': ..., 'step': ...}`` on policy-update steps, otherwise ``None``
    """
    policy_net = nets['policy_net']

    # Only the state is used; the reward comes from the critic rather than from the batch.
    state, _action, _reward, _next_state, _done = recnn.data.get_base_batch(batch)

    _sampled_action, action_probs = policy_net.select_action(state)
    critic_estimate = nets['value_net'](state, action_probs).detach()
    policy_net.rewards.append(critic_estimate.mean())

    value_loss = recnn.nn.value_update(batch, params, nets, optimizer,
                                       writer=writer,
                                       device=device,
                                       debug=debug, learn=learn, step=step)

    # Guard clause: nothing more to do between policy updates.
    if not (step % params['policy_step'] == 0 and step > 0):
        return None

    policy_loss = params['reinforce'](policy_net, optimizer['policy_optimizer'], learn=learn)

    # Make sure the episode buffers are empty before the next rollout window.
    del policy_net.rewards[:]
    del policy_net.saved_log_probs[:]

    print('step: ', step, '| value:', value_loss.item(), '| policy', policy_loss.item())

    for online, target in (('value_net', 'target_value_net'),
                           ('policy_net', 'target_policy_net')):
        recnn.utils.soft_update(nets[online], nets[target], soft_tau=params['soft_tau'])

    losses = {'value': value_loss.item(),
              'policy': policy_loss.item(),
              'step': step}

    phase = 'train' if learn else 'test'
    recnn.utils.write_losses(writer, losses, kind=phase)

    return losses

In [12]:
# Training loop: one reinforce_update per batch. The update returns a losses dict only
# on policy-update steps (None otherwise), and only those are sent to the plotter.
step = 0
for epoch in range(n_epochs):
    for batch in tqdm(env.train_dataloader):
        # Named `step_losses` so that the module-level `loss` history dict, which the
        # plotter was constructed with, is not overwritten by the loop.
        step_losses = reinforce_update(batch, params, nets, optimizer,
                     writer=writer,
                     device=device,
                     debug=debug, learn=True, step=step)
        if step_losses:
            plotter.log_losses(step_losses)
        step += 1
        if step % plot_every == 0:
            # clear_output(True)
            print('step', step)
            #test_loss = run_tests()
            #plotter.log_losses(test_loss, test=True)
            #plotter.plot_loss()
        #if step > 1000:
        #    pass
        #    assert False

NameError: name 'env' is not defined