In [None]:

import os

import h5py
import numpy as np
import xarray as xr
# Read the GRIB2 file


# Reference MRMS grid: only the coordinates are needed, so the dataset is
# closed as soon as they are read.
with xr.open_dataset('PrecipRate_00.00/20221223/MRMS_PrecipRate_00.00_20221223-045000.grib2', engine='cfgrib') as ds:
    lats = ds.latitude.values
    lons = ds.longitude.values
# Map longitudes > 180 into the signed [-180, 180] convention
# (presumably MRMS stores 0-360 degrees east; the GMI masks below assume signed longitudes).
lons = np.where(lons > 180, lons - 360, lons)
Lat_m, Lon_m = np.meshgrid(lats, lons, indexing='ij')

GMIfiles = sorted(os.listdir('GMI'))
# Explicit read-only mode; the arrays are copied out ([:]) before the file closes.
with h5py.File('GMI/' + GMIfiles[122], 'r') as dataset1:
    Lat_g = dataset1['/S2/Latitude'][:]
    Lon_g = dataset1['/S2/Longitude'][:]
    scantime1 = dataset1['/S2/ScanTime/SecondOfDay'][:]
    scantime2 = dataset1['/S2/ScanTime/DayOfYear'][:]


**Finding GMI Orbits crossing CONUS**

In [None]:
# Load geolocation and scan times for one example orbit; the file is opened
# read-only and closed once the arrays have been copied into memory.
with h5py.File('GMI/' + GMIfiles[129], 'r') as dataset1:
    Lat_g = dataset1['/S2/Latitude'][:]
    Lon_g = dataset1['/S2/Longitude'][:]
    scantime1 = dataset1['/S2/ScanTime/SecondOfDay'][:]
    scantime2 = dataset1['/S2/ScanTime/DayOfYear'][:]

def find_gmi_conus_indices(Lat_g, Lon_g, Lat_m, Lon_m):
    """Return the first and last GMI rows lying entirely inside the MRMS bounding box.

    A row qualifies when every latitude in it lies in
    ``[min(Lat_m), max(Lat_m)]`` and every longitude in it lies in
    ``[min(Lon_m), max(Lon_m)]``. This is a rectangular box test, not a test
    against the true MRMS footprint.

    Parameters
    ----------
    Lat_g, Lon_g : 2-D arrays of identical shape
        GMI coordinates; each row is tested as a unit.
    Lat_m, Lon_m : array_like
        MRMS coordinates (any shape); only their min/max are used.

    Returns
    -------
    (start_index, end_index) : tuple of int, or (None, None)
        Indices of the first and last qualifying rows, or ``(None, None)``
        when no row qualifies. If the qualifying rows are not contiguous, a
        warning is printed and the returned span still covers the gaps.
    """
    lat_min, lat_max = np.min(Lat_m), np.max(Lat_m)
    lon_min, lon_max = np.min(Lon_m), np.max(Lon_m)

    lat_g = np.asarray(Lat_g)
    lon_g = np.asarray(Lon_g)

    # One vectorized test for the whole swath instead of a Python loop per row.
    inside = ((lat_g >= lat_min) & (lat_g <= lat_max)
              & (lon_g >= lon_min) & (lon_g <= lon_max))
    rows_in_conus = np.flatnonzero(inside.all(axis=1))

    if rows_in_conus.size == 0:
        return None, None

    # Consecutive qualifying rows differ by exactly 1; anything else is a gap.
    if np.any(np.diff(rows_in_conus) != 1):
        print("Warning: Found gaps in CONUS coverage")

    return int(rows_in_conus[0]), int(rows_in_conus[-1])

# Usage example: first/last rows of the currently loaded GMI orbit that sit
# fully inside the MRMS bounding box.
start_idx, end_idx = find_gmi_conus_indices(Lat_g=Lat_g, Lon_g=Lon_g, Lat_m=Lat_m, Lon_m=Lon_m)

(1712, 1875)

**Creating spatiotemporally matched MRMS filenames**

In [None]:
# Minimum (end - start) row span for an orbit to be kept. The value comes from
# the original analysis; its rationale is not documented here.
MIN_CONUS_ROW_SPAN = 289

