In [55]:
from __future__ import annotations
import typing
import json
import pathlib
import os
import time
import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_

import transformers
import transformers.modeling_outputs
import transformers.configuration_utils
from transformers import AutoTokenizer, AutoModel, AutoConfig

from nltk.corpus import stopwords

import sklearn
from sklearn.decomposition import PCA
from sklearn.model_selection import GroupShuffleSplit
from tqdm import tqdm

import IPython
from IPython.display import display

In [41]:
IS_KAGGLE = "KAGGLE_DOCKER_IMAGE" in os.environ

DATASETS = pathlib.Path(
    "."
    if not IS_KAGGLE
    else "/kaggle/input/influencers-or-observers-predicting-social-roles/Kaggle2025"
)

DATASET_TRAIN = DATASETS / "train.jsonl"
DATASET_KAGGLE = DATASETS / "kaggle_test.jsonl"

CACHE_DIR = pathlib.Path(".")
VERSION = "v11-vaughn"

In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [43]:
torch.random.manual_seed(42)
np.random.seed(42)

# Data loading

In [44]:
def load_json(path: pathlib.Path, cache: bool = False) -> pd.DataFrame:
    path_pq = (CACHE_DIR / path.name).with_stem(f"{path.stem}_raw").with_suffix(".parquet")
    
    if cache and path_pq.exists():
        return pd.read_parquet(path_pq)
    
    # This leaves things to be desired, since there's no way to specify dtypes
    # and it assumes float instead of int, causing a loss in precision...
    # But I guess it only matters for ids, which we'll probably discard in preprocessing anyway
    result = pd.json_normalize(list(map(json.loads, path.read_bytes().splitlines())))
    
    if cache:
        result.to_parquet(path_pq)
    
    return result


In [45]:
train_data = load_json(DATASET_TRAIN, cache=True)
kaggle_data = load_json(DATASET_KAGGLE, cache=True)

# Preprocessing

In [46]:
def preprocess(df: pd.DataFrame) -> pd.DataFrame:
    # For technical reasons, any text columns we want to use should have no dots in their names.
    # The simplest way to achieve this is to replace all dots indiscriminately.
    
    df = df.rename(columns=lambda x: x.replace(".", "_"))
    
    df["is_reply"] = df["in_reply_to_status_id"].notna()
    
    df = df.drop(columns=[
        "in_reply_to_status_id_str",
        # "in_reply_to_status_id",
        "in_reply_to_user_id_str",
        "in_reply_to_user_id",
        "quoted_status_id_str",
        "quoted_status_id",
        "id_str",
        "quoted_status_in_reply_to_status_id_str",
        "quoted_status_in_reply_to_status_id",
        "quoted_status_in_reply_to_user_id_str",
        "quoted_status_in_reply_to_user_id",
        "quoted_status_id_str",
        "quoted_status_id",
        "quoted_status_user_id_str",
        "quoted_status_user_id",
        # "quoted_status_permalink_expanded",
        "quoted_status_permalink_display",
        "quoted_status_permalink_url",
        "quoted_status_quoted_status_id",
        "quoted_status_quoted_status_id_str",
        # "quoted_status_place_id",
        # "place_id",
        "lang",  # Always "fr"
        "retweeted",  # Always False
        "filter_level",  # Always "low"
        "geo",  # Always None
        "place",  # Always None
        "coordinates",  # Always None
        "contributors",  # Always None
        "quote_count",  # Always 0
        "reply_count",  # Always 0
        "retweet_count",  # Always 0
        "favorite_count",  # Always 0
        "favorited",  # Always False
        "quoted_status_geo",  # Always None
        "quoted_status_place",  # Always None
        "quoted_status_coordinates",  # Always None
        "quoted_status_retweeted",  # Always False
        "quoted_status_filter_level",  # Always "low"
        "quoted_status_contributors",  # Always None
        "quoted_status_user_utc_offset",  # Always None
        "quoted_status_user_lang",  # Always None
        "quoted_status_user_time_zone",  # Always None
        "quoted_status_user_follow_request_sent",  # Always None
        "quoted_status_user_following",  # Always None
        "quoted_status_user_notifications",  # Always None
        "user_default_profile_image",  # Always False
        "user_protected",  # Always False
        "user_contributors_enabled",  # Always False
        "user_lang",  # Always None
        "user_notifications",  # Always None
        "user_following",  # Always None
        "user_utc_offset",  # Always None
        "user_time_zone",  # Always None
        "user_follow_request_sent",  # Always None
    ])
    
    df["full_text"] = df.apply(lambda tweet: extract_full_text(tweet), axis=1)
    
    source_split = df["source"].str.removeprefix("<a href=\"").str.removesuffix("</a>").str.split("\" rel=\"nofollow\">").map(lambda x: x if len(x) == 2 else pd.NA)
    df["source_url"] = source_split.map(lambda x: x[0], na_action="ignore")
    df["source_name"] = source_split.map(lambda x: x[1], na_action="ignore")
    
    df["misc_text"] = df.apply(
        lambda x: "via: {0}; reply: @{1}; quote: @{2} {3}".format(x["source_name"], x["in_reply_to_screen_name"], x["quoted_status_user_screen_name"], x["quoted_status_user_name"]), axis=1,
    )
    
    return df


