The UPOS field contains a part-of-speech tag from the Universal POS tag set, while the XPOS field optionally contains a language-specific (or even treebank-specific) part-of-speech/morphological tag.

In [1]:
from itertools import zip_longest
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from functions import pickle_load
from process_conllu import ConlluDataset

In [2]:
def merge_word_hidden_pos(viterbi_pos, true_pos, obs_seqs):
    """Flatten per-sentence decoded states, reference tags and tokens into one long table.

    Args:
        viterbi_pos: per-sentence sequences of predicted hidden states.
        true_pos: per-sentence sequences of reference POS tags.
        obs_seqs: per-sentence sequences of observed tokens.

    Returns:
        DataFrame with one row per token and columns
        ``["sentence", "hidden state", "pos", "token"]``, where ``sentence`` is the
        0-based position of the sentence within the inputs.

    Raises:
        ValueError: if the three inputs do not hold the same number of sentences, or
            if a sentence's three sequences differ in length. Plain ``zip`` would
            silently drop the unmatched tail and misalign the comparison.
    """
    missing = object()  # sentinel zip_longest uses to pad an exhausted input

    def _aligned(what, *iterables):
        # Like zip(), but raise instead of silently truncating to the shortest input.
        for items in zip_longest(*iterables, fillvalue=missing):
            if any(item is missing for item in items):
                raise ValueError(f"length mismatch between {what}")
            yield items

    merged_pos = []
    for seq_idx, (pred_hidden_seq, true_hidden_seq, obs_seq) in enumerate(
            _aligned("viterbi_pos, true_pos and obs_seqs", viterbi_pos, true_pos, obs_seqs)):
        for hidden_state, pos, obs_state in _aligned(
                f"the sequences of sentence {seq_idx}", pred_hidden_seq, true_hidden_seq, obs_seq):
            merged_pos.append((seq_idx, hidden_state, pos, obs_state))
    return pd.DataFrame(merged_pos, columns=["sentence", "hidden state", "pos", "token"])

In [12]:
# Evaluation tables for each model (HMM, BERT) and tag set (UPOS, XPOS).
hmm_upos_df, hmm_xpos_df, bert_upos_df, bert_xpos_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        "results/eval_hmm_upos.csv",
        "results/eval_hmm_xpos.csv",
        "results/eval_bert_upos.csv",
        "results/eval_bert_xpos.csv",
    )
)

# Parsed CoNLL-U corpus plus the Viterbi-decoded hidden-state sequences per tag set.
# NOTE(review): pickle_load presumably unpickles — only point it at trusted checkpoints.
dataset: ConlluDataset = pickle_load("checkpoints/dataset.pkl")
viterbi_upos: list[list[int]] = pickle_load("checkpoints/viterbi_upos.pkl")
viterbi_xpos: list[list[int]] = pickle_load("checkpoints/viterbi_xpos.pkl")

In [3]:
# `dataset` and `viterbi_upos` are already loaded by the checkpoint cell above.
# Re-unpickling the same files here only duplicated that work, so instead just
# report how many sentences were decoded.
len(viterbi_upos)

In [4]:
# Flatten both decodings into long (sentence, hidden state, pos, token) frames.
# The XPOS analysis cells below read viterbi_xpos_df. Leaving this line commented
# out made them fail with NameError on Restart & Run All.
viterbi_upos_df = merge_word_hidden_pos(viterbi_upos, dataset.upos, dataset.sequences)
viterbi_xpos_df = merge_word_hidden_pos(viterbi_xpos, dataset.xpos, dataset.sequences)

In [5]:
# Number of tokens decoded into each UPOS-HMM hidden state.
viterbi_upos_df["hidden state"].pipe(lambda states: states.groupby(states).size())

hidden state
0      40087
1      20069
2      16881
3      32406
4     141473
5      86618
6     162839
7      38273
8      25998
9      38408
10     70538
11     44780
12    105267
13     23326
14     50313
15     16807
16     35618
Name: hidden state, dtype: int64

