In [19]:
import os
import re
import torch
import wandb
import numpy as np
import torch.nn as nn
from torch.optim import lr_scheduler
from music21 import converter, instrument, note, chord
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

In [20]:
def load_midi_files(midi_folder):
    notes = []
    for file in os.listdir(midi_folder):
        midi = converter.parse(os.path.join(midi_folder, file))
        notes_to_parse = None
        try:
            s2 = instrument.partitionByInstrument(midi)
            notes_to_parse = s2.parts[0].recurse() 
        except:
            notes_to_parse = midi.flat.notes

        for element in notes_to_parse:
            if isinstance(element, note.Note):
                if re.match(r"[A-G](#|-)?\d", str(element.pitch)):
                    notes.append(str(element.pitch))
            elif isinstance(element, chord.Chord):
                # Filter each note in the chord
                chord_notes = '.'.join(str(n) for n in element.normalOrder if re.match(r"[A-G](#|-)?\d", str(n)))
                if chord_notes:
                    notes.append(chord_notes)
    return notes

notes = load_midi_files('indian_classical')
# Extract the unique pitches in the dataset
pitchnames = sorted(set(item for item in notes))
# Create a dictionary to map pitches to integers
note_to_int = {note: num for num, note in enumerate(pitchnames)}

In [21]:
print(len(note_to_int))

54


In [22]:
# Prepare sequences
sequence_length = 100
network_input = []
network_output = []

for i in range(len(notes) - sequence_length):
    sequence_in = notes[i:i + sequence_length]
    sequence_out = notes[i + sequence_length]
    network_input.append([note_to_int[char] for char in sequence_in])
    network_output.append(note_to_int[sequence_out])

network_input = np.reshape(network_input, (len(network_input), sequence_length, 1))
network_input = torch.tensor(network_input / float(len(pitchnames)), dtype=torch.float32)
network_output = torch.tensor(network_output, dtype=torch.long)

# Assuming network_input and network_output are already defined
train_input, val_input, train_output, val_output = train_test_split(network_input, network_output, test_size=0.2, random_state=42)

train_dataset = TensorDataset(train_input, train_output)
val_dataset = TensorDataset(val_input, val_output)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)


In [25]:
class AdvancedMusicLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(AdvancedMusicLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, num_layers=3, dropout=0.3, bidirectional=True)
        self.fc1 = nn.Linear(hidden_size * 2, hidden_size)  # Adjust for bidirectional
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.fc1(x[:, -1, :])
        x = self.dropout(x)
        x = self.fc2(x)
        return x

model = AdvancedMusicLSTM(1, 512, len(pitchnames))

In [30]:
from torch.optim import lr_scheduler

def train_model_with_checkpoint(model, dataloader, validation_dataloader, epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)  # Adjust learning rate
    best_loss = float('inf')
    
    wandb.init(project="music-generation-2")

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for inputs, labels in train_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_dataloader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss}')
        wandb.log({"epoch": epoch + 1, "loss": avg_loss})

        # Validation phase
        if epoch % 10 == 0:
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for inputs, labels in validation_dataloader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
            avg_val_loss = val_loss / len(validation_dataloader)
            print(f'Validation Phase, Loss: {avg_val_loss}')
            wandb.log({"Validation loss" : avg_val_loss})

        # Scheduler step
        scheduler.step()

        # Checkpoint model
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print(f'Epoch {epoch+1}/{epochs}, Training Loss: {total_loss / len(dataloader)}, Validation Loss: {val_loss / len(validation_dataloader)} - Model Saved')

train_model_with_checkpoint(model, train_dataloader, validation_dataloader, 200)

cuda


0,1
Validation loss,▁
epoch,▁▂▃▄▅▅▆▇█
loss,█▃▃▃▃▂▃▂▁

0,1
Validation loss,3.25655
epoch,9.0
loss,3.28579


Epoch 1/200, Loss: 132.70593287291067
Validation Phase, Loss: 63.07322754398469
Epoch 1/200, Training Loss: 132.70593287291067, Validation Loss: 63.07322754398469 - Model Saved
Epoch 2/200, Loss: 39.63145472926478
Epoch 3/200, Loss: 26.61579434333309
Epoch 4/200, Loss: 29.19660988930733
Epoch 5/200, Loss: 23.115126479056574
Epoch 6/200, Loss: 30.095445656007335
Epoch 7/200, Loss: 20.83486185535308
Epoch 8/200, Loss: 7.793941011351924
Epoch 9/200, Loss: 3.525677923233278
Epoch 10/200, Loss: 3.48427782135625
Epoch 11/200, Loss: 3.8102242696669792
Validation Phase, Loss: 8.403663912127096
Epoch 11/200, Training Loss: 3.8102242696669792, Validation Loss: 8.403663912127096 - Model Saved
Epoch 12/200, Loss: 3.680686575751151
Epoch 13/200, Loss: 3.540700387570166
Epoch 14/200, Loss: 3.6210102092835212
Epoch 15/200, Loss: 3.3405400283875
Epoch 16/200, Loss: 3.3111663672231857
Epoch 17/200, Loss: 3.3094706054656737
Epoch 18/200, Loss: 3.306395257672956
Epoch 19/200, Loss: 3.3036360798343534
Epo

