In [None]:
# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2019, Numenta, Inc.  Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program.  If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------
import abc

import numpy as np
import torch
import torch.nn as nn

import k_winners_base as F
from duty_cycle_metrics import binary_entropy, max_entropy


def update_boost_strength(m):
    """Function used to update KWinner modules boost strength after each epoch.

    Call using :meth:`torch.nn.Module.apply` after each epoch if required
    For example: ``m.apply(update_boost_strength)``

    :param m: KWinner module
    """
    if isinstance(m, KWinnersBase):
        m.update_boost_strength()


class KWinnersBase(nn.Module, metaclass=abc.ABCMeta):
    def __init__(
        self,
        percent_on,
        k_inference_factor=1.0,
        boost_strength=1.0,
        boost_strength_factor=1.0,
        duty_cycle_period=1000,
    ):
        """Base KWinners class.

        :param percent_on:
          The activity of the top k = percent_on * number of input units will be
          allowed to remain, the rest are set to zero.
        :type percent_on: float

        :param k_inference_factor:
          During inference (training=False) we increase percent_on by this factor.
          percent_on * k_inference_factor must be strictly less than 1.0, ideally much
          lower than 1.0
        :type k_inference_factor: float

        :param boost_strength:
          boost strength (0.0 implies no boosting). Must be >= 0.0
        :type boost_strength: float

        :param boost_strength_factor:
          Boost strength factor to use [0..1]
        :type boost_strength_factor: float

        :param duty_cycle_period:
          The period used to calculate duty cycles
        :type duty_cycle_period: int
        """
        super(KWinnersBase, self).__init__()
        assert boost_strength >= 0.0
        assert 0.0 <= boost_strength_factor <= 1.0
        assert 0.0 < percent_on < 1.0
        assert 0.0 < percent_on * k_inference_factor < 1.0

        self.percent_on = percent_on
        self.percent_on_inference = percent_on * k_inference_factor
        self.k_inference_factor = k_inference_factor
        self.learning_iterations = 0
        self.n = 0
        self.k = 0
        self.k_inference = 0

        # Boosting related parameters
        self.boost_strength = boost_strength
        self.boost_strength_factor = boost_strength_factor
        self.duty_cycle_period = duty_cycle_period

    def extra_repr(self):
        return (
            "n={0}, percent_on={1}, boost_strength={2}, boost_strength_factor={3}, "
            "k_inference_factor={4}, duty_cycle_period={5}".format(
                self.n,
                self.percent_on,
                self.boost_strength,
                self.boost_strength_factor,
                self.k_inference_factor,
                self.duty_cycle_period,
            )
        )

    @abc.abstractmethod
    def update_duty_cycle(self, x):
        r"""Updates our duty cycle estimates with the new value. Duty cycles are
        updated according to the following formula:

        .. math::
            dutyCycle = \frac{dutyCycle \times \left( period - batchSize \right)
                                + newValue}{period}

        :param x:
          Current activity of each unit
        """
        raise NotImplementedError

    def update_boost_strength(self):
        """Update boost strength using given strength factor during
        training.
        """
        if self.training:
            self.boost_strength = self.boost_strength * self.boost_strength_factor

    def entropy(self):
        """Returns the current total entropy of this layer."""
        _, entropy = binary_entropy(self.duty_cycle)
        return entropy

    def max_entropy(self):
        """Returns the maximum total entropy we can expect from this layer."""
        return max_entropy(self.n, int(self.n * self.percent_on))


class KWinners(KWinnersBase):
    def __init__(
        self,
        n,
        percent_on,
        k_inference_factor=1.5,
        stoch_sd=0.0,
        boost_strength=1.0,
        boost_strength_factor=0.9,
        duty_cycle_period=1000,
    ):
        """Applies K-Winner function to the input tensor.

        See :class:`htmresearch.frameworks.pytorch.functions.k_winners`
        :param n:
          Number of units
        :type n: int

        :param percent_on:
          The activity of the top k = percent_on * n will be allowed to remain, the
          rest are set to zero.
        :type percent_on: float

        :param k_inference_factor:
          During inference (training=False) we increase percent_on by this factor.
          percent_on * k_inference_factor must be strictly less than 1.0, ideally much
          lower than 1.0
        :type k_inference_factor: float

        :param boost_strength:
          boost strength (0.0 implies no boosting).
        :type boost_strength: float

        :param boost_strength_factor:
          Boost strength factor to use [0..1]
        :type boost_strength_factor: float

        :param duty_cycle_period:
          The period used to calculate duty cycles
        :type duty_cycle_period: int
        """
        super(KWinners, self).__init__(
            percent_on=percent_on,
            k_inference_factor=k_inference_factor,
            boost_strength=boost_strength,
            boost_strength_factor=boost_strength_factor,
            duty_cycle_period=duty_cycle_period,
        )
        self.n = n
        self.k = int(round(n * percent_on))
        self.k_inference = int(self.k * self.k_inference_factor)
        self.stoch_sd = stoch_sd
        self.register_buffer("duty_cycle", torch.zeros(self.n))

    def forward(self, x):
        if self.stoch_sd:
            k = int(np.random.normal(self.k, self.stoch_sd))
        else:
            k = self.k
        if self.training:

            x = F.KWinners.apply(x, self.duty_cycle, k, self.boost_strength)

            self.update_duty_cycle(x)
        else:
            x = F.KWinners.apply(
                x,
                self.duty_cycle,
                int(k * self.k_inference_factor),
                self.boost_strength,
            )

        return x

    def update_duty_cycle(self, x):

        batch_size = x.shape[0]      
        self.learning_iterations += batch_size
        period = min(self.duty_cycle_period, self.learning_iterations)
        self.duty_cycle.mul_(period - batch_size)
        self.duty_cycle.add_(x.gt(0).sum(dim=0, dtype=torch.float)) # added [0,0,:] to x
        self.duty_cycle.div_(period)


In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn

from active_dendrite import ActiveDendriteLayer
# from nupic.torch.modules.k_winners import KWinners
from k_winners import KWinners
from sparse_weights import SparseWeights
from util import activity_square, count_parameters, get_grad_printer


def topk_mask(x, k=2):
    """
    Simple functional version of KWinnersMask/KWinners since
    autograd function apparently not currently exportable by JIT
    """
    res = torch.zeros_like(x)
    topk, indices = x.topk(k, sorted=False)
    return res.scatter(-1, indices, 1)


class RSMPredictor(torch.nn.Module):
    def __init__(self, d_in=28 * 28, d_out=10, hidden_size=20):
        """
        """
        super(RSMPredictor, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.hidden_size = hidden_size

        self.layers = nn.Sequential(
            nn.Linear(self.d_in, self.hidden_size),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_size, self.d_out),
            nn.LeakyReLU(),
            nn.Softmax(dim=1),
        )
        self._init_linear_weights()

    def forward(self, x):
        """
        Receive input as hidden memory state from RSM, batch
        x^B is with shape (batch_size, total_cells)
        Output is two tensors of shape (batch_size, d_out) being distribution and
        logits respectively.
        """

        x1 = None
        x2 = x
        for layer in self.layers:
            x1 = x2
            x2 = layer(x1)

        # Return multiple outputs
        # https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440
        logits = x1.view(-1, self.d_out)

        distribution = x2.view(-1, self.d_out)
        return distribution, logits

    def _init_linear_weights(self):
        for mod in self.modules():
            if isinstance(mod, nn.Linear):
                sd = 0.03
                mod.weight.data.normal_(0.0, sd)
                if mod.bias is not None:
                    mod.bias.data.normal_(0.0, sd)


class RSMNet(torch.nn.Module):
    def __init__(self, n_layers=1, **kwargs):
        super(RSMNet, self).__init__()
        self.n_layers = n_layers
        self.hooks_registered = False

        eps_arr = self._parse_param_array(kwargs["eps"])
        k_winners_arr = self._parse_param_array(kwargs["k"])
        boost_strength_arr = self._parse_param_array(kwargs["boost_strength"])
        duty_cycle_period_arr = self._parse_param_array(
            kwargs.get("duty_cycle_period", 1000)
        )
        m_arr = self._parse_param_array(kwargs["m"])
        n_arr = self._parse_param_array(kwargs["n"])
        last_output_dim = None
        self.total_cells = []
        for i in range(n_layers):
            first_layer = i == 0
            top_layer = i == n_layers - 1
            if not first_layer:
                kwargs["d_in"] = last_output_dim
                # Output is of same dim as input (predictive autoencoder)
                kwargs["d_out"] = kwargs["d_in"]
            if not top_layer:
                kwargs["d_above"] = m_arr[i + 1] * n_arr[i + 1]
            if kwargs.get("lateral_conn", True):
                if top_layer:
                    kwargs["lateral_conn"] = kwargs.get("top_lateral_conn", False)
            kwargs["eps"] = eps_arr[i]
            kwargs["m"] = m_arr[i]
            kwargs["n"] = n_arr[i]
            kwargs["k"] = k_winners_arr[i]
            kwargs["boost_strength"] = boost_strength_arr[i]
            kwargs["duty_cycle_period"] = duty_cycle_period_arr[i]
            self.total_cells.append(kwargs["m"] * kwargs["n"])
            last_output_dim = kwargs["m"] * kwargs["n"]
            self.add_module("RSM_%d" % (i + 1), RSMLayer(**kwargs))

        print("Created RSMNet with %d layer(s)" % n_layers)

    def _parse_param_array(self, param_val):
        param_by_layer = param_val
        if not isinstance(param_by_layer, list):
            param_by_layer = [param_by_layer for x in range(self.n_layers)]
        return param_by_layer

    def _zero_sparse_weights(self):
        for mod in self.children():
            mod._zero_sparse_weights()

    def _zero_kwinner_boost(self):
        # Zero KWinner boost strengths since learning in RSM is pausing
        for layer in self.children():
            for mod in layer.children():
                if isinstance(mod, KWinners) and mod.boost_strength_factor < 1.0:
                    print("Zeroing boost strength for %s" % mod)
                    mod.boost_strength = 0.0

    def _register_hooks(self):
        if not self.hooks_registered:
            for mod in self.children():
                mod._register_hooks()
        self.hooks_registered = True

    def forward(self, x_a_batch, hidden):
        """
        Each layer takes input (image batch from time sequence for first layer,
        batch of hidden states from prior layer otherwise), and generates both:
            - a prediction for the next input it will see
            - a hidden state which is passed to the next layer
        Arguments:
            x_a_batch: (bsz, d_in)
            hidden: Tuple (x_b, phi, psi), each Tensor (n_layers, bsz, total_cells)
                - x_b is (possibly normalized) winners without hysteresis/decayed memory
        Returns:
            output_by_layer: List of tensors
                (n_layers, bsz, dim (total_cells or d_in for first layer))
            new_hidden: Tuple of tensors (n_layers, bsz, total_cells)
        """
        output_by_layer = []

        new_x_b = []
        new_phi = []
        new_psi = []

        x_b, phi, psi = hidden
        layer_input = x_a_batch

        lid = 0
        for (_layer_name, layer), lay_phi, lay_psi in zip(
            self.named_children(), phi, psi
        ):
            last_layer = lid == len(phi) - 1
            lay_above = list(self.children())[lid + 1] if not last_layer else None
            lay_x_b = x_b[lid]
            lay_x_above = x_b[lid + 1] if not last_layer else None

            # Update memory psi with prior step winners and apply decay as per config
            if lay_x_above is not None:
                psi_above = psi[lid + 1]
                lay_x_above = lay_above._decay_memory(psi_above, lay_x_above)
            if lay_x_b is not None:
                lay_x_b = layer._decay_memory(lay_psi, lay_x_b)

            hidden_in = (lay_x_b, lay_x_above, lay_phi, lay_psi)

            pred_output, hidden = layer(layer_input, hidden_in)

            # If layers > 1, higher layers predict lower layer's phi
            # phi has hysteresis (if decay active), x_b is just winners
            layer_input = hidden[1]  # phi

            new_x_b.append(hidden[0])
            new_phi.append(hidden[1])
            new_psi.append(hidden[2])

            output_by_layer.append(pred_output)

            lid += 1

        new_hidden = (new_x_b, new_phi, new_psi)

        return (output_by_layer, new_hidden)

    def _post_train_epoch(self, epoch):
        for mod in self.children():
            mod._post_epoch(epoch)

    def init_hidden(self, batch_size):
        param = next(self.parameters())
        x_b = [
            param.new_zeros((batch_size, tc), dtype=torch.float32, requires_grad=False)
            for tc in self.total_cells
        ]
        phi = [
            param.new_zeros((batch_size, tc), dtype=torch.float32, requires_grad=False)
            for tc in self.total_cells
        ]
        psi = [
            param.new_zeros((batch_size, tc), dtype=torch.float32, requires_grad=False)
            for tc in self.total_cells
        ]
        return (x_b, phi, psi)


