In [2]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import os
from transformersFinal import TransformerAutoencoder  # Asegúrate que esta importación funcione

# Directory the notebook is executed from
# (expected to be src/TRANSFORMERS/good_architecture/).
notebook_dir = Path(".").resolve()

# parents[2] is the third ancestor of notebook_dir
# (good_architecture -> TRANSFORMERS -> src -> project root).
project_root = notebook_dir.parents[2]

# Absolute locations of the trained models, the input data and the plot output.
models_dir = project_root.joinpath("models", "transformers", "Good_model")
data_sequences_dir = project_root.joinpath("data", "sequences_ready")
data_original_dir = project_root.joinpath("data", "processed_data")
plot_output_dir = notebook_dir.joinpath("plots", "reconstructions")
# Create the plot folder (and any missing parents); no error if it already exists.
plot_output_dir.mkdir(parents=True, exist_ok=True)

def load_model(model_path, input_dim=5, seq_len=30, d_model=64, nhead=4, num_layers=2):
    """Build a TransformerAutoencoder, load saved weights into it, and set eval mode.

    Args:
        model_path: Path of the saved state dict to load.
        input_dim: Forwarded to ``TransformerAutoencoder`` as ``input_dim``.
        seq_len: Forwarded to ``TransformerAutoencoder`` as ``seq_len``.
        d_model: Forwarded to ``TransformerAutoencoder`` as ``d_model``.
        nhead: Forwarded to ``TransformerAutoencoder`` as ``nhead``.
        num_layers: Forwarded to ``TransformerAutoencoder`` as ``num_layers``.

    Returns:
        A ``(model, device)`` tuple; the device is CUDA when available,
        otherwise CPU.
    """
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(backend)

    hyperparameters = {
        "input_dim": input_dim,
        "seq_len": seq_len,
        "d_model": d_model,
        "nhead": nhead,
        "num_layers": num_layers,
    }
    network = TransformerAutoencoder(**hyperparameters).to(device)

    # NOTE(review): depending on the PyTorch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
    saved_weights = torch.load(model_path, map_location=device)
    network.load_state_dict(saved_weights)

    # Switch to inference mode before handing the model back.
    network.eval()
    return network, device

