In [1]:
import pygrib
import numpy as np
import os

import json

In [2]:
from uuid import uuid4

print(uuid4().hex[-6:])

faaf67


In [3]:
import numpy as np

x = np.array([[0,1,2,3], [0,1,2,3]])
print(x+1)

[[1 2 3 4]
 [1 2 3 4]]


In [2]:
SURFACE_VARIABLES = [
    '10u', 
    '10v', 
    '2d', 
    '2t', 
    'lsm', 
    'msl', 
    'sdor', 
    'skt', 
    'slor', 
    'sp', 
    'tcw']

UPPER_VARIABLES =[
    'q', 
    't', 
    'w', 
    'z', 
    'u', 
    'v'
]

VARIABLES = sorted(SURFACE_VARIABLES + UPPER_VARIABLES)

VARIABLES_DOC = {
    'levels': 'pressure_levels in hPa',
    'q': 'Specific humidity',
    't': 'Temperature',
    'w': 'Vertical velocity',
    'z': 'Geopotential',
    'u': 'U component of wind',
    'v': 'V component of wind',
    '10u': '10 metre U wind component',
    '10v': '10 metre V wind component',
    '2d': '2 metre dewpoint temperature',
    '2t': '2 metre temperature',
    'lsm': 'Land-sea mask',
    'msl': 'Mean sea level pressure',
    'sdor': 'Standard deviation of sub-gridscale orography',
    'skt': 'Skin temperature',
    'slor': 'Slope of sub-gridscale orography',
    'sp': 'Surface pressure',
    'tcw': 'Total column water'
    }

LEVELS = sorted(set([100, 200, 1000, 300, 400, 50, 850, 500, 150, 600, 250, 700, 925]))

VALID_VARIABLES_DICT = {k: LEVELS if k in UPPER_VARIABLES else [0] for k in VARIABLES}
VALID_VARIABLES_DICT['z'] = [0] + LEVELS

In [3]:
len(LEVELS)

13

In [7]:
grib_file = "/home/user/large-disk/aifs_preds/Code3_Interfaces_Python/Ctrl_nopert_20240326.grb"
# "/home/user/large-disk/pangu_preds/test_perturbed.grib"

In [8]:
GRIB_FILE = None

def validate_grib_file(grib_file_path):
    valid_grib = True

    grbs = pygrib.open(grib_file_path)

    invalid_columns = dict()
    invalid_levels = dict()
    time_levels = set()
    columns = set()

    all_time_columns_levels = dict()

    for grb in grbs:

        column_name = grb.shortName
        column_level = grb.level
        column_time = grb.dataTime

        all_columns_levels = all_time_columns_levels.get(column_time, {})

        if column_name not in VARIABLES:
            time_invalid_columns = invalid_columns.get(column_time, [])
            time_invalid_columns.append(column_name)
        
        if column_level not in LEVELS:
            time_invalid_levels = invalid_levels.get(column_time, [])
            time_invalid_levels.append(column_name)

        time_levels.add(column_time)
        columns.add(column_name)

        col = all_columns_levels.get(grb.shortName, list())
        col.append(grb.level)
        all_columns_levels[grb.shortName] = col

        all_time_columns_levels[column_time] = all_columns_levels

    grbs.close()
    
    # sort dict
    sorted_level_3 = {
        k1: {
            k2: sorted(v2)
            for k2, v2 in v1.items()
        }
        for k1, v1 in all_time_columns_levels.items()
    }

    sorted_level_2 = {
        k1: dict(sorted(v1.items()))
        for k1, v1 in sorted_level_3.items()
    }

    all_time_columns_levels = dict(sorted(sorted_level_2.items()))

    if invalid_columns:
        print(f"invalid grib:\nextra columns found:\n{invalid_columns}")
        valid_grib = False

    if invalid_levels:
        print(f"invalid grib:\nextra levels found:\n{invalid_levels}")
        valid_grib = False

    for (k,v) in all_time_columns_levels.items():
        if v != VALID_VARIABLES_DICT:
            print(f"{k} time does not have all the variables and levels")
            valid_grib = False
    
    return valid_grib


