<a href="https://colab.research.google.com/github/Seydifa/Challenge/blob/main/Torch_Version_de_CMI_idea_test_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# IMPORTANT: some Kaggle data sources are private.
# Run this cell first so the competition download in the next cell can authenticate.
import kagglehub
# Prompts for Kaggle credentials (rendered as a widget in the cell output).
kagglehub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


In [2]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# `data_path` is the local directory of the downloaded competition files;
# later cells read `train.csv` and `train_demographics.csv` from it.
data_path = kagglehub.competition_download('cmi-detect-behavior-with-sensor-data')

print('Data source import complete.')


Downloading from https://www.kaggle.com/api/v1/competitions/data/download-all/cmi-detect-behavior-with-sensor-data...


100%|██████████| 178M/178M [00:00<00:00, 238MB/s]

Extracting files...





Data source import complete.


In [3]:
from scipy.spatial.transform import Rotation as R
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from tqdm.auto import tqdm
import seaborn as sns
import pandas as pd
import polars as pl
import numpy as np
import torch
import cv2
import os

import warnings
warnings.filterwarnings("ignore")

In [4]:
# Load the per-timestep sensor table and the per-subject demographics table
# from the downloaded competition directory as polars frames.
df, demographies_df = (
    pl.read_csv(os.path.join(data_path, csv_name))
    for csv_name in ('train.csv', 'train_demographics.csv')
)

In [5]:
def remove_gravity_from_acc(acc_data, rot_data):
    """Subtract the gravity component from accelerometer readings.

    The world-frame gravity vector ``[0, 0, 9.81]`` is rotated into the sensor
    frame with the inverse of each sample's orientation quaternion and then
    subtracted from the raw acceleration of that sample.

    Args:
        acc_data: ``pandas.DataFrame`` with ``acc_x``/``acc_y``/``acc_z`` columns,
            or an array-like of shape ``(num_samples, 3)``.
        rot_data: ``pandas.DataFrame`` with ``rot_x``/``rot_y``/``rot_z``/``rot_w``
            columns, or an array-like of shape ``(num_samples, 4)`` in
            ``(x, y, z, w)`` order.

    Returns:
        Floating-point ``numpy.ndarray`` of shape ``(num_samples, 3)``. Rows whose
        quaternion is unusable (any non-finite component, or every component
        close to zero) keep the raw acceleration unchanged.
    """
    if isinstance(acc_data, pd.DataFrame):
        acc_values = acc_data[['acc_x', 'acc_y', 'acc_z']].values
    else:
        acc_values = acc_data

    if isinstance(rot_data, pd.DataFrame):
        quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    else:
        quat_values = rot_data

    acc_values = np.asarray(acc_values)
    if not np.issubdtype(acc_values.dtype, np.floating):
        # Integer input would silently truncate the gravity-compensated values.
        acc_values = acc_values.astype(float)
    quat_values = np.asarray(quat_values, dtype=float)

    # Start from the raw readings: this is the fallback for unusable quaternions.
    linear_accel = acc_values.copy()
    if acc_values.shape[0] == 0:
        return linear_accel

    gravity_world = np.array([0, 0, 9.81])

    # A quaternion with even a single NaN/inf component cannot be normalised, and
    # an all-zero quaternion has no direction; both are excluded up front so the
    # batched conversion below never sees them.
    usable = np.isfinite(quat_values).all(axis=1) & ~np.isclose(quat_values, 0).all(axis=1)
    if usable.any():
        rotations = R.from_quat(quat_values[usable])
        gravity_sensor_frame = rotations.apply(gravity_world, inverse=True)
        linear_accel[usable] = acc_values[usable] - gravity_sensor_frame

    return linear_accel

def calculate_angular_velocity_from_quat(rot_data, time_delta=1/200): # Assuming 200Hz sampling rate
    """Estimate angular velocity from consecutive orientation quaternions.

    For every pair of neighbouring samples the relative rotation
    ``inv(q[i]) * q[i + 1]`` is converted to a rotation vector and divided by
    ``time_delta``; this approximates the angular velocity for small rotations.

    Args:
        rot_data: ``pandas.DataFrame`` with ``rot_x``/``rot_y``/``rot_z``/``rot_w``
            columns, or an array-like of shape ``(num_samples, 4)`` in
            ``(x, y, z, w)`` order.
        time_delta: Time between two samples. The default assumes a 200 Hz
            sampling rate (NOTE(review): confirm against the actual sensor rate).

    Returns:
        ``numpy.ndarray`` of shape ``(num_samples, 3)``. Row ``i`` describes the
        step from sample ``i`` to ``i + 1``; the last row, and every row where
        either quaternion of the pair is unusable (any non-finite component or
        all components close to zero), is zero.
    """
    if isinstance(rot_data, pd.DataFrame):
        quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    else:
        quat_values = rot_data
    quat_values = np.asarray(quat_values, dtype=float)

    num_samples = quat_values.shape[0]
    angular_vel = np.zeros((num_samples, 3))
    if num_samples < 2:
        # No consecutive pair exists, so there is nothing to differentiate.
        return angular_vel

    # Reject partially-NaN quaternions as well: they would otherwise turn into
    # NaN velocities instead of the documented zero fallback.
    usable = np.isfinite(quat_values).all(axis=1) & ~np.isclose(quat_values, 0).all(axis=1)
    pair_index = np.flatnonzero(usable[:-1] & usable[1:])
    if pair_index.size:
        rot_t = R.from_quat(quat_values[pair_index])
        rot_t_plus_dt = R.from_quat(quat_values[pair_index + 1])
        # Relative rotation between the two samples, batched over all valid pairs.
        delta_rot = rot_t.inv() * rot_t_plus_dt
        angular_vel[pair_index] = delta_rot.as_rotvec() / time_delta

    return angular_vel


