###Compression Methods
This notebook implements two compression techniques on the LSTM and ST-GCN models
1. Dynamic Quantisation
2. Knowlegde Distillation

In [None]:
#Install libraries
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
#Import libraries
import tensorflow as tf
import time
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.data import Data, DataLoader
from sklearn.metrics import mean_absolute_error
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.preprocessing import LabelEncoder
from torch_geometric.nn import GCNConv
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout

In [None]:
#Can be uncommented if google colab is being used for running the script
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#Change the model directory to the path the models were saved in modelling_lstm.ipynb and modelling_stgcn.ipynb
lstm_model_path = '/content/drive/MyDrive/AI_Sustainability/models/lstm_model_lag_14_4layer.h5'
stgcn_model_path = '/content/drive/MyDrive/AI_Sustainability/models/st_gcn_model_epoch_10.pth'
look_back = 14

In [None]:
#Load dataset
train_data = "/content/drive/MyDrive/AI_Sustainability/data/train_data_ensemble_0-agg.csv"
X = pd.read_csv(train_data)
train_dataset = X[['Date', 'min_dist_node', 'stid', 'tcolc_eatm_0', 'ulwrf_tatm_0', 'dlwrf_sfc_0', 'tmp_sfc_0', 'tcdc_eatm_0', 'dswrf_sfc_0',
            'tmax_2m_0', 'tmin_2m_0', 'pwat_eatm_0', 'ulwrf_tatm_0', 'dlwrf_sfc_0', 'tmp_sfc_0',
            'uswrf_sfc_0', 'spfh_2m_0', 'ulwrf_sfc_0', 'tmp_2m_0', 'apcp_sfc_0', 'pres_msl_0', 'Daily_Production']]

test_data = "/content/drive/MyDrive/AI_Sustainability/data/test_data_ensemble_0-agg.csv"
X = pd.read_csv(test_data)
test_dataset = X[['Date', 'min_dist_node','stid', 'tcolc_eatm_0', 'ulwrf_tatm_0', 'dlwrf_sfc_0', 'tmp_sfc_0', 'tcdc_eatm_0', 'dswrf_sfc_0',
            'tmax_2m_0', 'tmin_2m_0', 'pwat_eatm_0', 'ulwrf_tatm_0', 'dlwrf_sfc_0', 'tmp_sfc_0',
            'uswrf_sfc_0', 'spfh_2m_0', 'ulwrf_sfc_0', 'tmp_2m_0', 'apcp_sfc_0', 'pres_msl_0', 'Daily_Production']]


##Data Preprocessing

In [None]:
def preprocess_data(df):
    """
    Data preprocessing that adds columns ['day', 'month', 'year'], extract latitude and longitude from 'min_dist_node' column and interpolates values.
    Station ID is also encoded using one-hot encoding.
    """
    # Convert date to datetime
    df["Date"] = pd.to_datetime(df["Date"])

    # Sort by station and date
    df = df.sort_values(["stid", "Date"]).reset_index(drop=True)

    # Extract weather features
    weather_features = [col for col in df.columns if col.endswith("_0")]

    df["dayofyear"] = df["Date"].dt.dayofyear
    df["month"] = df["Date"].dt.month
    df["weekday"] = df["Date"].dt.weekday

    df["min_dist_node"] = df["min_dist_node"].astype(str)

    def extract_float_tuple(s):
        try:
            return tuple(float(part.replace("np.float32(", "").replace(")", "")) for part in s.strip("()").split(","))
        except:
            return (None, None)

    df[['lat', 'lon']] = df['min_dist_node'].apply(lambda x: pd.Series(extract_float_tuple(str(x))))

    df["lat"] = df["lat"].fillna(method="ffill")
    df["lon"] = df["lon"].fillna(method="ffill")

    le = LabelEncoder()
    df["stid_encoded"] = le.fit_transform(df["stid"])

    return df, le, weather_features

In [None]:
train_df, station_encoder, weather_features = preprocess_data(train_dataset)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["Date"] = pd.to_datetime(df["Date"])
  df["lat"] = df["lat"].fillna(method="ffill")
  df["lon"] = df["lon"].fillna(method="ffill")


