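# EdgeNeXt: multi-task wildfire image classification

Fine-tunes a pretrained timm `edgenext_small` backbone on The Wildfire Dataset (Kaggle `elmadafri/the-wildfire-dataset`, version 1) with two heads: a binary fire / no-fire head and a 5-way subclass head. The two heads are trained jointly on a weighted sum of binary cross-entropy and cross-entropy losses.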

In [None]:
!pip install torch torchvision timm



In [None]:
import kagglehub

# Fetch version 1 of the Kaggle wildfire dataset (pinned, not "latest").
# kagglehub returns the local directory it was extracted into; later cells build paths from it.
dataset_handle = "elmadafri/the-wildfire-dataset/versions/1"
path = kagglehub.dataset_download(dataset_handle)

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/elmadafri/the-wildfire-dataset?dataset_version_number=1...


100%|██████████| 9.94G/9.94G [01:28<00:00, 120MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/elmadafri/the-wildfire-dataset/versions/1


In [None]:
import os

# The dataset ships a stray Windows `desktop.ini` inside one validation class folder.
# Remove it so it is not collected as an image. The path is built from the kagglehub
# download location, and the existence check keeps the cell safe to re-run.
desktop_ini = os.path.join(
    path, "the_wildfire_dataset", "the_wildfire_dataset",
    "val", "fire", "Both_smoke_and_fire", "desktop.ini",
)
if os.path.exists(desktop_ini):
    os.remove(desktop_ini)

In [None]:
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from timm import create_model

## Load Model

In [None]:
import torch
from torch import nn
from timm import create_model

# Load the pretrained EdgeNeXt backbone from timm. Its built-in classifier stays
# attached to the module, but the custom forward below never calls it.
model = create_model('edgenext_small', pretrained=True)

# Probe the backbone once to learn how many values forward_features yields per
# image after flattening. Both heads take that flattened vector as input.
dummy_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = model.forward_features(dummy_input)
flattened_features_size = features.reshape(features.size(0), -1).size(1)
# NOTE(review): the full, un-pooled feature map is flattened, so each head's first
# Linear holds about flattened_features_size**2 / 2 weights. A pooled embedding
# would be far smaller (e.g. timm's forward_head(..., pre_logits=True); verify for EdgeNeXt).


def make_head(in_features, num_outputs, dropout):
    """Return Linear -> LeakyReLU -> Dropout -> Linear with hidden width in_features // 2.

    The layer order fixes the state_dict keys (``0.*`` and ``3.*`` for the two
    Linear layers). Keep it stable so saved checkpoints keep loading.
    """
    hidden_size = in_features // 2
    return nn.Sequential(
        nn.Linear(in_features, hidden_size),
        nn.LeakyReLU(),
        nn.Dropout(p=dropout),
        nn.Linear(hidden_size, num_outputs),
    )


# Binary fire / no-fire head: one logit per image.
model.head_binary = make_head(flattened_features_size, 1, dropout=0.5)

# Multi-class head: one logit per subclass (5 subclasses).
model.head_multiclass = make_head(flattened_features_size, 5, dropout=0.4)


def forward_with_two_heads(self, x):
    """Run the backbone once and return (binary_logits, multiclass_logits)."""
    features = self.forward_features(x)
    # Flatten everything except the batch dimension.
    features = features.reshape(features.size(0), -1)
    binary_output = self.head_binary(features)
    multiclass_output = self.head_multiclass(features)
    return binary_output, multiclass_output


# Bind as an instance method so that model(images) returns both heads' outputs.
model.forward = forward_with_two_heads.__get__(model, type(model))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/22.4M [00:00<?, ?B/s]

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset

class CustomFireDataset(Dataset):
    """Wildfire images, each paired with a binary label and a 5-way subclass label.

    Expects the layout ``root_dir/<binary>/<subclass>/<file>``, where ``<binary>`` is a
    key of ``BINARY_LABEL_MAPPING`` and ``<subclass>`` is a key of the matching
    ``MULTI_CLASS_MAPPING`` entry. An unexpected directory name raises ``KeyError``.

    Only files whose lower-cased name ends with one of ``extensions`` are collected,
    so stray files such as ``desktop.ini`` are skipped. Pass ``extensions=None`` to
    accept every file. Directory listings are sorted so sample order is deterministic.

    Each item is ``(image, (binary_label, multi_class_label))``. The binary label is
    a float tensor and the multi-class label is a long tensor.
    """

    BINARY_LABEL_MAPPING = {'fire': 1, 'nofire': 0}
    MULTI_CLASS_MAPPING = {
        'fire': {'Both_smoke_and_fire': 0, 'Smoke_from_fires': 1},
        'nofire': {'Fire_confounding_elements': 2, 'Forested_areas_without_confounding_elements': 3, 'Smoke_confounding_elements': 4},
    }
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

    def __init__(self, root_dir, transform=None, extensions=IMG_EXTENSIONS):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.binary_labels = []
        self.multi_class_labels = []

        suffixes = None if extensions is None else tuple(ext.lower() for ext in extensions)

        for binary_label_name in sorted(os.listdir(root_dir)):
            binary_label_path = os.path.join(root_dir, binary_label_name)
            if not os.path.isdir(binary_label_path):
                continue
            binary_label = self.BINARY_LABEL_MAPPING[binary_label_name]
            subclass_mapping = self.MULTI_CLASS_MAPPING[binary_label_name]

            for subclass_name in sorted(os.listdir(binary_label_path)):
                subclass_path = os.path.join(binary_label_path, subclass_name)
                if not os.path.isdir(subclass_path):
                    continue
                multi_class_label = subclass_mapping[subclass_name]

                for img_name in sorted(os.listdir(subclass_path)):
                    img_path = os.path.join(subclass_path, img_name)
                    if not os.path.isfile(img_path):
                        continue
                    if suffixes is not None and not img_name.lower().endswith(suffixes):
                        continue
                    self.image_paths.append(img_path)
                    self.binary_labels.append(binary_label)
                    self.multi_class_labels.append(multi_class_label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # The context manager closes the file handle promptly; convert() returns a
        # separate, fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        binary_label = self.binary_labels[idx]
        multi_class_label = self.multi_class_labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, (torch.tensor(binary_label, dtype=torch.float), torch.tensor(multi_class_label, dtype=torch.long))


In [None]:
# Training adds random horizontal flips; both pipelines resize to 224x224 and use ImageNet normalization.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Build dataset paths from the kagglehub download location instead of a hard-coded absolute path.
dataset_root = os.path.join(path, "the_wildfire_dataset", "the_wildfire_dataset")
train_dataset = CustomFireDataset(root_dir=os.path.join(dataset_root, "train"), transform=train_transforms)
val_dataset = CustomFireDataset(root_dir=os.path.join(dataset_root, "val"), transform=val_transforms)

# Cap the original 10 workers at the CPU count; more workers than cores only adds overhead.
num_workers = min(10, os.cpu_count() or 1)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=num_workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=num_workers)

In [None]:
# Define loss functions.
# BCEWithLogitsLoss fuses the sigmoid into the loss, which is more numerically stable
# than a separate Sigmoid followed by BCELoss.
criterion_binary = nn.BCEWithLogitsLoss()
criterion_multiclass = nn.CrossEntropyLoss()

# head_binary and head_multiclass were assigned as sub-modules, so model.parameters()
# already yields their weights. Listing them again gave Adam duplicate parameters:
# PyTorch warns about this, and the duplicated tensors get stepped more than once per update.
optimizer = optim.Adam(model.parameters(), lr=0.001)

  return disable_fn(*args, **kwargs)


In [None]:
# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = model.to(device)

In [None]:
def combined_loss(binary_output, binary_target, multi_class_output, multi_class_target, alpha=0.35, beta=0.65):
    """Return ``alpha * binary_loss + beta * multi_class_loss`` for one batch.

    ``binary_output`` holds the binary head's logits; a trailing size-1 dimension,
    if present, is squeezed away before the BCE-with-logits criterion is applied.
    ``binary_target`` is cast to float. ``multi_class_output`` holds raw class
    logits, and ``multi_class_target`` holds class indices for the cross-entropy criterion.
    """
    binary_term = criterion_binary(binary_output.squeeze(-1), binary_target.float())
    multi_class_term = criterion_multiclass(multi_class_output, multi_class_target)
    return alpha * binary_term + beta * multi_class_term


In [None]:
import os

from tqdm.auto import tqdm  # previously used below without ever being imported

# Directory for per-epoch checkpoints.
checkpoint_dir = "model_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Per-epoch metric history, filled in below (these were referenced but never defined).
epoch_list = []
loss_list = []
train_binary_accuracy_list = []
train_multi_class_accuracy_list = []
val_binary_accuracy_list = []
val_multi_class_accuracy_list = []


def count_correct(binary_output, binary_labels, multi_class_output, multi_class_labels):
    """Return (correct binary predictions, correct multi-class predictions) for one batch.

    Binary logits are thresholded at sigmoid probability 0.5; the multi-class prediction
    is the arg-max logit. squeeze(-1) keeps a 1-sample batch one-dimensional.
    """
    predicted_binary = (torch.sigmoid(binary_output.squeeze(-1)) > 0.5).int()
    correct_binary = (predicted_binary == binary_labels.int()).sum().item()
    _, predicted_multi_class = torch.max(multi_class_output, 1)
    correct_multi_class = (predicted_multi_class == multi_class_labels).sum().item()
    return correct_binary, correct_multi_class


epochs = 50
for epoch in range(epochs):
    # ---- Training ----
    model.train()
    running_loss = 0.0
    correct_train_binary = 0
    correct_train_multi_class = 0
    total_train = 0

    for images, (binary_labels, multi_class_labels) in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training"):
        images = images.to(device)
        binary_labels = binary_labels.to(device)
        multi_class_labels = multi_class_labels.to(device)

        optimizer.zero_grad()
        binary_output, multi_class_output = model(images)
        loss = combined_loss(binary_output, binary_labels, multi_class_output, multi_class_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batch_binary, batch_multi = count_correct(binary_output, binary_labels, multi_class_output, multi_class_labels)
        correct_train_binary += batch_binary
        correct_train_multi_class += batch_multi
        total_train += binary_labels.size(0)

    # Mean of per-batch losses (the last, possibly smaller, batch counts as one batch).
    avg_loss = running_loss / len(train_loader)
    train_binary_accuracy = 100 * correct_train_binary / total_train
    train_multi_class_accuracy = 100 * correct_train_multi_class / total_train

    # ---- Validation ----
    model.eval()
    correct_val_binary = 0
    correct_val_multi_class = 0
    total_val = 0
    with torch.no_grad():
        for images, (binary_labels, multi_class_labels) in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{epochs} - Validation"):
            images = images.to(device)
            binary_labels = binary_labels.to(device)
            multi_class_labels = multi_class_labels.to(device)

            binary_output, multi_class_output = model(images)
            batch_binary, batch_multi = count_correct(binary_output, binary_labels, multi_class_output, multi_class_labels)
            correct_val_binary += batch_binary
            correct_val_multi_class += batch_multi
            total_val += binary_labels.size(0)

    val_binary_accuracy = 100 * correct_val_binary / total_val
    val_multi_class_accuracy = 100 * correct_val_multi_class / total_val

    # Store the metrics
    epoch_list.append(epoch + 1)
    loss_list.append(avg_loss)
    train_binary_accuracy_list.append(train_binary_accuracy)
    train_multi_class_accuracy_list.append(train_multi_class_accuracy)
    val_binary_accuracy_list.append(val_binary_accuracy)
    val_multi_class_accuracy_list.append(val_multi_class_accuracy)

    # Log the metrics
    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Training Binary Accuracy: {train_binary_accuracy:.2f}%")
    print(f"  Training Multi-Class Accuracy: {train_multi_class_accuracy:.2f}%")
    print(f"  Validation Binary Accuracy: {val_binary_accuracy:.2f}%")
    print(f"  Validation Multi-Class Accuracy: {val_multi_class_accuracy:.2f}%")

    # Save a checkpoint every epoch.
    # NOTE(review): each file bundles the full model plus Adam state. With the large
    # flattened-feature heads, these may be big; consider keeping only latest/best.
    checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch + 1}.pth")
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
        'train_binary_accuracy': train_binary_accuracy,
        'train_multi_class_accuracy': train_multi_class_accuracy,
        'val_binary_accuracy': val_binary_accuracy,
        'val_multi_class_accuracy': val_multi_class_accuracy
    }, checkpoint_path)

    print(f"Model checkpoint saved at {checkpoint_path}")


Epoch 1/50 - Training: 100%|██████████| 30/30 [01:31<00:00,  3.05s/it]
Epoch 1/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.81s/it]


Epoch 1/50
  Loss: 5.7504
  Training Binary Accuracy: 52.31%
  Training Multi-Class Accuracy: 28.46%
  Validation Binary Accuracy: 61.19%
  Validation Multi-Class Accuracy: 31.59%


Epoch 2/50 - Training: 100%|██████████| 30/30 [01:35<00:00,  3.17s/it]
Epoch 2/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.77s/it]


Epoch 2/50
  Loss: 1.2693
  Training Binary Accuracy: 59.78%
  Training Multi-Class Accuracy: 32.70%
  Validation Binary Accuracy: 61.19%
  Validation Multi-Class Accuracy: 27.61%


Epoch 3/50 - Training: 100%|██████████| 30/30 [01:27<00:00,  2.93s/it]
Epoch 3/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 3/50
  Loss: 1.1646
  Training Binary Accuracy: 65.82%
  Training Multi-Class Accuracy: 39.48%
  Validation Binary Accuracy: 64.43%
  Validation Multi-Class Accuracy: 48.76%


Epoch 4/50 - Training: 100%|██████████| 30/30 [01:29<00:00,  2.98s/it]
Epoch 4/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.80s/it]


Epoch 4/50
  Loss: 1.0335
  Training Binary Accuracy: 73.34%
  Training Multi-Class Accuracy: 45.89%
  Validation Binary Accuracy: 71.89%
  Validation Multi-Class Accuracy: 45.77%


Epoch 5/50 - Training: 100%|██████████| 30/30 [01:30<00:00,  3.03s/it]
Epoch 5/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.81s/it]


Epoch 5/50
  Loss: 0.9113
  Training Binary Accuracy: 80.60%
  Training Multi-Class Accuracy: 53.05%
  Validation Binary Accuracy: 80.60%
  Validation Multi-Class Accuracy: 61.69%


Epoch 6/50 - Training: 100%|██████████| 30/30 [01:32<00:00,  3.08s/it]
Epoch 6/50 - Validation: 100%|██████████| 7/7 [00:25<00:00,  3.67s/it]


Epoch 6/50
  Loss: 0.7942
  Training Binary Accuracy: 84.21%
  Training Multi-Class Accuracy: 59.57%
  Validation Binary Accuracy: 84.33%
  Validation Multi-Class Accuracy: 56.22%


Epoch 7/50 - Training: 100%|██████████| 30/30 [01:36<00:00,  3.22s/it]
Epoch 7/50 - Validation: 100%|██████████| 7/7 [00:29<00:00,  4.17s/it]


Epoch 7/50
  Loss: 0.6874
  Training Binary Accuracy: 88.61%
  Training Multi-Class Accuracy: 63.65%
  Validation Binary Accuracy: 83.58%
  Validation Multi-Class Accuracy: 65.67%


Epoch 8/50 - Training: 100%|██████████| 30/30 [01:30<00:00,  3.01s/it]
Epoch 8/50 - Validation: 100%|██████████| 7/7 [00:27<00:00,  3.88s/it]


Epoch 8/50
  Loss: 0.6038
  Training Binary Accuracy: 90.30%
  Training Multi-Class Accuracy: 69.10%
  Validation Binary Accuracy: 90.30%
  Validation Multi-Class Accuracy: 70.90%


Epoch 9/50 - Training: 100%|██████████| 30/30 [01:31<00:00,  3.04s/it]
Epoch 9/50 - Validation: 100%|██████████| 7/7 [00:29<00:00,  4.16s/it]


Epoch 9/50
  Loss: 0.5142
  Training Binary Accuracy: 91.89%
  Training Multi-Class Accuracy: 73.24%
  Validation Binary Accuracy: 86.82%
  Validation Multi-Class Accuracy: 69.15%


Epoch 10/50 - Training: 100%|██████████| 30/30 [01:33<00:00,  3.11s/it]
Epoch 10/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.73s/it]


Epoch 10/50
  Loss: 0.4868
  Training Binary Accuracy: 94.01%
  Training Multi-Class Accuracy: 75.89%
  Validation Binary Accuracy: 88.31%
  Validation Multi-Class Accuracy: 72.89%


Epoch 11/50 - Training: 100%|██████████| 30/30 [01:30<00:00,  3.01s/it]
Epoch 11/50 - Validation: 100%|██████████| 7/7 [00:28<00:00,  4.01s/it]


Epoch 11/50
  Loss: 0.4014
  Training Binary Accuracy: 94.49%
  Training Multi-Class Accuracy: 79.70%
  Validation Binary Accuracy: 91.04%
  Validation Multi-Class Accuracy: 73.13%


Epoch 12/50 - Training: 100%|██████████| 30/30 [01:29<00:00,  2.99s/it]
Epoch 12/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.85s/it]


Epoch 12/50
  Loss: 0.4083
  Training Binary Accuracy: 94.17%
  Training Multi-Class Accuracy: 79.44%
  Validation Binary Accuracy: 85.82%
  Validation Multi-Class Accuracy: 70.15%


Epoch 13/50 - Training: 100%|██████████| 30/30 [01:32<00:00,  3.08s/it]
Epoch 13/50 - Validation: 100%|██████████| 7/7 [00:28<00:00,  4.03s/it]


Epoch 13/50
  Loss: 0.3188
  Training Binary Accuracy: 96.18%
  Training Multi-Class Accuracy: 84.37%
  Validation Binary Accuracy: 89.80%
  Validation Multi-Class Accuracy: 73.88%


Epoch 14/50 - Training: 100%|██████████| 30/30 [01:31<00:00,  3.04s/it]
Epoch 14/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.81s/it]


Epoch 14/50
  Loss: 0.2184
  Training Binary Accuracy: 98.25%
  Training Multi-Class Accuracy: 89.30%
  Validation Binary Accuracy: 91.04%
  Validation Multi-Class Accuracy: 71.14%


Epoch 15/50 - Training: 100%|██████████| 30/30 [01:33<00:00,  3.10s/it]
Epoch 15/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 15/50
  Loss: 0.1984
  Training Binary Accuracy: 98.36%
  Training Multi-Class Accuracy: 89.72%
  Validation Binary Accuracy: 89.55%
  Validation Multi-Class Accuracy: 72.89%


Epoch 16/50 - Training: 100%|██████████| 30/30 [01:33<00:00,  3.13s/it]
Epoch 16/50 - Validation: 100%|██████████| 7/7 [00:27<00:00,  3.88s/it]


Epoch 16/50
  Loss: 0.1370
  Training Binary Accuracy: 98.83%
  Training Multi-Class Accuracy: 92.90%
  Validation Binary Accuracy: 88.56%
  Validation Multi-Class Accuracy: 73.88%


Epoch 17/50 - Training: 100%|██████████| 30/30 [01:31<00:00,  3.04s/it]
Epoch 17/50 - Validation: 100%|██████████| 7/7 [00:29<00:00,  4.18s/it]


Epoch 17/50
  Loss: 0.1955
  Training Binary Accuracy: 97.77%
  Training Multi-Class Accuracy: 91.26%
  Validation Binary Accuracy: 90.55%
  Validation Multi-Class Accuracy: 72.39%


Epoch 18/50 - Training: 100%|██████████| 30/30 [01:32<00:00,  3.08s/it]
Epoch 18/50 - Validation: 100%|██████████| 7/7 [00:28<00:00,  4.08s/it]


Epoch 18/50
  Loss: 0.1182
  Training Binary Accuracy: 98.73%
  Training Multi-Class Accuracy: 94.17%
  Validation Binary Accuracy: 89.30%
  Validation Multi-Class Accuracy: 76.87%


Epoch 19/50 - Training: 100%|██████████| 30/30 [01:30<00:00,  3.00s/it]
Epoch 19/50 - Validation: 100%|██████████| 7/7 [00:27<00:00,  4.00s/it]


Epoch 19/50
  Loss: 0.0953
  Training Binary Accuracy: 99.42%
  Training Multi-Class Accuracy: 95.18%
  Validation Binary Accuracy: 89.05%
  Validation Multi-Class Accuracy: 74.38%


Epoch 20/50 - Training: 100%|██████████| 30/30 [01:28<00:00,  2.94s/it]
Epoch 20/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.80s/it]


Epoch 20/50
  Loss: 0.1511
  Training Binary Accuracy: 97.51%
  Training Multi-Class Accuracy: 93.64%
  Validation Binary Accuracy: 89.30%
  Validation Multi-Class Accuracy: 76.37%


Epoch 21/50 - Training: 100%|██████████| 30/30 [01:30<00:00,  3.01s/it]
Epoch 21/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.81s/it]


Epoch 21/50
  Loss: 0.1821
  Training Binary Accuracy: 98.57%
  Training Multi-Class Accuracy: 92.21%
  Validation Binary Accuracy: 91.54%
  Validation Multi-Class Accuracy: 76.12%


Epoch 22/50 - Training: 100%|██████████| 30/30 [01:30<00:00,  3.02s/it]
Epoch 22/50 - Validation: 100%|██████████| 7/7 [00:27<00:00,  3.93s/it]


Epoch 22/50
  Loss: 0.1048
  Training Binary Accuracy: 98.41%
  Training Multi-Class Accuracy: 95.44%
  Validation Binary Accuracy: 92.04%
  Validation Multi-Class Accuracy: 78.86%


Epoch 23/50 - Training: 100%|██████████| 30/30 [01:31<00:00,  3.05s/it]
Epoch 23/50 - Validation: 100%|██████████| 7/7 [00:27<00:00,  3.87s/it]


Epoch 23/50
  Loss: 0.0838
  Training Binary Accuracy: 99.52%
  Training Multi-Class Accuracy: 96.66%
  Validation Binary Accuracy: 90.05%
  Validation Multi-Class Accuracy: 77.11%


Epoch 24/50 - Training: 100%|██████████| 30/30 [01:30<00:00,  3.01s/it]
Epoch 24/50 - Validation: 100%|██████████| 7/7 [00:27<00:00,  3.92s/it]


Epoch 24/50
  Loss: 0.0574
  Training Binary Accuracy: 99.31%
  Training Multi-Class Accuracy: 97.40%
  Validation Binary Accuracy: 91.54%
  Validation Multi-Class Accuracy: 78.61%


Epoch 25/50 - Training: 100%|██████████| 30/30 [01:35<00:00,  3.20s/it]
Epoch 25/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 25/50
  Loss: 0.1576
  Training Binary Accuracy: 98.41%
  Training Multi-Class Accuracy: 94.70%
  Validation Binary Accuracy: 89.30%
  Validation Multi-Class Accuracy: 75.37%


Epoch 26/50 - Training: 100%|██████████| 30/30 [01:29<00:00,  2.98s/it]
Epoch 26/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 26/50
  Loss: 0.1110
  Training Binary Accuracy: 99.10%
  Training Multi-Class Accuracy: 95.34%
  Validation Binary Accuracy: 91.29%
  Validation Multi-Class Accuracy: 77.36%


Epoch 27/50 - Training: 100%|██████████| 30/30 [01:29<00:00,  2.97s/it]
Epoch 27/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 27/50
  Loss: 0.1092
  Training Binary Accuracy: 99.31%
  Training Multi-Class Accuracy: 95.92%
  Validation Binary Accuracy: 90.05%
  Validation Multi-Class Accuracy: 76.87%


Epoch 28/50 - Training: 100%|██████████| 30/30 [01:35<00:00,  3.18s/it]
Epoch 28/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 28/50
  Loss: 0.0636
  Training Binary Accuracy: 99.05%
  Training Multi-Class Accuracy: 97.09%
  Validation Binary Accuracy: 88.06%
  Validation Multi-Class Accuracy: 73.63%


Epoch 29/50 - Training: 100%|██████████| 30/30 [01:27<00:00,  2.92s/it]
Epoch 29/50 - Validation: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 29/50
  Loss: 0.0643
  Training Binary Accuracy: 99.10%
  Training Multi-Class Accuracy: 97.14%
  Validation Binary Accuracy: 88.31%
  Validation Multi-Class Accuracy: 76.12%


Epoch 30/50 - Training: 100%|██████████| 30/30 [01:29<00:00,  2.99s/it]
Epoch 30/50 - Validation: 100%|██████████| 7/7 [00:28<00:00,  4.07s/it]


Epoch 30/50
  Loss: 0.0686
  Training Binary Accuracy: 99.42%
  Training Multi-Class Accuracy: 97.62%
  Validation Binary Accuracy: 90.30%
  Validation Multi-Class Accuracy: 79.35%


Epoch 31/50 - Training:   0%|          | 0/30 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Persist only the weights. The custom heads are registered sub-modules, so they are included.
weights_file = "edgenext_aug2.pth"
torch.save(model.state_dict(), weights_file)
print("Model saved successfully!")

Model saved successfully!


In [None]:
# Export the per-epoch history recorded during training as one CSV row per epoch.
column_names = [
    'Epoch',
    'Loss',
    'Train Binary Accuracy',
    'Train Multi-Class Accuracy',
    'Validation Binary Accuracy',
    'Validation Multi-Class Accuracy',
]
column_values = [
    epoch_list,
    loss_list,
    train_binary_accuracy_list,
    train_multi_class_accuracy_list,
    val_binary_accuracy_list,
    val_multi_class_accuracy_list,
]
data = dict(zip(column_names, column_values))

df = pd.DataFrame(data)
df.to_csv('edgenext_aug2_training_results.csv', index=False)
print("Training results saved successfully!")

Training results saved successfully!


## Test the Model

In [None]:
# Optionally restore the weights written by the "save the model" cell instead of retraining.
# (The old commented-out line pointed at an unrelated efficientnetv2 checkpoint.)
LOAD_SAVED_WEIGHTS = False
if LOAD_SAVED_WEIGHTS:
    # weights_only=True restricts unpickling to tensors/containers, which is enough for a state_dict.
    model.load_state_dict(torch.load("edgenext_aug2.pth", map_location=device, weights_only=True))

In [None]:
# Test-time preprocessing: deterministic resize + ImageNet normalization (no augmentation).
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Build the test split path from the kagglehub download location instead of a hard-coded absolute path.
test_root = os.path.join(path, "the_wildfire_dataset", "the_wildfire_dataset", "test")
test_dataset = CustomFireDataset(root_dir=test_root, transform=test_transforms)

# Cap the original 10 workers at the CPU count; more workers than cores only adds overhead.
test_num_workers = min(10, os.cpu_count() or 1)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=test_num_workers)

In [None]:
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
import pandas as pd
import numpy as np
import torch

# Per-sample ground truth and predictions over the whole test set.
true_binary_labels = []
predicted_binary_probs = []  # sigmoid of the binary logit, i.e. probability of the positive class (fire)

true_multi_class_labels = []
predicted_multi_class_labels = []

results_path = "edgenext_aug2_test_results.csv"

model.eval()
with torch.no_grad():
    for images, (binary_labels, multi_class_labels) in test_loader:
        # Only the images need the device; the labels are just recorded, so they stay on the CPU.
        images = images.to(device)

        binary_output, multi_class_output = model(images)

        # squeeze(-1) rather than squeeze(): a final batch with a single image would otherwise
        # become a 0-d array, and list.extend cannot iterate over a 0-d array.
        binary_probs = torch.sigmoid(binary_output).squeeze(-1).cpu().numpy()

        # Predicted subclass = index of the largest logit.
        _, predicted_multi_class = torch.max(multi_class_output, 1)

        true_binary_labels.extend(binary_labels.numpy())
        predicted_binary_probs.extend(binary_probs)
        true_multi_class_labels.extend(multi_class_labels.numpy())
        predicted_multi_class_labels.extend(predicted_multi_class.cpu().numpy())

# One row per test image.
results = pd.DataFrame({
    "True Binary Labels": true_binary_labels,
    "Predicted Binary Probabilities": predicted_binary_probs,
    "True Multi-Class Labels": true_multi_class_labels,
    "Predicted Multi-Class Labels": predicted_multi_class_labels
})

# The message previously named a different file ('edgenext_test_results.csv') than the one written.
results.to_csv(results_path, index=False)
print(f"Test results saved to '{results_path}'.")



Test results saved to 'edgenext_test_results.csv'.
