In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Union, Callable
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import logging
from collections import defaultdict
import heapq
from einops import rearrange, repeat

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ==================== DATASETS ====================
class UniversalRecDataset(Dataset):
    """Dataset that serves either per-user item sequences or single interactions.

    With ``task_type == 'sequential'`` every sample is one user: the last
    ``sequence_length`` interactions are taken, all but the final one become
    the input sequence and the final one is the prediction target.
    For any other ``task_type`` every sample is one row of ``df``.
    """
    def __init__(self,
                 df: pd.DataFrame,
                 user_col: str = 'user_id',
                 item_col: str = 'item_id',
                 rating_col: Optional[str] = 'rating',
                 time_col: Optional[str] = 'timestamp',
                 sequence_length: int = 50,
                 task_type: str = 'sequential',
                 additional_features: Optional[List[str]] = None):
        """
        Args:
            df: interaction table; it is copied, the caller's frame is not modified.
            user_col: name of the user id column.
            item_col: name of the item id column.
            rating_col: optional rating column (row-wise tasks only).
            time_col: optional ordering column. When ``None`` in sequential
                mode, the existing row order inside each user is kept.
            sequence_length: number of interactions per sample (inputs + target).
            task_type: ``'sequential'`` or any other value for row-wise access.
            additional_features: extra numeric columns to expose as features.
        """
        self.df = df.copy()
        self.user_col = user_col
        self.item_col = item_col
        self.rating_col = rating_col
        self.time_col = time_col
        self.sequence_length = sequence_length
        self.task_type = task_type
        self.additional_features = additional_features or []

        if task_type == 'sequential':
            if time_col is None:
                # No timestamp available: a stable sort keeps the incoming
                # order of interactions within every user.
                ordered = self.df.sort_values(user_col, kind='stable')
            else:
                ordered = self.df.sort_values([user_col, time_col])
            self.df = ordered.reset_index(drop=True)

            # Only features that actually exist as columns are stored.
            present = [f for f in self.additional_features if f in self.df.columns]
            users = self.df[user_col].tolist()
            items = self.df[item_col].tolist()
            feature_columns = [self.df[f].tolist() for f in present]

            # Column-wise iteration avoids the per-row Series construction
            # (and dtype upcasting) of DataFrame.iterrows().
            self.user_sequences = defaultdict(list)
            for user_id, item_id, *feat_values in zip(users, items, *feature_columns):
                seq_item = {'item_id': item_id}
                seq_item.update(zip(present, feat_values))
                self.user_sequences[user_id].append(seq_item)
            self.users = list(self.user_sequences.keys())

    def __len__(self):
        """Number of users (sequential mode) or number of rows (otherwise)."""
        if self.task_type == 'sequential':
            return len(self.users)
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors / Python scalars."""
        if self.task_type == 'sequential':
            user_id = self.users[idx]
            items = self.user_sequences[user_id]

            if len(items) > self.sequence_length:
                # Keep the most recent interactions only.
                sequence = items[-self.sequence_length:]
            else:
                # Left-pad short histories by repeating the first interaction.
                sequence = [items[0]] * (self.sequence_length - len(items)) + items

            history, target = sequence[:-1], sequence[-1]
            batch = {
                'user_id': int(user_id),
                'sequence': torch.tensor(
                    [step['item_id'] for step in history], dtype=torch.long),
                'target_item': torch.tensor([target['item_id']], dtype=torch.long),
            }

            # Features missing from the frame default to 0.
            for feat in self.additional_features:
                batch[f'sequence_{feat}'] = torch.tensor(
                    [step.get(feat, 0) for step in history], dtype=torch.float32)
                batch[f'target_{feat}'] = torch.tensor(
                    [target.get(feat, 0)], dtype=torch.float32)

            return batch

        row = self.df.iloc[idx]
        result = {
            'user_id': int(row[self.user_col]),
            'item_id': int(row[self.item_col]),
        }

        if self.rating_col and self.rating_col in row:
            result['rating'] = float(row[self.rating_col])

        if self.time_col and self.time_col in row:
            result['timestamp'] = float(row[self.time_col])

        for feat in self.additional_features:
            if feat in row:
                result[feat] = float(row[feat])

        return result

# ==================== MODELS ====================

class SASEmbedding(nn.Module):
    """Item embeddings plus learned positional embeddings for SASRec-style models."""
    def __init__(self, n_items: int, embedding_dim: int, max_len: int = 200,
                 dropout: float = 0.1):
        """
        Args:
            n_items: size of the item vocabulary.
            embedding_dim: embedding width.
            max_len: number of learned positions, i.e. the longest supported sequence.
            dropout: dropout probability applied after layer normalisation.
        """
        super().__init__()
        self.item_embeddings = nn.Embedding(n_items, embedding_dim)
        self.position_embeddings = nn.Embedding(max_len, embedding_dim)
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, item_ids: torch.Tensor, positions: Optional[torch.Tensor] = None):
        """Embed ``item_ids`` of shape (batch, seq_len).

        Args:
            item_ids: integer item ids, (batch, seq_len).
            positions: optional position ids with the same shape; defaults to
                ``0 .. seq_len - 1`` for every row.

        Returns:
            Tensor of shape (batch, seq_len, embedding_dim).

        Raises:
            ValueError: if default positions are used and ``seq_len`` exceeds
                the number of learned positions.
        """
        if positions is None:
            seq_len = item_ids.size(1)
            max_len = self.position_embeddings.num_embeddings
            if seq_len > max_len:
                # Fail early with a readable message instead of an index
                # error deep inside the embedding lookup.
                raise ValueError(
                    f"Sequence length {seq_len} exceeds max_len={max_len} "
                    "of the positional embedding table"
                )
            positions = torch.arange(seq_len, device=item_ids.device).unsqueeze(0)
            positions = positions.expand_as(item_ids)

        embeddings = self.item_embeddings(item_ids) + self.position_embeddings(positions)
        embeddings = self.layer_norm(embeddings)
        return self.dropout(embeddings)

class SASRec(nn.Module):
    """SASRec (Self-Attentive Sequential Recommendation) next-item model.

    A Transformer encoder runs over the embedded item sequence and the hidden
    state of the last position is projected onto the item vocabulary.
    """
    def __init__(self, n_items: int, embedding_dim: int = 128, n_heads: int = 8,
                 n_layers: int = 2, dropout: float = 0.1, max_len: int = 200):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embeddings = SASEmbedding(n_items, embedding_dim, max_len)

        # One encoder block; nn.TransformerEncoder clones it n_layers times.
        block = nn.TransformerEncoderLayer(
            embedding_dim,
            n_heads,
            dim_feedforward=4 * embedding_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(block, n_layers)
        self.output_projection = nn.Linear(embedding_dim, n_items)

    def forward(self, sequences: torch.Tensor) -> torch.Tensor:
        """Return log-probabilities over items, shape (batch, n_items).

        ``sequences`` holds item ids with shape (batch, seq_len).
        """
        hidden = self.transformer(self.embeddings(sequences))
        # Only the final position is used to predict the next item.
        final_step = hidden[:, -1]
        return F.log_softmax(self.output_projection(final_step), dim=1)

class DeepFM(nn.Module):
    """DeepFM over (user, item) id pairs: linear term + FM interaction + MLP."""
    def __init__(self, n_users: int, n_items: int, embedding_dim: int = 64,
                 deep_dims: Optional[List[int]] = None):
        """
        Args:
            n_users: number of distinct users.
            n_items: number of distinct items.
            embedding_dim: width of the user/item embeddings.
            deep_dims: hidden sizes of the MLP; defaults to ``[256, 128, 64]``.
        """
        super().__init__()
        if deep_dims is None:
            # Built per instance so no list object is shared between calls.
            deep_dims = [256, 128, 64]

        self.user_embedding = nn.Embedding(n_users, embedding_dim)
        self.item_embedding = nn.Embedding(n_items, embedding_dim)

        # First-order (linear) weights over the concatenated one-hot
        # [user | item] encoding, plus a separate global bias.
        self.fm_linear = nn.Linear(n_users + n_items, 1)
        self.fm_bias = nn.Parameter(torch.tensor(0.0))

        # Deep part operating on the concatenated user + item embeddings.
        layers = []
        input_dim = embedding_dim * 2
        for hidden_dim in deep_dims:
            layers.extend([
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2)
            ])
            input_dim = hidden_dim
        layers.append(nn.Linear(input_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, user_ids: torch.Tensor, item_ids: torch.Tensor) -> torch.Tensor:
        """Score each (user, item) pair.

        Args:
            user_ids: 1-D integer tensor, (batch,).
            item_ids: 1-D integer tensor, (batch,).

        Returns:
            Tensor of shape (batch, 1).
        """
        user_emb = self.user_embedding(user_ids)
        item_emb = self.item_embedding(item_ids)

        # Linear part. Multiplying a one-hot row by the weight matrix merely
        # selects two weights, so look them up directly instead of building a
        # dense (batch, n_users + n_items) one-hot matrix.
        n_users = self.user_embedding.num_embeddings
        weights = self.fm_linear.weight[0]
        first_order = weights[user_ids] + weights[item_ids + n_users] + self.fm_linear.bias
        linear_output = first_order.unsqueeze(1) + self.fm_bias

        # Second-order interaction (simplified to the user-item dot product).
        fm_output = torch.sum(user_emb * item_emb, dim=1, keepdim=True)

        # Deep part.
        deep_output = self.mlp(torch.cat([user_emb, item_emb], dim=1))

        return linear_output + fm_output + deep_output

class WideDeep(nn.Module):
    """Wide & Deep model over (user, item) id pairs."""
    def __init__(self, n_users: int, n_items: int, embedding_dim: int = 64,
                 deep_dims: Optional[List[int]] = None):
        """
        Args:
            n_users: number of distinct users.
            n_items: number of distinct items.
            embedding_dim: width of the user/item embeddings.
            deep_dims: hidden sizes of the deep MLP; defaults to ``[256, 128, 64]``.
        """
        super().__init__()
        if deep_dims is None:
            # Built per instance so no list object is shared between calls.
            deep_dims = [256, 128, 64]

        self.user_embedding = nn.Embedding(n_users, embedding_dim)
        self.item_embedding = nn.Embedding(n_items, embedding_dim)

        # Wide part: linear model over the one-hot [user | item] encoding.
        self.wide_linear = nn.Linear(n_users + n_items, 1)

        # Deep part: MLP over the concatenated embeddings.
        layers = []
        input_dim = embedding_dim * 2
        for hidden_dim in deep_dims:
            layers.extend([
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1)
            ])
            input_dim = hidden_dim
        layers.append(nn.Linear(input_dim, 1))
        self.deep_mlp = nn.Sequential(*layers)

    def forward(self, user_ids: torch.Tensor, item_ids: torch.Tensor) -> torch.Tensor:
        """Score each (user, item) pair.

        Args:
            user_ids: 1-D integer tensor, (batch,).
            item_ids: 1-D integer tensor, (batch,).

        Returns:
            Tensor of shape (batch, 1).
        """
        # Deep part (the embedding lookups also validate the id ranges).
        user_emb = self.user_embedding(user_ids)
        item_emb = self.item_embedding(item_ids)
        deep_output = self.deep_mlp(torch.cat([user_emb, item_emb], dim=1))

        # Wide part. A one-hot row times the weight matrix just selects two
        # weights, so index them directly rather than materialising a dense
        # (batch, n_users + n_items) one-hot matrix.
        n_users = self.user_embedding.num_embeddings
        weights = self.wide_linear.weight[0]
        wide = weights[user_ids] + weights[item_ids + n_users] + self.wide_linear.bias
        wide_output = wide.unsqueeze(1)

        return wide_output + deep_output

class LightGCN(nn.Module):
    """LightGCN: graph-propagated user/item embeddings scored by dot product."""
    def __init__(self, n_users: int, n_items: int, embedding_dim: int = 64, n_layers: int = 3):
        """
        Args:
            n_users: number of user nodes.
            n_items: number of item nodes.
            embedding_dim: embedding width.
            n_layers: number of propagation steps.
        """
        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.n_layers = n_layers

        self.user_embeddings = nn.Embedding(n_users, embedding_dim)
        self.item_embeddings = nn.Embedding(n_items, embedding_dim)

        nn.init.xavier_uniform_(self.user_embeddings.weight)
        nn.init.xavier_uniform_(self.item_embeddings.weight)

    def forward(self, user_ids: torch.Tensor, item_ids: torch.Tensor,
                adj_matrix: torch.Tensor) -> torch.Tensor:
        """Score (user, item) pairs after propagating embeddings over the graph.

        Args:
            user_ids: integer tensor, (batch,).
            item_ids: integer tensor, (batch,).
            adj_matrix: (n_users + n_items, n_users + n_items) adjacency
                matrix, users first and items second; typically a sparse
                tensor. Any normalisation is the caller's responsibility.

        Returns:
            Tensor of shape (batch,) with dot-product scores.

        Raises:
            ValueError: if ``adj_matrix`` does not have the expected shape.
        """
        n_nodes = self.n_users + self.n_items
        if tuple(adj_matrix.shape) != (n_nodes, n_nodes):
            raise ValueError(
                f"adj_matrix must have shape ({n_nodes}, {n_nodes}), "
                f"got {tuple(adj_matrix.shape)}"
            )

        # Layer-0 embeddings of all nodes: users first, then items.
        all_emb = torch.cat(
            [self.user_embeddings.weight, self.item_embeddings.weight], dim=0)

        # Running sum of the layer outputs; avoids stacking n_layers + 1
        # full copies of the embedding table just to average them.
        layer_sum = all_emb
        for _ in range(self.n_layers):
            all_emb = torch.sparse.mm(adj_matrix, all_emb)
            layer_sum = layer_sum + all_emb

        # LightGCN readout: mean over all layers, including layer 0.
        final_emb = layer_sum / (self.n_layers + 1)

        user_rep = final_emb[:self.n_users][user_ids]
        item_rep = final_emb[self.n_users:][item_ids]

        return torch.sum(user_rep * item_rep, dim=1)

class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product multi-head attention."""
    def __init__(self, embedding_dim: int, n_heads: int, dropout: float = 0.1):
        """
        Args:
            embedding_dim: model width; must be divisible by ``n_heads``.
            n_heads: number of attention heads.
            dropout: dropout probability applied to the attention weights.

        Raises:
            ValueError: if ``embedding_dim`` is not divisible by ``n_heads``.
        """
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if embedding_dim % n_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by n_heads ({n_heads})"
            )

        self.embedding_dim = embedding_dim
        self.n_heads = n_heads
        self.head_dim = embedding_dim // n_heads

        self.fc_q = nn.Linear(embedding_dim, embedding_dim)
        self.fc_k = nn.Linear(embedding_dim, embedding_dim)
        self.fc_v = nn.Linear(embedding_dim, embedding_dim)

        self.fc_o = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape (batch, len, dim) into (batch, heads, len, head_dim)."""
        return x.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query, key, value: tensors of shape (batch, len, embedding_dim).
            mask: optional tensor broadcastable to
                (batch, heads, query_len, key_len); positions where it equals
                0 are excluded from attention.

        Returns:
            ``(output, attention)`` with shapes (batch, query_len, embedding_dim)
            and (batch, heads, query_len, key_len).
        """
        batch_size = query.shape[0]

        Q = self._split_heads(self.fc_q(query), batch_size)
        K = self._split_heads(self.fc_k(key), batch_size)
        V = self._split_heads(self.fc_v(value), batch_size)

        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.head_dim ** 0.5

        if mask is not None:
            # Use the most negative value of the current dtype: a literal
            # like -1e10 is not representable in float16 and fails under
            # mixed-precision training.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)
        attention = self.dropout(attention)

        x = torch.matmul(attention, V)
        # Merge the heads back: (batch, heads, len, head_dim) -> (batch, len, dim).
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.embedding_dim)

        x = self.fc_o(x)

        return x, attention

class AdvancedSequentialModel(nn.Module):
    """Sequential recommender built from custom attention + feed-forward blocks.

    Each of the ``n_layers`` blocks applies self-attention and a position-wise
    feed-forward network, each followed by a residual connection and LayerNorm.
    The last position's state is projected onto the item vocabulary.
    """
    def __init__(self, n_items: int, embedding_dim: int = 128, n_heads: int = 8,
                 n_layers: int = 2, dropout: float = 0.1, max_len: int = 200):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embeddings = SASEmbedding(n_items, embedding_dim, max_len)

        def make_feed_forward() -> nn.Sequential:
            # Position-wise MLP with a 4x expansion.
            return nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim * 4),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(embedding_dim * 4, embedding_dim),
                nn.Dropout(dropout),
            )

        attention_blocks = []
        for _ in range(n_layers):
            attention_blocks.append(
                MultiHeadAttentionLayer(embedding_dim, n_heads, dropout))
        self.attention_layers = nn.ModuleList(attention_blocks)

        feed_forward_blocks = []
        for _ in range(n_layers):
            feed_forward_blocks.append(make_feed_forward())
        self.feed_forwards = nn.ModuleList(feed_forward_blocks)

        # Two norms per block: index 2*i follows attention, 2*i + 1 follows
        # the feed-forward network.
        norms = []
        for _ in range(n_layers * 2):
            norms.append(nn.LayerNorm(embedding_dim))
        self.layer_norms = nn.ModuleList(norms)

        self.output_projection = nn.Linear(embedding_dim, n_items)

    def forward(self, sequences: torch.Tensor) -> torch.Tensor:
        """Return log-probabilities over items for each input sequence."""
        x = self.embeddings(sequences)

        for idx in range(len(self.attention_layers)):
            # Self-attention sub-layer with residual connection.
            attn_out, _ = self.attention_layers[idx](x, x, x)
            x = self.layer_norms[2 * idx](x + attn_out)

            # Feed-forward sub-layer with residual connection.
            x = self.layer_norms[2 * idx + 1](x + self.feed_forwards[idx](x))

        # Predict from the state at the final position.
        final_step = x[:, -1]
        return F.log_softmax(self.output_projection(final_step), dim=1)

# ==================== (COMBINED) ====================

class UniversalPowerfulModel(nn.Module):
    """Ensemble that linearly combines the scores of several sub-models.

    Supported entries of ``model_types``: ``'mf'``, ``'deepfm'``, ``'sasrec'``,
    ``'wide_deep'`` and ``'advanced_seq'``. Unknown or repeated entries are
    ignored (with a warning for unknown ones).
    """
    def __init__(self,
                 n_users: int,
                 n_items: int,
                 model_types: List[str],
                 embedding_dim: int = 128,
                 hidden_dim: int = 256,
                 n_heads: int = 8,
                 n_layers: int = 2,
                 sequence_length: int = 50):
        """
        Args:
            n_users: number of distinct users.
            n_items: number of distinct items.
            model_types: names of the sub-models to build.
            embedding_dim: embedding width shared by all sub-models.
            hidden_dim: hidden width of the ``'mf'`` head.
            n_heads: attention heads for the sequential sub-models.
            n_layers: layers for the sequential sub-models.
            sequence_length: stored for reference; not used by this class.
        """
        super().__init__()
        self.model_types = model_types
        self.embedding_dim = embedding_dim
        self.sequence_length = sequence_length

        # Embeddings used by the 'mf' head and by the dot-product fallback.
        self.user_embedding = nn.Embedding(n_users, embedding_dim)
        self.item_embedding = nn.Embedding(n_items, embedding_dim)

        self.models = nn.ModuleDict()

        for model_type in model_types:
            if model_type in self.models:
                # A repeated name would only rebuild the same entry.
                continue
            if model_type == 'mf':
                self.models['mf'] = nn.Sequential(
                    nn.Linear(embedding_dim * 2, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_dim, 1)
                )
            elif model_type == 'deepfm':
                self.models['deepfm'] = DeepFM(n_users, n_items, embedding_dim)
            elif model_type == 'sasrec':
                self.models['sasrec'] = SASRec(n_items, embedding_dim, n_heads, n_layers)
            elif model_type == 'wide_deep':
                self.models['wide_deep'] = WideDeep(n_users, n_items, embedding_dim)
            elif model_type == 'advanced_seq':
                self.models['advanced_seq'] = AdvancedSequentialModel(
                    n_items, embedding_dim, n_heads, n_layers
                )
            else:
                logging.getLogger(__name__).warning(
                    "Unknown model type %r is ignored by UniversalPowerfulModel",
                    model_type,
                )

        # Size the combiner by the sub-models that were actually built;
        # sizing it by len(model_types) breaks forward() as soon as an entry
        # is unknown or repeated. At least one input keeps the layer valid
        # when nothing was built (forward() then uses the fallback).
        self.combination_layer = nn.Linear(max(len(self.models), 1), 1)

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return one combined score per sample, shape (batch,).

        Sequential sub-models need ``'sequence'`` and ``'target_item'``; the
        others need ``'user_id'`` and ``'item_id'``. If no sub-model can be
        evaluated, the dot product of the shared embeddings is returned.

        Raises:
            ValueError: if only some of the sub-models can be evaluated with
                the keys present in ``batch``.
        """
        outputs = []
        skipped = []

        for model_type, model in self.models.items():
            if model_type in ('sasrec', 'advanced_seq'):
                if 'sequence' not in batch:
                    skipped.append(model_type)
                    continue
                log_probs = model(batch['sequence'])
                # Keep only the log-probability assigned to the target item.
                target = batch['target_item'].reshape(-1, 1)
                outputs.append(log_probs.gather(1, target))
            else:
                if 'user_id' not in batch or 'item_id' not in batch:
                    skipped.append(model_type)
                    continue
                if model_type in ('deepfm', 'wide_deep'):
                    score = model(batch['user_id'], batch['item_id'])
                else:
                    user_emb = self.user_embedding(batch['user_id'])
                    item_emb = self.item_embedding(batch['item_id'])
                    score = model(torch.cat([user_emb, item_emb], dim=1))
                outputs.append(score.reshape(-1, 1))

        if outputs:
            if skipped:
                # The combiner expects one column per sub-model.
                raise ValueError(
                    f"Batch lacks the inputs required by sub-models {skipped}; "
                    f"available keys: {sorted(batch.keys())}"
                )
            combined_output = torch.cat(outputs, dim=1)
            return self.combination_layer(combined_output).squeeze(1)

        # Fallback: plain matrix-factorisation score.
        user_emb = self.user_embedding(batch['user_id'])
        item_emb = self.item_embedding(batch['item_id'])
        return torch.sum(user_emb * item_emb, dim=1)

# ==================== METRICS ====================

class RecMetrics:
    """Ranking metrics for a single user's recommendation list.

    In every method ``y_true`` is the collection of relevant item ids and
    ``y_pred`` the ranked recommendations, best first. Duplicate ids in either
    argument are counted once.
    """

    @staticmethod
    def precision_at_k(y_true: List[int], y_pred: List[int], k: int = 20) -> float:
        """Precision@K: share of the distinct top-k predictions that are relevant.

        Returns 0.0 when there are no predictions.
        """
        top_k = set(y_pred[:k])
        if not top_k:
            return 0.0
        return len(set(y_true) & top_k) / len(top_k)

    @staticmethod
    def recall_at_k(y_true: List[int], y_pred: List[int], k: int = 20) -> float:
        """Recall@K: share of the relevant items found in the top-k predictions.

        Returns 0.0 when there are no relevant items.
        """
        relevant = set(y_true)
        if not relevant:
            return 0.0
        return len(relevant & set(y_pred[:k])) / len(relevant)

    @staticmethod
    def map_at_k(y_true: List[int], y_pred: List[int], k: int = 20) -> float:
        """Average Precision@K for one user (average over users to obtain MAP).

        Returns 0.0 when there are no relevant items or ``k`` is not positive.
        """
        relevant = set(y_true)
        denominator = min(len(relevant), k)
        if denominator <= 0:
            return 0.0

        score = 0.0
        num_hits = 0
        seen = set()
        for rank, item in enumerate(y_pred[:k], start=1):
            # A repeated recommendation is credited only on first occurrence.
            if item in relevant and item not in seen:
                num_hits += 1
                score += num_hits / rank
            seen.add(item)

        return score / denominator

    @staticmethod
    def ndcg_at_k(y_true: List[int], y_pred: List[int], k: int = 20) -> float:
        """NDCG@K with binary relevance.

        Returns 0.0 when there are no relevant items or ``k`` is not positive.
        """
        relevant = set(y_true)

        # DCG: each relevant item contributes once, at its first position,
        # so that duplicates cannot push the score above 1.
        dcg = 0.0
        seen = set()
        for pos, item in enumerate(y_pred[:k]):
            if item in relevant and item not in seen:
                dcg += 1.0 / np.log2(pos + 2)
            seen.add(item)

        # IDCG: all relevant items ranked first.
        idcg = sum(1.0 / np.log2(pos + 2) for pos in range(min(len(relevant), k)))

        return float(dcg / idcg) if idcg > 0.0 else 0.0

# ==================== LIGHTNING ====================

class UniversalRecLightning(pl.LightningModule):
    """Lightning wrapper that trains any recommender defined in this file.

    ``task_type`` selects the loss:
      * ``'classic'``    - MSE against ``batch['rating']``;
      * ``'sequential'`` - NLL of ``batch['target_item']`` under the model's
        log-probabilities (MSE fallback when no target is present);
      * ``'implicit'``   - ``-logsigmoid(score)`` on the observed interactions.
    """
    def __init__(self,
                 model: nn.Module,
                 task_type: str = 'sequential',
                 learning_rate: float = 1e-3,
                 weight_decay: float = 1e-4,
                 metrics: Optional[List[str]] = None):
        """
        Args:
            model: the recommender to train.
            task_type: ``'classic'``, ``'sequential'`` or ``'implicit'``.
            learning_rate: AdamW learning rate.
            weight_decay: AdamW weight decay.
            metrics: names of ranking metrics of interest; defaults to ``['map']``.
        """
        super().__init__()
        self.model = model
        self.task_type = task_type
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        # A fresh list per instance instead of a shared mutable default.
        self.metrics = ['map'] if metrics is None else list(metrics)
        self.metrics_calculator = RecMetrics()

    def _predict(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Call the wrapped model with the arguments its ``forward`` expects.

        Only ``UniversalPowerfulModel`` consumes the batch dict directly; the
        stand-alone models take tensors.
        """
        model = self.model
        if isinstance(model, (SASRec, AdvancedSequentialModel)):
            return model(batch['sequence'])
        if isinstance(model, (DeepFM, WideDeep)):
            # These return (batch, 1); drop the trailing axis so the output
            # lines up with 1-D targets instead of broadcasting against them.
            return model(batch['user_id'], batch['item_id']).squeeze(-1)
        if isinstance(model, LightGCN):
            raise NotImplementedError(
                "LightGCN needs an adjacency matrix, which is not part of the batch"
            )
        return model(batch)

    def _compute_loss(self, batch: Dict[str, torch.Tensor],
                      predictions: torch.Tensor) -> torch.Tensor:
        """Compute the loss for ``self.task_type``.

        Raises:
            ValueError: for an unknown ``task_type``.
        """
        if self.task_type == 'classic':
            # Collated Python floats arrive as float64; match the model dtype.
            ratings = batch['rating'].to(predictions.dtype)
            return F.mse_loss(predictions, ratings)

        if self.task_type == 'sequential':
            if 'target_item' in batch:
                # The model emits log-softmax outputs, hence NLL loss.
                return F.nll_loss(predictions, batch['target_item'].squeeze(1))
            # Fallback for sequential batches without an explicit target.
            ratings = batch.get('rating')
            if ratings is None:
                ratings = torch.ones(
                    batch['user_id'].size(0),
                    device=predictions.device,
                    dtype=predictions.dtype,
                )
            return F.mse_loss(predictions, ratings.to(predictions.dtype))

        if self.task_type == 'implicit':
            return -F.logsigmoid(predictions).mean()

        raise ValueError(f"Unknown task_type: {self.task_type!r}")

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the wrapped model on a collated batch."""
        return self._predict(batch)

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Compute and log the training loss for one batch."""
        predictions = self._predict(batch)
        loss = self._compute_loss(batch, predictions)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int):
        """Log ``val_loss`` for every task and ``val_acc`` for sequential tasks."""
        with torch.no_grad():
            predictions = self._predict(batch)

            # Logged for all task types so that callbacks monitoring
            # 'val_loss' always find the metric.
            self.log('val_loss', self._compute_loss(batch, predictions))

            if (self.task_type == 'sequential'
                    and 'target_item' in batch
                    and predictions.dim() > 1):
                target_items = batch['target_item'].squeeze(1)
                accuracy = (predictions.argmax(dim=1) == target_items).float().mean()
                self.log('val_acc', accuracy)

    def configure_optimizers(self):
        """AdamW with cosine annealing over the trainer's ``max_epochs``."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler
        }

# ==================== TRAINER ====================

class UniversalRecTrainer:
    """End-to-end helper: encode data, build a model, train it, produce top-k lists.

    Typical flow: ``prepare_data`` -> ``build_model`` -> ``train`` ->
    ``predict_top_k`` / ``generate_submission``.
    """

    # Models whose forward() takes an item-id sequence tensor.
    _SEQUENTIAL_MODELS = ('sasrec', 'advanced_seq')
    # Models whose forward() takes (user_ids, item_ids) tensors.
    _PAIRWISE_MODELS = ('deepfm', 'wide_deep')
    # Candidates scored per forward pass for dict-based models.
    _SCORING_CHUNK = 1024

    def __init__(self,
                 model_type: str = 'mf',  # 'mf', 'deepfm', 'sasrec', 'lightgcn', 'wide_deep', 'advanced_seq', 'universal'
                 task_type: str = 'sequential',
                 embedding_dim: int = 128,
                 hidden_dim: int = 256,
                 sequence_length: int = 50,
                 n_heads: int = 8,
                 n_layers: int = 2,
                 learning_rate: float = 1e-3,
                 batch_size: int = 256,
                 max_epochs: int = 50,
                 model_types: Optional[List[str]] = None):  # sub-models of the 'universal' model
        self.model_type = model_type
        self.task_type = task_type
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.max_epochs = max_epochs
        self.model_types = model_types or [model_type]

        self.user_encoder = None
        self.item_encoder = None
        self.model = None
        self.lightning_model = None
        self.trainer = None
        # Encoded user id -> chronologically ordered encoded item ids;
        # filled by prepare_data() for sequential tasks.
        self.user_histories: Dict[int, List[int]] = {}

    def prepare_data(self, df: pd.DataFrame,
                    user_col: str = 'user_id',
                    item_col: str = 'item_id',
                    rating_col: Optional[str] = 'rating',
                    time_col: Optional[str] = 'timestamp',
                    additional_features: Optional[List[str]] = None,
                    test_size: float = 0.2,
                    val_size: float = 0.1):
        """Encode ids, split the data and wrap the parts into datasets.

        NOTE: the user and item columns of ``df`` are label-encoded IN PLACE,
        so after this call the caller's frame holds encoded ids (the ids that
        ``predict_top_k`` and ``generate_submission`` expect).

        Returns:
            ``(train_dataset, val_dataset, test_dataset)``.
        """
        self.user_encoder = LabelEncoder()
        self.item_encoder = LabelEncoder()

        df[user_col] = self.user_encoder.fit_transform(df[user_col])
        df[item_col] = self.item_encoder.fit_transform(df[item_col])

        if self.task_type == 'sequential':
            # Per-user chronological split: the earliest part of every
            # history goes to train, the rest to test.
            if time_col is None:
                # No timestamp: keep the incoming order within each user.
                df = df.sort_values(user_col, kind='stable')
            else:
                df = df.sort_values([user_col, time_col])

            # Full histories are kept so that predict_top_k can feed real
            # sequences to sequential models.
            self.user_histories = {
                int(user_id): items.tolist()
                for user_id, items in df.groupby(user_col)[item_col]
            }

            train_data, test_data = [], []
            for _, user_df in df.groupby(user_col):
                split_idx = int(len(user_df) * (1 - test_size))
                train_data.append(user_df.iloc[:split_idx])
                test_data.append(user_df.iloc[split_idx:])

            train_df = pd.concat(train_data)
            test_df = pd.concat(test_data)
        else:
            train_df, test_df = train_test_split(df, test_size=test_size, random_state=42)

        # NOTE(review): the validation part is carved out of train at random,
        # also for sequential tasks - confirm this is intended.
        if len(train_df) > len(test_df):
            train_df, val_df = train_test_split(
                train_df, test_size=val_size / (1 - test_size), random_state=42)
        else:
            val_df = test_df

        def make_dataset(part: pd.DataFrame) -> UniversalRecDataset:
            return UniversalRecDataset(
                part, user_col, item_col, rating_col, time_col,
                self.sequence_length, self.task_type, additional_features
            )

        return make_dataset(train_df), make_dataset(val_df), make_dataset(test_df)

    def build_model(self, n_users: int, n_items: int):
        """Instantiate the model for ``self.model_type`` and its Lightning wrapper."""
        if self.model_type == 'sasrec':
            self.model = SASRec(n_items, self.embedding_dim, self.n_heads, self.n_layers)
        elif self.model_type == 'deepfm':
            self.model = DeepFM(n_users, n_items, self.embedding_dim)
        elif self.model_type == 'wide_deep':
            self.model = WideDeep(n_users, n_items, self.embedding_dim)
        elif self.model_type == 'lightgcn':
            self.model = LightGCN(n_users, n_items, self.embedding_dim, self.n_layers)
        elif self.model_type == 'advanced_seq':
            self.model = AdvancedSequentialModel(
                n_items, self.embedding_dim, self.n_heads, self.n_layers
            )
        elif self.model_type == 'universal':
            self.model = UniversalPowerfulModel(
                n_users, n_items, self.model_types, self.embedding_dim,
                self.hidden_dim, self.n_heads, self.n_layers, self.sequence_length
            )
        else:  # 'mf'
            self.model = UniversalPowerfulModel(
                n_users, n_items, [self.model_type], self.embedding_dim,
                self.hidden_dim, self.n_heads, self.n_layers, self.sequence_length
            )

        self.lightning_model = UniversalRecLightning(
            model=self.model,
            task_type=self.task_type,
            learning_rate=self.learning_rate
        )

    def train(self, train_dataset, val_dataset, num_workers: int = 4):
        """Fit the model with checkpointing and early stopping.

        Raises:
            RuntimeError: if ``build_model`` has not been called.
        """
        if self.lightning_model is None:
            raise RuntimeError("Call build_model() before train()")

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )

        # Monitor a metric the validation step is known to log: next-item
        # accuracy for sequential tasks, the validation loss otherwise.
        if self.task_type == 'sequential':
            monitor, mode = 'val_acc', 'max'
        else:
            monitor, mode = 'val_loss', 'min'

        checkpoint_callback = ModelCheckpoint(
            monitor=monitor,
            mode=mode,
            save_top_k=1,
            filename='best-{epoch:02d}-{' + monitor + ':.2f}'
        )

        early_stop_callback = EarlyStopping(
            monitor=monitor,
            patience=10,
            mode=mode
        )

        self.trainer = pl.Trainer(
            max_epochs=self.max_epochs,
            callbacks=[checkpoint_callback, early_stop_callback],
            accelerator='auto',
            devices='auto',
            precision=16 if torch.cuda.is_available() else 32,
            gradient_clip_val=1.0,
            accumulate_grad_batches=2  # for training stability
        )

        self.trainer.fit(
            self.lightning_model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader
        )

        logger.info(f"Обучение завершено. Лучшая модель: {checkpoint_callback.best_model_path}")

    def _inference_sequence(self, user_id: int, device: torch.device) -> Optional[torch.Tensor]:
        """Build the model input from a user's stored history.

        Mirrors the dataset: the most recent ``sequence_length - 1`` items,
        left-padded by repeating the earliest one. Returns ``None`` when the
        user has no stored history.
        """
        history = self.user_histories.get(user_id)
        if not history:
            return None
        window = max(self.sequence_length - 1, 1)
        recent = history[-window:]
        padded = [recent[0]] * (window - len(recent)) + recent
        return torch.tensor(padded, dtype=torch.long, device=device)

    def _score_all_items(self, model: nn.Module, user_id: int,
                         all_items: torch.Tensor) -> torch.Tensor:
        """Return one score per item id for ``user_id``, shape (n_items,)."""
        device = all_items.device
        n_items = all_items.numel()

        # Stand-alone models have no `model_types`; fall back to the
        # trainer's own model type in that case.
        sub_models = getattr(model, 'model_types', [self.model_type])
        sequence = None
        if any(name in self._SEQUENTIAL_MODELS for name in sub_models):
            sequence = self._inference_sequence(user_id, device)
            if sequence is None:
                logger.warning(
                    "No interaction history for user %s; using random placeholder scores",
                    user_id,
                )
                return torch.randn(n_items, device=device)

        if self.model_type in self._SEQUENTIAL_MODELS:
            # Log-probabilities over the whole catalogue in a single pass.
            return model(sequence.unsqueeze(0))[0]

        if self.model_type in self._PAIRWISE_MODELS:
            users = torch.full_like(all_items, user_id)
            return model(users, all_items).reshape(-1)

        if self.model_type == 'lightgcn':
            # Scoring needs the interaction graph, which the trainer does not hold.
            logger.warning("LightGCN scoring requires an adjacency matrix; returning zero scores")
            return torch.zeros(n_items, device=device)

        # Dict-based models ('mf' / 'universal'): score candidates in chunks,
        # because sequential sub-models emit a full row of log-probabilities
        # for every candidate.
        chunks = []
        for items in all_items.split(self._SCORING_CHUNK):
            batch = {
                'user_id': torch.full_like(items, user_id),
                'item_id': items,
            }
            if sequence is not None:
                batch['sequence'] = sequence.unsqueeze(0).expand(items.numel(), -1)
                batch['target_item'] = items.unsqueeze(1)
            chunks.append(model(batch).reshape(-1))
        return torch.cat(chunks)

    def predict_top_k(self, user_ids: List[int], k: int = 20) -> np.ndarray:
        """Return the top-k encoded item ids for each encoded user id.

        Args:
            user_ids: encoded user ids (as produced by ``user_encoder``).
            k: number of items per user; capped at the catalogue size.

        Returns:
            Array of shape (len(user_ids), min(k, n_items)).

        Raises:
            RuntimeError: if the data or the model has not been prepared.
        """
        if self.lightning_model is None or self.item_encoder is None:
            raise RuntimeError("Call prepare_data() and build_model() before predict_top_k()")

        self.lightning_model.eval()
        model = self.lightning_model.model

        # Create the inputs on the model's device (it may still be on the
        # accelerator after training).
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        all_items = torch.arange(len(self.item_encoder.classes_), device=device)

        predictions = []
        with torch.no_grad():
            for user_id in user_ids:
                scores = self._score_all_items(model, int(user_id), all_items)
                top_k = torch.topk(scores, min(k, len(scores)))
                predictions.append(top_k.indices.cpu().numpy())

        return np.array(predictions)

    def generate_submission(self, test_users: List[int], k: int = 20,
                           output_file: str = 'submission.csv'):
        """Write top-k recommendations for ``test_users`` (encoded ids) to a CSV file.

        The file has the columns ``user_id`` (original id) and ``item_ids``
        (space-separated original item ids).
        """
        predictions = self.predict_top_k(test_users, k)

        # Map the encoded ids back to the original ones.
        decoded_predictions = [
            ' '.join(map(str, self.item_encoder.inverse_transform(pred)))
            for pred in predictions
        ]

        submission_df = pd.DataFrame({
            'user_id': self.user_encoder.inverse_transform(test_users),
            'item_ids': decoded_predictions
        })

        submission_df.to_csv(output_file, index=False)
        logger.info(f"Submission сохранен в {output_file}")


In [None]:

def example_powerful_t_shopping():
    """Example: a high-capacity sequential setup for the T-Shopping task.

    Only the trainer object is constructed here. The remaining pipeline
    steps are kept as commented-out templates, since they need a real
    dataframe to run.
    """
    # df = pd.read_parquet('train_data.pq')

    trainer_config = dict(
        model_type='advanced_seq',  # or 'sasrec' for a simpler variant
        task_type='sequential',
        embedding_dim=256,
        hidden_dim=512,
        n_heads=8,
        n_layers=3,
        sequence_length=100,
        learning_rate=1e-3,
        batch_size=64,  # lower this if you run out of memory
        max_epochs=30,
    )
    trainer = UniversalRecTrainer(**trainer_config)

    # Step 1: prepare the data
    # train_dataset, val_dataset, test_dataset = trainer.prepare_data(
    #     df, user_col='user_id', item_col='item_id', time_col='date'
    # )

    # Step 2: build the model
    # trainer.build_model(
    #     n_users=len(trainer.user_encoder.classes_),
    #     n_items=len(trainer.item_encoder.classes_)
    # )

    # Step 3: train
    # trainer.train(train_dataset, val_dataset)

    # Step 4: generate the submission file
    # test_users = list(set(df['user_id']))
    # trainer.generate_submission(test_users, k=20)

def example_universal_model():
    """Example: a 'universal' model that combines several sub-models.

    Only the trainer object is constructed; no data is loaded and nothing
    is trained.
    """
    # df = pd.read_csv('ratings.csv')

    # Sub-models blended by the universal model.
    combined_models = ['mf', 'deepfm', 'wide_deep']

    trainer = UniversalRecTrainer(
        model_type='universal',
        model_types=combined_models,
        task_type='classic',
        embedding_dim=128,
        learning_rate=1e-3,
    )

    # Next steps mirror the other examples: prepare_data -> build_model -> train

def example_graph_model():
    """Example: configuring a LightGCN (graph-based) model.

    Only the trainer object is constructed; no data is loaded and nothing
    is trained.
    """
    # df = pd.read_csv('interactions.csv')

    lightgcn_config = {
        'model_type': 'lightgcn',
        'task_type': 'implicit',
        'embedding_dim': 64,
        'n_layers': 3,
        # A higher learning rate is commonly used for LightGCN.
        'learning_rate': 1e-2,
    }
    trainer = UniversalRecTrainer(**lightgcn_config)


In [None]:
"""
МОЩНЫЕ МОДЕЛИ ПО ЗАДАЧАМ:

1. T-SHOPPING (Next Item Prediction):
   - 'sasrec': Transformer-based, excellent for sequences
   - 'advanced_seq': Custom attention with multiple heads
   - embedding_dim: 256-512
   - n_heads: 8-16
   - n_layers: 2-4

2. LARGE-SCALE IMPLICIT:
   - 'lightgcn': Graph-based, scales well
   - 'deepfm': Combines FM and deep learning
   - embedding_dim: 64-256

3. CLASSIC RATING PREDICTION:
   - 'deepfm': Best for feature-rich data
   - 'wide_deep': Good for sparse features
   - 'universal': Combines multiple approaches

4. HYBRID TASKS:
   - 'universal': Combines multiple model types
   - model_types: ['mf', 'deepfm', 'sasrec'] for maximum power

НАСТРОЙКИ ДЛЯ МАКСИМАЛЬНОЙ МОЩИ:
- embedding_dim: 256, 512 (для сложных паттернов)
- n_heads: 8, 16 (для attention моделей)
- n_layers: 3, 4, 6 (для глубоких моделей)
- hidden_dim: 512, 1024 (для deep parts)
- sequence_length: 100, 200 (для long sequences)
- batch_size: 32, 64 (адаптируйте под память GPU)

ПАМЯТЬ И ПРОИЗВОДИТЕЛЬНОСТЬ:
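- Используйте gradient accumulation для больших моделей (параметр accumulate_grad_batches в pl.Trainer)
- Включите mixed precision training (precision=16 в pl.Trainer; в Lightning 2.x — precision='16-mixed')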
- Используйте model parallelism для очень больших моделей
"""