In [None]:
import os
import glob
import random

from collections import namedtuple

import librosa

import matplotlib.pyplot as plt

import IPython.display as ipd

import numpy as np

import torch.optim as optim

In [None]:
import queue

import time

import threading

from tqdm import tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter # arcface-pytorch

In [None]:
# Number of clips per mini-batch; DatasetFeeder copies this value when it is constructed.
BATCH_SIZE = 64

In [None]:
# Per-file metadata record. The typename matches the name it is bound to, so
# the repr is accurate and instances can be pickled.
FileInfo = namedtuple('FileInfo', 'file_path mode equipment status equip_id file_id')

# One mini-batch: stacked mel arrays plus the integer equipment and status ids.
BatchData = namedtuple('BatchData', 'mel equipment status')

In [None]:
# Glob pattern matching every entry directly under the dataset root.
dataset_path = 'dev_data/*'

# Keep only the entries that are directories, in sorted order.
dataset_direc_list = sorted(filter(os.path.isdir, glob.glob(dataset_path)))

print(dataset_direc_list)

In [None]:
# Equipment class names; a name's position in the list is its integer class id.
equipments = ['ToyCar', 'ToyConveyor', 'fan', 'pump', 'slider', 'valve']
num_equipments = len(equipments)
EQUIPMENT_DICT = dict(zip(equipments, range(num_equipments)))

# Status names, mapped to integer ids the same way.
status = ['normal', 'anomaly']
STATUS_DICT = dict(zip(status, range(len(status))))

In [None]:
# Display the status-name -> id mapping as the cell output.
STATUS_DICT

In [None]:
def get_metadata(dataset_dir, mode='train'):

    '''
    Build one FileInfo record per .wav file found in `dataset_dir/mode`.

    dataset_dir : path of one equipment directory; its last path component is
                  stored as the `equipment` field of every record.
    mode        : name of the sub-directory to scan (the notebook uses
                  'train' and 'test').

    return a list of FileInfo sorted by file path (empty if nothing matches)
    '''

    # Use the argument, not a notebook-level loop variable, so the result does
    # not depend on which cells happened to run before this call.
    file_path_list = sorted(glob.glob(os.path.join(dataset_dir, mode, '*.wav')))

    # normpath drops a trailing separator, which would make basename return ''.
    equipment = os.path.basename(os.path.normpath(dataset_dir))

    return [FileInfo(file_path, mode, equipment, *path_to_file_info(file_path))
            for file_path in file_path_list]

In [None]:
def path_to_file_info(path):

    '''
    Split the base name of `path` on '_' and pick out three of its fields.

    return status, equip_id, file_num
    (fields 0, 2 and 3 of the base name; nothing is stripped, so a field keeps
    any file extension it contains)
    '''

    file_name = os.path.basename(path)
    fields = file_name.split('_')

    state = fields[0]
    equip_id = fields[2]
    file_num = fields[3]

    return state, equip_id, file_num
    

In [None]:
def audio_visual_inspection(metadatum):
    '''
    Plot the waveform and log-mel spectrogram of one file and return an audio player.

    metadatum : record exposing a `file_path` attribute (e.g. a FileInfo).

    return an IPython.display.Audio object for the loaded signal
    '''
    file = metadatum.file_path

    print(file)

    # sr=None keeps the file's own sampling rate instead of resampling.
    y, sr = librosa.load(file, sr=None)

    # 100 ms window with a 50 ms hop. The audio is passed by keyword because
    # recent librosa releases no longer accept it positionally.
    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=int(sr * 0.1), hop_length=int(sr * 0.05), power=1, n_mels=160)
    # Amplitude to dB, floored at 1e-8 so log10 never sees zero.
    mel = 20 * np.log10(np.maximum(mel, 1e-8))

    fig, axes = plt.subplots(2, 1, figsize=(15, 6))
    axes[0].plot(y)
    axes[0].set_xlim([0, len(y)])
    # imshow accepts only 'upper' or 'lower'; 'lower' draws mel bin 0 at the bottom.
    axes[1].imshow(mel, origin='lower', aspect='auto')
    plt.tight_layout()
    plt.show()

    print(mel.shape)

    return ipd.Audio(y, rate=sr)

In [None]:
# Gather the records of every dataset directory into two flat lists,
# one for the 'train' split and one for the 'test' split.
metadata_train = []
metadata_test = []

for direc in dataset_direc_list:
    metadata_train.extend(get_metadata(direc, 'train'))
    metadata_test.extend(get_metadata(direc, 'test'))

# Report how many files each split contains.
print(len(metadata_train))
print(len(metadata_test))

