# **V2** Thai Classical Music Generation ‚Äì Baseline Sequence Modeling




## 0. Prerequisites & Setup

- Reuse normalized symbolic dataset from Stage 2  
- Pitch-only representation (octave stripped)  
- Rest token included as explicit `<REST>` symbol  
- Sequence = flattened token stream per song  

**Goal:**  
Train a baseline conditional LSTM language model on symbolic pitch sequences.

### 0.0 Libs

In [5]:
!pip install mido python-rtmidi
!pip install tqdm


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.10 -m pip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.10 -m pip install --upgrade pip[0m


In [3]:
!wget https://github.com/Phonbopit/sarabun-webfont/raw/master/fonts/thsarabunnew-webfont.ttf

--2026-02-21 15:54:54--  https://github.com/Phonbopit/sarabun-webfont/raw/master/fonts/thsarabunnew-webfont.ttf
Resolving github.com (github.com)... 140.82.114.3
Connecting to github.com (github.com)|140.82.114.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/Phonbopit/sarabun-webfont/master/fonts/thsarabunnew-webfont.ttf [following]
--2026-02-21 15:54:55--  https://raw.githubusercontent.com/Phonbopit/sarabun-webfont/master/fonts/thsarabunnew-webfont.ttf
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 98308 (96K) [application/octet-stream]
Saving to: ‚Äòthsarabunnew-webfont.ttf.2‚Äô


2026-02-21 15:54:57 (388 KB/s) - ‚Äòthsarabunnew-webfont.ttf.2‚Äô saved [98308/98308]



#### Colab-only: Clone repo & cd into it

In [None]:
# Uncomment on first Colab run:
# !git clone https://github.com/GetomG/Thai-Music-Thesis.git
# %cd Thai-Music-Thesis

# If already cloned:
# !cd Thai-Music-Thesis && git pull

Cloning into 'Thai-Music-Thesis'...


OSError: [Errno 5] Input/output error

In [7]:
# ============================================================
# Import Helper Utilities from thai_music_utils
# ============================================================

# 1. Notation Processing
from thai_music_utils.notation_utils import (
    flatten_song_notation,
    normalize_octave_markers,
    notation_to_sequence
)

# 2. Octave Inference (DP-based register guessing)
from thai_music_utils.octave_inference import (
    is_thai_note,
    get_fixed_octave,
    guess_octaves_with_constraints,
    add_octaves_respecting_labels
)

# 3. Preprocessing Utilities
from thai_music_utils.preprocessing import (
    flatten_song_data,
    remove_all_signs
)

# 4. EDA Helpers (Symbolic Analysis)
from thai_music_utils.eda_symbolic_normalization import (
    normalize_token,
    normalize_bar,
    flatten_song,
    THAI_NOTES,
    UP_MARK,
    LOW_MARK,
    REST_TOKEN
)

# 5. EDA Stats
from thai_music_utils.eda_stats import (
    extract_symbols,
    pitch_stats,
    stats_to_df
)

# 6. I/O Utilities
from thai_music_utils.io_utils import (
    save_json_bar_per_line
)

# 7. MIDI Rendering (Ranad-specific)
from thai_music_utils.midi_ranad import (
    generate_ranad_midi
)


In [8]:
#setting Thai fonts

import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.font_manager.fontManager.addfont('thsarabunnew-webfont.ttf')
mpl.rc('font', family='TH Sarabun New')
mpl.rcParams["axes.unicode_minus"] = False

In [9]:
import copy

### 0.1 Data Intake

In [10]:
import sys, os
from pathlib import Path

IS_COLAB = "google.colab" in sys.modules

if IS_COLAB:
    from google.colab import drive
    drive.mount("/content/drive")
    DATA_ROOT = Path("/content/drive/MyDrive/thai_music_data")
else:
    # Local: data lives next to the notebook
    DATA_ROOT = Path(os.path.abspath("")).resolve() / "thai_music_data"

print(f"Runtime: {'Colab' if IS_COLAB else 'Local'}")
print(f"DATA_ROOT: {DATA_ROOT}")

Runtime: Local
DATA_ROOT: /Users/thanakrit/Documents/Thai Music Thesis/thai_music_data


In [11]:
import json
from pathlib import Path
from collections import defaultdict

BASE = DATA_ROOT / "songs"

songs = []

for motif_dir in BASE.iterdir():
    if not motif_dir.is_dir():
        continue

    motif = motif_dir.name

    for song_dir in motif_dir.iterdir():
        json_dir = song_dir / "json"
        if not json_dir.exists():
            continue

        for json_file in json_dir.glob("*.json"):
            try:
                with open(json_file, "r", encoding="utf-8") as f:
                    data = json.load(f)

                songs.append({
                    "motif": motif,
                    "song": song_dir.name,
                    "path": str(json_file),
                    "data": data
                })

            except Exception as e:
                print(f"‚ö†Ô∏è Skipped {json_file}: {e}")

In [12]:
print(f"Total songs loaded: {len(songs)}")

by_motif = defaultdict(int)
for s in songs:
    by_motif[s["motif"]] += 1

for motif, n in by_motif.items():
    print(f"{motif}: {n} songs")

Total songs loaded: 44
‡πÄ‡∏Ç‡∏°‡∏£: 7 songs
‡∏à‡∏µ‡∏ô: 4 songs
‡∏û‡∏°‡πà‡∏≤: 5 songs
‡∏•‡∏≤‡∏ß: 18 songs
‡πÑ‡∏ó‡∏¢‡πÄ‡∏î‡∏¥‡∏°: 2 songs
‡πÅ‡∏Ç‡∏Å: 8 songs


In [None]:
# from collections import defaultdict

# note_count_by_motif = defaultdict(int)
# song_count_by_motif = defaultdict(int)

# for s in songs:
#     motif = s["motif"]
#     note_count_by_motif[motif] += len(s["pitch_sequence"])
#     song_count_by_motif[motif] += 1

# print("=== Note Count per Motif ===\n")

# for motif in sorted(note_count_by_motif.keys()):
#     total_notes = note_count_by_motif[motif]
#     total_songs = song_count_by_motif[motif]
#     avg_len = total_notes / total_songs

#     print(f"{motif}:")
#     print(f"  Total notes (including rests): {total_notes}")
#     print(f"  Songs: {total_songs}")
#     print(f"  Avg notes per song: {avg_len:.2f}")
#     print()

=== Note Count per Motif ===

‡∏à‡∏µ‡∏ô:
  Total notes (including rests): 1740
  Songs: 4
  Avg notes per song: 435.00

‡∏û‡∏°‡πà‡∏≤:
  Total notes (including rests): 971
  Songs: 5
  Avg notes per song: 194.20

‡∏•‡∏≤‡∏ß:
  Total notes (including rests): 5706
  Songs: 17
  Avg notes per song: 335.65

‡πÄ‡∏Ç‡∏°‡∏£:
  Total notes (including rests): 5936
  Songs: 7
  Avg notes per song: 848.00

‡πÅ‡∏Ç‡∏Å:
  Total notes (including rests): 5725
  Songs: 8
  Avg notes per song: 715.62

‡πÑ‡∏ó‡∏¢‡πÄ‡∏î‡∏¥‡∏°:
  Total notes (including rests): 836
  Songs: 2
  Avg notes per song: 418.00



#### **Filter only Khmer**

In [13]:
songs = [s for s in songs if s["motif"] == "‡πÄ‡∏Ç‡∏°‡∏£"]

print("Total Khmer songs:", len(songs))

Total Khmer songs: 7


### *0.2* Symbolic Normalization & Sequence Construction

#### 1Ô∏è‚É£ Normalize + Flatten

Reuse the earlier symbolic normalization logic, but simplify it for sequence modeling.

**Goal:**
- Clean and standardize symbolic tokens
- Decide whether to keep or remove register (pitch-only vs pitch+register)
- Convert structured JSON (section ‚Üí bar ‚Üí token) into one linear sequence per song
- Prepare clean token streams ready for vocabulary building and LSTM training

Output:
Each song ‚Üí one flat symbolic sequence (list of tokens)

#### `normalize_token` & `song_to_pitch_sequence`

These are the core functions for this notebook's tokenization:

- `normalize_token` ‚Äî Strips octave markers, splits dashes into compressed `<REST_k>` tokens  
- `song_to_pitch_sequence` ‚Äî Walks the full song JSON (sections ‚Üí bars ‚Üí tokens) and returns one flat token list

In [14]:
THAI_NOTES = set("‡∏î‡∏£‡∏°‡∏ü‡∏ã‡∏•‡∏ó")
UP_MARK = "‡πç"
LOW_MARK = "‡∏∫"

def normalize_token(token):
    """
    Convert token into pitch tokens + compressed rest tokens.

    Rest compression rule:
    - Any number of consecutive dashes is decomposed into
      chunks of <REST_4>, <REST_3>, <REST_2>, <REST_1>
    """

    if not isinstance(token, str):
        return ["<REST_1>"]

    token = token.strip()

    # ---- Pure rest token ----
    if set(token) == {"-"}:
        dash_count = len(token)
        rests = []

        while dash_count > 0:
            if dash_count >= 4:
                rests.append("<REST_4>")
                dash_count -= 4
            elif dash_count >= 3:
                rests.append("<REST_3>")
                dash_count -= 3
            elif dash_count >= 2:
                rests.append("<REST_2>")
                dash_count -= 2
            else:
                rests.append("<REST_1>")
                dash_count -= 1

        return rests

    # ---- Mixed pitch token ----
    out = []
    i = 0

    while i < len(token):
        ch = token[i]

        if ch == "-":
            # count consecutive dashes
            dash_count = 0
            while i < len(token) and token[i] == "-":
                dash_count += 1
                i += 1

            while dash_count > 0:
                if dash_count >= 4:
                    out.append("<REST_4>")
                    dash_count -= 4
                elif dash_count >= 3:
                    out.append("<REST_3>")
                    dash_count -= 3
                elif dash_count >= 2:
                    out.append("<REST_2>")
                    dash_count -= 2
                else:
                    out.append("<REST_1>")
                    dash_count -= 1

        elif ch in THAI_NOTES:
            out.append(ch)
            i += 1

            # skip octave mark
            if i < len(token) and token[i] in {UP_MARK, LOW_MARK}:
                i += 1

        else:
            i += 1

    return out if out else ["<REST_1>"]

def song_to_pitch_sequence(song_json):
    """
    Convert full song JSON into one flat pitch sequence.
    Handles:
    - normal string tokens
    - dict blocks like {"‡∏ô‡∏≥": [...]}, {"‡∏ï‡∏≤‡∏°": [...]}
    - nested lists
    """

    sequence = []

    def process_token(tok):
        # Case 1: string token
        if isinstance(tok, str):
            sequence.extend(normalize_token(tok))

        # Case 2: dict (‡∏ô‡∏≥ / ‡∏ï‡∏≤‡∏° block)
        elif isinstance(tok, dict):
            for key in tok:
                for inner_tok in tok[key]:
                    process_token(inner_tok)

        # Case 3: list (nested structure)
        elif isinstance(tok, list):
            for inner_tok in tok:
                process_token(inner_tok)

    for section in song_json.get("sections", []):
        for bar in section.get("bars", []):
            for tok in bar:
                process_token(tok)

    return sequence

#### Apply to songs

In [15]:
#Apply to songs
for s in songs:
    s["pitch_sequence"] = song_to_pitch_sequence(s["data"])

print(len(songs[2]["pitch_sequence"]))
print(songs[2]["pitch_sequence"][:80])

1007
['<REST_4>', '<REST_3>', '‡∏ü', '<REST_1>', '‡∏ü', '‡∏ü', '‡∏ü', '<REST_1>', '‡∏ü', '<REST_1>', '‡∏ü', '<REST_1>', '‡∏•', '‡∏î', '‡∏£', '<REST_1>', '‡∏ü', '<REST_1>', '‡∏ã', '‡∏•', '‡∏ã', '‡∏î', '‡∏•', '<REST_1>', '‡∏ã', '<REST_1>', '‡∏ü', '<REST_2>', '‡∏ü', '‡∏ã', '‡∏•', '<REST_1>', '‡∏î', '<REST_1>', '‡∏£', '<REST_1>', '‡∏ã', '‡∏ü', '‡∏£', '<REST_1>', '‡∏î', '<REST_1>', '‡∏•', '<REST_3>', '‡∏ã', '<REST_3>', '‡∏•', '<REST_2>', '‡∏£', '‡∏î', '‡∏•', '‡∏î', '<REST_1>', '‡∏£', '<REST_4>', '<REST_4>', '<REST_3>', '‡∏ü', '<REST_3>', '‡∏ã', '<REST_2>', '‡∏ü', '‡∏•', '<REST_1>', '‡∏ã', '‡∏ü', '‡∏£', '<REST_4>', '<REST_4>', '<REST_3>', '‡∏ü', '<REST_3>', '‡∏£', '‡∏î', '‡∏£', '‡∏ü', '‡∏î', '<REST_4>', '‡∏ã', '‡∏•']


In [16]:
from collections import defaultdict

note_count_by_motif = defaultdict(int)
song_count_by_motif = defaultdict(int)

for s in songs:
    motif = s["motif"]
    note_count_by_motif[motif] += len(s["pitch_sequence"])
    song_count_by_motif[motif] += 1

print("=== Note Count per Motif ===\n")

for motif in sorted(note_count_by_motif.keys()):
    total_notes = note_count_by_motif[motif]
    total_songs = song_count_by_motif[motif]
    avg_len = total_notes / total_songs

    print(f"{motif}:")
    print(f"  Total notes (including rests): {total_notes}")
    print(f"  Songs: {total_songs}")
    print(f"  Avg notes per song: {avg_len:.2f}")
    print()

=== Note Count per Motif ===

‡πÄ‡∏Ç‡∏°‡∏£:
  Total notes (including rests): 5936
  Songs: 7
  Avg notes per song: 848.00



In [15]:
# from thai_music_utils.eda_symbolic_normalization import flatten_song

# pattern = "‡∏•‡∏ã‡∏ü"
# matches = []

# for s in songs:
#     if s["motif"] != "‡∏•‡∏≤‡∏ß":
#         continue

#     # flatten to normalized token list
#     sequence = flatten_song(s["data"])

#     # merge into one continuous string
#     seq_string = "".join(sequence)

#     if pattern in seq_string:
#         matches.append(s["song"])

# print("Songs under ‡∏•‡∏≤‡∏ß containing '‡∏•‡∏ã‡∏ü':")
# for name in matches:
#     print("-", name)

# print("\nTotal:", len(matches))

#### 2Ô∏è‚É£ Build Vocabulary (Token ‚Üí Integer Mapping)

Neural networks cannot process symbolic tokens directly.  
We must convert each pitch token into a numeric ID.

Steps:
- Collect all unique tokens from all songs
- Assign each token a unique integer
- Build two mappings:
  - `token_to_id`
  - `id_to_token`

This defines:
- The model vocabulary
- The input/output dimension for the LSTM

In [17]:
from collections import Counter

# Collect all tokens across songs
all_tokens = []

for s in songs:
    all_tokens.extend(s["pitch_sequence"])

# Unique vocabulary
vocab = sorted(set(all_tokens))

# Create mappings
token_to_id = {tok: i for i, tok in enumerate(vocab)}
id_to_token = {i: tok for tok, i in token_to_id.items()}

vocab_size = len(vocab)

print("Vocabulary:", vocab)
print("Vocab size:", vocab_size)

Vocabulary: ['<REST_1>', '<REST_2>', '<REST_3>', '<REST_4>', '‡∏ã', '‡∏î', '‡∏ó', '‡∏ü', '‡∏°', '‡∏£', '‡∏•']
Vocab size: 11


In [17]:
# for s in songs:
#     for tok in s["pitch_sequence"]:
#         if tok in ["<", "R", "E", "S", "T", ">"]:
#             print("BROKEN:", s["song"])
#             break

#### 3Ô∏è‚É£ Convert Songs to Integer Sequences

Neural networks cannot process symbolic tokens directly.  
We must convert each pitch token into a numeric ID.

Steps:
- Take each song‚Äôs `pitch_sequence`
- Replace each token with its corresponding integer ID using `token_to_id`
- Store the new numeric sequence (e.g., `id_sequence`) inside each song

This produces:
- One integer sequence per song  
- Shape per song: `(sequence_length,)`
- Vocabulary size = `vocab_size`

These integer sequences will be used to:
- Create training samples (input ‚Üí next token prediction)
- Feed into the LSTM model
- Compute loss over predicted next-token probabilities

In [18]:
# 3Ô∏è‚É£ Convert Songs to Integer Sequences

for s in songs:
    s["id_sequence"] = [
        token_to_id[token]
        for token in s["pitch_sequence"]
        if token in token_to_id
    ]

# sanity check
print("Example song length:", len(songs[2]["id_sequence"]))
print("First 30 token IDs:", songs[2]["id_sequence"][:30])

Example song length: 1007
First 30 token IDs: [3, 2, 7, 0, 7, 7, 7, 0, 7, 0, 7, 0, 10, 5, 9, 0, 7, 0, 4, 10, 4, 5, 10, 0, 4, 0, 7, 1, 7, 4]


#### 4Ô∏è‚É£ Prepare Training Sequences (LSTM Input Construction)

The LSTM does not see full songs at once.

Instead, we train it using sliding windows:

Given a sequence:
‡∏î ‡∏£ ‡∏° ‡∏ã ‡∏î ‡∏£ ‡∏î ...

We create training samples like:

```
Input (length = seq_len) ‚Üí Target
[‡∏î ‡∏£ ‡∏° ‡∏ã] ‚Üí ‡∏î
[‡∏£ ‡∏° ‡∏ã ‡∏î] ‚Üí ‡∏£
[‡∏° ‡∏ã ‡∏î ‡∏£] ‚Üí ‡∏î
```

This teaches the model:
"Given the previous N notes, predict the next note."

Steps:
- Choose a sequence length (e.g., 16)
- Slide window across every song
- Convert token IDs into (X, y) pairs
- X shape: (num_samples, seq_len)
- y shape: (num_samples,)

In [19]:
for s in songs:
    s["id_sequence"] = [
        token_to_id[token]
        for token in s["pitch_sequence"]
    ]

In [20]:
import numpy as np

SEQ_LEN = 16  # number of previous notes to condition on

X = []
y = []

for s in songs:
    ids = s["id_sequence"]

    if len(ids) <= SEQ_LEN:
        continue

    for i in range(len(ids) - SEQ_LEN):
        X.append(ids[i:i+SEQ_LEN])
        y.append(ids[i+SEQ_LEN])

X = np.array(X)
y = np.array(y)

print("X shape:", X.shape)
print("y shape:", y.shape)

X shape: (5824, 16)
y shape: (5824,)


## 1Ô∏è‚É£ LSTM Model Definition

We now define a baseline LSTM language model.

Goal:
- Input: sequence of 16 token IDs
- Output: probability distribution over next token

Architecture:
- Embedding layer (token ‚Üí dense vector)
- LSTM layer
- Linear output layer
- Softmax handled by CrossEntropyLoss

This is a standard neural language model setup.

In [21]:
# ============================================================
# 0. PREREQUISITES
# ============================================================

import numpy as np
import random
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [23]:
import torch
import torch.nn as nn

# ============================================================
# LSTM Language Model (Improved Version)
# - 2-layer stacked LSTM
# - Dropout between layers
# - Predict next token from last hidden state
# ============================================================

class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, hidden_dim=128, num_layers=2, dropout=0.25):
        super().__init__()

        # 1Ô∏è‚É£ Embedding Layer
        # Converts token IDs ‚Üí dense vectors
        # Shape: (batch, seq_len) ‚Üí (batch, seq_len, embed_dim)
        # --------------------------------------------------------
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim
        )

        # 2Ô∏è‚É£ Stacked LSTM
        # num_layers=2 ‚Üí hierarchical pattern modeling
        # dropout applied BETWEEN LSTM layers
        # --------------------------------------------------------
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True
        )

        # 3Ô∏è‚É£ Final Linear Layer
        # Maps final hidden state ‚Üí vocabulary logits
        # Shape: (batch, hidden_dim) ‚Üí (batch, vocab_size)
        # --------------------------------------------------------
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """
        x shape: (batch_size, seq_len)
        """

        # Step 1: Convert tokens to embeddings
        x = self.embedding(x)

        # Step 2: Pass through stacked LSTM
        # output shape: (batch, seq_len, hidden_dim)
        output, _ = self.lstm(x)

        # Step 3: Take last timestep only
        last_hidden = output[:, -1, :]

        # Step 4: Project to vocab size
        logits = self.fc(last_hidden)

        return logits


