# Practical Work: Fire Detection

In [None]:
import sys
import os
import random

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.models import resnet18, ResNet18_Weights
from torchsummary import summary

import matplotlib.pyplot as plt
from PIL import Image

from datasets import WildfirePredictionDataset
from transformations import RandomTransformation
from models import UNet, UNetClassifier

In [None]:
# Set random seeds for reproducibility and dataset splitting.
# Python's `random` (imported above) is seeded too, in case any helper draws
# from it; torch.manual_seed is called last so the cell still displays its result.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

## Exploring Data

In [None]:
# Dataset root and per-split / per-class image directories.
# The subset paths are derived from `path`, so changing the root updates them all.
path = "/home/ids/fallemand-24/ROB313/data/wildfire-prediction-dataset"
train_wf_path = os.path.join(path, "train", "wildfire")
train_nwf_path = os.path.join(path, "train", "nowildfire")
valid_wf_path = os.path.join(path, "valid", "wildfire")
valid_nwf_path = os.path.join(path, "valid", "nowildfire")
test_wf_path = os.path.join(path, "test", "wildfire")
test_nwf_path = os.path.join(path, "test", "nowildfire")

In [None]:
%%script false --no-raise-error
# Check data (corrupted images)
def check_data(path):
	"""Report the files in directory ``path`` that PIL cannot open and decode.

	Every entry of ``path`` is opened and converted to RGB; each file that
	fails is printed as ``"<path> | <filename>"``. Nothing is returned.
	"""
	for f in os.listdir(path):
		try:
			# `with` closes the file handle PIL keeps open for the image.
			# Image.open is lazy, so convert() is what forces the full decode
			# that exposes truncated / corrupted data.
			with Image.open(os.path.join(path, f)) as img:
				img.convert("RGB")
		except Exception:
			# Not a bare `except:`, which would also swallow KeyboardInterrupt /
			# SystemExit and make a long scan impossible to interrupt.
			print(f"{path} | {f}")

# Scan every split/class directory for unreadable images.
for subset_dir in (
	train_wf_path,
	train_nwf_path,
	valid_wf_path,
	valid_nwf_path,
	test_wf_path,
	test_nwf_path,
):
	check_data(subset_dir)

# Files reported as unreadable on a previous run:
# /home/ids/fallemand-24/ROB313/data/wildfire-prediction-dataset/train/nowildfire/-114.152378,51.027198.jpg
# /home/ids/fallemand-24/ROB313/data/wildfire-prediction-dataset/test/wildfire/-73.15884,46.38819.jpg

In [None]:
# Check class balance: count the image files in each split / class directory.
for split_name, wf_dir, nwf_dir in (
    ("Train", train_wf_path, train_nwf_path),
    ("Valid", valid_wf_path, valid_nwf_path),
    ("Test", test_wf_path, test_nwf_path),
):
    print(f"{split_name} [wf]: {len(os.listdir(wf_dir))}")
    print(f"{split_name} [nwf]: {len(os.listdir(nwf_dir))}")

In [None]:
# Only preprocessing step: ToTensor (PIL image / ndarray -> float tensor in C, H, W order).
transform = transforms.Compose([transforms.ToTensor()])

# Load the validation split of the Wildfire Prediction Dataset for exploration.
dataset = WildfirePredictionDataset(split="valid", transform=transform)
print(len(dataset))

In [None]:
# %%script false --no-raise-error
# Check sample sizes: print the tensor shape of one randomly drawn sample.
sample_id = int(torch.randint(len(dataset), size=(1,)))
img, label = dataset[sample_id]
print(img.shape)

In [None]:
# %%script false --no-raise-error
# Plot a grid of random samples, titled "<label id> | <class name>".
cols, rows = 3, 3
figure = plt.figure(figsize=(cols*4, rows*4))
for panel in range(1, rows * cols + 1):
    sample_id = int(torch.randint(len(dataset), size=(1,)))
    img, label = dataset[sample_id]
    ax = figure.add_subplot(rows, cols, panel)
    ax.set_title(f"{label} | {dataset.labels_dict[label]}")
    ax.axis("off")
    # Tensors are (C, H, W); matplotlib expects (H, W, C).
    ax.imshow(img.permute(1, 2, 0))
plt.show()

In [None]:
# %%script false --no-raise-error
# Data augmentation transform (argument (350, 350) is presumably the output size — see transformations.py).
random_trans = RandomTransformation((350, 350))

