In [None]:
!pip install pytorch-transformers

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import platform, sys, torch, numpy as np
# Environment report: interpreter version, CPU architecture, library versions
# and which accelerator back-ends PyTorch can see on this machine.
print("Python:", sys.version)
print("Machine:", platform.machine())       # e.g. arm64 on Apple Silicon; other hosts report their own architecture
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
# NOTE(review): torch.backends.mps presumably only exists on recent PyTorch builds — confirm for the pinned version.
print("MPS available:", torch.backends.mps.is_available())
print("NumPy:", np.__version__)

In [None]:
# Notebook-level options. Set config["use_gpu"] = True to request the first CUDA device.
config = {}


def get_device():
    """Return the torch device selected by ``config``.

    ``cuda:0`` is returned only when ``config["use_gpu"]`` is truthy *and*
    CUDA is actually available; in every other case the CPU device is
    returned, so a later ``.to(DEVICE)`` cannot fail on a machine without a
    GPU.

    Returns:
        torch.device: ``cuda:0`` or ``cpu``.
    """
    if config.get("use_gpu", False) and torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


DEVICE = get_device()

In [None]:
import torch
import numpy as np
import json

In [None]:
train_human_path = '/content/drive/MyDrive/train_human.npy'
train_ai_path = '/content/drive/MyDrive/train_ai.npy'
validation_data_path = '/content/drive/MyDrive/validation.jsonl'
net_pt_path = '/content/drive/MyDrive/net.pt'

In [None]:
HWT_ref = np.load(train_human_path)
MGT_ref = np.load(train_ai_path)


In [None]:
HWT_ref.shape

In [None]:
def load_validation_data(jsonl_path):
    """Read a JSON-Lines validation file into three parallel lists.

    Every non-blank line must be a JSON object with the keys ``features``,
    ``label`` and ``id``. Blank or whitespace-only lines (for example a
    trailing empty line) are skipped instead of raising a decode error.

    Args:
        jsonl_path: Path of the ``.jsonl`` file to read (decoded as UTF-8).

    Returns:
        tuple: ``(all_segments, all_labels, all_ids)`` where
        ``all_segments[i]`` is ``np.array(entry['features'])`` of record ``i``,
        ``all_labels[i]`` its ``label`` value and ``all_ids[i]`` its ``id``
        value, all in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the three required keys.
    """
    all_segments = []  # one feature array per record
    all_labels = []    # label value of each record
    all_ids = []       # id value of each record

    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Tolerate empty lines between or after records.
                continue
            entry = json.loads(line)

            all_segments.append(np.array(entry['features']))
            all_labels.append(entry['label'])
            all_ids.append(entry['id'])

    return all_segments, all_labels, all_ids

# Read every validation record from the JSONL file configured in the paths cell.
X_val_segments, y_val_segments, val_ids = load_validation_data(validation_data_path)
# Author's note: the second sample of X_val_segments has shape (20, 100, 768).

# Cast each feature array and each label to float32.
X_val_segments = [np.asarray(segment).astype('float32') for segment in X_val_segments]
y_val_segments = [np.float32(label) for label in y_val_segments]

In [None]:
#====================================================
# I don't know why this works; do not change it. It is the oldest version, but it is the only one that works.
#====================================================

import torch
from torch import nn
from collections import namedtuple
import math
from pytorch_transformers.modeling_bert import (
    BertEncoder,
    BertPreTrainedModel,
    BertConfig,
)
import os, torch
from collections import namedtuple

class GeLU(nn.Module):
    """Exact, erf-based Gaussian Error Linear Unit.

    Evaluates ``x * 0.5 * (1 + erf(x / sqrt(2)))``. OpenAI GPT uses a slightly
    different tanh approximation that gives slightly different results:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """

    def forward(self, x):
        # The module has no parameters or state; it is a pure element-wise map.
        half_x = x * 0.5
        return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))


class BertLayerNorm(nn.Module):
    """Layer normalisation over the last dimension, TF style (epsilon inside the square root)."""

    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        # Learnable per-feature scale and shift, initialised to the identity transform.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        # Biased variance of the last dimension.
        variance = (x - mean).pow(2).mean(-1, keepdim=True)
        normalised = (x - mean) / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalised + self.bias


class mlp_meta(nn.Module):
    """Width-preserving block: Linear -> GeLU -> BertLayerNorm -> Dropout.

    ``config`` must expose ``hid_dim`` (input and output width) and
    ``dropout`` (dropout probability).
    """

    def __init__(self, config):
        super().__init__()
        stages = [
            nn.Linear(config.hid_dim, config.hid_dim),
            GeLU(),
            BertLayerNorm(config.hid_dim, eps=1e-12),
            nn.Dropout(config.dropout),
        ]
        # Kept as a Sequential named ``mlp`` so parameter names stay ``mlp.<index>.*``.
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        return self.mlp(x)


