In [3]:
from __future__ import annotations
import typing
import json
import pathlib
import os
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_

import transformers
import transformers.modeling_outputs
from transformers import AutoTokenizer, AutoModel

from nltk.corpus import stopwords

import sklearn
from sklearn.dummy import DummyClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline

from tqdm import tqdm

import IPython
from IPython.display import display

In [4]:
IS_KAGGLE = "KAGGLE_DOCKER_IMAGE" in os.environ

DATASETS = pathlib.Path(
    "."
    if not IS_KAGGLE
    else "/kaggle/input/influencers-or-observers-predicting-social-roles/Kaggle2025"
)

DATASET_TRAIN = DATASETS / "train.jsonl"
DATASET_KAGGLE = DATASETS / "kaggle_test.jsonl"

CACHE_DIR = pathlib.Path(".")

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
torch.random.manual_seed(42)
np.random.seed(42)

# Data loading

In [7]:
def load_json(path: pathlib.Path, cache: bool = False) -> pd.DataFrame:
    path_pq = (CACHE_DIR / path.name).with_stem(f"{path.stem}_raw").with_suffix(".parquet")
    
    if cache and path_pq.exists():
        return pd.read_parquet(path_pq)
    
    # This leaves things to be desired, since there's no way to specify dtypes
    # and it assumes float instead of int, causing a loss in precision...
    # But I guess it only matters for ids, which we'll probably discard in preprocessing anyway
    result = pd.json_normalize(list(map(json.loads, path.read_bytes().splitlines())))
    
    if cache:
        result.to_parquet(path_pq)
    
    return result


In [16]:
train_data = load_json(DATASET_TRAIN, cache=True)
kaggle_data = load_json(DATASET_KAGGLE, cache=True)

# Preprocessing

In [9]:
set(train_data.columns) - set(kaggle_data.columns)

{'label', 'withheld_in_countries'}

In [42]:

def preprocess(df: pd.DataFrame) -> pd.DataFrame:
    df["is_reply"] = df["in_reply_to_status_id"].notna()
    
    df = df.drop(columns=[
        "in_reply_to_status_id_str",
        # "in_reply_to_status_id",
        "in_reply_to_user_id_str",
        "in_reply_to_user_id",
        "quoted_status_id_str",
        "quoted_status_id",
        "id_str",
        "quoted_status.in_reply_to_status_id_str",
        "quoted_status.in_reply_to_status_id",
        "quoted_status.in_reply_to_user_id_str",
        "quoted_status.in_reply_to_user_id",
        "quoted_status.id_str",
        "quoted_status.id",
        "quoted_status.user.id_str",
        "quoted_status.user.id",
        # "quoted_status_permalink.expanded",
        "quoted_status_permalink.display",
        "quoted_status_permalink.url",
        "quoted_status.quoted_status_id",
        "quoted_status.quoted_status_id_str",
        # "quoted_status.place.id",
        # "place.id",
        "lang",  # Always "fr"
        "retweeted",  # Always False
        "filter_level",  # Always "low"
        "geo",  # Always None
        "place",  # Always None
        "coordinates",  # Always None
        "contributors",  # Always None
        "quote_count",  # Always 0
        "reply_count",  # Always 0
        "retweet_count",  # Always 0
        "favorite_count",  # Always 0
        "favorited",  # Always False
        "quoted_status.geo",  # Always None
        "quoted_status.place",  # Always None
        "quoted_status.coordinates",  # Always None
        "quoted_status.retweeted",  # Always False
        "quoted_status.filter_level",  # Always "low"
        "quoted_status.contributors",  # Always None
        "quoted_status.user.utc_offset",  # Always None
        "quoted_status.user.lang",  # Always None
        "quoted_status.user.time_zone",  # Always None
        "quoted_status.user.follow_request_sent",  # Always None
        "quoted_status.user.following",  # Always None
        "quoted_status.user.notifications",  # Always None
        "user.default_profile_image",  # Always False
        "user.protected",  # Always False
        "user.contributors_enabled",  # Always False
        "user.lang",  # Always None
        "user.notifications",  # Always None
        "user.following",  # Always None
        "user.utc_offset",  # Always None
        "user.time_zone",  # Always None
        "user.follow_request_sent",  # Always None
    ])
    
    # TODO: Augment text with other string features?
    df["full_text"] = df.apply(lambda tweet: extract_full_text(tweet), axis=1)
    
    return df


