In [49]:
#!pip install emoji

In [50]:
from __future__ import annotations
import typing
import json
import pathlib
import os
import time
import datetime
import hashlib
import urllib.parse

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_

import transformers
import transformers.modeling_outputs
import transformers.configuration_utils
from transformers import AutoTokenizer, AutoModel, AutoConfig

from nltk.corpus import stopwords
import emoji

import sklearn
from sklearn.decomposition import PCA
from sklearn.model_selection import GroupShuffleSplit

from tqdm import tqdm

import IPython
from IPython.display import display

In [51]:
# The presence of KAGGLE_DOCKER_IMAGE in the environment is used as the signal
# that the notebook is running on Kaggle; the datasets then live under /kaggle/input.
IS_KAGGLE = os.environ.get("KAGGLE_DOCKER_IMAGE") is not None

if IS_KAGGLE:
    DATASETS = pathlib.Path("/kaggle/input/influencers-or-observers-predicting-social-roles/Kaggle2025")
else:
    DATASETS = pathlib.Path(".")

DATASET_TRAIN = DATASETS.joinpath("train.jsonl")
DATASET_KAGGLE = DATASETS.joinpath("kaggle_test.jsonl")

# Directory that receives the parquet copies of the raw datasets.
CACHE_DIR = pathlib.Path(".")

In [52]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [53]:
# Seed the global torch and NumPy generators so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Data loading

