In [None]:
import torch
from src.model.bert_classifier import BERTClassifier
from src.config import config, MODEL_CONFIG
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from model.data_loading import CustomDataModule

In [None]:
# Settings
# Label column name, read once from the model config and passed to the data module below.
target_col_name = MODEL_CONFIG.target_col_name
# Mini-batch size handed to CustomDataModule.
BATCH_SIZE = 16

In [None]:
# Release cached-but-unused GPU memory before building the data module.
torch.cuda.empty_cache()

# Frame that the model predictions are attached to in a later cell.
# NOTE(review): predictions come from `learning_dataset` via `dm`, but are assigned to this
# `merged` frame positionally — confirm both hold the same rows in the same order.
dataset = pd.read_parquet(config.data.merged)

dm = CustomDataModule(
    news_data_path=config.data.learning_dataset,
    input_ids_path=config.data.news.input_ids,
    masks_path=config.data.news.masks,
    batch_size=BATCH_SIZE,
    target_col_name=target_col_name,
)

In [None]:
# Load the fine-tuned classifier and score the examples served by the data module.
# Set this to the checkpoint you want to evaluate.
CHECKPOINT_PATH = "/path/to/checkpoint.ckpt"

model = BERTClassifier.load_from_checkpoint(CHECKPOINT_PATH)

# disable randomness, dropout, etc...
model.eval()

with torch.no_grad():
    # NOTE(review): the whole DataLoader is passed to the model in a single call — confirm that
    # BERTClassifier.forward accepts a DataLoader (otherwise iterate over batches and torch.cat
    # the per-batch logits). If CustomDataModule needs dm.setup(...) first, call it before this.
    logits = model(dm.predict_dataloader())
    # assumes logits are (n_examples, n_classes) — TODO confirm
    probs = logits.softmax(dim=1)
    # Row-wise max probability and its class index in one tensor op (replaces two
    # np.apply_along_axis Python-level loops); torch.max returns the first index on ties,
    # like np.argmax.
    max_probs_t, cls_preds_t = probs.max(dim=1)

# Explicitly move to host memory: numpy cannot convert a CUDA tensor implicitly.
max_probs = max_probs_t.cpu().numpy()
cls_preds = cls_preds_t.cpu().numpy()

In [None]:
# Attach per-row predictions to the merged dataset. Assignment is positional, not index-aligned.
# NOTE(review): assumes the prediction DataLoader yields rows in the same order as
# config.data.merged (no shuffling, same row set) — confirm.
assert len(max_probs) == len(cls_preds) == len(dataset), (
    f"{len(max_probs)} predictions for {len(dataset)} dataset rows"
)
dataset["max_probs"] = max_probs
dataset["cls_preds"] = cls_preds

# Change Over Time

# Analysis of a Single Forecast

In [None]:
# Inspect one forecast: print its row and the start of its body text, then plot the
# candlestick price series for the same ID starting at the row's Date.
# NOTE(review): `test_dat`, `pred_margin_mask` and `stocks` are not defined in any cell of
# this notebook — this cell only runs if they exist from elsewhere in the kernel. Define them above.
idx = 11  # position of the forecast to inspect within the masked rows
BODY_PREVIEW_CHARS = 750  # how many characters of `body` to print
N_PRICE_ROWS = 30  # how many price rows (from the forecast Date on) to plot

margin_rows = test_dat.loc[pred_margin_mask]
row = margin_rows.iloc[idx, :]
print(row)
print(row.body[:BODY_PREVIEW_CHARS])

pr_time, ticker, fcst = row[["Date", "ID", "Fcst"]]
# head() takes the first rows in frame order — assumes `stocks` is sorted by Date (TODO confirm).
df = stocks.query("(Date >= @pr_time) & (ID == @ticker)").head(N_PRICE_ROWS)
fig = go.Figure(
    data=[
        go.Candlestick(
            x=df["Date"],
            open=df["Open"],
            high=df["High"],
            low=df["Low"],
            close=df["Close"],
        )
    ]
)
fig.update_layout(
    xaxis_rangeslider_visible=False,
    title=f"{ticker} from {pr_time} (Fcst = {fcst})",
    xaxis_title="Date",
    yaxis_title="Price",
)
fig.show()

In [None]:
# Re-display the selected forecast row; a bare last expression uses Jupyter's rich display.
row

# Trading Performance

In [None]:
# Rows selected by `pred_margin_mask`, keeping only rows with no missing value in any column.
# NOTE(review): without `subset`, a NaN in ANY column drops the row, including columns unused
# below; consider subset=["Fcst", "CloseToCloseReturn", target_col_name]. Also, `test_dat` and
# `pred_margin_mask` are not defined in any cell of this notebook.
tmp = test_dat.loc[pred_margin_mask].dropna(axis="index", how="any")

In [None]:
# Preview the first five filtered rows.
tmp.head(n=5)

In [None]:
# Per-row strategy return: position = sign of the forecast (+1 long, -1 short, 0 flat),
# multiplied by the row's CloseToCloseReturn.
trades = tmp["CloseToCloseReturn"].mul(np.sign(tmp["Fcst"]))

In [None]:
# Average per-row strategy return (NaNs skipped). Rows with Fcst == 0 contribute a 0 return
# and are still counted in the denominator.
trades.mean(skipna=True)

In [None]:
# Target column vs. model forecast for the filtered rows. Uses the `target_col_name`
# set in the settings cell, so it stays consistent with the data module.
px.scatter(
    tmp,
    x=target_col_name,
    y="Fcst",
    title=f"Fcst vs. {target_col_name}",
)