In [None]:
#|default_exp models.misc

# Miscellaneous

> This module contains a set of experimental model wrappers.

In [None]:
#|export
from tsai.imports import *
from tsai.models.layers import *
from tsai.models.utils import *

In [None]:
#|export
class InputWrapper(Module):
    "Builds `arch` for a (`new_c_in`, `new_seq_len`) input and prepends linear projections that map a (`c_in`, `seq_len`) input to that shape."
    def __init__(self, arch, c_in, c_out, seq_len, new_c_in=None, new_seq_len=None, **kwargs):
        new_c_in = ifnone(new_c_in, c_in)
        new_seq_len = ifnone(new_seq_len, seq_len)
        self.new_shape = c_in != new_c_in or seq_len != new_seq_len
        if self.new_shape:
            layers = []
            # Each projection is added only when its own dimension changes, so a seq_len-only
            # change is still projected and a c_in-only change leaves the time axis untouched.
            # NOTE(review): weights start at zero, so at init each projection outputs only its bias.
            if c_in != new_c_in:
                # Channel projection. Assumes x is (batch, c_in, seq_len); Transpose(1,2) moves the
                # channel axis last so nn.Linear acts on it, then moves it back.
                lin = nn.Linear(c_in, new_c_in)
                nn.init.constant_(lin.weight, 0)
                layers += [Transpose(1,2), lin, Transpose(1,2)]
            if seq_len != new_seq_len:
                # Time projection: nn.Linear acts on the last axis (seq_len -> new_seq_len).
                lin2 = nn.Linear(seq_len, new_seq_len)
                nn.init.constant_(lin2.weight, 0)
                layers += [lin2]
            self.new_shape_fn = nn.Sequential(*layers)
        self.model = build_ts_model(arch, c_in=new_c_in, c_out=c_out, seq_len=new_seq_len, **kwargs)
    def forward(self, x):
        # Reshape only when a projection was configured, then run the wrapped model.
        if self.new_shape: x = self.new_shape_fn(x)
        return self.model(x)

In [None]:
from tsai.models.TST import *

In [None]:
# Project a 1-channel, 1000-step batch to 10 channels x 224 steps before it reaches TST (4 outputs)
xb = torch.randn(16, 1, 1000)
model = InputWrapper(TST, c_in=1, c_out=4, seq_len=1000, new_c_in=10, new_seq_len=224)
test_eq(model.to(xb.device)(xb).shape, (16, 4))

In [None]:
#|export
class ResidualWrapper(Module):
    "Adds the input's last time step to the wrapped model's output."
    def __init__(self, model):
        self.model = model

    def forward(self, x):
        # Residual connection: the model's output is added on top of the last observed step.
        # Assumes the model's output broadcasts with x[..., -1] (e.g. c_out == c_in).
        last_step = x[..., -1]
        return last_step + self.model(x)

In [None]:
#|eval: false
#|hide
# RecursiveWrapper has not proved to be very useful so far.

In [None]:
#|export
class RecursiveWrapper(Module):
    "Calls `model` `n_steps` times, appending each prediction to the input before the next call, and concatenates all predictions."
    def __init__(self, model, n_steps, anchored=False):
        self.model, self.n_steps, self.anchored = model, n_steps, anchored
    def forward(self, x):
        preds = []
        for step in range(self.n_steps):
            pred = self.model(x)
            preds.append(pred)
            # After the last call the updated input would never be used, so skip building it.
            if step == self.n_steps - 1: break
            # Add an axis at dim 1 so the prediction can be concatenated along the last (time) axis;
            # presumably pred is (batch, k) and x is (batch, 1, seq_len) — TODO confirm for c_in > 1.
            if x.ndim != pred.ndim: pred = pred[:, np.newaxis]
            # Sliding window (anchored=False): drop as many of the oldest steps as were just appended,
            # so the window length stays constant even when the model predicts k > 1 steps.
            # Anchored: keep the full history, so x grows each step (the model must accept that).
            x = torch.cat((x if self.anchored else x[..., pred.shape[-1]:], pred), -1)
        return torch.cat(preds, -1)

In [None]:
# Unroll a single-output TST for 5 calls; the 5 predictions are concatenated per sample
xb = torch.randn(16, 1, 20)
model = RecursiveWrapper(TST(1, 1, 20), n_steps=5)
test_eq(model.to(xb.device)(xb).shape, (16, 5))

In [None]:
#|eval: false
#|hide
from tsai.export import get_nb_name
from tsai.imports import create_scripts

# Look up this notebook's name and convert it into the library's script files
nb_name = get_nb_name(locals())
create_scripts(nb_name)

<IPython.core.display.Javascript object>

/Users/nacho/notebooks/tsai/nbs/140_models.misc.ipynb saved at 2022-11-09 13:16:35
Correct notebook to script conversion! 😃
Wednesday 09/11/22 13:16:39 CET
