In [73]:
#!pip install emoji catboost

In [74]:

from __future__ import annotations
import typing
import json
import pathlib
import os
import hashlib
import urllib.parse

import numpy as np
import pandas as pd

import torch

from nltk.corpus import stopwords
import emoji

import sklearn
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split, GroupShuffleSplit

import catboost

from tqdm import tqdm

import IPython
from IPython.display import display

In [75]:
IS_KAGGLE = "KAGGLE_DOCKER_IMAGE" in os.environ

DATASETS = pathlib.Path(
    "."
    if not IS_KAGGLE
    else "/kaggle/input/influencers-or-observers-predicting-social-roles/Kaggle2025"
)

DATASET_TRAIN = DATASETS / "train.jsonl"
DATASET_KAGGLE = DATASETS / "kaggle_test.jsonl"

CACHE_DIR = pathlib.Path(".")

In [76]:
np.random.seed(42)

# Data loading

In [77]:
def load_json(path: pathlib.Path, cache: bool = False) -> pd.DataFrame:
    path_pq = (CACHE_DIR / path.name).with_stem(f"{path.stem}_raw").with_suffix(".parquet")
    
    if cache and path_pq.exists():
        return pd.read_parquet(path_pq)
    
    # This leaves things to be desired, since there's no way to specify dtypes
    # and it assumes float instead of int, causing a loss in precision...
    # But I guess it only matters for ids, which we'll probably discard in preprocessing anyway
    result = pd.json_normalize(list(map(json.loads, path.read_bytes().splitlines())))
    
    if cache:
        result.to_parquet(path_pq)
    
    return result


In [78]:
train_data = load_json(DATASET_TRAIN, cache=True)
kaggle_data = load_json(DATASET_KAGGLE, cache=True)

# Preprocessing

