In [None]:
import pandas as pd
import nltk
from nltk.corpus import wordnet as wn
from nltk.tokenize import word_tokenize
from nltk import pos_tag
import re
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
import string
!pip install contractions
import contractions
!pip install python-Levenshtein
from Levenshtein import distance as levenshtein_distance
!pip install eng_to_ipa
import eng_to_ipa as ipa
import numpy as np
from collections import Counter
from nltk.corpus import words
!pip install pyspellchecker
from spellchecker import SpellChecker

# Fetch the NLTK resources this notebook relies on: the English word list,
# the punkt tokenizer data, the perceptron POS tagger, and WordNet (plus the
# Open Multilingual WordNet data). Both the older and the newer resource names
# for the tokenizer and tagger are requested.
nltk.download('words')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('punkt_tab')

In [None]:
def get_wordnet_pos(treebank_tag):
    """Translate a Penn Treebank POS tag into the matching WordNet POS constant.

    Tags beginning with 'J', 'V', 'N' or 'R' map to ``wn.ADJ``, ``wn.VERB``,
    ``wn.NOUN`` and ``wn.ADV`` respectively. Any other tag yields ``None``.
    """
    # (tag prefix, name of the WordNet constant) pairs, checked in order.
    prefix_to_constant = (('J', 'ADJ'), ('V', 'VERB'), ('N', 'NOUN'), ('R', 'ADV'))
    for prefix, constant_name in prefix_to_constant:
        if treebank_tag.startswith(prefix):
            return getattr(wn, constant_name)
    return None

In [None]:
# WordNet synset for "animal" (noun, sense 1); the functions below test whether
# it appears among a candidate sense's lowest common hypernyms with it.
animal_synset = wn.synset("animal.n.01")

def get_synset(word, pos, sft):
    """
    Look up a WordNet synset for `word` given its Treebank POS tag.

    When `sft` is truthy, the first sense that has `animal_synset` among its
    lowest common hypernyms with `animal_synset` is preferred (e.g.
    Synset('wren.n.02'), the bird, over n.01, the architect). If no sense
    qualifies, or `sft` is falsy, the first listed sense is returned.

    Returns None when the POS tag has no WordNet equivalent or WordNet has no
    entry for the word under that POS.
    """
    wordnet_pos = get_wordnet_pos(pos)
    if not wordnet_pos:
        return None

    candidates = wn.synsets(word, pos=wordnet_pos)
    if not candidates:
        return None

    if sft:
        animal_senses = (
            sense for sense in candidates
            if animal_synset in sense.lowest_common_hypernyms(animal_synset)
        )
        # Fall back to the first sense when no animal sense exists.
        return next(animal_senses, candidates[0])
    return candidates[0]