In [9]:
validate_grib_file(grib_file)

True

In [10]:
def parturbation_of_variable(
        grib_file,
        variable,
        level=0,
        zmul=1,
        zadd=0,
        output_grib_file=None,):
    
    
    if not validate_grib_file(grib_file):
        return False
    
    if zmul == 1 and zadd == 0:
        print(f"no addition term and no multiplication factor, grib file will not be perturbed")
        return False
    
    path, file = os.path.split(grib_file)
    filename, extension = os.path.splitext(file)

    input_grib_file = grib_file
    output_grib_file = os.path.join(path, f"{filename}_{variable}_{level}_perturbed{extension}")
    with open(os.path.join(path, f"{filename}_{variable}_{level}_perturbed.perturbation_config"), "w") as f:
        f.write(json.dumps({
            "grib_file_name": str(grib_file),
            "perturbed_variable": variable,
            "perturbed_level": level,
            "perturbation_add_term": zadd,
            "perturbation_mul_factor": zmul,
        }))

    # Open the GRIB file
    grbs = pygrib.open(input_grib_file)

    with open(output_grib_file, 'wb') as out_file:
        found_variable = False
        for grb in grbs:
            if grb.shortName == variable and grb.level == level:
                grb.expand_grid(False)

                found_variable = True
                print(f"perturbing {grb.shortName} - {grb.level} @ Time: {grb.dataTime} by factor with multiplication of {zmul} and addition of {zadd}.")
                data, latitudes, longitudes = grb.data()

                modified_data = data * zmul + zadd
                
                grb.values = modified_data.reshape(-1)

            out_file.write(grb.tostring())
    
    grbs.close()

    if not found_variable:
        print(f"{variable} column does not exist in the grib file.")
        return False
    else:
        return True


def perturbation_by_factor_list(
        grib_file,
        perturbation_dict = {
            "u": {300: 0.6}, 
            "v": {300: 0.6}}
        ):
    """
    perturbatio_dict contains tuples {variable_name: {variables_level: perturbation_factor}}
    """
    
    if not validate_grib_file(grib_file):
        return False
    
    path, file = os.path.split(grib_file)
    filename, extension = os.path.splitext(file)

    input_grib_file = grib_file
    output_grib_file = os.path.join(path, f"{filename}_perturbed{extension}")
    with open(os.path.join(path, f"{filename}_perturbed.perturbation_config"), "w") as f:
        f.write(json.dumps(perturbation_dict))

    # Open the GRIB file
    grbs = pygrib.open(input_grib_file)

    # this will be used to track if the variables to perturb exist in the grib file
    variables_levels_check = list()
    _ = [variables_levels_check.extend([(k, k0) for k0 in v.keys()]) for (k,v) in perturbation_dict.items()]

    with open(output_grib_file, 'wb') as out_file:
        variables_levels = list()
        for grb in grbs:
            # if grb.shortName in perturbation_dict.keys():

            grbs.expand_grid(False)

            if grb.level in perturbation_dict.get(grb.shortName, {}).keys():
                print(f"perturbing {grb.shortName} - {grb.level} @ Time: {grb.dataTime}")
                # data, latitudes, longitudes = grb.data()
                data = grb.values

                modified_data = data * perturbation_dict.get(grb.shortName).get(grb.level)

                grb.values = modified_data.reshape(-1)
                
                variables_levels.append((grb.shortName, grb.level))

            out_file.write(grb.tostring())
    
    variables_diff = set(variables_levels_check) - set(variables_levels)

    if variables_diff:
        print(f"these variables were not found in the grib file:\n{variables_diff}")

    grbs.close()

    return True    

