In [None]:
!pip install xarray netCDF4 scipy torch numpy matplotlib
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from scipy.io import loadmat
import xarray as xr
import joblib
from scipy.stats import pearsonr
import seaborn as sns
import os
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader


# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('using device:', device)

Collecting netCDF4
  Downloading netcdf4-1.7.3-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.9 kB)
Collecting cftime (from netCDF4)
  Downloading cftime-1.6.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (8.7 kB)
Downloading netcdf4-1.7.3-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (9.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.5/9.5 MB[0m [31m29.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cftime-1.6.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: cftime, netCDF4
Successfully installed cftime-1.6.5 netCDF4-1.7.3
using device: cpu


In [None]:
from google.colab import drive

# Mount Google Drive into the Colab VM (remounting if it is already mounted).
drive.mount('/content/drive', force_remount=True)

# `data_folder` holds the inputs; `mydir` is a second name for the same folder.
mydir = data_folder = '/content/drive/MyDrive/DNN_Wildfire/'

print("Files in folder:")
for entry in os.listdir(data_folder):
  print(" ", entry)

def get_path(filename):
  """Return `filename` joined onto the module-level `data_folder`."""
  full_path = os.path.join(data_folder, filename)
  return full_path

# Split the folder listing by extension; sorting fixes the order files are loaded in.
nc4_files = sorted(name for name in os.listdir(data_folder) if name.endswith('.nc4'))
mat_files = sorted(name for name in os.listdir(data_folder) if name.endswith('.mat'))

# Open every NetCDF file with xarray, keyed by its filename.
nc4_dataset = {}
for file in nc4_files:
  path = get_path(file)
  nc4_dataset[file] = ds = xr.open_dataset(path)
  print(f"Loaded: {file}")

# Read every MATLAB file with scipy.io.loadmat, keyed by its filename.
mat_dataset = {}
for file in mat_files:
  path = get_path(file)
  mat_dataset[file] = ds = loadmat(path)
  print(f"Loaded: {file}")

Mounted at /content/drive
Files in folder:
  MC2_burnedFrac.nc4
  JSBACH-SPITFIRE_burnedFrac.nc4
  wildfire_surrogate4.mat
  wildfire_surrogate6.mat
  wildfire_surrogate9.mat
  LPJ-GUESS-GlobFIRM_burnedFrac.nc4
  JULES-INFERNO_burnedFrac.nc4
  wildfire_surrogate7.mat
  wildfire_surrogate8.mat
  wildfire_surrogate2.mat
  wildfire_surrogate5.mat
  wildfire_surrogate14.mat
  wildfire_surrogate11.mat
  wildfire_surrogate1.mat
  wildfire_surrogate10.mat
  CTEM_burnedFrac.nc4
  CLM_burnedFrac.nc4
  LPJ-GUESS-SPITFIRE_burnedFrac.nc4
  ORCHIDEE-SPITFIRE_burnedFrac.nc4
  wildfire_surrogate3.mat
  wildfire_surrogate13.mat
  wildfire_surrogate12.mat
  LPJ-GUESS-SIMFIRE-BLAZE_burnedFrac.nc4
  11_17_own_data
  11_22_transfer_learning
  11_14_results
  wildfire_surrogate1_DNN_softplus.pt
  11_29_transfer_learning
  global_wildfire_base_model.pt
Loaded: CLM_burnedFrac.nc4
Loaded: CTEM_burnedFrac.nc4
Loaded: JSBACH-SPITFIRE_burnedFrac.nc4
Loaded: JULES-INFERNO_burnedFrac.nc4
Loaded: LPJ-GUESS-GlobFIRM

In [None]:
class ANNWildfire(nn.Module):
  """Fully connected regressor: five 5-unit hidden layers and a 1-unit output.

  Softplus is applied after every layer, including the output layer, so the
  prediction is never negative.
  """

  def __init__(self, input_dim):
    """Build the layers, then reset every Linear to U(-0.05, 0.05) weights and zero biases."""
    super().__init__()
    # The attribute names fc1..fc6 become the state_dict keys of saved checkpoints.
    self.fc1 = nn.Linear(input_dim, 5)
    self.fc2 = nn.Linear(5, 5)
    self.fc3 = nn.Linear(5, 5)
    self.fc4 = nn.Linear(5, 5)
    self.fc5 = nn.Linear(5, 5)
    self.fc6 = nn.Linear(5, 1)
    self.softplus = nn.Softplus()

    for layer in self.modules():
      if not isinstance(layer, nn.Linear):
        continue
      nn.init.uniform_(layer.weight, -0.05, 0.05)
      nn.init.zeros_(layer.bias)

  def forward(self, x, y=None):
    """Return the prediction for `x`; when `y` is given return `(mse_loss, prediction)`."""
    hidden = x
    for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5, self.fc6):
      hidden = self.softplus(layer(hidden))

    if y is None:
      return hidden
    return nn.functional.mse_loss(hidden, y), hidden




In [None]:
def train_ann(net, train_iter, lr, epochs, device, decay_steps=1000, decay_rate=0.99):
  """Train `net` with Adam and a per-step exponential learning-rate decay.

  `net` must return `(loss, pred)` when called with the tensors of one batch.
  After every optimizer step the learning rate is set to
  `lr * decay_rate ** (step / decay_steps)`, so it applies from the next step on.
  The loss, first prediction and first target are printed once per epoch
  (on the first batch of each epoch).

  Returns:
      A list with one detached CPU loss tensor per training step.
  """
  net = net.to(device)
  optimizer = torch.optim.Adam(net.parameters(), lr=lr)
  steps_per_epoch = len(train_iter)
  total_iter = epochs * steps_per_epoch
  loss_list = []

  for epoch in range(epochs):
    net.train()
    for batch_idx, batch in enumerate(train_iter):
      batch = [tensor.to(device) for tensor in batch]

      loss, pred = net(*batch)
      loss_list.append(loss.mean().detach().cpu())

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Global step counter across epochs drives the decay schedule.
      step = epoch * steps_per_epoch + batch_idx
      decayed_lr = lr * (decay_rate ** (step/decay_steps))
      for group in optimizer.param_groups:
        group['lr'] = decayed_lr

      if step % steps_per_epoch == 0:
        print('iter {} / {}\tLoss:\t{:.6f}'.format(step, total_iter, loss.mean().detach()))
        print('pred:\t {}\n'.format(pred[0].detach().cpu()))
        print('tgt:\t {}\n'.format(batch[1][0].cpu()))
  return loss_list
##
def predict_model(net, X, device):
  """Run `net` on tensor `X` in eval mode without gradients; return a NumPy array."""
  model = net.to(device)
  model.eval()
  with torch.no_grad():
    outputs = model(X.to(device))
  return outputs.cpu().numpy()



In [None]:
def train_region(region_idx, mydir, device, force_retrain=False):
  """Fit (or reload) the surrogate ANN for one region and predict over its full data.

  Reads `ELMX` / `ELMy` from `wildfire_surrogate{region_idx+1}.mat` in `mydir`,
  min-max scales both to [0, 1], saves the fitted scalers with joblib, and
  trains on a stratified 80% split unless a checkpoint already exists and
  `force_retrain` is False.

  Returns:
      (ann_y, data_y): predictions and targets in original units, each
      reshaped to 360 columns (the sample count must be a multiple of 360).
  """
  region_no = region_idx + 1
  banner = '=' * 60
  matfiles = os.path.join(mydir, f'wildfire_surrogate{region_no}.mat')
  print(f"\n{banner}")
  print(f"Processing region {region_no}")
  print(banner)

  mat_contents = loadmat(matfiles)
  ELMX = mat_contents.get('ELMX')
  ELMy = mat_contents.get('ELMy')

  sc_X = MinMaxScaler(feature_range=(0, 1))
  sc_y = MinMaxScaler(feature_range=(0, 1))
  X = sc_X.fit_transform(ELMX)
  y = sc_y.fit_transform(ELMy.reshape(-1, 1))

  # The scalers are stored under a .mat suffix although they are joblib dumps;
  # tune_region reloads them from these exact paths.
  joblib.dump(sc_X, os.path.join(mydir, f"scaler_X{region_no}.mat"))
  joblib.dump(sc_y, os.path.join(mydir, f"scaler_y{region_no}.mat"))

  # Three strata from the 33rd / 66th percentiles of the scaled target;
  # anything not <= max (i.e. NaN) keeps label 0.
  lower_cut = np.percentile(y, 33)
  upper_cut = np.percentile(y, 66)
  top = np.max(y)
  strata_y = np.where(
      y <= lower_cut, 1,
      np.where(y <= upper_cut, 2, np.where(y <= top, 3, 0))
  )

  X_train, _, y_train, _ = train_test_split(
      X, y, test_size=0.2, stratify=strata_y, random_state=0
  )
  train_iter = DataLoader(
      TensorDataset(
          torch.tensor(X_train, dtype=torch.float32),
          torch.tensor(y_train, dtype=torch.float32)
      ),
      batch_size=20,
      shuffle=True
  )

  model_path = os.path.join(mydir, f'wildfire_surrogate{region_no}_ANN_softplus.pt')
  net = ANNWildfire(input_dim=X_train.shape[1])

  reuse_checkpoint = os.path.exists(model_path) and not force_retrain
  if reuse_checkpoint:
    print(f"Loading model")
    net.load_state_dict(torch.load(model_path, map_location=device))
  else:
    print(f"Training model")
    train_ann(net, train_iter, lr=0.01, epochs=30, device=device, decay_steps=1000, decay_rate=0.99)
    torch.save(net.state_dict(), model_path)

  # Predict over every sample (train and test alike), then undo the target scaling.
  X_all = sc_X.transform(ELMX)
  y_pred = predict_model(net, torch.tensor(X_all, dtype=torch.float32), device)

  ann_y = sc_y.inverse_transform(y_pred.reshape(-1, 1)).reshape(-1, 360)
  data_y = sc_y.inverse_transform(y.reshape(-1, 1)).reshape(-1, 360)
  return ann_y, data_y



In [None]:
def tune_region(region_idx, mydir, device, obs_name='ensemble'):
  """Fine-tune a region's surrogate ANN on the mean of five observation targets.

  Loads `OBSX` and the five `OBSy*` variables from
  `wildfire_surrogate{region_idx+1}.mat`, averages the five targets, scales
  inputs and target with the scalers saved by `train_region`, loads that
  region's base checkpoint, trains for 100 epochs on a stratified 80% split and
  saves the tuned weights.

  Args:
      region_idx: Zero-based region index (files are numbered `region_idx + 1`).
      mydir: Folder holding the .mat file, the scalers and the checkpoints.
      device: torch device used for training and prediction.
      obs_name: Only used as a tag in the tuned checkpoint's filename; the
          training target is always the five-product mean.

  Returns:
      (ann_total, obs_total): column sums of the predictions and of the
      observed target, both in original units after reshaping to 120 columns
      (the sample count must be a multiple of 120).

  Raises:
      KeyError: If the .mat file lacks any of the required variables.
  """
  region_no = region_idx + 1
  print(f"\n{'='*60}")
  print(f"TUNING region {region_no}")
  print(f"{'='*60}")

  matfiles = os.path.join(mydir, f'wildfire_surrogate{region_no}.mat')
  tmp = loadmat(matfiles)

  # Fail early with a clear message instead of an obscure hstack error on None.
  obs_keys = ('OBSy', 'OBSy_cci51', 'OBSy_ccilt11', 'OBSy_mcd64', 'OBSy_atlas')
  missing = [key for key in ('OBSX',) + obs_keys if key not in tmp]
  if missing:
    raise KeyError(f"{matfiles} is missing variables: {missing}")

  OBSX = tmp['OBSX']
  # Row-wise mean across the five observation columns.
  OBSy = np.mean(np.hstack([tmp[key] for key in obs_keys]), axis=1).reshape(-1, 1)

  # Reuse the scalers fitted in train_region so inputs/targets share its scaling.
  # NOTE: joblib.load unpickles; only load scaler files this notebook produced.
  sc_X = joblib.load(os.path.join(mydir, f"scaler_X{region_no}.mat"))
  sc_y = joblib.load(os.path.join(mydir, f"scaler_y{region_no}.mat"))

  X = sc_X.transform(OBSX)
  y = sc_y.transform(OBSy.reshape(-1, 1))

  # Three strata from the 33rd / 66th percentiles; NaN targets keep label 0.
  y_1 = np.percentile(y, 33)
  y_2 = np.percentile(y, 66)
  y_3 = np.max(y)
  strata_y = np.where(y <= y_1, 1, np.where(y <= y_2, 2, np.where(y <= y_3, 3, 0)))

  X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, stratify=strata_y, random_state=0)
  train_dataset = TensorDataset(
      torch.tensor(X_train, dtype=torch.float32),
      torch.tensor(y_train, dtype=torch.float32)
  )
  train_iter = DataLoader(train_dataset, batch_size=20, shuffle=True)

  base_model_path = os.path.join(mydir, f'wildfire_surrogate{region_no}_ANN_softplus.pt')
  net = ANNWildfire(input_dim=X_train.shape[1])
  print(f"Loading base model: {base_model_path}")
  net.load_state_dict(torch.load(base_model_path, map_location=device))

  # Default schedule, with per-region overrides (indices are zero-based).
  lr, decay_steps, decay_rate, tune_tag = 0.001, 1000, 0.9, 'tuned2'
  if region_idx in (4, 8):
    lr, decay_steps, decay_rate, tune_tag = 0.005, 3000, 0.99, 'tuned4'
  elif region_idx == 7:
    lr, decay_steps, decay_rate, tune_tag = 0.005, 1000, 0.99, 'tuned4'
  tuned_model_path = os.path.join(
      mydir, f'wildfire_surrogate{region_no}_ANN_softplus_{obs_name}_{tune_tag}.pt'
  )

  print(f"Tuning model with learning rate = {lr}. Saving to: {tuned_model_path}")
  train_ann(net, train_iter, lr=lr, epochs=100, device=device, decay_steps=decay_steps, decay_rate=decay_rate)
  torch.save(net.state_dict(), tuned_model_path)

  # Predict over every observation sample and map back to original units.
  X_all_tensor = torch.tensor(X, dtype=torch.float32)
  y_pred = predict_model(net, X_all_tensor, device)

  ann_y = sc_y.inverse_transform(y_pred.reshape(-1, 1)).reshape(-1, 120)
  # Fixed: the reference series must come from the observed target `y`,
  # not from the predictions (which made both outputs identical).
  data_y = sc_y.inverse_transform(y.reshape(-1, 1)).reshape(-1, 120)

  return np.sum(ann_y, 0), np.sum(data_y, 0)


In [None]:
class DNNWildfire(nn.Module):
  """Fully connected regressor: five 5-unit Softplus hidden layers and a linear 1-unit output.

  The output layer has no activation, so predictions may be negative.
  """

  def __init__(self, input_dim):
    """Build the layers, then reset every Linear to U(-0.05, 0.05) weights and zero biases."""
    super().__init__()
    # The attribute names fc1..fc6 become the state_dict keys of saved checkpoints.
    self.fc1 = nn.Linear(input_dim, 5)
    self.fc2 = nn.Linear(5, 5)
    self.fc3 = nn.Linear(5, 5)
    self.fc4 = nn.Linear(5, 5)
    self.fc5 = nn.Linear(5, 5)
    self.fc6 = nn.Linear(5, 1)
    self.softplus = nn.Softplus()

    for layer in self.modules():
      if not isinstance(layer, nn.Linear):
        continue
      nn.init.uniform_(layer.weight, -0.05, 0.05)
      nn.init.zeros_(layer.bias)

  def forward(self, x, y=None):
    """Return the prediction for `x`; when `y` is given return `(mse_loss, prediction)`."""
    hidden = x
    for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
      hidden = self.softplus(layer(hidden))
    pred = self.fc6(hidden)

    if y is None:
      return pred
    return nn.functional.mse_loss(pred, y), pred




In [None]:
def validate_epoch(net, val_iter, device):
  """Return the mean per-batch loss of `net` over `val_iter`, without gradients.

  `net` must return `(loss, pred)` when called with the tensors of one batch.
  The network is switched to eval mode for the pass and put back into whatever
  mode it was in on entry.

  Returns:
      The average of the batch losses as a float, or NaN when `val_iter`
      yields no batches (instead of raising ZeroDivisionError).
  """
  was_training = net.training
  net.eval()
  total_loss = 0.0
  count = 0
  with torch.no_grad():
    for val_data in val_iter:
      val_data = [tensor.to(device) for tensor in val_data]
      loss, _ = net(*val_data)
      total_loss += loss.mean().item()
      count += 1
  # Restore the caller's mode rather than unconditionally forcing train mode.
  net.train(was_training)
  if count == 0:
    return float('nan')
  return total_loss / count

In [None]:
def train_dnn(net, train_iter, val_iter, lr, epochs, device, decay_steps=1000, decay_rate=0.99):
  """Train `net` with Adam and exponential LR decay, tracking train/validation loss per epoch.

  `net` must return `(loss, pred)` when called with the tensors of one batch.
  After every optimizer step the learning rate is set to
  `lr * decay_rate ** (step / decay_steps)`. The loss, first prediction and
  first target are printed on the first batch of every epoch.

  Args:
      net: Model to train (moved to `device`).
      train_iter: Sized iterable of training batches.
      val_iter: Iterable of validation batches, passed to `validate_epoch`.
      lr: Initial learning rate.
      epochs: Number of passes over `train_iter`.
      device: torch device (or device string).
      decay_steps, decay_rate: Parameters of the decay schedule.

  Returns:
      (train_loss_per_epoch, val_loss_per_epoch): two lists of floats, one
      entry per epoch. An epoch with no training batches records 0.0 and
      prints a warning.
  """
  net = net.to(device)
  optimizer = torch.optim.Adam(net.parameters(), lr=lr)
  train_loss_per_epoch = []
  val_loss_per_epoch = []
  steps_per_epoch = len(train_iter)
  total_iter = epochs * steps_per_epoch

  for e in range(epochs):
    net.train()
    total_train_loss_sum = 0.0
    batch_count = 0
    for i, train_data in enumerate(train_iter):
      train_data = [tensor.to(device) for tensor in train_data]

      loss, pred = net(*train_data)

      total_train_loss_sum += loss.mean().item()
      # Fixed: the counter was never incremented, so every epoch's average
      # fell through to the "empty" branch and was recorded as 0.0.
      batch_count += 1

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Global step counter across epochs drives the decay schedule.
      step = i + e * steps_per_epoch
      new_lr = lr * (decay_rate ** (step/decay_steps))
      for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

      if step % steps_per_epoch == 0:
        print('iter {} / {}\tLoss:\t{:.6f}'.format(step, total_iter, loss.mean().detach()))
        print('pred:\t {}\n'.format(pred[0].detach().cpu()))
        print('tgt:\t {}\n'.format(train_data[1][0].cpu()))

    if batch_count > 0:
      avg_train_loss = total_train_loss_sum / batch_count
    else:
      avg_train_loss = 0.0
      print(f"WARNING: Epoch {e+1} had no training batches (batch_count=0).")

    train_loss_per_epoch.append(avg_train_loss)

    val_loss = validate_epoch(net, val_iter, device)
    val_loss_per_epoch.append(val_loss)
  return train_loss_per_epoch, val_loss_per_epoch
##
# NOTE: this redefines the identical predict_model helper from the earlier
# training cell; the later definition is the one in effect once this cell runs.
def predict_model(net, X, device):
  """Evaluate `net` on tensor `X` (eval mode, no gradients) and return a NumPy array."""
  model = net.to(device)
  model.eval()
  with torch.no_grad():
    batch = X.to(device)
    outputs = model(batch)
  return outputs.cpu().numpy()



In [None]:
def train_model_from_own_dataset(csv_path, mydir, device, n_splits=10, epochs=100, lr=0.005, batch_size=8, random_state=0, use_log_transform=True, pretrained_path=None, freeze_until='fc5'):
    """Cross-validate a DNN on a user CSV, then score the fold ensemble on a held-out test split.

    Pipeline: read the CSV, drop rows with missing features/target, optionally
    log1p-transform the target, standardise features and target, bin the target
    into three strata, hold out 20% as a test set, run stratified K-fold
    training on the remainder (keeping each fold's best-validation weights),
    then average the fold models' predictions on the test set and save three
    PNG plots into `mydir`.

    Args:
        csv_path: CSV with a 'Date' column, a 'Burned_Area' target column and
            feature columns (every other column is treated as a feature).
        mydir: Directory that receives the per-fold checkpoints and the plots.
        device: torch device (or device string) used for training and inference.
        n_splits: Number of stratified folds.
        epochs: Training epochs per fold.
        lr: Adam learning rate.
        batch_size: Upper bound on the mini-batch size (capped by the split size).
        random_state: Seed for the train/test split and the fold shuffling.
        use_log_transform: If True, train on log1p(target) and map predictions
            back with expm1 before computing metrics.
        pretrained_path: Optional checkpoint forwarded to `build_tl_model`; with
            None (or a missing file) every fold starts from a fresh model.
        freeze_until: Forwarded to `build_tl_model`.

    Returns:
        (net, info): the model of the last trained fold, and a dict with
        'folds' (per-fold r2 / mse / loss histories) and 'scalers' (sc_X, sc_y).

    Raises:
        RuntimeError: If every fold was skipped, so there is nothing to evaluate.
    """
    print(f"\nLoad data from: {csv_path}")
    tmp = pd.read_csv(csv_path)

    target_name = 'Burned_Area'
    non_feature_cols = ['Date', target_name]
    feature_names = [c for c in tmp.columns if c not in non_feature_cols]

    print(f"Found {len(feature_names)} features: ")
    print(feature_names)

    tmp = tmp.dropna(subset=feature_names + [target_name]).reset_index(drop=True)

    all_dates = pd.to_datetime(tmp['Date'].values)

    X = tmp[feature_names].values.astype(float)
    y = tmp[target_name].values.astype(float).reshape(-1, 1)
    input_dim = X.shape[1]

    if use_log_transform:
        print("\nApply log1p transform to target")
        y_trans = np.log1p(y)  # log(1 + y)
    else:
        print("\nNo transform applied to target")
        y_trans = y.copy()

    # NOTE(review): the scalers are fitted on the full dataset before the
    # train/test split, so test-set statistics leak into the scaling.
    sc_X = StandardScaler()
    sc_y = StandardScaler()
    X_scaled = sc_X.fit_transform(X)
    y_scaled = sc_y.fit_transform(y_trans)

    # Fixed: the tercile thresholds must be taken from the same array they are
    # compared against. They used to be percentiles of the raw target applied
    # to the standardised target, which collapses the strata.
    y_1 = np.percentile(y_scaled, 33)
    y_2 = np.percentile(y_scaled, 66)
    strata_y = np.where(y_scaled <= y_1, 1, np.where(y_scaled <= y_2, 2, 3)).ravel()

    X_kfold, X_test, y_kfold, y_test, _dates_kfold, dates_test, strata_kfold, _strata_test = train_test_split(
        X_scaled, y_scaled, all_dates, strata_y, test_size=0.20, stratify=strata_y, random_state=random_state
    )

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    fold_metrics = []
    fold_model_paths = []  # checkpoints actually written during this run
    net = None

    for fold, (train_idx, val_idx) in enumerate(skf.split(X_kfold, strata_kfold), start=1):
        print(f"\nFold {fold}/{n_splits}: train={len(train_idx)}, val={len(val_idx)}")
        X_train, X_val = X_kfold[train_idx], X_kfold[val_idx]
        y_train, y_val = y_kfold[train_idx], y_kfold[val_idx]

        if len(X_train) < 2:
            print(f"Skipping fold {fold} due to training set being too small")
            continue

        batch_size_train = min(batch_size, max(1, len(X_train)))
        batch_size_val = min(batch_size, max(1, len(X_val)))
        train_dataset = TensorDataset(
            torch.tensor(X_train, dtype=torch.float32),
            torch.tensor(y_train, dtype=torch.float32)
        )
        val_dataset = TensorDataset(
            torch.tensor(X_val, dtype=torch.float32),
            torch.tensor(y_val, dtype=torch.float32)
        )
        train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=False)

        net = build_tl_model(
            pretrained_path=pretrained_path,
            input_dim=input_dim,
            device=device,
            freeze_until=freeze_until
        )

        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        train_losses = []
        val_losses = []
        best_val_loss = np.inf
        best_state = None
        for epoch in range(1, epochs+1):
            net.train()
            running_loss = 0.0
            batch_count = 0
            for xb, yb in train_loader:
                xb = xb.to(device)
                yb = yb.to(device)

                optimizer.zero_grad()
                loss, pred = net(xb, yb)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                optimizer.step()

                running_loss += loss.item()
                batch_count += 1

            epoch_train_loss = running_loss / batch_count if batch_count > 0 else np.nan
            train_losses.append(epoch_train_loss)

            # validation
            net.eval()
            val_running = 0.0
            val_count = 0
            with torch.no_grad():
                for xb, yb in val_loader:
                    xb = xb.to(device)
                    yb = yb.to(device)
                    loss_val, pred_val = net(xb, yb)
                    val_running += loss_val.item()
                    val_count += 1
            epoch_val_loss = val_running / val_count if val_count > 0 else np.nan
            val_losses.append(epoch_val_loss)

            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                # Fixed: state_dict() returns references to the live tensors,
                # so the snapshot must be cloned or later epochs overwrite it.
                best_state = {k: v.detach().clone() for k, v in net.state_dict().items()}

        # restore best model
        if best_state is not None:
            net.load_state_dict(best_state)
        # save fold model
        fold_model_path = os.path.join(mydir, f'own_dataset_kfold_f{fold}_model.pt')
        torch.save(net.state_dict(), fold_model_path)
        fold_model_paths.append(fold_model_path)
        print("Saved fold model to: ", fold_model_path)

        net.eval()
        X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
        with torch.no_grad():
            preds_val = net(X_val_tensor).cpu().numpy().flatten()

        # Undo the standardisation (and the log1p, if used) before scoring.
        y_val_transformed = sc_y.inverse_transform(y_val.reshape(-1, 1)).flatten()
        preds_val_transformed = sc_y.inverse_transform(preds_val.reshape(-1, 1)).flatten()
        if use_log_transform:
            y_val_actual = np.expm1(y_val_transformed)
            preds_val_actual = np.expm1(preds_val_transformed)
        else:
            y_val_actual = y_val_transformed
            preds_val_actual = preds_val_transformed

        r2 = r2_score(y_val_actual, preds_val_actual)
        mse = mean_squared_error(y_val_actual, preds_val_actual)
        fold_metrics.append({'fold': fold, 'r2': r2, 'mse': mse, 'train_losses': train_losses, 'val_losses': val_losses})
    print("\nK-Fold cross validation completed")

    if not fold_model_paths:
        raise RuntimeError("No fold was trained (every fold was skipped); nothing to evaluate.")

    # k-fold loss plot
    plt.figure(figsize=(10,6))
    for m in fold_metrics:
        fold = m['fold']
        r2 = m['r2']
        trainloss = m['train_losses']
        valloss = m['val_losses']
        # Separate name so the `epochs` parameter is not shadowed.
        epoch_axis = np.arange(1, len(trainloss)+1)
        plt.plot(epoch_axis, trainloss, label=f"Fold {fold} - Train (R² = {r2:.3f})", alpha=0.6)
        plt.plot(epoch_axis, valloss, label=f"Fold {fold} - Validation (R² = {r2:.3f})", alpha=0.6, linestyle='--')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("K-Fold Loss Curve")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    kfold_path = os.path.join(mydir, 'k_fold_loss_1203.png')
    plt.savefig(kfold_path, dpi=200)
    plt.close()
    print("K-Fold loss plot saved to: ", kfold_path)

    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    all_fold_predictions = []

    print(f"Running Ensemble Prediction on {len(X_test)} Test Samples")
    # Fixed: only load checkpoints written by this run; iterating over
    # 1..n_splits could pick up stale or missing files for skipped folds.
    for model_path in fold_model_paths:
        net_fold = DNNWildfire(input_dim=input_dim).to(device)
        net_fold.load_state_dict(torch.load(model_path, map_location=device))

        net_fold.eval()
        with torch.no_grad():
            preds_fold = net_fold(X_test_tensor).cpu().numpy().flatten()
        all_fold_predictions.append(preds_fold)

    # Ensemble = plain average of the fold models' scaled predictions.
    avg_preds_kfold = np.mean(all_fold_predictions, axis=0)
    y_pred_transformed = sc_y.inverse_transform(avg_preds_kfold.reshape(-1, 1)).flatten()
    y_test_transformed = sc_y.inverse_transform(y_test).flatten()

    if use_log_transform:
        y_pred_actual = np.expm1(y_pred_transformed)
        y_test_actual = np.expm1(y_test_transformed)
    else:
        y_pred_actual = y_pred_transformed
        y_test_actual = y_test_transformed

    final_r2 = r2_score(y_test_actual, y_pred_actual)

    print(f"\nFINAL ENSEMBLE PERFORMANCE (on held-out Test Set):")
    print(f"R-squared (R²) Score: {final_r2:.3f}")

    # line plot actual vs pred (test-set samples only, ordered by date)
    plot_tmp = pd.DataFrame({
        'Date': dates_test,
        'Actual': y_test_actual,
        'Predicted': y_pred_actual
    })
    plot_tmp = plot_tmp.sort_values(by='Date')
    r2 = r2_score(plot_tmp['Actual'], plot_tmp['Predicted'])
    plt.figure(figsize=(12, 6))
    plt.plot(plot_tmp['Date'], plot_tmp['Actual'], label="Actual Burned Area", color='red', linewidth=2)
    plt.plot(plot_tmp['Date'], plot_tmp['Predicted'], label="Predicted Burned Area", color='blue', linestyle='--', linewidth=1.5)
    plt.xlabel("Date", fontsize=14, fontweight='bold')
    plt.ylabel(f"Burned Area (Mha)", fontsize=14, fontweight='bold')
    plt.title(f"Model Time-Series Comparison on Full Dataset (R²: {r2:.3f})", fontsize=16)
    plt.legend()
    plt.grid(True, alpha=0.3)
    line_path = os.path.join(mydir, 'kfold_line_plot_actual_vs_pred_1203.png')
    plt.savefig(line_path, dpi=200)
    plt.close()
    print("Line plot saved to: ", line_path)

    # time-series prediction visualization plot (test samples in split order)
    N = len(y_test_actual)
    plt.figure(figsize=(12,6))
    plt.plot(range(N), y_test_actual, label='Actual Burned Area', color='blue', linewidth=2)
    plt.plot(range(N), y_pred_actual, label='Predicted Burned Area', color='yellow', linewidth=2)
    plt.title(f"Time-Series Prediction Visualization (R² = {r2:.3f})")
    plt.legend()
    visualization_path = os.path.join(mydir, 'kfold_time-series_prediction_visualization_1203.png')
    plt.savefig(visualization_path, dpi=200)
    plt.close()
    print("Time-series prediction visualization saved to: ", visualization_path)

    return net, {'folds': fold_metrics, 'scalers': (sc_X, sc_y)}