class RSMLayer(torch.nn.Module):
    ACT_FNS = {"tanh": torch.tanh, "relu": F.relu, "sigmoid": torch.sigmoid}

    def __init__(
        self,
        d_in=28 * 28,
        d_out=28 * 28,
        d_above=None,
        m=200,
        n=6,
        k=25,
        k_winner_cells=1,
        gamma=0.5,
        eps=0.5,
        activation_fn="tanh",
        decode_activation_fn=None,
        embed_dim=0,
        vocab_size=0,
        decode_from_full_memory=False,
        debug_log_names=None,
        boost_strat="rsm_inhibition",
        x_b_norm=False,
        boost_strength=1.0,
        duty_cycle_period=1000,
        mult_integration=False,
        boost_strength_factor=1.0,
        forget_mu=0.0,
        weight_sparsity=None,
        feedback_conn=False,
        input_bias=False,
        decode_bias=True,
        lateral_conn=True,
        col_output_cells=False,
        debug=False,
        visual_debug=False,
        fpartition=None,
        balance_part_winners=False,
        trainable_decay=False,
        trainable_decay_rec=False,
        max_decay=1.0,
        mem_floor=0.0,
        additive_decay=False,
        stoch_decay=False,
        stoch_k_sd=0.0,
        rec_active_dendrites=0,
        **kwargs,
    ):
        """
        This class includes an attempted replication of the Recurrent Sparse Memory
        architecture suggested by by
        [Rawlinson et al 2019](https://arxiv.org/abs/1905.11589).
        Parameters allow experimentation with a wide array of adjustments to this model,
        both minor and major. Classes of models tested include:
        * "Adjusted" model with k-winners and column boosting, 2 cell winners,
            no inhibition
        * "Flattened" model with 1 cell per column, 1000 cols, 25 winners
            and multiplicative integration of FF & recurrent input
        * "Flat Partitioned" model with 120 winners, and cells partitioned into three
            functional types: ff only, recurrent only, and optionally a region that
            integrates both.
        :param d_in: Dimension of input
        :param m: Number of groups/columns
        :param n: Cells per group/column
        :param k: # of groups/columns to win in topk() (sparsity)
        :param k_winner_cells: # of winning cells per column
        :param gamma: Inhibition decay rate (0-1)
        :param eps: Integrated encoding decay rate (0-1)
        """
        super(RSMLayer, self).__init__()
        self.k = int(k)
        self.k_winner_cells = k_winner_cells
        self.m = m
        self.n = n
        self.gamma = gamma
        self.eps = eps
        self.d_in = d_in
        self.d_out = d_out
        self.d_above = d_above
        self.forget_mu = float(forget_mu)

        self.total_cells = m * n
        self.flattened = self.total_cells == self.m

        # Tweaks
        self.activation_fn = activation_fn
        self.decode_activation_fn = decode_activation_fn
        self.decode_from_full_memory = decode_from_full_memory
        self.boost_strat = boost_strat
        self.x_b_norm = x_b_norm
        self.boost_strength = boost_strength
        self.boost_strength_factor = boost_strength_factor
        self.duty_cycle_period = duty_cycle_period
        self.mult_integration = mult_integration
        self.fpartition = fpartition
        if isinstance(self.fpartition, float):
            # Handle simple single-param FF-percentage only
            # If fpartition is list, interpreted as [ff_pct, rec_pct]
            self.fpartition = [self.fpartition, 1.0 - self.fpartition]
        self.balance_part_winners = balance_part_winners
        self.weight_sparsity = weight_sparsity
        self.feedback_conn = feedback_conn
        self.input_bias = input_bias
        self.decode_bias = decode_bias
        self.lateral_conn = lateral_conn
        self.trainable_decay = trainable_decay
        self.trainable_decay_rec = trainable_decay_rec
        self.max_decay = max_decay
        self.additive_decay = additive_decay
        self.stoch_decay = stoch_decay
        self.col_output_cells = col_output_cells
        self.stoch_k_sd = stoch_k_sd
        self.rec_active_dendrites = rec_active_dendrites
        self.mem_floor = mem_floor

        self.debug = debug
        self.visual_debug = visual_debug
        self.debug_log_names = debug_log_names

        self._build_layers_and_kwinners()

        if self.additive_decay:
            decay_init = torch.ones(self.total_cells, dtype=torch.float32).uniform_(
                -3.0, 3.0
            )
        elif self.stoch_decay:
            # Fixed random decay rates, test with trainable_decay = False
            decay_init = torch.ones(self.total_cells, dtype=torch.float32).uniform_(
                -3.0, 3.0
            )
        else:
            decay_init = self.eps * torch.ones(self.total_cells, dtype=torch.float32)
        self.decay = nn.Parameter(decay_init, requires_grad=self.trainable_decay)
        self.register_parameter("decay", self.decay)
        self.learning_iterations = 0
        self.register_buffer("duty_cycle", torch.zeros(self.total_cells))

        print(
            "Created %s with %d trainable params" % (str(self), count_parameters(self))
        )

    def __str__(self):
        fp = ""
        if self.fpartition:
            fp = " partition=(%.2f,%.2f)" % (self.fpartition[0], self.fpartition[1])
        return "<RSMLayer m=%d n=%d k=%d d_in=%d eps=%.2f%s />" % (
            self.m,
            self.n,
            self.k,
            self.d_in,
            self.eps,
            fp,
        )

    def _debug_log(self, tensor_dict, truncate_len=400):
        if self.debug:
            for name, t in tensor_dict.items():
                if not self.debug_log_names or name in self.debug_log_names:
                    _type = type(t)
                    if _type in [int, float, bool]:
                        size = "-"
                    else:
                        size = t.size()
                        _type = t.dtype
                        if t.numel() > truncate_len:
                            t = "..truncated.."
                    print([name, t, size, _type])
        if self.visual_debug:
            for name, t in tensor_dict.items():
                if not self.debug_log_names or name in self.debug_log_names:
                    if isinstance(t, torch.Tensor):
                        t = t.detach().squeeze()
                        if t.dim() == 1:
                            t = t.flatten()
                            size = t.numel()
                            is_cell_level = t.numel() == self.total_cells and self.n > 1
                            if is_cell_level:
                                plt.imshow(
                                    t.view(self.m, self.n).t(),
                                    origin="bottom",
                                    extent=(0, self.m - 1, 0, self.n),
                                )
                            else:
                                plt.imshow(activity_square(t))
                            tmin = t.min()
                            tmax = t.max()
                            tsum = t.sum()
                            plt.title(
                                "%s (%s, rng: %.3f-%.3f, sum: %.3f)"
                                % (name, size, tmin, tmax, tsum)
                            )
                            plt.show()

    def _build_layers_and_kwinners(self):
        self.sparse_mods = []
        if self.fpartition:
            m_ff, m_int, m_rec = self._partition_sizes()
            # Partition memory into fpartition % FF & remainder recurrent
            self.linear_a = nn.Linear(self.d_in, m_ff, bias=self.input_bias)
            self.linear_b = nn.Linear(
                self.total_cells, m_rec, bias=self.input_bias
            )  # Recurrent weights (per cell)
            if m_int:
                # Add two additional layers for integrating ff & rec input
                # NOTE: Testing int layer that gets only input from prior int (no ff)
                self.linear_a_int = nn.Linear(self.d_in, m_int, bias=self.input_bias)
                self.linear_b_int = nn.Linear(
                    self.total_cells, m_int, bias=self.input_bias
                )
        else:
            # Standard architecture, no partition
            self.linear_a = nn.Linear(
                self.d_in, self.m, bias=self.input_bias
            )  # Input weights (shared per group / proximal)
            if self.lateral_conn:
                d1 = d2 = self.total_cells
                if self.col_output_cells:
                    d1 += self.m  # One output per column
                # Recurrent weights (per cell)
                if self.rec_active_dendrites:
                    sparsity = 0.3
                    self.linear_b = ActiveDendriteLayer(
                        d1,
                        n_cells=d2,
                        n_dendrites=self.rec_active_dendrites,
                        sparsity=sparsity,
                    )
                    if sparsity:
                        self.sparse_mods.append(self.linear_b.linear_dend)
                else:
                    self.linear_b = nn.Linear(d1, d2, bias=self.input_bias)

            if self.feedback_conn:
                # Linear layers for both recurrent input from above and below
                self.linear_b_above = nn.Linear(
                    self.d_above, self.total_cells, bias=self.input_bias
                )

        pct_on = self.k / self.m
        if self.fpartition and self.balance_part_winners:
            # Create a kwinners module for each partition each with specified
            # size but same pct on (balanced).
            self.kwinners_ff = self.kwinners_rec = self.kwinners_int = None
            if m_ff:
                self.kwinners_ff = self._build_kwinner_mod(m_ff, pct_on)
            if m_int:
                self.kwinners_int = self._build_kwinner_mod(m_int, pct_on)
            if m_rec:
                self.kwinners_rec = self._build_kwinner_mod(m_rec, pct_on)
        else:
            # We need only a single kwinners to run on full memory
            self.kwinners_col = self._build_kwinner_mod(self.m, pct_on)
        if self.weight_sparsity is not None:
            self.linear_a = SparseWeights(self.linear_a, self.weight_sparsity)
            self.linear_b = SparseWeights(self.linear_b, self.weight_sparsity)
            self.sparse_mods.extend([self.linear_a, self.linear_b])

        # Decode linear
        decode_d_in = self.total_cells if self.decode_from_full_memory else self.m
        self.linear_d = nn.Linear(decode_d_in, self.d_out, bias=self.decode_bias)

        if self.trainable_decay_rec:
            self.linear_decay_rec = nn.Linear(
                self.total_cells, self.total_cells, bias=True
            )

        self._init_linear_weights()

    def _init_linear_weights(self):
        for mod in self.modules():
            if isinstance(mod, nn.Linear):
                sd = 0.03
                mod.weight.data.normal_(0.0, sd)
                if mod.bias is not None:
                    mod.bias.data.normal_(0.0, sd)

    def _zero_sparse_weights(self):
        for mod in self.sparse_mods:
            if isinstance(mod, SparseWeights):
                mod.rezero_weights()

    def _partition_sizes(self):
        pct_ff, pct_rec = self.fpartition
        m_ff = int(round(pct_ff * self.m))
        m_rec = int(round(pct_rec * self.m))
        m_int = self.m - m_ff - m_rec
        return (m_ff, m_int, m_rec)

    def _build_kwinner_mod(self, m, pct_on):
        return KWinners(
            m,
            pct_on,
            boost_strength=self.boost_strength,
            duty_cycle_period=self.duty_cycle_period,
            k_inference_factor=1.0,
            stoch_sd=self.stoch_k_sd,
            boost_strength_factor=self.boost_strength_factor,
        )

    def _post_epoch(self, epoch):
        # Update boost strength of any KWinners modules
        for mod in self.modules():
            if hasattr(mod, "update_boost_strength"):
                mod.update_boost_strength()

    def _register_hooks(self):
        """Utility function to call retain_grad and Pytorch's register_hook
        in a single line
        """
        for label, t in [
            # ('y', self.y),
            # ('sigma', self.sigma),
            ("linear_b grad", self.linear_b.weight)
        ]:
            t.retain_grad()
            t.register_hook(get_grad_printer(label))

    def _decay_memory(self, psi_last, x_b):
        if self.trainable_decay_rec:
            decay_param = self.max_decay * torch.sigmoid(
                self.linear_decay_rec(psi_last)
            )
        elif self.trainable_decay:
            decay_param = self.max_decay * torch.sigmoid(self.decay)
        else:
            decay_param = self.eps
        
        updated = decay_param * psi_last
        if self.mem_floor:
            updated[updated <= self.mem_floor] = 0.0
        memory = torch.max(updated, x_b)
        return memory

    def _do_forgetting(self, phi, psi):
        bsz = phi.size(0)
        if self.training and self.forget_mu > 0:
            keep_idxs = torch.rand(bsz) > self.forget_mu
            mask = torch.zeros_like(phi)
            mask[keep_idxs, :] = 1
            phi = phi * mask
            psi = psi * mask
        return (phi, psi)

    def _group_max(self, activity):
        """
        :param activity: activity vector (bsz x total_cells)
        Returns max cell activity in each group
        """
        return activity.view(-1, self.m, self.n).max(dim=2).values

    def _fc_weighted_ave(self, x_a, x_b, x_b_above=None):
        """
        Compute sigma (weighted sum for each cell j in group i (mxn))
        """
        if self.fpartition:
            m_ff, m_int, m_rec = self._partition_sizes()
            sigma = torch.zeros_like(x_b)
            # Integrate partitioned memory.
            # Pack as 1xm: [ ... m_ff ... ][ ... m_int ... ][ ... m_rec ... ]
            # If m_int non-zero, these cells receive sum of FF & recurrent input
            z_a = self.linear_a(x_a)  # bsz x (m_ff)
            z_log = {"z_a": z_a}
            if m_int:
                z_b = self.linear_b(x_b)  # bsz x m_rec
                z_int_ff = self.linear_a_int(x_a)
                # NOTE: Testing from only int/rec portion of mem (no ff)
                z_int_rec = self.linear_b_int(x_b)
                z_int = (
                    z_int_ff * z_int_rec
                    if self.mult_integration
                    else z_int_ff + z_int_rec
                )
                sigma = torch.cat((z_a, z_int, z_b), 1)  # bsz x m
            else:
                z_b = self.linear_b(x_b)  # bsz x m_rec
                z_log["z_b"] = z_b
                sigma = torch.cat((z_a, z_b), 1)  # bsz x m
        else:
            # Col activation from inputs repeated for each cell
            z_a = self.linear_a(x_a).repeat_interleave(self.n, 1)

            sigma = z_a
            z_log = {"z_a": z_a}

            # Cell activation from recurrent (lateral) input
            if self.lateral_conn:
                z_b_in = x_b
                if self.col_output_cells:
                    z_b_in = torch.cat((z_b_in, self._group_max(x_b)), dim=1)
                z_b = self.linear_b(z_b_in)
                sigma = sigma * z_b if self.mult_integration else sigma + z_b
                z_log["z_b"] = z_b
            # Activation from recurrent (feedback) input
            if self.feedback_conn:
                if x_b_above is not None:
                    # Cell activation from recurrent input from layer above (apical)
                    z_b_above = self.linear_b_above(x_b_above)
                    z_log["z_b_above"] = z_b_above
                    if self.mult_integration:
                        sigma = sigma * z_b_above
                    else:
                        sigma = sigma + z_b_above
        self._debug_log(z_log)

        return sigma  # total_cells

    def _update_duty_cycle(self, winners):
        """
        For tracking layer entropy (across both inhibition/boosting approaches)
        """
        batch_size = winners.shape[0]
        self.learning_iterations += batch_size
        period = min(1000, self.learning_iterations)
        self.duty_cycle.mul_(period - batch_size)

        self.duty_cycle.add_(winners.gt(0).sum(dim=0, dtype=torch.float))
        self.duty_cycle.div_(period)

    def _k_winners(self, sigma, pi):
        bsz = pi.size(0)

        # Group-wise max pooling
        if not self.flattened:
            lambda_ = self._group_max(pi)
        else:
            lambda_ = pi

        # Cell-level mask: Make a bsz x total_cells binary mask of top 1 cell / column
        if self.n == self.k_winner_cells:
            # Usually just in flattened case, no need to choose winners
            m_pi = torch.ones(bsz, self.total_cells, device=sigma.device)
        else:
            mask = topk_mask(pi.view(bsz * self.m, self.n), self.k_winner_cells)
            m_pi = mask.view(bsz, self.total_cells).detach()

        if self.boost_strat == "rsm_inhibition":
            # Standard RSM-style inhibition via phi matrix

            self._debug_log({"lambda_": lambda_})

            # Column-level mask: Make a bsz x total_cells binary mask of top k columns
            mask = topk_mask(lambda_, self.k)
            m_lambda = (
                mask.view(bsz, self.m, 1)
                .repeat(1, 1, self.n)
                .view(bsz, self.total_cells)
                .detach()
            )
            col_winners = m_lambda

            self._debug_log({"m_pi": m_pi, "m_lambda": m_lambda})

            y_pre_act = m_pi * m_lambda * sigma

        elif self.boost_strat == "col_boosting":
            # HTM style boosted k-winner
            if self.balance_part_winners and self.fpartition:
                m_ff, m_int, m_rec = self._partition_sizes()
                winners = []
                if self.kwinners_ff is not None:
                    winners_ff = self.kwinners_ff(lambda_[:, :m_ff])
                    winners.append(winners_ff)
                if self.kwinners_int is not None:
                    winners_int = self.kwinners_int(lambda_[:, m_ff : m_ff + m_int])
                    winners.append(winners_int)
                if self.kwinners_rec is not None:
                    winners_rec = self.kwinners_rec(lambda_[:, -m_rec:])
                    winners.append(winners_rec)
                m_lambda = (torch.cat(winners, 1).abs() > 0).float()
            else:
                
                winning_cols = (
                    self.kwinners_col(lambda_).view(bsz, self.m, 1).abs() > 0
                ).float()
              
                m_lambda = winning_cols.repeat(1, 1, self.n).view(bsz, self.total_cells)

                col_winners = winning_cols

            self._debug_log({"m_lambda": m_lambda})

            y_pre_act = m_pi * m_lambda * sigma

        self._update_duty_cycle(col_winners.squeeze())

        del m_pi
        del m_lambda

        return y_pre_act

    def _inhibited_winners(self, sigma, phi):
        """
        Compute y_lambda
        """
        # Apply inhibition to non-neg shifted sigma
        inh = (1 - phi) if self.boost_strat == "rsm_inhibition" else 1
        pi = inh * (sigma - sigma.min() + 1)
        self._debug_log({"pi": pi})

        pi = pi.detach()  # Prevent gradients from flowing through inhibition/masking

        y_pre_act = self._k_winners(sigma, pi)

        activation = RSMLayer.ACT_FNS[self.activation_fn]
        y = activation(y_pre_act)  # 1 x total_cells

        return y

    def _update_memory_and_inhibition(self, y, phi, psi, x_b=None):
        """
        Decay memory and inhibition tensors
        """

        # Set psi to x_b, which includes decayed prior state (see RSMNet.forward)
        psi = x_b

        # Update phi for next step (decay inhibition cells)
        phi = torch.max(phi * self.gamma, y)

        return (phi, psi)

    def _decode_prediction(self, y):
        if self.decode_from_full_memory:
            decode_input = y
        else:
            decode_input = self._group_max(y)
        output = self.linear_d(decode_input)
        if self.decode_activation_fn:
            activation = RSMLayer.ACT_FNS[self.decode_activation_fn]
            output = activation(output)
        return output

    def forward(self, x_a_batch, hidden):
        """
        :param x_a_batch: Input batch of batch_size items from
        generating process (batch_size, d_in)
        :param hidden:
            x_b: Memory at same layer at t-1 (possibly with decayed/hysteresis memory
                from prior time steps)
            x_b_above: memory state at layer above at t-1
            phi: inhibition state (used only for boost_strat=='rsm_inhibition')
            psi: memory state at t-1, inclusive of hysteresis
        Note that RSMLayer takes a 4-tuple that includes the feedback state
        from the layer above, x_c, while the RSMNet takes only 3-tuple for
        hidden state.
        """
        x_b, x_b_above, phi, psi = hidden
        x_b_in = x_b.clone()

        phi, psi = self._do_forgetting(phi, psi)

        self._debug_log({"x_b": x_b, "x_a_batch": x_a_batch})

        sigma = self._fc_weighted_ave(x_a_batch, x_b, x_b_above=x_b_above)
        self._debug_log({"sigma": sigma})

        y = self._inhibited_winners(sigma, phi)

        phi, psi = self._update_memory_and_inhibition(y, phi, psi, x_b=x_b_in)
        self._debug_log({"phi": phi, "psi": psi})

        output = self._decode_prediction(y)

        self._debug_log({"y": y, "output": output})

        # Update recurrent input / output x_b
        if self.x_b_norm:
            # Normalizing scalar (force sum(x_b) == 1)
            alpha_y = (y.sum(dim=1) + 1e-9).unsqueeze(dim=1)
            x_b = y / alpha_y
        else:
            x_b = y

        hidden = (x_b, phi, psi)
        return (output, hidden)


