# In [None]:
from scipy.io.wavfile import read
import openl3
import os
import pickle
import numpy as np
from sklearn.linear_model import LogisticRegression
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import GaussianNB
import random


def set_seed(seed=42):
    """Seed every RNG this script uses so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's
    generators, and configures cuDNN for deterministic kernel selection.

    Args:
        seed: Seed value. Defaults to 42, the value used throughout the script.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the CPU generator and all CUDA devices.
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode picks kernels by timing, which can differ run to run.
    torch.backends.cudnn.benchmark = False


# Label vocabulary: sorted unique labels from the training metadata,
# mapped to consecutive integer ids (and back).
train_df = pd.read_csv('train.csv')
uniq_labels = np.sort(np.unique(train_df['label']))
label_decoder = dict(enumerate(uniq_labels))
label_encoder = {name: idx for idx, name in label_decoder.items()}

# Pretrained OpenL3 model (mel256 input, music content, 512-d embeddings)
# shared by every feature-extraction call.
emb_model = openl3.models.load_embedding_model(
    input_repr="mel256",
    content_type="music",
    embedding_size=512,
)


class Masked_Loss(nn.Module):
    """Cross-entropy against (possibly soft) target weights.

    Computes ``-sum(log_softmax(y_pred) * y_true) / norm``. With the default
    ``norm=41.`` the summed loss is divided by a constant 41 rather than by
    the batch size, so its magnitude scales with the batch size.
    """

    def __init__(self, norm=41.):
        """
        Args:
            norm: Constant the summed loss is divided by (default 41.).
        """
        super().__init__()
        self.norm = norm

    def forward(self, y_pred, y_true):
        """Return the normalized soft-target cross-entropy as a scalar tensor.

        Args:
            y_pred: Raw logits; the class axis is the last axis.
            y_true: Target weights broadcastable to ``y_pred`` (one-hot or
                soft mixup rows). Any numeric dtype is accepted.
        """
        # Follow y_pred's device/dtype instead of forcing CUDA, so the loss
        # also works on CPU and with integer-typed targets.
        y_true = y_true.to(device=y_pred.device, dtype=y_pred.dtype)
        # Explicit dim: normalize over classes (implicit dim is deprecated).
        losses = F.log_softmax(y_pred, dim=-1) * y_true
        return -losses.sum() / self.norm


class Augmentations():
    """Random waveform augmentations used to enlarge the training set.

    Each method fires with its own probability. When it does not fire,
    ``energy_boosting`` and ``random_noise`` return ``None`` and ``mixup``
    returns ``(None, None)``, so callers can always unpack two values.
    """

    def __init__(self, sr, energy_boosting_prob=0.5, random_noise_prob=0, mixup_prob=1):
        """Store the sample rate and per-augmentation firing probabilities.

        Args:
            sr: Sample rate in Hz (stored; not used by the augmentations).
            energy_boosting_prob: Probability that ``energy_boosting`` fires.
                0.5 gave the best result so far.
            random_noise_prob: Probability that ``random_noise`` fires
                (0 disables it).
            mixup_prob: Probability that ``mixup`` fires (1 = always).
        """
        self.energy_boosting_prob = energy_boosting_prob
        self.random_noise_prob = random_noise_prob
        self.mixup_prob = mixup_prob
        self.sr = sr

    def energy_boosting(self, wav):
        """Return ``wav`` scaled by a random gain in [0.5, 2.0), or ``None``."""
        if np.random.rand() < self.energy_boosting_prob:
            return wav * np.random.uniform(0.5, 2.0)
        return None

    def random_noise(self, wav):
        """Return ``wav`` plus per-sample random noise, or ``None``.

        NOTE(review): the noise is uniform in [0, 1), i.e. non-negative with a
        0.5 mean, so it also adds a DC offset. Confirm this is intended before
        enabling it. One noise value is drawn per entry along axis 0.
        """
        if np.random.rand() < self.random_noise_prob:
            return wav + np.float32(np.random.rand(wav.shape[0]))
        return None

    def mixup(self, wav, each_class_wav, label):
        """Blend ``wav`` with the example of a randomly chosen class.

        Args:
            wav: Waveform to augment.
            each_class_wav: Sequence with one waveform per class, indexed by
                class id.
            label: One-hot label vector of ``wav``. It is not modified.

        Returns:
            ``(mixed_wav, soft_label)`` when mixup fires. ``soft_label`` puts
            weight ``lam`` (at least 0.5) on the sample's own class and
            ``1 - lam`` on the partner class. If the drawn class is the
            sample's own class, ``wav`` is returned unchanged with a float
            copy of its label. ``(None, None)`` when mixup does not fire.
        """
        if np.random.rand() >= self.mixup_prob:
            # Two values so callers doing ``a, b = mixup(...)`` don't crash.
            return None, None
        lam = np.random.beta(0.4, 0.4)
        # Shift small draws up so the original sample keeps at least half
        # the weight.
        if lam < 0.5:
            lam += 0.5
        augm_class = np.random.randint(len(each_class_wav))
        # Float copy: the caller's label stays intact, and fractional weights
        # survive even if ``label`` has an integer dtype.
        res_label = np.array(label, dtype=float)
        own_class = res_label.argmax()
        if own_class == augm_class:
            return wav, res_label
        res_label[own_class] = lam
        res_label[augm_class] = 1 - lam
        return wav * lam + (1 - lam) * each_class_wav[augm_class], res_label


def pad_to_format(wav, sr, targ_sec):
    """Crop or zero-pad a waveform to exactly ``targ_sec`` seconds.

    Clips longer than the target are cropped at a uniformly random offset
    (every valid offset is possible). Shorter clips are zero-padded at the end.

    Args:
        wav: Waveform samples (array-like); time runs along axis 0. Extra
            axes (e.g. channels) are preserved.
        sr: Sample rate in Hz.
        targ_sec: Target duration in seconds.

    Returns:
        A new float64 ndarray with ``int(targ_sec * sr)`` samples along axis 0.
    """
    # Always float64 (and always a copy): the padding branch upcasts to float
    # anyway, so both branches now return the same dtype.
    wav = np.array(wav, dtype=np.float64)
    target_len = int(targ_sec * sr)
    if wav.shape[0] > target_len:
        # randint's upper bound is exclusive; +1 makes the last valid start
        # offset (len(wav) - target_len) reachable.
        start_point = np.random.randint(0, wav.shape[0] - target_len + 1)
        return wav[start_point:start_point + target_len]
    pad = np.zeros((target_len - wav.shape[0],) + wav.shape[1:])
    return np.concatenate((wav, pad), axis=0)


def make_label_matrix(labels, num_classes=41):
    """Convert integer class ids into one-hot vectors.

    Args:
        labels: Iterable of integer class ids in ``[0, num_classes)``.
        num_classes: Length of each one-hot vector (41 classes by default).

    Returns:
        A list holding one independent float64 one-hot ``np.ndarray`` per label.
    """
    identity = np.eye(num_classes)
    # .copy() so rows never alias each other or the identity matrix.
    return [identity[label].copy() for label in labels]


def make_augmentation(x, y):
    """Return the training set extended with randomly augmented copies.

    Every original sample is kept, in order. After each sample, up to three
    augmented copies may follow: an energy-boosted copy and a noisy copy
    (both with the original label), and a mixup copy with a soft label.

    Args:
        x: Sequence of waveforms. Presumably all the same length (as produced
            by ``pad_to_format``) so ``np.array`` can stack them.
        y: Sequence of one-hot label vectors aligned with ``x``.

    Returns:
        ``(final_x, final_y)`` as NumPy arrays.

    Raises:
        ValueError: if ``x`` is non-empty but some class has no example to
            serve as a mixup partner.
    """
    num_classes = len(label_decoder)
    # Group waveforms by class id in one pass (input order kept per class).
    examples = [[] for _ in range(num_classes)]
    for wav, label in zip(x, y):
        examples[label.argmax()].append(wav)
    missing = [cls for cls, members in enumerate(examples) if not members]
    if len(x) > 0 and missing:
        raise ValueError(
            f'mixup needs at least one example per class; none found for class ids {missing}'
        )

    augmentator = Augmentations(16000)
    final_x, final_y = [], []
    for wav, label in zip(x, y):
        boosted = augmentator.energy_boosting(wav)
        noisy = augmentator.random_noise(wav)
        # One random mixup partner per class.
        partners = [members[np.random.randint(len(members))] for members in examples]
        mixed = augmentator.mixup(wav, partners, label)
        # mixup reports "not applied" as None or (None, None); accept both.
        mixed_wav, soft_y = (None, None) if mixed is None else mixed

        final_x.append(wav)
        final_y.append(label)
        for extra_wav, extra_y in ((boosted, label), (noisy, label), (mixed_wav, soft_y)):
            if extra_wav is not None:
                final_x.append(extra_wav)
                final_y.append(extra_y)
    return np.array(final_x), np.array(final_y)


class EventDetectionDataset(Dataset):
    """Map-style dataset over precomputed inputs with optional targets."""

    def __init__(self, x, y=None):
        """
        Args:
            x: Indexable sequence of inputs.
            y: Optional indexable sequence of targets aligned with ``x``.
        """
        self.x, self.y = x, y
        # Not read by this class.
        self.anom_prob = 0.2

    def __len__(self):
        """Number of samples, i.e. ``len(x)``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(x[idx], y[idx])`` when targets exist, else just ``x[idx]``."""
        # TODO: apply prepare_shape() here.
        sample = self.x[idx]
        return sample if self.y is None else (sample, self.y[idx])


class VeryDumb(nn.Module):
    """MLP mapping a 512-d embedding to 41 class logits.

    Hidden widths 1024 -> 2048 -> 512 with tanh activations. Dropout
    (p=0.2, then p=0.3) follows the second and third hidden layers.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the state_dict keys used by saved checkpoints;
        # layers are built in this order so seeded initialization matches.
        self.dense_1 = nn.Linear(512, 1024)
        self.dense_2 = nn.Linear(1024, 2048)
        self.dense_3 = nn.Linear(2048, 512)
        self.dense_4 = nn.Linear(512, 41)
        self.do_1 = nn.Dropout(0.2)
        self.do_2 = nn.Dropout(0.3)

    def forward(self, x):
        """Return raw (unnormalized) logits for a batch of embeddings."""
        hidden = torch.tanh(self.dense_1(x))
        hidden = self.do_1(torch.tanh(self.dense_2(hidden)))
        hidden = self.do_2(torch.tanh(self.dense_3(hidden)))
        return self.dense_4(hidden)


