# In [None]:
# Device class names; position in this list is the integer class label.
arr = [
    'Aria', 'D-LinkCam', 'D-LinkDayCam', 'D-LinkDoorSensor', 'D-LinkHomeHub',
    'D-LinkSensor', 'D-LinkSiren', 'D-LinkSwitch', 'D-LinkWaterSensor',
    'EdimaxCam1', 'EdimaxCam2', 'EdimaxPlug1101W', 'EdimaxPlug2101W',
    'EdnetCam1', 'EdnetCam2', 'EdnetGateway', 'HomeMaticPlug', 'HueBridge',
    'HueSwitch', 'Lightify', 'MAXGateway', 'SmarterCoffee', 'TP-LinkPlugHS100',
    'TP-LinkPlugHS110', 'WeMoInsightSwitch', 'WeMoInsightSwitch2', 'WeMoLink',
    'WeMoSwitch', 'WeMoSwitch2', 'Withings', 'iKettle2',
]
# Device name -> integer label. Derived from ``arr`` rather than written out
# by hand so the two mappings can never disagree.
dic = {name: index for index, name in enumerate(arr)}

from collections import Counter
from collections import defaultdict
import torch

def evaluate_metrics(predicted_labels, true_labels):
    """Compute per-label counts, precision, recall and F1 score.

    Args:
        predicted_labels: Iterable of predicted (hashable) class labels.
        true_labels: Iterable of ground-truth class labels, aligned
            element-wise with ``predicted_labels``.

    Returns:
        A ``defaultdict(dict)`` mapping every label that occurs in either
        input to a dict with keys ``actual_count``, ``predicted_count``,
        ``precision``, ``recall`` and ``f1_score`` (scores rounded to 4
        decimal places). Labels are ordered by first appearance in
        ``true_labels``, followed by labels that only occur in
        ``predicted_labels``. A score whose denominator is zero is 0.

    Raises:
        ValueError: If the two inputs have different lengths.
    """
    predicted_labels = list(predicted_labels)
    true_labels = list(true_labels)
    if len(predicted_labels) != len(true_labels):
        # zip() would silently truncate and the counts would no longer
        # agree with the confusion statistics.
        raise ValueError(
            f"length mismatch: {len(predicted_labels)} predicted labels vs "
            f"{len(true_labels)} true labels"
        )

    # One O(n) pass per statistic instead of list.count()/zip per label.
    actual_counts = Counter(true_labels)
    predicted_counts = Counter(predicted_labels)
    true_positive_counts = Counter(
        t_label
        for p_label, t_label in zip(predicted_labels, true_labels)
        if p_label == t_label
    )

    label_metrics = defaultdict(dict)
    # dict.fromkeys keeps first-seen order: true labels first, then labels
    # that were only ever predicted. Every label gets all five keys, so
    # callers can index any of them without a KeyError.
    for label in dict.fromkeys(true_labels + predicted_labels):
        true_positives = true_positive_counts[label]
        false_positives = predicted_counts[label] - true_positives
        false_negatives = actual_counts[label] - true_positives

        predicted_total = true_positives + false_positives
        precision = true_positives / predicted_total if predicted_total > 0 else 0

        actual_total = true_positives + false_negatives
        recall = true_positives / actual_total if actual_total > 0 else 0

        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        label_metrics[label].update(
            actual_count=actual_counts[label],
            predicted_count=predicted_counts[label],
            precision=round(precision, 4),
            recall=round(recall, 4),
            f1_score=round(f1, 4),
        )

    return label_metrics

# Load the saved predictions and ground-truth labels as plain Python lists.
predicted_labels = torch.load("best_pre_TrimSENet.pt").tolist()
true_labels = torch.load("best_val_TrimSENet.pt").tolist()
metrics = evaluate_metrics(predicted_labels, true_labels)

# Print a per-label report. The .get() defaults guard against labels that
# occur in only one of the two lists, whose entries may lack some keys
# (e.g. a class that was never predicted has no 'predicted_count').
for label, metric in metrics.items():
    device_names = [name for name, index in dic.items() if index == label]
    print(f"Label: {label}")
    print(f"Device: {device_names}")
    print(f"Actual Count: {metric.get('actual_count', 0)}")
    print(f"Predicted Count: {metric.get('predicted_count', 0)}")
    print(f"Precision: {metric.get('precision', 0.0):.4f}")
    print(f"Recall: {metric.get('recall', 0.0):.4f}")
    print(f"F1 Score: {metric.get('f1_score', 0.0):.4f}")
    print()

# import torch
# from sklearn.metrics import precision_score
# from sklearn.metrics import accuracy_score
# from sklearn.metrics import f1_score
# from sklearn.metrics import recall_score

# arr = [("best_val_TrimSENet.pt","best_pre_TrimSENet.pt"),("best_val_AlexNet.pt","best_pre_AlexNet.pt"),
#        ("best_val_GoogLeNet.pt","best_pre_GoogLeNet.pt"),("best_val_MobileNet.pt","best_pre_MobileNet.pt"),
#        ("best_val_SENet.pt","best_pre_SENet.pt"),("best_val_VggNet.pt","best_pre_VggNet.pt")]

# for i in arr:
#     y_true_add, y_pred_add = i
#     y_true = torch.load(y_true_add).tolist()
#     y_pred = torch.load(y_pred_add).tolist()
#     precision_macro = precision_score(y_true, y_pred, average='macro')
#     recall = recall_score(y_true, y_pred, average='macro')
#     accuracy = accuracy_score(y_true, y_pred)
#     f1 = f1_score(y_true, y_pred, average='macro')
    
#     print(y_true_add.split("_")[2].split(".")[0])
#     print(f"Macro Precision: {precision_macro:.4f}")
#     print(f"recall: {recall:.4f}")
#     print(f"Accuracy: {accuracy:.4f}")
#     print(f"f1: {f1:.4f}")