In [None]:
import matplotlib.pyplot as plt
import geopandas as gpd
import numpy as np
from data_sources.nbac_fire_data_source import NbacFireDataSource
from boundaries.canada_boundary import CanadaBoundary
from data_sources.canada_boundary_data_source import CanadaBoundaryDataSource
from targets.fire_occurrence_target import FireOccurrenceTarget
from pathlib import Path
from osgeo import gdal
from osgeo import osr

In [None]:
# Analysis window (start inclusive, end exclusive) and working CRS
# (EPSG:3978, NAD83 / Canada Atlas Lambert).
years = range(2001, 2003)
target_epsg = 3978

# Provinces making up the study boundary; territory codes are not listed.
provinces = ["NL", "PE", "NS", "NB", "QC", "ON", "MB", "SK", "AB", "BC"]

canada_output_path = Path("../data/canada_boundary/")
canada_boundary_source = CanadaBoundaryDataSource(canada_output_path)
canada = CanadaBoundary(canada_boundary_source, target_epsg=target_epsg)
canada.load(provinces)

# Fire polygons (NBAC) are fetched into this scratch directory.
raw_data_path = Path("../data/tmp/")
fire_data_source = NbacFireDataSource(raw_data_path)

In [None]:
# Fire-occurrence target produced by the project pipeline on a 1000 m grid in
# the working CRS. NOTE(review): the manual rasterization cell uses
# pixel_size = 250; confirm which resolution is intended.
fire_target_output_path = Path("../data/tmp/fire_occurrence_target/")
target = FireOccurrenceTarget(
    fire_data_source=fire_data_source,
    boundary=canada,
    target_pixel_size_in_meters=1000,
    target_epsg_code=target_epsg,
    output_folder_path=fire_target_output_path,
)

In [None]:
def rasterize_fire_polygons(year_fire_polygons, target_epsg, x_min, y_min, x_max, y_max, pixel_size, output_raster_path):
    """Burn fire polygons into a single-band netCDF presence raster.

    Pixels selected by GDAL's rasterization of ``year_fire_polygons`` are set
    to 1. Every other pixel is 0, which is also registered as the band's
    nodata value.

    Parameters
    ----------
    year_fire_polygons : geopandas.GeoDataFrame
        Polygons to burn. They are burned without reprojection, so they are
        expected to already be in ``target_epsg``.
    target_epsg : int
        EPSG code written as the raster's projection.
    x_min, y_min, x_max, y_max : float
        Raster extent in ``target_epsg`` map units.
    pixel_size : float
        Edge length of the square pixels, in ``target_epsg`` map units.
    output_raster_path : pathlib.Path or str
        Destination netCDF file. The polygons are also written to an
        intermediate ``<stem>_shp`` shapefile in the same directory.

    Raises
    ------
    RuntimeError
        If GDAL cannot create the raster, cannot open the intermediate
        shapefile, or rasterization fails.
    """
    output_raster_path = Path(output_raster_path)

    # int() truncates, dropping any partial pixel on the right/bottom edge;
    # the caller sizes the combined grid the same way.
    x_res = int((x_max - x_min) / pixel_size)
    y_res = int((y_max - y_min) / pixel_size)

    # str() keeps this working with GDAL bindings that do not accept os.PathLike.
    target_ds = gdal.GetDriverByName('netCDF').Create(str(output_raster_path), x_res, y_res, 1, gdal.GDT_Byte)
    if target_ds is None:
        raise RuntimeError(f"Could not create raster {output_raster_path}")
    # North-up geotransform anchored at the top-left corner (x_min, y_max).
    target_ds.SetGeoTransform((x_min, pixel_size, 0, y_max, 0, -pixel_size))

    srs = osr.SpatialReference()
    srs.ImportFromEPSG(target_epsg)
    target_ds.SetProjection(srs.ExportToWkt())

    band = target_ds.GetRasterBand(1)
    band.SetNoDataValue(0)
    # Initialise the background explicitly so unburned pixels are 0
    # regardless of the driver's default fill value.
    band.Fill(0)

    # Write the polygons next to the output raster so GDAL can read them as an
    # OGR layer.
    shp_file_path = output_raster_path.parent / f"{output_raster_path.stem}_shp"
    year_fire_polygons.to_file(str(shp_file_path.resolve()))
    shp_ds = gdal.OpenEx(str(shp_file_path.resolve()), gdal.OF_VECTOR)
    if shp_ds is None:
        raise RuntimeError(f"Could not open intermediate shapefile {shp_file_path}")
    shp_layer = shp_ds.GetLayer(0)

    err = gdal.RasterizeLayer(
        target_ds,
        [1],
        shp_layer,
        burn_values=[1],
    )
    if err != gdal.CE_None:
        raise RuntimeError(f"Rasterization failed for {output_raster_path}")

    # Release handles explicitly so the raster is flushed to disk before
    # another GDAL call reopens it.
    shp_layer = None
    shp_ds = None
    band = None
    target_ds.FlushCache()
    target_ds = None

