# Data

In [2]:
import re

def deduplicate_entities(entities):
    deduplicated = entities
    
    deduplicated = list(map(lambda x: re.sub(r'in\'$', 'ing', x).strip(), deduplicated))
    deduplicated = list(map(lambda x: re.sub(r'ning$', '', x).strip(), deduplicated))
    deduplicated = list(map(lambda x: re.sub(r'ing$', '', x).strip(), deduplicated))
    deduplicated = list(map(lambda x: re.sub(r'ed$', '', x).strip(), deduplicated))
    
    deduplicated = list(set(map(str.lower, deduplicated)))
    
    return deduplicated

In [3]:
from string import punctuation

def preprocess_text(text):
    text = re.sub(r"\s+", " ", text)
    text = text.lower()
    punctuation_ = punctuation.replace("'", "").replace("-", "")
    text = re.sub(f"[{re.escape(punctuation_)}]+", "", text)
    
    text = re.sub(r"in\'\b", "ing", text)
    text = re.sub(r"ning\b", "", text)
    text = re.sub(r"ing\b", "", text)
    text = re.sub(r"ed\b", "", text)
    
    return text

In [4]:
import os
import json
from pathlib import Path

base_path = "../data/labels_manual"
lyrics_path = "../data/lyrics/"
lyrics_dict = dict()

for filename in os.listdir(base_path):
    if not filename.endswith(".json") or ("template" in filename): continue
    
    file_path = Path(lyrics_path) / filename.replace(".json", ".txt")
    with open(file_path, "r") as f:
        lyrics = f.read().lower()
    lyrics = preprocess_text(lyrics)
        
    file_path = Path(base_path) / filename
    with open(file_path, "r") as f:
        labels = json.load(f)
        
    for entity_type in labels:
        labels[entity_type] = deduplicate_entities(labels[entity_type])
        
    lyrics_dict[filename.replace(".json", "")] = {
        "lyrics": lyrics,
        "labels": labels
    }
    
len(lyrics_dict)  

334

In [17]:
test_lyrics_dict = dict()

with open("../data/extraction/MANUAL/test.json", "r") as f:
    test = json.load(f)
    
for values in test:
    for entity_type in values["entities"]:
        values["entities"][entity_type] = deduplicate_entities(values["entities"][entity_type])
        
    test_lyrics_dict[values["id"]] = {
        "lyrics": preprocess_text(values["context"]),
        "labels": values["entities"]
    }
    
len(test_lyrics_dict)  

51

In [9]:
manual_vocab = {
    "human-powered": ["bicycle", "bike", "scooter", "skateboard", "walking", "on foot"],
    "animal-powered": ["horse", "carriage", "camel", "donkey", "sleigh"],
    "railways": ["train", "subway", "metro", "tram", "railway", "railroad"],
    "roadways": ["car", "bus", "taxi", "truck", "lorry", "jeep", "motorcycle"],
    "water_transport": ["boat", "ship", "ferry", "yacht", "sailboat", "submarine"],
    "air_transport": ["plane", "airplane", "jet", "helicopter", "rocket", "zeppelin"],
}

print("Manual Vocab")
for entity_type in manual_vocab:
    manual_vocab[entity_type] = deduplicate_entities(manual_vocab[entity_type])
    
    print(entity_type, len(manual_vocab[entity_type]))

Manual Vocab
human-powered 6
animal-powered 5
railways 6
roadways 7
water_transport 6
air_transport 6


In [10]:
from collections import defaultdict, Counter

