In [1]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m25.9 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.7.0


In [2]:
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler
from scipy.spatial import cKDTree

In [16]:
df = pd.read_csv('/kaggle/input/datasets/firecastrl/us-wildfire-dataset/Wildfire_Dataset.csv', parse_dates=['datetime'])

feature_cols = ['pr', 'rmax', 'rmin', 'sph', 'srad', 'tmmn', 'tmmx',
                'vs', 'bi', 'fm100', 'fm1000', 'erc', 'etr', 'pet', 'vpd']
target_col = 'Wildfire'

# Encode target once, vectorized
df['label'] = (df['Wildfire'].str.strip() == 'Yes').astype(np.float32)

# Verify immediately
print("Label counts:", df['label'].value_counts().to_dict())
print("Wildfire %:", df['label'].mean() * 100, "%")

Label counts: {0.0: 9007860, 1.0: 502065}
Wildfire %: 5.2793794 %


In [17]:
# Get unique spatial locations and assign stable node IDs
coords = df[['latitude', 'longitude']].drop_duplicates().reset_index(drop=True)
coords['node_id'] = coords.index
num_nodes = len(coords)

# Merge node_id back — O(n) hash join
df = df.merge(coords, on=['latitude', 'longitude'], how='left')

print(f"Nodes: {num_nodes:,} | Timesteps: {df['datetime'].nunique():,} | Rows: {len(df):,}")

Nodes: 37,098 | Timesteps: 4,122 | Rows: 9,509,925


In [18]:
scaler = StandardScaler()
df[feature_cols] = scaler.fit_transform(df[feature_cols]).astype(np.float32)

In [19]:
# Build spatial graph: connect each node to its k nearest neighbors
K = 6  # neighbors per node
coord_arr = coords[['latitude', 'longitude']].values

tree = cKDTree(coord_arr)
distances, indices = tree.query(coord_arr, k=K + 1)  # +1 to exclude self

src, dst = [], []
for i, neighbors in enumerate(indices):
    for j in neighbors[1:]:  # skip self (index 0)
        src.append(i)
        dst.append(j)

edge_index = torch.tensor([src, dst], dtype=torch.long)
print(f"Edges: {edge_index.shape[1]:,}")

Edges: 222,588


In [20]:
# Same flat index trick but process in date chunks to cap RAM
df['t_id'] = df['datetime'].rank(method='dense').astype(int) - 1
num_times = df['t_id'].nunique()

CHUNK = 100  # process 100 dates at a time, tune up/down based on RAM

snapshots = []

