In [1]:
import xarray as xr
import numpy as np
import glob
import os
import h5py
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import warnings
warnings.filterwarnings("ignore")

In [2]:
C1 = 1.19104*10**(-5)  # in [mW (cm−1)−4 m-2 sr−1]
C2 = 1.43877  # in [K cm]

CHANNEL_NAME = {"channel_1": "VIS 0.6", 
                "channel_2": "VIS 0.8", 
                "channel_3": "NIR 1.6", 
                "channel_4": "IR 3.9", 
                "channel_5": "WV 6.2", 
                "channel_6": "WV 7.3", 
                "channel_7": "IR 8.7", 
                "channel_8": "IR 9.7 - O3", 
                "channel_9": "IR 10.8", 
                "channel_10": "IR 12.0", 
                "channel_11": "IR 13.4 - CO2", }
# in [cm−1]
VC = {'MSG1': {"channel_4": 2567.330, "channel_5": 1598.103, "channel_6": 1362.081, "channel_7": 1149.069, 
                "channel_8": 1034.343, "channel_9": 930.647, "channel_10": 839.660, "channel_11": 752.387
                }, 
      'MSG2': {"channel_4": 2568.832, "channel_5": 1600.548, "channel_6": 1360.330, "channel_7": 1148.620, 
                "channel_8": 1035.289, "channel_9": 931.700, "channel_10": 836.445, "channel_11": 751.792
                }, 
      'MSG3': {"channel_4": 2547.771, "channel_5": 1595.621, "channel_6": 1360.377, "channel_7": 1148.130, 
                "channel_8": 1034.715, "channel_9": 929.842, "channel_10": 838.659, "channel_11": 750.653
                }, 
      'MSG4': {"channel_4": 2555.280, "channel_5": 1596.080, "channel_6": 1361.748, "channel_7": 1147.433, 
                "channel_8": 1034.851, "channel_9": 931.122, "channel_10": 839.113, "channel_11": 748.585
                }, }
# unitless
ALPHA = {'MSG1': {"channel_4": 0.9956, "channel_5": 0.9962, "channel_6": 0.9991, "channel_7": 0.9996, 
                   "channel_8": 0.9999, "channel_9": 0.9983, "channel_10": 0.9988, "channel_11": 0.9981
                   }, 
         'MSG2': {"channel_4": 0.9954, "channel_5": 0.9963, "channel_6": 0.9991, "channel_7": 0.9996, 
                   "channel_8": 0.9999, "channel_9": 0.9983, "channel_10": 0.9988, "channel_11": 0.9981
                   }, 
         'MSG3': {"channel_4": 0.9915, "channel_5": 0.9960, "channel_6": 0.9991, "channel_7": 0.9996, 
                   "channel_8": 0.9999, "channel_9": 0.9983, "channel_10": 0.9988, "channel_11": 0.9982
                   }, 
         'MSG4': {"channel_4": 0.9916, "channel_5": 0.9959, "channel_6": 0.9990, "channel_7": 0.9996, 
                   "channel_8": 0.9998, "channel_9": 0.9983, "channel_10": 0.9988, "channel_11": 0.9981
                   }, }
# in [K]
BETA = {'MSG1': {"channel_4": 3.410, "channel_5": 2.218, "channel_6": 0.478, "channel_7": 0.179, ''
                  "channel_8": 0.060, "channel_9": 0.625, "channel_10": 0.397, "channel_11": 0.578
                  },
        'MSG2': {"channel_4": 3.438, "channel_5": 2.185, "channel_6": 0.470, "channel_7": 0.179, 
                  "channel_8": 0.056, "channel_9": 0.640, "channel_10": 0.408, "channel_11": 0.561
                  },
        'MSG3': {"channel_4": 2.9002, "channel_5": 2.0337, "channel_6": 0.4340, "channel_7": 0.1714, 
                  "channel_8": 0.0527, "channel_9": 0.6084, "channel_10": 0.3882, "channel_11": 0.5390
                  },
        'MSG4': {"channel_4": 2.9438, "channel_5": 2.0780, "channel_6": 0.4929, "channel_7": 0.1731, 
                  "channel_8": 0.0597, "channel_9": 0.6256, "channel_10": 0.4002, "channel_11": 0.5635
                  }, }