# Instantiate model
# ------------------------------------------------------------
model = LSTMLanguageModel(vocab_size).to(device)

print(model)

LSTMLanguageModel(
  (embedding): Embedding(11, 64)
  (lstm): LSTM(64, 128, num_layers=2, batch_first=True, dropout=0.25)
  (fc): Linear(in_features=128, out_features=11, bias=True)
)


### 1Ô∏è‚É£ Training Setup

Now we train the LSTM as a next-token predictor.

Task:
Given 16 previous tokens ‚Üí predict the next token.

We define:

‚Ä¢ Loss: CrossEntropyLoss  
  - Standard for multi-class classification  
  - Compares predicted distribution vs true next-token  

‚Ä¢ Optimizer: Adam  
  - Stable for sequence models  
  - Good default for LSTM  

‚Ä¢ Training loop:
  1. Forward pass
  2. Compute loss
  3. Backpropagation
  4. Update weights
  5. Repeat for multiple epochs

Goal:
Minimize next-token prediction loss.

In [None]:
from tqdm import tqdm

# Convert to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.long).to(device)
y_tensor = torch.tensor(y, dtype=torch.long).to(device)

# Training config
epochs = 30
batch_size = 64
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

num_samples = X_tensor.shape[0]


In [None]:
import math  # put this once at top of notebook
for epoch in range(epochs):

    model.train()
    total_loss = 0.0

    progress_bar = tqdm(
        range(0, num_samples, batch_size),
        desc=f"Epoch {epoch+1}/{epochs}"
    )

    for i in progress_bar:

        X_batch = X_tensor[i:i+batch_size]
        y_batch = y_tensor[i:i+batch_size]

        optimizer.zero_grad()

        logits = model(X_batch)
        loss = criterion(logits, y_batch)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    avg_loss = total_loss / (num_samples // batch_size)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

    perplexity = math.exp(avg_loss)
    print(f"Epoch {epoch+1} Perplexity: {perplexity:.4f}")

Epoch 1/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 16.89it/s, loss=0.635]