# Collect [filename, first_row, last_row] for every GMI orbit whose swath has
# a long enough run of rows inside the MRMS bounding box.
MyList = []
for gmi_name in GMIfiles:
    # Read-only, and closed after each orbit so handles do not accumulate.
    with h5py.File('GMI/' + gmi_name, 'r') as gmi_file:
        Lat_g = gmi_file['/S2/Latitude'][:]
        Lon_g = gmi_file['/S2/Longitude'][:]

    start_idx, end_idx = find_gmi_conus_indices(Lat_g, Lon_g, Lat_m, Lon_m)

    if start_idx is not None:
        bb = end_idx - start_idx
        if bb > MIN_CONUS_ROW_SPAN:
            MyList.append([gmi_name, start_idx, end_idx])
            print(bb)


import datetime

def generate_unique_mrms_filenames(scantime_SecondOfDay, scantime_DayOfYear, year=2022):
    """Build the sorted, de-duplicated MRMS PrecipRate/PrecipFlag paths for a set of scans.

    Each scan time is rounded to the nearest even minute, which is the
    2-minute grid the MRMS file names are assumed to follow. The rounded
    time is then turned into a ``PrecipRate_00.00`` path and a
    ``PrecipFlag_00.00`` path.

    Parameters
    ----------
    scantime_SecondOfDay : iterable of numbers
        Seconds since midnight for each scan (truncated to whole seconds).
    scantime_DayOfYear : iterable of ints
        1-based day of year for each scan; paired element-wise with the seconds.
    year : int, default 2022
        Calendar year the day-of-year values belong to.

    Returns
    -------
    list
        ``[rate_paths, flag_paths]``: two sorted lists of unique path strings.
        Both are empty when no scan times are given.
    """
    rate_filenames = set()
    flag_filenames = set()
    base_date = datetime.datetime(year, 1, 1)

    for day, sec in zip(scantime_DayOfYear, scantime_SecondOfDay):
        curr_date = base_date + datetime.timedelta(days=int(day - 1), seconds=int(sec))
        minutes = curr_date.minute + curr_date.second / 60
        # Python's round() sends exact .5 ties to the nearest even integer.
        rounded_minutes = round(minutes / 2) * 2
        # Add a timedelta rather than setting minute= so that 60 rolls over to the next hour.
        rounded_date = curr_date.replace(minute=0, second=0) + datetime.timedelta(minutes=rounded_minutes)

        folder = rounded_date.strftime("%Y%m%d")
        timestamp = rounded_date.strftime("%Y%m%d-%H%M%S")

        rate_filenames.add(f'PrecipRate_00.00/{folder}/MRMS_PrecipRate_00.00_{timestamp}.grib2')
        flag_filenames.add(f'PrecipFlag_00.00/{folder}/MRMS_PrecipFlag_00.00_{timestamp}.grib2')

    # Built once after the loop, so empty input yields [[], []] instead of an
    # unbound local, and the sets are sorted only once.
    return [sorted(rate_filenames), sorted(flag_filenames)]


def get_mapping_indices(scantime_SecondOfDay, scantime_DayOfYear, filenames, year=2022):
    """For each scan, return the index of the temporally closest file in ``filenames``.

    The file time is parsed from the last ``_``-separated token of each name,
    which must look like ``YYYYmmdd-HHMMSS.grib2``.

    Parameters
    ----------
    scantime_SecondOfDay : iterable of numbers
        Seconds since midnight for each scan (truncated to whole seconds).
    scantime_DayOfYear : iterable of ints
        1-based day of year for each scan.
    filenames : sequence of str
        Candidate MRMS file paths.
    year : int, default 2022
        Calendar year the day-of-year values belong to.

    Returns
    -------
    list of int
        One index per scan. On an exact tie the earlier entry in ``filenames`` wins.

    Raises
    ------
    ValueError
        If scan times are given but ``filenames`` is empty.
    """
    base_date = datetime.datetime(year, 1, 1)
    scan_times = [base_date + datetime.timedelta(days=int(day - 1), seconds=int(sec))
                  for day, sec in zip(scantime_DayOfYear, scantime_SecondOfDay)]

    file_times = [
        datetime.datetime.strptime(name.split('_')[-1].replace('.grib2', ''), '%Y%m%d-%H%M%S')
        for name in filenames
    ]

    if scan_times and not file_times:
        raise ValueError("get_mapping_indices: no filenames to match scan times against")

    mapping_indices = []
    for scan_time in scan_times:
        # min() returns the first minimal element, so ties go to the lower index.
        closest_idx = min(range(len(file_times)),
                          key=lambda k: abs((file_times[k] - scan_time).total_seconds()))
        mapping_indices.append(closest_idx)

    return mapping_indices


