# UID Paper Replication - Step by Step

Name: Ashish Pisey
Roll no.: M25Ai2117

In [1]:
# STEP 0: Mount Google Drive and set paths

from google.colab import drive
drive.mount('/content/drive')

TRAIN_PATH = '/content/drive/MyDrive/MTP/Projects/UID_replication/Datasets/hi_hdtb-ud-train.conllu'
DEV_PATH = '/content/drive/MyDrive/MTP/Projects/UID_replication/Datasets/hi_hdtb-ud-dev.conllu'
TEST_PATH = '/content/drive/MyDrive/MTP/Projects/UID_replication/Datasets/hi_hdtb-ud-test.conllu'

# For testing with smaller data, limit sentences
MAX_TRAIN = 5000  # Set to None for all
MAX_TEST = 500    # Set to None for all
MAX_VARIANTS = 20  # Variants per sentence

Mounted at /content/drive


In [2]:
# STEP 1: Import required libraries

import os
import math
import random
from collections import defaultdict, Counter
from itertools import permutations
import numpy as np
from scipy import stats
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

random.seed(42)
np.random.seed(42)

In [3]:
# STEP 2: Define data structures for tokens and sentences

class Token:
    def __init__(self, id, form, lemma, upos, xpos, feats, head, deprel, deps, misc):
        self.id = id
        self.form = form
        self.lemma = lemma
        self.upos = upos
        self.xpos = xpos
        self.feats = feats
        self.head = head
        self.deprel = deprel
        self.deps = deps
        self.misc = misc

class Sentence:
    def __init__(self, sent_id, text, tokens):
        self.sent_id = sent_id
        self.text = text
        self.tokens = tokens

    def get_root(self):
        for t in self.tokens:
            if t.head == 0:
                return t
        return None

    def get_word_sequence(self):
        sorted_tokens = sorted(self.tokens, key=lambda x: x.id)
        return [t.form for t in sorted_tokens]

    def get_preverbal_constituents(self):
        root = self.get_root()
        if not root:
            return []
        preverbal = [t for t in self.tokens if t.id < root.id]
        constituents = defaultdict(list)
        for t in preverbal:
            constituents[t.head].append(t)
        result = []
        for head_id in sorted(constituents.keys()):
            result.append(sorted(constituents[head_id], key=lambda x: x.id))
        return result

In [4]:
# STEP 3: Parse CoNLL-U files

def parse_feats(feats_str):
    feats = {}
    if feats_str and feats_str != '_':
        for feat in feats_str.split('|'):
            if '=' in feat:
                key, value = feat.split('=', 1)
                feats[key] = value
    return feats

def parse_conllu(filepath, max_sent=None):
    sentences = []
    current_tokens = []
    sent_id = ''
    text = ''

    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()

            if line.startswith('# sent_id'):
                sent_id = line.split('=', 1)[1].strip() if '=' in line else ''
            elif line.startswith('# text'):
                text = line.split('=', 1)[1].strip() if '=' in line else ''
            elif line == '':
                if current_tokens:
                    sentences.append(Sentence(sent_id, text, current_tokens))
                    current_tokens = []
                    sent_id = ''
                    text = ''
                    if max_sent and len(sentences) >= max_sent:
                        break
            else:
                fields = line.split('\t')
                if len(fields) == 10 and '-' not in fields[0] and '.' not in fields[0]:
                    token = Token(
                        id=int(fields[0]),
                        form=fields[1],
                        lemma=fields[2],
                        upos=fields[3],
                        xpos=fields[4],
                        feats=parse_feats(fields[5]),
                        head=int(fields[6]) if fields[6] != '_' else 0,
                        deprel=fields[7],
                        deps=fields[8],
                        misc=fields[9]
                    )
                    current_tokens.append(token)

    if current_tokens:
        sentences.append(Sentence(sent_id, text, current_tokens))

    return sentences

In [5]:
# CoNLL-U parsing with `conllu` library (reference implementation)

import sys
import subprocess

try:
    from conllu import parse_incr
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'conllu'])
    from conllu import parse_incr


