# Functions for Evaluation and XAI tasks

In [2]:
import numpy as np

In [11]:
def get_file_paths_rec(case_study: str, model: str,
                       base_dir: str = "/glade/campaign/cisl/aiml/ptype/ptype_case_studies") -> list:
    """Recursively collect the full path of every file under ``{base_dir}/{case_study}/{model}``.

    Args:
        case_study: Name of the case-study sub-directory.
        model: Name of the model sub-directory inside ``case_study``.
        base_dir: Root directory holding all case studies. Defaults to the
            GLADE path that used to be hard-coded, so existing calls behave
            the same.

    Returns:
        Full paths of all files found by ``os.walk``, in walk order. If the
        directory does not exist, ``os.walk`` yields nothing and the list is
        empty.
    """
    # Local import: this cell otherwise relies on `os` being imported in
    # some other notebook cell.
    import os

    search_root = f"{base_dir}/{case_study}/{model}"
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(search_root)
        for filename in filenames
    ]

In [16]:
def plot_CM(truth: np.ndarray, pred: np.ndarray, main_title: str, norm_val: str = None,
            font_size=9, class_names: list = None) -> None:
    """Plot one confusion matrix as an annotated seaborn heatmap.

    The matrix is ``confusion_matrix(truth, pred, normalize=norm_val)``, so each
    call draws exactly one variant, named in the title:

        * ``norm_val=None``   -> "Unnormalized" (raw counts)
        * ``norm_val='true'`` -> "Normalized by Truth"
        * ``norm_val='pred'`` -> "Normalized by Prediction"
        * ``norm_val='all'``  -> "Normalized over All"

    Args:
        truth: True class labels.
        pred: Predicted class labels.
        main_title: Title prefix; the normalization name is appended.
        norm_val: ``normalize`` argument forwarded to ``confusion_matrix``.
        font_size: Accepted for backward compatibility; not currently used.
        class_names: Tick labels for both axes. If omitted, a notebook-level
            ``class_names`` global is used when present, otherwise seaborn's
            automatic tick labels.
    """
    cm = confusion_matrix(truth, pred, normalize=norm_val)

    if class_names is None:
        # The original implementation read the global directly; keep honouring
        # it so existing calls still get labelled axes.
        class_names = globals().get("class_names", "auto")

    variant_titles = {
        None: "Unnormalized",
        "true": "Normalized by Truth",
        "pred": "Normalized by Prediction",
        "all": "Normalized over All",
    }
    # Raw counts are integers; every normalization produces floats, which the
    # integer format code 'd' cannot render (seaborn raises ValueError).
    fmt = "d" if norm_val is None else ".2f"

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, cmap='Blues', cbar=True,
                xticklabels=class_names, yticklabels=class_names,
                linewidths=0, linecolor='black', fmt=fmt, annot_kws={"size": 15})
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(f"{main_title} {variant_titles.get(norm_val, norm_val)}")
    plt.show()

In [13]:
def _cm_log_norm(cm: np.ndarray):
    """Return a ``colors.LogNorm`` spanning the positive entries of ``cm``, or None.

    A log colour scale cannot start at 0: with ``vmin=0`` matplotlib raises
    ``ValueError("Invalid vmin or vmax")`` when drawing, and confusion matrices
    routinely contain empty cells. Limiting vmin/vmax to the positive entries
    keeps the scale valid; LogNorm masks non-positive cells instead of
    colour-mapping them. If no entry is positive there is no valid log range,
    so None is returned and the caller should use a linear scale.
    """
    positive = cm[cm > 0]
    if positive.size == 0:
        return None
    return colors.LogNorm(vmin=positive.min(), vmax=positive.max())


