In [None]:
import torch
import torch.nn as nn

class ThroughputEstimator(nn.Module):
    """Two-branch regressor combining a small CNN with an LSTM.

    The CNN branch maps an image-like tensor to a ``hidden_size`` vector; the
    LSTM branch maps a sequence to the ``hidden_size`` output of its last time
    step. Both vectors are concatenated and projected to ``output_size``
    values by a final linear layer.
    """

    def __init__(
        self,
        lstm_input_size: int,
        output_size: int,
        hidden_size: int,
        num_layers: int,
        drop_out: float = 0.1,
        *,
        cnn_in_channels: int = 2,
        cnn_input_shape: tuple = (273, 14),
    ) -> None:
        """Build the CNN branch, the LSTM branch and the fusion head.

        Args:
            lstm_input_size: Number of features per time step fed to the LSTM.
            output_size: Number of regression outputs.
            hidden_size: Width of the LSTM hidden state and of the CNN
                embedding (both branches emit ``hidden_size`` features).
            num_layers: Number of stacked LSTM layers.
            drop_out: Dropout probability used in the CNN head, between LSTM
                layers (only when ``num_layers > 1``) and after the final
                linear layer.
            cnn_in_channels: Number of channels of the CNN input.
            cnn_input_shape: ``(height, width)`` of the CNN input; used to
                size the linear layer that follows the flatten step.

        Raises:
            ValueError: If ``cnn_input_shape`` is too small to survive the two
                2x2 max-pooling stages.
        """
        super().__init__()

        self.lstm_input_size = lstm_input_size  # Number of features per time step for LSTM
        self.hidden_size = hidden_size  # Size of hidden state
        self.output_size = output_size  # Number of output features
        self.num_layers = num_layers    # Number of LSTM layers
        self.drop_out = drop_out        # Dropout probability
        self.cnn_in_channels = cnn_in_channels  # Channels of the CNN input
        self.cnn_input_shape = tuple(cnn_input_shape)  # (height, width) of the CNN input

        # LSTM: processes input of shape (batch_size, seq_len, lstm_input_size).
        self.lstm = nn.LSTM(
            input_size=self.lstm_input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers and warns
            # when asked for dropout with a single layer.
            dropout=self.drop_out if self.num_layers > 1 else 0.0,
        )

        # The 3x3 convolutions use padding=1/stride=1 and keep the spatial
        # size; each 2x2 max-pool floors the size to half. For the default
        # 273 x 14 input this yields 68 x 3.
        height, width = self.cnn_input_shape
        pooled_h = (height // 2) // 2
        pooled_w = (width // 2) // 2
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"cnn_input_shape {self.cnn_input_shape} is too small for two 2x2 max-pooling stages"
            )
        flat_features = 32 * pooled_h * pooled_w

        # CNN: processes input of shape (batch_size, cnn_in_channels, height, width).
        self.cnn = nn.Sequential(
            nn.Conv2d(
                in_channels=self.cnn_in_channels,
                out_channels=16,
                kernel_size=3,
                stride=1,
                padding=1,
            ),  # -> (batch_size, 16, H, W)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> (batch_size, 16, H//2, W//2)
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=1,
            ),  # -> (batch_size, 32, H//2, W//2)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> (batch_size, 32, pooled_h, pooled_w)
            nn.Flatten(),  # -> (batch_size, flat_features)
            nn.Linear(flat_features, self.hidden_size),  # -> (batch_size, hidden_size)
            nn.ReLU(),
            nn.Dropout(self.drop_out),
        )

        # Final FC layer: fuses the CNN (hidden_size) and LSTM (hidden_size) features.
        # NOTE(review): dropout is applied to the regression output itself,
        # which perturbs predictions in training mode; kept as-is to preserve
        # existing behaviour -- confirm this is intended.
        self.fc = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.output_size),
            nn.Dropout(self.drop_out) if self.drop_out > 0 else nn.Identity(),
        )

    def forward(self, cnn_input: torch.Tensor, lstm_input: torch.Tensor) -> torch.Tensor:
        """Run both branches and fuse them into the regression output.

        Args:
            cnn_input: Tensor of shape (batch_size, cnn_in_channels, height, width).
            lstm_input: Tensor of shape (batch_size, seq_len, lstm_input_size).

        Returns:
            Tensor of shape (batch_size,) if output_size == 1, else
            (batch_size, output_size).
        """
        # CNN branch -> (batch_size, hidden_size)
        cnn_out = self.cnn(cnn_input)

        # LSTM branch. When no initial state is given, nn.LSTM starts from
        # zero hidden/cell states on the input's device and dtype, so there is
        # no need to allocate them by hand.
        lstm_out, _ = self.lstm(lstm_input)  # (batch_size, seq_len, hidden_size)
        last_step = lstm_out[:, -1, :]       # (batch_size, hidden_size)

        # Fuse both branches -> (batch_size, hidden_size * 2)
        combined = torch.cat((cnn_out, last_step), dim=1)

        out = self.fc(combined)  # (batch_size, output_size)

        # Drop the trailing singleton dimension for scalar regression.
        return out.squeeze(-1) if self.output_size == 1 else out