# %%
#############
############# look up tables for calculating reflectances
#############
# constants taken from website: 
# https://eumetsatspace.atlassian.net/wiki/spaces/DSDT/pages/1537277953/MSG15+radiances+conversion+to+BT+and+Reflectances
# and from https://www-cdn.eumetsat.int/files/2020-04/pdf_msg_seviri_rad2refl.pdf

IRRAD = {'MSG1': {"channel_1": 65.2296, "channel_2": 73.0127, "channel_3": 62.3715},
         'MSG2': {"channel_1": 65.2065, "channel_2": 73.1869, "channel_3": 61.9923},
         'MSG3': {"channel_1": 65.5148, "channel_2": 73.1807, "channel_3": 62.0208}, 
         'MSG4': {"channel_1": 65.2656, "channel_2": 73.1692, "channel_3": 61.9416}, }


# %%
class ir_channel:
    """
    class that calls channel specific constants from look up tables above
    """
    def __init__(self, satellite, channel):

        self.name = CHANNEL_NAME[channel]
        self.vc = VC[satellite][channel]  # wavenumber in [cm−1]
        self.alpha = ALPHA[satellite][channel]  # unitless
        self.beta = BETA[satellite][channel]  # in [K]

class vis_nir_channel:
    def __init__(self, satellite, channel):
        
        self.name = CHANNEL_NAME[channel]
        self.irrad = IRRAD[satellite][channel]  # irradiance at 1AU in [mW·m-2·(cm-1)-1]

class MSG_satellite:
    def __init__(self, name):
        self.name =  name

    def _get_channel(self, channel_number):
        # return vis/nir or ir channel depending on channel number
        if channel_number <=3:
            return vis_nir_channel(satellite=self.name, channel=f"channel_{channel_number}")
        else:
            return ir_channel(satellite=self.name, channel=f"channel_{channel_number}")

    def rad_2_tb(self, channel_number, radiances):
        # error handling here:
        # TODO: raise exception when given incorrect channel_number, must be >=4

        # get constants for given channel
        channel_consts = self._get_channel(channel_number)

        # converting radiance to brightness temperature [K] with simplified equation
        numerator = C2 * channel_consts.vc
        fraction = C1 * channel_consts.vc**3 / radiances + 1
        denominator = channel_consts.alpha * (np.log(fraction))
        tb = numerator / denominator - channel_consts.beta / channel_consts.alpha  ## [K]
        return tb
    
    def _d(t):
        # Sun-Earth distance in AU at time t
        return None
    
    def _solar_zenith_angle(t, lon, lat):
        # Solar Zenith Angle in Radians at time t and location x
        return None

    def rad_2_refl(self, channel_number, radiances, t, lon, lat):
        # error handling here:
        # TODO: raise exception when given incorrect channel_number, must be <= 3

        # get constants for given channel
        channel_consts = self._get_channel(channel_number)

        numerator = np.pi * radiances * self._d(t)**2
        denominator = channel_consts.irrad * np.cos(self._solar_zenith_angle(t, lon, lat))

# %%
def radiances_2_brightnesstemp_and_reflectances(radiances, channel_number, satellite_name):
    ## radiances in [mW m−2 sr−1 (cm−1)−1)]
    # TODO: add constraint to channel_number (must be >= 4)

    # access correct satellite 
    satellite = MSG_satellite(satellite_name)
    if channel_number <= 3:
        print("not implemented yet for visible and near-infrared")

    elif channel_number >= 4 and channel_number < 12:
        # get brightness temp fro given channel
        return satellite.rad_2_tb(channel_number, radiances)
    
    else:
        print(f"This channel does not exist for satellite {satellite_name}")
        # TODO: raise exception

In [3]:
nc_file_loc = '/p/scratch/exaww/chatterjee1/msg_warmworld/2024/'
output_file = "msgobs_108_randcrops_2024.nc"

months = {
    4:'04/',
    5:'05/',
    6:'06/',
    7:'07/',
    8:'08/',
    9:'09/',
}

In [20]:
nc_file_loc = '/p/scratch/exaww/chatterjee1/msg_warmworld/2024/'
output_file = "msgobs_108_randcrops_2024.nc"
log_file = nc_file_loc + "processed_files_log.txt"
nan_crop_file = nc_file_loc + "nan_files_log.txt"

months = {
    4: '04/',
    5: '05/',
    6: '06/',
    7: '07/',
    8: '08/',
    9: '09/',
}