In [None]:
def get_similarity_edges(row, remove_none_syns):
    """Build semantic and phonetic similarity edges between consecutive tokens.

    Parameters
    ----------
    row : mapping
        Must provide 'asr_cleaned' (the transcript text) and 'file' (the file
        name; rows whose name contains 'SFT' keep nouns only and prefer animal
        senses when choosing a synset).
    remove_none_syns : bool
        When True, tokens without a WordNet synset are dropped before the
        semantic edges are built.

    Returns
    -------
    (sem_edges, phon_edges) : tuple of lists
        Each edge is ``(word1, word2, attrs)`` with lower-cased words.
        Semantic edges carry ``attrs['sem_similarity']`` (Wu-Palmer similarity;
        missing values are replaced by the mean of the known ones). Phonetic
        edges carry ``attrs['phon_similarity']`` (1 minus the normalised
        Levenshtein distance of the IPA strings, or None when an IPA string is
        empty). ``([], [])`` is returned when similarity cannot be computed:
        fewer than two usable tokens, or no pair with a numeric semantic
        similarity even after dropping tokens without a synset.

    Note: when `remove_none_syns` is True the two edge lists may differ in
    length, because phonetic edges are always built from every filtered token.
    """
    text = row['asr_cleaned']
    sft = 'SFT' in row['file']

    tokens = word_tokenize(contractions.fix(text))
    tagged = [pair for pair in pos_tag(tokens) if pair[0] not in string.punctuation]

    # Keep only nouns for SFT (change the prefixes to include other POS tags,
    # e.g. ('N', 'PRP', 'WP')); every other task keeps all tokens.
    valid_tags = ('N',) if sft else None
    if valid_tags:
        filtered = [(word, pos) for word, pos in tagged if pos.startswith(valid_tags)]
    else:
        filtered = list(tagged)

    # Resolve each synset once, then optionally drop the unresolved tokens.
    filtered_syns = [(word, get_synset(word, pos, sft)) for word, pos in filtered]
    if remove_none_syns:
        filtered_syns = [(word, syn) for word, syn in filtered_syns if syn is not None]

    # Similarity cannot be computed with fewer than two tokens.
    if len(filtered_syns) < 2:
        return [], []

    filtered_phons = [(word, ipa.convert(word)) for word, _ in filtered]

    sem_edges = []
    for (word1, syn1), (word2, syn2) in zip(filtered_syns, filtered_syns[1:]):
        # wup_similarity is None if the POS of syn1 and syn2 differ or there
        # is no common ancestor in the hypernym tree.
        sim = syn1.wup_similarity(syn2) if syn1 and syn2 else None
        sem_edges.append((word1.lower(), word2.lower(), {'sem_similarity': sim}))

    phon_edges = []
    for (word1, phon1), (word2, phon2) in zip(filtered_phons, filtered_phons[1:]):
        if phon1 and phon2:
            sim = 1 - (levenshtein_distance(phon1, phon2) / max(len(phon1), len(phon2)))
        else:
            sim = None
        phon_edges.append((word1.lower(), word2.lower(), {'phon_similarity': sim}))

    known = [attrs['sem_similarity'] for _, _, attrs in sem_edges
             if attrs['sem_similarity'] is not None]
    if not known:
        if remove_none_syns:
            # Already retried without the synset-less tokens: retrying again
            # with the same arguments would recurse forever, so give up.
            return [], []
        return get_similarity_edges(row, remove_none_syns=True)

    # Impute missing semantic similarities with the mean of the known ones.
    average = sum(known) / len(known)
    for _, _, attrs in sem_edges:
        if attrs['sem_similarity'] is None:
            attrs['sem_similarity'] = average
    return sem_edges, phon_edges

In [None]:
def get_similarity_edges_valid_only(row):
    """Build semantic and phonetic edges between consecutive valid tokens.

    Parameters
    ----------
    row : mapping
        Must provide 'valid_tokens': a list of ``(token, synset_or_None)``
        pairs.

    Returns
    -------
    (sem_edges, phon_edges) : tuple of lists
        Each edge is ``(word1, word2, {'similarity': value})`` with lower-cased
        words. Semantic values are Wu-Palmer similarities; missing ones are
        replaced by the mean of the known ones. If no edge has a numeric
        semantic similarity there is nothing to average, so those values stay
        None. Phonetic values are 1 minus the normalised Levenshtein distance
        of the IPA strings, or None when an IPA string is empty.
        ``([], [])`` is returned when there are fewer than two tokens.
    """
    syns = row['valid_tokens']

    # Similarity cannot be computed with fewer than two tokens.
    if len(syns) < 2:
        return [], []

    phons = [(token, ipa.convert(token)) for token, _ in syns]

    sem_edges = []
    for (word1, syn1), (word2, syn2) in zip(syns, syns[1:]):
        # wup_similarity is None if the POS of syn1 and syn2 differ or there
        # is no common ancestor in the hypernym tree.
        sim = syn1.wup_similarity(syn2) if syn1 and syn2 else None
        sem_edges.append((word1.lower(), word2.lower(), {'similarity': sim}))

    phon_edges = []
    for (word1, phon1), (word2, phon2) in zip(phons, phons[1:]):
        if phon1 and phon2:
            sim = 1 - (levenshtein_distance(phon1, phon2) / max(len(phon1), len(phon2)))
        else:
            sim = None
        phon_edges.append((word1.lower(), word2.lower(), {'similarity': sim}))

    sem_values = [attrs['similarity'] for _, _, attrs in sem_edges
                  if attrs['similarity'] is not None]
    # Guard the mean: with no numeric similarity the division would be by zero.
    if sem_values:
        sem_average = sum(sem_values) / len(sem_values)
        for _, _, attrs in sem_edges:
            if attrs['similarity'] is None:
                attrs['similarity'] = sem_average

    return sem_edges, phon_edges

