In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pickle


In [45]:
class Autoencoder(nn.Module):
    def __init__(self, input_size):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4096, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 1024),
            nn.ReLU(True),



        )
        self.decoder = nn.Sequential(
            nn.Linear(1024, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(True),
            nn.Linear(2048, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4096, input_size),
            nn.Sigmoid()
        )

    # initialize weights using xavier normal
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
                
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [46]:
class VectorDataset(Dataset):
    def __init__(self, vector_dict):
        self.vectors = list(vector_dict.values())
        self.keys = list(vector_dict.keys())

    def __len__(self):
        return len(self.vectors)

    def __getitem__(self, idx):
        vector = self.vectors[idx]
        return torch.tensor(vector, dtype=torch.float)

In [47]:
# load feature vectors
# Load the feature vectors
with open('midi_feature_vectors.pkl', 'rb') as f:
    feature_vectors = pickle.load(f)


dataset = VectorDataset(feature_vectors)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)  # Adjust batch size as needed


In [48]:
input = len(feature_vectors['fernando'])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autoencoder = Autoencoder(input_size=input).to(device)
autoencoder.init_weights()

In [49]:
criterion = nn.MSELoss()  # Using MSE loss for reconstruction error
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-4)  # Adjust learning rate as needed
# Example of setting a learning rate scheduler
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.1)

num_epochs = 100000 # Adjust the number of epochs as needed
autoencoder.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for data in dataloader:
        # Transfer data to GPU
        inputs = data.to(device)
        
        # Forward pass
        outputs = autoencoder(inputs)
        loss = criterion(outputs, inputs)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # save loss for average
        running_loss += loss.item()



    # average loss
    run = running_loss / len(dataloader)
    if epoch % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {run:.4f}')




Epoch [1/100000], Loss: 40.2246
Epoch [101/100000], Loss: 39.3882
Epoch [201/100000], Loss: 39.7564
Epoch [301/100000], Loss: 39.6306
Epoch [401/100000], Loss: 39.6084
Epoch [501/100000], Loss: 39.7256
Epoch [601/100000], Loss: 40.2051
Epoch [701/100000], Loss: 39.6996
Epoch [801/100000], Loss: 39.6544
Epoch [901/100000], Loss: 39.9970
Epoch [1001/100000], Loss: 38.8898
Epoch [1101/100000], Loss: 39.6201
Epoch [1201/100000], Loss: 39.2668
Epoch [1301/100000], Loss: 39.1118
Epoch [1401/100000], Loss: 39.9108
Epoch [1501/100000], Loss: 39.4932
Epoch [1601/100000], Loss: 39.1466
Epoch [1701/100000], Loss: 39.4219
Epoch [1801/100000], Loss: 39.7034
Epoch [1901/100000], Loss: 39.5246
Epoch [2001/100000], Loss: 39.3839
Epoch [2101/100000], Loss: 39.0843
Epoch [2201/100000], Loss: 39.7096
Epoch [2301/100000], Loss: 39.4051
Epoch [2401/100000], Loss: 39.6544
Epoch [2501/100000], Loss: 39.3942
Epoch [2601/100000], Loss: 39.5425
Epoch [2701/100000], Loss: 40.1405
Epoch [2801/100000], Loss: 39.65

KeyboardInterrupt: 

In [50]:
# Save the model
torch.save(autoencoder.state_dict(), 'midi_autoencoder.pth')