In [1]:
# ==== Import libraries ====
import math
import random
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F

RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)


<torch._C.Generator at 0x2171f717c90>

In [9]:
# ==== Check GPU/CUDA availability ====
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU count: {torch.cuda.device_count()}")
else:
    print("Training will use CPU")


PyTorch version: 2.9.1+cu130
CUDA available: True
CUDA version: 13.0
GPU device: NVIDIA GeForce RTX 4070 Ti SUPER
GPU count: 1


In [21]:
DATA_CSV = "data/GenomeCRISPR_+_strands.csv"
SEQ_LEN  = 23
VAL_FRAC = 0.10
TEST_FRAC= 0.10

seq_col   = "sequence"
cell_col  = "cellline"
phen_col  = "condition"
chr_col   = "chr"
strand_col= "strand"
screen_col= "screentype"
target_col= "log2fc"

# Read & keep only what we need
df = pd.read_csv(DATA_CSV, low_memory=False)
required_cols = [seq_col, cell_col, phen_col, chr_col, target_col]
optional_cols = [strand_col, screen_col]

# Check for required columns
missing = [c for c in required_cols if c not in df.columns]
if missing:
    raise KeyError(f"Missing expected columns: {missing}. Got: {list(df.columns)}")

# Check for optional columns (strand, screentype)
has_strand = strand_col in df.columns
has_screen = screen_col in df.columns
cols_to_keep = required_cols + ([strand_col] if has_strand else []) + ([screen_col] if has_screen else [])

df = df[cols_to_keep].copy()
df = df.dropna(subset=required_cols)

# Clean sequences: uppercase A/C/G/T and enforce length 23
df[seq_col] = df[seq_col].astype(str).str.upper().str.strip()
df = df[df[seq_col].str.len() == SEQ_LEN]
df = df[df[seq_col].str.match(r"^[ACGT]+$")]

# Factorize categoricals (one-liners)
cell_codes, cell_uniques = pd.factorize(df[cell_col].astype(str).str.strip(), sort=True)
phen_codes, phen_uniques = pd.factorize(df[phen_col].astype(str).str.strip(), sort=True)
# cast chr to string so 10/11/X/Y are handled uniformly
chr_codes,  chr_uniques  = pd.factorize(df[chr_col].astype(str).str.strip(),  sort=True)

# Factorize strand if available, otherwise create dummy
if has_strand:
    strand_codes, strand_uniques = pd.factorize(df[strand_col].astype(str).str.strip(), sort=True)
    n_strand = len(strand_uniques)
else:
    # Create dummy strand codes (all zeros) if not available
    strand_codes = np.zeros(len(df), dtype=np.int64)
    strand_uniques = np.array(["+"] if "+" in str(df.get(strand_col, "+").iloc[0] if len(df) > 0 else "+") else ["+"])
    n_strand = 1

# Factorize screentype if available, otherwise create dummy
if has_screen:
    screen_codes, screen_uniques = pd.factorize(df[screen_col].astype(str).str.strip(), sort=True)
    n_screen = len(screen_uniques)
else:
    # Create dummy screentype codes (all zeros) if not available
    screen_codes = np.zeros(len(df), dtype=np.int64)
    screen_uniques = np.array(["unknown"])
    n_screen = 1

n_cell, n_ph, n_chr = len(cell_uniques), len(phen_uniques), len(chr_uniques)

# One-hot the 23-mer sequences
BASE2IDX = {"A":0, "C":1, "G":2, "T":3}
def onehot_batch(seqs, L=SEQ_LEN):
    N = len(seqs)
    X = np.zeros((N, 4, L), dtype=np.float32)
    for i, s in enumerate(seqs):
        for j, ch in enumerate(s):
            X[i, BASE2IDX[ch], j] = 1.0
    return X

X_seq = onehot_batch(df[seq_col].tolist())
X_cell = cell_codes.astype(np.int64)
X_ph   = phen_codes.astype(np.int64)
X_chr  = chr_codes.astype(np.int64)
X_strand = strand_codes.astype(np.int64)
X_screen = screen_codes.astype(np.int64)
y      = df[target_col].astype(np.float32).to_numpy()

