In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader,random_split,TensorDataset
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import timm 
import tqdm as notebook_tqdm
import torchaudio
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import sys
import json
from sklearn.metrics.pairwise import cosine_similarity

# Đọc dữ liệu

In [3]:
# Root directory containing train.json and test.json, relative to the notebook's working directory.
path_to_data='../../dataset'

In [4]:
# Load the raw train/test splits. JSON is UTF-8 by specification, so pass the encoding
# explicitly: without it, open() uses the platform locale codec (e.g. cp1252 on Windows),
# which can fail or mis-decode non-ASCII content.
with open(path_to_data + '/train.json', 'r', encoding='utf-8') as f:
    train_data = json.load(f)
with open(path_to_data + '/test.json', 'r', encoding='utf-8') as f:
    test_data = json.load(f)

In [5]:
# Turn the parsed JSON (presumably a list of records) into DataFrames. The names are
# rebound, so later cells read train_data / test_data as DataFrames.
train_data, test_data = pd.DataFrame(train_data), pd.DataFrame(test_data)

# Lấy mẫu VGGISH TURKEY

In [6]:
# Fetch the pretrained VGGish model via torch.hub and put it in inference mode.
# eval() returns the module itself, so binding its result keeps the same object.
# NOTE: torch.hub runs the remote repo's hubconf.py; pin a ref ('owner/repo:tag') for reproducibility.
model = torch.hub.load(repo_or_dir='harritaylor/torchvggish', model='vggish').eval()
model

Using cache found in C:\Users\vietl/.cache\torch\hub\harritaylor_torchvggish_master


