In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torchsummary import summary
import pandas as pd
from torch.utils.data import Dataset, DataLoader
# from torch_summary import summary
from sklearn.preprocessing import MinMaxScaler,StandardScaler
import numpy as np
import talib
import matplotlib.pyplot as plt

In [2]:


class CustomDataset(Dataset):
    """Sliding-window dataset over a row-ordered feature matrix.

    Every sample is a block of ``window_size`` consecutive rows of ``X_data``
    paired with the label found on the LAST row of that block.

    :param X_data: array-like of shape (n_rows, n_features); converted to float32.
    :param y_data: array-like of length n_rows; converted to float32.
    :param window_size: number of consecutive rows per sample (must be >= 1).
    :param normalize: accepted for backward compatibility only; it is not used
        and no normalization is performed here.
    :raises ValueError: if ``window_size`` < 1 or if ``X_data`` and ``y_data``
        do not have the same number of rows.
    """

    def __init__(self, X_data, y_data, window_size=36, normalize=True):
        X_data = np.array(X_data, dtype=np.float32)
        y_data = np.array(y_data, dtype=np.float32)

        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        # A length mismatch would silently pair windows with the wrong labels.
        if len(X_data) != len(y_data):
            raise ValueError(
                f"X_data and y_data must have the same length, "
                f"got {len(X_data)} and {len(y_data)}"
            )

        # --- Create sliding windows: shape (n_windows, window_size, n_features) ---
        n_windows = len(X_data) - window_size + 1
        if n_windows > 0:
            self.X_data = np.stack(
                [X_data[start:start + window_size] for start in range(n_windows)]
            )
        else:
            # Fewer rows than one window: keep a well-shaped empty array.
            self.X_data = np.empty((0, window_size) + X_data.shape[1:], dtype=np.float32)

        # Align y_data: the label of a window is the label of its last row.
        self.y_data = y_data[window_size - 1:]

        self.length = len(self.X_data)

    def __getitem__(self, index):
        """Return ``(x, y)`` where ``x`` is the window transposed to (n_features, window_size)."""
        x = torch.tensor(self.X_data[index], dtype=torch.float32)
        y = torch.tensor(self.y_data[index], dtype=torch.float32)
        # Channels-first layout, as expected by nn.Conv1d.
        return x.T, y

    def __len__(self):
        """Number of windows available."""
        return self.length


