In [1]:
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Linear, Softmax, Dropout, LeakyReLU, Module, Tanh, Flatten, CrossEntropyLoss

from game_core import Wordle

In [2]:
class Guesser(Module):
    '''
    Feed-forward network mapping an encoded board to per-position letter logits.

    The input is flattened to ``26 * (n + 1)`` features per batch item, passed
    through one 2048-unit hidden layer, and projected to ``26 * n`` logits that
    are reshaped to ``(batch, n, 26)``: one 26-way score vector per position.

    n: number of output letter positions. Each input batch item must contain
       exactly ``(n + 1) * 26`` values.
    '''

    def __init__(self, n: int) -> None:
        super().__init__()

        self.n = n
        self.in_feature = Linear(26 * (n + 1), 2048)
        self.leakyRelu = LeakyReLU()
        self.tanh = Tanh()
        self.dropout = Dropout()
        self.out_feature = Linear(2048, 26 * n)

        self.loss = CrossEntropyLoss()

    def forward(self, x: Tensor, label: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        '''
        Run the network and, when a label is given, compute the training loss.

        x: board tensor with ``(n + 1) * 26`` values per batch item
           (batch dimension first).
        label: optional target, either
           - class indices of shape ``(batch, n)`` (letter index 0..25 per
             position), or
           - float probabilities / one-hot of shape ``(batch, n, 26)``.

        Returns ``(logits, loss)``. ``logits`` has shape ``(batch, n, 26)``;
        ``loss`` is a scalar tensor, or ``None`` when no label is given.
        '''
        x = self.encode_input(x)
        x = self.in_feature(x)
        x = self.leakyRelu(x)
        # NOTE(review): tanh applied after LeakyReLU squashes the hidden layer
        # into (-1, 1); kept as-is to preserve the existing model behaviour.
        x = self.tanh(x)
        x = self.dropout(x)
        x = self.out_feature(x)
        x = self.decode_output(x)

        loss = None
        if label is not None:
            # CrossEntropyLoss expects the class dimension at dim 1, but the
            # logits are (batch, n, 26) with the letters last. Move the letters
            # to dim 1 -> (batch, 26, n) so every position is scored over the
            # 26 letters rather than over the n positions.
            logits = x.transpose(1, 2)
            if label.dim() == x.dim():
                # Probability / one-hot targets laid out like the logits.
                target = label.transpose(1, 2)
            else:
                # Class-index targets (batch, n) already match (batch, 26, n).
                target = label
            loss = self.loss(logits, target)
        return x, loss

    def encode_input(self, x: Tensor) -> Tensor:
        '''Flatten every batch item to a ``(n + 1) * 26`` feature vector.'''
        batch_size = x.shape[0]
        x = x.reshape(batch_size, ((self.n + 1) * 26))
        return x

    def decode_output(self, x: Tensor) -> Tensor:
        '''Reshape flat logits to ``(batch, n, 26)``.'''
        batch_size = x.shape[0]
        x = x.reshape((batch_size, self.n, 26))
        return x

    def save(self, save_path):
        '''
        Write this model's state_dict to disk.

        save_path: path to the file
        e.g. './model/saved_model/save.bin'
        '''
        model_state_dict = self.state_dict()
        torch.save(model_state_dict, save_path)

    def load(self, save_path, map_location=None):
        '''
        Load a state_dict previously written by ``save`` into this model.

        save_path: path to the file
        e.g. './model/saved_model/save.bin'
        map_location: forwarded to ``torch.load``; e.g. pass ``'cpu'`` to load
        a checkpoint that was saved from GPU tensors on a CPU-only machine.
        '''
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        self.load_state_dict(torch.load(save_path, map_location=map_location))

In [3]:
# Seed torch's RNG so the randomly initialised Guesser weights (and any torch
# randomness used afterwards) are reproducible on Restart & Run All.
SEED = 0
TARGET_WORD = "scone"

torch.manual_seed(SEED)
wordle = Wordle(target=TARGET_WORD, training=True)
# Guesser emits one 26-way letter distribution per position of the word.
guesser = Guesser(n=len(TARGET_WORD))

In [4]:
# Stack the single board into a new leading batch dimension for the model.
x = torch.stack((wordle.query(),))
# Guesser.encode_input flattens each batch item to (n + 1) * 26 features.
assert x[0].numel() == (guesser.n + 1) * 26, f"unexpected board shape {tuple(x.shape)}"
x.shape

torch.Size([1, 1, 6, 26])
