In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel


class FineTuneClassifier(nn.Module):
    def __init__(self, base_model_path: str, num_labels: int) -> None:
        super(FineTuneClassifier, self).__init__()
        self.base_model = AutoModel.from_pretrained(base_model_path)

        for param in self.base_model.parameters():
            param.requires_grad = False

        self.classifier = nn.Linear(self.base_model.config.hidden_size * 2, num_labels)

    @classmethod
    def from_classifier_head(
        cls, base_model_path: str, path: str, num_labels: int
    ) -> nn.Module:
        model = cls(base_model_path, num_labels)
        model.classifier.load_state_dict(torch.load(path))
        return model

    def forward(
        self, input_ids: torch.tensor, attention_mask: torch.tensor
    ) -> torch.tensor:
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        B, T, C = outputs.logits.shape

        all_tokens_hidden = outputs.logits  # (B, T, C)
        last_token_hidden = outputs.logits[:, -1, :]  # (B, C)
        last_token_hidden = last_token_hidden.unsqueeze(1).expand(B, T, C)

        combined_representation = torch.cat(
            (all_tokens_hidden, last_token_hidden), dim=-1
        )
        logits = self.classifier(combined_representation)
        return logits


class BaselineClassifier(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_layers: int,
        nhead: int,
        max_seq_length: int,
        vocab_size: int,
        pad_token_id: int,
        num_labels: int,
    ) -> None:
        super(BaselineClassifier, self).__init__()
        self.pad_token_id = pad_token_id
        self.token_embedding = nn.Embedding(
            vocab_size, d_model, padding_idx=pad_token_id
        )
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        decoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(decoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model * 2, num_labels)

    def forward(self, token_ids: torch.tensor) -> torch.tensor:
        batch_size, seq_len = token_ids.shape

        token_emb = self.token_embedding(token_ids)
        pos_ids = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        pos_emb = self.pos_embedding(pos_ids)
        embeddings = token_emb + pos_emb

        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=token_ids.device, dtype=torch.bool),
            diagonal=1,
        )

        pad_mask = token_ids.eq(self.pad_token_id)  # shape: (batch_size, seq_len)

        output = self.transformer(
            embeddings, mask=causal_mask, src_key_padding_mask=pad_mask
        )

        B, T, C = output.shape
        all_tokens_hidden = output  # (B, T, C)
        last_token_hidden = output[:, -1, :]  # (B, C)
        last_token_hidden = last_token_hidden.unsqueeze(1).expand(B, T, C)

        combined_representation = torch.cat(
            (all_tokens_hidden, last_token_hidden), dim=-1
        )
        logits = self.classifier(combined_representation)
        return logits


In [None]:
from typing import Dict
BASELINE_MODELS: Dict[str, Dict[str, int]] = {
    "mini": {
        "d_model": 64,
        "num_layers": 4,
        "num_heads": 4,
        "max_len": 16_384,
    },
    "small": {
        "d_model": 510,
        "num_layers": 8,
        "num_heads": 6,
        "max_len": 16_384,
    },
    "medium": {
        "d_model": 1344,
        "num_layers": 24,
        "num_heads": 16,
        "max_len": 16_384,
    },
    "large": {
        "d_model": 1824,
        "num_layers": 36,
        "num_heads": 24,
        "max_len": 16_384,
    },
}

In [None]:
tmp()

In [None]:
def tmp():
    for name, config in BASELINE_MODELS.items():
        d_model = config["d_model"]
        num_layers = config["num_layers"]
        nhead = config["num_heads"]
        max_seq_length = config["max_len"]
        vocab_size = 130_000
        pad_token_id = 0
        num_labels = 2

        model = BaselineClassifier(
            d_model,
            num_layers,
            nhead,
            max_seq_length,
            vocab_size,
            pad_token_id,
            num_labels,
        )
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        emb_params = vocab_size * d_model

        model.cuda()
        used_memory = torch.cuda.memory_allocated(device=torch.device("cuda:0"))
        print(f"Model: {name}, Total Parameters: {total_params / 1e6:.2f}M, % emb parameters: {emb_params/total_params}")
        print(f"Active params: {(total_params - emb_params) / 1e6:.2f}M, emb params: {emb_params / 1e6:.2f}M")
        print(f"Used VRAM: {used_memory / (1024 ** 2):.2f} MB")

In [14]:
import pandas as pd