In [3]:
class HybridGSRModel(nn.Module):
    """Two-branch network: (CNN -> stacked LSTM) alongside (CNN -> global average pool).

    The two branch outputs are concatenated and passed through a small dense
    head that ends in a Sigmoid, so the result lies in (0, 1).
    """

    def __init__(self, n_features=5, lstm_units=50, cnn1_filters=64, cnn2_filters=32):
        """
        Build both branches and the prediction head.

        :param n_features: The number of features at each time step (input channels).
        :param lstm_units: The number of hidden units in each LSTM layer.
        :param cnn1_filters: The number of filters for the 1st CNN branch.
        :param cnn2_filters: The number of filters for the 2nd CNN branch.
        """
        super().__init__()

        # --- Branch 1 CNN: three (conv k=3 -> ReLU -> pool/2) stages, then conv -> ReLU ---
        # kernel_size=3 with padding=1 keeps the sequence length (Keras 'same').
        branch1 = []
        channels_in = n_features
        for _ in range(3):
            branch1.extend([
                nn.Conv1d(in_channels=channels_in, out_channels=cnn1_filters, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2),
            ])
            channels_in = cnn1_filters
        branch1.extend([
            nn.Conv1d(in_channels=cnn1_filters, out_channels=cnn1_filters, kernel_size=3, padding=1),
            nn.ReLU(),
        ])
        self.cnn_branch1 = nn.Sequential(*branch1)

        # --- Branch 1 LSTM: three stacked layers fed with the CNN feature maps ---
        self.lstm_branch1 = nn.LSTM(
            input_size=cnn1_filters,
            hidden_size=lstm_units,
            num_layers=3,
            batch_first=True,  # (batch, seq_len, features)
        )

        # --- Branch 2 CNN: two (conv k=5 -> ReLU -> pool/2) stages ---
        # kernel_size=5 with padding=2 keeps the sequence length (Keras 'same').
        branch2 = []
        channels_in = n_features
        for _ in range(2):
            branch2.extend([
                nn.Conv1d(in_channels=channels_in, out_channels=cnn2_filters, kernel_size=5, padding=2),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2),
            ])
            channels_in = cnn2_filters
        self.cnn_branch2 = nn.Sequential(*branch2)

        # --- Head: consumes [LSTM state (lstm_units) | pooled CNN (cnn2_filters)] ---
        merged_width = lstm_units + cnn2_filters
        self.head = nn.Sequential(
            nn.Linear(merged_width, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """
        Run both branches and return the merged prediction.

        :param x: Input tensor of shape (batch_size, timesteps, n_features)
        :return: Tensor of shape (batch_size, 1) with values in (0, 1); no squeeze is applied.
        """
        # Conv1d wants channels first: (batch, n_features, timesteps).
        channels_first = x.transpose(1, 2)

        # Branch 1: CNN features, back to (batch, reduced_timesteps, cnn1_filters) for the LSTM.
        branch1_maps = self.cnn_branch1(channels_first)
        _, (hidden, _cell) = self.lstm_branch1(branch1_maps.transpose(1, 2))
        # hidden is (num_layers, batch, lstm_units); keep the top layer only.
        top_hidden = hidden[-1]

        # Branch 2: CNN features averaged over the time axis -> (batch, cnn2_filters).
        branch2_pooled = self.cnn_branch2(channels_first).mean(dim=2)

        # Merge both branches and predict.
        merged = torch.cat((top_hidden, branch2_pooled), dim=1)
        return self.head(merged)

In [4]:
class CNN_LSTM_Model(nn.Module):
    """Conv1d feature extractor followed by a 2-layer LSTM and a sigmoid output layer.

    :param input_features: number of input channels (features per time step).
    :param ws: window size; accepted but not used by the layers built here.
    :param num_classes: width of the final linear layer (default 1).
    """

    def __init__(self, input_features, ws, num_classes=1):
        super().__init__()
        n_filters = 256
        kernel = 3
        hidden_units = 256

        def conv_bn_relu(channels_in):
            # One length-preserving conv unit: Conv1d('same') -> BatchNorm -> ReLU.
            return [
                nn.Conv1d(in_channels=channels_in, out_channels=n_filters,
                          kernel_size=kernel, padding='same'),
                nn.BatchNorm1d(num_features=n_filters),
                nn.ReLU(),
            ]

        # Two stages of (conv unit, conv unit, max-pool). The pool uses kernel 3,
        # stride 1, padding 1, so the sequence length is unchanged.
        stages = []
        for stage_in in (input_features, n_filters):
            stages.extend(conv_bn_relu(stage_in))
            stages.extend(conv_bn_relu(n_filters))
            stages.append(nn.MaxPool1d(kernel_size=3, stride=1, padding=1))
        self.cnn_extractor = nn.Sequential(*stages)

        self.lstm = nn.LSTM(
            input_size=n_filters,
            hidden_size=hidden_units,
            num_layers=2,
            batch_first=True,
            dropout=0.5,  # applied between the two LSTM layers
        )

        # Output layer: dropout, then a single linear layer on the last LSTM
        # output (width `hidden_units`), then Sigmoid.
        self.fc = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=hidden_units, out_features=num_classes),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, input_features, seq_len) to (batch, num_classes)."""
        # CNN output is (batch, filters, seq_len); the LSTM wants (batch, seq_len, filters).
        sequence = self.cnn_extractor(x).transpose(1, 2)
        lstm_out, _ = self.lstm(sequence)
        # Only the output at the final time step feeds the classifier.
        last_step = lstm_out[:, -1, :]
        return self.fc(last_step)

In [5]:
# Load the merged market/sentiment table and index it by the parsed timestamps.
# The original "timestamp" column is kept alongside the new index.
df = pd.read_csv('data/ALL.csv')
df = df.set_index(pd.to_datetime(df["timestamp"]))

In [6]:
df

Unnamed: 0_level_0,timestamp,nvda_open,nvda_high,nvda_low,nvda_close,nvda_volume,amd_open,amd_high,amd_low,amd_close,...,btc_close,btc_volume,gold_open,gold_high,gold_low,gold_close,gold_volume,overall_sentiment_score,nvda_sentiment_score,nvda_relevance_score
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2020-01-02 09:00:00+00:00,2020-01-02 09:00:00+00:00,5.9180,5.9481,5.9180,5.9389,60000,46.63,46.78000,46.63,46.63,...,,,,,,,,,,
2020-01-02 10:00:00+00:00,2020-01-02 10:00:00+00:00,5.9389,5.9486,5.9372,5.9449,29920,46.64,47.00000,46.63,46.80,...,,,,,,,,,,
2020-01-02 11:00:00+00:00,2020-01-02 11:00:00+00:00,5.9486,5.9501,5.9436,5.9436,37800,46.85,46.92000,46.71,46.79,...,,,,,,,,,,
2020-01-02 12:00:00+00:00,2020-01-02 12:00:00+00:00,5.9329,5.9481,5.9242,5.9464,614480,46.76,46.88000,46.63,46.86,...,,,,,,,,,,
2020-01-02 13:00:00+00:00,2020-01-02 13:00:00+00:00,5.9456,5.9625,5.9247,5.9556,1660520,46.86,46.95000,46.63,46.90,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2025-06-30 19:00:00+00:00,2025-06-30 19:00:00+00:00,157.7411,158.6510,157.6611,157.8611,26770205,141.72,142.24000,141.22,141.89,...,107760.75,10943.0,3298.40,3309.47,3297.98,3308.65,10356.0,0.211782,0.053219,0.150592
2025-06-30 20:00:00+00:00,2025-06-30 20:00:00+00:00,157.8611,158.6600,155.9600,157.7811,72869000,141.89,144.01955,141.01,141.70,...,107528.57,9128.0,3308.65,3309.13,3302.22,3302.95,4943.0,0.358923,0.043446,0.043264
2025-06-30 21:00:00+00:00,2025-06-30 21:00:00+00:00,157.7811,158.6600,155.9600,157.7811,223935,141.65,143.77955,141.01,141.52,...,,,,,,,,0.190412,0.122056,0.208046
2025-06-30 22:00:00+00:00,2025-06-30 22:00:00+00:00,157.7611,157.7911,157.4911,157.5113,219009,141.52,141.75000,141.33,141.44,...,,,,,,,,0.000000,0.000000,0.000000


In [7]:
# df.columns = ['data','open','high','low','close','tickvol','volume','spread']

In [8]:
# Build the prediction target and the NVDA technical-indicator features.
# NOTE: columns are appended to `df` in the order written here, and that order
# becomes the feature order of X later on — do not reorder casually.

# Lookback shared by the SMA / Bollinger features below and by the
# sliding-window datasets created later in the notebook.
window_size = 14
# 1 when the bar closes above its open, else 0.
df['y'] = (df['nvda_open'] < df['nvda_close']).astype(int)
# shift(-1) moves every label up one row, so each row carries the label of the
# NEXT bar; the final row becomes NaN.
df['y'] = df['y'].shift(-1)

## Momentum Indicators
# RSI
df['rsi_14'] = talib.RSI(df['nvda_close'], timeperiod=14)

# Stochastic Oscillator
df['stoch_k'], df['stoch_d'] = talib.STOCH(
    df['nvda_high'], df['nvda_low'], df['nvda_close'],
    fastk_period=14, slowk_period=3, slowd_period=3
)

# MACD
df['macd'], df['macd_signal'], df['macd_hist'] = talib.MACD(
    df['nvda_close'],
    fastperiod=12, slowperiod=26, signalperiod=9
)

## Trend Indicators
# We calculate the SMA to create our custom feature
# NOTE(review): despite the name `sma_20`, the period used is `window_size` (14).
sma_20 = talib.SMA(df['nvda_close'], timeperiod=window_size)
# Relative distance of the close from its moving average.
df['price_to_sma'] = (df['nvda_close'] - sma_20) / sma_20

## Volatility Indicators
# ATR
df['atr_14'] = talib.ATR(
    df['nvda_high'], df['nvda_low'], df['nvda_close'],
    timeperiod=14
)

# Bollinger Bands
upper_bb, middle_bb, lower_bb = talib.BBANDS(
    df['nvda_close'],
    timeperiod=window_size, nbdevup=2, nbdevdn=2
)
# Create the "Bollinger Band Width" feature
df['bb_width'] = (upper_bb - lower_bb) / middle_bb

## Volume Indicators
# On-Balance Volume (OBV)
df['obv'] = talib.OBV(df['nvda_close'], df['nvda_volume'])
# Create the "OBV Slope" feature (using a 10-period change)
df['obv_slope'] = df['obv'].diff(periods=10)
# Disabled lag features, kept for reference:
# df['rsi_lag_1'] = df['rsi_14'].shift(1)
# df['rsi_lag_2'] = df['rsi_14'].shift(2)
# df['price_to_sma_lag_1'] = df['price_to_sma'].shift(1)
# df['bb_width_lag_1'] = df['bb_width'].shift(1)
# df['atr_lag_1'] = df['atr_14'].shift(1)
# df.drop

In [9]:
# Forward-fill the non-NVDA market columns so that rows where those
# instruments have no value carry the most recent earlier observation
# (only past values are propagated, so no future information leaks in).
# The column list is generated once instead of being written out twice, and
# `DataFrame.ffill()` replaces the deprecated `fillna(method='ffill')`.
FFILL_COLUMNS = [
    f"{asset}_{field}"
    for asset in ('amd', 'intc', 'spy', 'dia', 'iwm', 'btc', 'gold')
    for field in ('open', 'high', 'low', 'close', 'volume')
]
df[FFILL_COLUMNS] = df[FFILL_COLUMNS].ffill()

  'gold_high', 'gold_low', 'gold_close', 'gold_volume']].fillna(method='ffill')


In [10]:
# df["2022-03-03 9:00":][['overall_sentiment_score', 'nvda_sentiment_score',
#        'nvda_relevance_score', 'y', 'rsi_14', 'stoch_k', 'stoch_d', 'macd',
#        'macd_signal', 'macd_hist', 'price_to_sma', 'atr_14', 'bb_width', 'obv',
#        'obv_slope', 'rsi_lag_1', 'rsi_lag_2', 'price_to_sma_lag_1',
#        'bb_width_lag_1', 'atr_lag_1']].isnull().sum()

In [11]:
# Drop, in place, every row that still contains a NaN in any column. This
# includes the last row (its shifted label `y` is NaN) and the early rows where
# some columns have no data yet — presumably indicator warm-up and the
# sentiment/BTC/gold columns that are empty at the start of the file.
df.dropna(inplace=True)

In [12]:
df

Unnamed: 0_level_0,timestamp,nvda_open,nvda_high,nvda_low,nvda_close,nvda_volume,amd_open,amd_high,amd_low,amd_close,...,stoch_k,stoch_d,macd,macd_signal,macd_hist,price_to_sma,atr_14,bb_width,obv,obv_slope
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2022-03-03 09:00:00+00:00,2022-03-03 09:00:00+00:00,24.0095,24.0374,23.9216,23.9735,46180,117.4600,117.50000,116.900,117.40,...,58.092946,58.122658,0.071580,0.054932,0.016648,0.001353,0.231889,0.032379,6.500109e+09,70025360.0
2022-03-03 10:00:00+00:00,2022-03-03 10:00:00+00:00,23.9795,23.9975,23.9466,23.9596,9530,117.4000,117.49000,117.200,117.38,...,58.235596,58.160142,0.066487,0.057243,0.009244,-0.000089,0.218961,0.029836,6.500099e+09,26313440.0
2022-03-03 11:00:00+00:00,2022-03-03 11:00:00+00:00,23.9825,24.0484,23.9456,23.9596,50490,117.3000,117.78000,117.300,117.30,...,56.812660,57.713734,0.061739,0.058142,0.003597,-0.001209,0.210664,0.024157,6.500099e+09,-3122080.0
2022-03-03 12:00:00+00:00,2022-03-03 12:00:00+00:00,23.9945,24.0694,23.9596,24.0574,166470,117.5400,117.94000,117.450,117.65,...,59.804748,58.284334,0.065117,0.059537,0.005580,0.002022,0.203460,0.022117,6.500266e+09,-29258710.0
2022-03-03 13:00:00+00:00,2022-03-03 13:00:00+00:00,23.9715,24.2750,23.9715,24.2680,776990,117.8001,118.26000,117.432,118.26,...,70.100361,62.239256,0.083822,0.064394,0.019428,0.008960,0.210605,0.017972,6.501043e+09,-59969930.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2025-06-30 18:00:00+00:00,2025-06-30 18:00:00+00:00,157.8111,158.0161,156.8311,157.7361,14311666,142.3450,142.43000,141.400,141.71,...,74.995658,75.831163,0.562108,0.721528,-0.159420,-0.001190,1.188289,0.014730,1.978200e+10,-28124420.0
2025-06-30 19:00:00+00:00,2025-06-30 19:00:00+00:00,157.7411,158.6510,157.6611,157.8611,26770205,141.7200,142.24000,141.220,141.89,...,72.165016,75.129133,0.528390,0.682900,-0.154510,-0.000620,1.174118,0.014222,1.980877e+10,-1354215.0
2025-06-30 20:00:00+00:00,2025-06-30 20:00:00+00:00,157.8611,158.6600,155.9600,157.7811,72869000,141.8900,144.01955,141.010,141.70,...,69.890375,72.350350,0.489569,0.644234,-0.154664,-0.001231,1.283110,0.013996,1.973590e+10,-74223215.0
2025-06-30 21:00:00+00:00,2025-06-30 21:00:00+00:00,157.7811,158.6600,155.9600,157.7811,223935,141.6500,143.77955,141.010,141.52,...,68.435802,70.163731,0.453575,0.606102,-0.152527,-0.001334,1.384316,0.013753,1.973590e+10,-73691014.0


In [13]:
# df.drop(['high','low'],axis=1,inplace=True)

In [14]:
# Separate the label column from the predictor columns.
y = df['y']
X = df.drop(columns='y')

In [15]:
# Remove the raw `timestamp` column from the predictors; the same timestamps
# already serve as the DataFrame index.
X.drop('timestamp',axis=1,inplace=True)

In [16]:
# window_size = 24

# # Create list of windows
# windows = []
# for i in range(len(X) - window_size + 1):
#     # Take 30 consecutive rows for all columns, flatten them into 1D array
#     window = X.iloc[i:i + window_size].values.flatten()
#     windows.append(window)
# cols = []
# for col in X.columns:
#     for j in range(window_size):
#         cols.append(f"{col}_{j}")

# # Make new DataFrame
# X_new = pd.DataFrame(windows, columns=cols)
# y_new = y[window_size - 1:].reset_index(drop=True)
# X_new = X_new.reset_index(drop=True)
# # print(df_windows.head())

In [17]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report,roc_curve,auc
from xgboost import XGBClassifier

# Chronological split: the first 70% of rows train, the next 15% validate and
# the remainder tests. Rows are never shuffled, so later data stays out of training.
train_split, val_split, test_split = 0.7, 0.15, 0.15

# Row positions where the training and validation segments end.
n_rows = len(X)
train_end = int(n_rows * train_split)
val_end = int(n_rows * val_split) + train_end

# Slice features and labels with the same boundaries.
X_train, y_train = X.iloc[:train_end], y.iloc[:train_end]
X_val, y_val = X.iloc[train_end:val_end], y.iloc[train_end:val_end]
X_test, y_test = X.iloc[val_end:], y.iloc[val_end:]

# Min-max scaling: statistics come from the training rows only and are then
# applied unchanged to the validation and test rows.
scaler = MinMaxScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled, X_test_scaled = (scaler.transform(part) for part in (X_val, X_test))

# Report the segment sizes.
print(f"Total samples: {len(X)}")
print(f"Training samples: {len(X_train_scaled)}")
print(f"Validation samples: {len(X_val_scaled)}")
print(f"Testing samples: {len(X_test_scaled)}")

Total samples: 13325
Training samples: 9327
Validation samples: 1998
Testing samples: 2000


In [18]:
# Wrap each chronological split in a sliding-window dataset. Windows are built
# per split, so no window spans a train/val/test boundary.
trainDataset = CustomDataset(X_train_scaled, y_train, window_size=window_size)
valDataset = CustomDataset(X_val_scaled, y_val, window_size=window_size)
testDataset = CustomDataset(X_test_scaled, y_test, window_size=window_size)

batch_size = 32
# Training batches keep chronological order; the final partial batch is dropped.
# NOTE(review): consecutive windows overlap heavily, so shuffle=True may train
# better — left unchanged here to preserve the existing behaviour.
train_loader = DataLoader(dataset=trainDataset, batch_size=batch_size, shuffle=False, drop_last=True)

# Evaluation loaders are batched instead of batch_size=1. The model runs in
# eval mode and the loop weights each batch loss by its size, so the metrics are
# the same but evaluation is far faster. drop_last=False keeps every sample.
eval_batch_size = 256
val_loader = DataLoader(dataset=valDataset, batch_size=eval_batch_size, shuffle=False, drop_last=False)
test_loader = DataLoader(dataset=testDataset, batch_size=eval_batch_size, shuffle=False, drop_last=False)

In [19]:
# Take the feature tensor (element 0 of the (x, y) pair) from the first
# training batch; it is used below as a sample input.
temp = next(iter(train_loader))[0]

In [20]:
temp.shape

torch.Size([32, 55, 14])

In [21]:
# Choose the device BEFORE building the model so that the model and the sample
# batch land on the same device, and so the cell also runs on CPU-only machines
# (the previous hard-coded 'cuda' crashed there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The channel count is read from the sample batch, laid out as
# (batch, features, window), instead of being hard-coded.
n_input_features = temp.shape[1]
model = CNN_LSTM_Model(n_input_features, window_size).to(device)
# NOTE: HybridGSRModel.forward permutes its input and therefore expects
# (batch, window, features); transpose the batches before swapping it in here.

# Smoke test: a single forward pass in eval mode, without gradient tracking.
model.eval()
with torch.no_grad():
    output = model(temp.to(device))
print("\n--- Test Successful ---")
print(f"Input tensor shape:  {temp.shape} (Batch x feature x window size)")
print(f"Output tensor shape: {output.shape}")



--- Test Successful ---
Input tensor shape:  torch.Size([32, 55, 14]) (Batch x feature x window size)
Output tensor shape: torch.Size([32, 1])


In [22]:
# Binary cross-entropy on probabilities; it pairs with the Sigmoid that ends
# the model, so the model outputs are already in (0, 1).
criterion = nn.BCELoss()
# Adam over all model parameters, with a small weight decay for regularization.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
# Alternative optimizer, currently disabled:
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01,momentum=0.9, weight_decay=1e-5)

In [23]:
from sklearn.metrics import precision_score, recall_score, f1_score,roc_auc_score


In [None]:
def evaluate(model, loader, criterion, device):
    """Evaluate `model` on every batch of `loader` and compute binary-classification metrics.

    The model is switched to eval mode and gradients are disabled.

    :param model: module mapping a feature batch to probabilities of shape (batch, 1).
    :param loader: iterable yielding ``(x_batch, y_batch)`` pairs.
    :param criterion: loss taking ``(outputs, targets)``, both shaped (batch, 1).
    :param device: device the batches are moved to.
    :return: ``(metrics, labels, preds, probs)``. ``metrics`` is a dict with the
        keys loss, acc, precision, recall, f1 and auc; the other three are 1-D
        NumPy arrays with one entry per evaluated sample.
    :raises ValueError: if `loader` yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    prob_chunks, label_chunks = [], []

    with torch.no_grad():
        for x_batch, y_batch in loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device).unsqueeze(1)
            outputs = model(x_batch)
            # Weight the batch-mean loss by batch size so the total is per-sample.
            loss_sum += criterion(outputs, y_batch).item() * x_batch.size(0)
            prob_chunks.append(outputs.cpu().numpy().ravel())
            label_chunks.append(y_batch.cpu().numpy().ravel())

    if not prob_chunks:
        raise ValueError("evaluate() received a loader with no batches")

    probs = np.concatenate(prob_chunks)
    labels = np.concatenate(label_chunks)
    preds = (probs >= 0.5).astype(float)

    metrics = {
        "loss": loss_sum / len(labels),
        "acc": (preds == labels).mean(),
        # zero_division=0 yields the same 0.0 as the default when no positives
        # are predicted, but without the repeated UndefinedMetricWarning.
        "precision": precision_score(labels, preds, zero_division=0),
        "recall": recall_score(labels, preds, zero_division=0),
        "f1": f1_score(labels, preds, zero_division=0),
        "auc": roc_auc_score(labels, probs),
    }
    return metrics, labels, preds, probs


