In [46]:
from abc import ABC, abstractmethod
from typing import List, Callable, TypeVar, Generic, Optional, Type, Union, Tuple, Dict, Any
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F


# Result type produced by a TensorCollector: a single tensor or a list of them.
TLT = TypeVar("TLT", Tensor, List[Tensor])

# Type of one "input" element: a single tensor, or a tuple of tensors of any
# length. (``Tuple[Tensor]`` alone would mean a tuple of exactly one tensor,
# while the tuple merge path unpacks tuples of arbitrary length.)
ITTT = TypeVar("ITTT", Tensor, Tuple[Tensor, ...])

class TensorCollector(ABC, Generic[TLT]):
    """Interface for objects that accumulate tensors and hand back a result.

    Concrete subclasses decide the shape of ``result()`` (for example a list
    of every tensor, or just one tensor).
    """

    @abstractmethod
    def append(self, t: Tensor) -> None:
        """Record the tensor ``t``."""

    @abstractmethod
    def result(self) -> TLT:
        """Return whatever has been collected."""


class ListCollector(TensorCollector[List[Tensor]]):
    """Collect tensors into a list, in the order they were appended."""

    def __init__(self):
        self.data = []

    def append(self, t: Tensor) -> None:
        """Add ``t`` to the end of the pending list."""
        self.data.append(t)

    def result(self) -> List[Tensor]:
        """Hand back the collected list and start over with an empty one."""
        collected, self.data = self.data, []
        return collected


class ReverseListCollector(ListCollector):
    """A ``ListCollector`` whose ``result()`` lists tensors newest-first."""

    def result(self) -> List[Tensor]:
        """Return the collected tensors in reverse append order, then reset."""
        # Reverse in place, then reuse the parent's hand-off-and-reset step.
        self.data.reverse()
        return super().result()


class LastElementCollector(TensorCollector[Tensor]):
    """Keep only the most recently appended tensor."""

    def __init__(self) -> None:
        self.data: Optional[Tensor] = None

    def result(self) -> Optional[Tensor]:
        """Return the last appended tensor and reset the collector.

        Returns:
            The tensor passed to the latest ``append`` call, or ``None`` if
            nothing was appended since construction or the previous
            ``result()`` call.
        """
        out = self.data
        self.data = None
        return out

    def append(self, t: Tensor) -> None:
        """Replace the stored tensor with ``t`` (earlier tensors are dropped)."""
        self.data = t


class Cat_function(ABC, Generic[ITTT]):
    """Abstract base for strategies that merge an input with a state tensor."""

    @abstractmethod
    def tensor_merge(self, input: ITTT, state: Tensor):
        """Combine ``input`` with ``state``; concrete subclasses define how."""


class Cat_Tuple(Cat_function):
    """Marker type selecting the "pack into a tuple" merge strategy.

    Used as a key of ``Implicit_cat.implicit_cat``. It does not implement
    ``tensor_merge`` itself, so (being an ABC subclass) it is not
    instantiable and serves purely as a type tag.
    """


class Cat_torch(Cat_function):
    """Marker type selecting the ``torch.cat`` (dim=1) merge strategy.

    Used as a key of ``Implicit_cat.implicit_cat``. It does not implement
    ``tensor_merge`` itself, so (being an ABC subclass) it is not
    instantiable and serves purely as a type tag.
    """


# Type variable for a Cat_function subclass; it is the second type parameter
# of Implicit_cat and selects which merge strategy table is used.
CF = TypeVar("CF", bound=Cat_function)


def merge_not_torch_tensor(input: Tensor, state: Tensor):
    """Pair ``input`` with ``state`` without concatenating them.

    Returns:
        The 2-tuple ``(input, state)``; the tensors are not copied.
    """
    pair = (input, state)
    return pair


def merge_torch_tensor(input: Tensor, state: Tensor):
    """Concatenate ``input`` followed by ``state`` along dimension 1."""
    return torch.cat([input, state], 1)


def merge_non_torch_tuple(input: Tuple[Tensor, ...], state: Tensor):
    """Append ``state`` to a tuple of inputs without concatenating tensors.

    Args:
        input: Tuple holding any number of tensors (including none); its
            elements are unpacked into the result.
        state: Tensor placed last.

    Returns:
        The tuple ``(*input, state)``.
    """
    return (*input, state)


