# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
from hyperopt import hp, fmin, tpe, Trials, STATUS_OK
from joblib import Parallel, delayed
import os
from scipy.stats import pearsonr

# ----------------------------
# Proximal operator for L0.5
# ----------------------------
def prox_half(v, alpha):
    """
    Proximal operator of the L0.5 penalty (half-thresholding).

    Solves, element-wise,

        argmin_w  0.5 * (w - v) ** 2 + alpha * |w| ** 0.5

    with the closed-form half-thresholding rule: entries with
    ``|v| <= 1.5 * alpha ** (2/3)`` are mapped to zero, all others to

        (2/3) * v * (1 + cos(2*pi/3 - (2/3) * phi)),
        phi = arccos((alpha / 4) * (|v| / 3) ** (-1.5)).

    Args:
        v (torch.Tensor): The tensor to apply the proximal operator to.
        alpha (float or torch.Tensor): Non-negative regularization
            strength; a scalar or a tensor broadcastable against ``v``.

    Returns:
        torch.Tensor: A new tensor holding the result of the proximal
        operation (``v`` itself is not modified).
    """
    if not isinstance(alpha, torch.Tensor):
        alpha = torch.tensor(alpha, device=v.device, dtype=v.dtype)

    abs_v = torch.abs(v)

    # The 3/2 factor sits OUTSIDE the power: thr = (3/2) * alpha^(2/3).
    thr = 1.5 * alpha ** (2.0 / 3.0)
    keep = abs_v > thr

    # Entries that will be zeroed anyway are replaced by 1 so the negative
    # power below never sees 0 (which would produce inf/nan intermediates).
    safe_abs = torch.where(keep, abs_v, torch.ones_like(abs_v))

    # arccos argument uses |v| / 3; the clamp only guards against rounding
    # pushing the value marginally outside [-1, 1].
    cos_arg = torch.clamp((alpha / 4) * (safe_abs / 3) ** (-1.5), -1.0, 1.0)
    phi = torch.acos(cos_arg)

    shrunk = (2.0 / 3.0) * v * (
        1 + torch.cos(2.0 / 3.0 * np.pi - (2.0 / 3.0) * phi)
    )

    # Element-wise select instead of boolean indexing: no data-dependent
    # Python branch, and alpha may be a tensor.
    return torch.where(keep, shrunk, torch.zeros_like(v))

