In [1]:
from sklearn.metrics import multilabel_confusion_matrix, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import json

In [2]:
# The 35 keyword classes used for evaluation. The order matters: this list is
# handed to the reporting helpers as `labels`, which use it for the row/column
# order of the confusion matrix and for its tick labels.
labels = [
    # directions and movement
    "backward", "forward", "right", "left", "down", "up", "go", "stop",
    # switches and answers
    "on", "off", "yes", "no",
    # misc. commands
    "learn", "follow",
    # digits
    "zero", "one", "two", "three", "four",
    "five", "six", "seven", "eight", "nine",
    # remaining words
    "bed", "bird", "cat", "dog", "happy", "house",
    "read", "write", "tree", "visual", "wow",
]

In [3]:
def results(preds_path, labels, print_cm=False, cm_path="confusion_matrix.png"):
    """Print a classification report for a predictions file; optionally plot a confusion matrix.

    Parameters
    ----------
    preds_path : str
        Path to a JSON file holding a single object that maps a sample key to
        its predicted label. The ground-truth label is taken as the
        second-to-last '/'-separated component of each key (presumably the
        keys are file paths of the form ``<root>/<label>/<file>`` -- confirm
        against the script that writes the file).
    labels : sequence of str
        Class names, in the order used for the rows/columns of the confusion
        matrix. Only used when ``print_cm`` is true; the classification report
        is computed over whatever labels occur in the data.
    print_cm : bool, default False
        When true, draw a row-normalised (percent) confusion matrix, save it
        to ``cm_path`` and show it.
    cm_path : str, default "confusion_matrix.png"
        Where the confusion-matrix figure is written.

    Returns
    -------
    None
        The report is printed; the figure (if any) is saved and shown.

    Raises
    ------
    OSError
        If ``preds_path`` cannot be opened.
    json.JSONDecodeError
        If the file is not valid JSON.
    IndexError
        If a key contains no '/' (there is no second-to-last component).
    """
    # Predictions equal to one of these keys are rewritten to the mapped class
    # before scoring (NOTE(review): presumably the checkpoint's label map still
    # uses the older names for these two classes -- confirm).
    pred_aliases = {'sheila': 'read', 'marvin': 'write'}

    # `with` closes the handle even when json.load raises on malformed input.
    with open(preds_path) as f:
        data = json.load(f)

    y_true = []
    y_pred = []
    for sample_key, pred in data.items():
        y_true.append(sample_key.split('/')[-2])
        y_pred.append(pred_aliases.get(pred, pred))

    print(classification_report(y_true, y_pred, digits=4))

    if not print_cm:
        return

    # Confusion matrix in percent of each true class (row-normalised).
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    row_totals = cm.sum(axis=1, keepdims=True)
    # A class listed in `labels` but absent from y_true has an all-zero row;
    # keep that row at 0 instead of dividing by zero and plotting NaNs.
    cmn = np.divide(
        cm.astype('float'),
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    ) * 100
    cmn = np.round(cmn, 1)

    tick_names = list(labels)
    fig, ax = plt.subplots(figsize=(16, 16))
    # Let seaborn place one tick per class with the right name. Overwriting
    # the tick labels afterwards fails when seaborn thins out the automatic
    # ticks, because the label count no longer matches the tick count.
    sns.heatmap(cmn, annot=True, ax=ax, fmt=".1f", linewidths=.1,
                cmap='YlGn', cbar=False, square=True, linecolor='white',
                xticklabels=tick_names, yticklabels=tick_names)

    # Axis labels, tick styling and title.
    ax.set_xlabel('Predicted commands', fontsize=14)
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()
    ax.set_ylabel('Actual commands', fontsize=14)
    ax.tick_params(axis='x', labelsize=12, labelrotation=90)
    ax.tick_params(axis='y', labelsize=12, labelrotation=0)
    ax.set_title('Confusion Matrix', fontsize=16)

    fig.savefig(cm_path)
    plt.show()

