In [1]:
# ========= 0. Setup device and threading =========

from collections import Counter

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.stats import pearsonr

# Device setup: use CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == 'cuda':
    # Make sure cuDNN is on and let its autotuner pick kernels
    # (useful if input sizes are stable).
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

# Limit the CPU thread pool to at most MAX_CPU_THREADS, but never request more
# threads than PyTorch currently reports, so small machines (or an explicit
# thread limit already in effect) are not oversubscribed by a hard-coded 8.
MAX_CPU_THREADS = 8
torch.set_num_threads(max(1, min(MAX_CPU_THREADS, torch.get_num_threads())))

Using device: cuda


In [2]:
# ========= 1. Read and parse the input table =========
# Alternative input: "../pymochi_catal_nostop_dataset_weighted.tsv"
df = pd.read_csv("../pymochi_stab_dataset_weighted.tsv", sep="\t")

# The wild-type sequence is taken from the first row flagged WT == True.
# Fail with a clear message instead of an IndexError when no such row exists.
wt_rows = df.loc[df["WT"] == True, "aa_seq"]
if wt_rows.empty:
    raise ValueError("No wild-type row found: expected at least one row with WT == True")
wt_seq = wt_rows.values[0]
mut_seqs = df["aa_seq"].values

# Add a fitness column derived from ddG: fitness = 1 / (1 + exp(ddG))
df["fitness"] = 1 / (1 + np.exp(df["ddG"].values))
fitness = torch.tensor(df["fitness"].values, dtype=torch.float32)
# The "sigma" column is not used; per-variant weights come from "weight".
weight = torch.tensor(df["weight"].values, dtype=torch.float32)

In [None]:
# ========= 2. Find all mutation types =========

# find single mutations
def find_mutations(seq, wt):
    """List the substitutions that distinguish ``seq`` from ``wt``.

    Args:
        seq: Variant sequence; indexed at every position of ``wt``.
        wt: Reference (wild-type) sequence.

    Returns:
        A list of ``(position, wt_residue, seq_residue)`` tuples, one per
        differing position, in increasing 0-based position order.

    Raises:
        IndexError: If ``seq`` is shorter than ``wt``.
    """
    return [(pos, ref, seq[pos]) for pos, ref in enumerate(wt) if seq[pos] != ref]

# Collect every distinct single substitution observed in any variant, then fix
# a stable (sorted) order so each one gets a reproducible feature index.
observed_mutations = set()
for s in mut_seqs:
    observed_mutations.update(find_mutations(s, wt_seq))
all_mutations = sorted(observed_mutations)

mut_to_idx = dict(zip(all_mutations, range(len(all_mutations))))
M = len(all_mutations)  # number of single-mutation features
N = len(df)  # number of rows (variants) in the table

# find all mutation pairs for mutantation sequences with >1 mutation site
# e.g. for a sequence with 5 different sites, we have 10 pairs
def find_mutation_pairs(seq, wt):
    """Return every unordered pair of substitutions carried by ``seq``.

    The substitutions come from ``find_mutations(seq, wt)``; each pair is
    emitted once as ``(earlier, later)`` following the order of that list, so
    a sequence with k substitutions yields k * (k - 1) / 2 pairs.
    """
    muts = find_mutations(seq, wt)
    count = len(muts)
    return [(muts[a], muts[b]) for a in range(count) for b in range(a + 1, count)]
# Every pair observed in any variant, stored in canonical (sorted) form so that
# ((0,'A','C'),(1,'G','T')) and ((1,'G','T'),(0,'A','C')) count as one pair.
all_mutation_pairs = sorted({
    tuple(sorted(pair))
    for s in mut_seqs
    for pair in find_mutation_pairs(s, wt_seq)
})
pair_to_idx = dict((p, k) for k, p in enumerate(all_mutation_pairs))
P = len(all_mutation_pairs)  # number of second-order (pairwise) features
print(f"Number of features (M) = {M}, Number of total pairwise features (P) = {P}")



In [None]:
# ========= 3. Build the one-hot feature matrices =========

# A. single mutation features
# Gather the (row, column) coordinates of every 1 first and write them with a
# single indexed assignment; setting tensor elements one at a time from a
# Python loop is far slower and produces exactly the same matrix.
single_rows, single_cols = [], []
for n, seq in enumerate(mut_seqs):
    for m in find_mutations(seq, wt_seq):
        single_rows.append(n)
        single_cols.append(mut_to_idx[m])

X = torch.zeros((N, M), dtype=torch.float32)
X[
    torch.tensor(single_rows, dtype=torch.long),
    torch.tensor(single_cols, dtype=torch.long),
] = 1.0

# B. mutation pair features - kept sparse to avoid a huge dense allocation

# ---------- Step 1: count how many variants carry each pair ----------
pair_counts = Counter(
    tuple(sorted(p))
    for seq in mut_seqs
    for p in find_mutation_pairs(seq, wt_seq)
)

# ---------- Step 2: keep pairs seen at least min_support times ----------
min_support = 3  # tunable threshold, e.g. 3 or 5
filtered_pairs = [p for p in pair_counts if pair_counts[p] >= min_support]
pair_to_idx = dict(zip(filtered_pairs, range(len(filtered_pairs))))
P = len(filtered_pairs)

print(f"Number of filtered pairs kept: {P} / {len(pair_counts)} (min_support={min_support})")

# ---------- Step 3: collect the sparse (row, column) coordinates ----------
row_idx, col_idx = [], []
for n, seq in enumerate(mut_seqs):
    for p in find_mutation_pairs(seq, wt_seq):
        idx = pair_to_idx.get(tuple(sorted(p)))
        if idx is not None:  # pairs removed by the filter are left out of the matrix
            row_idx.append(n)
            col_idx.append(idx)
vals = [1.0] * len(row_idx)

# ---------- Step 4: assemble the sparse tensor ----------
if row_idx:
    indices = torch.tensor([row_idx, col_idx], dtype=torch.long)
    values = torch.tensor(vals, dtype=torch.float32)
    X_pair = torch.sparse_coo_tensor(indices, values, size=(N, P)).coalesce()
    print(f"Constructed sparse X_pair with nnz={X_pair._nnz()}")
else:
    # No pair survived: fall back to an empty dense (N, 0) matrix.
    X_pair = torch.zeros((N, 0), dtype=torch.float32)

# print(f"N={N}, P={P}")
# row_idx = []
# col_idx = []
# vals = []
# for n, seq in enumerate(mut_seqs):
#     for p in find_mutation_pairs(seq, wt_seq):
#         p = tuple(sorted(p))
#         idx = pair_to_idx.get(p)
#         if idx is None:
#             continue
#         row_idx.append(n)
#         col_idx.append(idx)
#         vals.append(1.0)

# # convert to tensors
# if len(row_idx) == 0:
#     # fallback to a small dense tensor if no pairs exist
#     X_pair = torch.zeros((N, 0), dtype=torch.float32)
# else:
#     indices = torch.tensor([row_idx, col_idx], dtype=torch.long)  # shape (2, nnz)
#     values = torch.tensor(vals, dtype=torch.float32)
#     X_pair = torch.sparse_coo_tensor(indices, values, size=(N, P))  # sparse tensor (N, P)
#     X_pair = X_pair.coalesce()  # ensure unique indices
#     print(f"Constructed sparse X_pair with nnz={X_pair._nnz()}")

# Put the feature and target tensors on the compute device.
X = X.to(device)
try:
    X_pair = X_pair.to(device)
except Exception as e:
    # Best effort: if the move fails, X_pair simply stays where it is.
    print("Warning: could not move sparse X_pair to device:", e)

fitness, weight = fitness.to(device), weight.to(device)

# Sanity checks on X: shape (N, M), per-row sums, per-column sums
# (feature distribution), first row and first column.
print(X.shape)
for check in (X.sum(dim=1), X.sum(dim=0), X[0], X[:, 0]):
    print(check)

# Sanity check on X_pair: shape (N, P)
print(X_pair.shape)

In [None]:
# ========= 4. Define the models =========
class MoCHI_Core(nn.Module):
    """Additive latent score followed by a small learned 1-D network.

    Parameters held by the module:
        theta: one additive coefficient per input feature (initialised to 0).
        phi0:  scalar offset of the latent score (initialised to 0).
        g:     Linear(1, 20) -> Sigmoid -> Linear(20, 1), applied to the score.
        a, b:  scalar scale (1) and shift (0) applied to the output of ``g``.
    """

    def __init__(self, n_features):
        super().__init__()
        self.theta = nn.Parameter(torch.zeros(n_features))
        self.phi0 = nn.Parameter(torch.tensor(0.0))
        # Global-epistasis term: a sum of 20 sigmoids of the latent score.
        self.g = nn.Sequential(nn.Linear(1, 20), nn.Sigmoid(), nn.Linear(20, 1))
        self.a = nn.Parameter(torch.tensor(1.0))
        self.b = nn.Parameter(torch.tensor(0.0))

    def forward(self, X):
        """Return ``(yhat, phi)`` for a feature matrix ``X`` with one row per sample."""
        phi = torch.matmul(X, self.theta) + self.phi0
        transformed = self.g(phi[:, None])  # the network expects a trailing feature axis
        yhat = self.a * transformed.squeeze(1) + self.b
        return yhat, phi