def extract_full_text(tweet: pd.Series) -> str:
    text: str = tweet["text"]
    
    if not pd.isna(tweet["extended_tweet_full_text"]):
        text = tweet["extended_tweet_full_text"]
    
    return text


In [47]:
X_train = train_data.drop("label", axis=1)
y_train = train_data["label"]

X_kaggle = kaggle_data

X_train = preprocess(X_train)
X_kaggle = preprocess(X_kaggle)

# Data exploration

# Models

In [48]:
# Made this a class to hold all the caches. It may resemble an nn.Module, but isn't one!
class FeatureExtractor:
    text_encoder_name: str | None
    text_tokenizer: nn.Module | None
    text_encoder: nn.Module | None
    text_config: transformers.configuration_utils.PretrainedConfig | None
    text_enc_cache_path: pathlib.Path | None
    
    # New attribute to hold pre-computed embeddings
    text_encodings: dict[str, dict[str, torch.Tensor]]
    
    def __init__(
        self,
        text_encoder_name: str | None = None,
        text_enc_cache_path: pathlib.Path | None = None,
        device: torch.device = device,
    ):
        # super().__init__() # Removed this line as FeatureExtractor is not an nn.Module
        self.device = device
        self.means = None
        self.stds = None
        self.afm_cache = {}
        self.text_enc_cache_path = text_enc_cache_path
        
        self.text_encoder_name = text_encoder_name
        self.text_tokenizer = None
        self.text_encoder = None
        self.text_config = None
        self.text_encodings = {"train": {}, "infer": {}} # Initialize dict for train/infer cache
        
        # Load Text Encoder/Tokenizer only if text_encoder_name is provided
        if text_encoder_name is not None:
            self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_name)
            if hasattr(self.text_tokenizer, "to"):
                self.text_tokenizer = self.text_tokenizer.to(self.device)
            self.text_encoder = AutoModel.from_pretrained(text_encoder_name).to(self.device)
            self.text_config = self.text_encoder.config
            
        self.train() # Default to training mode
    
    def freeze_encoder(self):
        if self.text_encoder is not None:
            for param in self.text_encoder.parameters():
                param.requires_grad = False
            self.text_encoder.eval()
            print("Text encoder frozen.")

    def unfreeze_encoder(self):
        if self.text_encoder is not None:
            for param in self.text_encoder.parameters():
                param.requires_grad = True
            print("Text encoder unfrozen.")

    def train(self):
        self.training = True
        if self.text_encoder is not None:
            self.text_encoder.train()
    
    def eval(self):
        self.training = False
        if self.text_encoder is not None:
            self.text_encoder.eval()
    
    def state_dict(self):
        return {
            "means": self.means,
            "stds": self.stds,
            "afm_cache": self.afm_cache,
        }
    
    def load_state_dict(self, state_dict):
        self.means = state_dict["means"]
        self.stds = state_dict["stds"]
        self.afm_cache = state_dict["afm_cache"]
    
    def dims(self) -> dict[str, int]:
        return {
            "md": len(self.METADATA_FIELDS),
        } | {
            field: compress or self.embed_size
            for field, compress in self.TEXT_FIELDS
        }
    
    @property
    def embed_size(self) -> int:
        return self.text_config.hidden_size
    
    def extract(self, df: pd.DataFrame, split_name: str, override_cache: bool = False) -> dict[str, torch.Tensor]:
        """
        Extracts features, loading or computing text embeddings from cache.
        split_name should be 'train' or 'infer'
        """
        result: dict[str, torch.Tensor] = {}
        cache_key = split_name
        
        # 1. Metadata extraction (always computed)
        result["md"] = self.extract_raw_metadata(df)
        
        # 2. Text embedding extraction (cached)
        cf = self.text_enc_cache_path / f"{cache_key}.ckpt"
        cf.parent.mkdir(parents=True, exist_ok=True)
        
        if not override_cache and cf.exists():
            print(f"Loading cached encodings for {split_name}...")
            # Load embeddings into the internal dictionary first
            self.text_encodings[split_name] = torch.load(cf)
            
            # Transfer to the result dictionary
            for col_name, value in self.text_encodings[split_name].items():
                result[col_name] = value.to(self.device)
                
            # Perform PCA/padding only on loaded tensors if needed
            for col_name, compress in self.TEXT_FIELDS:
                if col_name in result:
                     if compress is not None and compress != result[col_name].shape[1]:
                         # This part is complex if PCA was applied, best to ensure PCA is part of the initial encoding if cached.
                         # Since the original notebook applied PCA *after* encoding but *before* caching, 
                         # we assume the cached tensor is the final (potentially PCA'd/padded) result.
                         # If you need to re-apply PCA after loading, you must store the original embeddings and PCA components.
                         # For now, we assume the cached size is the intended size (either original or compressed).
                         pass 
        else:
            print(f"Computing and caching embeddings for {split_name}...")
            self.text_encodings[split_name] = {}
            for col_name, compress in self.TEXT_FIELDS:
                emb = self.embed_texts(df[col_name])
                
                # Apply PCA/Padding
                if compress is not None and compress < emb.shape[1]:
                    pca = PCA(n_components=compress)
                    # PCA requires NumPy/CPU, ensure the tensor is on CPU before converting to NumPy
                    emb_np = emb.cpu().detach().numpy()
                    emb_np_compressed = pca.fit_transform(emb_np)
                    emb = torch.tensor(emb_np_compressed, dtype=torch.float32, device=self.device)
                    print(f"Applied PCA to {col_name} reducing size from {emb_np.shape[1]} to {compress}")
                elif compress is not None and compress > emb.shape[1]:
                    print(f"Warning: embedding for {col_name} zero-padded from {emb.shape[1]} to {compress}")
                    emb = torch.nn.functional.pad(emb, (0, compress - emb.shape[1]))
                
                result[col_name] = emb
                self.text_encodings[split_name][col_name] = emb.cpu().detach()
                
            # Save the computed embeddings
            torch.save(self.text_encodings[split_name], cf)
            print(f"Encodings saved to {cf}")
            
        return result
    
    # The _extract method is removed as its logic is now inside extract
    
    def extract_raw_metadata(self, df: pd.DataFrame) -> torch.Tensor:
        # ... (Keep existing implementation of extract_raw_metadata)
        md_cols: list[pd.Series] = []

        for fn, col_name in tqdm(self.METADATA_FIELDS, desc="Extracting metadata"):\
            md_cols.append(fn(self, df[col_name]))
        
        md: pd.DataFrame = pd.concat(md_cols, axis=1)
        
        if self.training:
            self.means = md.mean().fillna(0)
            self.stds = md.std().fillna(1)
        
        assert self.means is not None and self.stds is not None, "You forgot to train/load the feature extractor"

        md = (md - self.means) / self.stds

        return torch.from_numpy(md.to_numpy()).float().to(self.device)

    def embed_texts(
        self,
        texts: pd.Series,
        batch_size: int = 64,
        progress: bool = True
    ) -> torch.Tensor:
        # ... (Keep existing implementation of embed_texts)
        # Ensure encoder is available before calling
        if self.text_encoder is None or self.text_tokenizer is None:
            raise ValueError("Text encoder and tokenizer must be loaded to embed texts.")
            
        tokenizer = self.text_tokenizer
        encoder = self.text_encoder
        encoder.eval() # Always evaluate the encoder when embedding texts

        all_embeddings = []

        with torch.no_grad():
            batch_offsets = range(0, len(texts), batch_size)
            if progress:
                batch_offsets = tqdm(batch_offsets, desc=f"Embedding {texts.name or '<unnamed>'}")
            for i in batch_offsets:
                batch_texts = texts.iloc[i:i + batch_size]
                nonna = batch_texts.notna() & batch_texts.str.len().gt(0)

                tokenized = tokenizer(
                    batch_texts[nonna].tolist(),
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=self.text_config.max_position_embeddings
                ).to(self.device)

                outputs: transformers.modeling_outputs.BaseModelOutput = encoder(**tokenized)
                last_hidden: torch.Tensor = outputs.last_hidden_state
                mask: torch.Tensor = tokenized["attention_mask"].unsqueeze(-1)
                
                masked_hidden = last_hidden * mask
                summed = masked_hidden.sum(dim=1)
                counts = mask.sum(dim=1)
                embeddings = torch.zeros(len(batch_texts), last_hidden.shape[2], device=self.device)
                nonna = nonna.reset_index(drop=True)
                embeddings[nonna[nonna].index] = (summed / counts)

                all_embeddings.append(embeddings)

        return torch.cat(all_embeddings, dim=0)
    
    def apply_fill_mean(
        self,
        col: pd.Series,
        func: typing.Callable[[typing.Any], typing.Any],
    ) -> pd.Series:
        col = col.map(func, na_action="ignore")
        
        key = (col.name, func.__name__)
        if self.training:
            self.afm_cache[key] = col.mean()
        assert key in self.afm_cache, "You forgot to train/load the feature extractor"
        
        return col.fillna(self.afm_cache[key])
    
    def md_bool(self, col: pd.Series) -> pd.Series:
        return col.map(lambda x: (1 if x else -1), na_action="ignore").fillna(0)

    def md_len(self, col: pd.Series) -> pd.Series:
        return col.map(len, na_action="ignore").fillna(0)

    def md_time(self, col: pd.Series) -> pd.Series:
        return self.apply_fill_mean(col, lambda x: time.mktime(time.strptime(x, "%a %b %d %H:%M:%S %z %Y")))

    def md_num(self, col: pd.Series) -> pd.Series:
        return self.apply_fill_mean(col, pd.to_numeric)

    def md_place(self, col: pd.Series) -> pd.Series:
        return col.map(lambda x: int(x, 16), na_action="ignore").fillna(0)
    
    METADATA_FIELDS: list[tuple[typing.Callable[[FeatureExtractor, pd.Series], pd.Series], str]] = [
        (md_bool, "is_quote_status"),
        (md_bool, "is_reply"),
        (md_bool, "possibly_sensitive"),
        (md_bool, "quoted_status_user_verified"),
        (md_bool, "user_is_translator"),
        (md_bool, "user_geo_enabled"),
        (md_bool, "user_profile_use_background_image"),
        (md_bool, "user_default_profile"),
        
        (md_len, "full_text"),
        (md_len, "source_name"),
        (md_len, "in_reply_to_screen_name"),
        (md_len, "quoted_status_extended_tweet_entities_urls"),
        (md_len, "quoted_status_extended_tweet_entities_user_mentions"),
        (md_len, "quoted_status_extended_tweet_full_text"),
        (md_len, "quoted_status_entities_urls"),
        (md_len, "quoted_status_user_profile_image_url_https"),
        (md_len, "quoted_status_user_profile_background_image_url"),
        (md_len, "quoted_status_user_profile_background_image_url_https"),
        (md_len, "quoted_status_user_screen_name"),
        (md_len, "quoted_status_user_name"),
        (md_len, "entities_hashtags"),
        (md_len, "entities_user_mentions"),
        (md_len, "user_profile_image_url_https"),
        (md_len, "user_profile_background_image_url"),
        (md_len, "user_description"),
        (md_len, "user_translator_type"),
        (md_len, "user_url"),
        (md_len, "user_profile_banner_url"),
        (md_len, "user_location"),
        (md_len, "display_text_range"),
        (md_len, "extended_tweet_entities_urls"),
        (md_len, "extended_tweet_entities_hashtags"),
        (md_len, "extended_tweet_entities_user_mentions"),
        (md_len, "quoted_status_permalink_expanded"),
        
        (md_time, "created_at"),
        (md_time, "user_created_at"),
        (md_time, "quoted_status_created_at"),
        (md_time, "quoted_status_user_created_at"),
        
        (md_num, "user_statuses_count"),
        (md_num, "user_listed_count"),
        (md_num, "user_favourites_count"),
        (md_num, "user_profile_background_tile"),
        (md_num, "quoted_status_quote_count"),
        (md_num, "quoted_status_user_followers_count"),
        (md_num, "quoted_status_user_favourites_count"),
        (md_num, "in_reply_to_status_id"),
        
        (md_place, "quoted_status_place_id"),
        (md_place, "place_id"),
    ]

    TEXT_FIELDS: list[tuple[str, int | None]] = [
        ("full_text", None),
        ("user_description", 64),
        ("misc_text", None),
        # ("source_name", None),
        # ("in_reply_to_screen_name", None),
        # ("quoted_status_user_screen_name", None),
        # ("quoted_status_user_name", None),
    ]


