In [2]:
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
import os
import numpy as np

In [1]:
def load_song_files(song_dir):
    song_data = []
    for file in os.listdir(song_dir):
        if file.endswith(".mp3"):
            song, _ = torchaudio.load(os.path.join(song_dir, file))
            song_data.append(song)
    return song_data

def preprocess_audio(song_data):
    processed_data = [torchaudio.transforms.Resample(orig_freq=44100, new_freq=5000)(song) for song in song_data]
    max_length = max(len(song[0]) for song in processed_data)
    processed_data = [torch.nn.functional.pad(song, (0, max_length - len(song[0])), mode='constant', value=0) for song in processed_data]
    return processed_data

In [10]:
def create_siamese_network(input_shape):
    model = nn.Sequential(
        nn.Conv1d(in_channels=input_shape[0], out_channels=32, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool1d(kernel_size=2),
        nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool1d(kernel_size=2),
        nn.Flatten(),
        nn.Linear(in_features=64 * (input_shape[1] // 4), out_features=32),
        nn.ReLU(),
        nn.Linear(in_features=32, out_features=16),
        nn.ReLU(),
        nn.Linear(in_features=16, out_features=8),
        nn.ReLU(),
        nn.Linear(in_features=8, out_features=4)
    )
    return model

In [11]:
def train_siamese_network(song_data, num_epochs=10, learning_rate=0.001, margin=0.2):
    processed_data = preprocess_audio(song_data)
    model = create_siamese_network(processed_data[0].shape)
    triplet_loss = nn.TripletMarginLoss(margin=margin)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        running_loss = 0.0
        print(num_epochs, running_loss)
        for i in range(len(processed_data)):
            anchor = processed_data[i]
            positive = anchor.clone()  
            negative_idx = np.random.randint(0, len(processed_data))
            negative = processed_data[negative_idx]

            anchor_output = model(anchor.unsqueeze(0))
            positive_output = model(positive.unsqueeze(0))
            negative_output = model(negative.unsqueeze(0))

            loss = triplet_loss(anchor_output, positive_output, negative_output)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(processed_data)}")

    return model

In [12]:
def extract_embeddings(song_data, model):
    embeddings = []
    with torch.no_grad():
        for song in song_data:
            processed_song = preprocess_audio([song])[0]
            embedding = model(processed_song.unsqueeze(0))
            embeddings.append(embedding)
    return embeddings

In [13]:
if __name__ == "__main__":
    # Directory containing MP3 song files
    song_dir = r'D:\vs_code\DL\proj\resources\fma_small_edited_truncated'

    # Load song files
    song_data = load_song_files(song_dir)

    # Train Siamese network
    model = train_siamese_network(song_data)

    # Extract embeddings
    learned_embeddings = extract_embeddings(song_data, model)
    for i, embedding in enumerate(learned_embeddings):
        print(f"Embedding for song {i + 1}:", embedding)

10 0.0


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x2611200 and 2611264x32)