In [8]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import pandas as pd
import matplotlib.pyplot as plt

from encoder import Encoder, load_encoder
from decoder import Decoder, load_decoder
from train import train_epoch, test_epoch
from utility import get_all_files_paths

In [9]:
# Run-mode switch: when True the training cell below runs and saves checkpoints;
# when False training is skipped and a saved encoder checkpoint is loaded instead.
TRAIN = False

# Samples per mini-batch.
BATCH_SIZE = 128
# Size of the encoded (latent) vector; passed to Encoder/Decoder as `encoded_space_dim`.
LATENT_SPACE_DIM = 128
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 1e-5

# A checkpoint is written every SAVE_ROUND epochs.
SAVE_ROUND = 20
# Total number of training epochs.
NUM_EPOCHS = 500

# NOTE: this seeds PyTorch's global RNG only; NumPy's RNG is not affected by it.
# As the last expression of the cell, the returned Generator is echoed as output.
torch.manual_seed(0) # random seed for reproducible results

<torch._C.Generator at 0x7f6d54145150>

In [10]:
# Input spectrogram directory and the output locations for checkpoints / CSV files.
dataset_path = "data/spec/GTZAN_646"
model_save_path = "models/Echoes"
csv_save_path = "output/Echoes_output"

# Make sure both output directories exist before anything tries to write to them.
for _out_dir in (model_save_path, csv_save_path):
    os.makedirs(_out_dir, exist_ok=True)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f'device: {device}')

device: cuda


