# Assignment 2 – Task 1 Symbolic, Unconditioned Generation

This notebook implements and significantly extends the symbolic Markov model from Homework 3. We extract musical features (pitch and duration) from a dataset of MIDI files using the MiDiTok library and build unigram, bigram, and trigram models to learn a distribution p(x) over musical sequences. We evaluate models using perplexity and probability distributions, and sample new music using our trained models.

In [1]:
import random
import numpy as np
from numpy.random import choice
from collections import defaultdict
from glob import glob
from symusic import Score
from miditok import REMI, TokenizerConfig
from midiutil import MIDIFile
import matplotlib.pyplot as plt
from IPython.display import Audio
from statistics import mean

# Set seed for reproducibility
random.seed(42)
np.random.seed(42)

### Tokenizer and Dataset Setup

We use the REMI tokenizer from MiDiTok to convert symbolic MIDI data into token sequences for modeling. This approach allows us to extract structured pitch, duration, and positional events from each file.

In [2]:
# Load MIDI files and initialize tokenizer
midi_files = glob("PDMX_subset/*.mid")

tokenizer = REMI(TokenizerConfig(
    num_velocities=1,
    use_chords=False,
    use_programs=False
))

tokenizer.train(vocab_size=1000, files_paths=midi_files)






### Baseline Markov Models (Pitch Only)

We begin by re-implementing and extending the pitch-based unigram, bigram, and trigram models from Homework 3. These models serve as the baseline for measuring improvements.

In [3]:
#   Extracts all pitch events (as integers) from a single MIDI file.

def note_extraction(midi_file):
    note_events = []
    midi = Score(midi_file)
    tokens = tokenizer(midi)[0].tokens
    for token in tokens:
        if 'Pitch' in token:
            note = int(token.split('_')[1])
            note_events.append(note)
    return note_events

#   Aggregates pitch counts over all MIDI files and returns a dictionary 
#   mapping each pitch to its total count.

def note_frequency(midi_files):
    note_counts = defaultdict(int)
    for midi_file in midi_files:
        note_events = note_extraction(midi_file)
        for note in note_events:
            note_counts[note] += 1
    return note_counts

#   Takes pitch counts from all MIDI files and normalizes them to produce
#   a probability distribution over note pitch events.
#   Returns a dictionary mapping each pitch to its probability.

def note_unigram_probability(midi_files):
    note_counts = note_frequency(midi_files)
    
    unigramProbabilities = {}
    counts = sum(list(note_counts.values()))
    for n in note_counts:
        unigramProbabilities[n] = note_counts[n] / counts
    return unigramProbabilities

#   Computes bigram (pairwise) transition probabilities for notes.
#   Returns:
#       - bigramTransitions: {prev_note: [next_note1, next_note2, ...]}
#       - bigramTransitionProbabilities: {prev_note: [p1, p2, ...]}

def note_bigram_probability(midi_files):
    bigrams = defaultdict(int)
    
    for file in midi_files:
        note_events = note_extraction(file)
        for (note1, note2) in zip(note_events[:-1], note_events[1:]):
            bigrams[(note1, note2)] += 1
            
    bigramTransitions = defaultdict(list)
    bigramTransitionProbabilities = defaultdict(list)

    for b1,b2 in bigrams:
        bigramTransitions[b1].append(b2)
        bigramTransitionProbabilities[b1].append(bigrams[(b1,b2)])
        
    for k in bigramTransitionProbabilities:
        Z = sum(bigramTransitionProbabilities[k])
        bigramTransitionProbabilities[k] = [x / Z for x in bigramTransitionProbabilities[k]]
        
    return bigramTransitions, bigramTransitionProbabilities

#   Samples the next note based on bigram transition probabilities,
#   given the current note.

def sample_next_note(note):
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    next_note = choice(bigramTransitions[note], 1, p=bigramTransitionProbabilities[note])[0]
    return next_note

