<a href="https://colab.research.google.com/github/IVornehm/Thesis/blob/main/Embeddings_and_Neural_Network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Word Embeddings

In [None]:

import os, re, ast, json
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import Counter, defaultdict
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, fcluster
import spacy
from sentence_transformers import SentenceTransformer
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
import torch
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import silhouette_score
from kneed import KneeLocator
import matplotlib.pyplot as plt
import ast


# Mount Google Drive inside the Colab runtime; root_dir below lives on that mount.
from google.colab import drive
drive.mount('/content/drive')

# Drive folder that transcript paths are resolved against and that outputs are written to.
root_dir = "/content/drive/My Drive/Thesis Linguistics"
# Input spreadsheet, read from the Colab runtime's local disk (not from Drive).
file_path = "/content/childes_data_filtered_2.xlsx"
# Destination of the augmented spreadsheet: same file name, but inside root_dir on Drive.
output_path = os.path.join(root_dir, "childes_data_filtered_2.xlsx")

def parse_lists(cell):
    """Turn a spreadsheet cell into a collection of items.

    Args:
        cell: Raw cell value. Strings are parsed; any other value is
            returned unchanged.

    Returns:
        * ``cell.split("|")`` when the string contains a pipe;
        * the parsed Python literal when the string is the ``repr`` of a
          list, tuple or set;
        * ``[cell]`` for every other string (including strings that parse
          to a scalar such as ``"123"``, so callers can always iterate);
        * ``cell`` itself when it is not a string.
    """
    if not isinstance(cell, str):
        return cell
    if "|" in cell:
        return cell.split("|")
    try:
        parsed = ast.literal_eval(cell)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a Python literal (e.g. a plain path): treat it as a single item.
        return [cell]
    if isinstance(parsed, (list, tuple, set)):
        return parsed
    # A scalar literal is not iterable; keep the original text as one item.
    return [cell]

def extract_clean_tokens(utt):
    """Return the lower-cased word tokens of ``utt`` that pass the noise filters.

    Word-character runs are pulled out of the utterance, split on underscores
    and lower-cased. A token is dropped when it looks like transcription
    noise: annotation symbols, babble/moan/vocalisation markers, digits,
    placeholder codes, stray single letters, a tripled letter, a vowel
    followed by a final "h", listed fillers, "hm"/"mh"/"hrm" sounds, or any
    substring of two or more characters that occurs twice.

    Args:
        utt: Utterance text.

    Returns:
        List of surviving tokens in their original order.
    """
    placeholder_codes = {"xxx", "yyy", "www", "nonspeech"}
    fillers = {"haha", "uhhum", "uhhuhhaw", "mm", "er", "ew", "ha", "hee"}

    def is_noise(word):
        # Each rule mirrors one filter; a token is noise if any rule fires.
        if any(sym in word for sym in ("&", "@", "=", "{", "}")):
            return True
        if any(marker in word for marker in ("babbl", "moan", "vocal")):
            return True
        if any(ch.isdigit() for ch in word):
            return True
        if word in placeholder_codes or word in fillers:
            return True
        if len(word) == 1 and word not in {"i", "a", "s"}:
            return True
        if len(word) == 3 and len(set(word)) == 1:
            return True
        if re.search(r"[aeiou]h$", word):
            return True
        if "hm" in word or "mh" in word or "hrm" in word:
            return True
        return bool(re.search(r"(.{2,}).*\1", word))

    kept = []
    for chunk in re.findall(r'\b\w+\b', utt):
        # "\w" also matches "_", so compound chunks are split here.
        for piece in chunk.split("_"):
            word = piece.lower()
            if word and not is_noise(word):
                kept.append(word)
    return kept

def normalize_batch(vectors):
    """Scale every row of a 2-D array to unit Euclidean length.

    Row norms are floored at 1e-10 before dividing, so an all-zero row is
    returned as zeros instead of producing a division by zero.
    """
    row_lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
    safe_lengths = np.clip(row_lengths, 1e-10, None)
    return vectors / safe_lengths

def compute_inequality(counter):
    """Summarise how unevenly the counts in ``counter`` are distributed.

    Args:
        counter: Mapping from item to frequency.

    Returns:
        Five values rounded to four decimals, in this order: coefficient of
        variation (std / mean), Gini coefficient, share of the largest
        count, share of the three largest counts, share of the five largest
        counts. All five are ``0.0`` when the counts sum to zero.
    """
    freqs = list(counter.values())
    grand_total = sum(freqs)
    if grand_total == 0:
        return [0.0] * 5

    average = np.mean(freqs)
    spread = np.std(freqs)
    if average > 0:
        norm_cv = round(spread / average, 4)
    else:
        norm_cv = 0.0

    # Gini from rank-weighted sorted values: sum((2i - n - 1) * x_i) / (n * sum(x)).
    n = len(freqs)
    rank_weights = 2 * np.arange(1, n + 1) - n - 1
    gini_score = round(np.sum(rank_weights * np.sort(freqs)) / (n * np.sum(freqs)), 4)

    ranked = sorted(freqs, reverse=True)
    dominance = [round(sum(ranked[:k]) / grand_total, 4) for k in (1, 3, 5)]
    return [norm_cv, gini_score, *dominance]

# Load the session table and expand each RelativePath cell into a list of paths.
df = pd.read_excel(file_path)
df["RelativePath"] = df["RelativePath"].apply(parse_lists)

# file_data maps a forward-slash transcript path to its (speaker code, utterance) pairs.
file_data, all_utt_set = {}, set()
all_paths = set(p for paths in df["RelativePath"] for p in paths)
for path in tqdm(all_paths, desc="Loading .cha files"):
    # Store under the normalised path: the utterance-collection step looks
    # transcripts up with backslashes replaced, so a raw Windows-style key
    # would never be found.
    norm_path = path.replace("\\", "/")
    full = os.path.join(root_dir, norm_path)
    if not os.path.exists(full):
        continue
    try:
        with open(full, "r", encoding="utf-8") as f:
            text = f.read()
    except (OSError, UnicodeDecodeError) as exc:
        # Still best-effort, but report the file instead of dropping it silently.
        print(f"Skipping unreadable transcript {full}: {exc}")
        continue
    # Main-tier lines look like "*CHI:<whitespace>utterance".
    file_data[norm_path] = re.findall(
        r"^(\*[A-Z]+[A-Za-z0-9_]*):\s*(.*)", text, flags=re.MULTILINE
    )



# Per-row token lists and utterance strings for the child (CHI) and mother (MOT),
# plus the global vocabulary and the set of distinct cleaned utterances.
speaker_tokens, speaker_utterances = {}, {}
combined_vocab = set()
all_utt_set = set()

for idx, row in tqdm(df.iterrows(), total=len(df), desc="Collecting utterances"):
    tokens = {"CHI": [], "MOT": []}
    utts = {"CHI": [], "MOT": []}
    for path in row["RelativePath"]:
        transcript = file_data.get(path.replace("\\", "/"), [])
        for spk_code, utt in transcript:
            spk = spk_code.lstrip("*")
            if spk not in tokens:
                # Only CHI and MOT tiers are analysed.
                continue
            # Split a tier line into sentence-like segments at ., ? and !.
            for raw_seg in re.split(r"[.?!]+", utt):
                seg = raw_seg.strip()
                if not seg:
                    continue
                tks = extract_clean_tokens(seg)
                if not tks:
                    continue
                utt_str = " ".join(tks)
                tokens[spk].append(tks)
                utts[spk].append(utt_str)
                combined_vocab.update(tks)
                all_utt_set.add(utt_str)
    speaker_tokens[idx] = tokens
    speaker_utterances[idx] = utts


# Tag every vocabulary word in isolation and remember whether spaCy assigns it
# a content-word part of speech (noun, verb, adjective or adverb).
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])
word_pos_cache = {}
vocab_list = list(combined_vocab)
# nlp.pipe yields one Doc per input text (batch_size only controls internal
# buffering), so the progress total is the vocabulary size, not the batch count.
for doc in tqdm(nlp.pipe(vocab_list, batch_size=1000), total=len(vocab_list), desc="POS tagging"):
    for token in doc:
        if token.is_alpha:
            word_pos_cache[token.text] = token.pos_ in {"NOUN", "VERB", "ADJ", "ADV"}
