# UHI CNN

In [1]:
import geopandas as gpd
import pandas as pd
from shapely.geometry import Point
import numpy as np
import torch.nn.functional as F
import rasterio
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
import rasterio.mask
import matplotlib.pyplot as plt
import xarray as xr
import rioxarray as rxr
import rasterio

import torch.nn as nn

#### SET UP

In [2]:


def get_gdf(path="Training_data_uhi_index_2025-02-18"):
    df = pd.read_csv(path)
    gdf = gpd.GeoDataFrame(
        df,
        geometry=[Point(lon, lat) for lon, lat in zip(df['Longitude'], df['Latitude'])],
        crs="EPSG:4326"
    ).to_crs("EPSG:2263")
    return gdf

# Load your CSV with UHI index values (with Latitude, Longitude, UHI_Index)
train_gdf = get_gdf(path="Training_data_uhi_index_2025-02-18.csv")
submission_gdf = get_gdf(path="_Submission_Processed_Data/Submission_template_UHI2025-v2.csv")
landcover_data = rasterio.open(r"Referentiels/landcover_nyc_2021_6in_Clipped_5.tif")


def extract_satellite_patch(satellite_ds, point, patch_size_meters=200):
    # Extract original patch
    geom = point.buffer(patch_size_meters* 3.28084)
    patch = satellite_ds.rio.clip([geom])
    patch_array = patch.squeeze(dim="band", drop=True).to_array().values
    patch_array = torch.tensor(patch_array)
    print("orgi ",patch_array.shape )

    Global_band_means= satellite_ds.mean().to_array().values
    patch_array = np.where(np.isnan(patch_array), Global_band_means[:, None, None], patch_array)
    # Add batch and channel dimensions (required for F.interpolate)
    upsampled_tensor = F.interpolate(
        torch.tensor(patch_array).unsqueeze(0),  # shape: (1, 1, H, W)
        size=(132, 132),
        mode='bilinear'
    ).squeeze()  # shape: (132, 132)

    return  upsampled_tensor


def extract_LandCover_patch(data, point, patch_size_meters=100):

    geom = point.buffer(patch_size_meters * 3.28084)
    out_image, _ = rasterio.mask.mask(data, [geom], crop=True, nodata=0)

    return out_image[0]


def one_hot_encode_landcover(patch, all_classes=[1,2,3,4,5,6,7,8]):

    one_hot = np.zeros((len(all_classes), patch.shape[0], patch.shape[1]), dtype=np.uint8)

    for i, c in enumerate(all_classes):
        one_hot[i] = (patch == c).astype(np.uint8)

    return torch.tensor(one_hot, dtype=torch.float32)




#### Catboost on CNN

In [3]:
train_gdf = get_gdf(path="Training_data_uhi_index_2025-02-18.csv")
submission_gdf = get_gdf(path="_Submission_Processed_Data/Submission_template_UHI2025-v2.csv")
landcover_data = rasterio.open(r"Referentiels/landcover_nyc_2021_6in_Clipped_5.tif")


In [4]:
from _LandCover_CNN import Adaptive_UHI_CNN, UHIPatchDataset
# get the train val split: manual seed keeps us tracking wich one were training points, which ones were validation
indices = torch.randperm(len(train_gdf), generator=torch.Generator().manual_seed(42))
train_size = int(0.9 * len(train_gdf))
train_indices = indices[:train_size]
val_indices = indices[train_size:]


train_gdf_split = train_gdf.iloc[train_indices]
val_gdf_split = train_gdf.iloc[val_indices]


train_dataset = UHIPatchDataset(
    gdf=train_gdf_split,
    patch_size_meters=500,
    output_size=(200,200)
)