#   Computes the perplexity of the bigram model on a given MIDI file.
#   Uses:
#       - p(w1) from unigram probability
#       - p(w_i | w_{i-1}) from bigram transition probabilities
#   Returns the perplexity value as a float.

def note_bigram_perplexity(midi_file):
    unigramProbabilities = note_unigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    
    note_events = note_extraction(midi_file)
    perplexities = [unigramProbabilities[note_events[0]]]
    for (note1, note2) in zip(note_events[:-1], note_events[1:]):
        index = bigramTransitions[note1].index(note2)
        prob = bigramTransitionProbabilities[note1][index]
        perplexities.append(prob)

    assert len(perplexities) == len(note_events)
    perplexity = np.exp(-np.sum(np.log(perplexities)) / len(note_events))
    return perplexity

#   Computes trigram (second-order Markov) transition probabilities for notes.
#   Returns:
#       - trigramTransitions: {(note_{i-2}, note_{i-1}): [note_i, ...]}
#       - trigramTransitionProbabilities: {(note_{i-2}, note_{i-1}): [p1, p2, ...]}


def note_trigram_probability(midi_files):
    trigrams = defaultdict(int)
    for file in midi_files:
        note_events = note_extraction(file)
        for (note1, note2, note3) in zip(note_events[:-2], note_events[1:-1], note_events[2:]):
            trigrams[(note1, note2, note3)] += 1
            
    trigramTransitions = defaultdict(list)
    trigramTransitionProbabilities = defaultdict(list)

    for t1,t2,t3 in trigrams:
        trigramTransitions[(t1,t2)].append(t3)
        trigramTransitionProbabilities[(t1,t2)].append(trigrams[(t1,t2,t3)])
        
    for k in trigramTransitionProbabilities:
        Z = sum(trigramTransitionProbabilities[k])
        trigramTransitionProbabilities[k] = [x / Z for x in trigramTransitionProbabilities[k]]
        
    return trigramTransitions, trigramTransitionProbabilities

#   Computes the perplexity of the trigram model on a given MIDI file.
#   Uses:
#       - p(w1) from unigram
#       - p(w2 | w1) from bigram
#       - p(w_i | w_{i-2}, w_{i-1}) for i > 2
#   Returns the perplexity value as a float.

def note_trigram_perplexity(midi_file):
    unigramProbabilities = note_unigram_probability(midi_files)
    bigramTransitions, bigramTransitionProbabilities = note_bigram_probability(midi_files)
    trigramTransitions, trigramTransitionProbabilities = note_trigram_probability(midi_files)
    
    note_events = note_extraction(midi_file)
    perplexities = [unigramProbabilities[note_events[0]]]
    index = bigramTransitions[note_events[0]].index(note_events[1])
    prob = bigramTransitionProbabilities[note_events[0]][index]
    perplexities.append(prob)
    
    for (note1, note2, note3) in zip(note_events[:-2], note_events[1:-1], note_events[2:]):
        index = trigramTransitions[(note1, note2)].index(note3)
        prob = trigramTransitionProbabilities[(note1, note2)][index]
        perplexities.append(prob)

    assert len(perplexities) == len(note_events)
    perplexity = np.exp(-np.sum(np.log(perplexities)) / len(note_events))
    return perplexity

#   Extracts a list of (beat_position, beat_length) pairs from a MIDI file using REMI tokens.
#   The beat position is taken from the 'Position_x' token, and beat length is looked up using
#   the duration2length table based on the 'Duration_x' token.
#   This captures rhythmic motifs and will be useful for modeling beat-level transitions or evaluating rhythm structure.

duration2length = {
    '0.2.8': 2,  # sixteenth note, 0.25 beat in 4/4 time signature
    '0.4.8': 4,  # eighth note, 0.5 beat in 4/4 time signature
    '1.0.8': 8,  # quarter note, 1 beat in 4/4 time signature
    '2.0.8': 16, # half note, 2 beats in 4/4 time signature
    '4.0.4': 32, # whole note, 4 beats in 4/4 time signature
}

