In [None]:
import numpy as np
from pathlib import Path

from music21 import chord, corpus, expressions, interval, note, pitch

# Encode every Bach piece returned by music21's corpus as one line of tokens per chorale.
#
# Each measure is sampled on a grid of `subdivision` quarter notes. At every grid slot each
# part emits exactly one token, and the tokens of all parts are interleaved slot by slot:
#   p{i}{pitch}  a note (music21 nameWithOctave, e.g. "p0C5") starts on this slot in part i
#   p{i}|        a rest starts on this slot in part i
#   p{i}         nothing starts on this slot in part i (the previous note or rest continues)
# A trailing "·" is appended once per fermata attached to the element.
# Before encoding, minor-mode pieces are transposed to tonic A and major-mode pieces to tonic C.

output_file = 'data/preprocess.txt'
subdivision = 0.5  # grid step in quarterLength (0.5 = an eighth note)


def _bar_quarter_length(part):
    """Return the length of one full bar of `part`, in quarterLength.

    The first TimeSignature found in the part is preferred over the first measure's own
    length, because the first measure may be a pickup that is shorter than a full bar.
    Without any time signature, fall back to the first measure's length, treating a first
    measure of exactly one quarter note as a four-quarter bar.
    """
    time_signature = next(iter(part.recurse().getElementsByClass('TimeSignature')), None)
    if time_signature is not None:
        return time_signature.barDuration.quarterLength
    first_measure_length = part.getElementsByClass('Measure')[0].quarterLength
    return 4 if first_measure_length == 1 else first_measure_length


def _element_token(element, part_index):
    """Return the token for `element`, which starts on a grid slot of part `part_index`.

    Chords are reduced to their highest pitch so the part still emits one pitch token.
    Any other element type yields the plain "nothing starts here" token.
    """
    prefix = f"p{part_index}"
    if isinstance(element, note.Note):
        token = f"{prefix}{element.nameWithOctave}"
    elif isinstance(element, note.Rest):
        token = f"{prefix}|"
    elif isinstance(element, chord.Chord) and element.pitches:
        # NOTE(review): keeping only the top pitch of a chord is a modelling choice.
        top_pitch = max(element.pitches, key=lambda p: p.ps)
        token = f"{prefix}{top_pitch.nameWithOctave}"
    else:
        return prefix
    for exp in element.expressions:
        if isinstance(exp, expressions.Fermata):
            token += '·'
    return token


def _measure_tokens(measure, part_index, num_slots, slot_length):
    """Return the `num_slots` grid tokens of part `part_index` for one measure.

    When `measure` is None (this part has no such measure) the bar is encoded as a full-bar
    rest: a rest token on the first slot followed by "nothing starts here" tokens.
    """
    hold = f"p{part_index}"
    if measure is None:
        return [f"{hold}|"] + [hold] * (num_slots - 1)
    # Index elements by their offset inside the measure; the first element at an offset wins.
    onsets = {}
    for element in measure.notesAndRests:
        onsets.setdefault(element.offset, element)
    tokens = []
    for slot in range(num_slots):
        element = onsets.get(slot * slot_length)
        tokens.append(hold if element is None else _element_token(element, part_index))
    return tokens


# Create the output directory so a fresh checkout does not fail on open().
Path(output_file).parent.mkdir(parents=True, exist_ok=True)

