In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model

In [None]:


class MainLSTM(Model):
    """Two-stage LSTM regressor over main-station and auxiliary-station features.

    The input is a rank-3 tensor ``(batch_size, time_steps, features)``. The
    first ``main_features`` columns are passed through unchanged; the remaining
    columns are cut into groups of ``aux_chunk_size`` and every group is
    encoded by one *shared* auxiliary LSTM. The main columns and all auxiliary
    encodings are concatenated, fed to the main LSTM, and the output at the
    last time step is mapped to a single value (the GHI prediction) by a dense
    layer.

    Missing data: a time step is masked only when *every* feature at that step
    equals ``mask_value``. Masked steps are skipped by both LSTMs. Individual
    features equal to ``mask_value`` inside an otherwise valid step are NOT
    masked and reach the LSTMs as-is.

    Args:
        dropout: Input dropout rate applied in both LSTMs.
        main_features: Number of leading feature columns belonging to the
            main station.
        aux_chunk_size: Number of feature columns per auxiliary group.
        aux_units: Hidden size of the shared auxiliary LSTM.
        main_units: Hidden size of the main LSTM.
        mask_value: Sentinel value marking missing data.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, dropout=0.0, main_features=10, aux_chunk_size=8,
                 aux_units=8, main_units=128, mask_value=-99.0, **kwargs):
        super().__init__(**kwargs)

        if main_features < 0:
            raise ValueError("main_features must be non-negative")
        if aux_chunk_size <= 0:
            raise ValueError("aux_chunk_size must be a positive integer")

        self.main_features = main_features
        self.aux_chunk_size = aux_chunk_size

        # Masking layer to ignore fully-missing time steps (all features == mask_value)
        self.mask = layers.Masking(mask_value=mask_value)

        # LSTM for auxiliary stations (one set of weights shared by all groups)
        self.aux_lstm = layers.LSTM(aux_units, return_sequences=True,
                                    dropout=dropout)

        # Main LSTM for concatenated sequence
        self.station_lstm = layers.LSTM(main_units, return_sequences=True,
                                        dropout=dropout, recurrent_dropout=0.0)

        # Fully-connected output
        self.station_fc = layers.Dense(1)

    def call(self, inputs, training=None):
        """Run the forward pass.

        Args:
            inputs: Tensor of shape ``(batch_size, time_steps, features)``.
            training: Whether the call is in training mode; forwarded to the
                LSTMs so that dropout is only active while training.

        Returns:
            Tensor of shape ``(batch_size, 1)`` with the prediction.

        Raises:
            ValueError: If the feature dimension is not statically known, or
                if the auxiliary columns would be split into groups of
                different widths (the shared LSTM needs a fixed input width).
        """
        # The mask attached by the Masking layer does not survive the slicing
        # below (slices are new tensors), so compute it explicitly and hand it
        # to every LSTM call.
        step_mask = self.mask.compute_mask(inputs)
        x = self.mask(inputs)

        # Split main vs auxiliary features
        main_station_data = x[:, :, :self.main_features]
        aux_station_data = x[:, :, self.main_features:]

        num_aux = aux_station_data.shape[-1]
        if num_aux is None:
            raise ValueError(
                "The feature dimension of `inputs` must be statically known."
            )
        # A single short group is fine (the LSTM is built for that width);
        # several groups must all have the same width because weights are shared.
        if num_aux > self.aux_chunk_size and num_aux % self.aux_chunk_size:
            raise ValueError(
                f"Number of auxiliary features ({num_aux}) must be a multiple "
                f"of aux_chunk_size ({self.aux_chunk_size})."
            )

        # Encode each auxiliary group with the shared LSTM
        aux_lstm_outs = [
            self.aux_lstm(
                aux_station_data[:, :, start:start + self.aux_chunk_size],
                mask=step_mask,
                training=training,
            )
            for start in range(0, num_aux, self.aux_chunk_size)
        ]

        # Concatenate main station data + all aux LSTM outputs along features axis
        concat = tf.concat([main_station_data, *aux_lstm_outs], axis=-1)

        # Pass through main LSTM
        lstm_out = self.station_lstm(concat, mask=step_mask, training=training)

        # Take last time step; for masked steps the LSTM repeats its previous
        # output, so this is the output of the last valid step.
        final_feature = lstm_out[:, -1, :]

        # Fully-connected layer to predict GHI
        return self.station_fc(final_feature)


In [None]:
# Assumes X_train has shape (batch_size, 24, 34).
# Build the network with a dropout rate of 0.1.
model = MainLSTM(dropout=0.1)

# Configure training: Adam optimizer, mean-squared-error loss.
model.compile(loss='mse', optimizer='adam')

# Fit for 50 epochs with mini-batches of 32 samples.
model.fit(
    X_train,
    y_train,
    batch_size=32,
    epochs=50,
)
