In [1]:
from torch import nn
from torchvision.transforms import v2
import torchvision
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
from tqdm.notebook import tqdm
import numpy as np
from torchsummary import summary
import time

In [2]:
!git clone https://github.com/simon20010923/DDAMFN.git
%cd DDAMFN

Cloning into 'DDAMFN'...
remote: Enumerating objects: 215, done.[K
remote: Counting objects: 100% (145/145), done.[K
remote: Compressing objects: 100% (98/98), done.[K
remote: Total 215 (delta 54), reused 110 (delta 40), pack-reused 70[K
Receiving objects: 100% (215/215), 62.00 MiB | 13.93 MiB/s, done.
Resolving deltas: 100% (79/79), done.
/content/DDAMFN


In [3]:
import torch
import torch.utils.data as data
from networks.DDAM import DDAMNet

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model = DDAMNet(num_class=7,num_head=2)
checkpoint = torch.load('/content/DDAMFN/DDAMFN++/checkpoints_ver2.0/rafdb_epoch20_acc0.9204_bacc0.8617.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

In [None]:
summary(model, (3, 112, 112))

In [6]:
data_transforms = v2.Compose([
                              v2.Resize((112, 112)),
                              v2.ToDtype(torch.float32, scale=True),
                              v2.Normalize(mean=[0.485, 0.456, 0.406],
                                           std=[0.229, 0.224, 0.225])
                             ])

In [7]:
sota_class_names = ['neutral', 'happy', 'sad', 'surprise', 'fear', 'disgust', 'anger']
used_class_names = ['anger', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
sl_to_ul = {0: 4, 1: 3, 2: 5, 3: 6, 4: 2, 5: 1, 6: 0}       # from sota labels to used labels
ul_to_sl = {4: 0, 3: 1, 5:2, 6: 3, 2: 4, 1: 5, 0: 6}        # from used labels to sota labels

In [8]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = y
        self.tf = transform

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        if self.tf:
            return self.tf(self.x[idx]), self.y[idx]
        return self.x[idx], self.y[idx]

In [10]:
raf_x = torch.load('/content/drive/MyDrive/Colab Notebooks/project/to_check/raf/test_raf_x.pt')
raf_y = torch.load('/content/drive/MyDrive/Colab Notebooks/project/to_check/raf/test_raf_y.pt')
test_dataset = Dataset(raf_x, raf_y, data_transforms)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size = 16, drop_last=True)

In [26]:
def change_type_predictions(predictions, mapper:dict):
    to_predictions = []

    if torch.is_tensor(predictions):
        to_predictions = to_predictions.numpy().tolist()

    for pred in predictions:
        to_predictions.append(mapper[pred])
    return to_predictions

In [27]:
def test_sota(model, test_dataloader, sl_to_ul, ul_to_sl, used_class_names, device):
    pred_y = []
    true_y = []
    pred_score = np.full((0, 7), -1)
    mean_latency = 0
    model.eval()
    softmax = nn.Softmax(dim=1)
    with torch.no_grad():
        for X, y in tqdm(test_dataloader):
            X = X.to(device)
            start = time.time()
            predictions, _, _ = model(X)
            end = time.time()
            mean_latency += end - start
            pred_y.append(torch.argmax(predictions, dim=-1).cpu().numpy().tolist())
            pred_score = np.vstack((pred_score, softmax(predictions).cpu().numpy()))
            true_y.append(y.numpy().tolist())

    pred_y = np.array(pred_y).ravel().tolist()
    true_y = np.array(true_y).ravel().tolist()

    pred_y = change_type_predictions(pred_y, sl_to_ul)

    changed_to_sota_true_y = change_type_predictions(true_y, ul_to_sl)

    conf_m = confusion_matrix(pred_y, true_y)
    class_rep = classification_report(true_y, pred_y, target_names=used_class_names)
    roc_auc = roc_auc_score(changed_to_sota_true_y, pred_score.tolist(), multi_class='ovo')

    num_elements = test_dataloader.batch_size * len(test_dataloader)

    return conf_m, class_rep, roc_auc, mean_latency / num_elements

In [28]:
conf_m, class_rep, roc_auc, latency = test_sota(model, test_dataloader, sl_to_ul, ul_to_sl, used_class_names, device)

  0%|          | 0/191 [00:00<?, ?it/s]

In [29]:
conf_m

array([[ 143,    9,    2,    3,    1,    1,    5],
       [   4,  114,    0,    3,   10,    8,    3],
       [   3,    2,   49,    3,    1,    4,    8],
       [   5,    9,    2, 1134,   10,   10,    2],
       [   4,   10,    2,   31,  608,   30,   11],
       [   0,   13,    9,    5,   30,  422,    1],
       [   3,    3,   10,    6,    8,    3,  299]])

In [30]:
print(class_rep)

              precision    recall  f1-score   support

       anger       0.87      0.88      0.88       162
     disgust       0.80      0.71      0.75       160
        fear       0.70      0.66      0.68        74
       happy       0.97      0.96      0.96      1185
     neutral       0.87      0.91      0.89       668
         sad       0.88      0.88      0.88       478
    surprise       0.90      0.91      0.90       329

    accuracy                           0.91      3056
   macro avg       0.86      0.85      0.85      3056
weighted avg       0.91      0.91      0.91      3056



In [31]:
roc_auc

0.984091266422086

In [32]:
latency

0.003123847445892414

In [33]:
ck_x = torch.load('/content/drive/MyDrive/Colab Notebooks/project/to_check/ck_tensor112_x.pt')
ck_y = torch.load('/content/drive/MyDrive/Colab Notebooks/project/to_check/ck_tensor112_y.pt')
test_dataset = Dataset(ck_x, ck_y, data_transforms)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size = 16, drop_last=True)

In [35]:
conf_m, class_rep, roc_auc, latency = test_sota(model, test_dataloader, sl_to_ul, ul_to_sl, used_class_names, device)

  0%|          | 0/56 [00:00<?, ?it/s]

In [36]:
conf_m

array([[  4,  11,   0,   0,   0,   0,   0],
       [ 20,  42,   1,   0,   4,   0,   0],
       [  1,   0,  13,   0,   0,   0,   0],
       [  1,   2,   0,  68,  29,   0,   1],
       [  1,   1,   0,   0, 513,   1,   0],
       [ 12,   3,  11,   1,  47,  27,   0],
       [  0,   0,   0,   0,   0,   0,  82]])

In [37]:
print(class_rep)

              precision    recall  f1-score   support

       anger       0.27      0.10      0.15        39
     disgust       0.63      0.71      0.67        59
        fear       0.93      0.52      0.67        25
       happy       0.67      0.99      0.80        69
     neutral       0.99      0.87      0.93       593
         sad       0.27      0.96      0.42        28
    surprise       1.00      0.99      0.99        83

    accuracy                           0.84       896
   macro avg       0.68      0.73      0.66       896
weighted avg       0.89      0.84      0.85       896



In [38]:
roc_auc

0.9655785655736328

In [39]:
latency

0.004232391981141908

In [40]:
micro_x = torch.load('/content/drive/MyDrive/Colab Notebooks/project/to_check/micro_test_x.pt')
micro_y = torch.load('/content/drive/MyDrive/Colab Notebooks/project/to_check/micro_test_y.pt')
test_dataset = Dataset(micro_x, micro_y, data_transforms)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size = 16, drop_last=True)

In [41]:
conf_m, class_rep, roc_auc, latency = test_sota(model, test_dataloader, sl_to_ul, ul_to_sl, used_class_names, device)

  0%|          | 0/91 [00:00<?, ?it/s]

In [42]:
conf_m

array([[134,   9,   3,   0,   0,   1,   2],
       [ 38,  64,   1,   0,   1,   7,   0],
       [ 16,   5,  29,   1,   0,   2,   6],
       [ 14,  20,   5, 389,  29,   7,  26],
       [ 29,   1,   1,   0, 106,  29,   1],
       [ 22,  30,   5,   1,   5, 180,   1],
       [ 10,   5,  34,   3,   3,   2, 179]])

In [43]:
print(class_rep)

              precision    recall  f1-score   support

       anger       0.90      0.51      0.65       263
     disgust       0.58      0.48      0.52       134
        fear       0.49      0.37      0.42        78
       happy       0.79      0.99      0.88       394
     neutral       0.63      0.74      0.68       144
         sad       0.74      0.79      0.76       228
    surprise       0.76      0.83      0.79       215

    accuracy                           0.74      1456
   macro avg       0.70      0.67      0.67      1456
weighted avg       0.75      0.74      0.73      1456



In [44]:
roc_auc

0.9296239904464431

In [45]:
latency

0.004018594930460165