sample_counter = 0  # Initialize a counter for unique sample naming

# Create an empty NetCDF file to start with
#xr.Dataset().to_netcdf(output_file, mode='w')

all_crops = []  # List to store all crops
all_lats = []   # List to store all latitude coordinates
all_lons = []   # List to store all longitude coordinates
all_times = []  # List to store all timestamps

first_write = True  # Flag to check if it's the first time writing to the file

for _, key in enumerate(months.keys()):
    loc = nc_file_loc + months[key]
    nc_filepattern = "HRSEVIRI_2024*_PC.nc"
    nc_files = sorted(glob.glob(loc + nc_filepattern))

    for i, file in enumerate(nc_files):

        # Log the name of the current file
        with open(log_file, 'a') as log:
            log.write(f"{file}\n")
        
        data = xr.open_dataset(file)
        satellite_name = data.EPCT_product_name.split('-')[0]
        timestamp = data.EPCT_product_name.split('A-')[1].split('.')[0]

        lat = data.lat[465:611].values  
        lon = data.lon[252:823].values  
        radiances = data["channel_9"][465:611, 252:823].values
        bt_data = radiances_2_brightnesstemp_and_reflectances(radiances, 9, satellite_name)

        y_dim, x_dim = bt_data.shape

        # Divide x dimension into 4 segments for spreading the crops
        x_segments = np.linspace(0, x_dim - 128, 5, dtype=int)

        for j in range(4):
            # Ensure crops are more spread in the x direction by choosing a segment
            start_y = np.random.randint(0, y_dim - 128)
            start_x = np.random.randint(x_segments[j], x_segments[j + 1])

            # Crop the data
            crop = bt_data[start_y:start_y + 128, start_x:start_x + 128]

            # Skip this crop if it contains any NaN values
            if np.isnan(crop).any():
                with open(nan_crop_file, 'a') as log:
                    log.write(f"{file} and crop no {j}\n")
                continue

            # Store the crop and the corresponding coordinates
            all_crops.append(crop)
            all_lats.append(lat[start_y:start_y + 128])
            all_lons.append(lon[start_x:start_x + 128])
            all_times.append(timestamp)

            # Increment the sample counter
            sample_counter += 1

            # Save every 100 samples
            if sample_counter % 100 == 0:
                # Convert lists to numpy arrays
                all_crops_np = np.array(all_crops)
                all_lats_np = np.array(all_lats)
                all_lons_np = np.array(all_lons)
                all_times_np = np.array(all_times)

                # Create a dataset with the combined data
                ds = xr.Dataset(
                    {
                        "sample_data": (["sample", "y", "x"], all_crops_np)  # Data variable
                    },
                    coords={
                        "sample": (["sample"], np.arange(sample_counter - len(all_crops_np) + 1, sample_counter + 1)),  # Sample numbers
                        "lat": (["sample", "y"], all_lats_np),
                        "lon": (["sample", "x"], all_lons_np),
                        "time": (["sample"], all_times_np)
                    }
                )

                # Write or append to the NetCDF file
                if first_write:
                    ds.to_netcdf(output_file, mode='w')
                    first_write = False
                else:
                    ds.to_netcdf(output_file, mode='a')

                # Clear the lists to free up memory
                all_crops.clear()
                all_lats.clear()
                all_lons.clear()
                all_times.clear()

# Save any remaining data after the loop ends
if all_crops:
    all_crops_np = np.array(all_crops)
    all_lats_np = np.array(all_lats)
    all_lons_np = np.array(all_lons)
    all_times_np = np.array(all_times)

    ds = xr.Dataset(
        {
            "sample_data": (["sample", "y", "x"], all_crops_np)  # Data variable
        },
        coords={
            "sample": (["sample"], np.arange(sample_counter - len(all_crops_np) + 1, sample_counter + 1)),  # Sample numbers
            "lat": (["sample", "y"], all_lats_np),
            "lon": (["sample", "x"], all_lons_np),
            "time": (["sample"], all_times_np)
        }
    )

    ds.to_netcdf(output_file, mode='a')

OSError: [Errno -101] NetCDF: HDF error: '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/msgobs_108_randcrops.nc'

In [15]:
i

2

In [8]:
file

'/p/scratch/exaww/chatterjee1/msg_netcdf/2023/04/HRSEVIRI_20230419T180009Z_20230419T181240Z_epct_03c6f782_PC.nc'