def calculate_angular_distance(rot_data):
    """Compute the rotation angle between consecutive orientation quaternions.

    The angular distance of a pair is the norm of the rotation vector of the
    relative rotation ``inv(q[i]) * q[i + 1]``, i.e. its angle in radians.

    Args:
        rot_data: ``pandas.DataFrame`` with ``rot_x``/``rot_y``/``rot_z``/``rot_w``
            columns, or an array-like of shape ``(num_samples, 4)`` in
            ``(x, y, z, w)`` order.

    Returns:
        ``numpy.ndarray`` of shape ``(num_samples,)``. Entry ``i`` is the angle
        between samples ``i`` and ``i + 1``; the last entry, and every entry
        where either quaternion of the pair is unusable (any non-finite
        component or all components close to zero), is ``0``.
    """
    if isinstance(rot_data, pd.DataFrame):
        quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    else:
        quat_values = rot_data
    quat_values = np.asarray(quat_values, dtype=float)

    num_samples = quat_values.shape[0]
    angular_dist = np.zeros(num_samples)
    if num_samples < 2:
        # No consecutive pair exists.
        return angular_dist

    # Partially-NaN quaternions are rejected too, so the result never holds NaN.
    usable = np.isfinite(quat_values).all(axis=1) & ~np.isclose(quat_values, 0).all(axis=1)
    pair_index = np.flatnonzero(usable[:-1] & usable[1:])
    if pair_index.size:
        r1 = R.from_quat(quat_values[pair_index])
        r2 = R.from_quat(quat_values[pair_index + 1])
        relative_rotation = r1.inv() * r2
        # The norm of a rotation vector is the rotation angle in radians.
        angular_dist[pair_index] = np.linalg.norm(relative_rotation.as_rotvec(), axis=1)

    return angular_dist

def group_ing(df, max_sequence_length=None):
    """Build the per-timestep feature matrix of one sequence.

    The raw IMU (``acc_*``, ``rot_*``) and thermopile (``thm_1``..``thm_5``)
    columns are extended with derived features: acceleration magnitude,
    rotation angle, their first differences, gravity-compensated acceleration
    (plus magnitude and its first difference), angular velocity and angular
    distance. NaNs in the final matrix are replaced by ``0``.

    Args:
        df: Rows of one sequence; a pandas frame or any object exposing
            ``to_pandas()``.
        max_sequence_length: When truthy, the result is zero-padded at the end
            or truncated to exactly this many rows.

    Returns:
        ``numpy.ndarray`` of shape ``(rows, 25)``.
    """
    raw_columns = (
        ['acc_x', 'acc_y', 'acc_z', 'rot_w', 'rot_x', 'rot_y', 'rot_z']
        + [f'thm_{i}' for i in range(1, 6)]
    )
    linear_columns = ['linear_acc_x', 'linear_acc_y', 'linear_acc_z']
    angular_columns = ['angular_vel_x', 'angular_vel_y', 'angular_vel_z']
    # Column order of the returned matrix.
    feature_order = (
        raw_columns
        + ['acc_mag', 'rot_angle', 'acc_mag_jerk', 'rot_angle_vel']
        + linear_columns
        + ['linear_acc_mag', 'linear_acc_mag_jerk']
        + angular_columns
        + ['angular_distance']
    )

    frame = df.to_pandas() if hasattr(df, 'to_pandas') else df
    frame = frame[raw_columns].copy()

    def _magnitude(prefix):
        # Euclidean norm of the three `<prefix>_x/y/z` columns, row by row.
        return np.sqrt(frame[f'{prefix}_x']**2 + frame[f'{prefix}_y']**2 + frame[f'{prefix}_z']**2)

    frame['acc_mag'] = _magnitude('acc')
    frame['rot_angle'] = 2 * np.arccos(frame['rot_w'].clip(-1, 1))
    frame['acc_mag_jerk'] = frame['acc_mag'].diff().fillna(0)
    frame['rot_angle_vel'] = frame['rot_angle'].diff().fillna(0)

    acc_xyz = frame[['acc_x', 'acc_y', 'acc_z']]
    quat_xyzw = frame[['rot_x', 'rot_y', 'rot_z', 'rot_w']]

    gravity_free = remove_gravity_from_acc(acc_xyz, quat_xyzw)
    for position, column in enumerate(linear_columns):
        frame[column] = gravity_free[:, position]
    frame['linear_acc_mag'] = _magnitude('linear_acc')
    frame['linear_acc_mag_jerk'] = frame['linear_acc_mag'].diff().fillna(0)

    velocity = calculate_angular_velocity_from_quat(quat_xyzw)
    for position, column in enumerate(angular_columns):
        frame[column] = velocity[:, position]

    frame['angular_distance'] = calculate_angular_distance(quat_xyzw)

    features = np.nan_to_num(frame[feature_order].values, nan=0.0)

    if not max_sequence_length:
        return features

    missing_rows = max_sequence_length - features.shape[0]
    if missing_rows > 0:
        # Too short: append all-zero rows at the end.
        return np.pad(features, [(0, missing_rows), (0, 0)], mode='constant', constant_values=0)
    # Long enough: keep only the first `max_sequence_length` rows.
    return features[:max_sequence_length]

