In [1]:
import torch
import torch.nn as nn
import numpy as np
from sf_examples.nethack.models.scaled import CharColorEncoderResnet, ScaledNet

In [2]:
from sample_factory.cfg.arguments import load_from_checkpoint
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.model.model_utils import get_rnn_size

from sf_examples.nethack.train_nethack import parse_nethack_args, register_nethack_components, make_nethack_actor_critic
from sf_examples.nethack.models.utils import scale_width_critic, downscale_first_layer, downscale_last_layer

/home/bartek/Workspace/ideas/sample-factory/sf_examples/nethack/render_utils/Hack-Regular.ttf


In [3]:
env_name = "challenge"
register_nethack_components()

[36m[2024-05-27 14:43:29,852][209200] register_encoder_factory: <function make_nethack_encoder at 0x7b3fd7138ee0>[0m
[36m[2024-05-27 14:43:29,853][209200] register_actor_critic_factory: <function make_nethack_actor_critic at 0x7b3fd7139090>[0m


In [4]:
# Build the experiment config from CLI-style arguments for the environment
# named by `env_name` (defined in an earlier cell).
# NOTE(review): --critic_increase_factor=2 presumably has to agree with the
# `factor` used by the critic-scaling cell further down — confirm.
cfg = parse_nethack_args(
    [
        f"--env={env_name}",
        "--model=ScaledNet",
        "--use_resnet=True",
        # h_dim and rnn_size are deliberately given the same value here.
        "--h_dim=1738",
        "--rnn_size=1738",
        # Later cells access `model.critic_encoder` / `model.critic`;
        # presumably those only exist when weights are not shared — verify.
        "--actor_critic_share_weights=False",
        "--critic_increase_factor=2",
    ]
)

In [5]:
# Create one batched environment from the config (all worker/vector/env ids
# set to 0) and collect its env info.
env = make_env_func_batched(cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0))
env_info = extract_env_info(env, cfg)

# The observation space is taken from env_info, the action space from the env itself.
obs_space = env_info.obs_space
action_space = env.action_space
# Seeded reset; `obs` is the input fed to the model in the cells below.
obs, info = env.reset(seed=0)

In [6]:
factor = 2

In [7]:
# All-zero initial RNN state with a single row of width get_rnn_size(cfg)
# (presumably one row per environment in the batch — TODO confirm).
rnn_states = torch.zeros([1, get_rnn_size(cfg)], dtype=torch.float32)
# Build the actor-critic for the spaces extracted from the environment above.
model = make_nethack_actor_critic(cfg, obs_space, action_space)

[36m[2024-05-27 14:43:33,459][209200] RunningMeanStd input shape: (1,)[0m


In [8]:
# Step 1: widen the critic by `factor`.
scale_width_critic(model, factor=factor)

# Step 2: shrink the input side of the first layer (child '0') of every
# critic sub-encoder stack again by `factor`.
# NOTE(review): presumably needed because those layers consume raw inputs
# whose size is not affected by the widening — confirm against
# sf_examples.nethack.models.utils.
critic_encoder = model.critic_encoder
critic_input_stacks = (
    critic_encoder.topline_encoder.msg_fwd,
    critic_encoder.bottomline_encoder.conv_net,
    critic_encoder.screen_encoder.conv_net[0],
    critic_encoder.extract_crop_representation,
)
for input_stack in critic_input_stacks:
    downscale_first_layer(input_stack, '0', factor)

# Step 3: adjust the final critic layer by the same factor.
downscale_last_layer(model.critic, "critic_linear", factor)

# Step 4: keep the screen encoder's recorded output size in step with the
# widened convolutions.
screen_encoder = critic_encoder.screen_encoder
screen_encoder.out_size = screen_encoder.out_size * factor

In [10]:
model(obs, rnn_states)

{'values': tensor(0.0167, grad_fn=<SqueezeBackward0>),
 'action_logits': tensor([[-0.0682,  0.0368, -0.0674,  0.1289,  0.0284,  0.0404,  0.0260, -0.0440,
           0.0368, -0.0241,  0.0561, -0.0605,  0.0252, -0.0351,  0.0306,  0.1509,
          -0.1443,  0.0058,  0.0551, -0.0392,  0.0716,  0.0833,  0.1486, -0.0617,
          -0.0385, -0.0565,  0.1088,  0.0341,  0.0141,  0.1072,  0.0090,  0.0732,
          -0.0770,  0.0481,  0.0686,  0.0546,  0.0880, -0.0526,  0.0450, -0.0167,
           0.0202, -0.0883, -0.0978, -0.0865, -0.0264, -0.0358, -0.0701, -0.0983,
           0.0713,  0.0654, -0.0719,  0.0692,  0.0038, -0.0144,  0.0355, -0.1783,
           0.0142,  0.0204, -0.0392,  0.0806,  0.0317,  0.0232, -0.0904,  0.1519,
           0.0628, -0.1408, -0.0519,  0.0024, -0.0246, -0.0061,  0.0137,  0.0216,
           0.0264,  0.0604,  0.1507, -0.0522, -0.0405,  0.0677,  0.0082,  0.0526,
          -0.0601,  0.0098,  0.0258,  0.0578,  0.0570,  0.1510,  0.0259,  0.0474,
           0.0135,  0.0109

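# OLD — earlier scratch experiments, kept for reference (execution counts below are out of order, so these cells do not reproduce with Restart & Run All)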

In [None]:
# Re-parse the config with only the environment given; every other option is
# left to whatever parse_nethack_args uses by default.
# NOTE: this rebinds `cfg`, replacing the ScaledNet config built earlier.
cfg = parse_nethack_args(
    [
        f"--env={env_name}",
    ]
)

In [None]:
# Recreate the batched environment for the re-parsed config (ids all 0) and
# re-extract its env info; this rebinds `env`, `env_info` and `obs_space`.
env = make_env_func_batched(cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0))
env_info = extract_env_info(env, cfg)

obs_space = env_info.obs_space

In [289]:
model = ScaledNet(cfg, obs_space=obs_space)

In [290]:
# Handles of every forward hook installed through register_hooks (across all
# calls). Call `.remove()` on a handle to detach its hook again.
handles = []

def register_hooks(model):
    """Attach shape-recording forward hooks to the Linear/Conv layers of `model`.

    The module tree is walked recursively. Every child that is an
    ``nn.Linear``, ``nn.Conv1d`` or ``nn.Conv2d`` gets a forward hook that
    stores the shape of its latest output in ``module.output_shape``; any
    other child is descended into. `model` itself is never hooked, only its
    descendants.

    Args:
        model: the ``nn.Module`` whose descendants should be instrumented.

    Returns:
        The list of hook handles created by this call. The same handles are
        also appended to the module-level ``handles`` list, as before, so
        existing code that relies on that list keeps working.
    """
    def record_output_shape(module, inputs, output):
        # `inputs` is unused; it is part of the forward-hook signature.
        module.output_shape = output.shape

    def attach(parent, registered):
        for child in parent.children():
            if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d)):
                registered.append(child.register_forward_hook(record_output_shape))
            else:
                attach(child, registered)
        return registered

    new_handles = attach(model, [])
    # Extend the global list exactly once per top-level call.
    handles.extend(new_handles)
    return new_handles

# Fresh, unscaled encoder for the current config.
model = ScaledNet(cfg, obs_space=obs_space)

# Install the shape-recording forward hooks on its Linear/Conv layers.
register_hooks(model)

# NOTE: unlike the seeded reset near the top of the notebook, no seed is
# passed here, so the observation may differ between runs.
obs, info = env.reset()

# One forward pass; this fires the hooks, which set `output_shape` on each
# hooked layer. The cell displays the encoder output.
model(obs)

tensor([[0.0000, 0.0000, 0.1639,  ..., 0.1478, 0.2298, 0.1524]],
       grad_fn=<ReluBackward0>)

In [291]:
model

ScaledNet(
  (encoders): ModuleDict()
  (crop): Crop()
  (extract_crop_representation): Sequential(
    (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ELU(alpha=1.0)
    (6): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ELU(alpha=1.0)
    (9): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ELU(alpha=1.0)
    (12): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_ru

In [160]:
def scale_width(module, factor=2):
    """Widen the Conv/Linear/BatchNorm2d layers nested inside `module` in place.

    For every child of `module`, each of that child's direct sub-children
    that is an ``nn.Conv1d``/``nn.Conv2d``, ``nn.Linear`` or
    ``nn.BatchNorm2d`` is replaced by a new layer of the same kind whose
    channel/feature counts are multiplied by `factor` (truncated with
    ``int``). The function then recurses into the child. Layers that are
    *direct* children of `module` itself are therefore left untouched; only
    layers nested at least one container deep are replaced.

    The replacement layers are freshly initialised (no weights are copied).
    All other hyper-parameters are carried over: kernel size, stride,
    padding, dilation, ``groups``, ``padding_mode`` and bias for convolutions;
    bias for linear layers; ``eps``, ``momentum``, ``affine`` and
    ``track_running_stats`` for batch norm. The new layer is also moved to
    the device/dtype of the layer it replaces and put in the same
    train/eval mode.

    Args:
        module: root ``nn.Module`` to modify in place.
        factor: width multiplier applied to both input and output sizes.

    Returns:
        None. `module` is modified in place.
    """
    def widened(layer):
        # Build the scaled replacement for `layer`, or None if it is not a
        # layer type this function rescales.
        if isinstance(layer, (nn.Conv1d, nn.Conv2d)):
            replacement = layer.__class__(
                int(layer.in_channels * factor),
                int(layer.out_channels * factor),
                kernel_size=layer.kernel_size,
                stride=layer.stride,
                padding=layer.padding,
                dilation=layer.dilation,
                groups=layer.groups,
                bias=layer.bias is not None,
                padding_mode=layer.padding_mode,
            )
        elif isinstance(layer, nn.Linear):
            replacement = nn.Linear(
                int(layer.in_features * factor),
                int(layer.out_features * factor),
                bias=layer.bias is not None,
            )
        elif isinstance(layer, nn.BatchNorm2d):
            replacement = nn.BatchNorm2d(
                int(layer.num_features * factor),
                eps=layer.eps,
                momentum=layer.momentum,
                affine=layer.affine,
                track_running_stats=layer.track_running_stats,
            )
        else:
            return None

        # Keep the replacement on the same device/dtype as the old layer
        # (a BatchNorm with affine=False has no parameters to copy from).
        reference = next(layer.parameters(), None)
        if reference is not None:
            replacement = replacement.to(device=reference.device, dtype=reference.dtype)
        replacement.train(layer.training)
        return replacement

    for child in list(module.children()):
        # Snapshot the sub-children so replacing entries cannot disturb iteration.
        for name, subchild in list(child.named_children()):
            replacement = widened(subchild)
            if replacement is not None:
                setattr(child, name, replacement)

        scale_width(child, factor=factor)

In [245]:
def downscale_first_layer(module, name, factor=2):
    """Replace one layer of an indexable container by a copy with fewer inputs.

    The layer at position ``int(name)`` of `module` is looked up. If it is an
    ``nn.Conv1d``/``nn.Conv2d`` or an ``nn.Linear``, it is replaced in place
    by a freshly initialised layer of the same kind whose input size is
    ``in_channels // factor`` (or ``in_features // factor``). The output size
    and all other hyper-parameters (kernel size, stride, padding, dilation,
    ``groups``, ``padding_mode``, bias) are kept, and the new layer is moved
    to the device/dtype and train/eval mode of the layer it replaces. No
    weights are copied. Any other layer type is left untouched.

    Args:
        module: container supporting integer indexing and item assignment,
            e.g. ``nn.Sequential``.
        name: position of the layer as a string, e.g. ``'0'``.
        factor: divisor applied to the layer's input size (floor division).

    Returns:
        None. `module` is modified in place.
    """
    index = int(name)
    cur_layer = module[index]

    if isinstance(cur_layer, (nn.Conv1d, nn.Conv2d)):
        new_layer = cur_layer.__class__(
            int(cur_layer.in_channels // factor),
            cur_layer.out_channels,
            kernel_size=cur_layer.kernel_size,
            stride=cur_layer.stride,
            padding=cur_layer.padding,
            dilation=cur_layer.dilation,
            groups=cur_layer.groups,
            bias=cur_layer.bias is not None,
            padding_mode=cur_layer.padding_mode,
        )
    elif isinstance(cur_layer, nn.Linear):
        new_layer = nn.Linear(
            int(cur_layer.in_features // factor),
            cur_layer.out_features,
            bias=cur_layer.bias is not None,
        )
    else:
        # Not a layer type handled here: leave the container unchanged.
        return

    new_layer = new_layer.to(device=cur_layer.weight.device, dtype=cur_layer.weight.dtype)
    new_layer.train(cur_layer.training)
    # Assign by position rather than by attribute name, so the layer that was
    # read is the one replaced even when the container uses non-numeric keys.
    module[index] = new_layer
        

In [270]:
model = ScaledNet(cfg, obs_space=obs_space)

In [280]:
factor = 2

In [281]:
scale_width(model, factor=factor)

In [282]:
# After widening every layer, shrink the input side of the first layer
# (child '0') of each encoder stack again by `factor`.
first_layer_stacks = (
    model.topline_encoder.msg_fwd,
    model.bottomline_encoder.conv_net,
    model.screen_encoder.conv_net[0],
    model.extract_crop_representation,
)
for encoder_stack in first_layer_stacks:
    downscale_first_layer(encoder_stack, '0', factor)

# The screen encoder reshapes its conv output with `out_size`, so it has to
# grow together with the widened convolutions.
screen_encoder = model.screen_encoder
screen_encoder.out_size = screen_encoder.out_size * factor

In [283]:
model(obs).shape

torch.Size([1, 6952])

In [284]:
model

ScaledNet(
  (encoders): ModuleDict()
  (crop): Crop()
  (extract_crop_representation): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ELU(alpha=1.0)
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ELU(alpha=1.0)
    (9): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ELU(alpha=1.0)
    (12): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_

In [215]:
# obs["tty_chars"] unpacks into three dimensions, named batch, height, width here.
B, H, W = obs["tty_chars"].shape
# to process images with CNNs we need channels dim
C = 1

# Slice along the second (H) dimension: index 0 is kept as the top line,
# the last two indices as the bottom lines.
topline = obs["tty_chars"][:, 0].contiguous()
bottom_line = obs["tty_chars"][:, -2:].contiguous()

In [216]:
scale = 2

In [244]:
def downscale_first_layer(module, name, factor=2):
    """Replace one layer of an indexable container by a copy with fewer inputs.

    The layer at position ``int(name)`` of `module` is looked up. If it is an
    ``nn.Conv1d``/``nn.Conv2d`` or an ``nn.Linear``, it is replaced in place
    by a freshly initialised layer of the same kind whose input size is
    ``in_channels // factor`` (or ``in_features // factor``). The output size
    and all other hyper-parameters (kernel size, stride, padding, dilation,
    ``groups``, ``padding_mode``, bias) are kept, and the new layer is moved
    to the device/dtype and train/eval mode of the layer it replaces. No
    weights are copied. Any other layer type is left untouched.

    Args:
        module: container supporting integer indexing and item assignment,
            e.g. ``nn.Sequential``.
        name: position of the layer as a string, e.g. ``'0'``.
        factor: divisor applied to the layer's input size (floor division).

    Returns:
        None. `module` is modified in place.
    """
    index = int(name)
    cur_layer = module[index]

    if isinstance(cur_layer, (nn.Conv1d, nn.Conv2d)):
        new_layer = cur_layer.__class__(
            int(cur_layer.in_channels // factor),
            cur_layer.out_channels,
            kernel_size=cur_layer.kernel_size,
            stride=cur_layer.stride,
            padding=cur_layer.padding,
            dilation=cur_layer.dilation,
            groups=cur_layer.groups,
            bias=cur_layer.bias is not None,
            padding_mode=cur_layer.padding_mode,
        )
    elif isinstance(cur_layer, nn.Linear):
        new_layer = nn.Linear(
            int(cur_layer.in_features // factor),
            cur_layer.out_features,
            bias=cur_layer.bias is not None,
        )
    else:
        # Not a layer type handled here: leave the container unchanged.
        return

    new_layer = new_layer.to(device=cur_layer.weight.device, dtype=cur_layer.weight.dtype)
    new_layer.train(cur_layer.training)
    # Assign by position rather than by attribute name, so the layer that was
    # read is the one replaced even when the container uses non-numeric keys.
    module[index] = new_layer

In [217]:
model.topline_encoder

TopLineEncoder(
  (msg_fwd): Sequential(
    (0): Linear(in_features=40960, out_features=128, bias=True)
    (1): ELU(alpha=1.0, inplace=True)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ELU(alpha=1.0, inplace=True)
  )
)

In [218]:
# Manual version of the first-layer downscale for the top-line encoder:
# child '0' of msg_fwd is replaced by a freshly initialised nn.Linear with
# in_features divided by `scale` and the same out_features (no weights are
# copied; the new layer always has a bias).
# NOTE(review): this cell's execution count is lower than that of the
# downscale_first_layer calls shown earlier in the file. Run top to bottom
# on the same `model`, it would shrink the already-shrunk layer a second
# time — confirm the intended order.
name = '0'
cur_layer = model.topline_encoder.msg_fwd[int(name)]
value = nn.Linear(cur_layer.in_features // scale, cur_layer.out_features)
setattr(model.topline_encoder.msg_fwd, name, value)

In [219]:
model.topline_encoder

TopLineEncoder(
  (msg_fwd): Sequential(
    (0): Linear(in_features=20480, out_features=128, bias=True)
    (1): ELU(alpha=1.0, inplace=True)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ELU(alpha=1.0, inplace=True)
  )
)

In [220]:
model.topline_encoder(topline.float(memory_format=torch.contiguous_format).view(B, -1)).shape

torch.Size([1, 128])

In [221]:
model.bottomline_encoder

BottomLinesEncoder(
  (conv_net): Sequential(
    (0): Conv1d(4, 64, kernel_size=(8,), stride=(4,))
    (1): ELU(alpha=1.0, inplace=True)
    (2): Conv1d(64, 128, kernel_size=(4,), stride=(1,))
    (3): ELU(alpha=1.0, inplace=True)
  )
  (fwd_net): Sequential(
    (0): Linear(in_features=4608, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ELU(alpha=1.0)
  )
)

In [222]:
# Manual first-layer downscale for the bottom-lines encoder: child '0' of
# conv_net is rebuilt with the same class and in_channels divided by `scale`.
# Out channels, kernel size, bias presence, stride, padding and dilation are
# copied from the old layer; the weights are freshly initialised.
# NOTE(review): like the other manual cells in this section, this mutates
# `model` in place and is not idempotent — re-running it divides again.
name = '0'
cur_layer = model.bottomline_encoder.conv_net[int(name)]
value = cur_layer.__class__(
    cur_layer.in_channels // scale, 
    cur_layer.out_channels, 
    kernel_size=cur_layer.kernel_size, 
    bias=cur_layer.bias is not None, 
    stride=cur_layer.stride, 
    padding=cur_layer.padding, 
    dilation=cur_layer.dilation
)
setattr(model.bottomline_encoder.conv_net, name, value)

In [223]:
model.bottomline_encoder

BottomLinesEncoder(
  (conv_net): Sequential(
    (0): Conv1d(2, 64, kernel_size=(8,), stride=(4,))
    (1): ELU(alpha=1.0, inplace=True)
    (2): Conv1d(64, 128, kernel_size=(4,), stride=(1,))
    (3): ELU(alpha=1.0, inplace=True)
  )
  (fwd_net): Sequential(
    (0): Linear(in_features=4608, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ELU(alpha=1.0)
  )
)

In [224]:
model.bottomline_encoder(bottom_line.float(memory_format=torch.contiguous_format).view(B, -1)).shape

torch.Size([1, 256])

In [225]:
model.screen_encoder

CharColorEncoderResnet(
  (conv_net): Sequential(
    (0): Sequential(
      (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): ResBlock(
        (net): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0, inplace=True)
          (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ELU(alpha=1.0, inplace=True)
        )
      )
      (3): ResBlock(
        (net): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0, inplac

In [226]:
# Manual first-layer downscale for the screen encoder: the layer at index 0
# of the first sub-block (conv_net[0]) is rebuilt with in_channels divided by
# `scale`. All other constructor arguments are copied from the old layer; the
# weights are freshly initialised.
# NOTE(review): mutates `model` in place and is not idempotent.
name = '0'
cur_layer = model.screen_encoder.conv_net[0][int(name)]
value = cur_layer.__class__(
    cur_layer.in_channels // scale, 
    cur_layer.out_channels, 
    kernel_size=cur_layer.kernel_size, 
    bias=cur_layer.bias is not None, 
    stride=cur_layer.stride, 
    padding=cur_layer.padding, 
    dilation=cur_layer.dilation
)
setattr(model.screen_encoder.conv_net[0], name, value)

In [232]:
model.screen_encoder.out_size *= scale

In [227]:
# name = '0'
# cur_layer = model.screen_encoder.fc_head[int(name)]
# value = nn.Linear(cur_layer.in_features // scale, cur_layer.out_features)
# setattr(model.screen_encoder.fc_head, name, value)

In [233]:
model.screen_encoder

CharColorEncoderResnet(
  (conv_net): Sequential(
    (0): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): ResBlock(
        (net): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0, inplace=True)
          (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ELU(alpha=1.0, inplace=True)
        )
      )
      (3): ResBlock(
        (net): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0, inplac

In [234]:
# Screen body: drop index 0 and the last two indices along the H dimension
# (leaving H - 3 rows), cast to float and insert a channel dimension of size C.
tty_chars = (
    obs["tty_chars"][:, 1:-2]
    .contiguous()
    .float(memory_format=torch.contiguous_format)
    .view(B, C, H - 3, W)
)
# Same slice and reshape for the colors, but without the float cast.
tty_colors = obs["tty_colors"][:, 1:-2].contiguous().view(B, C, H - 3, W)
# Cursor flattened to one row per batch element; the cells that follow in
# this notebook do not use it.
tty_cursor = obs["tty_cursor"].contiguous().view(B, -1)

In [235]:
# Step through the screen encoder by hand, using its private helpers, to
# inspect the flattened conv output before the encoder's remaining layers.
chars, colors = model.screen_encoder._embed(tty_chars, tty_colors)  # 21 x 80
x = model.screen_encoder._stack(chars, colors)
x = model.screen_encoder.conv_net(x)
# Flatten to rows of `out_size` features; this only matches the conv output
# if out_size was scaled together with the conv widths.
x = x.view(-1, model.screen_encoder.out_size)

In [236]:
x.shape

torch.Size([1, 4864])

In [237]:
model.screen_encoder(tty_chars, tty_colors).shape

torch.Size([1, 1024])

In [238]:
model.extract_crop_representation

Sequential(
  (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ELU(alpha=1.0)
  (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ELU(alpha=1.0)
  (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): ELU(alpha=1.0)
  (9): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): ELU(alpha=1.0)
  (12): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (14): ELU(alpha=1.0)
)

In [239]:
model.extract_crop_representation[0]

Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [240]:
# Manual first-layer downscale for the crop branch: child '0' of
# extract_crop_representation is rebuilt with in_channels divided by `scale`.
# All other constructor arguments are copied from the old layer; the weights
# are freshly initialised.
# NOTE(review): mutates `model` in place and is not idempotent.
name = '0'
cur_layer = model.extract_crop_representation[int(name)]
value = cur_layer.__class__(
    cur_layer.in_channels // scale, 
    cur_layer.out_channels, 
    kernel_size=cur_layer.kernel_size, 
    bias=cur_layer.bias is not None, 
    stride=cur_layer.stride, 
    padding=cur_layer.padding, 
    dilation=cur_layer.dilation
)
setattr(model.extract_crop_representation, name, value)

In [241]:
model.extract_crop_representation

Sequential(
  (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ELU(alpha=1.0)
  (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ELU(alpha=1.0)
  (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): ELU(alpha=1.0)
  (9): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): ELU(alpha=1.0)
  (12): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (14): ELU(alpha=1.0)
)

In [243]:
model(obs).shape

torch.Size([1, 3476])