In [None]:
import kagglehub

# Fetch the pinned release (version 1) of the wildfire dataset and keep the
# directory that kagglehub reports for the downloaded files
dataset_handle = "elmadafri/the-wildfire-dataset/versions/1"
path = kagglehub.dataset_download(dataset_handle)

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/elmadafri/the-wildfire-dataset?dataset_version_number=1...


100%|██████████| 9.94G/9.94G [01:19<00:00, 134MB/s] 

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/elmadafri/the-wildfire-dataset/versions/1


In [None]:
!rm /root/.cache/kagglehub/datasets/elmadafri/the-wildfire-dataset/versions/1/the_wildfire_dataset/the_wildfire_dataset/val/fire/Both_smoke_and_fire/desktop.ini

In [None]:
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from transformers import MobileViTForImageClassification, MobileViTImageProcessor

In [None]:
# Initialise the backbone from the published "apple/mobilevit-small" weights
pretrained_checkpoint = "apple/mobilevit-small"
model = MobileViTForImageClassification.from_pretrained(pretrained_checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/70.0k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/22.5M [00:00<?, ?B/s]

In [None]:
# Width of the feature vector consumed by the stock classification layer
in_features = model.classifier.in_features

# Swap the stock classifier for a pass-through so the model hands back the
# features themselves instead of its original class scores
model.classifier = nn.Identity()

# Head 1 - fire vs. no fire: one sigmoid-activated unit (a probability)
binary_head = nn.Sequential(nn.Linear(in_features, 1), nn.Sigmoid())

# Head 2 - five-way subclass prediction (raw scores, no activation)
num_subclasses = 5
multi_class_head = nn.Linear(in_features, num_subclasses)

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset

class CustomFireDataset(Dataset):
    """Wildfire images labelled for a binary task and a five-way task at once.

    Expected directory layout::

        root_dir/<fire|nofire>/<subclass folder>/<image file>

    The top-level folder gives the binary label and the subclass folder gives
    the multi-class label (see ``BINARY_LABELS`` / ``MULTI_CLASS_LABELS``).

    Args:
        root_dir: Directory of one split (it directly contains ``fire`` and
            ``nofire``).
        transform: Optional callable applied to each PIL image.
        extensions: File extensions (case-insensitive) accepted as images.
            Files with any other extension, e.g. ``desktop.ini``, are skipped
            rather than indexed. Pass ``None`` to index every regular file.

    Raises:
        KeyError: If a directory name under ``root_dir`` is not present in the
            label mappings.
    """

    # Top-level folder name -> binary target
    BINARY_LABELS = {'fire': 1, 'nofire': 0}

    # Top-level folder name -> {subclass folder name -> multi-class target}
    MULTI_CLASS_LABELS = {
        'fire': {'Both_smoke_and_fire': 0, 'Smoke_from_fires': 1},
        'nofire': {'Fire_confounding_elements': 2, 'Forested_areas_without_confounding_elements': 3, 'Smoke_confounding_elements': 4}
    }

    # Mirrors the extension list accepted by torchvision's ImageFolder
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

    def __init__(self, root_dir, transform=None, extensions=IMAGE_EXTENSIONS):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.binary_labels = []
        self.multi_class_labels = []

        # str.endswith needs a tuple; lower-case so the match is case-insensitive
        if extensions is not None:
            extensions = tuple(ext.lower() for ext in extensions)

        # os.listdir returns entries in arbitrary order, so sort at every level
        # to make the sample order reproducible across runs and machines.
        for binary_label_name in sorted(os.listdir(root_dir)):
            binary_label_path = os.path.join(root_dir, binary_label_name)
            if not os.path.isdir(binary_label_path):
                continue

            binary_label = self.BINARY_LABELS[binary_label_name]
            subclass_labels = self.MULTI_CLASS_LABELS[binary_label_name]

            for subclass_name in sorted(os.listdir(binary_label_path)):
                subclass_path = os.path.join(binary_label_path, subclass_name)
                if not os.path.isdir(subclass_path):
                    continue

                multi_class_label = subclass_labels[subclass_name]

                for img_name in sorted(os.listdir(subclass_path)):
                    img_path = os.path.join(subclass_path, img_name)
                    if not os.path.isfile(img_path):
                        continue
                    # Skip non-image files now instead of failing later in
                    # Image.open when a worker tries to load them.
                    if extensions is not None and not img_name.lower().endswith(extensions):
                        continue

                    self.image_paths.append(img_path)
                    self.binary_labels.append(binary_label)
                    self.multi_class_labels.append(multi_class_label)

    def __len__(self):
        """Number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, (binary_label, multi_class_label))`` for one sample.

        The binary label is a float tensor and the multi-class label a long
        tensor, both zero-dimensional.
        """
        img_path = self.image_paths[idx]
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the with-block exits.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        binary_label = self.binary_labels[idx]
        multi_class_label = self.multi_class_labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, (torch.tensor(binary_label, dtype=torch.float), torch.tensor(multi_class_label, dtype=torch.long))


In [None]:
# Per-channel mean / std passed to Normalize for every split
# NOTE(review): these are the usual ImageNet statistics; confirm they match the
# preprocessing the apple/mobilevit-small checkpoint was trained with.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Training: resize, random horizontal flip as the only augmentation, normalise
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD)
])