In [6]:
# Per token: how often it occurs and the set of UPOS-HMM hidden states it was decoded into.
upos_hidden_state_to_token = (
    viterbi_upos_df[["hidden state", "token"]]
    .groupby("token")
    .agg(
        count=pd.NamedAgg(column="token", aggfunc="count"),
        hidden_states=pd.NamedAgg(column="hidden state", aggfunc=frozenset),
    )
    .reset_index()
)

In [7]:
# The 20 most frequent tokens and the UPOS-HMM hidden states each was assigned.
upos_hidden_state_to_token.sort_values(by="count", ascending=False).iloc[:20]

Unnamed: 0,token,count,hidden_states
19,",",48723,"(2, 9, 10, 11, 12)"
30807,the,47975,"(0, 9, 4)"
27,.,39020,"(9, 12)"
108,[NUM],23927,"(0, 3, 4, 6, 9, 10, 16)"
21245,of,23005,"(2, 5, 14)"
31175,to,22352,"(9, 13, 11, 5)"
279,a,20149,"(0, 4, 5, 8, 9, 11, 13)"
15156,in,16931,"(0, 2, 5, 7, 12, 13, 14)"
1383,and,16668,"(1, 2, 3, 5, 6, 9, 10, 11, 12, 14, 15)"
15,'s,9326,"(9, 10, 11, 5)"


In [8]:
# Distinct sizes of the per-token hidden-state sets, i.e. how many states a token can land in.
{len(states) for states in upos_hidden_state_to_token["hidden_states"]}

{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}

In [5]:
# Number of tokens decoded into each XPOS-HMM hidden state.
# NOTE(review): the saved output lists only 5 distinct states, one of them holding most
# tokens; worth checking whether XPOS training collapsed onto a few states.
viterbi_xpos_df["hidden state"].pipe(lambda states: states.groupby(states).size())

hidden state
6     663367
10     48723
30     39020
40    104502
44     94089
Name: hidden state, dtype: int64

In [6]:
# Per token: how often it occurs and the set of XPOS-HMM hidden states it was decoded into.
xpos_hidden_state_to_token = (
    viterbi_xpos_df[["hidden state", "token"]]
    .groupby("token")
    .agg(
        count=pd.NamedAgg(column="token", aggfunc="count"),
        hidden_states=pd.NamedAgg(column="hidden state", aggfunc=frozenset),
    )
    .reset_index()
)

In [7]:
# The 20 most frequent tokens and the XPOS-HMM hidden states each was assigned.
xpos_hidden_state_to_token.sort_values(by="count", ascending=False).iloc[:20]

Unnamed: 0,token,count,hidden_states
19,",",48723,(10)
30807,the,47975,(40)
27,.,39020,(30)
108,[NUM],23927,(6)
21245,of,23005,(6)
31175,to,22352,(6)
279,a,20149,(6)
15156,in,16931,(6)
1383,and,16668,(6)
15,'s,9326,(6)


In [26]:
# Distinct sizes of the per-token XPOS hidden-state sets.
{len(states) for states in xpos_hidden_state_to_token["hidden_states"]}

{1, 2}

In [9]:
# Hidden-state sets of the tokens that were decoded into exactly two XPOS-HMM states.
xpos_hidden_state_to_token.loc[
    xpos_hidden_state_to_token["hidden_states"].map(len) == 2,
    "hidden_states",
]

24       (40, 6)
303      (40, 6)
426      (40, 6)
636      (40, 6)
724      (40, 6)
          ...   
33631    (40, 6)
33910    (40, 6)
34230    (40, 6)
34241    (40, 6)
34270    (40, 6)
Name: hidden_states, Length: 191, dtype: object

In [10]:
# Tokens decoded into exactly two XPOS-HMM states, tallied by which pair of states they use.
s = xpos_hidden_state_to_token.loc[
    xpos_hidden_state_to_token["hidden_states"].map(len).eq(2),
    "hidden_states",
]
s.groupby(s).size()

hidden_states
(40, 6)    191
Name: hidden_states, dtype: int64