In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
from src.efficient_kan import KAN

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

In [4]:
import numpy as np
import random
import matplotlib.pyplot as plt

In [5]:
config = {}
config['encoder'] = 'KAN'
config['input_dim'] = 1024*16
config['feature_dim'] = 36864
config['batch_size'] = 1024
config['hidden_layers'] = [128, 64]
config['lr'] = 1e-3
config['epoch'] = 100

In [6]:
import librosa
import librosa.display

In [7]:
class2label = {'blues': 0,
               'classical': 1,
               'country': 2,
               'metal': 3,
               'disco': 4,
               'hiphop': 5,
               'jazz': 6,
               'pop': 7,
               'reggae': 8,
               'rock': 9,}
               
class genres_classification_dataset(Dataset):
    def __init__(self, root='data/Data/genres_original/', input_dim=1024, train=True):
        super(genres_classification_dataset, self).__init__()
        self.classes = os.listdir(root)
        self.data = []
        self.train = train
        for _class in self.classes:
            wavs = os.listdir(os.path.join(root, _class))
            if train:
                sample_wav = wavs[:len(wavs)//2]
            else:
                sample_wav = wavs[len(wavs)//2:]
            for wav in sample_wav:
                sound, sample_rate = librosa.load(os.path.join(root, _class, wav))
                for i in range(len(sound)//input_dim):
                    d={}
                    clip = sound[i*input_dim:(i+1)*input_dim]
                    d['clip'] = clip
                    d['class'] = _class
                    self.data.append(d)

    def __getitem__(self, index):
        data = self.data[index]
        _class = data['class']
        clip = data['clip']
        # sound, sample_rate = librosa.load(wav_path)
        target = torch.from_numpy(clip).squeeze().unsqueeze(0).to(torch.float32)
        onehot = torch.zeros(10)
        onehot[class2label[_class]] = 1
        return target, onehot

    def __len__(self):
        return len(self.data)

trainset = genres_classification_dataset(input_dim=config['input_dim'], train=True)
valset = genres_classification_dataset(input_dim=config['input_dim'], train=False)
trainloader = DataLoader(trainset, batch_size=config['batch_size'], shuffle=True)
valloader = DataLoader(trainset, batch_size=config['batch_size'], shuffle=False)

In [8]:
class KAE(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        input_dim,
    ):
        super(KAE, self).__init__()
        self.encoder = KAN([input_dim, layers_hidden[0], layers_hidden[1]])
        self.decoder = KAN([layers_hidden[1], layers_hidden[0], input_dim])

    def forward(self, x: torch.Tensor):
        x = self.decoder(self.encoder(x))
        return x
class simple_classification(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        layers_hidden,
        output_dim=10,
    ):
        super(simple_classification, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, layers_hidden[0]),
            nn.ReLU(inplace=True),
            nn.Linear(layers_hidden[0], layers_hidden[1]),
            nn.ReLU(inplace=True),
            nn.Linear(layers_hidden[1], output_dim)
            )

    def forward(self, x: torch.Tensor):
        x = self.model(x)
        return x

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import sys
new_path = 'third_part/beats'
sys.path.append(new_path)
if config['encoder'] == 'beats':
    from BEATs import BEATs, BEATsConfig
    # load the pre-trained checkpoints
    checkpoint = torch.load('BEATs_iter3_plus_AS2M.pt')
    
    cfg = BEATsConfig(checkpoint['cfg'])
    BEATs_model = BEATs(cfg)
    BEATs_model.load_state_dict(checkpoint['model'])
    BEATs_model.eval()
    BEATs_model.to(device)
    model = simple_classification(config['feature_dim'], config['hidden_layers'])
if config['encoder'] == 'KAE':
    # Define model
    kae_pretrain = torch.load('checkpoints/KAN_music_genres_pretrain.pth')
    kae_pretrain.to(device)
    model = simple_classification(config['feature_dim'], config['hidden_layers'])
if config['encoder'] == 'KAN':
    model = KAN([config['input_dim'], 1024, 64, 10])
model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=1e-4)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# Define loss
criterion = nn.CrossEntropyLoss()

In [None]:
for epoch in range(config['epoch']):
    # Train
    model.train()
    with tqdm(trainloader) as pbar:
        for i, (_input, label) in enumerate(pbar):
            _input = _input.view(-1, config['input_dim']).to(device)
            label = label.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                if config['encoder'] == 'beats':
                    feature = BEATs_model.extract_features(_input, padding_mask=torch.zeros(_input.shape[0], config['input_dim']).bool().to(device))[0]
                    feature = feature.reshape(_input.shape[0], -1)
                elif config['encoder'] == 'KAE':
                    feature = kae_pretrain.encoder(_input)
                else:
                    feature = _input
            output = model(feature)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.detach().cpu().item(), lr=optimizer.param_groups[0]['lr'])

    # # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for (_input, label) in valloader:
            _input = _input.view(-1, config['input_dim']).to(device)
            label = label.to(device)
            if config['encoder'] == 'beats':
                feature = BEATs_model.extract_features(_input, padding_mask=torch.zeros(_input.shape[0], config['input_dim']).bool().to(device))[0]
                feature = feature.reshape(_input.shape[0], -1)
            elif config['encoder'] == 'KAE':
                feature = kae_pretrain.encoder(_input)
            else:
                feature = _input
            output = model(feature)
            val_loss += criterion(output, label).item()
            
    val_loss /= len(valloader)

    # Update learning rate
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}"
    )

