In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [25]:
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm
import pickle

# Target device for every model and tensor. Despite its name, this is the CPU device.
cuda = torch.device('cpu')

# frame_size: number of (movie, rating) pairs that make up one state.
# batch_size: number of histories per forward pass (only 1 for testing).
frame_size, batch_size = 10, 1

In [3]:
import json

# Movie embeddings keyed by movie id (presumably 128-d PCA vectors, per the file name).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('../data/infos_pca128.pytorch', 'rb') as f:
    movies = pickle.load(f)
# Raw movie metadata.
with open('../data/infos.json') as f:
    infos_web = json.load(f)

In [26]:
# Move every movie embedding onto the configured device, replacing each entry in place.
for movie_id, embedding in movies.items():
    movies[movie_id] = embedding.to(cuda)

In [27]:
class StateRepresentation(nn.Module):
    """Encode a window of rated movies into a fixed-size state vector.

    The movie embeddings of one window are flattened, concatenated with their
    ratings, and passed through ``Linear -> Tanh``.

    Args:
        n_frames: number of (movie, rating) pairs per state. Defaults to the
            notebook-level ``frame_size`` so ``StateRepresentation()`` keeps working.
        embed_size: size of a single movie embedding.
        state_size: size of the produced state vector.
    """

    def __init__(self, n_frames=None, embed_size=128, state_size=256):
        super(StateRepresentation, self).__init__()
        self.n_frames = frame_size if n_frames is None else n_frames
        self.embed_size = embed_size
        self.lin = nn.Sequential(
            # input: n_frames embeddings plus one rating per frame
            nn.Linear(self.n_frames * (embed_size + 1), state_size),
            nn.Tanh(),
        )

    def forward(self, info, ratings):
        """Compute the state for a batch of rating histories.

        Args:
            info: movie embeddings; any shape holding batch * n_frames * embed_size
                elements, e.g. batch x n_frames x embed_size or batch x (n_frames * embed_size).
            ratings: batch x n_frames tensor of ratings.

        Returns:
            batch x state_size tensor with values in [-1, 1] (tanh output).
        """
        # Take the batch size from the inputs rather than the global batch_size,
        # so batches larger than 1 work. reshape also accepts non-contiguous input.
        batch = ratings.size(0)
        info = info.reshape(batch, self.n_frames * self.embed_size)
        # batch x (n_frames * embed_size + n_frames)
        stacked = torch.cat([info, ratings], 1)
        return self.lin(stacked)

In [28]:
class Actor(nn.Module):
    """Policy network: encode the rating history into a state, then map it to an action.

    Args:
        num_inputs: size of the state vector produced by ``StateRepresentation``.
        num_actions: size of the action vector.
        hidden_size: width of both hidden layers.
        init_w: half-width of the uniform init range for the output layer.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        super(Actor, self).__init__()

        self.state_rep = StateRepresentation()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, num_actions)

        # Output layer starts with small weights/bias drawn from U(-init_w, init_w).
        out_layer = self.linear3
        for param in (out_layer.weight, out_layer.bias):
            param.data.uniform_(-init_w, init_w)

    def forward(self, info, rewards):
        """Return ``(state, action)``; each action entry is squashed to [-1, 1] by tanh."""
        state = self.state_rep(info, rewards)
        hidden = state
        for layer in (self.linear1, self.linear2):
            hidden = F.relu(layer(hidden))
        action = torch.tanh(self.linear3(hidden))
        return state, action

    def get_action(self, info, rewards):
        """Alias for ``forward``: returns the same ``(state, action)`` pair."""
        return self.forward(info, rewards)

In [29]:
class Critic(nn.Module):
    """Value network: score a (state, action) pair with a single scalar.

    Args:
        num_inputs: size of the state vector.
        num_actions: size of the action vector.
        hidden_size: width of both hidden layers.
        init_w: half-width of the uniform init range for the output layer.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        super(Critic, self).__init__()

        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)

        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        """Compute values for a batch of (state, action) pairs.

        Args:
            state: batch x num_inputs tensor.
            action: batch x num_actions tensor, possibly with extra singleton
                dimensions (e.g. batch x 1 x num_actions).

        Returns:
            batch x 1 tensor of values.
        """
        # Flatten the action per sample. A bare torch.squeeze(action) also removed
        # the batch dimension when batch == 1, which broke the concatenation below.
        action = action.reshape(state.size(0), -1)
        x = torch.cat([state, action], 1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return self.linear3(x)

In [30]:
# Build the critic and actor; the sizes must match the saved checkpoints
# (load_state_dict is strict by default).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
value_net  = Critic(256, 128, 320).to(cuda)
policy_net = Actor(256, 128, 192).to(cuda)
# Map checkpoint tensors onto the configured device instead of a hard-coded 'cpu'.
value_net.load_state_dict(torch.load("../models/value.pt", map_location=cuda))
policy_net.load_state_dict(torch.load("../models/policy.pt", map_location=cuda))
# Switch both nets to eval mode; the trailing ';' suppresses the module repr.
value_net.eval()
policy_net.eval();




In [31]:
# A fixed example history: movie ids and the ratings given to them (one per frame).
watched_ids = [1732, 172, 370, 1639, 1380, 2054, 471, 2502, 1625, 2001]
watched_ratings = torch.tensor(
    [4.0, -3.0, -3.0, 2.0, 3.0, -2.0, 3.0, 1.0, 0.0, -1.0],
    dtype=torch.float32, device=cuda,
).unsqueeze(0)  # 1 x frame_size
assert len(watched_ids) == watched_ratings.size(1) == frame_size, "history length must equal frame_size"
watched_infos = torch.cat([movies[i] for i in watched_ids]).unsqueeze(0)
# Pure inference: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    enc_state, action = policy_net(watched_infos, watched_ratings)

In [32]:
# Display the actor's output for the history above: a 1 x 128 tensor (num_actions=128),
# each entry squashed to [-1, 1] by tanh.
action

tensor([[-0.9985, -0.9969,  0.9973, -0.9970, -0.9986, -0.9968, -0.9970, -0.9993,
         -0.9966, -0.9953, -0.9994, -0.9990, -0.9992, -0.9925, -0.9973, -0.9991,
         -0.9986, -0.9967, -0.9962, -0.9967, -0.9971, -0.9966,  0.7486, -0.9983,
         -0.9938, -0.9993, -0.9943, -0.9988, -0.9951,  0.9997, -0.9969, -0.9994,
          0.9977, -0.9935, -0.9993,  0.9988, -0.9990, -0.9986,  0.9983,  0.5878,
         -0.9795, -0.9975,  0.0671, -0.9976, -0.9993, -0.9992, -0.9970, -0.9964,
          0.8853, -0.9967, -0.9966, -0.9948, -0.9954,  0.9823, -0.9990,  0.9873,
         -0.9997, -0.9993, -0.9930, -0.9982, -0.3362, -0.9938, -0.9922, -0.9955,
         -0.9971,  0.9997,  0.9988, -0.9993, -0.9989, -0.9710,  0.9949, -0.9990,
         -0.9954,  0.9978, -0.9994, -0.9944, -0.9970, -0.9937, -0.9984,  0.9967,
         -0.9955, -0.9990, -0.9992, -0.9989, -0.9993, -0.9974, -0.9963, -0.9993,
         -0.9911, -0.9990,  0.9983,  0.9981, -0.9991, -0.9993, -0.9993, -0.9991,
         -0.9981, -0.9991, -

In [24]:
# Stray out-of-order debug print removed; this cell can be deleted.
