# Models(5.18)
This tutorial is to demonstrate to use simple policy gradient agents for portfolio management, specifically we deal with different model and how to visualize and train them today

## Step1: Import Packages

In [2]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
import sys
from pathlib import Path
import os
import torch

ROOT = os.path.dirname(os.path.abspath("."))
sys.path.append(ROOT)

import argparse
import os.path as osp
from mmcv import Config
from trademaster.utils import replace_cfg_vals
from trademaster.nets.builder import build_net
from trademaster.environments.builder import build_environment
from trademaster.datasets.builder import build_dataset
from trademaster.agents.builder import build_agent
from trademaster.optimizers.builder import build_optimizer
from trademaster.losses.builder import build_loss
from trademaster.trainers.builder import build_trainer
from trademaster.utils import plot
from trademaster.utils import set_seed
import matplotlib.pyplot as plt
set_seed(2023)

2023-05-25 15:19:54,605	INFO services.py:1470 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
2023-05-25 15:20:00,230	INFO worker.py:973 -- Calling ray.init() again after it has already been called.
2023-05-25 15:20:00,237	INFO worker.py:973 -- Calling ray.init() again after it has already been called.


## Take a look at Environment

In [3]:
from trademaster.environments.portfolio_management.environment import PortfolioManagementEnvironment
cfg = dict(
    data = dict(
        type = "PortfolioManagementDataset",
        data_path = "data/portfolio_management/sz50",
        train_path = "data/portfolio_management/sz50/train.csv",
        valid_path = "data/portfolio_management/sz50/valid.csv",
        test_path = "data/portfolio_management/sz50/test.csv",
        test_dynamic_path='data/portfolio_management/sz50/test.csv',
        tech_indicator_list = [
            "zopen",
            "zhigh",
            "zlow",
            "zadjcp",
            "zclose",
            "zd_5",
            "zd_10",
            "zd_15",
            "zd_20",
            "zd_25",
            "zd_30"
        ],
        initial_amount = 100000,
        transaction_cost_pct = 0.001
    )
)
cfg=Config(cfg)
dataset = build_dataset(cfg)
cfg2 = dict(dataset=dataset)
env = PortfolioManagementEnvironment(cfg2)


## Play a single step with a random agent

In [4]:
from ray.rllib.agents.pg import PGTrainer 
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog
import ray
trainer_cfg = dict(
    rollout_fragment_length = 200,
    # explore = True,
    framework = "torch",
    model = {
        "post_fcnet_hiddens" : [30, 520, 321]
    #     # use_attention = True,
    #     "use_lstm": True,
    #     # Max seq len for training the LSTM, defaults to 20.
    #     "max_seq_len": 200,
    #     # Size of the LSTM cell.
    #     "lstm_cell_size": 1000,
    #     # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
    #     # "lstm_use_prev_action": True,
    #     # Whether to feed r_{t-1} to LSTM.
    #     "lstm_use_prev_reward": False,
    #     # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
    }
    
)
ray.init(ignore_reinit_error=True)
trainer_cfg["env"] = "portfolio_management"
trainer_cfg["env_config"] = dict(dataset=dataset, task="train", device="cpu")
register_env("portfolio_management", lambda config: PortfolioManagementEnvironment(config))
pg_trainer = PGTrainer(trainer_cfg, env="portfolio_management")

2023-05-25 15:20:01,465	INFO worker.py:973 -- Calling ray.init() again after it has already been called.
2023-05-25 15:20:01,541	INFO trainer.py:903 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


## Print the model
note that this can only be done with pytorch version of the model

In [5]:
pol = pg_trainer.get_policy()
print(pol)
# way1 to print the model paramters
for k, v in pol.get_weights().items():
    print(k, v.shape)
# way2 to print the model
model = pol.model
print(model)
# beside the defined model, 
# we also have flatten and various other predefined components
# which you can look for further details in 
# 'ray.rllib.models.torch.complex_input_net.ComplexInputNetwork'
print(model.flatten)
print(model.one_hot)
print(model.cnns)
# if you want to look into the source code for prototype of the model see here
print(type(model)) # of class 'ray.rllib.models.torch.complex_input_net.ComplexInputNetwork'