In [49]:
class TweetDataset(Dataset):
    features: dict[str, torch.Tensor]
    labels: torch.Tensor
    
    def __init__(
        self,
        features: dict[str, torch.Tensor],
        labels: pd.Series,
        device: torch.device,
    ):
        self.features = features
        self.labels = torch.tensor(labels.values, dtype=torch.long, device=device)

    def __len__(self):
        return len(self.features["md"])

    def __getitem__(self, idx):
        return {
            "features": {key: val[idx] for key, val in self.features.items()},
            "label": self.labels[idx],
        }


def collate_fn(batch):
    features = {
        key: torch.stack([x["features"][key] for x in batch])
        for key in batch[0]["features"].keys()
    }
    labels = torch.stack([x["label"] for x in batch])
    return features, labels

In [50]:
NUM_CLASSES = 2

class TweetClassifier(nn.Module):
    feature_sizes: dict[str, int]
    
    layer1: nn.ModuleDict
    fc2: nn.Linear
    fc3: nn.Linear
    
    def __init__(
        self,
        feature_sizes: dict[str, int],
        hidden_dim: int = 512,
    ):
        super().__init__()

        self.feature_sizes = feature_sizes
        
        self.layer1 = nn.ModuleDict()
        
        def _add(name, dropout: float):
            self.layer1[name] = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(feature_sizes[name], hidden_dim),
            )
        
        _add("md", 0.1)
        _add("full_text", 0.1)
        _add("user_description", 0.75)
        _add("misc_text", 0.3)
        
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, NUM_CLASSES)
    
    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device
    
    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        batch_size = len(features["md"])
        
        x = torch.zeros(batch_size, self.fc2.in_features, device=self.device)
        
        for name, module in self.layer1.items():
            x += module(features[name])

        x = F.relu(x)
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        probs = F.softmax(logits, dim=-1)
        log_probs = F.log_softmax(logits, dim=-1)

        return {
            "logits": logits,
            "probs": probs,
            "log_probs": log_probs,
        }


