In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [2]:
from src.efficient_kan import KAN

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

In [4]:
import numpy as np
import random
import matplotlib.pyplot as plt

In [5]:
config = {}
config['encoder'] = 'Wav2Vec'
config['input_dim'] = 1024*16
config['feature_dim'] = 36864
config['batch_size'] = 1024
config['hidden_layers'] = [1024, 64]
config['lr'] = 1e-3
config['epoch'] = 50

In [6]:
import librosa
import librosa.display

In [7]:
class2label = {'blues': 0,
               'classical': 1,
               'country': 2,
               'metal': 3,
               'disco': 4,
               'hiphop': 5,
               'jazz': 6,
               'pop': 7,
               'reggae': 8,
               'rock': 9,}
               
class genres_classification_dataset(Dataset):
    def __init__(self, root='data/Data/genres_original/', input_dim=1024, train=True):
        super(genres_classification_dataset, self).__init__()
        self.classes = os.listdir(root)
        self.data = []
        self.train = train
        for _class in self.classes:
            wavs = os.listdir(os.path.join(root, _class))
            if train:
                sample_wav = wavs[:len(wavs)//2]
            else:
                sample_wav = wavs[len(wavs)//2:]
            for wav in sample_wav:
                sound, sample_rate = librosa.load(os.path.join(root, _class, wav))
                for i in range(len(sound)//input_dim):
                    d={}
                    clip = sound[i*input_dim:(i+1)*input_dim]
                    d['clip'] = clip
                    d['class'] = _class
                    self.data.append(d)

    def __getitem__(self, index):
        data = self.data[index]
        _class = data['class']
        clip = data['clip']
        # sound, sample_rate = librosa.load(wav_path)
        target = torch.from_numpy(clip).squeeze().unsqueeze(0).to(torch.float32)
        onehot = torch.zeros(10)
        onehot[class2label[_class]] = 1
        return target, onehot

    def __len__(self):
        return len(self.data)

trainset = genres_classification_dataset(input_dim=config['input_dim'], train=True)
valset = genres_classification_dataset(input_dim=config['input_dim'], train=False)
trainloader = DataLoader(trainset, batch_size=config['batch_size'], shuffle=True)
valloader = DataLoader(trainset, batch_size=config['batch_size'], shuffle=False)

In [8]:
class KAE(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        input_dim,
    ):
        super(KAE, self).__init__()
        self.encoder = KAN([input_dim, layers_hidden[0], layers_hidden[1]])
        self.decoder = KAN([layers_hidden[1], layers_hidden[0], input_dim])

    def forward(self, x: torch.Tensor):
        x = self.decoder(self.encoder(x))
        return x
class simple_classification(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        layers_hidden,
        output_dim=10,
    ):
        super(simple_classification, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, layers_hidden[0]),
            nn.ReLU(inplace=True),
            nn.Linear(layers_hidden[0], layers_hidden[1]),
            nn.ReLU(inplace=True),
            # nn.LayerNorm(layers_hidden[1]),
            nn.Linear(layers_hidden[1], output_dim)
            )

    def forward(self, x: torch.Tensor):
        x = self.model(x)
        return x

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import sys
if config['encoder'] == 'beats':
    new_path = 'third_part/beats'
    sys.path.append(new_path)
    from BEATs import BEATs, BEATsConfig
    # load the pre-trained checkpoints
    checkpoint = torch.load('BEATs_iter3_plus_AS2M.pt')
    
    cfg = BEATsConfig(checkpoint['cfg'])
    BEATs_model = BEATs(cfg)
    BEATs_model.load_state_dict(checkpoint['model'])
    BEATs_model.eval()
    BEATs_model.to(device)
    model = simple_classification(config['feature_dim'], config['hidden_layers'])
elif config['encoder'] == 'KAE':
    # Define model
    kae_pretrain = torch.load('checkpoints/KAN_music_genres_pretrain.pth')
    kae_pretrain.to(device)
    model = simple_classification(config['feature_dim'], config['hidden_layers'])
elif config['encoder'] == 'KAN':
    model = KAN([config['input_dim'], 4096, 1024, 64, 10], grid_size=5, spline_order=4,)
elif config['encoder'] == 'SSAST':
    new_path = 'third_part/ssast'
    sys.path.append(new_path)
    from src.models import ASTModel
    pretrain_model_path = 'SSAST-Base-Patch-400.pth'
    import torchaudio.compliance.kaldi as ta_kaldi
    def ssast_preprocess(
                source: torch.Tensor,
                fbank_mean: float = 15.41663,
                fbank_std: float = 6.55582,
        ) -> torch.Tensor:
            fbanks = []
            for waveform in source:
                waveform = waveform.unsqueeze(0)
                waveform = waveform - waveform.mean()
                fbank = ta_kaldi.fbank(waveform, htk_compat=True, sample_frequency=33000, use_energy=False,
                                                      window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10)
                fbanks.append(fbank)
            fbank = torch.stack(fbanks, dim=0)
            fbank = (fbank - fbank_mean) / (2 * fbank_std)
            return fbank
    input_tdim = 48  # fine-tuning data length can be different with pretraining data length
    ast_mdl = ASTModel(label_dim=35,
                 fshape=16, tshape=16, fstride=10, tstride=10,
                 input_fdim=128, input_tdim=input_tdim, model_size='base',
                 pretrain_stage=False, load_pretrained_mdl_path=pretrain_model_path)
    model = simple_classification(config['feature_dim'], config['hidden_layers'])
elif config['encoder'] == 'Wav2Vec':
    new_path = 'third_part/soxan'
    sys.path.append(new_path)
    from models import Wav2Vec2ForSpeechClassification
    pretrain_model_path = 'soxan_checkpoints'
    import torchaudio
    from transformers import AutoConfig, Wav2Vec2FeatureExtractor
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrain_model_path)
    sampling_rate = feature_extractor.sampling_rate
    
    # for wav2vec
    w2v = Wav2Vec2ForSpeechClassification.from_pretrained(pretrain_model_path).to(device)
    resampler = torchaudio.transforms.Resample(16800, sampling_rate).to(device)
    model = simple_classification(config['feature_dim'], config['hidden_layers'])
else:
    raise NotImplementError
model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=1e-4)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# Define loss
criterion = nn.CrossEntropyLoss()

