In [1]:
import pandas as pd

In [2]:
df = pd.read_csv('chorales/train/chorale_000.csv')

In [3]:
df

Unnamed: 0,note0,note1,note2,note3
0,74,70,65,58
1,74,70,65,58
2,74,70,65,58
3,74,70,65,58
4,75,70,58,55
...,...,...,...,...
187,70,65,62,46
188,70,65,62,46
189,70,65,62,46
190,70,65,62,46


In [4]:
import os

train_files = sorted(os.path.join('chorales','train', f) for f in os.listdir(os.path.join('chorales', 'train')) if f.endswith('.csv'))
test_files = sorted(os.path.join('chorales','test', f) for f in os.listdir(os.path.join('chorales', 'test')) if f.endswith('.csv'))
valid_files = sorted(os.path.join('chorales','valid', f) for f in os.listdir(os.path.join('chorales', 'valid')) if f.endswith('.csv'))

In [5]:
train_data = [pd.read_csv(f).values.tolist() for f in train_files]
test_data = [pd.read_csv(f).values.tolist() for f in test_files]
valid_data = [pd.read_csv(f).values.tolist() for f in valid_files]

36 = C1
81 = A5
0 -> silence

In [6]:
from music21 import stream, chord

chorale = train_data[20]

s= stream.Stream()
for row in chorale:
    s.append(chord.Chord([n for n in row if n], quarterLength=1))

s.show('midi')

### Preprocessing

In [7]:
import numpy as np

min_note, max_note = 36, 81
window_size, window_offset, batch_size = 32, 16, 32

def make_xy(chorales):
    windows = [c[i:i+window_size+1] for c in chorales for i in range(0, len(c) - window_size, window_offset)]

    data = np.array(windows, dtype=int)

    data = np.where(data==0, 0, data - min_note + 1)
    data = np.clip(data, 0, max_note - min_note + 1)
    
    flat = data.reshape(data.shape[0], -1)

    return flat[:, :-1], flat[:, 1:]

X_train, Y_train = make_xy(train_data)
X_test, Y_test = make_xy(test_data)
X_valid, Y_valid = make_xy(valid_data)

In [8]:
X_train.shape

(3111, 131)

In [9]:
Y_train.shape

(3111, 131)

### Training The Model

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [11]:
train_notes = set([z for x in train_data for y in x for z in y])
test_notes = set([z for x in test_data for y in x for z in y])
valid_notes = set([z for x in valid_data for y in x for z in y])

num_notes = len(set.union(train_notes, test_notes, valid_notes))
num_notes

47

In [12]:
class CasualConv1D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        self.padding = (kernel_size-1)*dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding=self.padding
        )

    def forward(self, x):
        x = self.conv(x)
        return x[:, :, :-self.padding]

In [13]:
class Model(nn.Module):
    def __init__(self, num_notes):
        super().__init__()

        # Embedding
        self.embedding = nn.Embedding(num_notes, 20)

        self.conv1 = CasualConv1D(4*5, 32, kernel_size=2, dilation=1)
        self.bn1 = nn.BatchNorm1d(32)

        self.conv2 = CasualConv1D(32, 48, kernel_size=2, dilation=2)
        self.bn2 = nn.BatchNorm1d(48)

        self.conv3 = CasualConv1D(48, 64, kernel_size=2, dilation=4)
        self.bn3 = nn.BatchNorm1d(64)

        self.conv4 = CasualConv1D(64, 96, kernel_size=2, dilation=8)
        self.bn4 = nn.BatchNorm1d(96)

        self.conv5 = CasualConv1D(96, 128, kernel_size=2, dilation=16)
        self.bn5 = nn.BatchNorm1d(128)

        self.dropout = nn.Dropout(0.05)

        # LSTM
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=256,
            batch_first=True
        )

        # Dense output
        self.fc = nn.Linear(256, num_notes)

    def forward(self, x):
        if x.dim() == 3:
            x = x.view(x.size(0), -1)
        x = self.embedding(x)
        x = x.permute(0, 2, 1)   

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))

        x = self.dropout(x)

        x = x.permute(0, 2, 1)
        x, _ = self.lstm(x)

        x = self.fc(x)
        return x

In [14]:
model = Model(num_notes=num_notes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.NAdam(model.parameters(), lr=1e-3)

In [15]:
from torch.utils.data import DataLoader, TensorDataset

train_ds = TensorDataset(
    torch.LongTensor(X_train),
    torch.LongTensor(Y_train)
)

val_ds = TensorDataset(
    torch.LongTensor(X_valid),
    torch.LongTensor(Y_valid)
)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)