# Simple random split
idx_all = np.arange(len(df))
idx_train, idx_test = train_test_split(idx_all, test_size=TEST_FRAC, random_state=42)
idx_train, idx_val  = train_test_split(idx_train, test_size=VAL_FRAC/(1-TEST_FRAC), random_state=42)

def take(a, idx): return a[idx]
Xtr_seq, Xva_seq, Xte_seq = take(X_seq, idx_train), take(X_seq, idx_val), take(X_seq, idx_test)
Xtr_cel, Xva_cel, Xte_cel = take(X_cell, idx_train), take(X_cell, idx_val), take(X_cell, idx_test)
Xtr_ph,  Xva_ph,  Xte_ph  = take(X_ph,  idx_train), take(X_ph,  idx_val), take(X_ph,  idx_test)
Xtr_chr, Xva_chr, Xte_chr = take(X_chr, idx_train), take(X_chr, idx_val), take(X_chr, idx_test)
Xtr_str, Xva_str, Xte_str = take(X_strand, idx_train), take(X_strand, idx_val), take(X_strand, idx_test)
Xtr_scr, Xva_scr, Xte_scr = take(X_screen, idx_train), take(X_screen, idx_val), take(X_screen, idx_test)
y_tr,    y_va,    y_te    = take(y,     idx_train), take(y,     idx_val), take(y,     idx_test)

# Standardize targets using training set statistics
mu = y_tr.mean()
sigma = y_tr.std()
y_tr_norm = (y_tr - mu) / sigma
y_va_norm = (y_va - mu) / sigma
y_te_norm = (y_te - mu) / sigma

# Create full normalized array for easy indexing
y_norm = np.zeros_like(y)
y_norm[idx_train] = y_tr_norm
y_norm[idx_val] = y_va_norm
y_norm[idx_test] = y_te_norm

print(f"train={len(idx_train)}  val={len(idx_val)}  test={len(idx_test)}")
print(f"cells={n_cell}  phenotypes={n_ph}  chrs={n_chr}  strands={n_strand}  screentypes={n_screen}")
print(f"Target stats: mu={mu:.4f}, sigma={sigma:.4f}")


train=29452509  val=3681564  test=3681564
cells=420  phenotypes=34  chrs=279  strands=3  screentypes=2
Target stats: mu=-0.0878, sigma=0.8148


