In [1]:
# Reload all imported modules before each cell runs, so edits to src/* are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
# Move to the repository root so that `import src.*` resolves.
# A bare `%cd ..` is not idempotent: every re-run climbs one more directory.
# Only step up when `src/` is not already visible from the current directory
# (assumes the notebook lives one level below the repo root, as the original `%cd ..` implied).
import os
from pathlib import Path

if not Path("src").is_dir():
    os.chdir("..")
os.getcwd()

/home/den/dev/git/ozon-e-cup-2025


In [None]:
import numpy as np
import torch
from torchinfo import summary

import src.preprocessing as prep
import src.torch_modules as torch_modules

In [None]:
# Generate random synthetic data.
# A fixed seed makes every run (including "Restart & Run All") produce the same arrays.
SEED = 42
rng = np.random.default_rng(SEED)

N_SAMPLES = 1000

BINARY__FEATS_DIM = 10
META__FEATS_DIM = 50

TEXT__EMB_DIM = 128
TEXT__FEATS_DIM = 300

IMG__EMB_DIM = 512

# 0/1 features; upper bound of `integers` is exclusive, same as `randint`
binary_feats = rng.integers(0, 2, size=(N_SAMPLES, BINARY__FEATS_DIM))
meta_feats = rng.standard_normal((N_SAMPLES, META__FEATS_DIM))

text_embs = rng.standard_normal((N_SAMPLES, TEXT__EMB_DIM))
text_feats = rng.standard_normal((N_SAMPLES, TEXT__FEATS_DIM))

img_embs = rng.standard_normal((N_SAMPLES, IMG__EMB_DIM))

In [None]:
# Shapes of the raw synthetic inputs
tuple(arr.shape for arr in (binary_feats, meta_feats, text_embs, text_feats, img_embs))

((1000, 10), (1000, 50), (1000, 128), (1000, 300), (1000, 512))

In [None]:
# Total raw feature width across all inputs
sum(arr.shape[1] for arr in (binary_feats, meta_feats, text_embs, text_feats, img_embs))

1000

# 1 - Нормализация, конкатенация

In [None]:
# Inputs shared by both variants; only the image embeddings differ
shared_kwargs = dict(
    binary_feats=binary_feats,
    meta_feats=meta_feats,
    text_feats=text_feats,
    text_embs=text_embs,
    fit_scalers=True,
)

# Preprocessing can be run with image embeddings...
X_all_with_imgs, meta_scaler_1, text_scaler_1 = prep.preprocess_features(img_embs=img_embs, **shared_kwargs)

# ...or without them
X_all_wo_imgs, meta_scaler_2, text_scaler_2 = prep.preprocess_features(**shared_kwargs)

In [None]:
# Number of feature blocks returned with / without images
tuple(map(len, (X_all_with_imgs, X_all_wo_imgs)))

(7, 5)

In [None]:
# Shape of the fully concatenated feature matrix, with / without images
tuple(np.concatenate(blocks, axis=-1).shape for blocks in (X_all_with_imgs, X_all_wo_imgs))

((1000, 1002), (1000, 489))

In [None]:
# Fitted scalers that can also be applied on validation/test data
tuple(
    item
    for scaler in (meta_scaler_1, text_scaler_1)
    for item in (scaler, scaler.mean_.shape, scaler.scale_.shape)
)

(StandardScaler(), (50,), (50,), StandardScaler(), (300,), (300,))

# 2 - Добавим аттеншн

## 2.1 легко

In [None]:
# Preprocess the data.
# `binary_feats` is overwritten at the end of this cell with the presence flags appended.
# Re-running the cell naively would feed that widened array back in and grow it every time.
# The flags are appended AFTER the binary columns, so slicing the first BINARY__FEATS_DIM
# columns recovers the binary features and keeps the cell idempotent.
# NOTE(review): on a re-run this yields the *preprocessed* binary columns; assumes
# preprocess_features leaves binary features unchanged — TODO confirm.
binary_feats_raw = binary_feats[:, :BINARY__FEATS_DIM]

