In [2]:
import xarray as xr
import numpy as np
import glob
import os
import h5py
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import warnings
warnings.filterwarnings("ignore")

In [3]:
# Radiation constants and per-satellite SEVIRI calibration tables used to convert
# radiances to brightness temperatures (IR channels 4-11) and reflectances
# (VIS/NIR channels 1-3).
C1 = 1.19104*10**(-5)  # in [mW m-2 sr-1 (cm-1)-4]
C2 = 1.43877  # in [K cm]

# Human-readable SEVIRI channel names, keyed like the variables in the files.
CHANNEL_NAME = {"channel_1": "VIS 0.6",
                "channel_2": "VIS 0.8",
                "channel_3": "NIR 1.6",
                "channel_4": "IR 3.9",
                "channel_5": "WV 6.2",
                "channel_6": "WV 7.3",
                "channel_7": "IR 8.7",
                "channel_8": "IR 9.7 - O3",
                "channel_9": "IR 10.8",
                "channel_10": "IR 12.0",
                "channel_11": "IR 13.4 - CO2", }

# Central wavenumber of each IR channel, in [cm-1]
VC = {'MSG1': {"channel_4": 2567.330, "channel_5": 1598.103, "channel_6": 1362.081, "channel_7": 1149.069,
               "channel_8": 1034.343, "channel_9": 930.647, "channel_10": 839.660, "channel_11": 752.387},
      'MSG2': {"channel_4": 2568.832, "channel_5": 1600.548, "channel_6": 1360.330, "channel_7": 1148.620,
               "channel_8": 1035.289, "channel_9": 931.700, "channel_10": 836.445, "channel_11": 751.792},
      'MSG3': {"channel_4": 2547.771, "channel_5": 1595.621, "channel_6": 1360.377, "channel_7": 1148.130,
               "channel_8": 1034.715, "channel_9": 929.842, "channel_10": 838.659, "channel_11": 750.653},
      'MSG4': {"channel_4": 2555.280, "channel_5": 1596.080, "channel_6": 1361.748, "channel_7": 1147.433,
               "channel_8": 1034.851, "channel_9": 931.122, "channel_10": 839.113, "channel_11": 748.585}, }

# Band-correction coefficient alpha of each IR channel, unitless
ALPHA = {'MSG1': {"channel_4": 0.9956, "channel_5": 0.9962, "channel_6": 0.9991, "channel_7": 0.9996,
                  "channel_8": 0.9999, "channel_9": 0.9983, "channel_10": 0.9988, "channel_11": 0.9981},
         'MSG2': {"channel_4": 0.9954, "channel_5": 0.9963, "channel_6": 0.9991, "channel_7": 0.9996,
                  "channel_8": 0.9999, "channel_9": 0.9983, "channel_10": 0.9988, "channel_11": 0.9981},
         'MSG3': {"channel_4": 0.9915, "channel_5": 0.9960, "channel_6": 0.9991, "channel_7": 0.9996,
                  "channel_8": 0.9999, "channel_9": 0.9983, "channel_10": 0.9988, "channel_11": 0.9982},
         'MSG4': {"channel_4": 0.9916, "channel_5": 0.9959, "channel_6": 0.9990, "channel_7": 0.9996,
                  "channel_8": 0.9998, "channel_9": 0.9983, "channel_10": 0.9988, "channel_11": 0.9981}, }

# Band-correction coefficient beta of each IR channel, in [K]
# (a stray '' that relied on implicit string concatenation was removed from the MSG1 row)
BETA = {'MSG1': {"channel_4": 3.410, "channel_5": 2.218, "channel_6": 0.478, "channel_7": 0.179,
                 "channel_8": 0.060, "channel_9": 0.625, "channel_10": 0.397, "channel_11": 0.578},
        'MSG2': {"channel_4": 3.438, "channel_5": 2.185, "channel_6": 0.470, "channel_7": 0.179,
                 "channel_8": 0.056, "channel_9": 0.640, "channel_10": 0.408, "channel_11": 0.561},
        'MSG3': {"channel_4": 2.9002, "channel_5": 2.0337, "channel_6": 0.4340, "channel_7": 0.1714,
                 "channel_8": 0.0527, "channel_9": 0.6084, "channel_10": 0.3882, "channel_11": 0.5390},
        'MSG4': {"channel_4": 2.9438, "channel_5": 2.0780, "channel_6": 0.4929, "channel_7": 0.1731,
                 "channel_8": 0.0597, "channel_9": 0.6256, "channel_10": 0.4002, "channel_11": 0.5635}, }

