In [22]:
import os
import sys
from torchaudio.backend.soundfile_backend import load
from torch.utils.data import Dataset,DataLoader, ConcatDataset
from torchaudio.transforms import MelSpectrogram
import torch.nn as nn
import torch 
import numpy as np
import gc
from tqdm import tqdm
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights, resnet50, ResNet50_Weights, ResNet18_Weights, resnet18
import torch.nn.functional as F
import math
from torch import flatten
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.optim import lr_scheduler
import pandas as pd
from sklearn.metrics import roc_auc_score
#

# Global experiment configuration, read by the model, dataset and training cells.
params = {
    "eval": True,          # True: run the EVAL train/test cell; False: run the DEV train/test cells
    "num_classes": 3,      # number of machine-id classes
    "batch_size": 32,      # batch size for every DataLoader
    "epochs": 1,           # number of training epochs
    "win_len": 626,        # kernel size of TGramNet's first Conv1d
    "mel_bins": 128,       # number of output channels of TGramNet's convolutions
    "hop_len": 313,        # Conv1d stride in TGramNet and hop_length of the mel spectrogram
    "n_fft": 626,          # FFT size of the mel spectrogram
    "power": 2.0,          # spectrogram magnitude exponent
    "use_log_mel": False,  # read by SepSTGramNet.forward
    "arcface": True,       # True: ArcFace criterion on embeddings; False: plain cross-entropy
    "net": "STGramNet",    # resnet50, SepSTGramNet, STGramNet
}

### ArcFaceLoss

In [23]:
class ArcFace(nn.Module):
    """ArcFace (additive angular margin) classification head combined with its loss.

    The module owns a learnable class-centre matrix of shape
    ``(num_classes, embed_size)``. ``forward`` computes the cosine between the
    L2-normalised embeddings and the L2-normalised class centres, applies the
    angular margin to the ground-truth class only, scales the logits and
    returns the cross-entropy loss together with those logits.

    Args:
        embed_size: dimensionality of the incoming embeddings.
        num_classes: number of classes (rows of the weight matrix).
        scale: factor the margin-adjusted cosines are multiplied by.
        margin: additive angular margin, in radians.
        easy_margin: if True, the margin is applied only where cos(theta) > 0;
            otherwise the ``cos(theta) - mm`` fallback is used when
            cos(theta) <= cos(pi - margin).
        **kwargs: accepted and ignored.
    """

    def __init__(self, embed_size, num_classes, scale=64, margin=0.5, easy_margin=False, **kwargs):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.ce = nn.CrossEntropyLoss()
        self.weight = nn.Parameter(torch.empty(num_classes, embed_size))
        self.easy_margin = easy_margin
        # Pre-computed constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # Threshold / penalty used when the margin is not applied directly (see forward).
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

        nn.init.xavier_uniform_(self.weight)

    def forward(self, embedding: torch.Tensor, ground_truth):
        """Compute the ArcFace loss and logits.

        Args:
            embedding: tensor of shape ``(batch, embed_size)``.
            ground_truth: integer class indices, one per sample.

        Returns:
            ``(loss, output)`` where ``loss`` is the scalar cross-entropy and
            ``output`` holds the scaled logits of shape ``(batch, num_classes)``
            with the margin applied at the ground-truth class.
        """
        labels = ground_truth.view(-1).long()

        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cos_theta = F.linear(F.normalize(embedding), F.normalize(self.weight)).clamp(-1 + 1e-7, 1 - 1e-7)
        # Lower bound is 0 so the argument of sqrt can never be negative.
        sin_theta = torch.sqrt((1.0 - cos_theta.pow(2)).clamp(0.0, 1 - 1e-7))
        phi = cos_theta * self.cos_m - sin_theta * self.sin_m
        if self.easy_margin:
            phi = torch.where(cos_theta > 0, phi, cos_theta)
        else:
            phi = torch.where(cos_theta > self.th, phi, cos_theta - self.mm)

        # --------------------------- margin only on the target class ---------------------------
        # Build the mask on the same device/dtype as the logits (no hard-coded 'cuda'),
        # so the module also works on CPU.
        one_hot = torch.zeros_like(cos_theta)
        one_hot.scatter_(1, labels.view(-1, 1), 1)
        output = torch.where(one_hot.bool(), phi, cos_theta) * self.scale

        loss = self.ce(output, labels)
        return loss, output

## SepSTGramNet