if __name__ == "__main__":
    batch_size, d_in = 50, 64

    x = torch.randn(batch_size, d_in)
    y = torch.randn(batch_size, d_in)

    model = RSMLayer(d_in)

    criterion = torch.nn.MSELoss(reduction="sum")
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    for t in range(500):
        y_pred = model(x)

        loss = criterion(y_pred, y)
        print(t, loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [1]:
import argparse
import os

import ray
from ray import tune
from ray.tune.logger import CSVLogger, JsonLogger


from parse_config import parse_config

import numpy as np
import random
import sys
import time
from functools import partial, reduce

import torch
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision import transforms

from duty_cycle_metrics import binary_entropy
from ptb import lang_util
from rsm import RSMNet, RSMPredictor
from rsm_samplers import (
    PTBSequenceSampler,
    pred_sequence_collate,
    ptb_pred_sequence_collate,
)
from util import (
    fig2img,
    plot_activity,
    plot_activity_grid,
    plot_confusion_matrix,
    plot_representation_similarity,
    plot_tensors,
    print_aligned_sentences,
    print_epoch_values,
)

torch.autograd.set_detect_anomaly(True)


class RSMExperiment(object):
    """
    Generic class for creating tiny RSM models. This can be used with Ray
    tune or PyExperimentSuite, to run a single trial or repetition of a
    network.
    
    """

    def __init__(self, config=None):
        self.data_dir = config.get("data_dir", "/home/user/nta/datasets") #hardcode please change this one 
        self.path = config.get("path", "/home/user/nta/results") #hardcode please change this one
        self.model_filename = config.get("model_filename", "model.pt")
        self.pred_model_filename = config.get("pred_model_filename", "pred_model.pt")
        self.graph_filename = config.get("graph_filename", "rsm.onnx")
        self.save_onnx_graph_at_checkpoint = config.get(
            "save_onnx_graph_at_checkpoint", False
        )
        self.exp_name = config.get("name", "exp")
        self.batch_log_interval = config.get("batch_log_interval", 0)
        self.eval_interval = config.get("eval_interval", 1)
        self.eval_interval_schedule = config.get("eval_interval_schedule", None)
        self.model_kind = config.get("model_kind", "rsm")
        self.debug = config.get("debug", False)
        self.visual_debug = config.get("visual_debug", False)

        # Instrumentation
        self.instrumentation = config.get("instrumentation", False)
        self.plot_gradients = config.get("plot_gradients", False)
        self.instr_charts = config.get("instr_charts", [])

        self.iterations = config.get("iterations", 1)
        self.dataset_kind = config.get("dataset", "ptb")

        # Training / testing parameters
        self.batch_size = config.get("batch_size", 100)
        self.eval_batch_size = config.get("eval_batch_size", 11)
        self.batches_in_epoch = config.get("batches_in_epoch", 100) #number of game per epoch
        self.batches_in_first_epoch = config.get(
            "batches_in_first_epoch", self.batches_in_epoch
        )
        self.eval_batches_in_epoch = config.get(
            "eval_batches_in_epoch", self.batches_in_epoch
        )
        self.pause_after_upticks = config.get("pause_after_upticks", 0)
        self.pause_after_epochs = config.get("pause_after_epochs", 0)
        self.pause_eval_interval = config.get("pause_eval_interval", 10)
        self.pause_min_epoch = config.get("pause_min_epoch", 0)

        # Data parameters
        self.input_size = config.get("input_size", (1, 100)) #(1, 28, 28)
        self.sequences = config.get("sequences", [[0, 1, 2, 3]])
        self.static_digit = config.get("static_digit", False)
        self.randomize_sequence_cursors = config.get("randomize_sequence_cursors", True)
#        self.use_mnist_pct = config.get("use_mnist_pct", 1.0)

        self.learning_rate = config.get("learning_rate", 0.0005)
        self.pred_learning_rate = config.get("pred_learning_rate", self.learning_rate)
        self.momentum = config.get("momentum", 0.9)
        self.optimizer_type = config.get("optimizer", "adam")
        self.pred_optimizer_type = config.get("pred_optimizer", self.optimizer_type)

        # Model
        self.heads = config.get("heads", 1)
        self.m_groups = config.get("m_groups", 60)
        self.n_cells_per_group = config.get("n_cells_per_group", 1)
        self.k_winners = config.get("k_winners", 7)
        self.k_winners_pct = config.get("k_winners_pct", None)
        if self.k_winners_pct is not None:
            # Optionally define k-winners proportionally
            self.k_winners = int(self.m_groups * self.k_winners_pct)
        self.gamma = config.get("gamma", 0.5)
        self.eps = config.get("eps", 0.5)
        self.k_winner_cells = config.get("k_winner_cells", 1)
        self.flattened = self.n_cells_per_group == 1
        self.forget_mu = config.get("forget_mu", 0.0)

        # Tweaks
        self.activation_fn = config.get("activation_fn", "tanh")
        self.decode_activation_fn = config.get("decode_activation_fn", None)
        self.pred_l2_reg = config.get("pred_l2_reg", 0)
        self.l2_reg = config.get("l2_reg", 0)
        self.dec_l2_reg = config.get("dec_l2_reg", 0)
        self.decode_from_full_memory = config.get("decode_from_full_memory", False)
        self.boost_strat = config.get("boost_strat", "col_boosting")
        self.x_b_norm = config.get("x_b_norm", False)
        self.mask_shifted_pi = config.get("mask_shifted_pi", False)
        self.boost_strength = config.get("boost_strength", 1.0)
        self.boost_strength_factor = config.get("boost_strength_factor", 1.0)
        self.duty_cycle_period = config.get("duty_cycle_period", 1000)
        self.mult_integration = config.get("mult_integration", False)
        self.noise_buffer = config.get("noise_buffer", False)
        self.col_output_cells = config.get("col_output_cells", False)
        self.fpartition = config.get("fpartition", None)
        self.balance_part_winners = config.get("balance_part_winners", False)
        self.weight_sparsity = config.get("weight_sparsity", None)
        self.embedding_kind = config.get("embedding_kind", "ptb_fasttext_e5")
        self.feedback_conn = config.get("feedback_conn", False)
        self.input_bias = config.get("input_bias", False)
        self.decode_bias = config.get("decode_bias", True)
        self.loss_layers = config.get("loss_layers", "first")
        self.top_lateral_conn = config.get("top_lateral_conn", True)
        self.lateral_conn = config.get("lateral_conn", True)
        self.trainable_decay = config.get("trainable_decay", True)
        self.trainable_decay_rec = config.get("trainable_decay_rec", False)
        self.max_decay = config.get("max_decay", 1.0)
        self.additive_decay = config.get("additive_decay", False)
        self.stoch_decay = config.get("stoch_decay", False)
        self.stoch_k_sd = config.get("stoch_k_sd", False)
        self.rec_active_dendrites = config.get("rec_active_dendrites", 0)
        self.mem_floor = config.get("mem_floor", 0.0)

        # Prediction smoothing
        self.word_cache_decay = config.get("word_cache_decay", 0.0)
        self.word_cache_pct = config.get("word_cache_pct", 0.0)
        self.unif_smoothing = config.get("unif_smoothing", 0.0)
        self.kn5_pct = config.get("kn5_pct", 0.0)

        # Predictor network
        self.predictor_hidden_size = config.get("predictor_hidden_size", 200)
        self.predictor_output_size = config.get("predictor_output_size", 10000)

        self.n_layers = config.get("n_layers", 1)

        # Embeddings for language modeling
        self.embed_dim = config.get("embed_dim", 100)
        self.vocab_size = config.get("vocab_size", 0)

        self.loss_function = config.get("loss_function", "MSELoss")
        self.lr_step_schedule = config.get("lr_step_schedule", None)
        self.learning_rate_gamma = config.get("learning_rate_gamma", 0.0)
        self.learning_rate_min = config.get("learning_rate_min", 0.0)

        # Training state
        self.best_val_loss = None
        self.do_anneal_learning = False
        self.model_learning_paused = False
        self.n_upticks = 0

        self.train_hidden_buffer = []

        # Additional state for vis, etc
        self.activity_by_inputs = {}  # 'digit-digit' -> list of distribution arrays

    def _build_dataloader(self):
        self.val_loader = self.corpus = None
        
        # Download "Penn Treebank" dataset
        from torchnlp.datasets import penn_treebank_dataset

        print("Maybe download PTB...")
        penn_treebank_dataset(self.data_dir + "/PTB", train=True, test=True)
        corpus = lang_util.Corpus(self.data_dir + "/PTB")
        train_sampler = PTBSequenceSampler(
            corpus.train,
            batch_size=self.batch_size,
            max_batches=self.batches_in_epoch,
        )

        import fasttext

        # Generated via notebooks/ptb_embeddings.ipynb
        embedding = {}
        ft_model = fasttext.load_model(
            self.data_dir + "/embeddings/%s.bin" % self.embedding_kind
        )
        for word_id, word in enumerate(corpus.dictionary.idx2word):
            embedding[word_id] = torch.tensor(ft_model[word])

        if self.embedding_kind:
            print(
                "Loaded embedding dict (%s) with %d entries"
                % (self.embedding_kind, len(embedding))
            )

        collate_fn = partial(ptb_pred_sequence_collate, vector_dict=embedding)
        self.train_loader = DataLoader(
            corpus.train, batch_sampler=train_sampler, collate_fn=collate_fn
        )
        val_sampler = PTBSequenceSampler(
            corpus.test,
            batch_size=self.eval_batch_size,
            max_batches=self.eval_batches_in_epoch,
            uniform_offsets=True,
        )
        self.val_loader = DataLoader(
            corpus.test, 
            batch_sampler=val_sampler, 
            collate_fn=collate_fn
        )
        self.corpus = corpus
        print("Built dataloaders...")

    def _get_loss_function(self):
        self.loss = getattr(torch.nn, self.loss_function)(reduction="mean")
        self.predictor_loss = None
        if self.predictor:
            # https://pytorch.org/docs/stable/nn.html#crossentropyloss
            self.predictor_loss = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_one_optimizer(self, otype, params, lr, l2_reg=0.0):
        if otype == "adam":
            optimizer = torch.optim.Adam(params, lr=lr, weight_decay=l2_reg)
        elif otype == "sgd":
            optimizer = torch.optim.SGD(
                params, lr=lr, momentum=self.momentum, weight_decay=l2_reg
            )
        return optimizer

    def _get_optimizer(self):
        self.pred_optimizer = None
        self.optimizer = self._get_one_optimizer(
            self.optimizer_type, self.model.parameters(), self.learning_rate
        )
        if self.predictor:
            self.pred_optimizer = self._get_one_optimizer(
                self.pred_optimizer_type,
                self.predictor.parameters(),
                self.pred_learning_rate,
                l2_reg=self.pred_l2_reg,
            )

    def model_setup(self, config, restore_path=None):
        seed = config.get("seed", random.randint(0, 10000))
        if torch.cuda.is_available():
            print("setup: Using cuda")
            self.device = torch.device("cuda")
            torch.cuda.manual_seed(seed)
        else:
            print("setup: Using cpu")
            self.device = torch.device("cpu")

        self._build_dataloader()

        if restore_path:
            # Restore to self.model and self.predictor
            self.model_restore(restore_path)
        else:
            # Build model and optimizer
            self.d_in = reduce(lambda x, y: x * y, self.input_size)
            #print('inpust size',self.input_size)
            self.d_out = config.get("output_size", self.d_in)
            self.predictor = None
            predictor_d_in = self.m_groups
            self.model = RSMNet(
                n_layers=self.n_layers,
                d_in=self.d_in,
                d_out=self.d_out,
                m=self.m_groups,
                n=self.n_cells_per_group,
                k=self.k_winners,
                k_winner_cells=self.k_winner_cells,
                gamma=self.gamma,
                eps=self.eps,
                forget_mu=self.forget_mu,
                activation_fn=self.activation_fn,
                decode_activation_fn=self.decode_activation_fn,
                decode_from_full_memory=self.decode_from_full_memory,
                col_output_cells=self.col_output_cells,
                x_b_norm=self.x_b_norm,
                mask_shifted_pi=self.mask_shifted_pi,
                boost_strat=self.boost_strat,
                boost_strength=self.boost_strength,
                boost_strength_factor=self.boost_strength_factor,
                duty_cycle_period=self.duty_cycle_period,
                weight_sparsity=self.weight_sparsity,
                mult_integration=self.mult_integration,
                fpartition=self.fpartition,
                balance_part_winners=self.balance_part_winners,
                feedback_conn=self.feedback_conn,
                lateral_conn=self.lateral_conn,
                top_lateral_conn=self.top_lateral_conn,
                input_bias=self.input_bias,
                decode_bias=self.decode_bias,
                trainable_decay=self.trainable_decay,
                trainable_decay_rec=self.trainable_decay_rec,
                max_decay=self.max_decay,
                additive_decay=self.additive_decay,
                stoch_decay=self.stoch_decay,
                embed_dim=self.embed_dim,
                vocab_size=self.vocab_size,
                stoch_k_sd=self.stoch_k_sd,
                rec_active_dendrites=self.rec_active_dendrites,
                mem_floor=self.mem_floor,
                debug=self.debug,
                visual_debug=self.visual_debug,
            )
            if self.n_layers > 1:
                predictor_d_in = sum([l.total_cells for l in self.model.children()])
            else:
                predictor_d_in = self.m_groups * self.n_cells_per_group

            if self.predictor_hidden_size:
                self.predictor = RSMPredictor(
                    d_in=predictor_d_in,
                    d_out=self.predictor_output_size,
                    hidden_size=self.predictor_hidden_size,
                )

        # Move to device
        self.model.to(self.device)
        if self.predictor:
            self.predictor.to(self.device)

        self._get_loss_function()
        self._get_optimizer()

        if self.word_cache_decay:
            self.word_cache = torch.zeros(
                (self.eval_batch_size, self.vocab_size),
                device=self.device,
                requires_grad=False,
            )

        if self.kn5_pct:
            # This KN5 model likely needs to be generated / downloaded
            self.kn5_distr = torch.load(
                self.data_dir + "/PTB/KN5/kn5_distr_remapped.pt",
                map_location=self.device,
            )

    def _repackage_hidden(self, h):
        """Wraps hidden states in new Tensors, to detach them from their history."""
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self._repackage_hidden(v) for v in h)

    def _adjust_learning_rate(self, epoch):
        if self.do_anneal_learning and self.learning_rate > self.learning_rate_min:
            self.learning_rate *= self.learning_rate_gamma
            self.do_anneal_learning = False
            print(
                "Reducing learning rate by gamma %.2f to: %.5f"
                % (self.learning_rate_gamma, self.learning_rate)
            )
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = self.learning_rate

    def _init_hidden(self, batch_size):
        return self.model.init_hidden(batch_size)

    def _cache_inputs(self, input_labels, clear=False):
        """
        Word cache for smoothing, currently only used for eval (on test)
        """
        if self.word_cache_decay:
            if clear:
                # Clear cache
                self.word_cache = self.word_cache * 0.0

            self.word_cache.scatter_(1, input_labels.unsqueeze(1), 1.0)
            # Decay
            self.word_cache = self.word_cache * self.word_cache_decay

    def _get_prediction_and_loss_inputs(self, hidden):
        # hidden is (x_b, phi, psi)
        x_b = hidden[0]
        #print('self.predictor',self.predictor)
        #print('torch.cat(x_b, dim=1).view(-1, self.predictor.d_in).detach() = ',torch.cat(x_b, dim=1).view(-1, self.predictor.d_in).detach())
        #print('x_b at 437',x_b[0].shape)
        if self.predictor:
            # Predict from concat of all layer hidden states
            predictor_input = (
                torch.cat(x_b, dim=1).view(-1,self.predictor.d_in).detach()
            ) #remember to change 31 to batchsize later on
            
        #print('predictor_input at 444',predictor_input.shape)
        return x_b, predictor_input

    def _backward_and_optimize(self, loss):
        if self.debug:
            self.model._register_hooks()
        loss.backward()
        if self.model_kind == "lstm":
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.25)
            for p in self.model.parameters():
                p.data.add_(-self.learning_rate, p.grad.data)
        else:
            self.optimizer.step()

    def _interpolated_loss(
        self, predictor_dist, pred_targets, loader=None, train=False
    ):
        predictor_dist_size = list(predictor_dist.size())
        num_classes = predictor_dist_size[1]
        
        #print('pred_targets ',pred_targets)
        #print('num_classes ',num_classes)
        
        labels_one_hot = torch.nn.functional.one_hot(
            pred_targets, num_classes=num_classes
        )
        labels_one_hot = labels_one_hot.to(self.device).float()

        predictions = torch.zeros_like(predictor_dist)
        predictor_mass_pct = 1.0
        if not train:
            if (
                self.word_cache_decay
            ):  # and predictions.size(0) == self.word_cache.size(0):
                # Word cache enabled
                mass_pct = self.word_cache_pct
                predictor_mass_pct -= mass_pct
                predictions += (
                    mass_pct
                    * self.word_cache
                    / self.word_cache.sum(dim=1, keepdim=True)
                )

            if self.unif_smoothing:
                # Uniform smoothing enabled
                mass_pct = self.unif_smoothing
                predictor_mass_pct -= mass_pct
                predictions += (
                    mass_pct * torch.ones_like(predictor_dist) / self.vocab_size
                )

            if self.kn5_pct:
                # KN5 model interpolation
                mass_pct = self.kn5_pct
                predictor_mass_pct -= mass_pct
                predictions += (
                    mass_pct * self.kn5_distr[loader.batch_sampler.batch_idxs, :]
                )

        predictions += predictor_mass_pct * predictor_dist
        ll = (labels_one_hot * torch.log(predictions)).sum(dim=[0, 1])
        interp_loss = -ll  # sum negative log likelihood

        return interp_loss.item()

    def _do_prediction(
        self, inputs, pred_targets, pcounts, train=False, batch_idx=0, loader=None
    ):
        """
        Do prediction.
        """
        class_predictions = correct_arr = None
        #print('sef.predictor', self.predictor)
        if self.predictor:

            predictor_dist, predictor_logits = self.predictor(inputs.detach())

            # This loss is without inference-time model interpolation
            #print('predictor_logits, pred_targets line 525',predictor_logits.shape,' ',pred_targets.shape)
            
            pred_loss = self.predictor_loss(
                torch.squeeze(predictor_logits), torch.squeeze(pred_targets)
            )  # cross-entropy loss
            
            #print('pred_loss ',pred_loss)
            # This loss is for the interpolated model
            interp_loss = self._interpolated_loss(
                predictor_dist, pred_targets, loader=loader, train=train
            )

            _, class_predictions = torch.max(predictor_dist, 1)
            pcounts["total_samples"] += pred_targets.size(0)
            correct_arr = class_predictions == pred_targets
            pcounts["correct_samples"] += correct_arr.sum().item()
            pred_loss_ = pred_loss.item()
            pcounts["total_pred_loss"] += pred_loss_
            pcounts["total_interp_loss"] += interp_loss
            if train:
                # Predictor backward + optimize
                pred_loss.backward()
                self.pred_optimizer.step()

        if self.batch_log_interval and batch_idx % self.batch_log_interval == 0:
            print("Finished batch %d" % batch_idx)
            if self.predictor:
                batch_acc = correct_arr.float().mean() * 100
                batch_ppl = lang_util.perpl(pred_loss_ / pred_targets.size(0))
                print(
                    "Partial pred acc - "
                    "batch acc: %.3f%%, pred ppl: %.1f" % (batch_acc, batch_ppl)
                )

        return (pcounts, class_predictions, correct_arr)

    def _compute_loss(self, predicted_outputs, targets):
        """
        Compute loss across multiple layers (if applicable).

        First layer loss (l1_loss) is between last image prediction and actual input
            image
        Layers > 1 loss (ls_loss) is between last output (hidden predictions) and
            actual hidden

        Args:
            - predicted_outputs: list of len n_layers of (bsz, d_in or total_cells)
            - targets: 2-tuple
                - list of actual_input (bsz, d_in) by layer
                - list of x_b (bsz, total_cells)) by layer

        Note that batch size will differ if using a smaller first epoch batch size.
        In this case we crop target tensors to match predictions.

        TODO: Decision to be made on whether to compute loss vs max-pooled column
        activations or cells.
        """
        loss = None
        if predicted_outputs is not None:

            # TODO: We can stack these and run loss once only
            if self.loss_layers in ["first", "all_layers"]:
                bottom_targets = targets[0].detach()
                pred_img = predicted_outputs[0]
                #print('pred_img at 590', pred_img.shape)
                #print('bottom_targets at 591', bottom_targets.shape)
                l1_loss = self.loss(pred_img, bottom_targets)
                if loss is None:
                    loss = l1_loss
                else:
                    loss += l1_loss

            if self.n_layers > 1 and self.loss_layers in ["above_first", "all_layers"]:
                memory = self._repackage_hidden(targets[1])
                for l in range(self.n_layers - 1):
                    higher_targets = memory[
                        l
                    ]  # Target memory states up to 2nd to last layer
                    outputs = predicted_outputs[l + 1]  # Predictions from layer above
                    ls_loss = self.loss(outputs, higher_targets)
                    if loss is None:
                        loss = ls_loss
                    else:
                        loss += ls_loss

            if self.l2_reg and self.lateral_conn:
                for l in self.model.children():
                    # Add L2 reg term for recurrent weights
                    loss += self.l2_reg * l.linear_b.weight.norm(2) ** 2
                    if hasattr(l.linear_b, "bias"):
                        loss += self.l2_reg * l.linear_b.bias.norm(2) ** 2
            if self.dec_l2_reg:
                for l in self.model.children():
                    # Add L2 reg term for decode weights
                    loss += self.dec_l2_reg * l.linear_d.weight.norm(2) ** 2
                    if hasattr(l.linear_d, "bias") and l.linear_d.bias is not None:
                        loss += self.dec_l2_reg * l.linear_d.bias.norm(2) ** 2

        return loss

    def eval_epoch(self, epoch, loader=None):
        ret = {}
        print("Evaluating...")
        if not loader:
            loader = self.val_loader

        if self.instrumentation:
            # Capture entropy prior to evaluation
            _, train_entropy = binary_entropy(self.model.RSM_1.duty_cycle)
            ret["train_entropy"] = train_entropy.item()
            self.model.RSM_1.duty_cycle.fill_(0.0)  # Clear duty cycle

        self.model.eval()
        if self.predictor:
            self.predictor.eval()

        if self.weight_sparsity is not None:
            # Rezeroing happens before forward pass, so rezero after last
            # training forward.
            self.model._zero_sparse_weights()

        with torch.no_grad():
            total_loss = 0.0
            pcounts = {
                "total_samples": 0.0,
                "correct_samples": 0.0,
                "total_pred_loss": 0.0,
                "total_interp_loss": 0.0,
            }
            print(self.eval_batch_size)#debug
            hidden = self._init_hidden(self.eval_batch_size)

            read_out_tgt = []
            read_out_pred = []
            metrics = {}

            for _b_idx, (inputs, targets, pred_targets, input_labels) in enumerate(
                loader
            ):

                # Forward
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                pred_targets = pred_targets.to(self.device)
                input_labels = input_labels.to(self.device)

                self._cache_inputs(input_labels, clear=_b_idx == 0)
                


                output, hidden = self.model(inputs, hidden)
                
                x_b, pred_input = self._get_prediction_and_loss_inputs(hidden)

                # Loss
                #print('output shape ', output)
                loss = self._compute_loss(output, (targets, x_b))
                if loss is not None:
                    total_loss += loss.item()

                pcounts, class_predictions, correct_arr = self._do_prediction(
                    pred_input, pred_targets, pcounts, batch_idx=_b_idx, loader=loader
                )

                self._read_out_predictions(
                    pred_targets, class_predictions, read_out_tgt, read_out_pred
                )

                hidden = self._repackage_hidden(hidden)

                if self.instrumentation:
                    metrics = self._agg_batch_metrics(
                        metrics,
                        pred_images=output[0].unsqueeze(0),
                        targets=targets.unsqueeze(0),
                        correct_arr=correct_arr.unsqueeze(0),
                        pred_targets=pred_targets,
                        class_predictions=class_predictions,
                    )

            if self.instrumentation:
                # Save some snapshots from last batch of epoch
                # if self.model_kind == "rsm":
                #     metrics['last_hidden_snp'] = x_b
                #     metrics['last_input_snp'] = inputs
                #     metrics['last_output_snp'] = last_output

                # After all eval batches, generate stats & figures
                ret.update(self._generate_instr_charts(metrics))
                ret.update(self._store_instr_hists())
                _, test_entropy = binary_entropy(self.model.RSM_1.duty_cycle)
                ret["test_entropy"] = test_entropy.item()
                self.model.RSM_1.duty_cycle.fill_(0.0)  # Clear duty cycle

            num_batches = _b_idx + 1
            num_samples = pcounts["total_samples"]
            ret["val_loss"] = val_loss = total_loss / num_batches
            if self.predictor:
                test_pred_loss = pcounts["total_pred_loss"] / num_samples
                test_interp_loss = pcounts["total_interp_loss"] / num_samples
                ret["val_interp_ppl"] = lang_util.perpl(test_interp_loss)
                ret["val_pred_ppl"] = lang_util.perpl(test_pred_loss)
                ret["val_pred_acc"] = 100 * pcounts["correct_samples"] / num_samples

            if not self.best_val_loss or val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
            else:
                # Val loss increased
                if self.learning_rate_gamma:
                    self.do_anneal_learning = True  # Reduce LR during post_epoch
                if self.pause_after_upticks and not self.model_learning_paused:
                    if not self.pause_min_epoch or (
                        self.pause_min_epoch and epoch >= self.pause_min_epoch
                    ):
                        self.n_upticks += 1
                        if self.n_upticks >= self.pause_after_upticks:
                            print(
                                ">>> Pausing learning after %d upticks, validation "
                                "loss rose to %.3f, best: %.3f"
                                % (self.n_upticks, val_loss, self.best_val_loss)
                            )
                            self._pause_learning(epoch)

        return ret

    def _pause_learning(self, epoch):
        print(
            "Pausing learning... Setting eval interval to %d" % self.pause_eval_interval
        )
        self.model_learning_paused = True
        self.eval_interval = self.pause_eval_interval
        self.model._zero_kwinner_boost()

    def train_epoch(self, epoch):
        """
        Do one epoch of training and testing.

        Returns:
            A dict that describes progress of this epoch.
            The dict includes the key 'stop'. If set to one, this network
            should be stopped early. Training is not progressing well enough.
        """
        t1 = time.time()
        ret = {}

        self.model.train()  # Needed if using dropout
        
        if self.predictor:
            self.predictor.train()

        # Performance metrics
        total_loss = 0.0
        pcounts = {
            "total_samples": 0.0,
            "correct_samples": 0.0,
            "total_pred_loss": 0.0,
            "total_interp_loss": 0.0,
        }

        bsz = self.batch_size

        hidden = self.train_hidden_buffer[-1] if self.train_hidden_buffer else None
        if hidden is None:
            hidden = self._init_hidden(self.batch_size)
        i = 1
        for batch_idx, (inputs, targets, pred_targets, _input_labels) in enumerate(
            self.train_loader
        ):
           

            # Inputs are of shape (batch, input_size)
            if inputs.size(0) > bsz:
                # Crop to smaller first epoch batch size
                inputs = inputs[:bsz]
                targets = targets[:bsz]
                pred_targets = pred_targets[:bsz]

            hidden = self._repackage_hidden(hidden) #detach from computational graph
            
            self.optimizer.zero_grad() #zero out th
            if self.pred_optimizer:
                self.pred_optimizer.zero_grad()

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            pred_targets = pred_targets.to(self.device)
            
            output, hidden = self.model(inputs, hidden)
            

            x_b, pred_input = self._get_prediction_and_loss_inputs(hidden)

            self.train_hidden_buffer.append(hidden)

            loss_targets = (targets, x_b)
            #print('output at line 851',output[0].shape)
            loss = self._compute_loss(output, loss_targets)
            if loss is not None:
                total_loss += loss.item()
                if not self.model_learning_paused:
                    self._backward_and_optimize(loss)        

            # Keep only latest batch states around
            self.train_hidden_buffer = self.train_hidden_buffer[-1:]
            
            
            #print('now do prediction line 857')
            pcounts, class_predictions, correct_arr = self._do_prediction(
                pred_input,
                pred_targets, #pred_targets
                pcounts,
                train=True,
                batch_idx=batch_idx,
                loader=self.train_loader,
            )

            if epoch == 0 and batch_idx >= self.batches_in_first_epoch - 1:
                print(
                    "Breaking after %d batches in epoch %d"
                    % (self.batches_in_first_epoch, epoch)
                )
                break

        ret["stop"] = 0
        self.model._post_train_epoch(epoch)  # Update kwinners duty cycles, etc

        if self.eval_interval and (epoch - 1) % self.eval_interval == 0:

            # Evaluate each x epochs
            ret.update(self.eval_epoch(epoch))

        train_time = time.time() - t1
        self._post_epoch(epoch)

        num_batches = batch_idx + 1
        ret["train_loss"] = total_loss / num_batches
        if self.predictor:
            num_samples = num_batches * self.batch_size
            train_pred_loss = pcounts["total_pred_loss"] / num_samples
            train_interp_loss = pcounts["total_interp_loss"] / num_samples
            ret["train_interp_ppl"] = lang_util.perpl(train_interp_loss)
            ret["train_pred_ppl"] = lang_util.perpl(train_pred_loss)
            ret["train_pred_acc"] = (
                100 * pcounts["correct_samples"] / pcounts["total_samples"]
            )

        ret["epoch_time_train"] = train_time
        ret["epoch_time"] = time.time() - t1
        ret["learning_rate"] = self.learning_rate
        print(epoch, print_epoch_values(ret))
        return ret

    def _post_epoch(self, epoch):
        """
        The set of actions to do after each epoch of training: adjust learning
        rate, rezero sparse weights, and update boost strengths.
        """
        if self.pause_after_epochs and epoch == self.pause_after_epochs:
            self._pause_learning(epoch)
        self._adjust_learning_rate(epoch)
        if self.eval_interval_schedule:
            for step, new_interval in self.eval_interval_schedule:
                if step == epoch:
                    print(">> Changing eval interval to %d" % new_interval)
                    self.eval_interval = new_interval

    def model_save(self, checkpoint_dir):
        """Save the model in this directory.

        :param checkpoint_dir:

        :return: str: The return value is expected to be the checkpoint path that
        can be later passed to `model_restore()`.

        NOTE: Embedding is not saved with model, so results may vary if a different
        embedding binary (even if re-generated with same config) is used with the
        restored model.
        """
        checkpoint_file = os.path.join(checkpoint_dir, self.model_filename)
        if checkpoint_file.endswith(".pt"):
            torch.save(self.model, checkpoint_file)
        else:
            torch.save(self.model.state_dict(), checkpoint_file)
        if self.predictor:
            checkpoint_file = os.path.join(checkpoint_dir, self.pred_model_filename)
            if checkpoint_file.endswith(".pt"):
                torch.save(self.predictor, checkpoint_file)
            else:
                torch.save(self.predictor.state_dict(), checkpoint_file)

        if self.save_onnx_graph_at_checkpoint:
            dummy_input = (torch.rand(1, 1, 28, 28),)
            torch.onnx.export(
                self.model, dummy_input, self.graph_filename, verbose=True
            )

        return checkpoint_file

    def model_restore(self, checkpoint_path):
        """
        :param checkpoint_path: Loads model from this checkpoint path.
        If path is a directory, will append the parameter model_filename
        """
        print("Loading from", checkpoint_path)
        checkpoint_file = os.path.join(checkpoint_path, self.model_filename)
        if checkpoint_file.endswith(".pt"):
            self.model = torch.load(checkpoint_file, map_location=self.device)
        else:
            self.model.load_state_dict(
                torch.load(checkpoint_file, map_location=self.device)
            )
        checkpoint_file = os.path.join(checkpoint_path, self.pred_model_filename)
        if checkpoint_file.endswith(".pt"):
            self.predictor = torch.load(checkpoint_file, map_location=self.device)
        else:
            self.predictor.load_state_dict(
                torch.load(checkpoint_file, map_location=self.device)
            )
        return self.model

    def model_cleanup(self):
        pass

    def _agg_batch_metrics(self, metrics, **kwargs):
        for metric_key, val in kwargs.items():
            if val is not None:
                if metric_key not in metrics:
                    metrics[metric_key] = val
                else:
                    current = metrics[metric_key]
                    metrics[metric_key] = torch.cat((current, val))
        return metrics

    def _generate_instr_charts(self, metrics):
        ret = {}
        if self.model_kind == "rsm" and self.instrumentation:

            if "img_preds" in self.instr_charts:
                ret["img_preds"] = self._image_grid(
                    metrics["pred_images"],
                    compare_with=metrics["targets"],
                    compare_correct=metrics["correct_arr"],
                ).cpu()

            if "img_memory_snapshot" in self.instr_charts and self.n_layers > 1:
                last_inp_layers = [None for x in range(self.n_layers)]
                last_inp_layers[0] = metrics["last_input_snp"]
                fig = plot_tensors(
                    self.model,
                    [
                        ("last_out", metrics["last_output_snp"]),
                        ("inputs", last_inp_layers),
                        ("x_b", metrics["last_hidden_snp"]),
                    ],
                    return_fig=True,
                )
                ret["img_memory_snapshot"] = fig2img(fig)

        return ret

    def _read_out_predictions(
        self,
        pred_targets,
        class_predictions,
        read_out_tgt,
        read_out_pred,
        read_out_len=20,
    ):
        if self.predictor and self.corpus and len(read_out_tgt) < read_out_len:
            read_out_tgt.append(pred_targets[0])
            read_out_pred.append(class_predictions[0])
            

            if len(read_out_tgt) == read_out_len:
                print_aligned_sentences(
                    self.corpus.read_out(read_out_tgt),
                    self.corpus.read_out(read_out_pred),
                    labels=["Targ", "Pred"],
                )
                
    def _reward_out_predictions(
        self,
        pred_targets,
        class_predictions,
        read_out_tgt,
        read_out_pred,
        read_out_len=20,
    ):
        
        if pred_targets[0].item() == class_predictions[0].item():
            return 1
        else:
            return 0


    def _store_instr_hists(self):
        ret = {}
        if self.instrumentation:
            for name, param in self.model.named_parameters():
                if "weight" in name or "decay" in name or "ramp" in name:
                    data = param.data.cpu()
                    if data.size(0):
                        ret["hist_" + name] = data
                        if self.debug:
                            print(
                                "%s: mean: %.3f std: %.3f"
                                % (name, data.mean(), data.std())
                            )
        return ret

    def _store_activity_for_viz(self, x_bs, input_labels, pred_labels):
        """
        Aggregate activity for a supplied batch
        """
        for _x_b, label, target in zip(x_bs, input_labels, pred_labels):
            _label = label.item()
            _label_next = target.item()
            activity = _x_b.detach().view(self.m_groups, -1).squeeze()
            key = "%d-%d" % (_label, _label_next)
            if key not in self.activity_by_inputs:
                self.activity_by_inputs[key] = []
            self.activity_by_inputs[key].append(activity)


