In [1]:
import numpy as np
import os

import multiprocessing as mp
import librosa
import csv

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataloader import Dataset
from utils import *
from efficientnet_pytorch import EfficientNet
from tqdm import tqdm, tqdm_notebook
from sklearn.metrics import accuracy_score
from DCM.DCM import DCMoptimizer
import pandas as pd 

In [2]:
seed_everything(42)

In [14]:
valid_dir='../audioData/val_process_mel/'
test1_dir = '../audioData/test1_process_mel/'
test2_dir = '../audioData/test2_process_mel/'
test_private_dir = '../audioData/test_private_process_mel/'
valid_list=[item for item in os.listdir(valid_dir) if 'npy' in item]
test1_list = [item for item in os.listdir(test1_dir) if 'npy' in item]
test2_list = [item for item in os.listdir(test2_dir) if 'npy' in item]
test_private_list = [item for item in os.listdir(test_private_dir) if 'npy' in item]
test1_list.sort()
test2_list.sort()
test_private_list.sort()

label2emo={'hap':0,'ang':1,'dis':2,'fea':3,'sad':4,'neu':5,'sur':6}
emo2label = ['hap','ang','dis','fea','sad','neu','sur']
class_dist = np.array([5295, 6571, 6785, 5048, 6843, 9259, 5380])
class_dist = class_dist/np.sum(class_dist)
batch_size = 128

weight_paths = ['./model/b4_smooth_0_mix_0.5_RANGER_COSINEANNEALING/39_best_1.4295.pth',
                './model/b4_smooth_0_mix_0.2_RANGER_COSINEANNEALING/33_best_1.4162.pth',
                './model/b0_smooth_0_mix_0.5_RANGER_COSINEANNEALING/62_best_1.4409.pth',
                './model/b0_smooth_0_mix_0.2_RANGER_COSINEANNEALING/36_best_1.4377.pth'
               ]

In [4]:
valid_dataset = Dataset(file_list=valid_list,root_dir=valid_dir,label_smooth_weight=0,is_train=False)
valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size,num_workers=15,shuffle=False)

In [5]:
test1_dataset = Dataset(file_list=test1_list,root_dir=test1_dir,label_smooth_weight=0,is_train=False)
test1_loader = torch.utils.data.DataLoader(test1_dataset,batch_size=batch_size,num_workers=15,shuffle=False)

In [6]:
test2_dataset = Dataset(file_list=test2_list,root_dir=test2_dir,label_smooth_weight=0,is_train=False)
test2_loader = torch.utils.data.DataLoader(test2_dataset,batch_size=batch_size,num_workers=15,shuffle=False)

In [24]:
test_private_list_by_actor = [[],[],[],[]]
for item in test_private_list:
    actor_num = int(item.split("-")[2])
    test_private_list_by_actor[actor_num-1].append(item)
    
test_private_dataset_1 = Dataset(file_list=test_private_list_by_actor[0],root_dir=test_private_dir,label_smooth_weight=0,is_train=False)
test_private_dataset_2 = Dataset(file_list=test_private_list_by_actor[1],root_dir=test_private_dir,label_smooth_weight=0,is_train=False)
test_private_dataset_3 = Dataset(file_list=test_private_list_by_actor[2],root_dir=test_private_dir,label_smooth_weight=0,is_train=False)
test_private_dataset_4 = Dataset(file_list=test_private_list_by_actor[3],root_dir=test_private_dir,label_smooth_weight=0,is_train=False)
test_private_loader_1 = torch.utils.data.DataLoader(test_private_dataset_1,batch_size=batch_size,num_workers=15,shuffle=False)
test_private_loader_2 = torch.utils.data.DataLoader(test_private_dataset_2,batch_size=batch_size,num_workers=15,shuffle=False)
test_private_loader_3 = torch.utils.data.DataLoader(test_private_dataset_3,batch_size=batch_size,num_workers=15,shuffle=False)
test_private_loader_4 = torch.utils.data.DataLoader(test_private_dataset_4,batch_size=batch_size,num_workers=15,shuffle=False)