In [54]:
def load_json(path: pathlib.Path, cache: bool = False) -> pd.DataFrame:
    """Load a JSON-lines file into a flat DataFrame.

    Nested objects are flattened by ``pd.json_normalize`` (child keys are joined
    to their parent with a dot).

    Args:
        path: JSON-lines file with one JSON document per line.
        cache: When true, read from / write to a parquet copy stored in
            ``CACHE_DIR`` under the name ``<stem>_raw.parquet``.

    Returns:
        One row per non-blank line of ``path``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    path_pq = (CACHE_DIR / path.name).with_stem(f"{path.stem}_raw").with_suffix(".parquet")

    if cache and path_pq.exists():
        return pd.read_parquet(path_pq)

    # This leaves things to be desired, since there's no way to specify dtypes
    # and it assumes float instead of int, causing a loss in precision...
    # But I guess it only matters for ids, which we'll probably discard in preprocessing anyway
    records = [
        json.loads(line)
        for line in path.read_bytes().splitlines()
        if line.strip()  # tolerate blank lines instead of failing in json.loads
    ]
    result = pd.json_normalize(records)

    if cache:
        result.to_parquet(path_pq)

    return result


In [55]:
# Load both raw datasets, going through the parquet cache when it exists.
train_data, kaggle_data = (
    load_json(dataset_path, cache=True)
    for dataset_path in (DATASET_TRAIN, DATASET_KAGGLE)
)

# Preprocessing

In [56]:
def preprocess(df: pd.DataFrame) -> pd.DataFrame:
    """Derive model-ready columns from a raw (json_normalize'd) tweet frame.

    Steps, in order:
      * replace dots in column names with underscores;
      * add ``is_reply`` and a per-user ``user_hash``;
      * drop id columns and columns annotated below as constant;
      * add ``full_text``, ``source_url``/``source_name``/``source_domain``,
        ``misc_text`` and ``user_domain``;
      * split every profile colour column into ``_r``/``_g``/``_b`` integers.

    Args:
        df: Raw tweets. The frame is not modified; work happens on the copy
            returned by ``rename``.

    Returns:
        The preprocessed copy.
    """
    # For technical reasons, any text columns we want to use should have no dots in their names.
    # The simplest way to achieve this is to replace all dots indiscriminately.
    df = df.rename(columns=lambda x: x.replace(".", "_"))

    df["is_reply"] = df["in_reply_to_status_id"].notna()

    # A user is identified by hashing a few profile fields; when all of them are
    # missing the tweet's own id is hashed instead, so the row gets its own group.
    user_id_key = ["user_description", "user_created_at", "user_profile_image_url"]
    user_key_missing = df[user_id_key].isna().all(axis=1)
    df["user_hash"] = (
        df[user_id_key]
        .fillna("<NA>")
        .astype(str)
        .agg("".join, axis=1)
        .where(~user_key_missing, df["id_str"])
        .map(fast_hash)
    )

    df = df.drop(columns=[
        "in_reply_to_status_id_str",
        # "in_reply_to_status_id",
        "in_reply_to_user_id_str",
        "in_reply_to_user_id",
        "quoted_status_id_str",
        "quoted_status_id",
        "id_str",
        "quoted_status_in_reply_to_status_id_str",
        "quoted_status_in_reply_to_status_id",
        "quoted_status_in_reply_to_user_id_str",
        "quoted_status_in_reply_to_user_id",
        "quoted_status_user_id_str",
        "quoted_status_user_id",
        # "quoted_status_permalink_expanded",
        "quoted_status_permalink_display",
        "quoted_status_permalink_url",
        "quoted_status_quoted_status_id",
        "quoted_status_quoted_status_id_str",
        # "quoted_status_place_id",
        # "place_id",
        "lang",  # Always "fr"
        "retweeted",  # Always False
        "filter_level",  # Always "low"
        "geo",  # Always None
        "place",  # Always None
        "coordinates",  # Always None
        "contributors",  # Always None
        "quote_count",  # Always 0
        "reply_count",  # Always 0
        "retweet_count",  # Always 0
        "favorite_count",  # Always 0
        "favorited",  # Always False
        "quoted_status_geo",  # Always None
        "quoted_status_place",  # Always None
        "quoted_status_coordinates",  # Always None
        "quoted_status_retweeted",  # Always False
        "quoted_status_filter_level",  # Always "low"
        "quoted_status_contributors",  # Always None
        "quoted_status_user_utc_offset",  # Always None
        "quoted_status_user_lang",  # Always None
        "quoted_status_user_time_zone",  # Always None
        "quoted_status_user_follow_request_sent",  # Always None
        "quoted_status_user_following",  # Always None
        "quoted_status_user_notifications",  # Always None
        "user_default_profile_image",  # Always False
        "user_protected",  # Always False
        "user_contributors_enabled",  # Always False
        "user_lang",  # Always None
        "user_notifications",  # Always None
        "user_following",  # Always None
        "user_utc_offset",  # Always None
        "user_time_zone",  # Always None
        "user_follow_request_sent",  # Always None
    ])

    # Only hand the two columns the helper reads to the row-wise apply; building
    # a Series out of every column for every row is needlessly slow.
    df["full_text"] = df[["text", "extended_tweet_full_text"]].apply(extract_full_text, axis=1)

    # `source` looks like `<a href="URL" rel="nofollow">NAME</a>`; anything that
    # does not split into exactly (URL, NAME) is treated as missing.
    # na_action="ignore" keeps a missing source as NaN instead of calling len() on it.
    source_split = (
        df["source"]
        .str.removeprefix("<a href=\"")
        .str.removesuffix("</a>")
        .str.split("\" rel=\"nofollow\">")
        .map(lambda x: x if len(x) == 2 else pd.NA, na_action="ignore")
    )
    df["source_url"] = source_split.map(lambda x: x[0], na_action="ignore")
    df["source_name"] = source_split.map(lambda x: x[1], na_action="ignore")

    # Missing values are formatted as-is (e.g. "nan"), so the string is never empty.
    misc_columns = [
        "source_name",
        "in_reply_to_screen_name",
        "quoted_status_user_screen_name",
        "quoted_status_user_name",
    ]
    df["misc_text"] = df[misc_columns].apply(
        lambda x: "via: {0}; reply: @{1}; quote: @{2} {3}".format(x["source_name"], x["in_reply_to_screen_name"], x["quoted_status_user_screen_name"], x["quoted_status_user_name"]), axis=1,
    )

    df["source_domain"] = df["source_url"].map(extract_domain, na_action="ignore")
    # No trailing "/" is required, so bare URLs such as "http://example.com" are
    # matched too. expand=False returns a Series rather than a one-column frame.
    df["user_domain"] = df["user_url"].str.extract(r"https?://([^/]+)", expand=False)

    for col in [
        "quoted_status_user_profile_link_color",
        "quoted_status_user_profile_background_color",
        "quoted_status_user_profile_sidebar_border_color",
        "quoted_status_user_profile_text_color",
        "user_profile_link_color",
        "user_profile_background_color",
        "user_profile_sidebar_border_color",
        "user_profile_text_color",
        "user_profile_sidebar_fill_color",
    ]:
        df[f"{col}_r"], df[f"{col}_g"], df[f"{col}_b"] = extract_color(df[col])

    return df

def extract_full_text(tweet: pd.Series) -> str:
    """Return the extended tweet text when the row has one, else the plain ``text``."""
    short_text = tweet["text"]
    extended_text = tweet["extended_tweet_full_text"]
    return short_text if pd.isna(extended_text) else extended_text

def fast_hash(content: str) -> str:
    """Return the 16-byte BLAKE2s digest of ``content`` (UTF-8 encoded) as hex."""
    payload = content.encode('utf-8')
    return hashlib.blake2s(payload, digest_size=16).hexdigest()

def extract_domain(url: str) -> str:
    """Return the network-location part of ``url`` ("" when it has none)."""
    parsed = urllib.parse.urlparse(url)
    return parsed.netloc

def extract_color(color: pd.Series) -> tuple[pd.Series, pd.Series, pd.Series]:
    """Split a column of hex colour strings into red, green and blue integer columns.

    Characters 0-1, 2-3 and 4-5 of each string are parsed as base-16 integers;
    missing values are passed through untouched.
    """
    def _channel(start: int) -> pd.Series:
        hex_pair = color.str.slice(start, start + 2)
        return hex_pair.map(lambda digits: int(digits, 16), na_action="ignore")

    red = _channel(0)
    green = _channel(2)
    blue = _channel(4)
    return red, green, blue


In [57]:
# Separate the labels from the training frame, then run both datasets through
# the same preprocessing.
y_train = train_data["label"]
X_train = preprocess(train_data.drop(columns="label"))
X_kaggle = preprocess(kaggle_data)

# Data exploration

In [58]:
# Correlation between the number of emoji in a tweet and its label.
(
    X_train["full_text"]
    .map(emoji.emoji_count, na_action="ignore")
    .corr(y_train)
)

0.03703379602187509

In [59]:
# How often each distinct emoji occurs across all training tweets.
emoji_freq = (
    X_train["full_text"]
    .map(lambda text: [match["emoji"] for match in emoji.emoji_list(text)], na_action="ignore")
    .explode()
    .value_counts()
)
len(emoji_freq)

1569

In [60]:
# Correlate the per-tweet count of each of the 512 most frequent emoji with the
# label and keep those whose absolute correlation is at least 0.03.
top_emoji = emoji_freq.index[:512].tolist()
most_relevant_emoji = (
    pd.Series({
        f"{emoji.demojize(symbol)} in full_text": (
            X_train["full_text"]
            .map(lambda text: text.count(symbol), na_action="ignore")
            .corr(y_train)
        )
        for symbol in top_emoji
    })
    .abs()
    .sort_values(ascending=False)
    .loc[lambda corr: corr >= 0.03]
)

with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(most_relevant_emoji)

:backhand_index_pointing_right: in full_text    0.072842
:right_arrow: in full_text                      0.055572
:right_arrow_curving_down: in full_text         0.033804
:play_button: in full_text                      0.032676
dtype: float64

In [61]:
# Absolute correlation of every colour-channel column (suffix _r/_g/_b) with the
# label, strongest first, keeping only values of at least 0.05.
(
    X_train[[col for col in X_train.columns if col.endswith(("_r", "_g", "_b"))]]
    .corrwith(y_train)
    .abs()
    .sort_values(ascending=False)
    .loc[lambda corr: corr >= 0.05]
)


user_profile_link_color_b              0.245343
user_profile_background_color_r        0.196165
user_profile_background_color_g        0.192955
user_profile_link_color_g              0.189547
user_profile_background_color_b        0.184779
user_profile_link_color_r              0.164342
user_profile_sidebar_border_color_b    0.136743
user_profile_sidebar_fill_color_b      0.124482
user_profile_sidebar_border_color_g    0.119252
user_profile_sidebar_fill_color_g      0.115195
user_profile_sidebar_fill_color_r      0.107217
user_profile_sidebar_border_color_r    0.077069
dtype: float64

# Models

In [62]:
# Made this a class to hold all the caches. It may resemble an nn.Module, but isn't one!
class FeatureExtractor:
    """Turns preprocessed tweet rows into the tensors consumed by the classifier.

    ``extract`` returns a dict holding:
      * ``"md"``: standardized metadata, one column per ``METADATA_FIELDS`` entry;
      * ``"substrings"``: per-tweet occurrence counts of each ``SUBSTRINGS`` entry
        in ``full_text``;
      * one mean-pooled text-encoder embedding per ``TEXT_FIELDS`` entry.

    Statistics shared between splits (``means``, ``stds`` and the NaN fill values
    in ``afm_cache``) are only (re)computed while ``mode`` is ``"train"``; they are
    exported and restored through ``state_dict`` / ``load_state_dict``.

    When ``text_enc_cache_path`` is set, the extracted tensors are stored in
    ``<text_enc_cache_path>/<mode>.ckpt`` and reused by later calls in that mode.
    """

    mode: typing.Literal["train", "eval", "infer"]
    device: torch.device
    means: pd.Series | None
    stds: pd.Series | None
    afm_cache: dict[tuple[str, str], float]
    text_encoder_name: str | None
    text_tokenizer: transformers.PreTrainedTokenizerBase | None
    text_encoder: nn.Module | None
    text_config: transformers.configuration_utils.PretrainedConfig | None
    text_enc_cache_path: pathlib.Path | None

    def __init__(
        self,
        text_encoder_name: str | None = None,
        text_enc_cache_path: pathlib.Path | None = None,
        device: torch.device = device,
    ):
        """Load the tokenizer/encoder (if a name is given) and start in train mode.

        Args:
            text_encoder_name: Hugging Face model id of the text encoder, or None
                to skip loading one.
            text_enc_cache_path: Directory for the per-mode feature cache, or None
                to disable caching.
            device: Device for the encoder and for all returned tensors.
        """
        self.device = device
        self.means = None
        self.stds = None
        self.afm_cache = {}
        self.text_enc_cache_path = text_enc_cache_path

        self.text_encoder_name = text_encoder_name
        self.text_tokenizer = None
        self.text_encoder = None
        self.text_config = None

        if text_encoder_name is not None:
            self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_name)
            if hasattr(self.text_tokenizer, "to"):  # Distilbert doesn't, apparently
                self.text_tokenizer = self.text_tokenizer.to(self.device)
            self.text_encoder = AutoModel.from_pretrained(text_encoder_name).to(self.device)
            self.text_config = self.text_encoder.config

        self.train()

    def train(self):
        """Switch to train mode: statistics are recomputed from the data passed in."""
        self.mode = "train"

    def eval(self):
        """Switch to eval mode: previously computed statistics are reused."""
        self.mode = "eval"

    def infer(self):
        """Switch to infer mode: like eval, but with its own cache file."""
        self.mode = "infer"

    def state_dict(self):
        """Return the statistics needed to reproduce the metadata normalization."""
        return {
            "means": self.means,
            "stds": self.stds,
            "afm_cache": self.afm_cache,
        }

    def load_state_dict(self, state_dict):
        """Restore statistics previously produced by ``state_dict``."""
        self.means = state_dict["means"]
        self.stds = state_dict["stds"]
        self.afm_cache = state_dict["afm_cache"]

    def dims(self) -> dict[str, int]:
        """Return the width of every feature tensor, keyed like ``extract``'s result."""
        return {
            "md": len(self.METADATA_FIELDS),
            "substrings": len(self.SUBSTRINGS),
        } | {
            field: compress or self.embed_size
            for field, compress in self.TEXT_FIELDS
        }

    @property
    def embed_size(self) -> int:
        """Hidden size of the text encoder (requires an encoder to be loaded)."""
        return self.text_config.hidden_size

    def extract(self, df: pd.DataFrame, override_cache: bool = False) -> dict[str, torch.Tensor]:
        """Extract all feature tensors for ``df``, using the per-mode cache if enabled.

        Args:
            df: Preprocessed tweets.
            override_cache: Ignore an existing cache file and recompute everything.

        Returns:
            Mapping from feature name to a tensor with one row per row of ``df``.
        """
        result: dict[str, torch.Tensor] = {}

        if self.text_enc_cache_path is None:
            self._extract(df, result)
            return result

        cf = self.text_enc_cache_path / f"{self.mode}.ckpt"
        cf.parent.mkdir(parents=True, exist_ok=True)
        if cf.exists() and not override_cache:
            encodings: dict[str, torch.Tensor] = torch.load(cf)
            # The cache is keyed by mode only, so it may have been written for a
            # different frame; a row-count mismatch is the one thing that can be
            # detected cheaply, and reusing such tensors would misalign features.
            stale = [name for name, value in encodings.items() if len(value) != len(df)]
            if stale:
                print(f"Warning: ignoring cache {cf}, its row count does not match the {len(df)} rows given (fields: {stale})")
            else:
                for col_name, value in encodings.items():
                    result[col_name] = value.to(self.device)

        keys_pre = len(result)
        self._extract(df, result)
        keys_post = len(result)

        # Only rewrite the cache file when something new was computed.
        if keys_post > keys_pre:
            torch.save({
                field: embedding.cpu().detach()
                for field, embedding in result.items()
            }, cf)

        return result

    def _extract(self, df: pd.DataFrame, result: dict[str, torch.Tensor]):
        """Compute every feature that is not already present in ``result`` (in place)."""
        if "md" not in result:
            result["md"] = self.extract_raw_metadata(df)

        if "substrings" not in result:
            # One column per substring, one row per tweet; missing text counts as 0.
            counts = np.stack(
                [df["full_text"].str.count(x).fillna(0).to_numpy(dtype="float32") for x in self.SUBSTRINGS],
                axis=1,
            )
            result["substrings"] = torch.tensor(counts, dtype=torch.float32, device=self.device)

        for col_name, compress in self.TEXT_FIELDS:
            if col_name in result:
                continue

            emb = self.embed_texts(df[col_name])

            if compress is not None and compress < emb.shape[1]:
                # NOTE(review): the PCA is refit on every call, so the train/eval/infer
                # projections are not guaranteed to share a basis. Currently unused,
                # since every TEXT_FIELDS entry has compress=None.
                pca = PCA(n_components=compress)
                emb = pca.fit_transform(emb.cpu().detach().numpy())
                emb = torch.tensor(emb, dtype=torch.float32, device=self.device)
            elif compress is not None and compress > emb.shape[1]:
                print(f"Warning: embedding for {col_name} zero-padded from {emb.shape[1]} to {compress}, consider reducing requested size")
                emb = F.pad(emb, (0, compress - emb.shape[1]))

            result[col_name] = emb

    def extract_raw_metadata(self, df: pd.DataFrame) -> torch.Tensor:
        """Build the standardized metadata matrix described by ``METADATA_FIELDS``.

        In train mode the column means and standard deviations are recomputed and
        stored; in the other modes the stored ones are applied.

        Raises:
            AssertionError: If no statistics are available outside train mode.
        """
        md_cols: list[pd.Series] = []

        for fn, col_name in tqdm(self.METADATA_FIELDS, desc="Extracting metadata"):
            md_cols.append(fn(self, df[col_name]))

        md: pd.DataFrame = pd.concat(md_cols, axis=1)

        if self.mode == "train":
            self.means = md.mean().fillna(0)
            # std is NaN for a single row and 0 for a constant column; use 1 in both
            # cases so the division below cannot produce NaN or inf.
            self.stds = md.std().fillna(1).replace(0, 1)

        assert self.means is not None and self.stds is not None, "You forgot to train/load the feature extractor"

        # replace() also guards against zeros in statistics restored via load_state_dict.
        md = (md - self.means) / self.stds.replace(0, 1)

        return torch.from_numpy(md.to_numpy()).float().to(self.device)

    def embed_texts(
        self,
        texts: pd.Series,
        batch_size: int = 64,
        progress: bool = True
    ) -> torch.Tensor:
        """Embed every text as the attention-masked mean of the encoder's last layer.

        Missing or empty texts get an all-zero embedding.

        Args:
            texts: Strings to embed (may contain missing values).
            batch_size: Number of texts sent to the encoder at once.
            progress: Show a tqdm progress bar.

        Returns:
            Tensor with one row per entry of ``texts`` and ``hidden_size`` columns.
        """
        tokenizer = self.text_tokenizer
        encoder = self.text_encoder
        encoder.eval()

        hidden_size = self.text_config.hidden_size
        # Never ask for more tokens than the tokenizer itself declares as the limit.
        # NOTE(review): presumably model_max_length is the tighter bound for
        # RoBERTa-style encoders whose position table reserves extra slots; confirm.
        max_length = min(
            self.text_config.max_position_embeddings,
            getattr(tokenizer, "model_max_length", self.text_config.max_position_embeddings),
        )

        all_embeddings = []

        with torch.no_grad():
            batch_offsets = range(0, len(texts), batch_size)
            if progress:
                batch_offsets = tqdm(batch_offsets, desc=f"Embedding {texts.name or '<unnamed>'}")
            for i in batch_offsets:
                batch_texts = texts.iloc[i:i + batch_size]
                nonna = batch_texts.notna() & batch_texts.str.len().gt(0)

                embeddings = torch.zeros(len(batch_texts), hidden_size, device=self.device)

                # The tokenizer cannot be called with an empty list, so a batch
                # without any usable text keeps its all-zero rows.
                if nonna.any():
                    tokenized = tokenizer(
                        batch_texts[nonna].tolist(),
                        padding=True,
                        truncation=True,
                        return_tensors="pt",
                        max_length=max_length,
                    ).to(self.device)

                    outputs: transformers.modeling_outputs.BaseModelOutput = encoder(**tokenized)
                    last_hidden: torch.Tensor = outputs.last_hidden_state
                    mask: torch.Tensor = tokenized["attention_mask"].unsqueeze(-1)

                    # Mean over real (non-padding) tokens only.
                    summed = (last_hidden * mask).sum(dim=1)
                    counts = mask.sum(dim=1)

                    row_mask = torch.from_numpy(nonna.to_numpy(dtype=bool)).to(self.device)
                    embeddings[row_mask] = (summed / counts).to(embeddings.dtype)

                all_embeddings.append(embeddings)

        if not all_embeddings:
            return torch.zeros(0, hidden_size, device=self.device)

        return torch.cat(all_embeddings, dim=0)

    def apply_fill_mean(
        self,
        col: pd.Series,
        func: typing.Callable[[typing.Any], typing.Any],
    ) -> pd.Series:
        """Map ``func`` over the non-missing values and fill the rest with the train mean.

        The fill value is stored in ``afm_cache`` under ``(col.name, func.__name__)``
        in train mode and looked up there in the other modes.

        Raises:
            AssertionError: If no fill value is available outside train mode.
        """
        col = col.map(func, na_action="ignore")

        key = (col.name, func.__name__)
        if self.mode == "train":
            mean = col.mean()
            # An all-missing column has no mean; fall back to 0 rather than
            # filling NaN with NaN.
            self.afm_cache[key] = 0.0 if pd.isna(mean) else mean
        assert key in self.afm_cache, "You forgot to train/load the feature extractor"

        return col.fillna(self.afm_cache[key])

    def md_bool(self, col: pd.Series) -> pd.Series:
        """Encode truthy as 1, falsy as -1 and missing as 0."""
        return col.map(lambda x: (1 if x else -1), na_action="ignore").fillna(0)

    def md_len(self, col: pd.Series) -> pd.Series:
        """Length of each value (string or sequence); missing counts as 0."""
        return col.map(len, na_action="ignore").fillna(0)

    def md_time(self, col: pd.Series) -> pd.Series:
        """Parse timestamps such as ``Wed Oct 10 20:19:24 +0000 2018`` into epoch seconds.

        The UTC offset in the string is honoured, so the result does not depend on
        the timezone of the machine running the notebook.
        """
        # Keep this a lambda: apply_fill_mean keys its cache on func.__name__, and
        # stored states use "<lambda>" for the time columns.
        return self.apply_fill_mean(
            col,
            lambda x: datetime.datetime.strptime(x, "%a %b %d %H:%M:%S %z %Y").timestamp(),
        )

    def md_num(self, col: pd.Series) -> pd.Series:
        """Convert values to numbers; missing values become the train mean."""
        return self.apply_fill_mean(col, pd.to_numeric)

    def md_place(self, col: pd.Series) -> pd.Series:
        """Interpret hexadecimal place ids as integers; missing counts as 0."""
        return col.map(lambda x: int(x, 16), na_action="ignore").fillna(0)

    def md_emoji_count(self, col: pd.Series) -> pd.Series:
        """Number of emoji in each text; missing counts as 0."""
        return col.map(emoji.emoji_count, na_action="ignore").fillna(0)

    # (extractor, column) pairs; each contributes one column to the "md" tensor.
    METADATA_FIELDS: list[tuple[typing.Callable[[FeatureExtractor, pd.Series], pd.Series], str]] = [
        (md_bool, "is_quote_status"),
        (md_bool, "is_reply"),
        (md_bool, "possibly_sensitive"),
        (md_bool, "quoted_status_user_verified"),
        (md_bool, "user_is_translator"),
        (md_bool, "user_geo_enabled"),
        (md_bool, "user_profile_use_background_image"),
        (md_bool, "user_default_profile"),

        (md_len, "full_text"),
        (md_len, "source_name"),
        (md_len, "in_reply_to_screen_name"),
        (md_len, "quoted_status_extended_tweet_entities_urls"),
        (md_len, "quoted_status_extended_tweet_entities_user_mentions"),
        (md_len, "quoted_status_extended_tweet_full_text"),
        (md_len, "quoted_status_entities_urls"),
        (md_len, "quoted_status_user_profile_image_url_https"),
        (md_len, "quoted_status_user_profile_background_image_url"),
        (md_len, "quoted_status_user_profile_background_image_url_https"),
        (md_len, "quoted_status_user_screen_name"),
        (md_len, "quoted_status_user_name"),
        (md_len, "entities_hashtags"),
        (md_len, "entities_user_mentions"),
        (md_len, "user_profile_image_url_https"),
        (md_len, "user_profile_background_image_url"),
        (md_len, "user_description"),
        (md_len, "user_translator_type"),
        (md_len, "user_url"),
        (md_len, "user_profile_banner_url"),
        (md_len, "user_location"),
        (md_len, "display_text_range"),
        (md_len, "extended_tweet_entities_urls"),
        (md_len, "extended_tweet_entities_hashtags"),
        (md_len, "extended_tweet_entities_user_mentions"),
        (md_len, "quoted_status_permalink_expanded"),

        (md_time, "created_at"),
        (md_time, "user_created_at"),
        (md_time, "quoted_status_created_at"),
        (md_time, "quoted_status_user_created_at"),

        (md_num, "user_statuses_count"),
        (md_num, "user_listed_count"),
        (md_num, "user_favourites_count"),
        (md_num, "user_profile_background_tile"),
        (md_num, "quoted_status_quote_count"),
        (md_num, "quoted_status_user_followers_count"),
        (md_num, "quoted_status_user_favourites_count"),
        (md_num, "in_reply_to_status_id"),
        (md_num, "user_profile_link_color_b"),
        (md_num, "user_profile_background_color_r"),
        (md_num, "user_profile_background_color_g"),
        (md_num, "user_profile_link_color_g"),
        (md_num, "user_profile_background_color_b"),
        (md_num, "user_profile_link_color_r"),
        (md_num, "user_profile_sidebar_border_color_b"),
        (md_num, "user_profile_sidebar_fill_color_b"),
        (md_num, "user_profile_sidebar_border_color_g"),
        (md_num, "user_profile_sidebar_fill_color_g"),
        (md_num, "user_profile_sidebar_fill_color_r"),
        (md_num, "user_profile_sidebar_border_color_r"),

        (md_place, "quoted_status_place_id"),
        (md_place, "place_id"),

        (md_emoji_count, "full_text"),
    ]

    # (column, compressed size) pairs; None keeps the encoder's full hidden size.
    TEXT_FIELDS: list[tuple[str, int | None]] = [
        ("full_text", None),
        ("user_description", None),
        ("misc_text", None),
        # ("source_name", None),
        # ("in_reply_to_screen_name", None),
        # ("quoted_status_user_screen_name", None),
        # ("quoted_status_user_name", None),
    ]

    # The names are emoji shortcodes; emojize turns them into the actual emoji
    # characters, which is what occurs in tweet text. (demojize would leave the
    # shortcodes unchanged and the counts would always be zero.)
    # Feature caches written with the shortcode form hold all-zero "substrings"
    # and should be regenerated with override_cache=True.
    SUBSTRINGS: list[str] = [
        emoji.emojize(em)
        for em in [
            ":backhand_index_pointing_right:",
            ":right_arrow:",
            ":right_arrow_curving_down:",
            ":play_button:",
        ]
    ]


In [63]:
class TweetDataset(Dataset):
    """Map-style dataset pairing pre-extracted feature tensors with labels.

    All features are extracted once in ``__init__`` through
    ``feature_extractor.extract(df)``; ``__getitem__`` only slices the tensors.
    """

    features: dict[str, torch.Tensor]
    labels: torch.Tensor
    device: torch.device

    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        df: pd.DataFrame,
        labels: pd.Series,
        device: torch.device = device,
    ):
        """
        Args:
            feature_extractor: Extractor already switched to the desired mode.
            df: Preprocessed tweets.
            labels: One integer-convertible label per row of ``df``.
            device: Device the label tensor is created on.

        Raises:
            ValueError: If a feature tensor and the labels disagree on the number
                of rows.
        """
        self.device = device
        self.features = feature_extractor.extract(df)
        self.labels = torch.tensor(labels.to_numpy(), dtype=torch.long, device=device)

        # A mismatch would otherwise surface much later as an index error or,
        # worse, as silently misaligned features and labels.
        n_rows = len(self.labels)
        for name, tensor in self.features.items():
            if len(tensor) != n_rows:
                raise ValueError(
                    f"Feature {name!r} has {len(tensor)} rows but {n_rows} labels were given"
                )

    def __len__(self):
        return len(self.features["md"])

    def __getitem__(self, idx):
        """Return ``{"features": {name: row}, "label": label}`` for row ``idx``."""
        return {
            "features": {key: val[idx] for key, val in self.features.items()},
            "label": self.labels[idx],
        }


