In [None]:
import logging

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, classification_report

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision.models import (
    # ======================= ResNet ======================= #
    resnet18, ResNet18_Weights,
    resnet34, ResNet34_Weights,
    resnet50, ResNet50_Weights,
    resnet101, ResNet101_Weights,
    resnet152, ResNet152_Weights,
    # ===================== Mobile Net ===================== #
    mobilenet_v3_small, MobileNet_V3_Small_Weights,
    mobilenet_v3_large, MobileNet_V3_Large_Weights,
    # ==================== EfficientNet ==================== #
    efficientnet_b0, EfficientNet_B0_Weights,
    efficientnet_b1, EfficientNet_B1_Weights,
    efficientnet_v2_s, EfficientNet_V2_S_Weights,
    efficientnet_v2_m, EfficientNet_V2_M_Weights,
    efficientnet_v2_l, EfficientNet_V2_L_Weights,
    # ================= Vision Transformer ================= #
    vit_b_16, ViT_B_16_Weights,
    vit_b_32, ViT_B_32_Weights,
    vit_l_16, ViT_L_16_Weights,
    vit_l_32, ViT_L_32_Weights,
)

from data_loader import get_wildfire_datasets

**Because the ViT models' weights require a lot of space, it is recommended to run them separately from the other models.**

In [None]:
torch.manual_seed(42)
BATCHSIZE = 64

# logging setup
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")


def _model_entry(model):
    """Return a ``[model, param_size]`` entry for ``MODELS``.

    The parameter count is taken from the same instance that is stored, so each
    network (and its pretrained weights) is constructed only once instead of
    building a second throwaway copy just to count its parameters.
    """
    return [model, sum(p.numel() for p in model.parameters())]


# Define models for feature extraction
MODELS = {
    # name: [model, param_size]
    # =============================== ResNet =============================== #
    "resnet18": _model_entry(resnet18(weights=ResNet18_Weights.DEFAULT)),
    "resnet34": _model_entry(resnet34(weights=ResNet34_Weights.DEFAULT)),
    "resnet50": _model_entry(resnet50(weights=ResNet50_Weights.DEFAULT)),
    "resnet101": _model_entry(resnet101(weights=ResNet101_Weights.DEFAULT)),
    "resnet152": _model_entry(resnet152(weights=ResNet152_Weights.DEFAULT)),
    # ============================= Mobile Net ============================= #
    "mobilenetv3small": _model_entry(mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)),
    "mobilenetv3large": _model_entry(mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)),
    # ============================ EfficientNet ============================ #
    "efficientnet_b0": _model_entry(efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)),
    "efficientnet_b1": _model_entry(efficientnet_b1(weights=EfficientNet_B1_Weights.DEFAULT)),
    "efficientnet_v2_s": _model_entry(efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.DEFAULT)),
    "efficientnet_v2_m": _model_entry(efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.DEFAULT)),
    "efficientnet_v2_l": _model_entry(efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.DEFAULT)),
    # ========================= Vision Transformer ========================= #
    "vit_b_16": _model_entry(vit_b_16(weights=ViT_B_16_Weights.DEFAULT)),
    "vit_b_32": _model_entry(vit_b_32(weights=ViT_B_32_Weights.DEFAULT)),
    "vit_l_16": _model_entry(vit_l_16(weights=ViT_L_16_Weights.DEFAULT)),
    "vit_l_32": _model_entry(vit_l_32(weights=ViT_L_32_Weights.DEFAULT)),
}

# Clear cuda mem and cache
torch.cuda.empty_cache()

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Extract features function
def extract_features(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader`` and stack the outputs.

    Args:
        model: network used as a frozen feature extractor; it is switched to
            eval mode here and run without gradient tracking.
        dataloader: iterable of ``(images, labels)`` batches; labels are ignored.
        device: device each image batch is moved to before the forward pass
            (the model is expected to already be on it -- no ``.to()`` is done
            on the model here).

    Returns:
        A CPU tensor holding all per-batch outputs concatenated along dim 0,
        in dataloader order.
    """
    model.eval()
    features = []
    with torch.no_grad():
        for images, _ in tqdm(dataloader, desc="Extracting features"):
            # Copy outputs to CPU right away so device memory only ever holds
            # one batch of features.
            features.append(model(images.to(device)).cpu())
    return torch.cat(features)

class BinaryClassifier(nn.Module):
    """Three-layer MLP head that maps a feature vector to one probability.

    Layout: Linear(input_dim -> hidden_dim1) -> ReLU -> Linear(hidden_dim1 ->
    hidden_dim2) -> ReLU -> Linear(hidden_dim2 -> 1) -> Sigmoid.

    The attribute names ``fc1``/``fc2``/``fc3`` are the ``state_dict`` keys of
    saved checkpoints, so they must not be renamed.
    """

    def __init__(self, input_dim=1000, hidden_dim1=512, hidden_dim2=256):
        super().__init__()
        # Creation order is kept as-is: it fixes how weight initialisation
        # draws from the global RNG and the order shown by print(model).
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.relu = nn.ReLU()  # one activation module, reused after fc1 and fc2
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a probability in (0, 1) per input row; last dim of ``x`` is ``input_dim``."""
        hidden = self.relu(self.fc2(self.relu(self.fc1(x))))
        return self.sigmoid(self.fc3(hidden))

def train_model(model, train_loader, criterion, optimizer, num_epochs, num_samples, name):
    """Train ``model`` for ``num_epochs`` epochs, then save its weights.

    Args:
        model: network optimised in place (put into training mode here).
        train_loader: iterable of ``(features, labels)`` batches; no ``.to()``
            is applied here, so batches must already be on the model's device.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model.parameters()``.
        num_epochs: number of full passes over ``train_loader``.
        num_samples: total number of samples, used to turn the summed loss into
            the per-sample mean printed for each epoch.
        name: suffix of the checkpoint file ``binary_classifier_{name}.pth``,
            written to the current working directory.
    """
    model.train()
    print("Training begins!")
    for epoch_idx in range(1, num_epochs + 1):
        loss_sum = 0.0
        for inputs, targets in train_loader:
            batch_loss = criterion(model(inputs), targets)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # Weight each batch's mean loss by its size so a short final batch
            # contributes proportionally to the epoch average.
            loss_sum += batch_loss.item() * inputs.size(0)
        print(f"Epoch [{epoch_idx}/{num_epochs}], Loss: {loss_sum / num_samples:.4f}")
    torch.save(model.state_dict(), f"binary_classifier_{name}.pth")
    print("Model has been saved.")

# Evaluation function
def evaluate_model(predict_model, model, loader, device, true_labels):
    """Evaluate a classifier head on features produced by a frozen backbone.

    Args:
        predict_model: classifier head returning, per sample, the probability of
            label 1 ("Wildfire"); probabilities are rounded to 0/1 labels.
        model: backbone passed to ``extract_features`` to produce the features.
        loader: DataLoader of ``(images, labels)``; it must iterate in the same
            order as ``true_labels`` (i.e. no shuffling).
        device: device on which both ``model`` and ``predict_model`` live.
        true_labels: ground-truth 0/1 labels aligned with ``loader``.

    Returns:
        Tuple ``(accuracy, report, cm)``: scikit-learn accuracy score, text
        classification report, and confusion matrix.
    """
    features = extract_features(model, loader, device)
    test_features = features.to(device)

    # Pure inference: eval() for layer modes, and no_grad() so no autograd
    # graph (with its activations) is kept alive for the whole test set.
    predict_model.eval()
    with torch.no_grad():
        probabilities = predict_model(test_features)
    # torch.round maps p > 0.5 to 1 (exactly 0.5 rounds half-to-even, to 0).
    # Flatten the output column (N, 1 for BinaryClassifier) to the 1-D label
    # vector sklearn expects.
    predicted_labels = torch.round(probabilities).int().cpu().reshape(-1).numpy()

    accuracy = accuracy_score(true_labels, predicted_labels)
    report = classification_report(true_labels, predicted_labels, target_names=["No Wildfire", "Wildfire"])
    cm = confusion_matrix(true_labels, predicted_labels)

    return accuracy, report, cm

In [None]:
# Load datasets
train_ds, test_ds, valid_ds = get_wildfire_datasets()

# train (appears unused below: the classifier heads are trained on a split of valid_ds)
train_loader = DataLoader(train_ds, batch_size=BATCHSIZE, shuffle=True)


def _collect_labels(dataset, desc):
    """Return ``dataset``'s labels as a 1-D numpy array and as an ``(N, 1)`` float tensor.

    Iterates every ``(sample, label)`` pair in order, so the labels line up with
    a non-shuffling DataLoader over the same dataset.
    """
    labels = np.array([label for _, label in tqdm(dataset, desc=desc)])
    return labels, torch.tensor(labels).unsqueeze(-1).float()


# valid --> new train (80%) & new valid (20%)
val_size = len(valid_ds)
train_size = int(val_size * 0.8)
valid_size = val_size - train_size

# A dedicated, seeded generator makes the split reproducible regardless of how
# much global RNG state earlier cells consumed (e.g. while constructing models).
split_generator = torch.Generator().manual_seed(42)
new_train_ds, new_valid_ds = random_split(valid_ds, [train_size, valid_size], generator=split_generator)

new_train_loader = DataLoader(new_train_ds, batch_size=BATCHSIZE, shuffle=False)
new_train_labels, new_train_labels_tensor = _collect_labels(new_train_ds, "Extracting new train labels")

new_valid_loader = DataLoader(new_valid_ds, batch_size=BATCHSIZE, shuffle=False)
new_valid_labels, new_valid_labels_tensor = _collect_labels(new_valid_ds, "Extracting new valid labels")

# test
test_loader = DataLoader(test_ds, batch_size=BATCHSIZE, shuffle=False)
test_labels, test_labels_tensor = _collect_labels(test_ds, "Extracting test labels")

In [None]:
# Rough check that each model's weights fit in the GPU memory not yet reserved
# by PyTorch. Only parameters are counted (assuming float32, 4 bytes each);
# activations, which grow with batch size and image size, are NOT included,
# so this is a lower bound on what a model actually needs.
if torch.cuda.is_available():
    # Nothing is allocated on the GPU inside the loop, so compute this once.
    # Both sides of the comparison below are in MB.
    cuda_available_mb = (
        torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0)
    ) / 1024 ** 2
    for model_name, (model, param_size) in MODELS.items():
        estimated_memory = param_size * 4 / 1024 ** 2  # in MB
        logging.info(
            f"{model_name}: estimated model memory {estimated_memory:.2f} MB "
            f"(available: {cuda_available_mb:.2f} MB)"
        )
        if estimated_memory > cuda_available_mb:
            raise MemoryError(f"Not enough GPU memory for {model_name}.")
else:
    logging.info("CUDA is not available; skipping the GPU memory estimate.")

In [None]:
# Prepare models, extract features and train one classifier head per backbone
model_results = {}
for model_name, model_ls in tqdm(MODELS.items(), desc="Processing models"):
    model = model_ls[0]
    logging.info(f"Setting up {model_name}...")
    model = model.to(device)

    # Features and labels must describe the same samples in the same order:
    # both come from the new-train split, and new_train_loader does not
    # shuffle, so row i of the features matches new_train_labels_tensor[i].
    logging.info(f"Extracting features with {model_name}...")
    train_features = extract_features(model, new_train_loader, device).to(device)
    train_labels = new_train_labels_tensor.to(device)

    mlp_dataset = TensorDataset(train_features, train_labels)
    mlp_batch_size = 32
    mlp_dataloader = DataLoader(mlp_dataset, batch_size=mlp_batch_size, shuffle=True)

    predict_model = BinaryClassifier(input_dim=train_features.shape[1]).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(predict_model.parameters(), lr=0.001)

    train_model(predict_model, mlp_dataloader, criterion, optimizer, num_epochs=50, num_samples=len(mlp_dataset), name=model_name)

    model_results[model_name] = {
        "model": model,
        "predict_model": predict_model
    }

    # Clear cuda memory: Module.to() moves the backbone in place, so the
    # reference stored in model_results is moved back to CPU as well.
    model = model.to(torch.device("cpu"))
    torch.cuda.empty_cache()

In [None]:
# Evaluate all models on the test split
for model_name, data in model_results.items():
    logging.info(f"Evaluating {model_name}...")

    predict_model = data["predict_model"]
    model = data["model"]
    model = model.to(device)

    accuracy, report, cm = evaluate_model(predict_model, model, test_loader, device, test_labels)

    # Move the backbone back to CPU before the next one; otherwise every
    # evaluated model stays resident on the GPU and memory use keeps growing.
    model = model.to(torch.device("cpu"))
    torch.cuda.empty_cache()

    print(f"{model_name} {accuracy * 100:.2f}%")
    # Parameter count instead of the full architecture repr, which can be very long.
    print(f"{model_name} parameters: {MODELS[model_name][1]:,}")
    print(predict_model)
    print(f"{model_name} classification report:\n{report}")

    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set(title=f"{model_name} Confusion Matrix", xlabel='Predicted', ylabel='True')
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure so many plots in one loop don't accumulate

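## Recorded Results

Test-set results from a previous run (6,300 test images: 2,820 "No Wildfire" and 3,480 "Wildfire").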

### ResNet
```txt
resnet18 95.97%
resnet18 classification report:
              precision    recall  f1-score   support

 No Wildfire       0.96      0.95      0.95      2820
    Wildfire       0.96      0.97      0.96      3480

    accuracy                           0.96      6300
   macro avg       0.96      0.96      0.96      6300
weighted avg       0.96      0.96      0.96      6300
```
```txt
resnet34 95.86%
resnet34 classification report:
              precision    recall  f1-score   support

 No Wildfire       0.95      0.96      0.95      2820
    Wildfire       0.97      0.96      0.96      3480

    accuracy                           0.96      6300
   macro avg       0.96      0.96      0.96      6300
weighted avg       0.96      0.96      0.96      6300
```
```txt
resnet50 97.22%
resnet50 classification report:
              precision    recall  f1-score   support

 No Wildfire       0.97      0.97      0.97      2820
    Wildfire       0.98      0.97      0.97      3480

    accuracy                           0.97      6300
   macro avg       0.97      0.97      0.97      6300
weighted avg       0.97      0.97      0.97      6300
```
```txt
resnet101 97.27%
resnet101 classification report:
              precision    recall  f1-score   support

 No Wildfire       0.98      0.96      0.97      2820
    Wildfire       0.97      0.98      0.98      3480

    accuracy                           0.97      6300
   macro avg       0.97      0.97      0.97      6300
weighted avg       0.97      0.97      0.97      6300
```
```txt
resnet152 97.06%
resnet152 classification report:
              precision    recall  f1-score   support

 No Wildfire       0.97      0.96      0.97      2820
    Wildfire       0.97      0.98      0.97      3480

    accuracy                           0.97      6300
   macro avg       0.97      0.97      0.97      6300
weighted avg       0.97      0.97      0.97      6300
```

### MobileNetv3
```txt
mobilenetv3small 97.02%
mobilenetv3small classification report:
              precision    recall  f1-score   support

 No Wildfire       0.97      0.97      0.97      2820
    Wildfire       0.97      0.97      0.97      3480

    accuracy                           0.97      6300
   macro avg       0.97      0.97      0.97      6300
weighted avg       0.97      0.97      0.97      6300
```
```txt
mobilenetv3large 97.30%
mobilenetv3large classification report:
              precision    recall  f1-score   support

 No Wildfire       0.97      0.97      0.97      2820
    Wildfire       0.97      0.98      0.98      3480

    accuracy                           0.97      6300
   macro avg       0.97      0.97      0.97      6300
weighted avg       0.97      0.97      0.97      6300
```
### EfficientNet
```txt
efficientnet_b0 96.73%
efficientnet_b0 classification report:
              precision    recall  f1-score   support

 No Wildfire       0.97      0.96      0.96      2820
    Wildfire       0.97      0.97      0.97      3480

    accuracy                           0.97      6300
   macro avg       0.97      0.97      0.97      6300
weighted avg       0.97      0.97      0.97      6300
```
```txt
efficientnet_b1 97.56%
efficientnet_b1 classification report:
              precision    recall  f1-score   support

 No Wildfire       0.98      0.97      0.97      2820
    Wildfire       0.97      0.98      0.98      3480

    accuracy                           0.98      6300
   macro avg       0.98      0.97      0.98      6300
weighted avg       0.98      0.98      0.98      6300
```
```txt
efficientnet_v2_s 96.43%
efficientnet_v2_s classification report:
              precision    recall  f1-score   support

 No Wildfire       0.96      0.96      0.96      2820
    Wildfire       0.97      0.97      0.97      3480

    accuracy                           0.96      6300
   macro avg       0.96      0.96      0.96      6300
weighted avg       0.96      0.96      0.96      6300
```
```txt
efficientnet_v2_m 95.29%
efficientnet_v2_m classification report:
              precision    recall  f1-score   support

 No Wildfire       0.95      0.95      0.95      2820
    Wildfire       0.96      0.96      0.96      3480

    accuracy                           0.95      6300
   macro avg       0.95      0.95      0.95      6300
weighted avg       0.95      0.95      0.95      6300
```
```txt
efficientnet_v2_l 97.27%
efficientnet_v2_l classification report:
              precision    recall  f1-score   support

 No Wildfire       0.97      0.97      0.97      2820
    Wildfire       0.98      0.97      0.98      3480

    accuracy                           0.97      6300
   macro avg       0.97      0.97      0.97      6300
weighted avg       0.97      0.97      0.97      6300
```
### ViT
```txt
vit_b_16 97.05%
vit_b_16 classification report:
              precision    recall  f1-score   support

 No Wildfire       0.97      0.96      0.97      2820
    Wildfire       0.97      0.98      0.97      3480

    accuracy                           0.97      6300
   macro avg       0.97      0.97      0.97      6300
weighted avg       0.97      0.97      0.97      6300
```
```txt
vit_b_32 96.98%
vit_b_32 classification report:
              precision    recall  f1-score   support

 No Wildfire       0.97      0.96      0.97      2820
    Wildfire       0.97      0.98      0.97      3480

    accuracy                           0.97      6300
   macro avg       0.97      0.97      0.97      6300
weighted avg       0.97      0.97      0.97      6300
```
```txt
vit_l_16 97.75%
vit_l_16 classification report:
              precision    recall  f1-score   support

 No Wildfire       0.98      0.97      0.97      2820
    Wildfire       0.98      0.98      0.98      3480

    accuracy                           0.98      6300
   macro avg       0.98      0.98      0.98      6300
weighted avg       0.98      0.98      0.98      6300
```
```txt
vit_l_32 96.78%
vit_l_32 classification report:
              precision    recall  f1-score   support

 No Wildfire       0.97      0.96      0.96      2820
    Wildfire       0.97      0.97      0.97      3480

    accuracy                           0.97      6300
   macro avg       0.97      0.97      0.97      6300
weighted avg       0.97      0.97      0.97      6300
```

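### Evaluating with pretrained classifier weights

Run the cell below to evaluate every backbone with previously trained classifier heads instead of retraining them. It loads `./models/binary_classifier_<model_name>.pth`. The training cell above saves its heads as `binary_classifier_<model_name>.pth` in the current working directory, so move them into `./models/` to evaluate your own runs.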

In [None]:
for model_name, model_ls in tqdm(MODELS.items(), desc="Testing models"):

    model = model_ls[0]
    logging.info(f"Evaluating {model_name}...")
    model = model.to(device)

    # The backbones keep their torchvision classification heads, so the
    # features fed to the classifier are their ImageNet outputs (1000-dim for
    # these weights); this must match the input_dim the saved heads used.
    output_dim = 1000
    predict_model = BinaryClassifier(input_dim=output_dim).to(device)
    state_dict = torch.load(f"./models/binary_classifier_{model_name}.pth", weights_only=True, map_location=device)
    predict_model.load_state_dict(state_dict)
    predict_model.eval()

    accuracy, report, cm = evaluate_model(predict_model, model, test_loader, device, test_labels)

    # Move the backbone back to CPU before loading the next one; otherwise every
    # evaluated model stays resident on the GPU and memory use keeps growing.
    model = model.to(torch.device("cpu"))
    torch.cuda.empty_cache()

    print(f"{model_name} {accuracy * 100:.2f}%")
    # Parameter count instead of the full architecture repr, which can be very long.
    print(f"{model_name} parameters: {MODELS[model_name][1]:,}")
    print(predict_model)
    print(f"{model_name} classification report:\n{report}")

    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set(title=f"{model_name} Confusion Matrix", xlabel='Predicted', ylabel='True')
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure so many plots in one loop don't accumulate