class CMIDataset(Dataset):
    """Sequence-level dataset combining ToF, IMU/thermopile and demographic inputs.

    One item corresponds to one distinct value of ``group_id`` in ``df``. An
    item is a dict of float tensors with the keys ``tof_sequence``,
    ``cnt_sequence``, ``static``, ``padding_mask`` and ``labels`` (one-hot, or
    softened by mixup).

    Args:
        df: polars or pandas frame with the per-timestep sensor rows.
        demographies_df: Frame with one row per ``subject`` holding the
            demographic columns.
        max_sequence_length: Pad/truncate every sequence to this many steps.
        do_resize: Resize the ToF block to ``max_sequence_length`` with
            ``cv2.resize`` instead of padding/truncating it.
        group_id: Column identifying a sequence.
        label_col: Column holding the class label of a sequence.
        label2idx: Optional label -> class index mapping; built from ``df``
            (in order of first appearance) when omitted.
        jitter_prob: Probability of adding Gaussian noise to the sensor streams.
        channel_shuffle_prob: Probability of permuting the last axis of the
            sensor streams.
        shift_prob: Probability of shifting the sensor streams in time.
        max_shift: Largest absolute time shift, in steps.
        alpha: Beta-distribution parameter for mixup; ``0`` disables mixup.
        cache: Pre-compute every item once in ``__init__``.
    """

    # Only the sensor streams receive jitter / channel shuffle / time shift.
    # Labels must never go through them: shuffling a one-hot vector would turn
    # it into a random class.
    _AUGMENTED_KEYS = ('tof_sequence', 'cnt_sequence')
    # Keys that mixup leaves untouched.
    _MIXUP_SKIPPED_KEYS = ('static', 'padding_mask')

    def __init__(
        self,
        df,
        demographies_df,
        max_sequence_length=None,
        do_resize=False,
        group_id='sequence_id',
        label_col='gesture',
        label2idx=None,
        jitter_prob=0,
        channel_shuffle_prob=0,
        shift_prob=0,
        max_shift=1,
        alpha=0.0,
        cache=False
    ):
        self.df = df.clone() if isinstance(df, pl.DataFrame) else df.copy()
        self.group_id = group_id
        self.demographies_df = demographies_df
        self.max_sequence_length = max_sequence_length
        self.do_resize = do_resize
        self._ids = self._ordered_unique(self.df[self.group_id])
        self.label_col = label_col
        self.label2idx = label2idx
        self.alpha = alpha
        self.cache = cache

        if self.label2idx is None:
            labels_in_order = self._ordered_unique(self.df[self.label_col])
            self.label2idx = {label: idx for idx, label in enumerate(labels_in_order)}

        self._demo_static_cols = ['age', 'height_cm', 'shoulder_to_wrist_cm', 'elbow_to_wrist_cm', 'adult_child', 'sex', 'handedness']
        self._df_static_cols = ['sequence_counter']

        self.probs = [jitter_prob, channel_shuffle_prob, shift_prob]
        self.transforms = [self._jitter, self._channel_shuffle, lambda x: self._shift_time(x, max_shift)]

        if cache:
            self._tof_data    = []
            self._cnt_data    = []
            self._static_vecs = []
            self._labels      = []
            self._masks       = []
            for idx in tqdm(range(len(self)), desc='Caching data'):
                inputs = self._pipeline_(idx)
                self._tof_data.append(inputs['tof_sequence'])
                self._cnt_data.append(inputs['cnt_sequence'])
                self._static_vecs.append(inputs['static'])
                self._labels.append(inputs['labels'])
                self._masks.append(inputs['padding_mask'])

    @staticmethod
    def _ordered_unique(column):
        """Return the distinct values of ``column`` in order of first appearance.

        polars' ``unique()`` gives no ordering guarantee unless
        ``maintain_order=True`` is passed; without it the item order and the
        label mapping could change between runs.
        """
        if isinstance(column, pl.Series):
            return column.unique(maintain_order=True)
        return column.unique()

    def _prepare_tof_data(self, df_grouped, max_sequence_length=None, do_resize=False):
        """Stack the five ToF sensors of one sequence into a ``(T, 64, 5)`` uint8 array.

        Raw values are shifted by ``+1``. Missing (NaN) readings are mapped to
        ``0`` after the shift, the same value a raw ``-1`` ends up as.
        """
        tof_sequences = []
        for sensor_id in range(1, 6):
            tof_columns = [f"tof_{sensor_id}_v{i}" for i in range(64)]
            if hasattr(df_grouped, 'to_numpy'):
                sequence = df_grouped[tof_columns].to_numpy()
            else:
                sequence = df_grouped[tof_columns].values
            tof_sequences.append(sequence)

        tof_sequences = np.stack(tof_sequences, axis=-1)
        # Casting NaN to uint8 is undefined, so replace NaN first and clip to the
        # uint8 range before converting.
        tof_sequences = np.nan_to_num(tof_sequences.astype(np.float64), nan=-1.0)
        tof_sequences = np.clip(tof_sequences + 1, 0, 255).astype(np.uint8)
        if max_sequence_length:
            if do_resize:
                tof_sequences = cv2.resize(tof_sequences, (tof_sequences.shape[1], max_sequence_length), interpolation=cv2.INTER_CUBIC)
            else:
                length = tof_sequences.shape[0]
                if length < max_sequence_length:
                    pad_width = [(0, max_sequence_length - tof_sequences.shape[0])] + [(0, 0)] * (tof_sequences.ndim - 1)
                    tof_sequences = np.pad(tof_sequences, pad_width, mode='constant', constant_values=0)
                else:
                    tof_sequences = tof_sequences[:max_sequence_length]
        return tof_sequences

    def __len__(self):
        """Number of distinct sequences."""
        return len(self._ids)

    def _compute_attention_mask(self, sequence):
        """Return a boolean vector that is True for rows with any non-zero feature."""
        return np.any(np.not_equal(sequence, 0), axis=-1)

    def _pipeline_(self, idx):
        """Compute the un-augmented numpy sample for sequence number ``idx``."""
        if isinstance(self.df, pl.DataFrame):
            group_df = self.df.filter(pl.col(self.group_id) == self._ids[idx])
            demo_df = self.demographies_df.filter(pl.col('subject') == group_df['subject'].first())
            label = group_df[self.label_col].first()
        else:
            group_df = self.df[self.df[self.group_id] == self._ids[idx]]
            demo_df = self.demographies_df[self.demographies_df['subject'] == group_df['subject'].iloc[0]]
            # pandas has no argument-free `Series.first()`; take the first row explicitly.
            label = group_df[self.label_col].iloc[0]
        tof_sequence = self._prepare_tof_data(group_df, max_sequence_length=self.max_sequence_length, do_resize=self.do_resize)
        cnt_sequence = group_ing(group_df, max_sequence_length=self.max_sequence_length)
        demo_static = demo_df[self._demo_static_cols].to_numpy() if hasattr(demo_df, 'to_numpy') else demo_df[self._demo_static_cols].values
        df_static = group_df[self._df_static_cols].to_numpy() if hasattr(group_df, 'to_numpy') else group_df[self._df_static_cols].values
        # The largest `sequence_counter` of the group is the only per-sequence static feature.
        df_static = np.asarray(df_static.max(), dtype=np.float32).reshape((len(self._df_static_cols), ))
        static = np.concatenate([demo_static.flatten(), df_static.flatten()], axis=0)
        padding_mask = self._compute_attention_mask(cnt_sequence)
        # (T, 64, 5) -> (5, T, 64): sensors become the channel axis.
        tof_sequence = np.transpose(tof_sequence, (2, 0, 1))
        inputs = {'tof_sequence': tof_sequence, 'cnt_sequence': cnt_sequence, 'static': static, 'padding_mask': padding_mask}
        inputs['labels'] = self._one_hot(self.label2idx[label], len(self.label2idx))
        return inputs

    def _get_sample(self, idx):
        """Return the un-augmented sample ``idx``, from the cache when it is enabled.

        A fresh dict is returned on every call, so callers may rebind its
        entries without touching the cached arrays.
        """
        if self.cache:
            return {
                'tof_sequence': self._tof_data[idx],
                'cnt_sequence': self._cnt_data[idx],
                'static': self._static_vecs[idx],
                'padding_mask': self._masks[idx],
                'labels': self._labels[idx],
            }
        return self._pipeline_(idx)

    def __getitem__(self, idx):
        """Return the (possibly augmented) sample ``idx`` as a dict of float tensors."""
        inputs = self._augment(self._get_sample(idx))
        inputs = {k: torch.from_numpy(v).float() for k, v in inputs.items()}
        return inputs

    def _augment(self, inputs):
        """Apply the stochastic transforms to the sensor streams, then mixup.

        NOTE(review): each stream is transformed by its own call, so a time
        shift draws a separate offset for ``tof_sequence`` and ``cnt_sequence``
        and ``padding_mask`` is not shifted with them.
        """
        if any(self.probs):
            for prob, transform in zip(self.probs, self.transforms):
                if np.random.rand() < prob:
                    for k in self._AUGMENTED_KEYS:
                        if k in inputs:
                            inputs[k] = transform(inputs[k])
        inputs = self._mixup(inputs, self.alpha)
        return inputs

    def _mixup(self, x, alpha=0.0):
        """Blend ``x`` with a randomly chosen sample using a Beta(alpha, alpha) weight.

        Sensor streams and labels are mixed; ``static`` and ``padding_mask``
        are kept from ``x``.
        """
        if alpha > 0:
            random_idx = np.random.randint(0, len(self))
            # Reuse the cache instead of recomputing the whole feature pipeline
            # for the partner sample on every item.
            y = self._get_sample(random_idx)
            lam = np.random.beta(alpha, alpha)
            for k in x.keys():
                if k in self._MIXUP_SKIPPED_KEYS:
                    continue
                x[k] = lam * x[k] + (1 - lam) * y[k]
        return x

    def _one_hot(self, x, num_classes):
        """Return the one-hot row for class index ``x``."""
        return np.eye(num_classes)[x]

    def _jitter(self, x, sigma=0.01):
        """Add zero-mean Gaussian noise with standard deviation ``sigma``."""
        return x + sigma * np.random.randn(*x.shape)

    def _channel_shuffle(self, x):
        """Randomly permute the last axis of ``x``.

        NOTE(review): for ``cnt_sequence`` the last axis holds the feature
        columns and for ``tof_sequence`` the 64 ToF values, so this reorders
        features/pixels - confirm that is intended.
        """
        shuffled_indices = np.random.permutation(x.shape[-1])
        return x[..., shuffled_indices]

    def _shift_time(self, x, max_shift=1, axis=None):
        """Shift ``x`` along its time axis by a random offset, filling with zeros.

        Args:
            x: Array to shift.
            max_shift: The offset is drawn uniformly from
                ``[-max_shift, max_shift]``; positive offsets move content
                towards index 0.
            axis: Time axis. When omitted, axis 0 is used for arrays with up to
                two dimensions (``cnt_sequence`` is ``(T, F)``) and axis 1
                otherwise (``tof_sequence`` is ``(C, T, 64)``).

        Returns:
            Array of the same shape as ``x``; ``x`` itself when the drawn offset
            is zero.
        """
        shift = np.random.randint(-max_shift, max_shift + 1)
        if shift == 0:
            # Nothing to move; slicing with a zero offset would empty the axis.
            return x
        if axis is None:
            axis = 0 if x.ndim <= 2 else 1

        source = np.moveaxis(x, axis, 0)
        steps = source.shape[0]
        # Offsets larger than the sequence simply blank the whole array.
        offset = min(abs(shift), steps)
        shifted = np.zeros(source.shape, dtype=np.result_type(x.dtype, np.float64))
        if shift > 0:
            shifted[:steps - offset] = source[offset:]
        else:
            shifted[offset:] = source[:steps - offset]
        return np.ascontiguousarray(np.moveaxis(shifted, 0, axis))

