In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from tqdm import tqdm
import numpy as np
from numpy import linalg as LA
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

import json
import keras
import tensorflow
from keras import layers, Model
from keras.models import Sequential
from keras.applications import DenseNet201
from keras.callbacks import Callback, EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, roc_curve

In [None]:
csv_path = 'drive/My Drive/Team No Name/UrbanSound8K/metadata/UrbanSound8K.csv'
spectrograms_path = "drive/My Drive/Team No Name/numpySpectrograms/"
model_save_path = "drive/My Drive/Team No Name/basemodel-known"
embeddings_path = "drive/My Drive/Team No Name/embeddings/"
embeddings_model_save_path = "drive/My Drive/Team No Name/feature_extraction_test"
test_size = 0.2
val_size = 0.2
batch_size = 16
num_classes = 5

In [None]:
class SpecLoader(keras.utils.Sequence):
  def __init__(self, x_set, y_set, batch_size, spec_dir):
    self.x, self.y = x_set, y_set
    self.batch_size = batch_size
    self.spec_dir = spec_dir

  def __len__(self):
    return int(np.ceil(len(self.x) / self.batch_size))

  def __getitem__(self, idx):
    batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
    batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

    batchSpecs = []
    for fileName in batch_x:
        spec = np.load(self.spec_dir + fileName + ".npy")
        batchSpecs.append(spec.transpose())
    return np.array(batchSpecs), np.array(batch_y)

In [None]:
data_df = pd.read_csv(csv_path)
data_df_known = data_df.loc[data_df["classID"] < 5]
data_df_unknown = data_df.loc[data_df["classID"] >= 5]

X_trainval, known_X_test, y_trainval, known_y_test = train_test_split(data_df_known['slice_file_name'].tolist(), data_df_known['classID'].tolist(), test_size=test_size, random_state = 42)
X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval, test_size=val_size, random_state = 42)
X_trash, X_unknown, y_trash, y_unknown = train_test_split(data_df_unknown['slice_file_name'].tolist(), data_df_unknown['classID'].tolist(), test_size=test_size, random_state = 42)
y_unknown = [-1] * len(y_unknown)

X_test = known_X_test + X_unknown
y_test = known_y_test + y_unknown
train_loader = SpecLoader(X_train, y_train, batch_size, spectrograms_path)
known_test_loader = SpecLoader(known_X_test, known_y_test, batch_size, spectrograms_path)
test_loader = SpecLoader(X_test, y_test, batch_size, spectrograms_path)
val_loader = SpecLoader(X_val, y_val, batch_size, spectrograms_path)
trainval_loader = SpecLoader(X_trainval, y_trainval, batch_size, spectrograms_path)

# Base Model

In [None]:
densenet = DenseNet201(
            include_top=False,
            weights="imagenet",
            input_tensor=None,
            input_shape=None,
            pooling=None)
