# TOY MODEL

In [83]:
import torch as tr
from torch import nn

### CONVOLUTION


# Definición de N_Conv: Una secuencia de Conv1d -> BatchNorm1d -> ReLU
class N_Conv(nn.Module):
    """Stack of ``num_conv`` ``Conv1d -> BatchNorm1d -> ReLU`` units.

    Args:
        input_channels: Channels of the incoming tensor (consumed by the first conv).
        output_channels: Channels produced by every conv in the stack.
        num_conv: How many conv/BN/ReLU units to stack.
        kernel_size: Kernel size of every ``Conv1d``.
        padding: Padding of every ``Conv1d``.
        stride: Stride of every ``Conv1d``.
        AVG_POOL: If True, an ``AvgPool1d(kernel_size=2, stride=2)`` is inserted
            before *each* conv unit, so the length is halved ``num_conv`` times.
    """

    def __init__(
        self,
        input_channels,
        output_channels,
        num_conv,
        kernel_size=3,
        padding=1,
        stride=1,
        AVG_POOL=False,
    ):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding, stride=stride)
        modules = []
        for idx in range(num_conv):
            if AVG_POOL:
                # NOTE(review): pooling lives inside the loop, so it is applied
                # once per conv unit rather than once per stack — confirm intent.
                modules.append(nn.AvgPool1d(kernel_size=2, stride=2, padding=0))
            # Only the first conv sees the input width; the rest map out -> out.
            in_ch = input_channels if idx == 0 else output_channels
            modules.extend(
                (
                    nn.Conv1d(in_ch, output_channels, **conv_kwargs),
                    nn.BatchNorm1d(output_channels),
                    nn.ReLU(inplace=True),
                )
            )
        # Attribute name kept unchanged: it determines the state_dict keys.
        self.N_Conv = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        return self.N_Conv(x)


# Downsampling blocks: a pooling layer (max or average) halves the length, then N_Conv