# %%
#############
############# look up tables for calculating reflectances
#############
# constants taken from website:
# https://eumetsatspace.atlassian.net/wiki/spaces/DSDT/pages/1537277953/MSG15+radiances+conversion+to+BT+and+Reflectances
# and from https://www-cdn.eumetsat.int/files/2020-04/pdf_msg_seviri_rad2refl.pdf

# Band solar irradiance at 1 AU of each VIS/NIR channel, in [mW m-2 (cm-1)-1]
IRRAD = {'MSG1': {"channel_1": 65.2296, "channel_2": 73.0127, "channel_3": 62.3715},
         'MSG2': {"channel_1": 65.2065, "channel_2": 73.1869, "channel_3": 61.9923},
         'MSG3': {"channel_1": 65.5148, "channel_2": 73.1807, "channel_3": 62.0208},
         'MSG4': {"channel_1": 65.2656, "channel_2": 73.1692, "channel_3": 61.9416}, }


# %%
class ir_channel:
    """
    Calibration constants of one SEVIRI infrared channel on one MSG satellite.

    Attributes
    ----------
    name : str
        Display name of the channel, from CHANNEL_NAME.
    vc : float
        Central wavenumber in [cm-1], from VC.
    alpha : float
        Unitless band-correction coefficient, from ALPHA.
    beta : float
        Band-correction offset in [K], from BETA.
    """
    def __init__(self, satellite, channel):
        self.name = CHANNEL_NAME[channel]
        # the three lookups are evaluated in the order VC, ALPHA, BETA
        self.vc, self.alpha, self.beta = (
            table[satellite][channel] for table in (VC, ALPHA, BETA)
        )

class vis_nir_channel:
    """
    Constants of one SEVIRI visible / near-infrared channel on one MSG satellite.

    Attributes
    ----------
    name : str
        Display name of the channel, from CHANNEL_NAME.
    irrad : float
        Band solar irradiance at 1 AU in [mW m-2 (cm-1)-1], from IRRAD.
    """
    def __init__(self, satellite, channel):
        self.name = CHANNEL_NAME[channel]
        irradiances = IRRAD[satellite]
        self.irrad = irradiances[channel]

