In [22]:
%%writefile pytorch_policy.py
import torch
import torch.nn as nn
from episode import Episode
import numpy as np
from config import max_guesses, guess_embedding_size, lstm_hidden_size,feedback_embedding_size, max_feedback
                    

# Seed NumPy's global RNG at import time so NumPy-based randomness is repeatable.
# NOTE(review): this does not seed torch's RNG, so the initial weights of the
# nn.Embedding / nn.LSTMCell / nn.Linear layers below are not fixed by this call.
np.random.seed(123)

class Policy(nn.Module):
    """Recurrent policy over a history of (guess, feedback) pairs.

    Each (guess, feedback) pair is embedded, the two embeddings are
    concatenated, and the result is fed through an ``nn.LSTMCell`` one step
    at a time. The final hidden state is projected to ``max_guesses`` scores
    and normalised with a softmax.
    """

    def __init__(self):
        """Build the embedding tables, the LSTM cell and the output layer.

        All sizes come from the constants imported from ``config``.
        """
        super().__init__()
        self.guess_embed = nn.Embedding(max_guesses + 1, guess_embedding_size)
        # NOTE(review): the feedback table is sized with max_guesses + 1 while
        # the imported max_feedback is unused -- confirm this is intended.
        # Left unchanged here because resizing would invalidate saved weights.
        self.feedback_embed = nn.Embedding(max_guesses + 1, feedback_embedding_size)
        self.lstm_cell = nn.LSTMCell(
            input_size=guess_embedding_size + feedback_embedding_size,
            hidden_size=lstm_hidden_size,
        )
        self.fc = nn.Linear(lstm_hidden_size, max_guesses)

    def forward(self, game_state):
        """Return a probability distribution over the next guess.

        Args:
            game_state: iterable of ``(guess, feedback)`` pairs, oldest first.
                Each element must be convertible to an integer index tensor
                accepted by the corresponding embedding table.

        Returns:
            Tensor of shape ``(1, out_features of self.fc)`` whose entries sum
            to 1. For an empty ``game_state`` the distribution is computed
            from the all-zero initial hidden state.
        """
        device = self.fc.weight.device

        # Explicit zero states: this is what nn.LSTMCell uses when no state is
        # passed, and it also gives a well-defined output for an empty history.
        hidden = self.fc.weight.new_zeros(1, self.lstm_cell.hidden_size)
        cell_state = torch.zeros_like(hidden)

        for guess, feedback in game_state:
            guess_embedded = self.guess_embed(torch.as_tensor(guess, device=device))
            feedback_embedded = self.feedback_embed(
                torch.as_tensor(feedback, device=device)
            )
            # LSTMCell expects input of shape (batch, input_size); batch is 1.
            step_input = torch.cat(
                [guess_embedded, feedback_embedded], dim=-1
            ).reshape(1, -1)
            # Carry BOTH recurrent states forward; the cell state must be the
            # one returned by the previous step, not the step input.
            hidden, cell_state = self.lstm_cell(step_input, (hidden, cell_state))

        logits = self.fc(hidden)
        # Explicit dim avoids the deprecated implicit-dimension softmax.
        return nn.functional.softmax(logits, dim=-1)
                                                 

Writing pytorch_policy.py