In [None]:
def build_tl_model(pretrained_path, input_dim, device, freeze_until='fc5'):
    """Create a DNNWildfire, load matching pretrained weights, and freeze the early layers.

    Args:
        pretrained_path: Checkpoint file holding a state dict (optionally
            wrapped under a 'state_dict' key). If None or missing, a freshly
            initialised, fully trainable model is returned.
        input_dim: Number of input features of the new model.
        device: torch device (or device string) for the model and the load.
        freeze_until: Name of the last hidden layer to freeze. Layers fc1 up to
            and including this one get `requires_grad=False`; later hidden
            layers stay trainable. None, or a name that is not one of
            fc1..fc5, freezes all of fc1..fc5. The output layer fc6 is always
            trainable.

    Returns:
        The model on `device`. Only checkpoint entries whose key and shape
        match the new model are copied; the rest keep their fresh values.
    """
    model = DNNWildfire(input_dim).to(device)

    if pretrained_path is None or not os.path.exists(pretrained_path):
        print(f"[TL] No pretrained model found at {pretrained_path}. ")
        return model

    state = torch.load(pretrained_path, map_location=device)
    # Unwrap {'state_dict': ...} checkpoints, unless the top level already has fc* keys.
    if isinstance(state, dict) and 'state_dict' in state and not any(k.startswith('fc') for k in state.keys()):
        state = state['state_dict']

    # Copy only compatible tensors (e.g. fc1 is skipped when input_dim differs).
    model_state = model.state_dict()
    keys_copied = []
    for k, v in state.items():
        if k in model_state and model_state[k].shape == v.shape:
            model_state[k] = v.clone()
            keys_copied.append(k)

    model.load_state_dict(model_state)
    missing = [k for k in model_state.keys() if k not in keys_copied]
    print(f"[TL] Loaded pretrained keys: {len(keys_copied)}. Skipped/mismatched keys: {len(missing)}.")

    hidden_layers = ['fc1', 'fc2', 'fc3', 'fc4', 'fc5']
    # Fixed: the flag used to be assigned straight to requires_grad, which left
    # the layers up to `freeze_until` trainable and froze the ones after it.
    still_freezing = True
    for lname in hidden_layers:
        for p in getattr(model, lname).parameters():
            p.requires_grad = not still_freezing
        if freeze_until is not None and lname == freeze_until:
            still_freezing = False

    # The regression head is always fine-tuned.
    for p in model.fc6.parameters():
        p.requires_grad = True

    return model

