In [94]:
import xarray as xr
import numpy as np
import glob
import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import warnings
warnings.filterwarnings("ignore")

1. Preparing the random-crop dataset
2. Mean and std of different data chunks
3. NetCDF to HDF5
4. FFT
5. Region-specific and all-data map plots

In [97]:
# Root folder of the ICON-EU NetCDF files; the month sub-folders sit directly below it.
nc_file_loc = '/p/project1/exaww/chatterjee1/dataset/iconeu/'
# NetCDF file that receives the stacked random crops.
output_file = "msgobs_108_randcrops_icon.nc"

# Month number -> zero-padded month sub-folder, April through September.
months = {month_no: f"{month_no:02d}/" for month_no in range(4, 10)}

# Plain-text logs written next to the input data: every file visited, and every
# crop that was rejected because it contained NaNs.
log_file = f"{nc_file_loc}processed_files_log.txt"
nan_crop_file = f"{nc_file_loc}nan_files_log.txt"

In [102]:
# Build the random-crop dataset: for every ICON-EU file, draw CROPS_PER_FILE square
# crops from the regional window and stack all NaN-free crops into one NetCDF file.

CROP_SIZE = 78       # edge length of each square crop, in grid points
CROPS_PER_FILE = 4   # one crop is drawn from each x-segment of the regional window
RANDOM_SEED = 42     # fixed seed so that Restart & Run All reproduces the same crops

rng = np.random.default_rng(RANDOM_SEED)

sample_counter = 0  # number of crops that were kept

all_crops = []  # kept crops, each a (CROP_SIZE, CROP_SIZE) NumPy array
all_lats = []   # latitude values of each kept crop
all_lons = []   # longitude values of each kept crop
all_times = []  # time-of-day string of the file each kept crop came from

# Iterate over each month
for key in months:
    month_loc = nc_file_loc + months[key]
    day_folders = sorted(glob.glob(month_loc + '*/'))  # all day sub-folders of this month

    # Iterate over each day folder within the current month
    for day_folder in day_folders:
        nc_filepattern = "iefrf*.nc"
        nc_files = sorted(glob.glob(day_folder + nc_filepattern))

        for file in nc_files:

            # Log the name of the current file
            with open(log_file, 'a') as log:
                log.write(f"{file}\n")

            # The handle is deliberately not closed here: other cells of this notebook
            # inspect the last opened `data` object after the loop has finished.
            data = xr.open_dataset(file)
            date = data.time.dt.strftime('%Y-%m-%d').values[0]
            timestamp = data.time.dt.strftime('%H:%M:%S').values[0]

            # Regional window of the full grid, loaded eagerly as plain NumPy arrays so
            # the crops appended below are cheap ndarrays rather than xarray objects.
            lat = data['lat'][288:377].values
            lon = data['lon'][472:816].values
            bt_data = data['SYNMSG_BT_CL_IR10.8'][0, 288:377, 472:816].values

            y_dim, x_dim = bt_data.shape

            # Split the admissible x start positions into equal segments so the crops
            # of one file are spread along the x direction.
            x_segments = np.linspace(0, x_dim - CROP_SIZE, CROPS_PER_FILE + 1, dtype=int)

            for j in range(CROPS_PER_FILE):
                start_y = rng.integers(0, y_dim - CROP_SIZE)
                start_x = rng.integers(x_segments[j], x_segments[j + 1])

                # Copy so the small crop does not keep the whole window array alive.
                crop = bt_data[start_y:start_y + CROP_SIZE, start_x:start_x + CROP_SIZE].copy()

                # Skip this crop if it contains any NaN values
                if np.isnan(crop).any():
                    with open(nan_crop_file, 'a') as log:
                        log.write(f"{file} and crop no {j}\n")
                    continue

                all_crops.append(crop)
                all_lats.append(lat[start_y:start_y + CROP_SIZE])
                all_lons.append(lon[start_x:start_x + CROP_SIZE])
                # NOTE(review): only the time of day is stored; the calendar date held in
                # `date` is not written. Kept as-is so the file layout does not change.
                all_times.append(timestamp)

                sample_counter += 1