class ProxGEN(optim.Optimizer):
    """
    Adaptive proximal gradient optimizer (Adam-style moments) with an
    L0.5 proximal step.

    Each step keeps exponential moving averages of the gradient and of the
    squared gradient (no bias correction is applied), takes the
    preconditioned gradient step

        theta_hat = p - lr * exp_avg / (sqrt(exp_avg_sq) + eps)

    and then sets ``p = prox_half(theta_hat, lam * lr)``.

    NOTE(review): the proximal threshold is the scalar ``lam * lr`` for
    every coordinate; it is not rescaled by the per-coordinate denominator.
    Confirm this is the intended variant.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr (float): Step size.
        lam (float): Strength of the L0.5 penalty.
        betas (tuple[float, float]): Decay rates of the first and second
            moment estimates.
        eps (float): Constant added to the denominator for stability.

    Raises:
        ValueError: If a hyperparameter is outside its valid range.
    """
    def __init__(self, params, lr=1e-3, lam=1e-6, betas=(0.1, 0.999), eps=1e-8):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if lam < 0.0:
            raise ValueError(f"Invalid regularization strength: {lam}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        defaults = dict(lr=lr, lam=lam, betas=betas, eps=eps)
        super(ProxGEN, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        Args:
            closure (callable, optional): Re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure typically calls
            # backward(), so grad mode must be switched back on for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            step_size = group["lr"]
            eps = group["eps"]
            lam = group["lam"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # Lazy state initialisation on the first update of p.
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                state["step"] += 1
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(eps)

                # Preconditioned gradient step (out of place).
                theta_hat = p.addcdiv(exp_avg, denom, value=-step_size)

                # Proximal step for L_0.5, written back in place so the
                # parameter keeps its storage.
                p.copy_(prox_half(theta_hat, lam * step_size))

        return loss

# ----------------------------
# ResNet-10 Definition (1D)
# ----------------------------
class BasicBlock(nn.Module):
    """
    Two-convolution 1D residual block.

    Main path: conv(k=3) -> BN -> ReLU -> conv(k=3) -> BN. The shortcut is
    the identity unless the stride or channel count changes, in which case
    it is a 1x1 convolution followed by BN. The two paths are summed and
    passed through a final ReLU.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm1d(out_channels)

        # A projection is only needed when the main path changes the
        # temporal resolution or the number of channels.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv1d(
                    in_channels, out_channels,
                    kernel_size=1, stride=stride, bias=False,
                ),
                nn.BatchNorm1d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        residual = self.shortcut(x)
        return F.relu(hidden + residual)

class ResNet(nn.Module):
    """
    1D ResNet: a stem convolution, three residual stages (64, 128 and 256
    output channels; stages two and three start with stride 2), global
    average pooling and a linear head with ``num_tasks`` outputs.

    The stem convolution takes a single input channel.
    """

    def __init__(self, block, num_blocks, num_tasks=7):
        super().__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_channels = 64

        self.conv1 = nn.Conv1d(
            1, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm1d(64)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)

        self.linear = nn.Linear(256 * block.expansion, num_tasks)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        stride_plan = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in stride_plan:
            stage.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map an input batch to ``num_tasks`` outputs per sample."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)
        # Collapse the length axis to 1, then flatten to one row per sample.
        pooled = F.adaptive_avg_pool1d(feats, 1)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

def ResNet10(num_tasks=7):
    """Build the ResNet-10 variant: one BasicBlock in each of the three stages."""
    blocks_per_stage = [1, 1, 1]
    return ResNet(BasicBlock, blocks_per_stage, num_tasks)

# ----------------------------
# Metrics Calculation
# ----------------------------
def calculate_sparsity(model):
    """
    Return the percentage of exactly-zero entries over all parameters of
    ``model`` (0.0 when the model has no parameters).
    """
    tensors = [param.data for param in model.parameters()]
    total_params = sum(t.numel() for t in tensors)

    # Guard against division by zero for parameter-free models.
    if total_params == 0:
        return 0.0

    non_zero_params = sum(torch.count_nonzero(t).item() for t in tensors)
    return (1 - (non_zero_params / total_params)) * 100

def calculate_correlations(labels, predictions):
    """
    Return a list with one Pearson correlation coefficient per column,
    computed between matching columns of ``labels`` and ``predictions``.
    """
    # pearsonr returns (coefficient, p-value); only the coefficient is kept.
    return [
        pearsonr(labels[:, col], predictions[:, col])[0]
        for col in range(labels.shape[1])
    ]

# ----------------------------
# Load Data
# ----------------------------
data = pd.read_csv("/content/drive/My Drive/Adaptive gradient method/Adaptive gradient method from L_0 to L infnity/Final_pine_data.csv")
Y = data.iloc[:, :7].values  # first 7 columns: regression targets
X = data.iloc[:, 7:].values  # remaining columns: SNPs/features

# Split into train+val and test BEFORE fitting the scaler, so that the
# held-out test rows do not contribute to the standardization statistics.
X_trainval, X_test, Y_trainval, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=42
)

# Fit the scaler on the train+val portion only and reuse its mean/std for
# the test portion.
scaler = StandardScaler()
X_trainval = scaler.fit_transform(X_trainval)
X_test = scaler.transform(X_test)

# Keep `X` standardized (now with train+val statistics) for any later code
# that reads the full feature matrix.
X = scaler.transform(X)

# ----------------------------
# Training + CV Evaluation
# ----------------------------
# Pick CUDA when it is available, otherwise fall back to the CPU, and report
# whether more than one GPU is visible.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.device_count() <= 1:
    print("Using a single GPU or CPU.")
else:
    print(f"Using {torch.cuda.device_count()} GPUs for parallel processing.")

def train_and_eval(X_train, Y_train, X_val, Y_val, params, epochs=30):
    """
    Train a fresh ResNet10 with ProxGEN on one split and return the
    validation MSE.

    Args:
        X_train, Y_train: Training features and targets (array-like, 2-D).
            A channel axis is inserted into the features at position 1.
        X_val, Y_val: Validation features and targets, same layout.
        params (dict): Must provide ``"lr"``, ``"lam"`` and
            ``"batch_size"``.
        epochs (int): Number of passes over the training data.

    Returns:
        The mean squared error of the model on the validation split, as
        computed by ``mean_squared_error``.

    Uses the module-level ``device``; the model is wrapped in
    ``nn.DataParallel`` when more than one GPU is visible.
    """
    x_train_t = torch.tensor(X_train, dtype=torch.float32).unsqueeze(1)
    y_train_t = torch.tensor(Y_train, dtype=torch.float32)
    x_val_t = torch.tensor(X_val, dtype=torch.float32).unsqueeze(1)
    y_val_t = torch.tensor(Y_val, dtype=torch.float32)

    train_ds = torch.utils.data.TensorDataset(x_train_t, y_train_t)
    val_ds = torch.utils.data.TensorDataset(x_val_t, y_val_t)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=int(params["batch_size"]), shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(val_ds, batch_size=128, shuffle=False)

    # Size the output head from the targets instead of a fixed constant.
    model = ResNet10(num_tasks=y_train_t.shape[1])
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    optimizer = ProxGEN(model.parameters(), lr=params["lr"], lam=params["lam"])
    criterion = nn.MSELoss()

    for _ in range(epochs):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()

    # Collect predictions on the validation split without tracking gradients.
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for xb, yb in val_loader:
            preds.append(model(xb.to(device)).cpu().numpy())
            labels.append(yb.numpy())
    preds, labels = np.vstack(preds), np.vstack(labels)
    return mean_squared_error(labels, preds)

def objective(params):
    """
    Hyperopt objective: mean validation MSE of ``train_and_eval`` over a
    fixed 5-fold split of the train+val data, with folds run via joblib.
    """
    splitter = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_jobs = [
        delayed(train_and_eval)(
            X_trainval[fit_idx], Y_trainval[fit_idx],
            X_trainval[held_idx], Y_trainval[held_idx],
            params,
        )
        for fit_idx, held_idx in splitter.split(X_trainval)
    ]
    fold_mses = Parallel(n_jobs=-1)(fold_jobs)
    return {"loss": np.mean(fold_mses), "status": STATUS_OK}

# ----------------------------
# Bayesian Optimization
# ----------------------------
# Log-uniform ranges for the step size and the penalty strength, plus a
# categorical choice of batch size.
search_space = dict(
    lr=hp.loguniform("lr", np.log(1e-4), np.log(1e-2)),
    lam=hp.loguniform("lam", np.log(1e-3), np.log(1e+2)),
    batch_size=hp.choice("batch_size", [32, 64, 128]),
)

trials = Trials()
best = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=trials,
    rstate=np.random.default_rng(42),
)

print("Best hyperparameters:", best)
# For hp.choice, fmin reports the index of the selected option, so it is
# mapped back to the actual batch size here.
print("Best batch size:", [32, 64, 128][best["batch_size"]])

# ----------------------------
# Final Training on train+val
# ----------------------------
# Fresh ResNet10, wrapped for multi-GPU use when more than one GPU is visible.
final_model = ResNet10(num_tasks=7)
if torch.cuda.device_count() > 1:
    final_model = nn.DataParallel(final_model)
final_model.to(device)

# Optimizer configured with the best hyperparameters found by the search.
final_optimizer = ProxGEN(
    final_model.parameters(), lr=best["lr"], lam=best["lam"]
)
criterion = nn.MSELoss()

# Full train+val set; a channel axis is inserted into the features.
trainval_ds = torch.utils.data.TensorDataset(
    torch.tensor(X_trainval, dtype=torch.float32).unsqueeze(1),
    torch.tensor(Y_trainval, dtype=torch.float32),
)
trainval_loader = torch.utils.data.DataLoader(
    trainval_ds,
    # best["batch_size"] is an index into the choice list.
    batch_size=[32,64,128][best["batch_size"]],
    shuffle=True,
)

for epoch in range(50):
    final_model.train()
    for xb, yb in trainval_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        final_optimizer.zero_grad()
        pred = final_model(xb)
        loss = criterion(pred, yb)
        loss.backward()
        final_optimizer.step()

# ----------------------------
# Test Evaluation and Metrics Calculation
# ----------------------------
# Held-out test set; a channel axis is inserted into the features.
test_ds = torch.utils.data.TensorDataset(
    torch.tensor(X_test, dtype=torch.float32).unsqueeze(1),
    torch.tensor(Y_test, dtype=torch.float32),
)
test_loader = torch.utils.data.DataLoader(
    test_ds, batch_size=128, shuffle=False
)

# Collect predictions and targets batch by batch without tracking gradients.
final_model.eval()
preds, labels = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        pred = final_model(xb)
        preds.append(pred.cpu().numpy())
        labels.append(yb.cpu().numpy())
preds = np.vstack(preds)
labels = np.vstack(labels)

# Overall MSE and per-trait R2
total_mse = mean_squared_error(labels, preds)
r2_per_trait = r2_score(labels, preds, multioutput="raw_values")
print("\n--- Final Test Evaluation ---")
print(f"Total Test MSE: {total_mse:.4f}")
print("R2 per trait:", r2_per_trait)

# Per-trait MSE
mse_per_trait = mean_squared_error(labels, preds, multioutput='raw_values')
print("Mean Test MSE for each trait:", mse_per_trait)

# Per-trait Pearson correlation
correlations = calculate_correlations(labels, preds)
print("Pearson correlation coefficient per trait:", correlations)

# Sparsity is measured on the underlying network, unwrapping DataParallel
# when it was used.
sparsity = calculate_sparsity(
    final_model.module if isinstance(final_model, nn.DataParallel) else final_model
)
print(f"Model sparsity: {sparsity:.2f}%")