In [None]:
def plot_graph(graph):
    """Draw `graph` on a new 7x7 inch figure and show it.

    Node positions come from a spring layout that uses the 'similarity' edge
    attribute as the edge weight.
    """
    plt.figure(figsize=(7, 7))
    layout = nx.spring_layout(graph, weight='similarity')
    draw_options = dict(
        with_labels=True,
        node_color='lightblue',
        edgecolors='black',
        node_size=1600,
        font_size=16,
        connectionstyle="arc3,rad=0.2",
        arrows=True,
    )
    nx.draw(graph, layout, **draw_options)
    plt.show()


def get_graph(edges):
    """Return a new directed multigraph (``nx.MultiDiGraph``) holding `edges`."""
    multigraph = nx.MultiDiGraph()
    multigraph.add_edges_from(edges)
    return multigraph

def get_graph_info(graph):
    """Summarise a token graph as a dict of scalar features.

    Keys, in order:
      - 'diameter': largest weighted shortest-path length between any pair of
        nodes where one is reachable from the other, using the 'similarity'
        edge attribute as the weight
      - 'number_of_nodes', 'number_of_edges'
      - 'PE': number of (source, target) pairs connected by more than one edge
      - 'LCC': number of weakly connected components
      - 'LSC': number of strongly connected components
      - 'degree_average', 'degree_std': mean and standard deviation of the
        node degrees

    An empty graph (produced when fewer than two tokens were available) yields
    zeros for every feature instead of raising on ``max()`` of an empty
    sequence or returning NaN degree statistics.

    NOTE(review): 'LCC'/'LSC' hold component *counts*; if the size of the
    largest (strongly) connected component was intended, this needs changing.
    """
    if graph.number_of_nodes() == 0:
        return {
            'diameter': 0,
            'number_of_nodes': 0,
            'number_of_edges': 0,
            'PE': 0,
            'LCC': 0,
            'LSC': 0,
            'degree_average': 0.0,
            'degree_std': 0.0,
        }

    res = {}
    lengths = dict(nx.all_pairs_dijkstra_path_length(graph, weight='similarity'))
    res['diameter'] = max(max(lengths[u].values()) for u in graph.nodes)

    res['number_of_nodes'] = graph.number_of_nodes()
    res['number_of_edges'] = graph.number_of_edges()
    # graph.edges() yields one (u, v) pair per parallel edge, so a count above
    # one marks a repeated transition.
    res['PE'] = (np.array(list(Counter(graph.edges()).values())) > 1).sum()

    res['LCC'] = nx.algorithms.components.number_weakly_connected_components(graph)
    res['LSC'] = nx.algorithms.components.number_strongly_connected_components(graph)

    degrees = list(dict(graph.degree()).values())
    res['degree_average'] = np.mean(degrees)
    res['degree_std'] = np.std(degrees)
    return res


In [None]:
# Read in dataframe
# NOTE(review): the transcript path is absolute while the metadata path is
# relative to the working directory; consider one configurable data directory.
df = pd.read_csv('/content/ASR transcripts - Process-train_manual_vs_asr.csv') #replace with file path
# 'id' is the part of the file name before the first '__'
df['id'] = df['file'].str.split('__').str[0]
metadata = pd.read_csv('PROCESS_METADATA_ALL.csv')
# Default (inner) merge: rows whose id has no match in the metadata are dropped
df = df.merge(metadata, left_on='id', right_on='anyon_IDs')