def plot_kfold_histories(train_histories, val_histories, fold_results, out_path_loss):
    """Plot every fold's train (solid) and validation (dashed) loss curves and save the figure.

    Folds whose training history is None or empty are skipped. When
    `fold_results[i]` exists and has an 'r2' entry, that R² is appended to the
    fold's legend labels. The figure is written to `out_path_loss` and closed.
    """
    plt.figure(figsize=(9,6))
    for i, tr in enumerate(train_histories):
        va = val_histories[i]
        if tr is None or len(tr) == 0:
            continue

        r2 = None
        if i < len(fold_results):
            r2 = fold_results[i].get('r2')
        r2_text = "" if r2 is None else f" (R²={r2:.3f})"

        epoch_axis = np.arange(1, len(tr)+1)
        plt.plot(epoch_axis, tr, label=f"Fold {i+1} Train{r2_text}", alpha=0.7)
        plt.plot(epoch_axis, va, linestyle='--', label=f"Fold {i+1} Val{r2_text}", alpha=0.6)

    plt.xlabel("Epoch")
    plt.ylabel("Loss (MSE scaled space)")
    plt.title("K-Fold Train / Validation Loss")
    plt.grid(alpha=0.2)
    plt.legend(fontsize=9, loc='best')
    plt.tight_layout()
    plt.savefig(out_path_loss, dpi=200)
    plt.close()
    print(f"Saved K-fold loss plot: {out_path_loss}")