# Validation: same pipeline without the random flip
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD)
])

# Single place that knows where the extracted dataset lives
DATASET_ROOT = "/root/.cache/kagglehub/datasets/elmadafri/the-wildfire-dataset/versions/1/the_wildfire_dataset/the_wildfire_dataset"

train_dataset = CustomFireDataset(root_dir=os.path.join(DATASET_ROOT, "train"), transform=train_transforms)
val_dataset = CustomFireDataset(root_dir=os.path.join(DATASET_ROOT, "val"), transform=val_transforms)

# Never ask for more loader workers than the machine has CPUs; DataLoader
# warns (and can slow down) when the request exceeds the available cores.
NUM_WORKERS = min(10, os.cpu_count() or 1)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# One objective per task: BCE on the sigmoid output of the binary head,
# cross-entropy on the raw scores of the multi-class head
criterion_binary = nn.BCELoss()
criterion_multi_class = nn.CrossEntropyLoss()

# A single Adam optimiser updates the backbone and both heads together.
# NOTE(review): lr=0.01 is applied to the pretrained backbone parameters as
# well as the new heads - confirm that is intended.
trainable_parameters = [
    parameter
    for module in (model, binary_head, multi_class_head)
    for parameter in module.parameters()
]
optimizer = torch.optim.Adam(trainable_parameters, lr=0.01)

In [None]:
# Prefer the first CUDA device when one is available, otherwise use the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Move the backbone and both task heads onto the chosen device
model, binary_head, multi_class_head = (
    module.to(device) for module in (model, binary_head, multi_class_head)
)

In [None]:
def combined_loss(binary_output, binary_target, multi_class_output, multi_class_target, alpha=0.5, beta=0.5):
    """Weighted sum of the binary loss and the multi-class loss.

    Args:
        binary_output: Predictions passed to ``criterion_binary``.
        binary_target: Targets passed to ``criterion_binary``.
        multi_class_output: Predictions passed to ``criterion_multi_class``.
        multi_class_target: Targets passed to ``criterion_multi_class``.
        alpha: Weight of the binary term.
        beta: Weight of the multi-class term.

    Returns:
        ``alpha * binary_loss + beta * multi_class_loss``.
    """
    weighted_binary = alpha * criterion_binary(binary_output, binary_target)
    weighted_multi_class = beta * criterion_multi_class(multi_class_output, multi_class_target)
    return weighted_binary + weighted_multi_class

In [None]:
# Per-epoch history, later written to CSV
epoch_list = []
loss_list = []
train_binary_accuracy_list = []
train_multi_class_accuracy_list = []
val_binary_accuracy_list = []
val_multi_class_accuracy_list = []


def _forward_both_heads(images):
    """Run the MobileViT backbone once and feed its output to both heads.

    Returns:
        (binary_probs, multi_class_output): the sigmoid outputs of
        ``binary_head`` with the trailing unit dimension removed, and the raw
        scores of ``multi_class_head``.
    """
    features = model(images).logits
    # squeeze(1) removes only the head's single output column. A bare
    # squeeze() would also drop the batch dimension when a batch holds one
    # sample, leaving a 0-d tensor that no longer matches the (1,) target.
    binary_probs = binary_head(features).squeeze(1)
    multi_class_output = multi_class_head(features)
    return binary_probs, multi_class_output