def parse_conllu_with_conllu(filepath, max_sent=None):
    sentences = []

    with open(filepath, 'r', encoding='utf-8') as f:
        for i, tokenlist in enumerate(parse_incr(f)):
            sent_id = tokenlist.metadata.get('sent_id', '')
            text = tokenlist.metadata.get('text', '')
            tokens = []

            for tok in tokenlist:
                tok_id = tok.get('id')
                # Skip multiword tokens and empty nodes to match our parser behavior.
                if not isinstance(tok_id, int):
                    continue

                head = tok.get('head')
                tokens.append(Token(
                    id=tok_id,
                    form=tok.get('form', ''),
                    lemma=tok.get('lemma', '_'),
                    upos=tok.get('upos', '_'),
                    xpos=tok.get('xpos', '_'),
                    feats=dict(tok.get('feats') or {}),
                    head=head if isinstance(head, int) else 0,
                    deprel=tok.get('deprel', '_'),
                    deps=tok.get('deps', '_'),
                    misc=tok.get('misc', '_'),
                ))

            if tokens:
                sentences.append(Sentence(sent_id, text, tokens))

            if max_sent and len(sentences) >= max_sent:
                break

    return sentences

# Quick demo (parses a few sentences if one of the configured paths exists)
_demo_path = next((p for p in [globals().get('TEST_PATH'), globals().get('DEV_PATH'), globals().get('TRAIN_PATH')] if isinstance(p, str) and os.path.exists(p)), None)
if _demo_path:
    _demo_sents = parse_conllu_with_conllu(_demo_path, max_sent=2)
    print(f'conllu parser demo: parsed {len(_demo_sents)} sentence(s) from {_demo_path}')
    if _demo_sents:
        print(f'  First sent_id: {_demo_sents[0].sent_id}')
        print(f'  First text: {_demo_sents[0].text}')
        print(f'  First words: {_demo_sents[0].get_word_sequence()[:15]}')
else:
    print('conllu parser cell ready. Set TRAIN_PATH/DEV_PATH/TEST_PATH to run the demo.')


conllu parser demo: parsed 2 sentence(s) from /content/drive/MyDrive/MTP/Projects/UID_replication/Datasets/hi_hdtb-ud-test.conllu
  First sent_id: test-s1
  First text: इसके अतिरिक्त गुग्गुल कुंड, भीम गुफा तथा भीमशिला भी दर्शनीय स्थल हैं ।
  First words: ['इसके', 'अतिरिक्त', 'गुग्गुल', 'कुंड', ',', 'भीम', 'गुफा', 'तथा', 'भीमशिला', 'भी', 'दर्शनीय', 'स्थल', 'हैं', '।']


In [6]:
# STEP 4: Loading dataset

print('Loading training data...')
train_sentences = parse_conllu(TRAIN_PATH, MAX_TRAIN)
print(f'Loaded {len(train_sentences)} training sentences')

print('Loading dev data...')
dev_sentences = parse_conllu(DEV_PATH)
print(f'Loaded {len(dev_sentences)} dev sentences')

print('Loading test data...')
test_sentences = parse_conllu(TEST_PATH, MAX_TEST)
print(f'Loaded {len(test_sentences)} test sentences')

all_sentences = train_sentences + dev_sentences + test_sentences
print(f'Total: {len(all_sentences)} sentences')

Loading training data...
Loaded 5000 training sentences
Loading dev data...
Loaded 1659 dev sentences
Loading test data...
Loaded 500 test sentences
Total: 7159 sentences


In [7]:
# STEP 5: Build trigram model for lexical surprisal

unigram_counts = Counter()
bigram_counts = Counter()
trigram_counts = Counter()
vocab = set()
total_tokens = 0

print('Training trigram model...')

for sentence in train_sentences:
    words = sentence.get_word_sequence()
    tokens = ['<s>', '<s>'] + words + ['</s>']

    for i in range(len(tokens)):
        unigram_counts[tokens[i]] += 1
        vocab.add(tokens[i])

        if i > 0:
            bigram_counts[(tokens[i-1], tokens[i])] += 1

        if i > 1:
            trigram_counts[(tokens[i-2], tokens[i-1], tokens[i])] += 1

    total_tokens += len(words)