In [51]:
def train_model(
    model: TweetClassifier,
    train_ds: Dataset,
    val_ds: Dataset,
    epochs: int = 3,
    lr: float = 2e-4,
    weight_decay: float = 0.01,  # TODO: Lower?
    max_grad_norm: float = 1.0,
    device: torch.device = device,
    batch_size: int = 32,
    optimizer: torch.optim.Optimizer | None = None,
    checkpoints_path: pathlib.Path | str | None = ".",
    return_best: bool = False,
) -> TweetClassifier:
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    
    model.to(device)
    if optimizer is None:
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    criterion = torch.nn.CrossEntropyLoss()
    
    best_val_loss = float("inf")
    best_model_file: pathlib.Path | None = None

    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}/{epochs}")
        model.train()
        total_loss = 0.0

        status_bar = tqdm(train_loader, desc="Training")

        for features, labels in status_bar:
            features: dict[str, torch.Tensor]
            labels: torch.Tensor
            features = {k: v.to(device) for k, v in features.items()}
            labels = labels.to(device)

            optimizer.zero_grad(set_to_none=True)

            out = model(features)
            logits = out["logits"]
            
            loss: torch.Tensor = criterion(logits, labels)
            loss.backward()
            clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            total_loss += loss.item()
            status_bar.set_postfix({"loss": total_loss / (status_bar.n + 1)})
        
        print(f"Train Loss: {total_loss / len(train_loader):.4f}")
        
        val_metrics = evaluate_model(
            model=model,
            val_ds=val_ds,
            device=device,
            batch_size=batch_size,
        )

        print(f"Val Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['acc']:.4f}")

        if checkpoints_path is not None:
            ckpt = pathlib.Path(checkpoints_path) / f"epoch_{epoch:02}.pt"
            torch.save(model.state_dict(), ckpt)
            print(f"Checkpoint saved to {ckpt}")
            
            if val_metrics["loss"] < best_val_loss:
                best_val_loss = val_metrics["loss"]
                best_model_file = ckpt

    if return_best and best_model_file is not None:
        print(f"Best model: {best_model_file}")
        model.load_state_dict(torch.load(best_model_file))
    
    return model


