In [1]:
### MODEL TRAINING
import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader
import rasterio
import numpy as np

class FireBurnDataset(Dataset):
    """Paired raster dataset of multi-band images and their label masks.

    Image files are discovered once, at construction time, by globbing
    ``merged_dir`` for ``merged_ID*.tif`` and sorting the matches. The mask
    for an image is looked up in ``mask_dir`` as ``mask_ID<id>.tif``, where
    ``<id>`` is the text between the last ``_ID`` and the first following
    ``.`` in the image file name.

    Args:
        merged_dir: Directory containing the ``merged_ID*.tif`` image rasters.
        mask_dir: Directory containing the ``mask_ID*.tif`` label rasters.
        transform: Optional callable ``transform(image, mask)`` that returns
            the transformed ``(image, mask)`` pair. It may return NumPy
            arrays or tensors.
    """

    def __init__(self, merged_dir, mask_dir, transform=None):
        self.merged_paths = sorted(glob.glob(os.path.join(merged_dir, 'merged_ID*.tif')))
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        """Return the number of image rasters found in ``merged_dir``."""
        return len(self.merged_paths)

    @staticmethod
    def _to_tensor(value):
        """Convert ``value`` to a tensor, passing existing tensors through.

        Calling ``torch.tensor`` on something that is already a tensor makes
        PyTorch emit a copy-construct ``UserWarning`` and performs a needless
        copy, so tensors returned by ``transform`` are returned unchanged.
        """
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value)

    def __getitem__(self, idx):
        """Load the ``idx``-th image/mask pair.

        Returns:
            ``(image, mask)`` as tensors. Without a transform, ``image`` is
            float32 with every band of the raster (bands first) and ``mask``
            is int64 holding band 1 of the mask raster.
        """
        merged_path = self.merged_paths[idx]
        id_str = os.path.basename(merged_path).split('_ID')[-1].split('.')[0]
        mask_path = os.path.join(self.mask_dir, f'mask_ID{id_str}.tif')

        with rasterio.open(merged_path) as src:
            image = src.read().astype(np.float32)  # Shape: (bands, H, W)

        with rasterio.open(mask_path) as src:
            mask = src.read(1).astype(np.int64)  # Shape: (H, W)

        if self.transform:
            image, mask = self.transform(image, mask)

        return self._to_tensor(image), self._to_tensor(mask)


In [2]:
### UNET

import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """U-Net for dense per-pixel classification.

    The network has four encoder stages, a bottleneck and four decoder
    stages. Each stage is two 3x3 convolutions with ReLU; encoder stages are
    separated by 2x2 max pooling, and decoder stages upsample with a stride-2
    transposed convolution and concatenate the matching encoder output (skip
    connection). A final 1x1 convolution maps to ``n_classes`` channels.

    Args:
        in_channels: Number of channels in the input tensor.
        n_classes: Number of output channels (one score per class).
        base_channels: Width of the first encoder stage; every deeper stage
            doubles it. The default of 64 gives 64-128-256-512-1024.

    Input is ``(N, in_channels, H, W)``; output is ``(N, n_classes, H, W)``
    of raw scores (no softmax is applied). ``H`` and ``W`` must each be at
    least 16 so that four poolings leave a non-empty bottleneck. They do not
    need to be multiples of 16: upsampled maps are zero-padded to the size of
    their skip connection before concatenation.
    """

    def __init__(self, in_channels, n_classes, base_channels=64):
        super().__init__()

        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.ReLU(inplace=True)
            )

        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        # Modules are created in the same order as before so that, for a
        # given RNG state, the initial weights are unchanged.
        self.enc1 = conv_block(in_channels, c1)
        self.enc2 = conv_block(c1, c2)
        self.enc3 = conv_block(c2, c3)
        self.enc4 = conv_block(c3, c4)

        self.pool = nn.MaxPool2d(2)

        self.middle = conv_block(c4, c5)

        # Each decoder block sees the upsampled map concatenated with the
        # skip connection, hence twice the upsampled channel count as input.
        self.up4 = nn.ConvTranspose2d(c5, c4, 2, stride=2)
        self.dec4 = conv_block(c5, c4)
        self.up3 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.dec3 = conv_block(c4, c3)
        self.up2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.dec2 = conv_block(c3, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = conv_block(c2, c1)

        self.out_conv = nn.Conv2d(c1, n_classes, kernel_size=1)

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad ``x`` so its last two dimensions match those of ``ref``.

        Max pooling floors odd sizes, so after 2x upsampling a map can be one
        pixel short of its skip connection in either dimension.
        """
        dh = ref.shape[-2] - x.shape[-2]
        dw = ref.shape[-1] - x.shape[-1]
        if dh == 0 and dw == 0:
            return x
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(x, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))

    def forward(self, x):
        """Return per-pixel class scores with the same H and W as ``x``."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        m = self.middle(self.pool(e4))

        d4 = self._pad_to(self.up4(m), e4)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))
        d3 = self._pad_to(self.up3(d4), e3)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self._pad_to(self.up2(d3), e2)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self._pad_to(self.up1(d2), e1)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        return self.out_conv(d1)


In [5]:
### Training Set Up

from torch.utils.data import random_split
import torch.optim as optim

