In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import glob
import random
import numpy as np
from torchmetrics import Accuracy, Precision, Recall
import wandb
from pytorch_lightning.loggers import WandbLogger
import torchvision.transforms as T
from src.fire_series_dataset import FireSeriesDataset
from src.temporal_model import TemporalModel 


In [2]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

In [3]:
class FireDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size=16, img_size=224, num_workers=12):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.train_dataset = FireSeriesDataset(
            os.path.join(self.data_dir, "train"), self.img_size
        )
        self.val_dataset = FireSeriesDataset(
            os.path.join(self.data_dir, "val"), self.img_size
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)



In [4]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import Accuracy, Precision, Recall

class FireClassifier(pl.LightningModule):
    def __init__(self, learning_rate=1e-4, unfreeze_epoch=5):
        super(FireClassifier, self).__init__()
        self.save_hyperparameters()

        self.model = TemporalModel()

        # Initialize the accuracy, precision, and recall metrics
        self.train_accuracy = Accuracy(task="binary")
        self.val_accuracy = Accuracy(task="binary")
        self.train_precision = Precision(task="binary")
        self.val_precision = Precision(task="binary")
        self.train_recall = Recall(task="binary")
        self.val_recall = Recall(task="binary")

        # Track the epoch to unfreeze the model
        self.unfreeze_epoch = unfreeze_epoch

    def unfreeze_model(self):
        """Unfreeze efficient last 5 layers."""
        for param in self.model.efficientnet[-5:].parameters(): 
            param.requires_grad = True

            
    def forward(self, x):
        x = self.model(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x).squeeze()

        # Use binary cross-entropy, since the sigmoid is already applied in the model
        loss = F.binary_cross_entropy(y_hat, y.float())
        
        # Metrics computation (no need for sigmoid as it's already applied)
        acc = self.train_accuracy(y_hat, y.int())
        precision = self.train_precision(y_hat, y.int())
        recall = self.train_recall(y_hat, y.int())
        
        # Logging
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        self.log("train_acc", acc, on_step=False, on_epoch=True)
        self.log("train_precision", precision, on_step=False, on_epoch=True)
        self.log("train_recall", recall, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x).squeeze()

        # Use binary cross-entropy for validation as well
        loss = F.binary_cross_entropy(y_hat, y.float())
        
        # Metrics computation (no need for sigmoid as it's already applied)
        acc = self.val_accuracy(y_hat, y.int())
        precision = self.val_precision(y_hat, y.int())
        recall = self.val_recall(y_hat, y.int())
        
        # Logging
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True)
        self.log("val_precision", precision, on_step=False, on_epoch=True)
        self.log("val_recall", recall, on_step=False, on_epoch=True)
        return loss

    def on_train_epoch_end(self):
        # At the end of each training epoch, display the current metrics
        train_acc = self.train_accuracy.compute().item()
        train_precision = self.train_precision.compute().item()
        train_recall = self.train_recall.compute().item()
        print(f"Epoch {self.current_epoch}:")
        print(f"Train Accuracy: {train_acc:.4f}")
        print(f"Train Precision: {train_precision:.4f}")
        print(f"Train Recall: {train_recall:.4f}")
        
        # Unfreeze the model after the specified number of epochs
        if self.current_epoch + 1 == self.unfreeze_epoch:
            print(f"Unfreeze efficient last 5 layers at epoch {self.current_epoch + 1}")
            self.unfreeze_model()

        # Reset metrics for the next epoch
        self.train_accuracy.reset()
        self.train_precision.reset()
        self.train_recall.reset()

    def on_validation_epoch_end(self):
        # At the end of each validation epoch, display the current metrics
        val_acc = self.val_accuracy.compute().item()
        val_precision = self.val_precision.compute().item()
        val_recall = self.val_recall.compute().item()
        print(f"Validation - Epoch {self.current_epoch}:")
        print(f"Val Accuracy: {val_acc:.4f}")
        print(f"Val Precision: {val_precision:.4f}")
        print(f"Val Recall: {val_recall:.4f}")
        
        # Reset metrics for the next epoch
        self.val_accuracy.reset()
        self.val_precision.reset()
        self.val_recall.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams['learning_rate'], weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss'
            }
        }


