In [None]:
import copy
import os
import traceback
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler, random_split
from torchvision import transforms

In [None]:
# Mount Google Drive into this Colab runtime so files under MyDrive are
# readable from the local filesystem at /content/drive.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
# Training-time preprocessing pipeline:
#   1. resize every image to 256x256,
#   2. random augmentation (horizontal/vertical flips with p=0.5, rotation within
#      +/-30 degrees, brightness/contrast/saturation jitter of 0.2),
#   3. convert to a float tensor and normalise with the commonly used ImageNet
#      per-channel mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        # --- augmentation ---
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=30),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        # --- tensor conversion + normalisation ---
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
class WildfireRiskDataset(Dataset):
    """Folder-per-class image dataset for the seven FireRisk categories.

    ``root_dir`` must contain one sub-directory per entry of ``self.classes``.
    Every file in those directories whose name ends with one of ``extensions``
    (compared case-insensitively) becomes a sample labelled with that class's
    index in ``self.classes``. Files are visited in sorted order, so sample
    indices are stable across machines and runs. Any seeded split built on those
    indices is therefore reproducible too.

    Args:
        root_dir: Directory containing one folder per class.
        transform: Optional callable applied to each RGB PIL image.
        extensions: Suffix or tuple of suffixes to include (default ``'.png'``).
    """

    def __init__(self, root_dir, transform=None, extensions=('.png',)):
        self.root_dir = root_dir
        self.transform = transform
        # Accept a single suffix string as well as a tuple of suffixes.
        if isinstance(extensions, str):
            extensions = (extensions,)
        self.extensions = tuple(ext.lower() for ext in extensions)
        self.classes = ['Very_Low', 'Low', 'Moderate', 'High', 'Very_High', 'Water', 'Non-burnable']
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        self.image_paths = []
        self.labels = []

        for cls in self.classes:
            cls_dir = os.path.join(root_dir, cls)
            # os.listdir order is arbitrary and platform-dependent; sorting keeps
            # the index -> file mapping deterministic.
            for img_name in sorted(os.listdir(cls_dir)):
                if img_name.lower().endswith(self.extensions):
                    self.image_paths.append(os.path.join(cls_dir, img_name))
                    self.labels.append(self.class_to_idx[cls])

        self.print_dataset_stats()

    def print_dataset_stats(self):
        """Print per-class sample counts and percentages, in class-index order."""
        counts = Counter(self.labels)
        total = len(self.labels)
        print("\nDataset Statistics:")
        print(f"{'Class':<15} {'Count':<10} {'Percentage':<10}")
        for cls_idx, cls_name in enumerate(self.classes):
            count = counts.get(cls_idx, 0)
            percentage = count / total * 100 if total else 0.0
            print(f"{cls_name:<15} {count:<10} {percentage:.2f}%")
        print(f"\nTotal images: {total}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, skipping unreadable files.

        If opening or transforming the image fails, the error is printed and the
        next index (wrapping around) is tried instead. At most one full pass over
        the dataset is attempted.

        Raises:
            IndexError: If ``idx`` is out of range. This keeps Python's sequence
                protocol intact, so ``for sample in dataset`` terminates.
            RuntimeError: If no image in the dataset could be loaded.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        start = idx % n
        for offset in range(n):
            i = (start + offset) % n
            img_path = self.image_paths[i]
            try:
                # The context manager closes the file handle once pixels are decoded.
                with Image.open(img_path) as img:
                    image = img.convert('RGB')
                if self.transform:
                    image = self.transform(image)
                return image, self.labels[i]
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
        raise RuntimeError(f"No loadable image found in dataset (started at index {idx})")

In [None]:
# Absolute path under the Colab mount point. The relative "drive/..." form only
# resolves when the working directory happens to be /content.
DATA_ROOT = "/content/drive/MyDrive/FireRisk/train"
dataset = WildfireRiskDataset(root_dir=DATA_ROOT, transform=transform)

# Class-balancing sample weights are derived from the training split only, once
# the stratified split has been made. Weights computed here over the full dataset
# (which also contains the future validation/test samples) were never used and
# were overwritten before the training loader was built.

In [None]:
def stratified_split(dataset, test_size=0.2, val_size=0.2, random_state=42):
    """Split ``dataset`` into stratified train / validation / test index sets.

    First a ``test_size`` fraction is held out. Then, from the remaining pool, a
    validation set is carved out whose size is ``val_size`` of the *whole*
    dataset. Both splits preserve the label proportions of ``dataset.labels``.

    Args:
        dataset: Object exposing parallel ``image_paths`` and ``labels`` lists.
        test_size: Fraction of all samples reserved for testing.
        val_size: Fraction of all samples reserved for validation.
        random_state: Seed passed to both ``StratifiedShuffleSplit`` calls.

    Returns:
        ``(train_indices, val_indices, test_indices)``. Train/val are lists and
        test is the index array produced by scikit-learn. All are indices into
        ``dataset``.
    """
    all_labels = dataset.labels

    outer = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    pool_idx, test_idx = next(outer.split(dataset.image_paths, all_labels))

    # val_size is relative to the full dataset; rescale it to the remaining pool.
    inner_fraction = val_size/(1-test_size)
    inner = StratifiedShuffleSplit(n_splits=1, test_size=inner_fraction, random_state=random_state)
    pool_paths = [dataset.image_paths[i] for i in pool_idx]
    pool_labels = [all_labels[i] for i in pool_idx]
    train_pos, val_pos = next(inner.split(pool_paths, pool_labels))

    # Positions returned by the inner split are offsets into the pool; map them
    # back to indices into the original dataset.
    train_indices = [pool_idx[pos] for pos in train_pos]
    val_indices = [pool_idx[pos] for pos in val_pos]

    return train_indices, val_indices, test_idx

train_idx, val_idx, test_idx = stratified_split(dataset)

# `dataset` applies the random-augmentation `transform`, and a Subset always uses
# its parent's transform. Validation/test images would therefore be randomly
# flipped, rotated and colour-jittered, making the metrics noisy. They are served
# instead from a shallow copy whose transform is deterministic: same resize and
# normalisation, no augmentation. The copy shares `image_paths`/`labels` with
# `dataset`, so the split indices still line up.
eval_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
eval_base_dataset = copy.copy(dataset)
eval_base_dataset.transform = eval_transform

train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(eval_base_dataset, val_idx)
test_dataset = torch.utils.data.Subset(eval_base_dataset, test_idx)

In [None]:
# Class-balanced sampling, computed from the training split only so validation and
# test labels do not influence the weights. Each sample is weighted by
# 1 / (count of its class).
train_labels = [dataset.labels[i] for i in train_idx]
class_counts = torch.bincount(torch.tensor(train_labels), minlength=len(dataset.classes))
class_weights = 1. / class_counts.float()
sample_weights = class_weights[torch.tensor(train_labels)]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(train_labels),
    replacement=True,
    # Seeded generator so the resampled training order is reproducible.
    generator=torch.Generator().manual_seed(42))