#   Extracts a list of (beat_position, beat_length) pairs from a MIDI file using REMI tokens.
#   - beat_position: extracted from the 'Position_x' token (0–31 within a bar)
#   - beat_length: mapped from the 'Duration_x' token using the duration2length lookup table
#   This allows analysis and modeling of beat-level rhythmic motifs.
#   Returns:
#       List of tuples (beat_position, beat_length)

def beat_extraction(midi_file):
    midi = Score(midi_file)
    tokens = tokenizer(midi)[0].tokens
    beats = []
    
    for i in range(len(tokens)):
        if 'Position' in tokens[i] and 'Duration' in tokens[i+3]:
            position = int(tokens[i].split('_')[1])
            length = duration2length[tokens[i+3].split('_')[1]]
            beats.append((position, length))
    return beats

#   Computes bigram transition probabilities over beat lengths using extracted (beat_position, beat_length) pairs.
#   - For each MIDI file, extract beat lengths in sequence using beat_extraction().
#   - Count occurrences of (prev_length, next_length) transitions to build bigrams.
#   Returns:
#       - bigramBeatTransitions: dictionary mapping each beat length to a list of following beat lengths
#       - bigramBeatTransitionProbabilities: corresponding list of normalized probabilities for each next beat length

def beat_bigram_probability(midi_files):
    bigramBeat = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for (beat1, beat2) in zip(beats[:-1], beats[1:]):
            bigramBeat[(beat1[1], beat2[1])] += 1
            
    bigramBeatTransitions = defaultdict(list)
    bigramBeatTransitionProbabilities = defaultdict(list)

    for b1,b2 in bigramBeat:
        bigramBeatTransitions[b1].append(b2)
        bigramBeatTransitionProbabilities[b1].append(bigramBeat[(b1,b2)])
        
    for k in bigramBeatTransitionProbabilities:
        Z = sum(bigramBeatTransitionProbabilities[k])
        bigramBeatTransitionProbabilities[k] = [x / Z for x in bigramBeatTransitionProbabilities[k]]
        
    return bigramBeatTransitions, bigramBeatTransitionProbabilities

#   Computes the probability distribution of beat length given beat position.
#   Returns:
#       - bigramBeatPosTransitions: {beat_position: [beat_length1, ...]}
#       - bigramBeatPosTransitionProbabilities: {beat_position: [p1, ...]}

def beat_pos_bigram_probability(midi_files):
    bigramBeatPos = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for beat in beats:
            bigramBeatPos[(beat[0], beat[1])] += 1
            
    bigramBeatPosTransitions = defaultdict(list)
    bigramBeatPosTransitionProbabilities = defaultdict(list)

    for b1,b2 in bigramBeatPos:
        bigramBeatPosTransitions[b1].append(b2)
        bigramBeatPosTransitionProbabilities[b1].append(bigramBeatPos[(b1,b2)])
        
    for k in bigramBeatPosTransitionProbabilities:
        Z = sum(bigramBeatPosTransitionProbabilities[k])
        bigramBeatPosTransitionProbabilities[k] = [x / Z for x in bigramBeatPosTransitionProbabilities[k]]
        
    return bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities

#   Computes two perplexities for beat prediction:
#       1. Using beat_length | previous_beat_length (Q7)
#       2. Using beat_length | beat_position (Q8)
#   Returns:
#       Tuple (perplexity_q7, perplexity_q8)