In [24]:
class TGramNet(nn.Module):
    """Learned time-domain front end that turns a raw waveform into a spectrogram-like map.

    A strided Conv1d maps the waveform to ``mel_bins`` channels, followed by
    ``num_layers`` blocks of LayerNorm -> GELU -> Conv1d(kernel 3, padding 1).

    NOTE(review): a class with the same name is defined again in a later cell
    of this notebook; on "Run All" the later definition shadows this one.

    Args:
        mel_bins: number of output channels of every convolution.
        win_len: kernel size of the first convolution (its padding is ``win_len // 2``).
        hop_len: stride of the first convolution.
        num_layers: number of LayerNorm/GELU/Conv1d blocks.
        num_samples: length (in samples) of the waveforms the network will be
            fed. LayerNorm normalises over the frame axis, so its size has to
            match the first convolution's output length for this input length.
            Defaults to 160000, the value that was previously hard-coded.
        **kwargs: accepted and ignored, so the whole config dict can be splatted in.
    """

    def __init__(self, mel_bins, win_len, hop_len, num_layers=3, num_samples=160000, **kwargs):
        super(TGramNet, self).__init__()

        padding = win_len // 2
        # Conv1d output length: floor((L + 2*padding - kernel) / stride) + 1.
        # For even win_len this equals num_samples // hop_len + 1 (the old formula);
        # for odd win_len the old formula could be off by one.
        num_frames = (num_samples + 2 * padding - win_len) // hop_len + 1

        self.conv1d = nn.Conv1d(in_channels=1, out_channels=mel_bins, kernel_size=win_len,
                                stride=hop_len, padding=padding, bias=False)
        self.conv_encoder = nn.Sequential(
            *[nn.Sequential(
                nn.LayerNorm(num_frames),
                nn.GELU(),
                nn.Conv1d(mel_bins, mel_bins, 3, 1, 1, bias=False)
            ) for _ in range(num_layers)]
        )

    def forward(self, x_w):
        """Apply the strided convolution and the encoder blocks to ``x_w``.

        ``x_w`` is fed to a Conv1d with one input channel; the result has
        ``mel_bins`` channels.
        """
        x_w = self.conv1d(x_w)
        x_w = self.conv_encoder(x_w)
        return x_w

class SepSTGramNet(nn.Module):
    """Separate ResNet-18 branches for the learned temporal map and the mel spectrogram.

    The waveform goes through ``TGramNet`` and a ResNet-18 (``t_resnet``); the
    mel spectrogram goes through a second ResNet-18 (``s_resnet``). When
    ``params['use_log_mel']`` is true, a third ResNet-18 (``l_resnet``) is fed
    the log-scaled spectrogram. Each branch yields a 384-d vector; the vectors
    are concatenated and passed through ``fc``.

    With ``params['arcface']`` false, ``fc`` ends in a ``num_classes`` linear
    layer; otherwise it stops at the 70-d embedding.

    Args:
        params: configuration dict. Reads ``arcface``, ``num_classes`` (only when
            ``arcface`` is false), and optionally ``use_log_mel`` and ``power``;
            the whole dict is also forwarded to ``TGramNet``.
    """

    def __init__(self, params):
        super(SepSTGramNet, self).__init__()
        # Keep what forward() needs on the instance rather than reading a module-level dict.
        self.use_log_mel = bool(params.get("use_log_mel", False))
        self.power = params.get("power", 2.0)

        self.tgramnet = TGramNet(**params)

        self.t_resnet = self._make_branch()
        self.s_resnet = self._make_branch()
        num_branches = 2
        if self.use_log_mel:
            # Previously this branch was used in forward() without ever being created.
            self.l_resnet = self._make_branch()
            num_branches = 3

        head = [nn.Linear(in_features=384 * num_branches, out_features=70),
                nn.BatchNorm1d(70),
                nn.GELU()]
        if not params["arcface"]:
            head.append(nn.Linear(in_features=70, out_features=params["num_classes"]))
        self.fc = nn.Sequential(*head)

    @staticmethod
    def _make_branch():
        """Return an ImageNet-pretrained, fully trainable ResNet-18 with a 384-d output layer."""
        branch = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        for p in branch.parameters():
            p.requires_grad = True
        branch.fc = nn.Linear(in_features=512, out_features=384, bias=True)
        return branch

    def forward(self, x_wav, x_mel):
        """Run all branches and the fusion head.

        Args:
            x_wav: waveform batch, passed to ``TGramNet``.
            x_mel: mel-spectrogram batch fed directly to a ResNet-18
                (callers in this notebook repeat it to 3 channels first).

        Returns:
            Output of ``fc``: class logits, or the 70-d embedding in ArcFace mode.
        """
        x_t = self.tgramnet(x_wav).unsqueeze(1).repeat(1, 3, 1, 1)
        # Fix: the temporal features used to be computed and then dropped in
        # favour of a second copy of the spectral features.
        features = [self.t_resnet(x_t), self.s_resnet(x_mel)]

        if self.use_log_mel:
            # The log is taken of the input spectrogram, not of the ResNet embedding.
            x_log_mel = 20.0 / self.power * torch.log10(x_mel + sys.float_info.epsilon)
            features.append(self.l_resnet(x_log_mel))

        # All branch outputs already live on the inputs' device; no global `device` needed.
        x = torch.cat(features, dim=1)
        out = self.fc(x)
        return out