# Left column: original sample with its label; right column: the augmented version.
rows = 3
cols = 2
fig, axs = plt.subplots(rows, cols, figsize=(cols*4, rows*4))
for ax_orig, ax_trans in axs:
	sample_id = int(torch.randint(len(dataset), size=(1,)))
	img, label = dataset[sample_id]
	img_trans = random_trans(img)

	ax_orig.set_title(f"{label} | {dataset.labels_dict[label]}")
	ax_orig.axis("off")
	ax_orig.imshow(img.permute(1, 2, 0))

	ax_trans.axis("off")
	ax_trans.imshow(img_trans.permute(1, 2, 0))
plt.show()

## Loading Data

In [None]:
# Build the training, validation (two loaders) and test data loaders, batch size 16.
# NOTE(review): the exact meaning of max_valid_2 is defined in datasets.py — confirm there.
train_loader, valid_loader_1, valid_loader_2, test_loader = (
	WildfirePredictionDataset.get_dataloaders(
		transform=transform,
		batch_size=16,
		max_valid_2=10,
	)
)

## SimCLR

### ResNet

#### Image Projection

In [None]:
# SimCLR projection network: ResNet-18 backbone whose `fc` head is replaced by a
# 512 -> 256 -> 128 MLP, restored from the best checkpoint of training run `run_id`.
run_id = 285136  # directory under train_res/ (not named `id`, which shadows the builtin)
# weights=None: the strict load_state_dict below overwrites every parameter and
# buffer, so downloading the ImageNet weights first would be wasted work.
resnet_auto = resnet18(weights=None)
resnet_auto.fc = nn.Sequential(
	nn.Linear(512, 256),
	nn.ReLU(),
	nn.Linear(256, 128),
)
# Load onto the CPU first so the checkpoint also opens on machines without a GPU.
checkpoint = torch.load(f"train_res/{run_id}/checkpoint_best.pth.tar",
	weights_only=True, map_location=torch.device("cpu"))
resnet_auto.load_state_dict(checkpoint["state_dict"])
resnet_auto = resnet_auto.eval().to(device)

In [None]:
%%script false --no-raise-error
# Layer-by-layer summary of the projection network for input size (3, 350, 350).
summary(model=resnet_auto, input_size=(3, 350, 350))

In [None]:
# Run one test batch through the projection network.
# Inference only: no_grad() avoids building an autograd graph, and the inputs are
# moved to the model's device (otherwise this fails with a device mismatch on GPU).
# batch_img itself is left on the CPU.
batch_img, batch_label = next(iter(test_loader))
with torch.no_grad():
	batch_out = resnet_auto(batch_img.to(device))

#### Image Classification

In [None]:
# ResNet SL
# For baseline comparison
run_id = 285460  # directory under train_res/ (not named `id`, which shadows the builtin)
# weights=None: the strict load_state_dict below overwrites every parameter and
# buffer, so downloading the ImageNet weights first would be wasted work.
resnet_sl = resnet18(weights=None)
# Classification head: 512 -> 256 -> 64 -> 2 outputs (one per class).
# nn.Flatten() is a no-op on ResNet's already-flattened features, but it sets the
# Linear layers' indices inside `fc`, which presumably match the checkpoint's keys — keep it.
resnet_sl.fc = nn.Sequential(
	nn.Flatten(),
	nn.Linear(512, 256),
	nn.ReLU(),
	nn.Linear(256, 64),
	nn.ReLU(),
	nn.Linear(64, 2),
)
# Load onto the CPU first so the checkpoint also opens on machines without a GPU.
checkpoint = torch.load(f"train_res/{run_id}/checkpoint_best.pth.tar",
	weights_only=True, map_location=torch.device("cpu"))
resnet_sl.load_state_dict(checkpoint["state_dict"])
resnet_sl = resnet_sl.eval().to(device)

In [None]:
# ResNet SSL
run_id = 285444  # directory under train_res/ (not named `id`, which shadows the builtin)
# weights=None: the strict load_state_dict below overwrites every parameter and
# buffer, so downloading the ImageNet weights first would be wasted work.
resnet_ssl = resnet18(weights=None)
# Classification head: 512 -> 256 -> 64 -> 2 outputs (one per class).
# nn.Flatten() is a no-op on ResNet's already-flattened features, but it sets the
# Linear layers' indices inside `fc`, which presumably match the checkpoint's keys — keep it.
resnet_ssl.fc = nn.Sequential(
	nn.Flatten(),
	nn.Linear(512, 256),
	nn.ReLU(),
	nn.Linear(256, 64),
	nn.ReLU(),
	nn.Linear(64, 2),
)
# Load onto the CPU first so the checkpoint also opens on machines without a GPU.
checkpoint = torch.load(f"train_res/{run_id}/checkpoint_best.pth.tar",
	weights_only=True, map_location=torch.device("cpu"))