In [10]:
crop.shape

(128, 128)

In [11]:
ds

In [12]:
nc_file_loc + output_file

'/p/scratch/exaww/chatterjee1/msg_netcdf/2023/msgobs_108_randcrops.nc'

In [60]:
# 6 – 18°E, 47.5 – 53°N

In [49]:
idx = np.argmin(np.abs(detail.lat.values - 53))
idx

611

In [51]:
detail.lat.values[465], detail.lat.values[611]

(47.5, 52.98872180451127)

In [61]:
idx = np.argmin(np.abs(detail.lon.values - 27.45))
idx

823

In [62]:
detail.lon.values[252] , detail.lon.values[823]

(5.988517745302714, 27.44572025052192)

In [68]:
data

In [4]:
nc_file_loc = '/p/scratch/exaww/chatterjee1/msg_warmworld/2021/'
output_file = "/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2021.nc"
log_file = nc_file_loc + "processed_files_log.txt"
nan_crop_file = nc_file_loc + "nan_files_log.txt"

months = {
    4: '04/',
    5: '05/',
    6: '06/',
    7: '07/',
    8: '08/',
    9: '09/',
}

sample_counter = 0  # Initialize a counter for unique sample naming

# Create an empty NetCDF file to start with
#xr.Dataset().to_netcdf(output_file, mode='w')

all_crops = []  # List to store all crops
all_lats = []   # List to store all latitude coordinates
all_lons = []   # List to store all longitude coordinates
all_times = []  # List to store all timestamps

first_write = True  # Flag to check if it's the first time writing to the file

for _, key in enumerate(months.keys()):
    loc = nc_file_loc + months[key]
    nc_filepattern = "HRSEVIRI_2021*_PC.nc"
    nc_files = sorted(glob.glob(loc + nc_filepattern))

    for i, file in enumerate(nc_files):

        # Log the name of the current file
        with open(log_file, 'a') as log:
            log.write(f"{file}\n")
        
        data = xr.open_dataset(file)
        satellite_name = data.EPCT_product_name.split('-')[0]
        timestamp = data.EPCT_product_name.split('A-')[1].split('.')[0]

        lat = data.lat[465:611].values  
        lon = data.lon[252:823].values  
        radiances = data["channel_9"][465:611, 252:823].values
        bt_data = radiances_2_brightnesstemp_and_reflectances(radiances, 9, satellite_name)

        y_dim, x_dim = bt_data.shape

        # Divide x dimension into 4 segments for spreading the crops
        x_segments = np.linspace(0, x_dim - 128, 5, dtype=int)

        for j in range(4):
            # Ensure crops are more spread in the x direction by choosing a segment
            start_y = np.random.randint(0, y_dim - 128)
            start_x = np.random.randint(x_segments[j], x_segments[j + 1])

            # Crop the data
            crop = bt_data[start_y:start_y + 128, start_x:start_x + 128]

            # Skip this crop if it contains any NaN values
            if np.isnan(crop).any():
                with open(nan_crop_file, 'a') as log:
                    log.write(f"{file} and crop no {j}\n")
                continue

            # Store the crop and the corresponding coordinates
            all_crops.append(crop)
            all_lats.append(lat[start_y:start_y + 128])
            all_lons.append(lon[start_x:start_x + 128])
            all_times.append(timestamp)

            # Increment the sample counter
            sample_counter += 1

            all_crops_np = np.array(all_crops)
            all_lats_np = np.array(all_lats)
            all_lons_np = np.array(all_lons)
            all_times_np = np.array(all_times)

        # Create a dataset with the combined data
ds = xr.Dataset(
    {
        "sample_data": (["sample", "y", "x"], all_crops_np)  # Data variable
    },
    coords={
        "sample": (["sample"], np.arange(len(all_crops_np))),  # Sample numbers
        "lat": (["sample", "y"], all_lats_np),
        "lon": (["sample", "x"], all_lons_np),
        "time": (["sample"], all_times_np)
    }
)

# Write or append to the NetCDF file
if first_write:
    ds.to_netcdf(output_file, mode='w')
    first_write = False




In [5]:
all_crops_np.shape

(70284, 128, 128)

In [8]:
file = '/p/project/exaww/chatterjee1/dataset/warmworld_datasets/msgobs_108_randcrops_2020.nc'
data = xr.open_dataset(file)
data