if not all_crops:
    raise RuntimeError(f"No NaN-free crops were collected below {nc_file_loc}")

# Convert once, after all files have been visited (doing this inside the loop
# re-copies the whole, growing list for every single crop).
all_crops_np = np.array(all_crops)
all_lats_np = np.array(all_lats)
all_lons_np = np.array(all_lons)
all_times_np = np.array(all_times)

# Create a dataset with the combined data
ds = xr.Dataset(
    {
        "model_108": (["sample", "y", "x"], all_crops_np)  # Data variable
    },
    coords={
        "sample": (["sample"], np.arange(len(all_crops_np))),  # Sample numbers
        "lat": (["sample", "y"], all_lats_np),
        "lon": (["sample", "x"], all_lons_np),
        "time": (["sample"], all_times_np)
    }
)

# The dataset is built exactly once per run, so it is always written from scratch.
ds.to_netcdf(output_file, mode='w')


In [103]:
sample_counter

18300

In [36]:
day_folder + nc_filepattern

'/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf*.nc'

In [39]:
day_folders[0]

'/p/project1/exaww/chatterjee1/dataset/iconeu/09/01/'

In [101]:
nc_files

['/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00000000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00010000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00020000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00030000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00040000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00050000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00060000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00070000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00080000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00090000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00100000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00110000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf00120000.nc',
 '/p/project1/exaww/chatterjee1/dataset/iconeu/09/30/iefrf001300

In [5]:
# NOTE(review): `loc` is not assigned in any cell visible here; presumably it is left
# over from an earlier session. Confirm, or remove this cell.
loc

'/p/project1/exaww/chatterjee1/dataset/iconeu/09/'

In [110]:
# Re-open the stacked crop archive from disk and display its summary.
# NOTE(review): the cell that writes the archive uses the relative path in `output_file`,
# while this reads from the dataset directory; presumably the file was written or moved
# there. Confirm. This also rebinds `data`, which other cells use for single ICON files.
data = xr.open_dataset('/p/project1/exaww/chatterjee1/dataset/msgobs_108_randcrops_icon.nc')
data

In [111]:
data.model_108.shape

(18300, 78, 78)

## mean and std of different data chunks

In [112]:
def compute_mean_std_chunked(data, chunk_size=1000):
    """Return the global mean and population standard deviation of a 3-D array.

    The array is traversed along its first axis in slices of ``chunk_size`` samples so
    that only one slice is materialised at a time.

    Parameters
    ----------
    data : array-like with ``.shape`` of length 3 and slice support
        For example a NumPy array or an ``xarray.DataArray``.
    chunk_size : int, optional
        Number of samples (first-axis entries) processed per step.

    Returns
    -------
    (mean, std)
        Scalars of whatever type ``np.sum`` yields for ``data`` (NumPy float for
        arrays, 0-d DataArray for DataArrays). NaNs in ``data`` propagate to both.
    """
    n_samples, height, width = data.shape
    n_elements = n_samples * height * width

    total_sum = 0.0
    total_square_sum = 0.0

    for start in range(0, n_samples, chunk_size):
        # Promote to float64 before summing/squaring: squaring in the input dtype
        # overflows for small integer types and loses precision for float32.
        chunk = data[start:start + chunk_size].astype(np.float64)

        total_sum += np.sum(chunk)
        total_square_sum += np.sum(chunk * chunk)

    mean = total_sum / n_elements
    # E[x^2] - E[x]^2 can come out marginally negative through rounding when the data
    # are (nearly) constant; clamp so the square root does not turn that into NaN.
    variance = np.maximum((total_square_sum / n_elements) - (mean ** 2), 0.0)
    std = np.sqrt(variance)

    return mean, std

# Despite the variable name, no subsampling happens here: this is the complete
# `model_108` variable of the in-memory dataset `ds`.
# NOTE(review): `ds` is the object built by the crop-writing cell, not the `data` handle
# opened from disk; in a fresh kernel `ds` has to exist before this cell runs.
small_sample_data = ds.model_108

# Calculate the mean and standard deviation using chunks
mean_sample_data_chunked, std_sample_data_chunked = compute_mean_std_chunked(small_sample_data)

# Display both statistics as the cell output
mean_sample_data_chunked, std_sample_data_chunked


(<xarray.DataArray 'model_108' ()>
 array(266.87361996),
 <xarray.DataArray 'model_108' ()>
 array(20.80882374))

## NetCDF to HDF5

In [113]:
def _to_hdf5_compatible(values):
    """Return ``values`` as an array with a dtype h5py can store.

    datetime64 arrays are rendered as ISO-8601 text first; object and unicode arrays
    are then converted to fixed-length byte strings. Other dtypes pass through unchanged.
    """
    values = np.asarray(values)
    if values.dtype.kind == 'M':
        values = np.asarray(np.datetime_as_string(values))
    if values.dtype.kind in ('O', 'U'):
        values = values.astype('S')
    return values


def _write_hdf5_dataset(hdf5_group, name, variable):
    """Store one xarray variable (values and attributes) as dataset ``name``."""
    values = _to_hdf5_compatible(variable.values)
    target = hdf5_group.create_dataset(name, data=values, dtype=values.dtype)
    for attr_name, attr_value in variable.attrs.items():
        target.attrs[attr_name] = attr_value


def convert_nc_to_hdf5(nc_file, hdf5_file):
    """Copy every data variable, coordinate and global attribute of a NetCDF file to HDF5.

    Parameters
    ----------
    nc_file : str
        Path of the NetCDF file, opened with ``xarray.open_dataset``.
    hdf5_file : str
        Path of the HDF5 file to create; an existing file is overwritten (mode ``'w'``).

    Each variable and coordinate becomes a top-level HDF5 dataset of the same name and
    carries its xarray attributes. Text and datetime arrays are stored as fixed-length
    byte strings.
    """
    # Both handles are context-managed so they are released even if a write fails.
    with xr.open_dataset(nc_file) as ds, h5py.File(hdf5_file, 'w') as hdf5_data:
        for var_name in ds.data_vars:
            _write_hdf5_dataset(hdf5_data, var_name, ds[var_name])

        for coord_name in ds.coords:
            _write_hdf5_dataset(hdf5_data, coord_name, ds[coord_name])

        # Copy global attributes
        for attr_name, attr_value in ds.attrs.items():
            hdf5_data.attrs[attr_name] = attr_value


# Source crop archive and the HDF5 copy that is created from it.
nc_file = '/p/project1/exaww/chatterjee1/dataset/msgobs_108_randcrops_icon.nc'
hdf5_file = '/p/project1/exaww/chatterjee1/warmworld_scripts/msgobs_108_randcrops_icon.h5'

convert_nc_to_hdf5(nc_file, hdf5_file)
print(f"Converted {nc_file} to {hdf5_file}")

Converted /p/project1/exaww/chatterjee1/dataset/msgobs_108_randcrops_icon.nc to /p/project1/exaww/chatterjee1/warmworld_scripts/msgobs_108_randcrops_icon.h5


In [57]:
date = data.time.dt.strftime('%Y-%m-%d').values
timestamp = data.time.dt.strftime('%H:%M:%S').values

In [100]:
date, timestamp

('2023-09-30', '01:00:00')

In [32]:
data.lon.shape, data.lat.shape

((1377,), (657,))

In [14]:
# Indices of the longitude grid points closest to 27.5 and 6 (same units as `data.lon`,
# presumably degrees east; confirm). The result shown below, (472, 816), matches the
# hard-coded column slice 472:816 used by the crop and plot cells.
idx_lon_max = np.argmin(np.abs(data.lon.values - 27.5))
idx_lon_min = np.argmin(np.abs(data.lon.values - 6))
idx_lon_min,idx_lon_max

(472, 816)

In [34]:
# Indices of the latitude grid points closest to 53 and 47.5 (same units as `data.lat`,
# presumably degrees north; confirm). The result shown below is (288, 376); the crop
# and plot cells slice rows 288:377, i.e. they include index 376.
idx_lat_max = np.argmin(np.abs(data.lat.values - 53))
idx_lat_min = np.argmin(np.abs(data.lat.values - 47.5))
idx_lat_min,idx_lat_max

(288, 376)

In [38]:
data.lat.max(), data.lat.min()

(<xarray.DataArray 'lat' ()>
 array(70.5),
 <xarray.DataArray 'lat' ()>
 array(29.5))

In [48]:
data['SYNMSG_BT_CL_IR10.8'][0,:,:].shape

(657, 1377)

## Region specific map plots

In [89]:
# One small PlateCarree map of the regional 10.8 µm brightness-temperature window per
# file in `nc_files`, saved as "<day>-<month>:<hour>.png".
for file in nc_files:
    data = xr.open_dataset(file)
    date = data.time.dt.strftime('%Y-%m-%d').values
    timestamp = data.time.dt.strftime('%H:%M:%S').values

    # Regional window of the full grid: 89 rows x 344 columns
    long = data['lon'][472:816]
    lat = data['lat'][288:377]
    bt_data = data['SYNMSG_BT_CL_IR10.8'][0,288:377,472:816]

    # 2-D coordinate grids for pcolormesh
    lon_grid, lat_grid = np.meshgrid(long, lat)

    # 3.44 x 0.89 inch at dpi=100 mirrors the 344 x 89 grid-point window
    fig = plt.figure(figsize=(3.44, 0.89))
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle='-', alpha=0.7)

    ax.pcolormesh(
        lon_grid, lat_grid, bt_data,
        cmap='viridis', shading='nearest', transform=ccrs.PlateCarree(),
        vmin=220, vmax=300,
    )

    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')

    # File name pieces taken from the first time step of the file
    dd, mm, hh = date[0][8:], date[0][5:7], timestamp[0][0:2]
    save_path = '/p/project1/exaww/chatterjee1/plots/vis_model_data/'
    fig.savefig(save_path + dd + '-' + mm + ':' + hh + '.png', dpi=100, bbox_inches='tight', pad_inches=0)
    plt.close(fig)