KeyboardInterrupt: 

In [35]:
import torch
import numpy as np

# Assuming 'AdvancedMusicLSTM' is the class of your model
model = AdvancedMusicLSTM(input_size=1, hidden_size=512, output_size=len(pitchnames))  # Adjust parameters as necessary

# Load the model state
model_path = 'best_model.pth'  # Replace with your actual model path
model.load_state_dict(torch.load(model_path))
model.eval()

def generate_music(model, network_input, pitchnames, note_to_int, num_generate=500):
    """ Generate music given a sequence of notes """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Pick a random sequence from the input as a starting point for the generation
    start = np.random.randint(0, len(network_input)-1)
    int_to_note = {num: note for note, num in note_to_int.items()}
    pattern = network_input[start].tolist()
    prediction_output = []

    # Generate notes
    for note_index in range(num_generate):
        prediction_input = torch.tensor([pattern], dtype=torch.float32).to(device)
        prediction = model(prediction_input)
        _, index = torch.max(prediction, 1)
        
        result = int_to_note[index.item()]
        prediction_output.append(result)
        
        pattern.append(index.item() / float(len(pitchnames)))
        pattern = pattern[1:len(pattern)]

    return prediction_output

# Generate a piece of music
generated_notes = generate_music(model, network_input, pitchnames, note_to_int)

RuntimeError: Error(s) in loading state_dict for AdvancedMusicLSTM:
	Missing key(s) in state_dict: "lstm.weight_ih_l0", "lstm.weight_hh_l0", "lstm.bias_ih_l0", "lstm.bias_hh_l0", "lstm.weight_ih_l0_reverse", "lstm.weight_hh_l0_reverse", "lstm.bias_ih_l0_reverse", "lstm.bias_hh_l0_reverse", "lstm.weight_ih_l1", "lstm.weight_hh_l1", "lstm.bias_ih_l1", "lstm.bias_hh_l1", "lstm.weight_ih_l1_reverse", "lstm.weight_hh_l1_reverse", "lstm.bias_ih_l1_reverse", "lstm.bias_hh_l1_reverse", "lstm.weight_ih_l2", "lstm.weight_hh_l2", "lstm.bias_ih_l2", "lstm.bias_hh_l2", "lstm.weight_ih_l2_reverse", "lstm.weight_hh_l2_reverse", "lstm.bias_ih_l2_reverse", "lstm.bias_hh_l2_reverse". 
	Unexpected key(s) in state_dict: "lstm1.weight_ih_l0", "lstm1.weight_hh_l0", "lstm1.bias_ih_l0", "lstm1.bias_hh_l0", "lstm1.weight_ih_l1", "lstm1.weight_hh_l1", "lstm1.bias_ih_l1", "lstm1.bias_hh_l1", "batch_norm1.weight", "batch_norm1.bias", "batch_norm1.running_mean", "batch_norm1.running_var", "batch_norm1.num_batches_tracked", "lstm2.weight_ih_l0", "lstm2.weight_hh_l0", "lstm2.bias_ih_l0", "lstm2.bias_hh_l0", "lstm2.weight_ih_l1", "lstm2.weight_hh_l1", "lstm2.bias_ih_l1", "lstm2.bias_hh_l1", "batch_norm2.weight", "batch_norm2.bias", "batch_norm2.running_mean", "batch_norm2.running_var", "batch_norm2.num_batches_tracked". 
	size mismatch for fc1.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([512, 1024]).
	size mismatch for fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for fc2.weight: copying a param with shape torch.Size([54, 256]) from checkpoint, the shape in current model is torch.Size([54, 512]).

In [None]:
from music21 import stream, note, chord, midi

def create_midi(prediction_output, output_path='output.mid'):
    """ Convert the output from the prediction to MIDI file """
    offset = 0
    output_notes = []

    # Create note and chord objects based on the values generated by the model
    for pattern in prediction_output:
        # Pattern is a chord
        if ('.' in pattern) or pattern.isdigit():
            notes_in_chord = pattern.split('.')
            notes = []
            for current_note in notes_in_chord:
                new_note = note.Note(int(current_note))
                new_note.storedInstrument = instrument.Piano()
                notes.append(new_note)
            new_chord = chord.Chord(notes)
            new_chord.offset = offset
            output_notes.append(new_chord)
        # Pattern is a note
        else:
            new_note = note.Note(pattern)
            new_note.offset = offset
            new_note.storedInstrument = instrument.Piano()
            output_notes.append(new_note)

        # Increase offset each iteration so that notes do not stack
        offset += 0.5

    midi_stream = stream.Stream(output_notes)
    midi_stream.write('midi', fp=output_path)

# Create a MIDI file from the generated notes
create_midi(generated_notes)