In [6]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# polars' `unique()` gives no ordering guarantee, so the ids are sorted first;
# otherwise `random_state` alone would not make the split reproducible.
sequence_id = sorted(df['sequence_id'].unique())
# 80 % train, then the remaining 20 % is halved into validation and test.
train_ids, dev_ids = train_test_split(sequence_id, test_size=0.2, random_state=42)
val_ids, test_ids = train_test_split(dev_ids, test_size=0.5, random_state=42)

# NOTE(review): the split is made per sequence, so sequences of one subject can
# land in several partitions - consider a subject-grouped split if
# subject-independent scores are needed.
if isinstance(df, pl.DataFrame):
    df_train = df.filter(pl.col('sequence_id').is_in(train_ids))
    df_valid = df.filter(pl.col('sequence_id').is_in(val_ids))
    df_test = df.filter(pl.col('sequence_id').is_in(test_ids))
else:
    df_train = df[df['sequence_id'].isin(train_ids)]
    df_valid = df[df['sequence_id'].isin(val_ids)]
    df_test = df[df['sequence_id'].isin(test_ids)]

# Only the training set is augmented (jitter, channel shuffle, mixup).
train_dataset = CMIDataset(
    df_train,
    demographies_df,
    group_id='sequence_id',
    label_col='gesture',
    max_sequence_length=256,
    do_resize=False,
    jitter_prob=0.25,
    channel_shuffle_prob=0.15,
    shift_prob=0.0,
    max_shift=10,
    alpha=0.2,
    cache=True
)