METRIC_KEYS = ("loss", "acc", "precision", "recall", "f1", "auc")

num_epochs = 20
for epoch in range(num_epochs):
    # --- Training ---
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device).unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * x_batch.size(0)
        samples_seen += x_batch.size(0)
    # Average over the samples actually processed: train_loader uses
    # drop_last=True, so dividing by the dataset size would understate the loss.
    train_loss = running_loss / max(samples_seen, 1)

    # --- Validation ---
    val_metrics, val_labels, val_preds, val_probs = evaluate(model, val_loader, criterion, device)
    val_loss, val_acc, val_precision, val_recall, val_f1, val_auc = (
        val_metrics[key] for key in METRIC_KEYS
    )

    # --- Test ---
    # Reported every epoch for monitoring only; model or epoch selection
    # should be based on the validation metrics, not on these numbers.
    test_metrics, test_labels, test_preds, test_probs = evaluate(model, test_loader, criterion, device)
    test_loss, test_acc, test_precision, test_recall, test_f1, test_auc = (
        test_metrics[key] for key in METRIC_KEYS
    )

    # --- Print results ---
    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f" Train Loss: {train_loss:.4f}")
    print(f" Val   -> Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | "
          f"Precision: {val_precision:.4f} | Recall: {val_recall:.4f} | "
          f"F1: {val_f1:.4f} | AUC: {val_auc:.4f}")
    print(f" Test  -> Loss: {test_loss:.4f} | Acc: {test_acc:.4f} | "
          f"Precision: {test_precision:.4f} | Recall: {test_recall:.4f} | "
          f"F1: {test_f1:.4f} | AUC: {test_auc:.4f}")
    print("-" * 90)