def extract_full_text(tweet: pd.Series) -> str:
    text: str = tweet["text"]
    
    if not pd.isna(tweet["extended_tweet.full_text"]):
        text = tweet["extended_tweet.full_text"]
    
    return text


In [43]:
X_train = train_data.drop("label", axis=1)
y_train = train_data["label"]

X_kaggle = kaggle_data

X_train = preprocess(X_train)
X_kaggle = preprocess(X_kaggle)

# Data exploration

In [None]:
X_train[[
    "in_reply_to_status_id_str",
    "in_reply_to_status_id",
    "in_reply_to_user_id_str",
    "in_reply_to_user_id",
    "quoted_status_id_str",
    "quoted_status_id",
    "id_str",
    "quoted_status.in_reply_to_status_id_str",
    "quoted_status.in_reply_to_status_id",
    "quoted_status.in_reply_to_user_id_str",
    "quoted_status.in_reply_to_user_id",
    "quoted_status.id_str",
    "quoted_status.id",
    "quoted_status.user.id_str",
    "quoted_status.user.id",
    "quoted_status.quoted_status_id",
    "quoted_status.quoted_status_id_str",
]].map(lambda x: float(x) if isinstance(x, str) else x).corrwith(y_train)

in_reply_to_status_id_str                 -0.050524
in_reply_to_status_id                     -0.050524
in_reply_to_user_id_str                   -0.008360
in_reply_to_user_id                       -0.008360
quoted_status_id_str                      -0.019543
quoted_status_id                          -0.019543
id_str                                    -0.026025
quoted_status.in_reply_to_status_id_str   -0.029068
quoted_status.in_reply_to_status_id       -0.029068
quoted_status.in_reply_to_user_id_str     -0.026719
quoted_status.in_reply_to_user_id         -0.026719
quoted_status.id_str                      -0.019542
quoted_status.id                          -0.019542
quoted_status.user.id_str                  0.018087
quoted_status.user.id                      0.018087
quoted_status.quoted_status_id            -0.022994
quoted_status.quoted_status_id_str        -0.022994
dtype: float64

In [25]:
X_train[[
    "quoted_status.place.id",
    "place.id",
]].map(lambda x: int(x, base=16) if isinstance(x, str) else x).corrwith(y_train)

quoted_status.place.id    0.043506
place.id                 -0.044059
dtype: float64

In [26]:
X_train[[
    "quoted_status_permalink.expanded",
    "quoted_status_permalink.display",
    "quoted_status_permalink.url",
]].map(lambda x: len(x) if isinstance(x, str) else x).corrwith(y_train)

  c /= stddev[:, None]
  c /= stddev[None, :]


quoted_status_permalink.expanded   -0.037841
quoted_status_permalink.display     0.000148
quoted_status_permalink.url              NaN
dtype: float64

In [111]:
X_train.corrwith(y_train, numeric_only=True)

is_quote_status                       -0.018314
truncated                             -0.009665
challenge_id                           0.001228
quoted_status.retweet_count            0.017500
quoted_status.favorite_count           0.017766
quoted_status.quote_count              0.046755
quoted_status.reply_count              0.007355
quoted_status.user.friends_count      -0.019404
quoted_status.user.listed_count        0.001474
quoted_status.user.favourites_count   -0.050869
quoted_status.user.statuses_count     -0.000199
quoted_status.user.followers_count     0.007178
user.listed_count                      0.078584
user.favourites_count                  0.146453
user.is_translator                     0.013591
user.geo_enabled                       0.296986
user.profile_background_tile           0.180543
user.statuses_count                    0.281050
user.profile_use_background_image     -0.129781
user.default_profile                  -0.324203
is_reply                              -0

In [112]:
dt_cols = X_train[:10].apply(lambda col: pd.to_datetime(col, format="%a %b %d %H:%M:%S %z %Y", errors="coerce"))
dt_cols = dt_cols.columns[dt_cols.notna().any()]