val_dataset = UHIPatchDataset(
    gdf=val_gdf_split,
    patch_size_meters=500,
    output_size=(200,200)
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [5]:
loaded_model = Adaptive_UHI_CNN(num_classes=8)  
loaded_model.load_state_dict(torch.load("_Models/LandCover_AdaptiveUHICNN_500m_220325.pth"))
loaded_model.eval()  # Set the model to evaluation mode
print("Model loaded and set to evaluation mode.")

def get_predictions(model, dataloader):
    model.eval()  
    predictions = []
    with torch.no_grad():
        for batch in dataloader:
            patches, _ = batch  
            preds = model(patches).squeeze()  
            predictions.extend(preds.cpu().numpy())
    return predictions

train_preds = get_predictions(loaded_model, train_loader)
val_preds = get_predictions(loaded_model, val_loader)



Model loaded and set to evaluation mode.


In [22]:
train_gdf_split["LandCover_CNN"] = train_preds
val_gdf_split["LandCover_CNN"] = val_preds


In [20]:
## Here, we are insuring that there is no data leakage, by taking the same samples the CNN learned from and feeding them to the catboost

import json
# from sklearn.preprocessing import StandardScaler, MinMaxScaler, KBinsDiscretizer
from catboost import CatBoostRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

with open("_Models/Final_Model_Features.json") as f:
    loaded_top_features =  ["LandCover_CNN"] + json.load(f)
    print("number of Top features : ",len(loaded_top_features))

Landsat_Train = pd.read_csv('_Train_Processed_Data/Landsat_Train_V26022024.csv')
Sentinel2_Train = pd.read_csv('_Train_Processed_Data/Sentinel2_Train_v26022025.csv')
satellite_Data = Landsat_Train.merge(Sentinel2_Train, on=["Latitude", "Longitude", "UHI Index"])


Geo_FP_train = pd.read_csv('_Train_Processed_Data/Train_Buildings_v14032025_NegativeMask.csv').select_dtypes("number")
train_FEdata = satellite_Data.merge(Geo_FP_train, on=["Latitude", "Longitude", "UHI Index"])


LandCover_Train = pd.read_csv('_Train_Processed_Data/Polygonized_LandCover_Train_IDW_500Circles_15032025.csv')
LandCover_Train = LandCover_Train.drop(columns=["datetime", "geometry"])
train_FEdata = train_FEdata.merge(LandCover_Train, on=["Latitude", "Longitude", "UHI Index"])

merged_train = train_gdf_split.merge(train_FEdata, on=["Latitude", "Longitude", "UHI Index"], how='inner')
merged_val = val_gdf_split.merge(train_FEdata, on=["Latitude", "Longitude", "UHI Index"], how='inner')
X_train, y_train = merged_train.drop(columns=["Latitude", "Longitude", "UHI Index", "geometry", "datetime"]), merged_train["UHI Index"]
X_val, y_val = merged_val.drop(columns=["Latitude", "Longitude", "UHI Index", "geometry", "datetime"]), merged_val["UHI Index"]

X_train

number of Top features :  251


Unnamed: 0,LandCover_CNN,NDVI_Landsat_mean_0_25,GNDVI_Landsat_mean_0_25,EVI_Landsat_mean_0_25,NDWI_Landsat_mean_0_25,MNDWI_Landsat_mean_0_25,NDBI_Landsat_mean_0_25,BUI_Landsat_mean_0_25,NDBSI_Landsat_mean_0_25,BI_Landsat_mean_0_25,...,vegetation_fragmentation_0_1000,IDW_vegetation_area_0_1000,vegetation_vertical_complexity_0_1000,closest_impervious_distance_0_1000,IDW_full_impervious_area_0_1000,largest_impervious_patch_0_1000,effective_albedo_0_1000,impervious_contiguity_0_1000,impervious_edge_density_0_1000,UHI_risk_index_0_1000
0,-0.002003,0.652412,0.603310,0.443548,-0.603310,-0.995921,0.982771,0.330175,0.995921,0.015381,...,0.000672,1.290229e+05,3.307903,4.566037,2.435746e+08,2.668397e+08,0.095146,0.608690,0.027274,0.447731
1,0.016050,0.201871,0.232142,0.127338,-0.232142,-0.993990,0.990308,0.788648,0.993990,0.079868,...,0.000517,2.770513e+03,3.274995,0.267323,1.333850e+08,2.668397e+08,0.111480,0.690433,0.025385,0.464369
2,-0.000353,0.488310,0.468337,0.307704,-0.468337,-0.995512,0.987214,0.498952,0.995512,0.029043,...,0.000670,4.326922e+03,3.318159,3.281855,2.453028e+08,2.668397e+08,0.103613,0.668536,0.024207,0.463354
3,-0.021167,0.224412,0.233008,0.116109,-0.233008,-0.995089,0.992148,0.767969,0.995089,0.056764,...,0.000296,5.276023e+04,3.308367,8.636637,9.765718e+07,2.668397e+08,0.081889,0.493932,0.020449,0.415259
4,0.009183,0.215856,0.228908,0.121622,-0.228908,-0.995028,0.991699,0.776458,0.995028,0.065269,...,0.000810,1.970201e+03,3.352945,1.030631,2.636623e+08,2.668397e+08,0.123026,0.778121,0.023478,0.487772
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10101,-0.021690,0.441748,0.445051,0.295785,-0.445051,-0.994918,0.986308,0.544888,0.994918,0.039094,...,0.000476,2.032839e+04,3.306644,10.040297,1.103865e+08,2.668397e+08,0.103160,0.664894,0.024782,0.462358
10102,-0.012212,0.161643,0.190181,0.080648,-0.190181,-0.995316,0.992932,0.831686,0.995316,0.066801,...,0.000300,3.204989e+03,3.243334,2.217983,1.308809e+08,2.668397e+08,0.092419,0.564055,0.018799,0.432478
10103,-0.001134,0.287953,0.308946,0.188137,-0.308946,-0.994061,0.988562,0.700818,0.994061,0.060311,...,0.000212,7.648320e+03,3.560187,14.406976,7.708284e+07,2.668397e+08,0.087176,0.551061,0.025177,0.432394
10104,0.007139,0.110761,0.156293,0.075441,-0.156293,-0.992802,0.989830,0.878963,0.992802,0.112409,...,0.000676,1.171671e+03,3.338494,14.150980,7.026273e+07,2.668397e+08,0.107930,0.681247,0.029108,0.464162