In [4]:
def results_2(preds_path, labels, print_cm=False, cm_path="confusion_matrix.png"):
    """Print a classification report for suffixed predictions; optionally plot a confusion matrix.

    Differs from ``results`` in two ways: a language tag derived from
    ``preds_path`` is printed first, and each prediction is cut at its first
    '_' before scoring (NOTE(review): presumably predictions look like
    ``<label>_<lang>`` -- confirm against the label map used for inference).

    Parameters
    ----------
    preds_path : str
        Path to a JSON file holding a single object that maps a sample key to
        its predicted label. The ground-truth label is taken as the
        second-to-last '/'-separated component of each key. The language tag
        is the second '_'-separated piece of the second '/'-separated path
        component (e.g. ``outputs/multi_en_2/preds.json`` -> ``en``).
    labels : sequence of str
        Class names, in the order used for the rows/columns of the confusion
        matrix. Only used when ``print_cm`` is true.
    print_cm : bool, default False
        When true, draw a row-normalised (percent) confusion matrix, save it
        to ``cm_path`` and show it.
    cm_path : str, default "confusion_matrix.png"
        Where the confusion-matrix figure is written.

    Returns
    -------
    None
        The tag and report are printed; the figure (if any) is saved and shown.

    Raises
    ------
    OSError
        If ``preds_path`` cannot be opened.
    json.JSONDecodeError
        If the file is not valid JSON.
    IndexError
        If a key contains no '/' (there is no second-to-last component).
    """
    # The tag is informational only, so a path that does not follow the
    # expected layout must not abort the evaluation.
    try:
        lang = preds_path.split('/')[1].split('_')[1]
    except IndexError:
        lang = 'unknown'
    print(lang)

    # `with` closes the handle even when json.load raises on malformed input.
    with open(preds_path) as f:
        data = json.load(f)

    y_true = []
    y_pred = []
    for sample_key, pred in data.items():
        y_true.append(sample_key.split('/')[-2])
        # keep only the part of the prediction before the first underscore
        y_pred.append(pred.split('_')[0])

    print(classification_report(y_true, y_pred, digits=4))

    if not print_cm:
        return

    # Confusion matrix in percent of each true class (row-normalised).
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    row_totals = cm.sum(axis=1, keepdims=True)
    # A class listed in `labels` but absent from y_true has an all-zero row;
    # keep that row at 0 instead of dividing by zero and plotting NaNs.
    cmn = np.divide(
        cm.astype('float'),
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    ) * 100
    cmn = np.round(cmn, 1)

    tick_names = list(labels)
    fig, ax = plt.subplots(figsize=(16, 16))
    # Let seaborn place one tick per class with the right name. Overwriting
    # the tick labels afterwards fails when seaborn thins out the automatic
    # ticks, because the label count no longer matches the tick count.
    sns.heatmap(cmn, annot=True, ax=ax, fmt=".1f", linewidths=.1,
                cmap='YlGn', cbar=False, square=True, linecolor='white',
                xticklabels=tick_names, yticklabels=tick_names)

    # Axis labels, tick styling and title.
    ax.set_xlabel('Predicted commands', fontsize=14)
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()
    ax.set_ylabel('Actual commands', fontsize=14)
    ax.tick_params(axis='x', labelsize=12, labelrotation=90)
    ax.tick_params(axis='y', labelsize=12, labelrotation=0)
    ax.set_title('Confusion Matrix', fontsize=16)

    fig.savefig(cm_path)
    plt.show()
    

In [5]:
# Shell escape: run inference.py with the mono-35-en config, checkpoint and
# label map on test_data_en/ (CPU, batch size 256). The log printed below
# reports the predictions being saved to outputs/mono_en/preds.json.
!python3 inference.py --conf checkpoints/mono-35-en/kwmlp_google.yaml \
                      --ckpt checkpoints/mono-35-en/best.pth \
                      --inp test_data_en/ \
                      --out outputs/mono_en/ \
                      --lmap checkpoints/mono-35-en/label_map.json \
                      --device cpu \
                      --batch_size 256 