In [79]:
def preprocess(df: pd.DataFrame) -> pd.DataFrame:
    # For technical reasons, any text columns we want to use should have no dots in their names.
    # The simplest way to achieve this is to replace all dots indiscriminately.
    
    df = df.rename(columns=lambda x: x.replace(".", "_"))
    
    df["is_reply"] = df["in_reply_to_status_id"].notna()
    
    user_id_key = ["user_description", "user_created_at", "user_profile_image_url"]
    df["user_hash"] = df[user_id_key].fillna("<NA>").astype(str).agg(''.join, axis=1).where(~df[user_id_key].isna().all(axis=1), df["id_str"]).map(fast_hash)
    
    df = df.drop(columns=[
        "in_reply_to_status_id_str",
        # "in_reply_to_status_id",
        "in_reply_to_user_id_str",
        "in_reply_to_user_id",
        "quoted_status_id_str",
        "quoted_status_id",
        "id_str",
        "quoted_status_in_reply_to_status_id_str",
        "quoted_status_in_reply_to_status_id",
        "quoted_status_in_reply_to_user_id_str",
        "quoted_status_in_reply_to_user_id",
        "quoted_status_id_str",
        "quoted_status_id",
        "quoted_status_user_id_str",
        "quoted_status_user_id",
        # "quoted_status_permalink_expanded",
        "quoted_status_permalink_display",
        "quoted_status_permalink_url",
        "quoted_status_quoted_status_id",
        "quoted_status_quoted_status_id_str",
        # "quoted_status_place_id",
        # "place_id",
        "lang",  # Always "fr"
        "retweeted",  # Always False
        "filter_level",  # Always "low"
        "geo",  # Always None
        "place",  # Always None
        "coordinates",  # Always None
        "contributors",  # Always None
        "quote_count",  # Always 0
        "reply_count",  # Always 0
        "retweet_count",  # Always 0
        "favorite_count",  # Always 0
        "favorited",  # Always False
        "quoted_status_geo",  # Always None
        "quoted_status_place",  # Always None
        "quoted_status_coordinates",  # Always None
        "quoted_status_retweeted",  # Always False
        "quoted_status_filter_level",  # Always "low"
        "quoted_status_contributors",  # Always None
        "quoted_status_user_utc_offset",  # Always None
        "quoted_status_user_lang",  # Always None
        "quoted_status_user_time_zone",  # Always None
        "quoted_status_user_follow_request_sent",  # Always None
        "quoted_status_user_following",  # Always None
        "quoted_status_user_notifications",  # Always None
        "user_default_profile_image",  # Always False
        "user_protected",  # Always False
        "user_contributors_enabled",  # Always False
        "user_lang",  # Always None
        "user_notifications",  # Always None
        "user_following",  # Always None
        "user_utc_offset",  # Always None
        "user_time_zone",  # Always None
        "user_follow_request_sent",  # Always None
    ])
    
    df["full_text"] = df.apply(lambda tweet: extract_full_text(tweet), axis=1)
    
    source_split = df["source"].str.removeprefix("<a href=\"").str.removesuffix("</a>").str.split("\" rel=\"nofollow\">").map(lambda x: x if len(x) == 2 else pd.NA)
    df["source_url"] = source_split.map(lambda x: x[0], na_action="ignore")
    df["source_name"] = source_split.map(lambda x: x[1], na_action="ignore")
    
    df["misc_text"] = df.apply(
        lambda x: "via: {0}; reply: @{1}; quote: @{2} {3}".format(x["source_name"], x["in_reply_to_screen_name"], x["quoted_status_user_screen_name"], x["quoted_status_user_name"]), axis=1,
    )
    
    df["source_domain"] = df["source_url"].map(extract_domain, na_action="ignore")
    df["user_domain"] = df["user_url"].str.extract(r"https?://([^/]+)/")
    
    for col in [
        "quoted_status_user_profile_link_color",
        "quoted_status_user_profile_background_color",
        "quoted_status_user_profile_sidebar_border_color",
        "quoted_status_user_profile_text_color",
        "user_profile_link_color",
        "user_profile_background_color",
        "user_profile_sidebar_border_color",
        "user_profile_text_color",
        "user_profile_sidebar_fill_color",
    ]:
        df[f"{col}_r"], df[f"{col}_g"], df[f"{col}_b"] = extract_color(df[col])
    
    for col in [
        "full_text",
        "source_name",
        "in_reply_to_screen_name",
        "quoted_status_extended_tweet_entities_urls",
        "quoted_status_extended_tweet_entities_user_mentions",
        "quoted_status_extended_tweet_full_text",
        "quoted_status_entities_urls",
        "quoted_status_user_profile_image_url_https",
        "quoted_status_user_profile_background_image_url",
        "quoted_status_user_profile_background_image_url_https",
        "quoted_status_user_screen_name",
        "quoted_status_user_name",
        "entities_hashtags",
        "entities_user_mentions",
        "user_profile_image_url_https",
        "user_profile_background_image_url",
        "user_description",
        "user_translator_type",
        "user_url",
        "user_profile_banner_url",
        "user_location",
        "display_text_range",
        "extended_tweet_entities_urls",
        "extended_tweet_entities_hashtags",
        "extended_tweet_entities_user_mentions",
        "quoted_status_permalink_expanded",
    ]:
        df[f"{col}_len"] = df[col].map(len, na_action="ignore").fillna(0)
    
    dt_cols = ["created_at", "quoted_status_created_at", "quoted_status_user_created_at", "user_created_at"]
    df[dt_cols] = df[dt_cols].apply(pd.to_datetime, format="%a %b %d %H:%M:%S %z %Y", errors="coerce").map(pd.Timestamp.timestamp, na_action="ignore")
    
    ts = pd.to_datetime(df["timestamp_ms"].astype("int64"), unit="ms")
    df["hour"] = ts.dt.hour
    df["weekday"] = ts.dt.weekday
    df["month"] = ts.dt.month
    
    for em in [
        ":backhand_index_pointing_right:",
        ":right_arrow:",
        ":red_circle:",
        ":right_arrow_curving_down:",
        ":play_button:",
        ":backhand_index_pointing_down:",
        ":thinking_face:",
        ":check_mark_button:",
        ":warning:",
        ":Canada:",
        ":down_arrow:",
        ":loudly_crying_face:",
        ":face_with_medical_mask:",
        ":rolling_on_the_floor_laughing:",
        ":megaphone:",
        ":police_car_light:",
        ":loudspeaker:",
        ":laptop:",
        ":syringe:",
        ":studio_microphone:",
        ":face_vomiting:",
        ":information:",
        ":skull:",
        ":round_pushpin:",
        ":speech_balloon:",
        ":face_with_tears_of_joy:",
        ":blue_circle:",
        ":television:",
    ]:
        df[f"emoji_{em.strip(':')}"] = df["full_text"].str.contains(emoji.demojize(em), regex=False)
    
    return df