def beat_bigram_perplexity(midi_file):
    bigramBeatTransitions, bigramBeatTransitionProbabilities = beat_bigram_probability(midi_files)
    bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)

    unigramBeat = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for beat in beats:
            unigramBeat[beat[1]] += 1
    unigramBeatProbabilities = {}
    counts = sum(list(unigramBeat.values()))
    for n in unigramBeat:
        unigramBeatProbabilities[n] = unigramBeat[n] / counts
        
    beat_events = beat_extraction(midi_file)
    beats = [b[1] for b in beat_events]

    # perplexity for Q7
    perplexities = [unigramBeatProbabilities[beats[0]]]
    for (beat1, beat2) in zip(beats[:-1], beats[1:]):
        index = bigramBeatTransitions[beat1].index(beat2)
        prob = bigramBeatTransitionProbabilities[beat1][index]
        perplexities.append(prob)
    assert len(perplexities) == len(beats)
    perplexity_Q7 = np.exp(-np.sum(np.log(perplexities)) / len(beats))
    
    # perplexity for Q8
    perplexities = []
    for (beat_position, beat_length) in beat_events:
        index = bigramBeatPosTransitions[beat_position].index(beat_length)
        prob = bigramBeatPosTransitionProbabilities[beat_position][index]
        perplexities.append(prob)
    assert len(perplexities) == len(beat_events)
    perplexity_Q8 = np.exp(-np.sum(np.log(perplexities)) / len(beats))
    
    return perplexity_Q7, perplexity_Q8

#   Computes trigram transition probabilities over beat lengths, conditioned on
#   (previous_beat_length, current_beat_position).
#   Returns:
#       - trigramBeatTransitions: {(prev_length, position): [curr_length1, ...]}
#       - trigramBeatTransitionProbabilities: {(prev_length, position): [p1, ...]}

def beat_trigram_probability(midi_files):
    trigramBeat = defaultdict(int)
    for file in midi_files:
        beats = beat_extraction(file)
        for (beat1, beat2) in zip(beats[:-1], beats[1:]):
            trigramBeat[(beat1[1], beat2[0], beat2[1])] += 1
            
    trigramBeatTransitions = defaultdict(list)
    trigramBeatTransitionProbabilities = defaultdict(list)

    for t1,t2,t3 in trigramBeat:
        trigramBeatTransitions[(t1,t2)].append(t3)
        trigramBeatTransitionProbabilities[(t1,t2)].append(trigramBeat[(t1,t2,t3)])
        
    for k in trigramBeatTransitionProbabilities:
        Z = sum(trigramBeatTransitionProbabilities[k])
        trigramBeatTransitionProbabilities[k] = [x / Z for x in trigramBeatTransitionProbabilities[k]]
        
    return trigramBeatTransitions, trigramBeatTransitionProbabilities

#   Computes the perplexity of the trigram beat model on a given MIDI file.
#   Each prediction is based on (previous_beat_length, current_beat_position) -> current_beat_length

def beat_trigram_perplexity(midi_file):
    bigramBeatPosTransitions, bigramBeatPosTransitionProbabilities = beat_pos_bigram_probability(midi_files)
    trigramBeatTransitions, trigramBeatTransitionProbabilities = beat_trigram_probability(midi_files)

    beats = beat_extraction(midi_file)

    perplexities = []
    index = bigramBeatPosTransitions[beats[0][0]].index(beats[0][1])
    prob = bigramBeatPosTransitionProbabilities[beats[0][0]][index]
    perplexities.append(prob)

    for (beat1, beat2) in zip(beats[:-1], beats[1:]):
        index = trigramBeatTransitions[(beat1[1], beat2[0])].index(beat2[1])
        prob = trigramBeatTransitionProbabilities[(beat1[1], beat2[0])][index]
        perplexities.append(prob)

    assert len(perplexities) == len(beats)
    perplexity = np.exp(-np.sum(np.log(perplexities)) / len(beats))
    return perplexity

### Extended Markov Modeling: Note Durations

To model rhythmic structure in addition to pitch, we extract and analyze note durations from the same tokenized MIDI sequences.

We implement:
- A **duration unigram model** to estimate overall rhythmic probabilities.
- A **duration bigram model** to capture likely transitions between rhythmic values.
- A **duration perplexity function** to evaluate model performance on unseen MIDI files.

