In [None]:
from PIL import Image
import os
import csv
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader

In [None]:
def test_model_and_save_results(model, test_dir, output_csv):
    model.eval()  # Set the model to evaluation mode

    results = []  # To store results for the CSV

    # Disable gradient computation during inference
    with torch.no_grad():
        for file_name in os.listdir(test_dir):
            file_path = os.path.join(test_dir, file_name)

            # Check if the file is an image
            if not file_name.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue

            # Load and preprocess the image
            image = Image.open(file_path).convert("RGB")
            processed_image = test_transform(image).unsqueeze(0)

            # Perform inference
            # The model outputs a value between -1 and 1
            outputs = model(processed_image)
            # We need a probability score, this is why we apply sigmoid to outputs
            # We assume that if score >0.5 then the image contains a fire
            # IMPORTANT: the model outputs a logi between -1 and 1 
            # in this case the model outputs how likeley a picture is to 
            # NOT have a fire ( ~ -1: fire , ~ 1: NO fire )
            preds = torch.sigmoid(outputs) < 0.5  
            predicted_class = int(preds.item())  # Convert to 0 or 1

            # Append result
            results.append({'id': file_name, 'class': predicted_class})
            
    with open(output_csv, mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=['id', 'class'])
        writer.writeheader()
        writer.writerows(results)

    print(f"Results saved to {output_csv}")

In [None]:
# Updated for batch processing during inference
from torch.utils.data import DataLoader, Dataset

class TestDataset(Dataset):
    """Unlabelled image dataset used for batch inference on the test set.

    Collects the files directly inside ``test_dir`` whose names end in
    ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive), in sorted order, and
    yields ``(file_name, transformed_image)`` pairs so predictions can be
    written back keyed by file name.

    Args:
        test_dir: directory containing the test images (not searched recursively).
        transform: callable applied to each RGB ``PIL.Image``
            (e.g. a ``torchvision.transforms.Compose``).
    """

    _IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, test_dir, transform):
        self.test_dir = test_dir
        self.transform = transform
        # lower(): do not silently skip files such as "img.JPG".
        # sorted(): os.listdir order is arbitrary; keep the submission order reproducible.
        self.images = sorted(
            f for f in os.listdir(test_dir)
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.test_dir, img_name)
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image that remains usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        return img_name, self.transform(image)

# Define the model structure
def get_model_resnet18(pretrained=True):
    """Build a ResNet-18 whose final layer is replaced by a single-logit head.

    Args:
        pretrained: if True (the original behaviour), initialise the backbone from
            torchvision's ImageNet weights. Pass False when every weight will be
            overwritten by ``load_state_dict`` anyway, so nothing is downloaded.

    Returns:
        A torchvision ResNet-18 whose ``fc`` maps ``in_features`` -> 1 output
        (binary classification).
    """
    model = models.resnet18(pretrained=pretrained)
    num_ftrs = model.fc.in_features
    # Modify the last layer for binary classification
    model.fc = nn.Linear(num_ftrs, 1)
    return model


# Import the model trained on kaggle.
# The checkpoint overwrites every weight, so ImageNet weights are not needed.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model = get_model_resnet18(pretrained=False)
model.load_state_dict(torch.load("resnet18-15ep.pth", map_location=torch.device('cpu')))
model.eval()  # Must be set before inference (BatchNorm / dropout behaviour)

# Define the transformations to pre-process test data.
# Make sure it's the same transformation as for validation and training.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Path to test dir
test_dir = "dl2425_challenge_dataset/test"

# Create the DataLoader for the test set. This must come after test_transform is
# defined; otherwise the cell fails on a fresh kernel.
test_dataset = TestDataset(test_dir=test_dir, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Batch inference
results = []  # (file name, 0/1 label) tuples, one per image
with torch.no_grad():
    for file_names, images in test_loader:
        logits = model(images).view(-1)  # (batch, 1) -> (batch,)
        # Same rule as the per-image path: sigmoid(logit) < 0.5 -> class 1.
        preds = (torch.sigmoid(logits) < 0.5).long()
        # tolist() yields plain ints, so the CSV gets "1" rather than "[1]".
        results.extend(zip(file_names, preds.tolist()))

# Save results to CSV
output_csv = "test_v2.csv"
with open(output_csv, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["id", "class"])
    writer.writerows(results)

print(f"Batch inference completed. Results saved to {output_csv}")

In [None]:
# Unbatched, one-image-at-a-time inference. NOTE: this writes to the same
# `output_csv` path as the batch cell, so it overwrites the batch results.
test_model_and_save_results(model=model, test_dir=test_dir, output_csv=output_csv)

In [None]:
# Scratch utility: truncate `output_csv` to zero bytes.
# Guarded by a flag so that "Restart & Run All" does not wipe the predictions
# written by the earlier cells. Set it to True and run this cell manually when
# you really want an empty file.
RESET_OUTPUT_CSV = False
if RESET_OUTPUT_CSV:
    with open(output_csv, 'w'):
        pass  # Opening in 'w' mode truncates the file to 0 length