In [8]:
[print(i) for i in ['GRIBEXSection1Problem', 'GRIBEditionNumber', 'N', 'Ni', 'Nj', 'P1', 'P2', 'PLPresent', 'PVPresent', 'WMO', '__class__', '__delattr__', '__delitem__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '_all_keys', '_get_key', '_read_only_keys', '_reshape_mask', '_ro_keys', '_set_projparams', '_unshape_mask', 'additionalFlagPresent', 'alternativeRowScanning', 'analDate', 'analDate', 'angleSubdivisions', 'average', 'binaryScaleFactor', 'bitMapIndicator', 'bitmapPresent', 'bitsPerValue', 'bitsPerValueAndRepack', 'boustrophedonic', 'centre', 'centreDescription', 'centuryOfReferenceTimeOfData', 'cfName', 'cfNameECMF', 'cfVarName', 'cfVarNameECMF', 'changeDecimalPrecision', 'complexPacking', 'constantFieldHalfByte', 'data', 'dataDate', 'dataFlag', 'dataLength', 'dataRepresentationType', 'dataTime', 'day', 'decimalPrecision', 'decimalScaleFactor', 'deleteLocalDefinition', 'deletePV', 'distinctLatitudes', 'distinctLongitudes', 'earthIsOblate', 'editionNumber', 'endStep', 'eps', 'expand_grid', 'expand_reduced', 'experimentVersionNumber', 'fcstimeunits', 'gaussianGridName', 'generatingProcessIdentifier', 'getNumberOfValues', 'global', 'globalDomain', 'grib2LocalSectionNumber', 'gridDefinition', 'gridDefinitionDescription', 'gridDefinitionTemplateNumber', 'gridDescriptionSectionPresent', 'gridType', 'halfByte', 'has_key', 'hideThis', 'hour', 'iDirectionIncrement', 'iDirectionIncrementInDegrees', 'iScansNegatively', 'iScansPositively', 'ifsParam', 'ijDirectionIncrementGiven', 'indicatorOfParameter', 'indicatorOfTypeOfLevel', 'integerPointValues', 'interpretationOfNumberOfPoints', 'isConstant', 'isGridded', 'isOctahedral', 'isSpectral', 'is_missing', 'jPointsAreConsecutive', 'jScansNegatively', 'jScansPositively', 'julianDay', 'keys', 'kurtosis', 'latLonValues', 'latitudeOfFirstGridPoint', 'latitudeOfFirstGridPointInDegrees', 'latitudeOfLastGridPoint', 'latitudeOfLastGridPointInDegrees', 'latitudes', 'latlons', 'legacyGaussSubarea', 'lengthOfHeaders', 'level', 'localDefinitionNumber', 'localUsePresent', 'longitudeOfFirstGridPoint', 'longitudeOfFirstGridPointInDegrees', 'longitudeOfLastGridPoint', 'longitudeOfLastGridPointInDegrees', 'longitudes', 'marsClass', 'marsParam', 'marsStream', 'marsType', 'maximum', 'md5Headers', 'md5Product', 'md5Section1', 'md5Section2', 'md5Section4', 'messagenumber', 'minimum', 'minute', 'missingValue', 'month', 'name', 'nameECMF', 'neitherPresent', 'numberIncludedInAverage', 'numberMissingFromAveragesOrAccumulations', 'numberOfCodedValues', 'numberOfDataPoints', 'numberOfDataPointsExpected', 'numberOfForecastsInEnsemble', 'numberOfMissing', 'numberOfOctectsForNumberOfPoints', 'numberOfValues', 'numberOfVerticalCoordinateValues', 'offsetSection0', 'offsetValuesBy', 'optimizeScaleFactor', 'orderOfSPD', 'packingError', 'packingType', 'paramId', 'paramIdECMF', 'parameterName', 'parameterUnits', 'perturbationNumber', 'pl', 'pressureUnits', 'productionStatusOfProcessedData', 'projparams', 'pvlLocation', 'radius', 'referenceValue', 'referenceValueError', 'resolutionAndComponentFlags', 'resolutionAndComponentFlags3', 'resolutionAndComponentFlags4', 'resolutionAndComponentFlags6', 'resolutionAndComponentFlags7', 'resolutionAndComponentFlags8', 'scaleValuesBy', 'scanningMode', 'scanningMode4', 'scanningMode5', 'scanningMode6', 'scanningMode7', 'scanningMode8', 'second', 'section0Length', 'section1Length', 'section2Length', 'section4Length', 'section5Length', 'setLocalDefinition', 'setPackingType', 'shortName', 'shortNameECMF', 'skewness', 'sphericalHarmonics', 'standardDeviation', 'startStep', 'stepRange', 'stepType', 'stepTypeForConversion', 'stepUnits', 'subCentre', 'swapScanningAlternativeRows', 'table2Version', 'tableReference', 'timeRangeIndicator', 'tostring', 'totalLength', 'typeOfLevel', 'typeOfLevelECMF', 'unitOfTimeRange', 'units', 'unitsECMF', 'unpackedError', 'uvRelativeToGrid', 'validDate', 'validDate', 'valid_key', 'validityDate', 'validityDateTime', 'validityTime', 'values', 'wrongPadding', 'year', 'yearOfCentury']
]

GRIBEXSection1Problem
GRIBEditionNumber
N
Ni
Nj
P1
P2
PLPresent
PVPresent
WMO
__class__
__delattr__
__delitem__
__dir__
__doc__
__eq__
__format__
__ge__
__getattr__
__getattribute__
__getitem__
__gt__
__hash__
__init__
__init_subclass__
__le__
__lt__
__ne__
__new__
__reduce__
__reduce_ex__
__repr__
__setattr__
__setitem__
__setstate__
__sizeof__
__str__
__subclasshook__
_all_keys
_get_key
_read_only_keys
_reshape_mask
_ro_keys
_set_projparams
_unshape_mask
additionalFlagPresent
alternativeRowScanning
analDate
analDate
angleSubdivisions
average
binaryScaleFactor
bitMapIndicator
bitmapPresent
bitsPerValue
bitsPerValueAndRepack
boustrophedonic
centre
centreDescription
centuryOfReferenceTimeOfData
cfName
cfNameECMF
cfVarName
cfVarNameECMF
changeDecimalPrecision
complexPacking
constantFieldHalfByte
data
dataDate
dataFlag
dataLength
dataRepresentationType
dataTime
day
decimalPrecision
decimalScaleFactor
deleteLocalDefinition
deletePV
distinctLatitudes
distinctLongitudes
earthIsOblate
edition

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,

In [6]:
import pygrib

grib_file = "/home/user/large-disk/pangu_preds/test_perturbed_phase.grib" #Ctrl_nopert_20240326.grb"

grbs = pygrib.open(grib_file)

for grb in grbs:
    print(f"{grb.shortName} - {grb.level} @ Time: {grb.dataTime}")
    print(grb.values)

    data, latitudes, longitudes = grb.data()
    # print(data)

q - 50 @ Time: 0
[[3.06946822e-06 3.06854357e-06 3.06761892e-06 ... 3.07045048e-06
  3.07012306e-06 3.06979564e-06]
 [3.07066148e-06 3.06951370e-06 3.06836591e-06 ... 3.07117443e-06
  3.07100345e-06 3.07083246e-06]
 [3.06906077e-06 3.06870637e-06 3.06835197e-06 ... 3.07325627e-06
  3.07185777e-06 3.07045927e-06]
 ...
 [3.07645314e-06 3.07681603e-06 3.07717892e-06 ... 3.06873062e-06
  3.07130479e-06 3.07387897e-06]
 [3.03844354e-06 3.03858724e-06 3.03873094e-06 ... 3.04002060e-06
  3.03949491e-06 3.03896923e-06]
 [3.05218055e-06 3.05198500e-06 3.05178946e-06 ... 3.05260346e-06
  3.05246249e-06 3.05232152e-06]]
t - 50 @ Time: 0
[[223.77107239 223.77183533 223.77259827 ... 223.77305603 223.77239482
  223.7717336 ]
 [223.9322052  223.93293762 223.93367004 ... 223.95069885 223.9445343
  223.93836975]
 [224.02302551 224.02900696 224.0349884  ... 224.06489563 224.05093892
  224.03698222]
 ...
 [218.0806427  218.0732371  218.0658315  ... 218.06034851 218.06711324
  218.07387797]
 [217.62654114

In [33]:
grib_file = "/home/user/large-disk/aifs_preds/Code3_Interfaces_Python/Ctrl_nopert_20240326.grb"


perturbation_by_factor_list(
        grib_file,
        perturbation_dict = {
            "u": {300: 0.6}, 
            "v": {300: 0.6}}
        )

perturbing v - 300 @ Time: 1800
perturbing v - 300 @ Time: 0
these variables were not found in the grib file:
{('u1', 300)}


True

In [23]:
grib_file = "/home/user/large-disk/aifs_preds/Code3_Interfaces_Python/Ctrl_nopert_20240326.grb"

def perturbation_phase(
        grib_file,
        ):
    
    if not validate_grib_file(grib_file):
        return False

    path, file = os.path.split(grib_file)
    filename, extension = os.path.splitext(file)

    input_grib_file = grib_file
    output_grib_file = os.path.join(path, f"{filename}_phase_perturbed{extension}")

    # Open the GRIB file
    grbs = pygrib.open(input_grib_file)

    with open(output_grib_file, 'wb') as out_file:
        for grb in grbs:
            # Check the current dataTime
            if grb.dataTime == 0:
                # Change dataTime from 0 to 18
                grb.dataTime = 1800
            elif grb.dataTime == 1800:
                # Change dataTime from 18 to 0
                grb.dataTime = 0
            
            # Write the modified message to the new GRIB file
            out_file.write(grb.tostring())

    grbs.close()

    return True


In [34]:
LAT_MIN_LIM, LAT_MAX_LIM = -90, 90
LON_MIN_LIM, LON_MAX_LIM = -180, 180

def value_minus_3(value):
    return value - 3

def value_plus_3(value):
    return value + 3

def perturb_regionally(
        grib_file,
        variable,
        level,
        perturb_function,
        lat_min=LAT_MIN_LIM, 
        lat_max=LAT_MAX_LIM,
        lon_min=LON_MIN_LIM,
        lon_max=LON_MAX_LIM):
    
    if not validate_grib_file(grib_file):
        return False

    # arg check
    if (lat_min == LAT_MIN_LIM and \
        lat_max == LAT_MAX_LIM and \
        lon_min == LON_MIN_LIM and \
        lon_max == LON_MAX_LIM):

        print(f"will perturb {variable} on all the coordinates.")
    
    if (lat_min < LAT_MIN_LIM and \
        lat_max > LAT_MAX_LIM and \
        lon_min < LON_MIN_LIM and \
        lon_max > LON_MAX_LIM):
        
        print(f"coordinates out of range limit lat",
              f"({LAT_MIN_LIM}, {LAT_MAX_LIM})",
              f"lon ({LON_MIN_LIM}, {LON_MAX_LIM})")
        return False


    path, file = os.path.split(grib_file)
    filename, extension = os.path.splitext(file)

    input_grib_file = grib_file
    output_grib_file = os.path.join(path, f"{filename}_regional_{variable}_{perturb_function}_perturbed{extension}")
    with open(os.path.join(path, f"{filename}_regional_{variable}_{perturb_function}_perturbed.perturbation_config"), "w") as f:
        f.write(json.dumps({
            "lat_min": lat_min, 
            "lat_max": lat_max,
            "lon_min": lon_min,
            "lon_max": lon_max
        }))

    grbs = pygrib.open(input_grib_file)

    with open(output_grib_file, 'wb') as out_file:
        found_variable = False
        for (_i, grb) in enumerate(grbs):
            # Get the data, latitudes, and longitudes
            data, lats, lons = grb.data()

            if grb.shortName == variable and grb.level == level:
                found_variable = True
                unique_lons = np.unique(lons.flatten())
                unique_lats = np.unique(lats.flatten())

                lat_min, lat_max = (np.abs(np.unique(unique_lats) - lat_min)).min() + lat_min, (np.abs(np.unique(unique_lats) - lat_max)).min() + lat_max
                lon_min, lon_max = (np.abs(np.unique(unique_lons) - lon_min)).min() + lon_min, (np.abs(np.unique(unique_lons) - lon_max)).min() + lon_max
                                                
                # Create a mask based on the latitude and longitude range
                mask = (lats >= lat_min) & (lats <= lat_max) & (lons >= lon_min) & (lons <= lon_max)

                # Modify the data where the mask is True
                data[mask] = eval(f"{perturb_function}(data[mask])")

                # Flatten the data to 1D if required by the grid type
                grb.values = data.flatten()

            # Write the modified message to the new GRIB file
            out_file.write(grb.tostring())

    grbs.close()

    if not found_variable:
        print(f"{variable} column does not exist in the grib file.")
        return False
    else:
        return True



In [2]:
int(False)

0

In [35]:
perturb_regionally(
        grib_file,
        variable="u",
        perturb_function="value_plus_3",
        lat_min=LAT_MIN_LIM, 
        lat_max=LAT_MAX_LIM,
        lon_min=LON_MIN_LIM,
        lon_max=LON_MAX_LIM)

will perturb u on all the coordinates.


True

In [None]:
def perturb_specific_location(
        grib_file,
        variable,
        lat,
        lon,
        perturb_function):

    if not validate_grib_file(grib_file):
        return False

    if (lat < LAT_MIN_LIM and \
        lat > LAT_MAX_LIM and \
        lon < LON_MIN_LIM and \
        lon > LON_MAX_LIM):
        
        print(f"coordinates out of range limit lat",
              f"({LAT_MIN_LIM}, {LAT_MAX_LIM})",
              f"lon ({LON_MIN_LIM}, {LON_MAX_LIM})")
        return False
    
    path, file = os.path.split(grib_file)
    filename, extension = os.path.splitext(file)

    input_grib_file = grib_file
    output_grib_file = os.path.join(path, f"{filename}_coord_point__{variable}_{perturb_function}_perturbed{extension}")
    with open(os.path.join(path, f"{filename}_coord_point_{variable}_{perturb_function}_perturbed.perturbation_config"), "w") as f:
        f.write(json.dumps({ 
            "lat": lat,
            "lon": lon
        }))

    grbs = pygrib.open(input_grib_file)

    with open(output_grib_file, 'wb') as out_file:
        found_variable = False
        # Loop through all messages in the GRIB file
        for (_i, grb) in enumerate(grbs):
            # Get the data, latitudes, and longitudes
            data, lats, lons = grb.data()

            if grb.shortName == variable:
                found_variable = True
                unique_lons = np.unique(lons.flatten())
                unique_lats = np.unique(lats.flatten())

                lat = np.abs(np.unique(unique_lats) - lat) + lat
                lon = np.abs(np.unique(unique_lons) - lon) + lon
                                                
                # Create a mask based on the latitude and longitude range
                mask = (lats == lat) & (lons == lon)

                # Modify the data where the mask is True
                data[mask] = eval(f"{perturb_function}(data[mask])")

                # Flatten the data to 1D if required by the grid type
                grb.values = data.flatten()

            # Write the modified message to the new GRIB file
            out_file.write(grb.tostring())

    grbs.close()
    
    if not found_variable:
        print(f"{variable} column does not exist in the grib file.")
        return False
    else:
        return True