def _snapshot_state_dict(model):
    """Return a detached CPU copy of ``model.state_dict()`` that later training cannot alter."""
    # state_dict() hands back tensors that share storage with the live parameters, and
    # Tensor.cpu() returns the very same tensor when it already lives on the CPU, so
    # without clone() the "snapshot" would silently follow every optimizer step.
    return {key: tensor.detach().cpu().clone() for key, tensor in model.state_dict().items()}


def _to_original_units(values_scaled, scaler_y, use_log_transform):
    """Undo the target standardisation (and the log1p transform, if used) for 1-D values."""
    unscaled = scaler_y.inverse_transform(np.asarray(values_scaled).reshape(-1, 1)).flatten()
    return np.expm1(unscaled) if use_log_transform else unscaled


def train_kfold_with_tl(csv_path, mydir, device,
                        pretrained_path=None,
                        freeze_until='fc5',
                        n_splits=5, epochs=60, lr=1e-3, batch_size=8,
                        use_log_transform=True, random_state=0,
                        out_dir=None):
    """Fine-tune ``DNNWildfire`` with stratified K-fold CV and score the fold ensemble on a test split.

    Pipeline:
      1. Read ``csv_path``; ``'Burned_Area'`` is the target, ``'Date'`` is kept for
         plotting, every other column is a feature. Rows with missing values are dropped.
      2. Optionally ``log1p``-transform the target, then bin it into three strata
         (33rd / 66th percentile) used to stratify every split.
      3. Hold out 20% as a test set, fit the X / y ``StandardScaler`` objects on the
         remaining pool and run ``StratifiedKFold`` on that pool.
      4. For each fold: build a fresh ``DNNWildfire``, optionally copy shape-matching
         pretrained tensors and freeze layers, train with Adam, and keep the weights
         of the epoch with the lowest validation loss.
      5. Average the fold models' scaled predictions on the test set and report
         R2 / MSE in original target units.

    Files written to the output directory: ``scaler_X_mydata.pkl``,
    ``scaler_y_mydata.pkl``, ``kfold_fold{i}.pt`` (one per fold),
    ``transfer_learning_kfold_loss.png``, ``transfer_learning_kfold_line_plot.png`` and
    ``transfer_learning_kfold_time-series_prediction_visualization.png``.

    Args:
        csv_path: path of the CSV file described above.
        mydir: output directory used when ``out_dir`` is ``None``.
        device: torch device the models and batches are moved to.
        pretrained_path: optional path to pretrained weights. Tensors whose name and
            shape match the new model are copied; everything else keeps its fresh
            initialisation. A checkpoint dict wrapping the weights under
            ``'state_dict'`` is unwrapped. If the path is given but does not exist a
            warning is printed and training starts from scratch.
        freeze_until: layer-name fragment controlling freezing (see the note in the
            fold loop). Only applied when pretrained weights were loaded.
        n_splits: number of stratified folds.
        epochs: training epochs per fold.
        lr: Adam learning rate.
        batch_size: upper bound for the train / validation batch size.
        use_log_transform: apply ``log1p`` to the target before scaling and
            ``expm1`` when mapping predictions back.
        random_state: seed passed to ``train_test_split`` and ``StratifiedKFold``.
            The torch RNG is not seeded here.
        out_dir: output directory; defaults to ``mydir``.

    Returns:
        dict with keys ``'fold_results'`` (list of ``{'fold', 'r2', 'mse'}``),
        ``'fold_train_histories'``, ``'fold_val_histories'``, ``'scalers'``
        (``(sc_X, sc_y)``), ``'final_test'`` (``{'r2', 'mse'}``) and ``'out_dir'``.

    Raises:
        ValueError: if no saved fold model could be loaded for the test ensemble.
    """
    if out_dir is None:
        out_dir = mydir
    os.makedirs(out_dir, exist_ok=True)

    df = pd.read_csv(csv_path)
    target_name = 'Burned_Area'
    feature_names = [c for c in df.columns if c not in ['Date', target_name]]
    df = df.dropna(subset=feature_names + [target_name]).reset_index(drop=True)

    X_raw = df[feature_names].values.astype(float)
    y_raw = df[target_name].values.astype(float).reshape(-1, 1)
    dates = pd.to_datetime(df['Date'].values)

    y_trans = np.log1p(y_raw) if use_log_transform else y_raw.copy()

    # Three target strata (lower / middle / upper third) so that every split sees a
    # similar target distribution.
    y_1 = np.percentile(y_trans, 33)
    y_2 = np.percentile(y_trans, 66)
    strata = np.where(y_trans <= y_1, 1, np.where(y_trans <= y_2, 2, 3))

    # Split BEFORE scaling: the scalers must not see the held-out test rows,
    # otherwise test-set statistics leak into training.
    (X_pool_raw, X_test_raw, y_pool_trans, y_test_trans,
     dates_kf_pool, dates_test, strata_kf_pool, _strata_test) = train_test_split(
        X_raw, y_trans, dates, strata,
        test_size=0.20, stratify=strata, random_state=random_state)

    sc_X = StandardScaler().fit(X_pool_raw)
    sc_y = StandardScaler().fit(y_pool_trans)
    X_kf_pool = sc_X.transform(X_pool_raw)
    y_kf_pool = sc_y.transform(y_pool_trans)
    X_test = sc_X.transform(X_test_raw)
    y_test = sc_y.transform(y_test_trans)

    joblib.dump(sc_X, os.path.join(out_dir, "scaler_X_mydata.pkl"))
    joblib.dump(sc_y, os.path.join(out_dir, "scaler_y_mydata.pkl"))

    # Read the pretrained weights once; they are only copied from, never modified.
    pretrained_dict = None
    if pretrained_path:
        if os.path.exists(pretrained_path):
            pretrained_dict = torch.load(pretrained_path, map_location=device)
            # Accept checkpoints that wrap the weights as {'state_dict': ...}.
            if (isinstance(pretrained_dict, dict) and 'state_dict' in pretrained_dict
                    and not any(str(k).startswith('fc') for k in pretrained_dict.keys())):
                pretrained_dict = pretrained_dict['state_dict']
        else:
            print(f"[TL] WARNING: pretrained weights not found at {pretrained_path}; "
                  f"training from scratch without freezing.")

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    fold_results = []
    fold_train_histories = []
    fold_val_histories = []

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_kf_pool, strata_kf_pool), start=1):
        print(f"\n--- Fold {fold_idx}/{n_splits} ---")
        X_train, X_val = X_kf_pool[train_idx], X_kf_pool[val_idx]
        y_train, y_val = y_kf_pool[train_idx], y_kf_pool[val_idx]

        net = DNNWildfire(input_dim=X_train.shape[1]).to(device)

        if pretrained_dict is not None:
            model_dict = net.state_dict()
            # Copy only tensors whose name AND shape match the new model.
            valid_weights = {k: v for k, v in pretrained_dict.items()
                             if k in model_dict and hasattr(v, 'shape')
                             and v.shape == model_dict[k].shape}
            model_dict.update(valid_weights)
            net.load_state_dict(model_dict)
            print(f"[TL] Fold {fold_idx}: loaded {len(valid_weights)}/{len(model_dict)} pretrained tensors.")

            if freeze_until:
                # Parameters named fc1* / fc6* always stay trainable. Any other parameter
                # whose name contains `freeze_until`, or sorts before it as a string, is frozen.
                # NOTE(review): the string comparison relies on layer names sorting in depth
                # order (fc2 < fc3 < ...) and would also freeze any non-"fc" parameter whose
                # name sorts before `freeze_until` -- confirm this is intended.
                for name, param in net.named_parameters():
                    if 'fc6' in name or 'fc1' in name:
                        param.requires_grad = True
                    elif freeze_until in name or name < freeze_until:
                        param.requires_grad = False

        bs_train = min(batch_size, max(1, len(X_train)))
        bs_val = min(batch_size, max(1, len(X_val)))
        train_loader = DataLoader(
            TensorDataset(torch.tensor(X_train, dtype=torch.float32),
                          torch.tensor(y_train, dtype=torch.float32)),
            batch_size=bs_train, shuffle=True)
        val_loader = DataLoader(
            TensorDataset(torch.tensor(X_val, dtype=torch.float32),
                          torch.tensor(y_val, dtype=torch.float32)),
            batch_size=bs_val, shuffle=False)

        # Only parameters left trainable above are handed to the optimizer.
        optimizer = torch.optim.Adam([p for p in net.parameters() if p.requires_grad], lr=lr)

        best_state = None
        best_val = np.inf
        train_hist = []
        val_hist = []

        for ep in range(1, epochs + 1):
            net.train()
            running_loss, n_batches = 0.0, 0
            for xb, yb in train_loader:
                xb, yb = xb.to(device), yb.to(device)
                optimizer.zero_grad()
                # Called with targets, the model returns (loss, predictions).
                loss, _ = net(xb, yb)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
                optimizer.step()
                running_loss += loss.item()
                n_batches += 1
            epoch_train_loss = running_loss / n_batches if n_batches > 0 else np.nan
            train_hist.append(epoch_train_loss)

            net.eval()
            val_running, val_batches = 0.0, 0
            with torch.no_grad():
                for xb, yb in val_loader:
                    xb, yb = xb.to(device), yb.to(device)
                    vloss, _ = net(xb, yb)
                    val_running += vloss.item()
                    val_batches += 1
            epoch_val_loss = val_running / val_batches if val_batches > 0 else np.nan
            val_hist.append(epoch_val_loss)

            if epoch_val_loss < best_val:
                best_val = epoch_val_loss
                # Must be a real copy, otherwise the "best" weights keep changing
                # as training continues (see _snapshot_state_dict).
                best_state = _snapshot_state_dict(net)

        if best_state is not None:
            net.load_state_dict(best_state)
        fold_model_path = os.path.join(out_dir, f"kfold_fold{fold_idx}.pt")
        torch.save(net.state_dict(), fold_model_path)
        print(f"Saved fold model: {fold_model_path}")

        net.eval()
        with torch.no_grad():
            preds_val_scaled = net(torch.tensor(X_val, dtype=torch.float32).to(device))
            if isinstance(preds_val_scaled, tuple):
                _, preds_val_scaled = preds_val_scaled
            preds_val_scaled = preds_val_scaled.cpu().numpy().flatten()

        # Fold metrics are reported in original target units.
        preds_val_actual = _to_original_units(preds_val_scaled, sc_y, use_log_transform)
        y_val_actual = _to_original_units(y_val, sc_y, use_log_transform)

        r2 = r2_score(y_val_actual, preds_val_actual)
        mse = mean_squared_error(y_val_actual, preds_val_actual)
        print(f"Fold {fold_idx}: val R2={r2:.4f}, mse={mse:.4f}")

        fold_results.append({'fold': fold_idx, 'r2': r2, 'mse': mse})
        fold_train_histories.append(train_hist)
        fold_val_histories.append(val_hist)

    out_loss = os.path.join(out_dir, "transfer_learning_kfold_loss.png")
    plot_kfold_histories(fold_train_histories, fold_val_histories, fold_results, out_loss)

    # Ensemble: reload every saved fold model and average the scaled test predictions.
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    all_fold_preds_test = []
    for fold_idx in range(1, len(fold_train_histories) + 1):
        model_path = os.path.join(out_dir, f"kfold_fold{fold_idx}.pt")
        if not os.path.exists(model_path):
            continue
        m = DNNWildfire(input_dim=X_test.shape[1]).to(device)
        st = torch.load(model_path, map_location=device)

        ms = m.state_dict()
        for k, v in list(st.items()):
            if k in ms and ms[k].shape == v.shape:
                ms[k] = v
        m.load_state_dict(ms)
        m.eval()
        with torch.no_grad():
            pr = m(X_test_tensor)
            if isinstance(pr, tuple):
                _, pr = pr
            all_fold_preds_test.append(pr.cpu().numpy().flatten())

    if not all_fold_preds_test:
        raise ValueError("No fold models were available to build the test-set ensemble.")

    avg_preds = np.mean(np.stack(all_fold_preds_test, axis=1), axis=1)
    preds_actual = _to_original_units(avg_preds, sc_y, use_log_transform)
    y_test_actual = _to_original_units(y_test, sc_y, use_log_transform)

    final_r2 = r2_score(y_test_actual, preds_actual)
    final_mse = mean_squared_error(y_test_actual, preds_actual)
    print(f"\nFINAL ENSEMBLE PERFORMANCE (on held-out Test set): R2={final_r2:.4f}, MSE={final_mse:.4f}")

    # line plot actual vs pred
    df_line = pd.DataFrame({'Date': dates_test, 'Actual': y_test_actual, 'Predicted': preds_actual})
    df_line = df_line.sort_values('Date')
    out_line = os.path.join(out_dir, "transfer_learning_kfold_line_plot.png")
    plt.figure(figsize=(12, 6))
    plt.plot(df_line['Date'], df_line['Actual'], label='Actual', color='red', linewidth=2)
    plt.plot(df_line['Date'], df_line['Predicted'], label='Predicted', color='blue', linestyle='--')
    plt.title(f"K-Fold Ensemble: Actual vs Pred (R²={final_r2:.3f})")
    plt.xlabel("Date")
    plt.ylabel("Burned area")
    plt.legend()
    plt.grid(alpha=0.2)
    plt.tight_layout()
    plt.savefig(out_line, dpi=200)
    plt.close()
    print(f"Saved line-plot: {out_line}")

    # time-series visualization: samples in chronological order on an index axis
    # (the test split itself is shuffled, so the date-sorted frame is used here).
    out_vis = os.path.join(out_dir, "transfer_learning_kfold_time-series_prediction_visualization.png")
    sample_axis = range(len(df_line))
    plt.figure(figsize=(12, 5))
    plt.plot(sample_axis, df_line['Actual'].values, label='Actual', linewidth=2)
    plt.plot(sample_axis, df_line['Predicted'].values, label='Predicted', linewidth=2)
    plt.title(f"Time-series Test Predictions (R²={final_r2:.3f})")
    plt.legend()
    plt.grid(alpha=0.2)
    plt.tight_layout()
    plt.savefig(out_vis, dpi=200)
    plt.close()
    print(f"Saved time-series visualization: {out_vis}")

    return {
        'fold_results': fold_results,
        'fold_train_histories': fold_train_histories,
        'fold_val_histories': fold_val_histories,
        'scalers': (sc_X, sc_y),
        'final_test': {'r2': final_r2, 'mse': final_mse},
        'out_dir': out_dir
    }


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Monthly dataset to fine-tune on.
csv_path = "/content/drive/MyDrive/DNN_Wildfire/11_17_own_data/DeepLearning_Climate_Biomass_Human_Fire_dataset_2001_2024_montly_dataset.csv"
# Folder holding the pretrained surrogate weights (rebinds `mydir` from the Drive-mount cell).
mydir = "/content/drive/MyDrive/DNN_Wildfire/11_14_results"
# Parent folder for everything this run writes.
outdir = "/content/drive/MyDrive/DNN_Wildfire/11_29_transfer_learning"