resnet_ssl.load_state_dict(checkpoint["state_dict"])
resnet_ssl = resnet_ssl.eval().to(device)

In [None]:
%%script false --no-raise-error
# Layer-by-layer summary of the SSL classifier for input size (3, 350, 350).
summary(model=resnet_ssl, input_size=(3, 350, 350))

### UNet

#### Image Reconstruction

In [None]:
# UNet auto-encoder, restored from the best checkpoint of training run `run_id`.
run_id = 285138  # directory under train_res/ (not named `id`, which shadows the builtin)
unet_auto = UNet()
# Load onto the CPU first so the checkpoint also opens on machines without a GPU.
checkpoint = torch.load(f"train_res/{run_id}/checkpoint_best.pth.tar",
	weights_only=True, map_location=torch.device("cpu"))
unet_auto.load_state_dict(checkpoint["state_dict"])
unet_auto = unet_auto.eval().to(device)

In [None]:
%%script false --no-raise-error
# Layer-by-layer summary of the UNet auto-encoder for input size (3, 350, 350).
summary(model=unet_auto, input_size=(3, 350, 350))

In [None]:
# Reconstruct one test batch with the UNet.
# Inference only: no_grad() avoids building an autograd graph, and the inputs are
# moved to the model's device (otherwise this fails with a device mismatch on GPU).
# batch_img stays on the CPU for plotting; batch_out lives on `device`.
batch_img, batch_label = next(iter(test_loader))
with torch.no_grad():
	batch_out = unet_auto(batch_img.to(device))

In [None]:
# %%script false --no-raise-error
# Plot random images from the current batch (col 0), their UNet reconstruction
# (col 1) and the per-pixel absolute error averaged over channels (col 2).
rows = 4
cols = 3
fig, axs = plt.subplots(rows, cols, figsize=(cols*4, rows*4))
for i in range(rows):
	sample_id = torch.randint(batch_img.shape[0], size=(1,)).item()
	img, label = batch_img[sample_id], batch_label[sample_id].item()
	# Bring the model output to the CPU (a no-op if it is already there) so it can
	# be compared with the input image and plotted, whatever device the model used.
	reconstruction = batch_out["x_hat"][sample_id].detach().cpu()
	diffs = (reconstruction - img.cpu()).abs().mean(dim=0)

	axs[i,0].imshow(img.permute(1, 2, 0))
	axs[i,0].set_title(f"{label} | {dataset.labels_dict[label]}")
	axs[i,0].axis("off")

	# ToPILImage scales floats by 255 and casts to uint8, so values outside [0, 1]
	# would overflow instead of saturating; clamp for display only.
	axs[i,1].imshow(transforms.ToPILImage()(reconstruction.clamp(0, 1)))
	axs[i,1].set_title("Reconstruction")
	axs[i,1].axis("off")

	axs[i,2].imshow(diffs, cmap="viridis")
	axs[i,2].set_title("Difference")
	axs[i,2].axis("off")

plt.show()

#### Image Classification

In [None]:
# UNet SL
# For baseline comparison
# TODO: 0 is a placeholder run id (was written `000`); set it to the real
# training-run directory under train_res/ before running this cell.
run_id = 0
# `Classifier` is not imported in this notebook (NameError); the imported UNet
# classifier class is `UNetClassifier`.
# NOTE(review): assumes UNetClassifier() takes no constructor arguments — confirm in models.py.
unet_sl = UNetClassifier()
# Load onto the CPU first so the checkpoint also opens on machines without a GPU.
checkpoint = torch.load(f"train_res/{run_id}/checkpoint_best.pth.tar",
	weights_only=True, map_location=torch.device("cpu"))
unet_sl.load_state_dict(checkpoint["state_dict"])
unet_sl = unet_sl.eval().to(device)

In [None]:
# UNet SSL
# TODO: 0 is a placeholder run id (was written `000`); set it to the real
# training-run directory under train_res/ before running this cell.
run_id = 0
# `Classifier` is not imported in this notebook (NameError); the imported UNet
# classifier class is `UNetClassifier`.
# NOTE(review): assumes UNetClassifier() takes no constructor arguments — confirm in models.py.
unet_ssl = UNetClassifier()
# Load onto the CPU first so the checkpoint also opens on machines without a GPU.
checkpoint = torch.load(f"train_res/{run_id}/checkpoint_best.pth.tar",
	weights_only=True, map_location=torch.device("cpu"))
unet_ssl.load_state_dict(checkpoint["state_dict"])
unet_ssl = unet_ssl.eval().to(device)

In [None]:
%%script false --no-raise-error
# Layer-by-layer summary of the UNet SSL classifier for input size (3, 350, 350).
summary(model=unet_ssl, input_size=(3, 350, 350))