X_all_with_imgs, meta_scaler, text_scaler = prep.preprocess_features(
    binary_feats=binary_feats_raw,
    meta_feats=meta_feats,
    text_feats=text_feats,
    text_embs=text_embs,
    img_embs=img_embs,
    fit_scalers=True,
)
(
    binary_feats_prep,
    meta_feats_prep,
    text_presence_flag,
    text_feats_prep,
    text_embs_prep,
    img_presence_flag,
    img_embs_prep,
) = X_all_with_imgs

# Add presence flags to binary_feats (downstream cells consume this augmented array)
binary_feats = np.concatenate([binary_feats_prep, text_presence_flag, img_presence_flag], axis=1)
assert binary_feats.shape[1] == BINARY__FEATS_DIM + 2, binary_feats.shape

binary_feats.shape, meta_feats_prep.shape, text_feats_prep.shape, text_embs_prep.shape, img_embs_prep.shape

((1000, 12), (1000, 50), (1000, 300), (1000, 128), (1000, 512))

In [None]:
# Multi-modal fusion: project every input block into a shared 64-dim embedding space.
# (`torch` is imported in the imports cell at the top of the notebook.)
modality_feats = [binary_feats, meta_feats_prep, text_feats_prep, text_embs_prep, img_embs_prep]

MMProj = torch_modules.MultiModalProjector(
    [feats.shape[1] for feats in modality_feats],  # one input width per modality
    emb_dim=64,
)

proj_embs = MMProj(*(torch.from_numpy(feats).float() for feats in modality_feats))
proj_embs.shape

torch.Size([1000, 5, 64])

Получили последовательность из 5 векторов размерности 64 (по одному на каждый входной блок признаков), которую теперь можно подать в простой трансформер.

В `MMTransformerEncoder` это уже зашито: внутри него тот же `MultiModalProjector`, за которым идут блоки трансформера (attention + MLP).

In [None]:
# Full model: multi-modal projector followed by attention/MLP blocks
mm_tr_enc = torch_modules.MMTransformerEncoder(
    input_dims=[
        feats.shape[1]
        for feats in (binary_feats, meta_feats_prep, text_feats_prep, text_embs_prep, img_embs_prep)
    ],
    emb_dim=64,
    num_heads=4,
    num_layers=2,
    mlp_hidden_dim=128,
)

In [None]:
# Forward pass over all samples; each numpy block is converted to a float32 tensor in input order
logits = mm_tr_enc(
    *(
        torch.from_numpy(feats).float()
        for feats in (binary_feats, meta_feats_prep, text_feats_prep, text_embs_prep, img_embs_prep)
    )
)
logits.shape

torch.Size([1000])

In [None]:
# Layer-by-layer output shapes and parameter counts for the full model
summary(
    mm_tr_enc,
    input_data=[
        torch.from_numpy(feats).float()
        for feats in (binary_feats, meta_feats_prep, text_feats_prep, text_embs_prep, img_embs_prep)
    ],
)

Layer (type:depth-idx)                        Output Shape              Param #
MMTransformerEncoder                          [1000]                    --
├─MultiModalProjector: 1-1                    [1000, 5, 64]             --
│    └─ModuleList: 2-1                        --                        --
│    │    └─Linear: 3-1                       [1000, 64]                832
│    │    └─Linear: 3-2                       [1000, 64]                3,264
│    │    └─Linear: 3-3                       [1000, 64]                19,264
│    │    └─Linear: 3-4                       [1000, 64]                8,256
│    │    └─Linear: 3-5                       [1000, 64]                32,832
├─ModuleList: 1-2                             --                        --
│    └─ModuleList: 2-2                        --                        --
│    │    └─AttentionBlock: 3-6               [1000, 5, 64]             16,704
│    │    └─MLPBlock: 3-7                     [1000, 5, 64]             24,9

## 2.2 тяжело