def extract_full_text(tweet: pd.Series) -> str:
    text: str = tweet["text"]
    
    if not pd.isna(tweet["extended_tweet_full_text"]):
        text = tweet["extended_tweet_full_text"]
    
    return text

def fast_hash(content: str) -> str:
    h = hashlib.blake2s(digest_size=16)
    h.update(content.encode('utf-8'))
    return h.hexdigest()

def extract_domain(url: str) -> str:
    return urllib.parse.urlparse(url).netloc

def extract_color(color: pd.Series) -> tuple[pd.Series, pd.Series, pd.Series]:
    return tuple(
        color.str.slice(i, i + 2).map(lambda x: int(x, 16), na_action="ignore")
        for i in (0, 2, 4)
    )


In [80]:
X_train_pre = train_data.drop("label", axis=1)
y_train = train_data["label"]

X_kaggle_pre = kaggle_data

X_train_pre = preprocess(X_train_pre)
X_kaggle_pre = preprocess(X_kaggle_pre)

  df[f"{col}_len"] = df[col].map(len, na_action="ignore").fillna(0)
  df[f"emoji_{em.strip(':')}"] = df["full_text"].str.contains(emoji.demojize(em), regex=False)
  df[f"emoji_{em.strip(':')}"] = df["full_text"].str.contains(emoji.demojize(em), regex=False)
  df[f"emoji_{em.strip(':')}"] = df["full_text"].str.contains(emoji.demojize(em), regex=False)
  df[f"{col}_len"] = df[col].map(len, na_action="ignore").fillna(0)
  df[f"emoji_{em.strip(':')}"] = df["full_text"].str.contains(emoji.demojize(em), regex=False)
  df[f"emoji_{em.strip(':')}"] = df["full_text"].str.contains(emoji.demojize(em), regex=False)
  df[f"emoji_{em.strip(':')}"] = df["full_text"].str.contains(emoji.demojize(em), regex=False)


In [81]:
inferred_dtypes = pd.Series({
    col: X_train_pre[col].dropna().convert_dtypes().dtype
    for col in X_train_pre.columns
})

  if (arr.astype(int) == arr).all():
  if (arr.astype(int) == arr).all():


In [82]:
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(inferred_dtypes)

in_reply_to_status_id                                               Float64
created_at                                                            Int64
source                                                       string[python]
in_reply_to_screen_name                                      string[python]
is_quote_status                                                     boolean
text                                                         string[python]
truncated                                                           boolean
timestamp_ms                                                 string[python]
challenge_id                                                          Int64
quoted_status_extended_tweet_entities_urls                           object
quoted_status_extended_tweet_entities_hashtags                       object
quoted_status_extended_tweet_entities_user_mentions                  object
quoted_status_extended_tweet_entities_symbols                        object
quoted_statu

In [83]:
column_kinds: dict[str, typing.Literal["unknown", "num", "bool", "text", "cat", "emb", "skip"]] = \
    dict.fromkeys(X_train_pre.columns, "unknown")

def _mark_resolved():
    inferred_dtypes.drop([k for k, v in column_kinds.items() if v != "unknown"], axis=0, errors="ignore", inplace=True)

In [84]:
column_kinds["challenge_id"] = "skip"
_mark_resolved()

In [85]:
print("Inferring kind as num for:")
for col in inferred_dtypes[inferred_dtypes.astype("str").isin(["Int64", "Float64", "Int32", "Float32"])].index:
    print(f"> {col}")
    column_kinds[col] = "num"

_mark_resolved()

Inferring kind as num for:
> in_reply_to_status_id
> created_at
> quoted_status_created_at
> quoted_status_retweet_count
> quoted_status_favorite_count
> quoted_status_quote_count
> quoted_status_reply_count
> quoted_status_user_friends_count
> quoted_status_user_listed_count
> quoted_status_user_favourites_count
> quoted_status_user_created_at
> quoted_status_user_statuses_count
> quoted_status_user_followers_count
> user_listed_count
> user_favourites_count
> user_created_at
> user_statuses_count
> quoted_status_user_profile_link_color_r
> quoted_status_user_profile_link_color_g
> quoted_status_user_profile_link_color_b
> quoted_status_user_profile_background_color_r
> quoted_status_user_profile_background_color_g
> quoted_status_user_profile_background_color_b
> quoted_status_user_profile_sidebar_border_color_r
> quoted_status_user_profile_sidebar_border_color_g
> quoted_status_user_profile_sidebar_border_color_b
> quoted_status_user_profile_text_color_r
> quoted_status_user_profile