# Training loop
epochs = 30  # You can adjust the number of epochs based on your needs
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct_train_binary = 0
    total_train_binary = 0
    correct_train_multi_class = 0
    total_train_multi_class = 0

    # ---- Training pass ----
    for images, (binary_labels, multi_class_labels) in train_loader:
        images = images.to(device)
        binary_labels = binary_labels.to(device).float()
        multi_class_labels = multi_class_labels.to(device).long()

        optimizer.zero_grad()
        binary_probs, multi_class_output = _forward_both_heads(images)

        # Joint objective over both tasks
        loss = combined_loss(binary_probs, binary_labels, multi_class_output, multi_class_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Binary accuracy: threshold the probabilities at 0.5
        predicted_binary = (binary_probs > 0.5).int()
        correct_train_binary += (predicted_binary == binary_labels.int()).sum().item()
        total_train_binary += binary_labels.size(0)

        # Multi-class accuracy: highest-scoring class wins
        _, predicted_multi_class = torch.max(multi_class_output, 1)
        correct_train_multi_class += (predicted_multi_class == multi_class_labels).sum().item()
        total_train_multi_class += multi_class_labels.size(0)

    # Mean loss per batch and training accuracies in percent
    avg_loss = running_loss / len(train_loader)
    train_binary_accuracy = 100 * correct_train_binary / total_train_binary
    train_multi_class_accuracy = 100 * correct_train_multi_class / total_train_multi_class

    # ---- Validation pass (no gradient tracking, no parameter updates) ----
    model.eval()
    correct_val_binary = 0
    total_val_binary = 0
    correct_val_multi_class = 0
    total_val_multi_class = 0
    with torch.no_grad():
        for images, (binary_labels, multi_class_labels) in val_loader:
            images = images.to(device)
            binary_labels = binary_labels.to(device).float()
            multi_class_labels = multi_class_labels.to(device).long()

            binary_probs, multi_class_output = _forward_both_heads(images)

            predicted_binary = (binary_probs > 0.5).int()
            correct_val_binary += (predicted_binary == binary_labels.int()).sum().item()
            total_val_binary += binary_labels.size(0)

            _, predicted_multi_class = torch.max(multi_class_output, 1)
            correct_val_multi_class += (predicted_multi_class == multi_class_labels).sum().item()
            total_val_multi_class += multi_class_labels.size(0)

    # Validation accuracies in percent
    val_binary_accuracy = 100 * correct_val_binary / total_val_binary
    val_multi_class_accuracy = 100 * correct_val_multi_class / total_val_multi_class

    # Store the metrics
    epoch_list.append(epoch + 1)
    loss_list.append(avg_loss)
    train_binary_accuracy_list.append(train_binary_accuracy)
    train_multi_class_accuracy_list.append(train_multi_class_accuracy)
    val_binary_accuracy_list.append(val_binary_accuracy)
    val_multi_class_accuracy_list.append(val_multi_class_accuracy)

    # Log the metrics
    print(f"Epoch {epoch+1}/{epochs}")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Training Binary Accuracy: {train_binary_accuracy:.2f}%")
    print(f"  Training Multi-Class Accuracy: {train_multi_class_accuracy:.2f}%")
    print(f"  Validation Binary Accuracy: {val_binary_accuracy:.2f}%")
    print(f"  Validation Multi-Class Accuracy: {val_multi_class_accuracy:.2f}%")



Epoch 1/30
  Loss: 1.0513
  Training Binary Accuracy: 64.39%
  Training Multi-Class Accuracy: 39.85%
  Validation Binary Accuracy: 59.70%
  Validation Multi-Class Accuracy: 37.56%




Epoch 2/30
  Loss: 0.9995
  Training Binary Accuracy: 65.08%
  Training Multi-Class Accuracy: 42.02%
  Validation Binary Accuracy: 67.66%
  Validation Multi-Class Accuracy: 44.03%




Epoch 3/30
  Loss: 0.9643
  Training Binary Accuracy: 67.94%
  Training Multi-Class Accuracy: 46.00%
  Validation Binary Accuracy: 46.77%
  Validation Multi-Class Accuracy: 31.09%




Epoch 4/30
  Loss: 0.9032
  Training Binary Accuracy: 72.23%
  Training Multi-Class Accuracy: 49.07%
  Validation Binary Accuracy: 74.13%
  Validation Multi-Class Accuracy: 51.00%




Epoch 5/30
  Loss: 0.8530
  Training Binary Accuracy: 74.14%
  Training Multi-Class Accuracy: 50.56%
  Validation Binary Accuracy: 73.63%
  Validation Multi-Class Accuracy: 50.00%




Epoch 6/30
  Loss: 0.8104
  Training Binary Accuracy: 77.11%
  Training Multi-Class Accuracy: 52.31%
  Validation Binary Accuracy: 75.37%
  Validation Multi-Class Accuracy: 52.49%




Epoch 7/30
  Loss: 0.7717
  Training Binary Accuracy: 80.34%
  Training Multi-Class Accuracy: 56.39%
  Validation Binary Accuracy: 76.62%
  Validation Multi-Class Accuracy: 51.99%




Epoch 8/30
  Loss: 0.7491
  Training Binary Accuracy: 80.50%
  Training Multi-Class Accuracy: 56.28%
  Validation Binary Accuracy: 78.61%
  Validation Multi-Class Accuracy: 45.77%




Epoch 9/30
  Loss: 0.7267
  Training Binary Accuracy: 81.19%
  Training Multi-Class Accuracy: 56.76%
  Validation Binary Accuracy: 77.86%
  Validation Multi-Class Accuracy: 57.96%




Epoch 10/30
  Loss: 0.7205
  Training Binary Accuracy: 81.19%
  Training Multi-Class Accuracy: 57.34%
  Validation Binary Accuracy: 75.37%
  Validation Multi-Class Accuracy: 54.23%




Epoch 11/30
  Loss: 0.6980
  Training Binary Accuracy: 82.09%
  Training Multi-Class Accuracy: 59.30%
  Validation Binary Accuracy: 65.42%
  Validation Multi-Class Accuracy: 41.54%




Epoch 12/30
  Loss: 0.7249
  Training Binary Accuracy: 81.66%
  Training Multi-Class Accuracy: 57.98%
  Validation Binary Accuracy: 75.62%
  Validation Multi-Class Accuracy: 51.74%




Epoch 13/30
  Loss: 0.6639
  Training Binary Accuracy: 83.62%
  Training Multi-Class Accuracy: 59.94%
  Validation Binary Accuracy: 73.13%
  Validation Multi-Class Accuracy: 57.21%




Epoch 14/30
  Loss: 0.6703
  Training Binary Accuracy: 83.15%
  Training Multi-Class Accuracy: 59.30%
  Validation Binary Accuracy: 58.21%
  Validation Multi-Class Accuracy: 39.55%




Epoch 15/30
  Loss: 0.6488
  Training Binary Accuracy: 84.47%
  Training Multi-Class Accuracy: 61.00%
  Validation Binary Accuracy: 81.09%
  Validation Multi-Class Accuracy: 60.95%




Epoch 16/30
  Loss: 0.6371
  Training Binary Accuracy: 84.68%
  Training Multi-Class Accuracy: 63.12%
  Validation Binary Accuracy: 82.59%
  Validation Multi-Class Accuracy: 61.69%




Epoch 17/30
  Loss: 0.6408
  Training Binary Accuracy: 85.43%
  Training Multi-Class Accuracy: 62.06%
  Validation Binary Accuracy: 79.60%
  Validation Multi-Class Accuracy: 59.45%




Epoch 18/30
  Loss: 0.6408
  Training Binary Accuracy: 85.80%
  Training Multi-Class Accuracy: 63.33%
  Validation Binary Accuracy: 81.84%
  Validation Multi-Class Accuracy: 61.94%




Epoch 19/30
  Loss: 0.6182
  Training Binary Accuracy: 85.96%
  Training Multi-Class Accuracy: 64.86%
  Validation Binary Accuracy: 82.34%
  Validation Multi-Class Accuracy: 62.44%




Epoch 20/30
  Loss: 0.6159
  Training Binary Accuracy: 86.06%
  Training Multi-Class Accuracy: 63.65%
  Validation Binary Accuracy: 78.61%
  Validation Multi-Class Accuracy: 54.48%




Epoch 21/30
  Loss: 0.5907
  Training Binary Accuracy: 86.70%
  Training Multi-Class Accuracy: 64.60%
  Validation Binary Accuracy: 82.59%
  Validation Multi-Class Accuracy: 66.92%




Epoch 22/30
  Loss: 0.5872
  Training Binary Accuracy: 85.74%
  Training Multi-Class Accuracy: 64.44%
  Validation Binary Accuracy: 81.59%
  Validation Multi-Class Accuracy: 57.46%




Epoch 23/30
  Loss: 0.5732
  Training Binary Accuracy: 86.75%
  Training Multi-Class Accuracy: 66.30%
  Validation Binary Accuracy: 75.37%
  Validation Multi-Class Accuracy: 56.22%




Epoch 24/30
  Loss: 0.5773
  Training Binary Accuracy: 86.70%
  Training Multi-Class Accuracy: 66.98%
  Validation Binary Accuracy: 78.36%
  Validation Multi-Class Accuracy: 58.46%




Epoch 25/30
  Loss: 0.5582
  Training Binary Accuracy: 87.44%
  Training Multi-Class Accuracy: 68.57%
  Validation Binary Accuracy: 83.33%
  Validation Multi-Class Accuracy: 66.67%




Epoch 26/30
  Loss: 0.5491
  Training Binary Accuracy: 87.18%
  Training Multi-Class Accuracy: 67.67%
  Validation Binary Accuracy: 83.83%
  Validation Multi-Class Accuracy: 67.41%




Epoch 27/30
  Loss: 0.5487
  Training Binary Accuracy: 87.65%
  Training Multi-Class Accuracy: 67.04%
  Validation Binary Accuracy: 81.59%
  Validation Multi-Class Accuracy: 62.94%




Epoch 28/30
  Loss: 0.5384
  Training Binary Accuracy: 88.82%
  Training Multi-Class Accuracy: 68.68%
  Validation Binary Accuracy: 82.09%
  Validation Multi-Class Accuracy: 63.93%




Epoch 29/30
  Loss: 0.5107
  Training Binary Accuracy: 88.77%
  Training Multi-Class Accuracy: 69.63%
  Validation Binary Accuracy: 79.60%
  Validation Multi-Class Accuracy: 60.70%




Epoch 30/30
  Loss: 0.5291
  Training Binary Accuracy: 88.71%
  Training Multi-Class Accuracy: 70.32%
  Validation Binary Accuracy: 82.59%
  Validation Multi-Class Accuracy: 64.18%


In [None]:
# Backbone weights (the MobileViT model whose classifier is a pass-through)
torch.save(model.state_dict(), "mobileVit_fire_multi_classifier.pth")

# The two task heads are separate modules, so their weights are NOT part of
# model.state_dict(). Save them as well; without them the trained classifier
# cannot be restored from disk.
torch.save(
    {
        "binary_head": binary_head.state_dict(),
        "multi_class_head": multi_class_head.state_dict(),
    },
    "mobileVit_fire_multi_classifier_heads.pth",
)
print("Model saved successfully!")

Model saved successfully!


In [None]:
# Gather the per-epoch history, one column per tracked quantity
data = {
    'Epoch': epoch_list,
    'Loss': loss_list,
    'Train Binary Accuracy': train_binary_accuracy_list,
    'Train Multi-Class Accuracy': train_multi_class_accuracy_list,
    'Validation Binary Accuracy': val_binary_accuracy_list,
    'Validation Multi-Class Accuracy': val_multi_class_accuracy_list
}

# Create a DataFrame from the dictionary
df = pd.DataFrame(data)

# Write it out; the message is built from the same variable as the file name
# so the two cannot drift apart.
results_csv = 'multi_training_results.csv'
df.to_csv(results_csv, index=False)
print(f"Training results saved to '{results_csv}' successfully!")

Training results saved to 'training_results.csv' successfully!


In [None]:
# Deterministic preprocessing for the test split: resize, tensor conversion and
# normalisation, with no random augmentation
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Index the held-out test split
test_dataset = CustomFireDataset(root_dir="/root/.cache/kagglehub/datasets/elmadafri/the-wildfire-dataset/versions/1/the_wildfire_dataset/the_wildfire_dataset/test", transform=test_transforms)

# Unshuffled loader so predictions are stored in dataset order. The worker
# count is capped at the number of CPUs, because DataLoader warns (and can
# slow down) when more workers are requested than cores exist.
test_num_workers = min(10, os.cpu_count() or 1)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=test_num_workers)