batch_size = 32
# Pinned host memory only helps (and only avoids a warning) when a GPU is present.
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler,
                          pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2,
                        pin_memory=torch.cuda.is_available())
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2,
                         pin_memory=torch.cuda.is_available())

In [None]:
def print_class_distribution(dataset, subset, subset_name):
    """Print a per-class count/percentage table for one split of ``dataset``.

    Args:
        dataset: Parent dataset exposing ``labels`` (class index per sample) and
            ``classes`` (class names ordered by index).
        subset: Object with an ``indices`` sequence of positions into ``dataset``
            (e.g. ``torch.utils.data.Subset``).
        subset_name: Label used in the table heading.

    An empty subset prints 0.00% for every class instead of dividing by zero.
    """
    labels = [dataset.labels[i] for i in subset.indices]
    counter = Counter(labels)
    total = len(labels)

    print(f"\nClass Distribution for {subset_name} set:")
    print(f"{'Class':<15} {'Count':<10} {'Percentage':<10}")

    for cls_idx, cls_name in enumerate(dataset.classes):
        count = counter.get(cls_idx, 0)
        percentage = count / total * 100 if total else 0.0
        print(f"{cls_name:<15} {count:<10} {percentage:.2f}%")

    print(f"Total images: {total}")

# Sanity-check that the stratified split kept class proportions similar across splits.
print_class_distribution(dataset, subset=train_dataset, subset_name="Training")
print_class_distribution(dataset, subset=val_dataset, subset_name="Validation")
print_class_distribution(dataset, subset=test_dataset, subset_name="Test")

