In [1]:
import matplotlib.pyplot as plt
import numpy as np
from path import Path
import pickle
from scipy.signal import find_peaks
# Global figure styling: serif text with Times New Roman preferred, 16 pt tick labels.
plt.rcParams["font.family"] = "serif"
# Put Times New Roman first in the serif fallback list. Filtering it out of the existing
# list keeps this cell idempotent: re-running it no longer prepends another duplicate.
plt.rcParams['font.serif'] = ['Times New Roman'] + [
    font_name for font_name in plt.rcParams['font.serif'] if font_name != 'Times New Roman'
]
plt.rc('xtick', labelsize=16)    # fontsize of the tick labels
plt.rc('ytick', labelsize=16)
graph_dpi = 100  # DPI of the feature bar-chart figures

In [7]:
# Prerequisite: run test.py first; it writes the results pickles read below from results/.
# Experiment grid: compared architectures, training datasets, isolated features, input modalities.
architectures = "fc cnn vit rvit rcnn".split()
dataset = dict(dvrk="dVRK", dafoes="DaFoEs", mixed="Mixed")  # dataset key -> display name
datasets = list(dataset)  # the dataset keys, in the same order as the mapping above
features = "random stiffness structure".split()
data = ["rgb"]
save_dir = Path("results")

In [None]:
# Load the per-architecture results pickles written by test.py. For every training dataset,
# draw one bar panel per (modality, isolated feature). Each panel compares each model's mean
# RMSE on the DaFoEs and dVRK test sets; the error bar is the mean of the stored RMSE std.
# Every pickle is kept in `results` for the cells below, under two keys:
#   "<dataset>_<modality>_<arch>_state_<feature>"  (looked up by the plotting cells)
#   "<dataset>_<arch>_<feature>"                   (short alias looked up by the summary cell)
results = {}
graphs_dir = Path('graphs')
graphs_dir.makedirs_p()  # savefig does not create missing directories
n_panels = len(data) * len(features)  # one subplot per (modality, feature) pair
for s in datasets:
    fig = plt.figure(figsize = (1920 / graph_dpi, 1080 / graph_dpi), dpi = graph_dpi)
    fig.suptitle(f"Results for training over {dataset[s]}", fontsize=20, fontweight="bold")
    idx = 1
    for d in data:
        for f in features:
            dafoes_mean_results, dafoes_std_results = [], []
            dvrk_mean_results, dvrk_std_results = [], []
            labels = []
            ax = fig.add_subplot(1, n_panels, idx)
            for arch in architectures:
                file_s = save_dir/"{}/{}/{}_state_{}.pkl".format(s, d, arch, f)
                # NOTE(review): pickle.load can execute arbitrary code; only load files produced by test.py.
                with open(file_s, 'rb') as fileobject:
                    results_s = pickle.load(fileobject)
                results["{}_{}_{}_state_{}".format(s, d, arch, f)] = results_s
                results[f"{s}_{arch}_{f}"] = results_s
                try:
                    # Read all four metrics before appending anything, so a missing key
                    # cannot leave `labels` and the value lists with different lengths.
                    dafoes_mean = results_s['dafoes_rmse_mean'].mean()
                    dafoes_std = results_s['dafoes_rmse_std'].mean()
                    dvrk_mean = results_s["dvrk_rmse_mean"].mean()
                    dvrk_std = results_s["dvrk_rmse_std"].mean()
                except KeyError as missing_key:
                    print(f"Skipping {arch.upper()} in {file_s}: missing key {missing_key}")
                    continue
                labels.append(arch.upper())
                dafoes_mean_results.append(dafoes_mean)
                dafoes_std_results.append(dafoes_std)
                dvrk_mean_results.append(dvrk_mean)
                dvrk_std_results.append(dvrk_std)

            # DaFoEs bars at even x positions, the matching dVRK bar immediately to the right.
            x_pos_0 = np.arange(0, 2 * len(labels), 2)
            x_pos_1 = x_pos_0 + 1

            ax.bar(x_pos_0, dafoes_mean_results, yerr=dafoes_std_results, color="purple", alpha=0.6, ecolor="black", error_kw=dict(lw=5, capsize=10, capthick=2))
            ax.bar(x_pos_1, dvrk_mean_results, yerr=dvrk_std_results, color="orange", alpha=0.6, ecolor='black', error_kw=dict(lw=5, capsize=10, capthick=2))
            ax.legend(["DaFoEs", "dVRK"], fontsize=16)

            if idx == 1:
                ax.set_ylabel("RMSE (N)", fontsize = 16, fontweight="bold")

            # Centre each architecture label between its two bars.
            ax.set_xticks(x_pos_0 + 0.5)
            ax.set_xticklabels(labels)
            ax.set_title(f"Errors for isolated feature {f.capitalize()}", fontsize=16, fontweight="bold")

            idx += 1

    # Save once per dataset, after every (modality, feature) panel has been drawn.
    fig.tight_layout()
    fig.savefig(graphs_dir/f"features_{dataset[s]}.png", dpi=graph_dpi)
    plt.show()
    plt.close(fig)
    