In [86]:
print("Inferring kind as bool for:")
for col in inferred_dtypes[inferred_dtypes == "boolean"].index:
    print(f"> {col}")
    column_kinds[col] = "bool"

_mark_resolved()

Inferring kind as bool for:
> is_quote_status
> truncated
> quoted_status_is_quote_status
> quoted_status_favorited
> quoted_status_truncated
> quoted_status_user_default_profile_image
> quoted_status_user_is_translator
> quoted_status_user_protected
> quoted_status_user_geo_enabled
> quoted_status_user_verified
> quoted_status_user_contributors_enabled
> quoted_status_user_profile_background_tile
> quoted_status_user_profile_use_background_image
> quoted_status_user_default_profile
> user_is_translator
> user_geo_enabled
> user_profile_background_tile
> user_profile_use_background_image
> user_default_profile
> possibly_sensitive
> quoted_status_possibly_sensitive
> quoted_status_scopes_followers
> is_reply
> emoji_backhand_index_pointing_right
> emoji_right_arrow
> emoji_red_circle
> emoji_right_arrow_curving_down
> emoji_play_button
> emoji_backhand_index_pointing_down
> emoji_thinking_face
> emoji_check_mark_button
> emoji_Canada
> emoji_down_arrow
> emoji_loudly_crying_face
> emoj

In [87]:
list_like_cols: list[str] = [
    col
    for col in inferred_dtypes.index
    if X_train_pre[col].map(lambda x: isinstance(x, np.ndarray), na_action="ignore").all()
]

In [88]:
print("Inferring kind as skip for:")
for col in list_like_cols:
    # # Variable-length lists
    # if X_train[col].dropna().map(len).unique().size > 2:
    print(f"> {col}")
    column_kinds[col] = "skip"

_mark_resolved()

# TODO: Split fixed-length list columns, aka the following, into separate scalar columns:
# tuple_cols = [
#     "coordinates_coordinates",
#     "display_text_range",
#     "entities_media",
#     "extended_tweet_display_text_range",
#     "geo_coordinates",
#     "place_bounding_box_coordinates",
#     "quoted_status_coordinates_coordinates",
#     "quoted_status_display_text_range",
#     "quoted_status_entities_media",
#     "quoted_status_extended_tweet_display_text_range",
#     "quoted_status_geo_coordinates",
#     "quoted_status_place_bounding_box_coordinates",
#     "quoted_status_withheld_in_countries",
#     "withheld_in_countries",
# ]

Inferring kind as skip for:
> quoted_status_extended_tweet_entities_urls
> quoted_status_extended_tweet_entities_hashtags
> quoted_status_extended_tweet_entities_user_mentions
> quoted_status_extended_tweet_entities_symbols
> quoted_status_extended_tweet_display_text_range
> quoted_status_entities_urls
> quoted_status_entities_hashtags
> quoted_status_entities_user_mentions
> quoted_status_entities_symbols
> entities_urls
> entities_hashtags
> entities_user_mentions
> entities_symbols
> display_text_range
> extended_tweet_entities_urls
> extended_tweet_entities_hashtags
> extended_tweet_entities_user_mentions
> extended_tweet_entities_symbols
> extended_tweet_display_text_range
> quoted_status_extended_entities_media
> quoted_status_entities_media
> quoted_status_display_text_range
> extended_tweet_extended_entities_media
> extended_tweet_entities_media
> quoted_status_extended_tweet_extended_entities_media
> quoted_status_extended_tweet_entities_media
> place_bounding_box_coordinates


In [89]:
print("Inferring kind as cat for:")
for col in inferred_dtypes.index:
    col: str
    if col.endswith("_type") or col.endswith("_kind"):
        print(f"> {col}")
        column_kinds[col] = "cat"

_mark_resolved()

Inferring kind as cat for:
> quoted_status_user_translator_type
> user_translator_type
> place_bounding_box_type
> place_place_type
> quoted_status_place_bounding_box_type
> quoted_status_place_place_type
> quoted_status_geo_type
> quoted_status_coordinates_type
> geo_type
> coordinates_type


In [90]:
print("Inferring kind as emb for:")
for col in inferred_dtypes.index:
    col: str
    if col.endswith("_emb"):
        print(f"> {col}")
        column_kinds[col] = "emb"