class MoCHI_TwoState_order1(nn.Module):
    """First-order two-state model.

    The latent score ``phi = phi0 + X @ theta`` is divided by ``R * T``,
    clamped to [-50, 50] and mapped to ``p = 1 / (1 + exp(z))``; the
    prediction is the affine readout ``a * p + b``.

    Args:
        M: number of single-mutation features (length of ``theta``).
        R: constant multiplied with ``T`` to scale ``phi`` (default 8.314).
        T: constant multiplied with ``R`` to scale ``phi`` (default 303.0).
    """

    def __init__(self, M, R=8.314, T=303.0):
        super().__init__()
        self.theta = nn.Parameter(torch.zeros(M))
        self.phi0 = nn.Parameter(torch.tensor(0.0))
        self.a = nn.Parameter(torch.tensor(1.0))
        self.b = nn.Parameter(torch.tensor(0.0))
        self.R, self.T = R, T

    def forward(self, X):
        """Return ``(yhat, phi)``, each with one entry per row of ``X``."""
        phi = torch.matmul(X, self.theta) + self.phi0
        scaled = phi / (self.R * self.T)
        z = scaled.clamp(-50, 50)  # keeps exp() finite
        p = 1.0 / (1.0 + torch.exp(z))
        return self.a * p + self.b, phi

class MoCHI_TwoState_order2(nn.Module):
    """Second-order two-state model: single-mutation plus pairwise terms.

    ``phi = phi0 + X @ theta + X_pair @ phi_pair`` is divided by ``R * T``,
    clamped to [-50, 50] and mapped to ``p = 1 / (1 + exp(z))``; the
    prediction is ``a * p + b``.

    Args:
        M: number of single-mutation features (length of ``theta``).
        P: number of pairwise features (length of ``phi_pair``).
        R, T: constants whose product scales ``phi`` (defaults 8.314, 303.0).
    """

    def __init__(self, M, P, R=8.314, T=303.0):
        super().__init__()
        self.theta = nn.Parameter(torch.zeros(M))
        self.phi0 = nn.Parameter(torch.tensor(0.0))
        self.phi_pair = nn.Parameter(torch.zeros(P))
        self.a = nn.Parameter(torch.tensor(1.0))
        self.b = nn.Parameter(torch.tensor(0.0))
        self.R, self.T = R, T

    def forward(self, X, X_pair):
        """Return ``(yhat, phi)``; ``X_pair`` may be a dense or a sparse tensor."""
        additive = self.phi0 + X @ self.theta
        if not getattr(X_pair, 'is_sparse', False):
            pair_term = X_pair @ self.phi_pair
        else:
            pair_column = self.phi_pair.unsqueeze(1)  # (P, 1) for the sparse product
            if torch.is_autocast_enabled():
                # Autocast is switched off around the sparse product.
                with torch.amp.autocast(device_type='cuda', enabled=False):
                    pair_term = torch.sparse.mm(X_pair, pair_column).squeeze(1)
            else:
                pair_term = torch.sparse.mm(X_pair, pair_column).squeeze(1)
        phi = additive + pair_term
        z = (phi / (self.R * self.T)).clamp(-50, 50)
        p = 1.0 / (1.0 + torch.exp(z))
        return self.a * p + self.b, phi


class Linear_order2(nn.Module):
    """
    Linear model: yhat = phi0 + X @ theta + X_pair @ phi_pair
    """

    def __init__(self, M, P):
        super().__init__()
        self.phi0 = nn.Parameter(torch.zeros(1))
        self.theta = nn.Parameter(torch.zeros(M))
        self.phi_pair = nn.Parameter(torch.zeros(P))

    def forward(self, X, X_pair):
        """Return the linear prediction; ``X_pair`` may be dense or sparse."""
        single_part = self.phi0 + X @ self.theta

        if not getattr(X_pair, 'is_sparse', False):
            pair_part = X_pair @ self.phi_pair
        else:
            pair_column = self.phi_pair.unsqueeze(1)  # (P, 1) for the sparse product
            if torch.is_autocast_enabled():
                # Autocast is switched off around the sparse product.
                with torch.amp.autocast(device_type='cuda', enabled=False):
                    pair_part = torch.sparse.mm(X_pair, pair_column).squeeze(1)
            else:
                pair_part = torch.sparse.mm(X_pair, pair_column).squeeze(1)

        return single_part + pair_part


class Linear_order1(nn.Module):
    """
    Linear model on single mutations only: yhat = phi0 + X @ theta
    LASSO on the single-mutation features is used to obtain a sparse theta
    (no pairwise terms).
    """

    def __init__(self, M):
        super().__init__()
        self.phi0 = nn.Parameter(torch.zeros(1))
        self.theta = nn.Parameter(torch.zeros(M))

    def forward(self, X):
        """Return ``phi0 + X @ theta`` with one entry per row of ``X``."""
        return self.phi0 + torch.matmul(X, self.theta)


# Build one instance of every model and place it on the compute device.
# Two-state models:
model1 = MoCHI_TwoState_order1(M=M).to(device)
model2 = MoCHI_TwoState_order2(M=M, P=P).to(device)
# Linear models: with pairwise terms, and with single-mutation features only.
model_linear = Linear_order2(M=M, P=P).to(device)
model_linear1 = Linear_order1(M=M).to(device)


In [None]:
# ========= 4.5 Solving the linear models (LASSO) =========
# The Linear_order2 model can be solved directly with LASSO; no iterative optimization is needed

from sklearn.linear_model import Lasso
from scipy.sparse import csr_matrix, hstack
import numpy as np

banner = "=" * 60
print("\n" + banner)
print("Linear Model (Linear_order2) - LASSO Solution")
print(banner)

# Single-mutation features: dense NumPy copy plus a SciPy CSR view of it
X_np = X.cpu().numpy()  # (N, M)
X_np_sparse = csr_matrix(X_np)

# Pairwise features (may live on GPU or CPU).
# Key point: stay in sparse format, never densify a sparse X_pair.
if not getattr(X_pair, 'is_sparse', False):
    X_pair_sparse = csr_matrix(X_pair.cpu().numpy())
else:
    # Pull indices/values out of the PyTorch sparse tensor and rebuild it in SciPy
    X_pair_cpu = X_pair.cpu()
    indices = X_pair_cpu._indices().numpy()
    values = X_pair_cpu._values().numpy()
    shape = X_pair_cpu.shape
    X_pair_sparse = csr_matrix((values, (indices[0], indices[1])), shape=shape)
    print(f"  Built sparse X_pair with shape {shape}, nnz={len(values)}")

fitness_np = fitness.cpu().numpy()  # (N,)
weight_np = weight.cpu().numpy()  # (N,)

print(f"\nFeature dimensions:")
print(f"  X shape: {X_np.shape} (single mutations)")
print(f"  X_pair shape: {X_pair_sparse.shape} (pairwise interactions)")
print(f"  Combined features: {X_np.shape[1] + X_pair_sparse.shape[1] + 1} (including phi0)")

# Combined design matrix [ones | X | X_pair], still sparse
ones = csr_matrix(np.ones((N, 1)))
X_combined = hstack([ones, X_np_sparse, X_pair_sparse])

# The sample weights are not applied in the fit; they are applied when the
# loss is computed, to avoid over-regularisation driving every coefficient to zero.
print(f"  Feature matrix (unweighted): {X_combined.shape} (sparse)")

# Try several alpha values - small alphas are needed to avoid over-regularisation
alphas = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
lasso_results = {}

print(f"\nTraining LASSO with different alpha values...\n")

for alpha in alphas:
    lasso = Lasso(alpha=alpha, max_iter=10000, tol=1e-6, fit_intercept=False)
    # Fit on the original (unweighted) data; the intercept is the leading
    # column of ones in X_combined, hence fit_intercept=False.
    lasso.fit(X_combined, fitness_np)
    coef = lasso.coef_

    # Split the coefficient vector back into [phi0 | theta | phi_pair]
    phi0_sol = coef[0]
    theta_sol = coef[1:1+M]
    phi_pair_sol = coef[1+M:1+M+P]

    # Predictions and weighted loss
    yhat_linear = phi0_sol + X_np @ theta_sol + X_pair_sparse.dot(phi_pair_sol)
    residual = fitness_np - yhat_linear
    weighted_mae = np.mean(np.abs(residual * weight_np))

    # Sparsity of the solution
    nnz_theta = np.count_nonzero(np.abs(theta_sol) > 1e-10)
    nnz_phi_pair = np.count_nonzero(np.abs(phi_pair_sol) > 1e-10)
    # With no pairwise feature (P == 0) there is nothing to be sparse about;
    # report 0.0 instead of dividing by zero.
    n_pair_coef = len(phi_pair_sol)
    sparsity_phi_pair = 100 * (1 - nnz_phi_pair / n_pair_coef) if n_pair_coef else 0.0

    # L1 norms
    l1_theta = np.sum(np.abs(theta_sol))
    l1_phi_pair = np.sum(np.abs(phi_pair_sol))

    lasso_results[alpha] = {
        'phi0': phi0_sol,
        'theta': theta_sol,
        'phi_pair': phi_pair_sol,
        'loss': weighted_mae,
        'nnz_theta': nnz_theta,
        'nnz_phi_pair': nnz_phi_pair,
        'sparsity_phi_pair': sparsity_phi_pair,
        'l1_theta': l1_theta,
        'l1_phi_pair': l1_phi_pair
    }

    print(f"alpha = {alpha:8.2e} | Loss = {weighted_mae:.6f} | " +
          f"NNZ(theta)={nnz_theta:4d} | NNZ(phi_pair)={nnz_phi_pair:5d} | " +
          f"Sparsity={sparsity_phi_pair:5.1f}%")

