In [34]:
import torch
import torch.nn as nn

import numpy as np
from config import max_guesses, guess_embedding_size, lstm_hidden_size,feedback_embedding_size, max_feedback
                    


class Policy(nn.Module):
    """Recurrent policy that turns a game history into per-turn guess distributions.

    Every turn is encoded as the concatenation of one guess embedding and two
    feedback embeddings. The sequence of turns is run through an LSTM followed
    by a linear head, and a softmax over the last dimension is returned.
    """

    def __init__(self):
        super().__init__()
        # Layers are built in a fixed order so seeded initialisation stays stable.
        self.guess_embed = nn.Embedding(max_guesses, guess_embedding_size)
        # pegs in right place
        # NOTE(review): sized max_guesses + 1 while `max_feedback` is imported
        # but unused in the visible cells - confirm the intended vocabulary size.
        self.pos_feedback_embed = nn.Embedding(max_guesses + 1, feedback_embedding_size)
        # pegs in wrong place.
        # NOTE(review): vocabulary size is hard-coded to 5 - confirm against config.
        self.neg_feedback_embed = nn.Embedding(5, feedback_embedding_size)
        turn_width = guess_embedding_size + 2 * feedback_embedding_size
        self.lstm = nn.LSTM(input_size=turn_width, hidden_size=lstm_hidden_size)
        self.fc = nn.Linear(lstm_hidden_size, max_guesses)

    def forward(self, game_state):
        """Score every turn of a game.

        Args:
            game_state: sequence of turns; each turn is indexed as
                ``turn[0]`` (guess index) and ``turn[1]`` (a pair whose first
                item feeds ``pos_feedback_embed`` and second item feeds
                ``neg_feedback_embed``).

        Returns:
            Tensor with one softmax-normalised row per turn. Despite the
            variable name used by callers, these are probabilities, not raw
            logits. A singleton dimension is inserted at position 1 before the
            LSTM, so the result has a size-1 axis there.
        """
        turn_vectors = []
        for turn in game_state:
            guess, feedback = turn[0], turn[1]
            pieces = [
                self.guess_embed(torch.tensor(guess)),
                self.pos_feedback_embed(torch.tensor(feedback[0])),
                self.neg_feedback_embed(torch.tensor(feedback[1])),
            ]
            turn_vectors.append(torch.cat(pieces, -1))

        # Stack turns along dim 0, then add the singleton axis the LSTM sees at dim 1.
        sequence = torch.stack(turn_vectors).unsqueeze(1)
        hidden_states, _ = self.lstm(sequence)

        scores = self.fc(hidden_states)
        return nn.functional.softmax(scores, -1)
                                                 

In [36]:
pol = Policy()

# Two played turns, each given as (guess_index, (feedback_a, feedback_b)); the
# Policy embeds feedback_a as "pegs in right place" and feedback_b as "pegs in
# wrong place".
game_state = [(10, (1,1)), (20,(2,2))]

# Call the module itself instead of .forward() so that nn.Module.__call__ runs
# any registered hooks. The result is softmax output, kept under the name
# `logits` because later cells refer to it.
logits = pol(game_state)

In [37]:
# Shape check: one row per turn, a singleton axis from unsqueeze(1), and one
# column per output unit of the final linear layer.
logits.shape

torch.Size([2, 1, 1296])

In [3]:
# NOTE(review): `Episode` is neither defined nor imported in the visible cells,
# and this cell's execution count (In [3]) is out of sequence with its
# neighbours - confirm it still works after Restart Kernel -> Run All.
ep = Episode(pol, '1223')

In [39]:
import itertools
import random
from collections import Counter

In [40]:
# Quick check of itertools.product: repeat=2 over range(2) enumerates every
# ordered pair, which is how the candidate codes are built below.
list(itertools.product(range(2), repeat=2))

[(0, 0), (0, 1), (1, 0), (1, 1)]

In [66]:
def score(p, q):
    """Compare two codes and return ``(hits, misses)``.

    ``hits`` counts positions where both sequences hold the same symbol.
    ``misses`` counts symbols the two sequences share (respecting
    multiplicity) that are not already counted as hits.
    """
    exact = 0
    for left, right in zip(p, q):
        if left == right:
            exact += 1

    # Multiset intersection gives the total number of shared symbols.
    shared = Counter(p) & Counter(q)
    return exact, sum(shared.values()) - exact


def generate_best_pattern(code='1222', rng=None):
    """Play one game against ``code`` and return the sequence of guesses.

    Each guess is drawn at random from the codes still consistent with all
    feedback received so far. The returned sequence always ends with ``code``
    and contains it exactly once.

    Args:
        code: secret code, a sequence of four symbols from ``'012345'``
            (typically a 4-character string).
        rng: optional object with a ``choice`` method (for example a
            ``random.Random`` instance) used to pick guesses; defaults to the
            module-level ``random`` generator.

    Returns:
        List of 4-character strings, the guesses in the order they were made.

    Raises:
        ValueError: if ``code`` is not four symbols drawn from ``'012345'``.
            Such a code can never be reached, so the search would otherwise
            run out of candidates or settle on a wrong answer.
    """
    symbols = '012345'
    length = 4

    # Candidates are tuples of characters, so normalise the secret to the same
    # form; comparing a tuple against the raw string would never be equal.
    secret = tuple(code)
    if len(secret) != length or not set(secret) <= set(symbols):
        raise ValueError(
            "code must be %d symbols drawn from %r, got %r" % (length, symbols, code)
        )

    chooser = random if rng is None else rng
    guesses = []
    candidates = list(itertools.product(symbols, repeat=length))

    while len(candidates) > 1:
        guess = chooser.choice(candidates)
        guesses.append(guess)
        if guess == secret:
            # Solved by this guess; nothing left to deduce.
            break
        feedback = score(secret, guess)
        # Keep only the codes that would have produced the same feedback.
        candidates = [c for c in candidates if score(c, guess) == feedback]
    else:
        # Exactly one consistent code remains and it has not been guessed yet.
        guesses.append(candidates[0])

    return [''.join(g) for g in guesses]

def _number_from_index(index):
        assert(0 <= index < 6**4)
        digits = []
        while index > 0:
            digits.append(str(index % 6))
            index = index // 6
        return "".join(reversed(digits)).zfill(4)

def gen_patterns(num):
    """Draw ``num`` random code indices and return one guess sequence per code.

    Indices are sampled with ``np.random.randint`` from ``[0, 6**4)``, turned
    into code strings by ``_number_from_index``, and each code is passed to
    ``generate_best_pattern``.
    """
    sampled_indices = np.random.randint(low=0, high=6**4, size=num)
    return [
        generate_best_pattern(_number_from_index(idx))
        for idx in sampled_indices
    ]

# Smoke test: sample 20 random codes and show the guess sequence produced for
# each. Output varies between runs because no random seed is set.
gen_patterns(20)