def combine_rasters(raster_paths, combined_raster_path, x_res, y_res, x_min, y_max, pixel_size):
    """Union single-band rasters by taking their pixel-wise maximum.

    Band 1 of each input is read as an array and merged with ``np.maximum``;
    for 0/1 presence rasters this is a logical OR. The result is written as a
    single-band ``Byte`` netCDF raster on a north-up grid. The projection of
    the first input that has one is copied onto the output, so downstream
    tools (e.g. ``gdal.Warp`` with a cutline) know its CRS.

    Parameters
    ----------
    raster_paths : iterable of str or pathlib.Path
        Input rasters; band 1 of each must have shape ``(y_res, x_res)``.
    combined_raster_path : str or pathlib.Path
        Destination netCDF file.
    x_res, y_res : int
        Output width and height in pixels.
    x_min, y_max : float
        Top-left corner of the grid in map units.
    pixel_size : float
        Square pixel edge length in map units.

    Raises
    ------
    RuntimeError
        If an input cannot be opened or the output cannot be created.
    ValueError
        If an input's shape differs from ``(y_res, x_res)``.
    """
    combined_raster = np.zeros((y_res, x_res), dtype=np.uint8)
    projection = ""

    for raster_path in raster_paths:
        ds = gdal.Open(str(raster_path))
        if ds is None:
            raise RuntimeError(f"Could not open raster {raster_path}")
        data = ds.GetRasterBand(1).ReadAsArray()
        if data.shape != combined_raster.shape:
            raise ValueError(
                f"Raster {raster_path} has shape {data.shape}, expected {combined_raster.shape}"
            )
        # Assumes all inputs share one CRS; keep the first non-empty one.
        if not projection:
            projection = ds.GetProjection()
        combined_raster = np.maximum(combined_raster, data)
        ds = None  # close the input dataset

    driver = gdal.GetDriverByName('netCDF')
    out_raster = driver.Create(str(combined_raster_path), x_res, y_res, 1, gdal.GDT_Byte)
    if out_raster is None:
        raise RuntimeError(f"Could not create raster {combined_raster_path}")
    out_raster.SetGeoTransform((x_min, pixel_size, 0, y_max, 0, -pixel_size))
    if projection:
        out_raster.SetProjection(projection)
    out_band = out_raster.GetRasterBand(1)
    # The array covers the full grid, so no prior Fill is needed.
    out_band.WriteArray(combined_raster)
    # Drop the band and dataset handles so GDAL flushes and closes the file.
    out_band = None
    out_raster.FlushCache()
    out_raster = None

# --- Grid definition --------------------------------------------------------
# The union raster covers the bounding box of the loaded boundary, in
# `target_epsg` map units.
x_min, y_min, x_max, y_max = canada.boundary.total_bounds
# NOTE(review): FireOccurrenceTarget is configured with 1000 m pixels;
# confirm that 250 is the intended resolution for this manual pipeline.
pixel_size = 250

# Truncating division drops any partial edge pixel; rasterize_fire_polygons
# sizes its grid the same way, so shapes line up in combine_rasters.
x_res = int((x_max - x_min) / pixel_size)
y_res = int((y_max - y_min) / pixel_size)

print(f"x_res: {x_res}, y_res: {y_res}")

# --- Per-year rasterization -------------------------------------------------
scratch_dir = Path("../data/tmp/")
# All per-year rasters share this directory, so create it once.
scratch_dir.mkdir(parents=True, exist_ok=True)

raster_paths = []
for year in years:
    # Reproject to the grid CRS: rasterize_fire_polygons burns geometries as-is.
    year_fire_polygons = fire_data_source.download(year).to_crs(epsg=target_epsg)
    output_raster_path = scratch_dir / f"fire_{year}.nc"
    rasterize_fire_polygons(year_fire_polygons, target_epsg, x_min, y_min, x_max, y_max, pixel_size, output_raster_path)
    raster_paths.append(str(output_raster_path.resolve()))

# --- Union across years -----------------------------------------------------
combined_raster_path = "../data/tmp/fire_union_combined.nc"
combine_rasters(raster_paths, combined_raster_path, x_res, y_res, x_min, y_max, pixel_size)

# --- Mask to the boundary shapefile -----------------------------------------
# sorted() makes the pick deterministic when several shapefiles exist; fail
# with a clear message instead of a bare IndexError when there are none.
boundary_shapefiles = sorted(Path(canada_output_path).glob("*.shp"))
if not boundary_shapefiles:
    raise FileNotFoundError(f"No .shp file found in {Path(canada_output_path).resolve()}")
boundary_shapefile = boundary_shapefiles[0]

print(f"boundary_shapefile: {boundary_shapefile}")

# NOTE(review): the cutline is the shapefile on disk, not the province subset
# passed to canada.load(); if the file contains other regions, the cropped
# extent will differ from the grid above. Confirm this is intended.
# srcSRS declares the combined raster's CRS so GDAL can reproject the cutline
# into it even if the raster carries no projection metadata.
canada_mask = gdal.Warp('../data/tmp/fire_union_masked.nc', combined_raster_path,
                        cutlineDSName=str(boundary_shapefile.resolve()),
                        cropToCutline=True,
                        srcSRS=f"EPSG:{target_epsg}")
if canada_mask is None:
    raise RuntimeError("gdal.Warp failed to produce ../data/tmp/fire_union_masked.nc")
# Flush so the masked file on disk is complete while `canada_mask` stays open
# for later cells.
canada_mask.FlushCache()