In [2]:
import gym
from gym import spaces

import numpy as np

class RsmAttEnv(gym.Env):
    """This is rsm attention environment"""
    
    def __init__(self, config_model = {}, epoch = 1):
        super(RsmAttEnv, self).__init__()
        # There are two configs: config_env for environment, and config_model for model to be reset
        print("Environment initalized...")            
        print("Using torch version", torch.__version__)
        print("Torch device count=%d" % torch.cuda.device_count())
        
        self.epoch = epoch

        self.exp = RSMExperiment(config_model)
        self.exp.model_setup(config_model)
        #self._build_dataloader()
        self.cur_step = 0
        
        #set up model

        self.batch_size = config_model.get("batch_size", self.exp.batch_size)
        self.eval_batch_size = config_model.get("eval_batch_size", self.exp.eval_batch_size)
        self.batches_in_epoch = config_model.get("batches_in_epoch", self.exp.batches_in_epoch)
        self.eval_batches_in_epoch = config_model.get(
            "eval_batches_in_epoch", self.batches_in_epoch
        )
        self.data_dir = config_model.get("data_dir", "data")
        self.embedding_kind = config_model.get("embedding_kind", "ptb_fasttext_e5")
        self.m_groups = config_model.get("m_groups", self.exp.m_groups)
        self.embed_dim = config_model.get("embed_dim", self.exp.embed_dim)
        
        #self.action_space = spaces.Discrete(4**self.heads) #4**heads
        self.action_space = spaces.Discrete(4**1)
        self.x, self.x_b = self._init_state()
        #print('self.total_cells at 227 rsm_att',self.total_cells)
        max_val = sys.maxsize
        
        self.obs_x = spaces.Box(low=-3,
                                high=3, 
                                shape=(self.batch_size, self.embed_dim),
                                dtype=np.float64)
        self.obs_x_b = spaces.Box(low=-3,
                                  high=3,
                                  shape=(self.batch_size, self.m_groups),
                                  dtype=np.float64)
        
        self.observation_space = spaces.Tuple([self.obs_x, self.obs_x_b])
        
        self.hidden = self.exp._init_hidden(self.batch_size)
        self.batch_idx = 0 #this is the current_step check, if the batch_idx equal datasets size then training is done
        self.done = False
        self.get_one = iter(self.exp.train_loader)
        self.reward = 0        
        print('Environment created!')

    def reset(self):
        self.batch_idx = 0 
        self.exp = RSMExperiment(config_model)
        self.exp.model_setup(config_model)
        self.x, self.x_b = self._init_state()
        max_val = sys.maxsize
        
        self.obs_x = spaces.Box(low=-3,
                                high=3, 
                                shape=(self.batch_size, self.embed_dim),
                                dtype=np.float64)
        
        self.obs_x_b = spaces.Box(low=-3,
                                  high=3,
                                  shape=(self.batch_size, self.m_groups),
                                  dtype=np.float64)
        
        self.observation_space = spaces.Tuple([self.obs_x, self.obs_x_b])
        
        self.hidden = self.exp._init_hidden(self.batch_size)
        self.done = False
        self.get_one = iter(self.exp.train_loader)
        self.reward = 0
        print("Environment reset!")   
        return 

    def step(self, action):

        
        self.exp.model.train() #telling model that we are training
        t1 = time.time()
        ret = {}

        self.exp.model.train()  # Needed if using dropout
        
        if self.exp.predictor:
            self.exp.predictor.train()

        # Performance metrics
        total_loss = 0.0
        pcounts = {
            "total_samples": 0.0,
            "correct_samples": 0.0,
            "total_pred_loss": 0.0,
            "total_interp_loss": 0.0,
        }
        
        bsz = self.batch_size

        if self.hidden is None:
            self.hidden = self.exp._init_hidden(self.batch_size)

        if self.batch_idx == len(self.exp.train_loader) - 1:
            self.done = True
        else:
            self.batch_idx+=1
            
        inputs, targets, pred_targets, _input_labels = self.get_one.next() #getdata to train  
        
        if inputs.size(0) > bsz:
                # Crop to smaller first epoch batch size
                inputs = inputs[:bsz]
                targets = targets[:bsz]
                pred_targets = pred_targets[:bsz]

        self.hidden = self.exp._repackage_hidden(self.hidden) #detach from computational graph
        
        pre_inputs = inputs
        
        action = list(f'{action:0{2*1}b}') 
        action = [int(i) for i in action] 
        #convert action from one integer i.e. 15 to list of tuples of binary [(1,0),(1,1)]

        iter_action = iter(action) #[1,0,1,0] => [(1,0),(1,0)]
        
        action_heads = list(zip(iter_action,iter_action))

        for act in range(len(action_heads)):
            
            inputs = action_heads[act][0]*inputs


        self.exp.optimizer.zero_grad() #zero out th
        
        if self.exp.pred_optimizer:
            self.exp.pred_optimizer.zero_grad()

        # Forward
        inputs = inputs.to(self.exp.device)
        targets = targets.to(self.exp.device)
        pred_targets = pred_targets.to(self.exp.device)

        output, self.hidden = self.exp.model(inputs, self.hidden)

        self.x_b, pred_input = self.exp._get_prediction_and_loss_inputs(self.hidden)

        self.exp.train_hidden_buffer.append(self.hidden)

        loss_targets = (targets, self.x_b)
        #print('output at line 851',output[0].shape)
        loss = self.exp._compute_loss(output, loss_targets)
        if loss is not None:
            total_loss += loss.item()
            if not self.exp.model_learning_paused:
                self.exp._backward_and_optimize(loss)        

        # Keep only latest batch states around
        self.exp.train_hidden_buffer = self.exp.train_hidden_buffer[-1:]

        pcounts, class_predictions, correct_arr = self.exp._do_prediction(
            pred_input,
            pred_targets, #pred_targets
            pcounts,
            train=True,
            batch_idx=self.batch_idx,
            loader=self.exp.train_loader,
        )

        ret["stop"] = 0
        
        self.exp.model._post_train_epoch(self.epoch)  # Update kwinners duty cycles, etc
        
        read_out_tgt = []
        read_out_pred = []
        metrics = {}
        
        self.reward += self.exp._reward_out_predictions(
                    pred_targets, class_predictions, read_out_tgt, read_out_pred
                )
        if self.done:
            print('rewards value of epoch {} is {}'.format(self.epoch,self.reward))
            
        #TODOL: reward should be a list of dictionary for each step
        x_b = self.hidden[0][0].reshape(self.batch_size,-1)
        obs = [pre_inputs, x_b]
        return obs, self.reward, self.done #self.info
    
    def _build_dataloader(self):
        
        """Data Loader should be initiated once,
        Leave it here we can adjust it later"""
        
        self.val_loader = self.corpus = None

        # Download "Penn Treebank" dataset
        from torchnlp.datasets import penn_treebank_dataset

        print("Maybe download PTB...")
        penn_treebank_dataset(self.data_dir + "/PTB", train=True, test=True)
        corpus = lang_util.Corpus(self.data_dir + "/PTB")
        train_sampler = PTBSequenceSampler(
            corpus.train,
            batch_size=self.batch_size,
            max_batches=self.batches_in_epoch,
        )

        import fasttext

        # Generated via notebooks/ptb_embeddings.ipynb
        embedding = {}
        ft_model = fasttext.load_model(
            self.data_dir + "/embeddings/%s.bin" % self.embedding_kind
        )
        for word_id, word in enumerate(corpus.dictionary.idx2word):
            embedding[word_id] = torch.tensor(ft_model[word])

        if self.embedding_kind:
            print(
                "Loaded embedding dict (%s) with %d entries"
                % (self.embedding_kind, len(embedding))
            )

        collate_fn = partial(ptb_pred_sequence_collate, vector_dict=embedding)
        self.train_loader = DataLoader(
            corpus.train, batch_sampler=train_sampler, collate_fn=collate_fn
        )
        val_sampler = PTBSequenceSampler(
            corpus.test,
            batch_size=self.eval_batch_size,
            max_batches=self.eval_batches_in_epoch,
            uniform_offsets=True,
        )
        self.val_loader = DataLoader(
            corpus.test, 
            batch_sampler=val_sampler, 
            collate_fn=collate_fn
        )
        self.corpus = corpus
        print("Built dataloaders...")
        
    def _init_state(self):
        x_b = torch.zeros((self.batch_size, 
                           self.m_groups), 
                          dtype=torch.float32, 
                          requires_grad=False) #before d_in it was tc #2 is number of heads
        x = torch.zeros((self.batch_size, 
                        self.embed_dim), 
                          dtype=torch.float32, 
                          requires_grad=False) #before d_in it was tc #2 is number of heads
        return x, x_b

