In [1]:
import argparse
from datetime import datetime
import os
import pandas as pd

from utils import seed_everything, get_device, check_max_len, clean_gpu
from utils_config import get_model_config
from trainer import Trainer
from utils import prepare_datasets

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run-level settings shared by the cells below.
seed, device = 2, 0          # both are handed to Trainer(...)
target = 'mirna'
experiment_name = 'jupyter'
verbose, kmer = True, 1

In [3]:
# Model and data hyper-parameters used when building datasets and the model.
max_len, batch_size = 1024, 32
d_model, n_layer = 64, 4
rc = False
is_trained = True
pooling_mode_target = 'mean'
is_convblock, is_cross_attention = False, True
rna_model = 'rnabert'
is_pretrained = False

In [4]:
# Backbone identifier and the checkpoint name later passed to trainer.load_model.
model_name, load_pretrain_name = 'tthymba', 'both__rnabert'


In [5]:
# Build the project Trainer from the run-level settings (seed, device, experiment name, verbosity).
trainer = Trainer(
    seed=seed,
    device=device,
    experiment_name=experiment_name,
    verbose=verbose,
)

INFO:utils:Seeds set to 2.
✅ Logging setup complete.


In [39]:
# NOTE(review): `max_len` is replaced by the value check_max_len returns, so re-running this
# cell feeds the already-adjusted value back in. Re-run the hyper-parameter cell first if needed.
max_len = check_max_len(max_len, model_name)

# Train / test frames. read_pickle can execute code embedded in the file: only load trusted pickles.
df = pd.read_pickle('./data/df_train_final.pkl')
df_test = pd.read_pickle('./data/df_test_final.pkl')

INFO:utils:Max length set to 1018 for model tthymba


In [41]:
# Load miRBase mature miRNA sequences from a FASTA file into parallel header / sequence lists.
file_path = './data/mature.fa'
headers = []
sequences = []
with open(file_path, 'r') as file:
    current_header = None  # header of the record being read; None until the first '>' line
    chunks = []            # sequence lines collected for the current record
    for line in file:
        if line.startswith('>'):
            # Close the previous record even when its sequence is empty, so that
            # `headers` and `sequences` always stay the same length and aligned.
            if current_header is not None:
                headers.append(current_header)
                sequences.append(''.join(chunks))
            current_header = line.strip()[1:]
            chunks = []
        elif current_header is not None:
            # Text before the first header belongs to no record and is skipped.
            chunks.append(line.strip())
    if current_header is not None:
        headers.append(current_header)
        sequences.append(''.join(chunks))

df_miRBase = pd.DataFrame({'Header': headers, 'miRNA': sequences})
# Split each header on spaces once; the first two fields become the ID columns.
header_fields = df_miRBase['Header'].str.split(' ', expand=True)
df_miRBase['query_id'] = header_fields.iloc[:, 0]
df_miRBase['gene_id_MI'] = header_fields.iloc[:, 1]


In [44]:
# Expose the first header field under the name `miRNA_ID`, the key used to merge with df / df_test.
df_miRBase['miRNA_ID'] = df_miRBase['query_id']

In [49]:
def _attach_mirna_sequence(frame, lookup):
    """Left-join the miRNA sequence onto `frame` by `miRNA_ID` and mirror it into `target`.

    A pre-existing `miRNA` column is dropped first so re-running the cell on frames that
    were already enriched (this cell writes back to the files it was loaded from) does not
    silently change the join keys. `validate='many_to_one'` raises if the lookup contains
    duplicated IDs instead of silently duplicating rows.
    """
    merged = pd.merge(
        frame.drop(columns=['miRNA'], errors='ignore'),
        lookup,
        on='miRNA_ID',
        how='left',
        validate='many_to_one',
    )
    merged['target'] = merged['miRNA']
    return merged


mirna_lookup = df_miRBase[['miRNA_ID', 'miRNA']]
df = _attach_mirna_sequence(df, mirna_lookup)
df_test = _attach_mirna_sequence(df_test, mirna_lookup)

# NOTE(review): these are the same paths the frames were read from, so the inputs are overwritten.
df.to_pickle('./data/df_train_final.pkl')
df_test.to_pickle('./data/df_test_final.pkl')

In [51]:
# Sequence length per row, then drop circRNAs longer than max_len.
df['length'] = df['circRNA'].apply(len)
df_test['length'] = df_test['circRNA'].apply(len)