In [None]:
test_df, _, _ = preprocess_data(test_dataset)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["Date"] = pd.to_datetime(df["Date"])
  df["lat"] = df["lat"].fillna(method="ffill")
  df["lon"] = df["lon"].fillna(method="ffill")


##Temporal Sequence Creation

In [None]:
def create_sequences(data, stations, features, target_col, lookback=7):
    """
    Function to created shifted window sequences for all stations for a given look back window
    """
    X_sequences = []
    y_values = []
    station_ids = []
    dates = []

    # Additional features to include (non-weather)
    additional_features = ['dayofyear', 'month', 'weekday', 'lat', 'lon', 'stid_encoded']
    all_features = features + additional_features

    for station in stations:
        # Get data for this station
        station_data = data[data['stid'] == station].copy()
        if len(station_data) <= lookback:
            continue

        # Create sequences for this station
        station_features = station_data[all_features].values
        station_target = station_data[target_col].values
        station_dates = station_data['Date'].values

        for i in range(len(station_data) - lookback):
            X_sequences.append(station_features[i:i+lookback])
            y_values.append(station_target[i+lookback])
            station_ids.append(station)
            dates.append(station_dates[i+lookback])

    return np.array(X_sequences), np.array(y_values), np.array(station_ids), np.array(dates)

In [None]:
# Extract features and targets
features_cols = weather_features
target_col = 'Daily_Production'
station_list = train_df['stid'].unique()

# Create sequences for training data
X_train_seq, y_train, train_stations, train_dates = create_sequences(
    train_df, station_list, features_cols, target_col, lookback=look_back
)

# Create sequences for test data
X_test_seq, y_test, test_stations, test_dates = create_sequences(
    test_df, station_list, features_cols, target_col, lookback=look_back
)

print(f"Training sequences: {X_train_seq.shape}")
print(f"Training targets: {y_train.shape}")
print(f"Test sequences: {X_test_seq.shape}")
print(f"Test targets: {y_test.shape}")

X_val_seq, y_val = X_train_seq[-1000:], y_train[-1000:]
X_train_seq, y_train = X_train_seq[:-1000], y_train[:-1000]

Training sequences: (177674, 14, 30)
Training targets: (177674,)
Test sequences: (34398, 14, 30)
Test targets: (34398,)


##Feature Scaling

In [None]:
# Normalize features
scaler_X = StandardScaler()
# Reshape to 2D for scaling
n_samples_train, n_timesteps, n_features = X_train_seq.shape
X_train_reshaped = X_train_seq.reshape(n_samples_train * n_timesteps, n_features)
X_train_scaled = scaler_X.fit_transform(X_train_reshaped)
# Reshape back to 3D
X_train_scaled = X_train_scaled.reshape(n_samples_train, n_timesteps, n_features)

# Scale test data using the same scaler
n_samples_test, _, _ = X_test_seq.shape
X_test_reshaped = X_test_seq.reshape(n_samples_test * n_timesteps, n_features)
X_test_scaled = scaler_X.transform(X_test_reshaped)
X_test_scaled = X_test_scaled.reshape(n_samples_test, n_timesteps, n_features)

# Scale target
scaler_y = MinMaxScaler()
y_train_scaled = scaler_y.fit_transform(y_train.reshape(-1, 1)).flatten()
y_test_scaled = scaler_y.transform(y_test.reshape(-1, 1)).flatten()

X_val_scaled, y_val_scaled = X_train_scaled[-1000:], y_train_scaled[-1000:]
X_train_scaled, y_train_scaled = X_train_scaled[:-1000], y_train_scaled[:-1000]

In [None]:
def model_size(filepath):
  """
  Function to calculate model size in KB
  """
  if not os.path.exists(filepath):
      raise FileNotFoundError(f"No such file: {filepath}")
  size_kb = os.path.getsize(filepath) / 1024
  print(f"File size of '{filepath}': {size_kb:.2f} KB")
  return size_kb

##Post Training Quantisation
Here since two different frameworks were used to create the LSTM and STGCN models (Tensorflow Keras and Pytorch respectively), their corresponding post training quantistaion modules are used for compression