def evaluate_extraction(true_entities_dict, predicted_entities_dict):
    true_entities_list = true_entities_dict.values()
    predicted_entities_list = predicted_entities_dict.values()
    
    true_entities_keys = list(true_entities_dict.keys())
    predicted_entities_keys = list(predicted_entities_dict.keys())
    
    assert len(true_entities_list) == len(predicted_entities_list), "The number of examples must match."

    # Collect all entity types appearing in any example
    entity_types = set()
    for true_dict in true_entities_list:
        entity_types.update(true_dict.keys())
    for pred_dict in predicted_entities_list:
        entity_types.update(pred_dict.keys())

    # Dictionary to store aggregated counts per entity type
    per_entity_counts = {etype: {"TP": 0, "FP": 0, "FN": 0} for etype in entity_types}
    fp = defaultdict(list)
    fn = defaultdict(list)

    # Iterate over all examples and aggregate counts for each entity type
    for true_dict, pred_dict, true_key, pred_key in zip(true_entities_list, predicted_entities_list, true_entities_keys, predicted_entities_keys):
        for etype in entity_types:
            # Get the list of entities for this type; default to empty list if missing.
            true_list = true_dict.get(etype, [])
            pred_list = pred_dict.get(etype, [])
            tmp = []
            for entity in pred_list:
                if isinstance(entity, str):
                    tmp.append(entity)
                elif "text" in entity:
                    tmp.append(entity["text"])
                elif "name" in entity:
                    tmp.append(entity["name"])
            pred_list = tmp
            pred_list = list(set(pred_list))
            
            # Use Counter to account for duplicates
            true_counter = Counter(true_list)
            pred_counter = Counter(pred_list)
            
            # Count true positives: for each entity present in both, add the minimum count
            common_entities = set(true_counter.keys()) & set(pred_counter.keys())
            TP = sum(min(true_counter[ent], pred_counter[ent]) for ent in common_entities)
            
            # Count false positives: predicted count minus the matched count for every predicted entity
            FP = sum(pred_counter[ent] - min(true_counter.get(ent, 0), pred_counter[ent]) for ent in pred_counter)
            
            # Count false negatives: true count minus the matched count for every true entity
            FN = sum(true_counter[ent] - min(true_counter[ent], pred_counter.get(ent, 0)) for ent in true_counter)
            
            # Aggregate counts
            per_entity_counts[etype]["TP"] += TP
            per_entity_counts[etype]["FP"] += FP
            per_entity_counts[etype]["FN"] += FN
            
            if FP > 0:
                fp[true_key].append({"entity": etype, "true": true_list, "pred": pred_list})
            if FN > 0:
                fn[true_key].append({"entity": etype, "true": true_list, "pred": pred_list})

    # Compute precision, recall, and F1 for each entity type
    per_entity_results = {}
    for etype, counts in per_entity_counts.items():
        TP = counts["TP"]
        FP = counts["FP"]
        FN = counts["FN"]
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
        recall    = TP / (TP + FN) if (TP + FN) > 0 else 0.0
        f1        = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        per_entity_results[etype] = {"Precision": round(precision, 3), "Recall": round(recall, 3), "F1": round(f1, 3)}

    # Macro-average: average the metric scores over all entity types
    macro_precision = sum(result["Precision"] for result in per_entity_results.values()) / len(entity_types)
    macro_recall    = sum(result["Recall"] for result in per_entity_results.values()) / len(entity_types)
    macro_f1        = sum(result["F1"] for result in per_entity_results.values()) / len(entity_types)
    
    # Micro-average: aggregate counts over all entity types
    total_TP = sum(counts["TP"] for counts in per_entity_counts.values())
    total_FP = sum(counts["FP"] for counts in per_entity_counts.values())
    total_FN = sum(counts["FN"] for counts in per_entity_counts.values())

    micro_precision = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0.0
    micro_recall = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0.0
    micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall) if (micro_precision + micro_recall) > 0 else 0.0
    
    results = {
        "Macro Precision": macro_precision,
        "Macro Recall": macro_recall,
        "Macro F1": macro_f1,
        "Micro Precision": micro_precision,
        "Micro Recall": micro_recall,
        "Micro F1": micro_f1
    }
    results = {metric: [round(value, 3)] for metric, value in results.items()}
    
    return per_entity_results, results, fp, fn

# Baseline 1: Rule-based keywords matching

In [11]:
from tqdm.notebook import tqdm

vocab_ = manual_vocab