Epoch 1 Average Loss: 0.5374
Epoch 1 Perplexity: 1.7115


Epoch 2/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:07<00:00, 12.70it/s, loss=0.574]


Epoch 2 Average Loss: 0.4845
Epoch 2 Perplexity: 1.6233


Epoch 3/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 17.12it/s, loss=0.578]


Epoch 3 Average Loss: 0.4638
Epoch 3 Perplexity: 1.5901


Epoch 4/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 15.59it/s, loss=0.521]


Epoch 4 Average Loss: 0.4309
Epoch 4 Perplexity: 1.5387


Epoch 5/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:06<00:00, 14.28it/s, loss=0.482]


Epoch 5 Average Loss: 0.3998
Epoch 5 Perplexity: 1.4915


Epoch 6/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 17.14it/s, loss=0.488]


Epoch 6 Average Loss: 0.3833
Epoch 6 Perplexity: 1.4672


Epoch 7/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:07<00:00, 12.91it/s, loss=0.388]


Epoch 7 Average Loss: 0.3488
Epoch 7 Perplexity: 1.4174


Epoch 8/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 18.02it/s, loss=0.401]


Epoch 8 Average Loss: 0.3263
Epoch 8 Perplexity: 1.3859


Epoch 9/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 17.19it/s, loss=0.383]


Epoch 9 Average Loss: 0.3121
Epoch 9 Perplexity: 1.3663


Epoch 10/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:06<00:00, 13.01it/s, loss=0.375]


Epoch 10 Average Loss: 0.2934
Epoch 10 Perplexity: 1.3410


Epoch 11/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 17.28it/s, loss=0.24]


Epoch 11 Average Loss: 0.2834
Epoch 11 Perplexity: 1.3276


Epoch 12/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:06<00:00, 13.75it/s, loss=0.279]


Epoch 12 Average Loss: 0.2658
Epoch 12 Perplexity: 1.3045


Epoch 13/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 16.14it/s, loss=0.186]


Epoch 13 Average Loss: 0.2423
Epoch 13 Perplexity: 1.2742


Epoch 14/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 17.38it/s, loss=0.194]


Epoch 14 Average Loss: 0.2215
Epoch 14 Perplexity: 1.2479


Epoch 15/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:07<00:00, 12.93it/s, loss=0.203]


Epoch 15 Average Loss: 0.2191
Epoch 15 Perplexity: 1.2449


Epoch 16/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 15.23it/s, loss=0.198]


Epoch 16 Average Loss: 0.2120
Epoch 16 Perplexity: 1.2361


Epoch 17/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:06<00:00, 13.17it/s, loss=0.228]


Epoch 17 Average Loss: 0.1973
Epoch 17 Perplexity: 1.2181


Epoch 18/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 16.53it/s, loss=0.17]


Epoch 18 Average Loss: 0.1869
Epoch 18 Perplexity: 1.2055


Epoch 19/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 17.31it/s, loss=0.186]


Epoch 19 Average Loss: 0.1765
Epoch 19 Perplexity: 1.1931


Epoch 20/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:07<00:00, 12.87it/s, loss=0.171]


Epoch 20 Average Loss: 0.1690
Epoch 20 Perplexity: 1.1841


Epoch 21/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 17.16it/s, loss=0.0998]


Epoch 21 Average Loss: 0.1531
Epoch 21 Perplexity: 1.1654


Epoch 22/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:06<00:00, 13.45it/s, loss=0.182]


Epoch 22 Average Loss: 0.1517
Epoch 22 Perplexity: 1.1638


Epoch 23/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 16.28it/s, loss=0.099]


Epoch 23 Average Loss: 0.1426
Epoch 23 Perplexity: 1.1533


Epoch 24/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 17.17it/s, loss=0.101]


Epoch 24 Average Loss: 0.1306
Epoch 24 Perplexity: 1.1396


Epoch 25/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:07<00:00, 12.83it/s, loss=0.0746]


Epoch 25 Average Loss: 0.1335
Epoch 25 Perplexity: 1.1428