In [None]:
def post_training_quantization_tf(model_path, save_path, model='lstm'):
    """
    Function to perform post-training quantisation
    """
    if model == "lstm":
      # Load model
      model = tf.keras.models.load_model(model_path, custom_objects={'mse': tf.keras.losses.MeanSquaredError})

      # Convert to a quantized TFLite model
      converter = tf.lite.TFLiteConverter.from_keras_model(model)
      converter.optimizations = [tf.lite.Optimize.DEFAULT]  # Default setting with float32 to int8 quantisation
      # converter.target_spec.supported_types = [tf.float16]
      converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
      converter._experimental_lower_tensor_list_ops = False
      tflite_model = converter.convert()
      save_dir = os.path.dirname(save_path)

      # Save quantized model
      with open(save_path, 'wb') as f:
          f.write(tflite_model)

      print(f"Quantized model saved to: {save_path}")
      return save_path
    else:
      class STGCNModel(nn.Module):
        def __init__(self, num_node_features, hidden_dim):
            super(STGCNModel, self).__init__()
            self.gcn1 = GCNConv(num_node_features, hidden_dim)
            self.gcn2 = GCNConv(hidden_dim, hidden_dim)
            self.fc = nn.Linear(hidden_dim, 1)

        def forward(self, x, edge_index, batch):
            x = torch.relu(self.gcn1(x, edge_index))
            x = torch.relu(self.gcn2(x, edge_index))
            # Apply global pooling to get a graph-level representation
            x = torch_geometric.nn.global_mean_pool(x, batch)
            x = self.fc(x)
            return x.squeeze()
      model = STGCNModel(num_node_features=30, hidden_dim=64)
      model.load_state_dict(torch.load(model_path, map_location="cpu"))
      model.eval()
      model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
      model_prepared = torch.quantization.prepare(model)
      torch.save(model_prepared.state_dict(), save_path)
      print(f"Quantized model saved to: {save_path}")
      return model_prepared


LSTM

In [None]:
quant_lstm_model = post_training_quantization_tf(lstm_model_path, '/lstm_quantized_model.tflite', 'lstm')
print(f"LSTM Model size after quantisation: {model_size('/lstm_quantized_model.tflite')} KB")
print(f"LSTM Model size before quantisation: {model_size(lstm_model_path)} KB")



Saved artifact at '/tmp/tmp7zv87yji'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 14, 30), dtype=tf.float32, name='input_layer')
Output Type:
  TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)
Captures:
  134257162913168: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134257162925648: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134257162912400: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134256709714192: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134256709720144: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134256709719760: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134256709719568: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134256709721104: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134256709718800: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134256709715536: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134256709718608: Ten

