In [None]:
import copy
import os
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report
from torch.utils.data import Dataset, DataLoader, Subset, random_split, WeightedRandomSampler
from torchvision import transforms

In [None]:
# Training-time preprocessing pipeline, applied in three stages:
#   1. resize every image to a fixed 256x256;
#   2. random augmentation: horizontal/vertical flips, rotation within ±30 degrees,
#      and brightness/contrast/saturation jitter;
#   3. convert to a tensor and normalize per channel with the commonly used
#      ImageNet mean/std statistics.
transform = transforms.Compose(
    [transforms.Resize((256, 256))]
    + [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(30),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    ]
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
class WildfireRiskDataset(Dataset):
    """FireRisk image-classification dataset laid out as ``root_dir/<class_name>/<image file>``.

    Every sub-directory named in ``self.classes`` is scanned for image files. An image's
    label is the index of its directory name in ``self.classes``.

    Args:
        root_dir: Directory containing one sub-directory per entry of ``self.classes``.
            A missing class directory raises ``FileNotFoundError`` (from ``os.listdir``).
        transform: Optional callable applied to each RGB ``PIL.Image`` in ``__getitem__``.
            When ``None``, the RGB image itself is returned.
        extensions: File-name suffix (or tuple of suffixes), matched case-insensitively,
            that identifies image files. Defaults to ``('.png',)``.
    """

    def __init__(self, root_dir, transform=None, extensions=('.png',)):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['Very_Low', 'Low', 'Moderate', 'High', 'Very_High', 'Water', 'Non-burnable']
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        # Accept a single suffix string as well as a tuple; lower-case once so that
        # '.PNG' and '.png' both match.
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)

        self.image_paths = []
        self.labels = []

        for cls in self.classes:
            cls_dir = os.path.join(root_dir, cls)
            # os.listdir order is arbitrary. Sorting makes sample indices (and thus any
            # seeded random split over them) reproducible across runs and machines.
            for img_name in sorted(os.listdir(cls_dir)):
                if img_name.lower().endswith(suffixes):
                    self.image_paths.append(os.path.join(cls_dir, img_name))
                    self.labels.append(self.class_to_idx[cls])

        self.print_dataset_stats()

    def print_dataset_stats(self):
        """Print the image count and percentage for every class, including empty classes."""
        counter = Counter(self.labels)
        total = len(self.labels)
        print("\nDataset Statistics:")
        print(f"{'Class':<15} {'Count':<10} {'Percentage':<10}")
        # Iterate over all classes (not just those seen) so empty classes are visible:
        # they break inverse-frequency class weighting downstream.
        for cls_idx, cls in enumerate(self.classes):
            count = counter[cls_idx]
            # Guard against ZeroDivisionError when no images were found at all.
            percentage = count / total * 100 if total else 0.0
            print(f"{cls:<15} {count:<10} {percentage:.2f}%")
        print(f"\nTotal images: {total}")

    def __len__(self):
        """Number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        If the image at ``idx`` cannot be opened or transformed, the error is printed and
        the following samples (wrapping around to the start) are tried instead. The label
        returned always belongs to the image actually returned.

        Raises:
            IndexError: if ``idx`` is out of range. Plain iteration over the dataset relies
                on this to stop.
            RuntimeError: if no image in the dataset can be loaded.
        """
        # Index directly first so an out-of-range idx raises IndexError instead of wrapping.
        self.image_paths[idx]
        n = len(self.image_paths)
        # Bounded iterative retry: each file is tried at most once, with no recursion.
        for attempt in range(n):
            i = (idx + attempt) % n
            img_path = self.image_paths[i]
            try:
                # The context manager closes the underlying file; convert() returns a copy
                # that stays valid after the file is closed.
                with Image.open(img_path) as img:
                    image = img.convert('RGB')
                if self.transform:
                    image = self.transform(image)
                return image, self.labels[i]
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
        raise RuntimeError(
            f"Could not load any image under {self.root_dir} (starting at index {idx})"
        )

In [None]:
# Fixed seed so the 60/20/20 split and the sampler's draws are identical on every
# Restart & Run All.
SEED = 42
batch_size = 32

dataset = WildfireRiskDataset(root_dir="FireRisk/train", transform=transform)
assert len(dataset) > 0, "no training images found under FireRisk/train"

# Deterministic evaluation view of the same images. It keeps only the non-random steps of
# `transform` (resize, tensor conversion, normalization), so validation/test metrics are
# not computed on randomly flipped/rotated/jittered images. The shallow copy shares
# image_paths/labels with `dataset`, so an index refers to the same file in both views.
eval_transform = transforms.Compose(
    [t for t in transform.transforms
     if isinstance(t, (transforms.Resize, transforms.ToTensor, transforms.Normalize))]
)
eval_dataset = copy.copy(dataset)
eval_dataset.transform = eval_transform

# Split the dataset 60/20/20 into training/validation/testing.
train_size = int(0.6 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_split, test_split = random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(SEED)
)
val_dataset = Subset(eval_dataset, val_split.indices)
test_dataset = Subset(eval_dataset, test_split.indices)

# Class-balanced sampling over the TRAINING split only. The sampler yields positions into
# train_dataset (0 .. len(train_dataset)-1), so there must be one weight per training
# sample. Full-dataset weights would produce out-of-range indices and mismatched labels.
train_labels = torch.tensor([dataset.labels[i] for i in train_dataset.indices], dtype=torch.long)
class_counts = torch.bincount(train_labels, minlength=len(dataset.classes))
# clamp(min=1) stops a class that is absent from the split from producing an inf weight.
class_weights = 1.0 / class_counts.clamp(min=1).float()
sample_weights = class_weights[train_labels]
sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)