In [5]:
data.sample_data.shape

(74268, 128, 128)

In [15]:
data.time[0].shape,data.time[0].values, data.time[1].dtype

((), array('20230401001241', dtype='<U14'), dtype('O'))

In [14]:
data.lat[0].shape,data.lat[0].values,data.lat[0].dtype

((128,),
 array([47.68796992, 47.72556391, 47.76315789, 47.80075188, 47.83834586,
        47.87593985, 47.91353383, 47.95112782, 47.9887218 , 48.02631579,
        48.06390977, 48.10150376, 48.13909774, 48.17669173, 48.21428571,
        48.2518797 , 48.28947368, 48.32706767, 48.36466165, 48.40225564,
        48.43984962, 48.47744361, 48.51503759, 48.55263158, 48.59022556,
        48.62781955, 48.66541353, 48.70300752, 48.7406015 , 48.77819549,
        48.81578947, 48.85338346, 48.89097744, 48.92857143, 48.96616541,
        49.0037594 , 49.04135338, 49.07894737, 49.11654135, 49.15413534,
        49.19172932, 49.22932331, 49.26691729, 49.30451128, 49.34210526,
        49.37969925, 49.41729323, 49.45488722, 49.4924812 , 49.53007519,
        49.56766917, 49.60526316, 49.64285714, 49.68045113, 49.71804511,
        49.7556391 , 49.79323308, 49.83082707, 49.86842105, 49.90601504,
        49.94360902, 49.98120301, 50.01879699, 50.05639098, 50.09398496,
        50.13157895, 50.16917293, 50.20676

In [13]:
data.lon[0].shape,data.lon[0].values,data.lon[0].dtype

((128,),
 array([ 7.0782881 ,  7.11586639,  7.15344468,  7.19102296,  7.22860125,
         7.26617954,  7.30375783,  7.34133612,  7.37891441,  7.41649269,
         7.45407098,  7.49164927,  7.52922756,  7.56680585,  7.60438413,
         7.64196242,  7.67954071,  7.717119  ,  7.75469729,  7.79227557,
         7.82985386,  7.86743215,  7.90501044,  7.94258873,  7.98016701,
         8.0177453 ,  8.05532359,  8.09290188,  8.13048017,  8.16805846,
         8.20563674,  8.24321503,  8.28079332,  8.31837161,  8.3559499 ,
         8.39352818,  8.43110647,  8.46868476,  8.50626305,  8.54384134,
         8.58141962,  8.61899791,  8.6565762 ,  8.69415449,  8.73173278,
         8.76931106,  8.80688935,  8.84446764,  8.88204593,  8.91962422,
         8.95720251,  8.99478079,  9.03235908,  9.06993737,  9.10751566,
         9.14509395,  9.18267223,  9.22025052,  9.25782881,  9.2954071 ,
         9.33298539,  9.37056367,  9.40814196,  9.44572025,  9.48329854,
         9.52087683,  9.55845511,  9.59603

In [17]:
data.sample_data[0].shape,data.sample_data[0].values,data.sample_data[0].dtype

((128, 128),
 array([[262.9626 , 263.492  , 262.60773, ..., 245.39163, 241.7357 ,
         241.7357 ],
        [262.9626 , 263.492  , 262.60773, ..., 242.89915, 240.3127 ,
         239.34691],
        [264.19278, 264.88785, 263.8431 , ..., 241.50061, 240.55196,
         239.34691],
        ...,
        [251.98094, 252.99074, 254.77533, ..., 246.49591, 245.1687 ,
         244.94505],
        [251.98094, 252.99074, 252.99074, ..., 245.39163, 247.36717,
         247.36717],
        [253.19109, 254.38232, 254.38232, ..., 250.12828, 250.12828,
         247.36717]], dtype=float32),
 dtype('float32'))

In [18]:
data.data_vars

Data variables:
    sample_data  (sample, y, x) float32 ...

In [4]:
def convert_nc_to_hdf5(nc_file, hdf5_file):
    # Open the NetCDF file using xarray
    ds = xr.open_dataset(nc_file)

    # Create the HDF5 file
    with h5py.File(hdf5_file, 'w') as hdf5_data:
        # Iterate over all variables in the xarray dataset
        for var_name in ds.data_vars:
            var_data = ds[var_name].values
            
            # Create a dataset in the HDF5 file
            hdf5_dataset = hdf5_data.create_dataset(
                var_name, 
                data=var_data, 
                dtype=var_data.dtype, 
                #chunks=chunking,  # Enable chunking if requested
                #compression=compression  # Apply compression if provided
            )

            # Copy variable attributes of the variable to the HDF5 dataset
            for attr_name, attr_value in ds[var_name].attrs.items():
                hdf5_dataset.attrs[attr_name] = attr_value
        
        # Iterate over all coordinates in the xarray dataset
        for coord_name in ds.coords:
            coord_data = ds[coord_name].values
            
            # Handle special case for time coordinate with dtype('O')
            if coord_data.dtype == 'O':
                # Convert to fixed-length strings
                coord_data = coord_data.astype('S')
            
            # Create a dataset in the HDF5 file for the coordinate
            hdf5_coord = hdf5_data.create_dataset(
                coord_name, 
                data=coord_data, 
                dtype=coord_data.dtype, 
                #chunks=chunking,  # Enable chunking if requested
                #compression=compression  # Apply compression if provided
            )
            
            # Copy coordinate attributes to the HDF5 dataset
            for attr_name, attr_value in ds[coord_name].attrs.items():
                hdf5_coord.attrs[attr_name] = attr_value
        
        # Copy global attributes
        for attr_name, attr_value in ds.attrs.items():
            hdf5_data.attrs[attr_name] = attr_value
    # Close the xarray dataset
    ds.close()


nc_file = '/p/project1/deepacf/kiste/DC/dataset/msgobs_108_randcrops.nc'
hdf5_file = '/p/project1/deepacf/kiste/DC/dataset/msgobs_108_randcrops.h5'
convert_nc_to_hdf5(
    nc_file, 
    hdf5_file, 
    #chunking=False, 
    #compression=None
)
print(f"Converted {nc_file} to {hdf5_file}")

Converted /p/project1/deepacf/kiste/DC/dataset/msgobs_108_randcrops.nc to /p/project1/deepacf/kiste/DC/dataset/msgobs_108_randcrops.h5


In [6]:
ds = xr.open_dataset(nc_file)
ds

In [7]:
# Function to calculate mean and standard deviation using chunks
def compute_mean_std_chunked(data, chunk_size=1000):
    n_samples, height, width = data.shape
    n_elements = n_samples * height * width

    # Initialize mean and variance sums
    total_sum = 0.0
    total_square_sum = 0.0

    # Process in chunks
    for i in range(0, n_samples, chunk_size):
        chunk = data[i:i+chunk_size]

        # Update sums
        total_sum += np.sum(chunk)
        total_square_sum += np.sum(chunk ** 2)

    # Calculate mean and variance
    mean = total_sum / n_elements
    variance = (total_square_sum / n_elements) - (mean ** 2)
    std = np.sqrt(variance)

    return mean, std

# Simulate smaller sample_data for demonstration purposes
small_sample_data = ds.sample_data

# Calculate the mean and standard deviation using chunks
mean_sample_data_chunked, std_sample_data_chunked = compute_mean_std_chunked(small_sample_data)

mean_sample_data_chunked, std_sample_data_chunked


(<xarray.DataArray 'sample_data' ()>
 array(270.52612414),
 <xarray.DataArray 'sample_data' ()>
 array(17.90132273))

In [8]:
mean_sample_data_chunked, std_sample_data_chunked = compute_mean_std_chunked(small_sample_data)

mean_sample_data_chunked, std_sample_data_chunked

(<xarray.DataArray 'sample_data' ()>
 array(270.52612414),
 <xarray.DataArray 'sample_data' ()>
 array(17.90132273))

# domain with country contours

In [66]:
nc_file_loc = '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/'


months = {
    #4: '04/',
    #5: '05/',
    #6: '06/',
    #7: '07/',
    #8: '08/',
    9: '09/',
}


for _, key in enumerate(months.keys()):
    loc = nc_file_loc + months[key]
    nc_filepattern = "HRSEVIRI_2023*_PC.nc"
    nc_files = sorted(glob.glob(loc + nc_filepattern))


In [15]:
nc_files[0:25]

['/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T000009Z_20230901T001241Z_epct_c154a552_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T001510Z_20230901T002742Z_epct_a48198b9_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T003010Z_20230901T004242Z_epct_23ef759b_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T004510Z_20230901T005742Z_epct_1c3c8570_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T010010Z_20230901T011242Z_epct_d83a4484_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T011509Z_20230901T012742Z_epct_34660b62_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T013009Z_20230901T014241Z_epct_ab143892_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T014509Z_20230901T015741Z_epct_99d35475_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T020010Z_2023

In [67]:
for file in nc_files[0:96:4]:
    data = xr.open_dataset(file)
    satellite_name = data.EPCT_product_name.split('-')[0]
    timestamp = data.EPCT_product_name.split('A-')[1].split('.')[0]

    # Your data and grid setup
    long = data['lon'][252:823]
    lat = data['lat'][465:611]
    bt_data = data['channel_9'][465:611, 252:823]  # shape: (146, 571)
    bt_data = radiances_2_brightnesstemp_and_reflectances(bt_data, 9, satellite_name)

    # Create meshgrid for lon/lat
    lon_grid, lat_grid = np.meshgrid(long, lat)

    # Calculate figure size to achieve 146x571 pixels at dpi=100
    fig = plt.figure(figsize=(5.71, 1.46))  # Width=5.71 inches, height=1.46 inches

    # Set up the plot with PlateCarree projection
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle='-', alpha=0.7)

    # Plot data with 'nearest' shading
    plt.pcolormesh(lon_grid, lat_grid, bt_data, cmap='viridis', shading='nearest', transform=ccrs.PlateCarree(), vmin=220, vmax=300)

    # Optional: Add colorbar, labels
    #plt.colorbar(label='Brightness Temperature')
    #plt.title('OBS:_' + timestamp [0:12] + '_146 x 571 pixels')
    plt.xlabel('Longitude')
    plt.ylabel('Latitude')

    # Save the plot with dpi=100 and no extra whitespace
    save_path = '/p/project1/exaww/chatterjee1/plots/vis_obs_data/'
    plt.savefig(save_path +timestamp[6:8] + '-' + timestamp[4:6] + ':' + timestamp[8:10] + '.png', dpi=100, bbox_inches='tight', pad_inches=0)
    plt.close()

In [64]:
data

In [40]:
nc_files[0:96:4]

['/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T000009Z_20230901T001241Z_epct_c154a552_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T010010Z_20230901T011242Z_epct_d83a4484_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T020010Z_20230901T021242Z_epct_190db69c_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T030010Z_20230901T031242Z_epct_d4065cd7_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T040009Z_20230901T041241Z_epct_71861e1d_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T050009Z_20230901T051242Z_epct_4cb15bc0_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T060008Z_20230901T061241Z_epct_f59e8043_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T070009Z_20230901T071241Z_epct_f4fe47bf_PC.nc',
 '/p/scratch/exaww/chatterjee1/msg_netcdf/2023/09/HRSEVIRI_20230901T080009Z_2023

# 2D fft of spatial distributions

In [58]:
for file in nc_files[0:96:4]: 
    data = xr.open_dataset(file)
    satellite_name = data.EPCT_product_name.split('-')[0]
    timestamp = data.EPCT_product_name.split('A-')[1].split('.')[0]

    long = data['lon'][252:823]
    lat = data['lat'][465:611]
    bt_data = data['channel_9'][465:611, 252:823]  # shape: (146, 571)
    bt_data = radiances_2_brightnesstemp_and_reflectances(bt_data, 9, satellite_name)

    if np.isnan(bt_data).any():
          pass# Fill NaNs with zeros or any other method if necessary
    else:
        # Step 2: Perform 2D FFT
        fft_result = np.fft.fft2(bt_data)

        # Step 3: Shift zero frequency component to the center
        fft_shifted = np.fft.fftshift(fft_result)

        # Step 4: Calculate the magnitude (absolute value) of the FFT
        magnitude_spectrum = np.abs(fft_shifted)

        # Optionally, use a logarithmic scale for better visualization
        log_magnitude_spectrum = np.log1p(magnitude_spectrum)  # log1p for log(1 + x) to avoid log(0) issues

        # Plot the magnitude spectrum
        plt.figure(figsize=(6, 6))
        im = plt.imshow(log_magnitude_spectrum, cmap='viridis')
        plt.colorbar(im, label='Log Magnitude Spectrum',fraction=0.016, pad=0.04)
        plt.title('2D FFT OBS ' + timestamp[6:8] + '-' + timestamp[4:6] + ':' + timestamp[8:10] )
        plt.xlabel('Frequency X')
        plt.ylabel('Frequency Y')
        save_path = '/p/project1/exaww/chatterjee1/plots/vis_obs_data/fft/'
        plt.savefig(save_path +timestamp[6:8] + '-' + timestamp[4:6] + ':' + timestamp[8:10] + '_fft.png', dpi=100, bbox_inches='tight', pad_inches=0)  # dpi=300 for high resolution, bbox_inches='tight' to trim whitespace
        plt.close()

In [53]:
timestamp[4:6],timestamp[6:8],timestamp[8:10]

('09', '01', '23')

# 1D FFT

In [62]:
def radial_profile(data):
    """Calculate the radial profile of a 2D FFT array."""
    # Get the center of the data
    center_y, center_x = np.array(data.shape) // 2
    
    # Create a meshgrid of distances from the center
    y, x = np.indices(data.shape)
    r = np.sqrt((x - center_x)**2 + (y - center_y)**2)
    r = r.astype(np.int)
    
    # Calculate the mean value at each radial distance
    tbin = np.bincount(r.ravel(), data.ravel())
    nr = np.bincount(r.ravel())
    radial_profile = tbin / nr
    return radial_profile

# List to store each radial profile
radial_profiles = []

# Loop through each file to process the FFT and calculate the radial profile
for file in nc_files[0:96:4]: 
    data = xr.open_dataset(file)
    satellite_name = data.EPCT_product_name.split('-')[0]
    timestamp = data.EPCT_product_name.split('A-')[1].split('.')[0]

    long = data['lon'][252:823]
    lat = data['lat'][465:611]
    bt_data = data['channel_9'][465:611, 252:823]  # shape: (146, 571)
    bt_data = radiances_2_brightnesstemp_and_reflectances(bt_data, 9, satellite_name)

    if not np.isnan(bt_data).any():
        # Step 2: Perform 2D FFT
        fft_result = np.fft.fft2(bt_data)

        # Step 3: Shift zero frequency component to the center
        fft_shifted = np.fft.fftshift(fft_result)

        # Step 4: Calculate the magnitude (absolute value) of the FFT
        magnitude_spectrum = np.abs(fft_shifted)

        # Step 5: Calculate the log of the magnitude spectrum
        log_magnitude_spectrum = np.log1p(magnitude_spectrum)  # log(1 + x) to avoid log(0) issues

        # Calculate the radial profile
        radial_profile_1d = radial_profile(log_magnitude_spectrum)
        
        # Store the radial profile with the timestamp for plotting
        radial_profiles.append((timestamp, radial_profile_1d))

# Plot all radial profiles in a single plot
plt.figure(figsize=(10, 6))
for timestamp, radial_profile_1d in radial_profiles:
    plt.plot(radial_profile_1d, label=f'OBS {timestamp[6:8]}-{timestamp[4:6]}:{timestamp[8:10]}')

plt.xlabel('Radial Frequency')
plt.ylabel('Log Magnitude')
plt.title('Obs: FFT for Each BT Distribution')
plt.legend()
save_path = '/p/project1/exaww/chatterjee1/plots/vis_obs_data/fft/'
plt.savefig(save_path +timestamp[6:8] + '-' + timestamp[4:6] + '_fft.png', dpi=100, bbox_inches='tight', pad_inches=0)
plt.close()

In [70]:
len(radial_profiles)

24

In [81]:
r_prof = np.load('/p/project1/exaww/chatterjee1/plots/vis_model_data/fft/radprof_icon.npy',allow_pickle=True)
plt.figure(figsize=(10, 6))
for timestamp, radial_profile_1d in r_prof:
    plt.plot(radial_profile_1d)
    
plt.text(180, 3.8, 'ICON', fontsize = 22)
    
for timestamp, radial_profile_1d in radial_profiles:
    plt.plot(radial_profile_1d)    

plt.text(270, 6.2, 'MSG', fontsize = 22)

plt.xlabel('Radial Frequency')
plt.ylabel('Log Magnitude')
plt.title('Obs: FFT for Each BT Distribution')
#plt.legend()
save_path = '/p/project1/exaww/chatterjee1/plots/vis_obs_data/fft/'
plt.savefig(save_path + '_fft_icon_trial.png', dpi=100, bbox_inches='tight', pad_inches=0)
plt.close()