In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.autograd import Variable

In [9]:
device = torch.device("cpu")
time_size = 10
embed_dim = 128
embed_max = 256
label_num = 527

In [3]:
def df_to_tensor(data_name):
    df = pd.read_parquet(f'{data_name}.parquet')
    df = df[~df.isnull()]
    df = df[df['audio_embedding'].apply(lambda x: len(x)) == time_size]
    print(f"{data_name}_df shape: ",df.shape)

    embeddings = np.vstack(df['audio_embedding'].apply(lambda x: np.vstack(x))).reshape(-1,time_size,embed_dim)
    embedding_tensor = torch.Tensor(embeddings)
    print(f"{data_name}_embedding shape: ", embedding_tensor.size())

    def label_converter(x):
        output = np.zeros(527,dtype=int)
        for label in x:
            output[label] = 1
        return output

    df['label'] = df['labels'].apply(lambda x: label_converter(x))

    labels = np.vstack(df['label']).reshape(-1,label_num)
    label_tensor = torch.Tensor(labels)
    print(f"{data_name}_label shape: ",label_tensor.size())
    
    return embedding_tensor/embed_max, label_tensor

In [4]:
train_embedding, train_label = df_to_tensor('bal_train')
test_embedding, test_label = df_to_tensor('eval')

bal_train_df shape:  (21782, 5)
bal_train_embedding shape:  torch.Size([21782, 10, 128])
bal_train_label shape:  torch.Size([21782, 527])
eval_df shape:  (19976, 5)
eval_embedding shape:  torch.Size([19976, 10, 128])
eval_label shape:  torch.Size([19976, 527])


In [5]:
train_set = TensorDataset(train_embedding, train_label)
test_set = TensorDataset(test_embedding, test_label)

In [6]:
bs = 16
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=1, shuffle=False)

