# EfficientNetV2 - Large (note: the model actually loaded below is the Small variant, `efficientnet_v2_s`)

In [None]:
import kagglehub

# Download version 1 of the wildfire dataset via kagglehub.
# The handle explicitly pins `/versions/1`, so this is NOT necessarily the
# latest version; the hard-coded dataset paths in later cells assume version 1.
path = kagglehub.dataset_download("elmadafri/the-wildfire-dataset/versions/1")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/elmadafri/the-wildfire-dataset?dataset_version_number=1...


100%|██████████| 9.94G/9.94G [01:26<00:00, 123MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/elmadafri/the-wildfire-dataset/versions/1


In [None]:
# Delete a stray Windows `desktop.ini` from val/fire/Both_smoke_and_fire.
# CustomFireDataset collects every regular file inside a sub-class folder and
# opens it with PIL, so this non-image file would make __getitem__ fail.
!rm /root/.cache/kagglehub/datasets/elmadafri/the-wildfire-dataset/versions/1/the_wildfire_dataset/the_wildfire_dataset/val/fire/Both_smoke_and_fire/desktop.ini

In [None]:
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms

## Load Model

In [None]:
# Backbone: torchvision EfficientNetV2 *Small* (efficientnet_v2_s), initialised
# with its default pretrained weights -- the same enum member the string
# "DEFAULT" resolves to.
model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.DEFAULT)

Downloading: "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_v2_s-dd5fe13b.pth
100%|██████████| 82.7M/82.7M [00:00<00:00, 153MB/s]


In [None]:
# Width of the feature vector the backbone hands to the heads; must match
# the backbone (1280 for torchvision's efficientnet_v2_s).
in_features = 1280
hidden_units = 512

# Drop torchvision's stock classifier so `model(images)` returns the backbone
# features that both task heads consume.
model.classifier = nn.Identity()

# Fire / no-fire head. It ends in Sigmoid, i.e. it outputs probabilities,
# which is why the notebook scores it with nn.BCELoss.
# NOTE(review): emitting raw logits and using nn.BCEWithLogitsLoss is the
# numerically safer pairing; changing it requires updating the loss,
# the 0.5 thresholds and the test cell together.
binary_head = nn.Sequential(
    nn.Linear(in_features, hidden_units),
    nn.LeakyReLU(),
    nn.Dropout(0.2),
    nn.Linear(hidden_units, 1),
    nn.Sigmoid(),
)

# Five-way sub-class head; emits raw logits (nn.CrossEntropyLoss expects logits).
multi_class_head = nn.Sequential(
    nn.Linear(in_features, hidden_units),
    nn.LeakyReLU(),
    nn.Linear(hidden_units, 5),
)

### Change Notes:
Added new hidden layers between the backbone features and each output head, for better generalizability and to prevent the plateauing seen in earlier runs.

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset

class CustomFireDataset(Dataset):
    """Wildfire image dataset yielding a binary label and a 5-way sub-class label.

    Expected layout::

        root_dir/
            fire/    Both_smoke_and_fire/, Smoke_from_fires/
            nofire/  Fire_confounding_elements/,
                     Forested_areas_without_confounding_elements/,
                     Smoke_confounding_elements/

    Each item is ``(image, (binary_label, multi_class_label))``:
    ``binary_label`` is a float tensor (1.0 = fire, 0.0 = nofire) and
    ``multi_class_label`` is a long tensor in ``0..4``.

    Args:
        root_dir: Split directory (e.g. ``.../train``) holding ``fire``/``nofire``.
        transform: Optional callable applied to the RGB PIL image.
        extensions: Iterable of file suffixes treated as images (matched
            case-insensitively). Defaults to ``IMG_EXTENSIONS``. Other files,
            e.g. a stray ``desktop.ini``, are skipped instead of crashing
            ``__getitem__``.

    Raises:
        KeyError: if ``root_dir`` contains a (non-hidden) top-level or
            sub-class folder that has no label mapping.
    """

    # Default suffixes accepted as image files.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
                      '.tif', '.tiff', '.webp', '.gif')

    def __init__(self, root_dir, transform=None, extensions=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.binary_labels = []
        self.multi_class_labels = []

        # str.endswith needs a tuple; lower-case for case-insensitive matching.
        if extensions is None:
            extensions = self.IMG_EXTENSIONS
        extensions = tuple(ext.lower() for ext in extensions)

        # Define mappings for binary and multi-class labels
        binary_label_mapping = {'fire': 1, 'nofire': 0}
        multi_class_mapping = {
            'fire': {'Both_smoke_and_fire': 0, 'Smoke_from_fires': 1},
            'nofire': {'Fire_confounding_elements': 2, 'Forested_areas_without_confounding_elements': 3, 'Smoke_confounding_elements': 4}
        }

        # os.listdir order is arbitrary; sorting makes the sample order (and
        # therefore any unshuffled evaluation output) reproducible.
        for binary_label_name in sorted(os.listdir(root_dir)):
            binary_label_path = os.path.join(root_dir, binary_label_name)
            # Skip plain files and hidden folders such as .ipynb_checkpoints.
            if binary_label_name.startswith('.') or not os.path.isdir(binary_label_path):
                continue
            if binary_label_name not in binary_label_mapping:
                raise KeyError(
                    f"Unexpected class folder {binary_label_path!r}; "
                    f"expected one of {sorted(binary_label_mapping)}"
                )
            binary_label = binary_label_mapping[binary_label_name]
            subclass_mapping = multi_class_mapping[binary_label_name]

            for subclass_name in sorted(os.listdir(binary_label_path)):
                subclass_path = os.path.join(binary_label_path, subclass_name)
                if subclass_name.startswith('.') or not os.path.isdir(subclass_path):
                    continue
                if subclass_name not in subclass_mapping:
                    raise KeyError(
                        f"Unexpected sub-class folder {subclass_path!r}; "
                        f"expected one of {sorted(subclass_mapping)}"
                    )
                multi_class_label = subclass_mapping[subclass_name]

                # Collect only image files; hidden files (e.g. macOS "._x.jpg"
                # metadata) and non-images would fail to open later.
                for img_name in sorted(os.listdir(subclass_path)):
                    img_path = os.path.join(subclass_path, img_name)
                    if (
                        not img_name.startswith('.')
                        and img_name.lower().endswith(extensions)
                        and os.path.isfile(img_path)
                    ):
                        self.image_paths.append(img_path)
                        self.binary_labels.append(binary_label)
                        self.multi_class_labels.append(multi_class_label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # The context manager closes the file handle deterministically;
        # convert() returns a new, fully loaded image that outlives it.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        binary_label = self.binary_labels[idx]
        multi_class_label = self.multi_class_labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, (torch.tensor(binary_label, dtype=torch.float), torch.tensor(multi_class_label, dtype=torch.long))


In [None]:
import os

# ImageNet channel statistics used to normalise inputs for the pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: fixed 224x224 resize plus random horizontal flips
# (the only augmentation applied).
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation pipeline: same resize and normalisation, no randomness.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Root of the split folders extracted by kagglehub (dataset version 1).
DATA_ROOT = "/root/.cache/kagglehub/datasets/elmadafri/the-wildfire-dataset/versions/1/the_wildfire_dataset/the_wildfire_dataset"
train_dataset = CustomFireDataset(root_dir=os.path.join(DATA_ROOT, "train"), transform=train_transforms)
val_dataset = CustomFireDataset(root_dir=os.path.join(DATA_ROOT, "val"), transform=val_transforms)

# More DataLoader workers than CPUs only adds overhead (PyTorch warns about
# it), so cap the original 10 workers at the machine's CPU count.
NUM_WORKERS = min(10, os.cpu_count() or 1)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# binary_head already ends in Sigmoid, so it is scored with plain BCE on probabilities.
criterion_binary = nn.BCELoss()
# multi_class_head emits raw logits; CrossEntropyLoss applies log-softmax itself.
criterion_multi_class = nn.CrossEntropyLoss()

# A single Adam optimizer updates the backbone and both heads jointly.
trainable_params = [
    *model.parameters(),
    *binary_head.parameters(),
    *multi_class_head.parameters(),
]
optimizer = torch.optim.Adam(trainable_params, lr=0.002)

In [None]:
# Prefer the first CUDA GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Move the backbone and both heads onto the chosen device.
model, binary_head, multi_class_head = (
    module.to(device) for module in (model, binary_head, multi_class_head)
)

In [None]:
def combined_loss(binary_output, binary_target, multi_class_output, multi_class_target, alpha=0.5, beta=0.5):
    """Return the weighted multi-task loss ``alpha * BCE + beta * CrossEntropy``.

    Args:
        binary_output: Probabilities from ``binary_head`` (which ends in
            Sigmoid). Any shape with as many elements as ``binary_target`` is
            accepted, e.g. ``(B,)``, ``(B, 1)``, or the 0-d tensor a bare
            ``.squeeze()`` produces for a one-sample batch.
        binary_target: Float targets in {0, 1}.
        multi_class_output: Raw logits from ``multi_class_head``.
        multi_class_target: Long class indices.
        alpha: Weight of the binary loss term.
        beta: Weight of the multi-class loss term.

    Returns:
        The combined loss tensor (both criteria use their default 'mean' reduction).
    """
    # BCELoss requires input and target to have identical sizes; align them
    # here so a one-sample batch squeezed to 0-d does not raise.
    loss_binary = criterion_binary(binary_output.reshape_as(binary_target), binary_target)
    loss_multi_class = criterion_multi_class(multi_class_output, multi_class_target)
    return alpha * loss_binary + beta * loss_multi_class

In [None]:
epoch_list = []
loss_list = []
train_binary_accuracy_list = []
train_multi_class_accuracy_list = []
val_binary_accuracy_list = []
val_multi_class_accuracy_list = []

# Training loop
epochs = 50  # You can adjust the number of epochs based on your needs
for epoch in range(epochs):
    # binary_head and multi_class_head are standalone modules, not children of
    # `model`, so model.train()/model.eval() never reaches them. Toggle all
    # three explicitly; otherwise binary_head's Dropout stays active during
    # validation and makes the validation metrics noisy.
    model.train()
    binary_head.train()
    multi_class_head.train()
    running_loss = 0.0
    correct_train_binary = 0
    total_train_binary = 0
    correct_train_multi_class = 0
    total_train_multi_class = 0

    # Training pass
    for images, (binary_labels, multi_class_labels) in train_loader:
        images = images.to(device)
        binary_labels = binary_labels.to(device).float()
        multi_class_labels = multi_class_labels.to(device).long()

        optimizer.zero_grad()
        features = model(images)
        # (B, 1) -> (B,). squeeze(1) keeps the batch axis even when B == 1,
        # whereas a bare squeeze() would yield a 0-d tensor and break BCELoss.
        binary_output = binary_head(features).squeeze(1)
        multi_class_output = multi_class_head(features)

        # Compute the combined loss
        loss = combined_loss(binary_output, binary_labels, multi_class_output, multi_class_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Binary accuracy: threshold the Sigmoid probabilities at 0.5
        predicted_binary = (binary_output > 0.5).int()
        correct_train_binary += (predicted_binary == binary_labels.int()).sum().item()
        total_train_binary += binary_labels.size(0)

        # Multi-class accuracy: arg-max over the logits
        _, predicted_multi_class = torch.max(multi_class_output, 1)
        correct_train_multi_class += (predicted_multi_class == multi_class_labels).sum().item()
        total_train_multi_class += multi_class_labels.size(0)

    # Calculate the average loss and training accuracies
    avg_loss = running_loss / len(train_loader)
    train_binary_accuracy = 100 * correct_train_binary / total_train_binary
    train_multi_class_accuracy = 100 * correct_train_multi_class / total_train_multi_class

    # Validation pass: backbone AND heads in eval mode (disables head Dropout)
    model.eval()
    binary_head.eval()
    multi_class_head.eval()
    correct_val_binary = 0
    total_val_binary = 0
    correct_val_multi_class = 0
    total_val_multi_class = 0
    with torch.no_grad():
        for images, (binary_labels, multi_class_labels) in val_loader:
            images = images.to(device)
            binary_labels = binary_labels.to(device).float()
            multi_class_labels = multi_class_labels.to(device).long()

            features = model(images)
            binary_output = binary_head(features).squeeze(1)
            multi_class_output = multi_class_head(features)

            # Calculate validation accuracy for binary classification
            predicted_binary = (binary_output > 0.5).int()
            correct_val_binary += (predicted_binary == binary_labels.int()).sum().item()
            total_val_binary += binary_labels.size(0)

            # Calculate validation accuracy for multi-class classification
            _, predicted_multi_class = torch.max(multi_class_output, 1)
            correct_val_multi_class += (predicted_multi_class == multi_class_labels).sum().item()
            total_val_multi_class += multi_class_labels.size(0)

    # Calculate validation accuracies
    val_binary_accuracy = 100 * correct_val_binary / total_val_binary
    val_multi_class_accuracy = 100 * correct_val_multi_class / total_val_multi_class

    # Store the metrics
    epoch_list.append(epoch + 1)
    loss_list.append(avg_loss)
    train_binary_accuracy_list.append(train_binary_accuracy)
    train_multi_class_accuracy_list.append(train_multi_class_accuracy)
    val_binary_accuracy_list.append(val_binary_accuracy)
    val_multi_class_accuracy_list.append(val_multi_class_accuracy)

    # Log the metrics
    print(f"Epoch {epoch+1}/{epochs}")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Training Binary Accuracy: {train_binary_accuracy:.2f}%")
    print(f"  Training Multi-Class Accuracy: {train_multi_class_accuracy:.2f}%")
    print(f"  Validation Binary Accuracy: {val_binary_accuracy:.2f}%")
    print(f"  Validation Multi-Class Accuracy: {val_multi_class_accuracy:.2f}%")



Epoch 1/50
  Loss: 0.8030
  Training Binary Accuracy: 79.23%
  Training Multi-Class Accuracy: 55.33%
  Validation Binary Accuracy: 79.60%
  Validation Multi-Class Accuracy: 62.69%




Epoch 2/50
  Loss: 0.6600
  Training Binary Accuracy: 86.06%
  Training Multi-Class Accuracy: 62.43%
  Validation Binary Accuracy: 87.81%
  Validation Multi-Class Accuracy: 68.91%




Epoch 3/50
  Loss: 0.5892
  Training Binary Accuracy: 88.18%
  Training Multi-Class Accuracy: 65.61%
  Validation Binary Accuracy: 83.08%
  Validation Multi-Class Accuracy: 60.45%




Epoch 4/50
  Loss: 0.5398
  Training Binary Accuracy: 89.45%
  Training Multi-Class Accuracy: 68.52%
  Validation Binary Accuracy: 84.33%
  Validation Multi-Class Accuracy: 68.16%




Epoch 5/50
  Loss: 0.5125
  Training Binary Accuracy: 90.25%
  Training Multi-Class Accuracy: 70.16%
  Validation Binary Accuracy: 86.07%
  Validation Multi-Class Accuracy: 70.90%




Epoch 6/50
  Loss: 0.4576
  Training Binary Accuracy: 91.41%
  Training Multi-Class Accuracy: 72.66%
  Validation Binary Accuracy: 87.31%
  Validation Multi-Class Accuracy: 74.88%




Epoch 7/50
  Loss: 0.4370
  Training Binary Accuracy: 91.73%
  Training Multi-Class Accuracy: 74.40%
  Validation Binary Accuracy: 87.06%
  Validation Multi-Class Accuracy: 74.38%




Epoch 8/50
  Loss: 0.4234
  Training Binary Accuracy: 93.22%
  Training Multi-Class Accuracy: 75.78%
  Validation Binary Accuracy: 89.30%
  Validation Multi-Class Accuracy: 72.64%




Epoch 9/50
  Loss: 0.3669
  Training Binary Accuracy: 93.85%
  Training Multi-Class Accuracy: 79.23%
  Validation Binary Accuracy: 88.56%
  Validation Multi-Class Accuracy: 73.38%




Epoch 10/50
  Loss: 0.3744
  Training Binary Accuracy: 94.22%
  Training Multi-Class Accuracy: 78.38%
  Validation Binary Accuracy: 87.56%
  Validation Multi-Class Accuracy: 72.14%




Epoch 11/50
  Loss: 0.3351
  Training Binary Accuracy: 94.70%
  Training Multi-Class Accuracy: 80.29%
  Validation Binary Accuracy: 88.81%
  Validation Multi-Class Accuracy: 73.38%




Epoch 12/50
  Loss: 0.3217
  Training Binary Accuracy: 94.86%
  Training Multi-Class Accuracy: 80.71%
  Validation Binary Accuracy: 87.31%
  Validation Multi-Class Accuracy: 74.63%




Epoch 13/50
  Loss: 0.3023
  Training Binary Accuracy: 95.39%
  Training Multi-Class Accuracy: 82.14%
  Validation Binary Accuracy: 89.80%
  Validation Multi-Class Accuracy: 75.12%




Epoch 14/50
  Loss: 0.2554
  Training Binary Accuracy: 96.93%
  Training Multi-Class Accuracy: 84.47%
  Validation Binary Accuracy: 89.05%
  Validation Multi-Class Accuracy: 75.87%




Epoch 15/50
  Loss: 0.3054
  Training Binary Accuracy: 95.28%
  Training Multi-Class Accuracy: 81.61%
  Validation Binary Accuracy: 86.07%
  Validation Multi-Class Accuracy: 74.13%




Epoch 16/50
  Loss: 0.2884
  Training Binary Accuracy: 95.97%
  Training Multi-Class Accuracy: 84.68%
  Validation Binary Accuracy: 87.81%
  Validation Multi-Class Accuracy: 74.13%




Epoch 17/50
  Loss: 0.2457
  Training Binary Accuracy: 96.61%
  Training Multi-Class Accuracy: 85.43%
  Validation Binary Accuracy: 90.30%
  Validation Multi-Class Accuracy: 78.11%




Epoch 18/50
  Loss: 0.2380
  Training Binary Accuracy: 96.82%
  Training Multi-Class Accuracy: 86.75%
  Validation Binary Accuracy: 92.29%
  Validation Multi-Class Accuracy: 76.87%




Epoch 19/50
  Loss: 0.2015
  Training Binary Accuracy: 96.82%
  Training Multi-Class Accuracy: 87.81%
  Validation Binary Accuracy: 90.05%
  Validation Multi-Class Accuracy: 73.88%




Epoch 20/50
  Loss: 0.1774
  Training Binary Accuracy: 98.25%
  Training Multi-Class Accuracy: 88.87%
  Validation Binary Accuracy: 88.06%
  Validation Multi-Class Accuracy: 73.88%




Epoch 21/50
  Loss: 0.1547
  Training Binary Accuracy: 98.25%
  Training Multi-Class Accuracy: 90.94%
  Validation Binary Accuracy: 88.31%
  Validation Multi-Class Accuracy: 75.62%




Epoch 22/50
  Loss: 0.2201
  Training Binary Accuracy: 96.61%
  Training Multi-Class Accuracy: 87.44%
  Validation Binary Accuracy: 88.06%
  Validation Multi-Class Accuracy: 74.13%




Epoch 23/50
  Loss: 0.1920
  Training Binary Accuracy: 96.98%
  Training Multi-Class Accuracy: 88.82%
  Validation Binary Accuracy: 91.04%
  Validation Multi-Class Accuracy: 75.37%




Epoch 24/50
  Loss: 0.1360
  Training Binary Accuracy: 98.15%
  Training Multi-Class Accuracy: 92.00%
  Validation Binary Accuracy: 91.04%
  Validation Multi-Class Accuracy: 78.11%




Epoch 25/50
  Loss: 0.1649
  Training Binary Accuracy: 97.99%
  Training Multi-Class Accuracy: 90.25%
  Validation Binary Accuracy: 86.82%
  Validation Multi-Class Accuracy: 72.89%




Epoch 26/50
  Loss: 0.1757
  Training Binary Accuracy: 96.77%
  Training Multi-Class Accuracy: 89.88%
  Validation Binary Accuracy: 86.32%
  Validation Multi-Class Accuracy: 74.38%




Epoch 27/50
  Loss: 0.1631
  Training Binary Accuracy: 97.83%
  Training Multi-Class Accuracy: 90.78%
  Validation Binary Accuracy: 90.55%
  Validation Multi-Class Accuracy: 71.39%




Epoch 28/50
  Loss: 0.1238
  Training Binary Accuracy: 98.36%
  Training Multi-Class Accuracy: 93.48%
  Validation Binary Accuracy: 89.05%
  Validation Multi-Class Accuracy: 76.12%




Epoch 29/50
  Loss: 0.1470
  Training Binary Accuracy: 97.83%
  Training Multi-Class Accuracy: 92.16%
  Validation Binary Accuracy: 88.56%
  Validation Multi-Class Accuracy: 74.63%




Epoch 30/50
  Loss: 0.1131
  Training Binary Accuracy: 99.05%
  Training Multi-Class Accuracy: 93.38%
  Validation Binary Accuracy: 88.31%
  Validation Multi-Class Accuracy: 76.37%




Epoch 31/50
  Loss: 0.1495
  Training Binary Accuracy: 97.51%
  Training Multi-Class Accuracy: 92.42%
  Validation Binary Accuracy: 90.05%
  Validation Multi-Class Accuracy: 76.12%




Epoch 32/50
  Loss: 0.0872
  Training Binary Accuracy: 99.26%
  Training Multi-Class Accuracy: 94.91%
  Validation Binary Accuracy: 89.55%
  Validation Multi-Class Accuracy: 76.62%




Epoch 33/50
  Loss: 0.0650
  Training Binary Accuracy: 99.42%
  Training Multi-Class Accuracy: 95.81%
  Validation Binary Accuracy: 87.56%
  Validation Multi-Class Accuracy: 72.39%




Epoch 34/50
  Loss: 0.0974
  Training Binary Accuracy: 98.62%
  Training Multi-Class Accuracy: 94.06%
  Validation Binary Accuracy: 88.81%
  Validation Multi-Class Accuracy: 73.38%




Epoch 35/50
  Loss: 0.1296
  Training Binary Accuracy: 98.20%
  Training Multi-Class Accuracy: 92.63%
  Validation Binary Accuracy: 91.04%
  Validation Multi-Class Accuracy: 80.10%




Epoch 36/50
  Loss: 0.1027
  Training Binary Accuracy: 98.62%
  Training Multi-Class Accuracy: 94.75%
  Validation Binary Accuracy: 84.08%
  Validation Multi-Class Accuracy: 71.64%




Epoch 37/50
  Loss: 0.0799
  Training Binary Accuracy: 98.83%
  Training Multi-Class Accuracy: 95.65%
  Validation Binary Accuracy: 88.31%
  Validation Multi-Class Accuracy: 71.39%




Epoch 38/50
  Loss: 0.0787
  Training Binary Accuracy: 99.05%
  Training Multi-Class Accuracy: 95.39%
  Validation Binary Accuracy: 87.81%
  Validation Multi-Class Accuracy: 75.12%




Epoch 39/50
  Loss: 0.0792
  Training Binary Accuracy: 99.36%
  Training Multi-Class Accuracy: 95.18%
  Validation Binary Accuracy: 89.80%
  Validation Multi-Class Accuracy: 75.37%




Epoch 40/50
  Loss: 0.0789
  Training Binary Accuracy: 98.78%
  Training Multi-Class Accuracy: 95.97%
  Validation Binary Accuracy: 87.56%
  Validation Multi-Class Accuracy: 71.39%




Epoch 41/50
  Loss: 0.1057
  Training Binary Accuracy: 98.25%
  Training Multi-Class Accuracy: 94.59%
  Validation Binary Accuracy: 91.04%
  Validation Multi-Class Accuracy: 76.87%




Epoch 42/50
  Loss: 0.0726
  Training Binary Accuracy: 98.68%
  Training Multi-Class Accuracy: 96.18%
  Validation Binary Accuracy: 92.54%
  Validation Multi-Class Accuracy: 79.10%




Epoch 43/50
  Loss: 0.0801
  Training Binary Accuracy: 99.10%
  Training Multi-Class Accuracy: 95.02%
  Validation Binary Accuracy: 88.31%
  Validation Multi-Class Accuracy: 71.39%




Epoch 44/50
  Loss: 0.1047
  Training Binary Accuracy: 98.04%
  Training Multi-Class Accuracy: 95.50%
  Validation Binary Accuracy: 89.55%
  Validation Multi-Class Accuracy: 76.62%




Epoch 45/50
  Loss: 0.0709
  Training Binary Accuracy: 99.05%
  Training Multi-Class Accuracy: 96.18%
  Validation Binary Accuracy: 90.30%
  Validation Multi-Class Accuracy: 77.61%




Epoch 46/50
  Loss: 0.0621
  Training Binary Accuracy: 99.36%
  Training Multi-Class Accuracy: 96.24%
  Validation Binary Accuracy: 91.04%
  Validation Multi-Class Accuracy: 77.86%




Epoch 47/50
  Loss: 0.0657
  Training Binary Accuracy: 98.83%
  Training Multi-Class Accuracy: 96.18%
  Validation Binary Accuracy: 85.57%
  Validation Multi-Class Accuracy: 72.14%




Epoch 48/50
  Loss: 0.1036
  Training Binary Accuracy: 98.04%
  Training Multi-Class Accuracy: 94.91%
  Validation Binary Accuracy: 89.05%
  Validation Multi-Class Accuracy: 76.87%




Epoch 49/50
  Loss: 0.0740
  Training Binary Accuracy: 99.05%
  Training Multi-Class Accuracy: 95.55%
  Validation Binary Accuracy: 90.30%
  Validation Multi-Class Accuracy: 75.37%




Epoch 50/50
  Loss: 0.0419
  Training Binary Accuracy: 99.42%
  Training Multi-Class Accuracy: 97.93%
  Validation Binary Accuracy: 86.82%
  Validation Multi-Class Accuracy: 72.89%


In [None]:
# Save the backbone's state dictionary (same file as before).
torch.save(model.state_dict(), "aug_efficientnetv2_multi_classifier.pth")
# The backbone's classifier was replaced by nn.Identity and the trained heads
# live in separate modules, so the backbone file alone cannot reproduce
# predictions. Persist both heads alongside it.
torch.save(binary_head.state_dict(), "aug_efficientnetv2_binary_head.pth")
torch.save(multi_class_head.state_dict(), "aug_efficientnetv2_multi_class_head.pth")
print("Model saved successfully!")

Model saved successfully!


In [None]:
# Per-epoch training history, one (column name, values) pair per metric, in
# the order the columns should appear in the CSV.
metric_columns = [
    ('Epoch', epoch_list),
    ('Loss', loss_list),
    ('Train Binary Accuracy', train_binary_accuracy_list),
    ('Train Multi-Class Accuracy', train_multi_class_accuracy_list),
    ('Validation Binary Accuracy', val_binary_accuracy_list),
    ('Validation Multi-Class Accuracy', val_multi_class_accuracy_list),
]
data = dict(metric_columns)

# Write without the pandas index so the file contains only the metric columns.
df = pd.DataFrame(data)
df.to_csv('aug_efficientnetv2_training_results.csv', index=False)
print("Training results saved successfully!")

Training results saved successfully!


## Test the Model

In [None]:
import os

# Test-time preprocessing: deterministic 224x224 resize + ImageNet channel
# normalisation, no augmentation (identical to the validation pipeline).
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the held-out test split
test_dataset = CustomFireDataset(root_dir="/root/.cache/kagglehub/datasets/elmadafri/the-wildfire-dataset/versions/1/the_wildfire_dataset/the_wildfire_dataset/test", transform=test_transforms)

# No shuffling, so saved predictions follow the dataset order. Workers are
# capped at the CPU count, since extra workers only add overhead.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=min(10, os.cpu_count() or 1),
)

In [None]:
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
import numpy as np
import torch

# Per-sample ground truth and predictions, accumulated across all test batches.
true_binary_labels = []
predicted_binary_probs = []  # Sigmoid probability of the positive class (fire)
true_multi_class_labels = []
predicted_multi_class_labels = []

# binary_head / multi_class_head are separate modules that model.eval() does
# not reach. Without switching them too, binary_head's Dropout stays active
# and the test probabilities are randomly perturbed.
model.eval()
binary_head.eval()
multi_class_head.eval()
with torch.no_grad():
    for images, (binary_labels, multi_class_labels) in test_loader:
        images = images.to(device)

        # Forward pass through the shared backbone and both heads
        features = model(images)
        binary_output = binary_head(features)
        multi_class_output = multi_class_head(features)

        # Flatten (B, 1) -> (B,). reshape(-1) stays 1-D for a one-image batch,
        # where squeeze() would give a 0-d array that list.extend cannot iterate.
        binary_probs = binary_output.reshape(-1).cpu().numpy()

        # Arg-max over the logits gives the predicted sub-class index
        _, predicted_multi_class = torch.max(multi_class_output, 1)
        predicted_multi_class = predicted_multi_class.cpu().numpy()

        # Store true labels and predictions for binary classification
        true_binary_labels.extend(binary_labels.float().cpu().numpy())
        predicted_binary_probs.extend(binary_probs)

        # Store true labels and predictions for multi-class classification
        true_multi_class_labels.extend(multi_class_labels.long().cpu().numpy())
        predicted_multi_class_labels.extend(predicted_multi_class)

# Combine all data into a single DataFrame
data = pd.DataFrame({
    "True Binary Labels": true_binary_labels,
    "Predicted Binary Probabilities": predicted_binary_probs,
    "True Multi-Class Labels": true_multi_class_labels,
    "Predicted Multi-Class Labels": predicted_multi_class_labels
})

data.to_csv("aug1_efficientnetv2l_test_results.csv", index=False)