def calc_feature(x, sr):
    """Embed each waveform with OpenL3 and average the frames.

    Args:
        x: Iterable of waveforms, all at sample rate ``sr``.
        sr: Sample rate in Hz passed to OpenL3.

    Returns:
        List with one embedding per waveform, averaged over axis 0 of the
        embedding array that ``openl3.get_embedding`` returns for it.
    """
    # Lazily embed one clip at a time (hop of 1 s, no centering); each call
    # returns (embeddings, timestamps) and only the embeddings are kept.
    embeddings = (
        openl3.get_embedding(wav, sr, emb_model, hop_size=1, center=False)
        for wav in x
    )
    return [frames.mean(axis=0) for frames, _timestamps in embeddings]


def _run_epoch(network, loader, loss_f, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Trains (dropout on, gradients, optimizer steps) when ``optimizer`` is
    given; otherwise evaluates (dropout off, no gradients). Accuracy compares
    the argmax of the logits with the argmax of the (possibly soft) targets.
    """
    training = optimizer is not None
    network.train(training)
    losses, preds, trues = [], [], []
    with torch.set_grad_enabled(training):
        for x, y in loader:
            output = network(x.float().cuda())
            # Keep targets as floats: mixup rows hold fractional weights that
            # an integer cast would truncate to zero.
            loss = loss_f(output, y.float().cuda())
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
            preds.append(output.detach().cpu().numpy().argmax(axis=1))
            trues.append(y.numpy().argmax(axis=1))
    return np.mean(losses), accuracy_score(np.hstack(trues), np.hstack(preds))


def train():
    """Extract OpenL3 features, train the MLP, and checkpoint the best epoch.

    Reads ``train.csv`` and the clips in ``audio_train/train``, crops/pads
    them to 2 s, makes a stratified 90/10 train/validation split, augments
    the training part, embeds both parts, and trains ``VeryDumb`` for 500
    epochs.

    Writes:
        openl3_augmented/{train,validate}_{features,targets}.pickle: cached
            features together with their matching targets.
        openl3_augmented.pt: state_dict of the epoch with the best
            validation accuracy.
    """
    train_df = pd.read_csv('train.csv')
    train_folder = 'audio_train/train'
    sample_rate = 16000

    x = []
    for fname in train_df['fname']:
        sr, wav_data = read(os.path.join(train_folder, fname))
        x.append(pad_to_format(wav_data, sr, 2))
    y = [label_encoder[label] for label in train_df['label']]
    x_tr, x_val, y_tr, y_val = train_test_split(x, y, stratify=y, test_size=0.1, random_state=42)

    y_tr = make_label_matrix(y_tr)
    y_val = make_label_matrix(y_val)
    x_tr, y_tr = make_augmentation(x_tr, y_tr)
    x_tr = calc_feature(x_tr, sample_rate)
    x_val = calc_feature(x_val, sample_rate)

    # Cache features together with the targets they belong to. Augmentation
    # is random, so targets from an earlier run would not line up with these
    # features. Training uses the in-memory objects directly.
    os.makedirs('openl3_augmented', exist_ok=True)
    cache = (('train_features', x_tr), ('validate_features', x_val),
             ('train_targets', y_tr), ('validate_targets', y_val))
    for name, obj in cache:
        with open(f'openl3_augmented/{name}.pickle', 'wb') as fh:
            pickle.dump(obj, fh)

    train_dset = EventDetectionDataset(x_tr, y_tr)
    val_dset = EventDetectionDataset(x_val, y_val)
    train_loader = DataLoader(train_dset, batch_size=200, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dset, batch_size=200, shuffle=False, num_workers=0)

    network = VeryDumb().cuda()
    optimizer = optim.SGD(network.parameters(), lr=1e-2)
    loss_f = Masked_Loss()

    n_epoch = 500
    train_loss, val_loss = [], []
    train_acc, val_acc = [], []
    top_score, top_epoch = 0, 0
    for e in range(n_epoch):
        print('epoch #', e)
        loss, acc = _run_epoch(network, train_loader, loss_f, optimizer)
        train_loss.append(loss)
        train_acc.append(acc)
        print('mean train loss:', loss)
        print('train accuracy:', acc)

        loss, acc = _run_epoch(network, val_loader, loss_f)
        val_loss.append(loss)
        val_acc.append(acc)
        print('mean val loss:', loss)
        print('val accuracy:', acc)

        # LR schedule: divide by 10 every 40 epochs and reset to 1e-2 every
        # 100 epochs. On epochs divisible by both (0, 200, 400) the reset
        # runs second, so the LR ends at 1e-2.
        if e % 40 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 1e-1
        if e % 100 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 1e-2
        # Checkpoint on validation accuracy.
        if acc > top_score:
            torch.save(network.state_dict(), 'openl3_augmented.pt')
            top_score = acc
            top_epoch = e
    print('best val accuracy:', top_score, 'at epoch', top_epoch)



def test():
    """Predict labels for the test clips and write ``submit_augmented_fc.csv``.

    Reads clip names from ``sample_submission.csv`` and audio from
    ``audio_test/test``, crops/pads to 2 s, embeds with OpenL3, caches the
    features to ``openl3_augmented/test_features.pickle``, and classifies
    them with the checkpoint saved at ``openl3_augmented.pt``.
    """
    test_df = pd.read_csv('sample_submission.csv')
    test_folder = 'audio_test/test'
    sample_rate = 16000

    waves = []
    for fname in test_df['fname']:
        sr, wav_data = read(os.path.join(test_folder, fname))
        waves.append(pad_to_format(wav_data, sr, 2))

    features = calc_feature(waves, sample_rate)
    os.makedirs('openl3_augmented', exist_ok=True)
    with open('openl3_augmented/test_features.pickle', 'wb') as fh:
        pickle.dump(features, fh)

    network = VeryDumb().cuda()
    network.load_state_dict(torch.load('openl3_augmented.pt'))
    network.eval()

    # Dummy all-zero targets: the dataset then yields (x, y) pairs like the
    # training loaders do; y is ignored here.
    test_dset = EventDetectionDataset(features, np.zeros((len(features), 41)))
    test_loader = DataLoader(test_dset, batch_size=200, shuffle=False, num_workers=0)
    outputs = []
    with torch.no_grad():
        for batch_x, _ in test_loader:
            logits = network(batch_x.float().cuda())
            outputs.append(logits.cpu().numpy().argmax(axis=1))
    # np.hstack([]) raises, so handle an empty test set explicitly.
    y_pred = np.hstack(outputs) if outputs else np.array([], dtype=int)
    test_df['label'] = [label_decoder[int(idx)] for idx in y_pred]
    test_df.to_csv('submit_augmented_fc.csv', index=False)



# Run the pipeline only when executed as a script or notebook (where
# __name__ is '__main__'), not when this module is imported.
if __name__ == '__main__':
    set_seed()
    train()
    test()