_mark_resolved()

Inferring kind as emb for:


In [91]:
print("Inferring kind as text for:")
for col in [
    "full_text",
    "user_description",
    "misc_text",
    "quoted_status_text",
    "quoted_status_extended_tweet_full_text",
    "quoted_status_user_description",
]:
    print(f"> {col}")
    column_kinds[col] = "text"

_mark_resolved()

Inferring kind as text for:
> full_text
> user_description
> misc_text
> quoted_status_text
> quoted_status_extended_tweet_full_text
> quoted_status_user_description


In [92]:
print("Inferring kind as cat for:")
for col in [
    "source",
    "in_reply_to_screen_name",
    "quoted_status_source",
    "quoted_status_in_reply_to_screen_name",
    "quoted_status_lang",
    "quoted_status_user_screen_name",
    "quoted_status_user_name",
    "quoted_status_user_location",
    "user_location",
    "place_country_code",
    "place_country",
    "place_full_name",
    "place_name",
    "place_id",
    "place_url",
    "quoted_status_place_country_code",
    "quoted_status_place_country",
    "quoted_status_place_full_name",
    "quoted_status_place_name",
    "quoted_status_place_id",
    "quoted_status_place_url",
    "source_name",
    "source_url",
    "source_domain",
    "user_domain",
]:
    print(f"> {col}")
    column_kinds[col] = "cat"

_mark_resolved()

Inferring kind as cat for:
> source
> in_reply_to_screen_name
> quoted_status_source
> quoted_status_in_reply_to_screen_name
> quoted_status_lang
> quoted_status_user_screen_name
> quoted_status_user_name
> quoted_status_user_location
> user_location
> place_country_code
> place_country
> place_full_name
> place_name
> place_id
> place_url
> quoted_status_place_country_code
> quoted_status_place_country
> quoted_status_place_full_name
> quoted_status_place_name
> quoted_status_place_id
> quoted_status_place_url
> source_name
> source_url
> source_domain
> user_domain


In [93]:
print("Inferring kind as skip for:")
for col in inferred_dtypes[inferred_dtypes == "string[python]"].index:
    print(f"> {col}")
    column_kinds[col] = "skip"

_mark_resolved()

Inferring kind as skip for:
> text
> timestamp_ms
> quoted_status_user_profile_image_url_https
> quoted_status_user_profile_background_image_url
> quoted_status_user_profile_background_image_url_https
> quoted_status_user_profile_link_color
> quoted_status_user_profile_background_color
> quoted_status_user_profile_sidebar_border_color
> quoted_status_user_profile_text_color
> quoted_status_user_profile_image_url
> quoted_status_user_url
> quoted_status_user_profile_banner_url
> quoted_status_user_profile_sidebar_fill_color
> quoted_status_permalink_expanded
> user_profile_image_url_https
> user_profile_background_image_url
> user_profile_background_image_url_https
> user_profile_link_color
> user_profile_background_color
> user_profile_sidebar_border_color
> user_profile_text_color
> user_profile_image_url
> user_url
> user_profile_banner_url
> user_profile_sidebar_fill_color
> extended_tweet_full_text
> user_hash


In [94]:
assert len(inferred_dtypes) == 0, "Some columns have not been categorized!"

In [95]:
column_kinds: pd.Series = pd.Series(column_kinds)

In [96]:
def preprocess2(df: pd.DataFrame) -> pd.DataFrame:
    # skip_cols = list(column_kinds[column_kinds == "skip"].index)
    # df = df.drop(columns=skip_cols, errors="ignore")
    df = df.copy()
    
    bool_cols = list(column_kinds[column_kinds == "bool"].index)
    df[bool_cols] = df[bool_cols].map({True: 1, False: -1, None: 0}.__getitem__).astype(int)
    
    cat_cols = list(column_kinds[column_kinds == "cat"].index)
    df[cat_cols] = df[cat_cols].fillna("none").astype(str)
    
    text_cols = list(column_kinds[column_kinds == "text"].index)
    df[text_cols] = df[text_cols].fillna("").astype(str)
    
    num_cols = list(column_kinds[column_kinds == "num"].index)
    df[num_cols] = df[num_cols].fillna(0).convert_dtypes()
    
    return df