class Max_Down(nn.Module):
    """Halve the sequence length with max pooling, then apply ``N_Conv``.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the convolution stack.
        num_conv: Number of conv/BN/ReLU units in the stack.
        kernel_size: Kernel size forwarded to every ``Conv1d``.
        padding: Padding forwarded to every ``Conv1d``.
        stride: Stride forwarded to every ``Conv1d``.
    """

    def __init__(
        self, in_channels, out_channels, num_conv, kernel_size=3, padding=1, stride=1
    ):
        super().__init__()
        pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        convs = N_Conv(
            in_channels,
            out_channels,
            num_conv,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        # Attribute name kept unchanged: it determines the state_dict keys.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` by 2 along the last axis, then convolve."""
        return self.maxpool_conv(x)


class Avg_Down(nn.Module):
    """Halve the sequence length with average pooling, then apply ``N_Conv``.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the convolution stack.
        num_conv: Number of conv/BN/ReLU units in the stack.
        kernel_size: Kernel size forwarded to every ``Conv1d``.
        padding: Padding forwarded to every ``Conv1d``.
        stride: Stride forwarded to every ``Conv1d``.
    """

    def __init__(
        self, in_channels, out_channels, num_conv, kernel_size=3, padding=1, stride=1
    ):
        super().__init__()
        pool = nn.AvgPool1d(kernel_size=2, stride=2, padding=0)
        convs = N_Conv(
            in_channels,
            out_channels,
            num_conv,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        # Attribute name kept unchanged: it determines the state_dict keys.
        self.avgpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Average-pool ``x`` by 2 along the last axis, then convolve."""
        return self.avgpool_conv(x)


class UpBlock(nn.Module):
    """Upsampling block with an optional skip connection and fusion, followed by ``N_Conv``.

    Args:
        in_channels: Channels of the tensor to upsample (``x1`` in ``forward``).
        out_channels: Desired output channels. When ``skip`` is True, the skip
            tensor ``x2`` is expected to carry this many channels.
        num_conv: Number of conv/BN/ReLU units in the trailing ``N_Conv``.
        up_mode: Upsampling method: ``"upsample"`` (linear interpolation, channel
            count unchanged) or ``"transpose"`` (``ConvTranspose1d`` to
            ``out_channels``). Both double the sequence length.
        skip: If True, fuse the upsampled tensor with ``x2``.
        addition: Skip fusion mode: ``"cat"`` (concatenate along channels) or
            ``"sum"`` (element-wise addition).
        kernel_size: Kernel size for ``N_Conv`` (default 3).
        padding: Padding for ``N_Conv`` (default 1).
        stride: Stride for ``N_Conv`` (default 1).

    Raises:
        ValueError: If ``up_mode`` is invalid, or if ``skip`` is True and
            ``addition`` is invalid.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_conv: int,
        up_mode: str = "upsample",
        skip: bool = True,
        addition: str = "cat",
        kernel_size: int = 3,
        padding: int = 1,
        stride: int = 1,
    ) -> None:
        super().__init__()
        self.skip = skip
        self.addition = addition
        self.up_mode = up_mode

        if up_mode not in ["upsample", "transpose"]:
            raise ValueError(
                'El parámetro "up_mode" debe ser "upsample" o "transpose".'
            )

        # Upsampling layer and the channel count it produces.
        if up_mode == "upsample":
            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
            up_out_channels = in_channels  # interpolation keeps the channel count
        else:  # "transpose"
            # L_out = (L - 1) * 2 - 2 * 1 + 3 + 1 = 2 * L
            self.up = nn.ConvTranspose1d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            )
            up_out_channels = out_channels

        # Input channels of the trailing N_Conv depend on how the skip is fused.
        if skip:
            if addition == "cat":
                # Assumes x2 carries out_channels channels.
                conv_in_channels = up_out_channels + out_channels
            elif addition == "sum":
                conv_in_channels = out_channels
                if up_mode == "upsample":
                    # 1x1 projection so the upsampled tensor matches x2's channels.
                    self.adjust = nn.Conv1d(
                        up_out_channels, out_channels, kernel_size=1
                    )
            else:
                raise ValueError('El parámetro "addition" debe ser "cat" o "sum".')
        else:
            conv_in_channels = up_out_channels

        self.conv = N_Conv(
            conv_in_channels, out_channels, num_conv, kernel_size, padding, stride
        )

    @staticmethod
    def _match_length(x1, x2):
        """Pad (or crop, when ``x1`` is longer) ``x1``'s last axis to ``x2``'s length."""
        diff = x2.size(2) - x1.size(2)
        if diff == 0:
            return x1
        # Negative padding amounts make constant-mode pad crop instead.
        return nn.functional.pad(x1, [diff // 2, diff - diff // 2])

    def forward(self, x1, x2=None):
        """Upsample ``x1``, optionally fuse it with the skip tensor ``x2``, then convolve.

        Args:
            x1: Tensor to upsample (``in_channels`` channels).
            x2: Skip tensor; required when the block was built with ``skip=True``,
                ignored otherwise.

        Raises:
            ValueError: If ``skip`` is enabled and ``x2`` is None.
        """
        x1 = self.up(x1)
        if not self.skip:
            return self.conv(x1)
        if x2 is None:
            raise ValueError("Se requiere x2 para la conexión skip.")
        # Align lengths for both fusion modes (previously only "cat" did this,
        # so "sum" crashed on odd-length skips).
        x1 = self._match_length(x1, x2)
        if self.addition == "cat":
            x = tr.cat([x2, x1], dim=1)
        else:  # "sum"
            if self.up_mode == "upsample":
                x1 = self.adjust(x1)
            x = x2 + x1
        return self.conv(x)


class OutConv(nn.Module):
    """Pointwise (kernel size 1) ``Conv1d`` mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` channel-wise; the sequence length is unchanged."""
        projection = self.conv
        return projection(x)


# Definición de un modelo simple similar a U-Net para pruebas
class SimpleUNet(nn.Module):
    """Small two-level 1-D U-Net-like model used to smoke-test the blocks above.

    Args:
        embedding_dim: Channels of the input and of the reconstructed output.
        num_conv: Conv/BN/ReLU units per ``N_Conv`` stack.
        features: Channel widths ``(f0, f1, f2)`` for the initial stack and the
            two downsampling levels. Only the first three entries are used.
        up_mode: ``"upsample"`` or ``"transpose"`` (see ``UpBlock``).
        addition: Skip fusion, ``"sum"`` or ``"cat"`` (see ``UpBlock``).
        skip: Whether the decoder uses skip connections.
        verbose: If True (default), print intermediate shapes in ``forward``.

    Raises:
        ValueError: If ``features`` has fewer than three entries.
    """

    def __init__(
        self,
        embedding_dim=4,
        num_conv=2,
        features=(4, 8, 8),
        up_mode="upsample",
        addition="sum",
        skip=True,
        verbose=True,
    ):
        super().__init__()
        if len(features) < 3:
            raise ValueError("features must provide at least three channel widths.")
        self.verbose = verbose

        # Initial stack: embedding_dim -> features[0] channels, length unchanged.
        self.inc = N_Conv(embedding_dim, features[0], num_conv)
        # Two downsampling levels, each halving the length (128 -> 64 -> 32 in the example).
        self.down1 = Max_Down(features[0], features[1], num_conv)
        self.down2 = Max_Down(features[1], features[2], num_conv)
        # First decoder level: upsample the down2 output and fuse it with the
        # down1 output ("sum" or "cat" according to `addition`) -> features[1].
        self.up1 = UpBlock(
            features[2],
            features[1],
            num_conv,
            up_mode=up_mode,
            addition=addition,
            skip=skip,
        )
        # Second decoder level: upsample up1's output and fuse it with the
        # inc output -> features[0].
        self.up2 = UpBlock(
            features[1],
            features[0],
            num_conv,
            up_mode=up_mode,
            addition=addition,
            skip=skip,
        )
        # Final 1x1 projection back to embedding_dim channels.
        self.outc = nn.Conv1d(features[0], embedding_dim, kernel_size=1)

    def _log(self, *parts):
        """Print ``parts`` only when ``verbose`` is enabled."""
        if self.verbose:
            print(*parts)

    def forward(self, x):
        """Encode ``x`` two levels down, decode back up and project to ``embedding_dim``.

        Args:
            x: Tensor ``[batch, embedding_dim, L]`` (L = 128 in the notebook example).
        """
        self._log("Input shape:", x.shape)
        x1 = self.inc(x)
        self._log("After inc:", x1.shape)
        x2 = self.down1(x1)
        self._log("After down1:", x2.shape)
        x3 = self.down2(x2)
        self._log("After down2 (latente):", x3.shape)
        # Decoder: upsample the latent and fuse with the down1 output ...
        x_up1 = self.up1(x3, x2)
        self._log("After up1:", x_up1.shape)
        # ... then upsample again and fuse with the initial stack's output.
        x_up2 = self.up2(x_up1, x1)
        self._log("After up2:", x_up2.shape)
        out = self.outc(x_up2)
        self._log("Output shape:", out.shape)
        return out

In [84]:
# Smoke test: build a random batch and run SimpleUNet once for every
# (up_mode, addition, skip) combination.
batch_size = 4
embedding_dim = 4
num_conv = 2
features = [4, 8, 8]
# Default configuration; the sweep below iterates over all values instead.
up_mode = "upsample"
addition = "sum"
skip = True
L = 128  # initial temporal length

# Random example input of shape [batch, embedding_dim, L].
x = tr.randn(batch_size, embedding_dim, L)

for mode in ["upsample", "transpose"]:
    for add in ["sum", "cat"]:
        for use_skip in [True, False]:
            print(f"m {mode} a {add} s {use_skip}")
            model = SimpleUNet(
                embedding_dim=embedding_dim,
                num_conv=num_conv,
                features=features,
                up_mode=mode,
                addition=add,
                skip=use_skip,
            )
            output = model(x)

m upsample a sum s True
Input shape: torch.Size([4, 4, 128])
After inc: torch.Size([4, 4, 128])
After down1: torch.Size([4, 8, 64])
After down2 (latente): torch.Size([4, 8, 32])
After up1: torch.Size([4, 8, 64])
After up2: torch.Size([4, 4, 128])
Output shape: torch.Size([4, 4, 128])
m upsample a sum s False
Input shape: torch.Size([4, 4, 128])
After inc: torch.Size([4, 4, 128])
After down1: torch.Size([4, 8, 64])
After down2 (latente): torch.Size([4, 8, 32])
After up1: torch.Size([4, 8, 64])
After up2: torch.Size([4, 4, 128])
Output shape: torch.Size([4, 4, 128])
m upsample a cat s True
Input shape: torch.Size([4, 4, 128])
After inc: torch.Size([4, 4, 128])
After down1: torch.Size([4, 8, 64])
After down2 (latente): torch.Size([4, 8, 32])
After up1: torch.Size([4, 8, 64])
After up2: torch.Size([4, 4, 128])
Output shape: torch.Size([4, 4, 128])
m upsample a cat s False
Input shape: torch.Size([4, 4, 128])
After inc: torch.Size([4, 4, 128])
After down1: torch.Size([4, 8, 64])
After down2

# SEQ2SEQ

In [None]:
class Seq2Seq(nn.Module):
    """1-D convolutional encoder/decoder (U-Net style) that reconstructs its input.

    The constructor records hyperparameters, builds the network via
    ``build_graph`` and creates an Adam optimizer over all parameters.
    """

    def __init__(
        self,
        train_len=0,
        embedding_dim=4,
        device="cpu",
        lr=1e-3,
        scheduler="none",
        output_th=0.5,
        verbose=True,
        **kwargs,
    ):
        """Build the model and its optimizer.

        Args:
            train_len: Accepted but not used by this class.
            embedding_dim: Channels of the input/output sequences.
            device: Stored in ``self.device`` and ``hyperparameters``; the model
                is not moved to it here.
            lr: Adam learning rate.
            scheduler: Only recorded in ``hyperparameters``.
            output_th: Stored in ``self.output_th`` (not used in this class).
            verbose: If True, ``forward`` prints intermediate tensor shapes.
            **kwargs: Architecture options forwarded to ``build_graph`` (e.g.
                ``num_conv``, ``up_mode``, ``skip``, ``addition``, ``features``,
                ``seq_len``); unknown keys are accepted and ignored.
        """
        super().__init__()

        self.device = device
        self.verbose = verbose
        self.config = kwargs
        self.output_th = output_th

        self.hyperparameters = {
            "hyp_device": device,
            "hyp_lr": lr,
            "hyp_scheduler": scheduler,
            "hyp_verbose": verbose,
            "hyp_output_th": output_th,
        }
        # Define architecture
        self.build_graph(embedding_dim, **kwargs)
        self.optimizer = tr.optim.Adam(self.parameters(), lr=lr)

    def build_graph(
        self,
        embedding_dim,
        num_conv=2,
        up_mode="transpose",
        skip=False,
        addition="cat",
        features=None,
        seq_len=128,
        **kwargs,
    ):
        """Create the encoder/decoder layers and the ``architecture`` summary.

        Args:
            embedding_dim: Channels of the input/output sequences.
            num_conv: Conv/BN/ReLU units per ``N_Conv`` stack.
            up_mode: ``"upsample"`` or ``"transpose"`` (see ``UpBlock``).
            skip: Whether decoder blocks fuse encoder outputs.
            addition: Skip fusion, ``"cat"`` or ``"sum"``.
            features: Channel widths per level; defaults to ``[4, 8, 8, 8]``.
                ``len(features) - 1`` max-pool levels are built.
            seq_len: Nominal input length used for ``L_min`` and the volume
                figures in ``architecture`` (default 128).
            **kwargs: Ignored (e.g. ``kernel`` is currently not consumed).
        """
        self.features = [4, 8, 8, 8] if features is None else list(features)
        self.r_features = self.features[::-1]
        self.encoder_blocks = len(self.features) - 1
        # Each Max_Down halves the length.
        self.L_min = seq_len // (2**self.encoder_blocks)
        volume = [(seq_len / 2**i) * f for i, f in enumerate(self.features)]

        self.architecture = {
            "arc_embedding_dim": embedding_dim,
            "arc_encoder_blocks": self.encoder_blocks,
            "arc_initial_volume": embedding_dim * seq_len,
            "arc_latent_volume": volume[-1],
            "arc_features": self.features,
            "arc_num_conv": num_conv,
            "arc_up_mode": up_mode,
            "arc_addition": addition,
            "arc_skip": skip,
        }
        self.inc = N_Conv(embedding_dim, self.features[0], num_conv)

        self.down = nn.ModuleList(
            [
                Max_Down(self.features[i], self.features[i + 1], num_conv)
                for i in range(self.encoder_blocks)
            ]
        )
        # Decoder mirrors the encoder: channel widths walked in reverse.
        self.up = nn.ModuleList(
            [
                UpBlock(
                    in_channels=self.r_features[i],
                    out_channels=self.r_features[i + 1],
                    num_conv=num_conv,
                    up_mode=up_mode,
                    skip=skip,
                    addition=addition,
                )
                for i in range(len(self.r_features) - 1)
            ]
        )
        self.outc = OutConv(self.features[0], embedding_dim)

    def _log(self, *parts):
        """Print ``parts`` only when the model was built with ``verbose=True``."""
        if self.verbose:
            print(*parts)

    def forward(self, batch):
        """Reconstruct ``batch`` through the encoder/decoder.

        Args:
            batch: Input tensor, presumably ``[batch, embedding_dim, seq_len]``
                (the notebook example uses ``[4, 4, 128]``).

        Returns:
            Tuple ``(x_rec, x_latent)``: the ``OutConv`` output and the output
            of the last encoder block.
        """
        # NOTE(review): the batch is used as-is; it is not moved to self.device.
        x = batch
        self._log("Input shape:", x.shape)
        x = self.inc(x)
        encoder_outputs = [x]
        self._log("Encoder 0 shape:", x.shape)
        for level, down in enumerate(self.down, start=1):
            x = down(x)
            encoder_outputs.append(x)
            self._log(f"Encoder {level} shape:", x.shape)

        x_latent = x
        self._log("X latent shape:", x_latent.shape)

        # Every encoder output except the latent, deepest first, so each one
        # pairs with the decoder block at the matching resolution.
        skips = encoder_outputs[:-1][::-1]
        for up, skip_x in zip(self.up, skips):
            self._log(f"Up shape: {x.shape};  skip shape: {skip_x.shape}")
            x = up(x, skip_x)

        x_rec = self.outc(x)

        return x_rec, x_latent

In [None]:


# Build a Seq2Seq instance for the demo and run it on the example batch `x`.
seq2seq_kwargs = dict(
    train_len=0,
    embedding_dim=4,
    device="cpu",
    lr=1e-3,
    scheduler="none",
    output_th=0.5,
    verbose=True,
    kernel=3,  # NOTE(review): not consumed by build_graph; it only ends up in model.config
    num_conv=2,
    up_mode="transpose",
    addition="cat",
    skip=False,
)
model = Seq2Seq(**seq2seq_kwargs)
output = model(x)

Input shape: torch.Size([4, 4, 128])
Encoder 0 shape: torch.Size([4, 4, 128])
Encoder 1 shape: torch.Size([4, 8, 64])
Encoder 2 shape: torch.Size([4, 8, 32])
Encoder 3 shape: torch.Size([4, 8, 16])
X latent shape: torch.Size([4, 8, 16])
Up shape: torch.Size([4, 8, 16]);  skip shape: torch.Size([4, 8, 32])
Up shape: torch.Size([4, 8, 32]);  skip shape: torch.Size([4, 8, 64])
Up shape: torch.Size([4, 8, 64]);  skip shape: torch.Size([4, 4, 128])