def evaluate_model(
    model: TweetClassifier,
    val_ds: Dataset,
    device: torch.device = device,
    batch_size: int = 32,
) -> tuple[float, float]:
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )
    
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()

    total_loss = 0.0
    correct = 0
    count = 0

    with torch.no_grad():
        status_bar = tqdm(val_loader, desc="Evaluating")
        
        for features, labels in status_bar:
            features: dict[str, torch.Tensor]
            labels: torch.Tensor
            features = {k: v.to(device) for k, v in features.items()}
            labels = labels.to(device)

            out = model(features)
            logits: torch.Tensor = out["logits"]
            
            loss: torch.Tensor = criterion(logits, labels)
            total_loss += loss.item()
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            count += labels.size(0)
            
            status_bar.set_postfix({"loss": total_loss / (status_bar.n + 1), "acc": correct / count})

    return {
        "loss": total_loss / len(val_loader),
        "acc": correct / count,
    }


In [52]:
def infer_with_model(
    model: TweetClassifier,
    feature_extractor: FeatureExtractor,
    df: pd.DataFrame,
    out_file: pathlib.Path,
    device: torch.device = device,
    batch_size: int = 32,
) -> pd.Series:
    
    feature_extractor.eval()
    
    # 1. Setup Data Loader with Lazy Extracted Features
    # The features are extracted/loaded from cache here:
    infer_features = feature_extractor.extract(df, 'infer')
    infer_ds = TweetDataset(
        infer_features, 
        pd.Series(torch.zeros(len(df), dtype=torch.long)), # Placeholder labels
        device=device
    )
    
    data_loader = DataLoader(
        infer_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )
    
    model.eval()
    
    predictions = torch.zeros(len(df), dtype=torch.long)
    offset = 0
    
    # 2. Run Model Inference
    with torch.no_grad():
        for features, _ in tqdm(data_loader, desc="Inferring"):
            
            out = model(features)
            logits = out["logits"]
            preds = logits.argmax(dim=-1)
            
            predictions[offset: offset + len(preds)] = preds.cpu()
            offset += len(preds)
            
    # --- USER-LEVEL RECONCILIATION ---
    
    # Copy the input dataframe and attach the single-tweet predictions
    df = df.copy()
    df["pred_label"] = pd.Series(predictions).astype(int)

    # Reconciliation between same users
    same_user_key = ["user_created_at", "user_profile_image_url"]
    
    # Step A: Count predicted labels (0 or 1) for each unique user key
    per_user_stats: dict[tuple[str, str], list[int]] = dict()
    for _, row in df.iterrows():
        # .setdefault returns [count_label_0, count_label_1]
        per_user_stats.setdefault(tuple(row[same_user_key].tolist()), [0, 0])[int(row["pred_label"])] += 1
    
    # Step B: Determine the reconciled label for users with conflicting predictions
    per_user_correct: dict[tuple[str, str], int] = dict()
    for key, stats in per_user_stats.items():
        # The original code only calculates the majority/tie-breaker if both labels were seen (conflict)
        if stats[0] == 0 or stats[1] == 0:
            continue # Skip users with unanimous predictions
        
        # Calculate majority vote (0 or 1), or randomly pick on a tie
        per_user_correct[key] = np.select(
            [stats[0] > stats[1], stats[1] > stats[0]],
            [0, 1],
            default=np.random.randint(0, 2),
        )
    
    del per_user_stats
    
    # Step C: Apply the reconciled prediction back to the DataFrame
    for idx, row in df.iterrows():
        key = tuple(row[same_user_key].tolist())
        if key in per_user_correct:
            # Overwrite the prediction with the reconciled label
            df.at[idx, "pred_label"] = per_user_correct[key]
    
    # 3. Save to Output File
    if out_file is not None:
        output = df[["challenge_id", "pred_label"]]
        output.columns = ["ID", "Prediction"]
        output.to_csv(out_file, index=False)

    return df["pred_label"]