In [None]:
# Use the first CUDA device when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def load_mel(file_path):

    '''
    Load an audio file and return its normalised log-mel spectrogram.

    The signal is cut or zero-padded to exactly 10 seconds, so every file
    yields the same number of frames and the results can be stacked.

    return a 2-D array with 160 mel bins on the first axis
    '''

    # sr=None keeps the file's own sampling rate instead of resampling.
    y, sr = librosa.load(file_path, sr=None)

    target_length = sr * 10
    y = y[:target_length]
    if len(y) < target_length:
        # Short clips would otherwise give fewer frames and break np.stack.
        y = np.pad(y, (0, target_length - len(y)))

    # 100 ms window with a 50 ms hop. The audio is passed by keyword because
    # recent librosa releases no longer accept it positionally.
    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=int(sr * 0.1), hop_length=int(sr * 0.05), power=1, n_mels=160)
    # dB scale floored at 1e-8 (-160 dB), then shifted and divided by 160 so
    # the floor maps to 0.
    mel = (20 * np.log10(np.maximum(mel, 1e-8)) + 160) / 160

    return mel

def batch_list_to_batch(batch_list):

    '''
    Turn a list of (mel, equipment name, status name) triples into one BatchData.

    Names are translated to integer ids through EQUIPMENT_DICT and STATUS_DICT,
    and the mel arrays are stacked along a new leading axis.
    '''

    mels, equip_ids, status_ids = [], [], []

    for mel, equip_name, status_name in batch_list:
        mels.append(mel)
        equip_ids.append(EQUIPMENT_DICT[equip_name])
        status_ids.append(STATUS_DICT[status_name])

    mel_array = np.stack(mels)
    equip_array = np.array(equip_ids, dtype=int)
    status_array = np.array(status_ids, dtype=int)

    return BatchData(mel=mel_array, equipment=equip_array, status=status_array)

class DatasetFeeder:
    '''
    Builds mini-batches on a background thread and hands them to the consumer
    through a bounded queue.

    metadata_list : records exposing `file_path`, `equipment` and `status`
                    attributes. The feeder keeps its own copy, so shuffling
                    never reorders the caller's list.
    '''

    # Seconds to block on the queue before re-checking the control flags.
    _POLL_INTERVAL = 0.1

    def __init__(self, metadata_list):
        self.batch_queue = queue.Queue(maxsize=100)
        self.batch_size = BATCH_SIZE
        self.metadata_list = list(metadata_list)
        self.batching_finished = False
        self.max_batch_num = int(np.ceil(len(self.metadata_list) / self.batch_size))
        # Set by the consumer to tell a running worker to give up early.
        self._stop_requested = threading.Event()
        # Exception raised inside the worker thread, re-raised by the consumer.
        self._worker_error = None

    def _enqueue(self, batch):
        '''Put `batch` on the queue; return False if a stop was requested first.'''
        while not self._stop_requested.is_set():
            try:
                self.batch_queue.put(batch, timeout=self._POLL_INTERVAL)
                return True
            except queue.Full:
                continue
        return False

    def start_batching(self):
        '''
        Shuffle the records, load every file and enqueue the resulting batches.

        `batching_finished` is set to True on every exit path, including
        errors, so a waiting consumer can never be left polling forever.
        '''
        try:
            random.shuffle(self.metadata_list)

            batch_data_list = list()

            for metadata in self.metadata_list:

                if self._stop_requested.is_set():
                    return

                mel = load_mel(metadata.file_path)
                batch_data_list.append((mel, metadata.equipment, metadata.status))

                if len(batch_data_list) >= self.batch_size:
                    if not self._enqueue(batch_list_to_batch(batch_data_list)):
                        return
                    batch_data_list = list()

            # Last, possibly smaller, batch.
            if len(batch_data_list) > 0:
                self._enqueue(batch_list_to_batch(batch_data_list))
        finally:
            self.batching_finished = True

    def _run_worker(self):
        '''Thread target: run start_batching and remember any exception it raises.'''
        try:
            self.start_batching()
        except Exception as error:
            self._worker_error = error

    def batch_generator(self):
        '''
        Yield batches in the order the background worker produces them.

        Raises whatever exception the worker hit, after the batches that were
        already queued have been delivered. Closing the generator early stops
        the worker and empties the queue.
        '''
        self.batching_finished = False
        self._worker_error = None
        self._stop_requested.clear()

        t = threading.Thread(target=self._run_worker, args=(), daemon=True)
        t.start()

        try:
            while not (self.batching_finished and self.batch_queue.empty()):
                try:
                    batch = self.batch_queue.get(timeout=self._POLL_INTERVAL)
                except queue.Empty:
                    continue
                # The yield sits outside the try/except so GeneratorExit and
                # consumer-side errors are never swallowed.
                yield batch
        finally:
            self._stop_requested.set()
            t.join()
            # Drop leftovers so the next pass starts from an empty queue.
            while True:
                try:
                    self.batch_queue.get_nowait()
                except queue.Empty:
                    break

        if self._worker_error is not None:
            raise self._worker_error
        