# Validation and test reuse the training label mapping so class indices agree.
val_dataset = CMIDataset(
    df_valid,
    demographies_df,
    group_id='sequence_id',
    label_col='gesture',
    max_sequence_length=256,
    do_resize=False,
    label2idx=train_dataset.label2idx,
    alpha=0.0,
    cache=True
)
test_dataset = CMIDataset(
    df_test,
    demographies_df,
    group_id='sequence_id',
    label_col='gesture',
    max_sequence_length=256,
    do_resize=False,
    label2idx=train_dataset.label2idx,
    cache=True
)

Caching data:   0%|          | 0/6520 [00:00<?, ?it/s]

Caching data:   0%|          | 0/815 [00:00<?, ?it/s]

Caching data:   0%|          | 0/816 [00:00<?, ?it/s]

In [7]:
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.t5.modeling_t5 import T5Attention, T5Config, T5LayerSelfAttention
from torchvision.transforms.functional import resize as torch_resize
from transformers import PretrainedConfig, PreTrainedModel
import torch.nn.functional as F
from torch import nn
import timm

In [8]:
class CMIConfig(PretrainedConfig):
    """Hyper-parameters of the CMI gesture classifier.

    Args:
        ctn_dim: Number of features per step of the continuous sequence.
        static_dim: Length of the static feature vector.
        tof_dim: Number of ToF input channels given to the vision backbone.
        tof_length: Size of the last ToF axis.
        sequence_length: Number of time steps per sequence.
        do_resize: Whether ToF sequences were resized instead of padded.
        vision_backbone: Name of the ``timm`` model used for the ToF input.
        num_classes: Number of output classes.
        model_dim: Width of the token embeddings.
        hidden_dim: Width of the feed-forward layers.
        num_heads: Number of attention heads.
        num_layers: Number of transformer encoder layers.
        dropout: Dropout probability.
        pretrained_weights: Load pretrained weights for the vision backbone.
        freeze_backbone: Stored on the config as-is.
        epsilon: Lower bound of the denominator in masked pooling.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    def __init__(
        self,
        ctn_dim=25,
        static_dim=8,
        tof_dim=5,
        tof_length=64,
        sequence_length=16,
        do_resize=False,
        vision_backbone='resnet50d',
        num_classes=18,
        model_dim=512,
        hidden_dim=512,
        num_heads=8,
        num_layers=6,
        dropout=0.1,
        pretrained_weights=False,
        freeze_backbone=False,
        epsilon=1e-4,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Input geometry.
        self.ctn_dim, self.static_dim = ctn_dim, static_dim
        self.tof_dim, self.tof_length = tof_dim, tof_length
        self.sequence_length, self.do_resize = sequence_length, do_resize

        # Backbone and classification head.
        self.vision_backbone, self.num_classes = vision_backbone, num_classes

        # Transformer sizing.
        self.model_dim, self.hidden_dim = model_dim, hidden_dim
        self.num_heads, self.num_layers = num_heads, num_layers
        self.dropout = dropout

        # Backbone initialisation and numerical stability.
        self.pretrained_weights, self.freeze_backbone = pretrained_weights, freeze_backbone
        self.epsilon = epsilon

In [9]:
# Pull one training item to inspect what the dataset yields.
sample = train_dataset[1]

In [10]:
# Print the shape of every tensor in the sample dict.
for k, v in sample.items():
    print(k, v.shape)

tof_sequence torch.Size([5, 256, 64])
cnt_sequence torch.Size([256, 25])
static torch.Size([8])
padding_mask torch.Size([256])
labels torch.Size([18])


In [11]:
# Derive every input dimension from the inspected sample so the model always
# matches what the dataset produces; the remaining values are hyper-parameters.
config = CMIConfig(
      static_dim=sample['static'].shape[0],
      ctn_dim=sample['cnt_sequence'].shape[1],
      tof_dim=sample['tof_sequence'].shape[0],
      tof_length=sample['tof_sequence'].shape[2],
      sequence_length=sample['cnt_sequence'].shape[0],
      num_classes=len(train_dataset.label2idx),
      model_dim=512,
      hidden_dim=1024,
      num_heads=8,
      num_layers=6,
      dropout=0.2,
      pretrained_weights=True,
      freeze_backbone=False,
      vision_backbone='resnet50d',
      epsilon=1e-4,
)
# Bare expression: display the resulting configuration.
config

CMIConfig {
  "ctn_dim": 25,
  "do_resize": false,
  "dropout": 0.2,
  "epsilon": 0.0001,
  "freeze_backbone": false,
  "hidden_dim": 1024,
  "model_dim": 512,
  "num_classes": 18,
  "num_heads": 8,
  "num_layers": 6,
  "pretrained_weights": true,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "sequence_length": 256,
  "static_dim": 8,
  "tof_dim": 5,
  "tof_length": 64,
  "transformers_version": "4.53.1",
  "vision_backbone": "resnet50d"
}

In [12]:
class TOFEncoder(nn.Module):
    """Encode the ToF block of a sequence into one ``model_dim`` vector.

    A ``timm`` backbone (without classifier or global pooling) produces a
    feature map that is averaged over its two spatial axes - restricted to the
    non-padded area when a padding mask is supplied and resizing is off - and
    then projected by a small adapter.
    """

    def __init__(self, config: CMIConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.vision_tower = timm.create_model(
            config.vision_backbone,
            pretrained=config.pretrained_weights,
            num_classes=0,
            global_pool='',
            in_chans=config.tof_dim,
        )
        backbone_channels = self._get_num_vision_features()
        self.adapter = nn.Sequential(
            nn.Linear(backbone_channels, config.model_dim),
            nn.LayerNorm(config.model_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
        )

    def _get_num_vision_features(self):
        """Return the channel count of the backbone's output feature map."""
        tower = self.vision_tower
        if hasattr(tower, 'feature_info'):
            last_stage = tower.feature_info[-1]
            return last_stage.get('num_chs', None) or last_stage.get('channels', None)
        if hasattr(tower, 'num_features'):
            return tower.num_features

        # Last resort: push a dummy batch through the backbone and read the shape.
        print('Warning !!! num feature inference')
        with torch.no_grad():
            probe = torch.ones((1, self.config.tof_dim, self.config.sequence_length, self.config.tof_length))
            probe = tower(probe)
        return probe.shape[-1]

    def forward(self, tof_sequence, padding_mask=None):
        """Return the pooled and projected backbone features of ``tof_sequence``.

        Args:
            tof_sequence: Backbone input.
            padding_mask: Optional per-step mask that is truthy for real steps;
                it is reshaped with ``config.sequence_length``.
        """
        feature_map = self.vision_tower(tof_sequence)

        masked_pooling = padding_mask is not None and not self.config.do_resize
        if not masked_pooling:
            # Plain global mean over both spatial axes.
            pooled = feature_map.mean(dim=(2, 3))
        else:
            # Broadcast the per-step mask over the ToF axis, then bring it to
            # the spatial size of the feature map.
            step_mask = padding_mask.view(-1, 1, self.config.sequence_length, 1)
            step_mask = step_mask.expand(-1, 1, -1, self.config.tof_length)
            step_mask = torch_resize(step_mask.float(), feature_map.shape[2:], antialias=True).bool()

            masked_sum = (feature_map * step_mask).sum(dim=(2, 3))
            # Clamp so a fully masked item does not divide by zero.
            visible_count = step_mask.sum(dim=(2, 3)).clamp_min(self.config.epsilon)
            pooled = masked_sum / visible_count

        return self.adapter(pooled)