PGTorchPolicy
post_fc_stack._hidden_layers.0._model.0.weight (30, 256)
post_fc_stack._hidden_layers.0._model.0.bias (30,)
post_fc_stack._hidden_layers.1._model.0.weight (520, 30)
post_fc_stack._hidden_layers.1._model.0.bias (520,)
post_fc_stack._hidden_layers.2._model.0.weight (321, 520)
post_fc_stack._hidden_layers.2._model.0.bias (321,)
post_fc_stack._value_branch_separate.0._model.0.weight (30, 256)
post_fc_stack._value_branch_separate.0._model.0.bias (30,)
post_fc_stack._value_branch_separate.1._model.0.weight (520, 30)
post_fc_stack._value_branch_separate.1._model.0.bias (520,)
post_fc_stack._value_branch_separate.2._model.0.weight (321, 520)
post_fc_stack._value_branch_separate.2._model.0.bias (321,)
post_fc_stack._value_branch._model.0.weight (1, 321)
post_fc_stack._value_branch._model.0.bias (1,)
logits_layer._model.0.weight (102, 321)
logits_layer._model.0.bias (102,)
value_layer._model.0.weight (1, 321)
value_layer._model.0.bias (1,)
ComplexInputNetwork(
  (post_fc_stack): Fu

## Try and construct and pass a sample_batch into the model

In [6]:
obs = env.reset()
print(obs.shape)
sample_batch = (dict(obs = torch.Tensor([obs, obs])), None, None)
model_output = (model(*sample_batch))
print(model_output, model_output[0].shape)

(11, 50)
(tensor([[-2.5922e-04,  3.7101e-04, -2.4454e-04, -1.4447e-04,  1.9354e-04,
          2.4718e-04, -5.6721e-04,  3.5246e-04,  4.6053e-05,  6.9738e-04,
          3.2152e-04, -1.9862e-04,  1.3373e-04,  7.9480e-05,  4.3697e-04,
          2.3909e-04,  1.6504e-04,  9.2378e-04, -1.7482e-04, -6.6044e-05,
          3.6818e-04,  1.5503e-04,  1.0421e-03, -2.6726e-06,  1.6905e-04,
          5.0499e-04,  1.5690e-04,  3.1352e-04, -2.2751e-05, -4.5401e-04,
          7.2833e-05,  5.3493e-04,  5.8630e-05, -5.0553e-04, -8.5027e-05,
          1.7867e-04, -3.0604e-05,  3.9327e-04, -3.9780e-04, -1.5325e-04,
         -2.6746e-05, -2.7179e-04, -8.5194e-04,  3.7674e-04, -3.5922e-04,
          3.2475e-04,  1.0711e-04,  6.9320e-04,  4.3957e-04,  4.9434e-05,
          6.1676e-04, -1.6000e-04,  4.2720e-05,  5.0777e-04,  5.1159e-04,
         -4.6186e-04, -2.8623e-04, -3.6844e-04,  8.1008e-04,  4.6258e-05,
          1.4441e-04,  1.7621e-04, -4.9282e-04, -4.1605e-04,  7.9114e-04,
          2.8525e-05, -5.577

## Let's get a more complex model and try to train!

In [7]:
trainer_cfg = dict(
    rollout_fragment_length = 200,
    # explore = True,
    framework = "torch",
    model = {
        # "post_fcnet_hiddens" : [256, 256],
        # use_attention = True,
        "use_lstm": True,
        # Max seq len for training the LSTM, defaults to 20.
        "max_seq_len": 1000,
        # Size of the LSTM cell.
        "lstm_cell_size": 64,
        # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
        "lstm_use_prev_action": True,
        # Whether to feed r_{t-1} to LSTM.
        "lstm_use_prev_reward": False,
        # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
    }
    
)
ray.init(ignore_reinit_error=True)
trainer_cfg["env"] = "portfolio_management"
trainer_cfg["env_config"] = dict(dataset=dataset, task="train", device="cpu")
pg_trainer2 = PGTrainer(trainer_cfg, env="portfolio_management")
pol = pg_trainer2.get_policy()
model = pol.model
print(type(model))
print(model)
print(model.flatten)

2023-05-25 15:20:03,557	INFO worker.py:973 -- Calling ray.init() again after it has already been called.


<class 'ray.rllib.models.catalog.ComplexInputNetwork_as_LSTMWrapper'>
ComplexInputNetwork_as_LSTMWrapper(
  (post_fc_stack): FullyConnectedNetwork(
    (_hidden_layers): Sequential()
    (_value_branch_separate): Sequential()
    (_value_branch): SlimFC(
      (_model): Sequential(
        (0): Linear(in_features=256, out_features=1, bias=True)
      )
    )
  )
  (lstm): LSTM(307, 64, batch_first=True)
  (_logits_branch): SlimFC(
    (_model): Sequential(
      (0): Linear(in_features=64, out_features=102, bias=True)
    )
  )
  (_value_branch): SlimFC(
    (_model): Sequential(
      (0): Linear(in_features=64, out_features=1, bias=True)
    )
  )
)
{0: FullyConnectedNetwork(
  (_hidden_layers): Sequential(
    (0): SlimFC(
      (_model): Sequential(
        (0): Linear(in_features=550, out_features=256, bias=True)
        (1): Tanh()
      )
    )
    (1): SlimFC(
      (_model): Sequential(
        (0): Linear(in_features=256, out_features=256, bias=True)
        (1): Tanh()
     

In [83]:
for e in range(30):
    pg_trainer2.train()

tensor([[[ 3.8168e-03,  0.0000e+00,  1.9036e-03,  ...,  1.0189e-01,
          -3.2553e-02,  7.5725e-03],
         [ 1.1450e-02,  2.4450e-03,  1.4594e-02,  ...,  1.1497e-01,
           7.4243e-03,  3.3968e-02],
         [-3.8168e-03, -4.8900e-03, -1.2690e-03,  ..., -2.3895e-02,
          -3.6551e-02, -1.1683e-02],
         ...,
         [ 2.5573e-02,  9.7800e-03,  5.5838e-03,  ..., -2.0595e-01,
           3.5394e-02, -1.6097e-02],
         [ 3.4198e-02,  1.0073e-02,  4.3147e-03,  ..., -2.3140e-01,
           4.0263e-02, -4.9632e-02],
         [ 3.8677e-02,  1.1491e-02,  5.2242e-03,  ..., -2.4836e-01,
           2.8822e-02, -8.6240e-02]],

        [[ 3.8462e-03,  0.0000e+00,  2.5478e-03,  ..., -1.9740e-02,
           8.3694e-03,  1.2547e-02],
         [ 7.6923e-03,  2.4390e-03,  7.6433e-03,  ...,  3.6788e-02,
           3.6652e-02,  2.1352e-02],
         [-3.8462e-03, -2.4390e-03, -1.2739e-03,  ..., -4.8452e-02,
          -1.7316e-03, -1.1006e-02],
         ...,
         [ 2.8462e-02,  6

KeyboardInterrupt: 

## Save the model

In [9]:
from trademaster.utils import get_attr, save_object, load_object

# model level saving & loading
# torch.save(model.state_dict(), "model.pkl")
# model.load_state_dict(torch.load("model.pkl"))

# trainer level saving & loading
obj = pg_trainer2.save_to_object()
save_object(obj, "pg_trainer2.pkl")
obj2 = load_object("pg_trainer2.pkl")
pg_trainer2.restore_from_object(obj2)


2023-05-25 15:20:37,646	INFO trainable.py:588 -- Restored on 127.0.0.1 from checkpoint: /Users/louison/ray_results/PGTrainer_portfolio_management_2023-05-25_15-20-03mka2jyj7/tmpbnnzv1eqrestore_from_object/checkpoint-30
2023-05-25 15:20:37,647	INFO trainable.py:597 -- Current state after restoring: {'_iteration': 30, '_timesteps_total': None, '_time_total': 30.85412883758545, '_episodes_total': 2}


### 2 ways to see weights after training

In [10]:
print((pg_trainer2.get_weights()))
print(next(pg_trainer2.get_policy().model.parameters()))

{'default_policy': {'post_fc_stack._value_branch._model.0.weight': array([[-1.17613155e-04, -1.35803159e-04, -4.87085257e-04,
        -7.34577407e-06, -8.16768501e-04,  2.45048665e-04,
        -8.93007324e-04,  7.75018183e-04,  1.06838932e-04,
        -2.81362154e-04,  9.24324922e-05, -6.81228761e-04,
         3.83171391e-06, -1.67411985e-04,  3.28960537e-04,
         1.07997644e-03,  2.44915660e-04,  3.46449779e-05,
         1.93583808e-04,  7.97267014e-04, -3.90597794e-04,
        -1.84415097e-04, -3.43009451e-04, -1.02459773e-04,
         9.53075767e-04, -7.73487147e-04,  3.37798294e-04,
        -6.73530973e-04, -7.86635850e-04,  7.31635024e-04,
        -1.10107350e-04, -8.38342472e-04,  3.20925727e-04,
         2.58596672e-04,  2.74689344e-04, -5.57639345e-04,
         4.02255449e-04, -1.56585709e-04, -4.38545831e-04,
        -7.02786201e-04,  3.52299568e-04,  1.19831273e-03,
        -2.92071316e-04, -3.97227668e-05, -5.37084299e-04,
        -4.61007439e-04,  4.40112752e-04, -4.487

# How to define your personal model (Different levels of customizability)

## Customize by passing model_config

In [13]:
from ray.rllib.models import ModelCatalog
from ray.rllib.models import MODEL_DEFAULTS
model = ModelCatalog.get_model_v2(env.observation_space, env.action_space, int(env.action_space.shape[0]), model_config=MODEL_DEFAULTS, framework="torch")
print(model)

ComplexInputNetwork(
  (post_fc_stack): FullyConnectedNetwork(
    (_hidden_layers): Sequential()
    (_value_branch_separate): Sequential()
    (_value_branch): SlimFC(
      (_model): Sequential(
        (0): Linear(in_features=256, out_features=1, bias=True)
      )
    )
  )
  (logits_layer): SlimFC(
    (_model): Sequential(
      (0): Linear(in_features=256, out_features=51, bias=True)
    )
  )
  (value_layer): SlimFC(
    (_model): Sequential(
      (0): Linear(in_features=256, out_features=1, bias=True)
    )
  )
)


## Customize by defining your own class

#### First let's get familiar with the API

In [14]:
from ray.rllib.models.torch.complex_input_net import ComplexInputNetwork
ModelCatalog.register_custom_model(model_name="cust_model", model_class=ComplexInputNetwork)
model2 = ModelCatalog.get_model_v2(env.observation_space, env.action_space, int(env.action_space.shape[0]), model_config=MODEL_DEFAULTS, framework="torch", name="cust_model")
print(model2)

ComplexInputNetwork(
  (post_fc_stack): FullyConnectedNetwork(
    (_hidden_layers): Sequential()
    (_value_branch_separate): Sequential()
    (_value_branch): SlimFC(
      (_model): Sequential(
        (0): Linear(in_features=256, out_features=1, bias=True)
      )
    )
  )
  (logits_layer): SlimFC(
    (_model): Sequential(
      (0): Linear(in_features=256, out_features=51, bias=True)
    )
  )
  (value_layer): SlimFC(
    (_model): Sequential(
      (0): Linear(in_features=256, out_features=1, bias=True)
    )
  )
)


#### We can then start to build our own class of Network
The follow code are modified from this class:
ray.rllib.models.torch.complex_input_net.ComplexInputNetwork

In [92]:

from gym.spaces import Box, Discrete, MultiDiscrete
import numpy as np
import tree  # pip install dm_tree

# TODO (sven): add IMPALA-style option.
# from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet
from ray.rllib.models.torch.misc import (
    normc_initializer as torch_normc_initializer,
    SlimFC,
)
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_filter_config
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.torch_utils import one_hot

torch, nn = try_import_torch()


class Stok(TorchModelV2, nn.Module):
    """TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).

    Note: This model should be used for complex (Dict or Tuple) observation
    spaces that have one or more image components.

    The data flow is as follows:

    `obs` (e.g. Tuple[img0, img1, discrete0]) -> `CNN0 + CNN1 + ONE-HOT`
    `CNN0 + CNN1 + ONE-HOT` -> concat all flat outputs -> `out`
    `out` -> (optional) FC-stack -> `out2`
    `out2` -> action (logits) and vaulue heads.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        self.original_space = (
            obs_space.original_space
            if hasattr(obs_space, "original_space")
            else obs_space
        )

        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self, self.original_space, action_space, num_outputs, model_config, name
        )

        self.flattened_input_space = flatten_space(self.original_space)

        # Atari type CNNs or IMPALA type CNNs (with residual layers)?
        # self.cnn_type = self.model_config["custom_model_config"].get(
        #     "conv_type", "atari")

        # Build the CNN(s) given obs_space's image components.
        # self.cnns = {}
        # self.one_hot = {}
        # self.flatten_dims = {}
        # self.flatten = {}
        # concat_size = 0
        # for i, component in enumerate(self.flattened_input_space):
        #     size = int(np.product(component.shape))
        #     config = {
        #         "fcnet_hiddens": model_config["fcnet_hiddens"],
        #         "fcnet_activation": model_config.get("fcnet_activation"),
        #         "post_fcnet_hiddens": [],
        #     }
        #     self.flatten[i] = ModelCatalog.get_model_v2(
        #         Box(-1.0, 1.0, (size,), np.float32),
        #         action_space,
        #         num_outputs=None,
        #         model_config=config,
        #         framework="torch",
        #         name="flatten_{}".format(i),
        #     )
        #     self.flatten_dims[i] = size
        #     concat_size += self.flatten[i].num_outputs
        # print(self.flatten)
        # print(self.flatten_dims)

        size = int(np.product(self.flattened_input_space[0].shape))
        # Optional post-concat FC-stack.
        post_fc_stack_config = {
            "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
            "fcnet_activation": model_config.get("post_fcnet_activation", "relu"),
        }
        self.post_fc_stack = ModelCatalog.get_model_v2(
            Box(float("-inf"), float("inf"), shape=(size,), dtype=np.float32),
            self.action_space,
            None,
            post_fc_stack_config,
            framework="torch",
            name="post_fc_stack",
        )

        # Actions and value heads.
        self.logits_layer = None
        self.value_layer = None
        self._value_out = None

        if num_outputs:
            # Action-distribution head.
            self.logits_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=num_outputs,
                activation_fn=None,
                initializer=torch_normc_initializer(0.01),
            )
            # Create the value branch model.
            self.value_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=1,
                activation_fn=None,
                initializer=torch_normc_initializer(0.01),
            )
        else:
            self.num_outputs = size

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        print(input_dict[SampleBatch.OBS].shape)
        if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
            orig_obs = input_dict[SampleBatch.OBS]
        else:
            orig_obs = restore_original_dimensions(
                input_dict[SampleBatch.OBS], self.processed_obs_space, tensorlib="torch"
            )
        # Push observations through the different components
        # (CNNs, one-hot + FC, etc..).
        outs = []
        # for i, component in enumerate(tree.flatten(orig_obs)):
        #     nn_out, _ = self.flatten[i](
        #         SampleBatch(
        #             {
        #                 SampleBatch.OBS: torch.reshape(
        #                     component, [-1, self.flatten_dims[i]]
        #                 )
        #             }
        #         )
        #     )
        outs.append(tree.flatten(orig_obs)[0])
        print(outs[0].shape)

        # Concat all outputs and the non-image inputs.
        out = torch.cat(outs, dim=1)
        # Push through (optional) FC-stack (this may be an empty stack).
        out, _ = self.post_fc_stack(SampleBatch({SampleBatch.OBS: out}))
        print(outs[0].shape)

        # No logits/value branches.
        if self.logits_layer is None:
            return out, []

        # Logits- and value branches.
        logits, values = self.logits_layer(out), self.value_layer(out)
        self._value_out = torch.reshape(values, [-1])
        return logits, []

    @override(ModelV2)
    def value_function(self):
        return self._value_out


In [93]:
# model2 = Stok(env.observation_space, env.action_space, int(env.action_space.shape[0]), model_config=MODEL_DEFAULTS, name="fuck")
# print(model2)
ModelCatalog.register_custom_model(model_name="cust_model", model_class=Stok)
trainer_cfg = dict(
    rollout_fragment_length = 200,
    # explore = True,
    framework = "torch",
    model = {
    #     # use_attention = True,
    #     "use_lstm": True,
    #     # Max seq len for training the LSTM, defaults to 20.
    #     "max_seq_len": 200,
    #     # Size of the LSTM cell.
    #     "lstm_cell_size": 1000,
    #     # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
    #     # "lstm_use_prev_action": True,
    #     # Whether to feed r_{t-1} to LSTM.
    #     "lstm_use_prev_reward": False,
    #     # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
        "custom_model": "cust_model"
    }
    
)
ray.init(ignore_reinit_error=True)
trainer_cfg["env"] = "portfolio_management"
trainer_cfg["env_config"] = dict(dataset=dataset, task="train", device="cpu")
register_env("portfolio_management", lambda config: PortfolioManagementEnvironment(config))
pg_trainer = PGTrainer(trainer_cfg, env="portfolio_management")

2023-05-25 21:38:16,392	INFO worker.py:973 -- Calling ray.init() again after it has already been called.


torch.Size([32, 11, 50])
torch.Size([32, 11, 50])
torch.Size([32, 11, 50])
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
  

In [94]:
for e in range(30):
    pg_trainer.train()

torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 1

KeyboardInterrupt: 