In [None]:
class WildfireRiskClassifier(nn.Module):
    """Four-stage CNN followed by a three-layer MLP classification head.

    Each feature stage is conv3x3 (padding 1) -> BatchNorm -> ReLU -> 2x2
    max-pool, so spatial size halves four times. A 256x256 input therefore yields
    a 512x16x16 map, which is the size the first ``nn.Linear`` expects. An
    adaptive average pool to 16x16 sits between the two. It is an exact identity
    for 256x256 inputs and lets other input resolutions reach the head with the
    same shape. It has no parameters or buffers, so ``state_dict`` keys/shapes
    match the original layout.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=7):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        # Normalises the feature map to the 16x16 grid the head is sized for.
        self.pool = nn.AdaptiveAvgPool2d((16, 16))

        self.classifier = nn.Sequential(
            nn.Linear(512 * 16 * 16, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Map a batch of 3-channel images ``(N, 3, H, W)`` to ``(N, num_classes)`` logits."""
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

# Training setup: pick the GPU when one is available, build the classifier on
# that device, and pair it with cross-entropy loss, Adam (lr 1e-4, weight decay
# 1e-5) and a plateau scheduler. The scheduler cuts the LR by 10x once the
# monitored metric (lower is better) stops improving for 3 scheduler steps.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = WildfireRiskClassifier(num_classes=7)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

In [None]:
def train_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Train ``model`` for one pass over ``loader``.

    Each batch gets exactly one forward pass, one backward pass and one optimizer
    step. On CUDA the forward pass runs under automatic mixed precision with a
    ``GradScaler``. On any other device it runs in full precision with a plain
    ``backward()`` / ``step()``.

    Args:
        model: Module to train; it is switched to train mode.
        loader: Iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        criterion: Loss function. Assumed to return the batch-mean loss (the
            ``nn.CrossEntropyLoss`` default), since it is re-weighted by batch
            size when accumulated.
        optimizer: Optimizer over ``model``'s parameters.
        device: ``torch.device`` or device string that batches are moved to.
        scaler: Optional ``torch.amp.GradScaler``. Pass the same instance on every
            call to keep its loss-scale state across epochs. If omitted on CUDA, a
            fresh scaler is created for this epoch.

    Returns:
        ``(epoch_loss, epoch_acc)``: the loss averaged over every sample seen and
        the fraction of those samples predicted correctly.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    use_amp = torch.device(device).type == 'cuda'
    if scaler is None and use_amp:
        scaler = torch.amp.GradScaler('cuda')

    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        # Clear gradients first so each step uses only this batch's gradients.
        optimizer.zero_grad()
        with torch.amp.autocast(device_type='cuda', enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion, device):
    """Run ``model`` over ``loader`` without gradients and collect metrics.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        loader: Iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        criterion: Loss function. Assumed to return the batch-mean loss (the
            ``nn.CrossEntropyLoss`` default), since it is re-weighted by batch
            size when accumulated.
        device: ``torch.device`` or device string that batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc, all_preds, all_labels)``:
        - ``epoch_loss``: mean loss over every sample seen.
        - ``epoch_acc``: fraction of those samples predicted correctly.
        - ``all_preds``, ``all_labels``: per-sample predicted and true class
          indices, in loader order.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")

    # Divide by the number of samples actually seen rather than
    # len(loader.dataset): correct for any sampler, and works for plain iterables.
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc, all_preds, all_labels

## Smoke test: run one training epoch

In [None]:
# Smoke test: a single pass over the training loader. On failure the full
# traceback is printed, not just the message, so the failing line is visible.
# The rest of the notebook is still not aborted.
try:
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
except Exception as e:
    print(f"Failing with error: {e}")
    traceback.print_exc()
else:
    print("Success! Training worked with simple loader")
    print(f"train_loss={train_loss:.4f}  train_acc={train_acc:.4f}")