Imports Completos

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import CrossEntropyLoss, MSELoss
from torch.utils.data import IterableDataset, DataLoader

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.compute as pc

from sklearn.cluster import KMeans
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

# Seus módulos locais
from external_information import ExternalInformationFusionDTPC, ExternalInformationDense
from partial_information import CoordLSTM

# Configurações
parquet_file = "humob_all_cities_dpsk.parquet"
n_users_by_city = {"A": 100_000, "B": 25_000, "C": 20_000, "D": 6_000}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"🚀 Configuração inicial")
print(f"Device: {device}")
print(f"Arquivo: {parquet_file}")

🚀 Configuração inicial
Device: cuda
Arquivo: humob_all_cities_dpsk.parquet


 Célula 2: Centros Estáveis (K-Means)

In [23]:
def compute_stable_centers(
    parquet_path: str,
    cities: list[str] = ["A"],
    day_threshold: int = 60,
    n_clusters: int = 1024,
    sample_size: int = 200_000,
    save_path: str = "centers_stable.npy",
    chunk_size: int = 50_000
) -> torch.Tensor:
    """
    Calcula centros estáveis usando K-Means.
    Mais robusto que HDBSCAN: sempre produz exatamente n_clusters centros.
    """
    
    # Se já existe, carrega
    if os.path.exists(save_path):
        print(f"📂 Carregando centros existentes: {save_path}")
        centers = np.load(save_path)
        return torch.from_numpy(centers.astype(np.float32))
    
    print(f"🔄 Calculando {n_clusters} centros para cidades {cities}...")
    
    # 1) Coleta coordenadas
    pf = pq.ParquetFile(parquet_path)
    coords_list = []
    
    for batch in pf.iter_batches(batch_size=chunk_size):
        tbl = pa.Table.from_batches([batch], schema=pf.schema_arrow)
        
        # Filtra cidades e dias
        city_mask = pc.is_in(tbl.column("city"), pa.array(cities))
        day_mask = pc.less(tbl.column("d"), day_threshold)
        mask = pc.and_(city_mask, day_mask)
        
        tbl = tbl.filter(mask)
        if tbl.num_rows == 0:
            continue
            
        xs = tbl.column("x").to_numpy()
        ys = tbl.column("y").to_numpy()
        coords_list.append(np.stack([xs, ys], axis=1))
    
    coords = np.vstack(coords_list)
    print(f"📊 Coletadas {len(coords):,} coordenadas")
    
    # 2) Amostra se muito grande
    if len(coords) > sample_size:
        idx = np.random.choice(len(coords), sample_size, replace=False)
        coords = coords[idx]
        print(f"🎲 Amostradas {sample_size:,} coordenadas")
    
    # 3) K-Means
    print("⚙️ Executando K-Means...")
    kmeans = KMeans(
        n_clusters=n_clusters, 
        n_init='auto', 
        random_state=42,
        max_iter=300
    ).fit(coords)
    
    centers = kmeans.cluster_centers_.astype(np.float32)
    
    # 4) Salva para reutilizar
    np.save(save_path, centers)
    print(f"💾 Centros salvos em: {save_path}")
    print(f"📏 Shape dos centros: {centers.shape}")
    
    return torch.from_numpy(centers)


def load_or_compute_centers(
    parquet_path: str,
    cities: list[str] = ["A"],
    n_clusters: int = 1024,
    save_path: str = "centers_stable.npy"
) -> torch.Tensor:
    """Função conveniente que carrega se existe, senão calcula."""
    return compute_stable_centers(
        parquet_path=parquet_path,
        cities=cities,
        n_clusters=n_clusters,
        save_path=save_path
    )