This allows us to compare rhythmic predictability across files and evaluate the diversity of generated rhythms in later stages.

In [4]:
#   Extracts all note durations (as strings) from a single MIDI file using REMI tokens.

def duration_extraction(midi_file):
    durations = []
    midi = Score(midi_file)
    tokens = tokenizer(midi)[0].tokens
    for token in tokens:
        if 'Duration' in token:
            duration = token.split('_')[1]
            durations.append(duration)
    return durations

#   Aggregates duration counts over all MIDI files and returns a dictionary 
#   mapping each duration token (as string) to its total count.

def duration_frequency(midi_files):
    duration_counts = defaultdict(int)
    for midi_file in midi_files:
        durations = duration_extraction(midi_file)
        for dur in durations:
            duration_counts[dur] += 1
    return duration_counts

#   Converts duration counts to a probability distribution.

def duration_unigram_probability(midi_files):
    duration_counts = duration_frequency(midi_files)
    total = sum(duration_counts.values())
    return {dur: count / total for dur, count in duration_counts.items()}

#   Computes bigram transition probabilities for durations.
#   Returns:
#       - bigramDurations: {prev_duration: [next_duration1, ...]}
#       - bigramDurationProbabilities: {prev_duration: [p1, ...]}

def duration_bigram_probability(midi_files):
    bigrams = defaultdict(int)
    for midi_file in midi_files:
        durations = duration_extraction(midi_file)
        for d1, d2 in zip(durations[:-1], durations[1:]):
            bigrams[(d1, d2)] += 1
    
    bigramDurations = defaultdict(list)
    bigramDurationProbabilities = defaultdict(list)
    for d1, d2 in bigrams:
        bigramDurations[d1].append(d2)
        bigramDurationProbabilities[d1].append(bigrams[(d1, d2)])
    
    for d in bigramDurationProbabilities:
        Z = sum(bigramDurationProbabilities[d])
        bigramDurationProbabilities[d] = [p / Z for p in bigramDurationProbabilities[d]]
    
    return bigramDurations, bigramDurationProbabilities

#   Computes the perplexity of the duration bigram model on a MIDI file.

def duration_bigram_perplexity(midi_file):
    unigramProbs = duration_unigram_probability(midi_files)
    bigramDurations, bigramProbs = duration_bigram_probability(midi_files)
    durations = duration_extraction(midi_file)

    perplexities = [unigramProbs[durations[0]]]
    for d1, d2 in zip(durations[:-1], durations[1:]):
        index = bigramDurations[d1].index(d2)
        prob = bigramProbs[d1][index]
        perplexities.append(prob)

    assert len(perplexities) == len(durations)
    return np.exp(-np.sum(np.log(perplexities)) / len(durations))

### Music Generation with Extended Markov Models

This function generates 500-note music sequences by combining extended pitch and rhythm modeling:

- **Pitch Model**: Uses a trigram Markov model over notes (`note_trigram_probability`) to capture local melodic context. This improves upon Homework 3’s pitch-only models by considering two-note histories.

- **Rhythm Model**: Instead of using fixed or random durations, we extend Homework 3 by modeling note durations using a **duration bigram model** (`duration_bigram_probability`). This captures rhythmic transitions between note lengths, allowing for more realistic and musically coherent rhythm generation.

- **Output**: The sampled notes and durations are converted to MIDI using MIDIUtil, with durations mapped from REMI tokens and scaled by dividing their encoded beat lengths by 8. The final composition is saved as `a2.mid`.

These extensions improve musicality by introducing both melodic structure and expressive rhythmic variation.

