# REINFORCE

The following code contains an implementation of the REINFORCE algorithm, **without Off Policy Correction, LSTM state encoder, and Noise Contrastive Estimation**. Look for these in other notebooks.

Also, I am not Google staff, and unlike the paper's authors, I have no access to online feedback on the recommendations.

**I use actor-critic for reward assigning.** In a real-world scenario that would be done through interactive user feedback, but here I use a neural network (critic) that aims to emulate it.

In [1]:
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torch.distributions import Categorical

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline


# == recnn ==
import sys
sys.path.append("../../")
import recnn

# Device handle; every network created further down is moved onto it with `.to(cuda)`.
cuda = torch.device('cuda')

# ---
# Notebook-wide settings: frame/batch sizes are handed to the environment constructor,
# the remaining values are read by the training loop.
frame_size = 10
batch_size = 10
n_epochs   = 100
plot_every = 30
step       = 0
num_items    = 5000 # n items to recommend. Can be adjusted for your vram 
# --- 
# Register tqdm's progress-bar integration with pandas.
tqdm.pandas()


# Cosmetic only: apply a jupyterthemes plotting style to matplotlib figures.
from jupyterthemes import jtplot
jtplot.style(theme='grade3')

## I will drop low-frequency items because the full item set doesn't fit into my video card's VRAM

In [2]:

def prepare_dataset(df, key_to_id, frame_size, env, sort_users=False):
    """
    Restrict the dataset and the environment's embeddings to the most rated movies.

    The `num_items` most frequently rated movies (module-level setting) are kept;
    ratings of every other movie are dropped from `df` IN PLACE, the matching
    entries are deleted from `env.movie_embeddings_key_dict`, and
    `env.embeddings`, `env.key_to_id` and `env.id_to_key` are rebuilt from the
    remaining entries.

    Args:
        df: ratings DataFrame with a 'movieId' column. Mutated in place.
        key_to_id: accepted for interface compatibility; the freshly rebuilt
            `env.key_to_id` is used instead.
        frame_size: forwarded to `recnn.data.prepare_dataset`.
        env: environment object exposing `embeddings` and
            `movie_embeddings_key_dict`.
        sort_users: forwarded to `recnn.data.prepare_dataset`.

    Returns:
        Whatever `recnn.data.prepare_dataset` returns for the reduced data.
    """
    # Count ratings per movie once; ascending order puts the most popular
    # movies at the tail so they can be selected with a negative slice.
    value_counts = df['movieId'].value_counts().sort_values()
    print('counted!')

    # `.iloc` makes the positional intent explicit (the index holds movie ids,
    # which are integers themselves).
    to_remove = value_counts.iloc[:-num_items].index
    to_keep = value_counts.iloc[-num_items:].index
    num_total = len(value_counts)
    num_removed = len(to_remove)

    to_remove_indices = df[df['movieId'].isin(to_remove)].index
    df.drop(to_remove_indices, inplace=True)
    print('dropped!')

    print('before', env.embeddings.size(), len(env.movie_embeddings_key_dict))
    kept_ids = set(to_keep)
    # Iterate over a snapshot of the keys because the dict is modified in the loop.
    for movie_id in list(env.movie_embeddings_key_dict.keys()):
        if movie_id not in kept_ids:
            del env.movie_embeddings_key_dict[movie_id]

    env.embeddings, env.key_to_id, env.id_to_key = recnn.data.utils.make_items_tensor(env.movie_embeddings_key_dict)

    print('after', env.embeddings.size(), len(env.movie_embeddings_key_dict))
    print('embeddings automatically updated')
    print('action space is reduced to {} - {} = {}'.format(num_total, num_removed,
                                                           num_total - num_removed))

    return recnn.data.prepare_dataset(df, env.key_to_id, frame_size, sort_users=sort_users)