In [None]:
# For the models trained on the mixed dataset, report the dVRK RMSE std and mean
# (printed in that order) of each architecture under each isolated feature.
for f in features:
    for arch in architectures:
        print(f"Values for occluded param: {f} and {arch}")
        mixed_run = results[f"mixed_{arch}_{f}"]
        dvrk_std_avg = mixed_run["dvrk_rmse_std"].mean()
        dvrk_mean_avg = mixed_run["dvrk_rmse_mean"].mean()
        print(dvrk_std_avg, dvrk_mean_avg)


In [None]:
# Cell to plot the force estimation for the unseen testing set. Each figure shows the dVRK
# ground truth and every model's prediction for X, Y and Z (one panel each), zoomed into
# hand-picked frame windows. One figure per (dataset, feature) is saved to
# figures/<dataset>/zoomed_<feature>.png.
my_dpi = 500
labels = ['X', 'Y', 'Z']
fps = 30  # frames per second: window bounds are written as fps * seconds, time axis is frame / fps
# The recurrent models (R-CNN, R-VIT) are drawn with their prediction index shifted back by
# this many frames (not applied on DaFoEs), reproducing the alignment used before.
# NOTE(review): presumably this compensates for the recurrent input window — confirm.
rnn_lag = 4

# General view of the predictions
# ranges = {"dvrk": [[30*10, int(30*16)], [30*10, int(30*16)], [int(30*84), int(30*90)]],
#           "dafoes": [[30*10, int(30*16)], [30*10, int(30*16)], [30*10, int(30*16)]],
#           "mixed": [[30*10, int(30*16)], [30*10, int(30*16)], [30*84, int(30*90)]]}

# Zoomed areas of interest: one [start_frame, stop_frame) window per force axis.
ranges = {"dvrk": [[int(fps*10.5), int(fps*11.5)], [fps*12, int(fps*13)], [int(fps*85.5), int(fps*86.5)]],
          "dafoes": [[fps*10, int(fps*16)], [fps*10, int(fps*16)], [fps*10, int(fps*16)]],
          "mixed": [[int(fps*10.5), int(fps*11.5)], [fps*12, int(fps*13)], [int(fps*85.5), int(fps*86.5)]]}

# (architecture key, line style, is recurrent), drawn in legend order after the ground truth.
model_styles = [("cnn", 'g-.', False), ("vit", 'y-.', False), ("rcnn", 'r-.', True),
                ("rvit", 'k-.', True), ("fc", 'm-.', False)]
legend_names = ["GT", "CNN", "VIT", "R-CNN", "R-VIT", "FC"]

save_fig_root = Path('figures')
for s in datasets:
    range_s = ranges[s]
    lag = 0 if s == "dafoes" else rnn_lag
    save_fig_path = save_fig_root/"{}".format(s)
    save_fig_path.makedirs_p()
    for f in features:
        for d in data:
            fig, axs = plt.subplots(ncols=1, nrows=3, sharex=False, sharey=False, figsize=(5000/my_dpi, 5000/my_dpi), dpi=my_dpi)
            fig.suptitle("{} force estimation in unknown {} environment".format(s, f), fontsize=12, fontweight='bold')
            for j, ax in enumerate(axs):
                start, stop = range_s[j]
                gt = results['{}_{}_cnn_state_{}'.format(s, d, f)]['dvrk_gt'][start:stop, j]
                # Frame i is plotted at i / fps seconds (exact spacing, no linspace drift).
                ax.plot((start + np.arange(len(gt))) / fps, gt, 'b', linewidth=6.0)
                for arch, style, recurrent in model_styles:
                    shift = lag if recurrent else 0
                    pred = results['{}_{}_{}_state_{}'.format(s, d, arch, f)]['dvrk_pred'][start - shift:stop - shift, j]
                    ax.plot((start + np.arange(len(pred))) / fps, pred, style, linewidth=3.0)
                ax.set_ylabel("Force {} (N)".format(labels[j]), fontsize=14, fontweight='bold')
                if j == len(axs) - 1:
                    # Bottom panel only: legend and the time label.
                    ax.legend(legend_names, fontsize=14, loc=1)
                    ax.set_xlabel("Time (s)", fontsize=14, fontweight='bold')

            fig.align_labels()
            fig.savefig(save_fig_path/"zoomed_{}.png".format(f), dpi=my_dpi)
            plt.show()
            plt.close(fig)  # 500-dpi figures are large; release them inside the loop

