In [None]:
import json
import pathlib

from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
DATASET_ID = 'iggy12345/rus-pol-edge-probing-phono-feats'
# Load all splits of the dataset; later cells iterate over `dataset` split by split.
dataset = load_dataset(DATASET_ID)

# Average span length

In [None]:
def _span_length(example):
    """Return ``{'span_length': ...}`` for one dataset row.

    The span length is the mean of the first word-level window length
    (``windows``) and the first phoneme-level window length
    (``windows-phoneme``). Window bounds are treated as inclusive, hence the
    ``+ 1``. Rows lacking either window list get ``0.0``. The float keeps the
    column a single numeric type, and checking both lists avoids an
    IndexError when only the phoneme windows are missing.
    """
    words = example['windows']
    phonemes = example['windows-phoneme']
    if len(words['start']) == 0 or len(phonemes['start']) == 0:
        return {'span_length': 0.0}
    word_len = words['end'][0] - words['start'][0] + 1
    phoneme_len = phonemes['end'][0] - phonemes['start'][0] + 1
    return {'span_length': (word_len + phoneme_len) / 2}


dataset = dataset.map(_span_length)

In [None]:
# Mean `span_length` over every row of every split. Rows without a window
# carry span_length 0 from the previous cell and still count in the denominator.
# NOTE(review): confirm whether span-less rows should be excluded from the mean.
split_totals = [sum(dataset[name]['span_length']) for name in dataset]
split_sizes = [len(dataset[name]) for name in dataset]
average_length = sum(split_totals)
count = sum(split_sizes)
average_length /= count
print('average span length:', average_length)

# Feature Distribution

In [None]:
# For each split, count how often each feature id appears in the first
# feature list of each row. Rows with no feature lists are skipped.
supports = {}
for split in dataset:
    split_counts = {}
    supports[split] = split_counts
    for row in dataset[split]:
        if len(row['features']) == 0:
            continue
        for feat in row['features'][0]:
            # Start unseen features at 0 and always increment, so the first
            # occurrence of every feature is counted too.
            split_counts[feat] = split_counts.get(feat, 0) + 1

In [None]:
# Load the feature-name -> feature-id table and build its inverse (id -> name).
with open('data/mappings.json', 'r') as mappings_file:
    phoneme_mappings = json.load(mappings_file)
inverse_mappings = {
    feat_id: feat_name
    for feat_name, feat_id in phoneme_mappings['features'].items()
}

In [None]:
plt.figure(figsize=(12, 4))

# x-axis layout: first one slot per "<feature> N/A" (keyed by the negated id),
# then one slot per feature (keyed by the id). feature_idxs maps id -> slot.
# NOTE(review): an id of 0 would collide with its own negation; confirm ids start at 1.
feature_idxs = {}
labels = []
for feat, feat_id in phoneme_mappings['features'].items():
    feature_idxs[-feat_id] = len(labels)
    labels.append(feat + ' N/A')
for feat, feat_id in phoneme_mappings['features'].items():
    feature_idxs[feat_id] = len(labels)
    labels.append(feat)
# Derive tick positions from the labels so ticks and labels always line up.
feature_ticks = np.arange(len(labels))

# Centre each group of bars on its tick for any number of splits
# (three splits give offsets -width, 0, +width). Shrink bars only when
# more than three splits would otherwise overlap.
n_splits = max(len(supports), 1)
width = min(0.25, 0.8 / n_splits)

for si, split in enumerate(supports.keys()):
    counts = np.zeros(feature_ticks.shape)
    for feat_idx, support in supports[split].items():
        counts[feature_idxs[feat_idx]] = support
    offset = (si - (n_splits - 1) / 2) * width
    plt.bar(feature_ticks + offset, counts, label=split, width=width)

plt.xlabel("Feature")
plt.xticks(feature_ticks, labels, rotation=90)
plt.ylabel("Count")
plt.title("Feature Distribution")
plt.legend()
plt.show()

# Feature Inventory

In [None]:
# For each split, collect per language the set of positive feature ids seen in
# the first feature list of its rows. A language only gets an entry once it
# has at least one row with a non-empty feature list.
lang_supports = {}
for split in dataset:
    per_language = lang_supports[split] = {}
    for row in dataset[split]:
        feature_lists = row['features']
        if len(feature_lists) == 0:
            continue
        inventory = per_language.setdefault(row['language'], set())
        inventory.update(f for f in feature_lists[0] if f > 0)

In [None]:
# For each split, compare the positive-feature inventories of 'rus' and 'pol':
# features seen only in one language, and those both share. The line labelled
# 'combined:' prints the shared set (the intersection).
for split, inventories in lang_supports.items():
    # A language with no annotated rows in this split has no entry; treat it as
    # an empty inventory instead of raising KeyError.
    rus_feats = inventories.get('rus', set())
    pol_feats = inventories.get('pol', set())
    split_intersection = rus_feats & pol_feats
    rus_invent = rus_feats - split_intersection
    pol_invent = pol_feats - split_intersection
    print(split)
    print('disjoint inventory:')
    print('russian:', rus_invent)
    print('polish:', pol_invent)
    print('combined:', split_intersection)