100%|██████████| 20/20 [00:03<00:00,  5.32it/s, loss=2.14, lr=0.001]


Epoch 1, Val Loss: 2.0353218257427215


100%|██████████| 20/20 [00:03<00:00,  5.84it/s, loss=2.01, lr=0.0008]


Epoch 2, Val Loss: 1.8675489127635956


100%|██████████| 20/20 [00:03<00:00,  5.61it/s, loss=1.81, lr=0.00064]


Epoch 3, Val Loss: 1.648305779695511


100%|██████████| 20/20 [00:03<00:00,  6.23it/s, loss=1.5, lr=0.000512] 


Epoch 4, Val Loss: 1.4207991421222688


100%|██████████| 20/20 [00:03<00:00,  6.16it/s, loss=1.43, lr=0.00041]


Epoch 5, Val Loss: 1.2244133085012436


100%|██████████| 20/20 [00:03<00:00,  6.07it/s, loss=1.24, lr=0.000328]


Epoch 6, Val Loss: 1.0839591175317764


100%|██████████| 20/20 [00:03<00:00,  6.01it/s, loss=1.11, lr=0.000262]


Epoch 7, Val Loss: 0.9583573043346405


100%|██████████| 20/20 [00:03<00:00,  6.00it/s, loss=0.96, lr=0.00021] 


Epoch 8, Val Loss: 0.8838223367929459


100%|██████████| 20/20 [00:03<00:00,  6.03it/s, loss=0.917, lr=0.000168]


Epoch 9, Val Loss: 0.8238902747631073


100%|██████████| 20/20 [00:03<00:00,  6.00it/s, loss=0.776, lr=0.000134]


Epoch 10, Val Loss: 0.7692569702863693


 40%|████      | 8/20 [00:01<00:02,  5.97it/s, loss=0.721, lr=0.000107]

In [None]:
test_data = valset[0]
if config['encoder'] == 'beats':
    feature = BEATs_model.extract_features(test_data[0].to(device), padding_mask=torch.zeros(1, config['input_dim']).bool().to(device))[0]
    feature = feature.reshape(1, -1)
elif config['encoder'] == 'KAE':
    feature = kae_pretrain.encoder(test_data[0].to(device))
else:
    feature = test_data[0].to(device)
pred = model(feature)
print(pred, test_data[1])

In [None]:
padding_mask = torch.zeros(1, 16384).bool().to(device)

representation = BEATs_model.extract_features(test_data[0].to(device), padding_mask=padding_mask)[0]

In [None]:
plt.imshow(representation[0].detach().cpu().numpy())
plt.show()