def plot_3CM(truth: np.ndarray, pred: np.ndarray, main_title: str, class_names: list,
             save_location=None) -> None:
    """Plot three confusion matrices of the same predictions side by side.

    Panels, left to right:
        1. ``normalize='pred'`` ("Normalized by Prediction"), linear colour scale;
        2. ``normalize='true'`` ("Normalized by Truth"), log colour scale;
        3. raw counts ("Unnormalized"), log colour scale.

    Args:
        truth: True class labels.
        pred: Predicted class labels.
        main_title: Figure-level title.
        class_names: Tick labels used on both axes of every panel.
        save_location: If truthy, the figure is saved there at 300 dpi before
            being shown.
    """
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    # (normalize argument, panel title, annotation format, use log colour scale)
    panels = (
        ("pred", "Normalized by Prediction", ".2f", False),
        ("true", "Normalized by Truth", ".2f", True),
        (None, "Unnormalized", "d", True),
    )

    for idx, (ax, (normalize, panel_title, fmt, log_scale)) in enumerate(zip(axs, panels)):
        cm = confusion_matrix(truth, pred, normalize=normalize)

        heatmap_kws = {}
        if log_scale:
            norm = _cm_log_norm(cm)
            if norm is not None:
                heatmap_kws["norm"] = norm

        sns.heatmap(cm, annot=True, cmap='Blues', cbar=True,
                    xticklabels=class_names, yticklabels=class_names,
                    linewidths=1, linecolor='black', fmt=fmt, ax=ax,
                    annot_kws={"size": 14}, **heatmap_kws)
        ax.set_title(panel_title, fontweight="bold")
        ax.set_xlabel('Predicted Label', fontweight="bold")
        if idx == 0:
            # Only the leftmost panel carries the y-axis label.
            ax.set_ylabel('True Label', fontweight="bold")
        ax.set_xticklabels(class_names, fontsize=12, fontweight='bold')
        ax.set_yticklabels(class_names, fontsize=12, fontweight='bold')

    fig.suptitle(main_title, fontsize=15, fontweight="bold")
    plt.tight_layout()

    if save_location:
        plt.savefig(save_location, dpi=300, bbox_inches="tight")

    plt.show()

In [14]:
def regional_CONUS(latN_: int, latS_: int, lonW_: int, lonE_:int, title: str, case_dates: list)-> None:
    """Draw side-by-side maps of observed (mPING) and predicted precipitation type.

    Rows of the global ``test_dataset`` whose ``datetime`` lies in
    ``[case_dates[0], case_dates[1]]`` are scattered on two Lambert Conformal
    maps centred on the requested box. The left panel is coloured by
    ``true_label`` and the right by ``pred_label``. A shared legend maps codes
    0-3 to Rain / Snow / Ice Pellets / Freezing Rain.

    Args:
        latN_, latS_: Northern / southern latitude bounds of the map extent.
        lonW_, lonE_: Western / eastern longitude bounds of the map extent.
        title: Text appended to both panel titles.
        case_dates: ``[start, end]``; compared against ``test_dataset['datetime']``
            and concatenated into the panel titles, so they must be strings.
    """
    start, end = case_dates[0], case_dates[1]
    window = (test_dataset['datetime'] >= start) & (test_dataset['datetime'] <= end)
    case_study = test_dataset[window]

    # Class code -> marker colour and legend name.
    palette = {0: "springgreen", 1: "skyblue", 2: "red", 3: "black"}
    class_labels = ["Rain", "Snow", "Ice Pellets", "Freezing Rain"]

    subtitle = start + " to " + end + ", " + title

    center_lat = (latN_ + latS_) / 2
    center_lon = (lonW_ + lonE_) / 2

    fig = plt.figure(figsize=(40, 24))
    lambert = ccrs.LambertConformal(central_longitude=center_lon, central_latitude=center_lat)
    geographic = ccrs.PlateCarree()

    def draw_panel(position, heading, label_column):
        # One map: base layers and title, then one translucent scatter per class.
        ax = plt.subplot(1, 2, position, projection=lambert)
        ax.set_extent([lonW_, lonE_, latS_, latN_], crs=geographic)
        ax.set_facecolor(cfeature.COLORS['water'])
        ax.add_feature(cfeature.LAND)
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.BORDERS, linestyle='--')
        ax.add_feature(cfeature.LAKES, alpha=0.5)
        ax.add_feature(cfeature.STATES)
        ax.set_title(heading + subtitle, fontsize=35, fontweight='bold')

        labels = case_study[label_column]
        for code in range(4):
            selected = labels == code
            ax.scatter(
                # Longitudes are shifted by -360 before plotting (presumably
                # stored as 0-360 degrees east — verify against the dataset).
                case_study["lon"][selected] - 360,
                case_study["lat"][selected],
                c=labels[selected].map(palette),
                s=250,
                transform=geographic,
                alpha=0.2,
            )

    draw_panel(1, 'mPING ', 'true_label')
    draw_panel(2, 'Predictions ', 'pred_label')

    legend_handles = [
        mpatches.Patch(facecolor=palette[code], label=name)
        for code, name in enumerate(class_labels)
    ]
    fig.legend(handles=legend_handles, bbox_to_anchor=(0.99, 0.43), loc='upper right', prop={'size': 25})

    plt.tight_layout()
    plt.show()


