In [1]:
import sys
sys.path.append("src")

In [2]:
import torch

from char_rnn.model import CharRNN
from char_rnn.infer import InferenceModule

In [3]:
with open("data/anna.txt", "r") as f:
    text = f.read()

tokens = tuple(set(text))

int2char = {i: ch for i, ch in enumerate(tokens)}
char2int = {ch: i for i, ch in int2char.items()}

In [4]:
model = CharRNN(tokens)
model

CharRNN(
  (lstm): LSTM(83, 256, num_layers=2, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=256, out_features=83, bias=True)
)

In [5]:
model.state_dict()

OrderedDict([('lstm.weight_ih_l0',
              tensor([[-0.0187,  0.0510, -0.0024,  ...,  0.0447, -0.0341,  0.0451],
                      [ 0.0327,  0.0398,  0.0088,  ...,  0.0068, -0.0435, -0.0376],
                      [ 0.0374,  0.0513, -0.0488,  ...,  0.0312, -0.0026, -0.0581],
                      ...,
                      [-0.0445,  0.0274,  0.0330,  ..., -0.0020, -0.0230,  0.0068],
                      [ 0.0208, -0.0154, -0.0294,  ...,  0.0179,  0.0614, -0.0599],
                      [ 0.0373,  0.0123, -0.0008,  ...,  0.0290,  0.0026,  0.0302]])),
             ('lstm.weight_hh_l0',
              tensor([[-0.0312,  0.0475,  0.0510,  ...,  0.0010,  0.0533, -0.0043],
                      [-0.0532,  0.0014,  0.0348,  ..., -0.0518,  0.0106,  0.0123],
                      [-0.0525,  0.0607,  0.0073,  ..., -0.0388,  0.0341,  0.0406],
                      ...,
                      [-0.0415, -0.0282,  0.0414,  ..., -0.0131,  0.0567,  0.0470],
                      [ 0.0201, -

In [6]:
_state_dict = torch.load(
    "checkpoints/rnn-epoch=02-val_loss=1.75.ckpt",
    map_location=torch.device("cpu"),
    weights_only=False
)

state_dict = {k.replace("model.", ""):v for k,v in _state_dict["state_dict"].items()}
model.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
infer = InferenceModule(
    tokens, 
    char2int, 
    int2char, 
    model_path="checkpoints/rnn-epoch=02-val_loss=1.75.ckpt"
)

In [11]:
infer.generate(1000, prime="The superman", top_k=2)

"The supermanv?;FY`'@RY?h@YF?RYe;``RpYF?RY`'VpYF'Y?RTY?;VpYF'YF?RYeT'eRVFY;VpYF'YF?RYe;TRY;VpYF'nF?RYeT'eRT`Y;VpY`?RY`;ppYF'Y?RTY?RTY`'VR:Y;`n?RYU;`YF'YF?RYg'@RVFY'-YF?RY`;@R:Y;VpY`?h`Y;FY?RTY`;Vp:Y;`YF?RYg;TRY;VpY`;hpYF'Y?RTY`;VpYF'Y?RTY`'VFYF'YF?RYe;TRY;VpYF?RYg'@eT;Fh'V:Y;`YF?RY`'@RF?RTY;VpY`?RY`'mVpYF?;FY?RY?RY?;pY4RRVnF?RYg'@eTR``hVfYF'YF?RneT'VgR``Y'-YF?RY`;@hVf:Y;`YF?;FY?RYU;`YF?RYeT'VRpYF'Y?RTY?;VpYF?RY`;hpYF?;FY?RY?;pnVRF?RTY?RTY?RTY`;hpYF'Y?R;TpYF?RYeT'ehVRYF'YF?RYeT'VgRY;Vpn?RTYF'VpY?h@YF'Y?h@YF'Y?h@Y;VpY`'Y?RTY?;VpYU;`YF'YF'Y?h`Y`;VpY;VpY`?h!!YF'YF'@RY?RTY?RTY?R;p:Y;VpYF?RYg'VVRTYF?RY`;hpY;FYF?RYg'VVRTYF'Y?RTY?;pY4RRpYF'Y?RTY?;Vp:Y;VpY`?RYU;`Y`'Y?RTY?RT:Y;VpY`?h!pYF'Y?h@Y;Y`;VpYF'Y?h@xnn8MF,`YV'FY4mFYF?RY`'@RY?RTY`;VpY;VpY`'Y@'TRY;Y`;hpYF'YF'YF?RYe;TRYF?RY`'@RY?;``RpYF'Y?RTY`'Vp:Y;VpYF?RY`'@RF?hVfY;VpYF?;FYF?RY`'@RYF?RYg'VpRT`R;TY;FYF?RY`;@RYF'Y?RTY?RTY`'VpRTY;VpYF?RYe;TFhVf:Y;VpYF'YF?RYeT'VRpYF'YF?RYg;TRY'-Y?RT:Y;VpnF?RYe;`FRTY;VpYF?RYeTh`Fh'VY;VpnF'YF'Y?RTY?R;TpYF?;FYF?RY