pd.Series({
    col: X_train[col].apply(lambda x: time.mktime(time.strptime(x, "%a %b %d %H:%M:%S %z %Y")) if pd.notnull(x) else pd.NA).corr(y_train)
    for col in dt_cols
})

created_at                      -0.026027
quoted_status.created_at        -0.018908
quoted_status.user.created_at    0.021092
user.created_at                 -0.291983
dtype: float64

In [113]:
def safe_len(x: typing.Any) -> int | float:
    try:
        return len(x)
    except TypeError:
        return pd.NA

len_cols = X_train[:10].apply(lambda col: col.map(safe_len))
len_cols = len_cols.columns[len_cols.notna().any()]

with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(pd.Series({
        f"len({col})": X_train[col].apply(safe_len).fillna(0).corr(y_train)
        for col in len_cols
    }))

  c /= stddev[:, None]
  c /= stddev[None, :]
  f"len({col})": X_train[col].apply(safe_len).fillna(0).corr(y_train)


len(created_at)                                                    NaN
len(source)                                                  -0.084287
len(in_reply_to_screen_name)                                 -0.217397
len(text)                                                    -0.011660
len(timestamp_ms)                                                  NaN
len(quoted_status.extended_tweet.entities.urls)              -0.031582
len(quoted_status.extended_tweet.entities.hashtags)          -0.000267
len(quoted_status.extended_tweet.entities.user_mentions)     -0.015762
len(quoted_status.extended_tweet.entities.symbols)            0.004725
len(quoted_status.extended_tweet.full_text)                  -0.031543
len(quoted_status.extended_tweet.display_text_range)         -0.027365
len(quoted_status.created_at)                                -0.018283
len(quoted_status.source)                                    -0.014854
len(quoted_status.text)                                      -0.023169
len(qu

# Models

In [114]:
# TODO: discard quoted_status.lang != "fr"?
# TODO: some tweets are images

In [115]:

NUM_CLASSES = 2

class TweetClassifier(nn.Module):
    # tokenizer: nn.Module
    # encoder: nn.Module
    metadata_dim: int
    md_batchnorm: nn.Module
    fc1: nn.Module
    fc2: nn.Module
    
    def __init__(
        self,
        # pretrained_encoder: str = "distilbert-base-cased", # "camembert-base", "Geotrend/distilbert-base-en-fr-cased", "flaubert/flaubert_base_cased", "flaubert/flaubert_small_cased"
        metadata_dim: int = 16,
        hidden_dim: int = 128,
        # max_length: int = 256,
    ):
        super().__init__()
        
        # self.tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder)
        # self.encoder = AutoModel.from_pretrained(pretrained_encoder)

        # # Don't finetune the encoder... yet?
        # for param in self.encoder.parameters():
        #     param.requires_grad = False
        
        # self.encoder_dim = self.encoder.config.hidden_size
        # self.max_length = max_length

        self.metadata_dim = metadata_dim
        self.md_batchnorm = nn.BatchNorm1d(metadata_dim)
        
        # self.fc1 = nn.Linear(self.encoder_dim + metadata_dim, hidden_dim)
        self.fc1 = nn.Linear(metadata_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, NUM_CLASSES)
    
    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device
    
    # def encode_text(self, texts: list[str]) -> torch.Tensor:
    #     encoded: torch.Tensor = self.tokenizer(
    #         texts,
    #         padding=True,
    #         truncation=True,
    #         max_length=self.max_length,
    #         return_tensors="pt",
    #     ).to(self.device)

    #     outputs: transformers.modeling_outputs.BaseModelOutput = self.encoder(**encoded)
    #     cls_embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
    #     return cls_embeddings  # [batch, encoder_dim]
    
    def forward(
        self,
        texts: list[str] | torch.Tensor,
        metadata: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Returns dict with:
            "logits": tensor [batch_size, num_classes]
            "probs": tensor [batch_size, num_classes]
        """
        device = self.device
        batch_size = len(texts)
        
        # if isinstance(texts, torch.Tensor):
        #     text_vecs = texts
        # else:
        #     text_vecs = self.encode_text(texts)  # [B, encoder_dim]

        metadata = metadata.to(device)
        assert metadata.shape == (batch_size, self.metadata_dim)
        
        metadata = self.md_batchnorm(metadata)
        
        # x = torch.cat([text_vecs, metadata], dim=1)
        x = metadata

        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        probs = F.softmax(logits, dim=-1)

        return {
            "logits": logits,
            "probs": probs,
        }


In [56]:
X_train[X_train["geo.coordinates"].notna()][["geo.type", "geo.coordinates"]]

Unnamed: 0,geo.type,geo.coordinates
2999,Point,"[46.79534912, 7.14200852]"
9352,Point,"[47.93526676, 5.20570719]"
19083,Point,"[45.7656, 4.9819]"
29074,Point,"[10.01801802, 123.41320069]"
29080,Point,"[48.87152512, 2.31456399]"
35951,Point,"[46.80194092, 7.13807141]"
38222,Point,"[45.7656, 4.9819]"
50626,Point,"[45.7656, 4.9819]"
50731,Point,"[45.7656, 4.9819]"
68162,Point,"[48.0041, 0.196681]"


In [None]:
def apply_fill_mean(
    col: pd.Series,
    func: typing.Callable[[typing.Any], typing.Any],
    # Note: the default value for cache is DELIBERATELY mutable!
    cache: dict[tuple[str, int], float] = {},
    **kwargs,
) -> pd.Series:
    key = (col, hash(func))
    if key not in cache:
        cache[key] = X_train[col.name].apply(func, **kwargs).mean()
    return col.apply(func, **kwargs).fillna(cache[key])

def md_bool(col: pd.Series) -> pd.Series:
    return col.apply(lambda x: (1 if x else -1) if pd.notnull(x) else 0)

def md_len(col: pd.Series) -> pd.Series:
    return col.apply(safe_len).fillna(0)

def md_time(col: pd.Series) -> pd.Series:
    # lambda x: time.mktime(time.strptime(x, "%a %b %d %H:%M:%S %z %Y")) if pd.notnull(x) else pd.NA
    return apply_fill_mean(col, pd.to_datetime, errors="coerce", format="%a %b %d %H:%M:%S %z %Y")

def md_num(col: pd.Series) -> pd.Series:
    return apply_fill_mean(col, pd.to_numeric)

def md_place(col: pd.Series) -> pd.Series:
    return col.apply(int, base=16).fillna(0)

METADATA = [
    (md_bool, "is_quote_status"),
    (md_bool, "is_reply"),
    (md_bool, "possibly_sensitive"),
    (md_bool, "quoted_status.user.verified"),
    (md_bool, "user.is_translator"),
    (md_bool, "user.geo_enabled"),
    (md_bool, "user.profile_use_background_image"),
    (md_bool, "user.default_profile"),
    
    (md_len, "full_text"),
    (md_len, "source"),  # TODO: Analyze the contents
    (md_len, "in_reply_to_screen_name"),  # TODO: Analyze the contents
    (md_len, "quoted_status.extended_tweet.entities.urls"),
    (md_len, "quoted_status.extended_tweet.entities.user_mentions"),
    (md_len, "quoted_status.extended_tweet.full_text"),
    (md_len, "quoted_status.entities.urls"),
    (md_len, "quoted_status.user.profile_image_url_https"),
    (md_len, "quoted_status.user.profile_background_image_url"),
    (md_len, "quoted_status.user.profile_background_image_url_https"),
    (md_len, "quoted_status.user.screen_name"),  # TODO: Analyze the contents
    (md_len, "quoted_status.user.name"),  # TODO: Analyze the contents
    (md_len, "entities.hashtags"),  # TODO: Analyze the contents
    (md_len, "entities.user_mentions"),  # TODO: Analyze the contents
    (md_len, "user.profile_image_url_https"),
    (md_len, "user.profile_background_image_url"),
    (md_len, "user.description"),
    (md_len, "user.translator_type"),
    (md_len, "user.url"),
    (md_len, "user.profile_banner_url"),
    (md_len, "user.location"),
    (md_len, "display_text_range"),
    (md_len, "extended_tweet.entities.urls"),
    (md_len, "extended_tweet.entities.hashtags"),
    (md_len, "extended_tweet.entities.user_mentions"),
    (md_len, "quoted_status_permalink.expanded"),
    
    (md_time, "created_at"),
    (md_time, "user.created_at"),
    (md_time, "quoted_status.created_at"),
    (md_time, "quoted_status.user.created_at"),
    
    (md_num, "user.statuses_count"),
    (md_num, "user.listed_count"),
    (md_num, "user.favourites_count"),
    (md_num, "user.profile_background_tile"),
    (md_num, "quoted_status.quote_count"),
    (md_num, "quoted_status.user.followers_count"),
    (md_num, "quoted_status.user.favourites_count"),
    (md_num, "in_reply_to_status_id"),
    
    (md_place, "quoted_status.place.id"),
    (md_place, "place.id"),
]

METADATA_DIM = len(METADATA)

def extract_metadata(df: pd.DataFrame) -> torch.Tensor:
    md: list[pd.Series] = []

    for fn, col_name in METADATA:
        md.append(fn(df[col_name]))

    return torch.from_numpy(np.array(md)).transpose(0, 1).float()


In [117]:
class TweetDataset(Dataset):
    texts: list[str]
    metadata: torch.Tensor
    labels: torch.Tensor
    device: torch.device
    
    def __init__(self, df: pd.DataFrame, labels: pd.Series, device: torch.device = device):
        self.texts = df["full_text"].tolist()
        self.metadata = extract_metadata(df).to(device)
        self.labels = torch.tensor(labels, dtype=torch.long, device=device)
        self.device = device

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return {
            "text": self.texts[idx],
            "metadata": self.metadata[idx],
            "label": self.labels[idx],
        }


def collate_fn(batch):
    texts = [x["text"] for x in batch]
    metadata = torch.stack([x["metadata"] for x in batch])
    labels = torch.stack([x["label"] for x in batch])
    return texts, metadata, labels


In [None]:
def train_model(
    model: TweetClassifier,
    train_ds: Dataset,
    val_ds: Dataset,
    epochs: int = 3,
    lr: float = 2e-4,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    device: torch.device = device,
    batch_size: int = 32,
    optimizer: torch.optim.Optimizer | None = None,
) -> TweetClassifier:
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    
    model.to(device)
    if optimizer is None:
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}/{epochs}")
        model.train()
        total_loss = 0.0

        status_bar = tqdm(train_loader, desc="Training")

        for texts, metadata, labels in status_bar:
            texts: list[str]
            metadata: torch.Tensor
            labels: torch.Tensor
            
            metadata = metadata.to(device)
            labels = labels.to(device)

            optimizer.zero_grad(set_to_none=True)

            out = model(
                texts=texts,
                metadata=metadata,
            )
            logits = out["logits"]
            
            loss: torch.Tensor = criterion(logits, labels)
            loss.backward()
            clip_grad_norm_(model.parameters(), max_grad_norm)  # TODO: ?
            optimizer.step()

            total_loss += loss.item()
            status_bar.set_postfix({"loss": total_loss / (status_bar.n + 1)})

        print(f"Train Loss: {total_loss / len(train_loader):.4f}")

    return model


In [119]:
def evaluate_model(
    model: TweetClassifier,
    val_ds: Dataset,
    device: torch.device = device,
    batch_size: int = 32,
) -> tuple[float, float]:
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )
    
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()

    total_loss = 0.0
    correct = 0
    count = 0

    with torch.no_grad():
        status_bar = tqdm(val_loader, desc="Evaluating")
        
        for texts, metadata, labels in status_bar:
            texts: list[str]
            metadata: torch.Tensor
            labels: torch.Tensor
            metadata = metadata.to(device)
            labels = labels.to(device)

            out = model(
                texts=texts,
                metadata=metadata,
            )
            logits: torch.Tensor = out["logits"]
            
            loss: torch.Tensor = criterion(logits, labels)
            total_loss += loss.item()
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            count += labels.size(0)
            
            status_bar.set_postfix({"loss": total_loss / (status_bar.n + 1), "acc": correct / count})

    return {
        "loss": total_loss / len(val_loader),
        "acc": correct / count,
    }


In [120]:
def infer_with_model(
    model: TweetClassifier,
    df: pd.DataFrame,
    out_file: pathlib.Path | str | None = None,
    device: torch.device = device,
    batch_size: int = 32,
) -> torch.Tensor:
    data_loader = DataLoader(
        TweetDataset(df, torch.zeros(len(df), dtype=torch.long, device=device), device=device),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )
    
    model.eval()
    
    predictions = torch.zeros(len(df), dtype=torch.long)
    cur_idx = 0
    
    with torch.no_grad():
        for texts, metadata, _ in tqdm(data_loader, desc="Inferring"):
            texts: list[str]
            metadata: torch.Tensor
            metadata = metadata.to(device)

            out = model(
                texts=texts,
                metadata=metadata,
            )
            logits: torch.Tensor = out["logits"].cpu()
            
            predictions[cur_idx:cur_idx+len(texts)] = logits.argmax(dim=-1)
            cur_idx += len(texts)
    
    if out_file is not None:
        output = pd.concat([df["challenge_id"], pd.DataFrame(predictions)], axis=1, ignore_index=True)
        output.columns = ["ID", "Prediction"]
        output.to_csv(out_file, index=False)
    
    return predictions


In [121]:
model = TweetClassifier(
    # pretrained_encoder="camembert-base",
    metadata_dim=METADATA_DIM,
    hidden_dim=128,
    # max_length=256
).to(device)

In [122]:
full_train_ds = TweetDataset(X_train, y_train, device=device)

train_ds, val_ds = random_split(full_train_ds, [0.9, 0.1])

  return col.apply(safe_len).fillna(0)
  return tmp.fillna(tmp.mean())


In [123]:
model.md_batchnorm(full_train_ds.metadata)[760, 0]

tensor(-0.7305, device='cuda:0', grad_fn=<SelectBackward0>)

In [124]:
def get_model_path(version: str) -> pathlib.Path:
    return pathlib.Path(
        f"./models/{version}/model-{version}.pt"
        if not IS_KAGGLE
        else f"/kaggle/input/model-{version}-pt/model-{version}.pt"
    )

In [125]:
model_path = get_model_path("v6")

# model.load_state_dict(torch.load(model_path, weights_only=True))

model = train_model(model, train_ds, epochs=5, batch_size=64, device=device)
torch.save(model.state_dict(), "model-v6.pt")


Epoch 1/5


Training: 100%|██████████| 2179/2179 [00:11<00:00, 185.15it/s, loss=0.482]


Train Loss: 0.4807
Epoch 2/5


Training: 100%|██████████| 2179/2179 [00:11<00:00, 190.94it/s, loss=0.444]


Train Loss: 0.4415
Epoch 3/5


Training: 100%|██████████| 2179/2179 [00:12<00:00, 181.01it/s, loss=0.434]


Train Loss: 0.4333
Epoch 4/5


Training: 100%|██████████| 2179/2179 [00:12<00:00, 168.37it/s, loss=0.429]


Train Loss: 0.4283
Epoch 5/5


Training: 100%|██████████| 2179/2179 [00:12<00:00, 172.17it/s, loss=0.427]

Train Loss: 0.4255





In [126]:
# TODO: Pre-encode all texts?
evaluate_model(model, val_ds, batch_size=64, device=device)

Evaluating: 100%|██████████| 243/243 [00:01<00:00, 234.71it/s, loss=0.46, acc=0.799] 


{'loss': 0.4410637799849726, 'acc': 0.7987218384868633}

In [127]:
infer_with_model(model, X_kaggle, batch_size=64, out_file="predictions-v6.csv", device=device)

  return col.apply(safe_len).fillna(0)
  return tmp.fillna(tmp.mean())
  self.labels = torch.tensor(labels, dtype=torch.long, device=device)
Inferring: 100%|██████████| 1616/1616 [00:02<00:00, 715.34it/s]


tensor([1, 1, 0,  ..., 1, 1, 0])