# .copy() makes the filtered frames independent objects, so the column assignment below
# does not write into a slice of the unfiltered frame (SettingWithCopyWarning).
df = df[df['length'] <= max_len].copy()
df_test = df_test[df_test['length'] <= max_len].copy()

# Sum of the per-row `sites` values (train frame only).
df['sum_sites'] = df['sites'].apply(sum)

In [53]:
# Build the datasets and register one dataloader per split with the trainer.
train_dataset, valid_dataset, test_dataset, extra_dataset = prepare_datasets(
    df=df,
    df_test=df_test,
    max_len=max_len + 2,  # 2 for special tokens (CLS and EOS)
    target=target,
    seed=seed,
    kmer=kmer,  # use the notebook-level setting instead of a second hard-coded 1
    # df_extra=df_test,
)
trainer.set_dataloader(train_dataset, part=0, batch_size=batch_size)
trainer.set_dataloader(valid_dataset, part=1, batch_size=batch_size)
trainer.set_dataloader(test_dataset, part=2, batch_size=batch_size)

# Step 4. Configure Model
print('[Step 4] Configuring Model for training')
config = get_model_config(
    model_name=model_name,
    d_model=d_model,
    n_layer=n_layer,
    verbose=verbose,
    rc=rc,
    vocab_size=train_dataset.vocab_size,
)

[Step 4] Configuring Model for training
- Model: tthymba
- d_model: 64
- n_layer: 4


In [54]:
# Instantiate the model inside the trainer from `config` and the architecture switches set above.
trainer.define_model(
    config=config,
    model_name=model_name,
    pretrain=is_pretrained,
    pooling_mode_target=pooling_mode_target,
    is_convblock=is_convblock,
    is_cross_attention=is_cross_attention,
)

Model 'tthymba' initialized. Pretraining mode: False


In [55]:
# Configure the target-side model for `target`; presumably `rna_model` selects which pretrained
# RNA encoder is used (confirm in Trainer.set_pretrained_target).
trainer.set_pretrained_target(target=target, rna_model=rna_model)

Target model for mirna set with projection dimension 120


Target model for mirna set with projection dimension 120


In [56]:
# Debug check: display the embedding weights BEFORE load_model, to compare with the same
# inspection made after loading.
trainer.model.embedding.word_embeddings.weight

Parameter containing:
tensor([[ 1.1242,  0.5686,  0.0561,  ..., -0.0315,  0.2471, -2.5913],
        [-0.4575,  0.5583,  0.3818,  ..., -2.0080, -1.2144, -0.9142],
        [ 0.1006, -0.6440,  1.0175,  ..., -0.6536, -0.4584, -1.7906],
        ...,
        [-1.7067,  0.0418, -0.3866,  ...,  0.9492, -1.8889,  0.6507],
        [ 1.5072,  1.2297, -1.5661,  ...,  1.6412, -0.6571,  1.0435],
        [ 1.0192, -0.3668,  0.7323,  ...,  1.9918, -0.4954, -0.5716]],
       device='cuda:0', requires_grad=True)

In [57]:
# Load the checkpoint named by `load_pretrain_name` into the trainer's model.
# NOTE(review): the recorded output of this cell reports a state_dict size mismatch for
# backbone.up1.proj.weight / backbone.up2.proj.weight (checkpoint [64, 64, 1] vs. model
# [64, 128, 1]) and execution continued, so the results below may come from a model without
# these weights. Confirm the architecture switches match the checkpoint before trusting them.
trainer.load_model(pretrain=is_pretrained, load_pretrain_name=load_pretrain_name, verbose=True)

X Error while loading the model: Error(s) in loading state_dict for ModelWrapper:
	size mismatch for backbone.up1.proj.weight: copying a param with shape torch.Size([64, 64, 1]) from checkpoint, the shape in current model is torch.Size([64, 128, 1]).
	size mismatch for backbone.up2.proj.weight: copying a param with shape torch.Size([64, 64, 1]) from checkpoint, the shape in current model is torch.Size([64, 128, 1]).


In [58]:
# Debug check: display the embedding weights AFTER load_model, to compare with the values shown before.
trainer.model.embedding.word_embeddings.weight

Parameter containing:
tensor([[ 1.1030, -0.1945, -0.9553,  ...,  1.0378,  0.0552, -0.8059],
        [-0.1885, -1.2301,  0.2224,  ..., -1.0126, -0.5797,  0.6166],
        [ 0.6543,  1.1731, -1.3243,  ..., -0.1927,  0.4507,  1.3230],
        ...,
        [-0.2757,  1.2168, -0.7317,  ...,  0.1693, -0.2110,  0.6153],
        [-0.2423,  0.1703,  0.5274,  ..., -0.1308, -0.2674,  0.1774],
        [ 1.2349,  1.8327, -0.7270,  ...,  1.0371,  0.2363, -1.4482]],
       device='cuda:0', requires_grad=True)