content_word_types = {w for w, is_content in word_pos_cache.items() if is_content}


# Keep only utterances containing at least one content word. After this step
# speaker_tokens[idx][spk] is a flat list of tokens (no longer one list per
# utterance), while speaker_utterances[idx][spk] holds the surviving strings.
for idx in speaker_tokens:
    for spk in ["CHI", "MOT"]:
        kept_pairs = [
            (tks, utt)
            for tks, utt in zip(speaker_tokens[idx][spk], speaker_utterances[idx][spk])
            if any(tok in content_word_types for tok in tks)
        ]
        speaker_tokens[idx][spk] = [tok for tks, _ in kept_pairs for tok in tks]
        speaker_utterances[idx][spk] = [utt for _, utt in kept_pairs]



# Encode every vocabulary word and every distinct cleaned utterance once, and
# cache the vectors by their text so later steps can look them up.
embedding_device = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer("all-MiniLM-L6-v2", device=embedding_device)

vocab_items = list(combined_vocab)
vocab_vectors = model.encode(vocab_items, batch_size=256, show_progress_bar=True)
word_emb_cache = dict(zip(vocab_items, vocab_vectors))

utt_items = list(all_utt_set)
utt_vectors = model.encode(utt_items, batch_size=256, show_progress_bar=True)
utt_emb_cache = dict(zip(utt_items, utt_vectors))

# Cosine-distance cut-offs at which the hierarchical clustering is flattened.
thresholds = [0.25, 0.50, 0.75, 1.00]
# Names, in order, of the five values returned by compute_inequality.
metric_names = ["TokenNormCV", "Gini", "MaxDominance", "Top3Dominance", "Top5Dominance"]

# Pre-create every result column as NaN, so rows with too little data stay NaN.
for spk in ["CHI", "MOT"]:
    for t in thresholds:
        tag = f"t{t:.2f}"
        new_cols = [f"{spk}_EmbWord_{tag}", f"{spk}_EmbUtt_{tag}"]
        for m in metric_names:
            new_cols.append(f"{spk}_{m}_EmbWord_{tag}")
            new_cols.append(f"{spk}_{m}_EmbUtt_{tag}")
        for col_name in new_cols:
            df[col_name] = np.nan

# Filled by process_row: one record per (row, speaker, threshold, cluster).
centroid_records_word, centroid_records_utt = [], []

def process_row(idx):
    """Cluster one row's words and utterances and compute its cluster metrics.

    For each speaker (CHI, MOT) the row's content-word types and its
    utterances are clustered separately with average-linkage hierarchical
    clustering on cosine distance. For every value in ``thresholds`` the
    tree is cut and the number of clusters and the ``compute_inequality``
    metrics are recorded. The mean embedding of every cluster is appended
    to the module-level ``centroid_records_word`` / ``centroid_records_utt``
    lists.

    Args:
        idx: Row label, used as key into ``speaker_tokens`` and
            ``speaker_utterances``.

    Returns:
        ``(idx, row_res)`` where ``row_res`` maps result column names to
        values. A speaker with fewer than two word types (or fewer than two
        utterances) contributes no entries for that level.
    """
    row_res = {}
    for spk in ["CHI", "MOT"]:
        # Word level: cluster distinct content-word types, weight by token frequency.
        type_counts = Counter(w for w in speaker_tokens[idx][spk] if w in content_word_types)
        if len(type_counts) >= 2:
            vocab = list(type_counts)
            unit_vecs = normalize_batch(np.array([word_emb_cache[w] for w in vocab]))
            tree = linkage(pdist(unit_vecs, "cosine"), method="average")
            for t in thresholds:
                suffix = f"EmbWord_t{t:.2f}"
                members = defaultdict(list)
                for w, label in zip(vocab, fcluster(tree, t=t, criterion="distance")):
                    members[label].append(w)
                row_res[f"{spk}_{suffix}"] = len(members)
                cluster_mass = Counter(
                    {label: sum(type_counts[w] for w in ws) for label, ws in members.items()}
                )
                for m, value in zip(metric_names, compute_inequality(cluster_mass)):
                    row_res[f"{spk}_{m}_{suffix}"] = value
                for ws in members.values():
                    centroid_records_word.append({
                        "RowID": idx,
                        "Speaker": spk,
                        "Threshold": t,
                        "Centroid": np.mean([word_emb_cache[w] for w in ws], axis=0),
                    })

        # Utterance level: every utterance token counts once (repeats included).
        utts = [u for u in speaker_utterances[idx][spk] if u in utt_emb_cache]
        if len(utts) >= 2:
            unit_vecs = normalize_batch(np.array([utt_emb_cache[u] for u in utts]))
            tree = linkage(pdist(unit_vecs, "cosine"), method="average")
            for t in thresholds:
                suffix = f"EmbUtt_t{t:.2f}"
                members = defaultdict(list)
                for u, label in zip(utts, fcluster(tree, t=t, criterion="distance")):
                    members[label].append(u)
                row_res[f"{spk}_{suffix}"] = len(members)
                cluster_sizes = Counter({label: len(us) for label, us in members.items()})
                for m, value in zip(metric_names, compute_inequality(cluster_sizes)):
                    row_res[f"{spk}_{m}_{suffix}"] = value
                for us in members.values():
                    centroid_records_utt.append({
                        "RowID": idx,
                        "Speaker": spk,
                        "Threshold": t,
                        "Centroid": np.mean([utt_emb_cache[u] for u in us], axis=0),
                    })
    return idx, row_res

# Fan the per-row clustering out over a thread pool and write each row's
# metrics into df as soon as its worker finishes.
print("Starting threaded row processing...")
worker_count = min(16, cpu_count())
with ThreadPoolExecutor(max_workers=worker_count) as executor:
    pending = [executor.submit(process_row, row_label) for row_label in df.index]
    progress = tqdm(as_completed(pending), total=len(pending), desc="Row clustering")
    for finished in progress:
        idx, row_metrics = finished.result()
        for column, value in row_metrics.items():
            df.at[idx, column] = value