# For each candidate orbit, attach the per-row MRMS file mapping and keep the
# orbit only if every referenced MRMS file exists on disk.
# Each entry: [gmi_name, start, end, indices, rate_paths, flag_paths].
# NOTE(review): nothing visible in this notebook writes Final_List to
# "Final_list.pkl", which the next cell loads; confirm how that file is produced.
Final_List = []
for entry in MyList:
    gmi_name = entry[0]
    with h5py.File('GMI/' + gmi_name, 'r') as gmi_file:
        scantime1 = gmi_file['/S2/ScanTime/SecondOfDay'][:]
        scantime2 = gmi_file['/S2/ScanTime/DayOfYear'][:]

    # Skip two rows at each end: find_reverse_nearest_neighbors trims the same
    # two edge rows ([2:-2]) from its output.
    start_f = entry[1] + 2
    end_f = entry[2] - 2
    scantime_SecondOfDay = scantime1[start_f:end_f + 1]
    scantime_DayOfYear = scantime2[start_f:end_f + 1]

    rate_files, flag_files = generate_unique_mrms_filenames(scantime_SecondOfDay, scantime_DayOfYear)
    indices = get_mapping_indices(scantime_SecondOfDay, scantime_DayOfYear, rate_files)

    missing_files = [path for path in rate_files + flag_files if not os.path.exists(path)]

    if missing_files:
        print(f"{gmi_name}: {len(missing_files)} files don't exist.")
    else:
        Final_List.append(entry + [indices, rate_files, flag_files])



In [None]:
import pickle
import numpy as np
import xarray as xr

# NOTE(review): pickle.load can execute arbitrary code; only load a
# Final_list.pkl produced by this workflow.
with open("Final_list.pkl", "rb") as f:
    mylist = pickle.load(f)

# Rebuild the MRMS coordinate grid; the dataset is closed after the coordinates are read.
with xr.open_dataset('PrecipRate_00.00/20221223/MRMS_PrecipRate_00.00_20221223-045000.grib2', engine='cfgrib') as ds:
    lats = ds.latitude.values
    lons = ds.longitude.values
lons = np.where(lons > 180, lons - 360, lons)
Lat_m, Lon_m = np.meshgrid(lats, lons, indexing='ij')

**Creating synthetic edge padding for the GMI grid**

In [None]:
import numpy as np

def add_padding_to_gmi(lat_conus, lon_conus, n_pads=2):
    """Extend a GMI coordinate grid by ``n_pads`` extrapolated columns on each side.

    Left padding repeats each row's ``col[1] - col[0]`` spacing outward.
    Right padding repeats ``col[-1] - col[-2]`` outward. Rows are not padded.

    Parameters
    ----------
    lat_conus, lon_conus : 2-D arrays of identical shape (rows, cols)
    n_pads : int, default 2
        Columns to add on each side. 0 returns float copies of the inputs.

    Returns
    -------
    (lat_padded, lon_padded) : float arrays of shape (rows, cols + 2 * n_pads)

    Raises
    ------
    ValueError
        If ``n_pads`` is negative, or if padding is requested for a grid with
        fewer than two columns (the edge spacing would be undefined).
    """
    rows, cols = lat_conus.shape
    if n_pads < 0:
        raise ValueError(f"n_pads must be >= 0, got {n_pads}")
    if n_pads > 0 and cols < 2:
        raise ValueError("need at least 2 columns to extrapolate padding")

    new_cols = cols + 2 * n_pads
    lat_padded = np.zeros((rows, new_cols))
    lon_padded = np.zeros((rows, new_cols))

    # Explicit stop index: ``n_pads:-n_pads`` would be an empty slice when n_pads == 0.
    lat_padded[:, n_pads:n_pads + cols] = lat_conus
    lon_padded[:, n_pads:n_pads + cols] = lon_conus

    if n_pads == 0:
        return lat_padded, lon_padded

    # Per-row spacing at each edge.
    lat_diff_left = lat_conus[:, 1] - lat_conus[:, 0]
    lon_diff_left = lon_conus[:, 1] - lon_conus[:, 0]
    lat_diff_right = lat_conus[:, -1] - lat_conus[:, -2]
    lon_diff_right = lon_conus[:, -1] - lon_conus[:, -2]

    # Walk outward from the first real column.
    for i in range(n_pads):
        lat_padded[:, n_pads - 1 - i] = lat_padded[:, n_pads - i] - lat_diff_left
        lon_padded[:, n_pads - 1 - i] = lon_padded[:, n_pads - i] - lon_diff_left

    # Walk outward from the last real column.
    right = n_pads + cols  # first right-padding column
    for i in range(n_pads):
        lat_padded[:, right + i] = lat_padded[:, right + i - 1] + lat_diff_right
        lon_padded[:, right + i] = lon_padded[:, right + i - 1] + lon_diff_right

    return lat_padded, lon_padded