# Test runs

In [None]:
text_encoder_name = "almanach/camembertav2-base"
print(f"\n===== [ {text_encoder_name} ] =====\n")

model_folder = pathlib.Path(f"./models/{VERSION}/") / text_encoder_name.split("/")[-1]
model_folder.mkdir(exist_ok=True, parents=True)

# 1. Initialize FeatureExtractor
feature_extractor = FeatureExtractor(
    text_encoder_name=text_encoder_name, 
    text_enc_cache_path=model_folder / "text_enc_cache", 
    device=device
)

# Load metadata normalization stats (means/stds) if available
f_ext_ckpt = model_folder / "feature_extractor.ckpt"
if f_ext_ckpt.exists():
    # Load state_dict, which includes means, stds, and afm_cache
    feature_extractor.load_state_dict(torch.load(f_ext_ckpt, weights_only=False))

# 2. Freeze the encoder
feature_extractor.freeze_encoder()
feature_extractor.train() # Set to train mode for proper metadata normalization (if not loaded)

# 3. Extract/Load the full training set features (text and metadata)
# 'train' split name ensures the cache is named 'train.ckpt'
full_train_features = feature_extractor.extract(X_train, 'train')

# Save metadata normalization stats (means/stds) after extraction/computation
torch.save(feature_extractor.state_dict(), f_ext_ckpt)