vocab_size = len(vocab)
print(f'Vocabulary size: {vocab_size}')
print(f'Total tokens: {total_tokens}')
print(f'Unigrams: {len(unigram_counts)}')
print(f'Bigrams: {len(bigram_counts)}')
print(f'Trigrams: {len(trigram_counts)}')

Training trigram model...
Vocabulary size: 10294
Total tokens: 101712
Unigrams: 10294
Bigrams: 52111
Trigrams: 81145


In [8]:
# STEP 6: Function to compute lexical surprisal

def get_trigram_prob(word, prev1, prev2):
    lambda3 = 0.6
    lambda2 = 0.3
    lambda1 = 0.1

    trigram_count = trigram_counts.get((prev2, prev1, word), 0)
    bigram_context = bigram_counts.get((prev2, prev1), 0)
    p_trigram = trigram_count / bigram_context if bigram_context > 0 else 0

    bigram_count = bigram_counts.get((prev1, word), 0)
    unigram_context = unigram_counts.get(prev1, 0)
    p_bigram = bigram_count / unigram_context if unigram_context > 0 else 0

    p_unigram = (unigram_counts.get(word, 0) + 1) / (total_tokens + vocab_size)

    prob = lambda3 * p_trigram + lambda2 * p_bigram + lambda1 * p_unigram
    return max(prob, 1e-10)

def compute_lexical_surprisal(word_sequence):
    tokens = ['<s>', '<s>'] + word_sequence + ['</s>']
    surprisals = []

    for i in range(2, len(tokens)):
        prob = get_trigram_prob(tokens[i], tokens[i-1], tokens[i-2])
        surprisal = -math.log2(prob)
        surprisals.append(surprisal)

    return surprisals

In [9]:
# STEP 7: Build syntactic surprisal model

state_counts = Counter()
transition_counts = Counter()
total_transitions = 0

print('Training syntactic model...')

def get_state(tokens, idx):
    if idx < 0:
        return 'START'
    parts = []
    for i in range(max(0, idx-1), idx+1):
        parts.append(f"{tokens[i].upos}:{tokens[i].deprel}")
    return '|'.join(parts)

for sentence in train_sentences:
    tokens = sorted(sentence.tokens, key=lambda x: x.id)

    for i, token in enumerate(tokens):
        state = get_state(tokens, i)
        state_counts[state] += 1

        if i > 0:
            prev_state = get_state(tokens, i-1)
            transition_counts[(prev_state, state)] += 1
            total_transitions += 1

print(f'Unique states: {len(state_counts)}')
print(f'Unique transitions: {len(transition_counts)}')

Training syntactic model...
Unique states: 2347
Unique transitions: 12203


In [10]:
# STEP 8: Function to compute syntactic surprisal

def get_state_prob(tokens, idx):
    state = get_state(tokens, idx)
    count = state_counts.get(state, 0)
    return (count + 1) / (total_transitions + len(state_counts) + 1)

def compute_syntactic_surprisal(sentence):
    tokens = sorted(sentence.tokens, key=lambda x: x.id)
    surprisals = []

    for i in range(len(tokens)):
        if i == 0:
            surprisals.append(0.0)
        else:
            prev_prob = get_state_prob(tokens, i-1)
            curr_prob = get_state_prob(tokens, i)
            ratio = curr_prob / max(prev_prob, 1e-10)
            surprisal = -math.log2(ratio)
            surprisals.append(max(0, surprisal))

    return surprisals

In [11]:
# STEP 9: Define UID measures

def mean_info(info_list):
    return sum(info_list) / len(info_list) if info_list else 0

def uid_global(info_list):
    if len(info_list) <= 1:
        return 0
    m = mean_info(info_list)
    variance = sum((x - m) ** 2 for x in info_list) / len(info_list)
    return -variance

def uid_local(info_list):
    if len(info_list) <= 1:
        return 0
    diffs = [(info_list[i] - info_list[i-1]) ** 2 for i in range(1, len(info_list))]
    return -sum(diffs) / len(info_list)