for t_start in range(0, num_times, CHUNK):
    t_end = min(t_start + CHUNK, num_times)
    chunk_size = t_end - t_start

    mask = (df['t_id'] >= t_start) & (df['t_id'] < t_end)
    chunk = df.loc[mask]

    local_t = (chunk['t_id'].values - t_start)
    flat_idx = local_t * num_nodes + chunk['node_id'].values

    X_chunk = np.zeros((chunk_size * num_nodes, len(feature_cols)), dtype=np.float32)
    y_chunk = np.zeros(chunk_size * num_nodes, dtype=np.float32)

    X_chunk[flat_idx] = chunk[feature_cols].values.astype(np.float32)
    y_chunk[flat_idx] = chunk['label'].values

    X_chunk = X_chunk.reshape(chunk_size, num_nodes, len(feature_cols))
    y_chunk = y_chunk.reshape(chunk_size, num_nodes)

    for i in range(chunk_size):
        snapshots.append(Data(
            x=torch.from_numpy(X_chunk[i]),
            y=torch.from_numpy(y_chunk[i]),
            edge_index=edge_index
        ))

    if (t_start // CHUNK) % 5 == 0:
        print(f"  {t_end}/{num_times} timesteps done...")

print(f"Done. {len(snapshots)} snapshots")

  100/4122 timesteps done...
  600/4122 timesteps done...
  1100/4122 timesteps done...
  1600/4122 timesteps done...
  2100/4122 timesteps done...
  2600/4122 timesteps done...
  3100/4122 timesteps done...
  3600/4122 timesteps done...
  4100/4122 timesteps done...
Done. 4122 snapshots


In [21]:
n = len(snapshots)
train_end = int(0.7 * n)
val_end   = int(0.85 * n)

train_snapshots = snapshots[:train_end]
val_snapshots   = snapshots[train_end:val_end]
test_snapshots  = snapshots[val_end:]

print(f"Train: {len(train_snapshots)} | Val: {len(val_snapshots)} | Test: {len(test_snapshots)}")

Train: 2885 | Val: 618 | Test: 619


In [22]:
snap = snapshots[0]
print("x shape     :", snap.x.shape)       # [num_nodes, 15]
print("y shape     :", snap.y.shape)       # [num_nodes]
print("edge_index  :", snap.edge_index.shape)  # [2, num_edges]
print("Wildfire %  :", snap.y.mean().item() * 100, "%")

x shape     : torch.Size([37098, 15])
y shape     : torch.Size([37098])
edge_index  : torch.Size([2, 222588])
Wildfire %  : 0.0 %


In [23]:
# Check what values actually exist in the Wildfire column
print("Raw value counts:")
print(df['Wildfire'].value_counts())

print("\nUnique values:", df['Wildfire'].unique()[:20])

print("\nDtype:", df['Wildfire'].dtype)

# Check what the label column looks like after encoding
print("\nLabel distribution:")
print(df['label'].value_counts())

print("\nSample rows:")
print(df[['Wildfire', 'label']].drop_duplicates().head(20))

Raw value counts:
Wildfire
No     9007860
Yes     502065
Name: count, dtype: int64

Unique values: ['No' 'Yes']

Dtype: object

Label distribution:
label
0.0    9007860
1.0     502065
Name: count, dtype: int64

Sample rows:
    Wildfire  label
0         No    0.0
285      Yes    1.0


In [24]:
# Run this BEFORE re-running Cell 6
print("Label sum in df:", df['label'].sum())          # should be 502065
print("Label in flat array test:")

t0_mask = df['t_id'] == 0
print("  t=0 positives:", df.loc[t0_mask, 'label'].sum())

Label sum in df: 502065.0
Label in flat array test:
  t=0 positives: 0.0


In [25]:
positives_per_t = df.groupby('t_id')['label'].sum()

print("Timesteps with ANY wildfire:", (positives_per_t > 0).sum(), "out of", len(positives_per_t))
print("\nFirst 10 timesteps with fires:")
print(positives_per_t[positives_per_t > 0].head(10))

print("\nDate of first wildfire timestep:")
first_fire_t = positives_per_t[positives_per_t > 0].index[0]
print(df[df['t_id'] == first_fire_t]['datetime'].iloc[0])

Timesteps with ANY wildfire: 3994 out of 4122

First 10 timesteps with fires:
t_id
69    1.0
70    1.0
71    1.0
72    1.0
73    1.0
74    1.0
75    1.0
76    1.0
77    1.0
78    1.0
Name: label, dtype: float32

Date of first wildfire timestep:
2014-03-10 00:00:00


In [26]:
# Check what date snapshot[0] corresponds to
print("Snapshot[0] date:", df[df['t_id'] == 0]['datetime'].iloc[0])
print("Snapshot[69] date:", df[df['t_id'] == 69]['datetime'].iloc[0])

# Overall health check across ALL snapshots
total_positives = sum(s.y.sum().item() for s in snapshots)
snapshots_with_fire = sum(1 for s in snapshots if s.y.sum() > 0)

print(f"\n=== Pipeline Health Check ===")
print(f"Total snapshots       : {len(snapshots)}")
print(f"Snapshots with fires  : {snapshots_with_fire} / {len(snapshots)}")
print(f"Total positive labels : {total_positives:,.0f}  (expected ~502,065)")
print(f"Overall Wildfire %    : {total_positives / (len(snapshots) * num_nodes) * 100:.3f}%")

# Spot check a snapshot that has fires
fire_snap = snapshots[69]
print(f"\nSnapshot[69] Wildfire %: {fire_snap.y.mean().item()*100:.3f}%")
print(f"Snapshot[69] x shape  : {fire_snap.x.shape}")
print(f"Snapshot[69] y shape  : {fire_snap.y.shape}")
print(f"Snapshot[69] edges    : {fire_snap.edge_index.shape}")

Snapshot[0] date: 2013-12-31 00:00:00
Snapshot[69] date: 2014-03-10 00:00:00

=== Pipeline Health Check ===
Total snapshots       : 4122
Snapshots with fires  : 3994 / 4122
Total positive labels : 498,661  (expected ~502,065)
Overall Wildfire %    : 0.326%

Snapshot[69] Wildfire %: 0.003%
Snapshot[69] x shape  : torch.Size([37098, 15])
Snapshot[69] y shape  : torch.Size([37098])
Snapshot[69] edges    : torch.Size([2, 222588])


In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class TGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, gru_hidden, num_layers=2):
        super().__init__()
        
        # Spatial: stack of GCN layers
        self.gcn_layers = nn.ModuleList()
        self.gcn_layers.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.gcn_layers.append(GCNConv(hidden_channels, hidden_channels))
        
        # Temporal: GRU over spatial embeddings
        self.gru = nn.GRU(hidden_channels, gru_hidden, batch_first=True, num_layers=2)
        
        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(gru_hidden, gru_hidden // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(gru_hidden // 2, 1)
        )
    
    def forward(self, snapshot_seq, edge_index):
        """
        snapshot_seq: [T, N, F]  (sequence of T snapshots)
        edge_index:   [2, E]
        returns:      [N] logits for final timestep
        """
        spatial_embeds = []
        
        for t in range(snapshot_seq.size(0)):
            x = snapshot_seq[t]                          # [N, F]
            for gcn in self.gcn_layers:
                x = F.relu(gcn(x, edge_index))           # [N, H]
            spatial_embeds.append(x)
        
        # Stack: [N, T, H]
        x = torch.stack(spatial_embeds, dim=1)
        
        # GRU: [N, T, H] -> [N, T, gru_hidden]
        x, _ = self.gru(x)
        
        # Take last timestep: [N, gru_hidden]
        x = x[:, -1, :]
        
        # Classify: [N]
        return self.classifier(x).squeeze(-1)