In [None]:
class CRNN_Model(nn.Module):
    '''
    Convolutional-recurrent classifier with an ArcFace-style additive angular
    margin head.

    Five convolutional stages are followed by three GRU layers. Every time
    step of the GRU output is compared, by cosine similarity, with one
    learnable weight vector per class; the margin `m` is added to the angle of
    the labelled class and the result is scaled by `s`.

    device      : stored on the instance for backward compatibility; forward()
                  creates its helper tensors on the device of its inputs.
    s           : scale applied to the output logits.
    m           : angular margin in radians.
    num_classes : number of classes (default 6, as used in this notebook).
    '''

    def __init__(self, device, s=5, m=0.35, num_classes=6):
        super(CRNN_Model, self).__init__()

        self.cnn_layers_1 = self._conv_stage(1, 64, 128)
        self.cnn_layers_2 = self._conv_stage(128, 128, 128)
        self.cnn_layers_3 = self._conv_stage(128, 256, 256)
        self.cnn_layers_4 = self._conv_stage(256, 256, 256)

        # Each dilated (9, 3) convolution removes 16 rows, so a 160-row input
        # has 32 rows left here and this layer collapses them to 1.
        self.cnn_layers_5 = nn.Sequential(nn.Conv2d(256, 512, (32, 3)),
                                          nn.BatchNorm2d(512),
                                          nn.ReLU())  # (B, H, 1, L)

        # GRU (B, L, H)
        self.rnn_layers = nn.ModuleList((nn.GRU(512, 256, batch_first=True),
                                         nn.GRU(256, 256, batch_first=True),
                                         nn.GRU(256, 256, batch_first=True)))

        # One class-centre vector per class, compared by cosine similarity.
        self.W = Parameter(torch.empty(num_classes, 256), requires_grad=True)
        nn.init.xavier_uniform_(self.W)

        self.s = s
        self.m = m

        self.cos_m = float(np.cos(m))
        self.sin_m = float(np.sin(m))

        self.device = device

    @staticmethod
    def _conv_stage(in_channels, mid_channels, out_channels):
        '''Two dilated (9, 3) convolutions, each followed by BatchNorm and ReLU.'''
        return nn.Sequential(nn.Conv2d(in_channels, mid_channels, (9, 3), dilation=2),
                             nn.BatchNorm2d(mid_channels),
                             nn.ReLU(),
                             nn.Conv2d(mid_channels, out_channels, (9, 3), dilation=2),
                             nn.BatchNorm2d(out_channels),
                             nn.ReLU())

    def forward(self, input_tensor, label_tensor):
        '''
        input_tensor : float tensor (B, 1, rows, frames); rows must shrink to 1
                       after the convolutions (160 rows with these kernels).
        label_tensor : integer (int64) tensor (B, 1) holding the class index of each clip.

        return (output_tensor, cosine_tensor), both (B, L, C):
        the margin-adjusted logits scaled by `s`, and the plain cosines.
        '''

        tensor = self.cnn_layers_1(input_tensor)
        tensor = self.cnn_layers_2(tensor)
        tensor = self.cnn_layers_3(tensor)
        tensor = self.cnn_layers_4(tensor)
        tensor = self.cnn_layers_5(tensor)
        tensor = torch.squeeze(tensor, 2)  # (B, H, 1, L) => (B, H, L)
        # Out-of-place transpose: the in-place variant edits a view that autograd still needs.
        tensor = tensor.transpose(1, 2).contiguous()  # (B, H, L) => (B, L, H)

        for rnn_layer in self.rnn_layers:
            tensor, _ = rnn_layer(tensor)  # (B, L, H)

        # Normalise along the feature axis. F.normalize defaults to dim=1,
        # which is the time axis here and would not give a true cosine.
        cosine_tensor = F.linear(F.normalize(tensor, dim=-1), F.normalize(self.W, dim=-1))  # (B, L, C)
        # The tiny floor keeps sqrt and its gradient finite when rounding
        # pushes cos^2 to 1 or slightly above.
        sine_tensor = torch.sqrt(torch.clamp(1.0 - torch.pow(cosine_tensor, 2), min=1e-12))  # (B, L, C)
        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        phi_tensor = cosine_tensor * self.cos_m - sine_tensor * self.sin_m  # (B, L, C)

        one_hot = torch.zeros(label_tensor.shape[0], self.W.shape[0],
                              dtype=cosine_tensor.dtype, device=cosine_tensor.device)
        one_hot.scatter_(1, label_tensor, 1)
        one_hot = one_hot.unsqueeze(1)  # (B, 1, C), broadcasts over L

        # Margin-adjusted value for the labelled class, plain cosine elsewhere.
        output_tensor = (one_hot * phi_tensor) + ((1.0 - one_hot) * cosine_tensor)
        output_tensor = output_tensor * self.s

        return output_tensor, cosine_tensor