def uid_global_norm(info_list):
    if len(info_list) <= 1:
        return 0
    m = mean_info(info_list)
    if m == 0:
        return 0
    return -sum(((x / m) - 1) ** 2 for x in info_list) / len(info_list)

def uid_local_norm(info_list):
    if len(info_list) <= 1:
        return 0
    m = mean_info(info_list)
    if m == 0:
        return 0
    diffs = [((info_list[i] - info_list[i-1]) ** 2) / (m ** 2) for i in range(1, len(info_list))]
    return -sum(diffs) / len(info_list)

def uid_local_prev_norm(info_list):
    if len(info_list) <= 1:
        return 0
    diffs = []
    for i in range(1, len(info_list)):
        if info_list[i-1] != 0:
            diffs.append(((info_list[i] / info_list[i-1]) - 1) ** 2)
    return -sum(diffs) / len(info_list) if diffs else 0

def compute_all_uid(info_list):
    return {
        'UIDglob': uid_global(info_list),
        'UIDloc': uid_local(info_list),
        'UIDglobNorm': uid_global_norm(info_list),
        'UIDlocNorm': uid_local_norm(info_list),
        'UIDlocPrevNorm': uid_local_prev_norm(info_list)
    }

In [12]:
# STEP 10: Generate variants by permuting preverbal constituents

def generate_variant(sentence, permutation):
    constituents = sentence.get_preverbal_constituents()
    root = sentence.get_root()
    postverbal = [t for t in sentence.tokens if root and t.id > root.id]

    new_tokens = []
    new_id = 1

    for const_idx in permutation:
        if const_idx < len(constituents):
            for t in constituents[const_idx]:
                new_tokens.append(Token(new_id, t.form, t.lemma, t.upos, t.xpos, t.feats, 0, t.deprel, t.deps, t.misc))
                new_id += 1

    if root:
        new_tokens.append(Token(new_id, root.form, root.lemma, root.upos, root.xpos, root.feats, 0, root.deprel, root.deps, root.misc))
        new_id += 1

    for t in postverbal:
        new_tokens.append(Token(new_id, t.form, t.lemma, t.upos, t.xpos, t.feats, t.head, t.deprel, t.deps, t.misc))
        new_id += 1

    text = ' '.join([t.form for t in new_tokens])
    return Sentence(f"{sentence.sent_id}_var", text, new_tokens)

def generate_variants(sentence, max_var=99):
    constituents = sentence.get_preverbal_constituents()
    n_const = len(constituents)
    if n_const <= 1:
        return []

    original = tuple(range(n_const))
    total_alternatives = math.factorial(n_const) - 1
    selected_perms = []

    # Avoid materializing all permutations; this can explode RAM for long sentences.
    if total_alternatives <= max_var:
        for p in permutations(range(n_const)):
            if p != original:
                selected_perms.append(p)
    else:
        seen = {original}
        while len(selected_perms) < max_var:
            p = list(range(n_const))
            random.shuffle(p)
            p = tuple(p)
            if p in seen:
                continue
            seen.add(p)
            selected_perms.append(p)

    return [generate_variant(sentence, p) for p in selected_perms]


In [13]:
# STEP 11: Extract features for a sentence

def extract_features(sentence):
    features = {}

    words = sentence.get_word_sequence()
    lex_surp = compute_lexical_surprisal(words)
    syn_surp = compute_syntactic_surprisal(sentence)

    features['lex_sum'] = sum(lex_surp)
    features['syn_sum'] = sum(syn_surp)

    lex_uid = compute_all_uid(lex_surp)
    syn_uid = compute_all_uid(syn_surp)

    for k, v in lex_uid.items():
        features[f'lex_{k}'] = v
    for k, v in syn_uid.items():
        features[f'syn_{k}'] = v

    return features

In [14]:
# STEP 12: Generate variants and extract features

print('Generating variants and extracting features...')

corpus_features = []
variant_features = []
pair_count = 0