In [30]:
for epoch in range(config['epoch']):
    # Train
    model.train()
    with tqdm(trainloader) as pbar:
        for i, (_input, label) in enumerate(pbar):
            _input = _input.view(-1, config['input_dim']).to(device)
            label = label.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                if config['encoder'] == 'beats':
                    feature = BEATs_model.extract_features(_input, padding_mask=torch.zeros(_input.shape[0], config['input_dim']).bool().to(device))[0]
                    feature = feature.reshape(_input.shape[0], -1)
                elif config['encoder'] == 'KAE':
                    feature = kae_pretrain.encoder(_input)
                elif config['encoder'] == 'SSAST':
                    filter_bank = ssast_preprocess(_input)
                    feature = ast_mdl(filter_bank, task='extract_feature')
                    feature = feature.reshape(_input.shape[0], -1)
                elif config['encoder'] == 'Wav2Vec':
                    speech = resampler(_input).squeeze()
                    inputs = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
                    inputs = {key: inputs[key].to(device).squeeze() for key in inputs}
                    feature = w2v.extract_feature(**inputs)
                    feature = feature.reshape(_input.shape[0], -1)
                else:
                    feature = _input
                    
            # if config['encoder'] == 'KAN':
            #     output = model(feature, update_grid=(i % 20 == 0))
            # else:
            output = model(feature)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.detach().cpu().item(), lr=optimizer.param_groups[0]['lr'])

    # # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for i, (_input, label) in enumerate(valloader):
            _input = _input.view(-1, config['input_dim']).to(device)
            label = label.to(device)
            if config['encoder'] == 'beats':
                feature = BEATs_model.extract_features(_input, padding_mask=torch.zeros(_input.shape[0], config['input_dim']).bool().to(device))[0]
                feature = feature.reshape(_input.shape[0], -1)
            elif config['encoder'] == 'KAE':
                feature = kae_pretrain.encoder(_input)
            elif config['encoder'] == 'SSAST':
                filter_bank = ssast_preprocess(_input)
                feature = ast_mdl(filter_bank, task='extract_feature')
                feature = feature.reshape(_input.shape[0], -1)
            elif config['encoder'] == 'Wav2Vec':
                speech = resampler(_input).squeeze()
                inputs = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
                inputs = {key: inputs[key].to(device).squeeze() for key in inputs}
                feature = w2v.extract_feature(**inputs)
                feature = feature.reshape(_input.shape[0], -1)
            else:
                feature = _input
            output = model(feature)
            val_loss += criterion(output, label).item()
            
    val_loss /= len(valloader)

    # Update learning rate
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}"
    )