def predict(text):
    predictions = dict()
    
    for entity_type in vocab_:
        predictions[entity_type] = []
        
        for entity in vocab_[entity_type]:
            entity = entity.lower()
            pattern = r'\b' + re.escape(entity) + r's?\b'
            if re.search(pattern, text):
                predictions[entity_type].append(entity)
                
        predictions[entity_type] = deduplicate_entities(predictions[entity_type])
    
    return predictions

In [18]:
from tqdm.notebook import tqdm

for id_, item in tqdm(test_lyrics_dict.items()):
    test_lyrics_dict[id_]["rule-based"] = predict(item["lyrics"])

  0%|          | 0/51 [00:00<?, ?it/s]

In [19]:
true_labels = {id_: item["labels"] for id_, item in test_lyrics_dict.items()}
predicted_labels = {id_: item["rule-based"] for id_, item in test_lyrics_dict.items()}

per_entity_results, overall_results, false_positives, false_negatives = evaluate_extraction(true_labels, predicted_labels)
per_entity_results

{'water_transport': {'Precision': 0.0, 'Recall': 0.0, 'F1': 0.0},
 'human-powered': {'Precision': 0.25, 'Recall': 0.2, 'F1': 0.222},
 'air_transport': {'Precision': 0.333, 'Recall': 0.25, 'F1': 0.286},
 'railways': {'Precision': 0.333, 'Recall': 1.0, 'F1': 0.5},
 'animal-powered': {'Precision': 0.0, 'Recall': 0.0, 'F1': 0.0},
 'roadways': {'Precision': 0.375, 'Recall': 0.15, 'F1': 0.214}}

# Baseline 2: Clustering models

In [20]:
from gensim.models import KeyedVectors

path = '../data/w2v/GoogleNews-vectors-negative300.bin.gz'
model = KeyedVectors.load_word2vec_format(path, binary=True)

In [21]:
import spacy
from nltk.corpus import stopwords
from tqdm.notebook import tqdm

stop_words = set(stopwords.words('english'))

nlp = spacy.load("en_core_web_sm")

def extract_noun_verb_tokens(text):
    doc = nlp(text.lower())
    return [token.text for token in doc if token.pos_ in ("NOUN", "VERB") and token.text not in stop_words]

lyrics_list = [lyrics_dict[id_]["lyrics"] for id_ in lyrics_dict]

sentences = []
for lyric in tqdm(lyrics_list):
    tokens = extract_noun_verb_tokens(lyric)
    sentences.append(tokens)

  0%|          | 0/334 [00:00<?, ?it/s]

In [22]:
from gensim.models import Word2Vec

w2v = Word2Vec(
    sentences,
    vector_size=300,
    window=15,
    min_count=5,
    workers=4,
    seed=42
)

In [171]:
w2v_ = model # w2v.wv

In [172]:
import numpy as np

def embed_mention(text):
    tokens = [t for t in text.lower().split() if t in w2v_]
    if not tokens:
        return np.zeros(w2v_.vector_size)
    return np.mean([w2v_[t] for t in tokens], axis=0)

In [173]:
import numpy as np
from numpy.linalg import norm

def cosine(a, b):
    d = norm(a)*norm(b)
    return float(np.dot(a, b) / d) if d != 0 else 0.0

def build_class_prototypes(w2v, class_seeds):
    prototypes = {}
    for cls, seeds in class_seeds.items():
        vecs = [embed_mention(w) for w in seeds if w in w2v]
        if vecs:
            prototypes[cls] = np.mean(vecs, axis=0)
    return prototypes

class_vecs = build_class_prototypes(w2v_, manual_vocab)

In [174]:
def max_transport_sim(word, w2v, class_vecs):
    if word not in w2v:
        return 0.0, None
    v = w2v[word]
    best_cls, best_sim = None, -1
    for cls, pvec in class_vecs.items():
        sim = cosine(v, pvec)
        if sim > best_sim:
            best_cls, best_sim = cls, sim
    return best_sim, best_cls

In [175]:
from tqdm.notebook import tqdm

words = list(w2v_.key_to_index.keys()) # all words in vocab

candidate_words = []
for w in tqdm(words):
    if w not in w2v_:
        continue
    sim, _ = max_transport_sim(w, w2v_, class_vecs)
    if sim >= 0.5:
        candidate_words.append(w)