def evaluate_and_plot(company, model, device, data_path, original_csv_path, start_year=2018, end_year=2021, threshold_percentile=95, seq_len=30, input_dim=5):
    """Reconstruct the sequences with the model, flag anomalies and save a plot.

    The squared reconstruction error is computed on feature index 0 of the
    last timestep of every sequence (plotted as "Close"). Points whose error
    exceeds the given percentile of *all* errors are marked as anomalies. Only
    the samples whose date falls in ``[start_year, end_year]`` are drawn.
    The figure is written to ``plot_output_dir / f"{company}_evaluation.png"``.

    Args:
        company: Label used in the plot titles and in the output file name.
        model: Callable model; invoked as ``model(tensor)`` under ``no_grad``.
            Assumed to return a tensor shaped like its input -- TODO confirm.
        device: Torch device the input tensor is moved to.
        data_path: CSV whose values are reshaped to
            ``(-1, seq_len, input_dim)``.
        original_csv_path: CSV with a ``"Date"`` column used for the x axis.
        start_year: First calendar year (inclusive) shown in the plot.
        end_year: Last calendar year (inclusive) shown in the plot.
        threshold_percentile: Percentile of the reconstruction error used as
            the anomaly threshold.
        seq_len: Number of timesteps per sequence (previously hard-coded 30).
        input_dim: Number of features per timestep (previously hard-coded 5).

    Raises:
        ValueError: If the original CSV does not provide one date for every
            sequence after skipping the first ``seq_len - 1`` rows.
    """
    data = pd.read_csv(data_path).values.reshape(-1, seq_len, input_dim)
    tensor_data = torch.tensor(data, dtype=torch.float32).to(device)

    with torch.no_grad():
        output = model(tensor_data).cpu().numpy()

    # Compare the first feature at the final timestep of each window.
    real_close = tensor_data[:, -1, 0].cpu().numpy()
    reconstructed_close = output[:, -1, 0]
    reconstruction_error = (real_close - reconstructed_close) ** 2

    original_data = pd.read_csv(original_csv_path)
    dates = pd.to_datetime(original_data["Date"], utc=True)
    # Window i ends at row i + seq_len - 1, so its date is that row's date.
    first_date_row = seq_len - 1
    num_sequences = len(tensor_data)
    aligned_dates = dates[first_date_row:first_date_row + num_sequences].reset_index(drop=True)
    if len(aligned_dates) != num_sequences:
        # Fail early with context instead of an opaque boolean-index error below.
        raise ValueError(
            f"{company}: expected {num_sequences} dates after skipping "
            f"{first_date_row} rows, found {len(aligned_dates)}"
        )

    # The threshold is taken over the whole series, not only the plotted years.
    threshold = np.percentile(reconstruction_error, threshold_percentile)
    anomalies = reconstruction_error > threshold

    years = aligned_dates.dt.year
    # Plain NumPy mask so every array below is indexed positionally.
    mask = ((years >= start_year) & (years <= end_year)).to_numpy()
    plot_dates = aligned_dates[mask]
    plot_real_close = real_close[mask]
    plot_reconstructed_close = reconstructed_close[mask]
    plot_error = reconstruction_error[mask]
    plot_anomalies = anomalies[mask]

    fig, axes = plt.subplots(2, 1, figsize=(16, 8), sharex=True)

    # Top panel: real vs reconstructed series with anomalies highlighted.
    axes[0].plot(plot_dates, plot_real_close, label="Real (Close)", color='blue', linewidth=1)
    axes[0].plot(plot_dates, plot_reconstructed_close, label="Reconstructed", color='orange', linewidth=1)
    axes[0].scatter(plot_dates[plot_anomalies], plot_real_close[plot_anomalies], color='red', s=30, label="Anomalies")
    axes[0].set_title(f"{company} - Reconstructed vs Real Close Price")
    axes[0].legend()
    axes[0].grid()

    # Bottom panel: reconstruction error against the anomaly threshold.
    axes[1].plot(plot_dates, plot_error, label="Reconstruction Error", color='purple', linewidth=1)
    axes[1].axhline(y=threshold, color='red', linestyle='--', label=f"Threshold ({threshold_percentile}%)")
    axes[1].set_title(f"{company} - Reconstruction Error")
    axes[1].legend()
    axes[1].grid()
    axes[1].set_xlabel("Date")

    plt.tight_layout()
    fig.savefig(plot_output_dir / f"{company}_evaluation.png")
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)

# Run in the notebook: evaluate every company whose input files are all present.
companies = ["AAPL", "GOOGL", "MSFT", "NVDA", "TSLA"]

for company in companies:
    print(f"Processing {company}...")

    model_path = models_dir / f"{company.lower()}_transformer_autoencoder.pth"
    data_path = data_sequences_dir / f"{company}_data_sequences.csv"
    original_csv_path = data_original_dir / f"{company}_data.csv"

    # Checked in this order; only the first missing file is reported.
    required_files = (
        (model_path, f"❌ Modelo no encontrado: {model_path}"),
        (data_path, f"❌ Datos de secuencia no encontrados: {data_path}"),
        (original_csv_path, f"❌ CSV original no encontrado: {original_csv_path}"),
    )
    missing_message = next(
        (message for path, message in required_files if not path.exists()),
        None,
    )
    if missing_message is not None:
        print(missing_message)
        continue

    model, device = load_model(model_path)
    evaluate_and_plot(company, model, device, data_path, original_csv_path)




Processing AAPL...
Processing GOOGL...
Processing MSFT...
Processing NVDA...
Processing TSLA...