**Gathering the MRMS points that fall inside each GMI pixel**

In [2]:
import numpy as np
from scipy.spatial import cKDTree
from collections import defaultdict

def find_reverse_nearest_neighbors(Lat_m, Lon_m, Lat_conus, Lon_conus):
    """Assign every MRMS point to its nearest GMI pixel and group the points per pixel.

    The GMI grid is first widened by two extrapolated columns on each side
    (``add_padding_to_gmi``). Presumably this lets MRMS points beyond the swath
    edges snap to the padding instead of piling onto the real edge pixels.
    Nearest neighbours come from a KD-tree on raw ``(lat, lon)`` pairs, i.e.
    Euclidean distance in degree space, with no distance cutoff.

    Parameters
    ----------
    Lat_m, Lon_m : 2-D arrays of identical shape
        MRMS coordinates.
    Lat_conus, Lon_conus : 2-D arrays of identical shape (rows, cols)
        GMI coordinates.

    Returns
    -------
    numpy.ndarray of object, shape (max(rows - 4, 0), cols)
        Element ``[i, j]`` is the list of ``(mrms_row, mrms_col)`` index
        tuples whose nearest GMI pixel is ``(i + 2, j)`` of the input grid.
        Tuples are listed in MRMS row-major order. The padding columns and
        the first/last two GMI rows are dropped.
    """
    trim = 2  # padding width per side; the same margin is trimmed from the rows
    lat_padded, lon_padded = add_padding_to_gmi(Lat_conus, Lon_conus, n_pads=trim)
    n_rows, n_cols = lat_padded.shape

    mrms_points = np.stack((Lat_m.ravel(), Lon_m.ravel()), axis=1)
    gmi_points = np.stack((lat_padded.ravel(), lon_padded.ravel()), axis=1)

    tree = cKDTree(gmi_points)
    _, nearest = tree.query(mrms_points, workers=-1)

    # Group MRMS flat indices by nearest GMI flat index. The stable sort keeps
    # the MRMS points in their original row-major order within each group.
    order = np.argsort(nearest, kind='stable')
    # bounds[k]:bounds[k + 1] is the slice of ``order`` assigned to GMI cell k.
    bounds = np.searchsorted(nearest[order], np.arange(n_rows * n_cols + 1))
    mrms_i = order // Lat_m.shape[1]
    mrms_j = order % Lat_m.shape[1]

    # Build lists only for the cells that survive the trim.
    out_rows = max(n_rows - 2 * trim, 0)
    out_cols = max(n_cols - 2 * trim, 0)
    final = np.empty((out_rows, out_cols), dtype=object)
    for i in range(out_rows):
        for j in range(out_cols):
            k = (i + trim) * n_cols + (j + trim)
            lo, hi = bounds[k], bounds[k + 1]
            final[i, j] = list(zip(mrms_i[lo:hi], mrms_j[lo:hi]))

    return final

In [None]:
def FlagRate(finall, P_rate, P_flag):
    """Gather MRMS rate and flag values for each GMI pixel's list of MRMS indices.

    Parameters
    ----------
    finall : 2-D object array
        Each element is a sequence of ``(row, col)`` index pairs into
        ``P_rate`` / ``P_flag``.
    P_rate, P_flag : 2-D arrays
        Indexed by those pairs.

    Returns
    -------
    (Rate, Flag) : two object arrays shaped like ``finall``
        Each element is a 1-D array of the gathered values, with the dtype of
        ``P_rate`` / ``P_flag``. Cells with no pairs get an empty float64 array.
    """
    Rate = np.zeros(finall.shape, dtype=object)
    Flag = np.zeros(finall.shape, dtype=object)

    for i in range(finall.shape[0]):
        for j in range(finall.shape[1]):
            pairs = finall[i, j]
            if len(pairs) == 0:
                Rate[i, j] = np.array([])
                Flag[i, j] = np.array([])
                continue
            # One fancy-indexing gather per cell instead of a Python lookup per pair.
            rows, cols = np.asarray(pairs).T
            Rate[i, j] = P_rate[rows, cols]
            Flag[i, j] = P_flag[rows, cols]

    return Rate, Flag

