# COGS 118B Final Project
## Rhythm Game Beat Chart Difficulty Classification

**Contributors:** Miles Davis, Theo Bui, Alan De Luna, Sanjith Shanmugavel, Nicole Huynh

#### Introduction

Rhythm games are blah blah blah. Add relevant background and motivation for classifying difficulties. Maybe something along the lines of: "Difficulty levels between songs in ___ game are inconsistent. Some beat charts may be considered more difficult because of the number of required hits for a high score simply increasing, increasing tempo, or rhythms becoming more complex."

#### Algorithm Choice
???

#### Data Preprocessing
1. read in `.tja` files as string
2. convert string to `numpy` array
3. extract relevant information

Metadata:
- `BPM`
- `OFFSET`
- `GENRE`

Course info:
- `COURSE`
- `LEVEL`
- `BALLOON`
- song notation -- written between `#START` and `#END` commands
- timing notation -- `#MEASURE` and `#BPMCHANGE`
    - from Taiko explanation: 'Measures in the chart are separated with a comma character , followed by a line break. Timing between each measure is the same as long as #MEASURE and #BPMCHANGE commands are not used. Measures may contain any amount of notes, including zero, the less numbers there are in a measure, the more far apart the notes will be in the chart, each measure is equally divided by the amount of numbers there are inside. "12," can be written as "1020," and "10002000,", the timing is identical in all three examples.'
- delay -- offsets position of following notation (can be negative to indicate overlap) and written as `#DELAY`
- scroll speed -- `#SCROLL`
- Go-Go Time -- written between `#GOGOSTART` and `#GOGOEND`
- measure lines turned on/off -- `#BARLINEOFF` and `#BARLINEON`
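- branching and paths -- written between `#BRANCHSTART` and `#BRANCHEND`, with the individual paths marked by `#N`, `#E`, and `#M`
- BPM change -- `#BPMCHANGE`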
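
The rule quoted above, that a measure is divided evenly by however many digits it contains, can be illustrated with a short sketch. `expand_measure` is a hypothetical helper, not part of the notebook, and it assumes the target resolution is a multiple of the measure's length.

```python
def expand_measure(measure: str, resolution: int) -> str:
    """Rewrite a measure on a finer grid without changing note timing.

    Hypothetical helper for illustration; assumes `resolution` is a
    multiple of the number of digits in `measure`.
    """
    if not measure:
        # An empty measure contains only rests
        return '0' * resolution
    if resolution % len(measure) != 0:
        raise ValueError('resolution must be a multiple of the measure length')
    step = resolution // len(measure)
    # Each digit keeps the start of its slot; the rest of the slot is rests
    return ''.join(digit + '0' * (step - 1) for digit in measure)


print(expand_measure('12', 4))    # 1020
print(expand_measure('12', 8))    # 10002000
print(expand_measure('1020', 8))  # 10002000
```

All three spellings from the quote end up as the same string once they are written at the same resolution, which is the property the padding step later in the notebook relies on.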

In [48]:
import numpy as np
import pandas as pd
import os

In [19]:
test_path = 'test_data/cruel_angels_thesis.tja'

try:
    with open(test_path, 'r', encoding='utf-8') as file:
        test_str = file.read()
except FileNotFoundError:
    print('file not found :-(')
except Exception as e:
    print(f'an error occurred: {e}')

In [20]:
print(test_str)

﻿//TJADB Project
TITLE:A Cruel Angel's Thesis -New Audio-
TITLEJA:残酷な天使のテーゼ -新曲-
SUBTITLE:--Youko Takahashi/Neon Genesis Evangelion
BPM:80
WAVE:A Cruel Angels Thesis -New Audio-.ogg
OFFSET:-5.041
DEMOSTART:68.079

COURSE:Edit
LEVEL:8
SCOREINIT:740
SCOREDIFF:208


#START


,
#BPMCHANGE 79
1000100010010010,
#MEASURE 8/4
#BPMCHANGE 76.7
100000000000100000000000100000000000500000000000
#BPMCHANGE 134
000000000000000000000000000000000008000000000000,
#MEASURE 4/4
#BPMCHANGE 129
1010210121012010,
1010210121020000,
1010210121012010,
1010221030040000,
1010201012102000,
1010201012102000,
1120212011202120,
1120221010000000,
1120201010222000,
1120201010222000,
1020102010222010,
1020102022102210,
1020102010221020,
1020102010221020,
1120221011202210,
1120221010222000,
1010221022221000,
1010221022221000,
1020122010201220,
1101210121012102,
1101102012022020,
1120221011202210,
1101201011012010,
1212,
1101102011011020,
1220122011122020,
1121102011211010,
3000000030030030,

#GOGOSTART
#SECTION
#BRANCHST

In [21]:
# split the file into a metadata header plus one chunk per course
chart = test_str.split('COURSE:')
len(chart)

6

In [22]:
chart

["\ufeff//TJADB Project\nTITLE:A Cruel Angel's Thesis -New Audio-\nTITLEJA:残酷な天使のテーゼ -新曲-\nSUBTITLE:--Youko Takahashi/Neon Genesis Evangelion\nBPM:80\nWAVE:A Cruel Angels Thesis -New Audio-.ogg\nOFFSET:-5.041\nDEMOSTART:68.079\n\n",
 'Edit\nLEVEL:8\nSCOREINIT:740\nSCOREDIFF:208\n\n\n#START\n\n\n,\n#BPMCHANGE 79\n1000100010010010,\n#MEASURE 8/4\n#BPMCHANGE 76.7\n100000000000100000000000100000000000500000000000\n#BPMCHANGE 134\n000000000000000000000000000000000008000000000000,\n#MEASURE 4/4\n#BPMCHANGE 129\n1010210121012010,\n1010210121020000,\n1010210121012010,\n1010221030040000,\n1010201012102000,\n1010201012102000,\n1120212011202120,\n1120221010000000,\n1120201010222000,\n1120201010222000,\n1020102010222010,\n1020102022102210,\n1020102010221020,\n1020102010221020,\n1120221011202210,\n1120221010222000,\n1010221022221000,\n1010221022221000,\n1020122010201220,\n1101210121012102,\n1101102012022020,\n1120221011202210,\n1101201011012010,\n1212,\n1101102011011020,\n1220122011122020,\n1121102
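
The first element of `chart` starts with `\ufeff`, which is a UTF-8 byte order mark carried over from the file. It is harmless here, but it will stick to the first header line if that line is ever parsed. A minimal sketch of the difference between the `utf-8` and `utf-8-sig` codecs, using made-up header bytes rather than the real file:

```python
import codecs

# Made-up bytes standing in for the start of a .tja file saved with a BOM
raw = codecs.BOM_UTF8 + 'TITLE:example\nBPM:80\n'.encode('utf-8')

with_bom = raw.decode('utf-8')         # keeps the BOM as '\ufeff'
without_bom = raw.decode('utf-8-sig')  # strips a leading BOM if present

print(with_bom.startswith('\ufeff'))     # True
print(without_bom.startswith('TITLE:'))  # True
```

Passing `encoding='utf-8-sig'` to `open()` has the same effect when reading the file directly, and it also works for files that have no BOM.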

In [23]:
# try to grab just relevant info listed above

# metadata for entire song chart
metadata = chart[0].split('\n')
bpm = int(metadata[4][4:])
offset = float(metadata[6][7:])

# courses by difficulty
edit = [i for i in chart if 'Edit' in i]
oni = [i for i in chart if 'Oni' in i]
hard = [i for i in chart if 'Hard' in i]
normal = [i for i in chart if 'Normal' in i]
easy = [i for i in chart if 'Easy' in i]

In [24]:
course_diff = {'Easy': 0, 'Normal': 1, 'Hard': 2, 'Oni': 3, 'Edit': 4, 'Tower': 5, 'Dan': 6}
lvl_diff = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

In [25]:
int([i[6:] for i in normal[0].split('\n') if 'LEVEL' in i][0])

4

In [26]:
normal

['Normal\nLEVEL:4\nSCOREINIT:1310\nSCOREDIFF:420\n\n\n#START\n\n\n1,\n#BPMCHANGE 79\n1,\n#BPMCHANGE 76.7\n#MEASURE 8/4\n100000000000100000000000500000000000000000000000\n#BPMCHANGE 134\n000000000000000000000008000000000000000000000000,\n#MEASURE 4/4\n#BPMCHANGE 129\n1110,\n11102000,\n1110,\n1010100030030000,\n1120,\n1120,\n11,\n12,\n10011020,\n10022000,\n10011020,\n10022000,\n1120,\n1120,\n11,\n12,\n10011020,\n10022000,\n10011020,\n500000000000000000000000000008000000000000000000,\n1001000010000000,\n2002000020000000,\n1001000010000000,\n1,\n1001000010000000,\n2002000020000000,\n1001000010000000,\n500000000000000000000000000008000000000000000000,\n\n#GOGOSTART\n1000100010010000,\n1000000020020000,\n1110,\n1000100020020000,\n1000100010010000,\n1000000020020000,\n1110,\n1000100030030000,\n1000100010010000,\n1000000020020000,\n1110,\n1000100030030000,\n\n#GOGOEND\n\n\n#END\n\n\n']

In [27]:
# complete course and measures
normal_full = normal[0].split('#START')[1]
normal_full = normal_full.split('#END')[0].split('\n')
all_measures = [i[:-1] for i in normal_full if ',' in i]

# just go-go time part of course and measures
gogo_time_idx = [
    i for i, val in enumerate(normal_full)
    if val in ('#GOGOSTART', '#GOGOEND')
]
gogo_time = normal_full[gogo_time_idx[0]:gogo_time_idx[1]+1]
gogo_time_measures = [i[:-1] for i in gogo_time if ',' in i]

# just non go-go time part of course and measures
reg_time = normal_full[:gogo_time_idx[0]] + normal_full[gogo_time_idx[1]+1:]
reg_measures = [i[:-1] for i in reg_time if ',' in i]

# every bpm and measure change
# but also taking note of what bpm/time sig the course is being changed to would also be valuable
bpm_change = [i for i in normal_full if '#BPMCHANGE' in i]
measure_change = [i for i in normal_full if '#MEASURE' in i]

# change gogo_time_idx to list of normal_full indices
gogo_time_idx = [i for i in range(gogo_time_idx[0], gogo_time_idx[1]+1)]
reg_time_idx = [i for i in range(len(normal_full)) if i not in gogo_time_idx]
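
In the Normal chart printed above, the 8/4 measure is written across two lines with a `#BPMCHANGE 134` between them, and only the second line ends in a comma. Filtering on `',' in i` therefore keeps the second half of that measure and drops the first. One possible way to join such measures is sketched below. `split_measures` is a hypothetical helper, and it assumes that command lines start with `#`, comments start with `//`, and every measure is closed by a comma.

```python
def split_measures(lines):
    """Group note digits into measures, joining measures that span lines.

    Hypothetical helper: assumes command lines start with '#', comments
    start with '//', and every measure is terminated by a comma.
    """
    measures = []
    current = ''  # digits of the measure that is still open
    for line in lines:
        text = line.split('//')[0].strip()
        if not text or text.startswith('#'):
            continue  # blank line or a command such as #BPMCHANGE
        # Everything before a comma is finished; the remainder stays open
        *finished, current = (current + text).split(',')
        measures.extend(finished)
    return measures


lines = ['1,', '#BPMCHANGE 76.7', '#MEASURE 8/4', '1000',
         '#BPMCHANGE 134', '0008,', ',', '1110,']
print(split_measures(lines))  # ['1', '10000008', '', '1110']
```

The sketch discards the commands themselves, so the position of each `#BPMCHANGE` inside a measure would still need to be tracked separately if it matters for the features.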

In [28]:
normal_course = {
    'full_chart': normal_full,
    'all_measures': all_measures,
    'full_gogo_chart': gogo_time,
    'gogo_time_measures': gogo_time_measures,
    'gogo_time_indices': gogo_time_idx,
    'full_non_gogo_chart': reg_time,
    'non_gogo_measures': reg_measures,
    'non_gogo_indices': reg_time_idx,
    'all_bpm_changes': bpm_change,
    'all_time_sig_changes': measure_change
}

In [29]:
gogo_time_measures_copy = [i for i in gogo_time_measures]
gogo_time_measures_copy[2] = '1000100010000000'
gogo_time_measures_copy[6] = '1000100010000000'
gogo_time_measures_copy[10] = '1000100010000000'

In [30]:
gogo_time_measures_copy

['1000100010010000',
 '1000000020020000',
 '1000100010000000',
 '1000100020020000',
 '1000100010010000',
 '1000000020020000',
 '1000100010000000',
 '1000100030030000',
 '1000100010010000',
 '1000000020020000',
 '1000100010000000',
 '1000100030030000']

In [31]:
len('1000100010000000')

16

In [32]:
print(np.array(normal_course))

{'full_chart': ['', '', '', '1,', '#BPMCHANGE 79', '1,', '#BPMCHANGE 76.7', '#MEASURE 8/4', '100000000000100000000000500000000000000000000000', '#BPMCHANGE 134', '000000000000000000000008000000000000000000000000,', '#MEASURE 4/4', '#BPMCHANGE 129', '1110,', '11102000,', '1110,', '1010100030030000,', '1120,', '1120,', '11,', '12,', '10011020,', '10022000,', '10011020,', '10022000,', '1120,', '1120,', '11,', '12,', '10011020,', '10022000,', '10011020,', '500000000000000000000000000008000000000000000000,', '1001000010000000,', '2002000020000000,', '1001000010000000,', '1,', '1001000010000000,', '2002000020000000,', '1001000010000000,', '500000000000000000000000000008000000000000000000,', '', '#GOGOSTART', '1000100010010000,', '1000000020020000,', '1110,', '1000100020020000,', '1000100010010000,', '1000000020020000,', '1110,', '1000100030030000,', '1000100010010000,', '1000000020020000,', '1110,', '1000100030030000,', '', '#GOGOEND', '', '', ''], 'all_measures': ['1', '1', '000000000000000

In [33]:
list('1110')

['1', '1', '1', '0']

In [38]:
def handle_branches(chart_lines):
    has_branch = any('#BRANCHSTART' in line for line in chart_lines)
    if not has_branch:
        return chart_lines
    
    result = []
    in_branch = False
    current_branch = None
    branch_content = {'N': [], 'E': [], 'M': []}
    
    for line in chart_lines:
        stripped = line.strip()
        
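        if stripped.startswith('#BRANCHSTART'):
            # a new branch can begin before the previous one was closed;
            # keep that branch's notes (M, else E, else N) instead of dropping them
            if in_branch:
                result.extend(branch_content['M'] or branch_content['E'] or branch_content['N'])
            in_branch = True
            branch_content = {'N': [], 'E': [], 'M': []}
            current_branch = None
            continue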
            
        elif stripped == '#N':
            current_branch = 'N'
            continue
            
        elif stripped == '#E':
            current_branch = 'E'
            continue
            
        elif stripped == '#M':
            current_branch = 'M'
            continue
            
        elif stripped == '#BRANCHEND':
            # Keep the master (M) path when it exists; otherwise fall back to E, then N
            if branch_content['M']:
                result.extend(branch_content['M'])
            elif branch_content['E']:
                result.extend(branch_content['E'])
            elif branch_content['N']:
                result.extend(branch_content['N'])
            
            in_branch = False
            current_branch = None
            continue
        
        if in_branch and current_branch:
            branch_content[current_branch].append(line)
        elif not in_branch:
            result.append(line)
    
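    # a branch still open at the end of the chart has no #BRANCHEND to flush it
    if in_branch:
        result.extend(branch_content['M'] or branch_content['E'] or branch_content['N'])

    return result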

In [34]:
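def format_measures(measures, max_len):
    """Pad every measure with rests ('0') so that all measures share max_len digits.

    Timing is only preserved when max_len is a multiple of the measure length.
    """
    new_measures = list(measures)

    for idx, measure in enumerate(measures):
        if len(measure) == 0:
            # empty measure: a full measure of rests
            new_measures[idx] = max_len * '0'
        elif len(measure) != max_len:
            # insert an equal run of rests after every digit
            zeroes = max_len // len(measure) - 1
            new_measures[idx] = ''.join(d + zeroes * '0' for d in measure)

    return new_measures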
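
`format_measures` pads to `max_len`, which is exact only when `max_len` is a multiple of every measure's length. A 3-digit measure padded to `max_len = 16`, for example, comes out 15 characters long. A sketch of one alternative is to pad to the least common multiple of all measure lengths instead. The helper names below are hypothetical, and the approach assumes the resulting resolution stays small enough to be practical.

```python
from functools import reduce
from math import gcd


def common_resolution(measures):
    """Smallest grid length that every non-empty measure divides evenly."""
    lengths = [len(m) for m in measures if m]
    return reduce(lambda a, b: a * b // gcd(a, b), lengths, 1)


def pad_to_resolution(measures):
    """Expand every measure to the shared resolution (hypothetical helper)."""
    res = common_resolution(measures)
    padded = []
    for m in measures:
        if not m:
            padded.append('0' * res)  # empty measure: all rests
            continue
        step = res // len(m)
        # each digit keeps the start of its slot, followed by rests
        padded.append(''.join(d + '0' * (step - 1) for d in m))
    return padded


print(pad_to_resolution(['12', '111', '']))
# ['100200', '101010', '000000']
```

The trade-off is that charts mixing unusual subdivisions can produce a long common grid, so a cap or a per-measure representation might be preferable for the final features.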

In [42]:
def extract_info(course):
    difficulty = course[0].split('\n')[0]
    level = int([i[6:] for i in course[0].split('\n') if 'LEVEL' in i][0])

    # complete course and measures
    full_course = course[0].split('#START')[1]
    full_course = full_course.split('#END')[0].split('\n')

    full_course = handle_branches(full_course)


    all_measures = [i[:-1] for i in full_course if ',' in i]
    max_measure_len = len(max(all_measures, key=len))
    all_measures = format_measures(all_measures, max_measure_len)

    # just go-go time part of course and measures
    # (only the first #GOGOSTART/#GOGOEND pair is used; charts without
    # Go-Go Time get an empty go-go section instead of an IndexError)
    gogo_marks = [
        i for i, val in enumerate(full_course)
        if val.strip() in ('#GOGOSTART', '#GOGOEND')
    ]
    if len(gogo_marks) >= 2:
        gogo_time_idx = list(range(gogo_marks[0], gogo_marks[1] + 1))
    else:
        gogo_time_idx = []
    reg_time_idx = [i for i in range(len(full_course)) if i not in gogo_time_idx]

    gogo_time = [full_course[i] for i in gogo_time_idx]
    gogo_time_measures = format_measures([i[:-1] for i in gogo_time if ',' in i], max_measure_len)

    # just non go-go time part of course and measures
    reg_time = [full_course[i] for i in reg_time_idx]
    reg_measures = format_measures([i[:-1] for i in reg_time if ',' in i], max_measure_len)

    # every bpm and measure change
    # but also taking note of what bpm/time sig the course is being changed to would also be valuable
    bpm_change = [i for i in full_course if '#BPMCHANGE' in i]
    measure_change = [i for i in full_course if '#MEASURE' in i]

    course_arr = {
        'course_difficulty': difficulty,
        'course_level': level,
        'full_course_chart': full_course,
        'all_measures': all_measures,
        'full_gogo_chart': gogo_time,
        'gogo_time_measures': gogo_time_measures,
        'gogo_time_indices': gogo_time_idx,
        'full_non_gogo_chart': reg_time,
        'non_gogo_measures': reg_measures,
        'non_gogo_indices': reg_time_idx,
        'all_bpm_changes': bpm_change,
        'all_time_sig_changes': measure_change
    }

    return np.array(course_arr)

In [52]:
def extract_all_courses(chart):
    metadata = chart[0].split('\n')
    bpm = int(metadata[4][4:])
    offset = float(metadata[6][7:])
    
    edit = [i for i in chart if 'Edit' in i]
    oni = [i for i in chart if 'Oni' in i]
    hard = [i for i in chart if 'Hard' in i]
    normal = [i for i in chart if 'Normal' in i]
    easy = [i for i in chart if 'Easy' in i]

    all_diff = [easy, normal, hard, oni, edit]
    all_extracted = []

    for d in all_diff:
        # not every chart defines all five difficulties; skip the missing ones
        if d:
            all_extracted.append(extract_info(d))

    return all_extracted

In [44]:
extract_all_courses(chart)

[array({'course_difficulty': 'Easy', 'course_level': 4, 'full_course_chart': ['', '', '', '1,', '#BPMCHANGE 79', '1,', '#BPMCHANGE 76.7', '#MEASURE 8/4', '100000000000100000000000500000000000000000000000', '#BPMCHANGE 134', '000000000000000000000000000000000008000000000000,', '#MEASURE 4/4', '#BPMCHANGE 129', '12,', '1120,', '11,', '1000100030030000,', '11,', '12,', '11,', '12,', '10011000,', '10022000,', '1012,', '10022000,', '11,', '12,', '11,', '12,', '10011000,', '10022000,', '1111,', '500000000000000000000000000000000008000000000000,', '1001000000000000,', '2002000000000000,', '12,', '12,', '1001000000000000,', '2002000000000000,', '11,', '500000000000000000000000000000000008000000000000,', '', '#GOGOSTART', '1110,', '1220,', '1110,', '1120,', '1110,', '1220,', '1110,', '1000000030030000,', '1110,', '1220,', '1110,', '1000000030030000,', '', '#GOGOEND', '', '', ''], 'all_measures': ['100000000000000000000000000000000000000000000000', '1000000000000000000000000000000000000000000000

In [45]:
def extract_note_features(measures):
    """Extract numerical features from note patterns"""
    
    # Concatenate all measures into one sequence
    full_sequence = ''.join(measures)
    
    features = {}
    
    # Basic note counts
    features['total_notes'] = sum(1 for c in full_sequence if c in '1234')
    features['don_notes'] = full_sequence.count('1')  # Red notes (center)
    features['ka_notes'] = full_sequence.count('2')   # Blue notes (rim)
    features['big_don'] = full_sequence.count('3')    # Big red
    features['big_ka'] = full_sequence.count('4')     # Big blue
    features['drumroll'] = full_sequence.count('5')   # Yellow drumroll starts
    features['balloon'] = full_sequence.count('7')    # Balloon starts ('6' is the big drumroll in TJA)
    
    # Note density (notes per measure)
    features['note_density'] = features['total_notes'] / len(measures) if measures else 0
    
    # Pattern complexity metrics
    features['note_variety'] = len(set(c for c in full_sequence if c in '123456'))
    
    # Calculate note spacing (grid positions between consecutive notes)
    note_positions = [i for i, c in enumerate(full_sequence) if c in '1234']
    if len(note_positions) > 1:
        spacings = [note_positions[i+1] - note_positions[i] 
                   for i in range(len(note_positions)-1)]
        features['avg_spacing'] = np.mean(spacings)
        features['min_spacing'] = np.min(spacings)
        features['spacing_variance'] = np.var(spacings)
    else:
        features['avg_spacing'] = 0
        features['min_spacing'] = 0
        features['spacing_variance'] = 0
    
    # Pattern transitions (how often notes change type)
    transitions = sum(1 for i in range(len(note_positions)-1) 
                     if full_sequence[note_positions[i]] != full_sequence[note_positions[i+1]])
    features['transition_rate'] = transitions / len(note_positions) if note_positions else 0
    
    return features


def extract_timing_features(course_data):
    """Extract features from BPM and timing changes"""
    
    features = {}
    
    # Count timing changes
    features['bpm_changes'] = len(course_data['all_bpm_changes'])
    features['time_sig_changes'] = len(course_data['all_time_sig_changes'])
    
    # Extract BPM values
    bpm_values = []
    for change in course_data['all_bpm_changes']:
        bpm = float(change.split()[1])
        bpm_values.append(bpm)
    
    if bpm_values:
        features['max_bpm'] = max(bpm_values)
        features['min_bpm'] = min(bpm_values)
        features['avg_bpm'] = np.mean(bpm_values)
        features['bpm_variance'] = np.var(bpm_values)
    else:
        features['max_bpm'] = 0
        features['min_bpm'] = 0
        features['avg_bpm'] = 0
        features['bpm_variance'] = 0
    
    return features


def extract_gogo_features(course_data):
    """Extract features from Go-Go Time sections"""
    
    features = {}
    
    gogo_measures = course_data['gogo_time_measures']
    all_measures = course_data['all_measures']
    
    # Proportion of chart in go-go time
    features['gogo_proportion'] = len(gogo_measures) / len(all_measures) if all_measures else 0
    
    # Note density in go-go vs regular sections
    gogo_notes = sum(sum(1 for c in m if c in '1234') for m in gogo_measures)
    regular_measures = course_data['non_gogo_measures']
    regular_notes = sum(sum(1 for c in m if c in '1234') for m in regular_measures)
    
    features['gogo_note_density'] = gogo_notes / len(gogo_measures) if gogo_measures else 0
    features['regular_note_density'] = regular_notes / len(regular_measures) if regular_measures else 0
    
    return features


def create_feature_vector(course_data):
    """Combine all features into a single vector"""
    
    features = {}
    
    # Get all feature types
    note_features = extract_note_features(course_data['all_measures'])
    timing_features = extract_timing_features(course_data)
    gogo_features = extract_gogo_features(course_data)
    
    # Combine all features
    features.update(note_features)
    features.update(timing_features)
    features.update(gogo_features)
    
    # Add the target labels
    features['difficulty_name'] = course_data['course_difficulty']
    features['difficulty_level'] = course_data['course_level']
    
    return features
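
`note_density` is measured per measure, so the same pattern scores the same at 80 BPM and at 160 BPM, even though the introduction names tempo as one source of difficulty. A sketch of a tempo-aware density follows. It assumes that BPM counts quarter-note beats and that the time signature comes from a `#MEASURE numerator/denominator` command; both function names are hypothetical.

```python
def measure_seconds(bpm, numerator=4, denominator=4):
    """Duration of one measure in seconds.

    Assumes BPM counts quarter-note beats and the time signature comes
    from a '#MEASURE numerator/denominator' command.
    """
    beats = 4 * numerator / denominator  # quarter-note beats in the measure
    return beats * 60 / bpm


def notes_per_second(measure, bpm, numerator=4, denominator=4):
    """Hit notes (types 1-4) per second within a single measure."""
    hits = sum(1 for c in measure if c in '1234')
    return hits / measure_seconds(bpm, numerator, denominator)


print(measure_seconds(120))           # 2.0
print(notes_per_second('1110', 120))  # 1.5
print(measure_seconds(120, 8, 4))     # 4.0
```

Using this per measure would require tracking the BPM and time signature in effect at each measure, which the current `all_bpm_changes` and `all_time_sig_changes` lists do not record.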

In [57]:
def process_all_charts(chart_directory):
    """Process all .tja files into a training dataset"""
    
    all_features = []
    
    # Iterate through all .tja files
    for filename in os.listdir(chart_directory):
        if not filename.endswith('.tja'):
            continue
            
        filepath = os.path.join(chart_directory, filename)
        
        try:
            with open(filepath, 'r', encoding='utf-8') as file:
                chart_str = file.read()
            
            # Parse the chart
            chart = chart_str.split('COURSE:')
            
            # Extract all courses
            all_courses = extract_all_courses(chart)
            
            # Process each difficulty
            for course_data in all_courses:
                # Convert numpy array to dict if needed
                if isinstance(course_data, np.ndarray):
                    course_data = course_data.item()  # This extracts the dict from numpy array
                
                features = create_feature_vector(course_data)
                all_features.append(features)
                
        except Exception as e:
            print(f"Error processing {filename}: {e}")
            import traceback
            traceback.print_exc()  # This will show full error details
            continue
    
    # Convert to DataFrame
    df = pd.DataFrame(all_features)
    
    return df

In [58]:
df = process_all_charts('test_data/')
df.to_csv('taiko_features.csv', index=False)

Error processing 88.tja: invalid literal for int() with base 10: 'ITLEJA:太鼓の達人オリジナル曲'
Error processing ADAMAS.tja: invalid literal for int() with base 10: 'ITLEJA:「ソードアート・オンライン アリシゼーション」より'
Error processing Ai Scream.tja: invalid literal for int() with base 10: 'ITLEJA:AiScReam 「ラブライブ！シリーズ」より'


Traceback (most recent call last):
  File "C:\Users\drake\AppData\Local\Temp\ipykernel_5720\4143115863.py", line 21, in process_all_charts
    all_courses = extract_all_courses(chart)
  File "C:\Users\drake\AppData\Local\Temp\ipykernel_5720\3790980975.py", line 3, in extract_all_courses
    bpm = int(metadata[4][4:])
ValueError: invalid literal for int() with base 10: 'ITLEJA:太鼓の達人オリジナル曲'
Traceback (most recent call last):
  File "C:\Users\drake\AppData\Local\Temp\ipykernel_5720\4143115863.py", line 21, in process_all_charts
    all_courses = extract_all_courses(chart)
  File "C:\Users\drake\AppData\Local\Temp\ipykernel_5720\3790980975.py", line 3, in extract_all_courses
    bpm = int(metadata[4][4:])
ValueError: invalid literal for int() with base 10: 'ITLEJA:「ソードアート・オンライン アリシゼーション」より'
Traceback (most recent call last):
  File "C:\Users\drake\AppData\Local\Temp\ipykernel_5720\4143115863.py", line 21, in process_all_charts
    all_courses = extract_all_courses(chart)
  File "C:\Users\d
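
These errors come from `int(metadata[4][4:])`, which assumes that `BPM:` is always the fifth line of the header. In the failing files a different field sits at that index (its text ends in `ITLEJA:`), so `int()` receives a title instead of a number. Looking fields up by key removes the dependence on line order. The sketch below uses a hypothetical `parse_metadata` helper and a made-up header; it also reads BPM as a float on the assumption that header values can be fractional, as the `#BPMCHANGE` values are.

```python
def parse_metadata(header):
    """Parse 'KEY:value' header lines into a dict (hypothetical helper)."""
    fields = {}
    for line in header.lstrip('\ufeff').split('\n'):
        line = line.strip()
        if not line or line.startswith('//') or ':' not in line:
            continue  # skip blanks, comments, and lines without a key
        key, value = line.split(':', 1)
        fields[key.strip().upper()] = value.strip()
    return fields


# Made-up header for illustration; field order differs from the test file
header = (
    '//TJADB Project\n'
    'TITLE:Example Song\n'
    'SUBTITLE:--Example Artist\n'
    'SUBTITLEJA:example\n'
    'BPM:132.5\n'
    'OFFSET:-1.25\n'
)
meta = parse_metadata(header)
bpm = float(meta.get('BPM', 0))        # float, since BPM can be fractional
offset = float(meta.get('OFFSET', 0))
print(bpm, offset)  # 132.5 -1.25
```

Defaulting a missing `BPM` or `OFFSET` to 0 is a placeholder choice; skipping the file or raising a clearer error may be more appropriate for the real dataset.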

In [56]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

# Load the feature dataset
df = pd.read_csv('taiko_features.csv')

# Separate features and target
X = df.drop(['difficulty_name', 'difficulty_level'], axis=1)
y = df['difficulty_level']  # or use 'difficulty_name' for categorical

# Handle any NaN values
X = X.fillna(0)

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Normalize features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train_scaled, y_train)

# Evaluate
y_pred = clf.predict(X_test_scaled)
print(classification_report(y_test, y_pred))
print("\nConfusion Matrix:")
print(confusion_matrix(y_test, y_pred))

# Feature importance
feature_importance = pd.DataFrame({
    'feature': X.columns,
    'importance': clf.feature_importances_
}).sort_values('importance', ascending=False)

print("\nTop 10 Most Important Features:")
print(feature_importance.head(10))

ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.
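
The error is raised because `stratify=y` needs at least two samples of every class, and with only a few charts parsed successfully some `difficulty_level` values occur once. One option, sketched here on made-up data rather than the real `taiko_features.csv`, is to drop levels with fewer than two samples before splitting. Whether that is appropriate depends on how many charts remain once the parsing errors above are resolved.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy stand-in for taiko_features.csv; the values are made up
df = pd.DataFrame({
    'total_notes': [100, 120, 300, 320, 310, 330, 500, 150, 140, 305],
    'difficulty_level': [3, 3, 7, 7, 7, 7, 10, 3, 3, 7],
})

# Keep only levels that occur at least twice so stratification is possible
counts = df['difficulty_level'].value_counts()
keep = counts[counts >= 2].index
filtered = df[df['difficulty_level'].isin(keep)]

X = filtered.drop(columns=['difficulty_level'])
y = filtered['difficulty_level']

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)
print(len(X_train), len(X_test))
```

An alternative is to predict `difficulty_name` instead, since every chart with all five courses contributes one sample to each of those classes.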