In [None]:
# Preprocessing to only keep patient speech and remove diarisation markers (Pat: and Oth:)

def extract_patient_speech(text):
    """Return only the speech attributed to the patient in a diarised transcript.

    Every segment that follows a 'Pat:' marker, up to the next 'Pat:' or
    'Oth:' marker or the end of the text, is kept. The segments are stripped
    and joined with single spaces. Segments may span several lines.

    Non-string input (e.g. NaN from an empty CSV cell) yields an empty string
    instead of raising a TypeError inside the regex engine.
    """
    if not isinstance(text, str):
        return ''
    # Lazy capture up to (but not including) the next speaker marker or the end
    patient_segments = re.findall(r'Pat:\s*(.*?)(?=Pat:|Oth:|$)', text, flags=re.DOTALL)
    # Join them into one cleaned string
    return ' '.join(segment.strip() for segment in patient_segments)

# Apply to the 'asr' transcripts; the patient-only text is stored in a new
# 'asr_cleaned' column and the original 'asr' column is left untouched
df['asr_cleaned'] = df['asr'].apply(extract_patient_speech)

In [None]:
# Resources used to decide whether a token counts as a real word:
# a SpellChecker instance and the lower-cased NLTK `words` corpus as a set.
spell = SpellChecker()
word_list = set(w.lower() for w in words.words())

def get_vf_scores(row, task):
    """Score one verbal-fluency transcript.

    Parameters
    ----------
    row : mapping
        Must provide 'asr_cleaned' (the transcript text) and, for the PFT
        task, 'id'.
    task : str
        'PFT': valid tokens are real words (present in `word_list` or known to
        `spell`) that start with the target letter. Each is paired with its
        first WordNet synset, or None when it has none.
        'SFT': valid tokens are those with a noun sense that has
        `animal_synset` among its lowest common hypernyms with
        `animal_synset`. Each is paired with the first such sense.

    Returns
    -------
    tuple
        ``(valid_tokens, tokens, n_valid_tokens, n_total_tokens, ratio)`` where
        `valid_tokens` is a list of ``(token, synset_or_None)`` pairs, `tokens`
        is the punctuation-free token list and `ratio` is
        ``n_valid_tokens / n_total_tokens`` (0.0 for an empty transcript).

    Raises
    ------
    ValueError
        If `task` is neither 'PFT' nor 'SFT'.
    """
    text = row['asr_cleaned']
    tokens = word_tokenize(contractions.fix(text))
    tokens = [token for token in tokens if token not in string.punctuation]

    if task == 'PFT':
        real_tokens = [token for token in tokens
                       if token.lower() in word_list or token in spell]
        # NOTE(review): this recording is scored on 'b' instead of 'p';
        # presumably the participant was given a different letter. Confirm.
        target_letter = 'b' if row['id'] == 'Process-rec-114' else 'p'
        valid_tokens = []
        for token in real_tokens:
            if not token.lower().startswith(target_letter):
                continue
            synsets = wn.synsets(token)
            valid_tokens.append((token, synsets[0] if synsets else None))
    elif task == 'SFT':
        valid_tokens = []
        for token in tokens:
            for syn in wn.synsets(token, pos=wn.NOUN):
                if animal_synset in syn.lowest_common_hypernyms(animal_synset):
                    valid_tokens.append((token, syn))
                    break
    else:
        raise ValueError(f"Unknown task {task!r}; expected 'PFT' or 'SFT'")

    n_tokens = len(tokens)
    # An empty transcript has no tokens; report a ratio of 0 rather than
    # dividing by zero.
    ratio = len(valid_tokens) / n_tokens if n_tokens else 0.0
    return valid_tokens, tokens, len(valid_tokens), n_tokens, ratio




In [None]:
# Per-task feature extraction: verbal-fluency scores, similarity edges (via
# get_similarity_edges_valid_only) and graph statistics for SFT and PFT rows.