def collate_fn(batch):
    """Stack a list of dataset items into ``(features, labels)`` batch tensors.

    ``features`` maps every feature name of the first item to the stack of that
    feature over all items; ``labels`` is the stack of the items' labels.
    """
    feature_names = batch[0]["features"].keys()

    stacked_features = {}
    for name in feature_names:
        stacked_features[name] = torch.stack([sample["features"][name] for sample in batch])

    stacked_labels = torch.stack([sample["label"] for sample in batch])
    return stacked_features, stacked_labels


In [78]:
NUM_CLASSES = 2


class TweetClassifier(nn.Module):
    """Small MLP over several feature groups.

    Every feature group goes through its own dropout + linear projection to
    ``hidden_dim``; the projections are summed, passed through ReLU, one hidden
    linear layer with ReLU, and a final linear layer producing ``NUM_CLASSES``
    logits.
    """

    feature_sizes: dict[str, int]

    layer1: nn.ModuleDict
    fc2: nn.Linear
    fc3: nn.Linear

    def __init__(
        self,
        feature_sizes: dict[str, int],
        hidden_dim: int = 512,
    ):
        """
        Args:
            feature_sizes: Input width of each feature group; must contain the
                keys listed in ``branch_dropout`` below.
            hidden_dim: Width of the shared hidden representation.
        """
        super().__init__()

        self.feature_sizes = feature_sizes

        # Dropout probability applied to each input group before its projection.
        branch_dropout = (
            ("md", 0.1),
            ("substrings", 0.1),
            ("full_text", 0.1),
            ("user_description", 0.4),
            ("misc_text", 0.1),
        )

        self.layer1 = nn.ModuleDict()
        for branch, rate in branch_dropout:
            self.layer1[branch] = nn.Sequential(
                nn.Dropout(rate),
                nn.Linear(feature_sizes[branch], hidden_dim),
            )

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, NUM_CLASSES)

    @property
    def device(self) -> torch.device:
        """Device of the model's first parameter."""
        first_param = next(self.parameters())
        return first_param.device

    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Return ``logits``, ``probs`` and ``log_probs`` for a batch of features."""
        n_rows = len(features["md"])

        # Sum of the per-group projections, accumulated onto a zero tensor.
        hidden = torch.zeros(n_rows, self.fc2.in_features, device=self.device)
        for branch, projection in self.layer1.items():
            hidden = hidden + projection(features[branch])

        hidden = F.relu(hidden)
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)

        return {
            "logits": logits,
            "probs": F.softmax(logits, dim=-1),
            "log_probs": F.log_softmax(logits, dim=-1),
        }


In [65]:

def train_model(
    model: TweetClassifier,
    train_ds: Dataset,
    val_ds: Dataset,
    epochs: int = 3,
    lr: float = 2e-4,
    weight_decay: float = 0.01,  # TODO: Lower?
    max_grad_norm: float = 1.0,
    freeze_components_after: dict[str, int] | None = None,
    device: torch.device = device,
    batch_size: int = 32,
    optimizer: torch.optim.Optimizer | None = None,
    checkpoints_path: pathlib.Path | str | None = ".",
    return_best: bool = False,
) -> TweetClassifier:
    """Train ``model`` with cross-entropy loss, validating after every epoch.

    Args:
        model: Classifier to train; moved to ``device`` and modified in place.
        train_ds: Training dataset (items as expected by ``collate_fn``).
        val_ds: Validation dataset, evaluated after each epoch.
        epochs: Number of passes over ``train_ds``.
        lr: Learning rate of the default AdamW optimizer.
        weight_decay: Weight decay of the default AdamW optimizer.
        max_grad_norm: Gradient-norm clipping threshold.
        freeze_components_after: Maps a ``model.layer1`` key to the epoch after
            which that component stops being updated.
        device: Device used for training.
        batch_size: Batch size for both training and validation.
        optimizer: Optimizer to use instead of the default AdamW.
        checkpoints_path: Directory receiving ``epoch_XX.pt`` after every epoch
            (created if missing); None disables checkpointing.
        return_best: Reload the checkpoint with the lowest validation loss before
            returning. Has no effect without checkpoints.

    Returns:
        The trained model (the same object that was passed in).
    """
    if freeze_components_after is None:
        freeze_components_after = {}

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

    model.to(device)
    if optimizer is None:
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Make every parameter trainable again; components frozen by an earlier call
    # on the same model would otherwise stay frozen.
    for param in model.parameters():
        param.requires_grad = True

    criterion = torch.nn.CrossEntropyLoss()

    checkpoints_dir: pathlib.Path | None = None
    if checkpoints_path is not None:
        checkpoints_dir = pathlib.Path(checkpoints_path)
        checkpoints_dir.mkdir(parents=True, exist_ok=True)

    best_val_loss = float("inf")
    best_model_file: pathlib.Path | None = None

    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}/{epochs}")
        model.train()
        total_loss = 0.0

        status_bar = tqdm(train_loader, desc="Training")

        # Count batches ourselves: the bar's own counter is only refreshed when
        # the display updates, so it lags behind the loop.
        for step, (features, labels) in enumerate(status_bar, start=1):
            features: dict[str, torch.Tensor]
            labels: torch.Tensor
            features = {k: v.to(device) for k, v in features.items()}
            labels = labels.to(device)

            optimizer.zero_grad(set_to_none=True)

            out = model(features)
            logits = out["logits"]

            loss: torch.Tensor = criterion(logits, labels)
            loss.backward()
            clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            total_loss += loss.item()
            status_bar.set_postfix({"loss": total_loss / step})

        print(f"Train Loss: {total_loss / max(len(train_loader), 1):.4f}")

        val_metrics = evaluate_model(
            model=model,
            val_ds=val_ds,
            device=device,
            batch_size=batch_size,
        )

        print(f"Val Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['acc']:.4f}")

        if checkpoints_dir is not None:
            ckpt = checkpoints_dir / f"epoch_{epoch:02}.pt"
            torch.save(model.state_dict(), ckpt)
            print(f"Checkpoint saved to {ckpt}")

            if val_metrics["loss"] < best_val_loss:
                best_val_loss = val_metrics["loss"]
                best_model_file = ckpt

        for comp, freeze_epoch in freeze_components_after.items():
            if freeze_epoch == epoch:
                print(f"Freezing {comp} for the rest of training")
                for param in model.layer1[comp].parameters():
                    param.requires_grad = False

    if return_best and best_model_file is not None:
        print(f"Best model: {best_model_file}")
        model.load_state_dict(torch.load(best_model_file, map_location=device))

    return model


def evaluate_model(
    model: TweetClassifier,
    val_ds: Dataset,
    device: torch.device = device,
    batch_size: int = 32,
) -> dict[str, float]:
    """Compute mean cross-entropy loss and accuracy of ``model`` on ``val_ds``.

    The model is switched to eval mode and left there.

    Args:
        model: Classifier to evaluate.
        val_ds: Dataset whose items are compatible with ``collate_fn``.
        device: Device the batches are moved to.
        batch_size: Evaluation batch size.

    Returns:
        ``{"loss": <mean per-sample loss>, "acc": <fraction correct>}``.

    Raises:
        ValueError: If ``val_ds`` yields no samples.
    """
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    model.eval()
    criterion = torch.nn.CrossEntropyLoss()

    total_loss = 0.0
    correct = 0
    count = 0

    with torch.no_grad():
        status_bar = tqdm(val_loader, desc="Evaluating")

        for features, labels in status_bar:
            features: dict[str, torch.Tensor]
            labels: torch.Tensor
            features = {k: v.to(device) for k, v in features.items()}
            labels = labels.to(device)

            out = model(features)
            logits: torch.Tensor = out["logits"]

            loss: torch.Tensor = criterion(logits, labels)
            # The criterion returns the batch mean; weight it by the batch size
            # so a smaller last batch does not skew the overall mean.
            total_loss += loss.item() * labels.size(0)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            count += labels.size(0)

            status_bar.set_postfix({"loss": total_loss / count, "acc": correct / count})

    if count == 0:
        raise ValueError("Cannot evaluate on an empty dataset")

    return {
        "loss": total_loss / count,
        "acc": correct / count,
    }


In [66]:
def infer_with_model(
    model: TweetClassifier,
    feature_extractor: FeatureExtractor,
    df: pd.DataFrame,
    out_file: pathlib.Path | str | None = None,
    device: torch.device = device,
    batch_size: int = 32,
) -> pd.Series:
    """Predict a label per tweet, then make predictions consistent per user.

    After the per-tweet argmax, all tweets sharing a ``user_hash`` are relabelled
    with the majority prediction for that user (ties are broken at random using
    the global NumPy RNG).

    Args:
        model: Trained classifier.
        feature_extractor: Extractor with loaded statistics; it is switched to
            infer mode.
        df: Preprocessed tweets with a ``user_hash`` column (and ``challenge_id``
            when ``out_file`` is given). Not modified.
        out_file: Optional CSV path for an ``ID``/``Prediction`` table.
        device: Device used for inference.
        batch_size: Inference batch size.

    Returns:
        The reconciled predictions, indexed like ``df``.
    """
    feature_extractor.infer()
    data_loader = DataLoader(
        TweetDataset(feature_extractor, df, pd.Series(np.zeros(len(df))), device=device),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    model.eval()

    predictions = torch.zeros(len(df), dtype=torch.long)
    cur_idx = 0

    with torch.no_grad():
        for features, _ in tqdm(data_loader, desc="Inferring"):
            features: dict[str, torch.Tensor]
            features = {k: v.to(device) for k, v in features.items()}

            out = model(features)
            logits: torch.Tensor = out["logits"].cpu()

            n_rows = len(features["md"])
            predictions[cur_idx:cur_idx + n_rows] = logits.argmax(dim=-1)
            cur_idx += n_rows

    df = df.copy()
    # Assign positionally: a pd.Series would be aligned on its 0..n-1 index and
    # scramble (or NaN out) the labels of any frame that is not indexed 0..n-1.
    df["pred_label"] = predictions.numpy().astype(int)

    # Reconciliation between same users
    # Note: the pandas implementation OOM-ed, so this stays in pure Python, but it
    # only walks the two columns it needs instead of materializing whole rows.
    user_hashes = df["user_hash"].tolist()
    labels = df["pred_label"].tolist()

    per_user_stats: dict[str, list[int]] = {}
    for user, label in zip(user_hashes, labels):
        per_user_stats.setdefault(user, [0, 0])[label] += 1

    per_user_correct: dict[str, int] = {}
    for user, (n_zero, n_one) in per_user_stats.items():
        if n_zero == 0 or n_one == 0:
            continue  # unanimous: nothing to reconcile

        # Drawn for every user with mixed votes (not only on ties) so that a seeded
        # run consumes the global NumPy RNG at an unchanged rate.
        tie_break = int(np.random.randint(0, 2))
        if n_zero > n_one:
            per_user_correct[user] = 0
        elif n_one > n_zero:
            per_user_correct[user] = 1
        else:
            per_user_correct[user] = tie_break

    del per_user_stats

    df["pred_label"] = [
        per_user_correct.get(user, label)
        for user, label in zip(user_hashes, labels)
    ]

    if out_file is not None:
        output = df[["challenge_id", "pred_label"]].rename(
            columns={"challenge_id": "ID", "pred_label": "Prediction"}
        )
        output.to_csv(out_file, index=False)

    return df["pred_label"]


In [67]:
def train_test_split_group(
    X: pd.DataFrame,
    y: pd.Series,
    group: pd.Series,
    test_size: float = 0.15,
    random_state: int = 42
) -> tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    """Split ``X``/``y`` once with ``GroupShuffleSplit``, grouping rows by ``group``.

    Returns:
        ``(X_train, X_val, y_train, y_val)``, selected positionally from the inputs.
    """
    splitter = GroupShuffleSplit(
        n_splits=1,
        test_size=test_size,
        random_state=random_state,
    )

    fit_rows, holdout_rows = next(splitter.split(X, y, group))

    X_fit, X_holdout = X.iloc[fit_rows], X.iloc[holdout_rows]
    y_fit, y_holdout = y.iloc[fit_rows], y.iloc[holdout_rows]
    return X_fit, X_holdout, y_fit, y_holdout


# Training & inference

In [None]:
# End-to-end run: extract features, train the classifier, write Kaggle predictions.
text_encoder_name = "camembert/camembert-large"
version = "v21"

print(f"Running: {text_encoder_name} {version}")

# Checkpoints and predictions go to ./models/<version>/<encoder short name>/
model_folder = pathlib.Path("./models/") / version / text_encoder_name.split("/")[-1]
model_folder.mkdir(exist_ok=True, parents=True)

# I guess a global cache does make more sense
# (the feature cache is keyed by encoder only, not by version)
text_enc_cache_path = pathlib.Path("./text_enc_cache/") / text_encoder_name.split("/")[-1]
text_enc_cache_path.mkdir(exist_ok=True, parents=True)

feature_extractor = FeatureExtractor(text_encoder_name=text_encoder_name, text_enc_cache_path=text_enc_cache_path, device=device)

# Restore previously saved normalization statistics, if any, before extracting.
f_ext_ckpt = text_enc_cache_path / "feature_extractor.ckpt"
if f_ext_ckpt.exists():
    feature_extractor.load_state_dict(torch.load(f_ext_ckpt, weights_only=False))  # This has a pd.Series in it, so otherwise torch 2.6+ complains

# Note: changing the split invalidates the cache now!
# Rows are grouped by user_hash so that no user ends up on both sides of the split.
X_train_for_real, X_val, y_train_for_real, y_val = train_test_split_group(X_train, y_train, X_train["user_hash"], test_size=0.1, random_state=42)

# Train mode: the extractor may recompute its statistics from the training rows.
feature_extractor.train()
train_ds = TweetDataset(feature_extractor, X_train_for_real, y_train_for_real, device=device)

# Persist the extractor state so later runs can reuse it.
torch.save(feature_extractor.state_dict(), f_ext_ckpt)

# Eval mode: validation features reuse the statistics held by the extractor.
feature_extractor.eval()
val_ds = TweetDataset(feature_extractor, X_val, y_val, device=device)

model = TweetClassifier(
    feature_sizes=feature_extractor.dims(),
    # hidden_dim=512,
    # hidden_dim=384,
    # hidden_dim=768,
    hidden_dim=256,
).to(device)

# Train for 10 epochs, stop updating the user_description branch after epoch 5,
# and return the checkpoint with the lowest validation loss.
model = train_model(
    model,
    train_ds,
    val_ds,
    lr=2e-4,
    epochs=10,
    freeze_components_after={
        "user_description": 5,
    },
    batch_size=64,
    device=device,
    checkpoints_path=model_folder,
    return_best=True,
)
# torch.save(model.state_dict(), model_folder / "best_model.ckpt")

# Predict on the Kaggle test set and write the submission CSV.
infer_with_model(
    model,
    feature_extractor,
    X_kaggle,
    batch_size=64,
    device=device,
    out_file=model_folder / f"predictions-{version}.csv",
)

In [81]:
# Inference only: reload the checkpoint of a given epoch from an earlier training
# run and write its Kaggle predictions.
text_encoder_name = "camembert/camembert-large"
version = "v21"
epoch = 5

print(f"Running pretrained: {text_encoder_name} {version} e{epoch:02}")

model_folder = pathlib.Path("./models/") / version / text_encoder_name.split("/")[-1]
model_folder.mkdir(exist_ok=True, parents=True)

text_enc_cache_path = pathlib.Path("./text_enc_cache/") / text_encoder_name.split("/")[-1]
text_enc_cache_path.mkdir(exist_ok=True, parents=True)

# The extractor state saved by the training run must exist; it is loaded
# unconditionally here.
feature_extractor = FeatureExtractor(text_encoder_name=text_encoder_name, text_enc_cache_path=text_enc_cache_path, device=device)
feature_extractor.load_state_dict(torch.load(text_enc_cache_path / "feature_extractor.ckpt", weights_only=False))

# hidden_dim must match the value used when the checkpoint was trained.
model = TweetClassifier(
    feature_sizes=feature_extractor.dims(),
    hidden_dim=256,
).to(device)
model.load_state_dict(torch.load(model_folder / f"epoch_{epoch:02}.pt"))

good_predictions = infer_with_model(model, feature_extractor, X_kaggle, batch_size=64, device=device, out_file=model_folder / f"predictions-{version}-e{epoch:02}.csv")

Running pretrained: camembert/camembert-large v21 e05


Some weights of CamembertModel were not initialized from the model checkpoint at camembert/camembert-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Inferring: 100%|██████████| 1616/1616 [00:07<00:00, 204.29it/s]