## STGramNet

In [25]:
class TGramNet(nn.Module):
    """Learned time-domain front end that turns a raw waveform into a spectrogram-like map.

    A strided Conv1d maps the waveform to ``mel_bins`` channels, followed by
    ``num_layers`` blocks of LayerNorm -> GELU -> Conv1d(kernel 3, padding 1).

    NOTE(review): a class with the same name is also defined in an earlier cell
    of this notebook; on "Run All" this later definition is the one in effect.

    Args:
        mel_bins: number of output channels of every convolution.
        win_len: kernel size of the first convolution (its padding is ``win_len // 2``).
        hop_len: stride of the first convolution.
        num_layers: number of LayerNorm/GELU/Conv1d blocks.
        num_samples: length (in samples) of the waveforms the network will be
            fed. LayerNorm normalises over the frame axis, so its size has to
            match the first convolution's output length for this input length.
            Defaults to 160000, the value that was previously hard-coded.
        **kwargs: accepted and ignored, so the whole config dict can be splatted in.
    """

    def __init__(self, mel_bins, win_len, hop_len, num_layers=3, num_samples=160000, **kwargs):
        super(TGramNet, self).__init__()

        padding = win_len // 2
        # Conv1d output length: floor((L + 2*padding - kernel) / stride) + 1.
        # For even win_len this equals num_samples // hop_len + 1 (the old formula);
        # for odd win_len the old formula could be off by one.
        num_frames = (num_samples + 2 * padding - win_len) // hop_len + 1

        self.conv1d = nn.Conv1d(in_channels=1, out_channels=mel_bins, kernel_size=win_len,
                                stride=hop_len, padding=padding, bias=False)
        self.conv_encoder = nn.Sequential(
            *[nn.Sequential(
                nn.LayerNorm(num_frames),
                nn.GELU(),
                nn.Conv1d(mel_bins, mel_bins, 3, 1, 1, bias=False)
            ) for _ in range(num_layers)]
        )

    def forward(self, x_w):
        """Apply the strided convolution and the encoder blocks to ``x_w``.

        ``x_w`` is fed to a Conv1d with one input channel; the result has
        ``mel_bins`` channels.
        """
        x_w = self.conv1d(x_w)
        x_w = self.conv_encoder(x_w)
        return x_w

class STGramNet(nn.Module):
    """Single ResNet-18 over a 3-channel stack of [mel spectrogram, TGramNet map, zeros].

    With ``params['arcface']`` false the ResNet's final layer is replaced by a
    ``num_classes`` linear layer. Otherwise the final layer is removed and the
    network returns the pooled feature map, which still carries its trailing
    spatial dimensions (callers in this notebook squeeze/flatten them).

    Args:
        params: configuration dict. Reads ``arcface`` and, when it is false,
            ``num_classes``; the whole dict is also forwarded to ``TGramNet``.
    """

    def __init__(self, params):
        super(STGramNet, self).__init__()
        self.tgramnet = TGramNet(**params)
        self.resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        if not params["arcface"]:
            self.resnet.fc = torch.nn.Linear(self.resnet.fc.in_features, params["num_classes"])
        else:
            # Drop the classification layer and keep everything up to the pooling stage.
            self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1]))

    def forward(self, x_wav, x_mel, label=None):
        """Fuse waveform and spectrogram into one image-like tensor and run the ResNet.

        Args:
            x_wav: waveform batch, passed to ``TGramNet``.
            x_mel: mel-spectrogram batch; it is concatenated with the TGramNet
                output along dim 1, so all other dimensions must match.
            label: unused; kept for interface compatibility.

        Returns:
            The ResNet output for the stacked input.
        """
        x_t = self.tgramnet(x_wav).unsqueeze(1)
        # Constant padding channel: it needs no gradient, and zeros_like already
        # places it on x_t's device, so no module-level `device` is required.
        zeros = torch.zeros_like(x_t, dtype=torch.float)
        x = torch.cat((x_mel, x_t, zeros), dim=1)
        out = self.resnet(x)
        return out