In [97]:
X_train = preprocess2(X_train_pre)
X_kaggle = preprocess2(X_kaggle_pre)

  df[num_cols] = df[num_cols].fillna(0).convert_dtypes()
  if (arr.astype(int) == arr).all():
  if (arr.astype(int) == arr).all():
  df[num_cols] = df[num_cols].fillna(0).convert_dtypes()
  if (arr.astype(int) == arr).all():
  if (arr.astype(int) == arr).all():


# Models

In [98]:
def train_test_split_group(
    X: pd.DataFrame,
    y: pd.Series,
    group: pd.Series,
    test_size: float = 0.15,
    random_state: int = 42
) -> tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    gss = GroupShuffleSplit(
        test_size=test_size,
        n_splits=1,
        random_state=random_state,
    )
    
    train_idx, val_idx = next(gss.split(X, y, group))
    
    return X.iloc[train_idx], X.iloc[val_idx], y.iloc[train_idx], y.iloc[val_idx]


In [99]:
X_train_for_real, X_val, y_train_for_real, y_val = train_test_split_group(
    X_train,
    y_train,
    X_train["user_hash"],
    test_size=0.05,
    random_state=42,
)

In [100]:
def make_pool(X, y) -> catboost.Pool:
    num_cols = list(column_kinds[column_kinds == "num"].index)
    cat_cols = list(column_kinds[column_kinds.isin(["cat", "bool"])].index)
    text_cols = list(column_kinds[column_kinds == "text"].index)
    # emb_cols = list(column_kinds[column_kinds == "emb"].index)
    
    X = X[num_cols + cat_cols + text_cols]
    # X = X[num_cols + cat_cols + text_cols + emb_cols]
    
    return catboost.Pool(
        data=X,
        label=y,
        cat_features=cat_cols,
        text_features=text_cols,
        # embedding_features=emb_cols,
    )

train_pool = make_pool(X_train_for_real, y_train_for_real)
val_pool = make_pool(X_val, y_val)
kaggle_pool = make_pool(X_kaggle, None)


In [None]:
VERSION = "v17"

In [None]:
model = catboost.CatBoostClassifier(
    depth=9,
    learning_rate=0.1,
    iterations=2000,
    early_stopping_rounds=200,
    loss_function="Logloss",
    eval_metric="Accuracy",
    train_dir=f"./models/{VERSION}",
    task_type="GPU",
    devices="0",
    random_seed=42,
)


In [103]:
# param_grid = {
#     "depth": [4, 6, 8, 10],
#     "learning_rate": [0.05, 0.1],
#     "iterations": [1000, 2000],
# }

# model.grid_search(
#     param_grid,
#     train_pool,
#     verbose=100,
#     plot=True,
#     plot_file="models/v12/gridsearch/grid_search.info",
# )

In [None]:
model.fit(
    train_pool,
    eval_set=val_pool,
    verbose=100,
    plot=True,
)

model.save_model(f"models/{VERSION}/model.cbm")

In [None]:
preds = model.predict(kaggle_pool)
# preds = model.virtual_ensembles_predict(kaggle_pool)

In [112]:
def reconcile_answers(preds: typing.Sequence[int]) -> typing.Sequence[int]:
    df = X_kaggle[["user_hash", "challenge_id"]].copy()
    df["pred_label"] = preds
    
    per_user_stats: dict[str, list[int]] = dict()
    for _, row in df.iterrows():
        per_user_stats.setdefault(row["user_hash"], [0, 0])[int(row["pred_label"])] += 1

    per_user_correct: dict[tuple[str, str], int] = dict()
    for key, stats in per_user_stats.items():
        if stats[0] == 0 or stats[1] == 0:
            continue
        
        per_user_correct[key] = np.select(
            [stats[0] > stats[1], stats[1] > stats[0]],
            [0, 1],
            default=np.random.randint(0, 2),
        )

    for idx, row in df.iterrows():
        if row["user_hash"] in per_user_correct:
            df.at[idx, "pred_label"] = per_user_correct[row["user_hash"]]
    
    return df["pred_label"].tolist()

In [113]:
output = pd.DataFrame({"ID": kaggle_data["challenge_id"], "Prediction": preds})
output.to_csv(f"models/{VERSION}/predictions-{VERSION}.csv", index=False)

In [114]:
output = pd.DataFrame({"ID": kaggle_data["challenge_id"], "Prediction": reconcile_answers(preds)})
output.to_csv(f"models/{VERSION}/predictions-{VERSION}-reconciled.csv", index=False)