# CNN-RNN


In [None]:
import torch
from torch.utils.data import Dataset
import pandas as pd
import pytz
import numpy as np
import os
import sys
import glob
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datetime import datetime
import matplotlib.pyplot as plt
import re


scripts_dir = os.path.abspath(os.path.join(os.getcwd(), '..', 'scripts'))
sys.path.append(scripts_dir)
from data_generator import normalize_new_data

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
# NOTE: set_default_device makes tensor factory calls without an explicit
# device (e.g. torch.tensor / torch.zeros, including inside Dataset code)
# allocate on `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
torch.set_default_device(device)

In [None]:
# Glob pattern matching every part of the split, normalized training CSV.
file_pattern = "../data/final_data/cleaned_compiled_data_normalized_part*.csv"

# Natural sort so "part10" comes after "part9" (plain sorted() would put it
# right after "part1").
csv_files = sorted(
    glob.glob(file_pattern),
    key=lambda path: [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", path)],
)
if not csv_files:
    # pd.concat on an empty iterable only says "No objects to concatenate";
    # report the pattern that matched nothing instead.
    raise FileNotFoundError(f"No CSV files match {file_pattern!r}")

# Load and concatenate all parts into one frame with a fresh RangeIndex.
data = pd.concat((pd.read_csv(f) for f in csv_files), ignore_index=True)

# `data` now holds the full combined normalized DataFrame.
print(f"Loaded {len(csv_files)} files. Final shape: {data.shape}")

In [None]:
# Stations whose readings are rasterised onto the image grid.
chosen_stations = ['S104', 'S107', 'S109', 'S115', 'S116', 'S24', 'S43', 'S50']

# (row, col, station_id): the grid cell each station is written to by
# tabular_to_image (default grid is 9 x 18).
pixel_coords = [
    (4, 11, 'S109'),
    (2, 7, 'S50'),
    (1, 16, 'S107'),
    (2, 13, 'S43'),
    (0, 0, 'S115'),
    (4, 17, 'S24'),
    (0, 6, 'S116'),
    (8, 8, 'S104'),
]
def tabular_to_image(data: pd.DataFrame, pixel_coords, image_shape=(9, 18), feature_types=None):
    """Rasterise per-station columns of ``data`` onto a sparse image grid.

    For each ``(y, x, station_id)`` in ``pixel_coords`` and each feature name
    ``feat``, the column ``f"{feat}_{station_id}"`` (when present in ``data``)
    is written to pixel ``(y, x)`` of the channel at ``feat``'s position in
    ``feature_types``. Pixels without a station, and station/feature pairs
    whose column is absent, stay NaN.

    Args:
        data: One row per time step; columns named ``<feature>_<station_id>``.
        pixel_coords: Iterable of ``(row, col, station_id)`` tuples.
        image_shape: ``(H, W)`` of the output grid.
        feature_types: Channel order. Defaults to rainfall, air_temperature,
            wind_speed, relative_humidity, wind_direction.

    Returns:
        float32 array of shape ``(T, H, W, C)`` where ``T = len(data)``. For a
        single row (``T == 1``) the time axis is dropped: ``(H, W, C)``.
    """
    if feature_types is None:
        feature_types = ['rainfall', 'air_temperature', 'wind_speed', 'relative_humidity', 'wind_direction']
    H, W = image_shape
    T = data.shape[0]
    image = np.full((T, H, W, len(feature_types)), np.nan, dtype=np.float32)

    for y, x, station_id in pixel_coords:
        for channel, feat in enumerate(feature_types):
            col_name = f"{feat}_{station_id}"
            if col_name in data.columns:
                image[:, y, x, channel] = data[col_name].to_numpy()

    # Keep the single-row squeeze, but return the (empty) 4-D array for
    # T == 0 instead of failing on image[0].
    return image[0] if T == 1 else image


In [None]:
class LightningDataset_Modified(Dataset):
    """Sliding-window dataset: station images -> future lightning labels.

    For every anchor timestamp ``t`` on an even hour (minute 0) inside the
    data range, one sample is built:

    * input: the rows at ``t-120, t-90, t-60, t-30, t`` minutes, rasterised by
      ``tabular_to_image`` and stored as a ``(C, T, H, W)`` float32 array;
    * target: ``Lightning_Risk`` at ``t, t+30, t+60, t+90, t+120`` minutes,
      cast to int.

    Windows are skipped when any timestamp is missing, when a timestamp is
    duplicated, or when a label is NaN. With ``reject_zeros=True`` (default),
    windows in which no label equals 1 are skipped as well.

    ``compiled_df`` must either have a ``DatetimeIndex`` or a ``"Timestamp"``
    column; tz-aware timestamps are made naive (wall-clock time kept).
    """

    def __init__(self, compiled_df, pixel_coords, image_shape=(9, 18), timezone_str="Asia/Singapore", reject_zeros = True):
        self.compiled_df = compiled_df.copy()
        self.pixel_coords = pixel_coords
        self.image_shape = image_shape
        # NOTE(review): stored but not used below; timestamps are never
        # converted to this zone.
        self.timezone = pytz.timezone(timezone_str)
        # list of (input array (C, T, H, W) float32, labels array (5,) int)
        self.samples = []
        self.reject_zeros = reject_zeros

        self._prepare_dataset()

    def _prepare_dataset(self):
        """Fill ``self.samples`` with one (input, labels) pair per valid window."""
        df = self.compiled_df

        # Index by timestamp unless a DatetimeIndex is already present (the
        # "Timestamp" column is only required in that case).
        if not isinstance(df.index, pd.DatetimeIndex):
            df["Timestamp"] = pd.to_datetime(df["Timestamp"])
            df.set_index("Timestamp", inplace=True)
        df.index = df.index.tz_localize(None)

        # Anchors: even hours at minute 0, starting 2h after the first even
        # hour so the 2h look-back window can exist.
        index = df.index
        min_ts = index.min().ceil("2h") + pd.Timedelta(hours=2)
        max_ts = index.max().floor("2h")
        valid_ts = index[
            (index >= min_ts)
            & (index <= max_ts)
            & (index.hour % 2 == 0)
            & (index.minute == 0)
        ]

        input_offsets = [pd.Timedelta(minutes=m) for m in (120, 90, 60, 30, 0)]
        output_offsets = [pd.Timedelta(minutes=m) for m in (0, 30, 60, 90, 120)]

        for timestamp in valid_ts:
            input_times = [timestamp - off for off in input_offsets]
            output_times = [timestamp + off for off in output_offsets]
            try:
                input_slices = self.compiled_df.loc[input_times]
                labels = self.compiled_df.loc[output_times, "Lightning_Risk"]
            except KeyError:
                continue  # Skip if any timestamps are missing

            # Duplicate timestamps make .loc return extra rows, which would
            # silently change the sample shape; skip such windows.
            if len(input_slices) != len(input_times) or len(labels) != len(output_times):
                continue
            # astype(int) raises on NaN labels; skip such windows.
            if labels.isna().any():
                continue

            output_data = labels.astype(int).to_numpy()
            if self.reject_zeros and not (output_data == 1).any():
                continue

            # Build the image only for windows that are actually kept.
            input_images = tabular_to_image(input_slices, self.pixel_coords, self.image_shape)  # (T, H, W, C)
            input_tensor = np.transpose(input_images, (3, 0, 1, 2))  # (C, T, H, W)
            self.samples.append((input_tensor, output_data))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(x, y)``: x is float32 ``(T, C, H, W)``, y float32 labels."""
        x, y = self.samples[idx]
        # Stored as (C, T, H, W); the model expects (T, C, H, W).
        x = torch.tensor(x, dtype=torch.float32).permute(1, 0, 2, 3)
        y = torch.tensor(y, dtype=torch.float32)
        return x, y

    def get_positive_ratio(self):
        """Fraction of all label entries equal to 1 (0.0 for an empty dataset)."""
        if not self.samples:
            return 0.0
        all_labels = np.stack([labels for _, labels in self.samples])  # (N, 5)
        return float((all_labels == 1).mean())


In [None]:
# Training windows: reject_zeros defaults to True, so only windows with at
# least one positive label are kept.
dataset = LightningDataset_Modified(compiled_df=data, pixel_coords=pixel_coords)

In [None]:
# Inspect one training sample: x is (T, C, H, W) = (5, 5, 9, 18), y holds 5 labels.
dataset[0]

In [None]:
# Number of windows kept (only those containing a label == 1, since reject_zeros defaults to True).
len(dataset)

In [None]:
# Fraction of training label entries equal to 1.
dataset.get_positive_ratio()

In [None]:
# Generator on the active device; presumably needed because
# torch.set_default_device(cuda) makes the shuffling sampler run on the GPU.
if torch.cuda.is_available():
    g = torch.Generator(device="cuda")
else:
    g = torch.Generator()
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, generator=g)


In [None]:
class LightningRiskCNNRNN(nn.Module):
    """Per-frame CNN encoder followed by an Elman RNN over time.

    Input ``x`` has shape ``(batch, seq_len, channels, H, W)``. NaNs are
    replaced with 0 before the convolutions. Every frame goes through three
    ReLU conv layers and is flattened; the resulting feature sequence is fed
    to ``nn.RNN`` starting from a learnable initial hidden state, and the last
    RNN output is mapped to ``num_future_steps`` sigmoid probabilities.

    Args:
        num_channels: Input channels per frame.
        num_future_steps: Number of outputs (one probability each).
        hidden_size: RNN hidden size.
        image_shape: ``(H, W)`` of each frame; only used to size the RNN input.
            Defaults to the 9x18 station grid.
    """

    def __init__(self, num_channels=5, num_future_steps=5, hidden_size=256, image_shape=(9, 18)):
        super().__init__()

        # --- CNN Feature Extractor ---
        self.conv1 = nn.Conv2d(in_channels=num_channels, out_channels=32, kernel_size=5, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()

        # Dynamically determine CNN output size from a dummy frame
        with torch.no_grad():
            dummy_input = torch.zeros(1, num_channels, *image_shape)
            self.feature_size = self._encode_frames(dummy_input).shape[1]

        # --- RNN ---
        self.rnn = nn.RNN(input_size=self.feature_size, hidden_size=hidden_size, batch_first=True)

        # --- Output Layer ---
        self.fc = nn.Linear(hidden_size, num_future_steps)

        # --- Learnable initial hidden state ---
        self.initial_hidden_state = nn.Parameter(torch.randn(1, 1, hidden_size))

    def _encode_frames(self, frames):
        """Encode frames ``(N, C, H, W)`` into flat features ``(N, feature_size)``."""
        out = F.relu(self.conv1(frames))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        return self.flatten(out)

    def forward(self, x):
        """Map ``(batch, seq_len, C, H, W)`` to ``(batch, num_future_steps)`` probabilities."""
        batch_size, seq_len, c, h, w = x.shape
        x = torch.nan_to_num(x, nan=0.0)

        # Run the CNN on all time steps in one batched pass instead of a
        # Python loop; frames are laid out as b * seq_len + t.
        frames = x.reshape(batch_size * seq_len, c, h, w)
        cnn_features = self._encode_frames(frames).reshape(batch_size, seq_len, -1)

        # Expand learnable hidden state for batch
        h0 = self.initial_hidden_state.expand(1, batch_size, -1).contiguous()
        rnn_out, _ = self.rnn(cnn_features, h0)

        final_hidden = rnn_out[:, -1, :]  # (batch_size, hidden_size)

        # Sigmoid probabilities (the training loop pairs them with nn.BCELoss)
        return torch.sigmoid(self.fc(final_hidden))

In [None]:
model = LightningRiskCNNRNN().to(device)
# Get one batch (the previous `for` loop walked the whole dataloader and kept
# only the last batch).
batch_x, batch_y = next(iter(dataloader))
batch_x = batch_x.to(device)
batch_y = batch_y.to(device)


In [None]:
def _plot_loss_history(epochs, losses, timestamp):
    """Plot ``losses`` vs ``epochs`` on linear and log y-axes; save to ``plots/`` and show."""
    fig, axs = plt.subplots(1, 2, figsize=(15, 6))
    for ax, scale_name, log_y in (
        (axs[0], 'Standard Y-Axis', False),
        (axs[1], 'Logarithmic Y-Axis', True),
    ):
        ax.plot(epochs, losses, label='Average Loss (every 5 epochs)', color='blue')
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.set_title(f'Average Training Loss Over Time ({scale_name})')
        if log_y:
            ax.set_yscale('log')
        ax.grid(True)
        ax.legend()

    os.makedirs("plots", exist_ok=True)  # Ensure the directory exists
    plot_filename = f'loss_plot_{timestamp}.png'
    fig.savefig(os.path.join("plots", plot_filename))
    print(f"Plot saved as {plot_filename}")
    plt.show()


def train(dataloader, model, num_epochs, learning_rate, device=device):
    """Train ``model`` with BCE loss and Adam, checkpointing along the way.

    Side effects (paths relative to the working directory):
      * every 20th epoch (0, 20, 40, ...): ``models/model_<start>_<avg_loss>_<epoch>.pth``;
      * after training: ``models/model_<start>_<last_epoch_avg_loss>.pth``;
      * loss plot: ``plots/loss_plot_<start>.png`` (also shown).

    Args:
        dataloader: Yields ``(inputs, targets)``; ``model(inputs)`` must return
            probabilities with the same shape as ``targets``.
        model: Module to train; moved to ``device``.
        num_epochs: Number of passes over ``dataloader``.
        learning_rate: Adam learning rate.
        device: Device for the model and every batch.

    Returns:
        The average epoch loss recorded at epochs 0, 5, 10, ...

    Raises:
        ValueError: if ``num_epochs > 0`` and ``dataloader`` has no batches.
    """
    model.train()
    model.to(device)  # Ensure model is on GPU/CPU

    # The model already applies a sigmoid, so plain BCE on probabilities.
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    num_batches = len(dataloader)
    if num_epochs > 0 and num_batches == 0:
        raise ValueError("dataloader has no batches; cannot compute an average loss")

    avg_loss_history = []   # sampled every 5 epochs
    recorded_epochs = []    # x-values matching avg_loss_history
    last_avg_loss = None    # loss of the most recent epoch

    os.makedirs("models", exist_ok=True)  # Ensure the directory exists
    timestamp = datetime.now().strftime('%Y_%m_%d_%H_%M')
    for epoch in range(num_epochs):
        total_loss = 0.0
        for inputs, targets in dataloader:
            inputs = inputs.to(device)    # (B, T, C, H, W)
            targets = targets.to(device)  # (B, 5)

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

        last_avg_loss = total_loss / num_batches

        if epoch % 5 == 0:
            avg_loss_history.append(last_avg_loss)
            recorded_epochs.append(epoch)
            print(f"Epoch {epoch+1}/{num_epochs} - Avg Loss: {last_avg_loss:.4f}")

        if epoch % 20 == 0:
            filename = f'model_{timestamp}_{last_avg_loss:0.6f}_{epoch}.pth'
            torch.save(model.state_dict(), os.path.join("models", filename))
            print(f"Model saved as {filename}")

    # Name the final checkpoint after the LAST epoch's loss; the 5-epoch
    # sample is stale whenever (num_epochs - 1) % 5 != 0 (e.g. 500 epochs).
    final_loss = last_avg_loss if last_avg_loss is not None else 0
    filename = f'model_{timestamp}_{final_loss:0.6f}.pth'
    torch.save(model.state_dict(), os.path.join("models", filename))
    print(f"Model saved as {filename}")

    _plot_loss_history(recorded_epochs, avg_loss_history, timestamp)

    return avg_loss_history  # Return the loss history


In [None]:
model = LightningRiskCNNRNN().to(device)
%timeit -r 1 -n 1 train(dataloader = dataloader, model = model, num_epochs = 500, learning_rate = 1e-3)

In [None]:
def evaluate_accuracy(dataloader, model, device=device, threshold=0.5):
    """Return element-wise accuracy (in percent) of thresholded predictions.

    Every entry of every target tensor counts as one prediction, so when
    positive labels are rare this figure is dominated by the negative class.

    Args:
        dataloader: Yields ``(inputs, targets)`` with targets in {0, 1}.
        model: Module returning probabilities shaped like ``targets``.
        device: Device for the model and every batch.
        threshold: Probability above which a prediction counts as 1.

    Returns:
        Accuracy in percent (also printed).

    Raises:
        ValueError: if ``dataloader`` yields no target elements.
    """
    model.eval()  # Set the model to evaluation mode
    model.to(device)  # Ensure the model is on the correct device (GPU/CPU)

    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():  # No need to track gradients during evaluation
        for inputs, targets in dataloader:
            inputs = inputs.to(device)    # (B, T, C, H, W)
            targets = targets.to(device)  # (B, 5)

            predicted_labels = (model(inputs) > threshold).float()

            correct_predictions += (predicted_labels == targets).sum().item()
            total_predictions += targets.numel()

    if total_predictions == 0:
        raise ValueError("dataloader yielded no targets; accuracy is undefined")

    accuracy = correct_predictions / total_predictions * 100  # Percentage accuracy
    print(f"Accuracy: {accuracy:.2f}%")
    return accuracy


In [None]:
# Built from the SAME `data` as the training set, only without the
# all-negative filter, so every training window is also in this set.
# NOTE(review): this is not a held-out split; validation accuracy will be
# optimistic. Consider a time-based split of `data` before building both sets.
validation_dataset = LightningDataset_Modified(data,pixel_coords,reject_zeros=False)

In [None]:
# Number of windows kept without the all-negative filter (reject_zeros=False).
len(validation_dataset)

In [None]:
# Assuming you have a test_dataloader prepared
validation_dataloader = DataLoader(validation_dataset, batch_size=64, shuffle=True, generator=g)
accuracy = evaluate_accuracy(validation_dataloader, model, device=device)


In [None]:
def extract_epoch(filename):
    """Return the epoch from a ``..._<epoch>.pth`` filename, or -1.

    Final checkpoints (``..._<loss>.pth``, where the loss contains a '.') and
    ``..._best.pth`` do not match and yield -1.
    """
    match = re.search(r"_(\d+)\.pth$", filename)
    return int(match.group(1)) if match else -1


def evaluate_all_checkpoints(model_class, checkpoint_dir, train_loader, val_loader, device="cuda"):
    """Evaluate every per-epoch checkpoint and save the best one by val accuracy.

    Args:
        model_class: Zero-argument constructor of the model.
        checkpoint_dir: Filename prefix (not a directory) of the checkpoints in
            ``./models``, e.g. ``"model_2025_04_18_01_35"``.
        train_loader: Loader for train accuracy.
        val_loader: Loader for validation accuracy.
        device: ``map_location`` / evaluation device.

    Returns:
        List of ``(epoch, train_acc, val_acc)`` sorted by epoch. The best state
        dict is written to ``./models/<checkpoint_dir>_best.pth``.
    """
    model_paths = sorted(glob.glob(f"./models/{checkpoint_dir}_*.pth"), key=extract_epoch)
    results = []
    best_val_acc = -1
    best_model_state = None

    for path in model_paths:
        epoch = extract_epoch(path)
        if epoch == -1:
            continue  # final/best checkpoints carry no epoch number
        print(f"\nEvaluating model at epoch {epoch}...")

        # Initialize model and load state dict
        model = model_class()
        model.load_state_dict(torch.load(path, map_location=device))

        train_acc = evaluate_accuracy(train_loader, model, device)
        val_acc = evaluate_accuracy(val_loader, model, device)

        print(f"Epoch {epoch}: Train Acc = {train_acc:.2f}% | Val Acc = {val_acc:.2f}%")
        results.append((epoch, train_acc, val_acc))

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = model.state_dict()

    # Save the best model
    if best_model_state is not None:
        best_path = f"./models/{checkpoint_dir}_best.pth"
        torch.save(best_model_state, best_path)
        print(f"\n✅ Best model saved to {best_path} with Val Acc = {best_val_acc:.2f}%")

    return results


In [None]:
def plot_accuracy_validation(results, save_path):
    """Plot train/validation accuracy per epoch, save it to ``save_path`` and show it.

    Args:
        results: ``(epoch, train_acc, val_acc, ...)`` tuples, e.g. the output of
            ``evaluate_all_checkpoints``.
        save_path: Image path; its parent directory is created if missing
            (only ``train`` creates ``plots/`` otherwise).
    """
    epochs = [r[0] for r in results]
    train_accuracies = [r[1] for r in results]
    val_accuracies = [r[2] for r in results]

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_accuracies, label="Train Accuracy", marker='o', linestyle='-')
    plt.plot(epochs, val_accuracies, label="Validation Accuracy", marker='x', linestyle='--')
    plt.title("Model Accuracy Over Epochs")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    # savefig raises FileNotFoundError if the target directory does not exist.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.show()
    print(f"Accuracy plot saved to {save_path}")


In [None]:
model_timestamp = "model_2025_04_18_01_35"

# Score every per-epoch checkpoint saved under ./models with this prefix,
# then plot train vs. validation accuracy per epoch.
results = evaluate_all_checkpoints(
    LightningRiskCNNRNN,
    model_timestamp,
    dataloader,
    validation_dataloader,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
plot_accuracy_validation(results, f"./plots/{model_timestamp}.png")

In [None]:
# Define the directory and base filename pattern
file_path_test = "../data/test_data/cleaned_compiled_data_normalized.csv"

# Load and concatenate all parts
test_data = pd.read_csv(file_path_test)

In [None]:
test_dataset = LightningDataset_Modified(test_data,pixel_coords,reject_zeros=False)
# Evaluation only: keep a fixed order and leave the generator `g` shared with
# the training loader untouched.
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
def evaluate_and_plot_checkpoints(model_class, checkpoint_dir, train_loader, val_loader, test_loader, device="cuda"):
    """Evaluate every per-epoch checkpoint on train/val/test, save the best, plot.

    Args:
        model_class: Zero-argument constructor of the model.
        checkpoint_dir: Filename prefix (not a directory) of the checkpoints in
            ``./models``.
        train_loader, val_loader, test_loader: Loaders to score each checkpoint on.
        device: ``map_location`` / evaluation device.

    Returns:
        List of ``(epoch, train_acc, val_acc, test_acc)`` sorted by epoch.

    Side effects (only when at least one checkpoint was evaluated):
    ``./models/<prefix>_best.pth`` (best by validation accuracy),
    ``./models/<prefix>_best_test.txt`` and ``./models/<prefix>_accuracy_plot.png``.
    """
    def extract_epoch(filename):
        # Epoch from "..._<epoch>.pth"; -1 for final ("..._<loss>.pth") and best files.
        match = re.search(r"_(\d+)\.pth$", filename)
        return int(match.group(1)) if match else -1

    model_paths = sorted(glob.glob(f"./models/{checkpoint_dir}_*.pth"), key=extract_epoch)
    results = []
    best_val_acc = -1
    best_model_state = None
    best_test_acc = -1

    for path in model_paths:
        epoch = extract_epoch(path)
        if epoch == -1:
            continue  # final/best checkpoints carry no epoch number
        print(f"\n🔍 Evaluating model at epoch {epoch}...")

        # Initialize model and load weights
        model = model_class()
        model.load_state_dict(torch.load(path, map_location=device))

        # Evaluate on all splits
        train_acc = evaluate_accuracy(train_loader, model, device)
        val_acc = evaluate_accuracy(val_loader, model, device)
        test_acc = evaluate_accuracy(test_loader, model, device)

        print(f"Epoch {epoch}: Train Acc = {train_acc:.2f}% | Val Acc = {val_acc:.2f}% | Test Acc = {test_acc:.2f}%")
        results.append((epoch, train_acc, val_acc, test_acc))

        # Track best validation model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
            best_model_state = model.state_dict()

    # Save best model weights
    if best_model_state is not None:
        best_model_path = f"./models/{checkpoint_dir}_best.pth"
        torch.save(best_model_state, best_model_path)
        print(f"\n✅ Best model saved to {best_model_path} with Val Acc = {best_val_acc:.2f}%")

        # Save best test accuracy
        test_acc_path = f"./models/{checkpoint_dir}_best_test.txt"
        with open(test_acc_path, "w", encoding="utf-8") as f:
            f.write(f"Best Validation Accuracy: {best_val_acc:.2f}%\n")
            f.write(f"Test Accuracy at Best Val: {best_test_acc:.2f}%\n")
        print(f"✅ Test accuracy saved to {test_acc_path}")

    # --- Plotting ---
    if results:
        epochs = [r[0] for r in results]
        train_accuracies = [r[1] for r in results]
        val_accuracies = [r[2] for r in results]
        test_accuracies = [r[3] for r in results]

        plt.figure(figsize=(10, 6))
        plt.plot(epochs, train_accuracies, label="Train Accuracy", marker='o', linestyle='-')
        plt.plot(epochs, val_accuracies, label="Validation Accuracy", marker='x', linestyle='--')
        plt.plot(epochs, test_accuracies, label="Test Accuracy", marker='s', linestyle=':')
        plt.title("Model Accuracy Over Epochs")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy (%)")
        plt.grid(True)
        plt.legend()
        plt.tight_layout()

        # Save the plot
        plot_path = f"./models/{checkpoint_dir}_accuracy_plot.png"
        plt.savefig(plot_path)
        plt.show()
        print(f"📊 Accuracy plot saved to {plot_path}")

    return results


In [None]:
model_timestamp = "model_2025_04_18_01_35"

# Score every per-epoch checkpoint with this prefix on train, validation and
# test data; the function also saves the best checkpoint and an accuracy plot.
results = evaluate_and_plot_checkpoints(
    LightningRiskCNNRNN,
    model_timestamp,
    dataloader,
    validation_dataloader,
    test_dataloader,
    device="cuda" if torch.cuda.is_available() else "cpu",
)