Epoch 26/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 17.22it/s, loss=0.111]


Epoch 26 Average Loss: 0.1296
Epoch 26 Perplexity: 1.1384


Epoch 27/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:06<00:00, 14.00it/s, loss=0.143]


Epoch 27 Average Loss: 0.1267
Epoch 27 Perplexity: 1.1351


Epoch 28/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 15.75it/s, loss=0.117]


Epoch 28 Average Loss: 0.1193
Epoch 28 Perplexity: 1.1267


Epoch 29/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:05<00:00, 17.33it/s, loss=0.0982]


Epoch 29 Average Loss: 0.1146
Epoch 29 Perplexity: 1.1214


Epoch 30/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 91/91 [00:07<00:00, 12.92it/s, loss=0.0955]

Epoch 30 Average Loss: 0.1097
Epoch 30 Perplexity: 1.1159





In [None]:
WEIGHTS_DIR = DATA_ROOT / "weights"
WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)

print("Weights directory:", WEIGHTS_DIR)

Weights directory: /content/drive/MyDrive/thai_music_data/weights


In [None]:
import torch

torch.save(model.state_dict(), WEIGHTS_DIR / "lstm_pitch_only_khmer_35.pth")
print("Model saved.")

Model saved.


## 2Ô∏è‚É£ Generation (Sampling)

Now we use the trained LSTM to generate new symbolic pitch sequences.

The model works as a next-token predictor:
- Input: 16-token context window
- Output: probability distribution over next token

Generation procedure:
1. Start with a seed sequence
2. Predict next token
3. Append prediction
4. Slide window forward
5. Repeat

We use **temperature sampling**:
- Low temperature (<1.0) ‚Üí safer, repetitive
- High temperature (>1.0) ‚Üí more creative, unstable

This lets us observe whether the model has learned meaningful Thai melodic structure.

In [31]:
WEIGHTS_DIR = DATA_ROOT / "weights"

model = LSTMLanguageModel(vocab_size).to(device)
model.load_state_dict(
    torch.load(
        WEIGHTS_DIR / "lstm_pitch_only_khmer_35.pth",
        map_location=device
    )
)
model.eval()

LSTMLanguageModel(
  (embedding): Embedding(11, 64)
  (lstm): LSTM(64, 128, num_layers=2, batch_first=True, dropout=0.25)
  (fc): Linear(in_features=128, out_features=11, bias=True)
)

In [32]:
import torch
import torch.nn.functional as F

def generate_sequence(
    model,
    seed_ids,
    max_new_tokens=100,
    temperature=1.0,
    seed=None   # ‚Üê add this
):

    model.eval()


    # üî¥ ADD THIS
    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    generated = seed_ids.copy()

    for _ in range(max_new_tokens):

        context = generated[-16:]
        x = torch.tensor(context, dtype=torch.long).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(x)

        logits = logits / temperature
        probs = F.softmax(logits, dim=-1)

        next_id = torch.multinomial(probs, num_samples=1).item()
        generated.append(next_id)

    return generated

def decode_ids(id_sequence):
  return [id_to_token[i] for i in id_sequence]

### 2.1 Generate from Existing Song (Paused)

Seeds the model with the first `seq_len` tokens of a loaded song, then generates a continuation. Useful for side-by-side comparison with the original.

In [27]:
def generate_from_song(
    song_idx,
    model,
    songs,
    seq_len=16,
    max_new_tokens=120,
    temperature=0.8
):
    """
    Generate continuation from a selected song
    and print side-by-side comparison.
    """

    model.eval()

    song = songs[song_idx]
    song_name = song["song"]

    print(f"\n=== Song Index: {song_idx} | Song: {song_name} ===")

    # ---- Seed ----
    seed_ids = song["id_sequence"][:seq_len]

    # ---- Generate ----
    generated_ids = generate_sequence(
        model,
        seed_ids=seed_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature
    )

    # ---- Decode ----
    original_ids = song["id_sequence"][:seq_len + max_new_tokens]

    original_tokens = decode_ids(original_ids)
    generated_tokens = decode_ids(generated_ids)

    # ---- Print ----
    print("\nSEED:")
    print(original_tokens[:seq_len])

    print("\nORIGINAL continuation:")
    print(original_tokens[seq_len:seq_len + 60])

    print("\nGENERATED continuation:")
    print(generated_tokens[seq_len:seq_len + 60])

In [None]:
generate_from_song(
    song_idx=22,
    model=model,
    songs=songs,
    seq_len=16,
    max_new_tokens=120,
    temperature=1.2
)


=== Song Index: 22 | Song: ‡πÅ‡∏Ç‡∏Å‡∏Ç‡∏≤‡∏ß ===