In [None]:
# Run inference on tflite model
interpreter = tf.lite.Interpreter(model_path='/lstm_quantized_model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

input_data = X_test_scaled.astype(input_details[0]['dtype'])

preds = []
start = time.time()
for i in range(input_data.shape[0]):
    interpreter.set_tensor(input_details[0]['index'], input_data[i:i+1])
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    preds.append(output_data[0])
end = time.time()
print(f"Time taken for inference: {end - start} seconds")
preds = np.array(preds)

Time taken for inference: 37.71252655982971 seconds


In [None]:
# Inverse transform the predictions and true values
y_pred = scaler_y.inverse_transform(preds.reshape(-1, 1))
y_true = scaler_y.inverse_transform(y_test_scaled.reshape(-1, 1))

# Calculate Mean Absolute Error
mae = mean_absolute_error(y_true, y_pred)
print(f"Mean Absolute Error (MAE) after Quantisation of LSTM Model: {mae}")

Mean Absolute Error (MAE) after Quantisation of LSTM Model: 4456819.899176551


ST-GCN

In [None]:
stgcn_quant_model = post_training_quantization_tf(stgcn_model_path, '/stgcn_quantized_model.pth', 'GCN')
print(f"STGCN Model size after quantisation: {model_size('/stgcn_quantized_model.pth')} KB")
print(f"STGCN Model size before quantisation: {model_size(stgcn_model_path)} KB")

Quantized PyTorch model saved to: /stgcn_quantized_model.pth
File size of '/stgcn_quantized_model.pth': 36.18 KB
STGCN Model size after quantisation: 36.18359375 KB
File size of '/content/drive/MyDrive/ams-2014-solar-energy-prediction/models/st_gcn_model_epoch_1.pth': 27.07 KB
STGCN Model size before quantisation: 27.072265625 KB




In [None]:
class STGCNModel(nn.Module):
    def __init__(self, num_node_features, hidden_dim):
        super(STGCNModel, self).__init__()
        self.gcn1 = GCNConv(num_node_features, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch):
        x = torch.relu(self.gcn1(x, edge_index))
        x = torch.relu(self.gcn2(x, edge_index))
        # Apply global pooling to get a graph-level representation
        x = torch_geometric.nn.global_mean_pool(x, batch)
        x = self.fc(x)
        return x.squeeze()

# Re-initialize the model architecture
stgcn_quant_model.load_state_dict(torch.load('/stgcn_quantized_model.pth', map_location="cpu"))
stgcn_quant_model.eval()

STGCNModel(
  (gcn1): GCNConv(30, 64)
  (gcn2): GCNConv(64, 64)
  (fc): Linear(
    in_features=64, out_features=1, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
)

In [None]:
preds = []
X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)

preds = []

start = time.time()
for i in range(X_test_scaled.shape[0]):
    sample_sequence = X_test_scaled[i]

    x = torch.tensor(sample_sequence, dtype=torch.float32)

    x = x.view(look_back, -1)
    edge_index = torch.empty((2, 0), dtype=torch.long)
    batch = torch.zeros(look_back, dtype=torch.long)

    # Perform inference
    with torch.no_grad():
        prediction = stgcn_quant_model(x, edge_index, batch)

    preds.append(prediction.item())
end = time.time()
print(f"Time taken for inference: {end - start} seconds")

preds = np.array(preds) # Convert list of predictions to a NumPy array

y_pred = scaler_y.inverse_transform(preds.reshape(-1, 1))
y_true = scaler_y.inverse_transform(y_test_scaled.reshape(-1, 1))

# Calculate Mean Absolute Error
mae = mean_absolute_error(y_true, y_pred)
print(f"Mean Absolute Error (MAE) for STGCN Model: {mae}")

Time taken for inference: 46.592411279678345 seconds
Mean Absolute Error (MAE) for STGCN Model: 4340312.914298176


##Distillation

##LSTM Distillation

##LSTM Model Definiton

In [None]:
# Import libraries needed for distillation
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping
from torch_geometric.nn import GCNConv
from sklearn.metrics import mean_absolute_error

"""## LSTM Distillation
Create a smaller student LSTM model and train it to mimic our larger teacher LSTM model
"""

def create_student_lstm_model(input_shape, hidden_units=64):
    """
    Creates a smaller LSTM model with fewer parameters than the teacher
    """
    model = Sequential()
    model.add(LSTM(units=hidden_units, input_shape=input_shape))
    model.add(Dense(1))
    model.compile(optimizer='adam', loss='mse')
    return model

# Define distillation loss function for Tensorflow LSTM
def distillation_loss(alpha=0.5):
    """
    Create a loss function that combines:
    - standard MSE against true values
    - distillation loss (MSE between student & teacher predictions)

    Args:
        alpha: weight for balancing the two losses (0-1)
    """
    mse = tf.keras.losses.MeanSquaredError()
    def loss_fn(y_true, y_pred):
        # Extract the true target and teacher predictions
        # y_true is expected to have shape (batch_size, 2) where:
        # - [:, 0] contains the actual targets
        # - [:, 1] contains the teacher predictions
        true_targets = y_true[:, 0]
        teacher_preds = y_true[:, 1]

        # Standard MSE loss against true targets
        mse_loss = mse(true_targets, y_pred)

        # Distillation loss (MSE between student & teacher predictions)
        distill_loss = mse(teacher_preds, y_pred)

        # Combined loss
        return alpha * mse_loss + (1 - alpha) * distill_loss

    return loss_fn

# Load the teacher LSTM model
custom_objects_for_loading = {'mse': tf.keras.losses.MeanSquaredError()}
lstm_teacher_model = tf.keras.models.load_model(lstm_model_path, custom_objects=custom_objects_for_loading)

# Generate teacher predictions for training data
teacher_train_preds = lstm_teacher_model.predict(X_train_scaled)
teacher_val_preds = lstm_teacher_model.predict(X_val_scaled)

# Prepare the training data for distillation
# Combine true targets with teacher predictions
y_train_combined = np.column_stack([y_train_scaled, teacher_train_preds.flatten()])
y_val_combined = np.column_stack([y_val_scaled, teacher_val_preds.flatten()])

# Create student model
input_shape = (look_back, X_train_scaled.shape[2])
lstm_student_model = create_student_lstm_model(input_shape)

print("Teacher LSTM Model:")
lstm_teacher_model.summary()

print("\nStudent LSTM Model:")
lstm_student_model.summary()

# Train the student model with distillation
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Custom loss function
student_loss = distillation_loss(alpha=0.3)

# Compile with custom loss
lstm_student_model.compile(optimizer='adam', loss=student_loss)

train_start = time.time()
history = lstm_student_model.fit(
    X_train_scaled,
    y_train_combined,
    epochs=5,
    batch_size=128,
    validation_data=(X_val_scaled, y_val_combined),
    callbacks=[early_stopping],
    verbose=1
)
train_end = time.time()
print(f"Training time for student LSTM model: {train_end - train_start:.2f} seconds")

# Save the student model
lstm_student_model.save('/lstm_student_model.h5')

# Evaluate student model
student_preds = lstm_student_model.predict(X_test_scaled)
y_student_pred = scaler_y.inverse_transform(student_preds.reshape(-1, 1))
student_mae = mean_absolute_error(y_test, y_student_pred)
print(f"Student LSTM Model MAE: {student_mae:.4f}")

# Size comparison
print(f"Teacher LSTM Model size: {model_size(lstm_model_path)} KB")
print(f"Student LSTM Model size: {model_size('/lstm_student_model.h5')} KB")




[1m5490/5490[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 13ms/step
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step
Teacher LSTM Model:


  super().__init__(**kwargs)



Student LSTM Model:


Epoch 1/5
[1m1373/1373[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 25ms/step - loss: 0.0350 - val_loss: 0.0167
Epoch 2/5
[1m1373/1373[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 21ms/step - loss: 0.0174 - val_loss: 0.0126
Epoch 3/5
[1m1373/1373[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 21ms/step - loss: 0.0130 - val_loss: 0.0090
Epoch 4/5
[1m1373/1373[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 23ms/step - loss: 0.0099 - val_loss: 0.0073
Epoch 5/5
[1m1373/1373[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 21ms/step - loss: 0.0080 - val_loss: 0.0058




Training time for student LSTM model: 176.24 seconds
[1m1075/1075[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 3ms/step
Student LSTM Model MAE: 5025533.0000
File size of '/content/drive/MyDrive/ams-2014-solar-energy-prediction/models/lstm_model_lag_14_4layer.h5': 1722.70 KB
Teacher LSTM Model size: 1722.703125 KB
File size of '/lstm_student_model.h5': 309.17 KB
Student LSTM Model size: 309.171875 KB


##Graph Creation

In [None]:
def create_graph_data(X_seq, y_seq):
    data_list = []
    for i in range(len(X_seq)):
        x = torch.tensor(X_seq[i], dtype=torch.float)  # shape (nodes, features)
        num_nodes = x.size(0)
        edge_index = torch.combinations(torch.arange(num_nodes), r=2).T
        edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)  # undirected
        data = Data(x=x, edge_index=edge_index, y=torch.tensor([y_seq[i]], dtype=torch.float))
        data_list.append(data)
    return data_list
train_graphs = create_graph_data(X_train_scaled, y_train_scaled)

# STGCN Model Distillation

##STGCN Model Definition

In [None]:

class STGCNStudentModel(nn.Module):
    """
    A smaller version of the STGCN model with fewer parameters
    """
    def __init__(self, num_node_features, hidden_dim=32):
        super(STGCNStudentModel, self).__init__()
        self.gcn = GCNConv(num_node_features, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch):
        x = torch.relu(self.gcn(x, edge_index))
        x = torch_geometric.nn.global_mean_pool(x, batch)
        x = self.fc(x)
        return x.squeeze()

# Load teacher model
class STGCNModel(nn.Module):
    def __init__(self, num_node_features, hidden_dim):
        super(STGCNModel, self).__init__()
        self.gcn1 = GCNConv(num_node_features, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch):
        x = torch.relu(self.gcn1(x, edge_index))
        x = torch.relu(self.gcn2(x, edge_index))
        x = torch_geometric.nn.global_mean_pool(x, batch)
        x = self.fc(x)
        return x.squeeze()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize teacher model
stgcn_teacher_model = STGCNModel(num_node_features=X_train_scaled.shape[2], hidden_dim=64).to(device)
stgcn_teacher_model.load_state_dict(torch.load(stgcn_model_path, map_location=device))
stgcn_teacher_model.eval()

# Initialize student model
stgcn_student_model = STGCNStudentModel(num_node_features=X_train_scaled.shape[2], hidden_dim=32).to(device)

# Count parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Teacher STGCN model parameters: {count_parameters(stgcn_teacher_model)}")
print(f"Student STGCN model parameters: {count_parameters(stgcn_student_model)}")

# Create a new train loader with a smaller batch size for distillation
train_loader_distill = DataLoader(train_graphs, batch_size=128, shuffle=True)

# Optimizer
optimizer = torch.optim.Adam(stgcn_student_model.parameters(), lr=0.001)

# Loss functions
mse_loss = nn.MSELoss()
distill_loss = nn.MSELoss()

# Train the student model
stgcn_student_model.train()
alpha = 0.3  # Weight for balancing between true loss and distillation loss

train_start = time.time()
for epoch in range(10):
    total_loss = 0
    for batch in train_loader_distill:
        batch = batch.to(device)

        # Forward pass student model
        student_out = stgcn_student_model(batch.x, batch.edge_index, batch.batch)

        # Get teacher predictions
        with torch.no_grad():
            teacher_out = stgcn_teacher_model(batch.x, batch.edge_index, batch.batch)

        # Calculate losses
        loss_true = mse_loss(student_out, batch.y)  # Loss against true values
        loss_distill = distill_loss(student_out, teacher_out)  # Loss against teacher predictions

        # Combined loss
        loss = alpha * loss_true + (1 - alpha) * loss_distill

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

train_end = time.time()
print(f"Training time for student STGCN model: {train_end - train_start:.2f} seconds")

# Save student model
torch.save(stgcn_student_model.state_dict(), '/stgcn_student_model.pth')

# Evaluate student model
stgcn_student_model.eval()
preds = []

with torch.no_grad():
    for i in range(X_test_scaled.shape[0]):
        sample_sequence = X_test_scaled[i]
        x = torch.tensor(sample_sequence, dtype=torch.float32).to(device)
        x = x.view(look_back, -1)

        edge_index = torch.empty((2, 0), dtype=torch.long).to(device)
        batch = torch.zeros(look_back, dtype=torch.long).to(device)

        prediction = stgcn_student_model(x, edge_index, batch)
        preds.append(prediction.item())

preds = np.array(preds)
y_student_pred = scaler_y.inverse_transform(preds.reshape(-1, 1))
student_mae = mean_absolute_error(y_test, y_student_pred)
print(f"Student STGCN Model MAE: {student_mae:.4f}")

# Size comparison
print(f"Teacher STGCN Model size: {model_size(stgcn_model_path)} KB")
print(f"Student STGCN Model size: {model_size('/stgcn_student_model.pth')} KB")

Teacher STGCN model parameters: 6209
Student STGCN model parameters: 1025




Epoch 1, Loss: 18.3406
Epoch 2, Loss: 12.8995
Epoch 3, Loss: 12.7256
Epoch 4, Loss: 12.6443
Epoch 5, Loss: 12.5817
Epoch 6, Loss: 12.5418
Epoch 7, Loss: 12.5140
Epoch 8, Loss: 12.4878
Epoch 9, Loss: 12.4610
Epoch 10, Loss: 12.4522
Training time for student STGCN model: 412.71 seconds
Student STGCN Model MAE: 3974047.5527
File size of '/content/drive/MyDrive/ams-2014-solar-energy-prediction/models/st_gcn_model_epoch_1.pth': 27.07 KB
Teacher STGCN Model size: 27.072265625 KB
File size of '/stgcn_student_model.pth': 6.17 KB
Student STGCN Model size: 6.171875 KB