class MSG_satellite:
    """
    Converts SEVIRI radiances of one MSG satellite into physical quantities.

    Parameters
    ----------
    name : str
        Satellite key into the calibration tables ('MSG1' ... 'MSG4').
    """

    # channel numbers accepted by each conversion
    _VIS_NIR_CHANNELS = range(1, 4)  # channels 1-3 -> reflectance
    _IR_CHANNELS = range(4, 12)      # channels 4-11 -> brightness temperature

    def __init__(self, name):
        self.name = name

    def _get_channel(self, channel_number):
        """Return the vis/nir (channels <= 3) or ir constants object for ``channel_number``."""
        channel = f"channel_{channel_number}"
        if channel_number <= 3:
            return vis_nir_channel(satellite=self.name, channel=channel)
        return ir_channel(satellite=self.name, channel=channel)

    def rad_2_tb(self, channel_number, radiances):
        """
        Convert IR radiances to brightness temperature [K] (simplified equation).

        Tb = C2 * vc / (alpha * ln(C1 * vc**3 / R + 1)) - beta / alpha

        Parameters
        ----------
        channel_number : int
            IR channel number, 4-11.
        radiances : float, numpy array or xarray.DataArray
            Radiances in [mW m-2 sr-1 (cm-1)-1].

        Raises
        ------
        ValueError
            If ``channel_number`` is not an IR channel (4-11).
        """
        if channel_number not in self._IR_CHANNELS:
            raise ValueError(
                f"rad_2_tb needs an IR channel (4-11), got channel {channel_number}")

        consts = self._get_channel(channel_number)

        numerator = C2 * consts.vc
        fraction = C1 * consts.vc**3 / radiances + 1
        denominator = consts.alpha * np.log(fraction)
        return numerator / denominator - consts.beta / consts.alpha  # [K]

    def _d(self, t):
        """Sun-Earth distance in AU at time ``t`` (not implemented yet)."""
        raise NotImplementedError("Sun-Earth distance is not implemented yet")

    def _solar_zenith_angle(self, t, lon, lat):
        """Solar zenith angle in radians at time ``t`` and location (lon, lat) (not implemented yet)."""
        raise NotImplementedError("solar zenith angle is not implemented yet")

    def rad_2_refl(self, channel_number, radiances, t, lon, lat):
        """
        Convert VIS/NIR radiances to reflectance.

        refl = pi * R * d(t)**2 / (irrad * cos(sza(t, lon, lat)))

        Raises
        ------
        ValueError
            If ``channel_number`` is not a VIS/NIR channel (1-3).
        NotImplementedError
            Until ``_d`` and ``_solar_zenith_angle`` are implemented.
        """
        if channel_number not in self._VIS_NIR_CHANNELS:
            raise ValueError(
                f"rad_2_refl needs a VIS/NIR channel (1-3), got channel {channel_number}")

        consts = self._get_channel(channel_number)

        numerator = np.pi * radiances * self._d(t)**2
        denominator = consts.irrad * np.cos(self._solar_zenith_angle(t, lon, lat))
        return numerator / denominator

# %%
def radiances_2_brightnesstemp_and_reflectances(radiances, channel_number, satellite_name):
    """
    Convert SEVIRI radiances of an IR channel to brightness temperature.

    Parameters
    ----------
    radiances : float, numpy array or xarray.DataArray
        Radiances in [mW m-2 sr-1 (cm-1)-1].
    channel_number : int
        SEVIRI channel number. Only the IR channels 4-11 are supported so far.
    satellite_name : str
        Satellite key into the calibration tables ('MSG1' ... 'MSG4').

    Returns
    -------
    Brightness temperature in [K], computed element-wise from ``radiances``.

    Raises
    ------
    NotImplementedError
        For the visible / near-infrared channels 1-3.
    ValueError
        For channel numbers that do not exist (outside 1-11).
    """
    # Fail loudly instead of printing and returning None, which only broke later.
    if 1 <= channel_number <= 3:
        raise NotImplementedError("not implemented yet for visible and near-infrared")
    if not 4 <= channel_number <= 11:
        raise ValueError(f"This channel does not exist for satellite {satellite_name}")

    # access correct satellite and get brightness temp for the given channel
    satellite = MSG_satellite(satellite_name)
    return satellite.rad_2_tb(channel_number, radiances)

## Model and observation plotted side by side

In [4]:
# Side-by-side view of the SEVIRI IR 10.8 observation and the ICON-EU synthetic
# IR 10.8 brightness temperature, with the model subset to the observation domain.
nc_file_loc = '/p/scratch/exaww/chatterjee1/msg_warmworld/files/'
save_path = '/p/project1/exaww/chatterjee1/plots/'
SAVE_FIG = False         # set True to also write the figure to save_path
DRAW_COASTLINES = False  # set True to overlay coastlines and borders

fig, axes = plt.subplots(1, 2, figsize=(14, 6), subplot_kw={'projection': ccrs.PlateCarree()}, constrained_layout=True)

# --- Observation ---
data_obs = xr.open_dataset(nc_file_loc + '09/HRSEVIRI_20230901T000009Z_20230901T001241Z_epct_c154a552_PC.nc')
satellite_name = data_obs.EPCT_product_name.split('-')[0]
timestamp = data_obs.EPCT_product_name.split('A-')[1].split('.')[0]
long_obs = data_obs['lon']
lat_obs = data_obs['lat']
# channel_9 (IR 10.8) radiances -> brightness temperature [K]
bt_data_obs = radiances_2_brightnesstemp_and_reflectances(data_obs['channel_9'], 9, satellite_name)
lon_grid_obs, lat_grid_obs = np.meshgrid(long_obs, lat_obs)