In [None]:
# Plot the per-frame test RMSE of every model over one frame window and save one figure
# per (dataset, feature) to figures/<dataset>/metrics_<feature>.png.
my_dpi = 100
fps = 30  # frames per second: time axis is frame / fps
ranges = {"dvrk": [[30*10, 30*13], [30*10, 30*15], [30*84, 30*87]],
          "dafoes": [[30*10, 30*13], [30*10, 30*15], [30*10, 30*15]],
          "mixed": [[30*10, 30*13], [30*10, 30*15], [30*84, 30*87]]}
# Which of the three windows above is plotted. The cell previously read `range_s[j]` with `j`
# leaked from the zoomed-force cell's loop (final value 2), which fails on a fresh kernel.
# TODO(review): confirm that window 2 is the intended one.
rmse_window = 2
# Recurrent models' RMSE sample k is drawn at frame k + rnn_lag, as before.
# NOTE(review): the zoomed-force cell does not apply this lag on DaFoEs while this one does; confirm which is right.
rnn_lag = 4
# (architecture key, line colour, frame shift), in legend order.
model_styles = [("cnn", 'g', 0), ("vit", 'y', 0), ("rcnn", 'r', rnn_lag),
                ("rvit", 'k', rnn_lag), ("fc", 'm', 0)]
save_fig_root = Path('figures')  # defined here so the cell does not depend on the previous one

for s in datasets:
    start, stop = ranges[s][rmse_window]
    save_fig_path = save_fig_root/"{}".format(s)
    save_fig_path.makedirs_p()
    for f in features:
        for d in data:
            fig, ax = plt.subplots(ncols=1, nrows=1, sharex=False, sharey=False, figsize=(1920/my_dpi, 1080/my_dpi), dpi=my_dpi)
            fig.suptitle("{} evolution of RMSE in unknown {} environment".format(s, f), fontsize=24, fontweight='bold')
            for arch, color, shift in model_styles:
                rmse = results['{}_{}_{}_state_{}'.format(s, d, arch, f)]['test_rmse'][start:stop]
                ax.plot((start + shift + np.arange(len(rmse))) / fps, rmse, color, linewidth=3.0)
            ax.legend(["CNN", "VIT", "R-CNN", "R-VIT", "FC"], fontsize=14, loc=1)
            ax.set_ylabel("RMSE (N)", fontsize=16, fontweight='bold')
            ax.set_xlabel("Time (s)", fontsize=16, fontweight='bold')

            # Save inside the modality loop so every figure is written, not only the last one.
            fig.align_labels()
            fig.savefig(save_fig_path/'metrics_{}.png'.format(f), dpi=my_dpi)
            plt.show()
            plt.close(fig)

In [None]:
# For every model, report the DaFoEs RMSE at the ground-truth force peaks. Peaks are found
# separately on the X, Y and Z ground-truth traces (taken from the CNN results). The RMSE is
# averaged over each axis' peak frames, and the three per-axis averages are then averaged.
# (The stray empty `plt.figure()` and the unused `langs` list were removed.)
peak_distance = 150  # find_peaks `distance`: minimum number of frames between two peaks
n_axes = 3  # X, Y, Z columns of dafoes_gt
for s in datasets:
    for d in data:
        for f in features:
            gt = results["{}_{}_cnn_state_{}".format(s, d, f)]["dafoes_gt"]
            peaks_per_axis = [find_peaks(gt[:, axis], distance=peak_distance)[0] for axis in range(n_axes)]
            # NOTE(review): peak indices come from the CNN ground truth; the plotting cells offset the
            # recurrent models by a few frames, so confirm their dafoes_rmse is aligned with these indices.
            # An axis with no detected peaks yields NaN (mean of an empty selection).
            for arch in architectures:
                dafoes_rmse = results["{}_{}_{}_state_{}".format(s, d, arch, f)]["dafoes_rmse"]
                rmse = np.mean([dafoes_rmse[peaks].mean() for peaks in peaks_per_axis])

                print("The value for {}_{}_{}_{}: {} N".format(s, d, arch, f, rmse))
                
                