In [59]:
# Switches set directly on the trainer before running inference.
# NOTE(review): `rc` and `verbose` repeat the notebook-level `rc` / `verbose` values as literals.
trainer.rc = False
trainer.task = 'both'
trainer.verbose = True

In [60]:
# Run inference on the test loader; `results` is consumed later by flatten_result_dict_per_sample.
results = trainer.inference(data_loader=trainer.test_loader)

In [61]:
# NOTE(review): the recorded output shows this raising
# "ValueError: All arrays must be of the same length", i.e. the values in `results` differ in
# length. `df_results` is rebuilt in a later cell from flatten_result_dict_per_sample + a merge.
# results.pop('lengths_sites')
df_results = pd.DataFrame(results)

ValueError: All arrays must be of the same length

# Evaluate on the test loader; as the last expression of the cell, any return value is displayed.
trainer.evaluate(trainer.test_loader)

In [None]:
# Display the trainer's `best_threshold_site` attribute (presumably the site-level decision
# threshold chosen by the trainer; confirm where it is set).
trainer.best_threshold_site

In [None]:
import torch
import pandas as pd

def flatten_result_dict_per_sample(result_dict, sequences=None):
    """
    Build a DataFrame with one row per sample from an inference result dict.

    Parameters:
        result_dict: dict with keys 'binding', 'sites' and 'lengths'. Only element 0 of
            each value is read; that element is indexed per sample and each item must
            support ``.detach().cpu().tolist()`` (or ``.item()`` for lengths).
        sequences: optional positional collection of circRNA strings; ``sequences[i]``
            is stored in the ``circRNA`` column of sample ``i``.

    Returns:
        pandas.DataFrame with columns ``sample_id``, ``binding_logits``, ``site_logits``,
        ``length`` and, when ``sequences`` is given, ``circRNA``.
    """
    per_sample_binding = result_dict['binding'][0]
    per_sample_sites = result_dict['sites'][0]
    per_sample_lengths = result_dict['lengths'][0]

    def _as_python_list(tensor):
        # Detach and move to CPU before converting to nested Python lists.
        return tensor.detach().cpu().tolist()

    records = []
    for idx in range(len(per_sample_binding)):
        record = dict(
            sample_id=idx,
            binding_logits=_as_python_list(per_sample_binding[idx]),
            site_logits=_as_python_list(per_sample_sites[idx]),
            length=int(per_sample_lengths[idx].item()),
        )
        if sequences is not None:
            record["circRNA"] = sequences[idx]
        records.append(record)

    return pd.DataFrame(records)


def get_iou(pred_mask, true_mask):
    """Return intersection-over-union of two masks, or 0.0 when their union is empty."""
    overlap = np.logical_and(pred_mask, true_mask).sum()
    combined = np.logical_or(pred_mask, true_mask).sum()
    if combined > 0:
        return overlap / combined
    return 0.0

In [None]:
# One row per predicted sample. Sequences are attached by POSITION: sample i receives the
# circRNA of df_test row i. NOTE(review): this assumes the test loader yields samples in
# df_test order with nothing dropped or shuffled; confirm against prepare_datasets.
df_samples = flatten_result_dict_per_sample(results, sequences=df_test['circRNA'].values)

In [None]:
# df_samples row i was built from df_test row i (sequences were passed positionally), so the
# two frames are joined on row position. Joining on the circRNA string instead would pair each
# prediction with every df_test row sharing that sequence, mixing up labels.
# The sample-side `circRNA` copy is dropped so the column keeps its plain name; the `length`
# column that exists on both sides is suffixed by pandas' default merge suffixes.
df_results = pd.merge(
    df_samples.drop(columns=['circRNA'], errors='ignore'),
    df_test.reset_index(drop=True),
    left_index=True,
    right_index=True,
    how='outer',
)