In [24]:
class CityAPretrainDataset(IterableDataset):
    """
    Dataset otimizado para pré-treino em cidade A.
    Faz amostragem estratificada por usuário e divide train/val por dias.
    """
    def __init__(
        self,
        parquet_path: str,
        mode: str = "train",  # "train" ou "val"
        sequence_length: int = 48,
        prediction_steps: int = 1,
        chunk_size: int = 10_000,
        max_sequences_per_user: int = 50,
        train_days: tuple = (1, 55),
        val_days: tuple = (56, 60)
    ):
        self.parquet_path = parquet_path
        self.mode = mode
        self.sequence_length = sequence_length
        self.prediction_steps = prediction_steps
        self.chunk_size = chunk_size
        self.max_sequences_per_user = max_sequences_per_user
        
        if mode == "train":
            self.day_range = train_days
        else:
            self.day_range = val_days
    
    def _sample_user_sequences(self, user_group, max_seqs: int):
        """Amostra sequências de um usuário de forma estratificada."""
        sequences = self._build_sequences_for_user(user_group)
        
        if len(sequences) <= max_seqs:
            return sequences
        
        # Amostragem estratificada por hora do dia
        sequences_by_hour = {}
        for seq in sequences:
            hour = seq['t'] // 2  # agrupa slots em horas (48 slots / 24h)
            if hour not in sequences_by_hour:
                sequences_by_hour[hour] = []
            sequences_by_hour[hour].append(seq)
        
        # Amostra proporcionalmente de cada hora
        sampled = []
        for hour_seqs in sequences_by_hour.values():
            n_sample = max(1, min(len(hour_seqs), max_seqs // len(sequences_by_hour)))
            sampled.extend(np.random.choice(hour_seqs, n_sample, replace=False))
        
        return sampled[:max_seqs]
    
    def _build_sequences_for_user(self, user_data):
        """Constrói sequências temporais para um usuário específico."""
        sequences = []
        
        for i in range(len(user_data) - self.sequence_length - self.prediction_steps + 1):
            seq_start = i
            seq_end = i + self.sequence_length
            target_start = seq_end
            target_end = target_start + self.prediction_steps
            
            coords_seq = user_data.iloc[seq_start:seq_end][['x', 'y']].values
            current_info = user_data.iloc[seq_end - 1]
            target_coords = user_data.iloc[target_start:target_end][['x', 'y']].values
            
            sequences.append({
                'uid': current_info['uid'],
                'd': current_info['d'], 
                't': current_info['t'],
                'city_idx': 0,  # A é sempre 0
                'poi': current_info['POI'],
                'coords_seq': coords_seq.astype(np.float32),
                'target_coords': target_coords.astype(np.float32)
            })
            
        return sequences

    def __iter__(self):
        pf = pq.ParquetFile(self.parquet_path)
        
        for batch in pf.iter_batches(batch_size=self.chunk_size):
            table = pa.Table.from_batches([batch], schema=pf.schema_arrow)

            # Filtra cidade A e range de dias
            city_mask = pc.equal(table.column("city"), "A")
            day_mask = pc.and_(
                pc.greater_equal(table.column("d"), self.day_range[0]),
                pc.less_equal(table.column("d"), self.day_range[1])
            )
            mask = pc.and_(city_mask, day_mask)
            
            table = table.filter(mask)
            if table.num_rows == 0:
                continue

            # Converte e normaliza POIs
            df = table.to_pandas()

            # Normaliza POIs (VERSÃO CORRIGIDA)
            poi_cols = [col for col in df.columns if 'POI' in col or col == 'POI']
            if poi_cols:
                for col in poi_cols:
                    if col in df.columns:
                        # Verifica o tipo da coluna POI
                        sample_val = df[col].iloc[0] if len(df) > 0 else None
                        
                        if sample_val is not None:
                            # Se POI é uma lista/array (vetor 85-D)
                            if hasattr(sample_val, '__len__') and not isinstance(sample_val, str):
                                # df[col] = df[col].apply(lambda x: np.log1p(np.array(x)) if x is not None else x)
                                df[col] = df[col].apply(lambda x: np.log1p(np.array(x, dtype=np.float32)) if x is not None else np.zeros(85, dtype=np.float32))
                            # Se POI é um valor escalar
                            else:
                                df[col] = np.log1p(df[col].fillna(0))
            
            df = df.sort_values(['uid', 'd', 't'])
            
            # Processa cada usuário com amostragem estratificada
            for uid, user_group in df.groupby('uid'):
                if len(user_group) < self.sequence_length + self.prediction_steps:
                    continue
                
                sequences = self._sample_user_sequences(
                    user_group, 
                    self.max_sequences_per_user
                )
                
                for seq in sequences:
                    yield (
                        torch.tensor(seq['uid'], dtype=torch.long),
                        torch.tensor(seq['d'], dtype=torch.long),
                        torch.tensor(seq['t'], dtype=torch.long),
                        torch.tensor(seq['city_idx'], dtype=torch.long),
                        torch.from_numpy(seq['poi']),
                        torch.from_numpy(seq['coords_seq']),
                        torch.from_numpy(seq['target_coords'])
                    )


def create_pretrain_loaders(
    parquet_path: str,
    batch_size: int = 32,
    sequence_length: int = 48,
    max_sequences_per_user: int = 50,
    num_workers: int = 0
):
    """Cria loaders de treino e validação para cidade A."""
    
    train_ds = CityAPretrainDataset(
        parquet_path=parquet_path,
        mode="train",
        sequence_length=sequence_length,
        max_sequences_per_user=max_sequences_per_user
    )
    
    val_ds = CityAPretrainDataset(
        parquet_path=parquet_path,
        mode="val", 
        sequence_length=sequence_length,
        max_sequences_per_user=max_sequences_per_user // 2
    )
    
    train_loader = DataLoader(
        train_ds, 
        batch_size=batch_size, 
        num_workers=num_workers,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers, 
        pin_memory=True
    )
    
    return train_loader, val_loader

In [25]:
class ParquetCityDDataset(IterableDataset):
    """Dataset para avaliação em cidade D."""
    def __init__(
        self,
        parquet_path: str,
        city_list: list[str],
        chunk_size: int = 10_000,
        sequence_length: int = 48,
        prediction_steps: int = 1
    ):
        self.parquet_path = parquet_path
        self.city_list = city_list
        self.city_set = set(city_list)
        self.city_to_idx = {c: i for i, c in enumerate(city_list)}
        self.chunk_size = chunk_size
        self.sequence_length = sequence_length
        self.prediction_steps = prediction_steps

    def _build_sequences_for_user(self, user_data):
        """Constrói sequências temporais para um usuário específico."""
        sequences = []
        
        for i in range(len(user_data) - self.sequence_length - self.prediction_steps + 1):
            seq_start = i
            seq_end = i + self.sequence_length
            target_start = seq_end
            target_end = target_start + self.prediction_steps
            
            coords_seq = user_data.iloc[seq_start:seq_end][['x', 'y']].values
            current_info = user_data.iloc[seq_end - 1]
            target_coords = user_data.iloc[target_start:target_end][['x', 'y']].values
            
            sequences.append({
                'uid': current_info['uid'],
                'd': current_info['d'], 
                't': current_info['t'],
                'city_idx': self.city_to_idx[current_info['city']],
                'poi': current_info['POI'],
                'coords_seq': coords_seq.astype(np.float32),
                'target_coords': target_coords.astype(np.float32)
            })
            
        return sequences

    def __iter__(self):
        pf = pq.ParquetFile(self.parquet_path)
        
        for batch in pf.iter_batches(batch_size=self.chunk_size):
            table = pa.Table.from_batches([batch], schema=pf.schema_arrow)

            # Filtra cidades
            mask = pc.is_in(table.column("city"), pa.array(list(self.city_set)))
            table = table.filter(mask)
            if table.num_rows == 0:
                continue

            # Converte para pandas
            df = table.to_pandas()
            

            # Normaliza POIs (VERSÃO CORRIGIDA)
            poi_cols = [col for col in df.columns if 'POI' in col or col == 'POI']
            if poi_cols:
                for col in poi_cols:
                    if col in df.columns:
                        # Verifica o tipo da coluna POI
                        sample_val = df[col].iloc[0] if len(df) > 0 else None
                        
                        if sample_val is not None:
                            # Se POI é uma lista/array (vetor 85-D)
                            if hasattr(sample_val, '__len__') and not isinstance(sample_val, str):
                                # df[col] = df[col].apply(lambda x: np.log1p(np.array(x)).tolist() if x is not None else x)
                                df[col] = df[col].apply(lambda x: np.log1p(np.array(x, dtype=np.float32)) if x is not None else np.zeros(85, dtype=np.float32))
                            # Se POI é um valor escalar
                            else:
                                df[col] = np.log1p(df[col].fillna(0))
            
            df = df.sort_values(['uid', 'd', 't'])
            
            for uid, user_group in df.groupby('uid'):
                if len(user_group) < self.sequence_length + self.prediction_steps:
                    continue
                    
                sequences = self._build_sequences_for_user(user_group)
                
                for seq in sequences:
                    yield (
                        torch.tensor(seq['uid'], dtype=torch.long),
                        torch.tensor(seq['d'], dtype=torch.long),
                        torch.tensor(seq['t'], dtype=torch.long),
                        torch.tensor(seq['city_idx'], dtype=torch.long),
                        torch.from_numpy(seq['poi']),
                        torch.from_numpy(seq['coords_seq']),
                        torch.from_numpy(seq['target_coords'])
                    )

In [26]:
class WeightedFusion(nn.Module):
    """
    Funde dois vetores de mesmo tamanho por uma soma ponderada aprendível.
    """
    def __init__(self, dim: int = 20, init_w_r: float = 0.5, init_w_e: float = 0.5):
        super().__init__()
        self.w_r = nn.Parameter(torch.tensor(init_w_r, dtype=torch.float32))
        self.w_e = nn.Parameter(torch.tensor(init_w_e, dtype=torch.float32))
        self.dim = dim

    def forward(self, static_red: torch.Tensor, dyn_emb: torch.Tensor) -> torch.Tensor:
        assert static_red.shape == dyn_emb.shape and static_red.size(1) == self.dim
        fused = self.w_r * static_red + self.w_e * dyn_emb
        return fused


class MLP500(nn.Module):
    """MLP simples com 1 hidden layer de 500 ReLUs."""
    def __init__(self, in_dim: int, hidden_dim: int, n_clusters: int):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_clusters)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class DestinationHead(nn.Module):
    """Combina MLP + softmax + weighted sum pelos cluster centers."""
    def __init__(self, in_dim: int, hidden_dim: int, cluster_centers: torch.Tensor):
        super().__init__()
        C, coord_dim = cluster_centers.shape
        assert coord_dim == 2
        self.mlp500 = MLP500(in_dim, hidden_dim, C)
        self.register_buffer("centers", cluster_centers)

    def forward(self, fused: torch.Tensor) -> torch.Tensor:
        logits = self.mlp500(fused)
        P = F.softmax(logits, dim=1)
        coords = P @ self.centers
        return coords


def discretize_coordinates(coords_pred: torch.Tensor, grid_size: int = 200):
    """Converte coordenadas contínuas para grid discreto [0, grid_size-1]."""
    coords_discrete = torch.round(coords_pred).long()
    coords_discrete = torch.clamp(coords_discrete, 0, grid_size - 1)
    return coords_discrete

In [27]:
class HuMobModel(nn.Module):
    """
    Modelo completo que combina todas as partes e faz rollout para múltiplos passos.
    """
    def __init__(
        self,
        n_users: int,
        n_days: int,
        n_slots: int,
        n_cities: int,
        cluster_centers: torch.Tensor,
        emb_dim: int = 10,
        poi_in_dim: int = 85,
        poi_out_dim: int = 10,
        lstm_hidden: int = 10,
        fusion_dim: int = 20,
        sequence_length: int = 48,
        prediction_steps: int = 1
    ):
        super().__init__()
        
        # Salva configurações
        self.sequence_length = sequence_length
        self.prediction_steps = prediction_steps
        
        # Componentes
        self.fusion = ExternalInformationFusionDTPC(
            n_users=n_users,
            n_days=n_days,
            n_slots=n_slots,
            n_cities=n_cities,
            emb_dim=emb_dim,
            poi_in_dim=poi_in_dim,
            poi_out_dim=poi_out_dim
        )
        self.dense = ExternalInformationDense(
            in_dim=self.fusion.out_dim, 
            out_dim=fusion_dim
        )
        self.lstm = CoordLSTM(
            input_size=2, 
            hidden_size=lstm_hidden, 
            bidirectional=True
        )
        self.weighted_fusion = WeightedFusion(dim=fusion_dim)
        self.destination_head = DestinationHead(
            in_dim=fusion_dim,
            hidden_dim=500,
            cluster_centers=cluster_centers
        )
        
    def forward_single_step(self, uid, d, t, city, poi, coords_seq):
        """Faz uma predição para um único passo."""
        # Informação estática (contexto)
        static_emb = self.fusion(uid, d, t, city, poi)
        static_red = self.dense(static_emb)
        
        # Informação dinâmica (padrão de movimento)
        dyn_emb = self.lstm(coords_seq)
        
        # Fusão inteligente
        fused = self.weighted_fusion(static_red, dyn_emb)
        
        # Predição final
        pred_coords = self.destination_head(fused)
        
        return pred_coords
    
    def rollout_predictions(
        self, 
        uid, d, t, city, poi, coords_seq, 
        n_steps: int,
        use_predictions: bool = True
    ):
        """Faz predições para múltiplos passos futuros."""
        predictions = []
        current_seq = coords_seq.clone()
        current_t = t.clone()
        current_d = d.clone()
        
        for step in range(n_steps):
            # Prediz próximo passo
            pred = self.forward_single_step(uid, current_d, current_t, city, poi, current_seq)
            predictions.append(pred)
            
            if use_predictions and step < n_steps - 1:
                # Atualiza sequência: remove primeiro ponto, adiciona predição
                new_point = pred.unsqueeze(1)  # (batch, 1, 2)
                current_seq = torch.cat([current_seq[:, 1:, :], new_point], dim=1)
            
            # Incrementa tempo corretamente (wrap em 48 slots)
            current_t = (current_t + 1) % 48
            # Se voltou para slot 0, incrementa dia
            mask_new_day = (current_t == 0)
            current_d = current_d + mask_new_day.long()
            
        return torch.stack(predictions, dim=1)  # (batch, n_steps, 2)
    
    def forward(self, uid, d, t, city, poi, coords_seq, n_steps=1):
        """Forward principal - pode ser usado tanto para treino quanto inferência."""
        if n_steps == 1:
            return self.forward_single_step(uid, d, t, city, poi, coords_seq)
        else:
            return self.rollout_predictions(uid, d, t, city, poi, coords_seq, n_steps)

In [28]:
def pretrain_on_city_A(
    parquet_path: str,
    centers: torch.Tensor,
    device: torch.device,
    n_epochs: int = 6,
    learning_rate: float = 2e-3,
    batch_size: int = 32,
    sequence_length: int = 48,
    save_path: str = "ckpt_A_warmup.pt"
):
    """Pré-treina o modelo na cidade A para validar o pipeline."""
    print("🏋️ Iniciando pré-treino na cidade A...")
    
    # 1. Cria loaders
    train_loader, val_loader = create_pretrain_loaders(
        parquet_path=parquet_path,
        batch_size=batch_size,
        sequence_length=sequence_length
    )
    
    # 2. Instancia modelo
    model = HuMobModel(
        n_users=100_000,  # cidade A
        n_days=75,
        n_slots=48,
        n_cities=4,
        cluster_centers=centers,
        sequence_length=sequence_length,
        prediction_steps=1  # Next-step apenas
    ).to(device)
    
    # 3. Setup de treino (MSE apenas para começar)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)
    criterion = nn.MSELoss()
    
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    
    print(f"Parâmetros treináveis: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    
    for epoch in range(n_epochs):
        # === TREINO ===
        model.train()
        train_loss_epoch = 0
        train_count = 0
        
        print(f"\n🔄 Época {epoch+1}/{n_epochs}")
        train_pbar = tqdm(train_loader, desc=f'Treino A')
        
        for batch in train_pbar:
            uid, d, t, city, poi, coords_seq, target_coords = [b.to(device) for b in batch]
            
            optimizer.zero_grad()
            
            # Forward (apenas next-step)
            pred = model.forward_single_step(uid, d, t, city, poi, coords_seq)
            target = target_coords.squeeze(1)  # Remove dimensão de step
            
            # Loss MSE simples
            loss = criterion(pred, target)
            
            # Backward
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            
            # Métricas
            train_loss_epoch += loss.item() * target.size(0)
            train_count += target.size(0)
            
            train_pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'LR': f'{optimizer.param_groups[0]["lr"]:.6f}'
            })
        
        # === VALIDAÇÃO ===
        model.eval()
        val_loss_epoch = 0
        val_count = 0
        
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc='Val A')
            for batch in val_pbar:
                uid, d, t, city, poi, coords_seq, target_coords = [b.to(device) for b in batch]
                
                pred = model.forward_single_step(uid, d, t, city, poi, coords_seq)
                target = target_coords.squeeze(1)
                
                loss = criterion(pred, target)
                
                val_loss_epoch += loss.item() * target.size(0)
                val_count += target.size(0)
                
                val_pbar.set_postfix({'Val Loss': f'{loss.item():.4f}'})
        
        # Médias
        avg_train_loss = train_loss_epoch / max(train_count, 1)
        avg_val_loss = val_loss_epoch / max(val_count, 1)
        
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        
        print(f"Treino: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
        print(f"Fusion weights: w_r={model.weighted_fusion.w_r.item():.3f}, w_e={model.weighted_fusion.w_e.item():.3f}")
        
        # Learning rate scheduling
        scheduler.step(avg_val_loss)
        
        # Salva melhor modelo
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print(f"💾 Novo melhor modelo! Loss: {best_val_loss:.4f}")
            
            torch.save({
                'state_dict': model.state_dict(),
                'centers': centers.cpu().numpy(),
                'config': {
                    'sequence_length': sequence_length,
                    'prediction_steps': 1,
                    'n_clusters': centers.shape[0]
                },
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'epoch': epoch
            }, save_path)
    
    print(f"\n✅ Pré-treino concluído! Modelo salvo em: {save_path}")
    print(f"Melhor loss de validação: {best_val_loss:.4f}")
    
    return model, train_losses, val_losses


def evaluate_zero_shot_on_D(
    parquet_path: str,
    checkpoint_path: str,
    device: torch.device,
    n_samples: int = 5000,
    sequence_length: int = 48
):
    """Avalia o modelo pré-treinado em A na cidade D (zero-shot)."""
    print("🎯 Avaliação zero-shot em cidade D...")
    
    # 1. Carrega checkpoint
    ckpt = torch.load(checkpoint_path, map_location=device)
    centers = torch.from_numpy(ckpt['centers']).to(device)
    
    # 2. Instancia modelo
    model = HuMobModel(
        n_users=6_000,  # cidade D
        n_days=75,
        n_slots=48,
        n_cities=4,
        cluster_centers=centers,
        sequence_length=sequence_length,
        prediction_steps=1
    ).to(device)
    
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    
    # 3. Cria dataset de avaliação em D
    eval_ds = ParquetCityDDataset(
        parquet_path=parquet_path,
        city_list=["D"],
        sequence_length=sequence_length,
        prediction_steps=1,
        chunk_size=2500
    )
    
    # 4. Avalia em amostra
    criterion = nn.MSELoss()
    total_loss = 0
    total_samples = 0
    coords_errors = []
    
    with torch.no_grad():
        eval_loader = DataLoader(eval_ds, batch_size=32, num_workers=2)
        pbar = tqdm(eval_loader, desc='Eval D')
        
        for batch in pbar:
            if total_samples >= n_samples:
                break
                
            uid, d, t, city, poi, coords_seq, target_coords = [b.to(device) for b in batch]
            
            # Filtra apenas dados ≤ dia 60 (observados)
            mask = d <= 60
            if not mask.any():
                continue
                
            uid, d, t, city, poi = uid[mask], d[mask], t[mask], city[mask], poi[mask]
            coords_seq, target_coords = coords_seq[mask], target_coords[mask]
            
            pred = model.forward_single_step(uid, d, t, city, poi, coords_seq)
            target = target_coords.squeeze(1)
            
            loss = criterion(pred, target)
            total_loss += loss.item() * target.size(0)
            total_samples += target.size(0)
            
            # Calcula erro em células (após discretização)
            pred_discrete = discretize_coordinates(pred)
            target_discrete = discretize_coordinates(target)
            cell_error = torch.abs(pred_discrete - target_discrete).float().mean(dim=1)
            coords_errors.extend(cell_error.cpu().tolist())
            
            pbar.set_postfix({
                'MSE': f'{loss.item():.4f}',
                'Samples': total_samples
            })
    
    avg_mse = total_loss / max(total_samples, 1)
    avg_cell_error = np.mean(coords_errors)
    
    print(f"\n📊 Resultados zero-shot em D:")
    print(f"  MSE: {avg_mse:.4f}")
    print(f"  Erro médio em células: {avg_cell_error:.2f}")
    print(f"  Amostras avaliadas: {total_samples:,}")
    
    return avg_mse, avg_cell_error

In [29]:
def run_pretrain_pipeline(
    parquet_path: str = "humob_all_cities_dpsk.parquet",
    device: torch.device = None,
    n_clusters: int = 1024,
    pretrain_epochs: int = 6,
    sequence_length: int = 48,
    batch_size: int = 32,
    learning_rate: float = 2e-3
):
    """Pipeline completo de pré-treino."""
    
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    print("🚀 Iniciando pipeline de pré-treino HuMob")
    print(f"Device: {device}")
    print(f"Clusters: {n_clusters}")
    print(f"Epochs: {pretrain_epochs}")
    print(f"Sequence length: {sequence_length}")
    print("="*50)
    
    # 1. CENTROS ESTÁVEIS
    print("\n📍 ETAPA 1: Calculando centros estáveis...")
    centers = load_or_compute_centers(
        parquet_path=parquet_path,
        cities=["A"],
        n_clusters=n_clusters,
        save_path=f"centers_A_{n_clusters}.npy"
    ).to(device)
    
    print(f"✅ Centros prontos: {centers.shape}")
    
    # 2. PRÉ-TREINO EM A
    print("\n🏋️ ETAPA 2: Pré-treino na cidade A...")
    model, train_losses, val_losses = pretrain_on_city_A(
        parquet_path=parquet_path,
        centers=centers,
        device=device,
        n_epochs=pretrain_epochs,
        learning_rate=learning_rate,
        batch_size=batch_size,
        sequence_length=sequence_length,
        save_path="ckpt_A_warmup.pt"
    )
    
    # Plot das curvas
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses, label='Val')
    plt.xlabel('Época')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.title('Pré-treino em A')
    plt.grid(True)
    
    plt.subplot(1, 2, 2)
    plt.plot([model.weighted_fusion.w_r.item()], [model.weighted_fusion.w_e.item()], 'ro', markersize=10)
    plt.xlabel('w_r (estático)')
    plt.ylabel('w_e (dinâmico)')
    plt.title('Pesos da Fusão')
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig('pretrain_results.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    # 3. AVALIAÇÃO ZERO-SHOT EM D
    print("\n🎯 ETAPA 3: Avaliação zero-shot na cidade D...")
    mse_d, cell_error_d = evaluate_zero_shot_on_D(
        parquet_path=parquet_path,
        checkpoint_path="ckpt_A_warmup.pt",
        device=device,
        n_samples=5000,
        sequence_length=sequence_length
    )
    
    # 4. RELATÓRIO FINAL
    print("\n📊 RELATÓRIO FINAL")
    print("="*50)
    print(f"✅ Pré-treino concluído em {pretrain_epochs} épocas")
    print(f"📈 Loss final de treino: {train_losses[-1]:.4f}")
    print(f"📉 Loss final de validação: {val_losses[-1]:.4f}")
    print(f"🎯 MSE zero-shot em D: {mse_d:.4f}")
    print(f"📍 Erro médio em células: {cell_error_d:.2f}")
    print(f"⚖️ Pesos da fusão: w_r={model.weighted_fusion.w_r.item():.3f}, w_e={model.weighted_fusion.w_e.item():.3f}")
    print(f"💾 Checkpoint salvo: ckpt_A_warmup.pt")
    print(f"🎨 Gráficos salvos: pretrain_results.png")
    
    # Verifica se está tudo OK
    print("\n🔍 VERIFICAÇÕES:")
    converged = val_losses[-1] < val_losses[0] * 0.8
    reasonable_mse = mse_d < 1000
    reasonable_error = cell_error_d < 20
    
    print(f"  Convergiu? {'✅' if converged else '❌'} (loss diminuiu 20%+)")
    print(f"  MSE razoável? {'✅' if reasonable_mse else '❌'} (< 1000)")
    print(f"  Erro de células OK? {'✅' if reasonable_error else '❌'} (< 20 células)")
    
    all_good = converged and reasonable_mse and reasonable_error
    
    if all_good:
        print("\n🎉 PIPELINE EXECUTADO COM SUCESSO!")
        print("   Pronto para próximas etapas:")
        print("   - Fine-tune em D (opcional)")
        print("   - Rollout para submissão (dias 61-75)")
    else:
        print("\n⚠️ ALGUNS PROBLEMAS DETECTADOS")
        print("   Recomendações:")
        if not converged:
            print("   - Aumentar número de épocas ou ajustar LR")
        if not reasonable_mse:
            print("   - Verificar normalização dos dados")
        if not reasonable_error:
            print("   - Revisar cálculo de centros ou head")
    
    return {
        'model': model,
        'centers': centers,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'mse_zero_shot': mse_d,
        'cell_error': cell_error_d,
        'checkpoint_path': 'ckpt_A_warmup.pt'
    }

In [30]:
# ===================================================
# 🚀 EXECUÇÃO PRINCIPAL - ESTRATÉGIA DE PRÉ-TREINO
# ===================================================

print("🎯 Executando estratégia de pré-treino sugerida...")
print(f"Device: {device}")
print("Estratégia: A (pré-treino) → D (zero-shot) → D (fine-tune opcional)")
print("="*60)

# # EXECUTA O PIPELINE PRINCIPAL
# results = run_pretrain_pipeline(
#     parquet_path=parquet_file,
#     device=device,
#     n_clusters=1024,        # Estável e rápido
#     pretrain_epochs=6,      # Rápido para validar
#     sequence_length=48,     # 1 dia de histórico
#     batch_size=32,
#     learning_rate=2e-3
# )

results = run_pretrain_pipeline(
    parquet_path=parquet_file,
    device=device,
    n_clusters=64,         # ⬇️ Era 1024 → 256 (4x menos centros)
    pretrain_epochs=2,      # ⬇️ Era 6 → 2 (3x menos épocas)  
    sequence_length=6,     # ⬇️ Era 48 → 12 (4x menos histórico)
    batch_size=32,         
    learning_rate=1e-3     
)

# Verifica se o pré-treino foi bem-sucedido
if results['mse_zero_shot'] < 1000:  # Limite razoável
    print("\n✅ Pré-treino bem-sucedido! Modelo pronto para uso.")
    print("\n🎯 PRÓXIMOS PASSOS:")
    print(f"   1. Modelo final salvo em: {results['checkpoint_path']}")
    print("   2. Para submissão HuMob, use create_submission_data()")
    print("   3. Para rollout completo (15 dias), ajuste n_steps=48*15")
    print("   4. Lembre-se de discretizar coordenadas [0,199]")
else:
    print("\n❌ Pré-treino não convergiu bem.")
    print("💡 Sugestões:")
    print("   - Aumentar número de épocas")
    print("   - Ajustar learning rate")
    print("   - Verificar normalização dos POIs")
    print("   - Revisar sequência temporal do dataset")

print("\n" + "="*60)
print("🏁 Pipeline executado com sucesso!")
print("📁 Arquivos gerados:")
print("   - centers_A_1024.npy (centros estáveis)")
print("   - ckpt_A_warmup.pt (modelo pré-treinado)")
print("   - pretrain_results.png (gráficos)")

🎯 Executando estratégia de pré-treino sugerida...
Device: cuda
Estratégia: A (pré-treino) → D (zero-shot) → D (fine-tune opcional)
🚀 Iniciando pipeline de pré-treino HuMob
Device: cuda
Clusters: 64
Epochs: 2
Sequence length: 6

📍 ETAPA 1: Calculando centros estáveis...
🔄 Calculando 64 centros para cidades ['A']...
📊 Coletadas 88,405,298 coordenadas
🎲 Amostradas 200,000 coordenadas
⚙️ Executando K-Means...
💾 Centros salvos em: centers_A_64.npy
📏 Shape dos centros: (64, 2)
✅ Centros prontos: torch.Size([64, 2])

🏋️ ETAPA 2: Pré-treino na cidade A...
🏋️ Iniciando pré-treino na cidade A...
Parâmetros treináveis: 1,046,836

🔄 Época 1/2


Treino A: 94it [00:27,  3.44it/s, Loss=nan, LR=0.001000]      Traceback (most recent call last):
  File "/home/andersonc/anaconda3/envs/orion_ct/lib/python3.12/multiprocessing/queues.py", line 270, in _feed
    send_bytes(obj)
  File "/home/andersonc/anaconda3/envs/orion_ct/lib/python3.12/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/andersonc/anaconda3/envs/orion_ct/lib/python3.12/multiprocessing/connection.py", line 427, in _send_bytes
    self._send(header + buf)
  File "/home/andersonc/anaconda3/envs/orion_ct/lib/python3.12/multiprocessing/connection.py", line 384, in _send
    n = write(self._handle, buf)
        ^^^^^^^^^^^^^^^^^^^^^^^^
OSError: [Errno 9] Bad file descriptor
Traceback (most recent call last):
  File "/home/andersonc/anaconda3/envs/orion_ct/lib/python3.12/multiprocessing/queues.py", line 259, in _feed
    reader_close()
  File "/home/andersonc/anaconda3/envs/orion_ct/lib/python3.12/multiprocessi

KeyboardInterrupt: 