In [7]:
def get_preds(models,ensemble_weights,loader):
    """
    
    ensemble_preds : ndarray N,7
    trues : ndarray N,1
    """
    
    assert len(models) == len(ensemble_weights), "number of models does not match with number of ensemble_weights"
    
    ensemble_weights = ensemble_weights/np.sum(ensemble_weights)
    ensemble_preds = []
    for model_idx,model in enumerate(models):
        model.eval()
        preds = []
        trues = []
        for idx,(x,y) in enumerate(tqdm(loader)):
            with torch.no_grad():
                x,y = x.cuda(),y.cuda()
                pred = model(x)
                pred = F.softmax(pred)
            pred=pred.cpu().detach().numpy()
            y = y.cpu().detach().numpy()
            preds.append(pred)
            trues.append(y)
        preds = np.concatenate(preds)
        trues = np.concatenate(trues)
        ensemble_preds.append(preds*ensemble_weights[model_idx])
    ensemble_preds = np.sum(ensemble_preds,axis=0)
    try:
        trues = np.argmax(trues,1)
    except:
        pass
    return ensemble_preds, trues

In [8]:
models = [None]*len(weight_paths)  
for idx,path in enumerate(weight_paths):
    model_name = path.split('/')[-2].split("_")[0]
    model = EfficientNet.from_pretrained(model_name='efficientnet-'+model_name)
    model._fc=nn.Linear(model._fc.in_features,7)
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(path))
    model.cuda()

    models[idx] = model

Loaded pretrained weights for efficientnet-b4
Loaded pretrained weights for efficientnet-b4
Loaded pretrained weights for efficientnet-b0
Loaded pretrained weights for efficientnet-b0


In [9]:
def search_weight(models,loader,N_iter = 20):
    """
    return best ensemble_weights and DCM optimizer
    """
    best_score = -np.inf
    best_weights = None
    best_optimizer = None
    for iter in range(N_iter):
        print("================")
        print("iteration :",iter)
        ensemble_weights = [np.random.uniform() for i in range(len(models))]
        ensemble_weights = ensemble_weights/np.sum(ensemble_weights)
        print("ensemble weight :",ensemble_weights)
        
        preds,trues = get_preds(models,ensemble_weights,valid_loader)
        
        search_space = [i/100 for i in range(30)]
        DCM = DCMoptimizer(n_classes = 7, weight = class_dist, predict = preds, true = trues,metric = accuracy_score)
        DCM.search(space = search_space,verbose=False)
        
        if best_score < DCM.best_score:
            best_score = DCM.best_score
            best_weights = ensemble_weights
            best_optimizer = DCM
        
        print("score :",DCM.best_score)
    
    print("best_weights :",best_weights)
    print("best_l :",best_optimizer.best_l)
    print("best_score :",best_score)
    
    return best_weights,best_optimizer

In [10]:
%%time
best_ensemble_weights, best_optimizer = search_weight(models,valid_loader,30)

  0%|          | 0/43 [00:00<?, ?it/s]

iteration : 0
ensemble weight : [0.14102156 0.35796222 0.27560979 0.22540643]