for i, sentence in enumerate(test_sentences):
    variants = generate_variants(sentence, MAX_VARIANTS)
    if not variants:
        continue

    # Compute once per original sentence; reuse for all its variants.
    corp_feat = extract_features(sentence)

    for variant in variants:
        var_feat = extract_features(variant)
        corpus_features.append(corp_feat)
        variant_features.append(var_feat)
        pair_count += 1

    if (i + 1) % 10 == 0:
        print(f'Processed {i+1}/{len(test_sentences)} sentences, {pair_count} pairs')

print()
print(f'Total corpus-variant pairs: {pair_count}')


Generating variants and extracting features...
Processed 20/500 sentences, 198 pairs
Processed 30/500 sentences, 353 pairs
Processed 40/500 sentences, 503 pairs
Processed 50/500 sentences, 635 pairs
Processed 60/500 sentences, 781 pairs
Processed 70/500 sentences, 947 pairs
Processed 80/500 sentences, 1075 pairs
Processed 90/500 sentences, 1260 pairs
Processed 100/500 sentences, 1391 pairs
Processed 120/500 sentences, 1712 pairs
Processed 130/500 sentences, 1805 pairs
Processed 140/500 sentences, 1913 pairs
Processed 150/500 sentences, 2063 pairs
Processed 160/500 sentences, 2208 pairs
Processed 170/500 sentences, 2331 pairs
Processed 180/500 sentences, 2516 pairs
Processed 190/500 sentences, 2613 pairs
Processed 200/500 sentences, 2774 pairs
Processed 210/500 sentences, 2936 pairs
Processed 220/500 sentences, 3136 pairs
Processed 240/500 sentences, 3363 pairs
Processed 250/500 sentences, 3505 pairs
Processed 260/500 sentences, 3667 pairs
Processed 270/500 sentences, 3837 pairs
Process

In [15]:
# STEP 13: Classification functions

def make_diff(feat1, feat2):
    return [feat2[k] - feat1[k] for k in sorted(feat1.keys())]

def create_pairs(corpus_feats, variant_feats):
    X = []
    y = []

    for c, v in zip(corpus_feats, variant_feats):
        X.append(make_diff(c, v))
        y.append(0)

    for c, v in zip(corpus_feats, variant_feats):
        X.append(make_diff(v, c))
        y.append(1)

    return np.array(X), np.array(y)

def cross_validate(X, y, n_folds=5):
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)
    accuracies = []
    weights = []

    for train_idx, test_idx in kf.split(X):
        clf = LogisticRegression(solver='lbfgs', max_iter=100)
        clf.fit(X[train_idx], y[train_idx])
        accuracies.append(clf.score(X[test_idx], y[test_idx]))
        weights.append(clf.coef_[0])

    return np.mean(accuracies), np.mean(weights, axis=0)

In [16]:
# STEP 14: Run classification experiments (Table 1)

experiments = [
    ('Lexical Surprisal', ['lex_sum']),
    ('UIDglob (lex)', ['lex_UIDglob']),
    ('UIDloc (lex)', ['lex_UIDloc']),
    ('UIDglobNorm (lex)', ['lex_UIDglobNorm']),
    ('UIDlocNorm (lex)', ['lex_UIDlocNorm']),
    ('UIDlocPrevNorm (lex)', ['lex_UIDlocPrevNorm']),
    ('Lex+UIDglob', ['lex_sum', 'lex_UIDglob']),
    ('Lex+UIDloc', ['lex_sum', 'lex_UIDloc']),
    ('Lex+UIDglobNorm', ['lex_sum', 'lex_UIDglobNorm']),
    ('Lex+UIDlocNorm', ['lex_sum', 'lex_UIDlocNorm']),
    ('Syntactic Surprisal', ['syn_sum']),
    ('UIDglob (syn)', ['syn_UIDglob']),
    ('Syn+UIDglob', ['syn_sum', 'syn_UIDglob']),
]

print('='*60)
print('CLASSIFICATION RESULTS (Table 1 Replication)')
print('='*60)

for name, feat_names in experiments:
    corp = [{k: f[k] for k in feat_names} for f in corpus_features]
    var = [{k: f[k] for k in feat_names} for f in variant_features]

    X, y = create_pairs(corp, var)
    acc, w = cross_validate(X, y, n_folds=5)

    print(f"{name:30s}: {acc*100:5.2f}%  weights: {[round(x, 2) for x in w]}")