# Column names keep the existing spelling ('n_totoal_tokens') because cells
# further down select columns by these exact names.
VF_SCORE_COLUMNS = ['valid_tokens', 'tokens', 'n_valid_tokens', 'n_totoal_tokens', 'valid_token_ratio']


def _build_task_features(source_df, task):
    """Return (task_df, sem_graph_info, phon_graph_info) for one fluency task.

    `task_df` holds the rows of `source_df` whose 'file' contains `task`, with
    the verbal-fluency score columns and the 'sem_edges'/'phon_edges' columns
    added. The two info frames hold one row of graph statistics per transcript,
    aligned on the index of `task_df`.
    """
    # .copy() gives an independent frame, so adding columns below does not
    # write into a slice of `source_df` (avoids SettingWithCopyWarning).
    task_df = source_df[source_df['file'].str.contains(task)].copy()

    task_df[VF_SCORE_COLUMNS] = task_df.apply(
        lambda row: get_vf_scores(row, task=task), axis=1, result_type='expand')
    task_df[['sem_edges', 'phon_edges']] = task_df.apply(
        get_similarity_edges_valid_only, axis=1, result_type='expand')

    sem_info = task_df['sem_edges'].apply(lambda edges: get_graph_info(get_graph(edges))).apply(pd.Series)
    phon_info = task_df['phon_edges'].apply(lambda edges: get_graph_info(get_graph(edges))).apply(pd.Series)
    return task_df, sem_info, phon_info


sft_df, sft_graph_info_sem, sft_graph_info_phon = _build_task_features(df, 'SFT')
pft_df, pft_graph_info_sem, pft_graph_info_phon = _build_task_features(df, 'PFT')


In [None]:
def _attach_graph_features(task_df, sem_info, phon_info):
    """Return `task_df` with the graph statistics appended as columns.

    Semantic statistics get a 'sem_' prefix and phonetic ones a 'phon_' prefix.
    Columns with those names already present in `task_df` are dropped first, so
    re-running this cell does not create duplicate column labels.
    """
    sem_features = sem_info.add_prefix('sem_')
    phon_features = phon_info.add_prefix('phon_')
    already_attached = [*sem_features.columns, *phon_features.columns]
    base = task_df.drop(columns=already_attached, errors='ignore')
    return pd.concat([base, sem_features, phon_features], axis=1)


sft_df = _attach_graph_features(sft_df, sft_graph_info_sem, sft_graph_info_phon)
pft_df = _attach_graph_features(pft_df, pft_graph_info_sem, pft_graph_info_phon)

# Only the id and the PFT-derived features are taken from pft_df, so the
# metadata columns already present in sft_df are not duplicated by the merge.
pft_df_selected = pft_df[['id', 'valid_tokens', 'tokens',
       'n_valid_tokens', 'n_totoal_tokens', 'valid_token_ratio', 'sem_edges',
       'phon_edges', 'sem_diameter', 'sem_number_of_nodes',
       'sem_number_of_edges', 'sem_PE', 'sem_LCC', 'sem_LSC',
       'sem_degree_average', 'sem_degree_std', 'phon_diameter',
       'phon_number_of_nodes', 'phon_number_of_edges', 'phon_PE', 'phon_LCC',
       'phon_LSC', 'phon_degree_average', 'phon_degree_std']]
# Columns present on both sides get the '_sft' / '_pft' suffixes.
merged_df = sft_df.merge(pft_df_selected, on='id', suffixes=('_sft', '_pft'))


In [None]:
# Display the column labels of the merged SFT/PFT frame
merged_df.columns

In [None]:
# Encode the diagnosis as an integer class label (HC=0, MCI=1, Dementia=2).
# Series.map yields NaN for any diagnosis value not listed in the mapping.
merged_df['label'] = merged_df['diagnosis'].map({'HC': 0, 'MCI': 1, 'Dementia': 2})

In [None]:
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

# One seed for everything stochastic in this cell, so reruns give the same report.
RANDOM_STATE = 42