# Pick the alpha whose solution has the smallest weighted loss
best_alpha = min(lasso_results, key=lambda a: lasso_results[a]['loss'])
best_result = lasso_results[best_alpha]

rule = '=' * 60
print("\n" + rule)
print(f"Best LASSO solution: alpha = {best_alpha}")
print(rule)
print(f"Loss (Weighted MAE): {best_result['loss']:.6f}")
print(f"Parameters:")
print(f"  phi0: {best_result['phi0']:.6f}")
print(f"  theta: NNZ={best_result['nnz_theta']}, ||theta||_1={best_result['l1_theta']:.4f}")
print(f"  phi_pair: NNZ={best_result['nnz_phi_pair']}, Sparsity={best_result['sparsity_phi_pair']:.2f}%, " +
      f"||phi_pair||_1={best_result['l1_phi_pair']:.4f}")
print(rule)

# Load the best coefficients into the model
with torch.no_grad():
    model_linear.phi0.data = torch.tensor([best_result['phi0']], dtype=torch.float32, device=device)
    for param_name in ('theta', 'phi_pair'):
        getattr(model_linear, param_name).data = torch.tensor(
            best_result[param_name], dtype=torch.float32, device=device
        )

print("\n‚úì LASSO parameters loaded to model_linear")

# ========= 4.6b Linear model (single mutations only) - Linear_order1 (LASSO) =========
rule = '=' * 60
print('\n' + rule)
print('Linear Model (Linear_order1) - LASSO Solution (single mutations only)')
print(rule)

X_single_sparse = X_np_sparse  # csr_matrix of shape (N, M)
X1_combined = hstack([ones, X_single_sparse])

lasso1_results = {}
print('\nTraining LASSO (single-mutation only) with different alpha values...\n')

for alpha in alphas:
    # Fit on the original data (no sample weights)
    lasso1 = Lasso(alpha=alpha, max_iter=10000, tol=1e-6, fit_intercept=False).fit(X1_combined, fitness_np)
    coef1 = lasso1.coef_
    phi01, theta1 = coef1[0], coef1[1:1+M]

    # Weighted loss of this solution
    yhat1 = phi01 + X_np @ theta1
    residual1 = fitness_np - yhat1
    weighted_mae1 = np.mean(np.abs(residual1 * weight_np))

    nnz_theta1 = np.count_nonzero(np.abs(theta1) > 1e-10)
    l1_theta1 = np.sum(np.abs(theta1))
    lasso1_results[alpha] = dict(
        phi0=phi01,
        theta=theta1,
        loss=weighted_mae1,
        nnz_theta=nnz_theta1,
        l1_theta=l1_theta1,
    )
    print(f"alpha = {alpha:8.2e} | Loss = {weighted_mae1:.6f} | NNZ(theta)={nnz_theta1:4d}")

# choose the alpha with the smallest weighted loss
best_alpha1 = min(lasso1_results, key=lambda a: lasso1_results[a]['loss'])
best_result1 = lasso1_results[best_alpha1]

print('\n' + rule)
print(f'Best LASSO (single) alpha = {best_alpha1}')
print(f"Loss (Weighted MAE): {best_result1['loss']:.6f}")
print('\nLoading parameters into model_linear1...')

with torch.no_grad():
    model_linear1.phi0.data = torch.tensor([best_result1['phi0']], dtype=torch.float32, device=device)
    model_linear1.theta.data = torch.tensor(best_result1['theta'], dtype=torch.float32, device=device)

print('\n‚úì LASSO (single) parameters loaded to model_linear1')


In [None]:
# ========= 5. Training setup =========

# One Adam optimizer per two-state model
optimizer1 = optim.Adam(model1.parameters(), lr=0.05)
optimizer2 = optim.Adam(model2.parameters(), lr=0.05)

lambda_l1 = 1e-8  # L1 penalty factor for phi_pair
n_epochs = 30000

# Early stopping: give up after `patience` epochs without a relative
# improvement of at least `min_delta`.
patience, min_delta = 500, 1e-5
best_loss, patience_counter, stopped_early = float('inf'), 0, False

# Mixed precision is used only on CUDA; each model gets its own gradient scaler.
use_amp = device.type == 'cuda'
if use_amp:
    scaler1 = torch.amp.GradScaler('cuda')
    scaler2 = torch.amp.GradScaler('cuda')
else:
    scaler1 = scaler2 = None

# Loss histories filled in by the training loops
loss_history1, loss_history2 = [], []
data_loss_history, l1_loss_history = [], []
epoch_list = []

# ========= Order 1 Training =========
# Full-batch training of model1 on the single-mutation features X.
# The loss is the mean absolute residual multiplied by the per-sample weight.
print("=" * 60)
print("Training Order 1 Model (Single Mutations Only)")
print("=" * 60)

# Early-stopping state for this model (kept separate from best_loss /
# patience_counter, which the order-2 loop uses).
best_loss1 = float('inf')
patience_counter1 = 0

for epoch in range(n_epochs):
    optimizer1.zero_grad()
    
    # Mixed precision forward pass
    if use_amp:
        with torch.amp.autocast('cuda'):
            yhat, phi = model1(X)
            loss = torch.mean(torch.abs((fitness - yhat) * weight))
        # Scaled backward pass: scale the loss, step through the scaler, then
        # let the scaler update its scale factor.
        scaler1.scale(loss).backward()
        scaler1.step(optimizer1)
        scaler1.update()
    else:
        yhat, phi = model1(X)
        loss = torch.mean(torch.abs((fitness - yhat) * weight))
        loss.backward()
        optimizer1.step()
    
    # Record loss every 10 epochs
    if epoch % 10 == 0:
        loss_history1.append(loss.item())
        # Only epoch 0 is added to epoch_list here; the order-2 loop appends
        # further entries to the same list.
        if epoch == 0:
            epoch_list.append(epoch)
    
    # Early stopping check: the loss must beat the best value by a relative
    # margin of min_delta to count as an improvement.
    current_loss = loss.item()
    if current_loss < best_loss1 * (1 - min_delta):
        best_loss1 = current_loss
        patience_counter1 = 0
    else:
        patience_counter1 += 1
    
    # Progress report every 500 epochs
    if epoch % 500 == 0:
        phi_range = f"[{phi.min().item():.2f}, {phi.max().item():.2f}]"
        print(f"Epoch {epoch:4d} | Loss = {loss.item():.5f} | phi_range = {phi_range} | patience = {patience_counter1}/{patience}")
    
    # Stop once `patience` consecutive epochs brought no improvement
    if patience_counter1 >= patience:
        print(f"\n>>> Early stopping triggered at epoch {epoch}")
        print(f">>> Best loss: {best_loss1:.5f}, Current loss: {current_loss:.5f}")
        break

print("\n" + "=" * 60)
print("Training Order 2 Model (Single Mutations + Pairs)")
print(f"L1 regularization: lambda_l1={lambda_l1}")
print("=" * 60)


def _pair_sparsity(phi_pair, tol=1e-10):
    """Return ``(nnz, sparsity_pct)`` for the pairwise parameter vector.

    ``nnz`` counts entries with ``|value| > tol``; ``sparsity_pct`` is the
    percentage of the remaining entries. An empty vector (no pairwise feature
    was kept) is reported as ``(0, 0.0)`` instead of dividing by zero.
    """
    total = phi_pair.numel()
    nnz = int((torch.abs(phi_pair) > tol).sum().item())
    pct = 100.0 * (1.0 - nnz / total) if total else 0.0
    return nnz, pct


# ========= Order 2 Training =========
# Full-batch training of model2; the objective is the weighted absolute
# residual plus an L1 penalty on phi_pair.
for epoch in range(n_epochs):
    optimizer2.zero_grad()

    # Mixed precision forward pass
    if use_amp:
        with torch.amp.autocast('cuda'):
            yhat, phi = model2(X, X_pair)
            data_loss = torch.mean(torch.abs((fitness - yhat) * weight))
            l1_penalty = lambda_l1 * torch.norm(model2.phi_pair, p=1)
            loss = data_loss + l1_penalty
        # Scaled backward pass
        scaler2.scale(loss).backward()
        scaler2.step(optimizer2)
        scaler2.update()
    else:
        yhat, phi = model2(X, X_pair)
        data_loss = torch.mean(torch.abs((fitness - yhat) * weight))
        l1_penalty = lambda_l1 * torch.norm(model2.phi_pair, p=1)
        loss = data_loss + l1_penalty
        loss.backward()
        optimizer2.step()

    # Record loss every 10 epochs
    if epoch % 10 == 0:
        loss_history2.append(loss.item())
        data_loss_history.append(data_loss.item())
        l1_loss_history.append(l1_penalty.item())
        if len(epoch_list) < len(loss_history2):
            epoch_list.append(epoch)

    # Early stopping based on data_loss only (the L1 penalty is ignored here)
    current_data_loss = data_loss.item()
    if current_data_loss < best_loss * (1 - min_delta):
        best_loss = current_data_loss
        patience_counter = 0
    else:
        patience_counter += 1

    # Progress report every 500 epochs
    if epoch % 500 == 0:
        phi_pair_norm = torch.norm(model2.phi_pair, p=1).item()
        nnz_phi_pair, sparsity_pct = _pair_sparsity(model2.phi_pair)
        print(f"Epoch {epoch:4d} | Total={loss.item():.5f} | Data={data_loss.item():.5f} | L1={l1_penalty.item():.5f}")
        print(f"              ||phi_pair||_1={phi_pair_norm:.4f} | NNZ={nnz_phi_pair} | Sparsity={sparsity_pct:.1f}% | patience={patience_counter}/{patience}")

    if patience_counter >= patience:
        print(f"\n{'='*60}")
        print(f">>> Early stopping triggered at epoch {epoch}")
        print(f">>> Best data loss: {best_loss:.5f}, Current: {current_data_loss:.5f}")
        print(f">>> No improvement for {patience} consecutive epochs")
        print(f"{'='*60}")
        stopped_early = True
        break

