# Imports

In [None]:
import numpy as np
import os
from os import listdir as LD, makedirs as MDs
from os.path import join as PJ, basename as PBN, dirname as PDN, exists as PE
import pandas as pd
import shutil as sh
from datetime import datetime as DT
from matplotlib import pyplot as plt

In [None]:
from imod import msw
from imod import mf6

In [None]:
import WS_Mdl.utils as U
import WS_Mdl.utils_imod as UIM

In [None]:
import importlib as IL

# Re-import the project helper modules (U first, then UIM) so that edits made to
# them take effect without restarting the kernel.
for _helper_module in (U, UIM):
    IL.reload(_helper_module)

In [None]:
# import imod
# import imod.util.path
# from imod.formats.prj.prj import open_projectfile_data
# from imod.logging.config import LoggerType
# from imod.logging.loglevel import LogLevel
# from imod.mf6.oc import OutputControl
# from imod.mf6.simulation import Modflow6Simulation
# from imod.mf6.ims import Solution

# Options + Basics

In [None]:
# Model name; used below to look up the model's file paths.
MdlN = "NBr31"

In [None]:
# Disable verbose mode in the WS_Mdl.utils helpers (U).
U.set_verbose(False)

In [None]:
# Resolve the model's paths and read its INI settings into a dict.
d_Pa = U.get_MdlN_Pa(MdlN)
Pa_PRJ, d_INI = d_Pa['PRJ'], U.INI_to_d(d_Pa['INI'])
Dir_PRJ = PDN(Pa_PRJ)  # folder containing the PRJ file

In [None]:
# Model window (INI 'WINDOW' entry, read as Xmin, Ymin, Xmax, Ymax) and cell size.
Xmin, Ymin, Xmax, Ymax = [float(i) for i in d_INI['WINDOW'].split(',')]
cellsize = float(d_INI['CELLSIZE'])
# Number of rows / columns. Round instead of truncating with int(): a floating-point
# quotient such as 39.999999999 would otherwise lose a row or column.
N_R = int(round((Ymax - Ymin) / cellsize))
N_C = int(round((Xmax - Xmin) / cellsize))

In [None]:
# Start (SDATE) and end (EDATE) dates from the INI, converted from YYYYMMDD to YYYY-MM-DD.
SP_date_1st, SP_date_last = (
    DT.strptime(d_INI[key], '%Y%m%d').strftime('%Y-%m-%d')
    for key in ('SDATE', 'EDATE')
)

# Read PRJ

In [None]:
# Read the PRJ file content; keep only the first element of UIM.read_PRJ_with_OBS's return value.
PRJ_content = UIM.read_PRJ_with_OBS(Pa_PRJ)[0]

In [None]:
# Inspect the boundary entry of the raw PRJ content (keyed as '(bnd)' here, unlike the 'bnd' key of PRJ).
PRJ_content['(bnd)']

# Load PRJ

In [None]:
# Open the PRJ: PRJ_ holds the package data and the period data (unpacked below),
# PRJ_OBS the observation entries.
PRJ_, PRJ_OBS = UIM.open_PRJ_with_OBS(Pa_PRJ)

In [None]:
# Show the PRJ path in use.
Pa_PRJ

In [None]:
# Package data of the PRJ (dict keyed by package name, e.g. 'bnd', 'chd-1', 'cap').
PRJ = PRJ_[0]

In [None]:
# Period data of the PRJ; passed to Modflow6Simulation.from_imod5_data below.
period_data = PRJ_[1]

In [None]:
# PRJ.keys()

In [None]:
# period_data

# Load MF6 Mdl

In [None]:
# Daily time discretization spanning the simulation period (both end dates included).
times = pd.date_range(start=SP_date_1st, end=SP_date_last, freq='D')

In [None]:
# Inspect the PRJ package names. PRJ_no_cap is only created further down (CAP split-off
# cell), so inspect PRJ here; it differs from PRJ_no_cap only by the 'cap' and 'extra' entries.
PRJ.keys()

In [None]:
# Inspect the boundary package. PRJ_no_cap is only created further down and is a shallow
# copy of PRJ, so PRJ['bnd'] is the very same object.
PRJ['bnd']

In [None]:
# Inspect the heads of the first CHD package. PRJ_no_cap is only created further down and
# is a shallow copy of PRJ, so PRJ['chd-1'] is the very same object.
PRJ['chd-1']['head']

In [None]:
# Convert IMOD5 to MODFLOW6 (without CAP package due to mixed grid issues).
# The CAP (MetaSWAP) package and its 'extra' entry are moved into PRJ_CAP and converted
# separately further down.
PRJ_no_cap = PRJ.copy()  # shallow copy: only top-level entries are removed below
PRJ_CAP = {}
if "cap" in PRJ_no_cap:
    PRJ_CAP['cap'] = PRJ_no_cap.pop("cap")
    PRJ_CAP['extra'] = PRJ_no_cap.pop('extra')
    print("Removed CAP package due to mixed grid compatibility issues")

# The bare name `imod` is not imported in this notebook; `mf6` is (from imod import mf6).
Sim_MF6 = mf6.Modflow6Simulation.from_imod5_data(PRJ_no_cap, period_data, times)
print("Simulation created successfully!")

In [None]:
# The groundwater model created by from_imod5_data is stored under 'imported_model'.
MF6_Mdl = Sim_MF6['imported_model']

In [None]:
MF6_Mdl["oc"] = OutputControl(save_head="last", save_budget="last")
    
# Mimic iMOD5's "Moderate" settings
IMS = Solution(
    modelnames=["imported_model"],
    print_option="summary",
    outer_csvfile=None,
    inner_csvfile=None,
    no_ptc=None,
    outer_dvclose=0.001,
    outer_maximum=150,
    under_relaxation="dbd",
    under_relaxation_theta=0.9,
    under_relaxation_kappa=0.0001,
    under_relaxation_gamma=0.0,
    under_relaxation_momentum=0.0,
    backtracking_number=0,
    backtracking_tolerance=0.0,
    backtracking_reduction_factor=0.0,
    backtracking_residual_limit=0.0,
    inner_maximum=30,
    inner_dvclose=0.001,
    inner_rclose=100.0,
    rclose_option="strict",
    linear_acceleration="bicgstab",
    relaxation_factor=0.97,
    preconditioner_levels=0,
    preconditioner_drop_tolerance=0.0,
    number_orthogonalizations=0,
)
Sim_MF6["ims"] = IMS

MetaSWAP (MSW, i.e. the CAP package) had to be removed from the PRJ before calling `Modflow6Simulation.from_imod5_data`; otherwise this conversion fails.

## Check MF6 params

In [None]:
# Inspect the 'chd-1' package.
# NOTE(review): the cells below use 'chd_merged'; confirm 'chd-1' still exists after conversion.
MF6_Mdl['chd-1']

In [None]:
# Raw numpy array of the merged CHD heads, used for the quick plot below.
test = MF6_Mdl['chd_merged'].dataset['head'].values

In [None]:
# Shape of the merged CHD head array.
test.shape

In [None]:
# Quick look at the CHD heads of the first time step, layer position 7 (0-based).
plt.imshow(test[0][7, :, :])

In [None]:
# Plot the CHD heads of the first time step and first layer position.
MF6_Mdl['chd_merged'].dataset['head'].isel(time=0, layer=0).plot.imshow()

In [None]:
# Show the merged CHD package.
MF6_Mdl['chd_merged']

In [None]:
# Show the MF6 model overview.
MF6_Mdl

In [None]:
# Map of CHD cells (True where a head is defined) for the first time step and first layer position.
chd_first_slice = MF6_Mdl['chd_merged'].dataset['head'].isel(time=0, layer=0)
(~np.isnan(chd_first_slice)).plot.imshow()

In [None]:
# Diagnose CHD loading issues: summarise the merged CHD head array and count the
# active (non-NaN) CHD cells at the first time step.
chd_data = MF6_Mdl['chd_merged'].dataset['head']
print(f"CHD data shape: {chd_data.shape}")
print(f"CHD data type: {chd_data.dtype}")
print(f"CHD data range: {chd_data.min().values} to {chd_data.max().values}")
print(f"Number of non-NaN values: {(~np.isnan(chd_data)).sum().values}")
print(f"Coordinates: x={chd_data.x.min().values}-{chd_data.x.max().values}, y={chd_data.y.min().values}-{chd_data.y.max().values}")
print(f"Time coordinate: {chd_data.time.values}")
print(f"Layer coordinate: {chd_data.layer.values}")

# Count active CHD cells per layer at the first time step. Layers are reported by their
# 'layer' coordinate value (the label used with .sel, e.g. layer=1), not by positional
# index, so the numbers match the "Layer coordinate" line above. All layers are checked.
print("\nActive CHD cells at the first time step:")
active_t0 = ~np.isnan(chd_data.isel(time=0))
n_active_per_layer = active_t0.sum(dim=[dim for dim in active_t0.dims if dim != 'layer'])
first_layer = chd_data.layer.values[0]
n_active_first = int(n_active_per_layer.isel(layer=0))
if n_active_first > 0:
    print(f"Found {n_active_first} active CHD cells in layer {first_layer}")
else:
    print(f"No active CHD cells found in layer {first_layer} - checking all layers...")
    for layer_label, count in zip(n_active_per_layer.layer.values, n_active_per_layer.values):
        print(f"Layer {layer_label}: {int(count)} active CHD cells")

# Load MSW

## 1st attempt

In [None]:
MF6_DIS = Sim_MF6["imported_model"]["dis"]

In [None]:
# First attempt: build the MetaSWAP model directly from the (not regridded) CAP data.
# Uses the lowercase `msw` module imported at the top; `MSW` is not defined in this
# notebook, so referencing it would only ever report a NameError instead of the real failure.
try:
    MSW_Mdl = msw.MetaSwapModel.from_imod5_data(PRJ_CAP, MF6_DIS, times)
    print("🎉 MetaSwap model created successfully!")
except Exception as e:  # exploratory cell: report any failure and continue
    print(f"Error creating MetaSwap model ({type(e).__name__}): {e}")
    MSW_Mdl = None

## 2nd attempt

### Regriding
Match dimensions of MetaSwap model with MF6 model

The grids of MF6_DIS and the MetaSWAP (CAP) files do not align: the former has 100 m × 100 m cells, the latter 25 m × 25 m cells.

In [None]:
# Debug the grid alignment issue: compare the MF6 target grid with the CAP grid.
print("=== Target Discretization ===")
print(f"Type: {type(MF6_DIS)}")
dis_keys_text = list(MF6_DIS.keys()) if hasattr(MF6_DIS, 'keys') else 'No keys method'
print(f"Keys: {dis_keys_text}")

# Take the coordinates from the wrapped dataset when there is one,
# otherwise access MF6_DIS as a dataset directly.
if hasattr(MF6_DIS, 'dataset'):
    print(f"Dataset keys: {list(MF6_DIS.dataset.keys())}")
    coord_source = MF6_DIS.dataset
else:
    coord_source = MF6_DIS
target_x, target_y = coord_source.x.values, coord_source.y.values


def _print_axis_range(grid_name, axis_name, values):
    """Print min, max and number of coordinates of one grid axis."""
    print(f"{grid_name} {axis_name} range: {values.min()} to {values.max()}, length: {len(values)}")


def _same_axis(a, b):
    """Return whether two coordinate arrays have equal length and (nearly) equal values."""
    return len(a) == len(b) and np.allclose(a, b)


_print_axis_range("Target", "X", target_x)
_print_axis_range("Target", "Y", target_y)

print("\n=== CAP Data ===")
sample_cap = PRJ_CAP['cap']['urban_area']
cap_x, cap_y = sample_cap.x.values, sample_cap.y.values
_print_axis_range("CAP", "X", cap_x)
_print_axis_range("CAP", "Y", cap_y)

print("\n=== Comparison ===")
print(f"X grids compatible: {_same_axis(target_x, cap_x)}")
print(f"Y grids compatible: {_same_axis(target_y, cap_y)}")

In [None]:
# # Remove 'extra' from cap if it exists (to avoid isel error on metadata)
# if 'extra' in PRJ_CAP['cap'].keys():
#     PRJ_CAP['cap']['extra']

In [None]:
# Get the MODFLOW6 extent (target domain): min/max of the DIS x and y cell-centre coordinates.
dis_x, dis_y = MF6_DIS.dataset.x, MF6_DIS.dataset.y
MF6_x_min, MF6_x_max = dis_x.min().item(), dis_x.max().item()
MF6_y_min, MF6_y_max = dis_y.min().item(), dis_y.max().item()
print(f"MODFLOW6 domain: X({MF6_x_min:.1f} to {MF6_x_max:.1f}), Y({MF6_y_min:.1f} to {MF6_y_max:.1f})")

In [None]:
# CAP grid resolution from the spacing of the first two cell centres (absolute values).
sample_cap = PRJ_CAP['cap']['urban_area']
cap_dx, cap_dy = (abs(axis[1] - axis[0]) for axis in (sample_cap.x.values, sample_cap.y.values))
print(f"CAP grid resolution: dx={cap_dx:.1f}m, dy={cap_dy:.1f}m")

In [None]:
# Create exact coordinates for the refined grid within MODFLOW6 domain
refined_x = np.arange(MF6_x_min + cap_dx/2, MF6_x_max, cap_dx)
refined_y = np.arange(MF6_y_max - cap_dy/2, MF6_y_min, -cap_dy)
print(f"Refined grid: x from {refined_x.min():.1f} to {refined_x.max():.1f} ({len(refined_x)} cells)")
print(f"Refined grid: y from {refined_y.min():.1f} to {refined_y.max():.1f} ({len(refined_y)} cells)")

In [None]:
# Create a refined target discretization
MF6_DIS_refined = MF6_DIS.dataset.interp(
    x=refined_x,
    y=refined_y,
    method='nearest'
)
MF6_DIS_refined['idomain'] = MF6_DIS_refined['idomain'].astype(int)

In [None]:
# Convert back to StructuredDiscretization. Taken from the `mf6` module imported at the
# top; the bare name StructuredDiscretization is not imported in this notebook.
MF6_DIS = mf6.StructuredDiscretization(
    idomain=MF6_DIS_refined['idomain'],
    top=MF6_DIS_refined['top'],
    bottom=MF6_DIS_refined['bottom'],
)
print(f"Refined MF6_DIS shape: {MF6_DIS.dataset.sizes}")

In [None]:
# Now regrid all CAP data to the exact refined coordinates (nearest-neighbour sampling);
# entries without both x and y dimensions are kept as they are.
def _regrid_cap_field(name, field):
    """Return *field* sampled at refined_x/refined_y; 'wetted_area' is filled with the cell area."""
    print(f"Processing {name}: dims = {field.dims}")
    if not ('x' in field.dims and 'y' in field.dims):
        return field
    sampled = field.interp(x=refined_x, y=refined_y, method='nearest')
    if name == 'wetted_area':
        # Every non-NaN cell becomes the full CAP cell area (cap_dx * cap_dy); NaN stays NaN.
        # NOTE(review): this discards the original wetted_area values; confirm that is intended.
        return sampled * 0 + (cap_dx * cap_dy)
    return sampled


PRJ_CAP_regridded = {'cap': {key: _regrid_cap_field(key, data) for key, data in PRJ_CAP['cap'].items()}}
print("CAP data regridding completed.")

In [None]:
# Merge with extra metadata: the regridded CAP data plus a shallow copy of PRJ_CAP['extra'].
PRJ_CAP_for_MSW = {**PRJ_CAP_regridded, 'extra': dict(PRJ_CAP['extra'])}

### Fix mete_grid.inp relative paths
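`msw.MetaSwapModel.from_imod5_data` cannot handle the relative paths in mete_grid.inp, so they are converted to absolute paths (resolved relative to the PRJ folder). #666 caution: `os.path.join` returns a path that is already absolute unchanged, so absolute entries should pass through as-is, but verify this for your data.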

In [None]:
# Get the original mete_grid.inp file path: first element of the 3rd entry of the CAP
# 'extra' paths.
Pa_mete_grid = PRJ_CAP['extra']['paths'][2][0]  # 3rd file (index 2) (by design in imod - i.e. the order can't change)
print(f"Original file: {Pa_mete_grid}")

In [None]:
# Load mete_grid.inp and make its P and PET path columns absolute; the corrected copy
# will be written to <mete_grid folder>/temp/mete_grid.inp.
Dir_mete_grid = PDN(Pa_mete_grid)
Pa_mete_grid_AbsPa = PJ(Dir_mete_grid, 'temp', 'mete_grid.inp')
Dir_mete_grid_temp = PDN(Pa_mete_grid_AbsPa)
if not PE(Dir_mete_grid_temp):
    MDs(Dir_mete_grid_temp)

DF = pd.read_csv(Pa_mete_grid, header=None, names=['N', 'Y', 'P', 'PET'])
# Relative paths are resolved against the PRJ folder.
for path_col in ('P', 'PET'):
    DF[path_col] = DF[path_col].apply(lambda p: os.path.abspath(PJ(Dir_PRJ, p)))

In [None]:
# Write the corrected mete_grid.inp. Paths are quoted, as in the original format. Lines
# are joined with '\n', so the last line carries no trailing newline (imod doesn't strip
# newlines from paths).
corrected_lines = [
    f'{row["N"]},{row["Y"]},"{row["P"]}","{row["PET"]}"'
    for _, row in DF.iterrows()
]

with open(Pa_mete_grid_AbsPa, 'w') as f:
    f.write('\n'.join(corrected_lines))

print(f"Created corrected mete_grid.inp: {Pa_mete_grid_AbsPa}")

In [None]:
# Replace the mete_grid.inp path in the PRJ_CAP_for_MSW dictionary.
# PRJ_CAP_for_MSW['extra'] is only a shallow copy, so its 'paths' list (and the lists inside
# it) are shared with PRJ_CAP/PRJ. Rebuild them instead of assigning in place; otherwise
# PRJ_CAP['extra'] would be changed too, and re-running the mete_grid cells would treat the
# temp copy as the "original" file.
paths_for_MSW = list(PRJ_CAP_for_MSW['extra']['paths'])
paths_for_MSW[2] = [Pa_mete_grid_AbsPa, *paths_for_MSW[2][1:]]
PRJ_CAP_for_MSW['extra'] = {**PRJ_CAP_for_MSW['extra'], 'paths': paths_for_MSW}

### Finally load MS Sim

In [None]:
# Create the MetaSwap model from the regridded CAP data on the refined discretization.
# Uses the lowercase `msw` module imported at the top; `MSW` is not defined in this notebook.
MSW_Mdl = msw.MetaSwapModel.from_imod5_data(PRJ_CAP_for_MSW, MF6_DIS, times)
print("🎉 MetaSwap model created successfully!")

# Connect MF6 to MetaSWAP

## Clip models

In [None]:
# Clip both models to the model window (INI 'WINDOW' bounds).
AoI_box = dict(x_min=Xmin, x_max=Xmax, y_min=Ymin, y_max=Ymax)
Sim_MF6_AoI = Sim_MF6.clip_box(**AoI_box)
MSW_Mdl_AoI = MSW_Mdl.clip_box(**AoI_box)

#### Sense check

In [None]:
# Show the model window bounds read from the INI.
Xmin, Xmax, Ymin, Ymax

In [None]:
# The MF6 model of the clipped simulation.
MF6_Mdl_AoI = Sim_MF6_AoI['imported_model']

In [None]:
# Extent of the clipped DIS grid in cell-centre coordinates: (x min, x max, y min, y max).
dis_AoI_x, dis_AoI_y = MF6_Mdl_AoI['dis']['x'], MF6_Mdl_AoI['dis']['y']
(dis_AoI_x.min().values, dis_AoI_x.max().values, dis_AoI_y.min().values, dis_AoI_y.max().values)

Makes sense as those are cell centers. (dx=dy=100m)

## Load models into memory

In [None]:
# Load the package data of both clipped models into memory (MF6 first, then MetaSWAP).
for clipped_model in (MF6_Mdl_AoI, MSW_Mdl_AoI):
    for pkg in clipped_model.values():
        pkg.dataset.load()

## Cleanup

### MF6

In [None]:
# Model domain of the clipped MF6 model; used below as the mask for mask_all_models.
mask = MF6_Mdl_AoI.domain

In [None]:
# #111 Just to check the mask: plot layer 5 in greyscale.
mask.sel(layer=5).plot.imshow(cmap='gray')

In [None]:
# Apply the domain mask to all models of the clipped simulation and keep its DIS package.
Sim_MF6_AoI.mask_all_models(mask)
DIS_AoI = MF6_Mdl_AoI["dis"]

### Check if IMS has changed after clipping

In [None]:
# Compare the IMS objects of the full and clipped simulations with ==.
Sim_MF6['ims'] == Sim_MF6_AoI['ims']

In [None]:
# Compare every public, non-callable attribute of the clipped IMS with the original IMS
# and print the attributes whose values differ (or that exist on only one of them).
def _values_equal(a, b):
    """Return True if a and b hold equal values.

    Uses .equals() when the object provides it (e.g. xarray/pandas) and np.array_equal for
    numpy arrays: a plain `!=` on those returns an element-wise result, whose truth value is
    either always True (xarray Dataset with data variables) or ambiguous.
    """
    equals = getattr(a, 'equals', None)
    if callable(equals):
        return bool(equals(b))
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        return bool(np.array_equal(a, b))
    return bool(a == b)


IMS_AoI = Sim_MF6_AoI['ims']
IMS_AoI_Vars = [attr for attr in dir(IMS_AoI) if not attr.startswith('_') and not callable(getattr(IMS_AoI, attr))]
IMS_Vars = [attr for attr in dir(IMS) if not attr.startswith('_') and not callable(getattr(IMS, attr))]
for var in sorted(set(IMS_AoI_Vars) ^ set(IMS_Vars)):
    print(f"{var}: only present on one of the two IMS objects")
for var in IMS_AoI_Vars:
    if var in IMS_Vars and not _values_equal(getattr(IMS_AoI, var), getattr(IMS, var)):
        print(f"{var}: \n{getattr(IMS_AoI, var)}\n{getattr(IMS, var)}")

The `False` above suggests the IMS has changed, but no difference is visible in its attributes.

In [None]:
# Compare the IMS settings by content: True means the underlying datasets are identical.
IMS.dataset.equals(IMS_AoI.dataset)

This confirms the settings are identical. Clipping creates a new object, and `==` presumably compares object identity rather than contents, which is why `==` did not return True.

### Check if the packages are the same

In [None]:
# Compare the package names of the full and clipped MF6 models.
# NOTE(review): if keys() returns dict-style key views, this comparison ignores order.
MF6_Mdl.keys() == MF6_Mdl_AoI.keys()

In [None]:
# Detailed comparison of MF6 model packages (full model vs. clipped AoI model).
print("=== MF6 Model Package Comparison ===")

original_keys, aoi_keys = set(MF6_Mdl.keys()), set(MF6_Mdl_AoI.keys())
print(f"Original model packages: {len(original_keys)}")
print(f"AoI model packages: {len(aoi_keys)}")

only_in_original = original_keys - aoi_keys
only_in_aoi = aoi_keys - original_keys
common_keys = original_keys & aoi_keys


def _print_key_group(title, names):
    """Print a titled, sorted list of package names with their count."""
    print(f"\n{title} ({len(names)}):")
    for name in sorted(names):
        print(f"  - {name}")


_print_key_group("Packages only in original model", only_in_original)
_print_key_group("Packages only in AoI model", only_in_aoi)
_print_key_group("Common packages", common_keys)

# Sets ignore order, so equality here means both models hold the same package names.
print(f"\nSame packages (different order): {original_keys == aoi_keys}")
print(f"Original keys (ordered): {sorted(original_keys)}")
print(f"AoI keys (ordered): {sorted(aoi_keys)}")

In [None]:
# Check if the difference is related to the masking operation
print("\n=== Impact of Masking Operation ===")

# NOTE(review): this cell does not establish that mask_all_models removes packages; it only
# compares package names per type and reports mask statistics.
print("Note: The masking operation might remove packages that have no active cells in the AoI")

# Package types commonly affected by clipping/masking (matched as substrings of the package names)
potentially_affected = ['wel', 'drn', 'riv', 'ghb', 'chd', 'rch', 'evt']
for pkg_type in potentially_affected:
    orig_matches = [k for k in original_keys if pkg_type in k.lower()]
    aoi_matches = [k for k in aoi_keys if pkg_type in k.lower()]

    if len(orig_matches) != len(aoi_matches):
        print(f"\n{pkg_type.upper()} packages:")
        print(f"  Original: {orig_matches}")
        print(f"  AoI: {aoi_matches}")
        print(f"  Difference: {len(orig_matches) - len(aoi_matches)} packages removed")

# Check the mask that was used
print(f"\nMask shape: {mask.shape if hasattr(mask, 'shape') else 'No shape attribute'}")
print(f"Mask type: {type(mask)}")

# Count active cells (value > 0) in the mask. Only expected computation errors are caught
# (a bare except would also swallow KeyboardInterrupt), and an empty mask is reported
# instead of dividing by zero.
if hasattr(mask, 'values'):
    try:
        mask_values = np.asarray(mask.values)
        active_cells = int((mask_values > 0).sum())
        total_cells = mask_values.size
    except (TypeError, ValueError) as err:
        print(f"Could not compute active cell statistics: {err}")
    else:
        if total_cells:
            print(f"Active cells in mask: {active_cells}/{total_cells} ({100*active_cells/total_cells:.1f}%)")
        else:
            print("Could not compute active cell statistics: the mask is empty")

In [None]:
# Plot the CHD heads of the full model on 2010-01-01, layer 1.
MF6_Mdl['chd_merged'].dataset['head'].sel(time='2010-01-01', layer=1).plot.imshow()

**CHD is not plotting properly. We need to check whether the input files have been read properly at more regular intervals.**

### MSW

In [None]:
# Package names of the full MF6 model.
MF6_Mdl.keys()

In [None]:
# List the package names of the clipped MF6 model, one per line.
for Pkg in MF6_Mdl_AoI.keys():
    print(Pkg)

In [None]:
# Cleanup MetaSWAP: fill missing (NaN) root-zone depths of the clipped MetaSWAP grid with 1.0.
# The clipped models are MSW_Mdl_AoI / Sim_MF6_AoI; msw_model_clipped and simulation_clipped
# are not defined in this notebook.
MSW_Mdl_AoI["grid"].dataset["rootzone_depth"] = MSW_Mdl_AoI["grid"].dataset["rootzone_depth"].fillna(1.0)

import primod

# Couple MetaSWAP to "imported_model" via the MF6 recharge package "msw-rch" and the well
# package "msw-sprinkling".
# NOTE(review): confirm these two packages exist in (or are created for) the MF6 model.
out_dir = d_Pa['Pa_MdlN']  # NOTE(review): same folder as Dir_Sim below; confirm the intended output location
metamod_coupling = primod.MetaModDriverCoupling(mf6_model="imported_model", mf6_recharge_package="msw-rch", mf6_wel_package="msw-sprinkling")
metamod = primod.MetaMod(MSW_Mdl_AoI, Sim_MF6_AoI, coupling_list=[metamod_coupling])

# TODO: "modflow6_dll", "metaswap_dll" and "metaswap_dll_dependency" are placeholders;
# replace them with the actual library/dependency paths before running.
metamod.write(out_dir, "modflow6_dll", "metaswap_dll", "metaswap_dll_dependency", modflow6_write_kwargs={"binary": False})


# Write Simulation

In [None]:
# Apply the domain mask to all models of the clipped simulation.
# NOTE(review): the same call is already made in the MF6 cleanup section; this repeats it.
Sim_MF6_AoI.mask_all_models(mask)

In [None]:
# Output folder for the simulation: the model's 'Pa_MdlN' path from U.get_MdlN_Pa.
Dir_Sim = d_Pa['Pa_MdlN']

In [None]:
# Create the output folder (and parents); no error if it already exists.
MDs(Dir_Sim, exist_ok=True)

In [None]:
import imod  # only `from imod import mf6/msw` is done at the top; imod.util needs the package itself
from imod.schemata import ValidationError

# Write the clipped and masked simulation (Sim_MF6_AoI; `MF6_Sim` is not defined in this
# notebook). imod.util.print_if_error prints a ValidationError instead of raising it.
with imod.util.print_if_error(ValidationError):
    Sim_MF6_AoI.write(Dir_Sim)  # Attention: this will fail!

# Junkyard

#### Test sim without cap

 <!-- # Simple solution: Remove CAP package and test if conversion works
 print("=== Testing Without CAP Package ===")

 # Create a copy of PRJ without the CAP package
 PRJ_no_cap = PRJ.copy()
 if "cap" in PRJ_no_cap:
     del PRJ_no_cap["cap"]
     print("CAP package removed from PRJ")
 else:
     print("CAP package not found in PRJ")

 print(f"Original PRJ keys: {list(PRJ.keys())}")
 print(f"PRJ without CAP keys: {list(PRJ_no_cap.keys())}")

 # Test the conversion without CAP
 try:
     print("\n=== Testing Simulation Conversion Without CAP ===")
     simulation = imod.mf6.Modflow6Simulation.from_imod5_data(PRJ_no_cap, period_data, times)
     print("SUCCESS: Simulation created without CAP package!")
     print(f"Simulation keys: {list(simulation.keys())}")
    
     # Check what packages were created
     if "imported_model" in simulation:
         model = simulation["imported_model"]
         print(f"Model packages: {list(model.keys())}")
        
 except Exception as e:
     print(f"Error without CAP: {e}")
     import traceback
     traceback.print_exc() -->

#### Investigate well error

from imod.mf6.wel import LayeredWell, Well

PRJ['wel-WEL_Br_Wa_T_NBr1']['layer']

Well.from_imod5_data('wel-WEL_Br_Wa_T_NBr1', PRJ, times)

os.remove(Pa_PRJ_temp)  # Delete temp PRJ file as it's not needed anymore.