embeddings = np.vstack([embed_mention(w) for w in candidate_words])
embeddings.shape

  0%|          | 0/3000000 [00:00<?, ?it/s]

(3515, 300)

In [176]:
from sklearn.cluster import KMeans

K = 50
kmeans = KMeans(n_clusters=K, n_init=10, random_state=0)
cluster_ids = kmeans.fit_predict(embeddings)

# mapping: word -> cluster id
word_to_cluster = {w: cid for w, cid in zip(candidate_words, cluster_ids)}

In [177]:
from collections import defaultdict

def score_cluster(cid, candidate_words, word_to_cluster, w2v, class_vecs):
    words_in_cluster = [w for w in candidate_words if word_to_cluster[w] == cid]
    scores = defaultdict(float)
    count = 0

    for w in words_in_cluster:
        if w not in w2v:
            continue
        v = w2v[w]
        for cls, pvec in class_vecs.items():
            scores[cls] += cosine(v, pvec)
        count += 1

    if count == 0:
        return None

    # normalize by cluster size to avoid bias toward large clusters
    for cls in scores:
        scores[cls] /= count
    return dict(scores)

def assign_cluster_types(K, candidate_words, word_to_cluster, w2v, class_vecs, threshold=0.5):
    cluster_to_type = {}
    for cid in range(K):
        scores = score_cluster(cid, candidate_words, word_to_cluster, w2v, class_vecs)
        if not scores:
            cluster_to_type[cid] = "OTHER"
            continue
        best_cls, best_score = max(scores.items(), key=lambda kv: kv[1])
        cluster_to_type[cid] = best_cls if best_score >= threshold else "OTHER"
    return cluster_to_type

cluster_to_type = assign_cluster_types(K, candidate_words, word_to_cluster, w2v_, class_vecs)

Counter(cluster_to_type.values())

Counter({'water_transport': 17,
         'air_transport': 10,
         'roadways': 9,
         'railways': 8,
         'human-powered': 3,
         'OTHER': 2,
         'animal-powered': 1})

In [178]:
def label_token(tok, w2v, word_to_cluster, cluster_to_type, class_vecs, sim_threshold=0.35):
    t = tok.lower()
    if t in word_to_cluster:
        cls = cluster_to_type[word_to_cluster[t]]
        if cls == "OTHER":
            return "OTHER"
        sim, best_cls = max_transport_sim(t, w2v, class_vecs)
        return cls if (sim >= sim_threshold) and (best_cls == cls) else "OTHER"
    elif t in w2v:
        sim, best_cls = max_transport_sim(t, w2v, class_vecs)
        return best_cls if sim >= sim_threshold else "OTHER"
    else:
        return "OTHER"
    
# label_token("car", w2v_, word_to_cluster, cluster_to_type, class_vecs)

In [181]:
for id_, item in tqdm(test_lyrics_dict.items()):
    tokens = extract_noun_verb_tokens(item["lyrics"])
    predictions = dict()
    for entity_type in manual_vocab:
        predictions[entity_type] = []
    for tok in tokens:
        label = label_token(tok, w2v_, word_to_cluster, cluster_to_type, class_vecs, sim_threshold=0.7)
        if label != "OTHER":
            predictions[label].append(tok)
    for entity_type in predictions:
        predictions[entity_type] = deduplicate_entities(predictions[entity_type])
    test_lyrics_dict[id_]["clustering"] = predictions

  0%|          | 0/51 [00:00<?, ?it/s]

In [182]:
true_labels = {id_: item["labels"] for id_, item in test_lyrics_dict.items()}
predicted_labels = {id_: item["clustering"] for id_, item in test_lyrics_dict.items()}

per_entity_results, overall_results, false_positives, false_negatives = evaluate_extraction(true_labels, predicted_labels)
per_entity_results