VGGish(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False

In [7]:
import urllib
# Local WAV clip used as the reference turkey sound; it is embedded with VGGish in the next cell.
filename='turkey.wav'

In [8]:
# Embed the reference turkey clip with VGGish. Calling the module (instead of .forward)
# runs any registered hooks, and no_grad avoids building an autograd graph that is never used.
with torch.no_grad():
    turkey_features = model(filename)
# NOTE(review): the hub VGGish appears to squeeze its output, so a single-frame clip would
# come back 1-D. Reshape to (frames, dims) so the per-frame mean in the next cell stays a
# vector. For multi-frame output this is a no-op.
turkey_features = turkey_features.int().reshape(-1, turkey_features.shape[-1])

In [9]:
# Average the per-frame turkey embeddings into one reference vector, rounded to 2 decimals.
# Use the tensor's own round() and dim= rather than passing a torch tensor through np.round.
turkey_mean_vector = turkey_features.float().mean(dim=0).round(decimals=2)
assert turkey_mean_vector.ndim == 1, "expected turkey_features shaped (frames, dims)"

In [10]:
# Inspect the rounded mean turkey embedding (the reference vector used by get_similarity_score).
turkey_mean_vector

tensor([165.8000,  30.6000, 174.0000,  99.6000, 220.7000,  55.4000, 120.0000,
         72.3000, 205.4000, 199.8000,  71.5000,  81.9000, 176.0000, 207.6000,
          7.9000,  46.8000, 148.1000, 142.1000, 213.5000, 194.5000,  39.6000,
        210.9000, 121.0000,  50.2000,  96.5000, 123.2000, 172.7000, 215.1000,
        185.9000,   6.9000,  33.4000,  10.2000,  79.9000,  86.7000, 127.3000,
        192.5000,  20.4000, 104.2000, 151.7000, 132.5000,   2.5000,  63.1000,
          1.6000, 132.6000, 206.4000, 137.7000, 103.0000, 188.1000,  38.0000,
        255.0000, 196.7000, 185.1000,  25.3000, 146.9000,  43.7000, 186.6000,
        161.6000,  82.3000,  50.0000, 192.7000,  71.1000, 115.4000,  47.5000,
        180.6000, 179.9000, 139.1000,  24.6000,  94.0000, 233.0000, 171.5000,
        156.0000,  25.9000, 254.0000, 198.9000,  55.3000,  30.5000, 159.4000,
          0.0000,  13.4000, 222.1000,  61.6000, 232.2000, 241.9000, 235.6000,
        245.9000, 187.5000, 152.2000, 143.2000,   1.1000,   0.00

# Tách các array vector thành các mẫu mới

In [11]:
def expand_audio_embeddings(data, column='audio_embedding'):
    """Expand a per-clip DataFrame into one row per element of a list-valued column.

    Every element of ``data[column]`` becomes its own row; all other columns are
    repeated for each element.

    Parameters
    ----------
    data : pandas.DataFrame
        Input frame. It is not modified.
    column : str, default 'audio_embedding'
        Name of the list-valued column to expand.

    Returns
    -------
    pandas.DataFrame
        Expanded frame with the same column order as ``data`` and a fresh
        0..n-1 index. Rows whose list is empty (or missing) contribute no rows.
    """
    # explode repeats the other columns in vectorized form and keeps their dtypes,
    # instead of copying a whole Series for every element via iterrows().
    expanded = data.explode(column)
    # explode turns an empty list into a single NaN placeholder row; drop those so an
    # empty clip yields no rows, matching the original loop.
    expanded = expanded[expanded[column].notna()]
    return expanded.reset_index(drop=True)

In [12]:
# Expand both splits so every element of audio_embedding becomes its own row (train first, then test).
expanded_train_val_data, expanded_test_data = (
    expand_audio_embeddings(split) for split in (train_data, test_data)
)

In [13]:
# Class balance over the expanded rows (each clip's is_turkey label is repeated for each of its embeddings).
expanded_train_val_data['is_turkey'].value_counts()

is_turkey
0    6954
1    4841
Name: count, dtype: int64

In [14]:
# Keep only the columns used downstream. .copy() makes this an independent frame, so the
# column assignments in later cells never operate on a view/slice of the previous frame
# (which pandas flags with SettingWithCopyWarning).
expanded_train_val_data = expanded_train_val_data[['audio_embedding', 'is_turkey', 'vid_id']].copy()

In [15]:
def get_similarity_score(embedding, reference=None):
    """Cosine similarity between one embedding and a reference vector.

    Parameters
    ----------
    embedding : array-like
        A single embedding; it is flattened before comparison.
    reference : array-like or torch.Tensor, optional
        Vector to compare against (flattened). Defaults to the global
        ``turkey_mean_vector`` computed earlier in the notebook.

    Returns
    -------
    float
        Cosine similarity (within [-1, 1] up to rounding). Returns 0.0 when either
        vector has zero norm, the same convention as
        ``sklearn.metrics.pairwise.cosine_similarity``.

    Raises
    ------
    ValueError
        If the two vectors have different lengths.
    """
    if reference is None:
        reference = turkey_mean_vector
    emb = np.asarray(embedding, dtype=np.float64).ravel()
    ref = np.asarray(reference, dtype=np.float64).ravel()
    if emb.shape != ref.shape:
        raise ValueError(
            f"embedding has {emb.size} values but reference has {ref.size}"
        )
    emb_norm = np.linalg.norm(emb)
    ref_norm = np.linalg.norm(ref)
    # Zero-norm vectors have no direction; report 0 similarity rather than dividing by zero.
    if emb_norm == 0 or ref_norm == 0:
        return 0.0
    # Normalise first, then dot, mirroring how sklearn computes cosine similarity.
    return float(np.dot(emb / emb_norm, ref / ref_norm))


In [16]:
# Add a similarity_to_turkey column to both splits: each embedding scored against the turkey reference vector.
expanded_train_val_data['similarity_to_turkey'] = (
    expanded_train_val_data['audio_embedding']
    .apply(get_similarity_score)
)
expanded_test_data['similarity_to_turkey'] = (
    expanded_test_data['audio_embedding']
    .apply(get_similarity_score)
)

In [17]:
# Similarity above which a row is relabelled as turkey (fixing_label = 1). Hand-picked value;
# no tuning of it is visible in this notebook.
SIMILARITY_THRESHOLD = 0.85
# Vectorized comparison instead of a per-element lambda; NaN compares False -> 0, as before.
expanded_train_val_data['fixing_label'] = (
    expanded_train_val_data['similarity_to_turkey'].gt(SIMILARITY_THRESHOLD).astype('int64')
)

In [18]:
# Show an untruncated slice of the frame without changing pandas' display options for the
# rest of the session (set_option would leak these settings into every later cell).
# `display` is provided by the IPython/Jupyter kernel.
with pd.option_context(
    'display.max_rows', None,
    'display.max_columns', None,
    'display.width', None,
    'display.max_colwidth', None,
):
    display(expanded_train_val_data.iloc[300:400])

Unnamed: 0,audio_embedding,is_turkey,vid_id,similarity_to_turkey,fixing_label
300,"[152, 11, 180, 92, 229, 95, 97, 105, 153, 181, 69, 110, 151, 94, 81, 148, 160, 116, 162, 159, 30, 234, 54, 130, 66, 148, 167, 161, 171, 95, 181, 23, 65, 215, 113, 181, 149, 94, 178, 87, 75, 230, 20, 110, 105, 152, 72, 106, 46, 186, 147, 127, 26, 56, 3, 15, 209, 43, 114, 255, 123, 151, 121, 255, 217, 117, 105, 157, 255, 107, 163, 144, 136, 255, 108, 144, 95, 192, 136, 215, 0, 208, 255, 156, 125, 177, 18, 149, 208, 127, 229, 162, 194, 168, 39, 254, 54, 31, 255, 165, ...]",0,lehPGCdtNmY,0.842992,0
301,"[157, 15, 179, 122, 213, 69, 78, 59, 168, 168, 62, 26, 130, 128, 102, 85, 85, 105, 175, 144, 34, 182, 120, 114, 10, 139, 251, 116, 116, 19, 73, 121, 237, 147, 85, 154, 189, 86, 95, 34, 130, 135, 0, 205, 156, 134, 136, 111, 24, 166, 156, 92, 2, 144, 54, 51, 208, 1, 193, 255, 129, 98, 87, 114, 255, 105, 104, 207, 255, 123, 210, 61, 224, 228, 62, 122, 168, 158, 119, 255, 23, 255, 185, 136, 157, 174, 45, 180, 166, 0, 53, 205, 210, 214, 98, 206, 32, 60, 156, 96, ...]",0,lehPGCdtNmY,0.846113,0
302,"[153, 4, 183, 96, 212, 116, 67, 87, 151, 217, 71, 96, 128, 97, 134, 112, 128, 138, 199, 134, 53, 209, 140, 101, 0, 170, 159, 184, 173, 128, 141, 87, 92, 170, 63, 152, 163, 45, 243, 53, 103, 226, 102, 176, 126, 173, 149, 185, 20, 185, 213, 147, 51, 61, 131, 4, 204, 6, 13, 255, 144, 113, 94, 255, 231, 118, 150, 174, 242, 67, 237, 105, 175, 255, 72, 116, 11, 255, 133, 255, 61, 153, 220, 118, 143, 74, 101, 101, 162, 47, 103, 255, 156, 34, 0, 250, 14, 227, 168, 94, ...]",0,lehPGCdtNmY,0.830122,0
303,"[149, 5, 179, 51, 210, 103, 100, 101, 155, 208, 89, 125, 103, 165, 147, 152, 115, 149, 173, 154, 71, 197, 30, 167, 96, 172, 222, 191, 190, 119, 188, 0, 156, 88, 28, 158, 136, 16, 255, 0, 120, 208, 0, 200, 140, 255, 179, 152, 138, 170, 255, 97, 0, 62, 64, 0, 106, 19, 89, 255, 106, 122, 32, 149, 125, 87, 227, 185, 255, 0, 184, 209, 40, 213, 73, 168, 88, 255, 225, 72, 0, 97, 205, 118, 147, 129, 52, 145, 140, 142, 216, 192, 179, 167, 0, 255, 0, 129, 244, 193, ...]",0,lehPGCdtNmY,0.769359,0
304,"[158, 12, 170, 127, 190, 59, 53, 66, 181, 168, 86, 43, 159, 200, 61, 32, 96, 79, 187, 82, 89, 153, 143, 165, 39, 143, 146, 89, 117, 99, 78, 150, 174, 121, 130, 108, 180, 99, 162, 95, 0, 129, 35, 111, 184, 144, 107, 90, 0, 203, 137, 79, 57, 96, 143, 40, 150, 0, 176, 243, 136, 35, 97, 225, 255, 70, 226, 124, 229, 144, 187, 86, 187, 193, 15, 123, 32, 164, 123, 255, 83, 164, 176, 88, 255, 84, 130, 197, 164, 0, 110, 217, 126, 62, 31, 229, 46, 9, 181, 226, ...]",0,lehPGCdtNmY,0.848983,0
305,"[158, 11, 167, 131, 195, 57, 51, 73, 178, 167, 57, 54, 143, 121, 112, 100, 86, 93, 179, 106, 80, 164, 103, 77, 58, 130, 192, 180, 124, 70, 81, 120, 99, 173, 123, 165, 228, 99, 95, 99, 59, 187, 58, 142, 129, 135, 120, 99, 7, 167, 180, 131, 0, 73, 73, 0, 221, 17, 162, 228, 123, 118, 74, 255, 254, 117, 154, 167, 214, 167, 158, 89, 214, 255, 49, 180, 77, 163, 124, 255, 116, 255, 239, 119, 152, 118, 82, 182, 180, 38, 143, 203, 198, 140, 124, 230, 113, 29, 182, 255, ...]",0,lehPGCdtNmY,0.844075,0
306,"[151, 10, 181, 95, 218, 88, 96, 85, 176, 233, 59, 81, 147, 152, 66, 93, 125, 161, 195, 128, 53, 175, 95, 78, 56, 176, 134, 226, 181, 121, 129, 65, 97, 137, 7, 155, 103, 85, 219, 95, 93, 201, 54, 145, 136, 140, 106, 137, 0, 255, 232, 85, 31, 54, 127, 0, 173, 0, 15, 227, 95, 55, 43, 243, 223, 169, 122, 168, 213, 141, 216, 116, 151, 255, 98, 1, 75, 233, 159, 255, 0, 133, 231, 197, 194, 100, 208, 254, 162, 0, 129, 188, 201, 158, 22, 207, 70, 117, 133, 142, ...]",0,lehPGCdtNmY,0.858872,1
307,"[156, 7, 176, 105, 185, 59, 78, 58, 180, 185, 72, 27, 148, 168, 80, 35, 101, 94, 204, 119, 85, 156, 161, 126, 48, 157, 181, 136, 121, 141, 92, 105, 138, 123, 120, 140, 149, 22, 192, 80, 62, 129, 51, 190, 202, 146, 133, 93, 12, 214, 158, 130, 48, 98, 85, 0, 128, 72, 176, 255, 71, 45, 16, 175, 215, 80, 197, 154, 255, 47, 186, 37, 194, 255, 51, 139, 135, 200, 130, 255, 58, 191, 232, 78, 223, 99, 26, 182, 165, 47, 153, 255, 163, 110, 54, 245, 18, 55, 150, 150, ...]",0,lehPGCdtNmY,0.84915,0
308,"[157, 11, 186, 112, 219, 92, 86, 89, 166, 193, 93, 91, 124, 122, 117, 100, 108, 120, 192, 123, 36, 206, 68, 110, 11, 139, 169, 158, 154, 97, 110, 85, 152, 177, 93, 171, 173, 99, 142, 59, 126, 153, 81, 194, 158, 140, 110, 138, 0, 172, 165, 72, 71, 118, 90, 53, 175, 41, 162, 255, 156, 126, 85, 180, 255, 138, 105, 126, 244, 94, 186, 123, 226, 255, 70, 124, 183, 160, 127, 233, 33, 251, 236, 121, 154, 108, 48, 181, 139, 52, 118, 238, 172, 177, 16, 255, 10, 42, 136, 133, ...]",0,lehPGCdtNmY,0.878213,1
309,"[149, 9, 180, 80, 233, 115, 99, 107, 139, 222, 79, 128, 117, 99, 144, 149, 119, 174, 160, 131, 42, 221, 79, 91, 41, 179, 164, 181, 157, 139, 171, 84, 114, 154, 20, 192, 198, 84, 195, 27, 139, 255, 47, 188, 86, 185, 136, 168, 137, 147, 214, 100, 30, 55, 37, 5, 190, 45, 82, 255, 237, 107, 67, 255, 203, 132, 141, 199, 250, 80, 167, 108, 169, 255, 88, 44, 35, 190, 117, 226, 43, 132, 212, 151, 128, 161, 124, 193, 179, 70, 106, 194, 193, 97, 0, 255, 28, 161, 156, 88, ...]",0,lehPGCdtNmY,0.818036,0


In [19]:
# How many rows the similarity-threshold label disagrees with the original is_turkey label.
mismatch_count = expanded_train_val_data['is_turkey'].ne(
    expanded_train_val_data['fixing_label']
).sum()
print("Number of samples where is_turkey differs from fixing_label:", mismatch_count)

Number of samples where is_turkey differs from fixing_label: 4429


In [20]:
# Save the expanded frames (with similarity scores, plus fixing_label for train) to CSV.
# NOTE: list-valued audio_embedding cells are written as their string repr, so readers must
# parse them back (e.g. ast.literal_eval); a typed format such as parquet would avoid that.
expanded_train_val_data.to_csv('fixing_train_data.csv', index=False)
expanded_test_data.to_csv('fixing_test_data.csv', index=False)