Epoch [1/20]
 Train Loss: 0.6933
 Val   -> Loss: 0.6901 | Acc: 0.5395 | Precision: 0.3333 | Recall: 0.0011 | F1: 0.0022 | AUC: 0.4970
 Test  -> Loss: 0.6914 | Acc: 0.5325 | Precision: 0.4762 | Recall: 0.2200 | F1: 0.3010 | AUC: 0.4961
------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch [2/20]
 Train Loss: 0.6897
 Val   -> Loss: 0.6901 | Acc: 0.5395 | Precision: 0.0000 | Recall: 0.0000 | F1: 0.0000 | AUC: 0.5122
 Test  -> Loss: 0.6901 | Acc: 0.5425 | Precision: 0.0000 | Recall: 0.0000 | F1: 0.0000 | AUC: 0.5119
------------------------------------------------------------------------------------------
Epoch [3/20]
 Train Loss: 0.6891
 Val   -> Loss: 0.7006 | Acc: 0.5285 | Precision: 0.4588 | Recall: 0.1402 | F1: 0.2148 | AUC: 0.5090
 Test  -> Loss: 0.7067 | Acc: 0.5164 | Precision: 0.4366 | Recall: 0.1969 | F1: 0.2714 | AUC: 0.4864
------------------------------------------------------------------------------------------
Epoch [4/20]
 Train Loss: 0.6873
 Val   -> Loss: 0.6946 | Acc: 0.5370 | Precision: 0.4483 | Recall: 0.0285 | F1: 0.0536 | AUC: 0.4906
 Test  -> Loss: 0.6943 | Acc: 0.5335 | Precision: 0.4511 | Recall: 0.0913 | F1: 0.1519 | AUC: 0.5184
------------------------------------------------------------------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch [5/20]
 Train Loss: 0.6847
 Val   -> Loss: 0.6949 | Acc: 0.5385 | Precision: 0.4483 | Recall: 0.0142 | F1: 0.0276 | AUC: 0.4944
 Test  -> Loss: 0.6899 | Acc: 0.5425 | Precision: 0.0000 | Recall: 0.0000 | F1: 0.0000 | AUC: 0.5569
------------------------------------------------------------------------------------------
Epoch [6/20]
 Train Loss: 0.6841
 Val   -> Loss: 0.6911 | Acc: 0.5401 | Precision: 0.5000 | Recall: 0.0088 | F1: 0.0172 | AUC: 0.5133
 Test  -> Loss: 0.6917 | Acc: 0.5425 | Precision: 0.5000 | Recall: 0.0011 | F1: 0.0022 | AUC: 0.4983
------------------------------------------------------------------------------------------