In [17]:
def regional_CONUS_(latN_: int, latS_: int, lonW_: int, lonE_:int, title: str, pred, case_dates: list)-> None:
    """Map observed (mPING) precipitation type next to an external (e.g. RAP) prediction.

    Left panel: rows of the global ``test_dataset`` whose ``datetime`` lies in
    ``[case_dates[0], case_dates[1]]``, coloured by ``true_label``.
    Right panel: points of the global ``sourcedata``, coloured by the class codes
    in ``pred``. A shared legend maps codes 0-3 to Rain / Snow / Ice Pellets /
    Freezing Rain.

    Args:
        latN_, latS_: Northern / southern latitude bounds of the map extent.
        lonW_, lonE_: Western / eastern longitude bounds of the map extent.
        title: Text appended to both panel titles.
        pred: Class codes 0-3 for the right panel. ``pred == code`` is used as a
            boolean mask on ``sourcedata``, so it is presumably a Series or array
            aligned row-for-row with ``sourcedata`` (TODO confirm).
        case_dates: ``[start, end]`` strings, used for filtering and in the titles.

    NOTE(review): ``sourcedata`` is a notebook global that is not filtered by
    ``case_dates`` here; confirm it already covers only the case-study window.
    """
    start, end = case_dates[0], case_dates[1]
    window = (test_dataset['datetime'] >= start) & (test_dataset['datetime'] <= end)
    case_study = test_dataset[window]

    # Class code -> marker colour and legend name.
    palette = {0: "springgreen", 1: "skyblue", 2: "red", 3: "black"}
    class_labels = ["Rain", "Snow", "Ice Pellets", "Freezing Rain"]

    span = start + " to " + end + ", " + title

    center_lat = (latN_ + latS_) / 2
    center_lon = (lonW_ + lonE_) / 2

    fig = plt.figure(figsize=(40, 24))
    lambert = ccrs.LambertConformal(central_longitude=center_lon, central_latitude=center_lat)
    geographic = ccrs.PlateCarree()

    def base_map(position, heading):
        # Subplot with the shared extent, base layers and panel title.
        ax = plt.subplot(1, 2, position, projection=lambert)
        ax.set_extent([lonW_, lonE_, latS_, latN_], crs=geographic)
        ax.set_facecolor(cfeature.COLORS['water'])
        ax.add_feature(cfeature.LAND)
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.BORDERS, linestyle='--')
        ax.add_feature(cfeature.LAKES, alpha=0.5)
        ax.add_feature(cfeature.STATES)
        ax.set_title(heading + span, fontsize=35, fontweight='bold')
        return ax

    def scatter_classes(ax, lon, lat, codes):
        # All points in one call share a class, so a single colour per call is
        # enough. This replaces the broken `pred[[pred] == i]` lookup and also
        # works when `codes` is a NumPy array rather than a Series.
        for code in range(4):
            selected = codes == code
            ax.scatter(
                # Longitudes are shifted by -360 (presumably stored as 0-360 E).
                lon[selected] - 360,
                lat[selected],
                c=palette[code],
                s=250,
                transform=geographic,
                alpha=0.2,
            )

    # mPING observations
    ax = base_map(1, 'mPING ')
    scatter_classes(ax, case_study["lon"], case_study["lat"], case_study['true_label'])

    # RAP (or other external) predictions
    ax2 = base_map(2, 'Predictions ')
    scatter_classes(ax2, sourcedata["lon"], sourcedata["lat"], pred)

    legend_handles = [
        mpatches.Patch(facecolor=palette[code], label=name)
        for code, name in enumerate(class_labels)
    ]
    fig.legend(handles=legend_handles, bbox_to_anchor=(0.99, 0.43), loc='upper right', prop={'size': 25})

    plt.tight_layout()
    plt.show()