In [None]:
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
import numpy as np
import torch

# Ground truth and predicted fire probability for the binary task
true_binary_labels = []
predicted_binary_probs = []  # Probabilities for the positive class (fire)

# Ground truth and predicted class index for the multi-class task
true_multi_class_labels = []
predicted_multi_class_labels = []

# Evaluate the model on the test set
model.eval()
with torch.no_grad():
    for images, (binary_labels, multi_class_labels) in test_loader:
        # Only the images are needed on the device; the labels are just
        # recorded, so they are not moved there and back.
        images = images.to(device)

        # Forward pass through the backbone and both heads
        features = model(images).logits
        binary_output = binary_head(features)
        multi_class_output = multi_class_head(features)

        # squeeze(1) removes only the head's single output column. A bare
        # squeeze() would give a 0-d array for a one-sample batch, which
        # list.extend cannot iterate over.
        binary_probs = binary_output.squeeze(1).cpu().numpy()  # Probabilities from Sigmoid

        # Predicted class = index of the highest score
        _, predicted_multi_class = torch.max(multi_class_output, 1)
        predicted_multi_class = predicted_multi_class.cpu().numpy()

        # Store true labels and predictions for binary classification
        true_binary_labels.extend(binary_labels.float().numpy())
        predicted_binary_probs.extend(binary_probs)

        # Store true labels and predictions for multi-class classification
        true_multi_class_labels.extend(multi_class_labels.long().numpy())
        predicted_multi_class_labels.extend(predicted_multi_class)