In [16]:
epochs = 20

for epoch in range(epochs):
    # train
    model.train()
    total_loss, correct, total = 0, 0, 0

    for x, y in train_loader:
        optimizer.zero_grad()

        logits = model(x)
        logits = logits.reshape(-1, logits.size(-1))
        y = y.reshape(-1)

        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (logits.argmax(dim=1)==y).sum().item()
        total += y.numel()

    train_acc = correct / total

    # validation
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0

    with torch.no_grad():
        for x, y in val_loader:
            logits = model(x)
            logits = logits.reshape(-1, logits.size(-1))
            y = y.reshape(-1)

            loss = criterion(logits, y)
            val_loss += loss.item()
            val_correct += (logits.argmax(dim=1)==y).sum().item()
            val_total += y.numel()

    val_acc = val_correct / val_total

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"loss={total_loss:.4f} - acc={train_acc:.4f} | "
        f"val_loss={val_loss:.4f} - val_acc={val_acc:.4f}"
    )

Epoch 1/20 | loss=194.5846 - acc=0.4844 | val_loss=40.7552 - val_acc=0.6784
Epoch 2/20 | loss=91.4687 - acc=0.7547 | val_loss=30.3362 - val_acc=0.7553
Epoch 3/20 | loss=73.0343 - acc=0.7933 | val_loss=25.6505 - val_acc=0.7875
Epoch 4/20 | loss=65.0582 - acc=0.8100 | val_loss=23.5886 - val_acc=0.7998
Epoch 5/20 | loss=60.0539 - acc=0.8208 | val_loss=22.5357 - val_acc=0.8060
Epoch 6/20 | loss=56.6595 - acc=0.8286 | val_loss=22.7238 - val_acc=0.8032
Epoch 7/20 | loss=53.7213 - acc=0.8362 | val_loss=21.7971 - val_acc=0.8083
Epoch 8/20 | loss=51.1175 - acc=0.8426 | val_loss=21.0646 - val_acc=0.8132
Epoch 9/20 | loss=48.9457 - acc=0.8482 | val_loss=20.5404 - val_acc=0.8179
Epoch 10/20 | loss=47.0271 - acc=0.8538 | val_loss=21.3925 - val_acc=0.8127
Epoch 11/20 | loss=45.2653 - acc=0.8582 | val_loss=20.3259 - val_acc=0.8210
Epoch 12/20 | loss=43.6194 - acc=0.8626 | val_loss=20.5985 - val_acc=0.8172
Epoch 13/20 | loss=41.7966 - acc=0.8678 | val_loss=19.7476 - val_acc=0.8258
Epoch 14/20 | loss=4

In [20]:
import numpy as np

def sample_next_note(probs):
    probabilities = np.asarray(probs, dtype=float)  # probabilities for each note to be the next
    
    probs_sum = probabilities.sum()  # get the sum for normalization

    # if the probability sum is zero, negative or infinite -> just return the note with the highest probability
    if probs_sum <= 0 or not np.isfinite(probs_sum):
        return int(np.argmax(probabilities))

    probabilities /= probs_sum # otherwise normalize the probabilities to be between 0 and 1
    return np.random.choice(len(probabilities), p=probabilities)

def generate_chorale(model, seed_chords, length, device="cpu"):
    model.eval()

    # 1. Pre-process and Flatten the Seed
    token_sequence = np.array(seed_chords, dtype=int)
    token_sequence = np.where(
        token_sequence == 0,
        token_sequence,
        token_sequence - 36 + 1
    )

    # Convert to Tensor and flatten 8 chords of 4 notes into a sequence of 32 notes
    # Shape changes from (8, 4) -> (1, 32)
    token_sequence = torch.LongTensor(token_sequence).to(device)
    token_sequence = token_sequence.view(1, -1) 

    for _ in range(length * 4):
        with torch.no_grad():
            # logits shape: (1, current_length, num_notes)
            logits = model(token_sequence)          
            
            # We only care about the very last note predicted
            last_logits = logits[:, -1, :]          
            probs = F.softmax(last_logits, dim=-1)
            probs = probs.squeeze(0).cpu().numpy()

        # 2. Sample and Cat
        next_token_idx = sample_next_note(probs)
        # Create a (1, 1) tensor to match the (1, Seq_Len) shape of token_sequence
        next_token_tensor = torch.LongTensor([[next_token_idx]]).to(device)

        # Both are now 2D, so dim=1 works!
        token_sequence = torch.cat([token_sequence, next_token_tensor], dim=1)

    # 3. Post-process back to original note range
    token_sequence = token_sequence.cpu().numpy()
    token_sequence = np.where(
        token_sequence == 0,
        token_sequence,
        token_sequence + 36 - 1
    )

    # Reshape the long sequence of notes back into chords (4 notes each)
    return token_sequence.reshape(-1, 4)