class EarlyFusionLayer(nn.Module):
    """Fuse static, continuous-sequence and ToF features into one token sequence.

    The output sequence is laid out as
    ``[start, static, <continuous tokens>, tof, end]``. Self-attention over it
    drives a sigmoid gate that scales a residual copy of the tokens, followed
    by a feed-forward block.
    """

    def __init__(self, config: CMIConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.static_proj = nn.Sequential(
            nn.Linear(config.static_dim, config.model_dim),
            nn.LayerNorm(config.model_dim),
            nn.GELU(),
            nn.Dropout(config.dropout)
        )
        self.ctn_proj = nn.Sequential(
            nn.Linear(config.ctn_dim, config.model_dim),
            nn.LayerNorm(config.model_dim),
            nn.GELU(),
            nn.Dropout(config.dropout)
        )
        self.tof_proj = nn.Sequential(
            nn.Linear(config.model_dim, config.model_dim),
            nn.LayerNorm(config.model_dim),
            nn.GELU(),
            nn.Dropout(config.dropout)
        )
        self.attn = nn.MultiheadAttention(
            embed_dim=config.model_dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True
        )
        self.ln1 = nn.LayerNorm(config.model_dim)
        self.start_token = nn.Parameter(torch.randn(1,1,config.model_dim))
        self.end_token   = nn.Parameter(torch.randn(1,1,config.model_dim))
        self.gate = nn.Sequential(
            nn.Linear(config.model_dim, config.hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, 1),
            nn.Sigmoid()
        )
        self.ffn  = nn.Sequential(
            nn.Linear(config.model_dim, config.hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.model_dim),
            nn.Dropout(config.dropout)
        )
        self.ln2 = nn.LayerNorm(config.model_dim)

    def forward(self, static, cnt_seq, tof_seq, padding_mask=None):
        """Build and fuse the token sequence.

        Args:
            static: ``[B, static_dim]`` static features.
            cnt_seq: ``[B, Lc, ctn_dim]`` continuous sequence.
            tof_seq: ``[B, model_dim]`` encoded ToF features.
            padding_mask: Optional ``[B, Lc]`` mask, truthy for real steps.
                When omitted, attention sees every token.

        Returns:
            ``[B, Lc + 4, model_dim]`` fused tokens.
        """
        B = static.size(0)
        s = self.static_proj(static).unsqueeze(1)    # [B,1,D]
        c = self.ctn_proj(cnt_seq)                   # [B, Lc, D]
        t = self.tof_proj(tof_seq).unsqueeze(1)      # [B,1,D]

        x = torch.cat([
            self.start_token.expand(B,-1,-1), s, c, t,
            self.end_token.expand(B,-1,-1)
        ], dim=1)  # [B, L, D]

        # `key_padding_mask` marks positions to IGNORE, i.e. the inverse of the
        # "real step" mask. It must stay None when no mask was given: inverting
        # None raises a TypeError.
        key_padding_mask = None
        if padding_mask is not None:
            # The two tokens before and the two after the sequence are always visible.
            always_visible = torch.ones(B, 2, dtype=torch.bool, device=x.device)
            visible = torch.cat([
                always_visible,
                padding_mask.to(device=x.device, dtype=torch.bool),
                always_visible
            ], dim=1)
            key_padding_mask = ~visible

        attn_out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        # The attention output only controls how much of each token is added back.
        gated = self.gate(attn_out) * x
        x = x + gated
        x = self.ln1(x)
        x = x + self.ffn(x)
        x = self.ln2(x)
        return x


class CMIClassifier(PreTrainedModel):
    """Gesture classifier over ToF, continuous-sequence and static inputs.

    The ToF block is encoded by ``TOFEncoder``, fused with the other inputs by
    ``EarlyFusionLayer``, passed through a transformer encoder, and the start
    token is classified.
    """

    # Lets `from_pretrained` / `save_pretrained` pair the model with its config.
    config_class = CMIConfig

    def __init__(self, config: CMIConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.static_norm = nn.LayerNorm(config.static_dim)
        self.ctn_norm    = nn.LayerNorm(config.ctn_dim)
        self.vision_enc  = TOFEncoder(config)
        self.fusion      = EarlyFusionLayer(config)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.model_dim,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_dim,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=config.num_layers,
            norm=nn.LayerNorm(config.model_dim)
        )
        self.classifier = nn.Sequential(
            nn.Linear(config.model_dim, config.hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.num_classes)
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, tof_sequence, cnt_sequence, static, padding_mask=None, labels=None):
        """Run the classifier.

        Args:
            tof_sequence: ToF input with values in ``[0, 255]``.
            cnt_sequence: ``[B, L, ctn_dim]`` continuous features.
            static: ``[B, static_dim]`` static features.
            padding_mask: Optional ``[B, L]`` mask, truthy for real steps.
            labels: Optional targets - either class probabilities with the same
                shape as the logits (one-hot or mixup-softened) or class indices
                of shape ``[B]``.

        Returns:
            Dict with ``logits`` and, when ``labels`` is given, ``loss``.
        """
        static = self.static_norm(static.float())
        cnt_seq = self.ctn_norm(cnt_sequence.float())
        # Map [0, 255] to [-1, 1].
        tof_seq = (tof_sequence.float() / 127.5) - 1.0
        if padding_mask is not None:
            padding_mask = padding_mask.bool()

        x = self.vision_enc(tof_seq, padding_mask=padding_mask)
        x = self.fusion(static, cnt_seq, x, padding_mask=padding_mask)

        # Keep padded steps out of the encoder's attention as well. The fused
        # sequence carries two extra tokens before and two after the `L` steps,
        # and those are always attended.
        src_key_padding_mask = None
        if padding_mask is not None:
            always_visible = torch.ones(
                padding_mask.size(0), 2, dtype=torch.bool, device=padding_mask.device
            )
            src_key_padding_mask = ~torch.cat([always_visible, padding_mask, always_visible], dim=1)

        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)  # [B, L + 4, D]
        x = x[:, 0, :]                   # use start token slot
        logits = self.classifier(x)

        outputs = {'logits': logits}
        if labels is not None:
            if labels.dim() == logits.dim():
                # Class probabilities: CrossEntropyLoss expects a float target.
                target = labels.float()
            else:
                # Class indices: CrossEntropyLoss expects an integer target.
                target = labels.long()
            outputs['loss'] = self.loss_fn(logits, target)
        return outputs