# One row per test image, in loader order
data = pd.DataFrame({
    "True Binary Labels": true_binary_labels,
    "Predicted Binary Probabilities": predicted_binary_probs,
    "True Multi-Class Labels": true_multi_class_labels,
    "Predicted Multi-Class Labels": predicted_multi_class_labels
})

data.to_csv("mtl_mobileVit_test_results.csv", index=False)



# **METRICS**

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score, classification_report

# Read back the per-image test predictions from the results CSV
data = pd.read_csv("mtl_mobileVit_test_results.csv")

# ---- Binary task ----
true_binary = data['True Binary Labels'].to_numpy()
pred_binary_probs = data['Predicted Binary Probabilities'].to_numpy()
# Hard decisions: a probability of 0.5 or more counts as the positive class
pred_binary = (pred_binary_probs >= 0.5).astype(int)

# Threshold-based scores use the hard decisions; ROC-AUC uses the raw probabilities
accuracy = accuracy_score(true_binary, pred_binary)
precision = precision_score(true_binary, pred_binary)
recall = recall_score(true_binary, pred_binary)
f1 = f1_score(true_binary, pred_binary)
roc_auc = roc_auc_score(true_binary, pred_binary_probs)

print("Binary Classification Metrics:")
print(f"F1 Score: {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")

# ---- Multi-class task ----
true_multi = data['True Multi-Class Labels'].to_numpy()
pred_multi = data['Predicted Multi-Class Labels'].to_numpy()

print("\nMulti-class Classification Report:")
print(classification_report(true_multi, pred_multi))

Binary Classification Metrics:
F1 Score: 0.8100
Precision: 0.8025
Recall: 0.8176
Accuracy: 0.8512
ROC-AUC: 0.9257

Multi-class Classification Report:
              precision    recall  f1-score   support

           0       0.68      0.51      0.58        59
           1       0.61      0.76      0.68       100
           2       0.55      0.62      0.58        52
           3       0.82      0.78      0.80       128
           4       0.64      0.55      0.59        71

    accuracy                           0.68       410
   macro avg       0.66      0.64      0.65       410
weighted avg       0.68      0.68      0.67       410