100%|██████████| 43/43 [00:12<00:00,  3.56it/s]
100%|██████████| 43/43 [00:08<00:00,  4.96it/s]
100%|██████████| 43/43 [00:05<00:00,  7.48it/s]
100%|██████████| 43/43 [00:05<00:00,  7.52it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.1 best score is 0.47829313543599256
score : 0.47829313543599256
iteration : 1
ensemble weight : [0.12620081 0.1261813  0.04698284 0.70063506]


100%|██████████| 43/43 [00:08<00:00,  4.95it/s]
100%|██████████| 43/43 [00:08<00:00,  4.93it/s]
100%|██████████| 43/43 [00:05<00:00,  7.55it/s]
100%|██████████| 43/43 [00:05<00:00,  7.54it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.08 best score is 0.4647495361781076
score : 0.4647495361781076
iteration : 2
ensemble weight : [0.2613905  0.30790022 0.00895102 0.42175826]


100%|██████████| 43/43 [00:08<00:00,  4.91it/s]
100%|██████████| 43/43 [00:08<00:00,  4.94it/s]
100%|██████████| 43/43 [00:05<00:00,  7.62it/s]
100%|██████████| 43/43 [00:05<00:00,  7.46it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.15 best score is 0.4779220779220779
score : 0.4779220779220779
iteration : 3
ensemble weight : [0.59038015 0.15059391 0.12895285 0.13007308]


100%|██████████| 43/43 [00:08<00:00,  5.03it/s]
100%|██████████| 43/43 [00:08<00:00,  4.95it/s]
100%|██████████| 43/43 [00:05<00:00,  7.63it/s]
100%|██████████| 43/43 [00:05<00:00,  7.44it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.01 best score is 0.4792207792207792
score : 0.4792207792207792
iteration : 4
ensemble weight : [0.19601054 0.33807861 0.2782841  0.18762675]


100%|██████████| 43/43 [00:08<00:00,  4.92it/s]
100%|██████████| 43/43 [00:08<00:00,  4.95it/s]
100%|██████████| 43/43 [00:05<00:00,  7.53it/s]
100%|██████████| 43/43 [00:05<00:00,  7.44it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.17 best score is 0.47940630797773653
score : 0.47940630797773653
iteration : 5
ensemble weight : [0.43398339 0.09894211 0.20721635 0.25985814]


100%|██████████| 43/43 [00:08<00:00,  4.95it/s]
100%|██████████| 43/43 [00:08<00:00,  4.93it/s]
100%|██████████| 43/43 [00:05<00:00,  7.50it/s]
100%|██████████| 43/43 [00:05<00:00,  7.46it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.02 best score is 0.48070500927643783
score : 0.48070500927643783
iteration : 6
ensemble weight : [0.23326548 0.40159286 0.10212687 0.26301478]


100%|██████████| 43/43 [00:08<00:00,  4.95it/s]
100%|██████████| 43/43 [00:08<00:00,  4.93it/s]
100%|██████████| 43/43 [00:05<00:00,  7.53it/s]
100%|██████████| 43/43 [00:05<00:00,  7.44it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.2 best score is 0.47940630797773653
score : 0.47940630797773653
iteration : 7
ensemble weight : [0.41809611 0.03278234 0.42877429 0.12034726]


100%|██████████| 43/43 [00:08<00:00,  4.90it/s]
100%|██████████| 43/43 [00:08<00:00,  4.94it/s]
100%|██████████| 43/43 [00:05<00:00,  7.54it/s]
100%|██████████| 43/43 [00:05<00:00,  7.42it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.0 best score is 0.4749536178107607
score : 0.4749536178107607
iteration : 8
ensemble weight : [0.02333299 0.34035041 0.34635711 0.28995949]


100%|██████████| 43/43 [00:08<00:00,  4.98it/s]
100%|██████████| 43/43 [00:08<00:00,  4.92it/s]
100%|██████████| 43/43 [00:05<00:00,  7.65it/s]
100%|██████████| 43/43 [00:05<00:00,  7.43it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.16 best score is 0.47569573283858996
score : 0.47569573283858996
iteration : 9
ensemble weight : [0.19952805 0.06397717 0.44818618 0.2883086 ]


100%|██████████| 43/43 [00:08<00:00,  4.93it/s]
100%|██████████| 43/43 [00:08<00:00,  4.96it/s]
100%|██████████| 43/43 [00:05<00:00,  7.58it/s]
100%|██████████| 43/43 [00:05<00:00,  7.50it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.02 best score is 0.476252319109462
score : 0.476252319109462
iteration : 10
ensemble weight : [0.07818333 0.31723318 0.02203087 0.58255262]


100%|██████████| 43/43 [00:08<00:00,  4.93it/s]
100%|██████████| 43/43 [00:08<00:00,  4.94it/s]
100%|██████████| 43/43 [00:05<00:00,  7.59it/s]
100%|██████████| 43/43 [00:05<00:00,  7.49it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.14 best score is 0.47217068645640076
score : 0.47217068645640076
iteration : 11
ensemble weight : [0.14761436 0.37791873 0.17780754 0.29665937]


100%|██████████| 43/43 [00:08<00:00,  4.93it/s]
100%|██████████| 43/43 [00:08<00:00,  4.87it/s]
100%|██████████| 43/43 [00:05<00:00,  7.56it/s]
100%|██████████| 43/43 [00:05<00:00,  7.43it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.14 best score is 0.48126159554730985
score : 0.48126159554730985
iteration : 12
ensemble weight : [0.22077867 0.07465    0.39154852 0.31302282]


100%|██████████| 43/43 [00:08<00:00,  4.96it/s]
100%|██████████| 43/43 [00:08<00:00,  4.90it/s]
100%|██████████| 43/43 [00:05<00:00,  7.55it/s]
100%|██████████| 43/43 [00:05<00:00,  7.39it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.03 best score is 0.4766233766233766
score : 0.4766233766233766
iteration : 13
ensemble weight : [0.28010459 0.26678609 0.17825941 0.27484991]


100%|██████████| 43/43 [00:08<00:00,  4.92it/s]
100%|██████████| 43/43 [00:08<00:00,  4.92it/s]
100%|██████████| 43/43 [00:05<00:00,  7.59it/s]
100%|██████████| 43/43 [00:05<00:00,  7.21it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.09 best score is 0.4803339517625232
score : 0.4803339517625232
iteration : 14
ensemble weight : [0.13509625 0.29919541 0.06904582 0.49666252]


100%|██████████| 43/43 [00:08<00:00,  4.99it/s]
100%|██████████| 43/43 [00:08<00:00,  4.94it/s]
100%|██████████| 43/43 [00:05<00:00,  7.60it/s]
100%|██████████| 43/43 [00:05<00:00,  7.48it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.07 best score is 0.4766233766233766
score : 0.4766233766233766
iteration : 15
ensemble weight : [0.21060616 0.14703143 0.44905435 0.19330805]


100%|██████████| 43/43 [00:08<00:00,  4.93it/s]
100%|██████████| 43/43 [00:08<00:00,  4.91it/s]
100%|██████████| 43/43 [00:05<00:00,  7.65it/s]
100%|██████████| 43/43 [00:05<00:00,  7.43it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.01 best score is 0.47680890538033394
score : 0.47680890538033394
iteration : 16
ensemble weight : [0.15901187 0.30717166 0.07976459 0.45405188]


100%|██████████| 43/43 [00:08<00:00,  5.01it/s]
100%|██████████| 43/43 [00:08<00:00,  4.98it/s]
100%|██████████| 43/43 [00:05<00:00,  7.57it/s]
100%|██████████| 43/43 [00:05<00:00,  7.54it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.1 best score is 0.4775510204081633
score : 0.4775510204081633
iteration : 17
ensemble weight : [0.03668112 0.48557759 0.37996729 0.097774  ]


100%|██████████| 43/43 [00:08<00:00,  5.00it/s]
100%|██████████| 43/43 [00:08<00:00,  4.97it/s]
100%|██████████| 43/43 [00:05<00:00,  7.55it/s]
100%|██████████| 43/43 [00:05<00:00,  7.37it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.13 best score is 0.474025974025974
score : 0.474025974025974
iteration : 18
ensemble weight : [0.00244683 0.36132757 0.31320555 0.32302005]


100%|██████████| 43/43 [00:08<00:00,  4.97it/s]
100%|██████████| 43/43 [00:08<00:00,  4.94it/s]
100%|██████████| 43/43 [00:05<00:00,  7.56it/s]
100%|██████████| 43/43 [00:05<00:00,  7.59it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.02 best score is 0.4764378478664193
score : 0.4764378478664193
iteration : 19
ensemble weight : [0.58445078 0.05610932 0.27163701 0.08780289]


100%|██████████| 43/43 [00:08<00:00,  5.01it/s]
100%|██████████| 43/43 [00:08<00:00,  4.99it/s]
100%|██████████| 43/43 [00:05<00:00,  7.67it/s]
100%|██████████| 43/43 [00:05<00:00,  7.49it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.07 best score is 0.47829313543599256
score : 0.47829313543599256
iteration : 20
ensemble weight : [0.45888816 0.33139033 0.1759293  0.03379221]


100%|██████████| 43/43 [00:08<00:00,  4.99it/s]
100%|██████████| 43/43 [00:08<00:00,  4.93it/s]
100%|██████████| 43/43 [00:05<00:00,  7.51it/s]
100%|██████████| 43/43 [00:05<00:00,  7.58it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.02 best score is 0.47977736549165123
score : 0.47977736549165123
iteration : 21
ensemble weight : [0.15523275 0.16232145 0.36419683 0.31824896]


100%|██████████| 43/43 [00:08<00:00,  5.00it/s]
100%|██████████| 43/43 [00:08<00:00,  4.98it/s]
100%|██████████| 43/43 [00:05<00:00,  7.72it/s]
100%|██████████| 43/43 [00:05<00:00,  7.41it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.13 best score is 0.4766233766233766
score : 0.4766233766233766
iteration : 22
ensemble weight : [0.4047011  0.21540031 0.05455278 0.32534581]


100%|██████████| 43/43 [00:08<00:00,  5.04it/s]
100%|██████████| 43/43 [00:08<00:00,  4.97it/s]
100%|██████████| 43/43 [00:05<00:00,  7.57it/s]
100%|██████████| 43/43 [00:05<00:00,  7.49it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.09 best score is 0.47829313543599256
score : 0.47829313543599256
iteration : 23
ensemble weight : [0.29409993 0.21697532 0.29803608 0.19088867]


100%|██████████| 43/43 [00:08<00:00,  5.00it/s]
100%|██████████| 43/43 [00:08<00:00,  4.93it/s]
100%|██████████| 43/43 [00:05<00:00,  7.66it/s]
100%|██████████| 43/43 [00:05<00:00,  7.50it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.09 best score is 0.4792207792207792
score : 0.4792207792207792
iteration : 24
ensemble weight : [0.48241081 0.39456181 0.02345837 0.09956901]


100%|██████████| 43/43 [00:08<00:00,  4.98it/s]
100%|██████████| 43/43 [00:08<00:00,  4.92it/s]
100%|██████████| 43/43 [00:05<00:00,  7.64it/s]
100%|██████████| 43/43 [00:05<00:00,  7.47it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.0 best score is 0.47717996289424863
score : 0.47717996289424863
iteration : 25
ensemble weight : [0.02108257 0.42690154 0.21086872 0.34114717]


100%|██████████| 43/43 [00:08<00:00,  4.92it/s]
100%|██████████| 43/43 [00:08<00:00,  4.94it/s]
100%|██████████| 43/43 [00:05<00:00,  7.48it/s]
100%|██████████| 43/43 [00:05<00:00,  7.55it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.19 best score is 0.4779220779220779
score : 0.4779220779220779
iteration : 26
ensemble weight : [0.3907221  0.10732435 0.17667651 0.32527703]


100%|██████████| 43/43 [00:08<00:00,  4.93it/s]
100%|██████████| 43/43 [00:08<00:00,  4.90it/s]
100%|██████████| 43/43 [00:05<00:00,  7.55it/s]
100%|██████████| 43/43 [00:05<00:00,  7.63it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.04 best score is 0.48014842300556587
score : 0.48014842300556587
iteration : 27
ensemble weight : [0.30234281 0.10172425 0.38288885 0.21304409]


100%|██████████| 43/43 [00:08<00:00,  4.98it/s]
100%|██████████| 43/43 [00:08<00:00,  4.96it/s]
100%|██████████| 43/43 [00:05<00:00,  7.60it/s]
100%|██████████| 43/43 [00:05<00:00,  7.45it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.11 best score is 0.47717996289424863
score : 0.47717996289424863
iteration : 28
ensemble weight : [0.28670636 0.24921355 0.19533327 0.26874682]


100%|██████████| 43/43 [00:08<00:00,  4.95it/s]
100%|██████████| 43/43 [00:08<00:00,  4.93it/s]
100%|██████████| 43/43 [00:05<00:00,  7.49it/s]
100%|██████████| 43/43 [00:05<00:00,  7.54it/s]
  0%|          | 0/43 [00:00<?, ?it/s]

search completed!
Final l = 0.09 best score is 0.48014842300556587
score : 0.48014842300556587
iteration : 29
ensemble weight : [0.33180202 0.07702684 0.36849966 0.22267148]


100%|██████████| 43/43 [00:08<00:00,  4.98it/s]
100%|██████████| 43/43 [00:08<00:00,  4.92it/s]
100%|██████████| 43/43 [00:05<00:00,  7.41it/s]
100%|██████████| 43/43 [00:05<00:00,  7.56it/s]


search completed!
Final l = 0.06 best score is 0.4777365491651206
score : 0.4777365491651206
best_weights : [0.14761436 0.37791873 0.17780754 0.29665937]
best_l : 0.14
best_score : 0.48126159554730985
CPU times: user 24min 8s, sys: 6min 50s, total: 30min 58s
Wall time: 30min 19s


In [11]:
model_setting = 'MIXUP MODEL 4 ENSEMBLE'

In [12]:
preds,_ = get_preds(models,best_ensemble_weights,test1_loader)
df = pd.read_csv("./qia_test1_baseline.csv")
labels = [emo2label[pred] for pred in best_optimizer.apply(preds)]
df.iloc[:,1] = labels
df.to_csv('./result_csv/'+model_setting+'-test1-w-dcm.csv',index=False)

preds,_ = get_preds(models,[1]*len(models),test1_loader)
df = pd.read_csv("./qia_test1_baseline.csv")
labels = [emo2label[pred] for pred in np.argmax(preds,1)]
df.iloc[:,1] = labels
df.to_csv('./result_csv/'+ model_setting+'-test1-wo-dcm.csv',index=False)

100%|██████████| 22/22 [01:38<00:00,  4.49s/it]
100%|██████████| 22/22 [00:12<00:00,  1.83it/s]
100%|██████████| 22/22 [00:10<00:00,  2.15it/s]
100%|██████████| 22/22 [00:11<00:00,  1.94it/s]
100%|██████████| 22/22 [00:13<00:00,  1.69it/s]
100%|██████████| 22/22 [00:13<00:00,  1.61it/s]
100%|██████████| 22/22 [00:11<00:00,  1.89it/s]
100%|██████████| 22/22 [00:12<00:00,  1.83it/s]


In [13]:
preds,_ = get_preds(models,best_ensemble_weights,test2_loader)
df = pd.read_csv("./qia_test2_baseline.csv")
labels = [emo2label[pred] for pred in best_optimizer.apply(preds)]
df.iloc[:,1] = labels
df.to_csv('./result_csv/'+ model_setting+'-test2-w-dcm.csv',index=False)

preds,_ = get_preds(models,[1]*len(models),test2_loader)
df = pd.read_csv("./qia_test2_baseline.csv")
labels = [emo2label[pred] for pred in np.argmax(preds,1)]
df.iloc[:,1] = labels
df.to_csv('./result_csv/'+ model_setting+'-test2-wo-dcm.csv',index=False)

100%|██████████| 23/23 [00:13<00:00,  1.72it/s]
100%|██████████| 23/23 [00:13<00:00,  1.72it/s]
100%|██████████| 23/23 [00:11<00:00,  1.94it/s]
100%|██████████| 23/23 [00:11<00:00,  1.97it/s]
100%|██████████| 23/23 [00:13<00:00,  1.72it/s]
100%|██████████| 23/23 [00:13<00:00,  1.72it/s]
100%|██████████| 23/23 [00:11<00:00,  1.94it/s]
100%|██████████| 23/23 [00:11<00:00,  1.96it/s]


## Test on our data

In [52]:
from sklearn.metrics import confusion_matrix
np.set_printoptions(precision=4)

In [54]:
preds,true = get_preds(models,best_ensemble_weights,test_private_loader_1)
final = [pred for pred in best_optimizer.apply(preds)]
print(accuracy_score(true,final))
cm = confusion_matrix(true,final)
cm = cm/np.sum(cm,1)
cm

100%|██████████| 1/1 [00:10<00:00, 10.44s/it]
100%|██████████| 1/1 [00:10<00:00, 10.40s/it]
100%|██████████| 1/1 [00:10<00:00, 10.38s/it]
100%|██████████| 1/1 [00:10<00:00, 10.30s/it]

0.2222222222222222





array([[0.    , 0.1111, 0.2222, 0.    , 0.0556, 0.6111, 0.    ],
       [0.    , 0.0556, 0.1111, 0.    , 0.2778, 0.5556, 0.    ],
       [0.1111, 0.    , 0.0556, 0.    , 0.3333, 0.4444, 0.0556],
       [0.    , 0.    , 0.    , 0.    , 0.5   , 0.4444, 0.0556],
       [0.    , 0.    , 0.    , 0.    , 0.6667, 0.2778, 0.0556],
       [0.    , 0.    , 0.0556, 0.    , 0.3889, 0.5556, 0.    ],
       [0.    , 0.2778, 0.0556, 0.0556, 0.1111, 0.2778, 0.2222]])

In [57]:
preds,true = get_preds(models,best_ensemble_weights,test_private_loader_2)
final = [pred for pred in best_optimizer.apply(preds)]
print(accuracy_score(true,final))
cm = confusion_matrix(true,final)
cm = cm/np.sum(cm,1)
cm



100%|██████████| 1/1 [00:11<00:00, 11.41s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:10<00:00, 10.39s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:10<00:00, 10.12s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:10<00:00, 10.34s/it]

0.23622047244094488





array([[0.1667, 0.    , 0.0556, 0.0556, 0.4444, 0.2778, 0.    ],
       [0.0556, 0.    , 0.    , 0.    , 0.3889, 0.5556, 0.    ],
       [0.    , 0.    , 0.0556, 0.    , 0.8333, 0.0556, 0.0526],
       [0.    , 0.    , 0.    , 0.    , 0.7778, 0.2222, 0.    ],
       [0.    , 0.    , 0.    , 0.0556, 0.8333, 0.1111, 0.    ],
       [0.2222, 0.    , 0.0556, 0.    , 0.2778, 0.4444, 0.    ],
       [0.    , 0.0556, 0.2222, 0.2222, 0.2222, 0.1667, 0.1579]])

In [58]:
preds,true = get_preds(models,best_ensemble_weights,test_private_loader_3)
final = [pred for pred in best_optimizer.apply(preds)]
print(accuracy_score(true,final))
cm = confusion_matrix(true,final)
cm = cm/np.sum(cm,1)
cm



100%|██████████| 1/1 [00:10<00:00, 10.19s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:10<00:00, 10.41s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:10<00:00, 10.33s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:10<00:00, 10.44s/it]

0.29365079365079366





array([[0.2222, 0.6111, 0.    , 0.    , 0.0556, 0.1111, 0.    ],
       [0.    , 0.5556, 0.    , 0.0556, 0.0556, 0.2778, 0.0556],
       [0.0556, 0.5556, 0.    , 0.    , 0.0556, 0.3333, 0.    ],
       [0.0556, 0.2778, 0.    , 0.1667, 0.2222, 0.2778, 0.    ],
       [0.0556, 0.1111, 0.2222, 0.0556, 0.3889, 0.1667, 0.    ],
       [0.1111, 0.2778, 0.    , 0.0556, 0.0556, 0.5   , 0.    ],
       [0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.2222, 0.2222]])

In [59]:
preds,true = get_preds(models,best_ensemble_weights,test_private_loader_4)
final = [pred for pred in best_optimizer.apply(preds)]
print(accuracy_score(true,final))
cm = confusion_matrix(true,final)
cm = cm/np.sum(cm,1)
cm



100%|██████████| 1/1 [00:10<00:00, 10.35s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:10<00:00, 10.20s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:10<00:00, 10.27s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:10<00:00, 10.35s/it]

0.29365079365079366





array([[0.    , 0.0556, 0.1667, 0.    , 0.2778, 0.4444, 0.0556],
       [0.    , 0.1667, 0.2222, 0.    , 0.1667, 0.3889, 0.0556],
       [0.    , 0.    , 0.1111, 0.    , 0.3333, 0.5   , 0.0556],
       [0.    , 0.    , 0.0556, 0.1667, 0.3889, 0.2222, 0.1667],
       [0.    , 0.    , 0.1111, 0.1111, 0.6111, 0.1667, 0.    ],
       [0.    , 0.    , 0.1111, 0.    , 0.2778, 0.5556, 0.0556],
       [0.0556, 0.1111, 0.1111, 0.0556, 0.0556, 0.1667, 0.4444]])