In [3]:
def batch_contstate_discaction(batch, item_embeddings_tensor, frame_size, num_items, *args, **kwargs):
    """
    Embed Batch: continuous state discrete action

    Items are looked up in `item_embeddings_tensor`; the state is the flattened
    embeddings of every frame position but the last, concatenated with the
    matching ratings, and the next state is the same window shifted by one.
    The action is the last item of each row as a one-hot vector of width
    `num_items`, and the reward is the last rating of each row.

    Returns a dict with keys 'state', 'action', 'reward', 'next_state', 'done'
    and 'meta' (users and sizes as returned by `get_irsu`).
    """
    from recnn.data.utils import get_irsu

    item_ids, rating_frames, user_sizes, user_ids = get_irsu(batch)
    n_rows = rating_frames.size(0)
    embedded = item_embeddings_tensor[item_ids.long()]

    # Current window drops the last position, next window drops the first one.
    state = torch.cat([embedded[:, :-1, :].view(n_rows, -1), rating_frames[:, :-1]], 1)
    next_state = torch.cat([embedded[:, 1:, :].view(n_rows, -1), rating_frames[:, 1:]], 1)

    # One-hot encode the last item id of every row.
    action = item_ids[:, -1]
    one_hot_action = torch.zeros(action.size(0), num_items)
    one_hot_action.scatter_(1, action.view(-1, 1), 1)

    # Flag the row at the end of each cumulative (size - frame_size) segment.
    done = torch.zeros(n_rows)
    done[torch.cumsum(user_sizes - frame_size, dim=0) - 1] = 1

    return {
        'state': state,
        'action': one_hot_action,
        'reward': rating_frames[:, -1],
        'next_state': next_state,
        'done': done,
        'meta': {'users': user_ids, 'sizes': user_sizes},
    }

def embed_batch(batch, item_embeddings_tensor, *args, **kwargs):
    """
    Adapter around `batch_contstate_discaction` that fills in the notebook-level
    `frame_size` and `num_items` settings.

    Extra positional and keyword arguments are accepted but ignored.
    """
    return batch_contstate_discaction(batch, item_embeddings_tensor, frame_size=frame_size, num_items=num_items)

In [4]:
# embeddings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL
# Build the frame environment from the embeddings pickle and the ratings CSV,
# plugging in the custom dataset-preparation and batch-embedding hooks defined above.
env = recnn.data.env.FrameEnv('../../data/embeddings/ml20_pca128.pkl',
                              '../../data/ml-20m/ratings.csv', frame_size, batch_size,
                              embed_batch=embed_batch, prepare_dataset=prepare_dataset,
                              num_workers = 0)

counted!
dropped!
before torch.Size([27278, 128]) 27278
after torch.Size([5000, 128]) 5000
embeddings automatically updated
action space is reduced to 26744 - 21744 = 5000


HBox(children=(IntProgress(value=0, max=18946308), HTML(value='')))




HBox(children=(IntProgress(value=0, max=18946308), HTML(value='')))




HBox(children=(IntProgress(value=0, max=138493), HTML(value='')))