model = Sequential()
model.add(densenet)
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(num_classes, activation="softmax"))
model.summary()
model.compile(
      optimizer="Adam",
      loss="sparse_categorical_crossentropy",
      metrics=["accuracy"])

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/densenet/densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
densenet201 (Functional)     (None, None, None, 1920)  18321984  
_________________________________________________________________
global_average_pooling2d (Gl (None, 1920)              0         
_________________________________________________________________
dense (Dense)                (None, 5)                 9605      
Total params: 18,331,589
Trainable params: 18,102,533
Non-trainable params: 229,056
_________________________________________________________________


In [None]:
earlystopping = EarlyStopping(
                    patience=5, 
                    restore_best_weights=True)
checkpoint = ModelCheckpoint(
                    model_save_path, 
                    monitor="val_accuracy", 
                    save_best_only=True)

model.fit(x=train_loader,
          validation_data=val_loader,
          callbacks=[checkpoint, earlystopping],
          epochs=70,
          verbose=1
         )

Epoch 1/70
Epoch 2/70
Epoch 3/70
Epoch 4/70
Epoch 5/70
Epoch 6/70
Epoch 7/70
Epoch 8/70
Epoch 9/70
Epoch 10/70
Epoch 11/70
Epoch 12/70
Epoch 13/70
Epoch 14/70
Epoch 15/70
Epoch 16/70
Epoch 17/70


<keras.callbacks.callbacks.History at 0x7fa9296f6fd0>

In [None]:
model.load_weights(model_save_path)
# test_loss, test_accuracy= model.evaluate(x=known_test_loader)
# print(test_loss, test_accuracy)

# Softmax Threshold

In [None]:
def generate_thresholds(softmaxes, predictions, expected):
    trainval_df = pd.DataFrame([np.amax(softmaxes, 1), predictions, expected])
    trainval_df = trainval_df.transpose()
    trainval_df.columns = ["probability", "predicted", "expected"]
    thresholds = []
    for i in range(num_classes):
        class_df = trainval_df.loc[trainval_df["expected"] == i]
        class_df = class_df.loc[class_df["expected"] == class_df["predicted"]]
        thresholds.append(class_df["probability"].min())
    return thresholds

def check_negative(sample_softmax, thresholds):
    probability = max(sample_softmax)
    idx = sample_softmax.index(probability)
    if probability < thresholds[idx]:
        return -1
    return idx

def evaluate(predicted, expected):
    acc = np.mean(np.array(predicted) == np.array(expected))
    print("Overall accuracy: {}".format(acc))
    acc_dict = {}
    for i in range(len(expected)):
        expected_class = expected[i]
        if expected_class not in acc_dict:
            acc_dict[expected_class] = [0, 0]
        acc_dict[expected_class][1] += 1
        if expected_class == predicted[i]:
            acc_dict[expected_class][0] += 1
    for k,v in acc_dict.items():
        print("Accuracy for class {}: {}".format(k, v[0]/v[1]))

In [None]:
trainval_softmaxes = model.predict(x=trainval_loader)
trainval_predictions = model.predict_classes(x=trainval_loader, batch_size=None)
thresholds = generate_thresholds(trainval_softmaxes, trainval_predictions, y_trainval)

test_softmaxes = model.predict(x=test_loader)
test_class = []
for softmax in test_softmaxes:
    test_class.append(check_negative(softmax.tolist(), thresholds))
    
evaluate(test_class, y_test)

Overall accuracy: 0.5054378935317687
Accuracy for class 2: 0.9310344827586207
Accuracy for class 4: 0.9651162790697675
Accuracy for class 0: 0.9808612440191388
Accuracy for class 3: 0.8947368421052632
Accuracy for class 1: 0.972972972972973
Accuracy for class -1: 0.0545876887340302


In [None]:
# Macro F1-score will give the same importance to each label/class. It will be low for models that only perform well on the common classes while performing poorly on the rare classes.
macro_f1_score = f1_score(test_class, y_test, average='macro')
print("Macro F1 score: {}".format(macro_f1_score))

# Micro-averaging will put more emphasis on the common labels in the data set.
micro_f1_score = f1_score(test_class, y_test, average='micro')
print("Micro F1 score: {}".format(micro_f1_score))

# This alters ‘macro’ to account for label imbalance; it can result in an F-score that is not between precision and recall.
weighted_f1_score = f1_score(test_class, y_test, average='weighted')
print("Weighted F1 score: {}".format(weighted_f1_score))

# Scores for each class are returned.
norm_f1_score = f1_score(test_class, y_test, average=None)
print("F1 score for each class: {}".format(norm_f1_score))

Macro F1 score: 0.5644870096815503
Micro F1 score: 0.5054378935317687
Weighted F1 score: 0.6273909324342298
F1 score for each class: [0.10295728 0.6710311  0.68246445 0.76056338 0.62385321 0.54605263]


### Set threshold using ROC curve

In [None]:
def generate_thresholds_roc(softmaxes, expected):
    thresholds = []
    # create binary class for each class
    for i in range(5):
      binary_expected = [1 if x == i else 0 for x in expected]
      fpr, tpr, threshold = roc_curve(binary_expected, np.amax(softmaxes, 1))
      i = np.arange(len(tpr)) 
      roc = pd.DataFrame({'tf' : pd.Series(tpr-(1-fpr), index=i), 'threshold' : pd.Series(threshold, index=i)})
      roc_t = roc.iloc[(roc.tf-0).abs().argsort()[:1]]
      updated_threshold = list(roc_t['threshold'])[0]
      thresholds.append(updated_threshold)
    return thresholds

def check_negative_roc(sample_softmax, thresholds):
    probability = max(sample_softmax)
    idx = sample_softmax.index(probability)
    count = 0
    for t in thresholds:
        if probability < t:
          count += 1
    if count == len(thresholds):
      return -1
    return idx

In [None]:
roc_thresholds = generate_thresholds_roc(trainval_softmaxes, y_trainval)
print("Optimal threshold set by ROC curve: {}".format(roc_thresholds))

Optimal threshold set by ROC curve: [0.9997814297676086, 0.9998383522033691, 0.9995611310005188, 0.9993612170219421, 0.9996367692947388]


In [None]:
test_softmaxes = model.predict(x=test_loader)
test_class = []
for softmax in test_softmaxes:
    test_class.append(check_negative_roc(softmax.tolist(), roc_thresholds))
    
evaluate(test_class, y_test)

Overall accuracy: 0.7057813394390383
Accuracy for class 2: 0.4630541871921182
Accuracy for class 4: 0.5116279069767442
Accuracy for class 0: 0.6363636363636364
Accuracy for class 3: 0.33771929824561403
Accuracy for class 1: 0.7432432432432432
Accuracy for class -1: 0.9128919860627178


In [None]:
trainval_predictions = model.predict_classes(x=trainval_loader, batch_size=None)
evaluate(trainval_predictions, y_trainval)

Overall accuracy: 0.9774202653118825
Accuracy for class 4: 0.9927536231884058
Accuracy for class 0: 0.9949431099873578
Accuracy for class 3: 0.9637305699481865
Accuracy for class 1: 0.9802816901408451
Accuracy for class 2: 0.9560853199498118


In [None]:
trainval_class = []
for softmax in trainval_softmaxes:
    trainval_class.append(check_negative_roc(softmax.tolist(), roc_thresholds))
    
evaluate(trainval_class, y_trainval)

Overall accuracy: 0.5619531470505221
Accuracy for class 4: 0.5640096618357487
Accuracy for class 0: 0.7243994943109987
Accuracy for class 3: 0.3898963730569948
Accuracy for class 1: 0.7295774647887324
Accuracy for class 2: 0.4905897114178168


In [None]:
macro_f1_score = f1_score(test_class, y_test, average='macro')
print("Macro F1 score: {}".format(macro_f1_score))

micro_f1_score = f1_score(test_class, y_test, average='micro')
print("Micro F1 score: {}".format(micro_f1_score))

weighted_f1_score = f1_score(test_class, y_test, average='weighted')
print("Weighted F1 score: {}".format(weighted_f1_score))

norm_f1_score = f1_score(test_class, y_test, average=None)
print("F1 score for each class: {}".format(norm_f1_score))

Macro F1 score: 0.6427631320302183
Micro F1 score: 0.6971951917572983
Weighted F1 score: 0.7176721974743813
F1 score for each class: [0.75383899 0.68452381 0.73770492 0.58983051 0.50326797 0.58741259]


# Feature Extractor

### Base Model

In [None]:
feature_extractor = Model(inputs = model.inputs, outputs = model.layers[-2].output)
feature_extractor.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
densenet201_input (InputLaye [(None, None, None, 3)]   0         
_________________________________________________________________
densenet201 (Functional)     (None, None, None, 1920)  18321984  
_________________________________________________________________
global_average_pooling2d (Gl (None, 1920)              0         
Total params: 18,321,984
Trainable params: 18,092,928
Non-trainable params: 229,056
_________________________________________________________________


### VGGish Model

In [None]:
class EmbeddingsLoader(keras.utils.Sequence):
  def __init__(self, x_set, y_set, batch_size, emb_dir):
    self.x, self.y = x_set, y_set
    self.batch_size = batch_size
    self.emb_dir = emb_dir

  def __len__(self):
    return int(np.ceil(len(self.x) / self.batch_size))

  def __getitem__(self, idx):
    batch_x = self.x[idx * self.batch_size:(idx + 1) *
    self.batch_size]
    batch_y = self.y[idx * self.batch_size:(idx + 1) *
        self.batch_size]

    batchEmbs = []
    for fileName in batch_x:
        emb = np.load(self.emb_dir + fileName + ".npy")[0]
        batchEmbs.append(emb)
    batchEmbs = np.array(batchEmbs)
    return batchEmbs, np.array(batch_y)

In [None]:
vggish_train_loader = EmbeddingsLoader(X_train, y_train, batch_size, embeddings_path)
vggish_test_loader = EmbeddingsLoader(X_test, y_test, batch_size, embeddings_path)
vggish_val_loader = EmbeddingsLoader(X_val, y_val, batch_size, embeddings_path)
vggish_trainval_loader = EmbeddingsLoader(X_trainval, y_trainval, batch_size, embeddings_path)

# Similarity to centroid

## Euclidean Distance

In [None]:
def euclidean_dist(centroid, data_point):
    return LA.norm(centroid - data_point)
    
def generate_centroids(features, expected):
    result_dict = {}
    for i in range(len(features)):
        c = expected[i]
        f = features[i]
        if c not in result_dict:
            result_dict[c] = [np.array(f)]
        else:
            result_dict[c].append(np.array(f))
    for k, v in result_dict.items():
        result_dict[k] = [v, np.mean(np.array(v), axis=0)]
    return result_dict

def generate_thresholds_dist(centroids):
    result_dict = {}
    for k, [points, centroid] in centroids.items():
        current_max = -1
        for point in points:
            current_dist = euclidean_dist(centroid, point)
            if current_max < current_dist:
                current_max = current_dist
        result_dict[k] = current_max
    return result_dict

In [None]:
def check_negative(threshold, centroid, feature):
    dist = euclidean_dist(centroid, feature)
    print(dist)
    print(threshold)
    print()
    if dist > threshold:
        return True
    return False

### Base Model

In [None]:
extracted_features = feature_extractor.predict(x = trainval_loader)

centroids_dict = generate_centroids(extracted_features, y_trainval)
thresholds_dict = generate_thresholds_dist(centroids_dict)
print(thresholds_dict)

{4: 27.000278, 0: 19.336025, 3: 54.661667, 1: 31.284636, 2: 19.541674}


In [None]:
test_features = feature_extractor.predict(x = test_loader)

In [None]:
test_classes = model.predict_classes(x=test_loader, batch_size=None)

In [None]:
updated_predictions = []

for i in range(len(test_features)):
    predicted_class = test_classes[i]
    threshold = thresholds_dict[predicted_class]
    centroid = centroids_dict[predicted_class][1]
    feature = test_features[i]
    for c, thresh in thresholds_dict.items():
        print(euclidean_dist(centroids_dict[c][1], feature))
    print()
    check = check_negative(threshold, centroid, feature)
    if check:
        updated_predictions.append(-1)
    else:
        updated_predictions.append(predicted_class)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
16.630543

11.560316
54.661667

13.521412
10.466235
17.102825
22.627865
21.062817

10.466235
19.336025

8.5056
19.674131
16.821896
22.186884
21.665129

8.5056
27.000278

20.338982
19.79765
18.49885
14.87995
9.970603

9.970603
19.541674

12.173965
13.282677
16.526258
18.785336
19.52238

12.173965
27.000278

14.630336
13.417096
16.11815
17.617628
15.089116

13.417096
19.336025

15.043248
15.390046
14.963867
14.164211
14.761058

14.164211
31.284636

10.70269
21.163765
15.051924
18.396797
15.602848

10.70269
27.000278

17.00115
19.324425
17.004105
10.715025
12.919426

10.715025
31.284636

18.315367
20.185625
18.192015
16.14195
8.785779

8.785779
19.541674

10.5086155
18.5901
12.483031
20.729748
14.76169

12.483031
54.661667

11.452063
16.847862
12.59609
20.705847
17.013113

12.59609
54.661667

18.714346
17.956406
16.004002
13.070864
10.561116

10.561116
19.541674

17.548315
13.048171
12.897671
18.790792
17.949215

12.897671
5

In [None]:
evaluate(updated_predictions, y_test)

Overall accuracy: 0.4785346307956497
Accuracy for class 2: 0.9359605911330049
Accuracy for class 4: 0.9651162790697675
Accuracy for class 0: 0.9856459330143541
Accuracy for class 3: 0.8903508771929824
Accuracy for class 1: 0.9594594594594594
Accuracy for class -1: 0.0


In [None]:
macro_f1_score = f1_score(updated_predictions, y_test, average='macro')
print("Macro F1 score: {}".format(macro_f1_score))

micro_f1_score = f1_score(updated_predictions, y_test, average='micro')
print("Micro F1 score: {}".format(micro_f1_score))

weighted_f1_score = f1_score(updated_predictions, y_test, average='weighted')
print("Weighted F1 score: {}".format(weighted_f1_score))

norm_f1_score = f1_score(updated_predictions, y_test, average=None)
print("F1 score for each class: {}".format(norm_f1_score))

Macro F1 score: 0.5357650047244799
Micro F1 score: 0.4785346307956496
Weighted F1 score: 0.6312004428549921
F1 score for each class: [0.         0.66237942 0.6635514  0.73929961 0.60687593 0.54248366]


### Vggish Model

In [None]:
vggish_trainval_features = None
for data, labels in tqdm(vggish_trainval_loader):
  if vggish_trainval_features is None: vggish_trainval_features = data
  else: vggish_trainval_features = np.concatenate((vggish_trainval_features, data), axis=0)

100%|██████████| 222/222 [17:32<00:00,  4.74s/it]


In [None]:
vggish_centroids_dict = generate_centroids(vggish_trainval_features, y_trainval)
vggish_thresholds_dict = generate_thresholds_dist(vggish_centroids_dict)
print(vggish_thresholds_dict)

{4: 1127.6519556049934, 0: 1239.9636677483186, 3: 1271.5463307512155, 1: 1109.0767031765936, 2: 1116.1240536816063}


In [None]:
test_classes = model.predict_classes(x=test_loader, batch_size=None)

In [None]:
vggish_test_features = None
for data, labels in tqdm(vggish_test_loader):
  if vggish_test_features is None: vggish_test_features = data
  else: vggish_test_features = np.concatenate((vggish_test_features, data), axis=0)

100%|██████████| 110/110 [08:15<00:00,  4.50s/it]


In [None]:
updated_predictions = []

for i in tqdm(range(len(vggish_test_features))):
    predicted_class = test_classes[i]
    threshold = vggish_thresholds_dict[predicted_class]
    centroid = vggish_centroids_dict[predicted_class][1]
    feature = vggish_test_features[i]
    for c, thresh in vggish_thresholds_dict.items():
        print(euclidean_dist(vggish_centroids_dict[c][1], feature))
    print()
    check = check_negative(threshold, centroid, feature)
    if check:
        updated_predictions.append(-1)
    else:
        updated_predictions.append(predicted_class)

 15%|█▌        | 263/1747 [00:00<00:01, 1368.41it/s]

1046.2610174113559
1103.0231006406839
827.6471262243332
1018.3905728556084
773.1608702670204

773.1608702670204
1116.1240536816063

766.1053604550293
864.4318024447975
872.0836157759818
881.7485307939168
916.9221857249753

766.1053604550293
1127.6519556049934

735.9690681079269
821.1705012678856
810.542322288585
831.2585679368956
852.0688247598399

735.9690681079269
1127.6519556049934

944.7722296035208
925.892425477084
814.9600589426144
935.4494188459377
722.3697712346581

722.3697712346581
1116.1240536816063

685.3740035682749
645.4510685655019
660.8251426620662
735.9020744002061
636.5292247982173

685.3740035682749
1127.6519556049934

858.9725272352915
943.4179941828539
911.0344764940944
840.1330468122077
968.8815946032774

858.9725272352915
1127.6519556049934

903.8860978615073
780.0924535896329
1025.234978122601
958.6075441336286
1026.9875303638535

780.0924535896329
1239.9636677483186

1031.9691173028014
980.1233873101374
838.5758521731444
1012.8136827987036
753.1920444208149

75

 27%|██▋       | 476/1747 [00:00<00:01, 1195.80it/s]

673.8099838473923
657.1972629801544
728.8327873172561
751.7257879636863
738.4722834983337

657.1972629801544
1239.9636677483186

670.3795814471918
629.0361195410934
696.8441844206413
683.2518752834546
633.102157912498

633.102157912498
1116.1240536816063

608.6790346966807
461.87646702033896
732.9768329000055
740.2446481825211
751.5712164658523

461.87646702033896
1239.9636677483186

736.3031699650213
810.4799664214424
691.2733098156325
550.6979090664759
744.4143840810281

691.2733098156325
1271.5463307512155

718.6353315437888
593.7169465603581
773.0049978410835
759.6908811730679
759.632715893738

593.7169465603581
1239.9636677483186

682.7374306154168
696.2584889460346
812.9496643329813
767.950419106024
861.3727691875062

682.7374306154168
1127.6519556049934

709.1561670123515
823.0838275021935
681.9298174373164
683.3858753704399
761.4907487430269

709.1561670123515
1127.6519556049934

693.3011319403817
780.6302931638909
795.0033246929133
794.2318219828443
776.1562305185626

693.3011

 39%|███▉      | 689/1747 [00:00<00:00, 1127.84it/s]



527.6124013167663
1127.6519556049934

772.7445913826683
671.640446779769
707.2226813301535
805.4467982409933
639.0263599039412

639.0263599039412
1116.1240536816063

743.1204856335021
639.7326428691113
798.8873018065889
783.6358100704922
765.4455487820146

765.4455487820146
1116.1240536816063

693.1808591638744
556.6151777684554
725.5989477010769
682.2083825276801
751.8718238287959

556.6151777684554
1239.9636677483186

725.2828454655646
880.047539797855
956.7632265952814
935.2454139065784
947.5390192723197

725.2828454655646
1127.6519556049934

1034.0075421538152
1038.28945550688
889.1422709457864
992.5018769667807
846.0743942597297

889.1422709457864
1271.5463307512155

817.4194242280826
658.4112930666964
918.9967250068436
887.2781665276494
906.4742381121429

658.4112930666964
1239.9636677483186

996.0656945013089
1035.4597669351647
1035.0307440820613
916.2133007123948
1042.659797233745

916.2133007123948
1109.0767031765936

846.8911094923916
872.2771989155738
788.104853662284
659.

 52%|█████▏    | 910/1747 [00:00<00:00, 1116.93it/s]

762.0048788526391
803.3690569782189
628.5928472548317
675.2774002821941
737.9326507606272

628.5928472548317
1271.5463307512155

1261.7977860126914
1267.9387170245234
1061.1260265886772
1200.4868805740737
1033.4783648239732

1033.4783648239732
1116.1240536816063

599.4909963618275
655.6233991074664
670.0350418129203
693.9186955719474
703.8283213731971

670.0350418129203
1271.5463307512155

597.1421274867313
627.582817773024
681.4643093556958
672.5870222827275
671.9263173738319

597.1421274867313
1127.6519556049934

913.1897819770556
908.085657332323
806.165698490729
832.5283743467498
854.9763957029285

908.085657332323
1239.9636677483186

723.253899755243
821.0096970318394
808.4731924146259
726.1342664863582
855.6197375745331

723.253899755243
1127.6519556049934

947.7684064703981
949.6961379718883
787.1086798133741
876.7203103288678
754.8109220525514

754.8109220525514
1116.1240536816063

786.8446744278176
662.3559088039594
908.0817701609766
896.6810845177223
945.3364784749364

662.35

 64%|██████▍   | 1121/1747 [00:01<00:00, 1074.42it/s]

911.1571364462883

731.810083421346
1127.6519556049934

861.0348695066514
880.3049909595553
807.7208315459001
767.3539268303899
799.4226609154134

799.4226609154134
1116.1240536816063

1088.2708698296865
1132.6282156262496
1069.8454254981484
1089.025313539158
1052.7772655983501

1069.8454254981484
1271.5463307512155

706.0459610842208
543.8353342577055
709.9783319712234
771.4298639134424
716.1589390050398

716.1589390050398
1116.1240536816063

572.3617299295776
659.9997843089211
591.0668551064493
595.2719828188302
645.4818940939768

591.0668551064493
1271.5463307512155

776.1115143703386
924.397176578652
986.84312513707
930.3172235028023
995.2513133537592

776.1115143703386
1127.6519556049934

685.3472818563462
574.1247819286425
762.5225937386414
757.7983201963712
782.2785140828167

782.2785140828167
1116.1240536816063

683.905150028719
553.3351811937636
741.9775416298991
733.0827930161679
772.7306701898673

553.3351811937636
1239.9636677483186

965.9222611692301
987.4562127901343
909.

 77%|███████▋  | 1353/1747 [00:01<00:00, 1072.88it/s]

793.6248550953513
761.4598865870574
815.0510978860502

600.7163126080902
1239.9636677483186

1075.4377350031266
1095.660625287846
970.2386858481315
1042.7644439861467
980.2642205561317

970.2386858481315
1271.5463307512155

1007.6952733329856
1073.5846730960654
1016.0668230305816
1033.8504310637743
1001.9705834041052

1016.0668230305816
1271.5463307512155

1070.3306763104865
1056.93804039294
967.8775193640993
1017.9589997039947
947.9620363725805

1070.3306763104865
1127.6519556049934

915.9620325412297
893.5867220512966
942.0318926612414
925.4601115316121
914.7377021905895

942.0318926612414
1271.5463307512155

731.0483212179696
744.7449395748424
735.3448285279051
718.3520546181602
753.1847146441149

744.7449395748424
1239.9636677483186

609.9448701103806
643.6254169632732
755.369155369866
671.6307260099192
759.5639192143623

759.5639192143623
1116.1240536816063

716.2144802864827
617.6725540782321
795.0101500008614
779.0352618225087
812.3376725061962

812.3376725061962
1116.1240536816

 90%|████████▉ | 1568/1747 [00:01<00:00, 1041.15it/s]

1022.3911200391766
1271.5463307512155

622.6785258111413
691.6857883651949
673.1410123757249
636.384157173081
707.2459529484729

622.6785258111413
1127.6519556049934

663.5667626084834
709.8413310647827
770.7591330597813
769.2821153294034
774.5873948047115

663.5667626084834
1127.6519556049934

757.7420536939255
659.660812731963
801.4532032813819
833.5592992463389
810.0951752995135

801.4532032813819
1271.5463307512155

1003.4825244821749
1054.2806492935965
1020.3588435900749
1031.1788500803366
1009.4642303921873

1020.3588435900749
1271.5463307512155

640.0302420648136
595.3697039997863
623.869577719058
654.7544715663299
658.1895445894838

623.869577719058
1271.5463307512155

880.9894487592782
877.1486725006301
813.2430273331254
901.776269687396
839.8286882846305

880.9894487592782
1127.6519556049934

678.3517806372854
672.9218980691372
625.7882314530351
600.1207610933477
635.2167028211098

678.3517806372854
1127.6519556049934

1110.2206816508478
1091.4830449492636
1002.2578480834194


100%|██████████| 1747/1747 [00:01<00:00, 1065.84it/s]

877.4749184430711
846.9990975750087
817.2124591385226
874.7647819099183
819.7977570283686

817.2124591385226
1271.5463307512155

601.4474454165747
583.2810253031723
573.4880538693172
588.6751905830031
588.319003182186

573.4880538693172
1271.5463307512155

578.0392127965807
518.5216925348084
484.8632185234625
542.7794286604357
517.9329281018748

484.8632185234625
1271.5463307512155

1048.641709734657
996.5809545537944
1035.3711750916236
1086.4872514730987
1059.095993540191

996.5809545537944
1239.9636677483186

757.4561246738012
829.8554165516488
879.6338020170804
828.8552931057045
910.3367948842266

829.8554165516488
1239.9636677483186

746.5631450842745
676.3645249869097
766.2079959179082
774.0374547890984
809.5916228592665

766.2079959179082
1271.5463307512155

678.7618741296808
717.8317795330214
561.4406835945055
620.1279976981057
637.0656833388266

561.4406835945055
1271.5463307512155

715.2046973186926
636.1603002995415
621.3504211459731
743.0815063493086
573.9704632358165

621.3




In [None]:
evaluate(updated_predictions, y_test)

Overall accuracy: 0.48311390955924444
Accuracy for class 2: 0.9359605911330049
Accuracy for class 4: 0.9651162790697675
Accuracy for class 0: 0.9856459330143541
Accuracy for class 3: 0.8947368421052632
Accuracy for class 1: 0.972972972972973
Accuracy for class -1: 0.006968641114982578


In [None]:
macro_f1_score = f1_score(updated_predictions, y_test, average='macro')
print("Macro F1 score: {}".format(macro_f1_score))

micro_f1_score = f1_score(updated_predictions, y_test, average='micro')
print("Micro F1 score: {}".format(micro_f1_score))

weighted_f1_score = f1_score(updated_predictions, y_test, average='weighted')
print("Weighted F1 score: {}".format(weighted_f1_score))

norm_f1_score = f1_score(updated_predictions, y_test, average=None)
print("F1 score for each class: {}".format(norm_f1_score))

Macro F1 score: 0.5408110073556099
Micro F1 score: 0.48311390955924444
Weighted F1 score: 0.63235153942601
F1 score for each class: [0.01384083 0.66237942 0.6728972  0.74074074 0.60895522 0.54605263]


## Cosine Similarity

In [None]:
def cos_similarity(centroid, data_point):
    centroid = centroid.reshape(1, -1)
    data_point = data_point.reshape(1, -1)
    return LA.norm(cosine_similarity(centroid, data_point))
    
def generate_centroids(features, expected):
    result_dict = {}
    for i in range(len(features)):
        c = expected[i]
        f = features[i]
        print
        if c not in result_dict:
            result_dict[c] = [np.array(f)]
        else:
            result_dict[c].append(np.array(f))
    for k, v in result_dict.items():
        result_dict[k] = [v, np.mean(np.array(v), axis=0)]
    return result_dict

def generate_thresholds_sim(centroids):
    result_dict = {}
    for k, [points, centroid] in centroids.items():
        current_min = 1
        for point in points:
            current_sim = cos_similarity(centroid, point)
            if current_min > current_sim:
                current_min = current_sim
        result_dict[k] = current_min
    return result_dict

In [None]:
def check_negative_sim(threshold, centroid, feature):
    similarity = cos_similarity(centroid, feature)
    print(similarity)
    print(threshold)
    print()
    if similarity < threshold:
        return True
    return False

### Base Model

In [None]:
extracted_features = feature_extractor.predict(x = trainval_loader)

In [None]:
centroids_sim_dict = generate_centroids(extracted_features, y_trainval)
thresholds_sim_dict = generate_thresholds_sim(centroids_sim_dict)
print(thresholds_sim_dict)

{4: 0.6452441, 0: 0.5903251, 3: 0.607296, 1: 0.5934234, 2: 0.59954524}


In [None]:
test_features = feature_extractor.predict(x = test_loader)

In [None]:
test_classes = model.predict_classes(x=test_loader, batch_size=None)

Instructions for updating:
Please use instead:* `np.argmax(model.predict(x), axis=-1)`,   if your model does multi-class classification   (e.g. if it uses a `softmax` last-layer activation).* `(model.predict(x) > 0.5).astype("int32")`,   if your model does binary classification   (e.g. if it uses a `sigmoid` last-layer activation).


In [None]:
updated_predictions = []

for i in range(len(test_features)):
    predicted_class = test_classes[i]
    threshold = thresholds_sim_dict[predicted_class]
    centroid = centroids_sim_dict[predicted_class][1]
    feature = test_features[i]
    for c, thresh in thresholds_sim_dict.items():
        print(cos_similarity(centroids_sim_dict[c][1], feature))
    print()
    check = check_negative_sim(threshold, centroid, feature)
    if check:
        updated_predictions.append(-1)
    else:
        updated_predictions.append(predicted_class)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
0.5296407

0.7806202
0.607296

0.7528894
0.847187
0.5786557
0.3764298
0.3668459

0.847187
0.5903251

0.9076324
0.4917856
0.61866164
0.42958468
0.37093684

0.9076324
0.6452441

0.37411806
0.38554868
0.44314006
0.72389853
0.845824

0.845824
0.59954524

0.7842076
0.7216675
0.5336968
0.5228219
0.35033852

0.7842076
0.6452441

0.6662873
0.7106899
0.52480716
0.5762829
0.59819704

0.7106899
0.5903251

0.6431348
0.59105843
0.5886396
0.79091716
0.6116166

0.79091716
0.5934234

0.8422859
0.354815
0.66268384
0.58199716
0.6411814

0.8422859
0.6452441

0.55645174
0.3964571
0.5157601
0.88007766
0.7301694

0.88007766
0.5934234

0.5038367
0.37297472
0.47213423
0.671047
0.88231516

0.88231516
0.59954524

0.84585524
0.44010764
0.7441691
0.41348484
0.64193094

0.7441691
0.607296

0.81244373
0.5391653
0.7375891
0.40986514
0.514054

0.7375891
0.607296

0.43879053
0.4628412
0.5577142
0.80240864
0.82462656

0.82462656
0.59954524

0.5141313
0.73

In [None]:
evaluate(updated_predictions, y_test)

Overall accuracy: 0.48082427017744706
Accuracy for class 2: 0.9359605911330049
Accuracy for class 4: 0.9651162790697675
Accuracy for class 0: 0.9856459330143541
Accuracy for class 3: 0.8947368421052632
Accuracy for class 1: 0.972972972972973
Accuracy for class -1: 0.0023228803716608595


In [None]:
macro_f1_score = f1_score(updated_predictions, y_test, average='macro')
print("Macro F1 score: {}".format(macro_f1_score))

micro_f1_score = f1_score(updated_predictions, y_test, average='micro')
print("Micro F1 score: {}".format(micro_f1_score))

weighted_f1_score = f1_score(updated_predictions, y_test, average='weighted')
print("Weighted F1 score: {}".format(weighted_f1_score))

norm_f1_score = f1_score(updated_predictions, y_test, average=None)
print("F1 score for each class: {}".format(norm_f1_score))

Macro F1 score: 0.5382164978576821
Micro F1 score: 0.48082427017744706
Weighted F1 score: 0.6327856058642363
F1 score for each class: [0.00463499 0.66237942 0.66976744 0.73929961 0.60895522 0.5442623 ]


### Vggish Model

In [None]:
vggish_trainval_features = None
for data, labels in tqdm(vggish_trainval_loader):
  if vggish_trainval_features is None: vggish_trainval_features = data
  else: vggish_trainval_features = np.concatenate((vggish_trainval_features, data), axis=0)

100%|██████████| 222/222 [40:06<00:00, 10.84s/it]


In [None]:
vggish_centroids_sim_dict = generate_centroids(vggish_trainval_features, y_trainval)
vggish_thresholds_sim_dict = generate_thresholds_sim(vggish_centroids_sim_dict)
print(vggish_thresholds_sim_dict)

{4: 0.779731590950885, 0: 0.7462281809032909, 3: 0.7593273519904803, 1: 0.7803307604456015, 2: 0.7867504675040624}


In [None]:
test_classes = model.predict_classes(x=test_loader, batch_size=None)

In [None]:
vggish_test_features = None
for data, labels in tqdm(vggish_test_loader):
  if vggish_test_features is None: vggish_test_features = data
  else: vggish_test_features = np.concatenate((vggish_test_features, data), axis=0)

100%|██████████| 110/110 [19:50<00:00, 10.82s/it]


In [None]:
updated_predictions = []

for i in tqdm(range(len(vggish_test_features))):
    predicted_class = test_classes[i]
    threshold = vggish_thresholds_sim_dict[predicted_class]
    centroid = vggish_centroids_sim_dict[predicted_class][1]
    feature = vggish_test_features[i]
    for c, thresh in vggish_thresholds_sim_dict.items():
        print(cos_similarity(vggish_centroids_sim_dict[c][1], feature))
    print()
    check = check_negative_sim(threshold, centroid, feature)
    if check:
        updated_predictions.append(-1)
    else:
        updated_predictions.append(predicted_class)

  4%|▍         | 73/1747 [00:00<00:04, 361.43it/s]

0.7851245572579199
0.7602094239258231
0.8703186423304896
0.7973615133555609
0.8884608480103453

0.8884608480103453
0.7867504675040624

0.8814402896706768
0.8471728053213162
0.8438087856662726
0.8407432719466528
0.8265842112476399

0.8814402896706768
0.779731590950885

0.8853949457442012
0.8563919828272295
0.8595814774273012
0.8527813094839409
0.8444105499667454

0.8853949457442012
0.779731590950885

0.8132379185198653
0.821188535285692
0.8624629153885393
0.8174141577883105
0.8935495105856925

0.8935495105856925
0.7867504675040624

0.9031848364325744
0.9147195377831752
0.9105431385183712
0.887583197198728
0.9174035492114221

0.9031848364325744
0.779731590950885

0.8577658651435444
0.8261796894850614
0.8384006062861222
0.8644224594927836
0.8156338290372265

0.8577658651435444
0.779731590950885

0.836439113775278
0.8805374496908045
0.7861208245614002
0.8151585120949778
0.7856096270304576

0.8805374496908045
0.7462281809032909

0.7845284612392125
0.8068480913906129
0.8609255037148906
0.793

  8%|▊         | 142/1747 [00:00<00:04, 349.21it/s]


0.8887795209114556
0.9214512248248984
0.912767203388086
0.905676234206298

0.9286257474020524
0.779731590950885

0.7935604463330718
0.8082839580803756
0.8450503561200342
0.7860121405809435
0.8268872094063668

0.8450503561200342
0.7593273519904803

0.8625226380647557
0.8157797595423669
0.8471378570735298
0.9085477086344915
0.8305661531454953

0.9085477086344915
0.7803307604456015

0.8564358770886044
0.8647823607504797
0.8527236661171355
0.8388466000714551
0.8633451324648949

0.8647823607504797
0.7462281809032909

0.8799472270121722
0.8625647862538556
0.9239802280183232
0.871929892702302
0.8980396300594602

0.9239802280183232
0.7593273519904803

0.9332009663035505
0.8886357589995411
0.8665432102249251
0.8662379353902048
0.8578611988835937

0.9332009663035505
0.779731590950885

0.8903417018403581
0.9226971633781336
0.8693203379040062
0.8824639477889049
0.8563233049610519

0.8903417018403581
0.779731590950885

0.8834022190452743
0.859284240174675
0.8579732974323562
0.8645015889044949
0.83

 12%|█▏        | 205/1747 [00:00<00:04, 327.37it/s]


0.8555776939077668
0.8108865790645965
0.7925437345028923
0.7946580562402041
0.794864335167069

0.8555776939077668
0.779731590950885

0.8410528636212642
0.8583120834772615
0.8282235439957103
0.8274195483267643
0.8232537962996429

0.8583120834772615
0.7462281809032909

0.8659702971925086
0.8658445825836948
0.9187492931798044
0.8872147208709591
0.8876947532817998

0.9187492931798044
0.7593273519904803

0.8443441172876814
0.8724337258311936
0.8913480171945782
0.8431414956863715
0.9064921548630858

0.8913480171945782
0.7593273519904803

0.8769330912502674
0.8784745870250181
0.8917207253117456
0.8855045716096017
0.915087810685782

0.8769330912502674
0.779731590950885

0.9042098881372884
0.9406684910989331
0.8659900598884112
0.8664266896418802
0.8552059175286484

0.9406684910989331
0.7462281809032909

0.8275477604501023
0.8442337376508116
0.9115073840277025
0.8411090188261676
0.9036746144590093

0.9115073840277025
0.7593273519904803

0.8412059320624483
0.8407645958494784
0.8819390151431804
0

 16%|█▌        | 271/1747 [00:00<00:04, 322.89it/s]


0.8346716153608243
0.8773966447113584
0.9156011843189822
0.8656299137783607

0.9156011843189822
0.7803307604456015

0.9115912974386091
0.9255649530849877
0.9299211771241896
0.9132561096870845
0.9249338937261369

0.9255649530849877
0.7462281809032909

0.8676239908533513
0.9092047591334007
0.8728461993622827
0.8579711559400909
0.8728643405405883

0.9092047591334007
0.7462281809032909

0.8920488876356817
0.9054124545996747
0.9300748900583771
0.9100553821609377
0.9196051749447194

0.9300748900583771
0.7593273519904803

0.9125922945310176
0.8962249980900854
0.9038141984402137
0.9071082845216718
0.9160340046291645

0.9160340046291645
0.7867504675040624

0.871411342298565
0.8659250514837845
0.9232790603709304
0.8810978845370349
0.906630679662761

0.9232790603709304
0.7593273519904803

0.833554118212887
0.8079657104568414
0.8327112509070627
0.8671586894185176
0.8095892837211393

0.8671586894185176
0.7803307604456015

0.8953876268292348
0.8990785065929687
0.9082697480508907
0.9262617007190255


 19%|█▉        | 337/1747 [00:01<00:04, 316.93it/s]


0.9208110936214657

0.9469580202396222
0.7593273519904803

0.9079403595790154
0.8803007379870942
0.8822450050857307
0.9082881466760383
0.8800160593754276

0.8803007379870942
0.7462281809032909

0.8823018598962408
0.8613751128795826
0.8842035672006865
0.920911210172081
0.8805322405904089

0.920911210172081
0.7803307604456015

0.9206527265079032
0.8757337448770097
0.8852919011945851
0.8992229914367851
0.8861510069087105

0.8852919011945851
0.7593273519904803

0.8681095597295668
0.8876628132222435
0.8845835181339606
0.8760477401126308
0.9050795010299783

0.9050795010299783
0.7867504675040624

0.853161296103609
0.8627541933819807
0.8612422169758838
0.8307137097153771
0.8738522319953052

0.8738522319953052
0.7867504675040624

0.8132006846015081
0.8456493819568008
0.8565476590385172
0.8105660201453706
0.874047730203686

0.874047730203686
0.7867504675040624

0.8778807461236937
0.9129304830827594
0.9000835695420228
0.8793067253693924
0.9242501325868213

0.9242501325868213
0.7867504675040624



 23%|██▎       | 404/1747 [00:01<00:04, 323.94it/s]



0.8197682487641635
0.7867504675040624

0.8265599230068582
0.7727864120788809
0.8119364009464408
0.8138607573859238
0.7770248353637506

0.8265599230068582
0.779731590950885

0.908715704990583
0.8774311470375065
0.8981848231689953
0.8759960677975094
0.89943895127166

0.89943895127166
0.7867504675040624

0.9420899999411461
0.9231071184976488
0.8707706230993639
0.900040850564839
0.8615948410927187

0.9420899999411461
0.779731590950885

0.8805633403088682
0.8775134480786027
0.9309531254377073
0.8911146912645994
0.8974642023241302

0.9309531254377073
0.7593273519904803

0.7876583463325268
0.7938547481202967
0.817794053438562
0.7555560882741146
0.7988140996298393

0.817794053438562
0.7593273519904803

0.8571869763003983
0.8756984449929689
0.8815142535441546
0.8563744322646748
0.890938013158557

0.8815142535441546
0.7593273519904803

0.7693093838448246
0.7663489215338988
0.8197607721471947
0.7507586788945517
0.8074132468020344

0.8197607721471947
0.7593273519904803

0.9184384443780491
0.9123

 27%|██▋       | 468/1747 [00:01<00:04, 318.05it/s]

0.9185104526379573
0.9168625221018312
0.9310010313279727
0.9279245578275723
0.9328598280496563

0.9328598280496563
0.7867504675040624

0.88768750572875
0.9226750867433664
0.8745740512139251
0.8711012788204109
0.882371948488335

0.882371948488335
0.7867504675040624

0.8456148731257879
0.8375706093884451
0.8988140362945132
0.8767847896336254
0.9196669759890319

0.9196669759890319
0.7867504675040624

0.7939819680086624
0.7824788394593729
0.8014472334170999
0.8430760289751396
0.7772462888075573

0.8430760289751396
0.7803307604456015

0.8966721509686839
0.931896567681137
0.8928201487485109
0.8560398235396225
0.8768111093201176

0.931896567681137
0.7462281809032909

0.8856177435912179
0.8886555117857423
0.8670911042701928
0.8746753721444119
0.8610185813399724

0.8610185813399724
0.7867504675040624

0.7711117438657961
0.7439047491552542
0.8488195484512712
0.7847526762984665
0.8548921545702894

0.8488195484512712
0.7593273519904803

0.8911482041879828
0.9375824246410094
0.8757819634954475
0.87

 30%|███       | 532/1747 [00:01<00:03, 312.05it/s]



0.8569687814852065
0.7593273519904803

0.871987114674774
0.9207132232071749
0.8349841127746356
0.8472081222559031
0.8398425847777711

0.9207132232071749
0.7462281809032909

0.7981012928196691
0.7814972049538754
0.7804376371769027
0.8307728333259394
0.7773326083614815

0.8307728333259394
0.7803307604456015

0.854351902703084
0.845170138582684
0.8748994681016954
0.91463750511551
0.8604352527864507

0.91463750511551
0.7803307604456015

0.8649329845135981
0.9120428968291552
0.8188781166642495
0.8350525702939163
0.8196559123340907

0.9120428968291552
0.7462281809032909

0.8784273601941885
0.9025681927619966
0.8596465770313197
0.8519894798306237
0.855934926755011

0.9025681927619966
0.7462281809032909

0.9207157327620119
0.9538886336142929
0.9190965429536784
0.8970479982757594
0.9216841524924828

0.9216841524924828
0.7867504675040624

0.902078993527013
0.916442484567661
0.9281949672181584
0.8808862742083904
0.9212095241736549

0.9281949672181584
0.7593273519904803

0.8207485724155543
0.846

 34%|███▍      | 600/1747 [00:01<00:03, 320.59it/s]


0.895212916152787
0.8818185728448882
0.8550207769234457
0.8799662333394619

0.895212916152787
0.7462281809032909

0.887096038954867
0.9117560375982734
0.873488322186838
0.8643620577551411
0.8836700740398071

0.8836700740398071
0.7867504675040624

0.918321637464375
0.8818143267487948
0.8883584894785244
0.8890875125331086
0.8805181686413945

0.918321637464375
0.779731590950885

0.8781968734788632
0.8954674564513183
0.8526877797714781
0.8691580496754547
0.8382791183854256

0.8781968734788632
0.779731590950885

0.8200934507334443
0.8595120958827775
0.8008684815601411
0.8079618843809276
0.8060538174621852

0.8595120958827775
0.7462281809032909

0.8008458029950749
0.8045494751792048
0.828927396712183
0.8110772535980627
0.8417082462116294

0.8110772535980627
0.7803307604456015

0.8944265521060011
0.9076757069799899
0.8703990467981432
0.8877983715411222
0.851120845075515

0.9076757069799899
0.7462281809032909

0.8970157949589073
0.8971855640978983
0.8981451031140953
0.8854340561762076
0.88500

 38%|███▊      | 664/1747 [00:02<00:03, 293.50it/s]


0.9163188574563362
0.7593273519904803

0.9019857644756394
0.8980655784392323
0.9388090357063696
0.917240998916757
0.9306102105093639

0.8980655784392323
0.7462281809032909

0.9252207147873575
0.9495437891331657
0.9408518450895575
0.9202349421890254
0.9310266031980613

0.9408518450895575
0.7593273519904803

0.8688533756056573
0.8777994250673931
0.8466625610480463
0.8837374531270787
0.8371816490797533

0.8688533756056573
0.779731590950885

0.8351748869544358
0.8830034006388352
0.864989593326746
0.8140281984492987
0.8839731771713961

0.8830034006388352
0.7462281809032909

0.8550872650409447
0.8207620797614027
0.8846115521144859
0.8779942484163181
0.8452257537905281

0.8846115521144859
0.7593273519904803

0.863292615651712
0.8960831594141196
0.8364776333646551
0.8398419877738523
0.8324930643523959

0.8960831594141196
0.7462281809032909

0.7737487381369352
0.7520470621759288
0.8483828712275981
0.7738422452498497
0.845967800290356

0.8483828712275981
0.7593273519904803

0.904710809661811
0.

 42%|████▏     | 729/1747 [00:02<00:03, 306.24it/s]


0.8836779249020303
0.8775717379197283
0.8814862548507334

0.8836779249020303
0.7593273519904803

0.8225296477531414
0.822391570372555
0.8905364442846687
0.8529576349091277
0.8785537640450201

0.8905364442846687
0.7593273519904803

0.8628439486654349
0.837877998641853
0.9086577424204907
0.8722698180472309
0.9339684549813578

0.9339684549813578
0.7867504675040624

0.9016790107274042
0.8977479922335057
0.9385935099705015
0.9169586189890715
0.9304059232189149

0.8977479922335057
0.7462281809032909

0.8564002383715226
0.814487370934879
0.877601226074806
0.8288968160046839
0.8684675573638988

0.877601226074806
0.7593273519904803

0.8965594565853992
0.8933627196963397
0.8660861455709011
0.8513235304482303
0.8734207730997536

0.8965594565853992
0.779731590950885

0.8397436553532307
0.8231620999138287
0.834779427144406
0.8547303088644385
0.8447329635098566

0.8547303088644385
0.7803307604456015

0.9160170416805204
0.9210836353904756
0.9335454880663036
0.9002818287445276
0.9309063227215075

0.9

 46%|████▌     | 795/1747 [00:02<00:03, 315.96it/s]


0.8650024130845753
0.9044488331613607
0.8656961493906533
0.9131104670557266

0.9131104670557266
0.7867504675040624

0.8380198083586952
0.7913378473492727
0.8518138745307013
0.7846314468443973
0.8319616272427794

0.8518138745307013
0.7593273519904803

0.9093183821016583
0.85802559624292
0.8455619567757415
0.8689992113213685
0.8462891415770288

0.9093183821016583
0.779731590950885

0.8779203271185456
0.9352524957025385
0.8607579202810833
0.8671683359803943
0.8584417403670943

0.9352524957025385
0.7462281809032909

0.9009909000067791
0.8936755185767842
0.9216544774501993
0.9208270937138351
0.9109661258500268

0.9216544774501993
0.7593273519904803

0.8640939723199795
0.8949397548076062
0.8353606820262305
0.8171938498360041
0.835998921757616

0.8949397548076062
0.7462281809032909

0.8907556872384026
0.8809546966341351
0.9197280627931304
0.9149022051615803
0.9270290781280688

0.9270290781280688
0.7867504675040624

0.8420674741063647
0.907816154585706
0.7961315627637902
0.8027843717145147
0.

 49%|████▉     | 858/1747 [00:02<00:02, 306.05it/s]

0.779731590950885

0.916046981965173
0.9497353134102627
0.9166360678244001
0.8969718098352961
0.9153059834998925

0.9497353134102627
0.7462281809032909

0.8631745068072503
0.9231164647815949
0.8128405790914979
0.8274105205845398
0.8140019299495331

0.9231164647815949
0.7462281809032909

0.8229469106996288
0.8057902342747645
0.8928086822804726
0.8490397184812133
0.9017912542216313

0.9017912542216313
0.7867504675040624

0.9310433413535064
0.8690104592232264
0.8524468972387865
0.8496805404998758
0.8489761158300658

0.9310433413535064
0.779731590950885

0.8404926188284098
0.8754412989167364
0.8124615358132493
0.8191384617010786
0.8072965861792328

0.8754412989167364
0.7462281809032909

0.9169042869170523
0.9459594020938346
0.93949200876616
0.8967264840546849
0.9387236841671447

0.9459594020938346
0.7462281809032909

0.8425182987967723
0.850452401145031
0.8720561644409054
0.8248045378836135
0.8869237160658623

0.8869237160658623
0.7867504675040624

0.8985087805237085
0.9004441027357579
0.9

 53%|█████▎    | 925/1747 [00:02<00:02, 317.84it/s]

0.9196060036082212
0.8789248278875835
0.8440965869603496
0.833006251589193
0.843361516047324

0.9196060036082212
0.779731590950885

0.8146653594676654
0.833606351832286
0.8561353717496799
0.7955139780568862
0.8507469584222758

0.8507469584222758
0.7867504675040624

0.890621445278951
0.8126371672163433
0.8088994564723921
0.8355808661085313
0.7913595617151485

0.890621445278951
0.779731590950885

0.8514793868111455
0.9106431402940951
0.7989632715216766
0.819590340958823
0.8004982095414152

0.9106431402940951
0.7462281809032909

0.8067612475965089
0.7834735691968859
0.8739481633243749
0.866433941536527
0.8700705670175872

0.8700705670175872
0.7867504675040624

0.8863669208463257
0.9040841632249088
0.8962085765089777
0.8899515876125208
0.8947346988006001

0.8899515876125208
0.7803307604456015

0.9054271348384703
0.8649232114286889
0.8932950763920302
0.9042414433076686
0.8929280994381612

0.8929280994381612
0.7867504675040624

0.8729783739788344
0.8248181826480879
0.8871104909962774
0.86902

 57%|█████▋    | 989/1747 [00:03<00:02, 300.09it/s]


0.779731590950885

0.8023261129835533
0.8118049162983614
0.814292348446969
0.8722827712632839
0.8117234535568609

0.8722827712632839
0.7803307604456015

0.8525398114420544
0.8123697012741189
0.7687071202396076
0.7741496683282552
0.7712840859160246

0.8525398114420544
0.779731590950885

0.9333551750320549
0.9392835025169022
0.9336655592702646
0.9404610350079303
0.9282777887655911

0.9282777887655911
0.7867504675040624

0.8417948090807295
0.8516281582202738
0.8661910092629028
0.8382395624344654
0.8688603211884061

0.8661910092629028
0.7593273519904803

0.8720255536712788
0.8757862456969053
0.9071311567698146
0.8841200168394613
0.9076846077609412

0.9071311567698146
0.7593273519904803

0.9081385195765175
0.9017661729582447
0.902066558720963
0.8952622276449236
0.8901995817427536

0.9081385195765175
0.779731590950885

0.8715618905330075
0.8777553346034542
0.9182691426664548
0.8782895280178284
0.9165738744072789

0.8777553346034542
0.7462281809032909

0.7631699431515903
0.820246470214425
0.

 60%|██████    | 1056/1747 [00:03<00:02, 314.11it/s]

0.8619070090609052
0.873664801390609
0.8582052300535183
0.8552433136932812

0.873664801390609
0.7593273519904803

0.8703881682288326
0.8249629932526987
0.8327067787775899
0.8445178871490857
0.8072337957772218

0.8703881682288326
0.779731590950885

0.8717402749384989
0.8828782941386527
0.9091508570627818
0.9142381351302962
0.9191754763325992

0.9191754763325992
0.7867504675040624

0.9074212195116952
0.9043328831082263
0.889493022214029
0.9287093646147553
0.8826502143724835

0.8826502143724835
0.7867504675040624

0.860206538984718
0.89816236175764
0.8389627155157405
0.8456628001408815
0.8375683074524451

0.89816236175764
0.7462281809032909

0.93539722004917
0.9466806347995332
0.9140237213054372
0.9225941203051451
0.9105134999381248

0.9466806347995332
0.7462281809032909

0.8749132850113506
0.9133218922605378
0.8397582535522058
0.8399143568355818
0.8555773328931553

0.9133218922605378
0.7462281809032909

0.8447936739932607
0.8211261872172395
0.7575332245129456
0.7657635152492227
0.7632455

 65%|██████▍   | 1129/1747 [00:03<00:01, 335.56it/s]


0.8720446186540145
0.89037898252024
0.8866962504640744
0.8989372025127254

0.8989372025127254
0.7867504675040624

0.7490664932923374
0.7445488203475218
0.8056937619318674
0.7828988074600054
0.8049081683571382

0.8049081683571382
0.7867504675040624

0.9096886488319783
0.9300133398991139
0.8925544872012463
0.8874522092925035
0.8930661127685136

0.9300133398991139
0.7462281809032909

0.893803208108009
0.8484159325274738
0.7997156055044278
0.8108696052364621
0.8033590890816131

0.893803208108009
0.779731590950885

0.8727381926729136
0.8935604871741724
0.8359145735148116
0.8374493057366151
0.8306595734139601

0.8727381926729136
0.779731590950885

0.8999650836643637
0.9111660400692365
0.852343471251829
0.8560856819347924
0.8499183510492425

0.8999650836643637
0.779731590950885

0.9038213913146937
0.8973108180619309
0.8386627216299833
0.8575174088141844
0.8460982397515022

0.9038213913146937
0.779731590950885

0.9029674205194431
0.9359708085450461
0.9292709677979992
0.8989339874555886
0.9281

 69%|██████▉   | 1205/1747 [00:03<00:01, 330.49it/s]


0.9237813270261233
0.779731590950885

0.7976690445924906
0.8351905525598187
0.8154987189640535
0.7812860125002702
0.8071424961325888

0.8071424961325888
0.7867504675040624

0.8975226457911545
0.8936912087745577
0.8979331104082946
0.9175154804530542
0.8993380637858465

0.8975226457911545
0.779731590950885

0.9108289880145177
0.9551925325725098
0.8998804621087302
0.8825975796628284
0.8953602447962915

0.9551925325725098
0.7462281809032909

0.8797781205639141
0.8991767142427065
0.849474122008462
0.8486148992189426
0.843884313968398

0.8797781205639141
0.779731590950885

0.8686042445120061
0.8615285202830619
0.8926604292161737
0.8945562968238836
0.8853754673622561

0.8686042445120061
0.779731590950885

0.9069190970733882
0.8768312293501123
0.8390517202650735
0.8610192859100645
0.8298859915179522

0.9069190970733882
0.779731590950885

0.7855260624025739
0.7850468772626426
0.8437639177017153
0.8244392378513814
0.8457786128444926

0.8244392378513814
0.7803307604456015

0.9386885128789321
0.9

 73%|███████▎  | 1275/1747 [00:03<00:01, 329.31it/s]


0.8737192592264996

0.9500753815686107
0.7462281809032909

0.889959595981018
0.8934085703899579
0.8627467195946894
0.8808802507960577
0.852727922605904

0.8934085703899579
0.7462281809032909

0.8461108201300238
0.8615492244789916
0.82005147864719
0.8054067574937385
0.8276266685733041

0.8461108201300238
0.779731590950885

0.8294914667618801
0.7586479176064879
0.7272533987556575
0.7331717854369699
0.7312180890560347

0.8294914667618801
0.779731590950885

0.9263317082389585
0.961777188380154
0.9184746069321661
0.9168735996657669
0.9225987803803524

0.9225987803803524
0.7867504675040624

0.8498809376971723
0.8542583819368628
0.8836181046916342
0.8517783140663556
0.884954603600894

0.8836181046916342
0.7593273519904803

0.9101643272674489
0.9294526549091819
0.930791337773373
0.91542991083124
0.9298796717505451

0.9298796717505451
0.7867504675040624

0.8017528140677892
0.7895301799207664
0.8327749663140123
0.8627325036713034
0.8187879532867685

0.8627325036713034
0.7803307604456015

0.7780

 77%|███████▋  | 1342/1747 [00:04<00:01, 320.38it/s]


0.9087126142313984
0.8996936707123189
0.9032410066627938

0.93266542319478
0.7462281809032909

0.7597189462951439
0.7754509715485571
0.8236227526153316
0.7722958544167785
0.834395927329322

0.8236227526153316
0.7593273519904803

0.9190324529693752
0.9434673152286215
0.8954799706765167
0.9045619569961161
0.8957507834185653

0.8957507834185653
0.7867504675040624

0.8398616674019301
0.8108849363344142
0.7963679251169704
0.8027580423502174
0.7806642801498549

0.8398616674019301
0.779731590950885

0.8153364276284831
0.8084297301777665
0.8324544365095708
0.8002879380623426
0.8280722449592458

0.8324544365095708
0.7593273519904803

0.87374735880418
0.8977955683251465
0.8906436452835045
0.8537560338339332
0.8993503935649554

0.8993503935649554
0.7867504675040624

0.8044244832448106
0.8079243451501286
0.8348606060502581
0.84305921067558
0.8387306683716345

0.8348606060502581
0.7593273519904803

0.7685716196085473
0.7641148685051691
0.822615874466711
0.7964299514199953
0.8271598325103453

0.822

 81%|████████  | 1412/1747 [00:04<00:01, 332.52it/s]


0.7803307604456015

0.9041590488907554
0.8695739382682429
0.8792298579309658
0.8909293645527447
0.8594535496664428

0.8792298579309658
0.7593273519904803

0.8833026363100991
0.8643944819179721
0.8563959308819958
0.8543830944748027
0.8489211521095181

0.8563959308819958
0.7593273519904803

0.7952934860672599
0.7879302138401674
0.8093702335684232
0.7957392121789333
0.7864893473291058

0.7952934860672599
0.779731590950885

0.895480460560959
0.8701882035171824
0.8541775903612854
0.8745371797227106
0.8421995238327497

0.8701882035171824
0.7462281809032909

0.8960165798810608
0.9320784289984219
0.8918883643593055
0.8729651707907263
0.8903680555615945

0.8918883643593055
0.7593273519904803

0.9236690463302184
0.9433554635410197
0.9050221095744211
0.901618399863129
0.9018098500657816

0.9236690463302184
0.779731590950885

0.8644087083498251
0.9054785691717389
0.8657300575698007
0.8630796601083005
0.8602814230255865

0.8657300575698007
0.7593273519904803

0.8779342306148644
0.8985344849506243


 85%|████████▍ | 1479/1747 [00:04<00:00, 317.16it/s]


0.7867504675040624

0.834632919201019
0.8044045409784609
0.7913515193204415
0.7985203963375089
0.7709545189867917

0.834632919201019
0.779731590950885

0.8787245923352205
0.8722353141100954
0.8874369870426806
0.8756818643137969
0.8775846082906826

0.8874369870426806
0.7593273519904803

0.862303788615972
0.8626597689683428
0.8290465342079083
0.823391872564158
0.825612500885784

0.8626597689683428
0.7462281809032909

0.8966620980003961
0.909978907370323
0.9111002689751653
0.9063442759423752
0.9035428470653242

0.9111002689751653
0.7593273519904803

0.901333908309276
0.9170891531692917
0.8499542293397921
0.878234684417212
0.840077601397832

0.9170891531692917
0.7462281809032909

0.9265234255102259
0.9342225903819957
0.8547116790367606
0.8789818287486326
0.8514008000697388

0.9265234255102259
0.779731590950885

0.925369468804933
0.9472593244848468
0.9036920251418126
0.91817847673423
0.9008378767879825

0.9036920251418126
0.7593273519904803

0.8873959137913682
0.8484112049713067
0.89369625

 88%|████████▊ | 1544/1747 [00:04<00:00, 320.92it/s]

0.8561447125474002
0.8432020538716005
0.8729195039915315

0.8729195039915315
0.7867504675040624

0.9144885193809954
0.9332167588270359
0.9338953295737864
0.9165773579676536
0.9314340580444582

0.9332167588270359
0.7462281809032909

0.9218564408434295
0.9207848795035309
0.9184217884528754
0.93431589464732
0.9147914131102267

0.9207848795035309
0.7462281809032909

0.8871313429622277
0.8957540218428527
0.913064801909996
0.892792193664713
0.9282710293714918

0.9282710293714918
0.7867504675040624

0.9087789738735114
0.9358785524402542
0.9025766577275329
0.8993361197984161
0.8924480251658576

0.9025766577275329
0.7593273519904803

0.8503489327590006
0.8732833643520719
0.8232970902956818
0.8076830635220337
0.8271796505016469

0.8503489327590006
0.779731590950885

0.9075132768535751
0.8996795718267535
0.8815471878325136
0.9106210525701922
0.8655319698763834

0.8815471878325136
0.7593273519904803

0.8515821370963265
0.8632754715730399
0.886208838509392
0.8267441619291774
0.892005532964054

0.88

 92%|█████████▏| 1610/1747 [00:05<00:00, 318.64it/s]


0.8881773864394451
0.9001285044209253

0.9091744382222227
0.7593273519904803

0.918411699655766
0.8959672782804928
0.8715805813403994
0.8603064928730971
0.8680037179595046

0.918411699655766
0.779731590950885

0.8372813055758935
0.8035817926754218
0.819785852169934
0.8036985249385873
0.8207601330844811

0.8372813055758935
0.779731590950885

0.920446813780714
0.9396205441813489
0.8574015960953185
0.8831119050648453
0.8527169699490458

0.920446813780714
0.779731590950885

0.8482812936603799
0.9010633144463912
0.8450168665486812
0.830242943342999
0.8673905527853982

0.8450168665486812
0.7593273519904803

0.8148848969361877
0.8077698429865011
0.8486614548385829
0.8211121381352682
0.8640714983585707

0.8640714983585707
0.7867504675040624

0.7877940947644522
0.7675733430380982
0.8000351049549211
0.7872229364367702
0.7977689301158035

0.8000351049549211
0.7593273519904803

0.8997329548387766
0.9228982248807802
0.9079538449363935
0.8920836748418439
0.919584839836036

0.9079538449363935
0.7593

 96%|█████████▌| 1674/1747 [00:05<00:00, 302.97it/s]


0.8606323158428681
0.8906257867351064

0.902144249327316
0.7593273519904803

0.8984593825736791
0.942548698626039
0.8704045056556696
0.8805743849939971
0.8691488428510055

0.8704045056556696
0.7593273519904803

0.838012868980708
0.8532025373042637
0.8535939191571633
0.8321960954125238
0.8529689738091668

0.8532025373042637
0.7462281809032909

0.9180531310818063
0.9340071793532224
0.8564542709060885
0.8747494245135778
0.8525384930565214

0.9180531310818063
0.779731590950885

0.8241724525382781
0.8285138965812754
0.8592652353962342
0.9028068939834313
0.8620855191108103

0.9028068939834313
0.7803307604456015

0.8990781791008287
0.9415726172766332
0.8913551793445387
0.8879203269771485
0.8906834154216658

0.9415726172766332
0.7462281809032909

0.8538474646100482
0.8529089663101361
0.8848067825902427
0.8716061994271106
0.884616754867485

0.8848067825902427
0.7593273519904803

0.8225264352620195
0.8338663457310482
0.8515822977224987
0.8458620296609768
0.8460311775063858

0.8515822977224987
0

100%|██████████| 1747/1747 [00:05<00:00, 318.92it/s]



0.8997848327588935
0.7867504675040624

0.8691097305323605
0.8364073767719595
0.8152569899365953
0.8110824587471527
0.8119203809824649

0.8691097305323605
0.779731590950885

0.9052635237643016
0.9033257176452566
0.829312423664075
0.843629272343373
0.8309810510956193

0.9052635237643016
0.779731590950885

0.8718205629864049
0.865920385131223
0.8407701690010325
0.8480529512101194
0.8405523850030205

0.8718205629864049
0.779731590950885

0.8623538452788931
0.923881590586969
0.8331870739243509
0.810891780006902
0.8275222672136882

0.923881590586969
0.7462281809032909

0.8926673632089002
0.8863139584591437
0.931474206807546
0.911838588101498
0.922805742971744

0.922805742971744
0.7867504675040624

0.9049958319800152
0.922028739530541
0.8554115138330602
0.8886209994682819
0.8462634224612418

0.9049958319800152
0.779731590950885

0.8797609167289262
0.8635822443597371
0.8800452065179967
0.9290776336490254
0.8649069456247145

0.8797609167289262
0.779731590950885

0.8541709508076116
0.861430010




In [None]:
evaluate(updated_predictions, y_test)

Overall accuracy: 0.48597595878649114
Accuracy for class 2: 0.9359605911330049
Accuracy for class 4: 0.9651162790697675
Accuracy for class 0: 0.9856459330143541
Accuracy for class 3: 0.8947368421052632
Accuracy for class 1: 0.972972972972973
Accuracy for class -1: 0.012775842044134728


In [None]:
macro_f1_score = f1_score(updated_predictions, y_test, average='macro')
print("Macro F1 score: {}".format(macro_f1_score))

micro_f1_score = f1_score(updated_predictions, y_test, average='micro')
print("Micro F1 score: {}".format(micro_f1_score))

weighted_f1_score = f1_score(updated_predictions, y_test, average='weighted')
print("Weighted F1 score: {}".format(weighted_f1_score))

norm_f1_score = f1_score(updated_predictions, y_test, average=None)
print("F1 score for each class: {}".format(norm_f1_score))

Macro F1 score: 0.5432736573388565
Micro F1 score: 0.48597595878649114
Weighted F1 score: 0.6318731682194519
F1 score for each class: [0.02522936 0.66237942 0.66976744 0.74363992 0.60986547 0.54876033]
