In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import sys
import os

In [None]:
# Make the parent directory importable so the `src.*` imports below resolve.
# NOTE(review): '..' is relative to the kernel's CWD — assumes the notebook is
# launched from a direct subdirectory of the project root; confirm.
# Guarded so re-running this cell does not keep appending duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')
from src.model import UNet
from src.dataset import FireDataset

In [None]:
# Seed torch's global RNG before the model is initialised and the DataLoader
# shuffles, so a Restart & Run All reproduces the same run.
SEED = 42
torch.manual_seed(SEED)

# Prefer a GPU backend when one is available: CUDA first, then Apple MPS,
# falling back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Build the training dataset from the preprocessed feature stack and label
# arrays. tile_size=128 is forwarded to FireDataset (presumably the tile edge
# length in pixels — confirm in src/dataset.py).
dataset = FireDataset(
    '../data/processed/feature_stack.npy',
    '../data/processed/labels.npy',
    tile_size=128,
)

# Shuffled mini-batches of 8 samples for training.
train_loader = DataLoader(dataset, shuffle=True, batch_size=8)

In [None]:
# NOTE(review): in_channels=3 must match the channel count FireDataset yields
# for each tile — confirm against the feature stack.
model = UNet(in_channels=3, out_channels=1)
model = model.to(device)

# NOTE(review): nn.BCELoss expects inputs already in [0, 1]; this assumes UNet
# ends with a sigmoid. If it returns raw logits, use nn.BCEWithLogitsLoss.
criterion = nn.BCELoss()

optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Number of full passes over train_loader in the training loop below.
epochs = 20

print("Starting Training")

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: module to train; switched to train mode here.
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss callable ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to before the forward pass.

    Returns:
        Mean of the per-batch losses. Returns 0.0 if ``loader`` yields no
        batches, instead of raising ZeroDivisionError.
    """
    model.train()
    running_loss = 0.0
    n_batches = 0
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        outputs = model(batch_x)
        # NOTE(review): nn.BCELoss requires a float target with the same shape
        # as `outputs`; confirm FireDataset yields labels that way.
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
    # Counting batches (rather than len(loader)) also works for loaders
    # without __len__.
    return running_loss / n_batches if n_batches else 0.0


# Per-epoch mean losses, kept so they can be inspected or plotted later.
epoch_losses = []
for epoch in range(epochs):
    avg_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    epoch_losses.append(avg_loss)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

In [None]:
# Persist only the learned weights (state_dict), not the whole module object.
checkpoint_path = '../models/unet_fire_model.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

weights = model.state_dict()
torch.save(weights, checkpoint_path)
print("Model saved to models/unet_fire_model.pth")