## Homework 3: Symbolic Music Generation Using Markov Chains

**Before starting the homework:**

Please run `pip install miditok` to install the [MiDiTok](https://github.com/Natooz/MidiTok) package, which simplifies MIDI file processing by making note and beat extraction more straightforward.

You’re also welcome to experiment with other MIDI processing libraries such as [mido](https://github.com/mido/mido), [pretty_midi](https://github.com/craffel/pretty-midi) and [miditoolkit](https://github.com/YatingMusic/miditoolkit). However, with these libraries, you’ll need to handle MIDI quantization yourself, for example, converting note-on/note-off events into beat positions and durations.

In [2]:
# run this command to install MiDiTok
! pip install miditok
! pip install midiutil

[0mCollecting midiutil
  Downloading MIDIUtil-1.2.1.tar.gz (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: midiutil
  Building wheel for midiutil (setup.py) ... [?25ldone
[?25h  Created wheel for midiutil: filename=MIDIUtil-1.2.1-py3-none-any.whl size=54569 sha256=b733ffd403f1fdb8023223c8b16a9f1d8023007a9ae2756784117731dfed67e4
  Stored in directory: /root/.cache/pip/wheels/6c/42/75/fce10c67f06fe627fad8acd1fd3a004a24e07b0f077761fbbd
Successfully built midiutil
Installing collected packages: midiutil
Successfully installed midiutil-1.2.1
[0m

In [3]:
# import required packages
import random
from glob import glob
from collections import defaultdict

import numpy as np
from numpy.random import choice

from symusic import Score
from miditok import REMI, TokenizerConfig
from midiutil import MIDIFile

In [4]:
# You can change the random seed but try to keep your results deterministic!
# If I need to make changes to the autograder it'll require rerunning your code,
# so it should ideally generate the same results each time.
random.seed(42)

### Load music dataset
We will use a subset of the [PDMX dataset](https://zenodo.org/records/14984509).

Please find the link in the homework spec.

All pieces are monophonic music (i.e. one melody line) in 4/4 time signature.

In [7]:
midi_files = glob('PDMX_subset/*.mid')
len(midi_files)

1000

### Train a tokenizer with the REMI method in MidiTok

In [8]:
config = TokenizerConfig(num_velocities=1, use_chords=False, use_programs=False)
tokenizer = REMI(config)
tokenizer.train(vocab_size=1000, files_paths=midi_files)






### Use the trained tokenizer to get tokens for each midi file
In REMI representation, each note will be represented with four tokens: `Position, Pitch, Velocity, Duration`, e.g. `('Position_28', 'Pitch_74', 'Velocity_127', 'Duration_0.4.8')`; a `Bar_None` token indicates the beginning of a new bar.

In [9]:
# e.g.:
midi = Score(midi_files[0])
tokens = tokenizer(midi)[0].tokens
tokens[:10]

['Bar_None',
 'Position_0',
 'Pitch_60',
 'Velocity_127',
 'Duration_1.0.8',
 'Position_8',
 'Pitch_62',
 'Velocity_127',
 'Duration_1.0.8',
 'Position_16']

1. Write a function to extract note pitch events from a midi file; and another extract all note pitch events from the dataset and output a dictionary that maps note pitch events to the number of times they occur in the files. (e.g. {60: 120, 61: 58, …}).

`note_extraction()`
- **Input**: a midi file

- **Output**: a list of note pitch events (e.g. [60, 62, 61, ...])

`note_frequency()`
- **Input**: all midi files `midi_files`

- **Output**: a dictionary that maps note pitch events to the number of times they occur, e.g {60: 120, 61: 58, …}

In [10]:
def note_extraction(midi_file):
    """
    从MIDI文件中提取音符音高事件
    
    参数:
        midi_file: MIDI文件路径
    
    返回:
        音符音高事件列表
    """
    try:
        # 使用MiDiTok加载MIDI文件
        midi = Score(midi_file)
        tokens = tokenizer(midi)[0].tokens
        
        # 提取所有Pitch标记
        note_events = []
        for token in tokens:
            if token.startswith('Pitch_'):
                # 从'Pitch_60'这样的格式中提取数字
                pitch = int(token.split('_')[1])
                note_events.append(pitch)
                
        return note_events
    except Exception as e:
        print(f"处理文件 {midi_file} 时出错: {str(e)}")
        return []

In [24]:
note_extraction(midi_files[0])

[60,
 62,
 64,
 65,
 67,
 69,
 71,
 72,
 61,
 63,
 65,
 66,
 68,
 70,
 72,
 73,
 62,
 64,
 66,
 67,
 69,
 71,
 73,
 74,
 63,
 65,
 67,
 68,
 70,
 72,
 74,
 75,
 64,
 66,
 68,
 69,
 71,
 73,
 75,
 76,
 65,
 67,
 69,
 70,
 72,
 74,
 76,
 77,
 66,
 68,
 70,
 71,
 73,
 75,
 77,
 78,
 67,
 69,
 71,
 72,
 74,
 76,
 78,
 79,
 68,
 70,
 72,
 73,
 75,
 77,
 79,
 80,
 69,
 71,
 73,
 74,
 76,
 78,
 80,
 81,
 70,
 72,
 74,
 75,
 77,
 79,
 81,
 71,
 73,
 75,
 76,
 78,
 80,
 82,
 83]

In [25]:

from collections import Counter
def note_frequency(midi_files):
    """
    统计所有MIDI文件中音符音高的出现频率
    
    参数:
        midi_files: MIDI文件路径列表
    
    返回:
        音符音高及其出现次数的字典
    """
    all_notes = []
    
    # 处理每个MIDI文件
    for midi_file in midi_files:
        notes = note_extraction(midi_file)
        all_notes.extend(notes)
    
    # 使用Counter统计频率
    note_freq = dict(Counter(all_notes))
    
    return note_freq

In [27]:
note_frequency(midi_files)

{60: 729,
 62: 5010,
 64: 5976,
 65: 2223,
 67: 12003,
 69: 19315,
 71: 17023,
 72: 9395,
 61: 326,
 63: 327,
 66: 6087,
 68: 771,
 70: 2128,
 73: 6921,
 74: 19834,
 75: 696,
 76: 15518,
 77: 2628,
 78: 9714,
 79: 9400,
 80: 922,
 81: 6442,
 82: 291,
 83: 1309,
 84: 146,
 57: 457,
 59: 435,
 55: 178,
 58: 103,
 86: 42,
 88: 13,
 85: 65,
 52: 75,
 53: 29,
 50: 15,
 54: 1,
 56: 13,
 35: 16,
 37: 16,
 48: 16,
 43: 14,
 36: 6,
 45: 2,
 38: 2,
 47: 2,
 92: 1}

2. Write a function to normalize the above dictionary to produce probability scores (e.g. {60: 0.13, 61: 0.065, …})

`note_unigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: a dictionary that maps note pitch events to probabilities, e.g. {60: 0.13, 61: 0.06, …}

In [12]:
def note_unigram_probability(midi_files):
    note_counts = note_frequency(midi_files)
    unigramProbabilities = {}

    total_notes = sum(note_counts.values())
    for note, count in note_counts.items():
        unigramProbabilities[note] = count / total_notes


    return unigramProbabilities

In [28]:
note_unigram_probability(midi_files)

{60: 0.00465413221821432,
 62: 0.03198518849554697,
 64: 0.038152392504868,
 65: 0.01419223034443132,
 67: 0.07663038273693619,
 69: 0.12331215884061672,
 71: 0.10867941392409104,
 72: 0.059980208765601555,
 61: 0.002081271746416829,
 63: 0.0020876560155776167,
 66: 0.038861046381715454,
 68: 0.004922271522967409,
 70: 0.013585724774156479,
 73: 0.04418552686181249,
 74: 0.1266255945350656,
 75: 0.004443451335908322,
 76: 0.09907108883710537,
 77: 0.016777859354550388,
 78: 0.06201679062789287,
 79: 0.0600121301114055,
 80: 0.005886296166246369,
 81: 0.04112746193379513,
 82: 0.0018578223257892552,
 83: 0.008357008331471254,
 84: 0.0009321032974750215,
 57: 0.002917611006480033,
 59: 0.002777157084942701,
 55: 0.0011363999106202317,
 58: 0.0006575797235611453,
 86: 0.0002681393047530884,
 88: 8.299549909024164e-05,
 85: 0.00041497749545120824,
 52: 0.0004788201870590864,
 53: 0.00018514380566284674,
 50: 9.576403741181728e-05,
 54: 6.3842691607878185e-06,
 56: 8.299549909024164e-05,
 3

3. Generate a table of pairwise probabilities containing p(next_note | previous_note) values for the dataset; write a function that randomly generates the next note based on the previous note based on this distribution.

`note_bigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `bigramTransitions`: key: previous_note, value: a list of next_note, e.g. {60:[62, 64, ..], 62:[60, 64, ..], ...} (i.e., this is a list of every other note that occured after note 60, every note that occured after note 62, etc.)

  - `bigramTransitionProbabilities`: key:previous_note, value: a list of probabilities for next_note in the same order of `bigramTransitions`, e.g. {60:[0.3, 0.4, ..], 62:[0.2, 0.1, ..], ...} (i.e., you are converting the values above to probabilities)

`sample_next_note()`
- **Input**: a note

- **Output**: next note sampled from pairwise probabilities

In [29]:
def note_bigram_probability(midi_files):
    """
    计算音符音高的二元概率
    
    参数:
        midi_files: MIDI文件路径列表
    
    返回:
        bigramTransitions: 前一个音符到下一个音符的映射
        bigramTransitionProbabilities: 对应的转移概率
    """
    # 初始化两个字典
    bigramTransitions = defaultdict(list)
    bigramTransitionProbabilities = defaultdict(list)
    
    # 处理每个MIDI文件
    for midi_file in midi_files:
        # 获取音符序列
        notes = note_extraction(midi_file)
        
        # 统计转移关系
        for i in range(len(notes)-1):
            prev_note = notes[i]
            next_note = notes[i+1]
            bigramTransitions[prev_note].append(next_note)
    
    # 计算转移概率
    for prev_note, next_notes in bigramTransitions.items():
        # 计算每个下一个音符的出现次数
        note_counts = Counter(next_notes)
        total = len(next_notes)
        
        # 计算概率
        probabilities = [count/total for count in note_counts.values()]
        
        # 更新转移概率字典
        bigramTransitionProbabilities[prev_note] = probabilities
        # 更新转移字典，去除重复
        bigramTransitions[prev_note] = list(note_counts.keys())
    
    return bigramTransitions, bigramTransitionProbabilities

In [32]:
bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)

In [33]:
bigramTransitions

defaultdict(list,
            {60: [62,
              63,
              65,
              64,
              60,
              72,
              67,
              55,
              59,
              79,
              77,
              69,
              57,
              73,
              76,
              74,
              58,
              66,
              75,
              71,
              70],
             62: [64,
              62,
              74,
              67,
              63,
              72,
              66,
              69,
              71,
              59,
              57,
              61,
              79,
              65,
              60,
              81,
              73,
              55,
              58,
              70,
              76,
              78,
              68,
              77,
              50],
             64: [65,
              66,
              69,
              62,
              73,
              71,
              67,
              

In [34]:
bigramTransitionProbabilities

defaultdict(list,
            {60: [0.19665271966527198,
              0.03905160390516039,
              0.06415620641562064,
              0.18131101813110181,
              0.19665271966527198,
              0.05160390516039052,
              0.029288702928870293,
              0.019525801952580194,
              0.06834030683403068,
              0.005578800557880056,
              0.015341701534170154,
              0.01394700139470014,
              0.04184100418410042,
              0.002789400278940028,
              0.01813110181311018,
              0.002789400278940028,
              0.015341701534170154,
              0.011157601115760111,
              0.00697350069735007,
              0.009762900976290097,
              0.009762900976290097],
             62: [0.19212022745735174,
              0.15698619008935824,
              0.042648253452477664,
              0.13383428107229894,
              0.012185215272136474,
              0.016653127538586516,
              0

In [31]:
def sample_next_note(prev_note):
    """
    根据前一个音符采样下一个音符
    
    参数:
        prev_note: 前一个音符
    
    返回:
        采样得到的下一个音符
    """
    # 获取全局变量
    global bigramTransitions, bigramTransitionProbabilities
    
    # 如果前一个音符不在训练数据中，返回None
    if prev_note not in bigramTransitions:
        return None
    
    # 获取可能的下一个音符列表和对应的概率
    next_notes = bigramTransitions[prev_note]
    probabilities = bigramTransitionProbabilities[prev_note]
    
    # 使用numpy的choice函数进行采样
    next_note = choice(next_notes, p=probabilities)
    
    return next_note

4. Write a function to calculate the perplexity of your model on a midi file.

    The perplexity of a model is defined as

    $\quad \text{exp}(-\frac{1}{N} \sum_{i=1}^N \text{log}(p(w_i|w_{i-1})))$

    where $p(w_1|w_0) = p(w_1)$, $p(w_i|w_{i-1}) (i>1)$ refers to the pairwise probability p(next_note | previous_note).

`note_bigram_perplexity()`
- **Input**: a midi file

- **Output**: perplexity value

In [13]:
def note_bigram_perplexity(midi_file):
    """
    计算二元模型的困惑度
    
    参数:
        midi_file: MIDI文件路径
    
    返回:
        困惑度值
    """
    # 获取单字概率和二元概率
    unigramProbabilities = note_unigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    
    # 获取音符序列
    notes = note_extraction(midi_file)
    N = len(notes)
    
    if N == 0:
        return float('inf')
    
    # 计算对数概率之和
    log_prob_sum = 0
    
    for i in range(N):
        if i == 0:
            # 第一个音符使用单字概率
            if notes[i] in unigramProbabilities:
                log_prob_sum += np.log(unigramProbabilities[notes[i]])
            else:
                return float('inf')
        else:
            # 其他音符使用二元概率
            prev_note = notes[i-1]
            current_note = notes[i]
            
            if prev_note in bigramTransitions:
                next_notes = bigramTransitions[prev_note]
                probs = bigramTransitionProbabilities[prev_note]
                
                if current_note in next_notes:
                    idx = next_notes.index(current_note)
                    log_prob_sum += np.log(probs[idx])
                else:
                    return float('inf')
            else:
                return float('inf')
    
    # 计算困惑度
    perplexity = np.exp(-log_prob_sum / N)
    
    return perplexity

5. Implement a second-order Markov chain, i.e., one which estimates p(next_note | next_previous_note, previous_note); write a function to compute the perplexity of this new model on a midi file.

    The perplexity of this model is defined as

    $\quad \text{exp}(-\frac{1}{N} \sum_{i=1}^N \text{log}(p(w_i|w_{i-2}, w_{i-1})))$

    where $p(w_1|w_{-1}, w_0) = p(w_1)$, $p(w_2|w_0, w_1) = p(w_2|w_1)$, $p(w_i|w_{i-2}, w_{i-1}) (i>2)$ refers to the probability p(next_note | next_previous_note, previous_note).


`note_trigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `trigramTransitions`: key - (next_previous_note, previous_note), value - a list of next_note, e.g. {(60, 62):[64, 66, ..], (60, 64):[60, 64, ..], ...}

  - `trigramTransitionProbabilities`: key: (next_previous_note, previous_note), value: a list of probabilities for next_note in the same order of `trigramTransitions`, e.g. {(60, 62):[0.2, 0.2, ..], (60, 64):[0.4, 0.1, ..], ...}

`note_trigram_perplexity()`
- **Input**: a midi file

- **Output**: perplexity value

In [14]:
def note_trigram_probability(midi_files):
    """
    计算音符音高的三元概率
    
    参数:
        midi_files: MIDI文件路径列表
    
    返回:
        trigramTransitions: 前两个音符到下一个音符的映射
        trigramTransitionProbabilities: 对应的转移概率
    """
    # 初始化两个字典
    trigramTransitions = defaultdict(list)
    trigramTransitionProbabilities = defaultdict(list)
    
    # 处理每个MIDI文件
    for midi_file in midi_files:
        # 获取音符序列
        notes = note_extraction(midi_file)
        
        # 统计三元转移关系
        for i in range(len(notes)-2):
            next_prev_note = notes[i]
            prev_note = notes[i+1]
            next_note = notes[i+2]
            
            # 使用元组作为键
            key = (next_prev_note, prev_note)
            trigramTransitions[key].append(next_note)
    
    # 计算转移概率
    for key, next_notes in trigramTransitions.items():
        # 计算每个下一个音符的出现次数
        note_counts = Counter(next_notes)
        total = len(next_notes)
        
        # 计算概率
        probabilities = [count/total for count in note_counts.values()]
        
        # 更新转移概率字典
        trigramTransitionProbabilities[key] = probabilities
        # 更新转移字典，去除重复
        trigramTransitions[key] = list(note_counts.keys())
    
    return trigramTransitions, trigramTransitionProbabilities


In [15]:
def note_trigram_perplexity(midi_file):
    """
    计算三元模型的困惑度
    
    参数:
        midi_file: MIDI文件路径
    
    返回:
        困惑度值
    """
    # 获取各种概率模型
    unigramProbabilities = note_unigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    trigramTransitions, trigramTransitionProbabilities = note_trigram_probability(midi_files)
    
    # 获取音符序列
    notes = note_extraction(midi_file)
    N = len(notes)
    
    if N == 0:
        return float('inf')
    
    # 计算对数概率之和
    log_prob_sum = 0
    
    for i in range(N):
        if i == 0:
            # 第一个音符使用单字概率
            if notes[i] in unigramProbabilities:
                log_prob_sum += np.log(unigramProbabilities[notes[i]])
            else:
                return float('inf')
        elif i == 1:
            # 第二个音符使用二元概率
            prev_note = notes[i-1]
            current_note = notes[i]
            
            if prev_note in bigramTransitions:
                next_notes = bigramTransitions[prev_note]
                probs = bigramTransitionProbabilities[prev_note]
                
                if current_note in next_notes:
                    idx = next_notes.index(current_note)
                    log_prob_sum += np.log(probs[idx])
                else:
                    return float('inf')
            else:
                return float('inf')
        else:
            # 其他音符使用三元概率
            next_prev_note = notes[i-2]
            prev_note = notes[i-1]
            current_note = notes[i]
            
            key = (next_prev_note, prev_note)
            if key in trigramTransitions:
                next_notes = trigramTransitions[key]
                probs = trigramTransitionProbabilities[key]
                
                if current_note in next_notes:
                    idx = next_notes.index(current_note)
                    log_prob_sum += np.log(probs[idx])
                else:
                    return float('inf')
            else:
                return float('inf')
    
    # 计算困惑度
    perplexity = np.exp(-log_prob_sum / N)
    
    return perplexity

6. Our model currently doesn’t have any knowledge of beats. Write a function that extracts beat lengths and outputs a list of [(beat position; beat length)] values.

    Recall that each note will be encoded as `Position, Pitch, Velocity, Duration` using REMI. Please keep the `Position` value for beat position, and convert `Duration` to beat length using provided lookup table `duration2length` (see below).

    For example, for a note represented by four tokens `('Position_24', 'Pitch_72', 'Velocity_127', 'Duration_0.4.8')`, the extracted (beat position; beat length) value is `(24, 4)`.

    As a result, we will obtain a list like [(0,8),(8,16),(24,4),(28,4),(0,4)...], where the next beat position is the previous beat position + the beat length. As we divide each bar into 32 positions by default, when reaching the end of a bar (i.e. 28 + 4 = 32 in the case of (28, 4)), the beat position reset to 0.

In [16]:
duration2length = {
    '0.2.8': 2,  # sixteenth note, 0.25 beat in 4/4 time signature
    '0.4.8': 4,  # eighth note, 0.5 beat in 4/4 time signature
    '1.0.8': 8,  # quarter note, 1 beat in 4/4 time signature
    '2.0.8': 16, # half note, 2 beats in 4/4 time signature
    '4.0.4': 32, # whole note, 4 beats in 4/4 time signature
}

`beat_extraction()`
- **Input**: a midi file

- **Output**: a list of (beat position; beat length) values

In [38]:
def beat_extraction(midi_file):
    """
    从MIDI文件中提取节拍位置和长度
    
    参数:
        midi_file: MIDI文件路径
    
    返回:
        节拍位置和长度的列表,格式为[(position, length), ...]
    """
    try:
        # 使用MiDiTok加载MIDI文件
        midi = Score(midi_file)
        tokens = tokenizer(midi)[0].tokens
        
        # 存储节拍信息
        beat_info = []
        
        # 遍历所有标记
        i = 0
        while i < len(tokens):
            token = tokens[i]
            
            # 检查是否是Position标记
            if token.startswith('Position_'):
                position = int(token.split('_')[1])
                
                # 检查下一个标记是否是Duration
                if i + 3 < len(tokens) and tokens[i+3].startswith('Duration_'):
                    duration = tokens[i+3]
                    # 使用映射表获取节拍长度
                    if duration in duration2length:
                        beat_length = duration2length[duration]
                        beat_info.append((position, beat_length))
            
            i += 1
        
        return beat_info
        
    except Exception as e:
        print(f"处理文件 {midi_file} 时出错: {str(e)}")
        return []

In [40]:
beat_extraction(midi_files[1])

[]

7. Implement a Markov chain that computes p(beat_length | previous_beat_length) based on the above function.

`beat_bigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `bigramBeatTransitions`: key: previous_beat_length, value: a list of beat_length, e.g. {4:[8, 2, ..], 8:[8, 4, ..], ...}

  - `bigramBeatTransitionProbabilities`: key - previous_beat_length, value - a list of probabilities for beat_length in the same order of `bigramBeatTransitions`, e.g. {4:[0.3, 0.2, ..], 8:[0.4, 0.4, ..], ...}

In [18]:
def beat_bigram_probability(midi_files):
    """
    计算节拍长度的二元概率
    
    参数:
        midi_files: MIDI文件路径列表
    
    返回:
        bigramBeatTransitions: 前一个节拍长度到下一个节拍长度的映射
        bigramBeatTransitionProbabilities: 对应的转移概率
    """
    # 初始化两个字典
    bigramBeatTransitions = defaultdict(list)
    bigramBeatTransitionProbabilities = defaultdict(list)
    
    # 处理每个MIDI文件
    for midi_file in midi_files:
        # 获取节拍信息
        beat_info = beat_extraction(midi_file)
        
        # 统计转移关系
        for i in range(len(beat_info)-1):
            prev_length = beat_info[i][1]  # 获取前一个节拍长度
            next_length = beat_info[i+1][1]  # 获取下一个节拍长度
            bigramBeatTransitions[prev_length].append(next_length)
    
    # 计算转移概率
    for prev_length, next_lengths in bigramBeatTransitions.items():
        # 计算每个下一个节拍长度的出现次数
        length_counts = Counter(next_lengths)
        total = len(next_lengths)
        
        # 计算概率
        probabilities = [count/total for count in length_counts.values()]
        
        # 更新转移概率字典
        bigramBeatTransitionProbabilities[prev_length] = probabilities
        # 更新转移字典，去除重复
        bigramBeatTransitions[prev_length] = list(length_counts.keys())
    
    return bigramBeatTransitions, bigramBeatTransitionProbabilities

8. Implement a function to compute p(beat length | beat position), and compute the perplexity of your models from Q7 and Q8. For both models, we only consider the probabilities of predicting the sequence of **beat lengths**.

`beat_pos_bigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `bigramBeatPosTransitions`: key - beat_position, value - a list of beat_length

  - `bigramBeatPosTransitionProbabilities`: key - beat_position, value - a list of probabilities for beat_length in the same order of `bigramBeatPosTransitions`

`beat_bigram_perplexity()`
- **Input**: a midi file

- **Output**: two perplexity values correspond to the models in Q7 and Q8, respectively

In [19]:
def beat_pos_bigram_probability(midi_files):
    """
    计算节拍位置到节拍长度的概率
    
    参数:
        midi_files: MIDI文件路径列表
    
    返回:
        bigramBeatPosTransitions: 节拍位置到节拍长度的映射
        bigramBeatPosTransitionProbabilities: 对应的概率
    """
    # 初始化两个字典
    bigramBeatPosTransitions = defaultdict(list)
    bigramBeatPosTransitionProbabilities = defaultdict(list)
    
    # 处理每个MIDI文件
    for midi_file in midi_files:
        # 获取节拍信息
        beat_info = beat_extraction(midi_file)
        
        # 统计转移关系
        for position, length in beat_info:
            bigramBeatPosTransitions[position].append(length)
    
    # 计算转移概率
    for position, lengths in bigramBeatPosTransitions.items():
        # 计算每个节拍长度的出现次数
        length_counts = Counter(lengths)
        total = len(lengths)
        
        # 计算概率
        probabilities = [count/total for count in length_counts.values()]
        
        # 更新转移概率字典
        bigramBeatPosTransitionProbabilities[position] = probabilities
        # 更新转移字典，去除重复
        bigramBeatPosTransitions[position] = list(length_counts.keys())
    
    return bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities


In [20]:

def beat_bigram_perplexity(midi_file):
    """
    计算两个模型的困惑度
    
    参数:
        midi_file: MIDI文件路径
    
    返回:
        perplexity_Q7: 基于前一个节拍长度的模型的困惑度
        perplexity_Q8: 基于节拍位置的模型的困惑度
    """
    # 获取两个概率模型
    bigramBeatTransitions, bigramBeatTransitionProbabilities = beat_bigram_probability(midi_files)
    bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)
    
    # 获取节拍信息
    beat_info = beat_extraction(midi_file)
    N = len(beat_info)
    
    if N == 0:
        return float('inf'), float('inf')
    
    # 计算两个模型的对数概率之和
    log_prob_sum_Q7 = 0
    log_prob_sum_Q8 = 0
    
    for i in range(N):
        position, length = beat_info[i]
        
        # 计算Q7模型的对数概率
        if i > 0:
            prev_length = beat_info[i-1][1]
            if prev_length in bigramBeatTransitions:
                next_lengths = bigramBeatTransitions[prev_length]
                probs = bigramBeatTransitionProbabilities[prev_length]
                
                if length in next_lengths:
                    idx = next_lengths.index(length)
                    log_prob_sum_Q7 += np.log(probs[idx])
                else:
                    return float('inf'), float('inf')
            else:
                return float('inf'), float('inf')
        
        # 计算Q8模型的对数概率
        if position in bigramBeatPosTransitions:
            lengths = bigramBeatPosTransitions[position]
            probs = bigramBeatPosTransitionProbabilities[position]
            
            if length in lengths:
                idx = lengths.index(length)
                log_prob_sum_Q8 += np.log(probs[idx])
            else:
                return float('inf'), float('inf')
        else:
            return float('inf'), float('inf')
    
    # 计算困惑度
    perplexity_Q7 = np.exp(-log_prob_sum_Q7 / N)
    perplexity_Q8 = np.exp(-log_prob_sum_Q8 / N)
    
    return perplexity_Q7, perplexity_Q8

9. Implement a Markov chain that computes p(beat_length | previous_beat_length, beat_position), and report its perplexity.

`beat_trigram_probability()`
- **Input**: all midi files `midi_files`

- **Output**: two dictionaries:

  - `trigramBeatTransitions`: key: (previous_beat_length, beat_position), value: a list of beat_length

  - `trigramBeatTransitionProbabilities`: key: (previous_beat_length, beat_position), value: a list of probabilities for beat_length in the same order of `trigramBeatTransitions`

`beat_trigram_perplexity()`
- **Input**: a midi file

- **Output**: perplexity value

In [21]:
def beat_trigram_probability(midi_files):
    """
    计算节拍长度的三元概率
    
    参数:
        midi_files: MIDI文件路径列表
    
    返回:
        trigramBeatTransitions: 前一个节拍长度和当前位置到下一个节拍长度的映射
        trigramBeatTransitionProbabilities: 对应的转移概率
    """
    # 初始化两个字典
    trigramBeatTransitions = defaultdict(list)
    trigramBeatTransitionProbabilities = defaultdict(list)
    
    # 处理每个MIDI文件
    for midi_file in midi_files:
        # 获取节拍信息
        beat_info = beat_extraction(midi_file)
        
        # 统计三元转移关系
        for i in range(len(beat_info)-1):
            prev_length = beat_info[i][1]  # 前一个节拍长度
            current_position = beat_info[i+1][0]  # 当前位置
            next_length = beat_info[i+1][1]  # 下一个节拍长度
            
            # 使用元组作为键
            key = (prev_length, current_position)
            trigramBeatTransitions[key].append(next_length)
    
    # 计算转移概率
    for key, next_lengths in trigramBeatTransitions.items():
        # 计算每个下一个节拍长度的出现次数
        length_counts = Counter(next_lengths)
        total = len(next_lengths)
        
        # 计算概率
        probabilities = [count/total for count in length_counts.values()]
        
        # 更新转移概率字典
        trigramBeatTransitionProbabilities[key] = probabilities
        # 更新转移字典，去除重复
        trigramBeatTransitions[key] = list(length_counts.keys())
    
    return trigramBeatTransitions, trigramBeatTransitionProbabilities


In [22]:

def beat_trigram_perplexity(midi_file):
    """
    计算三元模型的困惑度
    
    参数:
        midi_file: MIDI文件路径
    
    返回:
        困惑度值
    """
    # 获取概率模型
    trigramBeatTransitions, trigramBeatTransitionProbabilities = beat_trigram_probability(midi_files)
    bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)
    
    # 获取节拍信息
    beat_info = beat_extraction(midi_file)
    N = len(beat_info)
    
    if N == 0:
        return float('inf')
    
    # 计算对数概率之和
    log_prob_sum = 0
    
    for i in range(N):
        position, length = beat_info[i]
        
        if i == 0:
            # 第一个节拍使用位置概率
            if position in bigramBeatPosTransitions:
                lengths = bigramBeatPosTransitions[position]
                probs = bigramBeatPosTransitionProbabilities[position]
                
                if length in lengths:
                    idx = lengths.index(length)
                    log_prob_sum += np.log(probs[idx])
                else:
                    return float('inf')
            else:
                return float('inf')
        else:
            # 其他节拍使用三元概率
            prev_length = beat_info[i-1][1]
            key = (prev_length, position)
            
            if key in trigramBeatTransitions:
                next_lengths = trigramBeatTransitions[key]
                probs = trigramBeatTransitionProbabilities[key]
                
                if length in next_lengths:
                    idx = next_lengths.index(length)
                    log_prob_sum += np.log(probs[idx])
                else:
                    return float('inf')
            else:
                return float('inf')
    
    # 计算困惑度
    perplexity = np.exp(-log_prob_sum / N)
    
    return perplexity

10. Use the model from Q5 to generate N notes, and the model from Q8 to generate beat lengths for each note. Save the generated music as a midi file (see code from workbook1) as q10.mid. Remember to reset the beat position to 0 when reaching the end of a bar.

`music_generate`
- **Input**: target length, e.g. 500

- **Output**: a midi file q10.mid

Note: the duration of one beat in MIDIUtil is 1, while in MidiTok is 8. Divide beat length by 8 if you use methods in MIDIUtil to save midi files.

In [23]:
def music_generate(length):
    """
    生成音乐并保存为MIDI文件
    
    参数:
        length: 目标音符数量
    
    返回:
        无，直接保存为q10.mid文件
    """
    # 获取概率模型
    unigramProbabilities = note_unigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    trigramTransitions, trigramTransitionProbabilities = note_trigram_probability(midi_files)
    bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)
    
    # 初始化MIDI文件
    midi = MIDIFile(1)  # 一个音轨
    midi.addTempo(0, 0, 120)  # 设置速度为120 BPM
    
    # 生成音符和节拍
    current_position = 0
    current_time = 0
    prev_note = None
    prev_prev_note = None
    
    for i in range(length):
        # 生成音符
        if i == 0:
            # 第一个音符使用单字概率
            notes = list(unigramProbabilities.keys())
            probs = list(unigramProbabilities.values())
            current_note = choice(notes, p=probs)
        elif i == 1:
            # 第二个音符使用二元概率
            if prev_note in bigramTransitions:
                notes = bigramTransitions[prev_note]
                probs = bigramTransitionProbabilities[prev_note]
                current_note = choice(notes, p=probs)
            else:
                # 如果前一个音符不在训练数据中，使用单字概率
                notes = list(unigramProbabilities.keys())
                probs = list(unigramProbabilities.values())
                current_note = choice(notes, p=probs)
        else:
            # 其他音符使用三元概率
            key = (prev_prev_note, prev_note)
            if key in trigramTransitions:
                notes = trigramTransitions[key]
                probs = trigramTransitionProbabilities[key]
                current_note = choice(notes, p=probs)
            else:
                # 如果前两个音符的组合不在训练数据中，使用二元概率
                if prev_note in bigramTransitions:
                    notes = bigramTransitions[prev_note]
                    probs = bigramTransitionProbabilities[prev_note]
                    current_note = choice(notes, p=probs)
                else:
                    # 如果前一个音符也不在训练数据中，使用单字概率
                    notes = list(unigramProbabilities.keys())
                    probs = list(unigramProbabilities.values())
                    current_note = choice(notes, p=probs)
        
        # 生成节拍长度
        if current_position in bigramBeatPosTransitions:
            lengths = bigramBeatPosTransitions[current_position]
            probs = bigramBeatPosTransitionProbabilities[current_position]
            beat_length = choice(lengths, p=probs)
        else:
            # 如果当前位置不在训练数据中，使用默认长度
            beat_length = 4  # 默认使用八分音符
        
        # 添加音符到MIDI文件
        # 注意：MIDIUtil中一个节拍的长度是1，而MidiTok中是8
        duration = beat_length / 8
        midi.addNote(0, 0, current_note, current_time, duration, 100)
        
        # 更新时间和位置
        current_time += duration
        current_position = (current_position + beat_length) % 32  # 重置到0-31范围内
        
        # 更新前一个音符
        prev_prev_note = prev_note
        prev_note = current_note
    
    # 保存MIDI文件
    with open('q10.mid', 'wb') as f:
        midi.writeFile(f)