class Bert_Transformer_Layer(BertPreTrainedModel):
    """Runs a ``BertEncoder`` directly over feature vectors supplied by the caller.

    ``fusion_config`` is a dict of keyword arguments forwarded to ``BertConfig``.
    """

    def __init__(self, fusion_config):
        super().__init__(BertConfig(**fusion_config))
        # A second BertConfig instance (same arguments) is built for the encoder itself.
        bertconfig_fusion = BertConfig(**fusion_config)
        self.encoder = BertEncoder(bertconfig_fusion)
        self.init_weights()

    def forward(self, input, mask=None):
        """
        Encode ``input`` and return ``(output, all_attention)``.

        input: 3-D tensor, unpacked below as (batch, feats, dim).
        mask:  optional tensor written into columns 1.. of a (batch, feats)
               all-ones mask, so column 0 always stays 1; presumably shaped
               (batch, feats - 1) — TODO confirm against callers. When it is
               omitted an all-ones mask is used.

        Returns the first two elements of the encoder output; presumably the
        hidden states and the attention maps — verify against the
        ``BertEncoder`` version in use.
        """
        batch, feats, dim= input.size() # unpack the three input dimensions
        if mask is not None:
            mask_ = torch.ones(size=(batch, feats), device=mask.device)
            mask_[:, 1:] = mask
            # Outer product (batch, feats, 1) x (batch, 1, feats) -> (batch, feats, feats).
            mask_ = torch.bmm(
                mask_.view(batch, 1, -1).transpose(1, 2), mask_.view(batch, 1, -1)
            )
            # Insert a singleton dimension at position 1 -> (batch, 1, feats, feats).
            mask_ = mask_.unsqueeze(1)

        else:
            # No mask given: tile a single 1.0 into a (batch, 1, feats, feats) all-ones mask.
            mask = torch.Tensor([1.0]).to(input.device)
            mask_ = mask.repeat(batch, 1, feats, feats)

        # 0 where the mask is 1 and -10000 where it is 0.
        extend_mask = (1 - mask_) * -10000
        assert not extend_mask.requires_grad
        # One ``None`` entry per encoder layer.
        head_mask = [None] * self.config.num_hidden_layers

        enc_output = self.encoder(input, extend_mask, head_mask=head_mask)
        output = enc_output[0]
        all_attention = enc_output[1]

        return output, all_attention


class mmdPreModel(nn.Module):
    """Feature extractor applied before the MMD kernels.

    Pipeline (each stage optional except the last):
    BERT encoder -> MLP (in_dim -> hid_dim) -> ``num_mlp`` extra ``mlp_meta``
    blocks -> flatten -> Linear(hid_dim * token_num -> out_dim).

    Args:
        config: Object exposing ``in_dim``, ``hid_dim``, ``dropout``,
            ``out_dim`` and ``token_num``.
        num_mlp: Number of additional ``mlp_meta`` blocks (0 disables them).
        transformer_flag: Whether to run ``Bert_Transformer_Layer`` first.
        num_hidden_layers: Number of encoder layers for that transformer.
        mlp_flag: Whether to apply the first MLP. When False the activations
            keep width ``in_dim``, so the final projection only fits if
            ``in_dim == hid_dim``.
    """

    def __init__(
        self,
        config,
        num_mlp=0,
        transformer_flag=False,
        num_hidden_layers=1,
        mlp_flag=True,
    ):
        super().__init__()
        self.num_mlp = num_mlp
        self.transformer_flag = transformer_flag
        self.mlp_flag = mlp_flag
        token_num = config.token_num
        # Sub-modules are registered in the original order so that parameter
        # names and checkpoint keys stay unchanged.
        self.mlp = nn.Sequential(
            nn.Linear(config.in_dim, config.hid_dim),
            GeLU(),
            BertLayerNorm(config.hid_dim, eps=1e-12),
            nn.Dropout(config.dropout),
        )
        self.fusion_config = {
            "hidden_size": config.in_dim,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": 4,
            "output_attentions": True,
        }
        if self.num_mlp > 0:
            self.mlp2 = nn.ModuleList([mlp_meta(config) for _ in range(self.num_mlp)])
        if self.transformer_flag:
            self.transformer = Bert_Transformer_Layer(self.fusion_config)
        # Projects the flattened (token_num * hid_dim) activations to out_dim.
        self.feature = nn.Linear(config.hid_dim * token_num, config.out_dim)

    def forward(self, features):
        """
        input: [batch, token_num, in_dim], output: [batch, out_dim]
        """
        if self.transformer_flag:
            features, _ = self.transformer(features)
        if self.mlp_flag:
            features = self.mlp(features)

        if self.num_mlp > 0:
            for mlp in self.mlp2:
                features = mlp(features)

        # reshape (not view) also accepts non-contiguous activations.
        flat = features.reshape(features.shape[0], -1)
        return self.feature(flat)


