# Libraries

In [1]:
from sklearn.model_selection import train_test_split

import polars as pl

import torch
from torch import nn

from transformers import BertConfig, BertForSequenceClassification

import pytorch_lightning

from torchinfo import summary
from torchmetrics.classification import F1Score

In [2]:
pl.set_random_seed(56)

In [3]:
pytorch_lightning.seed_everything(56, workers=True)

56

# Data

## Target

In [4]:
target = pl.read_csv("./data/train_targets.csv")
target = target.with_columns((pl.col("sex") == "male").cast(pl.Int8))
target.head()

viewer_uid,age,sex,age_class
i64,i64,i8,i64
10087154,30,1,1
10908708,25,0,1
10190464,34,1,2
10939673,25,1,1
10288257,48,1,3


In [5]:
train_target, val_target = train_test_split(target, test_size=1/3, random_state=56, stratify=target["age_class"])

In [6]:
len(train_target), len(val_target)

(120008, 60004)

## Events

In [7]:
events = pl.read_csv("./data/train_events.csv", try_parse_dates=True)
events = events.sort("event_timestamp")
print(events.shape)
events.head()

(1759616, 9)


event_timestamp,region,ua_device_type,ua_client_type,ua_os,ua_client_name,total_watchtime,rutube_video_id,viewer_uid
"datetime[μs, UTC]",str,str,str,str,str,i64,str,i64
2024-05-31 21:00:02 UTC,"""St.-Petersburg…","""smartphone""","""mobile app""","""Android""","""Rutube""",2661,"""video_157174""",10188822
2024-05-31 21:00:03 UTC,"""Altay Kray""","""desktop""","""browser""","""Windows""","""Yandex Browser…",281,"""video_331354""",10259113
2024-05-31 21:00:03 UTC,"""Sverdlovsk""","""smartphone""","""mobile app""",,"""Rutube""",48,"""video_390201""",10105859
2024-05-31 21:00:04 UTC,"""Irkutsk Oblast…","""smartphone""","""mobile app""","""Android""","""Rutube""",1275,"""video_445327""",10056670
2024-05-31 21:00:05 UTC,"""St.-Petersburg…","""smartphone""","""mobile app""","""Android""","""Rutube""",270,"""video_423245""",10181995


In [8]:
pad_id = 0
bos_id = 1
unk_id = 2
num_special_tokens = 3
n_positions = 64
min_item_cnt = 3

In [9]:
events = events.group_by("viewer_uid", maintain_order=True).tail(n_positions)

уберем редкие видео

In [10]:
item_counts = events["rutube_video_id"].value_counts()
print(len(item_counts))
items = sorted(item_counts.filter((pl.col("counts") >= min_item_cnt))["rutube_video_id"])
print(len(items))

124481
45038


In [11]:
item2id = {}
for item in items:
    item2id[item] = len(item2id) + num_special_tokens

In [12]:
events = events.with_columns(pl.col("rutube_video_id").map_dict(item2id, default=unk_id))

## Authors

In [13]:
videos_data = pl.read_csv("./data/video_info_v2.csv", columns=["rutube_video_id", "author_id"])
videos_data = videos_data.filter(pl.col("rutube_video_id").is_in(items))
videos_data.head()

rutube_video_id,author_id
str,i64
"""video_157198""",1043618
"""video_289824""",1009535
"""video_349723""",1048955
"""video_269867""",1002824
"""video_38243""",1010255


In [14]:
author2id = {}
for author in videos_data["author_id"].unique():
    author2id[author] = len(author2id) + num_special_tokens
    
item2author = torch.zeros(len(item2id) + num_special_tokens).int()
for item, author in zip(videos_data["rutube_video_id"], videos_data["author_id"]):
    item2author[item2id[item]] = author2id[author]
    
for i in range(num_special_tokens):
    item2author[i] = i
    
item2author

tensor([   0,    1,    2,  ..., 2651,  522, 2971], dtype=torch.int32)

# Dataset

In [15]:
class SeqDataset(torch.utils.data.Dataset):
    def __init__(self, data, target):
        super().__init__()
        
        self._data = (
            data
            .group_by("viewer_uid", maintain_order=True)
            .agg(pl.col("rutube_video_id"))
            .join(target, on="viewer_uid", how="left")
        )

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        row = self._data.row(index)
        input_ids = torch.tensor([bos_id] + row[1][::-1]).int()
        label = torch.tensor(row[-1]).long()
        return input_ids, label

In [16]:
train_dataset = SeqDataset(
    events.filter(pl.col("viewer_uid").is_in(train_target["viewer_uid"])),
    train_target.select(("viewer_uid", "age_class"))
)

In [17]:
val_dataset = SeqDataset(
    events.filter(pl.col("viewer_uid").is_in(val_target["viewer_uid"])),
    val_target.select(("viewer_uid", "age_class"))
)

In [18]:
train_dataset[0]

(tensor([    1, 18370, 15548, 32513, 21554, 10906, 39768, 16562, 15981, 30918,
         10906, 25024,    87,  7084, 44557, 44557,  9183, 31730,  4069, 22796,
         24118, 33968,  5991, 14651,    87, 35040,  4372,  3489, 37342, 13981,
         22676, 14651, 35040, 33225, 14593, 31578, 13131,     2],
        dtype=torch.int32),
 tensor(2))

In [19]:
val_dataset[0]

(tensor([    1, 20005, 20005,  9366, 20775, 15329,  9186, 17741, 17741, 17741,
         17741,  5869,  5869, 32923, 40564, 40564, 40564, 19128, 19128, 22408,
         42240, 15033, 15033, 15033, 33869, 33869, 28239, 28239, 15100, 16671,
         32256,  2585, 23849, 24794, 24794,  2155, 22979, 22979, 22979, 41911,
         41911, 38964, 38964, 38964, 28574, 15753, 15753, 15753, 44294, 44294,
         22305, 22305, 25327, 30127], dtype=torch.int32),
 tensor(1))

# Dataloader

In [20]:
def collate_fn(batch):
    batch_i, batch_l = [], []
    for i, l in batch:
        batch_i.append(i)
        batch_l.append(l)
    batch_i = nn.utils.rnn.pad_sequence(batch_i, batch_first=True, padding_value=pad_id)
    batch_l = torch.tensor(batch_l)
    return  batch_i, batch_l

In [21]:
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=256,
    num_workers=0,
    drop_last=True,
    shuffle=True,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [22]:
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=256,
    num_workers=0,
    drop_last=False,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [23]:
next(iter(train_dataloader))

[tensor([[    1,     2,     0,  ...,     0,     0,     0],
         [    1, 11004, 16629,  ...,     0,     0,     0],
         [    1, 33292, 34463,  ...,     0,     0,     0],
         ...,
         [    1,  1697,     0,  ...,     0,     0,     0],
         [    1, 38749, 23105,  ...,     0,     0,     0],
         [    1, 15198,  4785,  ...,     0,     0,     0]], dtype=torch.int32),
 tensor([3, 2, 1, 2, 3, 1, 1, 3, 1, 2, 2, 3, 1, 1, 2, 3, 1, 1, 3, 1, 3, 2, 2, 2,
         1, 2, 3, 2, 1, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 3,
         3, 1, 3, 3, 2, 2, 3, 1, 3, 3, 1, 0, 0, 0, 2, 2, 2, 3, 3, 2, 3, 2, 2, 3,
         1, 1, 1, 3, 3, 2, 3, 1, 1, 1, 1, 1, 3, 1, 3, 3, 3, 3, 3, 1, 1, 3, 3, 2,
         2, 1, 2, 1, 2, 1, 2, 3, 1, 1, 2, 2, 3, 1, 3, 0, 3, 1, 1, 1, 0, 3, 2, 2,
         1, 1, 2, 2, 3, 2, 2, 1, 3, 1, 2, 2, 2, 1, 1, 3, 1, 3, 1, 3, 1, 3, 2, 1,
         2, 3, 2, 3, 1, 3, 2, 1, 2, 1, 3, 1, 1, 2, 2, 1, 3, 1, 2, 3, 3, 1, 1, 1,
         2, 1, 3, 2, 3, 1, 3, 2, 3, 3, 2, 2,

In [24]:
next(iter(val_dataloader))

[tensor([[    1, 20005, 20005,  ...,     0,     0,     0],
         [    1, 32065,   421,  ...,     0,     0,     0],
         [    1, 12346, 32187,  ...,     0,     0,     0],
         ...,
         [    1, 43543, 22725,  ...,     0,     0,     0],
         [    1,  8440, 22729,  ...,     0,     0,     0],
         [    1, 20544,  1697,  ...,     0,     0,     0]], dtype=torch.int32),
 tensor([1, 0, 3, 3, 1, 1, 1, 1, 2, 3, 1, 3, 1, 1, 1, 2, 3, 1, 1, 3, 2, 3, 1, 1,
         2, 3, 1, 2, 1, 2, 1, 2, 1, 1, 2, 3, 0, 3, 2, 3, 1, 2, 2, 2, 1, 2, 2, 2,
         2, 2, 3, 2, 3, 3, 3, 2, 2, 3, 1, 2, 2, 2, 1, 0, 1, 2, 2, 1, 2, 2, 1, 2,
         1, 3, 3, 1, 2, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 3, 1, 2, 2, 2, 1, 1, 1, 1,
         3, 1, 2, 2, 2, 1, 1, 1, 2, 1, 3, 1, 1, 1, 3, 1, 1, 1, 2, 1, 1, 3, 1, 1,
         2, 1, 3, 2, 3, 1, 3, 3, 1, 3, 1, 2, 2, 1, 3, 2, 1, 1, 1, 1, 2, 2, 1, 3,
         2, 3, 1, 1, 2, 2, 2, 1, 1, 2, 3, 1, 1, 3, 2, 1, 0, 3, 3, 1, 3, 1, 2, 2,
         2, 2, 3, 1, 2, 1, 1, 3, 1, 1, 3, 2,

# Model

In [25]:
class ClsModel(pytorch_lightning.LightningModule):
    def __init__(
            self,
            train_dataloader,
            val_dataloader,
            **kwargs,
    ):
        super().__init__()

        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader

        padding_idx = kwargs.pop("padding_idx", 0)
        num_special_tokens = kwargs.pop("num_special_tokens", 3)

        n_positions = kwargs.pop("n_positions", 256)
        n_layer = kwargs.pop("n_layer", 1)
        n_head = kwargs.pop("n_head", 8)
        head_size = kwargs.pop("head_size", 64)
        embedding_size = n_head * head_size
        dropout = kwargs.pop("dropout", 0.1)
        
        num_labels = kwargs.pop("num_labels", 1)
        
        self.padding_idx = padding_idx
        
        config = BertConfig(
            vocab_size=len(item2id) + num_special_tokens,
            hidden_size=embedding_size,
            num_hidden_layers=n_layer,
            num_attention_heads=n_head,
            intermediate_size=embedding_size * 4,
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout,
            max_position_embeddings=n_positions,
            type_vocab_size=len(author2id) + num_special_tokens,
            pad_token_id=padding_idx,
            num_labels=num_labels,
            problem_type="single_label_classification",
        )
        self.bert = BertForSequenceClassification(config)
        
        self.register_buffer(
            "video2author", item2author.clone(), persistent=False
        )
        
        self.metric = F1Score("multiclass", num_classes=num_labels, average="weighted")
        
        self.val_preds = []
        self.val_target = []

    def forward(self, x, labels):
        attention_mask = (x != self.padding_idx)
        outputs = self.bert(x, attention_mask=attention_mask, token_type_ids=self.video2author[x], labels=labels, return_dict=True)
        return outputs

    def training_step(self, batch, batch_idx):
        input_ids, labels = batch

        sequence_output = self.forward(input_ids, labels)
        loss = sequence_output.loss
        
        self.log("train_loss", loss, logger=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, labels = batch

        sequence_output = self.forward(input_ids, labels)
        loss = sequence_output.loss
        
        self.val_preds.append(sequence_output.logits)
        self.val_target.append(labels)
        
        return loss

    def on_validation_epoch_end(self):
        f1_weighted = self.metric(
            torch.concatenate(self.val_preds).argmax(dim=-1),
            torch.concatenate(self.val_target)
        ).item()
        print(f"{f1_weighted=}")
        self.log("val_f1_weighted", f1_weighted, logger=True, on_epoch=True)
        self.val_preds.clear()
        self.val_target.clear()

    def configure_optimizers(self):
        weight_decay = 1e-05
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        param_optimizer = list(self.named_parameters())
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optim = torch.optim.AdamW(optimizer_grouped_parameters, lr=1e-4)
        return [optim]

    def train_dataloader(self):
        return self._train_dataloader

    def val_dataloader(self):
        return self._val_dataloader
    
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        input_ids, _ = batch
        sequence_output = self.forward(input_ids, None)
        return sequence_output.logits.detach().cpu()

In [26]:
cls_model = ClsModel(
    train_dataloader,
    val_dataloader,
    padding_idx=pad_id,
    n_positions=n_positions + 1,
    n_layer=4,
    n_head=8,
    head_size=32,
    dropout=0.1,
    num_labels=4,
)

In [27]:
cls_model

ClsModel(
  (bert): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(45041, 256, padding_idx=0)
        (position_embeddings): Embedding(65, 256)
        (token_type_embeddings): Embedding(3324, 256)
        (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-3): 4 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=256, out_features=256, bias=True)
                (key): Linear(in_features=256, out_features=256, bias=True)
                (value): Linear(in_features=256, out_features=256, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=256, out_features=256, bias=T

In [28]:
summary(cls_model)

Layer (type:depth-idx)                                       Param #
ClsModel                                                     --
├─BertForSequenceClassification: 1-1                         --
│    └─BertModel: 2-1                                        --
│    │    └─BertEmbeddings: 3-1                              12,398,592
│    │    └─BertEncoder: 3-2                                 3,159,040
│    │    └─BertPooler: 3-3                                  65,792
│    └─Dropout: 2-2                                          --
│    └─Linear: 2-3                                           1,028
├─MulticlassF1Score: 1-2                                     --
Total params: 15,624,452
Trainable params: 15,624,452
Non-trainable params: 0

# Training

In [29]:
trainer = pytorch_lightning.Trainer(
    accelerator="gpu",
    logger=False,
    callbacks=[],
    max_epochs=2,
    enable_checkpointing=False,
)

trainer.fit(cls_model)

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

0.46068698167800903

# Inference

пример для валидации

In [30]:
import numpy as np

In [31]:
prediction = trainer.predict(cls_model, val_dataloader)
prediction = np.concatenate([p.numpy() for p in prediction])
prediction

Predicting: 0it [00:00, ?it/s]

array([[-2.0205917 ,  0.87053806,  1.178466  , -0.3924513 ],
       [ 1.366481  ,  1.3662376 , -0.7040721 , -1.8407457 ],
       [-2.8360953 , -0.6107546 ,  1.1305466 ,  1.7830209 ],
       ...,
       [-0.9968277 ,  0.78015447,  0.27284256, -0.1264144 ],
       [-1.2421279 ,  0.99831134,  0.3773139 , -0.28822434],
       [-1.4236234 ,  1.0690387 ,  1.052502  , -0.9013518 ]],
      dtype=float32)

In [32]:
prediction = (
    val_dataset
    ._data
    .with_columns([pl.Series(f"age_class_pred_{i}", prediction[:, i]) for i in range(4)])
    .drop(("rutube_video_id", "age_class"))
)
prediction

viewer_uid,age_class_pred_0,age_class_pred_1,age_class_pred_2,age_class_pred_3
i64,f32,f32,f32,f32
10105859,-2.020592,0.870538,1.178466,-0.392451
10056670,1.366481,1.366238,-0.704072,-1.840746
10131484,-2.836095,-0.610755,1.130547,1.783021
10369145,-2.192721,0.907031,1.193253,-0.296297
10326121,-1.769892,1.072153,0.812914,-0.41428
10022320,-0.692371,1.625273,0.239216,-1.349897
10037831,-2.37202,0.36353,1.403388,0.154738
10472805,-0.962026,1.035835,0.457519,-0.704976
10005074,-2.830344,-0.343688,1.314867,1.329135
10245037,-0.536521,1.035475,0.190591,-0.713088