# Parameters
merged_dir = 'data/merged'
mask_dir = 'data/mask'
batch_size = 4
num_epochs = 20
learning_rate = 1e-4
num_classes = 4
seed = 42  # fixes the train/val split, batch shuffling and initial weights

# Seed the global RNG so weight initialisation and DataLoader shuffling repeat
# across "Restart & Run All".
torch.manual_seed(seed)

# Dataset and split
dataset = FireBurnDataset(merged_dir, mask_dir)
if len(dataset) == 0:
    # Fail here with a clear message rather than later inside the DataLoader.
    raise RuntimeError(f"No 'merged_ID*.tif' files found in {merged_dir!r}")
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A dedicated generator keeps the split stable even if other code draws from
# the global RNG before this cell runs.
split_generator = torch.Generator().manual_seed(seed)
train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=split_generator)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1)

# Get input channels from first sample
in_channels = dataset[0][0].shape[0]

# Model, loss, optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_channels, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
print("Using device:", device)


Using device: cpu


In [6]:
### Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of batch losses, each weighted by its batch size
    num_samples = 0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; weighting by batch size keeps a
        # smaller final batch from being over-represented in the epoch loss.
        batch_count = images.size(0)
        running_loss += loss.item() * batch_count
        num_samples += batch_count

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/num_samples:.4f}")


KeyboardInterrupt: 

In [None]:
### Evaluate model
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
import numpy as np

def evaluate_model(model, dataloader, device, num_classes=4):
    """Run ``model`` over ``dataloader`` and print pixel accuracy and per-class IoU.

    Args:
        model: Network returning per-class scores with the class dimension at
            index 1; it is switched to eval mode.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        device: Device the batches are moved to before the forward pass.
        num_classes: Number of classes; labels ``0 .. num_classes - 1`` form
            the confusion matrix.

    Returns:
        ``(all_preds, all_labels)``: flat NumPy arrays with every predicted
        and every true label seen, in iteration order.
    """
    model.eval()
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch_images, batch_masks in dataloader:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)

            # The highest-scoring class along dim 1 is the prediction.
            batch_preds = torch.argmax(model(batch_images), dim=1)

            pred_chunks.append(batch_preds.cpu().numpy())
            label_chunks.append(batch_masks.cpu().numpy())

    all_preds = np.concatenate(pred_chunks).flatten()
    all_labels = np.concatenate(label_chunks).flatten()

    class_ids = list(range(num_classes))
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    # IoU = TP / (row total + column total - TP); the epsilon avoids 0/0 for
    # classes that never occur.
    true_positive = np.diag(cm)
    union = cm.sum(1) + cm.sum(0) - true_positive
    iou = true_positive / (union + 1e-7)
    pixel_acc = np.mean(all_preds == all_labels)

    print(f"Pixel Accuracy: {pixel_acc:.4f}")
    for class_id, class_iou in enumerate(iou):
        print(f"Class {class_id} IoU: {class_iou:.4f}")

    return all_preds, all_labels


In [None]:
# Bind the returned arrays so the cell does not echo two dataset-sized arrays
# as its output, and pass the configured class count instead of the default.
val_preds, val_labels = evaluate_model(model, val_loader, device, num_classes=num_classes)


In [None]:
### Visualize example

import matplotlib.pyplot as plt

def visualize_predictions(model, dataloader, device, n=3):
    """Plot input, ground-truth mask and predicted mask for up to ``n`` samples.

    Args:
        model: Network returning per-class scores with the class dimension at
            index 1; it is switched to eval mode.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        device: Device the images are moved to before the forward pass.
        n: Number of samples to show before returning.

    Each sample is drawn as one three-panel figure. Masks are coloured on a
    fixed 0-3 scale.
    """
    mask_style = {'cmap': 'viridis', 'vmin': 0, 'vmax': 3}

    model.eval()
    shown = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            preds = torch.argmax(model(images), dim=1)

            for i in range(images.size(0)):
                input_img = images[i].cpu().numpy()
                pred_mask = preds[i].cpu().numpy()
                true_mask = masks[i].cpu().numpy()

                # Choose 3 bands for visualization (e.g., RGB-like); with fewer
                # than 3 bands the first one is repeated as grey.
                if input_img.shape[0] >= 3:
                    rgb = input_img[:3]
                else:
                    rgb = np.repeat(input_img[0:1], 3, axis=0)
                rgb = np.moveaxis(rgb, 0, -1)  # bands-first -> bands-last for imshow

                # Min-max scale to [0, 1]; the epsilon guards a constant image.
                lo, hi = rgb.min(), rgb.max()
                rgb = (rgb - lo) / (hi - lo + 1e-6)

                panels = (
                    ("Input Image", rgb, {}),
                    ("Ground Truth", true_mask, mask_style),
                    ("Prediction", pred_mask, mask_style),
                )
                fig, axs = plt.subplots(1, 3, figsize=(12, 4))
                for ax, (title, data, style) in zip(axs, panels):
                    ax.imshow(data, **style)
                    ax.set_title(title)
                    ax.axis('off')
                plt.tight_layout()
                plt.show()

                shown += 1
                if shown >= n:
                    return


In [None]:
# Show input / ground truth / prediction panels for the first 3 validation samples.
visualize_predictions(model, val_loader, device, n=3)
