In [11]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchsummary import summary

from encoder import Encoder, load_encoder
from decoder import Decoder, load_decoder
from train import train_epoch, test_epoch
# from utility import get_all_files_paths

In [13]:
TRAIN = False

BETA = 1
BATCH_SIZE = 64
LATENT_SPACE_DIM = 1024
LEARNING_RATE = 1e-5

SAVE_ROUND = 20
NUM_EPOCHS = 300

torch.manual_seed(0) # random seed for reproducible results
# 似乎不同设备要单独设置随机种子？

<torch._C.Generator at 0x7ff6f62a0990>

In [14]:
dataset_path = "data/spec/fma_small"
# dataset_path = "data/spec/GTZAN_646"
model_save_path = "models/Echoes"
csv_save_path = "output/Echoes_output"

os.makedirs(model_save_path, exist_ok=True)
os.makedirs(csv_save_path, exist_ok=True)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f'device: {device}')

device: cuda


In [4]:
class MusicDataset(Dataset):
	def __init__(self, file_paths):
		self.file_paths = file_paths

	def __len__(self):
		return len(self.file_paths)

	def __getitem__(self, idx):
		file_path = self.file_paths[idx]
		data = np.load(file_path)
	
		if np.isnan(data).any():
			filename = os.path.basename(file_path)
			label = filename.replace(".npy", "").lstrip("0")
			print(f"Warning: NaN value found in {label}")
  
		data = data[np.newaxis, :, :]  # Add a channel dimension
		data = torch.tensor(data, dtype=torch.float32)
		
		# filename = os.path.basename(os.path.dirname(file_path))
		filename = os.path.basename(file_path)
		label = filename.replace(".npy", "").lstrip("0")
		return data, label

file_paths = []
for root, dirs, files in os.walk(dataset_path):
	for file in files:
		file_paths.append(os.path.join(root, file))

# 取出20%的数据作为测试集，10%的数据作为验证集，剩下的作为训练集（理论上每个文件夹有风格区别，应该尽量做到比例均匀），但是（先验的）风格不好分，暂时直接随机划分

m = len(file_paths)
test_size = int(m * 0.15)
valid_size = int(m * 0.15)
train_size = m - test_size - valid_size
# train : valid : test = 70 : 15 : 15

paths = np.array(file_paths)
np.random.shuffle(paths)

train_paths = paths[:train_size]
valid_paths = paths[train_size:train_size + valid_size]
test_paths = paths[train_size + valid_size:]

train_dataset = MusicDataset(train_paths)
valid_dataset = MusicDataset(valid_paths)
test_dataset = MusicDataset(test_paths)

print(f'Training set size:   {len(train_dataset)}')
print(f'Validation set size: {len(valid_dataset)}')
print(f'Test set size:       {len(test_dataset)}')

if TRAIN:
	train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
	valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
	test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Training set size:   5597
Validation set size: 1198
Test set size:       1198


In [None]:
if TRAIN:
	loss_fn = torch.nn.MSELoss()
	
	encoder = Encoder(encoded_space_dim=LATENT_SPACE_DIM)
	decoder = Decoder(encoded_space_dim=LATENT_SPACE_DIM)
	params_to_optimize = [
		{'params': encoder.parameters()},
		{'params': decoder.parameters()}
	]

	optim = torch.optim.Adam(params_to_optimize, lr=LEARNING_RATE, weight_decay=1e-05)

	encoder = encoder.to(device)
	decoder = decoder.to(device)
	
	losses = {'train_loss':[],'val_loss':[]}

	for epoch in range(NUM_EPOCHS):
		# beta = epoch / NUM_EPOCHS  # Increase beta over time
		beta = BETA
		train_loss =train_epoch(encoder, decoder, device, train_loader, loss_fn, optim, beta)
		val_loss = test_epoch(encoder, decoder, device, valid_loader, loss_fn, beta)
	
		print('\n EPOCH {}/{} \t train loss {} \t val loss {}'.format(epoch + 1, NUM_EPOCHS,train_loss,val_loss))

		# track losses
		losses['train_loss'].append(train_loss)
		losses['val_loss'].append(val_loss)

		# save model
		if (epoch + 1) % SAVE_ROUND == 0: 
			torch.save(encoder.state_dict(), f'{model_save_path}/encoder_{epoch+1}.pth')
			torch.save(decoder.state_dict(), f'{model_save_path}/decoder_{epoch+1}.pth')

	test_loss = test_epoch(encoder, decoder, device, test_loader, loss_fn, beta)

In [6]:
plt.figure(figsize=(8,6))
plt.semilogy(losses['train_loss'], label='Train')
plt.semilogy(losses['val_loss'], label='Valid')
plt.xlabel('Epoch')
plt.ylabel('Average Loss|')
# plt.grid()
plt.legend()
plt.title('loss')
plt.show()

NameError: name 'losses' is not defined

<Figure size 800x600 with 0 Axes>

In [7]:
test_loss

NameError: name 'test_loss' is not defined