In [10]:
class YoutubeAudioClassifier(nn.Module):
    def __init__(self, time_size, embed_dim, fc_dim, output_dim):
        super(YoutubeAudioClassifier, self).__init__()
        self.time_size = time_size
        self.embed_dim = embed_dim
        self.embed_dim1, self.embed_dim2 = self.split_embed_dim()
        self.fc_dim = fc_dim
        self.output_dim = output_dim
        
        self.intra_init_conv = nn.Sequential(
            nn.Conv2d(self.time_size, self.time_size//2, 3, padding=1),    # 10*16*8 -> 5*16*8
            nn.BatchNorm2d(self.time_size//2),
            nn.ReLU()
        )
        
        self.intra_stride = nn.Sequential(
            nn.Conv2d(self.time_size//2, self.time_size//4, 2, stride=2),  # 5*16*8 -> 2*8*4
            nn.BatchNorm2d(self.time_size//4),
            nn.ReLU()
        )
        self.intra_dim1_dil = nn.Sequential(
            nn.Conv2d(self.time_size//2, self.time_size//2, 3, dilation=(2,1)),  # 5*16*8 -> 5*12*6
            nn.BatchNorm2d(self.time_size//2),
            nn.ReLU(),
            nn.Conv2d(self.time_size//2, self.time_size//4, 3, dilation=(2,1)),  # 5*12*6 -> 2*8*4
            nn.BatchNorm2d(self.time_size//4),
            nn.ReLU()
        )
        self.intra_dim2_dil = nn.Sequential(
            nn.Conv2d(self.time_size//2, self.time_size//4, (2,3), dilation=(1,2), stride=(2,1)),  # 5*16*8 -> 2*8*4
            nn.BatchNorm2d(self.time_size//4),
            nn.ReLU()
        )
        
        
        self.inter_conv1 = nn.Sequential(
            nn.Conv2d(1, 4, (3,5), stride=(1,2)),  # 1*10*128 -> 4*8*62
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.Conv2d(4, 8, (3,8), dilation=(1,3)),  # 4*8*62 -> 8*6*41
            nn.BatchNorm2d(8),
            nn.ReLU()
        )
        self.inter_max_conv = nn.Sequential(
            nn.MaxPool2d((2,6), stride=(2,3), padding=(1,0)),  # 1*10*128 -> 1*6*41
            nn.Conv2d(1, 8, 3, padding=1), # 1*6*41 -> 8*6*41
            nn.BatchNorm2d(8),
            nn.ReLU()
        )
        
        self.inter_conv2 = nn.Sequential(
            nn.Conv2d(8, 4, (3,3), stride=(1,2)),  # 8*6*41 -> 4*4*20
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.Conv2d(4, 2, (3,5), dilation=(1,3), padding=(1,0)),  # 4*4*20 -> 2*4*8
            nn.BatchNorm2d(2),
            nn.ReLU()
        )
        
        self.combine_norm = nn.Sequential(
            nn.BatchNorm2d(2),
            nn.ReLU()
        )
        
        self.fc = nn.Linear(self.fc_dim, self.output_dim)  # 64 -> 527


    
    # For balanced Width * Height split of input data for intra blocks    
    def split_embed_dim(self):
        for i in reversed(np.arange(np.ceil(np.sqrt(self.embed_dim))+1)):
            if self.embed_dim % i == 0:
                return self.embed_dim // int(i) , int(i)
                break
    
    def forward(self, data):
        intra_data = data.view(-1, self.time_size, self.embed_dim1, self.embed_dim2)
        inter_data = data.view(-1, 1, self.time_size, self.embed_dim)
        
        intra_block1_out = self.intra_init_conv(intra_data)
        intra_block2_out = self.intra_stride(intra_block1_out) + self.intra_dim1_dil(intra_block1_out) + self.intra_dim2_dil(intra_block1_out)
        
        inter_block1_out = self.inter_conv1(inter_data) + self.inter_max_conv(inter_data)
        inter_block2_out = self.inter_conv2(inter_block1_out).transpose(-2,-1)
        
        cnn_out = self.combine_norm(intra_block2_out + inter_block2_out).view(-1,self.fc_dim)
        fc_out = self.fc(cnn_out)
        
        return F.softmax(fc_out,dim=1)
        

In [21]:
YAC = YoutubeAudioClassifier(time_size, embed_dim, 64, label_num).to(device)
YAC_optimizer = optim.Adam(YAC.parameters(), lr = 0.01)
loss_function = nn.MultiLabelSoftMarginLoss()

In [22]:
# Training process
n_epoch = 100
loss_list = []

for epoch in tqdm(range(n_epoch)):           
    epoch_loss_list = []
    for embedding, labels in train_loader:
        YAC_optimizer.zero_grad()
        embedding = embedding.to(device)
        loss = loss_function(YAC(embedding), labels.to(device))
        loss.backward()
        YAC_optimizer.step()
        
        epoch_loss_list.append(loss.item())
    
    epoch_mean_loss = np.mean(np.array(epoch_loss_list))    
    loss_list.append(epoch_mean_loss)
    print(epoch_mean_loss)

    if epoch % (n_epoch//10) == 0:
        print(f"Epoch_{epoch} Loss: {epoch_mean_loss}")

  1%|▋                                                                         | 1/100 [00:11<18:55, 11.47s/it]

0.6934587360741283
Epoch_0 Loss: 0.6934587360741283


  2%|█▍                                                                        | 2/100 [00:22<18:43, 11.46s/it]

0.6934341386989


  3%|██▏                                                                       | 3/100 [00:34<18:30, 11.45s/it]

0.6934096862653628


  4%|██▉                                                                       | 4/100 [00:45<18:18, 11.44s/it]

0.6933993029786977


  5%|███▋                                                                      | 5/100 [00:57<18:10, 11.48s/it]

0.6933901992225787


  6%|████▍                                                                     | 6/100 [01:08<18:02, 11.52s/it]

0.6933778316964136


  7%|█████▏                                                                    | 7/100 [01:20<17:53, 11.54s/it]

0.6933716227917244


  8%|█████▉                                                                    | 8/100 [01:32<17:42, 11.55s/it]

0.6933697401251911


  9%|██████▋                                                                   | 9/100 [01:43<17:31, 11.55s/it]

0.6933666646217985


 10%|███████▎                                                                 | 10/100 [01:55<17:21, 11.57s/it]

0.6933631705730147


 11%|████████                                                                 | 11/100 [02:06<17:14, 11.62s/it]

0.6933609440200654
Epoch_10 Loss: 0.6933609440200654


 12%|████████▊                                                                | 12/100 [02:18<17:03, 11.63s/it]

0.6933572875245553


 13%|█████████▍                                                               | 13/100 [02:30<16:49, 11.61s/it]

0.6933549964655513


 14%|██████████▏                                                              | 14/100 [02:41<16:35, 11.57s/it]

0.6933526173124579


 15%|██████████▉                                                              | 15/100 [02:53<16:21, 11.54s/it]

0.6933502534762703


 16%|███████████▋                                                             | 16/100 [03:04<16:10, 11.56s/it]

0.6933471685639212


 17%|████████████▍                                                            | 17/100 [03:16<15:58, 11.55s/it]

0.6933458600625417


 18%|█████████████▏                                                           | 18/100 [03:27<15:46, 11.54s/it]

0.6933433359677571


 18%|█████████████▏                                                           | 18/100 [03:31<16:02, 11.74s/it]


KeyboardInterrupt: 