In [5]:
# train_supervised.py
import os, json, time, datetime as dt
import numpy as np
import pandas as pd
import yfinance as yf
from joblib import dump
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier

from pymongo import MongoClient
from dotenv import load_dotenv

load_dotenv()
connection_string = os.getenv("MONGODB_URI")
if not connection_string:
    raise ValueError("MONGODB_URI não encontrada. Verifique seu arquivo .env")

client = MongoClient(connection_string)
db = client["investia"]
assets_collection = db["assets"] # Nome da sua coleção
print("Conectado ao MongoDB Atlas com sucesso!")

BASE = "../"
ART = f"{BASE}/artifacts"
os.makedirs(ART, exist_ok=True)

US_TICKERS = ["AAPL","MSFT","AMZN","GOOGL","META","NVDA","TSLA","JPM","V","PG","KO","PEP","NFLX","AMD","INTC","DIS"]
BR_TICKERS = ["PETR4.SA","VALE3.SA","ITUB4.SA","BBDC4.SA","ABEV3.SA","WEGE3.SA","BBAS3.SA","B3SA3.SA","RAIL3.SA","PRIO3.SA","LREN3.SA","GGBR4.SA"]
TICKERS = US_TICKERS + BR_TICKERS

def region_for(t):
    return "BR" if t.endswith(".SA") else "US"

def _normalize_ohlcv(df):
    if df is None or df.empty:
        return None
    if hasattr(df, "columns") and isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(-1)
    cols = {c: c.strip() for c in df.columns}
    df = df.rename(columns=cols)
    if "Close" not in df.columns:
        if "Adj Close" in df.columns:
            df["Close"] = df["Adj Close"]
        else:
            return None
    if "Volume" not in df.columns:
        df["Volume"] = np.nan
    return df

def fetch_single(ticker, start, end, interval="1d", tries=3, sleep_s=0.6):
    for _ in range(tries):
        try:
            t = yf.Ticker(ticker)
            df = t.history(start=start, end=end, interval=interval, auto_adjust=False, actions=False, repair=True)
            df = _normalize_ohlcv(df)
            if df is not None and not df.empty:
                idx = df.index
                if hasattr(idx, "tz") and idx.tz is not None:
                    df.index = idx.tz_convert("UTC").tz_localize(None)
                df = df.reset_index().rename(columns={"Date":"date"})
                df["ticker"] = ticker
                return df
        except Exception:
            pass
        time.sleep(sleep_s)
    try:
        df = yf.download(ticker, period="1y", interval=interval, auto_adjust=False, progress=False)
        df = _normalize_ohlcv(df)
        if df is not None and not df.empty:
            idx = df.index
            if hasattr(idx, "tz") and hasattr(idx, "tz_localize") and idx.tz is not None:
                df.index = idx.tz_convert("UTC").tz_localize(None)
            df = df.reset_index().rename(columns={"Date":"date"})
            df["ticker"] = ticker
            return df
    except Exception:
        pass
    return pd.DataFrame()

def fetch_features(tickers, period_days=420, interval="1d"):
    end = dt.date.today()
    start = end - dt.timedelta(days=period_days)
    rows = []
    
    for t in tickers:
        print(f"Processando: {t}")
        
        # --- PASSO 1: Buscar 'name' e 'setor' ---
        try:
            ticker_info = yf.Ticker(t).info
            name = ticker_info.get('shortName', ticker_info.get('longName', t))
            setor = ticker_info.get('sector', 'N/A')
        except Exception as e:
            print(f"Erro ao buscar info de {t}: {e}. Usando valores padrão.")
            name = t
            setor = "N/A"
        
        # --- PASSO 2: Buscar dados históricos ---
        df = fetch_single(t, start=start, end=end, interval=interval)
        
        # --- PASSO 3: Checagens iniciais ---
        if df is None or df.empty:
            continue
        if "Close" not in df.columns:
            continue
        df = df.dropna(subset=["Close"])
        if df.empty:
            continue
            
        # --- PASSO 4: CALCULAR AS FEATURES (AQUI!) ---
        # Este bloco estava faltando ou no lugar errado
        df["ret_1d"] = df["Close"].pct_change()
        df["ret_3m"] = df["Close"].pct_change(63)
        df["ret_6m"] = df["Close"].pct_change(126)
        df["vol_63"] = df["ret_1d"].rolling(63).std()
        df["volavg_21"] = df["Volume"].rolling(21).mean()
        
        # --- PASSO 5: Filtrar para a última linha válida ---
        # Agora o df tem as colunas novas
        df = df.dropna().tail(1) 
        if df.empty:
            # Isso pode acontecer se não houver dados suficientes
            # para as janelas de 63 ou 126 dias
            continue
            
        # --- PASSO 6: Adicionar à lista (AGORA É SEGURO) ---
        # Agora df["ret_3m"] com certeza existe
        rows.append({
            "ticker": t,
            "pais": region_for(t),
            
            "name": name,
            "setor": setor,
            
            "ret_3m": float(df["ret_3m"].iloc[0]),
            "ret_6m": float(df["ret_6m"].iloc[0]),
            "vol_63": float(df["vol_63"].iloc[0]),
            "volavg_21": float(df["volavg_21"].iloc[0]),
        })
        time.sleep(0.1) 
    
    if not rows:
        return pd.DataFrame()
    return pd.DataFrame(rows)