In [13]:
# Instantiate the classifier from the configuration built above.
model = CMIClassifier(config)

model.safetensors:   0%|          | 0.00/103M [00:00<?, ?B/s]

In [14]:
# Smoke test: add a batch dimension to one dataset item and run a single forward pass.
sample = train_dataset[0]
sample = {k: v.unsqueeze(0) for k, v in sample.items()}
outputs = model(**sample)

In [15]:
# Bare expression: display the logits and loss returned by the forward pass.
outputs

{'logits': tensor([[-0.2498, -0.1620, -0.2841, -0.2317, -0.1529,  0.1177,  0.1414,  0.0854,
          -0.1853, -0.3522, -0.2700, -0.2751, -0.0190,  0.2404, -0.0663, -0.0804,
           0.0604, -0.2094]], grad_fn=<AddmmBackward0>),
 'loss': tensor(3.1417, grad_fn=<DivBackward1>)}

# Test: training the model

In [16]:
from transformers import Trainer, TrainingArguments

# Training setup: 10 epochs, batch size 64, evaluation after every epoch,
# no external experiment tracker.
args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    # NOTE(review): the run below reports 1020 optimizer steps in total, so a
    # 500-step warm-up covers roughly half of training - confirm this is intended.
    warmup_steps=500,
    weight_decay=0.00,
    logging_dir='./logs',
    logging_steps=10,
    eval_strategy='epoch',
    report_to='none',
)