In [None]:
def linear_binding_site_plot_with_overlap(df, sample_idx=0, threshold=0.5, min_span_len=20):
    """
    Plot predicted vs. true binding-site spans along one sequence and return their IoU.

    NOTE(review): a function with this same name is defined again in a later cell of this
    notebook (smaller figure, title without the IoU); whichever cell ran last is in effect.

    Parameters:
        df: DataFrame whose rows provide "site_logits", "sites", "isoform_ID" and "miRNA_ID".
        sample_idx: positional row index (used with ``df.iloc``).
        threshold: probability cut-off for calling a position positive.
        min_span_len: minimum run length forwarded to ``extract_positive_spans``.

    Returns:
        IoU between the span-filtered predicted mask and the span-filtered true mask.

    Raises:
        ValueError: if the last dimension of the logits is neither 1 nor 2.
    """
    row = df.iloc[sample_idx]

    # --- logits -> probabilities ---
    logits = torch.tensor(row["site_logits"])
    if logits.dim() == 3:
        logits = logits.squeeze(0)
    elif logits.dim() == 1:
        logits = logits.unsqueeze(-1)

    if logits.size(-1) == 1:
        probs = torch.sigmoid(logits).squeeze(-1).cpu().numpy()
    elif logits.size(-1) == 2:
        # torch.softmax instead of F.softmax: `F` is never imported in this notebook.
        probs = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
    else:
        raise ValueError(f"Unexpected logits shape: {logits.shape}")

    # --- labels (unwrap a singly nested list; tolerate an empty list) ---
    sites = row["sites"]
    if isinstance(sites, list) and sites and isinstance(sites[0], list):
        sites = sites[0]
    sites = np.array(sites)

    # Compare only the positions present in both arrays.
    min_len = min(len(probs), len(sites))
    probs = probs[:min_len]
    sites = sites[:min_len]

    preds_binary = (probs >= threshold).astype(int)

    # --- span-based masks (spans computed once and reused for drawing) ---
    pred_spans = extract_positive_spans(preds_binary, min_span_len)
    true_spans = extract_positive_spans(sites, min_span_len)

    pred_mask = np.zeros_like(preds_binary)
    for start, end in pred_spans:
        pred_mask[start:end] = 1

    true_mask = np.zeros_like(sites)
    for start, end in true_spans:
        true_mask[start:end] = 1

    overlap_mask = np.logical_and(pred_mask, true_mask)

    # --- IoU ---
    iou = get_iou(pred_mask, true_mask)

    # --- plot ---
    plt.figure(figsize=(12, 4))
    ax = plt.gca()
    ax.set_xlim(0, len(probs))

    # Overlap zone: one unit-wide band per overlapping position; only the first gets a legend label.
    for k, i in enumerate(np.flatnonzero(overlap_mask)):
        ax.axvspan(i, i + 1, color='limegreen', alpha=0.5, label='Overlap' if k == 0 else "")

    # Predicted spans
    for k, (start, end) in enumerate(pred_spans):
        ax.axvspan(start, end, color='skyblue', alpha=0.4, label="Predicted" if k == 0 else "")

    # True spans
    for k, (start, end) in enumerate(true_spans):
        ax.axvspan(start, end, color='orange', alpha=0.3, label="True" if k == 0 else "")

    # Per-position predicted probability curve.
    ax.plot(probs, label="Predicted probability", color='black', linewidth=1)

    plt.title(f"circRNA: {row['isoform_ID']} | miRNA: {row['miRNA_ID']} | IoU = {iou:.2f}")
    plt.xlabel("Sequence position")
    plt.ylabel("Binding probability")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    return iou


In [None]:

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

def compute_iou_for_sample(row, threshold=0.5, min_span_len=20):
    """
    Compute the span-filtered IoU between predicted and true binding sites for one row.

    Parameters:
        row: mapping / Series providing "site_logits" and "sites".
        threshold: probability cut-off for calling a position positive.
        min_span_len: minimum run length forwarded to ``extract_positive_spans``.

    Returns:
        intersection / union of the two span-filtered masks, or 0.0 when the union is empty.

    Raises:
        ValueError: if the last dimension of the logits is neither 1 nor 2.
    """
    # --- logits -> probabilities ---
    logits = torch.tensor(row["site_logits"])
    if logits.dim() == 3:
        logits = logits.squeeze(0)
    elif logits.dim() == 1:
        logits = logits.unsqueeze(-1)

    if logits.size(-1) == 1:
        probs = torch.sigmoid(logits).squeeze(-1).cpu().numpy()
    elif logits.size(-1) == 2:
        # torch.softmax instead of F.softmax: `F` is never imported in this notebook.
        probs = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
    else:
        raise ValueError(f"Unexpected logits shape: {logits.shape}")

    # --- labels (unwrap a singly nested list; tolerate an empty list) ---
    sites = row["sites"]
    if isinstance(sites, list) and sites and isinstance(sites[0], list):
        sites = sites[0]
    sites = np.array(sites)

    # Compare only the positions present in both arrays.
    min_len = min(len(probs), len(sites))
    probs = probs[:min_len]
    sites = sites[:min_len]

    preds_binary = (probs >= threshold).astype(int)

    # --- span-based masks ---
    pred_mask = np.zeros_like(preds_binary)
    for start, end in extract_positive_spans(preds_binary, min_span_len):
        pred_mask[start:end] = 1

    true_mask = np.zeros_like(sites)
    for start, end in extract_positive_spans(sites, min_span_len):
        true_mask[start:end] = 1

    # --- IoU ---
    intersection = np.logical_and(pred_mask, true_mask).sum()
    union = np.logical_or(pred_mask, true_mask).sum()
    iou = intersection / union if union > 0 else 0.0

    return iou