**Deciding the flag and rate of each GMI pixel**

In [4]:
import numpy as np

def process_mrms_vector(flag_vector, rate_vector):
    """Collapse the MRMS flags/rates that fall in one GMI pixel into a single (flag, rate).

    Rain flags are 1, 6, 7, 10, 91 and 96; the snow flag is 3.

    Returns
    -------
    tuple
        ``(-3, -3)`` if there are no rain/snow points and the no-coverage (-3)
        points outnumber the no-precip (0) points.
        ``(0, 0)`` if there are no rain/snow points otherwise.
        ``(2, r)`` if snow points strictly outnumber rain points; ``r`` is the
        summed snow rate divided by the total number of points.
        ``(1, r)`` if rain points are at least as many as snow points; ``r`` is
        the summed rain rate divided by the total number of points.
    """
    flags = np.array(flag_vector)
    rates = np.array(rate_vector)

    is_rain = np.isin(flags, [1, 6, 7, 10, 91, 96])
    is_snow = np.isin(flags, [3])
    n_rain = np.sum(is_rain)
    n_snow = np.sum(is_snow)
    n_total = len(flags)

    if n_rain == 0 and n_snow == 0:
        n_no_coverage = np.sum(np.isin(flags, [-3]))
        n_no_precip = np.sum(np.isin(flags, [0]))
        return (-3, -3) if n_no_coverage > n_no_precip else (0, 0)

    # Ties go to rain.
    chosen_flag, chosen_mask = (2, is_snow) if n_snow > n_rain else (1, is_rain)
    return chosen_flag, np.sum(rates[chosen_mask]) / n_total

def process_mrms_grid(mapped_flag, mapped_rate):
    """Apply ``process_mrms_vector`` to every cell of two object grids.

    Cells whose flag vector is empty get ``(-3, -3)`` without calling the
    per-vector routine.

    Returns
    -------
    (processed_flags, processed_rates)
        An int array and a float array, sized by the first two dimensions of
        ``mapped_flag``.
    """
    n_rows, n_cols = mapped_flag.shape[0], mapped_flag.shape[1]
    processed_flags = np.empty((n_rows, n_cols), dtype=int)
    processed_rates = np.empty((n_rows, n_cols), dtype=float)

    for r, c in np.ndindex(n_rows, n_cols):
        flags_here = mapped_flag[r, c]
        if len(flags_here) == 0:
            processed_flags[r, c], processed_rates[r, c] = -3, -3
            continue
        processed_flags[r, c], processed_rates[r, c] = process_mrms_vector(flags_here, mapped_rate[r, c])

    return processed_flags, processed_rates

**Executing the functions**

In [None]:

import xarray as xr
import numpy as np
import h5py


# Range of mylist entries processed in this run (upper bound exclusive,
# clamped to the list length).
Chunk_index = [200, 500]