CLASSIFICATION RESULTS (Table 1 Replication)
Lexical Surprisal             : 95.94%  weights: [np.float64(-0.28)]
UIDglob (lex)                 : 78.18%  weights: [np.float64(-0.29)]
UIDloc (lex)                  : 86.26%  weights: [np.float64(-0.15)]
UIDglobNorm (lex)             : 94.75%  weights: [np.float64(-26.1)]
UIDlocNorm (lex)              : 95.01%  weights: [np.float64(-14.26)]
UIDlocPrevNorm (lex)          : 86.34%  weights: [np.float64(-0.03)]
Lex+UIDglob                   : 95.89%  weights: [np.float64(-0.03), np.float64(-0.27)]
Lex+UIDloc                    : 96.79%  weights: [np.float64(-0.05), np.float64(-0.24)]
Lex+UIDglobNorm               : 95.94%  weights: [np.float64(0.0), np.float64(-0.28)]
Lex+UIDlocNorm                : 96.58%  weights: [np.float64(-3.57), np.float64(-0.23)]
Syntactic Surprisal           : 90.98%  weights: [np.float64(-0.35)]
UIDglob (syn)                 : 91.20%  weights: [np.float64(1.05)]
Syn+UIDglob                   : 91.86%  weights: [np.

In [17]:
# STEP 15: Correlation analysis (Table 2)

def pearson_corr(x, y):
    cx = [xi for xi, yi in zip(x, y) if not (np.isnan(xi) or np.isnan(yi))]
    cy = [yi for xi, yi in zip(x, y) if not (np.isnan(xi) or np.isnan(yi))]
    if len(cx) < 2:
        return 0
    return stats.pearsonr(cx, cy)[0]

print('\n' + '='*60)
print('CORRELATION ANALYSIS (Table 2 Replication)')
print('='*60)

lex_surp = [f['lex_sum'] for f in corpus_features]
syn_surp = [f['syn_sum'] for f in corpus_features]

print('\nLexical Surprisal vs UID Measures:')
print('-'*40)
for name in ['UIDglob', 'UIDloc', 'UIDglobNorm', 'UIDlocNorm', 'UIDlocPrevNorm']:
    vals = [f[f'lex_{name}'] for f in corpus_features]
    corr = pearson_corr(lex_surp, vals)
    print(f"  {name:20s}: {corr:7.4f}")

print('\nSyntactic Surprisal vs UID Measures:')
print('-'*40)
for name in ['UIDglob', 'UIDloc', 'UIDglobNorm', 'UIDlocNorm', 'UIDlocPrevNorm']:
    vals = [f[f'syn_{name}'] for f in corpus_features]
    corr = pearson_corr(syn_surp, vals)
    print(f"  {name:20s}: {corr:7.4f}")


CORRELATION ANALYSIS (Table 2 Replication)

Lexical Surprisal vs UID Measures:
----------------------------------------
  UIDglob             : -0.0218
  UIDloc              : -0.1868
  UIDglobNorm         :  0.3979
  UIDlocNorm          :  0.1942
  UIDlocPrevNorm      : -0.0261

Syntactic Surprisal vs UID Measures:
----------------------------------------
  UIDglob             : -0.3970
  UIDloc              : -0.3886
  UIDglobNorm         :  0.1351
  UIDlocNorm          :  0.0504
  UIDlocPrevNorm      : -0.2026


## Summary

This notebook replicates the UID paper:

1. Loaded actual Hindi treebank data from your Google Drive
2. Trained trigram model for lexical surprisal
3. Trained syntactic surprisal model
4. Generated variants by permuting preverbal constituents
5. Computed all 5 UID measures
6. Ran classification experiments
7. Computed correlations

**Expected findings (from paper):**
- Lexical surprisal: ~90% accuracy
- UID alone: 50-73% accuracy
- Lexical + UID: No significant improvement
- UIDglob negatively correlated with surprisal
- UIDglobNorm positively correlated

**Conclusion:** UID measures don't add value beyond surprisal.