# Performing FFT of 2D spatial fields

In [82]:
# For every file in `nc_files`: 2-D FFT of the regional brightness-temperature window,
# plotted as a centred log-magnitude spectrum. Windows containing NaNs are skipped.
for file in nc_files:
    data = xr.open_dataset(file)
    date = data.time.dt.strftime('%Y-%m-%d').values
    timestamp = data.time.dt.strftime('%H:%M:%S').values

    long = data['lon'][472:816]
    lat = data['lat'][288:377]

    # Regional window of the full grid: 89 rows x 344 columns
    bt_data = data['SYNMSG_BT_CL_IR10.8'][0,288:377,472:816]

    # The FFT is only computed for complete fields
    if np.isnan(bt_data).any():
        continue

    # 2-D FFT with the zero-frequency component moved to the centre
    fft_result = np.fft.fft2(bt_data)
    fft_shifted = np.fft.fftshift(fft_result)

    # log(1 + |F|) compresses the dynamic range and is defined at |F| = 0
    magnitude_spectrum = np.abs(fft_shifted)
    log_magnitude_spectrum = np.log1p(magnitude_spectrum)

    stamp = date[0][8:] + '-' + date[0][5:7] + ':' + timestamp[0][0:2]

    plt.figure(figsize=(6, 6))
    im = plt.imshow(log_magnitude_spectrum, cmap='viridis')
    plt.colorbar(im, label='Log Magnitude Spectrum',fraction=0.016, pad=0.04)
    plt.title('2D FFT ICON ' + stamp)
    plt.xlabel('Frequency X')
    plt.ylabel('Frequency Y')
    save_path = '/p/project1/exaww/chatterjee1/plots/vis_model_data/fft/'
    plt.savefig(save_path + stamp + '_fft.png', dpi=100, bbox_inches='tight',pad_inches=0)
    plt.close()