def plot_sequence_coloring_from_df(df, sample_idx=0, threshold=0.5, min_span_len=20, row_width=100, fontsize=10):
    """
    Draw one sequence as a character grid, colouring predicted / true / overlapping sites.

    NOTE(review): a function with this same name is defined again in a later cell of this
    notebook (slightly shorter rows); whichever cell ran last is in effect.

    Parameters:
        df: DataFrame whose rows provide "circRNA", "site_logits" and "sites".
        sample_idx: positional row index (used with ``df.iloc``).
        threshold: probability cut-off for calling a position positive.
        min_span_len: minimum run length forwarded to ``extract_positive_spans``.
        row_width: number of characters drawn per grid row.
        fontsize: font size of the drawn characters.

    Raises:
        ValueError: if the last dimension of the logits is neither 1 nor 2.
    """
    row = df.iloc[sample_idx]

    # --- sequence as a list of characters ---
    seq = row["circRNA"]
    if isinstance(seq, list):
        seq = ''.join(seq)
    sequence = list(seq)

    # --- logits -> probabilities ---
    logits = torch.tensor(row["site_logits"])
    if logits.dim() == 3:
        logits = logits.squeeze(0)
    elif logits.dim() == 1:
        logits = logits.unsqueeze(-1)

    if logits.size(-1) == 1:
        probs = torch.sigmoid(logits).squeeze(-1).cpu().numpy()
    elif logits.size(-1) == 2:
        # torch.softmax instead of F.softmax: `F` is never imported in this notebook.
        probs = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
    else:
        raise ValueError(f"Unexpected logits shape: {logits.shape}")

    # --- labels (unwrap a singly nested list; tolerate an empty list) ---
    sites = row["sites"]
    if isinstance(sites, list) and sites and isinstance(sites[0], list):
        sites = sites[0]
    sites = np.array(sites)

    # Keep only the positions present in probabilities, labels and sequence alike.
    min_len = min(len(probs), len(sites), len(sequence))
    probs = probs[:min_len]
    sites = sites[:min_len]
    sequence = sequence[:min_len]

    preds_binary = (probs >= threshold).astype(int)

    # --- span-based masks ---
    pred_mask = np.zeros_like(preds_binary)
    for start, end in extract_positive_spans(preds_binary, min_span_len):
        pred_mask[start:end] = 1

    true_mask = np.zeros_like(sites)
    for start, end in extract_positive_spans(sites, min_span_len):
        true_mask[start:end] = 1

    # --- visualisation ---
    length = len(sequence)
    n_rows = (length + row_width - 1) // row_width  # ceil(length / row_width)

    fig, ax = plt.subplots(figsize=(row_width * 0.1, n_rows * 0.6))
    ax.set_xlim(0, row_width)
    ax.set_ylim(-n_rows, 1)
    ax.axis("off")

    def get_color(i):
        # green = predicted and true, orange = true only, blue = predicted only
        if true_mask[i] and pred_mask[i]:
            return "limegreen"
        elif true_mask[i]:
            return "orange"
        elif pred_mask[i]:
            return "skyblue"
        else:
            return None

    for i in range(length):
        row_idx = -(i // row_width)  # rows grow downwards from y = 0
        col = i % row_width
        base = sequence[i]
        color = get_color(i)

        if color:
            rect = patches.Rectangle((col, row_idx), 1, 1, color=color, alpha=0.6)
            ax.add_patch(rect)

        ax.text(col + 0.5, row_idx + 0.5, base, ha='center', va='center', fontsize=fontsize, family='monospace')

    plt.tight_layout()
    plt.show()


In [None]:
# Scan every row of df_results, tracking the highest IoU and collecting rows whose IoU
# falls strictly between 0.6 and 0.8.
threshold = 0.6
min_span_len = 25

best = dict(pair=[], iou=[], idx=[], test=[])

ious = []
best_iou = 0.0
best_idx = -1  # -1 means no row has been selected yet

for i in range(len(df_results)):
    iou = compute_iou_for_sample(
        df_results.iloc[i],
        threshold=threshold,
        min_span_len=min_span_len,
    )
    ious.append(iou)
    print(f'\r {i}', end='\r')

    if 0.6 < iou < 0.8:
        best['test'].append([i, iou])
        print(f"test match found at index {i} with IOU: {iou:.4f}")
    if iou > best_iou:
        best_iou, best_idx = iou, i
        print(f"Best match found at index {best_idx} with IOU: {best_iou:.4f}")

# Record the best row in the `best` dict, if any row scored above zero.
if best_idx == -1:
    print("No IOU above threshold was found.")
else:
    best['pair'].append([best_idx, best_iou])
    best['iou'].append(best_iou)
    best['idx'].append(best_idx)

    print(f"Best match found at index {best_idx} with IOU: {best_iou:.4f}")


In [None]:

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

def linear_binding_site_plot_with_overlap(df, sample_idx=0, threshold=0.5, min_span_len=20):
    """
    Plot predicted vs. true binding-site spans along one sequence and return their IoU.

    NOTE(review): this redefines a function of the same name from an earlier cell of this
    notebook (that one uses a wider figure and shows the IoU in the title).

    Parameters:
        df: DataFrame whose rows provide "site_logits", "sites", "isoform_ID" and "miRNA_ID".
        sample_idx: positional row index (used with ``df.iloc``).
        threshold: probability cut-off for calling a position positive.
        min_span_len: minimum run length forwarded to ``extract_positive_spans``.

    Returns:
        IoU between the span-filtered predicted mask and the span-filtered true mask.

    Raises:
        ValueError: if the last dimension of the logits is neither 1 nor 2.
    """
    row = df.iloc[sample_idx]

    # --- logits -> probabilities ---
    logits = torch.tensor(row["site_logits"])
    if logits.dim() == 3:
        logits = logits.squeeze(0)
    elif logits.dim() == 1:
        logits = logits.unsqueeze(-1)

    if logits.size(-1) == 1:
        probs = torch.sigmoid(logits).squeeze(-1).cpu().numpy()
    elif logits.size(-1) == 2:
        # torch.softmax instead of F.softmax: `F` is never imported in this notebook.
        probs = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
    else:
        raise ValueError(f"Unexpected logits shape: {logits.shape}")

    # --- labels (unwrap a singly nested list; tolerate an empty list) ---
    sites = row["sites"]
    if isinstance(sites, list) and sites and isinstance(sites[0], list):
        sites = sites[0]
    sites = np.array(sites)

    # Compare only the positions present in both arrays.
    min_len = min(len(probs), len(sites))
    probs = probs[:min_len]
    sites = sites[:min_len]

    preds_binary = (probs >= threshold).astype(int)

    # --- span-based masks (spans computed once and reused for drawing) ---
    pred_spans = extract_positive_spans(preds_binary, min_span_len)
    true_spans = extract_positive_spans(sites, min_span_len)

    pred_mask = np.zeros_like(preds_binary)
    for start, end in pred_spans:
        pred_mask[start:end] = 1

    true_mask = np.zeros_like(sites)
    for start, end in true_spans:
        true_mask[start:end] = 1

    overlap_mask = np.logical_and(pred_mask, true_mask)

    # --- IoU ---
    iou = get_iou(pred_mask, true_mask)

    # --- plot ---
    plt.figure(figsize=(8, 4))
    ax = plt.gca()
    ax.set_xlim(0, len(probs))

    # Overlap zone: one unit-wide band per overlapping position; only the first gets a legend label.
    for k, i in enumerate(np.flatnonzero(overlap_mask)):
        ax.axvspan(i, i + 1, color='limegreen', alpha=0.5, label='Overlap' if k == 0 else "")

    # Predicted spans
    for k, (start, end) in enumerate(pred_spans):
        ax.axvspan(start, end, color='skyblue', alpha=0.4, label="Predicted" if k == 0 else "")

    # True spans
    for k, (start, end) in enumerate(true_spans):
        ax.axvspan(start, end, color='orange', alpha=0.3, label="True" if k == 0 else "")

    # Per-position predicted probability curve.
    ax.plot(probs, label="Predicted probability", color='black', linewidth=1)

    plt.title(f"circRNA: {row['isoform_ID']} | miRNA: {row['miRNA_ID']} ")
    plt.xlabel("Sequence position")
    plt.ylabel("Binding probability")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    return iou


def plot_sequence_coloring_from_df(df, sample_idx=0, threshold=0.5, min_span_len=20, row_width=100, fontsize=10):
    """
    Draw one sequence as a character grid, colouring predicted / true / overlapping sites.

    NOTE(review): this redefines a function of the same name from an earlier cell of this
    notebook (that one uses slightly taller rows).

    Parameters:
        df: DataFrame whose rows provide "circRNA", "site_logits" and "sites".
        sample_idx: positional row index (used with ``df.iloc``).
        threshold: probability cut-off for calling a position positive.
        min_span_len: minimum run length forwarded to ``extract_positive_spans``.
        row_width: number of characters drawn per grid row.
        fontsize: font size of the drawn characters.

    Raises:
        ValueError: if the last dimension of the logits is neither 1 nor 2.
    """
    row = df.iloc[sample_idx]

    # --- sequence as a list of characters ---
    seq = row["circRNA"]
    if isinstance(seq, list):
        seq = ''.join(seq)
    sequence = list(seq)

    # --- logits -> probabilities ---
    logits = torch.tensor(row["site_logits"])
    if logits.dim() == 3:
        logits = logits.squeeze(0)
    elif logits.dim() == 1:
        logits = logits.unsqueeze(-1)

    if logits.size(-1) == 1:
        probs = torch.sigmoid(logits).squeeze(-1).cpu().numpy()
    elif logits.size(-1) == 2:
        # torch.softmax instead of F.softmax: `F` is never imported in this notebook.
        probs = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
    else:
        raise ValueError(f"Unexpected logits shape: {logits.shape}")

    # --- labels (unwrap a singly nested list; tolerate an empty list) ---
    sites = row["sites"]
    if isinstance(sites, list) and sites and isinstance(sites[0], list):
        sites = sites[0]
    sites = np.array(sites)

    # Keep only the positions present in probabilities, labels and sequence alike.
    min_len = min(len(probs), len(sites), len(sequence))
    probs = probs[:min_len]
    sites = sites[:min_len]
    sequence = sequence[:min_len]

    preds_binary = (probs >= threshold).astype(int)

    # --- span-based masks ---
    pred_mask = np.zeros_like(preds_binary)
    for start, end in extract_positive_spans(preds_binary, min_span_len):
        pred_mask[start:end] = 1

    true_mask = np.zeros_like(sites)
    for start, end in extract_positive_spans(sites, min_span_len):
        true_mask[start:end] = 1

    # --- visualisation ---
    length = len(sequence)
    n_rows = (length + row_width - 1) // row_width  # ceil(length / row_width)

    fig, ax = plt.subplots(figsize=(row_width * 0.1, n_rows * 0.5))
    ax.set_xlim(0, row_width)
    ax.set_ylim(-n_rows, 1)
    ax.axis("off")

    def get_color(i):
        # green = predicted and true, orange = true only, blue = predicted only
        if true_mask[i] and pred_mask[i]:
            return "limegreen"
        elif true_mask[i]:
            return "orange"
        elif pred_mask[i]:
            return "skyblue"
        else:
            return None

    for i in range(length):
        row_idx = -(i // row_width)  # rows grow downwards from y = 0
        col = i % row_width
        base = sequence[i]
        color = get_color(i)

        if color:
            rect = patches.Rectangle((col, row_idx), 1, 1, color=color, alpha=0.6)
            ax.add_patch(rect)

        ax.text(col + 0.5, row_idx + 0.5, base, ha='center', va='center', fontsize=fontsize, family='monospace')

    plt.tight_layout()
    plt.show()


In [None]:
# Plot one sample twice: first keeping every positive run (min span 1), then keeping only
# runs at least as long as that row's miRNA sequence.
# idx = 19135  (alternative sample)
idx = 83393
threshold = 0.5

for min_span_len in (1, len(df_results['miRNA'].iloc[idx])):
    iou = linear_binding_site_plot_with_overlap(
        df_results, sample_idx=idx, threshold=threshold, min_span_len=min_span_len
    )
    plot_sequence_coloring_from_df(
        df_results, sample_idx=idx, threshold=threshold, min_span_len=min_span_len
    )



In [None]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def extract_positive_spans(binary_array, min_span_len=20):
    """
    Return the half-open ``(start, end)`` index pairs of runs of 1s in ``binary_array``.

    A run opens at the first element equal to 1 and closes at the next element equal to 0
    (or at the end of the array). Runs shorter than ``min_span_len`` are discarded.
    """
    found = []
    run_begin = None
    for pos, value in enumerate(binary_array):
        if run_begin is None:
            # Outside a run: only a 1 can open one.
            if value == 1:
                run_begin = pos
            continue
        # Inside a run: only a 0 closes it.
        if value == 0:
            if pos - run_begin >= min_span_len:
                found.append((run_begin, pos))
            run_begin = None
    # A run still open at the end extends to the array length.
    if run_begin is not None and len(binary_array) - run_begin >= min_span_len:
        found.append((run_begin, len(binary_array)))
    return found

def span_filter_mask(binary_array, min_span_len=20):
    """Return a mask shaped like ``binary_array`` that is 1 only inside the runs kept by
    ``extract_positive_spans(binary_array, min_span_len)``."""
    kept = np.zeros_like(binary_array)
    for span in extract_positive_spans(binary_array, min_span_len):
        kept[slice(*span)] = 1
    return kept

def evaluate_span_filtered(preds_binary: torch.Tensor, sites_logits: torch.Tensor, min_span_len=20):
    """
    Score span-filtered predictions against span-filtered reference labels.

    preds_binary: torch.Tensor of shape (B, L)
    sites_logits: torch.Tensor of shape (B, L, 2); its argmax over the last dimension is
        used as the reference labelling.

    Both labelings are passed row by row through ``span_filter_mask`` and then flattened.
    Returns a dict with 'accuracy', 'precision', 'recall' and 'f1'.
    """
    pred_np = preds_binary.detach().cpu().numpy()
    reference_np = torch.argmax(torch.softmax(sites_logits, dim=-1), dim=-1).cpu().numpy()  # shape (B, L)

    batch_indices = range(pred_np.shape[0])
    y_true = np.concatenate([span_filter_mask(reference_np[b], min_span_len) for b in batch_indices])
    y_pred = np.concatenate([span_filter_mask(pred_np[b], min_span_len) for b in batch_indices])

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1': f1_score(y_true, y_pred, zero_division=0)
    }


In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Position-level metrics over all rows of df_results after span filtering (min span 22).
all_y_true = []
all_y_pred = []

for site_logits, site_labels in zip(df_results['site_logits'], df_results['sites']):
    # logits -> probability of the positive class
    logits = torch.tensor(site_logits)
    if logits.dim() == 1:
        logits = logits.unsqueeze(-1)
    if logits.size(-1) == 1:
        probs = torch.sigmoid(logits).squeeze(-1).numpy()
    elif logits.size(-1) == 2:
        probs = torch.softmax(logits, dim=-1)[:, 1].numpy()
    else:
        raise ValueError(f"Unexpected logits shape: {logits.shape}")

    # predictions
    preds_binary = (probs >= 0.5).astype(int)

    # ground truth: accept tensors, arrays and plain lists (the plotting helpers in this
    # notebook treat `sites` as possibly being a list, which has no `.ndim`)
    labels = site_labels
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    labels = np.asarray(labels)
    if labels.ndim == 2 and labels.shape[1] == 2:
        labels = np.argmax(labels, axis=-1)
    elif labels.ndim == 1:
        pass
    else:
        raise ValueError(f"Unexpected label shape: {labels.shape}")

    # truncate both to the shorter length
    min_len = min(len(preds_binary), len(labels))
    preds_binary = preds_binary[:min_len]
    labels = labels[:min_len]

    # span filtering
    pred_mask = span_filter_mask(preds_binary, min_span_len=22)
    true_mask = span_filter_mask(labels, min_span_len=22)

    all_y_pred.append(pred_mask)
    all_y_true.append(true_mask)

# concatenate all rows, then score
y_pred = np.concatenate(all_y_pred)
y_true = np.concatenate(all_y_true)

metrics = {
    "accuracy": accuracy_score(y_true, y_pred),
    "precision": precision_score(y_true, y_pred, zero_division=0),
    "recall": recall_score(y_true, y_pred, zero_division=0),
    "f1": f1_score(y_true, y_pred, zero_division=0)
}
for k, v in metrics.items():
    print(f"{k}: {v:.4f}")