print("\n" + "=" * 60)
if stopped_early:
    print(f"Training stopped early at epoch {epoch}/{n_epochs}")
else:
    print("Training completed all epochs!")
print(f"Final data loss: {data_loss.item():.5f}")
print(f"Final L1 penalty: {l1_penalty.item():.5f}")
print(f"Final total loss: {loss.item():.5f}")

# Final sparsity analysis
final_nnz, final_sparsity = _pair_sparsity(model2.phi_pair)
print(f"Final phi_pair non-zeros: {final_nnz} / {model2.phi_pair.numel()} ({final_sparsity:.2f}% sparse)")
print("=" * 60)

In [None]:
# ========= 5.5 Visualize the training process =========
import matplotlib.pyplot as plt

# Losses were recorded every 10th epoch, so sample k belongs to epoch 10 * k.
epoch_list1 = [10 * k for k in range(len(loss_history1))] if loss_history1 else []
epoch_list2 = [10 * k for k in range(len(loss_history2))]

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
(ax1, ax2), (ax3, ax4) = axes

# Plots 1 and 3: total loss of both models, on a linear and on a log y-axis
for axis, draw, ylabel, title, grid_kwargs in (
    (ax1, ax1.plot, 'Loss (Weighted L1)', 'Training Loss Comparison', {}),
    (ax3, ax3.semilogy, 'Loss (Log Scale)', 'Training Loss Comparison (Log Scale)', {'which': 'both'}),
):
    if loss_history1:
        draw(epoch_list1, loss_history1, linewidth=2, color='steelblue', label='Order 1 Model', alpha=0.8)
    draw(epoch_list2, loss_history2, linewidth=2, color='coral', label='Order 2 Model (Total)', alpha=0.8)
    axis.set_xlabel('Epoch', fontsize=13)
    axis.set_ylabel(ylabel, fontsize=13)
    axis.set_title(title, fontsize=14, fontweight='bold')
    axis.legend(fontsize=12)
    axis.grid(alpha=0.3, **grid_kwargs)

# Plot 2: data loss vs L1 penalty decomposition (two y-axes).
# The x values are sliced to each history's own length rather than assuming
# that epoch_list2 and the history have the same length.
ax2.set_title('Loss Decomposition (Order 2 Model)', fontsize=14, fontweight='bold')
if not (data_loss_history and l1_loss_history):
    ax2.text(0.5, 0.5, 'Data loss history not available\n(training cell needs to complete)',
             ha='center', va='center', fontsize=12, transform=ax2.transAxes)
else:
    ax2.plot(epoch_list2[:len(data_loss_history)], data_loss_history,
             linewidth=2, color='green', label='Data Loss', alpha=0.8)
    ax2_twin = ax2.twinx()
    ax2_twin.plot(epoch_list2[:len(l1_loss_history)], l1_loss_history,
                  linewidth=2, color='red', label='L1 Penalty', alpha=0.8, linestyle='--')
    ax2.set_xlabel('Epoch', fontsize=13)
    ax2.set_ylabel('Data Loss', fontsize=13, color='green')
    ax2_twin.set_ylabel('L1 Penalty', fontsize=13, color='red')
    ax2.tick_params(axis='y', labelcolor='green')
    ax2_twin.tick_params(axis='y', labelcolor='red')
    ax2.legend(loc='upper left', fontsize=12)
    ax2_twin.legend(loc='upper right', fontsize=12)
    ax2.grid(alpha=0.3)

# Plot 4: data loss alone, log scale
ax4.semilogy(epoch_list2[:len(data_loss_history)], data_loss_history,
             linewidth=2, color='green', label='Data Loss (Order 2)', alpha=0.8)
ax4.set_xlabel('Epoch', fontsize=13)
ax4.set_ylabel('Data Loss (Log Scale)', fontsize=13)
ax4.set_title('Data Loss Only (Log Scale)', fontsize=14, fontweight='bold')
ax4.legend(fontsize=12)
ax4.grid(alpha=0.3, which='both')

plt.tight_layout()
plt.show()

# Text summary of both training runs, based on the recorded loss histories
rule = "=" * 60
print(rule)
print("Training Summary:")
print(rule)
if loss_history1:
    first1, last1 = loss_history1[0], loss_history1[-1]
    print(f"\nOrder 1 Model:")
    print(f"  Total epochs trained: {len(loss_history1) * 10}")
    print(f"  Initial loss: {first1:.5f}")
    print(f"  Final loss: {last1:.5f}")
    print(f"  Loss reduction: {first1 - last1:.5f} ({(first1 - last1)/first1*100:.2f}%)")

first2, last2 = loss_history2[0], loss_history2[-1]
print(f"\nOrder 2 Model:")
print(f"  Total epochs trained: {len(loss_history2) * 10}")
print(f"  Initial total loss: {first2:.5f}")
print(f"  Final total loss: {last2:.5f}")
print(f"  Total loss reduction: {first2 - last2:.5f} ({(first2 - last2)/first2*100:.2f}%)")

first_data, last_data = data_loss_history[0], data_loss_history[-1]
print(f"\n  Initial data loss: {first_data:.5f}")
print(f"  Final data loss: {last_data:.5f}")
print(f"  Data loss reduction: {first_data - last_data:.5f} ({(first_data - last_data)/first_data*100:.2f}%)")

first_l1, last_l1 = l1_loss_history[0], l1_loss_history[-1]
print(f"\n  Initial L1 penalty: {first_l1:.5f}")
print(f"  Final L1 penalty: {last_l1:.5f}")
print(f"  L1 penalty change: {last_l1 - first_l1:.5f}")

if loss_history1:
    final1 = loss_history1[-1]
    print(f"\nFinal loss comparison:")
    print(f"  Order 2 vs Order 1: {last2 - final1:.5f} ({(last2 - final1)/final1*100:.2f}%)")
print(rule)

In [None]:
# ========= 7. Analysis of the model parameter distributions =========


# Pull the fitted parameters of model2 back to NumPy / Python scalars
theta_values = model2.theta.detach().cpu().numpy()
phi_pair_values = model2.phi_pair.detach().cpu().numpy()
phi0_value, a_value, b_value = model2.phi0.item(), model2.a.item(), model2.b.item()

rule = "=" * 60
print(rule)
print("Model Parameters Summary:")
print(rule)
print(f"phi0 (baseline): {phi0_value:.4f}")
print(f"a (scaling): {a_value:.4f}")
print(f"b (offset): {b_value:.4f}")
# Same five-line report for the single-mutation and the pairwise parameters
for heading, values in (
    ("Single mutation parameters (theta)", theta_values),
    ("Mutation pair parameters (phi_pair)", phi_pair_values),
):
    print(f"\n{heading}:")
    print(f"  - Number of features: {len(values)}")
    print(f"  - Range: [{values.min():.4f}, {values.max():.4f}]")
    print(f"  - Mean: {values.mean():.4f}, Std: {values.std():.4f}")
    print(f"  - Non-zero count: {np.count_nonzero(np.abs(values) > 1e-6)}")
print(rule)

# 2x2 figure: histograms on the top row, sorted-value curves on the bottom row;
# left column = theta, right column = phi_pair.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
(ax1, ax2), (ax3, ax4) = axes

sorted_theta = np.sort(theta_values)
sorted_phi_pair = np.sort(phi_pair_values)