def get_device():
    """Return the default CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def load_checkpoint(path, map_location):
    """Load a checkpoint with ``torch.load`` after checking the file is usable.

    Args:
        path: Location of the checkpoint file.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        Whatever object ``torch.load`` deserialises from the file.

    Raises:
        FileNotFoundError: If ``path`` does not exist or is zero bytes long.
    """
    is_missing = not os.path.exists(path)
    if is_missing or os.path.getsize(path) == 0:
        raise FileNotFoundError(f"Checkpoint missing or empty: {path}")
    with open(path, "rb") as handle:
        checkpoint = torch.load(handle, map_location=map_location)
    return checkpoint

def save_checkpoint(path, net, sigma, sigma0_u, ep):
    """Atomically write the network weights and kernel scalars to ``path``.

    The payload is first written to ``path + ".tmp"`` and then moved over
    ``path`` with ``os.replace``, so an interrupted save never leaves a
    truncated checkpoint at ``path``.

    Args:
        path: Destination file; its parent directory is created if needed.
        net: Module whose ``state_dict()`` is stored under ``"net"``.
        sigma, sigma0_u, ep: Tensors stored (detached, on CPU) under the keys
            ``"sigma"``, ``"sigma0_u"`` and ``"ep"``.
    """
    payload = {
        "net": net.state_dict(),
        "sigma": sigma.detach().cpu(),
        "sigma0_u": sigma0_u.detach().cpu(),
        "ep": ep.detach().cpu(),
    }
    tmp = path + ".tmp"
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    try:
        # Serialise through the handle we opened rather than opening the same
        # path a second time inside torch.save.
        with open(tmp, "wb") as f:
            torch.save(payload, f)
        os.replace(tmp, path)
    finally:
        # After a successful replace the temp file is gone; otherwise remove
        # the partial file so it cannot be mistaken for a checkpoint.
        if os.path.exists(tmp):
            os.remove(tmp)

# TODO: replace with your real initial values / shapes if not scalars
def default_init_params(device):
    """Create the default scalar kernel parameters on ``device``.

    Returns:
        tuple: ``(sigma, sigma0_u, ep)`` as three separate 0-dim tensors with
        the values 1.0, 1.0 and 1e-6.
    """
    initial_values = (1.0, 1.0, 1e-6)  # sigma, sigma0_u, ep
    sigma, sigma0_u, ep = (torch.tensor(v, device=device) for v in initial_values)
    return sigma, sigma0_u, ep



In [None]:
# =========================
# MMD + NetLoader (drop-in)
# =========================
import math
import torch
import torch.nn as nn
from collections import namedtuple

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Your feature extractor class must exist/imported as mmdPreModel ---
# from your_module import mmdPreModel

# ---------------------------
# NetLoader: make params learn
# ---------------------------
class NetLoader:
    """Builds ``mmdPreModel``, restores its weights and exposes the kernel scalars.

    After construction ``self.net`` is the network on ``DEVICE`` in eval mode,
    and ``sigma``, ``sigma0_u`` and ``ep`` are learnable ``nn.Parameter``
    objects registered on ``self.net`` and mirrored as attributes of the
    loader.

    Args:
        checkpoint_filename: File readable by ``torch.load`` containing a dict
            with key ``"net"`` (state dict) and optionally ``"sigma"``,
            ``"sigma0_u"`` and ``"ep"``.
    """

    # Kernel scalars and the values used when the checkpoint does not provide them.
    _KERNEL_DEFAULTS = {"sigma": 1.0, "sigma0_u": 1.0, "ep": 1e-6}

    def __init__(self, checkpoint_filename= net_pt_path):
        token_num, hidden_size = 100, 768
        Config = namedtuple("Config", ["in_dim", "hid_dim", "dropout", "out_dim", "token_num"])
        config = Config(in_dim=hidden_size, token_num=token_num, hid_dim=512, dropout=0.2, out_dim=300)
        self.config = config

        self.net = mmdPreModel(config=config, num_mlp=0, transformer_flag=True, num_hidden_layers=1)

        ckpt = torch.load(checkpoint_filename, map_location="cpu")
        net_state = ckpt["net"]
        # A state dict taken after the kernel scalars were registered on the
        # network contains them as extra top-level keys. A fresh mmdPreModel
        # has no such parameters, so the strict load below would reject them.
        # Pop them in place (ckpt is a private, freshly loaded object), which
        # also keeps the state dict's type and metadata intact.
        embedded = {
            name: net_state.pop(name)
            for name in self._KERNEL_DEFAULTS
            if name in net_state
        }
        self.net.load_state_dict(net_state)   # load network weights first
        self.net = self.net.to(DEVICE)

        # Ensure sigma, sigma0_u, ep are learnable Parameters ON self.net.
        # Preference order: top-level checkpoint entry, value embedded in the
        # state dict, built-in default.
        for name, default in self._KERNEL_DEFAULTS.items():
            value = ckpt.get(name, embedded.get(name, default))
            self._ensure_learnable_param(name, value)

        # Mirrors, so code that references net_loader.sigma etc. still works.
        self.sigma    = self.net.sigma
        self.sigma0_u = self.net.sigma0_u
        self.ep       = self.net.ep

        # For inference feature extraction
        self.net.eval()

    def _ensure_learnable_param(self, name: str, value):
        """Make ``self.net.<name>`` a trainable float32 Parameter on ``DEVICE`` holding ``value``."""
        t = torch.as_tensor(value, dtype=torch.float32, device=DEVICE).clone().detach()
        # If already a Parameter, just copy the new value into it.
        if isinstance(self.net._parameters.get(name, None), nn.Parameter):
            with torch.no_grad():
                self.net._parameters[name].data.copy_(t)
            self.net._parameters[name].requires_grad = True
            return
        # If it exists as a buffer, remove it so a Parameter can be registered.
        if name in self.net._buffers:
            self.net._buffers.pop(name)
        # Register as Parameter
        self.net.register_parameter(name, nn.Parameter(t, requires_grad=True))

# ------------------------
# Differentiable utilities
# ------------------------
def Pdist2(x, y=None):
    """Squared Euclidean distances between the rows of ``x`` and the rows of ``y``.

    When ``y`` is None the distances are taken between the rows of ``x``
    itself. Small negative values caused by rounding are clamped to 0.
    """
    other = x if y is None else y
    x_norms = (x * x).sum(dim=1, keepdim=True)                          # (m,1)
    other_norms = (other * other).sum(dim=1, keepdim=True).transpose(0, 1)  # (1,n)
    cross = x @ other.transpose(0, 1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    return torch.clamp_min(x_norms + other_norms - 2.0 * cross, 0.0)

def flexible_kernel(X, Y, X_org, Y_org, sigma, sigma0=0.1, epsilon=1e-8):
    """
    Flexible kernel as in MMDu:
      K = (1-ε) * exp( - (Dxy/σ0)^L - Dxy_org/σ ) + ε * exp( - Dxy_org/σ )
    with L fixed to 1. Dxy is the squared distance between the rows of X and Y,
    Dxy_org the squared distance between the rows of X_org and Y_org.
    sigma, sigma0 and epsilon are converted to tensors of X's dtype/device and
    clamped (sigma, sigma0 >= 1e-12; epsilon in [1e-12, 1 - 1e-12]).
    """
    def _like_X(value):
        # Accepts Python numbers as well as tensors.
        return torch.as_tensor(value, dtype=X.dtype, device=X.device)

    # keep positive & stable
    bandwidth_org = torch.clamp_min(_like_X(sigma), 1e-12)
    bandwidth_fea = torch.clamp_min(_like_X(sigma0), 1e-12)
    mix = torch.clamp(_like_X(epsilon), 1e-12, 1.0 - 1e-12)

    dist_fea = Pdist2(X, Y)
    dist_org = Pdist2(X_org, Y_org)

    L = 1.0
    combined = torch.exp(- (dist_fea / bandwidth_fea) ** L - (dist_org / bandwidth_org))
    org_only = torch.exp(- (dist_org / bandwidth_org))
    return (1.0 - mix) * combined + mix * org_only

def MMD_Diff_Var(Kyy, Kzz, Kxy, Kxz, epsilon=1e-8):
    """
    Variance of the difference statistic MMD(X,Y) - MMD(X,Z).
    Differentiable (uses clamp_min instead of branching + .item()).

    Args:
        Kyy, Kzz: kernel matrices within Y and within Z; their diagonals are
            removed before use.
        Kxy, Kxz: cross kernel matrices between X and Y / X and Z.
        epsilon: lower bound applied to zeta1 and zeta2 (itself floored at 1e-12).

    Returns:
        (Var, Var_z2, data): Var is the zeta1-based variance, Var_z2 adds the
        zeta2 term, and data is a dict of Python floats (t1..t9, zeta1, zeta2)
        intended for logging only.

    Notes:
        zeta2 combines all four matrices element-wise, so they need a common
        square shape (m == n == r) or broadcast-compatible shapes.
        The Python-level divisions by ``m * (m - 1)`` raise ZeroDivisionError
        when m == 1.
    """
    device, dtype = Kyy.device, Kyy.dtype
    epsilon = torch.as_tensor(epsilon, dtype=dtype, device=device)
    epsilon = torch.clamp_min(epsilon, 1e-12)

    # Sample counts, read from the first dimension of each matrix.
    m = Kxy.shape[0]
    n = Kyy.shape[0]
    r = Kzz.shape[0]

    # Within-sample kernels with the diagonal zeroed out.
    Kyynd = Kyy - torch.diag(torch.diag(Kyy))
    Kzznd = Kzz - torch.diag(torch.diag(Kzz))

    # Mean kernel values (off-diagonal for the within-sample terms).
    u_yy = Kyynd.sum() / (n * (n - 1))
    u_zz = Kzznd.sum() / (r * (r - 1))
    u_xy = Kxy.sum()   / (m * n)
    u_xz = Kxz.sum()   / (m * r)

    # Nine second-moment terms, each minus the matching product of means.
    t1 = (Kyynd.t() @ Kyynd).sum() / (n**3)      - u_yy**2
    t2 = (Kxy.t()   @ Kxy   ).sum() / (n**2 * m) - u_xy**2
    t3 = (Kxy       @ Kxy.t()).sum() / (n * m**2) - u_xy**2
    t4 = (Kzznd.t() @ Kzznd).sum() / (r**3)      - u_zz**2
    t5 = (Kxz       @ Kxz.t()).sum() / (r * m**2) - u_xz**2
    t6 = (Kxz.t()   @ Kxz   ).sum() / (r**2 * m) - u_xz**2
    t7 = (Kyynd     @ Kxy.t()).sum() / (n**2 * m) - u_yy * u_xy
    t8 = (Kxy.t()   @ Kxz   ).sum() / (n * m * r) - u_xz * u_xy
    t9 = (Kzznd     @ Kxz.t()).sum() / (r**2 * m) - u_zz * u_xz

    # Both zeta terms are floored at epsilon so the variances stay positive.
    zeta1 = torch.clamp_min(t1 + t2 + t3 + t4 + t5 + t6 - 2 * (t7 + t8 + t9), epsilon)
    zeta2 = torch.clamp_min(
        (1.0 / (m * (m - 1))) * ((Kyynd - Kzznd - Kxy.t() - Kxy + Kxz + Kxz.t()) ** 2).sum()
        - (u_yy - 2.0 * u_xy - (u_zz - 2.0 * u_xz)) ** 2,
        epsilon
    )

    # For m == 2 the (m - 2) factor makes Var exactly zero.
    Var    = (4.0 * (m - 2) / (m * (m - 1))) * zeta1
    Var_z2 = Var + (2.0 / (m * (m - 1))) * zeta2

    # lightweight debug (safe for printing only): detached Python floats
    data = {
        "t1": t1.detach().cpu().item(),
        "t2": t2.detach().cpu().item(),
        "t3": t3.detach().cpu().item(),
        "t4": t4.detach().cpu().item(),
        "t5": t5.detach().cpu().item(),
        "t6": t6.detach().cpu().item(),
        "t7": t7.detach().cpu().item(),
        "t8": t8.detach().cpu().item(),
        "t9": t9.detach().cpu().item(),
        "zeta1": zeta1.detach().cpu().item(),
        "zeta2": zeta2.detach().cpu().item(),
    }
    return Var, Var_z2, data

# --------------------------------
# 3-sample test (differentiable t)
# --------------------------------
def MMD_3_Sample_Test(
    ref_fea, fea_y, fea_z,
    ref_fea_org, fea_y_org, fea_z_org,
    sigma, sigma0, epsilon, alpha,
):
    """
    Three-sample MMD test of X (``ref_fea``) against Y (``fea_y``) and Z (``fea_z``).

    The ``*_org`` arguments are the matching original features passed to the
    flexible kernel alongside the learned ones.

    Returns: h (int32 tensor, 1 where p_value <= alpha), p_value (tensor,
    two-sided), t_std (tensor, t_raw / sqrt(variance)), t_raw (tensor),
    Diff_Var (tensor).
    You can still call exactly as:
      h, p_value, t, *rest = MMD_3_Sample_Test(..., net.sigma, net.sigma0_u, net.ep, 0.05)
    """
    from torch.distributions.normal import Normal

    def _kernel(a, b, a_org, b_org):
        return flexible_kernel(a, b, a_org, b_org, sigma, sigma0, epsilon)

    def _without_diagonal(K):
        return K - torch.diag(torch.diag(K))

    n_x, n_y, n_z = ref_fea.shape[0], fea_y.shape[0], fea_z.shape[0]

    # Features are used as-is; gradients flow only through sigma / sigma0 / epsilon
    # when the inputs themselves carry no gradient.
    Kyy = _kernel(fea_y, fea_y, fea_y_org, fea_y_org)
    Kzz = _kernel(fea_z, fea_z, fea_z_org, fea_z_org)
    Kxy = _kernel(ref_fea, fea_y, ref_fea_org, fea_y_org)
    Kxz = _kernel(ref_fea, fea_z, ref_fea_org, fea_z_org)

    u_yy = _without_diagonal(Kyy).sum() / (n_y * (n_y - 1))
    u_zz = _without_diagonal(Kzz).sum() / (n_z * (n_z - 1))
    u_xy = Kxy.sum() / (n_x * n_y)
    u_xz = Kxz.sum() / (n_x * n_z)

    # Directional difference statistic.
    t_raw = u_yy - 2.0 * u_xy - (u_zz - 2.0 * u_xz)

    # Variance for standardisation, floored at epsilon.
    Diff_Var, _, _ = MMD_Diff_Var(Kyy, Kzz, Kxy, Kxz, epsilon)
    variance_floor = torch.as_tensor(epsilon, dtype=Diff_Var.dtype, device=Diff_Var.device)
    Diff_Var = torch.clamp_min(Diff_Var, variance_floor)

    t_std = t_raw / torch.sqrt(Diff_Var + 1e-12)

    # Decision and p-value need no gradient.
    with torch.no_grad():
        standard_normal = Normal(0, 1)
        # Two-sided p-value.
        p_value = 2.0 * (1.0 - standard_normal.cdf(torch.abs(t_std)))
        h = (p_value <= alpha).to(torch.int32)
    return h, p_value, t_std, t_raw, Diff_Var

# ==========================================
# Example: one-step MMD update of kernel pars
# ==========================================
@torch.no_grad()
def _flatten(x):  # helper to build X_org/Y_org/Z_org
    return x.view(x.size(0), -1)

def mmd_update_step(net_loader,
                    feature_for_sents_sample,     # (B, T, D)
                    feature_mgt_ref_sample,       # (B, T, D)
                    feature_hwt_ref_sample,       # (B, T, D)
                    lr=0.2, two_sided=False, alpha=0.05):
    """
    One gradient step on ``net_loader.sigma``, ``net_loader.sigma0_u`` and
    ``net_loader.ep`` that increases the three-sample MMD statistic.

    A new Adam optimizer is created on every call. The objective is ``t_raw``
    (or ``|t_raw|`` when ``two_sided`` is True). Afterwards the parameters are
    clamped to [1e-8, 1e6] (sigma, sigma0_u) and [1e-12, 1e-2] (ep).

    Returns a dict of Python numbers with the keys h, p_value, t_std, t_raw,
    diff_var (computed before the step) and sigma, sigma0_u, ep (after it).
    """
    kernel_params = [net_loader.sigma, net_loader.sigma0_u, net_loader.ep]
    opt_kernel = torch.optim.Adam(kernel_params, lr=lr)

    test_batch = feature_for_sents_sample.to(DEVICE)
    mgt_batch = feature_mgt_ref_sample.to(DEVICE)
    hwt_batch = feature_hwt_ref_sample.to(DEVICE)

    # 1) Learned features; the backbone receives no gradient.
    with torch.no_grad():
        Xl = net_loader.net(test_batch)
        Yl = net_loader.net(mgt_batch)
        Zl = net_loader.net(hwt_batch)

    # 2) Flattened originals.
    Xf = _flatten(test_batch)
    Yf = _flatten(mgt_batch)
    Zf = _flatten(hwt_batch)

    # 3) Test statistic, then backprop through t_raw.
    opt_kernel.zero_grad()
    h_u, p_value, t_std, t_raw, diff_var = MMD_3_Sample_Test(
        Xl, Yl, Zl,
        Xf, Yf, Zf,
        net_loader.sigma, net_loader.sigma0_u, net_loader.ep,
        alpha
    )
    loss = -t_raw.abs() if two_sided else -t_raw
    loss.backward()
    opt_kernel.step()

    # Keep the parameters in sane ranges.
    with torch.no_grad():
        net_loader.sigma.clamp_(1e-8, 1e6)
        net_loader.sigma0_u.clamp_(1e-8, 1e6)
        net_loader.ep.clamp_(1e-12, 1e-2)

    def _as_float(tensor):
        return float(tensor.detach().cpu().item())

    return {
        "h": int(h_u.item()),
        "p_value": float(p_value.item()),
        "t_std": _as_float(t_std),
        "t_raw": _as_float(t_raw),
        "diff_var": _as_float(diff_var),
        "sigma": _as_float(net_loader.sigma),
        "sigma0_u": _as_float(net_loader.sigma0_u),
        "ep": _as_float(net_loader.ep),
    }

# ==========================
# (Optional) quick sanity IO
# ==========================
def save_updated_checkpoint(net_loader, path="net.pt"):
    """Persist updated net + kernel params back to disk.

    The kernel scalars ``sigma``, ``sigma0_u`` and ``ep`` may be registered as
    parameters on ``net_loader.net`` and would then appear in its state dict.
    They are stored only under their own top-level keys and removed from the
    ``"net"`` entry, so the saved state dict can be loaded into a freshly
    constructed network that does not have those parameters yet.

    Args:
        net_loader: Object with attributes ``net`` (module), ``sigma``,
            ``sigma0_u`` and ``ep`` (tensors).
        path: Destination file passed to ``torch.save``.
    """
    kernel_names = ("sigma", "sigma0_u", "ep")
    with torch.no_grad():
        # state_dict() returns a new mapping, so popping keys from it does not
        # touch the module itself.
        net_state = net_loader.net.state_dict()
        for name in kernel_names:
            net_state.pop(name, None)
        torch.save({
            "net": net_state,
            "sigma":    net_loader.sigma.detach().cpu(),
            "sigma0_u": net_loader.sigma0_u.detach().cpu(),
            "ep":       net_loader.ep.detach().cpu(),
        }, path)


In [None]:
import torch
import torch.nn as nn

# If you already define DEVICE elsewhere, you can remove this.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- helpers ---------------------------------------------------------------

def to_tensor(x, device=None, dtype=torch.float32):
    """Return ``x`` as a tensor with the requested ``dtype`` on ``device``.

    Existing tensors are converted with ``Tensor.to`` (non-blocking); anything
    else goes through ``torch.as_tensor``.
    """
    if not isinstance(x, torch.Tensor):
        return torch.as_tensor(x, dtype=dtype, device=device)
    return x.to(device=device, dtype=dtype, non_blocking=True)

def feature_ref_loader(load_ref_data, num_ref=5000):
    """
    Randomly keep at most ``num_ref`` rows of ``load_ref_data`` (torch only, no NumPy).

    The input is returned unchanged when ``num_ref`` is None or not positive,
    or when the tensor has no more than ``num_ref`` rows. Otherwise the rows
    at the first ``num_ref`` entries of a random permutation are returned.
    """
    sampling_requested = num_ref is not None and num_ref > 0
    if not sampling_requested or load_ref_data.shape[0] <= num_ref:
        return load_ref_data
    keep = torch.randperm(load_ref_data.shape[0], device=load_ref_data.device)[:num_ref]
    return load_ref_data[keep]

def _ensure_param(obj, name, value, device):
    """
    Make sure obj.<name> is an nn.Parameter on the correct device/dtype,
    and copy 'value' into it without losing Parameter-ness.
    """
    val = torch.as_tensor(value, dtype=torch.float32, device=device).clone().detach()
    p = getattr(obj, name, None)
    if isinstance(p, nn.Parameter):
        with torch.no_grad():
            p.data.copy_(val)
        p.requires_grad_(True)
    else:
        setattr(obj, name, nn.Parameter(val, requires_grad=True))
    return getattr(obj, name)

# ---- your class ------------------------------------------------------------

class RelativeTester:
    """Relative (three-sample) MMD test of one paragraph against two reference sets.

    The test sample X is compared with a machine-generated reference Y and a
    human-written reference Z through ``MMD_3_Sample_Test``. The kernel
    scalars ``sigma``, ``sigma0_u`` and ``ep`` of the loaded network can be
    tuned with Adam while the test is repeated.
    """

    def __init__(self, feature_hwt_ref, feature_mgt_ref, feature_test, *,
                 device=DEVICE, dtype=torch.float32, max_ref=8161, net_loader_cls=None):
        """
        Convert the inputs to tensors once and build the network loader.

        Args:
            feature_hwt_ref: Human-written reference features.
            feature_mgt_ref: Machine-generated reference features.
            feature_test: Features of the paragraph under test.
            device, dtype: Target device and dtype of all stored tensors.
            max_ref: A reference set with more rows than this is randomly
                subsampled to ``max_ref`` rows; smaller sets are used in full
                (None or a non-positive value disables subsampling).
            net_loader_cls: Optional replacement for ``NetLoader``; it is
                called without arguments.
        """
        print("Relative Tester init")
        self.device, self.dtype = device, dtype

        # Convert inputs once
        feature_hwt_ref = to_tensor(feature_hwt_ref, device, dtype)
        feature_mgt_ref = to_tensor(feature_mgt_ref, device, dtype)
        feature_test    = to_tensor(feature_test,    device, dtype)

        # Subsample a reference set only when it has more than max_ref rows.
        feature_hwt_ref = feature_ref_loader(feature_hwt_ref, num_ref=max_ref)
        feature_mgt_ref = feature_ref_loader(feature_mgt_ref, num_ref=max_ref)

        self.feature_hwt_ref = feature_hwt_ref.contiguous()
        self.feature_mgt_ref = feature_mgt_ref.contiguous()
        self.feature_test    = feature_test.contiguous()

        # Build/load the network once per tester instead of once per test() call.
        NetLoaderCls = NetLoader if net_loader_cls is None else net_loader_cls
        self.net = NetLoaderCls()
        self.net.net.eval()        # deterministic inference for features

        # Ensure the kernel scalars are learnable Parameters on the loader.
        # If the loader already exposes them as nn.Parameter, this only copies values.
        self.net.sigma    = _ensure_param(self.net, "sigma",    getattr(self.net, "sigma",    1.0),   self.device)
        self.net.sigma0_u = _ensure_param(self.net, "sigma0_u", getattr(self.net, "sigma0_u", 1.0),   self.device)
        self.net.ep       = _ensure_param(self.net, "ep",       getattr(self.net, "ep",       1e-6),  self.device)

        # One optimizer over the three kernel hyperparameters (sigma, sigma0_u, ep),
        # built once and used to maximise the three-sample MMD statistic.
        self.opt_kernel = torch.optim.Adam([self.net.sigma, self.net.sigma0_u, self.net.ep], lr=0.2)

    def _update_kernel(self, t, rest, two_sided):
        """Take one Adam step on the kernel scalars that increases the test statistic."""
        # Prefer the first scalar tensor returned after ``t`` (the raw
        # statistic for MMD_3_Sample_Test); otherwise fall back to ``t``.
        target = next((r for r in rest if torch.is_tensor(r) and r.numel() == 1), None)
        if target is None and torch.is_tensor(t):
            target = t
        if target is None:
            # The MMD function returned plain numbers: there is no gradient path.
            return

        loss = -target.abs() if two_sided else -target
        if not torch.isfinite(loss).all():
            # A NaN/inf step would poison the parameters for every later round.
            return

        self.opt_kernel.zero_grad()
        loss.backward()
        self.opt_kernel.step()

        # keep params positive and in sane ranges
        with torch.no_grad():
            self.net.sigma.clamp_(1e-8, 1e6)
            self.net.sigma0_u.clamp_(1e-8, 1e6)
            self.net.ep.clamp_(1e-12, 1e-2)

    def test(self, threshold=0.2, rounds=800, two_sided=False, update_kernel=True):
        """
        Run ``rounds`` repetitions of the three-sample test.

        Each round draws equally sized random batches from the test features
        and both reference sets, evaluates ``MMD_3_Sample_Test`` at level 0.05
        and, when ``update_kernel`` is True, updates sigma/sigma0_u/ep with
        one optimizer step.

        Args:
            threshold: Currently unused; kept for backward compatibility.
            rounds: Number of repetitions (must be at least 1).
            two_sided: Optimise ``|t|`` instead of ``t`` in the kernel update.
            update_kernel: Whether to update the kernel scalars each round.

        Returns:
            float: ``abs(power - 1.0)``, where ``power`` is the fraction of
            rounds in which the test rejected.

        Raises:
            AssertionError: If any of the three feature sets is empty.
            ValueError: If fewer than two samples are available per set, or
                if ``rounds`` is smaller than 1.
        """
        n_test = len(self.feature_test)
        n_hwt = len(self.feature_hwt_ref)
        n_mgt = len(self.feature_mgt_ref)
        min_len = min(n_test, n_hwt, n_mgt)
        assert min_len > 0, "Empty inputs to RelativeTester.test()"
        if min_len < 2:
            # The variance estimate divides by m * (m - 1).
            raise ValueError(
                "RelativeTester.test() needs at least 2 samples in the test and reference sets; "
                f"got {min_len}"
            )
        if rounds < 1:
            raise ValueError(f"rounds must be >= 1; got {rounds}")

        rejections = 0.0
        for _ in range(rounds):
            # Sample equal sized mini-batches
            ix = torch.randperm(n_test, device=self.device)[:min_len]
            iy = torch.randperm(n_hwt,  device=self.device)[:min_len]
            iz = torch.randperm(n_mgt,  device=self.device)[:min_len]

            feature_for_sents_sample = self.feature_test[ix]
            feature_hwt_ref_sample   = self.feature_hwt_ref[iy]
            feature_mgt_ref_sample   = self.feature_mgt_ref[iz]

            # Freeze backbone; just compute learned features.
            # Y is the machine-generated reference, Z the human-written one.
            with torch.no_grad():
                Xl = self.net.net(feature_for_sents_sample)
                Yl = self.net.net(feature_mgt_ref_sample)
                Zl = self.net.net(feature_hwt_ref_sample)

            # Flatten the original features to one row per sample.
            Xf = feature_for_sents_sample.view(feature_for_sents_sample.size(0), -1)
            Yf = feature_mgt_ref_sample.view(feature_mgt_ref_sample.size(0), -1)
            Zf = feature_hwt_ref_sample.view(feature_hwt_ref_sample.size(0), -1)

            h_u, p_value, t, *rest = MMD_3_Sample_Test(
                Xl, Yl, Zl,   # learned features
                Xf, Yf, Zf,   # original/flattened features
                self.net.sigma, self.net.sigma0_u, self.net.ep,
                0.05,  # significance level; author's note: AUROC converged to 0.96 after 600 rounds
            )

            if update_kernel:
                self._update_kernel(t, rest, two_sided)

            rejections += float(h_u)

        power = rejections / rounds
        return abs(power - 1.0)


In [None]:
# Collect the size of the first dimension of every validation sample
# (presumably its number of segments — confirm), sort ascending and display.
lis = []

for i in X_val_segments:
    nmb = i.shape[0]
    lis.append(nmb)
lis.sort()
lis

In [None]:
# ---- main: score the first 20 validation paragraphs ----

# Convert the reference sets to tensors once. RelativeTester.__init__ converts
# its inputs with to_tensor, which leaves a tensor that is already on the
# right device and dtype untouched, so the large NumPy -> tensor (and
# host -> device) copy is no longer repeated for every paragraph.
HWT_ref_tensor = to_tensor(HWT_ref, DEVICE)
MGT_ref_tensor = to_tensor(MGT_ref, DEVICE)

prob_result_paragraph = []
for segments in X_val_segments[:20]:
    # Argument order: feature_hwt_ref, feature_mgt_ref, feature_test.
    # A new tester per paragraph reloads the checkpoint, so kernel parameters
    # tuned on one paragraph do not carry over to the next.
    relative_tester = RelativeTester(HWT_ref_tensor, MGT_ref_tensor, segments)
    prob = relative_tester.test()
    prob_result_paragraph.append(prob)


In [None]:
prob_result_paragraph

In [None]:
prob_result_paragraph # alpha = 0.027 -> 0.096

In [None]:
prob_result_paragraph # alpha == 0.025 -> result 0.096; alpha == 0.05 -> 0.96; alpha == 0.02 -> 0.96

In [None]:
prob_result_paragraph # rounds == 800, alpha == 0.015 -> result 0.95, which means there could be a peak in the range of 0.02 to 0.01

In [None]:
prob_result_paragraph # rounds == 800, alpha == 0.01 -> result 0.94, which means there could be a peak in the range of 0.02 to 0.01

In [None]:
prob_result_paragraph # rounds == 800, alpha == 0.009 -> result 0.93, which means there could be a peak in the range of 0.02 to 0.009

In [None]:
prob_result_paragraph # rounds == 800, alpha == 0.02 -> result 0.96; go into large

In [None]:
prob_result_paragraph # rounds == 800, alpha == 0.09 -> result 0.94; go into small

In [None]:
prob_result_paragraph # rounds == 800, alpha == 0.5 -> result 0.96, which means the probability has converged

In [None]:
prob_result_paragraph # round == 500 alpha = 0.05 result 0.96

In [None]:
import torch

# Inspect the raw checkpoint file; the trailing comments are the outputs recorded by the author.
data = torch.load(net_pt_path, map_location=DEVICE)
print(type(data))
print(data.keys()) # dict_keys(['net', 'sigma', 'sigma0_u', 'ep'])
print(data['sigma']) # tensor(1.)
print(data['sigma0_u']) # tensor(1.)
print(data['ep']) # tensor(1.0000e-06)


In [None]:
y_val_segments[:20]

In [None]:
X_val_segments[17].shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score

# 1) Inputs
# y_score: predicted scores for the positive class, one per scored paragraph
# y_true:  ground-truth labels (0/1)
y_score = np.asarray(prob_result_paragraph, dtype=float)
# Scores may exist for only the first N validation paragraphs (the scoring
# cell slices X_val_segments), so compare them with the first N labels.
# When every paragraph was scored this slice is the full label list.
y_true = np.asarray(y_val_segments[:len(y_score)], dtype=int)

# 2) Basic checks
assert y_true.shape[0] == y_score.shape[0], "y_true and y_score must have same length"
assert set(np.unique(y_true)).issubset({0, 1}), "y_true must be binary (0/1)"

# 3) Compute ROC and AUC
fpr, tpr, thresholds = roc_curve(y_true, y_score)
auc = roc_auc_score(y_true, y_score)
print(f"AUC: {auc:.4f}")

# 4) Plot ROC curve
plt.figure(figsize=(6, 5))
plt.plot(fpr, tpr, label=f"ROC (AUC = {auc:.3f})", linewidth=2)
plt.plot([0, 1], [0, 1], linestyle="--", linewidth=1)  # diagonal chance line
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()


In [None]:
# ------------------ Predict on test ------------------

In [None]:
def load_jsonl_test(path):
    """
    Read a JSON-Lines test file (decoded as UTF-8).

    Every non-blank line must be a JSON object with the keys ``features`` and
    ``id``. Blank or whitespace-only lines are skipped instead of raising a
    decode error.

    Returns:
      paragraphs: list of float32 arrays, one per record
                  (author's note: each is (n_segments,100,768))
      ids:        list of the records' ``id`` values, in file order

    Raises:
      json.JSONDecodeError: if a non-blank line is not valid JSON
      KeyError:             if a record lacks ``features`` or ``id``
    """
    paragraphs, ids = [], []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate empty lines between or after records.
                continue
            row = json.loads(line)
            paragraphs.append(np.asarray(row["features"], dtype="float32"))
            ids.append(row["id"])
    return paragraphs, ids

test_path = '/content/drive/MyDrive/test_features.jsonl'
X_test_segments, test_ids = load_jsonl_test(test_path)


In [None]:


# Convert the reference sets to tensors once. RelativeTester.__init__ converts
# its inputs with to_tensor, which leaves a tensor that is already on the
# right device and dtype untouched, so the large NumPy -> tensor (and
# host -> device) copy is no longer repeated for every paragraph.
HWT_ref_tensor = to_tensor(HWT_ref, DEVICE)
MGT_ref_tensor = to_tensor(MGT_ref, DEVICE)

prob_result_paragraph = []
for segments in X_test_segments:
    # Argument order: feature_hwt_ref, feature_mgt_ref, feature_test.
    # A new tester per paragraph reloads the checkpoint, so kernel parameters
    # tuned on one paragraph do not carry over to the next.
    relative_tester = RelativeTester(HWT_ref_tensor, MGT_ref_tensor, segments)
    prob = relative_tester.test()
    prob_result_paragraph.append(prob)


In [None]:
import pandas as pd
# ------------------ Package submission ------------------
# One row per test record: its id and the score stored in prob_result_paragraph.
# prob_result_paragraph must hold the test-set scores at this point (the
# validation loop uses the same variable name).
df_pred = pd.DataFrame({"id": test_ids, "y_prob": prob_result_paragraph})
out_csv = "submission_fucking_giveup.csv"
# Written to the current working directory, without the index column.
df_pred.to_csv(out_csv, index=False)
print(f"[done] Wrote {out_csv}")
print(df_pred.head())