# 1D FFT (radially averaged 2D spectrum)

In [86]:
def radial_profile(data):
    """Average a 2-D array over rings of equal integer distance from its centre.

    Parameters
    ----------
    data : 2-D array-like
        For example a centred (fft-shifted) magnitude spectrum.

    Returns
    -------
    numpy.ndarray
        1-D array whose element ``k`` is the mean of all entries whose distance from
        the centre index ``(rows // 2, cols // 2)`` truncates to ``k``.
    """
    data = np.asarray(data)

    # Centre index of the array
    center_y, center_x = np.array(data.shape) // 2

    # Distance of every element from the centre, truncated to an integer ring number.
    # The builtin `int` is used because the `np.int` alias was removed in NumPy 1.24.
    y, x = np.indices(data.shape)
    r = np.sqrt((x - center_x)**2 + (y - center_y)**2).astype(int)

    # Sum of the values per ring divided by the number of elements per ring
    tbin = np.bincount(r.ravel(), weights=data.ravel())
    nr = np.bincount(r.ravel())
    return tbin / nr

# One entry per NaN-free file: (timestamp array, radially averaged log-magnitude profile)
radial_profiles = []
# Date array of each stored profile. Kept in a parallel list so that the
# (timestamp, profile) layout of `radial_profiles` stays unchanged.
profile_dates = []