### Model definition

In [26]:
# Pick the compute device once; the training / evaluation cells move batches to it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network, the loss and the optimizer selected by params['net'] / params['arcface'].
# In ArcFace mode the criterion owns learnable class centres, so its parameters
# must be handed to the optimizer as well.
if params['net'] == 'resnet50':
    net = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    # Frozen backbone: only the new head (or the ArcFace weights) is optimised.
    for p in net.parameters():
        p.requires_grad = False

    if not params["arcface"]:
        net.fc = torch.nn.Linear(net.fc.in_features, params["num_classes"])
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(net.fc.parameters(), lr=1e-3)
    else:
        # Remove the classification layer; ArcFace consumes the 2048-d pooled features.
        net = torch.nn.Sequential(*(list(net.children())[:-1]))
        criterion = ArcFace(2048, params["num_classes"], scale=2, margin=0.1).to(device)
        optimizer = torch.optim.Adam(criterion.parameters(), lr=1e-3)

elif params['net'] == 'SepSTGramNet':
    net = SepSTGramNet(params)
    if not params["arcface"]:
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
        criterion = torch.nn.CrossEntropyLoss()
    else:
        # 70 = width of SepSTGramNet's embedding head.
        criterion = ArcFace(70, params["num_classes"], scale=2, margin=0.1).to(device)
        net_params = [{'params': net.parameters()}, {'params': criterion.parameters()}]
        optimizer = torch.optim.Adam(net_params, lr=1e-3)

elif params["net"] == "STGramNet":
    net = STGramNet(params)
    if not params["arcface"]:
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
        criterion = torch.nn.CrossEntropyLoss()
    else:
        # 512 = feature width of the ResNet-18 backbone with its final layer removed.
        criterion = ArcFace(512, params["num_classes"], scale=2, margin=0.1).to(device)
        net_params = [{'params': net.parameters()}, {'params': criterion.parameters()}]
        optimizer = torch.optim.Adam(net_params, lr=1e-3)

else:
    # Fail with a clear message instead of a NameError on `net` below.
    raise ValueError(
        f"Unknown params['net']: {params['net']!r} "
        "(expected 'resnet50', 'SepSTGramNet' or 'STGramNet')"
    )

net = net.to(device)

### Datasets

In [27]:
class AudioDataset(Dataset):
    """Dataset over a directory of audio files whose names encode label and machine id.

    File names are split on ``'_'``: the first token is the condition
    (``"normal"`` or anything else) and the second-to-last token is the numeric
    machine id.

    Args:
        dir_path: directory containing the audio files.
        params: configuration dict. Reads ``n_fft``, ``win_len``, ``hop_len``,
            ``mel_bins``, ``power``, ``arcface`` and ``num_classes``.
        dev: True for the development split; False for the evaluation split,
            whose machine ids are shifted by one before being mapped to a class
            index (see ``_return_one_hot``).
        audio_transform: optional callable applied to the waveform that is
            returned (the spectrogram is always computed from the untransformed
            waveform).
    """

    def __init__(self, dir_path, params, dev=True, audio_transform=None):
        self.dir_path = dir_path
        self.files = os.listdir(dir_path)
        self.audio_transform = audio_transform
        self.dev = dev
        self.params = params

    def __len__(self):
        """Number of files found in ``dir_path``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one file.

        Returns:
            ``(audio, mel_spectr, machine_target, anomaly_label, filename)`` where
            ``anomaly_label`` is 1 when the file name starts with ``"normal"``
            and 0 otherwise (note: 1 means *normal*, despite the name).
        """
        filename = self.files[idx]
        path = os.path.join(self.dir_path, filename)

        name_parts = filename.split('_')
        machine_id = name_parts[-2]
        # "anomaly" and any unlabeled prefix both map to 0.
        anomaly_label = 1 if name_parts[0] == "normal" else 0

        audio, sr = load(path)
        # Use this dataset's own config (not a module-level dict) and pass the
        # mel/window sizes explicitly instead of relying on library defaults.
        mel_spectrogram = MelSpectrogram(
            sample_rate=sr,
            n_fft=self.params['n_fft'],
            win_length=self.params['win_len'],
            hop_length=self.params['hop_len'],
            n_mels=self.params['mel_bins'],
            power=self.params["power"],
        )
        mel_spectr = mel_spectrogram(audio)
        if self.audio_transform is not None:
            audio = self.audio_transform(audio)
        return audio, mel_spectr, self._return_one_hot(machine_id), anomaly_label, filename

    def _return_one_hot(self, machine_id):
        """Map a machine-id string to the training target.

        The id is converted to int, decremented by one for the evaluation
        split, and integer-divided by two to obtain the class index. In ArcFace
        mode the index is returned as a long scalar tensor; otherwise as a float
        one-hot vector of length ``num_classes``.
        """
        machine_id = int(machine_id)
        if not self.dev:
            machine_id -= 1
        machine_id //= 2
        if self.params["arcface"]:
            return torch.tensor(machine_id).long()
        t = np.zeros(self.params["num_classes"])
        t[machine_id] = 1
        return torch.tensor(t).float()

    