In [5]:
# Ensure wandb is finished from the previous session
wandb.finish()

# Initialize the DataModule
data_dir = "temporal_ds/images"
data_module = FireDataModule(data_dir, batch_size=16, img_size=112, num_workers=12)

# Initialize the model
model = FireClassifier(learning_rate=1e-4, unfreeze_epoch=3)

# Define callbacks
# Save the best model based on the highest recall score in validation
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_recall",  # Monitor validation recall
    mode="max",            # Save the model with the highest recall
    save_top_k=1,          # Save only the best model
    dirpath="model_checkpoints/",  # Directory to save checkpoints
    filename="fire_model-{epoch:02d}-{val_recall:.4f}",  # Filename format for the saved model
    save_weights_only=True  # Save only the model weights (no optimizer state)
)

# Initialize WandbLogger
wandb_logger = pl.loggers.WandbLogger(project='fire_detection_project')

# Initialize the Trainer (total epochs include frozen + unfrozen training)
trainer = pl.Trainer(
    max_epochs=20,  # Total epochs (including frozen and unfrozen stages)
    callbacks=[checkpoint_callback],  # Add the checkpoint callback
    logger=wandb_logger
)

# Train the model
trainer.fit(model, datamodule=data_module)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3050 Ti Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmateolos[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type            | Params | Mode 
------------------------------------------------------------
0 | model           | TemporalModel   | 9.1 M  | train
1 | train_accuracy  | BinaryAccuracy  | 0      | train
2 | val_accuracy    | BinaryAccuracy  | 0      | train
3 | train_precision | BinaryPrecision | 0      | train
4 | val_precision   | BinaryPrecision | 0      | train
5 | train_recall    | BinaryRecall    | 0      | train
6 | val_recall      | BinaryRecall    | 0      | train
------------------------------------------------------------
2.6 M     Trainable params
6.5 M     Non-trainable params
9.1 M     Total params
36.564    Total estimated model params size (MB)
480       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 0:
Val Accuracy: 0.5000
Val Precision: 0.5000
Val Recall: 1.0000


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 0:
Val Accuracy: 0.8244
Val Precision: 0.7978
Val Recall: 0.8690
Epoch 0:
Train Accuracy: 0.7803
Train Precision: 0.7321
Train Recall: 0.8842


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 1:
Val Accuracy: 0.8601
Val Precision: 0.8712
Val Recall: 0.8452
Epoch 1:
Train Accuracy: 0.8655
Train Precision: 0.8391
Train Recall: 0.9042


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 2:
Val Accuracy: 0.8482
Val Precision: 0.8774
Val Recall: 0.8095
Epoch 2:
Train Accuracy: 0.8821
Train Precision: 0.8676
Train Recall: 0.9018
Unfreezing all model layers at epoch 3


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 3:
Val Accuracy: 0.9048
Val Precision: 0.9474
Val Recall: 0.8571
Epoch 3:
Train Accuracy: 0.9185
Train Precision: 0.9113
Train Recall: 0.9273


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 4:
Val Accuracy: 0.8958
Val Precision: 0.9524
Val Recall: 0.8333
Epoch 4:
Train Accuracy: 0.9606
Train Precision: 0.9617
Train Recall: 0.9594


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 5:
Val Accuracy: 0.8839
Val Precision: 0.9510
Val Recall: 0.8095
Epoch 5:
Train Accuracy: 0.9797
Train Precision: 0.9794
Train Recall: 0.9800


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 6:
Val Accuracy: 0.9315
Val Precision: 0.9503
Val Recall: 0.9107
Epoch 6:
Train Accuracy: 0.9821
Train Precision: 0.9801
Train Recall: 0.9842


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 7:
Val Accuracy: 0.9137
Val Precision: 0.9728
Val Recall: 0.8512
Epoch 7:
Train Accuracy: 0.9870
Train Precision: 0.9867
Train Recall: 0.9873


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 8:
Val Accuracy: 0.9226
Val Precision: 0.9733
Val Recall: 0.8690
Epoch 8:
Train Accuracy: 0.9858
Train Precision: 0.9861
Train Recall: 0.9855


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 9:
Val Accuracy: 0.9256
Val Precision: 0.9673
Val Recall: 0.8810
Epoch 9:
Train Accuracy: 0.9912
Train Precision: 0.9909
Train Recall: 0.9915


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 10:
Val Accuracy: 0.9167
Val Precision: 0.9730
Val Recall: 0.8571
Epoch 10:
Train Accuracy: 0.9924
Train Precision: 0.9927
Train Recall: 0.9921


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 11:
Val Accuracy: 0.9375
Val Precision: 0.9742
Val Recall: 0.8988
Epoch 11:
Train Accuracy: 0.9882
Train Precision: 0.9855
Train Recall: 0.9909


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 12:
Val Accuracy: 0.9107
Val Precision: 0.9662
Val Recall: 0.8512
Epoch 12:
Train Accuracy: 0.9936
Train Precision: 0.9963
Train Recall: 0.9909


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 13:
Val Accuracy: 0.9435
Val Precision: 0.9686
Val Recall: 0.9167
Epoch 13:
Train Accuracy: 0.9948
Train Precision: 0.9940
Train Recall: 0.9958


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 14:
Val Accuracy: 0.9315
Val Precision: 0.9801
Val Recall: 0.8810
Epoch 14:
Train Accuracy: 0.9936
Train Precision: 0.9927
Train Recall: 0.9945


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 15:
Val Accuracy: 0.9167
Val Precision: 0.9930
Val Recall: 0.8393
Epoch 15:
Train Accuracy: 0.9988
Train Precision: 0.9982
Train Recall: 0.9994


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 16:
Val Accuracy: 0.9286
Val Precision: 0.9865
Val Recall: 0.8690
Epoch 16:
Train Accuracy: 0.9964
Train Precision: 0.9952
Train Recall: 0.9976


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 17:
Val Accuracy: 0.9226
Val Precision: 0.9863
Val Recall: 0.8571
Epoch 17:
Train Accuracy: 0.9982
Train Precision: 0.9976
Train Recall: 0.9988


Validation: |          | 0/? [00:00<?, ?it/s]

Validation - Epoch 18:
Val Accuracy: 0.9435
Val Precision: 0.9806
Val Recall: 0.9048
Epoch 18:
Train Accuracy: 0.9958
Train Precision: 0.9970
Train Recall: 0.9945


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


Validation - Epoch 19:
Val Accuracy: 0.9286
Val Precision: 0.9865
Val Recall: 0.8690
Epoch 19:
Train Accuracy: 0.9970
Train Precision: 0.9964
Train Recall: 0.9976


In [6]:
wandb.finish()

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
train_acc,▁▄▄▅▇▇▇█████████████
train_loss,█▆▅▄▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁
train_precision,▁▃▄▅▇▇▇█████████████
train_recall,▁▂▂▄▆▇▇▇▇███████████
trainer/global_step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
val_acc,▁▃▂▆▅▅▇▆▇▇▆█▆█▇▆▇▇█▇
val_loss,█▇▇▃▃▄▁▂▃▁▄▂▆▂▄▆▅▅▁▃
val_precision,▁▄▄▇▇▆▆▇▇▇▇▇▇▇██████
val_recall,▅▄▁▄▃▁█▄▅▆▄▇▄█▅▃▅▄▇▅

0,1
epoch,19.0
train_acc,0.99697
train_loss,0.01004
train_precision,0.99562
train_recall,0.99781
trainer/global_step,4139.0
val_acc,0.92857
val_loss,0.24767
val_precision,0.98519
val_recall,0.86914