classifiers = {
    "Logistic Regression": LogisticRegression(max_iter=1000),
    "Random Forest": RandomForestClassifier(random_state=RANDOM_STATE),
    "SVM": SVC(probability=True),
    "Naive Bayes": MultinomialNB(),
    "KNN": KNeighborsClassifier()
}

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)

X = merged_df[['n_valid_tokens_sft', 'n_totoal_tokens_sft',
       'valid_token_ratio_sft',
       'sem_diameter_sft', 'sem_number_of_nodes_sft',
       'sem_number_of_edges_sft', 'sem_PE_sft', 'sem_LCC_sft', 'sem_LSC_sft',
       'sem_degree_average_sft', 'sem_degree_std_sft', 'phon_diameter_sft',
       'phon_number_of_nodes_sft', 'phon_number_of_edges_sft', 'phon_PE_sft',
       'phon_LCC_sft', 'phon_LSC_sft', 'phon_degree_average_sft',
       'phon_degree_std_sft',
       'n_valid_tokens_pft', 'n_totoal_tokens_pft', 'valid_token_ratio_pft',
       'sem_diameter_pft',
       'sem_number_of_nodes_pft', 'sem_number_of_edges_pft', 'sem_PE_pft',
       'sem_LCC_pft', 'sem_LSC_pft', 'sem_degree_average_pft',
       'sem_degree_std_pft', 'phon_diameter_pft', 'phon_number_of_nodes_pft',
       'phon_number_of_edges_pft', 'phon_PE_pft', 'phon_LCC_pft',
       'phon_LSC_pft', 'phon_degree_average_pft', 'phon_degree_std_pft']]
y = merged_df['label']

# 5-fold stratified cross-validation; the predictions of all folds are pooled
# and reported once per classifier.
for clf_name, clf in classifiers.items():
    print(f"\n=== {clf_name} ===")
    all_preds = []
    all_trues = []

    for fold, (train_idx, test_idx) in enumerate(skf.split(X, y), 1):
        # split() returns positional indices, so select by position for both
        # X and y; plain y[train_idx] would look the integers up as labels.
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

        # Fit the scaler on the training fold only to avoid leaking test statistics.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        if isinstance(clf, MultinomialNB):
            # MultinomialNB rejects negative inputs, hence the absolute value.
            # NOTE(review): abs() of standardised features discards the sign;
            # consider a non-negative scaler or GaussianNB instead.
            X_train = np.abs(X_train)
            X_test = np.abs(X_test)

        clf.fit(X_train, y_train)
        preds = clf.predict(X_test)

        all_preds.extend(preds)
        all_trues.extend(y_test)

    print(classification_report(all_trues, all_preds))


In [None]:
# SFT rows with the (unprefixed) semantic graph statistics appended as columns
sft_df_sem = pd.concat([sft_df, sft_graph_info_sem], axis=1)

# One density histogram per semantic graph statistic, split by diagnosis
for col in sft_graph_info_sem.columns:
    try:
        sns.histplot(data=sft_df_sem, x=col, hue='diagnosis', kde=True, element='step', stat='density')
        plt.title(f'SFT - Histogram of {col} by diagnosis')
        plt.show()
    except Exception as exc:
        # Keep plotting the remaining columns, but say why this one failed and
        # discard the half-drawn figure. KeyboardInterrupt is not swallowed.
        plt.close()
        print(f'Could not plot {col}: {exc!r}')

In [None]:
# SFT rows with the (unprefixed) phonetic graph statistics appended as columns
sft_df_phon = pd.concat([sft_df, sft_graph_info_phon], axis=1)

# One density histogram per phonetic graph statistic, split by diagnosis
for col in sft_graph_info_phon.columns:
    try:
        sns.histplot(data=sft_df_phon, x=col, hue='diagnosis', kde=True, element='step', stat='density')
        plt.title(f'SFT - Histogram of {col} by diagnosis')
        plt.show()
    except Exception as exc:
        # Keep plotting the remaining columns, but say why this one failed and
        # discard the half-drawn figure. KeyboardInterrupt is not swallowed.
        plt.close()
        print(f'Could not plot {col}: {exc!r}')