for file in nc_files:
    data = xr.open_dataset(file)
    date = data.time.dt.strftime('%Y-%m-%d').values
    timestamp = data.time.dt.strftime('%H:%M:%S').values

    long = data['lon'][472:816]
    lat = data['lat'][288:377]

    # Regional window of the full grid: 89 rows x 344 columns
    bt_data = data['SYNMSG_BT_CL_IR10.8'][0,288:377,472:816]

    # Only complete fields are transformed
    if np.isnan(bt_data).any():
        continue

    # 2-D FFT, zero frequency shifted to the centre, then log(1 + |F|)
    fft_result = np.fft.fft2(bt_data)
    fft_shifted = np.fft.fftshift(fft_result)
    magnitude_spectrum = np.abs(fft_shifted)
    log_magnitude_spectrum = np.log1p(magnitude_spectrum)

    # Collapse the 2-D spectrum to a 1-D radial profile
    radial_profile_1d = radial_profile(log_magnitude_spectrum)

    radial_profiles.append((timestamp, radial_profile_1d))
    profile_dates.append(date)

# Plot all radial profiles in a single figure, once, after every file has been read
# (drawing inside the file loop re-plotted and re-saved the figure for each file).
if radial_profiles:
    plt.figure(figsize=(10, 6))
    for (profile_time, profile), profile_date in zip(radial_profiles, profile_dates):
        # Each curve is labelled with its own date and hour
        plt.plot(profile, label=f'model {profile_date[0][8:]}-{profile_date[0][5:7]}:{profile_time[0][0:2]}')

    plt.xlabel('Radial Frequency')
    plt.ylabel('Log Magnitude')
    plt.title('ICON: FFT for Each BT Distribution')
    plt.legend()
    save_path = '/p/project1/exaww/chatterjee1/plots/vis_model_data/fft/'
    # The figure is named after the date of the first profile it contains
    first_date = profile_dates[0]
    plt.savefig(save_path + first_date[0][8:] + '-' + first_date[0][5:7] + '_fft.png', dpi=100, bbox_inches='tight', pad_inches=0)
    plt.close()

In [91]:
len(radial_profiles)

25

In [92]:
type(radial_profiles)

list

In [93]:
# Each entry pairs a short timestamp array with a longer profile array, so the list is
# ragged and can only be stored as an object array. Building that array explicitly is
# required on NumPy >= 1.24, which no longer creates it implicitly.
# Read it back with np.load(..., allow_pickle=True).
np.save('/p/project1/exaww/chatterjee1/plots/vis_model_data/fft/radprof_icon.npy', np.array(radial_profiles, dtype=object))