100%|███████████████████████████████████████████| 43/43 [00:20<00:00,  2.11it/s]
Saved preds to outputs/mono_en/preds.json
[0m

In [6]:
# Score the predictions written to outputs/mono_en/ by the inference run above.
# print_cm keeps its default (False), so only the text report is produced.
results('outputs/mono_en/preds.json', labels)

              precision    recall  f1-score   support

    backward     0.9532    0.9879    0.9702       165
         bed     0.9849    0.9469    0.9655       207
        bird     0.9424    0.9730    0.9574       185
         cat     0.9585    0.9536    0.9561       194
         dog     0.9548    0.9591    0.9569       220
        down     0.9823    0.9581    0.9701       406
       eight     0.9949    0.9559    0.9750       408
        five     0.9777    0.9843    0.9810       445
      follow     0.9651    0.9651    0.9651       172
     forward     0.9032    0.9032    0.9032       155
        four     0.9521    0.9450    0.9486       400
          go     0.9658    0.9826    0.9741       402
       happy     0.9850    0.9704    0.9777       203
       house     0.9590    0.9791    0.9689       191
       learn     0.9375    0.9317    0.9346       161
        left     0.9808    0.9903    0.9855       412
        nine     0.9877    0.9853    0.9865       408
          no     0.9757    

In [7]:
# Shell escape: run inference.py with the multi-35 config, checkpoint and
# label map on the same test_data_en/ input (CPU, batch size 256). The log
# printed below reports the predictions being saved to
# outputs/multi_en/preds.json.
!python3 inference.py --conf checkpoints/multi-35/kwmlp_multi_35.yaml \
                      --ckpt checkpoints/multi-35/best.pth \
                      --inp test_data_en/ \
                      --out outputs/multi_en/ \
                      --lmap checkpoints/multi-35/label_map.json \
                      --device cpu \
                      --batch_size 256 

100%|███████████████████████████████████████████| 43/43 [00:22<00:00,  1.93it/s]
Saved preds to outputs/multi_en/preds.json
[0m

In [8]:
# Score the predictions written to outputs/multi_en/ by the inference run above.
# print_cm keeps its default (False), so only the text report is produced.
results('outputs/multi_en/preds.json', labels)

              precision    recall  f1-score   support

    backward     0.9704    0.9939    0.9820       165
         bed     0.9557    0.9372    0.9463       207
        bird     0.9577    0.9784    0.9679       185
         cat     0.9738    0.9588    0.9662       194
         dog     0.9552    0.9682    0.9616       220
        down     0.9724    0.9532    0.9627       406
       eight     0.9878    0.9902    0.9890       408
        five     0.9756    0.9865    0.9810       445
      follow     0.9630    0.9070    0.9341       172
     forward     0.9267    0.8968    0.9115       155
        four     0.9416    0.9675    0.9544       400
          go     0.9798    0.9652    0.9724       402
       happy     0.9851    0.9803    0.9827       203
       house     0.9947    0.9738    0.9841       191
       learn     0.9548    0.9193    0.9367       161
        left     0.9783    0.9854    0.9819       412
        nine     0.9828    0.9828    0.9828       408
          no     0.9709    

In [9]:
# Shell escape: run inference.py with the multi-140 config, checkpoint and
# label map on the same test_data_en/ input (CPU, batch size 256). The log
# printed below reports the predictions being saved to
# outputs/multi_en_2/preds.json.
!python3 inference.py --conf checkpoints/multi-140/kwmlp_multi_140.yaml \
                      --ckpt checkpoints/multi-140/best.pth \
                      --inp test_data_en/ \
                      --out outputs/multi_en_2/ \
                      --lmap checkpoints/multi-140/label_map.json \
                      --device cpu \
                      --batch_size 256 

100%|███████████████████████████████████████████| 43/43 [00:20<00:00,  2.11it/s]
Saved preds to outputs/multi_en_2/preds.json
[0m

In [10]:
# Score the predictions written to outputs/multi_en_2/ by the inference run
# above. results_2 is used here because it keeps only the part of each
# prediction before the first '_' (presumably the multi-140 label map tags
# each label with a language suffix -- confirm against label_map.json).
# It also prints the language tag parsed from the path ("en" below).
results_2('outputs/multi_en_2/preds.json', labels)

en
              precision    recall  f1-score   support

    backward     0.9760    0.9879    0.9819       165
         bed     0.9604    0.9372    0.9487       207
        bird     0.9519    0.9622    0.9570       185
         cat     0.9388    0.9485    0.9436       194
         dog     0.9815    0.9636    0.9725       220
        down     0.9725    0.9581    0.9653       406
       eight     0.9827    0.9730    0.9778       408
        five     0.9821    0.9888    0.9854       445
      follow     0.9425    0.9535    0.9480       172
     forward     0.9467    0.9161    0.9311       155
        four     0.9578    0.9650    0.9614       400
          go     0.9581    0.9677    0.9629       402
       happy     0.9850    0.9704    0.9777       203
       house     0.9439    0.9686    0.9561       191
       learn     0.9182    0.9068    0.9125       161
        left     0.9736    0.9854    0.9795       412
        nine     0.9828    0.9828    0.9828       408
          no     0.9709 