SEED:
['<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '‡∏£', '<REST>', '‡∏£', '‡∏£', '‡∏£', '<REST>', '‡∏£', '<REST>', '‡∏£']

ORIGINAL continuation:
['<REST>', '‡∏°', '<REST>', '‡∏î', '<REST>', '‡∏£', '<REST>', '‡∏°', '<REST>', '‡∏ü', '<REST>', '‡∏ã', '<REST>', '<REST>', '<REST>', '‡∏£', '<REST>', '<REST>', '<REST>', '‡∏ã', '‡∏•', '‡∏ó', '‡∏î', '‡∏£', '<REST>', '‡∏£', '‡∏£', '‡∏£', '<REST>', '‡∏£', '<REST>', '‡∏£', '<REST>', '‡∏î', '<REST>', '‡∏î', '<REST>', '<REST>', '<REST>', '<REST>', '‡∏•', '<REST>', '<REST>', '‡∏î', '‡∏•', '‡∏ã', '‡∏ü', '<REST>', '‡∏ã', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>']

GENERATED continuation:
['‡∏î', '‡∏£', '‡∏ü', '‡∏ã', '‡∏ü', '‡∏£', '‡∏ü', '‡∏î', '‡∏£', '‡∏£', '<REST>', '‡∏£', '‡∏£', '‡∏£', '<REST>', '‡∏£', '<REST>', '‡∏î', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>', '<REST>'

### 2.2 Generate from Raw Fragment

Generate a continuation from a manually provided Thai notation fragment (slot-format strings like `"---‡∏ü"`).  
The fragment is normalized, padded to `seq_len`, and fed as the seed.

In [33]:
# Sort songs by length of pitch_sequence (descending)
sorted_songs = sorted(
    songs,
    key=lambda s: len(s["pitch_sequence"]),
    reverse=True
)

print("Top 5 Longest Songs:\n")

for i, s in enumerate(sorted_songs[:5]):
    print(f"{i+1}. {s['song']} | Motif: {s['motif']} | Length: {len(s['pitch_sequence'])}")

Top 5 Longest Songs:

1. ‡πÄ‡∏Ç‡∏°‡∏£‡πÇ‡∏û‡∏ò‡∏¥‡∏™‡∏±‡∏ï‡∏ß‡πå | Motif: ‡πÄ‡∏Ç‡∏°‡∏£ | Length: 1679
2. ‡πÄ‡∏Ç‡∏°‡∏£‡∏û‡∏ß‡∏á | Motif: ‡πÄ‡∏Ç‡∏°‡∏£ | Length: 1007
3. ‡πÄ‡∏Ç‡∏°‡∏£‡∏•‡∏≠‡∏≠‡∏≠‡∏á‡∏Ñ‡πå | Motif: ‡πÄ‡∏Ç‡∏°‡∏£ | Length: 838
4. ‡πÄ‡∏Ç‡∏°‡∏£‡∏ä‡∏ô‡∏ö‡∏ó | Motif: ‡πÄ‡∏Ç‡∏°‡∏£ | Length: 759
5. ‡πÄ‡∏Ç‡∏°‡∏£‡∏õ‡∏≤‡∏Å‡∏ó‡πà‡∏≠ | Motif: ‡πÄ‡∏Ç‡∏°‡∏£ | Length: 717


### 2.3 Generation & Post-processing Helpers

- `generate_from_fragment` ‚Äî Normalizes a raw fragment, generates continuation, pretty-prints per bar  
- `combine_fragment_and_generated` ‚Äî Stitches original fragment (with octave marks) + generated continuation into a single slot list  
- `slots_to_song_data` ‚Äî Converts flat slot list back into song JSON format (8 slots per bar)

In [25]:
def generate_from_fragment(
    fragment_tokens,
    model,
    token_to_id,
    id_to_token,
    seq_len=16,
    max_new_tokens=120,
    temperature=0.8,
    bar_size=32,  # 1 bar = 32 slots
    seed=None   # ‚Üê add this
):
    """
    Generate continuation from manually provided JSON-style fragment.

    Display:
    - <REST> shown as "-"
    - New line every bar (32 tokens)
    """

    model.eval()

    # ---- Normalize fragment ----
    normalized = []
    for tok in fragment_tokens:
        normalized.extend(normalize_token(tok))

    fragment_ids = [
        token_to_id[t]
        for t in normalized
        if t in token_to_id
    ]

    if not fragment_ids:
        print("‚ö†Ô∏è Fragment produced no valid tokens.")
        return None

    # ---- Left pad ----
    if len(fragment_ids) < seq_len:
        pad_len = seq_len - len(fragment_ids)
        fragment_ids = [token_to_id["<REST>"]] * pad_len + fragment_ids
    else:
        fragment_ids = fragment_ids[-seq_len:]

    # ---- Generate ----
    generated_ids = generate_sequence(
        model,
        seed_ids=fragment_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        seed=seed
    )

    generated_tokens = [id_to_token[i] for i in generated_ids]


    # ---- Convert REST_k back to proper dash count ----
    pretty_stream = []

    for t in generated_tokens[seq_len:]:
        if t.startswith("<REST_"):
            k = int(t.replace("<REST_", "").replace(">", ""))
            pretty_stream.extend(["-"] * k)
        else:
            pretty_stream.append(t)

    print("\nGENERATED continuation:\n")

    # ---- Print per bar ----
    for i in range(0, len(pretty_stream), bar_size):
        bar = pretty_stream[i:i+bar_size]

        # group every 4 characters (1 slot)
        grouped = [
            "".join(bar[j:j+4])
            for j in range(0, len(bar), 4)
        ]

        print(", ".join(grouped))

    return generated_tokens

In [26]:
def combine_fragment_and_generated(
    fragment_tokens,
    generated_tokens,
    normalize_token,
    seq_len=16
):
    """
    Combine original fragment (preserving octave marks)
    with generated continuation.

    Returns dash-based slot list.
    """

    # -----------------------------
    # 1Ô∏è‚É£ Keep fragment EXACTLY as is
    # -----------------------------
    fragment_slots = fragment_tokens.copy()

    # -----------------------------
    # 2Ô∏è‚É£ Remove seed duplication from generated
    # -----------------------------
    continuation = generated_tokens[seq_len:]

    # -----------------------------
    # 3Ô∏è‚É£ Convert generated tokens back to dash string
    # -----------------------------
    generated_parts = []

    for tok in continuation:
        if tok.startswith("<REST_"):
            k = int(tok.replace("<REST_", "").replace(">", ""))
            generated_parts.append("-" * k)
        else:
            generated_parts.append(tok)

    generated_string = "".join(generated_parts)

    # -----------------------------
    # 4Ô∏è‚É£ Split generated part into 4-char slots
    # -----------------------------
    generated_slots = [
        generated_string[i:i+4]
        for i in range(0, len(generated_string), 4)
    ]

    # -----------------------------
    # 5Ô∏è‚É£ Combine fragment + generated
    # -----------------------------
    combined_slots = fragment_slots + generated_slots

    return combined_slots

In [27]:
def slots_to_song_data(slots, title="Generated"):

    bars = [
        slots[i:i+8]
        for i in range(0, len(slots), 8)
    ]

    return {
        "title": title,
        "sections": [
            {
                "name": "Generated",
                "bars": bars
            }
        ]
    }

In [28]:
THAI_NOTES = "‡∏î‡∏£‡∏°‡∏ü‡∏ã‡∏•‡∏ó"

thai_base = {'‡∏î': 58, '‡∏£': 60, '‡∏°': 62, '‡∏ü': 63, '‡∏ã': 65, '‡∏•': 67, '‡∏ó': 69}
octave_offset = {1: -12, 2: 0, 3: 12}

# Only these note‚Äìoctave combos are allowed (16 total)
allowed_oct = {
    '‡∏î': [2, 3],      # ‡∏î, ‡∏î‡πç
    '‡∏£': [2, 3],      # ‡∏£, ‡∏£‡πç
    '‡∏°': [2, 3],      # ‡∏°, ‡∏°‡πç
    '‡∏ü': [2, 3],      # ‡∏ü, ‡∏ü‡πç
    '‡∏ã': [1, 2, 3],   # ‡∏ã‡∏∫ ‡∏ã ‡∏ã‡πç
    '‡∏•': [1, 2, 3],   # ‡∏•‡∏∫ ‡∏• ‡∏•‡πç
    '‡∏ó': [1, 2],      # ‡∏ó‡∏∫ ‡∏ó   (no ‡∏ó‡πç)
}


## 3Ô∏è‚É£ Fragment Demo

Generate a continuation from a hand-written Khmer-style fragment.  
The fragment uses standard Thai slot notation (4 chars per slot, 8 slots per bar).

In [34]:
fragment = [ "----", "---‡∏ü", "-‡∏ü‡∏ü‡∏ü", "-‡∏ü-‡∏ü", "---‡∏ã", "---‡∏•", "---‡∏î‡πç", "---‡∏£‡πç",
            "-‡∏ü‡πç-‡∏î‡πç", "-‡∏£‡πç-‡∏ü‡πç", "-‡∏•‡πç‡∏ã‡πç‡∏ü‡πç", "-‡∏î‡πç-‡∏£‡πç", "‡∏ü‡πç‡∏£‡πç‡∏î‡πç‡∏•", "-‡∏î‡πç--", "-‡∏ã-‡∏•", "‡∏î‡πç‡∏ã‡∏•‡∏î‡πç", "----" ]

generated_tokens = generate_from_fragment(
    fragment_tokens=fragment,
    model=model,
    token_to_id=token_to_id,
    id_to_token=id_to_token,
    seq_len=16,
    max_new_tokens=240,
    temperature=1.1,
    seed=42
)


# GENERATED continuation:

# ---‡∏î, -‡∏£‡∏ü‡∏î, ‡∏£‡∏î‡∏•‡∏î, ‡∏£‡∏ü-‡∏£, ---‡∏•, ---‡∏î, ---‡∏ü, ---‡∏£
# -‡∏ü-‡∏ã, ‡∏•‡∏ã‡∏ü‡∏£, ‡∏ã‡∏ü‡∏£‡∏î, ‡∏ü‡∏£‡∏î‡∏•, ----, ---‡∏•, -‡∏•‡∏•‡∏•, -‡∏•-‡∏•
# --‡∏î‡∏£, ‡∏ü‡∏£‡∏î‡∏•, ‡∏î‡∏•‡∏ã‡∏ü, ‡∏î‡∏£‡∏ü‡∏ã, ‡∏•‡∏ü‡∏•‡∏ã, -‡∏ü-‡∏£, ----, ----
# -‡∏°-‡∏°, -‡∏°-‡∏°, ‡∏ü‡∏ü‡∏ü‡∏ü, -‡∏£-‡∏£, ‡∏î‡∏î‡∏î‡∏î, -‡∏•-‡∏•, ‡∏ã‡∏ã‡∏ã‡∏ã, -‡∏•-‡∏•
# ‡∏î‡∏î‡∏î‡∏î, -‡∏£-‡∏£, -‡∏°‡∏ü‡∏£, ‡∏ü‡∏î-‡∏£, -‡∏ü--, -‡∏ã--, -‡∏•--, -‡∏î--
# -‡∏£-‡∏ü, ----, ---‡∏ü, -‡∏ü‡∏ü‡∏ü, -‡∏ü-‡∏ü, ‡∏•‡∏ã‡∏ü‡∏£, ‡∏ü‡∏î‡∏£‡∏ü, ‡∏î‡∏£‡∏ü‡∏ã
# ‡∏•‡∏ü‡∏ã‡∏•, -‡∏°-‡∏°, -‡∏ü-‡∏£, -‡∏î-‡∏•, ‡∏ã‡∏•‡∏î‡∏•, ‡∏ã‡∏ü-‡∏ã, ----, ----
# ‡∏ü‡∏ã‡∏•‡∏ã, ‡∏ü‡∏£-‡∏ü, ----, ‡∏î‡∏£‡∏ü‡∏ã, -‡∏ü-‡∏•, ‡∏ã‡∏ã--, --‡∏ã‡∏•, ‡∏î‡∏£‡∏ü‡∏£
# ‡∏ü‡∏ã‡∏ü‡∏£, ‡∏ü‡∏î-‡∏£, ----, ‡∏î‡∏£‡∏ü‡∏ã, ‡∏•‡∏ã‡∏ü‡∏£, ‡∏î‡∏•‡∏î‡∏£, ‡∏î‡∏£‡∏ü‡∏ã, -‡∏•‡∏î‡∏£
# ‡∏î‡∏£


GENERATED continuation:

-‡∏î-‡∏ü, ‡∏ã‡∏•‡∏î‡∏£, ‡∏î‡∏£‡∏ü‡∏ã, ‡∏•‡∏ã‡∏ü‡∏£, ‡∏ü‡∏ã-‡∏ü, ---‡∏î, ‡∏£‡∏°‡∏ü‡∏ã, ---‡∏•
‡∏ã‡∏ã‡∏ã‡∏ã, ‡∏î‡∏£‡∏î‡∏î, ‡∏ã‡∏•‡∏î‡∏£, ‡∏ã‡∏ü‡∏ü‡∏ü, ‡∏î‡∏£‡∏ü‡∏ã, -‡∏•‡∏î‡∏ã, ‡∏•‡∏ã‡∏ü‡∏£, ----
---‡∏•, ---‡∏ã, ‡∏ü‡∏ã‡∏•‡∏î, ---‡∏£, ‡∏î‡∏°‡∏£‡∏£, ----, ----, ‡∏ü‡∏£‡∏î‡∏•
-‡∏ã-‡∏ü, --‡∏•‡∏ã, ‡∏ü‡∏ã-‡∏•, -‡∏î‡∏ü‡∏£, ‡∏î‡∏£‡∏ü‡∏•, ‡∏ã‡∏ü--, ‡∏ü‡∏ü-‡∏ü, --‡∏•‡∏ã
‡∏ü‡∏£-‡∏ü, -‡∏ã-‡∏ã, -‡∏ã‡∏î‡∏£, ‡∏ü‡∏ã‡∏ü‡∏•, ‡∏ã‡∏ü‡∏•‡∏ã, ‡∏ü‡∏£--, ‡∏°‡∏°‡∏ã‡∏°, ‡∏£‡∏î‡∏£‡∏î
‡∏•‡∏î-‡∏£, -‡∏ü‡∏•‡∏ã, ‡∏ü‡∏ã-‡∏•, -‡∏î‡∏ü‡∏£, ‡∏î‡∏£‡∏ü‡∏•, ‡∏ã‡∏ü--, ‡∏ü‡∏ü-‡∏ü, --‡∏•‡∏ã
‡∏ü‡∏£-‡∏ü, -‡∏ã‡∏•‡∏ã, ‡∏ü‡∏£‡∏ü‡∏î, -‡∏£--, ----, -‡∏£-‡∏£, ‡∏£‡∏£-‡∏£, -‡∏£‡∏î‡∏£
‡∏ü‡∏ã‡∏ü‡∏£, ‡∏ü‡∏î‡∏£‡∏î, ‡∏•‡∏î-‡∏£, -‡∏ü--, --‡∏ü‡∏ã, ‡∏•‡∏î-‡∏£, -‡∏ü-‡∏î, ‡∏£‡∏ü-‡∏£
---‡∏î, ---‡∏•, ----, ---‡∏•, -‡∏•‡∏•‡∏•, -‡∏•-‡∏•, ‡∏ã‡∏•


## 4Ô∏è‚É£ Post-processing & MIDI Export

Pipeline: `combined_slots` ‚Üí song JSON ‚Üí DP octave inference ‚Üí octave-mark-to-numeric conversion ‚Üí Ranad MIDI

In [35]:
combined_slots = combine_fragment_and_generated(
    fragment_tokens=fragment,
    generated_tokens=generated_tokens,
    normalize_token=normalize_token
)

print(combined_slots[:80])

['----', '---‡∏ü', '-‡∏ü‡∏ü‡∏ü', '-‡∏ü-‡∏ü', '---‡∏ã', '---‡∏•', '---‡∏î‡πç', '---‡∏£‡πç', '-‡∏ü‡πç-‡∏î‡πç', '-‡∏£‡πç-‡∏ü‡πç', '-‡∏•‡πç‡∏ã‡πç‡∏ü‡πç', '-‡∏î‡πç-‡∏£‡πç', '‡∏ü‡πç‡∏£‡πç‡∏î‡πç‡∏•', '-‡∏î‡πç--', '-‡∏ã-‡∏•', '‡∏î‡πç‡∏ã‡∏•‡∏î‡πç', '----', '-‡∏î-‡∏ü', '‡∏ã‡∏•‡∏î‡∏£', '‡∏î‡∏£‡∏ü‡∏ã', '‡∏•‡∏ã‡∏ü‡∏£', '‡∏ü‡∏ã-‡∏ü', '---‡∏î', '‡∏£‡∏°‡∏ü‡∏ã', '---‡∏•', '‡∏ã‡∏ã‡∏ã‡∏ã', '‡∏î‡∏£‡∏î‡∏î', '‡∏ã‡∏•‡∏î‡∏£', '‡∏ã‡∏ü‡∏ü‡∏ü', '‡∏î‡∏£‡∏ü‡∏ã', '-‡∏•‡∏î‡∏ã', '‡∏•‡∏ã‡∏ü‡∏£', '----', '---‡∏•', '---‡∏ã', '‡∏ü‡∏ã‡∏•‡∏î', '---‡∏£', '‡∏î‡∏°‡∏£‡∏£', '----', '----', '‡∏ü‡∏£‡∏î‡∏•', '-‡∏ã-‡∏ü', '--‡∏•‡∏ã', '‡∏ü‡∏ã-‡∏•', '-‡∏î‡∏ü‡∏£', '‡∏î‡∏£‡∏ü‡∏•', '‡∏ã‡∏ü--', '‡∏ü‡∏ü-‡∏ü', '--‡∏•‡∏ã', '‡∏ü‡∏£-‡∏ü', '-‡∏ã-‡∏ã', '-‡∏ã‡∏î‡∏£', '‡∏ü‡∏ã‡∏ü‡∏•', '‡∏ã‡∏ü‡∏•‡∏ã', '‡∏ü‡∏£--', '‡∏°‡∏°‡∏ã‡∏°', '‡∏£‡∏î‡∏£‡∏î', '‡∏•‡∏î-‡∏£', '-‡∏ü‡∏•‡∏ã', '‡∏ü‡∏ã-‡∏•', '-‡∏î‡∏ü‡∏£', '‡∏î‡∏£‡∏ü‡∏•', '‡∏ã‡∏ü--', '‡∏ü‡∏ü-‡∏ü', '--‡∏•‡∏ã', '‡∏ü‡∏£-‡∏ü', '-‡∏ã‡∏•‡∏ã', '‡∏ü‡∏£‡∏ü‡∏î', '-‡∏£--', '----', '-‡∏£-‡∏£', '‡∏£‡∏£-‡∏£', '-‡∏£‡∏î‡∏

In [None]:
# Wrap combined slots into song JSON structure
song_data_generated = slots_to_song_data(combined_slots)
import pprint

pprint.pprint(song_data_generated, width=120)

{'sections': [{'bars': [['----', '---‡∏ü', '-‡∏ü‡∏ü‡∏ü', '-‡∏ü-‡∏ü', '---‡∏ã', '---‡∏•', '---‡∏î‡πç', '---‡∏£‡πç'],
                        ['-‡∏ü‡πç-‡∏î‡πç', '-‡∏£‡πç-‡∏ü‡πç', '-‡∏•‡πç‡∏ã‡πç‡∏ü‡πç', '-‡∏î‡πç-‡∏£‡πç', '‡∏ü‡πç‡∏£‡πç‡∏î‡πç‡∏•', '-‡∏î‡πç--', '-‡∏ã-‡∏•', '‡∏î‡πç‡∏ã‡∏•‡∏î‡πç'],
                        ['----', '-‡∏î-‡∏ü', '‡∏ã‡∏•‡∏î‡∏£', '‡∏î‡∏£‡∏ü‡∏ã', '‡∏•‡∏ã‡∏ü‡∏£', '‡∏ü‡∏ã-‡∏ü', '---‡∏î', '‡∏£‡∏°‡∏ü‡∏ã'],
                        ['---‡∏•', '‡∏ã‡∏ã‡∏ã‡∏ã', '‡∏î‡∏£‡∏î‡∏î', '‡∏ã‡∏•‡∏î‡∏£', '‡∏ã‡∏ü‡∏ü‡∏ü', '‡∏î‡∏£‡∏ü‡∏ã', '-‡∏•‡∏î‡∏ã', '‡∏•‡∏ã‡∏ü‡∏£'],
                        ['----', '---‡∏•', '---‡∏ã', '‡∏ü‡∏ã‡∏•‡∏î', '---‡∏£', '‡∏î‡∏°‡∏£‡∏£', '----', '----'],
                        ['‡∏ü‡∏£‡∏î‡∏•', '-‡∏ã-‡∏ü', '--‡∏•‡∏ã', '‡∏ü‡∏ã-‡∏•', '-‡∏î‡∏ü‡∏£', '‡∏î‡∏£‡∏ü‡∏•', '‡∏ã‡∏ü--', '‡∏ü‡∏ü-‡∏ü'],
                        ['--‡∏•‡∏ã', '‡∏ü‡∏£-‡∏ü', '-‡∏ã-‡∏ã', '-‡∏ã‡∏î‡∏£', '‡∏ü‡∏ã‡∏ü‡∏•', '‡∏ã‡∏ü‡∏•‡∏ã', '‡∏ü‡∏£--', '‡∏°‡∏°‡∏ã‡∏°'],
                        ['‡∏£‡∏î‡∏£‡∏î', '‡

In [None]:
# Infer octaves using DP smoothness constraint (from thai_music_utils)
song_data_auto = add_octaves_respecting_labels(song_data_generated)

In [None]:
# Inspect: song data after octave inference
from pprint import pprint

pprint(song_data_auto, width=120, sort_dicts=False)

{'title': 'Generated',
 'sections': [{'name': 'Generated',
               'bars': [['----', '---‡∏ü', '-‡∏ü‡∏ü‡∏ü', '-‡∏ü-‡∏ü', '---‡∏ã', '---‡∏•', '---‡∏î‡πç', '---‡∏£‡πç'],
                        ['-‡∏ü‡πç-‡∏î‡πç', '-‡∏£‡πç-‡∏ü‡πç', '-‡∏•‡πç‡∏ã‡πç‡∏ü‡πç', '-‡∏î‡πç-‡∏£‡πç', '‡∏ü‡πç‡∏£‡πç‡∏î‡πç‡∏•', '-‡∏î‡πç--', '-‡∏ã-‡∏•', '‡∏î‡πç‡∏ã‡∏•‡∏î‡πç'],
                        ['----', '-‡∏î‡πç-‡∏ü', '‡∏ã‡∏•‡∏î‡πç‡∏£‡πç', '‡∏î‡πç‡∏£‡πç‡∏ü‡πç‡∏ã‡πç', '‡∏•‡πç‡∏ã‡πç‡∏ü‡πç‡∏£‡πç', '‡∏ü‡πç‡∏ã‡πç-‡∏ü‡πç', '---‡∏î‡πç', '‡∏£‡πç‡∏°‡πç‡∏ü‡πç‡∏ã‡πç'],
                        ['---‡∏•‡πç', '‡∏ã‡πç‡∏ã‡πç‡∏ã‡πç‡∏ã‡πç', '‡∏î‡πç‡∏£‡πç‡∏î‡πç‡∏î‡πç', '‡∏ã‡∏•‡∏î‡πç‡∏£‡πç', '‡∏ã‡∏ü‡∏ü‡∏ü', '‡∏î‡∏£‡∏ü‡∏ã', '-‡∏•‡∏î‡πç‡∏ã', '‡∏•‡∏ã‡∏ü‡∏£'],
                        ['----', '---‡∏•', '---‡∏ã', '‡∏ü‡∏ã‡∏•‡∏î‡πç', '---‡∏£‡πç', '‡∏î‡πç‡∏°‡πç‡∏£‡πç‡∏£‡πç', '----', '----'],
                        ['‡∏ü‡πç‡∏£‡πç‡∏î‡πç‡∏•', '-‡∏ã-‡∏ü', '--‡∏•‡∏ã', '‡∏ü‡∏ã-‡∏•', '-‡∏î‡πç‡∏ü‡∏£', '‡∏î‡∏£‡∏ü‡∏•', '‡∏ã‡∏ü--', '‡∏ü‡∏ü-‡∏ü'],

In [None]:
# Re-flatten octave-annotated song data back into slot list for MIDI export
combined_slots = []

for sec in song_data_auto["sections"]:
    for bar in sec["bars"]:
        if isinstance(bar, list):
            combined_slots.extend(bar)

In [None]:
sequence_string = "".join(combined_slots)

# ============================================================
# Convert Thai octave marks to numeric tags before MIDI export
# ============================================================
import re

LOW_DOT = "‡∏∫"      # octave 1
HIGH_DOT = "‡πç"     # octave 3

# Convert ‡∏∫ (LOW_DOT) ‚Üí "1" after the note
sequence_string = re.sub(rf"([‡∏î‡∏£‡∏°‡∏ü‡∏ã‡∏•‡∏ó]){LOW_DOT}", r"\g<1>1", sequence_string)

# Convert ‡πç (HIGH_DOT) ‚Üí "3" after the note
sequence_string = re.sub(rf"([‡∏î‡∏£‡∏°‡∏ü‡∏ã‡∏•‡∏ó]){HIGH_DOT}", r"\g<1>3", sequence_string)

# Notes without octave marks default to octave 2 in generate_ranad_midi

midi_out = DATA_ROOT / "generated.mid" if not IS_COLAB else "/content/generated.mid"

generate_ranad_midi(
    sequence=sequence_string,
    output_path=str(midi_out),
    bpm=150,
    global_transpose=12,
    play_in_octave_pairs=True,
    enable_roll=True
)

print("‚úÖ MIDI exported successfully!")

‚úÖ MIDI exported successfully!


## 5Ô∏è‚É£ Evaluation

Compare the generated output against a reference song from the corpus.  
All metrics normalize `combined_slots` using the same `normalize_token` logic as training.

### 5.1 REST Type Distribution

Compares the proportion of `<REST_1>` / `<REST_2>` / `<REST_3>` / `<REST_4>` between the original song and the generated output.  
A well-trained model should produce a similar rhythmic density profile.

In [41]:
def evaluate_rest_type_distribution_from_slots(combined_slots, song_name, songs, normalize_token):
    """
    Compare REST_1/2/3/4 distribution between:
    - Original song (pitch_sequence)
    - Generated combined_slots (fragment + continuation)

    Uses SAME normalize_token logic as training.
    """

    # -----------------------------
    # 1Ô∏è‚É£ Find original song
    # -----------------------------
    song_match = None
    for s in songs:
        if s["song"] == song_name:
            song_match = s
            break

    if not song_match:
        print(f"‚ùå Song '{song_name}' not found.")
        return

    original_tokens = song_match["pitch_sequence"]

    # -----------------------------
    # 2Ô∏è‚É£ Normalize combined_slots
    # -----------------------------
    generated_tokens = []

    for slot in combined_slots:
        normalized = normalize_token(slot)
        generated_tokens.extend(normalized)

    # -----------------------------
    # 3Ô∏è‚É£ Count REST types
    # -----------------------------
    rest_types = ["<REST_1>", "<REST_2>", "<REST_3>", "<REST_4>"]

    def get_distribution(tokens):
        total_rests = sum(1 for t in tokens if t.startswith("<REST"))

        counts = {r: 0 for r in rest_types}
        for t in tokens:
            if t in counts:
                counts[t] += 1

        proportions = {
            r: (counts[r] / total_rests if total_rests > 0 else 0)
            for r in rest_types
        }

        return counts, proportions, total_rests

    orig_counts, orig_props, orig_total = get_distribution(original_tokens)
    gen_counts, gen_props, gen_total = get_distribution(generated_tokens)

    # -----------------------------
    # 4Ô∏è‚É£ Print Results
    # -----------------------------
    print("\n" + "="*70)
    print(f"üéµ REST TYPE DISTRIBUTION ‚Äî {song_name}")
    print("="*70)

    print("\nOriginal Song:")
    print(f"Total REST tokens: {orig_total}")
    for r in rest_types:
        print(f"{r}: {orig_counts[r]:4d} ({orig_props[r]*100:5.1f}%)")

    print("\nGenerated (Fragment + Continuation):")
    print(f"Total REST tokens: {gen_total}")
    for r in rest_types:
        print(f"{r}: {gen_counts[r]:4d} ({gen_props[r]*100:5.1f}%)")

    print("\nAbsolute Proportion Differences:")
    for r in rest_types:
        diff = abs(gen_props[r] - orig_props[r])
        print(f"{r}: {diff*100:5.1f}%")

    print("="*70 + "\n")

    return {
        "original": orig_props,
        "generated": gen_props
    }

In [42]:
evaluate_rest_type_distribution_from_slots(
    combined_slots,
    "‡πÄ‡∏Ç‡∏°‡∏£‡∏û‡∏ß‡∏á",
    songs,
    normalize_token
)


üéµ REST TYPE DISTRIBUTION ‚Äî ‡πÄ‡∏Ç‡∏°‡∏£‡∏û‡∏ß‡∏á

Original Song:
Total REST tokens: 277
<REST_1>:  176 ( 63.5%)
<REST_2>:   30 ( 10.8%)
<REST_3>:   29 ( 10.5%)
<REST_4>:   42 ( 15.2%)

Generated (Fragment + Continuation):
Total REST tokens: 77
<REST_1>:   47 ( 61.0%)
<REST_2>:   10 ( 13.0%)
<REST_3>:   13 ( 16.9%)
<REST_4>:    7 (  9.1%)

Absolute Proportion Differences:
<REST_1>:   2.5%
<REST_2>:   2.2%
<REST_3>:   6.4%
<REST_4>:   6.1%



{'original': {'<REST_1>': 0.6353790613718412,
  '<REST_2>': 0.10830324909747292,
  '<REST_3>': 0.10469314079422383,
  '<REST_4>': 0.15162454873646208},
 'generated': {'<REST_1>': 0.6103896103896104,
  '<REST_2>': 0.12987012987012986,
  '<REST_3>': 0.16883116883116883,
  '<REST_4>': 0.09090909090909091}}

In [38]:
# Convert combined_slots to normalized token stream
normalized_combined = []

for slot in combined_slots:
    normalized_combined.extend(normalize_token(slot))

# Print first 200 tokens for inspection
print("First 200 normalized tokens:\n")
print(normalized_combined[:200])

# Optional: print as single string
print("\nAs flat sequence:\n")
print(" ".join(normalized_combined[:200]))

First 200 normalized tokens:

['<REST_4>', '<REST_3>', '‡∏ü', '<REST_1>', '‡∏ü', '‡∏ü', '‡∏ü', '<REST_1>', '‡∏ü', '<REST_1>', '‡∏ü', '<REST_3>', '‡∏ã', '<REST_3>', '‡∏•', '<REST_3>', '‡∏î', '<REST_3>', '‡∏£', '<REST_1>', '‡∏ü', '<REST_1>', '‡∏î', '<REST_1>', '‡∏£', '<REST_1>', '‡∏ü', '<REST_1>', '‡∏•', '‡∏ã', '‡∏ü', '<REST_1>', '‡∏î', '<REST_1>', '‡∏£', '‡∏ü', '‡∏£', '‡∏î', '‡∏•', '<REST_1>', '‡∏î', '<REST_2>', '<REST_1>', '‡∏ã', '<REST_1>', '‡∏•', '‡∏î', '‡∏ã', '‡∏•', '‡∏î', '<REST_4>', '<REST_3>', '‡∏î', '<REST_1>', '‡∏£', '‡∏ü', '‡∏î', '‡∏£', '‡∏î', '‡∏•', '‡∏î', '‡∏£', '‡∏ü', '<REST_1>', '‡∏£', '<REST_3>', '‡∏•', '<REST_3>', '‡∏î', '<REST_3>', '‡∏ü', '<REST_3>', '‡∏£', '<REST_1>', '‡∏ü', '<REST_1>', '‡∏ã', '‡∏•', '‡∏ã', '‡∏ü', '‡∏£', '‡∏ã', '‡∏ü', '‡∏£', '‡∏î', '‡∏ü', '‡∏£', '‡∏î', '‡∏•', '<REST_4>', '<REST_3>', '‡∏•', '<REST_1>', '‡∏•', '‡∏•', '‡∏•', '<REST_1>', '‡∏•', '<REST_1>', '‡∏•', '<REST_2>', '‡∏î', '‡∏£', '‡∏ü', '‡∏£', '‡∏î', '‡∏•', '‡∏î', '‡∏•', '‡∏ã', '‡∏ü', '‡∏î', '‡∏£'

### 5.2 N-gram Overlap

Measures what fraction of the generated n-grams (bigram / trigram / 4-gram) also appear in the reference song.  
Higher overlap ‚Üí the model is reproducing known melodic patterns.  
Too high ‚Üí possible memorization; too low ‚Üí the output diverges from the style.

In [39]:
from collections import Counter

def evaluate_ngram_overlap(combined_slots, song_name, songs, normalize_token, n=3):
    """
    Compute n-gram overlap between:
    - Original song
    - Generated (fragment + continuation)

    n = 2 (bigram), 3 (trigram), 4 (quadgram)
    """

    # -----------------------------
    # 1Ô∏è‚É£ Find original song
    # -----------------------------
    song_match = None
    for s in songs:
        if s["song"] == song_name:
            song_match = s
            break

    if not song_match:
        print(f"‚ùå Song '{song_name}' not found.")
        return

    original_tokens = song_match["pitch_sequence"]

    # -----------------------------
    # 2Ô∏è‚É£ Normalize combined_slots
    # -----------------------------
    generated_tokens = []
    for slot in combined_slots:
        generated_tokens.extend(normalize_token(slot))

    # -----------------------------
    # 3Ô∏è‚É£ Build n-grams
    # -----------------------------
    def build_ngrams(tokens, n):
        return [
            tuple(tokens[i:i+n])
            for i in range(len(tokens) - n + 1)
        ]

    orig_ngrams = set(build_ngrams(original_tokens, n))
    gen_ngrams = build_ngrams(generated_tokens, n)

    if not gen_ngrams:
        print("‚ö†Ô∏è No generated n-grams.")
        return

    overlap_count = sum(1 for g in gen_ngrams if g in orig_ngrams)

    overlap_ratio = overlap_count / len(gen_ngrams)

    # -----------------------------
    # 4Ô∏è‚É£ Print results
    # -----------------------------
    print("\n" + "="*70)
    print(f"üéµ {n}-GRAM OVERLAP ‚Äî {song_name}")
    print("="*70)
    print(f"Generated {n}-grams: {len(gen_ngrams)}")
    print(f"Overlap count: {overlap_count}")
    print(f"Overlap ratio: {overlap_ratio:.3f} ({overlap_ratio*100:.1f}%)")
    print("="*70 + "\n")

    return overlap_ratio

In [40]:
evaluate_ngram_overlap(combined_slots, "‡πÄ‡∏Ç‡∏°‡∏£‡∏û‡∏ß‡∏á", songs, normalize_token, n=2)
evaluate_ngram_overlap(combined_slots, "‡πÄ‡∏Ç‡∏°‡∏£‡∏û‡∏ß‡∏á", songs, normalize_token, n=3)
evaluate_ngram_overlap(combined_slots, "‡πÄ‡∏Ç‡∏°‡∏£‡∏û‡∏ß‡∏á", songs, normalize_token, n=4)


üéµ 2-GRAM OVERLAP ‚Äî ‡πÄ‡∏Ç‡∏°‡∏£‡∏û‡∏ß‡∏á
Generated 2-grams: 295
Overlap count: 291
Overlap ratio: 0.986 (98.6%)


üéµ 3-GRAM OVERLAP ‚Äî ‡πÄ‡∏Ç‡∏°‡∏£‡∏û‡∏ß‡∏á
Generated 3-grams: 294
Overlap count: 271
Overlap ratio: 0.922 (92.2%)


üéµ 4-GRAM OVERLAP ‚Äî ‡πÄ‡∏Ç‡∏°‡∏£‡∏û‡∏ß‡∏á
Generated 4-grams: 293
Overlap count: 240
Overlap ratio: 0.819 (81.9%)



0.8191126279863481

### 5.3 Pitch KL Divergence

KL(P ‚Äñ Q) where P = pitch distribution of the reference song and Q = pitch distribution of the generated output.  
Measures how much the generated pitch usage diverges from the original.  
Lower KL ‚Üí closer match in overall pitch preference (e.g. how often ‡∏ã vs ‡∏î appears).

In [41]:
import numpy as np
from collections import Counter

def evaluate_pitch_kl(combined_slots, song_name, songs, normalize_token):
    """
    Compute KL divergence between pitch distributions
    of original song and generated output.
    """

    # -----------------------------
    # 1Ô∏è‚É£ Find original song
    # -----------------------------
    song_match = None
    for s in songs:
        if s["song"] == song_name:
            song_match = s
            break

    if not song_match:
        print(f"‚ùå Song '{song_name}' not found.")
        return

    original_tokens = song_match["pitch_sequence"]

    # -----------------------------
    # 2Ô∏è‚É£ Normalize generated slots
    # -----------------------------
    generated_tokens = []
    for slot in combined_slots:
        generated_tokens.extend(normalize_token(slot))

    # -----------------------------
    # 3Ô∏è‚É£ Extract pitch-only tokens
    # -----------------------------
    THAI_PITCHES = ["‡∏î", "‡∏£", "‡∏°", "‡∏ü", "‡∏ã", "‡∏•", "‡∏ó"]

    def get_pitch_distribution(tokens):
        pitch_tokens = [t for t in tokens if t in THAI_PITCHES]
        total = len(pitch_tokens)

        counts = Counter(pitch_tokens)

        probs = np.array([
            counts[p] / total if total > 0 else 0
            for p in THAI_PITCHES
        ])

        return probs

    P = get_pitch_distribution(original_tokens)
    Q = get_pitch_distribution(generated_tokens)

    # Add small epsilon to avoid log(0)
    epsilon = 1e-8
    P = P + epsilon
    Q = Q + epsilon

    kl_div = np.sum(P * np.log(P / Q))

    # -----------------------------
    # 4Ô∏è‚É£ Print Results
    # -----------------------------
    print("\n" + "="*70)
    print(f"üéµ PITCH KL DIVERGENCE ‚Äî {song_name}")
    print("="*70)
    print("Original distribution:", np.round(P, 3))
    print("Generated distribution:", np.round(Q, 3))
    print(f"\nKL(P || Q): {kl_div:.4f}")
    print("="*70 + "\n")

    return kl_div

In [42]:
evaluate_pitch_kl(combined_slots, "‡πÄ‡∏Ç‡∏°‡∏£‡∏û‡∏ß‡∏á", songs, normalize_token)


üéµ PITCH KL DIVERGENCE ‚Äî ‡πÄ‡∏Ç‡∏°‡∏£‡∏û‡∏ß‡∏á
Original distribution: [0.237 0.173 0.019 0.253 0.162 0.152 0.004]
Generated distribution: [0.196 0.187 0.033 0.263 0.148 0.172 0.   ]

KL(P || Q): 0.0591



np.float64(0.05914725564600476)