torch.save(model, os.path.join('checkpoints', config['encoder']+'_'+str(config['epoch'])+'_'+str(val_loss)+'.pth'))

100%|██████████| 20/20 [00:18<00:00,  1.08it/s, loss=1.65, lr=0.001]


Epoch 1, Val Loss: 1.6312610328197479


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.63, lr=0.0008]


Epoch 2, Val Loss: 1.5856821477413177


100%|██████████| 20/20 [00:21<00:00,  1.06s/it, loss=1.52, lr=0.00064]


Epoch 3, Val Loss: 1.5677696883678436


100%|██████████| 20/20 [00:21<00:00,  1.06s/it, loss=1.48, lr=0.000512]


Epoch 4, Val Loss: 1.550961172580719


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.54, lr=0.00041]


Epoch 5, Val Loss: 1.5517475485801697


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.52, lr=0.000328]


Epoch 6, Val Loss: 1.5335833489894868


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.53, lr=0.000262]


Epoch 7, Val Loss: 1.525952821969986


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.53, lr=0.00021]


Epoch 8, Val Loss: 1.5200729072093964


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.62, lr=0.000168]


Epoch 9, Val Loss: 1.5126644670963287


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.48, lr=0.000134]


Epoch 10, Val Loss: 1.5128972023725509


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.5, lr=0.000107] 


Epoch 11, Val Loss: 1.5062468856573106


100%|██████████| 20/20 [00:19<00:00,  1.04it/s, loss=1.42, lr=8.59e-5]


Epoch 12, Val Loss: 1.4989856511354447


100%|██████████| 20/20 [00:19<00:00,  1.04it/s, loss=1.48, lr=6.87e-5]


Epoch 13, Val Loss: 1.500696748495102


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.37, lr=5.5e-5]


Epoch 14, Val Loss: 1.4973071992397309


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.51, lr=4.4e-5]


Epoch 15, Val Loss: 1.4959874749183655


100%|██████████| 20/20 [00:19<00:00,  1.04it/s, loss=1.49, lr=3.52e-5]


Epoch 16, Val Loss: 1.4929079294204712


100%|██████████| 20/20 [00:19<00:00,  1.04it/s, loss=1.44, lr=2.81e-5]


Epoch 17, Val Loss: 1.4931572258472443


100%|██████████| 20/20 [00:19<00:00,  1.04it/s, loss=1.51, lr=2.25e-5]


Epoch 18, Val Loss: 1.4912997961044312


100%|██████████| 20/20 [00:19<00:00,  1.04it/s, loss=1.54, lr=1.8e-5]


Epoch 19, Val Loss: 1.490935242176056


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.46, lr=1.44e-5]


Epoch 20, Val Loss: 1.4903143256902696


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.56, lr=1.15e-5]


Epoch 21, Val Loss: 1.490217262506485


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.5, lr=9.22e-6] 


Epoch 22, Val Loss: 1.4896993815898896


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.47, lr=7.38e-6]