panels = (
    (ax1, ax3, theta_values, sorted_theta, 'steelblue',
     'Single Mutation Parameters (theta)', 'Sorted Single Mutation Parameters'),
    (ax2, ax4, phi_pair_values, sorted_phi_pair, 'coral',
     'Mutation Pair Parameters (phi_pair)', 'Sorted Mutation Pair Parameters'),
)
for hist_ax, curve_ax, values, ordered, color, hist_title, curve_title in panels:
    # Histogram of the parameter values, with a marker at zero
    hist_ax.hist(values, bins=100, edgecolor='black', alpha=0.7, color=color)
    hist_ax.axvline(0, color='red', linestyle='--', linewidth=2, label='Zero')
    hist_ax.set_xlabel('Parameter value', fontsize=12)
    hist_ax.set_ylabel('Frequency', fontsize=12)
    hist_ax.set_title(f'{hist_title}\nN={len(values)}, Mean={values.mean():.4f}', fontsize=13)
    hist_ax.legend()
    hist_ax.grid(alpha=0.3)

    # Sorted values (helps to spot the features with the largest effects)
    curve_ax.plot(ordered, linewidth=1.5, color=color, label='Sorted parameters')
    curve_ax.fill_between(range(len(ordered)), ordered, 0, alpha=0.3, color=color)
    curve_ax.axhline(0, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
    curve_ax.set_xlabel('Feature index (sorted)', fontsize=12)
    curve_ax.set_ylabel('Parameter value', fontsize=12)
    curve_ax.set_title(f'{curve_title}\nRange: [{ordered.min():.4f}, {ordered.max():.4f}]', fontsize=13)
    curve_ax.grid(alpha=0.3)
    curve_ax.legend(fontsize=10)

plt.tight_layout()
plt.show()

# Show top influential parameters.
# theta[k] belongs to all_mutations[k]. phi_pair[k] belongs to
# filtered_pairs[k]: model2 was sized with the support-filtered pair list, so
# looking the index up in the unfiltered all_mutation_pairs list would print
# the wrong pair.
assert len(all_mutations) == len(theta_values), "theta does not match all_mutations"
assert len(filtered_pairs) == len(phi_pair_values), "phi_pair does not match filtered_pairs"

theta_order = np.argsort(theta_values)
top_positive_theta_idx = theta_order[-10:][::-1]
top_negative_theta_idx = theta_order[:10]

pair_order = np.argsort(phi_pair_values)
top_positive_pair_idx = pair_order[-10:][::-1]
top_negative_pair_idx = pair_order[:10]

for heading, ranked in (
    ("\nTop 10 most positive single mutations (theta):", top_positive_theta_idx),
    ("\nTop 10 most negative single mutations (theta):", top_negative_theta_idx),
):
    print(heading)
    for i, idx in enumerate(ranked, 1):
        mut = all_mutations[idx]
        print(f"  {i}. Position {mut[0]}: {mut[1]}->{mut[2]}, theta={theta_values[idx]:.4f}")

for heading, ranked in (
    ("\nTop 10 most positive mutation pairs (phi_pair):", top_positive_pair_idx),
    ("\nTop 10 most negative mutation pairs (phi_pair):", top_negative_pair_idx),
):
    print(heading)
    for i, idx in enumerate(ranked, 1):
        pair = filtered_pairs[idx]
        print(f"  {i}. {pair[0]} & {pair[1]}, phi_pair={phi_pair_values[idx]:.4f}")


In [None]:
# Figure 4: four-model comparison (Order 1, Order 2, Linear pairs, Linear single)
# Every prediction and correlation shown here is computed inside this cell
# from the models themselves, so the cell does not depend on variables that
# only a later cell defines.
try:
    fig4, axes = plt.subplots(2, 4, figsize=(24, 12))

    with torch.no_grad():
        # The two-state models return (yhat, phi); only yhat is plotted.
        y_p1 = model1(X)[0].cpu().numpy()
        y_p2 = model2(X, X_pair)[0].cpu().numpy()
        y_lin_pairs = model_linear(X, X_pair).cpu().numpy()
        y_lin_single = model_linear1(X).cpu().numpy()

    models_data = [
        (name, y_pred, pearsonr(fitness_np, y_pred)[0])
        for name, y_pred in (
            ("Order 1\n(Single muts)", y_p1),
            ("Order 2\n(With pairs)", y_p2),
            ("Linear\n(Pairs)", y_lin_pairs),
            ("Linear\n(Single)", y_lin_single),
        )
    ]

    # Compute y-axis limits SEPARATELY for each model to avoid compression:
    # the prediction range padded by 10% (or by 0.1 if the range is empty).
    model_ylims_4 = []
    for col, (name, y_pred, r) in enumerate(models_data):
        y_min = y_pred.min()
        y_max = y_pred.max()
        y_margin = (y_max - y_min) * 0.1 if (y_max - y_min) > 0 else 0.1
        model_ylims_4.append((y_min - y_margin, y_max + y_margin))

    x_lo, x_hi = fitness_np.min(), fitness_np.max()

    # Top row: hexbin density plots
    for col, (name, y_pred, r) in enumerate(models_data):
        ax = axes[0, col]
        hb = ax.hexbin(fitness_np, y_pred, gridsize=100, cmap='viridis', mincnt=1, edgecolors='none')
        ax.plot([x_lo, x_hi], [x_lo, x_hi], 'r--', linewidth=2, alpha=0.7)  # identity line
        ax.set_xlabel("Measured fitness", fontsize=11)
        ax.set_ylabel("Predicted fitness", fontsize=11)
        ax.set_title(f"{name}\nr={r:.4f}", fontsize=12, fontweight='bold')
        ax.set_xlim(x_lo, x_hi)
        # Use individual model's y-limits for clearer visualization
        ax.set_ylim(model_ylims_4[col])
        plt.colorbar(hb, ax=ax, label='Density')

    # Bottom row: weight-colored scatter plots
    for col, (name, y_pred, r) in enumerate(models_data):
        ax = axes[1, col]
        scatter = ax.scatter(fitness_np, y_pred, c=weight_np, s=2, cmap='plasma', alpha=0.6, edgecolors='none')
        ax.plot([x_lo, x_hi], [x_lo, x_hi], 'r--', linewidth=2, alpha=0.7)  # identity line
        ax.set_xlabel("Measured fitness", fontsize=11)
        ax.set_ylabel("Predicted fitness", fontsize=11)
        ax.set_title(f"Weight distribution", fontsize=11)
        ax.set_xlim(x_lo, x_hi)
        # Use individual model's y-limits
        ax.set_ylim(model_ylims_4[col])
        plt.colorbar(scatter, ax=ax, label='Weight')

    plt.tight_layout()

    # results_dir is created by the results-saving cell. If that cell has not
    # run yet, show the figure inline instead of failing on an undefined name.
    try:
        out_dir = results_dir
    except NameError:
        out_dir = None

    if out_dir is None:
        print("  results_dir is not defined yet; showing the figure instead of saving it")
        plt.show()
    else:
        fig4.savefig(f"{out_dir}/04_four_models_comparison.png", dpi=300, bbox_inches='tight')
        plt.close(fig4)
        print(f"  ‚úì 04_four_models_comparison.png")
except Exception as e:
    print(f"  ‚ö† Could not save four models comparison figure: {e}")

In [None]:
# ========= 8. Save all results =========
import os
import json
from datetime import datetime
import shutil

# Create a results folder whose name carries the current timestamp
# (second resolution, e.g. results_20240131_235959).
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_dir = "results_" + timestamp

# If a folder with that name already exists, delete it and start from an empty one.
if os.path.exists(results_dir):
    shutil.rmtree(results_dir)
os.makedirs(results_dir, exist_ok=True)

# Banner announcing where everything below will be written.
print("\n" + "=" * 70)
print(f"üìÅ Saving all results to: {results_dir}")
print("=" * 70 + "\n")

# ========= Recalculate performance metrics if needed =========
print("üìà Calculating performance metrics...")

# Inference only: run both linear models without gradient tracking and
# bring their outputs back to NumPy on the CPU.
with torch.no_grad():
    y_pred_linear1_val = model_linear1(X).cpu().numpy()
    y_pred_linear_val = model_linear(X, X_pair).cpu().numpy()

# Linear_order1: residuals, mean absolute weight-scaled residual, Pearson r.
# NOTE: the "MAE" here is mean(|residual * weight|); it is not divided by the weight sum.
residual1 = fitness_np - y_pred_linear1_val
mae_linear1 = np.abs(residual1 * weight_np).mean()
r_linear1_calc, p_linear1_calc = pearsonr(fitness_np, y_pred_linear1_val)

# Linear (with pairs): same three quantities.
residual_linear = fitness_np - y_pred_linear_val
mae_linear = np.abs(residual_linear * weight_np).mean()
r_linear_calc, p_linear_calc = pearsonr(fitness_np, y_pred_linear_val)

# Sparsity of the pairwise coefficients: a coefficient counts as non-zero
# when its magnitude exceeds 1e-10.
_phi_pair_np = model_linear.phi_pair.detach().cpu().numpy()
nnz_phi_pair = (np.abs(_phi_pair_np) > 1e-10).sum()
sparsity_pct = 100 * (1 - nnz_phi_pair / len(_phi_pair_np))

print(f"  Linear (single): r={r_linear1_calc:.4f}, MAE={mae_linear1:.6f}")
print(f"  Linear (pairs): r={r_linear_calc:.4f}, MAE={mae_linear:.6f}, Sparsity={sparsity_pct:.2f}%\n")

# ========= 1. Save visualization figures =========
print("üìä Saving visualization figures...")


# Figure 1: training loss curves (only the histories that exist are drawn)
def _loss_history(name):
    """Return the history stored under `name` in the notebook namespace, or [] if it is missing or empty."""
    return globals().get(name) or []


def _epoch_axis(n_points):
    """Epoch numbers 0, 10, 20, ... for `n_points` logged values.

    Presumably the training loops append one value every 10 epochs —
    TODO confirm against the training cells.
    """
    return list(range(0, n_points * 10, 10))


fig1 = None
try:
    fig1, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Look every history up once instead of repeating `'name' in dir()` checks.
    _hist1 = _loss_history('loss_history1')
    _hist2 = _loss_history('loss_history2')
    _data_hist = _loss_history('data_loss_history')
    _l1_hist = _loss_history('l1_loss_history')

    epoch_list1 = _epoch_axis(len(_hist1))
    epoch_list2 = _epoch_axis(len(_hist2))
    # The decomposition curves get their own x-axes so they still plot when
    # loss_history2 is missing or shorter than they are.
    _data_epochs = _epoch_axis(len(_data_hist))
    _l1_epochs = _epoch_axis(len(_l1_hist))

    # Top-left: total loss of both models on a linear scale.
    ax = axes[0, 0]
    if _hist1:
        ax.plot(epoch_list1, _hist1, linewidth=2, color='steelblue', label='Order 1', alpha=0.8)
    if _hist2:
        ax.plot(epoch_list2, _hist2, linewidth=2, color='coral', label='Order 2 (Total)', alpha=0.8)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title('Training Loss Comparison', fontsize=13, fontweight='bold')
    if _hist1 or _hist2:
        # An empty legend only triggers a matplotlib warning, so skip it.
        ax.legend(fontsize=11)
    ax.grid(alpha=0.3)

    # Top-right: Order 2 loss split into data term (left axis) and L1 penalty (right axis).
    ax = axes[0, 1]
    if _data_hist and _l1_hist:
        ax.plot(_data_epochs, _data_hist, linewidth=2, color='green', label='Data Loss', alpha=0.8)
        ax2 = ax.twinx()
        ax2.plot(_l1_epochs, _l1_hist, linewidth=2, color='red', label='L1 Penalty', alpha=0.8, linestyle='--')
        ax.set_ylabel('Data Loss', fontsize=12, color='green')
        ax2.set_ylabel('L1 Penalty', fontsize=12, color='red')
        ax.tick_params(axis='y', labelcolor='green')
        ax2.tick_params(axis='y', labelcolor='red')
        ax.legend(loc='upper left', fontsize=11)
        ax2.legend(loc='upper right', fontsize=11)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_title('Loss Decomposition (Order 2)', fontsize=13, fontweight='bold')
    ax.grid(alpha=0.3)

    # Bottom-left: the same total losses on a log scale.
    ax = axes[1, 0]
    if _hist1:
        ax.semilogy(epoch_list1, _hist1, linewidth=2, color='steelblue', label='Order 1', alpha=0.8)
    if _hist2:
        ax.semilogy(epoch_list2, _hist2, linewidth=2, color='coral', label='Order 2', alpha=0.8)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss (log scale)', fontsize=12)
    ax.set_title('Training Loss (Log Scale)', fontsize=13, fontweight='bold')
    if _hist1 or _hist2:
        ax.legend(fontsize=11)
    ax.grid(alpha=0.3, which='both')

    # Bottom-right: data term alone on a log scale.
    ax = axes[1, 1]
    if _data_hist:
        ax.semilogy(_data_epochs, _data_hist, linewidth=2, color='green', alpha=0.8)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Data Loss (log scale)', fontsize=12)
    ax.set_title('Data Loss Only (Log Scale)', fontsize=13, fontweight='bold')
    ax.grid(alpha=0.3, which='both')

    plt.tight_layout()
    fig1.savefig(f"{results_dir}/01_training_loss.png", dpi=300, bbox_inches='tight')
    print(f"  ‚úì 01_training_loss.png")
except Exception as e:
    print(f"  ‚ö† Could not save training loss figure: {e}")
finally:
    # Close this specific figure even when plotting or saving failed, so a
    # half-drawn figure is neither leaked nor rendered at the end of the cell.
    if fig1 is not None:
        plt.close(fig1)

# Figure 2: Order 1 vs Order 2 comparison.
# One row per model: hexbin density on the left, weight-coloured scatter on the right.
fig2 = None
try:
    fig2, axes = plt.subplots(2, 2, figsize=(16, 14))

    _measured = df["fitness"]
    _fit_lo, _fit_hi = _measured.min(), _measured.max()

    # (title label, prediction column in df, name of the Pearson-r variable set by an earlier cell)
    _order_panels = [
        ("Order 1", "predicted_fitness_order1", "r1"),
        ("Order 2", "predicted_fitness_order2", "r2"),
    ]
    for _row, (_label, _pred_col, _r_name) in enumerate(_order_panels):
        _predicted = df[_pred_col]
        # The correlation is optional: it is appended to the titles only when the
        # variable exists in the notebook namespace.
        _r_value = globals().get(_r_name)
        _r_suffix = f"\nPearson r={_r_value:.4f}" if _r_value is not None else ""

        # Left column: point density.
        ax = axes[_row, 0]
        hb = ax.hexbin(_measured, _predicted, gridsize=120, cmap='viridis', mincnt=1, edgecolors='none')
        ax.set_xlabel("Measured fitness", fontsize=12)
        ax.set_ylabel("Predicted fitness", fontsize=12)
        ax.set_title(f"{_label} Model - Density{_r_suffix}", fontsize=13, fontweight='bold')
        ax.axis('equal')
        ax.set_xlim(_fit_lo, _fit_hi)
        ax.set_ylim(_fit_lo, _fit_hi)
        plt.colorbar(hb, ax=ax, label='Density')

        # Right column: the same points coloured by their per-sequence weight.
        ax = axes[_row, 1]
        ax.scatter(_measured, _predicted, c=df["weight"], s=5, cmap='plasma', alpha=0.6, edgecolors='none')
        ax.set_xlabel("Measured fitness", fontsize=12)
        ax.set_ylabel("Predicted fitness", fontsize=12)
        ax.set_title(f"{_label} Model - Weight{_r_suffix}", fontsize=13, fontweight='bold')
        ax.axis('equal')
        ax.set_xlim(_fit_lo, _fit_hi)
        ax.set_ylim(_fit_lo, _fit_hi)

    plt.tight_layout()
    fig2.savefig(f"{results_dir}/02_order1_vs_order2.png", dpi=300, bbox_inches='tight')
    print(f"  ‚úì 02_order1_vs_order2.png")
except Exception as e:
    print(f"  ‚ö† Could not save Order 1 vs 2 figure: {e}")
finally:
    # Always release the figure, including when a column is missing or saving fails.
    if fig2 is not None:
        plt.close(fig2)

# Figure 3: Order 2 parameter distributions.
# Top row: histograms; bottom row: the same values sorted ascending.
fig3 = None
try:
    theta_vals = model2.theta.detach().cpu().numpy()
    phi_pair_vals = model2.phi_pair.detach().cpu().numpy()
    sorted_theta = np.sort(theta_vals)
    sorted_phi = np.sort(phi_pair_vals)
    fig3, axes = plt.subplots(2, 2, figsize=(15, 10))

    # (values, sorted values, colour, histogram title, sorted-curve title) for each column
    _param_panels = [
        (theta_vals, sorted_theta, 'steelblue',
         'Single Mutations (theta)', 'Sorted Single Mutation Parameters'),
        (phi_pair_vals, sorted_phi, 'coral',
         'Pairwise Interactions (phi_pair)', 'Sorted Pairwise Interaction Parameters'),
    ]
    for _col, (_vals, _sorted_vals, _color, _hist_title, _sorted_title) in enumerate(_param_panels):
        # Histogram with a reference line at zero.
        ax = axes[0, _col]
        ax.hist(_vals, bins=100, edgecolor='black', alpha=0.7, color=_color)
        ax.axvline(0, color='red', linestyle='--', linewidth=2)
        ax.set_xlabel('Parameter value', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title(f'{_hist_title}\nN={len(_vals)}, Mean={_vals.mean():.4f}', fontsize=13)
        ax.grid(alpha=0.3)

        # Sorted parameter curve with a reference line at zero.
        ax = axes[1, _col]
        ax.plot(_sorted_vals, linewidth=1.5, color=_color)
        ax.axhline(0, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
        ax.set_xlabel('Feature index (sorted)', fontsize=12)
        ax.set_ylabel('Parameter value', fontsize=12)
        ax.set_title(_sorted_title, fontsize=13)
        ax.grid(alpha=0.3)

    plt.tight_layout()
    fig3.savefig(f"{results_dir}/03_order2_parameters.png", dpi=300, bbox_inches='tight')
    print(f"  ‚úì 03_order2_parameters.png")
except Exception as e:
    print(f"  ‚ö† Could not save Order 2 parameters figure: {e}")
finally:
    # fig3 stays None when model2 is unavailable; otherwise close it on every path.
    if fig3 is not None:
        plt.close(fig3)

# Figure 4: four-model comparison (Order 1, Order 2, Linear pairs, Linear single).
# Top row: hexbin density; bottom row: scatter coloured by per-sequence weight.
fig4 = None
try:
    fig4, axes = plt.subplots(2, 4, figsize=(24, 12))

    with torch.no_grad():
        # model1 returns a tuple; its first element is used as the prediction.
        y_p1, _ = model1(X)
        y_p1 = y_p1.cpu().numpy()
        # model2's second output is mapped to fitness here: a clamped two-state
        # sigmoid followed by the model's affine readout (a, b).
        # NOTE(review): 8.314 * 303.0 looks like R*T (gas constant times a temperature
        # in kelvin) — confirm it matches the constant used inside the model.
        _, phi2 = model2(X, X_pair)
        y_p2 = (1.0 / (1.0 + torch.exp(torch.clamp(phi2 / (8.314 * 303.0), -50, 50)))).cpu().numpy()
        y_p2 = model2.a.item() * y_p2 + model2.b.item()

    r1_tmp, p1_tmp = pearsonr(fitness_np, y_p1)
    r2_tmp, p2_tmp = pearsonr(fitness_np, y_p2)

    # Each linear panel plots the array its reported r was computed from
    # (the *_val predictions recalculated at the top of this cell), so the picture
    # and the title cannot drift apart through a stale variable from another cell.
    models_data = [
        ("Order 1\n(Single muts)", y_p1, r1_tmp),
        ("Order 2\n(With pairs)", y_p2, r2_tmp),
        ("Linear\n(Pairs)", y_pred_linear_val, r_linear_calc),
        ("Linear\n(Single)", y_pred_linear1_val, r_linear1_calc)
    ]

    # Compute y-axis limits SEPARATELY for each model to avoid compression:
    # pad the predicted range by 10%, or by a fixed 0.1 when the range is degenerate.
    model_ylims_4 = []
    for col, (name, y_pred, r) in enumerate(models_data):
        y_min = y_pred.min()
        y_max = y_pred.max()
        y_margin = (y_max - y_min) * 0.1 if (y_max - y_min) > 0 else 0.1
        model_ylims_4.append((y_min - y_margin, y_max + y_margin))

    # The identity line and the x-range are shared by every panel.
    _fit_lo, _fit_hi = fitness_np.min(), fitness_np.max()

    # Top row: hexbin density plots
    for col, (name, y_pred, r) in enumerate(models_data):
        ax = axes[0, col]
        hb = ax.hexbin(fitness_np, y_pred, gridsize=100, cmap='viridis', mincnt=1, edgecolors='none')
        ax.plot([_fit_lo, _fit_hi], [_fit_lo, _fit_hi], 'r--', linewidth=2, alpha=0.7)
        ax.set_xlabel("Measured fitness", fontsize=11)
        ax.set_ylabel("Predicted fitness", fontsize=11)
        ax.set_title(f"{name}\nr={r:.4f}", fontsize=12, fontweight='bold')
        ax.set_xlim(_fit_lo, _fit_hi)
        # Use individual model's y-limits for clearer visualization
        ax.set_ylim(model_ylims_4[col])
        plt.colorbar(hb, ax=ax, label='Density')

    # Bottom row: weight-colored scatter plots
    for col, (name, y_pred, r) in enumerate(models_data):
        ax = axes[1, col]
        scatter = ax.scatter(fitness_np, y_pred, c=weight_np, s=2, cmap='plasma', alpha=0.6, edgecolors='none')
        ax.plot([_fit_lo, _fit_hi], [_fit_lo, _fit_hi], 'r--', linewidth=2, alpha=0.7)
        ax.set_xlabel("Measured fitness", fontsize=11)
        ax.set_ylabel("Predicted fitness", fontsize=11)
        ax.set_title(f"Weight distribution", fontsize=11)
        ax.set_xlim(_fit_lo, _fit_hi)
        # Use individual model's y-limits
        ax.set_ylim(model_ylims_4[col])
        plt.colorbar(scatter, ax=ax, label='Weight')

    plt.tight_layout()
    fig4.savefig(f"{results_dir}/04_four_models_comparison.png", dpi=300, bbox_inches='tight')
    print(f"  ‚úì 04_four_models_comparison.png")
except Exception as e:
    print(f"  ‚ö† Could not save four models comparison figure: {e}")
finally:
    # Release the figure on every path so a failed run does not leave it open.
    if fig4 is not None:
        plt.close(fig4)

# Figure 5: linear-model parameter distributions (including Linear_order1).
fig5 = None
try:
    fig5, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Top-left: histogram of the single-mutation coefficients of Linear_order1.
    theta_linear1 = model_linear1.theta.detach().cpu().numpy()
    ax = axes[0, 0]
    ax.hist(theta_linear1, bins=80, edgecolor='black', alpha=0.7, color='steelblue')
    ax.axvline(0, color='red', linestyle='--', linewidth=2)
    ax.set_xlabel('Parameter value', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title(f'Linear_order1: Single Mutations\nN={len(theta_linear1)}', fontsize=13)
    ax.grid(alpha=0.3)

    # Top-right: histogram of the pairwise coefficients of the with-pairs model.
    theta_linear = model_linear.theta.detach().cpu().numpy()
    phi_pair_linear = model_linear.phi_pair.detach().cpu().numpy()
    ax = axes[0, 1]
    ax.hist(phi_pair_linear, bins=80, edgecolor='black', alpha=0.7, color='coral')
    ax.axvline(0, color='red', linestyle='--', linewidth=2)
    ax.set_xlabel('Parameter value', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title(f'Linear (pairs): Pairwise Interactions\nN={len(phi_pair_linear)}', fontsize=13)
    ax.grid(alpha=0.3)

    # Bottom-left: sorted single-mutation coefficients of both models on one axis.
    ax = axes[1, 0]
    sorted_t1 = np.sort(theta_linear1)
    ax.plot(sorted_t1, linewidth=1.5, color='steelblue', label='Linear_order1')
    sorted_t2 = np.sort(theta_linear)
    ax.plot(sorted_t2, linewidth=1.5, color='orange', label='Linear (pairs)', alpha=0.7)
    ax.axhline(0, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
    ax.set_xlabel('Feature index (sorted)', fontsize=12)
    ax.set_ylabel('Parameter value', fontsize=12)
    ax.set_title('Sorted Single Mutation Parameters Comparison', fontsize=13)
    ax.legend(fontsize=11)
    ax.grid(alpha=0.3)

    # Bottom-right: sorted pairwise coefficients, shaded between the curve and zero.
    ax = axes[1, 1]
    sorted_p = np.sort(phi_pair_linear)
    ax.plot(sorted_p, linewidth=1.5, color='coral')
    ax.axhline(0, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
    ax.fill_between(range(len(sorted_p)), sorted_p, 0, alpha=0.3, color='coral')
    ax.set_xlabel('Feature index (sorted)', fontsize=12)
    ax.set_ylabel('Parameter value', fontsize=12)
    ax.set_title('Sorted Pairwise Interaction Parameters (Linear)', fontsize=13)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    fig5.savefig(f"{results_dir}/05_linear_model_parameters.png", dpi=300, bbox_inches='tight')
    print(f"  ‚úì 05_linear_model_parameters.png")
except Exception as e:
    print(f"  ‚ö† Could not save linear model parameters figure: {e}")
finally:
    # Close the figure whether or not plotting/saving succeeded.
    if fig5 is not None:
        plt.close(fig5)

# ========= 2. Save model parameters as CSV =========
print("\nüìã Saving parameter CSV files...")

# Save Linear_order1 parameters (single mutations only).
# One row per entry of all_mutations, paired with the matching theta coefficient.
linear1_theta = pd.DataFrame({'mutation': all_mutations, 'theta_linear1': model_linear1.theta.detach().cpu().numpy()})
linear1_theta.to_csv(f"{results_dir}/linear1_theta.csv", index=False)
print(f"  ‚úì linear1_theta.csv")

# Save Linear model (with pairs) parameters if available.
# Pairwise coefficients are keyed by their position in filtered_pairs.
try:
    linear_theta_df = pd.DataFrame({'mutation': all_mutations, 'theta': model_linear.theta.detach().cpu().numpy()})
    linear_phi_pair_df = pd.DataFrame({'pair_idx': range(len(filtered_pairs)), 'phi_pair': model_linear.phi_pair.detach().cpu().numpy()})
    linear_theta_df.to_csv(f"{results_dir}/linear_theta.csv", index=False)
    linear_phi_pair_df.to_csv(f"{results_dir}/linear_phi_pair.csv", index=False)
    print(f"  ‚úì linear_theta.csv")
    print(f"  ‚úì linear_phi_pair.csv")
except Exception as e:
    # Catch Exception rather than everything, so Ctrl-C / kernel interrupts still
    # propagate, and report the cause like the other save steps in this cell do.
    print(f"  ‚ö† Could not save linear (pairs) parameters: {e}")

# ========= 3. Save predictions =========
print("\nüìä Saving predictions...")
try:
    # Columns exported for every sequence: measured value, both linear predictions, weight.
    _export_columns = [
        'aa_seq',
        'fitness',
        'predicted_fitness_linear',
        'predicted_fitness_linear1',
        'weight',
    ]
    predictions_df_save = df[_export_columns].copy()
    predictions_df_save.to_csv(f"{results_dir}/predictions_linear_models.csv", index=False)
    print(f"  ‚úì predictions_linear_models.csv ({len(predictions_df_save)} sequences)")
except Exception as e:
    print(f"  ‚ö† Could not save predictions: {e}")

# ========= 4. Save performance metrics =========
print("\nüìà Saving performance metrics...")
try:
    def _param_array(param):
        """Detach a model parameter and return it as a NumPy array on the CPU."""
        return param.detach().cpu().numpy()

    def _nonzero_count(param):
        """Number of entries of `param` whose magnitude exceeds 1e-10, as a plain int."""
        return int((np.abs(_param_array(param)) > 1e-10).sum())

    # Everything is cast to built-in int/float so json.dump can serialise it.
    _dataset_section = {
        "N_sequences": int(N),
        "N_mutations": int(M),
        "N_pairs": int(P),
        "Filtered_pairs_min_support_3": int(len(filtered_pairs)),
    }
    _single_section = {
        "Model": "Linear, single-mutation only (LASSO)",
        "Description": "Uses LASSO to fit only single-mutation features (no pairwise interactions)",
        "Pearson_r": float(r_linear1_calc),
        "p_value": float(p_linear1_calc),
        "MAE_weighted": float(mae_linear1),
        "NNZ_theta": _nonzero_count(model_linear1.theta),
        "N_theta": int(len(_param_array(model_linear1.theta))),
    }
    _pairs_section = {
        "Model": "Linear, with pairwise interactions (LASSO)",
        "Description": "Uses LASSO to fit single-mutation and pairwise epistasis features",
        "Pearson_r": float(r_linear_calc),
        "p_value": float(p_linear_calc),
        "MAE_weighted": float(mae_linear),
        "Sparsity_percent": float(sparsity_pct),
        "NNZ_theta": _nonzero_count(model_linear.theta),
        "NNZ_phi_pair": int(nnz_phi_pair),
    }
    perf = {
        "Dataset": _dataset_section,
        "Linear_single": _single_section,
        "Linear_pairs": _pairs_section,
    }

    with open(f"{results_dir}/performance_metrics.json", 'w') as f:
        json.dump(perf, f, indent=2)
    print(f"  ‚úì performance_metrics.json")
except Exception as e:
    print(f"  ‚ö† Could not save performance metrics: {e}")

# ========= 5. Save model checkpoints =========
print("\nü§ñ Saving model checkpoints...")

# (file/name tag, deferred model lookup, human-readable description).
# The lookup is a lambda so that a missing model surfaces inside the try below
# and only skips that one checkpoint.
_checkpoint_jobs = (
    ("model_linear1", lambda: model_linear1, "Linear single-mutation model"),
    ("model_linear", lambda: model_linear, "Linear with-pairs model"),
)
for _tag, _get_model, _blurb in _checkpoint_jobs:
    try:
        torch.save(_get_model().state_dict(), f"{results_dir}/{_tag}_checkpoint.pth")
        print(f"  ‚úì {_tag}_checkpoint.pth ({_blurb})")
    except Exception as e:
        print(f"  ‚ö† Could not save {_tag}: {e}")

# ========= 6. Save README =========
# Writes a Markdown summary of the run next to the other artefacts.
# The correlation change uses the `+` format flag, so the sign comes from the
# number itself: a drop prints as "-1.23%" rather than "+-1.23%".
print("\nüìù Saving README...")
try:
    readme = f"""# MoCHI Model Training Results - Linear & MoCHI Models
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Summary
This results folder contains comprehensive training results including:
- **MoCHI Models**: Order 1 (non-linear, single muts), Order 2 (non-linear, with pairs)
- **Linear Models**: Linear_order1 (single muts only), Linear (with pairs) - both trained via LASSO

## Dataset Summary
- Total sequences: {N}
- Single mutations (M): {M}
- Pairwise mutations (P): {P}
- Filtered pairs (min_support=3): {len(filtered_pairs)}

## Files

### Visualizations (5 PNG files)
- **01_training_loss.png**: Training dynamics for Order 1 & 2 models (loss curves)
- **02_order1_vs_order2.png**: MoCHI model comparison with density and weight plots
- **03_order2_parameters.png**: Parameter distributions for MoCHI Order 2
- **04_four_models_comparison.png**: All 4 models side-by-side (Order 1, Order 2, Linear pairs, Linear single)
- **05_linear_model_parameters.png**: Parameter distributions for Linear models

### Parameters (3 CSV files)
- **linear1_theta.csv**: Linear_order1 single-mutation parameters ({len(linear1_theta)} mutations)
- **linear_theta.csv**: Linear (with pairs) single-mutation parameters
- **linear_phi_pair.csv**: Linear (with pairs) pairwise interaction parameters

### Results
- **predictions_linear_models.csv**: Full predictions for all {N} sequences
- **performance_metrics.json**: Quantitative performance comparison

### Models
- **model_linear1_checkpoint.pth**: Linear_order1 weights (single-mutation only)
- **model_linear_checkpoint.pth**: Linear model weights (with pairwise interactions)

## Performance Comparison

| Model | Type | Pearson r | MAE | Sparsity |
|-------|------|-----------|-----|----------|
| Linear (single) | LASSO, M only | {r_linear1_calc:.4f} | {mae_linear1:.6f} | N/A |
| Linear (pairs) | LASSO, M+P | {r_linear_calc:.4f} | {mae_linear:.6f} | {sparsity_pct:.2f}% |

## Key Results

### Linear Models
- **Linear_order1** (r={r_linear1_calc:.4f}): Uses only single mutations without epistasis
- **Linear with pairs** (r={r_linear_calc:.4f}): Adds pairwise interactions
  - Correlation improvement: {(r_linear_calc/r_linear1_calc-1)*100:+.2f}%
  - MAE reduction: {(1-mae_linear/mae_linear1)*100:.1f}% (from {mae_linear1:.4f} ‚Üí {mae_linear:.4f})
  - Sparsity: {sparsity_pct:.2f}% sparse in phi_pair (via LASSO L1)

## Model Architectures

### Linear_order1 (Single-mutation only)
```
yhat = phi0 + X @ theta
```

### Linear (with Pairwise)
```
yhat = phi0 + X @ theta + X_pair @ phi_pair
```

## Notes
- All linear models use weighted least squares fitting
- LASSO regularization (L1 penalty) automatically induces sparsity
- Grid search over alpha values to find optimal regularization strength
"""

    # The text contains non-ASCII characters, so fix the encoding explicitly instead
    # of relying on the platform default (which is not UTF-8 everywhere).
    with open(f"{results_dir}/README.md", 'w', encoding='utf-8') as f:
        f.write(readme)
    print(f"  ‚úì README.md")
except Exception as e:
    print(f"  ‚ö† Could not save README: {e}")

# ========= 7. Final statistics =========
# Report where the results live and list every file in the folder with its size.
print("\n" + "=" * 70)
print(f"‚úÖ All results saved successfully!")
print("=" * 70)
print(f"\nüìÅ Location: {os.path.abspath(results_dir)}")
print(f"\nüìä Files generated ({len(os.listdir(results_dir))} total):")

# File name -> size in bytes, in alphabetical order.
file_stats = {
    fname: os.path.getsize(os.path.join(results_dir, fname))
    for fname in sorted(os.listdir(results_dir))
}

# Pick the unit per file: bytes below 1 KiB, KB below 1 MiB, MB otherwise.
for fname, size in file_stats.items():
    if size >= 1024**2:
        print(f"  ‚Ä¢ {fname:<40} {size / (1024**2):>10.2f} MB")
    elif size >= 1024:
        print(f"  ‚Ä¢ {fname:<40} {size/1024:>10.2f} KB")
    else:
        print(f"  ‚Ä¢ {fname:<40} {size:>10} B")

total_size = sum(file_stats.values()) / (1024**2)
print(f"\nüì¶ Total size: {total_size:.2f} MB")
print(f"{'='*70}\n")

# ========= 8. Model performance summary =========
# Console recap of the two linear models and the change from adding pairwise terms.

# Count coefficients whose magnitude exceeds 1e-10 (treated as non-zero).
_nnz_theta_single = int((np.abs(model_linear1.theta.detach().cpu().numpy()) > 1e-10).sum())
_nnz_theta_pairs = int((np.abs(model_linear.theta.detach().cpu().numpy()) > 1e-10).sum())

# Change in Pearson r when pairwise terms are added; may be negative.
_r_delta = r_linear_calc - r_linear1_calc

print("\n" + "="*70)
print("Model Performance Summary")
print("="*70)
print(f"Linear_order1 (single mutations only):")
print(f"  Pearson r: {r_linear1_calc:.6f}")
print(f"  MAE (weighted): {mae_linear1:.6f}")
print(f"  Non-zero parameters: {_nnz_theta_single}/{M}")
print(f"\nLinear (with pairwise interactions):")
print(f"  Pearson r: {r_linear_calc:.6f}")
print(f"  MAE (weighted): {mae_linear:.6f}")
print(f"  Sparsity (phi_pair): {sparsity_pct:.2f}%")
print(f"  Non-zero single mutations: {_nnz_theta_pairs}/{M}")
print(f"  Non-zero pairwise: {int(nnz_phi_pair)}/{len(filtered_pairs)}")
print(f"\nImprovement (single ‚Üí pairs):")
# The `+` format flag takes the sign from the value itself, so a decrease prints as
# "-0.012345" instead of the malformed "+-0.012345" a literal '+' prefix would give.
print(f"  Pearson r improvement: {_r_delta:+.6f} ({(r_linear_calc/r_linear1_calc - 1)*100:.2f}%)")
print(f"  MAE reduction: {(mae_linear1 - mae_linear):.6f} ({(1 - mae_linear/mae_linear1)*100:.2f}% better)")
print("="*70)