In [None]:
class EnhancedCrisprCNN(nn.Module):
    def __init__(self, base_channels=64, 
                 n_cell=420, n_phen=34, n_chr=301, n_strand=2, n_screen=1,
                 emb_dim=32, dropout=0.3):
        super().__init__()
        
        self.conv2d_1 = nn.Sequential(
            nn.Conv2d(1, base_channels, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(),
            nn.Dropout(dropout * 0.3)
        )
        
        self.conv2d_2 = nn.Sequential(
            nn.Conv2d(base_channels, base_channels * 2, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(),
            nn.Dropout(dropout * 0.3)
        )
        
        self.conv2d_3_1 = nn.Sequential(
            nn.Conv2d(base_channels * 2, base_channels * 2, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(),
            nn.Dropout(dropout * 0.3)
        )
        self.conv2d_3_2 = nn.Sequential(
            nn.Conv2d(base_channels * 2, base_channels * 2, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(base_channels * 2),
            nn.Dropout(dropout * 0.2)
        )
        
        self.conv2d_4 = nn.Sequential(
            nn.Conv2d(base_channels * 2, base_channels * 2, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(),
            nn.Dropout(dropout * 0.2)
        )
        
        self.conv_multi1 = nn.Conv2d(base_channels * 2, base_channels, kernel_size=(1, 1))
        self.conv_multi2 = nn.Conv2d(base_channels * 2, base_channels, kernel_size=(3, 3), padding=(1, 1))
        self.conv_multi3 = nn.Conv2d(base_channels * 2, base_channels, kernel_size=(5, 5), padding=(2, 2))
        
        seq_feat_dim = base_channels * 6
    
        self.cell_emb = nn.Embedding(n_cell, emb_dim)
        self.phen_emb = nn.Embedding(n_phen, emb_dim)
        self.chr_emb = nn.Embedding(n_chr, emb_dim)
        self.strand_emb = nn.Embedding(n_strand, emb_dim)
        self.screen_emb = nn.Embedding(n_screen, emb_dim)
        
        self.cell_phen_interaction = nn.Sequential(
            nn.Linear(emb_dim * 2, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.4)
        )
        
        self.cell_chr_interaction = nn.Sequential(
            nn.Linear(emb_dim * 2, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.4)
        )
        
        self.phen_chr_interaction = nn.Sequential(
            nn.Linear(emb_dim * 2, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.4)
        )
        
        self.triple_interaction = nn.Sequential(
            nn.Linear(emb_dim * 3, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.4)
        )
        
        self.screen_cell_interaction = nn.Sequential(
            nn.Linear(emb_dim * 2, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.4)
        )
        
        self.screen_phen_interaction = nn.Sequential(
            nn.Linear(emb_dim * 2, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.4)
        )
        
        self.screen_chr_interaction = nn.Sequential(
            nn.Linear(emb_dim * 2, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.4)
        )
        
        self.seq_proj1 = nn.Sequential(
            nn.Linear(seq_feat_dim, emb_dim * 2),
            nn.BatchNorm1d(emb_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout * 0.3)
        )
        
        self.seq_proj2 = nn.Sequential(
            nn.Linear(seq_feat_dim, emb_dim * 2),
            nn.BatchNorm1d(emb_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout * 0.3)
        )
        
        self.seq_proj3 = nn.Sequential(
            nn.Linear(seq_feat_dim, emb_dim * 2),
            nn.BatchNorm1d(emb_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout * 0.3)
        )
        
        interaction_features = emb_dim * 7
        total_features = (emb_dim * 6) + (emb_dim * 5) + interaction_features
        
        self.fusion1 = nn.Sequential(
            nn.Linear(total_features, 320),
            nn.BatchNorm1d(320),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        self.fusion2 = nn.Sequential(
            nn.Linear(320, 320),
            nn.BatchNorm1d(320),
            nn.ReLU(),
            nn.Dropout(dropout * 0.9)
        )
        
        self.fusion3 = nn.Sequential(
            nn.Linear(320, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout * 0.8)
        )
        
        self.fusion4 = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout * 0.7)
        )
        
        self.fusion5 = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5)
        )
        
        self.head = nn.Linear(64, 1)

    def forward(self, seq4x23, cell_idx, phen_idx, chr_idx, strand_idx, screen_idx):
        x = seq4x23.unsqueeze(1)
        
        x = self.conv2d_1(x)
        x = self.conv2d_2(x)
        
        residual = x
        x = self.conv2d_3_1(x)
        x = self.conv2d_3_2(x)
        x = F.relu(x + residual)
        
        x = self.conv2d_4(x)
        
        x1 = self.conv_multi1(x)
        x2 = self.conv_multi2(x)
        x3 = self.conv_multi3(x)
        
        x1_avg = F.adaptive_avg_pool2d(x1, 1).squeeze(-1).squeeze(-1)
        x1_max = F.adaptive_max_pool2d(x1, 1).squeeze(-1).squeeze(-1)
        x2_avg = F.adaptive_avg_pool2d(x2, 1).squeeze(-1).squeeze(-1)
        x2_max = F.adaptive_max_pool2d(x2, 1).squeeze(-1).squeeze(-1)
        x3_avg = F.adaptive_avg_pool2d(x3, 1).squeeze(-1).squeeze(-1)
        x3_max = F.adaptive_max_pool2d(x3, 1).squeeze(-1).squeeze(-1)
        
        x_seq = torch.cat([x1_avg, x1_max, x2_avg, x2_max, x3_avg, x3_max], dim=1)
        
        x_seq_proj1 = self.seq_proj1(x_seq)
        x_seq_proj2 = self.seq_proj2(x_seq)
        x_seq_proj3 = self.seq_proj3(x_seq)
        x_seq_proj = torch.cat([x_seq_proj1, x_seq_proj2, x_seq_proj3], dim=1)
        
        x_cell = self.cell_emb(cell_idx)
        x_phen = self.phen_emb(phen_idx)
        x_chr = self.chr_emb(chr_idx)
        x_strand = self.strand_emb(strand_idx)
        x_screen = self.screen_emb(screen_idx)
        
        cell_phen_inter = self.cell_phen_interaction(torch.cat([x_cell, x_phen], dim=1))
        cell_chr_inter = self.cell_chr_interaction(torch.cat([x_cell, x_chr], dim=1))
        phen_chr_inter = self.phen_chr_interaction(torch.cat([x_phen, x_chr], dim=1))
        triple_inter = self.triple_interaction(torch.cat([x_cell, x_phen, x_chr], dim=1))
        screen_cell_inter = self.screen_cell_interaction(torch.cat([x_screen, x_cell], dim=1))
        screen_phen_inter = self.screen_phen_interaction(torch.cat([x_screen, x_phen], dim=1))
        screen_chr_inter = self.screen_chr_interaction(torch.cat([x_screen, x_chr], dim=1))
        
        x = torch.cat([
            x_seq_proj,
            x_cell, x_phen, x_chr, x_strand, x_screen,
            cell_phen_inter, cell_chr_inter, phen_chr_inter, triple_inter, screen_cell_inter, screen_phen_inter, screen_chr_inter
        ], dim=1)
        
        x = self.fusion1(x)
        x = self.fusion2(x) + x
        x = self.fusion3(x)
        x = self.fusion4(x)
        x = self.fusion5(x)
        
        return self.head(x).squeeze(-1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = EnhancedCrisprCNN(
    base_channels=64,
    n_cell=n_cell, n_phen=n_ph, n_chr=n_chr, n_strand=n_strand, n_screen=n_screen,
    emb_dim=36,
    dropout=0.35
).to(device)
print("params:", sum(p.numel() for p in model.parameters()))

Using device: cuda
params: 1636641


In [None]:
criterion = nn.MSELoss(reduction="mean")

LEARNING_RATE = 2.2e-4
WEIGHT_DECAY = 8e-5
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3, min_lr=1e-6
)

EPOCHS = 5
BATCH  = 512

def iter_minibatches(indexes, batch_size=256, shuffle=True, epoch=None, use_normalized=True):
    """
    Generate minibatches. If use_normalized=True, returns normalized targets.
    """
    idx = np.asarray(indexes)
    if shuffle:
        if epoch is not None:
            rng = np.random.default_rng(42 + epoch)
        else:
            rng = np.random.default_rng()
        rng.shuffle(idx)
    
    target_array = y_norm if use_normalized else y
    
    for start in range(0, len(idx), batch_size):
        mb = idx[start:start+batch_size]
        yield (
            torch.from_numpy(X_seq[mb]).to(device),           # (B, 4, 23) float32
            torch.from_numpy(X_cell[mb]).long().to(device),   # (B,) int64
            torch.from_numpy(X_ph[mb]).long().to(device),     # (B,) int64
            torch.from_numpy(X_chr[mb]).long().to(device),    # (B,) int64
            torch.from_numpy(X_strand[mb]).long().to(device), # (B,) int64
            torch.from_numpy(X_screen[mb]).long().to(device), # (B,) int64
            torch.from_numpy(target_array[mb]).to(device),    # (B,) float32
        )

import time

for epoch in range(1, EPOCHS + 1):
    epoch_start = time.time()
    print(f"\n{'='*60}")
    print(f"Epoch {epoch:02d}/{EPOCHS}")
    print(f"{'='*60}")
    
    model.train()
    train_sum, n_train = 0.0, 0
    batch_count = 0
    total_train_batches = len(idx_train) // BATCH + (1 if len(idx_train) % BATCH != 0 else 0)
    last_update_time = time.time()
    
    for seq, cl, ph, ch, st, scr, tgt_norm in iter_minibatches(idx_train, batch_size=BATCH, shuffle=True, epoch=epoch, use_normalized=True):
        pred_norm = model(seq, cl, ph, ch, st, scr)
        loss = criterion(pred_norm, tgt_norm)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_sum += loss.item() * tgt_norm.size(0)
        n_train   += tgt_norm.size(0)
        batch_count += 1
        
        current_time = time.time()
        if batch_count % 100 == 0 or (current_time - last_update_time) >= 30:
            progress = (batch_count / total_train_batches) * 100
            elapsed = current_time - epoch_start
            batches_per_sec = batch_count / elapsed if elapsed > 0 else 0
            eta_seconds = (total_train_batches - batch_count) / batches_per_sec if batches_per_sec > 0 else 0
            eta_min = int(eta_seconds // 60)
            eta_sec = int(eta_seconds % 60)
            current_loss = train_sum / n_train if n_train > 0 else 0.0
            
            print(f"  Train: [{batch_count:5d}/{total_train_batches}] ({progress:5.1f}%) | "
                  f"Loss: {current_loss:.4f} | "
                  f"Speed: {batches_per_sec:.1f} batches/s | "
                  f"ETA: {eta_min:02d}:{eta_sec:02d}", flush=True)
            last_update_time = current_time

    train_loss = train_sum / n_train if n_train > 0 else 0.0
    train_time = time.time() - epoch_start
    
    print("  Running validation...", flush=True)
    val_start = time.time()
    model.eval()
    val_sum, n_val = 0.0, 0
    val_batch_count = 0
    total_val_batches = len(idx_val) // BATCH + (1 if len(idx_val) % BATCH != 0 else 0)
    
    with torch.no_grad():
        for seq, cl, ph, ch, st, scr, tgt_norm in iter_minibatches(idx_val, batch_size=BATCH, shuffle=False, use_normalized=True):
            pred_norm = model(seq, cl, ph, ch, st, scr)
            loss = criterion(pred_norm, tgt_norm)
            val_sum += loss.item() * tgt_norm.size(0)
            n_val   += tgt_norm.size(0)
            val_batch_count += 1
            
            if val_batch_count % 50 == 0:
                val_progress = (val_batch_count / total_val_batches) * 100
                print(f"    Val: [{val_batch_count:4d}/{total_val_batches}] ({val_progress:5.1f}%)", flush=True)

    val_loss = val_sum / n_val if n_val > 0 else 0.0
    val_time = time.time() - val_start
    total_time = time.time() - epoch_start
    
    scheduler.step(val_loss)
    
    print(f"\n  Epoch {epoch:02d} Summary:")
    print(f"    Train MSE (norm): {train_loss:.4f} | Val MSE (norm): {val_loss:.4f}")
    print(f"    Time: Train={train_time/60:.1f}min, Val={val_time/60:.1f}min, Total={total_time/60:.1f}min")
    print(f"{'='*60}\n", flush=True)


Epoch 01/5
  Train: [  100/57525] (  0.2%) | Loss: 1.0259 | Speed: 27.5 batches/s | ETA: 34:49
  Train: [  200/57525] (  0.3%) | Loss: 1.0036 | Speed: 33.5 batches/s | ETA: 28:29
  Train: [  300/57525] (  0.5%) | Loss: 1.0017 | Speed: 36.2 batches/s | ETA: 26:20
  Train: [  400/57525] (  0.7%) | Loss: 0.9966 | Speed: 37.7 batches/s | ETA: 25:13
  Train: [  500/57525] (  0.9%) | Loss: 0.9978 | Speed: 38.8 batches/s | ETA: 24:28
  Train: [  600/57525] (  1.0%) | Loss: 0.9895 | Speed: 39.6 batches/s | ETA: 23:57
  Train: [  700/57525] (  1.2%) | Loss: 0.9822 | Speed: 40.2 batches/s | ETA: 23:31
  Train: [  800/57525] (  1.4%) | Loss: 0.9775 | Speed: 40.7 batches/s | ETA: 23:12
  Train: [  900/57525] (  1.6%) | Loss: 0.9731 | Speed: 41.2 batches/s | ETA: 22:55
  Train: [ 1000/57525] (  1.7%) | Loss: 0.9689 | Speed: 41.5 batches/s | ETA: 22:41
  Train: [ 1100/57525] (  1.9%) | Loss: 0.9664 | Speed: 41.8 batches/s | ETA: 22:29
  Train: [ 1200/57525] (  2.1%) | Loss: 0.9654 | Speed: 42.0 bat

In [None]:
def mse_doc(yhat, y):
    yhat = np.asarray(yhat, dtype=np.float64)
    y    = np.asarray(y,    dtype=np.float64)
    n = y.size
    return float(np.sum((y - yhat)**2) / n)

def pearson_doc(x, y):
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    n      = x.size
    sum_x  = np.sum(x)
    sum_y  = np.sum(y)
    sum_xy = np.sum(x * y)
    sum_x2 = np.sum(x * x)
    sum_y2 = np.sum(y * y)
    denom = np.sqrt((n * sum_x2 - sum_x * sum_x) * (n * sum_y2 - sum_y * sum_y))
    return float((n * sum_xy - sum_x * sum_y) / denom) if denom != 0.0 else 0.0

def _ranks_avg(a):
    a = np.asarray(a, dtype=np.float64)
    order = np.argsort(a, kind="mergesort")
    ranks = np.empty_like(order, dtype=np.float64)
    sa = a[order]
    diff = np.concatenate(([True], sa[1:] != sa[:-1], [True]))
    idx = np.flatnonzero(diff)
    for s, e in zip(idx[:-1], idx[1:]):
        ranks[order[s:e]] = 0.5 * (s + e - 1) + 1.0
    return ranks

def spearman_doc(x, y):
    rx = _ranks_avg(x)
    ry = _ranks_avg(y)
    d  = rx - ry
    n  = rx.size
    denom = n * (n * n - 1.0)
    return float(1.0 - (6.0 * np.sum(d * d)) / denom) if denom != 0.0 else 0.0

def accuracy_direction(yhat, y):
    yhat = np.asarray(yhat, dtype=np.float64)
    y    = np.asarray(y,    dtype=np.float64)
    return float(np.mean((yhat >= 0) == (y >= 0)))

@torch.no_grad()
def preds_and_trues(indexes, batch_size=256):
    """
    Get predictions and true values. Predictions are de-normalized from normalized space.
    """
    model.eval()
    ps_norm, ys_norm = [], []
    for seq, cl, ph, ch, st, scr, tgt_norm in iter_minibatches(indexes, batch_size=batch_size, shuffle=False, use_normalized=True):
        out_norm = model(seq, cl, ph, ch, st, scr)
        ps_norm.append(out_norm.detach().cpu().numpy())
        ys_norm.append(tgt_norm.detach().cpu().numpy())
    
    yhat_norm = np.concatenate(ps_norm)
    y_norm = np.concatenate(ys_norm)

    yhat_real = yhat_norm * sigma + mu
    y_real = y_norm * sigma + mu
    
    return yhat_real, y_real

def eval_split(indexes):
    yhat, y = preds_and_trues(indexes, batch_size=256)
    return {
        "MSE": mse_doc(yhat, y),
        "Pearson": pearson_doc(yhat, y),
        "Spearman": spearman_doc(yhat, y),
        "Accuracy": accuracy_direction(yhat, y),
    }

print("Validation:", eval_split(idx_val))
print("Test:",       eval_split(idx_test))

Validation: {'MSE': 0.38142815814306374, 'Pearson': 0.6516409461316826, 'Spearman': 0.5244778973105289, 'Accuracy': 0.6940650766902328}
Test: {'MSE': 0.38301759073216407, 'Pearson': 0.6513514160595977, 'Spearman': 0.5244435260986505, 'Accuracy': 0.6938502223511529}