with open(output_file, 'w') as f:
    for chorale in corpus.getComposer('bach'):
        try:
            s = corpus.parse(chorale)
            k = s.analyze('key')
            target_key = pitch.Pitch('A') if k.mode == 'minor' else pitch.Pitch('C')
            s = s.transpose(interval.Interval(k.tonic, target_key))

            parts = s.parts
            if len(parts) > 4:
                continue  # only pieces with at most four parts are encoded

            measure_length = _bar_quarter_length(parts[0])
            num_slots = int(measure_length / subdivision)
            if num_slots < 1:
                raise ValueError(f"bar length {measure_length} is shorter than one grid slot")
            measure_count = len(parts[0].getElementsByClass('Measure'))

            tokens = []
            # NOTE(review): measures are looked up by number 1..measure_count. If a piece opens
            # with a pickup numbered 0 (music21's usual convention - confirm), the pickup is
            # skipped and the last number looked up matches no measure in any part.
            for measure_number in range(1, measure_count + 1):
                measures = [part.measure(measure_number) for part in parts]
                if all(measure is None for measure in measures):
                    continue
                columns = [
                    _measure_tokens(measure, part_index, num_slots, subdivision)
                    for part_index, measure in enumerate(measures)
                ]
                # Interleave slot-major: every part's token for slot 0, then slot 1, ...
                for slot_tokens in zip(*columns):
                    tokens.extend(slot_tokens)
            f.write(" ".join(tokens) + "\n")

        except Exception as e:
            print(f"Error processing {chorale}: {e}")
            continue  # Skip to the next chorale if an error occurs

        print(f"Processing {chorale} ")

print(f"Processing complete. Data saved to {output_file}")


In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
# Build the token vocabulary from the preprocessed corpus and encode every chorale.
#
# `data` holds all chorales' token ids concatenated into one 1-D int64 tensor, and `dict` maps
# a 0-based chorale index to that chorale's own int64 id tensor. Both are handed to
# `Batch_provider` in later cells.

PREPROCESSED_PATH = 'data/preprocess.txt'

with open(PREPROCESSED_PATH, 'r') as file:
    # One chorale per line; blank lines carry no tokens and are dropped.
    chorale_token_lists = [tokens for tokens in (line.split() for line in file) if tokens]

char_set = {token for tokens in chorale_token_lists for token in tokens}
# Sorted so ids are identical across runs: set iteration order of strings depends on the
# per-process hash seed, which would silently scramble the id mapping of a saved model.
string_to_int = {token: idx for idx, token in enumerate(sorted(char_set))}
int_to_string = {idx: token for token, idx in string_to_int.items()}


def encode(tokens):
    """Map a sequence of token strings to their integer ids."""
    return [string_to_int[token] for token in tokens]


def decode(ids):
    """Map a sequence of integer ids back to a space-separated token string."""
    return ' '.join(int_to_string[i] for i in ids)


# NOTE: `dict` shadows the built-in; the name is kept because later cells pass it on.
dict = {}
merged_ids = []
for i, tokens in enumerate(chorale_token_lists):
    ids = encode(tokens)
    merged_ids.extend(ids)  # each chorale is added to the merged stream exactly once
    dict[i] = torch.tensor(ids, dtype=torch.int64)

data = torch.tensor(merged_ids, dtype=torch.int64)
vocab_size = len(char_set)
print(len(data))
print(len(dict))


In [None]:
from backtobach.training import Batch_provider
from backtobach.model import GPTLanguageModel, device, learning_rate, block_size, batch_size


# Sanity check: decode the first example of one training batch, then its target sequence.
# NOTE(review): the target is presumably the input shifted one token ahead - confirm in
# backtobach.training.Batch_provider.
batch_provider = Batch_provider(
    chorales_dict=dict,
    merged_chorales=data,
    block_size=block_size,
    batch_size=batch_size,
    device=device,
)

inputs, targets = batch_provider.get_batch('train')
for sequence in (inputs[0], targets[0]):
    print(decode(sequence.tolist()))


In [None]:
from backtobach.model import GPTLanguageModel, device, learning_rate, block_size, batch_size
from backtobach.training import Batch_provider, train, eval_iters

# Train a freshly initialised model on the encoded chorales.
SEED = 1337  # seeds torch's RNG; TODO confirm Batch_provider samples through torch as well
torch.manual_seed(SEED)

model = GPTLanguageModel(vocab_size)
m = model.to(device)

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
train_batch_provider = Batch_provider(
    chorales_dict=dict,
    merged_chorales=data,
    block_size=block_size,
    batch_size=batch_size,
    device=device,
)
train(
    model=m,
    batch_provider=train_batch_provider,
    optimizer=optimizer,
    eval_iters=eval_iters,
)

In [None]:
import torch
# Generate a continuation from one opening slot: parts 0-3 sound C4, G3, E3, C3.
seed_tokens = ['p0C4', 'p1G3', 'p2E3', 'p3C3']