In [74]:
timestamp[0][0:2]

'00'

## All data map plot

In [None]:
# Render the full-grid fields of the currently opened file (`data`) as borderless images.
# NOTE(review): `data` is whatever file an earlier cell opened last; confirm it is the
# intended time step before trusting these images.

# (rows, columns) of the full grid, i.e. (number of lat points, number of lon points)
pixel_size = [657, 1377]
n_rows, n_cols = pixel_size

channel_images = (
    ('SYNMSG_BT_CL_IR10.8', '10_8.png'),
    ('SYNMSG_BT_CL_WV6.2', '6_2.png'),
)

for channel, png_name in channel_images:
    # figsize is (width, height). With dpi=1 one inch is one pixel, so the width must be
    # the number of columns and the height the number of rows.
    fig, ax = plt.subplots(figsize=(n_cols, n_rows), dpi=1)
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
    a = ax.imshow(data[channel][0,:,:], cmap='Spectral_r')
    ax.axis(False)
    fig.savefig("/p/project1/exaww/chatterjee1/plots/vis_model_data/" + png_name,  bbox_inches="tight")

In [None]:
# Collect one crop per file, centred (as far as the grid allows) on a fixed target
# location, and stack the NaN-free crops into NumPy arrays.

sample_counter = 0  # number of crops that were kept

all_crops, all_lats, all_lons, all_times = [], [], [], []

# Target location (currently Juelich). Other sites used with this cell, as (lat, lon):
#   Cabauw     51.9653,  4.8979      Sirta      48.717,   2.208
#   Zaragoza   41.6474, -0.8861      Madrid     40.4167, -3.7033
#   Bourges    47.0812,  2.3980      Vienna     48.2081, 16.3713
#   Warsaw     52.229,  21.012       Lindenberg 52.210,  14.122
#   Juelich    50.9224,  6.3639
target_lat = 50.9224
target_lon = 6.3639
crop_size = 78
half_crop = crop_size // 2

first_write = True  # not read anywhere in this cell

for key in months:
    month_loc = nc_file_loc + months[key]
    day_folders = sorted(glob.glob(month_loc + '*/'))  # all day sub-folders of this month

    for day_folder in day_folders:
        nc_filepattern = "iefrf*.nc"
        nc_files = sorted(glob.glob(day_folder + nc_filepattern))

        for file in nc_files:

            # Record the file in the processing log
            with open(log_file, 'a') as log:
                log.write(f"{file}\n")

            data = xr.open_dataset(file)
            date = data.time.dt.strftime('%Y-%m-%d').values[0]
            timestamp = data.time.dt.strftime('%H:%M:%S').values[0]

            lat = data['lat'].values  # shape: (H,)
            lon = data['lon'].values  # shape: (W,)
            bt_data = data['SYNMSG_BT_CL_IR10.8'][0, :, :].values  # shape: (H, W)

            # Grid indices nearest to the target location
            lat_idx = np.argmin(np.abs(lat - target_lat))
            lon_idx = np.argmin(np.abs(lon - target_lon))

            # Window start is clamped at the lower grid edge; the end follows from it
            start_y = max(0, lat_idx - half_crop)
            start_x = max(0, lon_idx - half_crop)
            end_y, end_x = start_y + crop_size, start_x + crop_size

            # A window running past the upper grid edge is skipped, not shifted
            n_rows, n_cols = bt_data.shape
            if end_y > n_rows or end_x > n_cols:
                continue

            crop = bt_data[start_y:end_y, start_x:end_x]

            # Incomplete crops are dropped
            if np.isnan(crop).any():
                continue

            all_crops.append(crop)
            all_lats.append(lat[start_y:end_y])
            all_lons.append(lon[start_x:end_x])
            # The first time value of the file is stored as-is (not the HH:MM:SS string)
            all_times.append(data.time.values[0])

            sample_counter += 1

all_crops_np = np.array(all_crops)
all_lats_np = np.array(all_lats)
all_lons_np = np.array(all_lons)
all_times_np = np.array(all_times)