ax = axes[0]
ax.set_title("Observation")
if DRAW_COASTLINES:
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle='-', alpha=0.7)

im1 = ax.pcolormesh(lon_grid_obs, lat_grid_obs, bt_data_obs, cmap='Greys',
                    shading='nearest', transform=ccrs.PlateCarree(), vmin=220, vmax=300)

# --- Model ---
data_icon = xr.open_dataset('/p/project1/exaww/chatterjee1/dataset/iconeu/09/01/iefrf00000000.nc')
date = data_icon.time.dt.strftime('%Y-%m-%d').values
timestamp_icon = data_icon.time.dt.strftime('%H:%M:%S').values
long = data_icon['lon']
lat = data_icon['lat']

# OBS domain bounds
lat_min = float(lat_obs.min().values)
lat_max = float(lat_obs.max().values)
lon_min = float(long_obs.min().values)
lon_max = float(long_obs.max().values)

# Subset ICON (first time step) to the OBS domain.
# NOTE(review): slice(min, max) label selection assumes ICON lat/lon are ascending — confirm.
bt_data_icon = data_icon['SYNMSG_BT_CL_IR10.8'].isel(time=0).sel(
    lat=slice(lat_min, lat_max),
    lon=slice(lon_min, lon_max)
)
lat_icon = bt_data_icon['lat']
lon_icon = bt_data_icon['lon']
lon_grid_icon, lat_grid_icon = np.meshgrid(lon_icon, lat_icon)

ax = axes[1]
ax.set_title("Model")
if DRAW_COASTLINES:
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle='-', alpha=0.7)

im2 = ax.pcolormesh(lon_grid_icon, lat_grid_icon, bt_data_icon, cmap='Greys',
                    shading='nearest', transform=ccrs.PlateCarree(), vmin=220, vmax=300)

# One shared horizontal colorbar (both panels use the same vmin/vmax)
cbar = fig.colorbar(im2, ax=axes.ravel().tolist(), orientation='horizontal', fraction=0.05, pad=0.07)
cbar.set_label("Brightness Temperature (K)")

if SAVE_FIG:
    fig.savefig(save_path + 'both_domains_equal_' + '.png', dpi=100, bbox_inches='tight', pad_inches=0)
# Show rather than close: closing the figure before the cell ends means it is never displayed.
plt.show()

In [5]:
# Extent of the observation grid: (lon min, lat min, lon max, lat max)
long_obs.min(), lat_obs.min(), long_obs.max(), lat_obs.max()

(<xarray.DataArray 'lon' ()>
 array(-3.48121086),
 <xarray.DataArray 'lat' ()>
 array(30.01879699),
 <xarray.DataArray 'lon' ()>
 array(32.48121086),
 <xarray.DataArray 'lat' ()>
 array(54.98120301))

In [6]:
# Extent of the full ICON grid (`long`/`lat` hold the ICON coordinates): (lon min, lat min, lon max, lat max)
long.min(), lat.min(), long.max(), lat.max()

(<xarray.DataArray 'lon' ()>
 array(-23.5),
 <xarray.DataArray 'lat' ()>
 array(29.5),
 <xarray.DataArray 'lon' ()>
 array(62.5),
 <xarray.DataArray 'lat' ()>
 array(70.5))

## Plot the original domain used for random cropping

In [12]:
# Same OBS/ICON comparison as above, with coastlines and borders, and with the
# sub-domain of the original random-crop experiment filled with 300 K in both
# panels. The figure is saved to save_path.
nc_file_loc = '/p/scratch/exaww/chatterjee1/msg_warmworld/files/'
save_path = '/p/project1/exaww/chatterjee1/plots/'

# (row, column) index boxes of the original experiment domain on each native grid.
# NOTE(review): the two grids differ, so these ranges are grid-specific — confirm
# that they outline the same geographic region.
OBS_EXP_BOX = (slice(465, 611), slice(252, 823))
ICON_EXP_BOX = (slice(288, 377), slice(472, 816))
EXP_BOX_VALUE = 300  # [K]; equals vmax, so the box is drawn at the dark end of 'Greys'