# Shape (1, 4): a batch containing one sequence of four token ids.
context = torch.tensor([encode(seed_tokens)], dtype=torch.long, device=device)
print(context)

# Disable dropout-style training behaviour and autograd bookkeeping while sampling.
# NOTE(review): assumes `generate` does not already do this itself - confirm in backtobach.model.
m.eval()
with torch.no_grad():
    generated_ids = m.generate(context, max_new_tokens=500)
generated_chars = decode(generated_ids[0].tolist())
print(generated_chars)

In [None]:
# Save the trained model together with the vocabulary it was trained on.
#
# The token <-> id mapping cannot be recovered from the weights, so it is written next to
# them; without it neither checkpoint can be decoded back into tokens.
# NOTE(review): the pickle stores the whole module object - loading it needs the same
# `backtobach.model` code, and unpickling can execute code, so only load trusted files.
# The state_dict checkpoint is the portable one.
import json
import pickle

with open('gpt-01.pkl', 'wb') as f:
    pickle.dump(m, f)
torch.save(m.state_dict(), 'bachGPT')

with open('bachGPT-vocab.json', 'w') as f:
    json.dump(string_to_int, f, indent=2, sort_keys=True)


In [None]:
import re

from music21 import expressions, note, stream

# Rebuild a four-part score from generated tokens and write it to MusicXML.
#
# Token grammar (the inverse of the preprocessing cell):
#   p{i}{pitch}  a note starts on this slot in part i, e.g. "p0C5", "p2B-3", "p1F##4"
#   p{i}|        a rest starts on this slot in part i
#   p{i}         nothing new starts in part i: the previous note/rest is held one slot longer
# A trailing "·" marks a fermata. Tokens for parts >= 4 and tokens matching neither form are
# skipped.
# NOTE(review): data/out.txt is not written by any cell in this notebook; presumably it holds a
# saved generation run - TODO write it from the generation cell so Restart & Run All works.

GENERATED_PATH = "data/out.txt"
OUTPUT_PATH = "my_chorale.mxl"
SLOT_LENGTH = 0.5  # quarterLength of one token slot; must equal the preprocessing `subdivision`
PART_NAMES = ["Soprano", "Alto", "Tenor", "Bass"]

onset_pattern = re.compile(r"p(\d+)([A-Ga-g](?:#+|-+)?\d+|\|)(·?)")
hold_pattern = re.compile(r"p(\d+)")

# events[i] lists [pitch name or "|", quarterLength, has_fermata] for part i, in time order.
events = [[] for _ in PART_NAMES]

with open(GENERATED_PATH, 'r') as f:
    for line in f:
        for token in line.split():
            onset = onset_pattern.match(token)
            if onset:
                part_index = int(onset.group(1))
                if part_index < len(PART_NAMES):
                    events[part_index].append(
                        [onset.group(2), SLOT_LENGTH, bool(onset.group(3))]
                    )
                continue
            hold = hold_pattern.match(token)
            if hold is None:
                continue  # unrecognised token (e.g. malformed model output)
            part_index = int(hold.group(1))
            if part_index >= len(PART_NAMES):
                continue
            if events[part_index]:
                events[part_index][-1][1] += SLOT_LENGTH  # sustain the sounding note/rest
            else:
                # Nothing to sustain at the very start of the part: emit a rest instead.
                events[part_index].append(["|", SLOT_LENGTH, False])

chorale = stream.Score()
parts = []
for part_name, part_events in zip(PART_NAMES, events):
    part = stream.Part()
    part.id = part_name
    for pitch_name, length, has_fermata in part_events:
        if pitch_name == "|":
            n = note.Rest(quarterLength=length)
        else:
            n = note.Note(pitch_name, quarterLength=length)
        if has_fermata:
            n.expressions.append(expressions.Fermata())
        part.append(n)
    parts.append(part)
    # Insert at offset 0 so the voices sound together; Stream.append would place each part
    # after the end of the previous one.
    chorale.insert(0, part)

chorale.write("musicxml", fp=OUTPUT_PATH)