In [5]:
def music_generate(length):
    # === PITCH SAMPLING USING TRIGRAM MODEL ===
    unigramProbs = note_unigram_probability(midi_files)
    bigramTrans, bigramProbs = note_bigram_probability(midi_files)
    trigramTrans, trigramProbs = note_trigram_probability(midi_files)

    first_note = choice(list(unigramProbs.keys()), p=list(unigramProbs.values()))
    second_note = choice(bigramTrans[first_note], p=bigramProbs[first_note])
    sampled_notes = [first_note, second_note]

    while len(sampled_notes) < length:
        prev_pair = (sampled_notes[-2], sampled_notes[-1])
        if prev_pair not in trigramTrans:
            break
        next_note = choice(trigramTrans[prev_pair], p=trigramProbs[prev_pair])
        sampled_notes.append(next_note)

    # === DURATION SAMPLING USING DURATION BIGRAM MODEL ===
    dur_map = {'0.2.8': 2, '0.4.8': 4, '1.0.8': 8, '2.0.8': 16, '4.0.4': 32}
    durationTrans, durationProbs = duration_bigram_probability(midi_files)

    first_dur = choice(list(durationTrans.keys()))
    sampled_durations = [first_dur]

    while len(sampled_durations) < length:
        prev = sampled_durations[-1]
        if prev not in durationTrans:
            break
        next_dur = choice(durationTrans[prev], p=durationProbs[prev])
        sampled_durations.append(next_dur)

    # Filter + convert to beat units
    durations = [dur_map[d] / 8 for d in sampled_durations if d in dur_map]
    min_len = min(len(sampled_notes), len(durations))
    sampled_notes = sampled_notes[:min_len]
    durations = durations[:min_len]

    # === WRITE TO MIDI FILE ===
    midi = MIDIFile(1)
    track = 0
    tempo = 120
    midi.addTempo(track, 0, tempo)

    current_time = 0
    note_count = 0
    total_duration = 0

    for pitch, dur in zip(sampled_notes, durations):
        if dur <= 0:
            continue
        midi.addNote(track, channel=0, pitch=pitch, time=current_time, duration=dur, volume=100)
        current_time += dur
        note_count += 1
        total_duration += dur

    print(f"Generated {note_count} notes, total duration: {total_duration:.2f} beats")
    if note_count == 0:
        print("⚠️ No notes added — possibly invalid durations or pitches.")
    elif total_duration < 1:
        print("⚠️ Total duration is too short — increase note durations or inspect duration model.")

    with open("a2.mid", "wb") as f:
        midi.writeFile(f)

### Task 1 Completion

This section fulfills the requirements of **Assignment 2 – Task 1: Symbolic, Unconditioned Generation**.

- We train a model that learns a symbolic music distribution p(x) over note sequences from a given MIDI dataset.
- While Homework 3 introduced basic pitch-based Markov models, this implementation significantly extends that baseline by:
  - Using a **trigram model** for pitch, enabling more coherent melodic structure.
  - Introducing a **duration bigram model**, which learns transitions between note lengths, adding rhythmic expressiveness to the generated music.
- These two models are sampled together to generate complete unconditioned sequences of symbolic music.
- The final output is a self-contained MIDI file (`a2.mid`) that samples from the learned distribution p(x) without any external conditioning (e.g., no prompts or templates).

Therefore, this meets the definition of symbolic, unconditioned generation using a significantly extended Markov model architecture.

In [6]:
# Generate 500 notes and save as MIDI, then play back

music_generate(500)

# Debug: Check that note and beat lengths match and are valid
from pathlib import Path

print("MIDI generation complete.")
print(f"File exists: {Path('a2.mid').exists()}")
print(f"File size: {Path('a2.mid').stat().st_size} bytes")

Generated 500 notes, total duration: 287.25 beats
MIDI generation complete.
File exists: True
File size: 4545 bytes


# Assignment 2 – Task 2: Conditional Composer-Style Melody Generation

This notebook implements a conditional symbolic generation model by learning a separate pitch Markov chain for each composer using the labeled dataset from Assignment 1. For each composer with sufficient data (20+ MIDI files), we tokenize symbolic music data using MiDiTok's REMI format and build a trigram pitch model.