In [None]:
# Draw the phonetic graph of every PFT transcript, grouped by diagnosis
for diagnosis, group_rows in pft_df.groupby('diagnosis'):
    print(diagnosis)
    for edge_list in group_rows['phon_edges']:
        plot_graph(get_graph(edge_list))

In [None]:
# Draw the semantic graph of every SFT transcript, grouped by diagnosis
for diagnosis, group_rows in sft_df.groupby('diagnosis'):
    print(diagnosis)
    for edge_list in group_rows['sem_edges']:
        plot_graph(get_graph(edge_list))


In [None]:
# PFT rows with the (unprefixed) phonetic graph statistics appended as columns
pft_df_phon = pd.concat([pft_df, pft_graph_info_phon], axis=1)

# One density histogram per phonetic graph statistic, split by diagnosis
for col in pft_graph_info_phon.columns:
    try:
        sns.histplot(data=pft_df_phon, x=col, hue='diagnosis', kde=True, element='step', stat='density')
        plt.title(f'PFT - Histogram of {col} by diagnosis')
        plt.show()
    except Exception as exc:
        # Keep plotting the remaining columns, but say why this one failed and
        # discard the half-drawn figure. KeyboardInterrupt is not swallowed.
        plt.close()
        print(f'Could not plot {col}: {exc!r}')

In [None]:
# PFT rows with the (unprefixed) semantic graph statistics appended as columns
pft_df_sem = pd.concat([pft_df, pft_graph_info_sem], axis=1)

# One density histogram per semantic graph statistic, split by diagnosis
for col in pft_graph_info_sem.columns:
    try:
        sns.histplot(data=pft_df_sem, x=col, hue='diagnosis', kde=True, element='step', stat='density')
        plt.title(f'PFT - Histogram of {col} by diagnosis')
        plt.show()
    except Exception as exc:
        # Keep plotting the remaining columns, but say why this one failed and
        # discard the half-drawn figure. KeyboardInterrupt is not swallowed.
        plt.close()
        print(f'Could not plot {col}: {exc!r}')

In [None]:
# Density histograms of the SFT token-count features, split by diagnosis
for col in ['n_valid_tokens', 'n_totoal_tokens', 'valid_token_ratio']:
    try:
        sns.histplot(data=sft_df, x=col, hue='diagnosis', kde=True, element='step', stat='density')
        plt.title(f'SFT - Histogram of {col} by diagnosis')
        plt.show()
    except Exception as exc:
        # Keep plotting the remaining columns, but say why this one failed and
        # discard the half-drawn figure. KeyboardInterrupt is not swallowed.
        plt.close()
        print(f'Could not plot {col}: {exc!r}')

In [None]:
# Density histograms of the PFT token-count features, split by diagnosis
for col in ['n_valid_tokens', 'n_totoal_tokens', 'valid_token_ratio']:
    try:
        sns.histplot(data=pft_df, x=col, hue='diagnosis', kde=True, element='step', stat='density')
        plt.title(f'PFT - Histogram of {col} by diagnosis')
        plt.show()
    except Exception as exc:
        # Keep plotting the remaining columns, but say why this one failed and
        # discard the half-drawn figure. KeyboardInterrupt is not swallowed.
        plt.close()
        print(f'Could not plot {col}: {exc!r}')

In [None]:
# Difference between the SFT and PFT valid-token counts per participant.
# NOTE(review): this rebinds `merged_df`, replacing the feature table built
# for the classifiers earlier in the notebook.
merged_df = (
    sft_df
    .merge(pft_df[['id', 'n_valid_tokens']], on='id', suffixes=('_sft', '_pft'))
    .assign(n_valid_tokens_diff=lambda frame: frame['n_valid_tokens_sft'] - frame['n_valid_tokens_pft'])
)

diff_ax = sns.histplot(data=merged_df, x='n_valid_tokens_diff', hue='diagnosis',
                       kde=True, element='step', stat='density')
diff_ax.set_title('Histogram of n_valid_tokens_diff by diagnosis')
plt.show()