In [21]:
seed_chords = test_data[2]

chorale = seed_chords
s = stream.Stream()
for row in chorale:
    s.append(chord.Chord([n for n in row if n], quarterLength=1))

s.show('midi')

In [22]:
seed_chords = test_data[2][:8]
new_chorale = generate_chorale(model, seed_chords, 56)
new_chorale

array([[73, 68, 61, 53],
       [73, 68, 61, 53],
       [73, 68, 61, 53],
       [73, 68, 61, 53],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [68, 64, 59, 56],
       [68, 64, 59, 56],
       [68, 64, 59, 57],
       [68, 64, 59, 57],
       [66, 66, 59, 59],
       [66, 66, 59, 59],
       [66, 66, 59, 59],
       [66, 66, 59, 59],
       [66, 66, 59, 59],
       [66, 66, 59, 59],
       [66, 64, 59, 47],
       [66, 64, 59, 47],
       [66, 64, 59, 47],
       [66, 64, 59, 47],
       [66, 63, 59, 47],
       [66, 63, 59, 47],
       [66, 63, 59, 47],
       [66, 63, 59, 47],
       [66, 63, 59, 47],
       [66, 63, 59, 47],
       [64, 61, 55, 52],
       [64, 59, 55, 52],
       [64, 59, 55, 52],
       [64, 59, 55, 52],
       [66, 59, 57, 52],
       [66, 59, 57, 52],
       [66, 59, 59, 50],
       [66, 59, 59, 50],
       [67, 59, 59, 49],
       [67, 59, 59, 52],
       [67, 60, 57, 52],
       [67, 60, 57, 52],


In [23]:
chorale = new_chorale.tolist()
s = stream.Stream()
for row in chorale:
    s.append(chord.Chord([n for n in row if n], quarterLength=1))
s.show('midi')

In [24]:
def generate_random_chorale(length, rest_probability=0.2, pitch_low=36, pitch_high=81, seed=None):
    rng = np.random.default_rng(seed)  # random number generator
    random_pitches = rng.integers(pitch_low, pitch_high + 1, size=(length, 4))  # generate random notes

    # some masking to have both silence and random pitches
    rest_mask = rng.random((length, 4)) < float(rest_probability)
    chorale = np.where(rest_mask, 0, random_pitches).astype(int)
    
    return chorale

In [25]:
# listen to completely random music to compare the quality to what our model generated
chorale = generate_random_chorale(56).tolist()
s = stream.Stream()
for row in chorale:
    s.append(chord.Chord([n for n in row if n], quarterLength=1))
s.show('midi')

In [51]:
torch.save(model.state_dict(), "weights.pth")

In [53]:
loaded_model = Model(num_notes=num_notes)
loaded_model.load_state_dict(torch.load("weights.pth", weights_only=False))

<All keys matched successfully>

In [56]:
new_chorale = generate_chorale(loaded_model, seed_chords, 56)
new_chorale

array([[73, 68, 61, 53],
       [73, 68, 61, 53],
       [73, 68, 61, 53],
       [73, 68, 61, 53],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [71, 67, 62, 59],
       [71, 67, 62, 59],
       [71, 67, 62, 59],
       [71, 67, 62, 59],
       [73, 66, 61, 54],
       [73, 66, 61, 54],
       [73, 66, 61, 54],
       [73, 66, 61, 54],
       [74, 70, 59, 47],
       [74, 70, 59, 54],
       [76, 70, 59, 54],
       [76, 70, 59, 54],
       [76, 70, 64, 54],
       [76, 68, 64, 52],
       [76, 69, 66, 52],
       [76, 69, 57, 52],
       [76, 68, 57, 52],
       [76, 68, 59, 52],
       [74, 66, 59, 59],
       [74, 66, 59, 59],
       [74, 66, 59, 59],
       [74, 66, 59, 59],
       [73, 64, 61, 57],
       [73, 64, 61, 57],
       [73, 64, 61, 55],
       [73, 64, 61, 55],
       [74, 66, 57, 54],
       [74, 66, 57, 54],
       [74, 67, 62, 52],
       [74, 67, 62, 52],
       [69, 69, 62, 54],
       [69, 69, 62, 54],


In [57]:
chorale = new_chorale.tolist()
s = stream.Stream()
for row in chorale:
    s.append(chord.Chord([n for n in row if n], quarterLength=1))
s.show('midi')