class Implicit_cat(Generic[ITTT, CF]):
    """Select a merge function from the class's type parameters.

    ``Implicit_cat[InputType, CatType].tensor_merge(input, state)`` calls
    ``Implicit_cat.implicit_cat[CatType][InputType](input, state)``; for
    example ``Implicit_cat[Tensor, Cat_Tuple].tensor_merge(a, b)`` returns
    ``(a, b)``.
    """

    # Outer key: a Cat_function subclass (the merge strategy).
    # Inner key: the input type. Value: a function called as f(input, state).
    implicit_cat: Dict[Type[Cat_function], Dict[Any, Callable[..., Any]]] = {
        Cat_Tuple: {
            Tensor: merge_not_torch_tensor,
            Tuple[Tensor]: merge_non_torch_tuple,
            # merge_non_torch_tuple unpacks its input, so any tuple length works.
            Tuple[Tensor, ...]: merge_non_torch_tuple,
        },
        Cat_torch: {Tensor: merge_torch_tensor},
    }

    # Parameterized subclasses already built by __class_getitem__, keyed by
    # the generic alias, so repeated subscripts reuse one class object.
    _specialized: Dict[Any, type] = {}

    @classmethod
    def tensor_merge(cls, input: ITTT, state: Tensor):
        """Merge ``input`` and ``state`` with the registered function.

        Must be called on a parameterized class such as
        ``Implicit_cat[Tensor, Cat_Tuple]``.

        Raises:
            KeyError: if no function is registered for the
                (CatType, InputType) combination.
        """
        input_type, cat_type = cls.__args__
        return Implicit_cat.implicit_cat[cat_type][input_type](input, state)

    def __class_getitem__(cls, params):
        """Subscript the class while keeping the type arguments reachable.

        Python 3.6 subscripts generics through ``GenericMeta`` and never calls
        this hook; its result already carries ``__args__``. On Python 3.7+ a
        plain ``Implicit_cat[...]`` alias forwards ``tensor_merge`` to the
        unparameterized class, which has no ``__args__``, so the lookup in
        ``tensor_merge`` would raise ``AttributeError``. Here we instead
        return (and cache) a subclass of the alias that stores ``__args__``.
        """
        import types

        alias = super().__class_getitem__(params)
        specialized = Implicit_cat._specialized.get(alias)
        if specialized is None:
            specialized = types.new_class(
                cls.__name__,
                (alias,),
                exec_body=lambda ns: ns.update(
                    __args__=alias.__args__,
                    __module__=cls.__module__,
                ),
            )
            Implicit_cat._specialized[alias] = specialized
        return specialized


class ProgressiveModuleList(nn.Module, Generic[TLT, ITTT]):
    """Run blocks in sequence, injecting one extra input after each early block.

    Block ``i`` receives the running features ``x``. After each of the first
    ``len(input) - 1`` blocks, the block output is concatenated along dim 1
    with the next element of ``input`` (output first, then the new input)
    before being fed to the next block. Every block output is appended to a
    fresh collector, whose ``result()`` is returned.
    """

    def __init__(self,
                 blocks: List[nn.Module],
                 cat_function: Callable[[ITTT, Tensor], Tensor],
                 collector_class: Type[TensorCollector[TLT]] = ListCollector
                 ):
        super(ProgressiveModuleList, self).__init__()
        self.model_list = nn.ModuleList(blocks)
        self.collector_class = collector_class
        # NOTE(review): stored but NOT used by forward(), which hard-codes
        # torch.cat([x, next_input], dim=1). Wiring it in as
        # cat_function(next_input, x) would reverse the channel order for
        # merge_torch_tensor, so confirm intent before changing forward().
        self.cat_function = cat_function

    def forward(self, input: List[ITTT], state: Optional[Tensor] = None) -> TLT:
        """Run the progressive pipeline.

        Args:
            input: ``input[0]`` feeds the first block; ``input[i + 1]`` is
                concatenated to the output of block ``i``.
            state: Currently unused; defaults to ``None`` so callers need not
                pass a placeholder.

        Returns:
            The collector's ``result()`` over all block outputs.

        Raises:
            IndexError: if ``input`` is empty, or if it has more than
                ``len(self.model_list) + 1`` elements.
        """
        collector: TensorCollector[TLT] = self.collector_class()
        x = input[0]
        n_extra = len(input) - 1
        # Iterating past len(model_list) when there are too many extra inputs
        # deliberately reproduces the IndexError from indexing model_list.
        for i in range(max(n_extra, len(self.model_list))):
            x = self.model_list[i](x)
            collector.append(x)
            if i < n_extra:
                # Inject the next input alongside this block's output.
                x = torch.cat([x, input[i + 1]], dim=1)
        return collector.result()


class ElementwiseModuleList(nn.Module, Generic[TLT]):
    """Apply block ``i`` to ``input[i]`` independently and collect the outputs."""

    def __init__(self,
                 blocks: List[nn.Module],
                 collector_class: Type[TensorCollector[TLT]] = ListCollector):
        super().__init__()
        self.model_list = nn.ModuleList(blocks)
        self.collector_class = collector_class

    def forward(self, input: List[Tensor]) -> TLT:
        """Map each input through its matching block.

        Returns:
            The result of a fresh collector fed every block output in order.

        Raises:
            IndexError: if there are more inputs than blocks.
        """
        collector: TensorCollector[TLT] = self.collector_class()
        for idx, tensor in enumerate(input):
            collector.append(self.model_list[idx](tensor))
        return collector.result()

In [47]:
# Show the merge strategies registered in the dispatch table (its outer keys).
Implicit_cat[Tensor, Cat_Tuple].implicit_cat.keys()

dict_keys([__main__.Cat_Tuple, __main__.Cat_torch])

In [48]:
# Tensor input with the Cat_Tuple strategy dispatches to merge_not_torch_tensor,
# which returns the pair (input, state) without concatenating.
res = Implicit_cat[Tensor, Cat_Tuple].tensor_merge(torch.ones(2, 2), torch.ones(3, 3))

In [49]:
# Display the merged result from the previous cell.
res

(tensor([[1., 1.],
         [1., 1.]]),
 tensor([[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]))