Epoch 23, Val Loss: 1.489520800113678


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.51, lr=5.9e-6]


Epoch 24, Val Loss: 1.489421832561493


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.51, lr=4.72e-6]


Epoch 25, Val Loss: 1.4891971409320832


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.48, lr=3.78e-6]


Epoch 26, Val Loss: 1.4891792505979538


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.4, lr=3.02e-6] 


Epoch 27, Val Loss: 1.4891421258449555


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.48, lr=2.42e-6]


Epoch 28, Val Loss: 1.4889690548181533


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.48, lr=1.93e-6]


Epoch 29, Val Loss: 1.4889730513095856


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.54, lr=1.55e-6]


Epoch 30, Val Loss: 1.4888917446136474


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.47, lr=1.24e-6]


Epoch 31, Val Loss: 1.4888496547937393


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.47, lr=9.9e-7]


Epoch 32, Val Loss: 1.4888130277395248


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.37, lr=7.92e-7]


Epoch 33, Val Loss: 1.488776046037674


100%|██████████| 20/20 [00:18<00:00,  1.05it/s, loss=1.54, lr=6.34e-7]


Epoch 34, Val Loss: 1.4887786149978637


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.54, lr=5.07e-7]


Epoch 35, Val Loss: 1.4887514531612396


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.54, lr=4.06e-7]


Epoch 36, Val Loss: 1.488738626241684


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.48, lr=3.25e-7]


Epoch 37, Val Loss: 1.4887233793735504


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.5, lr=2.6e-7] 


Epoch 38, Val Loss: 1.4887136667966843


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.45, lr=2.08e-7]


Epoch 39, Val Loss: 1.488707810640335


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.47, lr=1.66e-7]


Epoch 40, Val Loss: 1.4887009739875794


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.48, lr=1.33e-7]


Epoch 41, Val Loss: 1.488697847723961


100%|██████████| 20/20 [00:18<00:00,  1.05it/s, loss=1.48, lr=1.06e-7]


Epoch 42, Val Loss: 1.4886965245008468


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.43, lr=8.51e-8]


Epoch 43, Val Loss: 1.4886909246444702


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.5, lr=6.81e-8] 


Epoch 44, Val Loss: 1.4886894971132278


100%|██████████| 20/20 [00:21<00:00,  1.06s/it, loss=1.52, lr=5.44e-8]


Epoch 45, Val Loss: 1.4886873662471771


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.41, lr=4.36e-8]


Epoch 46, Val Loss: 1.4886861741542816


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.43, lr=3.48e-8]


Epoch 47, Val Loss: 1.4886856436729432


100%|██████████| 20/20 [00:18<00:00,  1.05it/s, loss=1.54, lr=2.79e-8]


Epoch 48, Val Loss: 1.4886843681335449


100%|██████████| 20/20 [00:19<00:00,  1.05it/s, loss=1.51, lr=2.23e-8]


Epoch 49, Val Loss: 1.4886827677488328


100%|██████████| 20/20 [00:18<00:00,  1.06it/s, loss=1.43, lr=1.78e-8]


Epoch 50, Val Loss: 1.4886819303035737


In [None]:
test_data = valset[0]
if config['encoder'] == 'beats':
    feature = BEATs_model.extract_features(test_data[0].to(device), padding_mask=torch.zeros(1, config['input_dim']).bool().to(device))[0]
    feature = feature.reshape(1, -1)
elif config['encoder'] == 'KAE':
    feature = kae_pretrain.encoder(test_data[0].to(device))
else:
    feature = test_data[0].to(device)
pred = model(feature)
print(pred, test_data[1])

In [None]:
if config['encoder'] == 'beats':
    padding_mask = torch.zeros(1, 16384).bool().to(device)
    
    representation = BEATs_model.extract_features(test_data[0].to(device), padding_mask=padding_mask)[0]
    plt.imshow(representation[0].detach().cpu().numpy())
    plt.show()