In [28]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

model = TGNN(
    in_channels=15,
    hidden_channels=64,
    gru_hidden=64,
    num_layers=2
).to(device)

# Class imbalance weight (~306:1 at node level, use softer weight)
pos_weight = torch.tensor([9007860 / 502065]).to(device)  # ~17.9
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)

print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")

Using device: cpu
Model params: 57,217


In [30]:
SEQ_LEN = 2        # use 7 days of history to predict next day
EPOCHS  = 2
BATCH_T = 2        # process 4 windows at a time (tune for your GPU RAM)

edge_index_dev = edge_index.to(device)

def make_windows(split_snapshots, seq_len):
    """Yield (input_seq [T,N,F], target [N]) windows"""
    for i in range(len(split_snapshots) - seq_len):
        xs = torch.stack([s.x for s in split_snapshots[i:i+seq_len]])   # [T,N,F]
        y  = split_snapshots[i + seq_len].y                              # [N]
        yield xs, y

best_val_loss = float('inf')

for epoch in range(EPOCHS):
    # --- Train ---
    model.train()
    train_loss, train_steps = 0.0, 0
    
    batch_xs, batch_ys = [], []
    for xs, y in make_windows(train_snapshots, SEQ_LEN):
        batch_xs.append(xs)
        batch_ys.append(y)
        
        if len(batch_xs) == BATCH_T:
            optimizer.zero_grad()
            loss = torch.tensor(0.0, device=device)
            
            for bx, by in zip(batch_xs, batch_ys):
                bx = bx.to(device)
                by = by.to(device)
                logits = model(bx, edge_index_dev)
                loss += criterion(logits, by)
            
            (loss / BATCH_T).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            train_loss += loss.item()
            train_steps += BATCH_T
            batch_xs, batch_ys = [], []
    
    # --- Validate ---
    model.eval()
    val_loss, val_steps = 0.0, 0
    
    with torch.no_grad():
        for xs, y in make_windows(val_snapshots, SEQ_LEN):
            xs, y = xs.to(device), y.to(device)
            logits = model(xs, edge_index_dev)
            val_loss += criterion(logits, y).item()
            val_steps += 1
    
    avg_train = train_loss / max(train_steps, 1)
    avg_val   = val_loss   / max(val_steps, 1)
    scheduler.step(avg_val)
    
    if avg_val < best_val_loss:
        best_val_loss = avg_val
        torch.save(model.state_dict(), 'best_tgnn.pt')
        tag = " ← best"
    else:
        tag = ""
    
    print(f"Epoch {epoch+1:02d}/{EPOCHS} | Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}{tag}")

KeyboardInterrupt: 

In [None]:
from sklearn.metrics import f1_score, roc_auc_score, precision_score, recall_score, classification_report

model.load_state_dict(torch.load('best_tgnn.pt'))
model.eval()

all_preds, all_labels = [], []

with torch.no_grad():
    for xs, y in make_windows(test_snapshots, SEQ_LEN):
        xs = xs.to(device)
        logits = model(xs, edge_index_dev)
        probs  = torch.sigmoid(logits).cpu().numpy()
        all_preds.append(probs)
        all_labels.append(y.numpy())

all_preds  = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

# Tune threshold for imbalanced data (default 0.5 is usually too high)
threshold = 0.3
binary_preds = (all_preds >= threshold).astype(int)

print(classification_report(all_labels, binary_preds, target_names=['No Fire', 'Wildfire']))
print(f"ROC-AUC : {roc_auc_score(all_labels, all_preds):.4f}")
print(f"F1      : {f1_score(all_labels, binary_preds):.4f}")
print(f"Recall  : {recall_score(all_labels, binary_preds):.4f}")
print(f"Precision:{precision_score(all_labels, binary_preds):.4f}")