# Locations of the slider recordings on the Kaggle input mount.
dev_train_path = "/kaggle/input/eurecom-aml-2023-challenge-2/dev_data/dev_data/slider/train"
dev_test_path = "/kaggle/input/eurecom-aml-2023-challenge-2/dev_data/dev_data/slider/test"
eval_train_path = "/kaggle/input/eurecom-aml-2023-challenge-2/eval_data/eval_data/slider/train"
eval_test_path = "/kaggle/input/eurecom-aml-2023-challenge-2/eval_data/eval_data/slider/test"

# Development split (dev=True is the default).
dev_train_ds = AudioDataset(dir_path=dev_train_path, params=params)
dev_test_ds = AudioDataset(dir_path=dev_test_path, params=params)

# Evaluation split: dev=False changes how machine ids are mapped to class indices.
eval_train_ds = AudioDataset(dir_path=eval_train_path, params=params, dev=False)
eval_test_ds = AudioDataset(dir_path=eval_test_path, params=params, dev=False)

# Training loaders shuffle and use two worker processes; test loaders keep file order.
dev_train_dl = DataLoader(
    dev_train_ds,
    batch_size=params["batch_size"],
    shuffle=True,
    num_workers=2,
)
eval_train_dl = DataLoader(
    eval_train_ds,
    batch_size=params["batch_size"],
    shuffle=True,
    num_workers=2,
)
dev_test_dl = DataLoader(dev_test_ds, batch_size=params["batch_size"])
eval_test_dl = DataLoader(eval_test_ds, batch_size=params["batch_size"])

## DEV - Train

In [28]:
# Train on the development training set (runs only when params["eval"] is False).
if not params["eval"]:
    net.train()
    l = []  # mean training loss per epoch
    for e in range(params["epochs"]):
        print(f"epoch {e}th")
        epoch_loss = []
        for audio, spectrogram, machine_id, label, filename in tqdm(dev_train_dl):
            optimizer.zero_grad()
            spec = spectrogram.to(device)
            target = machine_id.to(device)
            if params["net"] == "SepSTGramNet":
                output = net(audio.to(device), spec.repeat(1, 3, 1, 1))
            elif params["net"] == "STGramNet":
                output = net(audio.to(device), spec)
            else:
                # Plain ResNet expects a 3-channel image.
                output = net(spec.repeat(1, 3, 1, 1))
            # Collapse any trailing singleton dims to (batch, features). Unlike a bare
            # .squeeze(), this keeps the batch dimension when the last batch has one sample.
            output = output.flatten(1)
            if params["arcface"]:
                loss, _ = criterion(output, target)
            else:
                loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss.append(loss.item())
        mean_loss = np.array(epoch_loss).mean()
        l.append(mean_loss)
        print(mean_loss)
    print('done')

    plt.plot(l)
    plt.xlabel("epoch")
    plt.ylabel("mean training loss")

## DEV - Test