In [26]:
df = pd.read_csv("../data/stats/data_stats_master.csv")
df.head()

Unnamed: 0,data,model,num_samples,num_sentences,num_words,num_chars,num_tokens
0,nyt-comments,human,4223213,18713462,75699056,1418028952,367081295
1,blogs,human,576731,8328335,11967700,557323671,164358740
2,raid,human,138244,1808789,7756169,215270586,95663743
3,natural-questions,human,231628,544546,4821294,52668992,14758408
4,writingprompts,human,303140,13802625,4407470,721933659,209316368


In [27]:
df[df["model"] == "human"][["data", "num_samples", "num_sentences", "num_words", "num_tokens"]].sort_values("data")

Unnamed: 0,data,num_samples,num_sentences,num_words,num_tokens
1,blogs,576731,8328335,11967700,164358740
5,essays,2638,123010,67709,1910966
3,natural-questions,231628,544546,4821294,14758408
6,nyt-articles,15813,21318,316972,421258
0,nyt-comments,4223213,18713462,75699056,367081295
2,raid,138244,1808789,7756169,95663743
8,reddit,655484,1817797,11192328,32554655
7,tweets,389916,735759,4405208,8375173
4,writingprompts,303140,13802625,4407470,209316368
9,xsum,226394,4298218,5268988,105941910


In [28]:
df.groupby("data").sum().reset_index()[["data", "num_samples", "num_sentences", "num_words", "num_tokens"]].sort_values("data")

Unnamed: 0,data,num_samples,num_sentences,num_words,num_tokens
0,blogs,1182270,21203478,42920004,407024467
1,essays,58036,2765379,1685250,39439565
2,natural-questions,292083,866273,7906504,22504808
3,nyt-articles,347885,2165238,11276035,63994699
4,nyt-comments,8657565,36261431,175669667,746779789
5,raid,864020,10784312,36630232,332672268
6,reddit,3408276,14035768,91421960,340245625
7,tweets,3665193,8117363,44659891,117976085
8,writingprompts,621437,23319770,14339509,399446147
9,xsum,939528,12847522,26403940,360435457


In [30]:
df.groupby("data").sum().reset_index()[["data", "num_samples", "num_sentences", "num_words", "num_tokens"]].sort_values("data").sum()

data             blogsessaysnatural-questionsnyt-articlesnyt-co...
num_samples                                               20036293
num_sentences                                            132366534
num_words                                                452912992
num_tokens                                              2830518910
dtype: object

In [31]:
df.groupby("model").sum().reset_index()[["model", "num_samples", "num_sentences", "num_words", "num_tokens"]]

Unnamed: 0,model,num_samples,num_sentences,num_words,num_tokens
0,Falcon3-3B-Instruct,629186,3101084,13685350,70997241
1,Falcon3-7B-Instruct,629186,3137392,13173150,70783667
2,Llama-3.1-8B-Instruct,628962,3024230,16408956,73421205
3,Llama-3.2-3B-Instruct,640759,3045728,17141415,75775399
4,Meta-Llama-3.1-70B-Instruct-AWQ-INT4,640750,2994366,15242777,67950245
5,Meta-Llama-3.3-70B-Instruct-AWQ-INT4,640766,2861116,17991535,72319091
6,Ministral-8B-Instruct-2410,629186,5622574,14260174,96603355
7,Mistral-Nemo-Instruct-2407,629186,4600976,11064530,89254622
8,Phi-3-medium-128k-instruct,629186,4817462,18533613,104452491
9,Phi-3-mini-128k-instruct,640760,5390971,14991412,92989442


In [32]:
df.groupby("model").sum().reset_index()[["model", "num_samples", "num_sentences", "num_words", "num_tokens"]].sum()

model            Falcon3-3B-InstructFalcon3-7B-InstructLlama-3....
num_samples                                               20036293
num_sentences                                            132366534
num_words                                                452912992
num_tokens                                              2830518910
dtype: object

In [33]:
df["is_human"] = df["model"].apply(lambda x: 1 if x == "human" else 0)

In [36]:
df.groupby("is_human").sum()[["num_samples", "num_sentences", "num_words", "num_tokens"]].reset_index()

Unnamed: 0,is_human,num_samples,num_sentences,num_words,num_tokens
0,0,13273092,82172675,327010098,1830136394
1,1,6763201,50193859,125902894,1000382516