In [11]:
class AudioDataset(Dataset):
    """Dataset backed by a list of ``.npy`` files, labelled from the file name.

    Each item is a ``(tensor, label)`` pair. The tensor is the loaded array
    with a leading channel axis added, converted to ``float32``. The label is
    derived from the part of the file's base name before the first ``'.'``.
    """

    def __init__(self, file_paths):
        """Store the sequence of ``.npy`` file paths to serve."""
        self.file_paths = file_paths

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load the file at position ``idx`` and return ``(tensor, label)``."""
        path = self.file_paths[idx]

        # Prepend a channel axis to the loaded array before tensor conversion.
        array = np.load(path)[np.newaxis, :, :]
        tensor = torch.tensor(array, dtype=torch.float32)

        # The label prefix is everything before the first '.' in the file name.
        prefix = os.path.basename(path).split('.')[0]

        return tensor, self.label_to_index(prefix)

    @staticmethod
    def label_to_index(label_str):
        """Map a label prefix to its label value.

        Despite the name, known genres map to the genre string itself (not an
        integer index); anything else yields ``-1``.
        """
        known_genres = (
            'blues', 'disco', 'rock', 'metal', 'classical',
            'pop', 'reggae', 'country', 'hiphop', 'jazz',
        )
        identity_map = {genre: genre for genre in known_genres}
        return identity_map.get(label_str, -1)  # -1 marks an unknown label

# Genres to load; each one is read from the sub-directory `{dataset_path}/{genre}`.
genres = ['blues', 'disco', 'rock', 'metal', 'pop', 'classical', 'reggae', 'country','hiphop','jazz']

genre_file_paths = {genre: get_all_files_paths(f"{dataset_path}/{genre}", [".npy"]) for genre in genres}

# Dedicated, seeded generator for the split. `torch.manual_seed` does not seed
# NumPy, so the previous `np.random.shuffle` produced a different split on every
# kernel restart -- which also meant a checkpoint loaded with TRAIN = False had
# been trained on a different partition than the "test" set encoded below.
split_rng = np.random.default_rng(0)

train_file_paths = []
valid_file_paths = []
test_file_paths = []

# Stratified split: every genre is partitioned separately so each subset keeps
# the same genre balance.
for genre, paths in genre_file_paths.items():
    m = len(paths)
    # NOTE(review): 80% of every genre goes to the *test* set, leaving ~18% for
    # training and ~2% for validation. Kept as-is; confirm this is intentional.
    test_size = int(m * 0.8)
    valid_size = int((m - test_size) * 0.1)
    train_size = m - test_size - valid_size

    # Sort first so the shuffle starts from a known order regardless of the
    # order in which the file listing was returned.
    paths = np.array(sorted(paths))
    split_rng.shuffle(paths)

    train_paths = paths[:train_size]
    valid_paths = paths[train_size:train_size + valid_size]
    test_paths = paths[train_size + valid_size:]

    train_file_paths.extend(train_paths)
    valid_file_paths.extend(valid_paths)
    test_file_paths.extend(test_paths)

train_dataset = AudioDataset(train_file_paths)
valid_dataset = AudioDataset(valid_file_paths)
test_dataset = AudioDataset(test_file_paths)

# Reuse the configured batch size instead of repeating the literal 128.
batch_size = BATCH_SIZE

# Only the training loader is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f'Training set size: {len(train_dataset)}')
print(f'Validation set size: {len(valid_dataset)}')
print(f'Test set size: {len(test_dataset)}')

Training set size: 180
Validation set size: 20
Test set size: 800


In [12]:
# Train the autoencoder (encoder + decoder) on the training set.
# Skipped entirely when TRAIN is False.
if TRAIN:
	# Loss between the decoder output and its target, as computed by
	# train_epoch / test_epoch.
	loss_fn = torch.nn.MSELoss()

	# Move the models to the target device before building the optimizer, so the
	# optimizer is constructed over the parameters that will actually be trained.
	encoder = Encoder(encoded_space_dim=LATENT_SPACE_DIM).to(device)
	decoder = Decoder(encoded_space_dim=LATENT_SPACE_DIM).to(device)

	# A single optimizer updates both networks.
	params_to_optimize = [
		{'params': encoder.parameters()},
		{'params': decoder.parameters()}
	]

	optim = torch.optim.Adam(params_to_optimize, lr=LEARNING_RATE, weight_decay=1e-05)

	losses = {'train_loss':[],'val_loss':[]}

	for epoch in range(NUM_EPOCHS):
		train_loss = train_epoch(encoder, decoder, device, train_loader, loss_fn, optim)
		# Validation loss is measured on the validation split. It was previously
		# computed on `test_loader`, which leaked the test set into the monitored
		# curve and left `valid_loader` unused.
		val_loss = test_epoch(encoder, decoder, device, valid_loader, loss_fn)

		print('\n EPOCH {}/{} \t train loss {} \t val loss {}'.format(epoch + 1, NUM_EPOCHS,train_loss,val_loss))

		# track losses
		losses['train_loss'].append(train_loss)
		losses['val_loss'].append(val_loss)

		# save a numbered checkpoint every SAVE_ROUND epochs
		if (epoch + 1) % SAVE_ROUND == 0:
			torch.save(encoder.state_dict(), f'{model_save_path}/encoder_{epoch+1}.pth')
			torch.save(decoder.state_dict(), f'{model_save_path}/decoder_{epoch+1}.pth')

	# Loss curves on a logarithmic y-axis.
	plt.figure(figsize=(8,6))
	plt.semilogy(losses['train_loss'], label='Train')
	plt.semilogy(losses['val_loss'], label='Valid')
	plt.xlabel('Epoch')
	plt.ylabel('Average Loss')
	plt.legend()
	plt.title('loss')
	plt.show()

	# Save the final weights under un-numbered file names.
	encoder_path = f"{model_save_path}/encoder.pth"
	decoder_path = f"{model_save_path}/decoder.pth"
	torch.save(encoder.state_dict(), encoder_path)
	torch.save(decoder.state_dict(), decoder_path)

In [14]:
from tqdm import tqdm

# When training was skipped, restore a previously saved encoder checkpoint;
# otherwise `encoder` is the model produced by the training cell.
if not TRAIN:
    encoder_path = f"{model_save_path}/encoder_300.pth"
    encoder = load_encoder(encoder_path, LATENT_SPACE_DIM)
    encoder = encoder.to(device)

# Switching to evaluation mode is loop-invariant, so do it once up front
# instead of once per sample.
encoder.eval()

# One record (dict) per test sample; the DataFrame is built once at the end.
encoded_rows = []
for sample in tqdm(test_dataset):
    # Add a batch axis of size 1 and move the sample to the model's device.
    img = sample[0].unsqueeze(0).to(device)
    label = sample[1]
    with torch.no_grad():
        encoded_img = encoder(img)

    # Flatten the encoder output into a 1-D vector of latent values.
    encoded_img = encoded_img.flatten().cpu().numpy()
    encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
    encoded_sample['label'] = label
    encoded_rows.append(encoded_sample)

# Columns: "Enc. Variable 0..N-1" followed by "label".
encoded_samples = pd.DataFrame(encoded_rows)
encoded_samples.to_csv(f"{csv_save_path}/gtzan_encoded.csv", index=False)
encoded_samples

  model.load_state_dict(torch.load(model_path, map_location=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")))
100%|██████████| 800/800 [00:01<00:00, 613.09it/s]


Unnamed: 0,Enc. Variable 0,Enc. Variable 1,Enc. Variable 2,Enc. Variable 3,Enc. Variable 4,Enc. Variable 5,Enc. Variable 6,Enc. Variable 7,Enc. Variable 8,Enc. Variable 9,...,Enc. Variable 119,Enc. Variable 120,Enc. Variable 121,Enc. Variable 122,Enc. Variable 123,Enc. Variable 124,Enc. Variable 125,Enc. Variable 126,Enc. Variable 127,label
0,0.222251,0.747878,0.343024,1.020067,-0.482725,0.411456,-0.859440,1.467224,-0.780477,0.909242,...,0.602332,-0.512713,0.354799,0.731263,0.239950,-0.710075,-1.261259,0.665961,0.401077,blues
1,0.252542,-0.075412,0.165016,-0.114197,0.056422,-0.295308,0.880866,0.270192,0.030375,0.054961,...,-0.501592,-0.022005,0.054017,0.379239,-0.209144,-0.254233,-0.383936,0.051060,-0.464982,blues
2,0.428650,-0.202218,-0.356651,0.420297,-0.647681,0.569896,0.304857,-1.206783,0.345872,1.401142,...,-0.072946,0.780444,0.880327,1.024387,-0.153665,-0.902088,0.224170,0.445963,0.583370,blues
3,0.030659,0.428549,0.272852,-0.119707,0.110874,0.443100,0.579181,-0.371511,-0.099708,-0.157338,...,-0.738868,0.075720,0.150221,0.273691,0.202919,0.105559,0.000740,0.391945,-0.143470,blues
4,0.086829,0.118544,0.248251,-0.240198,0.711506,0.172540,0.649633,-0.209181,0.084456,-0.133848,...,-0.299968,-0.556515,-0.218803,0.415372,0.386510,-0.421162,0.225155,0.626977,-0.112364,blues
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
795,0.010365,0.236053,-0.600733,0.143314,0.105116,0.420065,0.736434,-0.204361,0.177135,0.015280,...,0.438108,-0.292478,0.147047,0.554778,0.147946,0.028099,0.218859,-0.193416,0.173785,jazz
796,0.255348,0.006153,-1.118552,1.267267,-0.734694,-0.069714,0.737840,-0.046843,0.153062,0.994863,...,-0.306522,-0.218442,-0.026611,0.466025,-0.643068,-0.165016,0.342865,0.117370,0.208695,jazz
797,-0.057818,0.247108,0.036724,0.010742,0.287799,0.509897,0.656168,-0.191511,0.119337,-0.120354,...,-0.427980,0.059454,-0.009033,0.008451,-0.373940,-0.143857,-0.023573,0.203725,-0.395718,jazz
798,0.094037,0.357295,-0.102208,0.030102,0.026339,-0.391917,0.416954,-0.207274,0.902881,0.306998,...,0.114790,0.097130,0.282914,-0.264130,-0.462403,-0.378946,-0.150494,0.087574,0.308878,jazz