In [None]:
# Build the model and move its parameters to the selected device.
net = CRNN_Model(device).to(device)

# Cross-entropy loss, applied to the logits returned by the model.
criterion = nn.CrossEntropyLoss()

# Plain SGD with momentum over every parameter of the model.
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

In [None]:
# for param in net.parameters():
#     print(param.shape)

In [None]:
train_dataset_feeder = DatasetFeeder(metadata_train)

In [None]:
# one_hot = torch.zeros([16, 6])
# print(one_hot.shape)

# label = torch.tensor(batch[1]).view(len(label),1)
# print(label.shape)

# one_hot.scatter_(1, label, 1)

# one_hot.unsqueeze_(1).shape

In [None]:
# Toy array holding 0..383, shaped (16, 4, 6), used below to check how reshape orders elements.
matrix = np.arange(16 * 4 * 6).reshape(16, 4, 6)

In [None]:
# Display the first row of the first item as the cell output.
matrix[0, 0, :]

In [None]:
# Copy the toy array into a tensor and print the first three rows of item 0,
# one per line.
m_tensor = torch.tensor(matrix)

print(*(m_tensor[0, row_index, :] for row_index in range(3)), sep='\n')

In [None]:
# m_tensor.transpose_(0, 1)

In [None]:
# Collapse every leading axis into one while keeping the last axis unchanged.
new_m_tensor = m_tensor.reshape(-1, m_tensor.shape[-1])

In [None]:
# Print the first three rows of the flattened tensor, one per line, for
# comparison with the rows printed before the reshape.
print(*(new_m_tensor[row_index, :] for row_index in range(3)), sep='\n')

In [None]:
# (B, L, C) => (B x L, C)

In [None]:
# One pass over the training set with one optimizer step per batch.
loss_history = list()

# Make the training mode explicit (BatchNorm then uses batch statistics).
net.train()

for i, batch in tqdm(enumerate(train_dataset_feeder.batch_generator()), total=train_dataset_feeder.max_batch_num):

    # Add the channel axis; explicit dtypes match what the convolutions
    # (float32) and scatter_ / CrossEntropyLoss (int64) require.
    mel_batch = torch.tensor(np.expand_dims(batch.mel, 1), dtype=torch.float32).to(device)
    label = torch.tensor(batch.equipment, dtype=torch.long).view(len(batch.equipment), 1).to(device)  # (B, 1)

    output_tensor, cosine_tensor = net(mel_batch, label)  # (B, L, C)

    # Every time step of a clip is trained against that clip's label.
    expanded_label = label.repeat(1, output_tensor.shape[1]).reshape(-1)  # (B x L)
    output_tensor = output_tensor.reshape(-1, output_tensor.shape[-1])  # (B x L, C)

    optimizer.zero_grad()

    loss = criterion(output_tensor, expanded_label)
    loss.backward()

    optimizer.step()

    loss_history.append(loss.item())

    if i % 20 == 0:
        # Report the step count and latest loss instead of dumping the whole history.
        print(len(loss_history), loss_history[-1])

        fig = plt.figure()
        plt.plot(loss_history)
        # Save before show(): once the figure has been shown, a later savefig
        # would write an empty canvas.
        fig.savefig('{:03d}.png'.format(i), dpi=300)
        plt.show()
        # Release the figure so they do not pile up over the run.
        plt.close(fig)
        
    

In [None]:
# Pick one training record at random and show its waveform, spectrogram and audio player.
metadatum = random.choice(metadata_train)
audio_visual_inspection(metadatum)

In [None]:
# Pick one anomalous test record at random. Filtering first, rather than
# resampling until a match turns up, cannot loop forever when no anomalous
# record exists, and it checks the parsed status field instead of searching
# the whole path for a substring.
anomalous_test = [record for record in metadata_test if record.status == 'anomaly']

if not anomalous_test:
    raise ValueError('metadata_test contains no anomalous record')

metadatum = random.choice(anomalous_test)

audio_visual_inspection(metadatum)