for i in range(Chunk_index[0], min(Chunk_index[-1], len(mylist))):
    entry = mylist[i]
    gmi_name = entry[0]
    conus_start, conus_end = entry[1], entry[2]
    indices = entry[3]        # per-row index into the MRMS file lists
    rate_files = entry[-2]
    flag_files = entry[-1]

    with h5py.File('GMI/' + gmi_name, 'r') as gmi_file:
        Lat_g = gmi_file['/S2/Latitude'][:]
        Lon_g = gmi_file['/S2/Longitude'][:]

    Lat_conus = Lat_g[conus_start:conus_end + 1, :]
    Lon_conus = Lon_g[conus_start:conus_end + 1, :]

    # Load each matched MRMS rate/flag grid once; files are closed right after reading.
    # The rate and flag lists are built in parallel from the same timestamps.
    PreRate_O = []
    Preflag_O = []
    for rate_path, flag_path in zip(rate_files, flag_files):
        with xr.open_dataset(rate_path, engine='cfgrib') as ds_rate:
            PreRate_O.append(ds_rate.unknown.values)
        with xr.open_dataset(flag_path, engine='cfgrib') as ds_flag:
            Preflag_O.append(ds_flag.unknown.values)

    finall = find_reverse_nearest_neighbors(Lat_m, Lon_m, Lat_conus, Lon_conus)

    # Row r only uses the MRMS file matched to that scan (indices[r]). Gather
    # each row from its own file instead of gathering the full grid from
    # every file and keeping one row of each.
    Final_Rate = np.zeros(finall.shape, dtype=object)
    Final_Flag = np.zeros(finall.shape, dtype=object)
    for r in range(finall.shape[0]):
        k = indices[r]
        row_rate, row_flag = FlagRate(finall[r:r + 1, :], PreRate_O[k], Preflag_O[k])
        Final_Rate[r, :] = row_rate[0, :]
        Final_Flag[r, :] = row_flag[0, :]

    processed_flags, processed_rates = process_mrms_grid(Final_Flag, Final_Rate)

    # finall drops two GMI rows at each end, so its rows start at conus_start + 2.
    start_f = conus_start + 2
    end_f = conus_end - 2

    GMI_FLAG = np.full(Lon_g.shape, np.nan)
    GMI_RATE = np.full(Lon_g.shape, np.nan)
    GMI_FLAG[start_f:end_f + 1] = processed_flags
    GMI_RATE[start_f:end_f + 1] = processed_rates

    GMI_FLAG_RAW = np.full(Lon_g.shape, np.nan, dtype=object)
    GMI_RATE_RAW = np.full(Lon_g.shape, np.nan, dtype=object)
    GMI_FLAG_RAW[start_f:end_f + 1] = Final_Flag
    GMI_RATE_RAW[start_f:end_f + 1] = Final_Rate

    # The sixth '.'-separated field of the GMI file name is used as the orbit id.
    orbit_id = gmi_name.split('.')[5]
    np.save('RAW/' + 'RFlag' + '.' + orbit_id, GMI_FLAG_RAW)
    np.save('RAW/' + 'RRate' + '.' + orbit_id, GMI_RATE_RAW)
    np.save('FILE/' + 'Flag' + '.' + orbit_id, GMI_FLAG)
    np.save('FILE/' + 'Rate' + '.' + orbit_id, GMI_RATE)

    # Release the large per-orbit arrays before the next iteration.
    del GMI_FLAG, GMI_RATE, GMI_FLAG_RAW, GMI_RATE_RAW, Preflag_O, PreRate_O

    print(f'#{i} has been done')
     
    

**Visualization of an example**

In [None]:
import xarray as xr
import numpy as np
import h5py
# Plotting imports come first in this cell: they were previously imported
# only after their first use, which failed on a fresh kernel.
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy.ma as ma
from matplotlib.colors import Normalize
from matplotlib.ticker import FormatStrFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable

# MRMS PrecipFlag codes treated as rain / snow (the same sets process_mrms_vector uses).
RAIN_FLAGS = [1, 6, 7, 10, 91, 96]
SNOW_FLAG = 3

RAIN_CMAP = mcolors.LinearSegmentedColormap.from_list(
    'custom', ['lightgray', 'green', 'orange', 'pink'])
SNOW_CMAP = mcolors.LinearSegmentedColormap.from_list('custom', [
    "#E0E0E0",  # Very Light Gray (little to no snowfall)
    "#A9EAFE",  # Light Cyan (light snowfall)
    "#00BFFF",  # Deep Sky Blue (moderate snowfall)
    "#0000FF",  # Blue (heavy snowfall)
    "#9400D3",  # Dark Violet (extreme snowfall)
])

# --- One raw MRMS snapshot ---------------------------------------------------
with xr.open_dataset('MRMS_PrecipFlag_00.00_20221223-040400.grib2', engine='cfgrib') as flag_ds:
    Flag_MRMS = flag_ds.unknown.values
with xr.open_dataset('MRMS_PrecipRate_00.00_20221223-040400.grib2', engine='cfgrib') as rate_ds:
    Rate_MRMS = rate_ds.unknown.values
    lats = rate_ds.latitude.values
    lons = rate_ds.longitude.values