{'water_transport': {'Precision': 0.0, 'Recall': 0.0, 'F1': 0.0},
 'human-powered': {'Precision': 0.0, 'Recall': 0.0, 'F1': 0.0},
 'air_transport': {'Precision': 0.333, 'Recall': 0.25, 'F1': 0.286},
 'railways': {'Precision': 0.5, 'Recall': 1.0, 'F1': 0.667},
 'animal-powered': {'Precision': 0.0, 'Recall': 0.0, 'F1': 0.0},
 'roadways': {'Precision': 0.429, 'Recall': 0.15, 'F1': 0.222}}

In [170]:
true_labels = {id_: item["labels"] for id_, item in test_lyrics_dict.items()}
predicted_labels = {id_: item["clustering"] for id_, item in test_lyrics_dict.items()}

per_entity_results, overall_results, false_positives, false_negatives = evaluate_extraction(true_labels, predicted_labels)
per_entity_results

{'water_transport': {'Precision': 0.0, 'Recall': 0.0, 'F1': 0.0},
 'human-powered': {'Precision': 0.002, 'Recall': 0.2, 'F1': 0.003},
 'air_transport': {'Precision': 0.0, 'Recall': 0.0, 'F1': 0.0},
 'railways': {'Precision': 0.0, 'Recall': 0.0, 'F1': 0.0},
 'animal-powered': {'Precision': 0.0, 'Recall': 0.0, 'F1': 0.0},
 'roadways': {'Precision': 0.003, 'Recall': 0.2, 'F1': 0.006}}

In [96]:
true_labels = {id_: item["labels"] for id_, item in test_lyrics_dict.items()}
predicted_labels = {id_: item["clustering"] for id_, item in test_lyrics_dict.items()}

per_entity_results, overall_results, false_positives, false_negatives = evaluate_extraction(true_labels, predicted_labels)
per_entity_results

{'water_transport': {'Precision': 0.0, 'Recall': 0.0, 'F1': 0.0},
 'human-powered': {'Precision': 0.1, 'Recall': 0.2, 'F1': 0.133},
 'air_transport': {'Precision': 0.048, 'Recall': 0.25, 'F1': 0.08},
 'railways': {'Precision': 0.033, 'Recall': 1.0, 'F1': 0.065},
 'animal-powered': {'Precision': 0.0, 'Recall': 0.0, 'F1': 0.0},
 'roadways': {'Precision': 0.318, 'Recall': 0.35, 'F1': 0.333}}

In [146]:
from collections import Counter

freq = Counter(w for sent in sentences for w in sent)

clusters = {cid: [] for cid in range(K)}
for w, cid in word_to_cluster.items():
    clusters[cid].append(w)

for cid in range(K):
    print(f"\nCluster {cid}")
    print(sorted(clusters[cid], key=lambda w: -freq[w])[:10])


Cluster 0
['stop', 'lights', "nothin'", 'drop', 'lil', 'hurt', 'swear', 'lookin', 'sky', 'hoes']

Cluster 1
['water', 'lady', 'peat', 'gett', 'birthday', 'feelin', 'wonder', 'hang', 'dog', 'jeans']

Cluster 2
['na']

Cluster 3
['ways', 'shoes', 'fit', 'fell', 'shorty', 'lay', 'family', 'space', 'fool', 'finish']

Cluster 4
['bitch', 'money', 'r']

Cluster 5
['hope', 'lost', 'gets', 'fight', 'dress', 'house', 'honey', 'livin', 'power', 'cash']

Cluster 6
['doo']

Cluster 7
['got', 'aingt', 'think', 'done', 'wait', 'boy', 'hands', 'ooh']

Cluster 8
['tonight', 'babe']

Cluster 9
['choose', 'celebrate', 'known', 'climb', 'daylight', 'luck', 'low', 'logic', 'weather', 'thicke']

Cluster 10
['thunder']

Cluster 11
['bam']

Cluster 12
['wick']

Cluster 13
['fuck', 'shit', 'put', 'man', 'hit', 'niggas', 'nigga', 'dance', 'ass', 'lot']

Cluster 14
['get', 'go', 'say', 'make', '-', 'come', 'give', 'talk', 'stay', 'day']

Cluster 15
['turn', 'pull', 'br', 'ones']

Cluster 16
['comin']

Cluster 