In [8]:
encoder_path = f"{model_save_path}/encoder.pth"
decoder_path = f"{model_save_path}/decoder.pth"
torch.save(encoder.state_dict(), encoder_path)
torch.save(decoder.state_dict(), decoder_path)

NameError: name 'encoder' is not defined

In [9]:
if not TRAIN:
	encoder_path = f"{model_save_path}/encoder_280.pth"
	encoder = load_encoder(encoder_path, LATENT_SPACE_DIM)
	encoder = encoder.to(device)

encoded_samples = []
for sample in tqdm(test_dataset):
	data = sample[0].unsqueeze(0).to(device)
	label = sample[1]
	encoder.eval()
	with torch.no_grad():
		mu, log_var = encoder(data)

	encoded_data = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
	encoded_data = encoded_data.flatten().cpu().numpy()
	encoded_sample = {f"enc. v {i}": enc for i, enc in enumerate(encoded_data)}
	encoded_sample['label'] = label
	encoded_samples.append(encoded_sample)

encoded_samples = pd.DataFrame(encoded_samples)
encoded_samples.to_csv(f"{csv_save_path}/fma_small_encoded.csv", index=False)
encoded_samples

  model.load_state_dict(torch.load(model_path, map_location=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")))
100%|██████████| 1198/1198 [00:03<00:00, 306.90it/s]


Unnamed: 0,enc. v 0,enc. v 1,enc. v 2,enc. v 3,enc. v 4,enc. v 5,enc. v 6,enc. v 7,enc. v 8,enc. v 9,...,enc. v 1015,enc. v 1016,enc. v 1017,enc. v 1018,enc. v 1019,enc. v 1020,enc. v 1021,enc. v 1022,enc. v 1023,label
0,-0.946606,-0.425239,-2.659478,0.162394,-0.106082,-0.586880,-0.604333,-0.326373,-1.081557,-0.359277,...,-1.900775,-1.080082,0.316355,0.337315,-0.339624,0.955027,-1.255026,-1.195891,-0.105890,74388
1,0.153706,-0.557118,0.899824,-0.714434,1.396821,0.869281,0.141491,-0.937902,0.799757,0.872251,...,-0.722580,-1.036961,-0.508824,0.724544,0.100975,-1.723591,-0.841320,-0.307872,1.502998,70774
2,2.511299,-0.717655,-0.519853,0.148019,0.111525,-0.411314,0.907526,0.808997,-0.700670,0.134560,...,-0.833435,0.387396,1.405884,0.841048,0.972707,0.032631,-0.392905,1.489119,-1.306501,67016
3,0.325960,-1.427279,-0.861544,0.333140,-0.803067,0.727766,-0.424727,0.413645,-0.140047,-0.247262,...,-0.215967,-0.910874,0.425684,-0.322223,1.161427,-0.297772,-0.081139,0.234709,0.092129,140259
4,-0.370955,0.620888,0.545438,-0.274665,0.054474,-1.041744,0.287337,0.654044,-0.173162,1.448630,...,-0.157629,1.303216,-0.125188,-0.714768,1.484253,3.224166,-0.414848,0.841281,0.273222,97393
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1193,1.618409,-2.107481,-0.294493,-0.587241,-0.084038,0.207361,-0.595117,1.163033,-0.929993,-0.155797,...,0.038489,-1.872512,-1.001752,-1.695376,-2.360294,0.044683,-1.196657,1.012713,-2.824301,97211
1194,1.075054,-0.508127,1.437324,1.397717,-2.181241,-0.535416,-1.249104,-0.114818,0.111373,1.085199,...,1.044502,0.012442,-0.293274,-0.253205,0.347142,-0.169346,-0.412717,0.034062,-1.604117,107912
1195,-0.119992,-0.400481,0.973900,0.299499,-0.928898,-0.505174,1.579869,0.656737,-0.811004,-0.174451,...,-0.910894,-0.055492,1.273955,1.067598,0.959377,0.365274,0.324394,-1.322098,-0.652902,114223
1196,-0.108738,1.671015,-0.976587,-0.006499,0.702491,0.928739,-0.242716,-0.183750,0.360491,0.523292,...,-0.482288,-0.040770,-1.196344,0.630051,-1.146848,0.650762,0.103186,-0.699240,0.403714,72477


In [12]:
summary(encoder, input_size=(1, 256, 646))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 128, 323]             320
       BatchNorm2d-2         [-1, 32, 128, 323]              64
         LeakyReLU-3         [-1, 32, 128, 323]               0
            Conv2d-4          [-1, 64, 64, 162]          18,496
       BatchNorm2d-5          [-1, 64, 64, 162]             128
         LeakyReLU-6          [-1, 64, 64, 162]               0
            Conv2d-7          [-1, 128, 32, 81]          73,856
       BatchNorm2d-8          [-1, 128, 32, 81]             256
         LeakyReLU-9          [-1, 128, 32, 81]               0
           Conv2d-10          [-1, 256, 16, 41]         295,168
      BatchNorm2d-11          [-1, 256, 16, 41]             512
        LeakyReLU-12          [-1, 256, 16, 41]               0
          Flatten-13               [-1, 167936]               0
           Linear-14                 [-