lons = np.where(lons > 180, lons - 360, lons)
Lat_m, Lon_m = np.meshgrid(lats, lons, indexing='ij')
del lats, lons

# --- Mapped product and GMI geolocation for the same orbit --------------------
Flag_mapped = np.load("Flag.050100.npy")
Rate_mapped = np.load("Rate.050100.npy")
with h5py.File('1B.GPM.GMI.TB2021.20221223-S030751-E044025.050100.V07A.HDF5', 'r') as gmi_file:
    Lat_g = gmi_file['/S2/Latitude'][:]
    Lon_g = gmi_file['/S2/Longitude'][:]

# --- Split rates into rain and snow layers -------------------------------------
# Each layer is 0 where its type is absent and NaN where the other type is present.
is_rain_mrms = np.isin(Flag_MRMS, RAIN_FLAGS)
is_snow_mrms = Flag_MRMS == SNOW_FLAG

Rain_MRMS = Rate_MRMS.copy()
Rain_MRMS[~is_rain_mrms] = 0
Rain_MRMS[is_snow_mrms] = np.nan

Snow_MRMS = Rate_MRMS.copy()
Snow_MRMS[~is_snow_mrms] = 0
Snow_MRMS[is_rain_mrms] = np.nan

# The mapped product uses 1 = rain and 2 = snow (see process_mrms_vector).
Rain_Mapped = Rate_mapped.copy()
Rain_Mapped[Flag_mapped != 1] = 0
Rain_Mapped[Flag_mapped == 2] = np.nan

Snow_Mapped = Rate_mapped.copy()
Snow_Mapped[Flag_mapped != 2] = 0
Snow_Mapped[Flag_mapped == 1] = np.nan

# Hide GMI pixels next to the antimeridian and the poles, presumably so that
# pcolormesh cells do not wrap around the globe.
swath_mask = (Lon_g > 179.5) | (Lon_g < -179.5) | (Lat_g > 89.5) | (Lat_g < -89.5)


def plot_rain_snow_map(lon, lat, rain, snow, title, out_path):
    """Draw rain and snow rate layers over CONUS with stacked colorbars, save and show.

    Rain uses a 0-5 mm/h scale and snow a 0-2 mm/h scale. The figure is saved
    to ``out_path`` at 500 dpi. Returns ``(fig, ax)``.
    """
    fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={'projection': ccrs.PlateCarree()})
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
    ax.add_feature(cfeature.BORDERS, linewidth=0.5, linestyle=':')

    gl = ax.gridlines(draw_labels=True)
    gl.right_labels = False
    gl.bottom_labels = False
    gl.xlabel_style = {'size': 10, 'weight': 'bold'}
    gl.ylabel_style = {'size': 10, 'weight': 'bold'}

    mesh_rain = ax.pcolormesh(lon, lat, rain, transform=ccrs.PlateCarree(),
                              cmap=RAIN_CMAP, vmin=0, vmax=5)
    mesh_snow = ax.pcolormesh(lon, lat, snow, transform=ccrs.PlateCarree(),
                              cmap=SNOW_CMAP, vmin=0, vmax=2)

    rain_cax = fig.add_axes([0.906, 0.54, 0.01, 0.33])  # top colorbar
    snow_cax = fig.add_axes([0.906, 0.12, 0.01, 0.33])  # bottom colorbar

    cbar_rain = fig.colorbar(mesh_rain, cax=rain_cax)
    cbar_rain.set_label('Rain Rate (mm/h)', weight='bold')
    cbar_rain.ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

    cbar_snow = fig.colorbar(mesh_snow, cax=snow_cax)
    cbar_snow.set_label('Snow Rate (mm/h)', weight='bold')

    ax.set_title(title, fontsize=16, fontweight='bold')
    ax.set_extent([-125, -66.5, 24, 49])

    fig.savefig(out_path, dpi=500)
    plt.show()
    return fig, ax


plot_rain_snow_map(Lon_g, Lat_g,
                   np.ma.masked_where(swath_mask, Rain_Mapped),
                   np.ma.masked_where(swath_mask, Snow_Mapped),
                   "Mapped MRMS on GMI Orbit #50100", 'Mapped')

plot_rain_snow_map(Lon_m, Lat_m, Rain_MRMS, Snow_MRMS, "MRMS Precipitation", 'MRMS');