pretrained_model_path = os.path.join(mydir, "wildfire_surrogate6_ANN_softplus.pt")

# 10-fold transfer-learning run; artefacts land in <outdir>/kfold_tl_results.
results = train_kfold_with_tl(
    csv_path,
    mydir,
    device,
    pretrained_path=pretrained_model_path,
    freeze_until='fc3',
    n_splits=10,
    epochs=100,
    lr=1e-3,
    batch_size=8,
    out_dir=os.path.join(outdir, "kfold_tl_results"),
)



--- Fold 1/10 ---
Saved fold model: /content/drive/MyDrive/DNN_Wildfire/11_29_transfer_learning/kfold_tl_results/kfold_fold1.pt
Fold 1: val R2=0.5317, mse=4.1938

--- Fold 2/10 ---
Saved fold model: /content/drive/MyDrive/DNN_Wildfire/11_29_transfer_learning/kfold_tl_results/kfold_fold2.pt
Fold 2: val R2=0.7274, mse=1.3607

--- Fold 3/10 ---
Saved fold model: /content/drive/MyDrive/DNN_Wildfire/11_29_transfer_learning/kfold_tl_results/kfold_fold3.pt
Fold 3: val R2=0.2757, mse=20.7231

--- Fold 4/10 ---
Saved fold model: /content/drive/MyDrive/DNN_Wildfire/11_29_transfer_learning/kfold_tl_results/kfold_fold4.pt
Fold 4: val R2=0.5194, mse=5.5238

--- Fold 5/10 ---
Saved fold model: /content/drive/MyDrive/DNN_Wildfire/11_29_transfer_learning/kfold_tl_results/kfold_fold5.pt
Fold 5: val R2=0.3752, mse=7.2138

--- Fold 6/10 ---
Saved fold model: /content/drive/MyDrive/DNN_Wildfire/11_29_transfer_learning/kfold_tl_results/kfold_fold6.pt
Fold 6: val R2=0.1504, mse=34.7134

--- Fold 7/10 ---
S