def assign_meta_topics_and_save(centroids, label_prefix, emb_cache, df, t, save_dir):
    """Group per-row cluster centroids into corpus-wide "meta-topics".

    The centroids are L2-normalised and clustered with MiniBatchKMeans. The
    number of meta-topics ``k`` is picked from 10, 20, ..., 300 (limited to
    the number of centroids) by the elbow of the smoothed inertia curve; if
    no elbow is found, the ``k`` with the best silhouette score is used.

    Side effects:
        * adds one integer column per (meta-topic, speaker) to ``df`` in
          place, named ``{label_prefix}_{spk}_MetaTopic_{id}_t{t}``, counting
          how many of that row's centroids fell into the meta-topic;
        * writes ``{label_prefix}_MetaTopic_TopItems_t{t}.json`` into
          ``save_dir`` with up to six exemplar items per meta-topic, chosen at
          the 0/5/10/15/20/25th percentiles of distance to the topic centre.

    Args:
        centroids: Records with keys "RowID", "Speaker" and "Centroid".
        label_prefix: "Word" or "Utt"; selects the speaker lookup and the
            item key ("word" / "utt") used in the JSON output.
        emb_cache: Mapping from item text to embedding; exemplars are drawn
            from all of its keys.
        df: DataFrame that receives the count columns.
        t: Clustering threshold the centroids were produced with (used in
            names and messages only).
        save_dir: Directory for the JSON file.

    Raises:
        ValueError: If ``centroids`` is empty, if there are fewer centroids
            than the smallest trial ``k``, or if no ``k`` could be selected.
    """
    print(f"Clustering {len(centroids)} centroids for {label_prefix} at threshold t={t:.2f}")

    if not centroids:
        raise ValueError(f"No centroids to cluster for {label_prefix} at t={t:.2f}")

    mat = normalize_batch(np.vstack([c["Centroid"] for c in centroids]))

    # k-means rejects n_clusters > n_samples, so drop trial values that cannot fit.
    trial_ks = [k_try for k_try in range(10, 301, 10) if k_try <= len(mat)]
    if not trial_ks:
        raise ValueError(
            f"Only {len(mat)} centroids for {label_prefix} at t={t:.2f}; "
            "at least 10 are needed to select a number of meta-topics"
        )

    distortions = [
        MiniBatchKMeans(n_clusters=k_try, random_state=42, batch_size=128, n_init="auto")
        .fit(mat)
        .inertia_
        for k_try in trial_ks
    ]

    # Trailing moving average (window 3) to damp mini-batch noise before the knee search.
    window = 3
    smooth_distortions = [
        np.mean(distortions[max(0, i - window + 1):i + 1]) for i in range(len(distortions))
    ]

    elbow = None
    if len(trial_ks) >= 3:
        # With fewer than three points there is no curve to locate a knee on.
        elbow = KneeLocator(
            trial_ks, smooth_distortions, curve="convex", direction="decreasing"
        ).elbow

    if elbow:
        k = int(elbow)
        inertia = distortions[trial_ks.index(k)]
        print(f"Elbow found (smoothed): k={k}, inertia={inertia:.2f} for {label_prefix} at t={t:.2f}")
    else:
        print(f"No elbow found after smoothing. Trying silhouette fallback "
              f"for {label_prefix} at t={t:.2f}")

        # Silhouette is quadratic in the sample count, so cap the sample size.
        max_centroids = 10000
        if len(mat) > max_centroids:
            print(f"Subsampling {max_centroids} of {len(mat)} centroids "
                  "for silhouette selection")
            # NOTE: draws from the global NumPy RNG; reproducible only if the
            # caller has seeded it.
            sample_idx = np.random.choice(len(mat), max_centroids, replace=False)
            mat_sample = mat[sample_idx]
        else:
            mat_sample = mat

        best_k, best_score = None, -1
        for k_try in trial_ks:
            if k_try >= len(mat_sample):
                continue
            candidate = MiniBatchKMeans(
                n_clusters=k_try,
                random_state=42,
                batch_size=128,
                n_init="auto"
            ).fit(mat_sample)
            score = silhouette_score(mat_sample, candidate.labels_)
            if score > best_score:
                best_k, best_score = k_try, score

        if best_k is None:
            raise ValueError(
                f"Could not select a number of meta-topics for {label_prefix} at t={t:.2f}"
            )
        k = best_k
        print(f"Using silhouette fallback: k={k} "
              f"(silhouette = {best_score:.4f}) for {label_prefix} at t={t:.2f}")

    method_used = "elbow" if elbow else "silhouette"
    print(f"Final k={k} used via {method_used} for {label_prefix} at t={t:.2f}")

    kmeans = MiniBatchKMeans(n_clusters=k, random_state=42, batch_size=128, n_init="auto", verbose=0).fit(mat)
    labels = kmeans.labels_

    # Count, per row and speaker, how many centroids landed in each meta-topic.
    topic_counts = defaultdict(lambda: {"CHI": [0] * k, "MOT": [0] * k})
    for rec, topic_id in zip(centroids, labels):
        topic_counts[rec["RowID"]][rec["Speaker"]][topic_id] += 1

    for topic_id in range(k):
        for spk in ["CHI", "MOT"]:
            colname = f"{label_prefix}_{spk}_MetaTopic_{topic_id:02d}_t{t:.2f}"
            # Membership test first so the defaultdict is not grown by lookups.
            df[colname] = [
                topic_counts[i][spk][topic_id] if i in topic_counts else 0
                for i in df.index
            ]

    all_items = list(emb_cache.keys())
    all_embs = normalize_batch(np.vstack([emb_cache[i] for i in all_items]))

    # Speaker attribution for exemplars. When both speakers produced an item,
    # CHI wins because it is written last.
    item_speaker_map = {}
    if label_prefix == "Word":
        per_row_items = speaker_tokens
    elif label_prefix == "Utt":
        per_row_items = speaker_utterances
    else:
        per_row_items = None
    if per_row_items is not None:
        for row_label in df.index:
            for spk in ("MOT", "CHI"):
                for item in per_row_items[row_label][spk]:
                    item_speaker_map[item] = spk

    item_key = "word" if label_prefix == "Word" else "utt"
    percentiles = [0, 5, 10, 15, 20, 25]
    output = {}
    for topic_id in range(k):
        dists = np.linalg.norm(all_embs - kmeans.cluster_centers_[topic_id], axis=1)
        order = np.argsort(dists)
        dists_sorted = dists[order]
        cutoffs = np.percentile(dists_sorted, percentiles)

        selected, used_positions = [], set()
        for pct_val, pct_label in zip(cutoffs, percentiles):
            # Item whose distance is closest to this percentile cut-off.
            pos = int(np.argmin(np.abs(dists_sorted - pct_val)))
            if pos in used_positions:
                continue
            used_positions.add(pos)
            item = all_items[order[pos]]
            selected.append({
                item_key: item,
                "speaker": item_speaker_map.get(item, "UNKNOWN"),
                "distance": round(float(dists_sorted[pos]), 4),
                "percentile": pct_label
            })

        output[f"{label_prefix}_MetaTopic_{topic_id:02d}_t{t:.2f}"] = selected

    out_path = os.path.join(save_dir, f"{label_prefix}_MetaTopic_TopItems_t{t:.2f}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, ensure_ascii=False)


# Build meta-topics for the word and utterance centroids of every threshold
# below 1.00, then save the augmented table to Drive.
for t in thresholds:
    if t >= 1.00:
        continue

    word_centroids = [rec for rec in centroid_records_word if rec["Threshold"] == t]
    assign_meta_topics_and_save(word_centroids, "Word", word_emb_cache, df, t, root_dir)

    utt_centroids = [rec for rec in centroid_records_utt if rec["Threshold"] == t]
    assign_meta_topics_and_save(utt_centroids, "Utt", utt_emb_cache, df, t, root_dir)


df.to_excel(output_path, index=False)
print("Meta-topic distributions and top items saved.")


ANOVA

In [None]:
import pandas as pd
import statsmodels.formula.api as smf
from statsmodels.stats.anova import anova_lm

# Session table; the ANOVA outcome columns are looked up in it by name.
df = pd.read_excel("/content/childes_data_filtered_2.xlsx")

# Column-name suffixes of the input representations to analyse, in table order.
suffixes = [
    "Raw", "Content", "ContentLemma", "ContentLemmaPOS", "Lex", "Syn",
    "EmbWord_t0.25", "EmbWord_t0.50", "EmbWord_t0.75"
]

# Column-name prefix of each child outcome -> label shown in the table.
# The full column name is "<prefix>_<suffix>".
outcome_labels = {
    "CHI_TTR": "TTR",
    "CHI_CV": "CV",
    "CHI_Gini": "Gini",
    "CHI_Top3Dominance": "Top-3"
}

# Suffix -> human-readable name used for the group header rows of the table.
suffix_labels = {
    "Raw": "All Words",
    "Content": "Content Words",
    "ContentLemma": "Lemmatized Content Words",
    "ContentLemmaPOS": "Lemmatized Content Words with POS",
    "Lex": "Lexical Category",
    "Syn": "Synonyms",
    "EmbWord_t0.25": "Word Embeddings (T=0.25)",
    "EmbWord_t0.50": "Word Embeddings (T=0.50)",
    "EmbWord_t0.75": "Word Embeddings (T=0.75)"
}

def _one_way_anova(data, outcome, factor):
    """Fit ``outcome ~ C(factor)`` by OLS and return (F, formatted p) strings."""
    fitted = smf.ols(f"Q('{outcome}') ~ C({factor})", data=data).fit()
    anova_table = anova_lm(fitted)
    f_value = anova_table.loc[f"C({factor})", "F"]
    p_value = anova_table.loc[f"C({factor})", "PR(>F)"]
    # Very small p-values are shown as "<0.001" in LaTeX markup.
    p_text = r"\textless{}0.001" if p_value < 0.001 else f"{p_value:.3f}"
    return f"{f_value:.2f}", p_text


# One header row per input type, followed by one row per available outcome.
rows = []

for suffix in suffixes:
    rows.append({
        "Suffix": suffix_labels.get(suffix, suffix),
        "Outcome": "",
        "F_Study": "", "p_Study": "", "F_Child": "", "p_Child": ""
    })

    for prefix, outcome_name in outcome_labels.items():
        outcome = f"{prefix}_{suffix}"
        if outcome not in df.columns:
            continue

        try:
            f_study_text, p_study_fmt = _one_way_anova(df, outcome, "Study")
            f_child_text, p_child_fmt = _one_way_anova(df, outcome, "NumericID")
        except Exception as e:
            print(f"Error with {outcome}: {e}")
            continue

        rows.append({
            "Suffix": "",
            "Outcome": outcome_name,
            "F_Study": f_study_text,
            "p_Study": p_study_fmt,
            "F_Child": f_child_text,
            "p_Child": p_child_fmt
        })

# Render the collected rows as a LaTeX table. escape=False keeps the
# "\textless{}" markup in the p-value cells intact.
table_df = pd.DataFrame(rows)

latex_header = ["Input Type", "Outcome", "F (Study)", "p (Study)", "F (Child)", "p (Child)"]
latex_caption = "ANOVA results for effects of Study and Child on each outcome, grouped by input type."
latex = table_df.to_latex(
    index=False,
    escape=False,
    caption=latex_caption,
    label="tab:anova_longformat_descriptive",
    column_format="llcccc",
    header=latex_header,
)

print(latex)


Neural Network

In [None]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import GroupKFold
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import numpy as np
import matplotlib.pyplot as plt
import joblib
import random
import os

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also switches cuDNN to deterministic mode and disables its benchmark
    auto-tuner.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# Train on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

file_path = "/content/childes_data_filtered_2.xlsx"
df = pd.read_excel(file_path)

# One label encoder per categorical variable; the fitted class lists later
# determine the sizes of the model's embedding tables.
lang_encoder, study_encoder, child_encoder, gender_encoder = (
    LabelEncoder(), LabelEncoder(), LabelEncoder(), LabelEncoder()
)

categorical_specs = (
    ("Language_encoded", "Language", lang_encoder),
    ("Study_encoded", "Study", study_encoder),
    ("Child_encoded", "NumericID", child_encoder),
    ("Gender_encoded", "Gender", gender_encoder),
)
for encoded_name, source_name, fitted_encoder in categorical_specs:
    # Values are cast to str first, so missing values become their own category.
    df[encoded_name] = fitted_encoder.fit_transform(df[source_name].astype(str))


# Input representations, identified by their column-name suffix.
suffixes = [
    "Raw", "Content", "ContentLemma", "ContentLemmaPOS", "Lex", "Syn",
    "EmbWord_t0.25", "EmbWord_t0.50", "EmbWord_t0.75"
]

# Representations that are additionally used to predict the "Content" outcomes.
content_inputs = ["Lex", "Syn", "EmbWord_t0.25", "EmbWord_t0.50", "EmbWord_t0.75"]

# Matched runs first (input and output share a suffix), then the X -> Content runs.
configs = [
    {
        "input_suffix": name,
        "output_suffix": name,
        "model_name": f"best_model_ROTATED_{name}_TO_{name}.pt"
    }
    for name in suffixes
]
configs += [
    {
        "input_suffix": name,
        "output_suffix": "Content",
        "model_name": f"best_model_ROTATED_{name}_TO_Content.pt"
    }
    for name in content_inputs
]



class TTRDataset(Dataset):
    """Tensor view of a DataFrame for the multitask model.

    Holds the numeric feature columns and the target columns as float32
    tensors, and the four encoded categorical columns as int64 tensors.
    Indexing returns one sample as a dict of tensors.
    """

    # (sample key / attribute name, DataFrame column) of the categorical inputs.
    _CATEGORICAL = (
        ("language", "Language_encoded"),
        ("study", "Study_encoded"),
        ("child", "Child_encoded"),
        ("gender", "Gender_encoded"),
    )

    def __init__(self, df, feature_columns, target_cols):
        def to_tensor(columns, dtype):
            return torch.tensor(df[columns].values, dtype=dtype)

        self.features = to_tensor(feature_columns, torch.float32)
        for attr, column in self._CATEGORICAL:
            setattr(self, attr, to_tensor(column, torch.long))
        self.targets = to_tensor(target_cols, torch.float32)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        sample = {"features": self.features[idx]}
        for attr, _ in self._CATEGORICAL:
            sample[attr] = getattr(self, attr)[idx]
        sample["target"] = self.targets[idx]
        return sample

class MultitaskTTRModel(nn.Module):
    """Shared-trunk regressor with four output heads (TTR, CV, Gini, dominance).

    Language, study, child and gender are each embedded into ``embed_dim``
    dimensions, concatenated with the numeric features and passed through a
    two-layer ReLU trunk (64 then 32 units). The TTR, Gini and dominance
    heads end in a sigmoid; the CV head ends in a ReLU.
    """

    def __init__(self, n_langs, n_studies, n_children, n_genders, input_dim, embed_dim=8):
        super().__init__()
        hidden_wide, hidden_narrow = 64, 32

        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the order in which parameters are initialised.
        self.lang_embed = nn.Embedding(n_langs, embed_dim)
        self.study_embed = nn.Embedding(n_studies, embed_dim)
        self.child_embed = nn.Embedding(n_children, embed_dim)
        self.gender_embed = nn.Embedding(n_genders, embed_dim)

        trunk_layers = [
            nn.Linear(input_dim + 4 * embed_dim, hidden_wide),
            nn.ReLU(),
            nn.Linear(hidden_wide, hidden_narrow),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk_layers)

        def make_head(activation):
            return nn.Sequential(nn.Linear(hidden_narrow, 1), activation)

        self.ttr = make_head(nn.Sigmoid())
        self.cv = make_head(nn.ReLU())
        self.gini = make_head(nn.Sigmoid())
        self.dom = make_head(nn.Sigmoid())

    def forward(self, features, language, study, child, gender):
        """Return a (batch, 4) tensor ordered as TTR, CV, Gini, dominance."""
        pieces = [
            features,
            self.lang_embed(language),
            self.study_embed(study),
            self.child_embed(child),
            self.gender_embed(gender),
        ]
        hidden = self.shared(torch.cat(pieces, dim=1))
        heads = (self.ttr, self.cv, self.gini, self.dom)
        return torch.cat([head(hidden) for head in heads], dim=1)

# One record per (input, output, metric) with the pooled cross-validation scores.
latex_results = []


def _forward_batch(net, batch):
    """Run ``net`` on one DataLoader batch after moving its inputs to ``device``."""
    return net(
        batch["features"].to(device),
        batch["language"].to(device),
        batch["study"].to(device),
        batch["child"].to(device),
        batch["gender"].to(device),
    )


def _new_model(input_dim):
    """Create an untrained model sized to the fitted label encoders, on ``device``."""
    return MultitaskTTRModel(
        n_langs=len(lang_encoder.classes_),
        n_studies=len(study_encoder.classes_),
        n_children=len(child_encoder.classes_),
        n_genders=len(gender_encoder.classes_),
        input_dim=input_dim
    ).to(device)


# For every configuration: 5-fold cross-validation grouped by study, pooled
# evaluation with scatter plots, then a final model trained on all rows and
# saved under cfg["model_name"].
for cfg in configs:
    input_suffix = cfg["input_suffix"]
    output_suffix = cfg["output_suffix"]

    print(f"\n=== Cross-validating: MOT input {input_suffix} → CHI output {output_suffix} ===")

    feature_columns = [
        f"MOT_TTR_{input_suffix}", f"MOT_CV_{input_suffix}",
        f"MOT_Gini_{input_suffix}", f"MOT_Top3Dominance_{input_suffix}",
        "MOT_AvgWordLength", "Rel_MOT_Turns", "MOT_AvgTurnLength",
        "MOT_Concreteness", "Age_in_Days"
    ]

    target_cols = [
        f"CHI_TTR_{output_suffix}", f"CHI_CV_{output_suffix}",
        f"CHI_Gini_{output_suffix}", f"CHI_Top3Dominance_{output_suffix}"
    ]

    required_cols = ["Language_encoded", "Study_encoded", "Child_encoded", "Gender_encoded"] + feature_columns + target_cols
    data = df.dropna(subset=required_cols).copy()
    if data.empty:
        print(f"Skipping {input_suffix} → {output_suffix} — no data after dropna.")
        continue

    all_preds, all_targets = [], []
    gkf = GroupKFold(n_splits=5)

    for fold, (train_idx, val_idx) in enumerate(gkf.split(data, groups=data["Study_encoded"])):
        print(f"Fold {fold + 1}/5")
        train_data = data.iloc[train_idx].copy()
        val_data = data.iloc[val_idx].copy()

        # Fit the scaler on the training fold only, so validation statistics
        # do not leak into the features the model is trained on.
        fold_scaler = StandardScaler().fit(train_data[feature_columns])
        train_data[feature_columns] = fold_scaler.transform(train_data[feature_columns])
        val_data[feature_columns] = fold_scaler.transform(val_data[feature_columns])

        train_dataset = TTRDataset(train_data, feature_columns, target_cols)
        val_dataset = TTRDataset(val_data, feature_columns, target_cols)
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=64)

        model = _new_model(len(feature_columns))

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        loss_fn = nn.MSELoss()
        best_val_loss = float("inf")
        best_val_preds = None
        patience = 3
        patience_counter = 0

        for epoch in range(1, 51):
            model.train()
            for batch in train_loader:
                optimizer.zero_grad()
                preds = _forward_batch(model, batch)
                targets = batch["target"].to(device)
                loss = loss_fn(preds, targets)
                loss.backward()
                optimizer.step()

            model.eval()
            val_preds, val_targets = [], []
            with torch.no_grad():
                for batch in val_loader:
                    val_preds.append(_forward_batch(model, batch).cpu().numpy())
                    val_targets.append(batch["target"].cpu().numpy())

            val_preds = np.vstack(val_preds)
            val_targets = np.vstack(val_targets)
            val_loss = loss_fn(torch.tensor(val_preds), torch.tensor(val_targets)).item()

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Remember the best epoch's predictions; without this the fold
                # would be scored on the epoch that triggered early stopping.
                best_val_preds = val_preds
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    break

        # best_val_preds stays None only if no epoch improved (e.g. NaN loss).
        all_preds.append(val_preds if best_val_preds is None else best_val_preds)
        all_targets.append(val_targets)

    all_preds = np.vstack(all_preds)
    all_targets = np.vstack(all_targets)

    metric_names = ["TTR", "CV", "Gini", "Top3"]
    print(f"\n Evaluation for: MOT_{input_suffix} → CHI_{output_suffix}")
    for i, name in enumerate(metric_names):
        y_true = all_targets[:, i]
        y_pred = all_preds[:, i]
        rmse = np.sqrt(mean_squared_error(y_true, y_pred))
        mae = mean_absolute_error(y_true, y_pred)
        r2 = r2_score(y_true, y_pred)
        print(f"{name:>5} — RMSE: {rmse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")

        latex_results.append({
            "Input": input_suffix,
            "Output": output_suffix,
            "Metric": name,
            "RMSE": rmse,
            "MAE": mae,
            "R2": r2
        })

        plt.figure()
        plt.scatter(y_true, y_pred, alpha=0.5, label="Predictions")
        plt.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--', label="Ideal")
        plt.xlabel(f"Actual {name}")
        plt.ylabel(f"Predicted {name}")
        plt.title(f"{input_suffix} → {output_suffix} | {name}")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    print(f"\n🟢 Final training on full data: {input_suffix} → {output_suffix}")
    # The final model uses every row, so its scaler is fitted on the full data.
    scaler = StandardScaler()
    data[feature_columns] = scaler.fit_transform(data[feature_columns])
    full_dataset = TTRDataset(data, feature_columns, target_cols)
    full_loader = DataLoader(full_dataset, batch_size=64, shuffle=True)

    model = _new_model(len(feature_columns))

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    patience = 3
    patience_counter = 0
    best_loss = float("inf")

    for epoch in range(1, 51):
        model.train()
        total_loss = 0
        for batch in full_loader:
            optimizer.zero_grad()
            preds = _forward_batch(model, batch)
            targets = batch["target"].to(device)
            loss = loss_fn(preds, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(full_loader)
        print(f"Epoch {epoch:02} | Full Train Loss: {avg_loss:.4f}")

        # No held-out data remains here, so stopping is based on training loss;
        # the checkpoint on disk is always the lowest-loss epoch so far.
        if avg_loss < best_loss:
            best_loss = avg_loss
            patience_counter = 0
            torch.save(model.state_dict(), cfg["model_name"])
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("⏹️ Early stopping.")
                break



# Short codes used for the input/output representations in the results table.
suffix_code = {
    "Raw": "A",
    "Content": "C",
    "ContentLemma": "CL",
    "ContentLemmaPOS": "CLP",
    "Lex": "LEX",
    "Syn": "SYN",
    "EmbWord_t0.25": "WEMB (0.25)",
    "EmbWord_t0.50": "WEMB (0.50)",
    "EmbWord_t0.75": "WEMB (0.75)",
}

results_df = pd.DataFrame(latex_results)

# Coerce the score columns to numbers; unparsable entries become NaN.
for col in ("RMSE", "MAE", "R2"):
    results_df[col] = pd.to_numeric(results_df[col], errors="coerce")

for axis_col in ("Input", "Output"):
    results_df[axis_col] = results_df[axis_col].map(suffix_code)

# Wide layout: one row per (Input, Output), one column per statistic/metric pair.
table = results_df.pivot_table(
    index=["Input", "Output"],
    columns="Metric",
    values=["RMSE", "MAE", "R2"],
    aggfunc="mean"
)

# Flatten the (statistic, metric) column MultiIndex into "RMSE_TTR"-style names.
table.columns = ["_".join(pair) for pair in table.columns]
table = table.reset_index()

# Column order of the final table: metric by metric, three statistics each.
order = [
    f"{stat}_{metric}"
    for metric in ["Gini", "CV", "TTR", "Top3"]
    for stat in ["RMSE", "MAE", "R2"]
]
table = table[["Input", "Output"] + order]




# Build the final LaTeX tabular with a two-level header: one group per metric,
# each spanning its RMSE / MAE / R^2 columns.
col_format = "ll" + "ccc" * 4
column_labels = ["Gini", "CV", "TTR", "Top-3 Proportion"]
sub_labels = ["RMSE", "MAE", "R$^2$"] * 4

header1 = (
    " & ".join(["Input", "Output"] + [f"\\multicolumn{{3}}{{c}}{{{label}}}" for label in column_labels])
    + " \\\\"
)
header2 = " & ".join(["", ""] + sub_labels) + " \\\\"
midrule = (
    "\\cmidrule(lr){3-5} "
    "\\cmidrule(lr){6-8} "
    "\\cmidrule(lr){9-11} "
    "\\cmidrule(lr){12-14}"
)

latex_body = table.to_latex(
    index=False,
    float_format="%.3f",
    column_format=col_format,
    header=False,
    escape=False
)

# to_latex already emits its own \begin{tabular}, rules and \end{tabular}, and
# the position of those lines varies between pandas versions. Keep only the
# data rows (the lines ending in "\\") and build the frame explicitly, so the
# headers always sit above the body and the environment is closed exactly once.
body_rows = [line for line in latex_body.splitlines() if line.rstrip().endswith("\\\\")]

lines = [
    "\\begin{tabular}{" + col_format + "}",
    "\\toprule",
    header1,
    midrule,
    header2,
    "\\midrule",
    *body_rows,
    "\\bottomrule",
    "\\end{tabular}",
]

latex_final = "\n".join(lines)
print(latex_final)


Simulations

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler, LabelEncoder
import torch.nn as nn
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import FormatStrFormatter, MaxNLocator
import gc
from tqdm.auto import tqdm

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
df = pd.read_excel("/content/childes_data_filtered_2.xlsx")
# Age is used as a numeric model feature, so force a float dtype.
df["Age_in_Days"] = df["Age_in_Days"].astype(float)

# Column-name suffix of each input representation -> display name.
# pretty_feature_label also uses the keys to strip suffixes from feature names.
suffix_labels = {
    "Raw": "All Words",
    "Content": "Content Words",
    "ContentLemma": "Lemmatized Content Words",
    "ContentLemmaPOS": "Lemmatized Content Words with POS",
    "Syn": "First Synset Identifier",
    "Lex": "Lexical Category",
    "EmbWord_t0.25": "Word Embeddings  T = 0.25",
    "EmbWord_t0.50": "Word Embeddings T = 0.5",
    "EmbWord_t0.75": "Word Embeddings T = 0.75",
}

# Base variable name (speaker prefix and suffix removed) -> display name.
var_labels = {
    "TTR": "TTR",
    "CV": "CV",
    "Gini": "Gini",
    "Top3Dominance": "Top-3 Proportion",
    "AvgWordLength": "Mean Word Length",
    "AvgTurnLength": "Mean Turn Length",
    "Rel_MOT_Turns": "Conversation Share",
    "Concreteness": "Concreteness",
}

def pretty_feature_label(var):
    """Return the display label for a feature column name.

    For names starting with "MOT_", the prefix and, if present, a trailing
    "_<suffix>" (any key of ``suffix_labels``) are stripped before the
    lookup in ``var_labels``. Names without a known label are returned in
    their stripped form (or unchanged when they lack the "MOT_" prefix).
    """
    if not var.startswith("MOT_"):
        return var_labels.get(var, var)

    rest = var[4:]
    # First suffix (in suffix_labels order) that terminates the name, if any.
    base = next(
        (rest[:-(len(suffix) + 1)] for suffix in suffix_labels if rest.endswith("_" + suffix)),
        rest,
    )
    return var_labels.get(base, base)

# Fit one label encoder per categorical column and add "<column>_encoded"
# with the integer codes. Values are cast to str before encoding.
encoders = {}
for col in ("Language", "Study", "NumericID", "Gender"):
    encoder = encoders.setdefault(col, LabelEncoder())
    df[f"{col}_encoded"] = encoder.fit_transform(df[col].astype(str))



class MultitaskTTRModel(nn.Module):
    """Shared-trunk network with four single-output heads.

    Each of the four categorical inputs (language, study, child, gender) has
    its own embedding table. The embeddings are concatenated with the
    numeric features, passed through a two-layer ReLU trunk (64 -> 32), and
    fed to four heads whose outputs are concatenated column-wise in the
    order: ttr (sigmoid), cv (ReLU), gini (sigmoid), dom (sigmoid).
    """

    def __init__(self, n_langs, n_studies, n_children, n_genders, input_dim, embed_dim=8):
        super().__init__()

        # Embedding tables, one per categorical covariate.
        self.lang_embed = nn.Embedding(num_embeddings=n_langs, embedding_dim=embed_dim)
        self.study_embed = nn.Embedding(num_embeddings=n_studies, embedding_dim=embed_dim)
        self.child_embed = nn.Embedding(num_embeddings=n_children, embedding_dim=embed_dim)
        self.gender_embed = nn.Embedding(num_embeddings=n_genders, embedding_dim=embed_dim)

        # Trunk input = numeric features followed by the four embeddings.
        trunk_width = input_dim + 4 * embed_dim
        self.shared = nn.Sequential(
            nn.Linear(trunk_width, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        )

        # Heads are created in output-column order.
        self.ttr = self._head(nn.Sigmoid())
        self.cv = self._head(nn.ReLU())
        self.gini = self._head(nn.Sigmoid())
        self.dom = self._head(nn.Sigmoid())

    @staticmethod
    def _head(activation):
        """Build a 32 -> 1 linear layer followed by ``activation``."""
        return nn.Sequential(nn.Linear(32, 1), activation)

    def forward(self, features, language, study, child, gender):
        """Return the four head outputs concatenated along dim 1."""
        embedded = (
            self.lang_embed(language),
            self.study_embed(study),
            self.child_embed(child),
            self.gender_embed(gender),
        )
        hidden = self.shared(torch.cat([features, *embedded], dim=1))
        heads = (self.ttr, self.cv, self.gini, self.dom)
        return torch.cat([head(hidden) for head in heads], dim=1)

# Child-side measures, in the column order of the model's output
# (the forward pass concatenates ttr, cv, gini, dom).
target_labels = ["TTR", "CV", "Gini", "Top3Dominance"]

output_types = ["Content", "Matched"]
suffixes = [
    "EmbWord_t0.75", "EmbWord_t0.50", "EmbWord_t0.25",
    "Syn", "Lex",
    "ContentLemmaPOS", "ContentLemma", "Content",
    "Raw"
]

cat_cols = ["Language_encoded", "Study_encoded", "NumericID_encoded", "Gender_encoded"]
grid_points = 50                 # resolution of the age axis and of each swept-variable axis
year_ticks = np.arange(1, 6, 1)  # x ticks and guide lines at whole years

# For every (mother representation, child representation) pair: load the
# trained model, sweep each mother-side feature against age while leaving all
# other inputs at their observed values, average the predictions over the
# rows of the data, and draw one contour panel per (feature, target).
for suffix in suffixes:
    for output_type in output_types:
        # "Matched" predicts the child measures for the same representation
        # as the mother's; "Content" always predicts the content-word ones.
        # NOTE(review): for suffix == "Content" both output types resolve to
        # the same pair, so that figure is produced twice.
        output_suffix = suffix if output_type == "Matched" else "Content"
        model_path = f"best_model_ROTATED_{suffix}_TO_{output_suffix}.pt"
        print(f"\n Plotting: MOT_{suffix} → CHI_{output_suffix}")

        feature_cols = [
            f"MOT_TTR_{suffix}", f"MOT_CV_{suffix}", f"MOT_Gini_{suffix}", f"MOT_Top3Dominance_{suffix}",
            "MOT_AvgWordLength", "Rel_MOT_Turns", "MOT_AvgTurnLength",
            "MOT_Concreteness", "Age_in_Days"
        ]
        target_cols = [f"CHI_{metric}_{output_suffix}" for metric in target_labels]

        df_clean = df.dropna(subset=feature_cols + target_cols).copy()
        if df_clean.empty:
            print(f" Skipping {suffix} → {output_suffix} — no data.")
            continue

        model = MultitaskTTRModel(
            n_langs=len(encoders["Language"].classes_),
            n_studies=len(encoders["Study"].classes_),
            n_children=len(encoders["NumericID"].classes_),
            n_genders=len(encoders["Gender"].classes_),
            input_dim=len(feature_cols)
        ).to(device)

        try:
            model.load_state_dict(torch.load(model_path, map_location=device))
        except FileNotFoundError:
            print(f"❌ Missing model: {model_path}")
            continue

        model.eval()

        # The scaler and the categorical index tensors do not depend on the
        # swept variable, so they are built once per model.
        # NOTE(review): the scaler is refit on df_clean here; presumably this
        # reproduces the standardization used during training — confirm.
        scaler = StandardScaler()
        scaler.fit(df_clean[feature_cols])
        cat_tensors = [torch.tensor(df_clean[c].values, device=device) for c in cat_cols]

        mot_inputs = [f for f in feature_cols if f != "Age_in_Days"]
        age_vals = np.linspace(df_clean["Age_in_Days"].min(), df_clean["Age_in_Days"].max(), grid_points)

        all_preds = []   # one (variable, age, target) grid per swept feature
        var_grids = []   # the swept values for each feature, reused when plotting
        for mot_var in mot_inputs:
            var_vals = np.linspace(df_clean[mot_var].min(), df_clean[mot_var].max(), grid_points)
            var_grids.append(var_vals)

            # Working copy of the inputs; only the age column and the swept
            # column are overwritten, every other feature keeps its observed value.
            grid_inputs = df_clean[feature_cols].copy()

            # Rows index the swept variable (y axis) and columns index age
            # (x axis): contourf(x, y, Z) expects Z with shape (len(y), len(x)).
            # Writing by explicit index keeps that layout independent of the
            # loop nesting order.
            avg_preds = np.empty((len(var_vals), len(age_vals), len(target_labels)))
            for age_idx, age in enumerate(age_vals):
                grid_inputs["Age_in_Days"] = age
                for var_idx, mot_val in enumerate(var_vals):
                    grid_inputs[mot_var] = mot_val
                    X = scaler.transform(grid_inputs)
                    X_tensor = torch.tensor(X, dtype=torch.float32, device=device)

                    with torch.no_grad():
                        pred = model(X_tensor, *cat_tensors).cpu().numpy()

                    # Average over all rows of the data.
                    avg_preds[var_idx, age_idx] = pred.mean(axis=0)

            all_preds.append(avg_preds)

        n_rows, n_cols = len(mot_inputs), len(target_labels)
        # squeeze=False keeps `axes` two-dimensional for any grid size.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, n_rows * 3), squeeze=False)

        # Colour limits shared by all panels of the same target column.
        Zmins = [min(p[:, :, j].min() for p in all_preds) for j in range(n_cols)]
        Zmaxs = [max(p[:, :, j].max() for p in all_preds) for j in range(n_cols)]

        for i, (mot_var, var_vals, avg_preds) in enumerate(zip(mot_inputs, var_grids, all_preds)):
            for j, label in enumerate(target_labels):
                ax = axes[i, j]
                Z = avg_preds[:, :, j]
                contour = ax.contourf(age_vals / 365.25, var_vals, Z, levels=50, cmap="viridis",
                                      vmin=Zmins[j], vmax=Zmaxs[j])

                ax.set_ylim(var_vals.min(), var_vals.max())
                # Each axis gets its own locator/formatter: ticker objects
                # keep a reference to a single axis and must not be shared.
                ax.yaxis.set_major_locator(MaxNLocator(nbins=3))
                ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

                if i == 0:
                    ax.set_title(var_labels.get(label, label), fontsize=14)
                if j == 0:
                    ax.set_ylabel(pretty_feature_label(mot_var), fontsize=14)
                else:
                    # Only the first column shows y tick labels; hide them
                    # here without replacing the formatter.
                    ax.tick_params(labelleft=False)

                # Whole-year ticks everywhere, labelled on the bottom row only.
                ax.set_xticks(year_ticks)
                if i == n_rows - 1:
                    ax.set_xticklabels([str(int(t)) for t in year_ticks])
                else:
                    ax.set_xticklabels([])

                for xtick in year_ticks:
                    ax.axvline(x=xtick, color="white", linestyle="--", linewidth=0.5, alpha=0.3)

                ax.set_xlabel("")
                ax.tick_params(labelsize=12)

                # Per-panel colourbar with three ticks spanning its range.
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="5%", pad=0.1)
                cbar = fig.colorbar(contour, cax=cax)
                cbar.ax.tick_params(labelsize=12)
                cbar_ticks = np.linspace(cbar.vmin, cbar.vmax, 3)
                cbar.set_ticks(cbar_ticks)
                cbar.ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

        plt.tight_layout(rect=[0, 0.03, 0.95, 0.94])
        plt.subplots_adjust(wspace=0.3, hspace=0.3)

        # Figure-level captions: child representation on top, mother
        # representation on the left, shared x-axis label at the bottom.
        fig.text(0.5, 0.94, f"Child ({suffix_labels.get(output_suffix, output_suffix)})",
                 ha="center", va="bottom", fontsize=20)
        fig.text(-0.04, 0.5, f"Mother ({suffix_labels.get(suffix, suffix)})",
                 ha="center", va="center", rotation="vertical", fontsize=20)
        fig.text(0.5, 0.02, "Age (years)", ha="center", va="center", fontsize=16)

        plt.show()
        plt.close(fig)

        # Release the grids and the model before the next one is loaded.
        del all_preds
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler, LabelEncoder
import torch.nn as nn
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import FormatStrFormatter, MaxNLocator
import gc
from tqdm.auto import tqdm

# Run inference on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the data table and make sure the age column is floating point,
# since it is later swept over a continuous grid.
df = pd.read_excel("/content/childes_data_filtered_2.xlsx")
df = df.astype({"Age_in_Days": float})



# Column of the model output that holds each child-side measure.
pred_indices = {"TTR": 0, "CV": 1, "Gini": 2, "Top3Dominance": 3}


# Display names for the column-name suffixes; used for the figure-level
# "Mother (...)" / "Child (...)" captions and to strip suffixes from
# feature names. Insertion order matters: suffixes are tried in this order.
suffix_labels = {
    "Raw": "All Words",
    "Content": "Content Words",
    "ContentLemma": "Lemmatized Content Words",
    "ContentLemmaPOS": "Lemmatized Content Words with POS",
    "Syn": "Synonyms",
    "Lex": "Lexical Category",
    "EmbWord_t0.25": "Word Embeddings  T = 0.25",
    "EmbWord_t0.50": "Word Embeddings T = 0.5",
    "EmbWord_t0.75": "Word Embeddings T = 0.75",
}

# Display names for the individual measures (panel titles and row labels).
var_labels = {
    "TTR": "TTR",
    "CV": "CV",
    "Gini": "Gini",
    "Top3Dominance": "Top-3 Proportion",
    "AvgWordLength": "Mean Word Length",
    "AvgTurnLength": "Mean Turn Length",
    "Rel_MOT_Turns": "Conversation Share",
    "Concreteness": "Concreteness",
}


def pretty_feature_label(var):
    """Return the display label for a feature column name.

    Names without the ``MOT_`` prefix are looked up in ``var_labels``
    directly. For ``MOT_``-prefixed names the prefix is dropped, the first
    matching ``_<suffix>`` from ``suffix_labels`` (if any) is removed from
    the end, and the remainder is looked up in ``var_labels``. Whenever the
    lookup misses, the looked-up key itself is returned.
    """
    prefix = "MOT_"
    if not var.startswith(prefix):
        return var_labels.get(var, var)

    stem = var[len(prefix):]
    for known_suffix in suffix_labels:
        tail = "_" + known_suffix
        if stem.endswith(tail):
            # Only the first matching suffix is stripped.
            stem = stem[:-len(tail)]
            break
    return var_labels.get(stem, stem)

# One LabelEncoder per categorical column. Values are cast to str before
# encoding and the integer codes are stored in a parallel "<col>_encoded"
# column; the fitted encoders are kept so their class counts can size the
# model's embedding tables.
encoders = {col: LabelEncoder() for col in ("Language", "Study", "NumericID", "Gender")}
for col, encoder in encoders.items():
    df[f"{col}_encoded"] = encoder.fit_transform(df[col].astype(str))



class MultitaskTTRModel(nn.Module):
    """Shared-trunk network with four single-output heads.

    Each of the four categorical inputs (language, study, child, gender) has
    its own embedding table. The embeddings are concatenated with the
    numeric features, passed through a two-layer ReLU trunk (64 -> 32), and
    fed to four heads whose outputs are concatenated column-wise in the
    order: ttr (sigmoid), cv (ReLU), gini (sigmoid), dom (sigmoid).
    """

    def __init__(self, n_langs, n_studies, n_children, n_genders, input_dim, embed_dim=8):
        super().__init__()

        # Embedding tables, one per categorical covariate.
        self.lang_embed = nn.Embedding(num_embeddings=n_langs, embedding_dim=embed_dim)
        self.study_embed = nn.Embedding(num_embeddings=n_studies, embedding_dim=embed_dim)
        self.child_embed = nn.Embedding(num_embeddings=n_children, embedding_dim=embed_dim)
        self.gender_embed = nn.Embedding(num_embeddings=n_genders, embedding_dim=embed_dim)

        # Trunk input = numeric features followed by the four embeddings.
        trunk_width = input_dim + 4 * embed_dim
        self.shared = nn.Sequential(
            nn.Linear(trunk_width, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        )

        # Heads are created in output-column order.
        self.ttr = self._head(nn.Sigmoid())
        self.cv = self._head(nn.ReLU())
        self.gini = self._head(nn.Sigmoid())
        self.dom = self._head(nn.Sigmoid())

    @staticmethod
    def _head(activation):
        """Build a 32 -> 1 linear layer followed by ``activation``."""
        return nn.Sequential(nn.Linear(32, 1), activation)

    def forward(self, features, language, study, child, gender):
        """Return the four head outputs concatenated along dim 1."""
        embedded = (
            self.lang_embed(language),
            self.study_embed(study),
            self.child_embed(child),
            self.gender_embed(gender),
        )
        hidden = self.shared(torch.cat([features, *embedded], dim=1))
        heads = (self.ttr, self.cv, self.gini, self.dom)
        return torch.cat([head(hidden) for head in heads], dim=1)



output_types = ["Content", "Matched"]
suffixes = [
    "EmbWord_t0.75", "EmbWord_t0.50", "EmbWord_t0.25",
    "Syn", "Lex",
    "ContentLemmaPOS", "ContentLemma", "Content",
    "Raw"
]

cat_cols = ["Language_encoded", "Study_encoded", "NumericID_encoded", "Gender_encoded"]
row_order = ["TTR", "CV", "Gini", "Top3Dominance"]  # mother-side measures swept, top to bottom
grid_points = 50                 # resolution of the age axis and of each swept-variable axis
year_ticks = np.arange(1, 6, 1)  # x ticks and guide lines at whole years

# Reduced version of the simulation figure: only the four mother-side
# distribution measures are swept against age, and only two child-side
# targets are drawn per figure.
for suffix in suffixes:

    # Targets shown as columns: TTR plus CV for "Lex", TTR plus Gini otherwise.
    if suffix == "Lex":
        target_labels = ["TTR", "CV"]
    else:
        target_labels = ["TTR", "Gini"]

    for output_type in output_types:
        # "Matched" predicts the child measures for the same representation
        # as the mother's; "Content" always predicts the content-word ones.
        # NOTE(review): for suffix == "Content" both output types resolve to
        # the same pair, so that figure is produced twice.
        output_suffix = suffix if output_type == "Matched" else "Content"
        model_path = f"best_model_ROTATED_{suffix}_TO_{output_suffix}.pt"
        print(f"\n Plotting: MOT_{suffix} → CHI_{output_suffix}")

        feature_cols = [
            f"MOT_TTR_{suffix}", f"MOT_CV_{suffix}", f"MOT_Gini_{suffix}", f"MOT_Top3Dominance_{suffix}",
            "MOT_AvgWordLength", "Rel_MOT_Turns", "MOT_AvgTurnLength",
            "MOT_Concreteness", "Age_in_Days"
        ]
        # Rows are filtered on all four targets even though only two are plotted.
        target_cols = [f"CHI_{metric}_{output_suffix}" for metric in pred_indices]

        df_clean = df.dropna(subset=feature_cols + target_cols).copy()
        if df_clean.empty:
            print(f" Skipping {suffix} → {output_suffix} — no data.")
            continue

        model = MultitaskTTRModel(
            n_langs=len(encoders["Language"].classes_),
            n_studies=len(encoders["Study"].classes_),
            n_children=len(encoders["NumericID"].classes_),
            n_genders=len(encoders["Gender"].classes_),
            input_dim=len(feature_cols)
        ).to(device)

        try:
            model.load_state_dict(torch.load(model_path, map_location=device))
        except FileNotFoundError:
            print(f"❌ Missing model: {model_path}")
            continue

        model.eval()

        # The scaler and the categorical index tensors do not depend on the
        # swept variable, so they are built once per model.
        # NOTE(review): the scaler is refit on df_clean here; presumably this
        # reproduces the standardization used during training — confirm.
        scaler = StandardScaler()
        scaler.fit(df_clean[feature_cols])
        cat_tensors = [torch.tensor(df_clean[c].values, device=device) for c in cat_cols]

        # Swept inputs: the four suffix-specific measures, in row_order.
        mot_inputs = [f"MOT_{metric}_{suffix}" for metric in row_order]

        age_vals = np.linspace(df_clean["Age_in_Days"].min(), df_clean["Age_in_Days"].max(), grid_points)

        all_preds = []   # one (variable, age, model output) grid per swept feature
        var_grids = []   # the swept values for each feature, reused when plotting
        for mot_var in mot_inputs:
            var_vals = np.linspace(df_clean[mot_var].min(), df_clean[mot_var].max(), grid_points)
            var_grids.append(var_vals)

            # Working copy of the inputs; only the age column and the swept
            # column are overwritten, every other feature keeps its observed value.
            grid_inputs = df_clean[feature_cols].copy()

            # Rows index the swept variable (y axis) and columns index age
            # (x axis): contourf(x, y, Z) expects Z with shape (len(y), len(x)).
            # Writing by explicit index keeps that layout independent of the
            # loop nesting order.
            avg_preds = np.empty((len(var_vals), len(age_vals), len(pred_indices)))
            for age_idx, age in enumerate(age_vals):
                grid_inputs["Age_in_Days"] = age
                for var_idx, mot_val in enumerate(var_vals):
                    grid_inputs[mot_var] = mot_val
                    X = scaler.transform(grid_inputs)
                    X_tensor = torch.tensor(X, dtype=torch.float32, device=device)

                    with torch.no_grad():
                        pred = model(X_tensor, *cat_tensors).cpu().numpy()

                    # Average over all rows of the data.
                    avg_preds[var_idx, age_idx] = pred.mean(axis=0)

            all_preds.append(avg_preds)

        n_rows = len(mot_inputs)
        # squeeze=False keeps `axes` two-dimensional for any grid size.
        fig, axes = plt.subplots(n_rows, len(target_labels), figsize=(16, 24), squeeze=False)

        # Colour limits shared by all panels of the same target column.
        Zmins = {}
        Zmaxs = {}
        for label in target_labels:
            idx = pred_indices[label]
            Zmins[label] = min(p[:, :, idx].min() for p in all_preds)
            Zmaxs[label] = max(p[:, :, idx].max() for p in all_preds)

        for i, (mot_var, var_vals, avg_preds) in enumerate(zip(mot_inputs, var_grids, all_preds)):
            for j, label in enumerate(target_labels):
                ax = axes[i, j]
                Z = avg_preds[:, :, pred_indices[label]]

                contour = ax.contourf(age_vals / 365.25, var_vals, Z, levels=50, cmap="viridis",
                                      vmin=Zmins[label], vmax=Zmaxs[label])

                ax.set_ylim(var_vals.min(), var_vals.max())
                # Each axis gets its own locator/formatter: ticker objects
                # keep a reference to a single axis and must not be shared.
                ax.yaxis.set_major_locator(MaxNLocator(nbins=3))
                ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

                if i == 0:
                    ax.set_title(var_labels.get(label, label), fontsize=22)
                if j == 0:
                    ax.set_ylabel(pretty_feature_label(mot_var), fontsize=22)
                else:
                    # Only the first column shows y tick labels; hide them
                    # here without replacing the formatter.
                    ax.tick_params(labelleft=False)

                # Whole-year ticks everywhere, labelled on the bottom row only.
                ax.set_xticks(year_ticks)
                if i == n_rows - 1:
                    ax.set_xticklabels([str(int(t)) for t in year_ticks])
                else:
                    ax.set_xticklabels([])

                for xtick in year_ticks:
                    ax.axvline(x=xtick, color="white", linestyle="--", linewidth=1.5, alpha=0.3)

                ax.set_xlabel("")
                ax.tick_params(labelsize=18, length=8)

                # Per-panel colourbar with three ticks spanning its range.
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="5%", pad=0.1)
                cbar = fig.colorbar(contour, cax=cax)
                cbar.ax.tick_params(labelsize=18)
                cbar_ticks = np.linspace(cbar.vmin, cbar.vmax, 3)
                cbar.set_ticks(cbar_ticks)
                cbar.ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

        plt.tight_layout(rect=[0, 0.03, 0.95, 0.94])
        plt.subplots_adjust(wspace=0.3, hspace=0.3)

        # Figure-level captions: child representation on top, mother
        # representation on the left, shared x-axis label at the bottom.
        fig.text(0.5, 0.94, f"Child ({suffix_labels.get(output_suffix, output_suffix)})",
                 ha="center", va="bottom", fontsize=24)
        fig.text(-0.04, 0.5, f"Mother ({suffix_labels.get(suffix, suffix)})",
                 ha="center", va="center", rotation="vertical", fontsize=24)
        fig.text(0.5, 0.02, "Age (years)", ha="center", va="center", fontsize=20)

        plt.show()
        plt.close(fig)

        # Release the grids and the model before the next one is loaded.
        del all_preds
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