def synth_labels(df):
    v33 = df["vol_63"].quantile(0.33)
    v66 = df["vol_63"].quantile(0.66)
    profs = ["conservador","equilibrado","ousado"]
    samples = []
    for _, r in df.iterrows():
        for p in profs:
            label = 0
            if p == "conservador" and r["vol_63"] <= v33:
                label = 1
            if p == "equilibrado" and (r["vol_63"] > v33 and r["vol_63"] < v66):
                label = 1
            if p == "ousado" and r["vol_63"] >= v66:
                label = 1
            samples.append({
                "ticker": r["ticker"],
                "pais": r["pais"],
                "ret_3m": r["ret_3m"],
                "ret_6m": r["ret_6m"],
                "vol_63": r["vol_63"],
                "volavg_21": r["volavg_21"],
                "perfil_conservador": 1 if p=="conservador" else 0,
                "perfil_equilibrado": 1 if p=="equilibrado" else 0,
                "perfil_ousado": 1 if p=="ousado" else 0,
                "label": label
            })
    return pd.DataFrame(samples)

feats = fetch_features(TICKERS)
if feats.empty:
    raise RuntimeError("Sem dados de features. Verifique rede/yfinance.")
## feats.to_csv(f"{ART}/universe_features.csv", index=False)

print(f"Atualizando {len(feats)} ativos no MongoDB...")
data_to_save = feats.to_dict('records') 

for asset in data_to_save:
    query = {"ticker": asset["ticker"]} 
    update_data = {"$set": asset}     
    assets_collection.update_one(query, update_data, upsert=True)

print("Universo de ações salvo no MongoDB Atlas.")

data = synth_labels(feats)
feature_cols = ["ret_3m","ret_6m","vol_63","volavg_21","perfil_conservador","perfil_equilibrado","perfil_ousado"]
X = data[feature_cols].replace([np.inf,-np.inf], 0.0).fillna(0.0).values
y = data["label"].astype(int).values
scaler = StandardScaler()
Xs = scaler.fit_transform(X)
Xtr, Xte, ytr, yte = train_test_split(Xs, y, test_size=0.2, random_state=42, stratify=y)
clf = MLPClassifier(hidden_layer_sizes=(16,8), activation="relu", alpha=1e-4, max_iter=500, random_state=42)
clf.fit(Xtr, ytr)
dump(clf, f"{ART}/reco_model.joblib")
dump(scaler, f"{ART}/reco_scaler.joblib")
with open(f"{ART}/feature_cols.json","w",encoding="utf-8") as f:
    json.dump(feature_cols, f, ensure_ascii=False)
print(ART)

Conectado ao MongoDB Atlas com sucesso!
Processando: AAPL
Processando: MSFT
Processando: AMZN
Processando: GOOGL
Processando: META
Processando: NVDA
Processando: TSLA
Processando: JPM
Processando: V
Processando: PG
Processando: KO
Processando: PEP
Processando: NFLX
Processando: AMD
Processando: INTC
Processando: DIS
Processando: PETR4.SA
Processando: VALE3.SA
Processando: ITUB4.SA
Processando: BBDC4.SA
Processando: ABEV3.SA
Processando: WEGE3.SA
Processando: BBAS3.SA
Processando: B3SA3.SA
Processando: RAIL3.SA
Processando: PRIO3.SA
Processando: LREN3.SA
Processando: GGBR4.SA
Atualizando 28 ativos no MongoDB...
Universo de ações salvo no MongoDB Atlas.
..//artifacts