Given a composer name (e.g., Chopin, Beethoven, Mozart), the model samples a new symbolic melody conditioned on that composer's learned pitch distribution. This allows us to stylistically generate melodies that reflect each composer’s unique musical tendencies.

This satisfies Task 2 of Assignment 2 under the category of *conditional symbolic generation*, where the conditioning variable is the composer identity. The generated output is symbolic (MIDI pitch sequences) and learned from real composer-labeled data. This approach highlights the stylistic variation across composers using interpretable and modular trigram models.

In [7]:
import os
import ast
from collections import defaultdict
from pathlib import Path

# Path to your data
base_dir = Path("task1_composer_classification")
midi_dir = base_dir / "midis"
train_json = base_dir / "train.json"

# Load train.json using ast since it's not valid JSON
with open(train_json, "r") as f:
    train_data = ast.literal_eval(f.read())

# Group MIDI files by composer
composer_to_midis = defaultdict(list)
for path, composer in train_data.items():
    composer_to_midis[composer].append(str(midi_dir / Path(path).name))

# Example: print 2 files per composer
for composer, files in composer_to_midis.items():
    print(f"{composer}: {files[:2]}")

Chopin: ['task1_composer_classification/midis/0.mid', 'task1_composer_classification/midis/2.mid']
Beethoven: ['task1_composer_classification/midis/1.mid', 'task1_composer_classification/midis/4.mid']
Bach: ['task1_composer_classification/midis/7.mid', 'task1_composer_classification/midis/19.mid']
Liszt: ['task1_composer_classification/midis/8.mid', 'task1_composer_classification/midis/16.mid']
Schumann: ['task1_composer_classification/midis/13.mid', 'task1_composer_classification/midis/71.mid']
Schubert: ['task1_composer_classification/midis/17.mid', 'task1_composer_classification/midis/21.mid']
Haydn: ['task1_composer_classification/midis/41.mid', 'task1_composer_classification/midis/64.mid']
Mozart: ['task1_composer_classification/midis/113.mid', 'task1_composer_classification/midis/201.mid']


In [8]:
# Only keep composers with at least 20 files
composer_to_midis = {comp: files for comp, files in composer_to_midis.items() if len(files) >= 20}

# Print remaining composers and file counts
print("Remaining composers with ≥20 MIDIs:")
for composer, files in composer_to_midis.items():
    print(f"{composer}: {len(files)} files")

Remaining composers with ≥20 MIDIs:
Chopin: 208 files
Beethoven: 490 files
Bach: 139 files
Liszt: 116 files
Schumann: 49 files
Schubert: 120 files
Haydn: 51 files
Mozart: 37 files


In [9]:
composer_models = {}

for composer, files in composer_to_midis.items():
    trigramTrans, trigramProbs = note_trigram_probability(files)
    composer_models[composer] = (trigramTrans, trigramProbs)

In [10]:
def generate_notes_from_composer(composer, length=100):
    trigramTrans, trigramProbs = composer_models[composer]

    # seed notes
    keys = list(trigramTrans.keys())
    seed = random.choice(keys)
    output = [seed[0], seed[1]]

    while len(output) < length:
        key = (output[-2], output[-1])
        if key not in trigramTrans:
            break
        next_note = choice(trigramTrans[key], p=trigramProbs[key])
        output.append(next_note)

    return output

In [11]:
def write_to_midi(pitches, output_file="composer_output.mid", duration=0.5):
    midi = MIDIFile(1)
    midi.addTempo(0, 0, 120)

    time = 0
    for pitch in pitches:
        midi.addNote(0, 0, pitch, time, duration, 100)
        time += duration

    with open(output_file, "wb") as f:
        midi.writeFile(f)

In [12]:
for composer in ["Chopin", "Beethoven", "Mozart"]:
    notes = generate_notes_from_composer(composer, length=100)
    write_to_midi(notes, output_file=f"{composer.lower()}_sample.mid")