In [6]:
import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
from torch.utils.data import Dataset

In [7]:
# Define data directories
# base_data_dir = '/srv/scratch/z5370003/projects/data/groundwater/FEFLOW/coastal/variable_density/'
base_data_dir = '/Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Projects/01_PhD/05_groundwater/data/FEFLOW/variable_density'  # Uncomment for local testing
raw_data_dir = os.path.join(base_data_dir, 'all')
patch_data_dir = os.path.join(base_data_dir, 'filter_all_ts_patch')

print(f"Base data directory: {base_data_dir}")
print(f"Raw data directory: {raw_data_dir}")
print(f"Patch filtered data directory: {patch_data_dir}")

Base data directory: /Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Projects/01_PhD/05_groundwater/data/FEFLOW/variable_density
Raw data directory: /Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Projects/01_PhD/05_groundwater/data/FEFLOW/variable_density/all
Patch filtered data directory: /Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Projects/01_PhD/05_groundwater/data/FEFLOW/variable_density/filter_all_ts_patch


In [8]:
import sys
sys.path.append('../')
from src.data.transform import Normalize
from src.data.patch_dataset import GWPatchDataset

In [9]:
# Read data
df = pd.read_csv(os.path.join(raw_data_dir, '0000.csv'))

# Calculate mean and std of coordinates
coord_mean = df[['X', 'Y', 'Z']].mean().values
coord_std = df[['X', 'Y', 'Z']].std().values

# Print mean and std of coordinates
print(f"Coordinate mean: {coord_mean}")
print(f"Coordinate std: {coord_std}")

# Create coordinate transform
coord_transform = Normalize(mean=coord_mean, std=coord_std)

del df

Coordinate mean: [ 3.57225665e+05  6.45774324e+06 -9.27782248e+00]
Coordinate std: [569.1699999  566.35797379  15.26565618]


In [10]:
# Read data
df = pd.read_csv(os.path.join(raw_data_dir, '0000.csv'))

# Define output columns
obs_cols = ['mass_concentration', 'head', 'pressure']

# Mean and std of output
obs_mean = df[obs_cols].mean().values
obs_std = df[obs_cols].std().values

# Print mean and std of output
print(f"Output mean: {obs_mean}")
print(f"Output std: {obs_std}")

# Define output transform
obs_transform = Normalize(mean=obs_mean, std=obs_std)

del df

Output mean: [1.77942252e+04 3.95881156e-01 9.48469883e+01]
Output std: [1.55859465e+04 2.13080032e-01 1.51226320e+02]


In [11]:
patch_ds = GWPatchDataset(data_path=patch_data_dir, dataset='train', 
                          coord_transform=coord_transform, 
                          obs_transform=obs_transform)
len(patch_ds)

13300

In [12]:
for patch_data in patch_ds:
    print([(k, v.shape) if k != 'patch_id' else (k, v) for k, v in patch_data.items()])
    input_coords = torch.concat([patch_data['core_coords'], patch_data['ghost_coords']], dim=0)
    input_obs = torch.concat([patch_data['core_in'], patch_data['ghost_in']], dim=1)
    output_obs = torch.concat([patch_data['core_out'], patch_data['ghost_out']], dim=1)
    core_len = int(input_coords.shape[0] * 0.8)
    

[('core_in', torch.Size([10, 79, 3])), ('ghost_in', torch.Size([10, 19, 3])), ('core_out', torch.Size([10, 79, 3])), ('ghost_out', torch.Size([10, 19, 3])), ('patch_id', 1), ('core_coords', torch.Size([79, 3])), ('ghost_coords', torch.Size([19, 3]))]
[('core_in', torch.Size([10, 79, 3])), ('ghost_in', torch.Size([10, 19, 3])), ('core_out', torch.Size([10, 79, 3])), ('ghost_out', torch.Size([10, 19, 3])), ('patch_id', 1), ('core_coords', torch.Size([79, 3])), ('ghost_coords', torch.Size([19, 3]))]
[('core_in', torch.Size([10, 79, 3])), ('ghost_in', torch.Size([10, 19, 3])), ('core_out', torch.Size([10, 79, 3])), ('ghost_out', torch.Size([10, 19, 3])), ('patch_id', 1), ('core_coords', torch.Size([79, 3])), ('ghost_coords', torch.Size([19, 3]))]
[('core_in', torch.Size([10, 79, 3])), ('ghost_in', torch.Size([10, 19, 3])), ('core_out', torch.Size([10, 79, 3])), ('ghost_out', torch.Size([10, 19, 3])), ('patch_id', 1), ('core_coords', torch.Size([79, 3])), ('ghost_coords', torch.Size([19, 3]