# SAINT Transformer Architecture for wildfire prediction

## Environment Setup

In this section we install the dependencies, import the required libraries, and load the dataset used to train the SAINT transformer.

In [29]:
# Install dependencies
# "numpy<2" must be quoted; unquoted, the shell reads `<2` as input redirection
%pip install --upgrade torch pytorch_lightning scikit-learn pandas "numpy<2" matplotlib

Note: you may need to restart the kernel to use updated packages.


In [30]:
import json
import torch
import torch.nn as nn
import pytorch_lightning as pl
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder

In [31]:
import pandas as pd

# Load the CSV file into a DataFrame
df = pd.read_csv('dataset.csv')

# Display the first few rows
print(df.head())

  system:index  EVH   EVT    NDVI  PRES_max  SPFH_max    TMP_max  WIND_max  \
0        0_0_0   75  3092  5691.0   97165.0  0.006520  17.850000      4.88   
1        0_1_0   89  3900  6176.0   98114.0  0.006578  21.589990      4.59   
2        0_2_0   96  3097  6563.0   97881.0  0.007021  19.989984      3.85   
3        0_3_0   21  3296  4034.0   98933.0  0.006326  22.369989      4.08   
4        0_4_0   99  3097  8027.0   97243.0  0.007369  16.290002      5.66   

   burned        date  elevation  sm_profile  sm_profile_wetness  sm_rootzone  \
0       0  2023-04-01  348.75693    0.365752            0.797881     0.389533   
1       0  2023-04-01  363.53400    0.356276            0.777190     0.380164   
2       0  2023-04-01  409.15347    0.361830            0.789306     0.381324   
3       0  2023-04-01  263.18298    0.358382            0.781829     0.382392   
4       0  2023-04-01  323.89856    0.367979            0.802734     0.387794   

   sm_rootzone_wetness  sm_surface  sm_surfa

In [32]:
import pandas as pd
import json

# Handle missing values
df.ffill(inplace=True)

# Clean up column names (remove leading/trailing spaces)
df.columns = df.columns.str.strip()

# Categorical feature columns (label-encoded later, before training)
cat_cols = ['EVT']

# Numerical feature columns (scaled later, before training)
num_cols = [
    'EVH', 'NDVI', 'PRES_max', 'SPFH_max', 'TMP_max', 'WIND_max',
    'elevation', 'sm_profile', 'sm_profile_wetness', 'sm_rootzone',
    'sm_rootzone_wetness', 'sm_surface', 'sm_surface_wetness'
]

# Check and process 'date' column
if 'date' in df.columns:
    df['date'] = pd.to_datetime(df['date'])
    df['year'] = df['date'].dt.year
    df['month'] = df['date'].dt.month
    df['day'] = df['date'].dt.day
    num_cols += ['year', 'month', 'day']
    df.drop(columns=['date'], inplace=True)

# Ensure 'burned' is integer
df['burned'] = df['burned'].astype(int)

# Extract longitude and latitude from '.geo' JSON string
def extract_coords(geo_str):
    geo = json.loads(geo_str)
    return pd.Series(geo['coordinates'], index=['longitude', 'latitude'])

if '.geo' in df.columns:
    df[['longitude', 'latitude']] = df['.geo'].apply(extract_coords)
    num_cols += ['longitude', 'latitude']
    df.drop(columns=['.geo'], inplace=True)

target = 'burned'

In [5]:
df.head()

Unnamed: 0,system:index,EVH,EVT,NDVI,PRES_max,SPFH_max,TMP_max,WIND_max,burned,elevation,...,sm_profile_wetness,sm_rootzone,sm_rootzone_wetness,sm_surface,sm_surface_wetness,year,month,day,longitude,latitude
0,0_0_0,75,3092,5691.0,97165.0,0.00652,17.85,4.88,0,348.75693,...,0.797881,0.389533,0.849741,0.349155,0.761735,2023,4,1,-118.815671,34.230304
1,0_1_0,89,3900,6176.0,98114.0,0.006578,21.58999,4.59,0,363.534,...,0.77719,0.380164,0.829294,0.351032,0.765764,2023,4,1,-118.546176,34.140472
2,0_2_0,96,3097,6563.0,97881.0,0.007021,19.989984,3.85,0,409.15347,...,0.789306,0.381324,0.831857,0.352863,0.769822,2023,4,1,-118.582109,34.122506
3,0_3_0,21,3296,4034.0,98933.0,0.006326,22.369989,4.08,0,263.18298,...,0.781829,0.382392,0.834207,0.354602,0.773576,2023,4,1,-118.627025,34.176405
4,0_4_0,99,3097,8027.0,97243.0,0.007369,16.290002,5.66,0,323.89856,...,0.802734,0.387794,0.845987,0.356098,0.776841,2023,4,1,-118.707873,34.059624


## 🔧 Defining the SAINT Model Architecture

We build a **Transformer model** that works on tabular data:

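- **Categorical features** are turned into embeddings (like word vectors), one token per categorical column.
- **Numerical features** are projected into the same embedding space via a linear layer. In the implementation below, a single `nn.Linear(num_cont, embed_dim)` maps all numerical features jointly to **one** token, whereas the original SAINT design embeds each numerical feature as its own token.
- The resulting **tokens** go into a Transformer Encoder.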
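
For comparison, here is a minimal sketch of per-feature tokenization, in which every numerical column gets its own token. `PerFeatureTokenizer` is a hypothetical helper that is not part of this notebook, and the initialization scale is an arbitrary assumption.

```python
import torch
import torch.nn as nn

class PerFeatureTokenizer(nn.Module):
    """Hypothetical helper: maps each numerical feature to its own token."""

    def __init__(self, num_cont, embed_dim):
        super().__init__()
        # One weight vector and one bias vector per numerical feature
        self.weight = nn.Parameter(torch.randn(num_cont, embed_dim) * 0.02)
        self.bias = nn.Parameter(torch.zeros(num_cont, embed_dim))

    def forward(self, x_cont):
        # x_cont: (batch, num_cont) -> tokens: (batch, num_cont, embed_dim)
        return x_cont.unsqueeze(-1) * self.weight + self.bias

tokenizer = PerFeatureTokenizer(num_cont=13, embed_dim=32)
tokens = tokenizer(torch.randn(64, 13))
print(tokens.shape)  # torch.Size([64, 13, 32])
```

With this layout the encoder would see 13 numerical tokens plus the categorical ones, so self-attention can relate individual features rather than one pooled vector.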

### Inside the Transformer:

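- **Self-Attention (column-wise)**: learns relationships between feature tokens. This is what `nn.TransformerEncoder` provides.
- **Intersample Attention (row-wise)**: each row can attend to other rows in the batch. This is the component **unique to SAINT**, but the class below does not implement it; it uses only standard self-attention layers.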
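
A rough sketch of what an intersample attention layer could look like is shown here. `IntersampleAttention` is a hypothetical module written for illustration, not code from this notebook; it flattens each row's tokens into one vector and lets the rows of a batch attend to each other.

```python
import torch
import torch.nn as nn

class IntersampleAttention(nn.Module):
    """Hypothetical sketch: attention across the rows of a batch."""

    def __init__(self, num_tokens, embed_dim, num_heads=4):
        super().__init__()
        row_dim = num_tokens * embed_dim  # must be divisible by num_heads
        self.attn = nn.MultiheadAttention(row_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(row_dim)

    def forward(self, x):
        # x: (batch, tokens, dim)
        b, t, d = x.shape
        # Treat the whole batch as one sequence whose elements are rows
        rows = x.reshape(1, b, t * d)
        attended, _ = self.attn(rows, rows, rows)
        rows = self.norm(rows + attended)  # residual connection + layer norm
        return rows.reshape(b, t, d)

x = torch.randn(8, 2, 32)  # 8 rows, 2 tokens (1 categorical + 1 numerical), dim 32
layer = IntersampleAttention(num_tokens=2, embed_dim=32, num_heads=4)
out = layer(x)
print(out.shape)  # torch.Size([8, 2, 32])
```

Because the output of each row depends on the other rows in the batch, predictions from such a layer vary with batch composition, which is worth keeping in mind at evaluation time.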

### Output Layer

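The transformer output is mean-pooled over the tokens and passed through a simple classifier:

- `Linear → ReLU → Dropout → Linear`, producing a single logit. The sigmoid is applied inside `BCEWithLogitsLoss` during training and explicitly via `torch.sigmoid` at evaluation time.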

### Class Imbalance Handling

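The model accepts an optional `pos_weight` for `BCEWithLogitsLoss` to account for the fact that wildfire events are **rare** in the dataset. Note that the training cells below construct `SAINT` without `pos_weight`, so those runs use the unweighted loss.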
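
One common way to choose `pos_weight` is the ratio of negatives to positives. The sketch below assumes that heuristic; the label tensor is illustrative only and reuses the validation class counts reported in the Conclusions (in practice the ratio would be computed from the training labels).

```python
import torch
import torch.nn as nn

# Illustrative labels: 503 positives among 146,200 rows
y = torch.cat([torch.zeros(145_697), torch.ones(503)])

n_pos = y.sum()
n_neg = y.numel() - n_pos
pos_weight = (n_neg / n_pos).reshape(1)  # shape (1,), broadcasts over the batch

loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Tiny example batch: the single positive contributes pos_weight times more
logits = torch.zeros(4)
targets = torch.tensor([0.0, 0.0, 0.0, 1.0])
weighted_loss = loss_fn(logits, targets)
print(pos_weight.item(), weighted_loss.item())
```

The resulting tensor can be passed as `SAINT(..., pos_weight=pos_weight)`. A weight this large tends to raise recall at the cost of precision, so a smaller value may be a reasonable compromise.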

In [33]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim

class SAINT(pl.LightningModule):
    def __init__(self, num_cont, cat_dims=[], embed_dim=32, 
                 num_heads=4, num_layers=3, dropout=0.1, pos_weight=None, lr=1e-3):
        super().__init__()
        self.save_hyperparameters(ignore=['pos_weight'])
        self.embeddings = nn.ModuleList([
            nn.Embedding(dim, embed_dim) for dim in cat_dims
        ])
        self.cont_proj = nn.Linear(num_cont, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dropout=dropout,
            batch_first=True, dim_feedforward=embed_dim*4
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 1)
        )
        if pos_weight is not None:
            self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        else:
            self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x_cat, x_cont):
        tokens = []
        for i, emb in enumerate(self.embeddings):
            tokens.append(emb(x_cat[:, i]))
        tokens.append(self.cont_proj(x_cont))
        tokens = torch.stack(tokens, dim=1)
        attn_output = self.transformer(tokens)
        pooled = attn_output.mean(dim=1)  # mean-pool over tokens (no CLS token is used)
        return self.classifier(pooled).squeeze(-1)

    def training_step(self, batch, batch_idx):
        x_cat, x_cont, y = batch
        logits = self(x_cat, x_cont)
        loss = self.loss_fn(logits, y.float())
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x_cat, x_cont, y = batch
        logits = self(x_cat, x_cont)
        loss = self.loss_fn(logits, y.float())
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.hparams.lr)

In [34]:
import pytorch_lightning as pl
import torch
from torch.utils.data import TensorDataset, DataLoader

class FireDataModule(pl.LightningDataModule):
    def __init__(self, X_train, y_train, X_test, y_test, cat_cols, batch_size=64):
        super().__init__()
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.cat_cols = list(cat_cols)
        self.batch_size = batch_size
        self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    def setup(self, stage=None):
        # Defensive: ensure all column names are strings
        self.cat_cols = [str(col) for col in list(self.cat_cols)]
        self.X_train.columns = [str(col) for col in self.X_train.columns]
        self.X_test.columns = [str(col) for col in self.X_test.columns]
        cat_cols = self.cat_cols
        cont_cols = [col for col in self.X_train.columns if col not in cat_cols]

        print("DEBUG cat_cols:", cat_cols)
        print("DEBUG cont_cols:", cont_cols)
        print("DEBUG X_train columns:", self.X_train.columns)
        print("DEBUG X_train[cat_cols] dtype:", self.X_train[cat_cols].values.dtype)

        X_train_cat = torch.tensor(self.X_train[cat_cols].values, dtype=torch.long, device=self.device)
        X_train_cont = torch.tensor(self.X_train[cont_cols].values, dtype=torch.float32, device=self.device)
        y_train = torch.tensor(self.y_train.values, dtype=torch.float32, device=self.device)
        X_test_cat = torch.tensor(self.X_test[cat_cols].values, dtype=torch.long, device=self.device)
        X_test_cont = torch.tensor(self.X_test[cont_cols].values, dtype=torch.float32, device=self.device)
        y_test = torch.tensor(self.y_test.values, dtype=torch.float32, device=self.device)

        self.train_dataset = TensorDataset(X_train_cat, X_train_cont, y_train)
        self.val_dataset = TensorDataset(X_test_cat, X_test_cont, y_test)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

## Model Training with all features

We use the `SAINT` model and the `FireDataModule` defined above to train on all features.

In [8]:
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split

cat_cols = list(cat_cols)
num_cols = list(num_cols)

# Encode categorical columns
for col in cat_cols:
    le = LabelEncoder()
    df[col] = le.fit_transform(df[col].astype(str))
    df[col] = df[col].astype(int)  # Ensure integer type for embeddings

# Compute cat_dims after encoding
cat_dims = [df[col].nunique() for col in cat_cols]

# Scale numerical features
scaler = StandardScaler()
df[num_cols] = scaler.fit_transform(df[num_cols])

X = df[cat_cols + num_cols]
y = df['burned']

# Split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Ensure categorical columns are first and in the same order as cat_cols
all_cols = cat_cols + [col for col in X_train.columns if col not in cat_cols]
# .copy() avoids pandas SettingWithCopyWarning when casting dtypes below
X_train = X_train[all_cols].copy()
X_test = X_test[all_cols].copy()

# Double-check for out-of-range values in categorical columns
for i, col in enumerate(cat_cols):
    max_val = X_train[col].max()
    assert max_val < cat_dims[i], f"Column {col} has value {max_val} >= {cat_dims[i]}"

# Convert ONLY numerical columns to float32 for PyTorch compatibility
for col in num_cols:
    X_train[col] = X_train[col].astype(np.float32)
    X_test[col] = X_test[col].astype(np.float32)
y_train = y_train.astype(np.float32)
y_test = y_test.astype(np.float32)
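
The cell above fits `StandardScaler` on the full DataFrame before splitting, so the test rows influence the scaling statistics. A leakage-free variant fits the scaler on the training rows only. The sketch below uses a small synthetic table with made-up values, and the names `demo`, `demo_num_cols`, and `demo_scaler` are placeholders rather than notebook variables.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Synthetic stand-in for the wildfire table (values are made up)
demo = pd.DataFrame({
    "NDVI": rng.normal(6000, 1000, 1000),
    "TMP_max": rng.normal(20, 3, 1000),
    "burned": (rng.random(1000) < 0.05).astype(int),
})
demo_num_cols = ["NDVI", "TMP_max"]

# Split first, then scale
train_df, test_df = train_test_split(
    demo, test_size=0.2, stratify=demo["burned"], random_state=42
)
train_df, test_df = train_df.copy(), test_df.copy()

# Fit on training rows only, then apply the same transform to both splits
demo_scaler = StandardScaler().fit(train_df[demo_num_cols])
train_df[demo_num_cols] = demo_scaler.transform(train_df[demo_num_cols])
test_df[demo_num_cols] = demo_scaler.transform(test_df[demo_num_cols])

print(train_df[demo_num_cols].mean().abs().max())  # close to zero
print(test_df[demo_num_cols].mean().tolist())      # near zero, but not exactly
```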

In [None]:
import torch
import pytorch_lightning as pl

# Initialize model
model = SAINT(
    num_cont=len([col for col in X_train.columns if col not in cat_cols]),
    cat_dims=cat_dims,
    embed_dim=32,
    num_heads=4,
    num_layers=3
)

# Initialize DataModule
dm = FireDataModule(X_train, y_train, X_test, y_test, cat_cols=cat_cols, batch_size=64)

# Train with PyTorch Lightning; accelerator='auto' picks the best available
# backend (e.g. MPS on Apple Silicon, otherwise CPU)
trainer = pl.Trainer(
    max_epochs=50,
    accelerator='auto',
    devices=1,
    callbacks=[
        pl.callbacks.EarlyStopping(monitor='val_loss', patience=5),
        pl.callbacks.ModelCheckpoint(monitor='val_loss', mode='min')
    ]
)

trainer.fit(model, dm)

In [39]:
# Re-create the model and DataModule (no training in this cell)
import torch
import pytorch_lightning as pl


# Initialize model
model = SAINT(
    num_cont=len([col for col in X_train.columns if col not in cat_cols]),
    cat_dims=cat_dims,
    embed_dim=32,
    num_heads=4,
    num_layers=3
)

# Initialize DataModule
dm = FireDataModule(X_train, y_train, X_test, y_test, cat_cols=cat_cols, batch_size=64)


In [40]:
# Re-instantiate and load the model
model = SAINT.load_from_checkpoint(
    '/Users/andrguardia/Documents/GitHub/wildfire/lightning_logs/version_14/checkpoints/epoch=20-step=191898.ckpt',
    num_cont=13,
    cat_dims=[39],
    embed_dim=32,
    num_heads=4,
    num_layers=3,
    dropout=0.1,
    lr=0.001
)

# Set model to evaluation mode
model.eval()

SAINT(
  (embeddings): ModuleList(
    (0): Embedding(39, 32)
  )
  (cont_proj): Linear(in_features=13, out_features=32, bias=True)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
        )
        (linear1): Linear(in_features=32, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=32, bias=True)
        (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_fea

In [42]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
import numpy as np
import torch

# Make sure your DataModule is set up
dm.setup()
val_loader = dm.val_dataloader()

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for x_cat, x_cont, y_true in val_loader:
        x_cat = x_cat.to(device)
        x_cont = x_cont.to(device)
        y_true = y_true.to(device)
        logits = model(x_cat, x_cont)
        probs = torch.sigmoid(logits)
        preds = (probs > 0.3).long()
        all_preds.extend(preds.cpu().detach().tolist())
        all_labels.extend(y_true.cpu().detach().tolist())

all_preds = np.array(all_preds).flatten()
all_labels = np.array(all_labels).flatten()

# Now you can compute metrics as before
acc = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
cm = confusion_matrix(all_labels, all_preds)

print(f"Accuracy:  {acc:.4f}")
print(f"F1 Score:  {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print("Confusion Matrix:\n", cm)

DEBUG cat_cols: ['EVT']
DEBUG cont_cols: ['EVH', 'NDVI', 'PRES_max', 'SPFH_max', 'TMP_max', 'WIND_max', 'elevation', 'sm_profile', 'sm_profile_wetness', 'sm_rootzone', 'sm_rootzone_wetness', 'sm_surface', 'sm_surface_wetness', 'year', 'month', 'day', 'longitude', 'latitude']
DEBUG X_train columns: Index(['EVT', 'EVH', 'NDVI', 'PRES_max', 'SPFH_max', 'TMP_max', 'WIND_max',
       'elevation', 'sm_profile', 'sm_profile_wetness', 'sm_rootzone',
       'sm_rootzone_wetness', 'sm_surface', 'sm_surface_wetness', 'year',
       'month', 'day', 'longitude', 'latitude'],
      dtype='object')
DEBUG X_train[cat_cols] dtype: int64


RuntimeError: linear(): input and weight.T shapes cannot be multiplied (64x18 and 13x32)
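
The error above is a shape mismatch: the checkpoint was loaded with `num_cont=13`, so `cont_proj` expects 13 inputs, while the DataModule now supplies 18 continuous columns (the DEBUG output lists `year`, `month`, `day`, `longitude`, and `latitude` in addition to the original 13). A small guard such as the following sketch can surface the problem with a clearer message before the forward pass. `check_feature_count` is a hypothetical helper, and the `nn.Linear` here stands in for `model.cont_proj`.

```python
import torch.nn as nn

def check_feature_count(cont_proj: nn.Linear, cont_cols: list) -> None:
    """Raise a readable error if the column count differs from the layer's input size."""
    expected = cont_proj.in_features
    if len(cont_cols) != expected:
        raise ValueError(
            f"Model expects {expected} continuous features, "
            f"got {len(cont_cols)}: {cont_cols}"
        )

# Stand-in for model.cont_proj from a checkpoint loaded with num_cont=13
cont_proj = nn.Linear(13, 32)
base_cols = [
    'EVH', 'NDVI', 'PRES_max', 'SPFH_max', 'TMP_max', 'WIND_max',
    'elevation', 'sm_profile', 'sm_profile_wetness', 'sm_rootzone',
    'sm_rootzone_wetness', 'sm_surface', 'sm_surface_wetness',
]
check_feature_count(cont_proj, base_cols)  # passes silently

try:
    check_feature_count(
        cont_proj, base_cols + ['year', 'month', 'day', 'longitude', 'latitude']
    )
except ValueError as err:
    message = str(err)
    print(message)
```

To evaluate this checkpoint, `X_train` and `X_test` would need to be rebuilt with the same columns that were used when the checkpoint was trained.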

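### Improving the F1 Score

To decide which features to keep, we estimate feature importance with a proxy model: a random forest trained on the same features and scored with permutation importance. These scores describe the random forest, not the SAINT model itself.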

In [29]:
from sklearn.inspection import permutation_importance
from sklearn.ensemble import RandomForestClassifier

# Train quick RF model to assess feature importance
rf = RandomForestClassifier(n_jobs=-1)
rf.fit(X_train, y_train)
result = permutation_importance(rf, X_test, y_test, n_repeats=5, random_state=42)

# Rank features
importance_df = pd.DataFrame({
    'feature': X_train.columns,
    'importance': result.importances_mean
}).sort_values(by='importance', ascending=False)

print(importance_df)

                feature  importance
2                  NDVI    0.003059
4              SPFH_max    0.001431
9    sm_profile_wetness    0.001384
1                   EVH    0.001271
10          sm_rootzone    0.000810
5               TMP_max    0.000720
8            sm_profile    0.000679
11  sm_rootzone_wetness    0.000573
0                   EVT    0.000561
7             elevation    0.000442
6              WIND_max    0.000272
12           sm_surface    0.000269
13   sm_surface_wetness    0.000249
3              PRES_max    0.000220
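
When `scoring` is not given, `permutation_importance` falls back to the estimator's default scorer, which is accuracy for classifiers. With very rare positives, accuracy barely changes when a feature is shuffled, which may be why every importance above is so small. The sketch below compares accuracy with `scoring="average_precision"` on synthetic imbalanced data; the dataset and all `*_demo` names are assumptions for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic imbalanced data (about 5% positives); not the wildfire dataset
X_demo, y_demo = make_classification(
    n_samples=2000, n_features=6, n_informative=3, n_redundant=0,
    weights=[0.95, 0.05], random_state=0,
)
Xtr, Xte, ytr, yte = train_test_split(
    X_demo, y_demo, test_size=0.25, stratify=y_demo, random_state=0
)

rf_demo = RandomForestClassifier(n_estimators=50, random_state=0).fit(Xtr, ytr)

# Default scorer (accuracy) versus an imbalance-aware scorer
imp_acc = permutation_importance(rf_demo, Xte, yte, n_repeats=5, random_state=0)
imp_ap = permutation_importance(
    rf_demo, Xte, yte, n_repeats=5, random_state=0, scoring="average_precision"
)

for i in range(X_demo.shape[1]):
    print(
        f"feature {i}: accuracy={imp_acc.importances_mean[i]:.4f}  "
        f"average_precision={imp_ap.importances_mean[i]:.4f}"
    )
```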


#### Feature Importance Analysis & Pruning Strategy

🧠 Observations from Permutation Importance

| Rank | Feature | Importance | Notes |
|---|---|---|---|
| 1 | NDVI | 0.00306 | ✅ Strong vegetation signal: keep |
| 2 | SPFH_max | 0.00143 | ✅ Specific humidity: keep |
| 3 | sm_profile_wetness | 0.00138 | ✅ Soil moisture: keep |
| 4 | EVH | 0.00127 | ✅ Likely vegetation height: keep |
| 5–8 | sm_rootzone, TMP_max, sm_profile, sm_rootzone_wetness | 0.0006–0.0008 | ⚠️ The sm_* features are largely redundant: keep only the top 1–2 |
| 9 | EVT (categorical) | 0.00056 | ✅ Low score but keep: embeddings may help |
| 10–14 | elevation, WIND_max, sm_surface, sm_surface_wetness, PRES_max | < 0.0005 | ❌ Very low signal: drop candidates |


⸻

✅ Recommended Feature Selection

Keep:

selected_features = [
    'NDVI', 'SPFH_max', 'sm_profile_wetness', 'EVH',  # top 4
    'sm_rootzone',                                   # optional (moderate importance)
    'EVT'                                            # categorical
]

Drop (low importance):

drop_features = [
    'PRES_max', 'WIND_max', 'elevation',
    'sm_surface', 'sm_surface_wetness',
    'sm_rootzone_wetness', 'sm_profile',
    # optionally: 'TMP_max' (the pruned training cell below keeps it)
]


⸻

🧪 Next Steps

- Retrain SAINT with the pruned feature set.
- Evaluate changes in F1 score, recall, and precision.
- Iterate: remove 1–2 additional features if no gain is observed.

If F1 or recall improves on the validation set, that suggests the dropped features were adding noise rather than signal.


## Model Training with pruned features

In [30]:
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split

# === Feature Selection ===

# Categorical features (keep EVT)
cat_cols = ['EVT']

# Numerical features (keep only the most important, by permutation importance)
num_cols = [
    'NDVI',                # rank 1
    'SPFH_max',            # rank 2
    'sm_profile_wetness',  # rank 3
    'EVH',                 # rank 4
    'sm_rootzone',         # rank 5
    'TMP_max',             # rank 6
]


# === Encode categorical columns ===
for col in cat_cols:
    le = LabelEncoder()
    df[col] = le.fit_transform(df[col].astype(str))
    df[col] = df[col].astype(int)  # Ensure integer type for embeddings

# Compute cat_dims after encoding
cat_dims = [df[col].nunique() for col in cat_cols]

# === Scale numerical features ===
scaler = StandardScaler()
df[num_cols] = scaler.fit_transform(df[num_cols])

# === Prepare Data ===
X = df[cat_cols + num_cols]
y = df['burned']

# === Train/Test Split ===
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Ensure categorical columns are first and in the same order as cat_cols
all_cols = cat_cols + [col for col in X_train.columns if col not in cat_cols]
# .copy() avoids pandas SettingWithCopyWarning when casting dtypes below
X_train = X_train[all_cols].copy()
X_test = X_test[all_cols].copy()

# Double-check for out-of-range values in categorical columns
for i, col in enumerate(cat_cols):
    max_val = X_train[col].max()
    assert max_val < cat_dims[i], f"Column {col} has value {max_val} >= {cat_dims[i]}"

# Convert ONLY numerical columns to float32 for PyTorch compatibility
for col in num_cols:
    X_train[col] = X_train[col].astype(np.float32)
    X_test[col] = X_test[col].astype(np.float32)
y_train = y_train.astype(np.float32)
y_test = y_test.astype(np.float32)

In [32]:
import torch
import pytorch_lightning as pl

# Initialize model
model = SAINT(
    num_cont=len([col for col in X_train.columns if col not in cat_cols]),
    cat_dims=cat_dims,
    embed_dim=32,
    num_heads=4,
    num_layers=3
)

# Initialize DataModule
dm = FireDataModule(X_train, y_train, X_test, y_test, cat_cols=cat_cols, batch_size=64)

# Train with PyTorch Lightning; accelerator='auto' picks the best available
# backend (this run fell back to CPU, as the log below shows)
trainer = pl.Trainer(
    max_epochs=50,
    accelerator='auto',
    devices=1,
    callbacks=[
        pl.callbacks.EarlyStopping(monitor='val_loss', patience=5),
        pl.callbacks.ModelCheckpoint(monitor='val_loss', mode='min')
    ]
)

trainer.fit(model, dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


DEBUG cat_cols: ['EVT']
DEBUG cont_cols: ['NDVI', 'SPFH_max', 'sm_profile_wetness', 'EVH', 'sm_rootzone', 'TMP_max']
DEBUG X_train columns: Index(['EVT', 'NDVI', 'SPFH_max', 'sm_profile_wetness', 'EVH', 'sm_rootzone',
       'TMP_max'],
      dtype='object')
DEBUG X_train[cat_cols] dtype: int64



  | Name        | Type               | Params | Mode 
-----------------------------------------------------------
0 | embeddings  | ModuleList         | 1.2 K  | train
1 | cont_proj   | Linear             | 224    | train
2 | transformer | TransformerEncoder | 38.1 K | train
3 | classifier  | Sequential         | 1.1 K  | train
4 | loss_fn     | BCEWithLogitsLoss  | 0      | train
-----------------------------------------------------------
40.7 K    Trainable params
0         Non-trainable params
40.7 K    Total params
0.163     Total estimated model params size (MB)
41        Modules in train mode
0         Modules in eval mode



/usr/local/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/usr/local/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.



In [26]:
# Re-instantiate and load the model
model = SAINT.load_from_checkpoint(
    '/Users/andrguardia/Documents/GitHub/wildfire/lightning_logs/version_16/checkpoints/epoch=17-step=164484.ckpt'
)

# Set model to evaluation mode
model.eval()

SAINT(
  (embeddings): ModuleList(
    (0): Embedding(39, 32)
  )
  (cont_proj): Linear(in_features=6, out_features=32, bias=True)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
        )
        (linear1): Linear(in_features=32, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=32, bias=True)
        (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_feat

In [27]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
import numpy as np
import torch

# Make sure your DataModule is set up
dm.setup()
val_loader = dm.val_dataloader()

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for x_cat, x_cont, y_true in val_loader:
        x_cat = x_cat.to(device)
        x_cont = x_cont.to(device)
        y_true = y_true.to(device)
        logits = model(x_cat, x_cont)
        probs = torch.sigmoid(logits)
        preds = (probs > 0.3).long()
        all_preds.extend(preds.cpu().detach().tolist())
        all_labels.extend(y_true.cpu().detach().tolist())

all_preds = np.array(all_preds).flatten()
all_labels = np.array(all_labels).flatten()

# Now you can compute metrics as before
acc = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
cm = confusion_matrix(all_labels, all_preds)

print(f"Accuracy:  {acc:.4f}")
print(f"F1 Score:  {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print("Confusion Matrix:\n", cm)

DEBUG cat_cols: ['EVT']
DEBUG cont_cols: ['EVH', 'NDVI', 'PRES_max', 'SPFH_max', 'TMP_max', 'WIND_max', 'elevation', 'sm_profile', 'sm_profile_wetness', 'sm_rootzone', 'sm_rootzone_wetness', 'sm_surface', 'sm_surface_wetness', 'year', 'month', 'day', 'longitude', 'latitude']
DEBUG X_train columns: Index(['EVT', 'EVH', 'NDVI', 'PRES_max', 'SPFH_max', 'TMP_max', 'WIND_max',
       'elevation', 'sm_profile', 'sm_profile_wetness', 'sm_rootzone',
       'sm_rootzone_wetness', 'sm_surface', 'sm_surface_wetness', 'year',
       'month', 'day', 'longitude', 'latitude'],
      dtype='object')
DEBUG X_train[cat_cols] dtype: int64


RuntimeError: linear(): input and weight.T shapes cannot be multiplied (64x18 and 6x32)

## Conclusions

### Model Evaluation Analysis

**1. Accuracy: 0.9974**

The model correctly classified 99.74% of the validation samples. While this appears excellent, accuracy can be misleading on imbalanced datasets, where one class is much more frequent than the other.

**2. F1 Score: 0.5698**

The F1 score is the harmonic mean of precision and recall. A value of 0.57 indicates a moderate balance between the two, but it is much lower than the accuracy, which points to class imbalance or difficulty in detecting the minority class.

**3. Precision: 0.6571**

Of all the samples the model predicted as positive (class 1), 65.71% were actually positive. About one-third of positive predictions are therefore false positives.

**4. Recall: 0.5030**

The model identified only 50.30% of the actual positive cases. It misses nearly half of the true positives, i.e. the false negative rate is high.

**5. Confusion Matrix**

    [[145565    132]
     [   250    253]]

Rows are true classes ([0, 1]); columns are predicted classes ([0, 1]).

- True Negatives (TN): 145,565 (correctly predicted class 0)
- False Positives (FP): 132 (predicted 1, actually 0)
- False Negatives (FN): 250 (predicted 0, actually 1)
- True Positives (TP): 253 (correctly predicted class 1)
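
As a consistency check, the reported metrics can be recomputed directly from the four confusion-matrix counts using the standard definitions:

```python
# Counts taken from the confusion matrix above
tn, fp, fn, tp = 145_565, 132, 250, 253

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

print(f"Accuracy:  {accuracy:.4f}")   # 0.9974
print(f"Precision: {precision:.4f}")  # 0.6571
print(f"Recall:    {recall:.4f}")     # 0.5030
print(f"F1 Score:  {f1:.4f}")         # 0.5698
```

A classifier that always predicted class 0 would still reach 145,697 / 146,200 ≈ 0.9966 accuracy on this split, which shows how little the accuracy figure says about fire detection.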
**6. Class Imbalance**

The validation set is highly imbalanced:

- Majority class (0): 145,697 samples
- Minority class (1): 503 samples

This imbalance explains the high accuracy (dominated by correct negatives) and the much lower F1, precision, and recall.

**7. Threshold Setting**

The evaluation cells use a sigmoid threshold of 0.3 rather than the default 0.5. Lowering the threshold increases recall but may reduce precision. Recall is still low here, so even lower thresholds or other balancing techniques are worth trying.
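
Rather than trying thresholds by hand, the whole precision-recall curve can be scanned for the threshold with the highest F1. The sketch below uses synthetic scores in place of `torch.sigmoid(logits)`, so the numbers it prints say nothing about the wildfire model. Ideally the threshold would be chosen on a validation split separate from the final test data.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

rng = np.random.default_rng(0)
# Synthetic labels (about 1% positives) and scores; not real model output
labels = (rng.random(20_000) < 0.01).astype(int)
scores = np.clip(rng.normal(0.1 + 0.4 * labels, 0.15), 0.0, 1.0)

prec, rec, thresholds = precision_recall_curve(labels, scores)

# prec and rec have one more entry than thresholds; drop the final (1, 0) point
denom = np.clip(prec[:-1] + rec[:-1], 1e-12, None)
f1_scores = 2 * prec[:-1] * rec[:-1] / denom

best = int(np.argmax(f1_scores))
best_threshold = float(thresholds[best])
print(
    f"best threshold={best_threshold:.3f}  precision={prec[best]:.3f}  "
    f"recall={rec[best]:.3f}  f1={f1_scores[best]:.3f}"
)
```

Note that `precision_recall_curve` treats a sample as positive when its score is greater than or equal to the threshold, whereas the evaluation cells use a strict `>` comparison.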
**8. Recommendations**

- **Metric focus:** Use F1, precision, and recall as the main metrics (not accuracy) because of the class imbalance.
- **Threshold tuning:** Experiment with different thresholds to find a better balance between precision and recall.
- **Class imbalance handling:** Try oversampling/undersampling, class weighting, or synthetic data (e.g., SMOTE) to improve minority class detection.
- **Model improvements:** Consider more complex models, feature engineering, or ensemble methods if appropriate.
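
As one possible form of oversampling, PyTorch's `WeightedRandomSampler` can draw minority-class rows more often so that batches are roughly balanced. This is a minimal sketch on synthetic tensors; wiring it into `FireDataModule.train_dataloader` (in place of `shuffle=True`) is an assumption about how it could be used here.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Synthetic data: 10 positives among 1,000 rows
y = torch.cat([torch.zeros(990), torch.ones(10)])
x = torch.randn(1000, 4)

# Weight each row by the inverse frequency of its class
class_counts = torch.bincount(y.long(), minlength=2).float()
sample_weights = (1.0 / class_counts)[y.long()]

sampler = WeightedRandomSampler(
    sample_weights, num_samples=len(y), replacement=True
)
# A sampler replaces shuffle=True; the two cannot be combined
loader = DataLoader(TensorDataset(x, y), batch_size=100, sampler=sampler)

pos_fraction = torch.cat([yb for _, yb in loader]).mean().item()
print(f"Positive fraction seen in one epoch: {pos_fraction:.2f}")
```

Oversampling and `pos_weight` address the same imbalance, so applying both at full strength would likely over-correct; the sampler should be used on the training loader only.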