# 4. Create full dataset
# Pass the pre-extracted features and the original labels
full_train_ds = TweetDataset(full_train_features, y_train, device=device)

# 5. Split into train and validation sets
user_descs = pd.Series(X_train['user_description']).fillna('__MISSING__').factorize()[0]

splitter = GroupShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
train_idx, val_idx = next(splitter.split(X_train, y_train, groups=user_descs))

train_ds = torch.utils.data.Subset(full_train_ds, train_idx)
val_ds   = torch.utils.data.Subset(full_train_ds, val_idx)

# Instantiate and train the classifier as before
model = TweetClassifier(
    feature_sizes=feature_extractor.dims(),
    hidden_dim=512,
).to(device)

# Uncomment the following lines to run training/inference
# model = train_model(model, train_ds, val_ds, lr=2e-4, epochs=10, batch_size=64, device=device, checkpoints_path=model_folder, return_best=True)
# torch.save(model.state_dict(), model_folder / "best_model.ckpt")
# torch.cuda.empty_cache()
# Inference requires extracting features for X_kaggle with split_name='infer'
feature_extractor.eval()
infer_features = feature_extractor.extract(X_kaggle, 'infer')
infer_ds = TweetDataset(infer_features, pd.Series(torch.zeros(len(X_kaggle), dtype=torch.long)), device=device)
infer_with_model(model, feature_extractor, X_kaggle, batch_size=64, device=device, out_file=model_folder / f"predictions-{VERSION}.csv")
torch.cuda.empty_cache()


===== [ almanach/camembertav2-base ] =====

Text encoder frozen.


  return col.map(len, na_action="ignore").fillna(0)
Extracting metadata: 100%|██████████| 48/48 [00:05<00:00,  8.77it/s]


Loading cached encodings for train...


  return col.map(len, na_action="ignore").fillna(0)
Extracting metadata: 100%|██████████| 48/48 [00:03<00:00, 13.08it/s]


Loading cached encodings for infer...


Extracting metadata: 100%|██████████| 48/48 [00:03<00:00, 13.90it/s]


Loading cached encodings for infer...


Inferring: 100%|██████████| 1616/1616 [00:02<00:00, 787.08it/s]


In [None]:
# model_folder = pathlib.Path(f"./models/{VERSION}/camembertav2-base/")
# feature_extractor = FeatureExtractor(text_encoder_name="almanach/camembertav2-base", text_enc_cache_path=model_folder / "text_enc_cache", device=device)
# feature_extractor.load_state_dict(torch.load(model_folder / "feature_extractor.ckpt", weights_only=False))
# model = TweetClassifier(
#     feature_sizes=feature_extractor.dims(),
#     hidden_dim=512,
# ).to(device)
# model.load_state_dict(torch.load(model_folder / "epoch_05.pt"))
# good_predictions = infer_with_model(model, feature_extractor, X_kaggle, batch_size=64, device=device, out_file=model_folder / "predictions-v10-e09.csv")

  return col.map(len, na_action="ignore").fillna(0)
Extracting metadata: 100%|██████████| 48/48 [00:03<00:00, 13.17it/s]


Loading cached encodings for infer...


Inferring: 100%|██████████| 1616/1616 [00:02<00:00, 719.25it/s]