In [3]:
config_model = {
    "data_dir": os.path.expanduser("~/nta/datasets"),
    "path": os.path.expanduser("~/nta/results"),
}
env = RsmAttEnv()

Environment initalized...
Using torch version 1.4.0
Torch device count=0
setup: Using cpu
Maybe download PTB...
Loaded embedding dict (ptb_fasttext_e5) with 10000 entries
Built dataloaders...
Created <RSMLayer m=60 n=1 k=7 d_in=100 eps=0.50 /> with 15760 trainable params
Created RSMNet with 1 layer(s)
Environment created!





In [4]:
env.step(1)

([tensor([[ 0.0397, -0.3821,  0.4202,  ...,  0.3019, -0.0484,  0.1362],
          [-0.1922, -0.1014,  0.3861,  ..., -0.3230,  0.0634,  0.1617],
          [ 0.1360, -0.0353,  0.2040,  ...,  0.0354,  0.2955, -0.0813],
          ...,
          [-0.1963, -0.3011,  0.4417,  ...,  0.0221,  0.2005,  0.0226],
          [-0.6475,  0.2050,  0.3260,  ...,  0.0362, -0.0282,  0.4589],
          [ 0.2003,  0.2507,  0.2764,  ..., -0.2628, -0.1686, -0.0167]]),
  tensor([[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<ViewBackward>)],
 0,
 False)

In [5]:
from ray.rllib.agents.sac.sac import SACTrainer, DEFAULT_CONFIG
from ray.tune.registry import register_env
ray.shutdown()
ray.init()
def env_creator(env_config):
    return RsmAttEnv()  # return an env instance

register_env("RsmAtt", env_creator)
config = DEFAULT_CONFIG.copy()
config["num_gpus"] = 1
# config["num_workers"] = int(multiprocessing.cpu_count() / 2)
config["num_workers"] = 1
config["eager"] = False
config["log_level"] = "INFO"
config["monitor"] = True
config["num_cpus_per_worker"] = 1
config["use_pytorch"] = 1
config["framework"] = 'torch'
config["sample_batch_size"] = 300
trainer = SACTrainer(config=config, env="RsmAtt")

2020-08-01 00:44:00,266	INFO resource_spec.py:212 -- Starting Ray with 8.64 GiB memory available for workers and up to 4.32 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
2020-08-01 00:44:00,756	INFO services.py:1165 -- View the Ray dashboard at [1m[32mlocalhost:8265[39m[22m


Environment initalized...
Using torch version 1.4.0
Torch device count=0
setup: Using cpu
Maybe download PTB...





Loaded embedding dict (ptb_fasttext_e5) with 10000 entries
Built dataloaders...
Created <RSMLayer m=60 n=1 k=7 d_in=100 eps=0.50 /> with 15760 trainable params
Created RSMNet with 1 layer(s)
Environment created!


2020-08-01 00:44:07,468	INFO rollout_worker.py:941 -- Built policy map: {'default_policy': <ray.rllib.policy.torch_policy_template.SACTorchPolicy object at 0x7f291bfc3a90>}
2020-08-01 00:44:07,468	INFO rollout_worker.py:942 -- Built preprocessor map: {'default_policy': <ray.rllib.models.preprocessors.TupleFlatteningPreprocessor object at 0x7f2939e95828>}
2020-08-01 00:44:07,469	INFO rollout_worker.py:413 -- Built filter map: {'default_policy': <ray.rllib.utils.filter.NoFilter object at 0x7f293b03ea90>}


In [6]:
trainer.train()

[2m[36m(pid=7634)[0m Environment initalized...
[2m[36m(pid=7634)[0m Using torch version 1.4.0
[2m[36m(pid=7634)[0m Torch device count=0
[2m[36m(pid=7634)[0m setup: Using cpu
[2m[36m(pid=7634)[0m Maybe download PTB...
[2m[36m(pid=7634)[0m 
[2m[36m(pid=7634)[0m Loaded embedding dict (ptb_fasttext_e5) with 10000 entries
[2m[36m(pid=7634)[0m Built dataloaders...
[2m[36m(pid=7634)[0m Created <RSMLayer m=60 n=1 k=7 d_in=100 eps=0.50 /> with 15760 trainable params
[2m[36m(pid=7634)[0m Created RSMNet with 1 layer(s)
[2m[36m(pid=7634)[0m Environment created!
[2m[36m(pid=7634)[0m setup: Using cpu
[2m[36m(pid=7634)[0m Maybe download PTB...
[2m[36m(pid=7634)[0m 2020-08-01 00:44:15,241	INFO rollout_worker.py:526 -- Generating sample batch of size 300
[2m[36m(pid=7634)[0m 


2020-08-01 00:44:20,964	INFO trainer.py:494 -- Worker crashed during call to train(). To attempt to continue training without the failed worker, set `'ignore_worker_failures': True`.


[2m[36m(pid=7634)[0m Loaded embedding dict (ptb_fasttext_e5) with 10000 entries
[2m[36m(pid=7634)[0m Built dataloaders...
[2m[36m(pid=7634)[0m Created <RSMLayer m=60 n=1 k=7 d_in=100 eps=0.50 /> with 15760 trainable params
[2m[36m(pid=7634)[0m Created RSMNet with 1 layer(s)
[2m[36m(pid=7634)[0m Environment reset!
[2m[36m(pid=7634)[0m 2020-08-01 00:44:20,962	INFO sampler.py:466 -- Raw obs from env: {0: {'agent0': None}}
[2m[36m(pid=7634)[0m 2020-08-01 00:44:20,962	INFO sampler.py:467 -- Info return from env: {0: {'agent0': None}}


RayTaskError(ValueError): [36mray::RolloutWorker.par_iter_next()[39m (pid=7634, ip=192.168.2.9)
  File "python/ray/_raylet.pyx", line 446, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 400, in ray._raylet.execute_task.function_executor
  File "/home/incubator/anaconda3/envs/rsm-rl/lib/python3.6/site-packages/ray/util/iter.py", line 1125, in par_iter_next
    return next(self.local_it)
  File "/home/incubator/anaconda3/envs/rsm-rl/lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 263, in gen_rollouts
    yield self.sample()
  File "/home/incubator/anaconda3/envs/rsm-rl/lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 528, in sample
    batches = [self.input_reader.next()]
  File "/home/incubator/anaconda3/envs/rsm-rl/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 59, in next
    batches = [self.get_data()]
  File "/home/incubator/anaconda3/envs/rsm-rl/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 164, in get_data
    item = next(self.rollout_provider)
  File "/home/incubator/anaconda3/envs/rsm-rl/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 489, in _env_runner
    observation_fn=observation_fn)
  File "/home/incubator/anaconda3/envs/rsm-rl/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 641, in _process_observations
    policy_id).transform(raw_obs)
  File "/home/incubator/anaconda3/envs/rsm-rl/lib/python3.6/site-packages/ray/rllib/models/preprocessors.py", line 199, in transform
    self.check_shape(observation)
  File "/home/incubator/anaconda3/envs/rsm-rl/lib/python3.6/site-packages/ray/rllib/models/preprocessors.py", line 62, in check_shape
    self._obs_space, observation)
ValueError: ('Observation outside expected value range', Tuple(Box(100, 100), Box(100, 60)), None)