fig, axes = plt.subplots(1, 2, figsize=(14, 6), subplot_kw={'projection': ccrs.PlateCarree()}, constrained_layout=True)

# --- Observation ---
data_obs = xr.open_dataset(nc_file_loc + '09/HRSEVIRI_20230901T000009Z_20230901T001241Z_epct_c154a552_PC.nc')
satellite_name = data_obs.EPCT_product_name.split('-')[0]
timestamp = data_obs.EPCT_product_name.split('A-')[1].split('.')[0]
long_obs = data_obs['lon']
lat_obs = data_obs['lat']
# bt_data_obs is freshly computed from the radiances, so marking it leaves data_obs untouched
bt_data_obs = radiances_2_brightnesstemp_and_reflectances(data_obs['channel_9'], 9, satellite_name)
bt_data_obs[OBS_EXP_BOX] = EXP_BOX_VALUE
lon_grid_obs, lat_grid_obs = np.meshgrid(long_obs, lat_obs)

ax = axes[0]
ax.set_title("Observation")
ax.coastlines()
ax.add_feature(cfeature.BORDERS, linestyle='-', alpha=0.7)

im1 = ax.pcolormesh(lon_grid_obs, lat_grid_obs, bt_data_obs, cmap='Greys',
                    shading='nearest', transform=ccrs.PlateCarree(), vmin=220, vmax=300)

# --- Model ---
data_icon = xr.open_dataset('/p/project1/exaww/chatterjee1/dataset/iconeu/09/01/iefrf00000000.nc')
date = data_icon.time.dt.strftime('%Y-%m-%d').values
timestamp_icon = data_icon.time.dt.strftime('%H:%M:%S').values
long = data_icon['lon']
lat = data_icon['lat']

# OBS domain bounds
lat_min = float(lat_obs.min().values)
lat_max = float(lat_obs.max().values)
lon_min = float(long_obs.min().values)
lon_max = float(long_obs.max().values)

# Deep-copy the first time step before marking. Writing into
# data_icon['SYNMSG_BT_CL_IR10.8'] directly would modify data_icon's own variable.
bt_data_icon = data_icon['SYNMSG_BT_CL_IR10.8'].isel(time=0).copy(deep=True)
bt_data_icon[ICON_EXP_BOX] = EXP_BOX_VALUE
# Subset ICON to the OBS domain.
# NOTE(review): slice(min, max) label selection assumes ICON lat/lon are ascending — confirm.
bt_data_icon = bt_data_icon.sel(
    lat=slice(lat_min, lat_max),
    lon=slice(lon_min, lon_max)
)
lat_icon = bt_data_icon['lat']
lon_icon = bt_data_icon['lon']
lon_grid_icon, lat_grid_icon = np.meshgrid(lon_icon, lat_icon)

ax = axes[1]
ax.set_title("Model")
ax.coastlines()
ax.add_feature(cfeature.BORDERS, linestyle='-', alpha=0.7)

im2 = ax.pcolormesh(lon_grid_icon, lat_grid_icon, bt_data_icon, cmap='Greys',
                    shading='nearest', transform=ccrs.PlateCarree(), vmin=220, vmax=300)

# One shared horizontal colorbar (both panels use the same vmin/vmax)
cbar = fig.colorbar(im2, ax=axes.ravel().tolist(), orientation='horizontal', fraction=0.05, pad=0.07)
cbar.set_label("Brightness Temperature (K)")

fig.savefig(save_path + 'original_exp_domain' + '.png', dpi=100, bbox_inches='tight', pad_inches=0)
plt.close(fig)

In [14]:
# Latitude bounds of the OBS domain, used to subset ICON
lat_min, lat_max

(30.018796992481203, 54.9812030075188)

In [15]:
# Longitude bounds of the OBS domain, used to subset ICON
lon_min, lon_max

(-3.4812108559498958, 32.4812108559499)

In [13]:
# Display the SEVIRI observation dataset
data_obs