In [17]:
from sklearn.metrics import f1_score, accuracy_score

def compute_metrics(eval_pred):
    """
    HuggingFace Trainer `compute_metrics` callback.

    Args:
      eval_pred: a pair (logits, labels); class scores and one-hot style
        labels, both reduced with an argmax over the last axis.

    Returns:
      metrics: dict with the macro-averaged F1 score (`f1`) and the
        accuracy (`acc`).
    """
    scores, targets = eval_pred
    # Collapse class scores and label vectors to class indices.
    predicted_classes = np.argmax(scores, axis=-1)
    true_classes = np.argmax(targets, axis=-1)

    return {
        'f1': f1_score(true_classes, predicted_classes, average='macro', zero_division=0),
        'acc': accuracy_score(true_classes, predicted_classes),
    }

In [18]:
# Wire model, data and metrics into the HuggingFace Trainer.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # None leaves the Trainer's default collator in place.
    data_collator=None,
    compute_metrics=compute_metrics,
)

In [19]:
# Run the training loop; per-epoch validation metrics appear in the output table.
trainer.train()

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,F1,Acc
1,2.7862,2.785944,0.019258,0.066258
2,2.714,2.711552,0.05381,0.119018
3,2.4287,2.276342,0.136098,0.240491
4,2.1252,1.879714,0.328965,0.373006


Epoch,Training Loss,Validation Loss,F1,Acc
1,2.7862,2.785944,0.019258,0.066258
2,2.714,2.711552,0.05381,0.119018
3,2.4287,2.276342,0.136098,0.240491
4,2.1252,1.879714,0.328965,0.373006
5,1.8833,1.582005,0.461583,0.471166
6,1.7404,1.443295,0.504778,0.516564
7,1.5994,1.363813,0.551724,0.544785
8,1.4973,1.273896,0.576996,0.573006
9,1.5207,1.256793,0.571609,0.566871
10,1.3987,1.269001,0.562112,0.568098


TrainOutput(global_step=1020, training_loss=2.045849386850993, metrics={'train_runtime': 5971.1334, 'train_samples_per_second': 10.919, 'train_steps_per_second': 0.171, 'total_flos': 0.0, 'train_loss': 2.045849386850993, 'epoch': 10.0})