In [29]:
# Score the development test set and report ROC AUC (runs only when params["eval"] is False).
# The score of a clip is the softmax probability assigned to its own machine id;
# `label` is 1 for files whose name starts with "normal", so higher score = more normal.
if not params["eval"]:
    net.eval()
    names = []
    scores = []
    labels = []
    with torch.no_grad():
        for audio, spectrogram, machine_id, label, filename in tqdm(dev_test_dl):
            spec = spectrogram.to(device)
            target = machine_id.to(device)

            # Call the module (not .forward) so that module hooks are honoured.
            if params["net"] == "SepSTGramNet":
                output = net(audio.to(device), spec.repeat(1, 3, 1, 1))
            elif params["net"] == "STGramNet":
                output = net(audio.to(device), spec)
            else:
                output = net(spec.repeat(1, 3, 1, 1))
            # (batch, C, 1, 1) -> (batch, C); a no-op for outputs that are already 2-D.
            output = output.flatten(1)

            if params["arcface"]:
                # ArcFace turns embeddings into class logits; targets are class indices.
                _, output = criterion(output, target)
                class_idx = target
            else:
                # Targets are one-hot vectors here.
                class_idx = target.argmax(dim=1)

            softmax = output.softmax(dim=1)
            # Probability of each sample's own class, without building a batch x batch matrix.
            score = softmax.gather(1, class_idx.long().view(-1, 1)).squeeze(1)
            # prob = 1 - score

            names.extend(filename)
            scores.extend(score.cpu().numpy())
            labels.extend(label.cpu().numpy())

    print(roc_auc_score(labels, scores))

## EVAL - Train and Test

In [32]:
# Train on the evaluation training set, score the evaluation test set and write
# the submission file (runs only when params["eval"] is True).
if params["eval"]:

    # train
    net.train()
    l = []  # mean training loss per epoch
    for e in range(params["epochs"]):
        print(f"epoch {e}th")
        epoch_loss = []
        for audio, spectrogram, machine_id, label, filename in tqdm(eval_train_dl):
            optimizer.zero_grad()
            spec = spectrogram.to(device)
            target = machine_id.to(device)
            if params["net"] == "SepSTGramNet":
                output = net(audio.to(device), spec.repeat(1, 3, 1, 1))
            elif params["net"] == "STGramNet":
                output = net(audio.to(device), spec)
            else:
                # Plain ResNet expects a 3-channel image.
                output = net(spec.repeat(1, 3, 1, 1))
            # Collapse trailing singleton dims to (batch, features). Unlike a bare
            # .squeeze(), this keeps the batch dimension when the last batch has one sample.
            output = output.flatten(1)
            if params["arcface"]:
                loss, _ = criterion(output, target)
            else:
                loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss.append(loss.item())
        mean_loss = np.array(epoch_loss).mean()
        l.append(mean_loss)
        print(mean_loss)
    print('done')

    # eval
    net.eval()
    names = []
    scores = []
    labels = []
    with torch.no_grad():
        for audio, spectrogram, machine_id, label, filename in tqdm(eval_test_dl):
            spec = spectrogram.to(device)
            target = machine_id.to(device)

            # Call the module (not .forward) so that module hooks are honoured.
            if params["net"] == "SepSTGramNet":
                output = net(audio.to(device), spec.repeat(1, 3, 1, 1))
            elif params["net"] == "STGramNet":
                output = net(audio.to(device), spec)
            else:
                output = net(spec.repeat(1, 3, 1, 1))
            output = output.flatten(1)

            if params["arcface"]:
                # ArcFace turns embeddings into class logits; targets are class indices.
                _, output = criterion(output, target)
                class_idx = target
            else:
                # Targets are one-hot vectors here.
                class_idx = target.argmax(dim=1)

            softmax = output.softmax(dim=1)
            # Probability of each sample's own machine id, without a batch x batch matrix.
            score = softmax.gather(1, class_idx.long().view(-1, 1)).squeeze(1)

            # Anomaly score: low confidence in the clip's own machine id = more anomalous.
            prob = 1 - score
            names.extend(filename)
            scores.extend(prob.cpu().numpy())
            labels.extend(label.cpu().numpy())

    submission_df = pd.DataFrame(names, columns=["file_name"])
    submission_df["anomaly_score"] = scores
    submission_df.to_csv("/kaggle/working/sample_submission.csv", index=False)

epoch 0th


100%|██████████| 75/75 [00:21<00:00,  3.46it/s]


0.268907280365626
done


100%|██████████| 27/27 [00:14<00:00,  1.82it/s]


5 epochs
mel: 0.8461880982105701
log_mel: 0.7744319600499376

10 epochs
mel: 0.8667582188930503
log_mel: 0.7555222638368705

20 epochs
mel: 0.8175343320848939
log_mel: 0.7449105285060342

ArcFaceLoss
batch_size=32, epochs=10, s=2, m=0.1
0.9132792342904703