In [14]:
# Display the ICON-EU model dataset
data_icon

### Prepare the dataset from random crops over the whole domain

In [None]:
# Build the training set: N_CROPS_PER_FILE random CROP_SIZE x CROP_SIZE crops of
# IR 10.8 brightness temperature from every SEVIRI file matching
# HRSEVIRI_2023*_PC.nc in the month folders 04-09, anywhere in the OBS domain.
# Crops containing NaN/inf are skipped and logged to nan_crop_file.
nc_file_loc = '/p/scratch/exaww/chatterjee1/msg_warmworld/files/'
output_file = "/p/project1/exaww/chatterjee1/dataset/msgobs_108_randcrops_alldom.nc"

log_file = nc_file_loc + "processed_files_log_alldom.txt"
nan_crop_file = nc_file_loc + "nan_files_log_alldom.txt"

months = {
    4: '04/',
    5: '05/',
    6: '06/',
    7: '07/',
    8: '08/',
    9: '09/',
}

CROP_SIZE = 128        # crop edge length in pixels
N_CROPS_PER_FILE = 8   # random crops drawn per file
CROP_SEED = 42         # fixed seed so crop positions are reproducible
rng = np.random.default_rng(CROP_SEED)

sample_counter = 0  # number of valid crops collected

# NOTE(review): all crops are held in memory until the end; for six months of
# files this can be many GB — consider float32 or incremental writing.
all_crops = []  # crops, each (CROP_SIZE, CROP_SIZE)
all_lats = []   # latitude coordinates of each crop's rows
all_lons = []   # longitude coordinates of each crop's columns
all_times = []  # timestamp string of each crop

first_write = True  # Flag to check if it's the first time writing to the file

for month_dir in months.values():
    loc = nc_file_loc + month_dir
    nc_filepattern = "HRSEVIRI_2023*_PC.nc"
    nc_files = sorted(glob.glob(loc + nc_filepattern))

    for file in nc_files:
        with open(log_file, 'a') as log:
            log.write(f"{file}\n")

        # Read everything needed, then close the file: thousands of files are
        # processed and unclosed datasets keep their file handles open.
        with xr.open_dataset(file) as data:
            satellite_name = data.EPCT_product_name.split('-')[0]
            timestamp = data.EPCT_product_name.split('A-')[1].split('.')[0]
            lat = data.lat.values
            lon = data.lon.values
            radiances = data["channel_9"].values

        bt_data = radiances_2_brightnesstemp_and_reflectances(radiances, 9, satellite_name)
        y_dim, x_dim = bt_data.shape

        for j in range(N_CROPS_PER_FILE):
            # `integers` excludes the upper bound: +1 lets the crop reach the last row/column
            start_y = rng.integers(0, y_dim - CROP_SIZE + 1)
            start_x = rng.integers(0, x_dim - CROP_SIZE + 1)

            crop = bt_data[start_y:start_y + CROP_SIZE, start_x:start_x + CROP_SIZE]

            if not np.all(np.isfinite(crop)):
                with open(nan_crop_file, 'a') as log:
                    log.write(f"{file} and crop no {j}\n")
                continue

            all_crops.append(crop)
            all_lats.append(lat[start_y:start_y + CROP_SIZE])
            all_lons.append(lon[start_x:start_x + CROP_SIZE])
            all_times.append(timestamp)
            sample_counter += 1

# After processing all files:
if not all_crops:
    raise RuntimeError("No valid crops were collected; check nc_file_loc, the month folders and the file pattern.")

all_crops_np = np.array(all_crops)
all_lats_np = np.array(all_lats)
all_lons_np = np.array(all_lons)
all_times_np = np.array(all_times)

ds = xr.Dataset(
    {
        "sample_data": (["sample", "y", "x"], all_crops_np)
    },
    coords={
        "sample": (["sample"], np.arange(len(all_crops_np))),
        "lat": (["sample", "y"], all_lats_np),
        "lon": (["sample", "x"], all_lons_np),
        "time": (["sample"], all_times_np)
    }
)

if first_write:
    ds.to_netcdf(output_file, mode='w')
    first_write = False