In [5]:
class DiscretePolicy(nn.Module):
    """
    Two-layer MLP policy over a discrete action set.

    The network maps a state vector to a probability distribution over
    `num_actions` actions. `select_action` samples from that distribution and
    stores the log-probability of the sample in `saved_log_probs`; `rewards` is
    a plain list that the training code fills in alongside it.
    """

    def __init__(self, hidden_size, num_inputs, num_actions):
        super(DiscretePolicy, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, num_actions)

        # Per-step buffers consumed (and cleared) by the policy-gradient update.
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, inputs):
        """Return action probabilities; the last axis of the output sums to 1."""
        x = F.relu(self.linear1(inputs))
        action_scores = self.linear2(x)
        # Normalise explicitly over the action axis. Calling softmax without
        # `dim` is deprecated and would pick axis 0 for 3-D inputs.
        return F.softmax(action_scores, dim=-1)

    def select_action(self, state):
        """
        Sample one action per input row.

        Side effect: appends the log-probability of the sampled action(s) to
        `self.saved_log_probs`.

        Returns:
            (action, probs): the sampled action indices and the full
            probability tensor they were drawn from.
        """
        probs = self.forward(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs.append(m.log_prob(action))
        return action, probs

### Because I do not have a dynamic environment, I will also include a critic. If you have a real, non-static environment, you can do without the critic.

In [6]:
# === reinforce settings ===

# Hyper-parameters shared by the update functions and the network/optimizer setup.
params = {
    # Discount factor and value bounds; presumably used for the bootstrapped
    # value target and its clamping range — TODO confirm against recnn.nn.value_update.
    'gamma'      : 0.99,
    'min_value'  : -10,
    'max_value'  : 10,
    # The policy is updated (and target networks soft-updated) every `policy_step` steps.
    'policy_step': 10,
    # Interpolation factor handed to recnn.utils.soft_update.
    'soft_tau'   : 0.001,
    
    # Learning rates for the policy and value optimizers.
    'policy_lr'  : 1e-5,
    'value_lr'   : 1e-5,
    # Weight-initialisation settings; interpreted by the recnn network constructors
    # (exact semantics live in recnn — TODO confirm).
    'actor_weight_init': 54e-2,
    'critic_weight_init': 6e-1,
}

# === end ===

In [7]:
# Online networks are trained directly; the `target_*` copies are kept in eval
# mode and only change through soft updates.
# 1290 is the state width; presumably frame_size * (128 + 1) for frame_size = 10
# and 128-d item embeddings plus one rating per frame position — TODO confirm.
nets = {
    'value_net': recnn.nn.Critic(1290, num_items, 2048, params['critic_weight_init']).to(cuda),
    # Both critics use the critic init setting (the target critic previously
    # received the actor setting by mistake).
    'target_value_net': recnn.nn.Critic(1290, num_items, 2048, params['critic_weight_init']).to(cuda).eval(),

    'policy_net':  DiscretePolicy(2048, 1290, num_items).to(cuda),
    'target_policy_net': DiscretePolicy(2048, 1290, num_items).to(cuda).eval(),
}

# Start every target network as an exact copy of its online counterpart.
# Without this the targets begin from unrelated random weights and, with a
# small soft_tau, stay unrelated for a long time.
nets['target_value_net'].load_state_dict(nets['value_net'].state_dict())
nets['target_policy_net'].load_state_dict(nets['policy_net'].state_dict())


# from good to bad: Ranger Radam Adam RMSprop
optimizer = {
    'value_optimizer': recnn.optim.Ranger(nets['value_net'].parameters(),
                                          lr=params['value_lr'], weight_decay=1e-2),

    'policy_optimizer': recnn.optim.Ranger(nets['policy_net'].parameters(),
                                           lr=params['policy_lr'], weight_decay=1e-5)
}


# Loss history containers handed to the plotter below.
loss = {
    'test': {'value': [], 'policy': [], 'step': []},
    'train': {'value': [], 'policy': [], 'step': []}
    }

debug = {}

writer = SummaryWriter(log_dir='../../runs')
plotter = recnn.utils.Plotter(loss, [['value', 'policy']],)
device = cuda

In [8]:
def td_update(batch, learn=True):
    """
    One temporal-difference step for the critic in the module-level `nets`.

    The target is `reward + (1 - done) * gamma * Q_target(next_state, pi_target(next_state))`,
    clamped to `[min_value, max_value]`. Discount and bounds are read from the
    module-level `params` dict so they stay in sync with the rest of the
    notebook instead of being repeated as literals.

    Args:
        batch: raw batch, unpacked with `recnn.data.get_base_batch`.
        learn: when True, back-propagate the loss and step `value_optimizer`.

    Returns:
        The mean squared TD error as a tensor.
    """
    state, action, reward, next_state, done = recnn.data.get_base_batch(batch)

    # Value Learning

    # The target is a constant for the regression below; no graph is needed.
    with torch.no_grad():
        next_action = nets['target_policy_net'](next_state)
        target_value = nets['target_value_net'](next_state, next_action)
        expected_value = reward + (1.0 - done) * params['gamma'] * target_value
        expected_value = torch.clamp(expected_value, params['min_value'], params['max_value'])

    value = nets['value_net'](state, action)
    value_loss = torch.pow(value - expected_value, 2).mean()

    if learn:
        optimizer['value_optimizer'].zero_grad()
        value_loss.backward()
        optimizer['value_optimizer'].step()

    return value_loss

In [11]:
class REINFORCE():
    """
    Namespace for the REINFORCE policy-gradient update.

    `call` is the entry point: it turns the rewards stored on the policy into
    normalised discounted returns, evaluates a loss with the chosen algorithm,
    steps the optimizer and clears the policy's buffers.
    """

    @staticmethod
    def reinforce(policy, returns, *args, **kwargs):
        """
        Plain REINFORCE loss: sum over steps of `-log_prob * return`.

        Args:
            policy: object with a `saved_log_probs` list of tensors.
            returns: iterable of per-step returns aligned with `saved_log_probs`.

        Returns:
            Scalar loss tensor.
        """
        # reshape(-1) lets 0-dim log-probs be concatenated alongside per-batch vectors.
        policy_loss = [
            (-log_prob * R).reshape(-1)
            for log_prob, R in zip(policy.saved_log_probs, returns)
        ]
        return torch.cat(policy_loss).sum()

    @staticmethod
    def reinforce_with_correction():
        """Placeholder for the off-policy-corrected variant (not implemented here)."""
        # NotImplemented is a comparison sentinel, not an exception class.
        raise NotImplementedError

    @staticmethod
    def call(policy, optimizer, algorithm=None, gamma=0.99):
        """
        Run one policy update from the rewards / log-probs stored on `policy`.

        Args:
            policy: object with `rewards` and `saved_log_probs` lists; both are
                emptied before returning.
            optimizer: optimizer over the policy parameters.
            algorithm: callable `(policy, returns) -> loss`; defaults to
                `REINFORCE.reinforce`.
            gamma: discount factor applied when accumulating returns.

        Returns:
            The loss tensor produced by `algorithm`.

        Raises:
            ValueError: if no rewards have been stored.
        """
        if algorithm is None:
            algorithm = REINFORCE.reinforce

        if not policy.rewards:
            raise ValueError('REINFORCE.call needs at least one stored reward')

        # Accumulate discounted returns from the last step backwards, then
        # restore chronological order.
        running = 0.0
        discounted = []
        for r in reversed(policy.rewards):
            running = r + gamma * running
            discounted.append(running)
        discounted.reverse()

        # Stack on whatever device the rewards live on; returns are constants
        # with respect to the policy parameters.
        returns = torch.stack([
            torch.as_tensor(g, dtype=torch.float32).reshape(()) for g in discounted
        ]).detach()

        # Standardise only when there is a spread to measure; the standard
        # deviation of a single value is NaN.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 0.0001)

        policy_loss = algorithm(policy, returns)

        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

        del policy.rewards[:]
        del policy.saved_log_probs[:]

        return policy_loss

In [14]:
def reinforce_update(batch, params, nets, optimizer,
                     writer=recnn.utils.DummyWriter(),
                     device=torch.device('cpu'),
                     debug=None, learn=False, step=-1):
    """
    One training step: critic update every call, policy update every
    `params['policy_step']` steps.

    On every call the policy samples an action for the batch, the critic's
    mean estimate for the predicted action probabilities is stored as that
    step's reward, and the critic is updated through `recnn.nn.value_update`.
    On policy steps the stored rewards are consumed by `REINFORCE.call`, both
    target networks are soft-updated and the losses are written to `writer`.

    Args:
        batch: raw batch, unpacked with `recnn.data.get_base_batch`.
        params: dict providing at least 'policy_step' and 'soft_tau'.
        nets: dict with 'policy_net', 'value_net' and their 'target_*' copies.
        optimizer: dict with at least 'policy_optimizer'.
        writer: summary writer that receives the losses.
        device: forwarded to `recnn.nn.value_update`.
        debug: dict forwarded to `recnn.nn.value_update`; a fresh dict is
            created per call when omitted, so calls do not share state.
        learn: only selects the 'train' / 'test' tag used when writing losses.
            NOTE: the critic and policy updates run regardless of this flag.
        step: global step counter.

    Returns:
        A dict with 'value', 'policy' and 'step' on policy-update steps,
        otherwise None.
    """
    if debug is None:
        debug = {}

    state, action, reward, next_state, done = recnn.data.get_base_batch(batch)

    # Sampling also records the log-probability on the policy network.
    predicted_action, predicted_probs = nets['policy_net'].select_action(state)
    # The critic stands in for user feedback: its mean estimate is the step reward.
    reward = nets['value_net'](state, predicted_probs).detach()
    nets['policy_net'].rewards.append(reward.mean())

    value_loss = recnn.nn.value_update(batch, params, nets, optimizer,
                     writer=writer,
                     device=device,
                     debug=debug, learn=True, step=step)

    if step % params['policy_step'] != 0 or step <= 0:
        return None

    # REINFORCE.call empties the policy's reward / log-prob buffers itself.
    policy_loss = REINFORCE.call(nets['policy_net'], optimizer['policy_optimizer'])

    value_loss_value = value_loss.item()
    policy_loss_value = policy_loss.item()
    print('step: ', step, '| value:', value_loss_value, '| policy', policy_loss_value)

    recnn.utils.soft_update(nets['value_net'], nets['target_value_net'], soft_tau=params['soft_tau'])
    recnn.utils.soft_update(nets['policy_net'], nets['target_policy_net'], soft_tau=params['soft_tau'])

    losses = {'value': value_loss_value,
              'policy': policy_loss_value,
              'step': step}

    recnn.utils.write_losses(writer, losses, kind='train' if learn else 'test')

    return losses

In [None]:
# Main training loop. The per-step result is kept in `step_losses` so the
# module-level `loss` history dict (the one handed to the plotter) is not
# rebound by running this cell.
step = 0
for epoch in range(n_epochs):
    for batch in tqdm(env.train_dataloader):
        step_losses = reinforce_update(batch, params, nets, optimizer,
                     writer=writer,
                     device=device,
                     debug=debug, learn=True, step=step)
        # reinforce_update only returns losses on policy-update steps.
        if step_losses:
            plotter.log_losses(step_losses)
        step += 1
        if step % plot_every == 0:
            # clear_output(True)
            print('step', step)
            # Test evaluation and plotting are currently disabled:
            #test_loss = run_tests()
            #plotter.log_losses(test_loss, test=True)
            #plotter.plot_loss()

HBox(children=(IntProgress(value=0, max=13155), HTML(value='')))

  from ipykernel import kernelapp as app


step:  10 | value: 50.69260025024414 | policy -54447.9140625
step:  20 | value: 39.12702941894531 | policy -1322.011474609375
step 30
step:  30 | value: 32.70941162109375 | policy 15132.70703125
step:  40 | value: 24.22550392150879 | policy 22501.21484375
step:  50 | value: 23.58995246887207 | policy 13874.126953125
step 60
step:  60 | value: 18.77687644958496 | policy -18705.01953125
step:  70 | value: 18.473474502563477 | policy -6204.1689453125
step:  80 | value: 18.67662811279297 | policy -11902.884765625
step 90
step:  90 | value: 17.839340209960938 | policy -18216.49609375
step:  100 | value: 17.994937896728516 | policy 12965.2548828125
step:  110 | value: 18.77390480041504 | policy 897.503662109375
step 120
step:  120 | value: 17.465091705322266 | policy 18707.359375
step:  130 | value: 17.891725540161133 | policy 12005.7431640625
step:  140 | value: 17.914501190185547 | policy -7836.19140625
step 150
step:  150 | value: 17.158405303955078 | policy 16218.330078125
step:  160 | v

step:  1240 | value: 7.234259605407715 | policy -8767.65234375
step:  1250 | value: 7.516167163848877 | policy -1342.419189453125
step 1260
step:  1260 | value: 7.649448871612549 | policy 2115.92431640625
step:  1270 | value: 7.189054489135742 | policy 11169.875
step:  1280 | value: 6.84197998046875 | policy -20524.76171875
step 1290
step:  1290 | value: 7.2946906089782715 | policy -8359.2216796875
step:  1300 | value: 6.865091800689697 | policy -1599.4044189453125
step:  1310 | value: 6.977664470672607 | policy -3350.458984375
step 1320
step:  1320 | value: 6.87705135345459 | policy -18109.0390625
step:  1330 | value: 6.453563213348389 | policy 141.71835327148438
step:  1340 | value: 6.6365509033203125 | policy -34606.1953125
step 1350
step:  1350 | value: 6.6394524574279785 | policy 8188.4697265625
step:  1360 | value: 6.757900238037109 | policy 15928.171875
step:  1370 | value: 6.947423458099365 | policy -7407.06640625
step 1380
step:  1380 | value: 6.174886226654053 | policy -1499.

step:  2470 | value: 3.309230089187622 | policy 23788.26953125
step:  2480 | value: 3.086838722229004 | policy -13578.12890625
step 2490
step:  2490 | value: 2.963146448135376 | policy 10986.34765625
step:  2500 | value: 3.2121074199676514 | policy 39080.453125
step:  2510 | value: 3.084400177001953 | policy 11100.427734375
step 2520
step:  2520 | value: 3.132249355316162 | policy 5319.16259765625
step:  2530 | value: 3.0717625617980957 | policy 1523.568359375
step:  2540 | value: 3.2605490684509277 | policy -16160.1259765625
step 2550
step:  2550 | value: 2.91607403755188 | policy -16951.5546875
step:  2560 | value: 3.175579786300659 | policy -19928.798828125
step:  2570 | value: 3.106559991836548 | policy -30896.44921875
step 2580
step:  2580 | value: 3.1524198055267334 | policy 1927.8369140625
step:  2590 | value: 2.9477086067199707 | policy 17286.58203125
step:  2600 | value: 2.92819881439209 | policy 6001.7998046875
step 2610
step:  2610 | value: 2.9511878490448 | policy 6979.4633

In [None]:
# Show the dictionary of networks as this cell's output.
nets