# cov_vaccine_degradation

In [5]:
import os
import argparse
import torch

from data.dataloaders import get_multimodal_loaders
from utils.load_config import load_config
from utils.calculate_embeddings import calculate_embeddings
from models.multimodel import build_model
from trainer import RegressionTrainer


def _has_any_embeddings(emb_dir) -> bool:
    """Tell whether ``emb_dir`` is a directory containing at least one ``.pt`` entry.

    Args:
        emb_dir: Path of the directory to inspect.

    Returns:
        ``True`` if ``emb_dir`` is an existing directory and the name of at
        least one of its entries ends with ``".pt"``; ``False`` otherwise,
        including when the path does not exist or is not a directory.
    """
    if os.path.isdir(emb_dir):
        try:
            entry_names = os.listdir(emb_dir)
        except FileNotFoundError:
            # The directory vanished between the isdir() check and the listing.
            return False
        return any(entry.endswith(".pt") for entry in entry_names)
    return False


def main(name: str, dataset: str, max_len: int, batch_size: int, epochs: int):
    """Run one multimodal (DNA + RNA + Protein) regression training job.

    Loads ``{name}.yml`` through ``load_config`` and overrides its ``Dataset``
    and ``device`` entries. If the filtered CSV is missing, or any of the three
    per-modality embedding directories has no ``.pt`` entry,
    ``calculate_embeddings`` is called for all three modalities. Afterwards the
    data loaders, the model and a ``RegressionTrainer`` are created and
    ``trainer.train`` is run.

    Args:
        name: Config file stem; ``f"{name}.yml"`` is handed to ``load_config``.
            Also used as the fallback run name when the config has no ``name``.
        dataset: Dataset identifier stored under ``config["Dataset"]``.
        max_len: Length threshold; coerced with ``int`` and used both in the
            cache paths and in the embedding / loader calls.
        batch_size: Forwarded to ``get_multimodal_loaders``.
        epochs: Forwarded to ``trainer.train``.
    """
    modalities = ("DNA", "RNA", "Protein")
    rule = "=" * 60

    config = load_config(f"{name}.yml")
    config["Dataset"] = dataset
    if torch.cuda.is_available():
        config["device"] = "cuda"
    else:
        config["device"] = "cpu"

    max_len = int(max_len)

    # Cached artefacts are keyed by dataset name and max_len.
    csv_path = f"data/datasets/{config['Dataset']}_multimodal_filtered_maxlen{max_len}.csv"
    csv_missing = not os.path.exists(csv_path)
    # A list (not a generator) so every modality directory is inspected.
    embeddings_missing = [
        not _has_any_embeddings(f"embeddings/{config['Dataset']}/{modality}/maxlen{max_len}")
        for modality in modalities
    ]

    if csv_missing or any(embeddings_missing):
        print("Embeddings and/or filtered CSV not found, calculating...")
        # All modalities are recomputed, even if only one artefact is missing.
        for modality in modalities:
            calculate_embeddings(
                dataset=config["Dataset"],
                modality=modality,
                device=config["device"],
                max_len=max_len,
            )

    banner = [
        rule,
        "Training Configuration (Multimodal):",
        f"  Config name: {config.get('name', name)}",
        f"  Dataset: {config['Dataset']}",
        f"  Max Len (filter): {max_len}",
        f"  Fusion: {config.get('fusion_type', 'concat')}",
        f"  Batch size: {batch_size}",
        f"  Epochs: {epochs}",
        f"  Device: {config['device']}",
        rule,
    ]
    for banner_line in banner:
        print(banner_line)

    print("\nLoading data...")
    loaders = get_multimodal_loaders(
        config["Dataset"],
        batch_size=batch_size,
        max_len=max_len,
    )
    train_loader, val_loader, test_loader = loaders

    print("\nInitializing model...")
    model = build_model(config)

    # Count parameters in one pass over model.parameters().
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        n_elements = param.numel()
        total_params += n_elements
        if param.requires_grad:
            trainable_params += n_elements
    print(f"Total number of parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {total_params - trainable_params}")

    print("\nInitializing trainer...")
    plots_dir = f"./plots/{config.get('name', name)}/{config['Dataset']}"
    trainer = RegressionTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=config["device"],
        save_dir=plots_dir,
    )

    # Entropy regularization weight for MIL: the nested ``trainer`` section
    # takes precedence, the top-level key is the backwards-compatible fallback.
    trainer_section = config.get("trainer", None)
    if isinstance(trainer_section, dict):
        lam_entropy = trainer_section.get("lam_entropy", None)
    else:
        lam_entropy = None
    if lam_entropy is None:
        lam_entropy = config.get("lam_entropy", None)

    if lam_entropy is not None:
        trainer.lam_entropy = float(lam_entropy)
        if trainer.lam_entropy > 0:
            print(f"Using MIL entropy regularization: lam_entropy={trainer.lam_entropy}")

    print(f"\n{rule}")
    print("Starting training...")
    print(f"{rule}\n")

    trainer.train(epochs=epochs)

    print(f"\n{rule}")
    print("Training completed!")
    print(rule)


if __name__ == "__main__":
    # Command-line entry point: parse the run options and hand them to main().
    parser = argparse.ArgumentParser(description="Train multimodal model (DNA+RNA+Protein)")
    parser.add_argument(
        "--name",
        type=str,
        default="fusion_concat",
        help="Config file name (without .yml). Options: fusion_concat, fusion_mil, fusion_xattn",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cov_vaccine_degradation",
        help=(
            # %(default)s is filled in by argparse, so the advertised default
            # can no longer drift from the real one (it used to claim
            # fungal_expression while the default was cov_vaccine_degradation).
            "Dataset to use. Default: %(default)s. Options: "
            "'mrna_stability', 'ecoli_proteins', 'cov_vaccine_degradation', 'fungal_expression'"
        ),
    )
    parser.add_argument(
        "--max-len",
        type=int,
        default=1000,
        help="Filter threshold: keep only sequences with raw length <= max_len before embedding/training.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=500,
        help="Number of training epochs.",
    )
    # parse_known_args() ignores unrecognised arguments instead of exiting.
    # NOTE(review): presumably this tolerates the extra argv entries injected
    # when the cell runs inside a notebook kernel — confirm before tightening.
    args, _ = parser.parse_known_args()

    main(
        name=args.name,
        dataset=args.dataset,
        max_len=args.max_len,
        batch_size=args.batch_size,
        epochs=args.epochs,
    )

Training Configuration (Multimodal):
  Config name: fusion_concat
  Dataset: cov_vaccine_degradation
  Max Len (filter): 1000
  Fusion: concat
  Batch size: 32
  Epochs: 500
  Device: cpu

Loading data...

Initializing model...
Total number of parameters: 2497753
Trainable parameters: 2497753
Non-trainable parameters: 0

Initializing trainer...

Starting training...

Starting training for 500 epochs...
Device: cpu
Save directory: plots/fusion_concat/cov_vaccine_degradation


Training: 100%|█████████████████████████████████| 50/50 [00:10<00:00,  4.97it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.62it/s]


Epoch 1/500
  Train Loss: 1.1657, MSE: 1.1657, Spearman: 0.2603, LR: 0.000030
  Val Loss: 0.4941, MSE: 0.4941, Spearman: 0.6864


Training: 100%|█████████████████████████████████| 50/50 [00:09<00:00,  5.47it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.31it/s]


Epoch 2/500
  Train Loss: 1.0804, MSE: 1.0804, Spearman: 0.3997, LR: 0.000030
  Val Loss: 0.4164, MSE: 0.4164, Spearman: 0.7122


Training: 100%|█████████████████████████████████| 50/50 [00:08<00:00,  6.24it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 10.71it/s]


Epoch 3/500
  Train Loss: 0.9973, MSE: 0.9973, Spearman: 0.5152, LR: 0.000030
  Val Loss: 0.3611, MSE: 0.3611, Spearman: 0.7258


Training: 100%|█████████████████████████████████| 50/50 [00:09<00:00,  5.36it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.57it/s]


Epoch 4/500
  Train Loss: 0.9349, MSE: 0.9349, Spearman: 0.5528, LR: 0.000030
  Val Loss: 0.3172, MSE: 0.3172, Spearman: 0.7520


Training: 100%|█████████████████████████████████| 50/50 [00:11<00:00,  4.31it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 10.44it/s]


Epoch 5/500
  Train Loss: 0.8856, MSE: 0.8856, Spearman: 0.5803, LR: 0.000030
  Val Loss: 0.2973, MSE: 0.2973, Spearman: 0.7541


Training: 100%|█████████████████████████████████| 50/50 [00:11<00:00,  4.30it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.10it/s]


Epoch 6/500
  Train Loss: 0.8257, MSE: 0.8257, Spearman: 0.6307, LR: 0.000030
  Val Loss: 0.2988, MSE: 0.2988, Spearman: 0.7600


Training: 100%|█████████████████████████████████| 50/50 [00:12<00:00,  4.00it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00,  9.72it/s]


Epoch 7/500
  Train Loss: 0.8009, MSE: 0.8009, Spearman: 0.6322, LR: 0.000030
  Val Loss: 0.2406, MSE: 0.2406, Spearman: 0.7756


Training: 100%|█████████████████████████████████| 50/50 [00:13<00:00,  3.76it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00,  9.97it/s]


Epoch 8/500
  Train Loss: 0.7565, MSE: 0.7565, Spearman: 0.6698, LR: 0.000030
  Val Loss: 0.2558, MSE: 0.2558, Spearman: 0.7767


Training: 100%|█████████████████████████████████| 50/50 [00:16<00:00,  2.96it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00,  9.70it/s]


Epoch 9/500
  Train Loss: 0.7439, MSE: 0.7439, Spearman: 0.6725, LR: 0.000030
  Val Loss: 0.2391, MSE: 0.2391, Spearman: 0.7781


Training: 100%|█████████████████████████████████| 50/50 [00:10<00:00,  4.71it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00,  9.98it/s]


Epoch 10/500
  Train Loss: 0.7225, MSE: 0.7225, Spearman: 0.6911, LR: 0.000030
  Val Loss: 0.2673, MSE: 0.2673, Spearman: 0.7855


Training: 100%|█████████████████████████████████| 50/50 [00:12<00:00,  3.97it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.17it/s]


Epoch 11/500
  Train Loss: 0.6991, MSE: 0.6991, Spearman: 0.6948, LR: 0.000030
  Val Loss: 0.2204, MSE: 0.2204, Spearman: 0.7928


Training: 100%|█████████████████████████████████| 50/50 [00:12<00:00,  4.14it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.16it/s]


Epoch 12/500
  Train Loss: 0.6925, MSE: 0.6925, Spearman: 0.6949, LR: 0.000030
  Val Loss: 0.2567, MSE: 0.2567, Spearman: 0.7950


Training: 100%|█████████████████████████████████| 50/50 [00:11<00:00,  4.17it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 10.74it/s]


Epoch 13/500
  Train Loss: 0.6662, MSE: 0.6662, Spearman: 0.7138, LR: 0.000030
  Val Loss: 0.2370, MSE: 0.2370, Spearman: 0.7904


Training: 100%|█████████████████████████████████| 50/50 [00:11<00:00,  4.24it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.46it/s]


Epoch 14/500
  Train Loss: 0.6492, MSE: 0.6492, Spearman: 0.7245, LR: 0.000030
  Val Loss: 0.2133, MSE: 0.2133, Spearman: 0.7912


Training: 100%|█████████████████████████████████| 50/50 [00:11<00:00,  4.20it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.25it/s]


Epoch 15/500
  Train Loss: 0.6325, MSE: 0.6325, Spearman: 0.7363, LR: 0.000030
  Val Loss: 0.2109, MSE: 0.2109, Spearman: 0.7949


Training: 100%|█████████████████████████████████| 50/50 [00:13<00:00,  3.73it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 10.76it/s]


Epoch 16/500
  Train Loss: 0.6146, MSE: 0.6146, Spearman: 0.7503, LR: 0.000030
  Val Loss: 0.2237, MSE: 0.2237, Spearman: 0.7959


Training: 100%|█████████████████████████████████| 50/50 [00:11<00:00,  4.32it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 10.40it/s]


Epoch 17/500
  Train Loss: 0.5959, MSE: 0.5959, Spearman: 0.7527, LR: 0.000030
  Val Loss: 0.2121, MSE: 0.2121, Spearman: 0.8019


Training: 100%|█████████████████████████████████| 50/50 [00:11<00:00,  4.29it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 10.39it/s]


Epoch 18/500
  Train Loss: 0.5861, MSE: 0.5861, Spearman: 0.7604, LR: 0.000030
  Val Loss: 0.2053, MSE: 0.2053, Spearman: 0.8033


Training: 100%|█████████████████████████████████| 50/50 [00:11<00:00,  4.40it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.45it/s]


Epoch 19/500
  Train Loss: 0.5920, MSE: 0.5920, Spearman: 0.7557, LR: 0.000030
  Val Loss: 0.2172, MSE: 0.2172, Spearman: 0.8072


Training: 100%|█████████████████████████████████| 50/50 [00:10<00:00,  4.61it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 10.70it/s]


Epoch 20/500
  Train Loss: 0.5595, MSE: 0.5595, Spearman: 0.7759, LR: 0.000030
  Val Loss: 0.2243, MSE: 0.2243, Spearman: 0.7967


Training: 100%|█████████████████████████████████| 50/50 [00:09<00:00,  5.23it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 13.81it/s]


Epoch 21/500
  Train Loss: 0.5346, MSE: 0.5346, Spearman: 0.7820, LR: 0.000030
  Val Loss: 0.2203, MSE: 0.2203, Spearman: 0.8093


Training: 100%|█████████████████████████████████| 50/50 [00:08<00:00,  5.92it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 13.37it/s]


Epoch 22/500
  Train Loss: 0.5143, MSE: 0.5143, Spearman: 0.7931, LR: 0.000030
  Val Loss: 0.1960, MSE: 0.1960, Spearman: 0.8116


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.32it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 13.76it/s]


Epoch 23/500
  Train Loss: 0.5051, MSE: 0.5051, Spearman: 0.7941, LR: 0.000030
  Val Loss: 0.1970, MSE: 0.1970, Spearman: 0.8148


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.33it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 16.80it/s]


Epoch 24/500
  Train Loss: 0.4984, MSE: 0.4984, Spearman: 0.7966, LR: 0.000030
  Val Loss: 0.2067, MSE: 0.2067, Spearman: 0.8028


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.50it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 17.30it/s]


Epoch 25/500
  Train Loss: 0.4749, MSE: 0.4749, Spearman: 0.8073, LR: 0.000030
  Val Loss: 0.2198, MSE: 0.2198, Spearman: 0.8079


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.43it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.72it/s]


Epoch 26/500
  Train Loss: 0.4672, MSE: 0.4672, Spearman: 0.8090, LR: 0.000030
  Val Loss: 0.2061, MSE: 0.2061, Spearman: 0.8098


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.50it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.95it/s]


Epoch 27/500
  Train Loss: 0.4480, MSE: 0.4480, Spearman: 0.8156, LR: 0.000030
  Val Loss: 0.2066, MSE: 0.2066, Spearman: 0.8115


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.41it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 16.93it/s]


Epoch 28/500
  Train Loss: 0.4499, MSE: 0.4499, Spearman: 0.8166, LR: 0.000030
  Val Loss: 0.1942, MSE: 0.1942, Spearman: 0.8181


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.54it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 16.86it/s]


Epoch 29/500
  Train Loss: 0.4487, MSE: 0.4487, Spearman: 0.8143, LR: 0.000030
  Val Loss: 0.2041, MSE: 0.2041, Spearman: 0.8066


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.64it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.58it/s]


Epoch 30/500
  Train Loss: 0.4190, MSE: 0.4190, Spearman: 0.8296, LR: 0.000030
  Val Loss: 0.2275, MSE: 0.2275, Spearman: 0.8155


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.76it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 17.84it/s]


Epoch 31/500
  Train Loss: 0.4252, MSE: 0.4252, Spearman: 0.8200, LR: 0.000030
  Val Loss: 0.2006, MSE: 0.2006, Spearman: 0.8068


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  7.12it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 17.98it/s]


Epoch 32/500
  Train Loss: 0.4033, MSE: 0.4033, Spearman: 0.8377, LR: 0.000030
  Val Loss: 0.2058, MSE: 0.2058, Spearman: 0.8114


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.97it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 18.74it/s]


Epoch 33/500
  Train Loss: 0.3726, MSE: 0.3726, Spearman: 0.8465, LR: 0.000030
  Val Loss: 0.2108, MSE: 0.2108, Spearman: 0.8115


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  7.04it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 18.07it/s]


Epoch 34/500
  Train Loss: 0.3644, MSE: 0.3644, Spearman: 0.8447, LR: 0.000030
  Val Loss: 0.2040, MSE: 0.2040, Spearman: 0.8152


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.97it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 18.44it/s]


Epoch 35/500
  Train Loss: 0.3658, MSE: 0.3658, Spearman: 0.8521, LR: 0.000030
  Val Loss: 0.1955, MSE: 0.1955, Spearman: 0.8200


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.94it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 17.55it/s]


Epoch 36/500
  Train Loss: 0.3567, MSE: 0.3567, Spearman: 0.8535, LR: 0.000030
  Val Loss: 0.2593, MSE: 0.2593, Spearman: 0.8116


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.83it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 17.09it/s]


Epoch 37/500
  Train Loss: 0.3472, MSE: 0.3472, Spearman: 0.8570, LR: 0.000030
  Val Loss: 0.2042, MSE: 0.2042, Spearman: 0.8112


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  7.00it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.73it/s]


Epoch 38/500
  Train Loss: 0.3389, MSE: 0.3389, Spearman: 0.8594, LR: 0.000030
  Val Loss: 0.2076, MSE: 0.2076, Spearman: 0.8004


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.80it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 18.48it/s]


Epoch 39/500
  Train Loss: 0.3316, MSE: 0.3316, Spearman: 0.8654, LR: 0.000030
  Val Loss: 0.1956, MSE: 0.1956, Spearman: 0.8185


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.51it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 16.62it/s]


Epoch 40/500
  Train Loss: 0.3121, MSE: 0.3121, Spearman: 0.8770, LR: 0.000030
  Val Loss: 0.2167, MSE: 0.2167, Spearman: 0.8120


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.54it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 17.25it/s]


Epoch 41/500
  Train Loss: 0.3101, MSE: 0.3101, Spearman: 0.8730, LR: 0.000030
  Val Loss: 0.2146, MSE: 0.2146, Spearman: 0.8063


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  7.06it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 16.87it/s]


Epoch 42/500
  Train Loss: 0.3182, MSE: 0.3182, Spearman: 0.8714, LR: 0.000030
  Val Loss: 0.2517, MSE: 0.2517, Spearman: 0.8027


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.56it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 17.23it/s]


Epoch 43/500
  Train Loss: 0.2869, MSE: 0.2869, Spearman: 0.8776, LR: 0.000030
  Val Loss: 0.2373, MSE: 0.2373, Spearman: 0.8088


Training: 100%|█████████████████████████████████| 50/50 [00:09<00:00,  5.32it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 10.53it/s]


Epoch 44/500
  Train Loss: 0.2846, MSE: 0.2846, Spearman: 0.8836, LR: 0.000030
  Val Loss: 0.2138, MSE: 0.2138, Spearman: 0.8082


Training: 100%|█████████████████████████████████| 50/50 [00:06<00:00,  7.67it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 14.42it/s]


Epoch 45/500
  Train Loss: 0.2767, MSE: 0.2767, Spearman: 0.8794, LR: 0.000030
  Val Loss: 0.2229, MSE: 0.2229, Spearman: 0.8106


Training: 100%|█████████████████████████████████| 50/50 [00:09<00:00,  5.38it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 17.98it/s]


Epoch 46/500
  Train Loss: 0.2764, MSE: 0.2764, Spearman: 0.8848, LR: 0.000030
  Val Loss: 0.2053, MSE: 0.2053, Spearman: 0.8052


Training: 100%|█████████████████████████████████| 50/50 [00:05<00:00,  8.40it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 21.48it/s]


Epoch 47/500
  Train Loss: 0.2677, MSE: 0.2677, Spearman: 0.8838, LR: 0.000030
  Val Loss: 0.2325, MSE: 0.2325, Spearman: 0.8046


Training: 100%|█████████████████████████████████| 50/50 [00:05<00:00,  8.52it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 20.35it/s]


Epoch 48/500
  Train Loss: 0.2589, MSE: 0.2589, Spearman: 0.8972, LR: 0.000030
  Val Loss: 0.2472, MSE: 0.2472, Spearman: 0.8022

Early stopping triggered after 48 epochs!
Best validation loss: 0.1942

Training completed! Best validation loss: 0.1942


Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 19.87it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 12.03it/s]



Test Results:
  Loss: 0.2283
  MSE: 0.2283
  Spearman: 0.7933

Training completed!


In [3]:
import os
import argparse
import torch

from data.dataloaders import get_multimodal_loaders
from utils.load_config import load_config
from utils.calculate_embeddings import calculate_embeddings
from models.multimodel import build_model
from trainer import RegressionTrainer


def _has_any_embeddings(emb_dir) -> bool:
    """Tell whether ``emb_dir`` is a directory containing at least one ``.pt`` entry.

    Args:
        emb_dir: Path of the directory to inspect.

    Returns:
        ``True`` if ``emb_dir`` is an existing directory and the name of at
        least one of its entries ends with ``".pt"``; ``False`` otherwise,
        including when the path does not exist or is not a directory.
    """
    if os.path.isdir(emb_dir):
        try:
            entry_names = os.listdir(emb_dir)
        except FileNotFoundError:
            # The directory vanished between the isdir() check and the listing.
            return False
        return any(entry.endswith(".pt") for entry in entry_names)
    return False


def main(name: str, dataset: str, max_len: int, batch_size: int, epochs: int):
    """Run one multimodal (DNA + RNA + Protein) regression training job.

    Loads ``{name}.yml`` through ``load_config`` and overrides its ``Dataset``
    and ``device`` entries. If the filtered CSV is missing, or any of the three
    per-modality embedding directories has no ``.pt`` entry,
    ``calculate_embeddings`` is called for all three modalities. Afterwards the
    data loaders, the model and a ``RegressionTrainer`` are created and
    ``trainer.train`` is run.

    Args:
        name: Config file stem; ``f"{name}.yml"`` is handed to ``load_config``.
            Also used as the fallback run name when the config has no ``name``.
        dataset: Dataset identifier stored under ``config["Dataset"]``.
        max_len: Length threshold; coerced with ``int`` and used both in the
            cache paths and in the embedding / loader calls.
        batch_size: Forwarded to ``get_multimodal_loaders``.
        epochs: Forwarded to ``trainer.train``.
    """
    modalities = ("DNA", "RNA", "Protein")
    rule = "=" * 60

    config = load_config(f"{name}.yml")
    config["Dataset"] = dataset
    if torch.cuda.is_available():
        config["device"] = "cuda"
    else:
        config["device"] = "cpu"

    max_len = int(max_len)

    # Cached artefacts are keyed by dataset name and max_len.
    csv_path = f"data/datasets/{config['Dataset']}_multimodal_filtered_maxlen{max_len}.csv"
    csv_missing = not os.path.exists(csv_path)
    # A list (not a generator) so every modality directory is inspected.
    embeddings_missing = [
        not _has_any_embeddings(f"embeddings/{config['Dataset']}/{modality}/maxlen{max_len}")
        for modality in modalities
    ]

    if csv_missing or any(embeddings_missing):
        print("Embeddings and/or filtered CSV not found, calculating...")
        # All modalities are recomputed, even if only one artefact is missing.
        for modality in modalities:
            calculate_embeddings(
                dataset=config["Dataset"],
                modality=modality,
                device=config["device"],
                max_len=max_len,
            )

    banner = [
        rule,
        "Training Configuration (Multimodal):",
        f"  Config name: {config.get('name', name)}",
        f"  Dataset: {config['Dataset']}",
        f"  Max Len (filter): {max_len}",
        f"  Fusion: {config.get('fusion_type', 'concat')}",
        f"  Batch size: {batch_size}",
        f"  Epochs: {epochs}",
        f"  Device: {config['device']}",
        rule,
    ]
    for banner_line in banner:
        print(banner_line)

    print("\nLoading data...")
    loaders = get_multimodal_loaders(
        config["Dataset"],
        batch_size=batch_size,
        max_len=max_len,
    )
    train_loader, val_loader, test_loader = loaders

    print("\nInitializing model...")
    model = build_model(config)

    # Count parameters in one pass over model.parameters().
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        n_elements = param.numel()
        total_params += n_elements
        if param.requires_grad:
            trainable_params += n_elements
    print(f"Total number of parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {total_params - trainable_params}")

    print("\nInitializing trainer...")
    plots_dir = f"./plots/{config.get('name', name)}/{config['Dataset']}"
    trainer = RegressionTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=config["device"],
        save_dir=plots_dir,
    )

    # Entropy regularization weight for MIL: the nested ``trainer`` section
    # takes precedence, the top-level key is the backwards-compatible fallback.
    trainer_section = config.get("trainer", None)
    if isinstance(trainer_section, dict):
        lam_entropy = trainer_section.get("lam_entropy", None)
    else:
        lam_entropy = None
    if lam_entropy is None:
        lam_entropy = config.get("lam_entropy", None)

    if lam_entropy is not None:
        trainer.lam_entropy = float(lam_entropy)
        if trainer.lam_entropy > 0:
            print(f"Using MIL entropy regularization: lam_entropy={trainer.lam_entropy}")

    print(f"\n{rule}")
    print("Starting training...")
    print(f"{rule}\n")

    trainer.train(epochs=epochs)

    print(f"\n{rule}")
    print("Training completed!")
    print(rule)


if __name__ == "__main__":
    # Command-line entry point: parse the run options and hand them to main().
    parser = argparse.ArgumentParser(description="Train multimodal model (DNA+RNA+Protein)")
    parser.add_argument(
        "--name",
        type=str,
        default="fusion_mil",
        help="Config file name (without .yml). Options: fusion_concat, fusion_mil, fusion_xattn",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cov_vaccine_degradation",
        help=(
            # %(default)s is filled in by argparse, so the advertised default
            # can no longer drift from the real one (it used to claim
            # fungal_expression while the default was cov_vaccine_degradation).
            "Dataset to use. Default: %(default)s. Options: "
            "'mrna_stability', 'ecoli_proteins', 'cov_vaccine_degradation', 'fungal_expression'"
        ),
    )
    parser.add_argument(
        "--max-len",
        type=int,
        default=1000,
        help="Filter threshold: keep only sequences with raw length <= max_len before embedding/training.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=500,
        help="Number of training epochs.",
    )
    # parse_known_args() ignores unrecognised arguments instead of exiting.
    # NOTE(review): presumably this tolerates the extra argv entries injected
    # when the cell runs inside a notebook kernel — confirm before tightening.
    args, _ = parser.parse_known_args()

    main(
        name=args.name,
        dataset=args.dataset,
        max_len=args.max_len,
        batch_size=args.batch_size,
        epochs=args.epochs,
    )

Training Configuration (Multimodal):
  Config name: fusion_mil
  Dataset: cov_vaccine_degradation
  Max Len (filter): 1000
  Fusion: mil
  Batch size: 32
  Epochs: 500
  Device: cpu

Loading data...

Initializing model...
Total number of parameters: 3185374
Trainable parameters: 3185374
Non-trainable parameters: 0

Initializing trainer...
Using MIL entropy regularization: lam_entropy=0.01

Starting training...

Starting training for 500 epochs...
Device: cpu
Save directory: plots/fusion_mil/cov_vaccine_degradation


Training: 100%|█████████████████████████████████| 50/50 [00:10<00:00,  4.92it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 10.86it/s]


Epoch 1/500
  Train Loss: 1.2053, MSE: 1.1944, Spearman: 0.1818, LR: 0.000030
  Val Loss: 0.5678, MSE: 0.5569, Spearman: 0.6082


Training: 100%|█████████████████████████████████| 50/50 [00:08<00:00,  5.96it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 13.63it/s]


Epoch 2/500
  Train Loss: 1.1636, MSE: 1.1528, Spearman: 0.2765, LR: 0.000030
  Val Loss: 0.4996, MSE: 0.4890, Spearman: 0.6371


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.39it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 13.52it/s]


Epoch 3/500
  Train Loss: 1.0946, MSE: 1.0843, Spearman: 0.4443, LR: 0.000030
  Val Loss: 0.4220, MSE: 0.4120, Spearman: 0.6461


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.35it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 13.09it/s]


Epoch 4/500
  Train Loss: 1.0093, MSE: 0.9997, Spearman: 0.4852, LR: 0.000030
  Val Loss: 0.3441, MSE: 0.3351, Spearman: 0.6638


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.62it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.42it/s]


Epoch 5/500
  Train Loss: 0.9348, MSE: 0.9260, Spearman: 0.5277, LR: 0.000030
  Val Loss: 0.3139, MSE: 0.3053, Spearman: 0.6751


Training: 100%|█████████████████████████████████| 50/50 [00:08<00:00,  6.00it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 10.85it/s]


Epoch 6/500
  Train Loss: 0.8758, MSE: 0.8674, Spearman: 0.5793, LR: 0.000030
  Val Loss: 0.3239, MSE: 0.3157, Spearman: 0.6838


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.79it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.65it/s]


Epoch 7/500
  Train Loss: 0.8452, MSE: 0.8372, Spearman: 0.6000, LR: 0.000030
  Val Loss: 0.2886, MSE: 0.2808, Spearman: 0.6907


Training: 100%|█████████████████████████████████| 50/50 [00:08<00:00,  6.23it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.41it/s]


Epoch 8/500
  Train Loss: 0.8215, MSE: 0.8138, Spearman: 0.6221, LR: 0.000030
  Val Loss: 0.2879, MSE: 0.2804, Spearman: 0.7008


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.34it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.35it/s]


Epoch 9/500
  Train Loss: 0.7716, MSE: 0.7642, Spearman: 0.6563, LR: 0.000030
  Val Loss: 0.2729, MSE: 0.2654, Spearman: 0.7092


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.35it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.13it/s]


Epoch 10/500
  Train Loss: 0.7489, MSE: 0.7417, Spearman: 0.6681, LR: 0.000030
  Val Loss: 0.2807, MSE: 0.2739, Spearman: 0.7095


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.29it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 10.64it/s]


Epoch 11/500
  Train Loss: 0.7288, MSE: 0.7220, Spearman: 0.6791, LR: 0.000030
  Val Loss: 0.2706, MSE: 0.2638, Spearman: 0.7132


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.41it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.32it/s]


Epoch 12/500
  Train Loss: 0.7373, MSE: 0.7306, Spearman: 0.6602, LR: 0.000030
  Val Loss: 0.3991, MSE: 0.3923, Spearman: 0.7072


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.29it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 18.24it/s]


Epoch 13/500
  Train Loss: 0.6898, MSE: 0.6830, Spearman: 0.7031, LR: 0.000030
  Val Loss: 0.2619, MSE: 0.2554, Spearman: 0.7305


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.48it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.46it/s]


Epoch 14/500
  Train Loss: 0.6772, MSE: 0.6709, Spearman: 0.7088, LR: 0.000030
  Val Loss: 0.2760, MSE: 0.2698, Spearman: 0.7178


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.54it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 17.95it/s]


Epoch 15/500
  Train Loss: 0.6488, MSE: 0.6428, Spearman: 0.7240, LR: 0.000030
  Val Loss: 0.2865, MSE: 0.2805, Spearman: 0.7173


Training: 100%|█████████████████████████████████| 50/50 [00:08<00:00,  6.15it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.05it/s]


Epoch 16/500
  Train Loss: 0.6141, MSE: 0.6083, Spearman: 0.7369, LR: 0.000030
  Val Loss: 0.3139, MSE: 0.3084, Spearman: 0.7067


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.72it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.05it/s]


Epoch 17/500
  Train Loss: 0.5996, MSE: 0.5942, Spearman: 0.7462, LR: 0.000030
  Val Loss: 0.2874, MSE: 0.2821, Spearman: 0.7307


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.38it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.38it/s]


Epoch 18/500
  Train Loss: 0.5861, MSE: 0.5808, Spearman: 0.7522, LR: 0.000030
  Val Loss: 0.3563, MSE: 0.3517, Spearman: 0.7083


Training: 100%|█████████████████████████████████| 50/50 [00:08<00:00,  6.13it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.36it/s]


Epoch 19/500
  Train Loss: 0.5680, MSE: 0.5631, Spearman: 0.7589, LR: 0.000030
  Val Loss: 0.2976, MSE: 0.2928, Spearman: 0.7045


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.27it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 14.84it/s]


Epoch 20/500
  Train Loss: 0.5358, MSE: 0.5310, Spearman: 0.7707, LR: 0.000030
  Val Loss: 0.3156, MSE: 0.3111, Spearman: 0.7098


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.38it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.68it/s]


Epoch 21/500
  Train Loss: 0.5106, MSE: 0.5061, Spearman: 0.7889, LR: 0.000030
  Val Loss: 0.3208, MSE: 0.3165, Spearman: 0.7123


Training: 100%|█████████████████████████████████| 50/50 [00:08<00:00,  5.69it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00,  7.92it/s]


Epoch 22/500
  Train Loss: 0.4955, MSE: 0.4912, Spearman: 0.7902, LR: 0.000030
  Val Loss: 0.3026, MSE: 0.2985, Spearman: 0.7087


Training: 100%|█████████████████████████████████| 50/50 [00:08<00:00,  5.85it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 14.99it/s]


Epoch 23/500
  Train Loss: 0.4856, MSE: 0.4816, Spearman: 0.7900, LR: 0.000030
  Val Loss: 0.3348, MSE: 0.3309, Spearman: 0.6931


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.26it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.38it/s]


Epoch 24/500
  Train Loss: 0.4617, MSE: 0.4579, Spearman: 0.7952, LR: 0.000030
  Val Loss: 0.3694, MSE: 0.3656, Spearman: 0.6948


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.35it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.79it/s]


Epoch 25/500
  Train Loss: 0.4526, MSE: 0.4487, Spearman: 0.8006, LR: 0.000030
  Val Loss: 0.3357, MSE: 0.3319, Spearman: 0.6967


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.33it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.19it/s]


Epoch 26/500
  Train Loss: 0.4154, MSE: 0.4118, Spearman: 0.8194, LR: 0.000030
  Val Loss: 0.3506, MSE: 0.3473, Spearman: 0.6898


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.37it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.02it/s]


Epoch 27/500
  Train Loss: 0.4174, MSE: 0.4141, Spearman: 0.8240, LR: 0.000030
  Val Loss: 0.3467, MSE: 0.3434, Spearman: 0.6861


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.42it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 14.38it/s]


Epoch 28/500
  Train Loss: 0.3789, MSE: 0.3756, Spearman: 0.8359, LR: 0.000030
  Val Loss: 0.3740, MSE: 0.3709, Spearman: 0.6783


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.39it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.34it/s]


Epoch 29/500
  Train Loss: 0.3657, MSE: 0.3627, Spearman: 0.8333, LR: 0.000030
  Val Loss: 0.3574, MSE: 0.3545, Spearman: 0.6808


Training: 100%|█████████████████████████████████| 50/50 [00:08<00:00,  5.96it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 18.39it/s]


Epoch 30/500
  Train Loss: 0.3535, MSE: 0.3506, Spearman: 0.8431, LR: 0.000030
  Val Loss: 0.3708, MSE: 0.3679, Spearman: 0.6815


Training: 100%|█████████████████████████████████| 50/50 [00:08<00:00,  6.15it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 18.18it/s]


Epoch 31/500
  Train Loss: 0.3317, MSE: 0.3289, Spearman: 0.8519, LR: 0.000030
  Val Loss: 0.3931, MSE: 0.3904, Spearman: 0.6770


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.52it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.37it/s]


Epoch 32/500
  Train Loss: 0.3168, MSE: 0.3142, Spearman: 0.8548, LR: 0.000030
  Val Loss: 0.3827, MSE: 0.3802, Spearman: 0.6590


Training: 100%|█████████████████████████████████| 50/50 [00:07<00:00,  6.49it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 13.74it/s]


Epoch 33/500
  Train Loss: 0.3075, MSE: 0.3051, Spearman: 0.8625, LR: 0.000030
  Val Loss: 0.3672, MSE: 0.3647, Spearman: 0.6751

Early stopping triggered after 33 epochs!
Best validation loss: 0.2619

Training completed! Best validation loss: 0.2619


Validating: 100%|███████████████████████████████| 13/13 [00:00<00:00, 15.12it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:01<00:00, 11.38it/s]



Test Results:
  Loss: 0.3121
  MSE: 0.3055
  Spearman: 0.7007

Training completed!


In [4]:
import os
import argparse
import torch

from data.dataloaders import get_multimodal_loaders
from utils.load_config import load_config
from utils.calculate_embeddings import calculate_embeddings
from models.multimodel import build_model
from trainer import RegressionTrainer


def _has_any_embeddings(emb_dir) -> bool:
    """Report whether ``emb_dir`` directly contains an entry ending in ``.pt``.

    Returns ``False`` when ``emb_dir`` is not an existing directory, when it
    disappears between the directory check and the listing, or when none of
    its immediate entries has a name ending in ``.pt``. Sub-directories are
    not descended into.
    """
    if os.path.isdir(emb_dir):
        try:
            entries = os.listdir(emb_dir)
        except FileNotFoundError:
            # Directory was removed after the isdir() check.
            entries = []
        return any(entry.endswith(".pt") for entry in entries)
    return False


def main(name: str, dataset: str, max_len: int, batch_size: int, epochs: int):
    """Train a multimodal (DNA + RNA + Protein) regression model end to end.

    Loads ``{name}.yml`` through ``load_config``, (re)computes embeddings for
    all three modalities when the filtered CSV or any modality's embeddings
    are missing, builds the data loaders, model and ``RegressionTrainer``,
    and finally calls ``trainer.train``.

    Args:
        name: Config file name without the ``.yml`` suffix.
        dataset: Dataset identifier; stored under ``config["Dataset"]``.
        max_len: Length threshold; part of the filtered-CSV and embedding
            paths and forwarded to the embedding and loader helpers.
        batch_size: Batch size forwarded to ``get_multimodal_loaders``.
        epochs: Epoch count forwarded to ``trainer.train``.
    """
    rule = "=" * 60

    config = load_config(f"{name}.yml")
    config["Dataset"] = dataset
    config["device"] = "cuda" if torch.cuda.is_available() else "cpu"

    max_len = int(max_len)
    ds = config["Dataset"]
    device = config["device"]

    filtered_csv = f"data/datasets/{ds}_multimodal_filtered_maxlen{max_len}.csv"
    embedding_dirs = (
        f"embeddings/{ds}/DNA/maxlen{max_len}",
        f"embeddings/{ds}/RNA/maxlen{max_len}",
        f"embeddings/{ds}/Protein/maxlen{max_len}",
    )

    csv_missing = not os.path.exists(filtered_csv)
    # A list (not a generator) so every modality directory is probed.
    dirs_without_embeddings = [d for d in embedding_dirs if not _has_any_embeddings(d)]

    # Anything missing triggers a recomputation for all three modalities.
    if csv_missing or dirs_without_embeddings:
        print("Embeddings and/or filtered CSV not found, calculating...")
        for modality in ("DNA", "RNA", "Protein"):
            calculate_embeddings(
                dataset=ds,
                modality=modality,
                device=device,
                max_len=max_len,
            )

    # The config may carry its own display name; fall back to the file name.
    run_name = config.get("name", name)

    print(rule)
    print("Training Configuration (Multimodal):")
    print(f"  Config name: {run_name}")
    print(f"  Dataset: {ds}")
    print(f"  Max Len (filter): {max_len}")
    print(f"  Fusion: {config.get('fusion_type', 'concat')}")
    print(f"  Batch size: {batch_size}")
    print(f"  Epochs: {epochs}")
    print(f"  Device: {device}")
    print(rule)

    print("\nLoading data...")
    loaders = get_multimodal_loaders(
        ds,
        batch_size=batch_size,
        max_len=max_len,
    )
    train_loader, val_loader, test_loader = loaders

    print("\nInitializing model...")
    model = build_model(config)

    # Tally parameter counts in a single pass over the model's parameters.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    non_trainable_params = total_params - trainable_params
    print(f"Total number of parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {non_trainable_params}")

    print("\nInitializing trainer...")
    trainer = RegressionTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=device,
        save_dir=f"./plots/{run_name}/{ds}",
    )

    # Entropy regularisation weight for MIL: prefer the nested
    # config["trainer"]["lam_entropy"], fall back to the flat top-level key.
    trainer_section = config.get("trainer", None)
    if isinstance(trainer_section, dict):
        lam_entropy = trainer_section.get("lam_entropy", None)
    else:
        lam_entropy = None
    if lam_entropy is None:
        lam_entropy = config.get("lam_entropy", None)

    if lam_entropy is not None:
        trainer.lam_entropy = float(lam_entropy)
        if trainer.lam_entropy > 0:
            print(f"Using MIL entropy regularization: lam_entropy={trainer.lam_entropy}")

    print("\n" + rule)
    print("Starting training...")
    print(rule + "\n")

    trainer.train(epochs=epochs)

    print("\n" + rule)
    print("Training completed!")
    print(rule)


if __name__ == "__main__":
    # Command-line entry point for the multimodal training run. Every option
    # has a default, so the cell also runs with no arguments supplied.
    parser = argparse.ArgumentParser(description="Train multimodal model (DNA+RNA+Protein)")
    parser.add_argument(
        "--name",
        type=str,
        default="fusion_xattn",
        help="Config file name (without .yml). Options: fusion_concat, fusion_mil, fusion_xattn",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cov_vaccine_degradation",
        # %(default)s keeps the help text in sync with the actual default
        # (the previous hard-coded text named a different dataset).
        help=(
            "Dataset to use. Default: %(default)s. Options: "
            "'mrna_stability', 'ecoli_proteins', 'cov_vaccine_degradation', 'fungal_expression'"
        ),
    )
    parser.add_argument(
        "--max-len",
        type=int,
        default=1000,
        help="Filter threshold: keep only sequences with raw length <= max_len before embedding/training.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=500,
        help="Number of training epochs.",
    )
    # parse_known_args() tolerates unrecognised argv entries instead of
    # exiting; presumably needed because the notebook kernel passes its own
    # flags -- confirm before switching to parse_args().
    args, _ = parser.parse_known_args()

    main(
        name=args.name,
        dataset=args.dataset,
        max_len=args.max_len,
        batch_size=args.batch_size,
        epochs=args.epochs,
    )

Training Configuration (Multimodal):
  Config name: fusion_xattn
  Dataset: cov_vaccine_degradation
  Max Len (filter): 1000
  Fusion: xattn
  Batch size: 32
  Epochs: 500
  Device: cpu

Loading data...

Initializing model...
Total number of parameters: 9315473
Trainable parameters: 9315473
Non-trainable parameters: 0

Initializing trainer...

Starting training...

Starting training for 500 epochs...
Device: cpu
Save directory: plots/fusion_xattn/cov_vaccine_degradation


Training: 100%|█████████████████████████████████| 50/50 [00:26<00:00,  1.86it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.87it/s]


Epoch 1/500
  Train Loss: 1.2714, MSE: 1.2714, Spearman: 0.0256, LR: 0.000030
  Val Loss: 0.5419, MSE: 0.5419, Spearman: 0.5023


Training: 100%|█████████████████████████████████| 50/50 [00:30<00:00,  1.62it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.86it/s]


Epoch 2/500
  Train Loss: 1.1629, MSE: 1.1629, Spearman: 0.2077, LR: 0.000030
  Val Loss: 0.4848, MSE: 0.4848, Spearman: 0.5456


Training: 100%|█████████████████████████████████| 50/50 [00:32<00:00,  1.53it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.43it/s]


Epoch 3/500
  Train Loss: 1.1197, MSE: 1.1197, Spearman: 0.2856, LR: 0.000030
  Val Loss: 0.4641, MSE: 0.4641, Spearman: 0.5780


Training: 100%|█████████████████████████████████| 50/50 [00:33<00:00,  1.51it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:03<00:00,  3.97it/s]


Epoch 4/500
  Train Loss: 1.0659, MSE: 1.0659, Spearman: 0.3484, LR: 0.000030
  Val Loss: 0.4321, MSE: 0.4321, Spearman: 0.5884


Training: 100%|█████████████████████████████████| 50/50 [00:33<00:00,  1.50it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:03<00:00,  4.28it/s]


Epoch 5/500
  Train Loss: 0.9936, MSE: 0.9936, Spearman: 0.4462, LR: 0.000030
  Val Loss: 0.4379, MSE: 0.4379, Spearman: 0.6336


Training: 100%|█████████████████████████████████| 50/50 [00:32<00:00,  1.56it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:03<00:00,  4.21it/s]


Epoch 6/500
  Train Loss: 1.0044, MSE: 1.0044, Spearman: 0.4297, LR: 0.000030
  Val Loss: 0.3981, MSE: 0.3981, Spearman: 0.6477


Training: 100%|█████████████████████████████████| 50/50 [00:38<00:00,  1.30it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.40it/s]


Epoch 7/500
  Train Loss: 0.9165, MSE: 0.9165, Spearman: 0.5104, LR: 0.000030
  Val Loss: 0.3625, MSE: 0.3625, Spearman: 0.6784


Training: 100%|█████████████████████████████████| 50/50 [00:33<00:00,  1.51it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.56it/s]


Epoch 8/500
  Train Loss: 0.8990, MSE: 0.8990, Spearman: 0.5454, LR: 0.000030
  Val Loss: 0.3538, MSE: 0.3538, Spearman: 0.7130


Training: 100%|█████████████████████████████████| 50/50 [00:30<00:00,  1.65it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.88it/s]


Epoch 9/500
  Train Loss: 0.8875, MSE: 0.8875, Spearman: 0.5503, LR: 0.000030
  Val Loss: 0.3274, MSE: 0.3274, Spearman: 0.6928


Training: 100%|█████████████████████████████████| 50/50 [00:30<00:00,  1.62it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.65it/s]


Epoch 10/500
  Train Loss: 0.8793, MSE: 0.8793, Spearman: 0.5215, LR: 0.000030
  Val Loss: 0.3054, MSE: 0.3054, Spearman: 0.7275


Training: 100%|█████████████████████████████████| 50/50 [00:29<00:00,  1.71it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.03it/s]


Epoch 11/500
  Train Loss: 0.8421, MSE: 0.8421, Spearman: 0.5857, LR: 0.000030
  Val Loss: 0.4045, MSE: 0.4045, Spearman: 0.7174


Training: 100%|█████████████████████████████████| 50/50 [00:28<00:00,  1.76it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.98it/s]


Epoch 12/500
  Train Loss: 0.8146, MSE: 0.8146, Spearman: 0.5984, LR: 0.000030
  Val Loss: 0.3477, MSE: 0.3477, Spearman: 0.7345


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.83it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.11it/s]


Epoch 13/500
  Train Loss: 0.7971, MSE: 0.7971, Spearman: 0.5983, LR: 0.000030
  Val Loss: 0.3774, MSE: 0.3774, Spearman: 0.7440


Training: 100%|█████████████████████████████████| 50/50 [00:25<00:00,  1.94it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.31it/s]


Epoch 14/500
  Train Loss: 0.7690, MSE: 0.7690, Spearman: 0.5953, LR: 0.000030
  Val Loss: 0.3977, MSE: 0.3977, Spearman: 0.7057


Training: 100%|█████████████████████████████████| 50/50 [00:25<00:00,  1.92it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.96it/s]


Epoch 15/500
  Train Loss: 0.8064, MSE: 0.8064, Spearman: 0.5799, LR: 0.000030
  Val Loss: 0.3469, MSE: 0.3469, Spearman: 0.7519


Training: 100%|█████████████████████████████████| 50/50 [00:25<00:00,  1.95it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.39it/s]


Epoch 16/500
  Train Loss: 0.7409, MSE: 0.7409, Spearman: 0.6291, LR: 0.000030
  Val Loss: 0.3348, MSE: 0.3348, Spearman: 0.7444


Training: 100%|█████████████████████████████████| 50/50 [00:25<00:00,  1.95it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.37it/s]


Epoch 17/500
  Train Loss: 0.7194, MSE: 0.7194, Spearman: 0.6231, LR: 0.000030
  Val Loss: 0.3489, MSE: 0.3489, Spearman: 0.7720


Training: 100%|█████████████████████████████████| 50/50 [00:25<00:00,  1.95it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.29it/s]


Epoch 18/500
  Train Loss: 0.6879, MSE: 0.6879, Spearman: 0.6505, LR: 0.000030
  Val Loss: 0.2785, MSE: 0.2785, Spearman: 0.7668


Training: 100%|█████████████████████████████████| 50/50 [00:28<00:00,  1.75it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.28it/s]


Epoch 19/500
  Train Loss: 0.6594, MSE: 0.6594, Spearman: 0.6438, LR: 0.000030
  Val Loss: 0.2958, MSE: 0.2958, Spearman: 0.7657


Training: 100%|█████████████████████████████████| 50/50 [00:26<00:00,  1.86it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.01it/s]


Epoch 20/500
  Train Loss: 0.6580, MSE: 0.6580, Spearman: 0.6544, LR: 0.000030
  Val Loss: 0.3336, MSE: 0.3336, Spearman: 0.7377


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.83it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.03it/s]


Epoch 21/500
  Train Loss: 0.6479, MSE: 0.6479, Spearman: 0.6577, LR: 0.000030
  Val Loss: 0.3093, MSE: 0.3093, Spearman: 0.7372


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.84it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.95it/s]


Epoch 22/500
  Train Loss: 0.6046, MSE: 0.6046, Spearman: 0.6713, LR: 0.000030
  Val Loss: 0.3275, MSE: 0.3275, Spearman: 0.7397


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.84it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.02it/s]


Epoch 23/500
  Train Loss: 0.5866, MSE: 0.5866, Spearman: 0.6627, LR: 0.000030
  Val Loss: 0.3604, MSE: 0.3604, Spearman: 0.7126


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.85it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.10it/s]


Epoch 24/500
  Train Loss: 0.6617, MSE: 0.6617, Spearman: 0.6255, LR: 0.000030
  Val Loss: 0.3151, MSE: 0.3151, Spearman: 0.7633


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.85it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.11it/s]


Epoch 25/500
  Train Loss: 0.6092, MSE: 0.6092, Spearman: 0.6751, LR: 0.000030
  Val Loss: 0.3216, MSE: 0.3216, Spearman: 0.7432


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.85it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.03it/s]


Epoch 26/500
  Train Loss: 0.6374, MSE: 0.6374, Spearman: 0.6303, LR: 0.000030
  Val Loss: 0.3556, MSE: 0.3556, Spearman: 0.7088


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.85it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.95it/s]


Epoch 27/500
  Train Loss: 0.6053, MSE: 0.6053, Spearman: 0.6526, LR: 0.000030
  Val Loss: 0.3349, MSE: 0.3349, Spearman: 0.7524


Training: 100%|█████████████████████████████████| 50/50 [00:26<00:00,  1.86it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.08it/s]


Epoch 28/500
  Train Loss: 0.5138, MSE: 0.5138, Spearman: 0.6933, LR: 0.000030
  Val Loss: 0.2866, MSE: 0.2866, Spearman: 0.7653


Training: 100%|█████████████████████████████████| 50/50 [00:26<00:00,  1.86it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.83it/s]


Epoch 29/500
  Train Loss: 0.5280, MSE: 0.5280, Spearman: 0.7057, LR: 0.000030
  Val Loss: 0.3214, MSE: 0.3214, Spearman: 0.7570


Training: 100%|█████████████████████████████████| 50/50 [00:26<00:00,  1.85it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.06it/s]


Epoch 30/500
  Train Loss: 0.5379, MSE: 0.5379, Spearman: 0.6817, LR: 0.000030
  Val Loss: 0.3162, MSE: 0.3162, Spearman: 0.7336


Training: 100%|█████████████████████████████████| 50/50 [00:30<00:00,  1.67it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.82it/s]


Epoch 31/500
  Train Loss: 0.5243, MSE: 0.5243, Spearman: 0.7038, LR: 0.000030
  Val Loss: 0.3158, MSE: 0.3158, Spearman: 0.7483


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.83it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.04it/s]


Epoch 32/500
  Train Loss: 0.4812, MSE: 0.4812, Spearman: 0.7101, LR: 0.000030
  Val Loss: 0.3183, MSE: 0.3183, Spearman: 0.7561


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.83it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.05it/s]


Epoch 33/500
  Train Loss: 0.5028, MSE: 0.5028, Spearman: 0.7162, LR: 0.000030
  Val Loss: 0.3198, MSE: 0.3198, Spearman: 0.7365


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.83it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.03it/s]


Epoch 34/500
  Train Loss: 0.5158, MSE: 0.5158, Spearman: 0.7060, LR: 0.000030
  Val Loss: 0.2982, MSE: 0.2982, Spearman: 0.7744


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.79it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.01it/s]


Epoch 35/500
  Train Loss: 0.4750, MSE: 0.4750, Spearman: 0.7290, LR: 0.000030
  Val Loss: 0.3271, MSE: 0.3271, Spearman: 0.7111


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.84it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.04it/s]


Epoch 36/500
  Train Loss: 0.4534, MSE: 0.4534, Spearman: 0.7342, LR: 0.000030
  Val Loss: 0.3666, MSE: 0.3666, Spearman: 0.7587


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.82it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.02it/s]


Epoch 37/500
  Train Loss: 0.5716, MSE: 0.5716, Spearman: 0.6922, LR: 0.000030
  Val Loss: 0.2879, MSE: 0.2879, Spearman: 0.7518


Training: 100%|█████████████████████████████████| 50/50 [00:27<00:00,  1.81it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  4.99it/s]


Epoch 38/500
  Train Loss: 0.4182, MSE: 0.4182, Spearman: 0.7463, LR: 0.000030
  Val Loss: 0.3069, MSE: 0.3069, Spearman: 0.7615

Early stopping triggered after 38 epochs!
Best validation loss: 0.2785

Training completed! Best validation loss: 0.2785


Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.85it/s]
Validating: 100%|███████████████████████████████| 13/13 [00:02<00:00,  5.05it/s]



Test Results:
  Loss: 0.3011
  MSE: 0.3011
  Spearman: 0.7285

Training completed!


---

# ecoli_proteins

In [None]:
import os
import argparse
import torch

from data.dataloaders import get_multimodal_loaders
from utils.load_config import load_config
from utils.calculate_embeddings import calculate_embeddings
from models.multimodel import build_model
from trainer import RegressionTrainer


def _has_any_embeddings(emb_dir) -> bool:
    """Return ``True`` iff ``emb_dir`` is a directory with a ``*.pt`` entry.

    A missing path, a non-directory path, or a directory that vanishes
    before it can be listed all yield ``False``. Only names of immediate
    entries are examined.
    """
    if not os.path.isdir(emb_dir):
        return False
    try:
        names = os.listdir(emb_dir)
    except FileNotFoundError:
        # Raced with a concurrent removal of the directory.
        return False
    return any(entry_name.endswith(".pt") for entry_name in names)

def main(name: str, dataset: str, max_len: int, batch_size: int, epochs: int):
    """Train a multimodal (DNA + RNA + Protein) model with ClassificationTrainer.

    Loads ``{name}.yml`` through ``load_config``, (re)computes embeddings for
    all three modalities when the filtered CSV or any modality's embeddings
    are missing, builds the data loaders, model and ``ClassificationTrainer``,
    and finally calls ``trainer.train``.

    Args:
        name: Config file name without the ``.yml`` suffix.
        dataset: Dataset identifier; stored under ``config["Dataset"]``.
        max_len: Length threshold; part of the filtered-CSV and embedding
            paths and forwarded to the embedding and loader helpers.
        batch_size: Batch size forwarded to ``get_multimodal_loaders``.
        epochs: Epoch count forwarded to ``trainer.train``.
    """
    # This cell's import block only brings in RegressionTrainer, so the
    # ClassificationTrainer used below would raise NameError. Import it here
    # to keep the cell self-contained.
    from trainer import ClassificationTrainer

    config = load_config(f"{name}.yml")
    config["Dataset"] = dataset
    config["device"] = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): the unimodal classification entry point in this notebook
    # sets config["task"] / config["num_classes"] before build_model; confirm
    # whether the multimodal config needs the same.

    max_len = int(max_len)

    filtered_csv = f"data/datasets/{config['Dataset']}_multimodal_filtered_maxlen{max_len}.csv"

    dna_dir = f"embeddings/{config['Dataset']}/DNA/maxlen{max_len}"
    rna_dir = f"embeddings/{config['Dataset']}/RNA/maxlen{max_len}"
    prot_dir = f"embeddings/{config['Dataset']}/Protein/maxlen{max_len}"

    # NOTE(review): a single .pt entry per modality counts as "present", so
    # a partially written embedding directory is not detected here.
    need_filtered_csv = not os.path.exists(filtered_csv)
    need_dna = not _has_any_embeddings(dna_dir)
    need_rna = not _has_any_embeddings(rna_dir)
    need_prot = not _has_any_embeddings(prot_dir)
    need_embeddings = need_dna or need_rna or need_prot

    # Anything missing triggers a recomputation for all three modalities.
    if need_filtered_csv or need_embeddings:
        print("Embeddings and/or filtered CSV not found, calculating...")
        for modality in ("DNA", "RNA", "Protein"):
            calculate_embeddings(
                dataset=config["Dataset"],
                modality=modality,
                device=config["device"],
                max_len=max_len,
            )

    print("=" * 60)
    print("Training Configuration (Multimodal):")
    print(f"  Config name: {config.get('name', name)}")
    print(f"  Dataset: {config['Dataset']}")
    print(f"  Max Len (filter): {max_len}")
    print(f"  Fusion: {config.get('fusion_type', 'concat')}")
    print(f"  Batch size: {batch_size}")
    print(f"  Epochs: {epochs}")
    print(f"  Device: {config['device']}")
    print("=" * 60)

    print("\nLoading data...")
    train_loader, val_loader, test_loader = get_multimodal_loaders(
        config["Dataset"],
        batch_size=batch_size,
        max_len=max_len,
    )

    print("\nInitializing model...")
    model = build_model(config)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params
    print(f"Total number of parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {non_trainable_params}")

    print("\nInitializing trainer...")
    trainer = ClassificationTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=config["device"],
        save_dir=f"./plots/{config.get('name', name)}/{config['Dataset']}",
    )

    # Entropy regularisation weight for MIL: prefer the nested
    # config["trainer"]["lam_entropy"], fall back to the flat top-level key.
    lam_entropy = None
    if isinstance(config.get("trainer", None), dict):
        lam_entropy = config["trainer"].get("lam_entropy", None)
    if lam_entropy is None:
        lam_entropy = config.get("lam_entropy", None)

    if lam_entropy is not None:
        trainer.lam_entropy = float(lam_entropy)
        if trainer.lam_entropy > 0:
            print(f"Using MIL entropy regularization: lam_entropy={trainer.lam_entropy}")

    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60 + "\n")

    trainer.train(epochs=epochs)

    print("\n" + "=" * 60)
    print("Training completed!")
    print("=" * 60)


if __name__ == "__main__":
    # Command-line entry point for the multimodal training run. Every option
    # has a default, so the cell also runs with no arguments supplied.
    parser = argparse.ArgumentParser(description="Train multimodal model (DNA+RNA+Protein)")
    parser.add_argument(
        "--name",
        type=str,
        default="fusion_concat",
        help="Config file name (without .yml). Options: fusion_concat, fusion_mil, fusion_xattn",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ecoli_proteins",
        # %(default)s keeps the help text in sync with the actual default
        # (the previous hard-coded text named a different dataset).
        help=(
            "Dataset to use. Default: %(default)s. Options: "
            "'mrna_stability', 'ecoli_proteins', 'cov_vaccine_degradation', 'fungal_expression'"
        ),
    )
    parser.add_argument(
        "--max-len",
        type=int,
        default=1000,
        help="Filter threshold: keep only sequences with raw length <= max_len before embedding/training.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=500,
        help="Number of training epochs.",
    )
    # parse_known_args() tolerates unrecognised argv entries instead of
    # exiting; presumably needed because the notebook kernel passes its own
    # flags -- confirm before switching to parse_args().
    args, _ = parser.parse_known_args()

    main(
        name=args.name,
        dataset=args.dataset,
        max_len=args.max_len,
        batch_size=args.batch_size,
        epochs=args.epochs,
    )

Embeddings and/or filtered CSV not found, calculating...
Calculating embeddings... for dataset: ecoli_proteins and modality: DNA
Filtering rule for CSV: RNA length <= 1000
Created filtered CSV (RNA-length based).
Original rows : 6348
Kept rows     : 4450 (RNA length <= 1000)
Saved to      : data/datasets/ecoli_proteins_multimodal_filtered_maxlen1000.csv


Embedding ecoli_proteins from DNA:   0%|       | 1/4450 [00:00<24:04,  3.08it/s]

First embedding shape: torch.Size([133, 512])


Embedding ecoli_proteins from DNA:  58%|██▎ | 2589/4450 [10:23<05:33,  5.58it/s]

In [None]:
import os
import argparse
import torch

from data.dataloaders import get_multimodal_loaders
from utils.load_config import load_config
from utils.calculate_embeddings import calculate_embeddings
from models.multimodel import build_model
from trainer import RegressionTrainer


def _has_any_embeddings(emb_dir) -> bool:
    """Tell whether at least one ``.pt``-suffixed entry sits in ``emb_dir``.

    ``False`` is returned if ``emb_dir`` is not a directory, if listing it
    fails because it no longer exists, or if no immediate entry name ends
    with ``.pt``.
    """
    found = False
    if os.path.isdir(emb_dir):
        try:
            listing = os.listdir(emb_dir)
        except FileNotFoundError:
            # The directory went away between the two filesystem calls.
            listing = ()
        for item in listing:
            if item.endswith(".pt"):
                found = True
                break
    return found

def main(name: str, dataset: str, max_len: int, batch_size: int, epochs: int):
    """Train a multimodal (DNA + RNA + Protein) model with ClassificationTrainer.

    Loads ``{name}.yml`` through ``load_config``, (re)computes embeddings for
    all three modalities when the filtered CSV or any modality's embeddings
    are missing, builds the data loaders, model and ``ClassificationTrainer``,
    and finally calls ``trainer.train``.

    Args:
        name: Config file name without the ``.yml`` suffix.
        dataset: Dataset identifier; stored under ``config["Dataset"]``.
        max_len: Length threshold; part of the filtered-CSV and embedding
            paths and forwarded to the embedding and loader helpers.
        batch_size: Batch size forwarded to ``get_multimodal_loaders``.
        epochs: Epoch count forwarded to ``trainer.train``.
    """
    # This cell's import block only brings in RegressionTrainer, so the
    # ClassificationTrainer used below would raise NameError. Import it here
    # to keep the cell self-contained.
    from trainer import ClassificationTrainer

    config = load_config(f"{name}.yml")
    config["Dataset"] = dataset
    config["device"] = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): the unimodal classification entry point in this notebook
    # sets config["task"] / config["num_classes"] before build_model; confirm
    # whether the multimodal config needs the same.

    max_len = int(max_len)

    filtered_csv = f"data/datasets/{config['Dataset']}_multimodal_filtered_maxlen{max_len}.csv"

    dna_dir = f"embeddings/{config['Dataset']}/DNA/maxlen{max_len}"
    rna_dir = f"embeddings/{config['Dataset']}/RNA/maxlen{max_len}"
    prot_dir = f"embeddings/{config['Dataset']}/Protein/maxlen{max_len}"

    # NOTE(review): a single .pt entry per modality counts as "present", so
    # a partially written embedding directory is not detected here.
    need_filtered_csv = not os.path.exists(filtered_csv)
    need_dna = not _has_any_embeddings(dna_dir)
    need_rna = not _has_any_embeddings(rna_dir)
    need_prot = not _has_any_embeddings(prot_dir)
    need_embeddings = need_dna or need_rna or need_prot

    # Anything missing triggers a recomputation for all three modalities.
    if need_filtered_csv or need_embeddings:
        print("Embeddings and/or filtered CSV not found, calculating...")
        for modality in ("DNA", "RNA", "Protein"):
            calculate_embeddings(
                dataset=config["Dataset"],
                modality=modality,
                device=config["device"],
                max_len=max_len,
            )

    print("=" * 60)
    print("Training Configuration (Multimodal):")
    print(f"  Config name: {config.get('name', name)}")
    print(f"  Dataset: {config['Dataset']}")
    print(f"  Max Len (filter): {max_len}")
    print(f"  Fusion: {config.get('fusion_type', 'concat')}")
    print(f"  Batch size: {batch_size}")
    print(f"  Epochs: {epochs}")
    print(f"  Device: {config['device']}")
    print("=" * 60)

    print("\nLoading data...")
    train_loader, val_loader, test_loader = get_multimodal_loaders(
        config["Dataset"],
        batch_size=batch_size,
        max_len=max_len,
    )

    print("\nInitializing model...")
    model = build_model(config)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params
    print(f"Total number of parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {non_trainable_params}")

    print("\nInitializing trainer...")
    trainer = ClassificationTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=config["device"],
        save_dir=f"./plots/{config.get('name', name)}/{config['Dataset']}",
    )

    # Entropy regularisation weight for MIL: prefer the nested
    # config["trainer"]["lam_entropy"], fall back to the flat top-level key.
    lam_entropy = None
    if isinstance(config.get("trainer", None), dict):
        lam_entropy = config["trainer"].get("lam_entropy", None)
    if lam_entropy is None:
        lam_entropy = config.get("lam_entropy", None)

    if lam_entropy is not None:
        trainer.lam_entropy = float(lam_entropy)
        if trainer.lam_entropy > 0:
            print(f"Using MIL entropy regularization: lam_entropy={trainer.lam_entropy}")

    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60 + "\n")

    trainer.train(epochs=epochs)

    print("\n" + "=" * 60)
    print("Training completed!")
    print("=" * 60)


if __name__ == "__main__":
    # Command-line entry point for the multimodal training run. Every option
    # has a default, so the cell also runs with no arguments supplied.
    parser = argparse.ArgumentParser(description="Train multimodal model (DNA+RNA+Protein)")
    parser.add_argument(
        "--name",
        type=str,
        default="fusion_mil",
        help="Config file name (without .yml). Options: fusion_concat, fusion_mil, fusion_xattn",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ecoli_proteins",
        # %(default)s keeps the help text in sync with the actual default
        # (the previous hard-coded text named a different dataset).
        help=(
            "Dataset to use. Default: %(default)s. Options: "
            "'mrna_stability', 'ecoli_proteins', 'cov_vaccine_degradation', 'fungal_expression'"
        ),
    )
    parser.add_argument(
        "--max-len",
        type=int,
        default=1000,
        help="Filter threshold: keep only sequences with raw length <= max_len before embedding/training.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=500,
        help="Number of training epochs.",
    )
    # parse_known_args() tolerates unrecognised argv entries instead of
    # exiting; presumably needed because the notebook kernel passes its own
    # flags -- confirm before switching to parse_args().
    args, _ = parser.parse_known_args()

    main(
        name=args.name,
        dataset=args.dataset,
        max_len=args.max_len,
        batch_size=args.batch_size,
        epochs=args.epochs,
    )

In [None]:
import os
import argparse
import torch

from data.dataloaders import get_multimodal_loaders
from utils.load_config import load_config
from utils.calculate_embeddings import calculate_embeddings
from models.multimodel import build_model
from trainer import RegressionTrainer


def _has_any_embeddings(emb_dir) -> bool:
    """Check ``emb_dir`` for at least one immediate entry named ``*.pt``.

    Yields ``False`` for a path that is not a directory, for a directory
    that is gone by the time it is listed, and for a directory without any
    entry name ending in ``.pt``.
    """
    if not os.path.isdir(emb_dir):
        return False
    try:
        contents = os.listdir(emb_dir)
    except FileNotFoundError:
        # Directory deleted after the isdir() probe.
        return False
    pt_entries = [item for item in contents if item.endswith(".pt")]
    return len(pt_entries) > 0

def main(name: str, dataset: str, max_len: int, batch_size: int, epochs: int):
    """Train a multimodal (DNA + RNA + Protein) model with ClassificationTrainer.

    Loads ``{name}.yml`` through ``load_config``, (re)computes embeddings for
    all three modalities when the filtered CSV or any modality's embeddings
    are missing, builds the data loaders, model and ``ClassificationTrainer``,
    and finally calls ``trainer.train``.

    Args:
        name: Config file name without the ``.yml`` suffix.
        dataset: Dataset identifier; stored under ``config["Dataset"]``.
        max_len: Length threshold; part of the filtered-CSV and embedding
            paths and forwarded to the embedding and loader helpers.
        batch_size: Batch size forwarded to ``get_multimodal_loaders``.
        epochs: Epoch count forwarded to ``trainer.train``.
    """
    # This cell's import block only brings in RegressionTrainer, so the
    # ClassificationTrainer used below would raise NameError. Import it here
    # to keep the cell self-contained.
    from trainer import ClassificationTrainer

    config = load_config(f"{name}.yml")
    config["Dataset"] = dataset
    config["device"] = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): the unimodal classification entry point in this notebook
    # sets config["task"] / config["num_classes"] before build_model; confirm
    # whether the multimodal config needs the same.

    max_len = int(max_len)

    filtered_csv = f"data/datasets/{config['Dataset']}_multimodal_filtered_maxlen{max_len}.csv"

    dna_dir = f"embeddings/{config['Dataset']}/DNA/maxlen{max_len}"
    rna_dir = f"embeddings/{config['Dataset']}/RNA/maxlen{max_len}"
    prot_dir = f"embeddings/{config['Dataset']}/Protein/maxlen{max_len}"

    # NOTE(review): a single .pt entry per modality counts as "present", so
    # a partially written embedding directory is not detected here.
    need_filtered_csv = not os.path.exists(filtered_csv)
    need_dna = not _has_any_embeddings(dna_dir)
    need_rna = not _has_any_embeddings(rna_dir)
    need_prot = not _has_any_embeddings(prot_dir)
    need_embeddings = need_dna or need_rna or need_prot

    # Anything missing triggers a recomputation for all three modalities.
    if need_filtered_csv or need_embeddings:
        print("Embeddings and/or filtered CSV not found, calculating...")
        for modality in ("DNA", "RNA", "Protein"):
            calculate_embeddings(
                dataset=config["Dataset"],
                modality=modality,
                device=config["device"],
                max_len=max_len,
            )

    print("=" * 60)
    print("Training Configuration (Multimodal):")
    print(f"  Config name: {config.get('name', name)}")
    print(f"  Dataset: {config['Dataset']}")
    print(f"  Max Len (filter): {max_len}")
    print(f"  Fusion: {config.get('fusion_type', 'concat')}")
    print(f"  Batch size: {batch_size}")
    print(f"  Epochs: {epochs}")
    print(f"  Device: {config['device']}")
    print("=" * 60)

    print("\nLoading data...")
    train_loader, val_loader, test_loader = get_multimodal_loaders(
        config["Dataset"],
        batch_size=batch_size,
        max_len=max_len,
    )

    print("\nInitializing model...")
    model = build_model(config)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params
    print(f"Total number of parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {non_trainable_params}")

    print("\nInitializing trainer...")
    trainer = ClassificationTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=config["device"],
        save_dir=f"./plots/{config.get('name', name)}/{config['Dataset']}",
    )

    # Entropy regularisation weight for MIL: prefer the nested
    # config["trainer"]["lam_entropy"], fall back to the flat top-level key.
    lam_entropy = None
    if isinstance(config.get("trainer", None), dict):
        lam_entropy = config["trainer"].get("lam_entropy", None)
    if lam_entropy is None:
        lam_entropy = config.get("lam_entropy", None)

    if lam_entropy is not None:
        trainer.lam_entropy = float(lam_entropy)
        if trainer.lam_entropy > 0:
            print(f"Using MIL entropy regularization: lam_entropy={trainer.lam_entropy}")

    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60 + "\n")

    trainer.train(epochs=epochs)

    print("\n" + "=" * 60)
    print("Training completed!")
    print("=" * 60)


if __name__ == "__main__":
    # Command-line entry point for the multimodal training run. Every option
    # has a default, so the cell also runs with no arguments supplied.
    parser = argparse.ArgumentParser(description="Train multimodal model (DNA+RNA+Protein)")
    parser.add_argument(
        "--name",
        type=str,
        default="fusion_xattn",
        help="Config file name (without .yml). Options: fusion_concat, fusion_mil, fusion_xattn",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ecoli_proteins",
        # %(default)s keeps the help text in sync with the actual default
        # (the previous hard-coded text named a different dataset).
        help=(
            "Dataset to use. Default: %(default)s. Options: "
            "'mrna_stability', 'ecoli_proteins', 'cov_vaccine_degradation', 'fungal_expression'"
        ),
    )
    parser.add_argument(
        "--max-len",
        type=int,
        default=1000,
        help="Filter threshold: keep only sequences with raw length <= max_len before embedding/training.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=500,
        help="Number of training epochs.",
    )
    # parse_known_args() tolerates unrecognised argv entries instead of
    # exiting; presumably needed because the notebook kernel passes its own
    # flags -- confirm before switching to parse_args().
    args, _ = parser.parse_known_args()

    main(
        name=args.name,
        dataset=args.dataset,
        max_len=args.max_len,
        batch_size=args.batch_size,
        epochs=args.epochs,
    )

# Uni Models:

In [1]:
import torch
import argparse
import os
import pandas as pd
import shutil

from data.dataloaders import get_loaders
from data.subsampler.subsample import subsample_loader
from utils.load_config import load_config
from utils.calculate_embeddings import calculate_embeddings
from models.unimodel import build_model
from trainer import RegressionTrainer, ClassificationTrainer


def main(name, dataset, max_len, batch_size=32, epochs=500):
    """Train a unimodal 3-class classification model on precomputed embeddings.

    Loads ``{name}.yml``, forces the task to classification with three
    classes, recomputes embeddings when they are missing or out of sync with
    the filtered CSV, then builds the data loaders and model and runs
    ``ClassificationTrainer``.

    Args:
        name: Config file name without the ``.yml`` extension.
        dataset: Dataset identifier; stored in ``config["Dataset"]`` and used
            to build the CSV and embedding paths.
        max_len: Length filter threshold; coerced with ``int()``.
        batch_size: Batch size passed to ``get_loaders``. Defaults to 32,
            the value that used to be hard-coded.
        epochs: Number of epochs passed to ``trainer.train``. Defaults to
            500, the value that used to be hard-coded.
    """
    config = load_config(f"{name}.yml")
    config["task"] = "classification"
    config["num_classes"] = 3
    config["Dataset"] = dataset
    config["device"] = "cuda" if torch.cuda.is_available() else "cpu"

    max_len = int(max_len)

    # Paths
    filtered_csv = f"data/datasets/{config['Dataset']}_multimodal_filtered_maxlen{max_len}.csv"
    emb_dir = f"embeddings/{config['Dataset']}/{config['modality']}/maxlen{max_len}"

    need_filtered_csv = not os.path.exists(filtered_csv)
    need_embeddings = not os.path.exists(emb_dir)

    # If the CSV has to be regenerated, existing embeddings are treated as
    # invalid and removed; the recompute below is triggered by the missing CSV.
    if need_filtered_csv and os.path.exists(emb_dir):
        print("Filtered CSV missing → removing stale embeddings")
        shutil.rmtree(emb_dir)

    # If both exist, make sure every id listed in the CSV has an embedding
    # file; otherwise the embedding directory is stale and is rebuilt.
    if not need_filtered_csv and os.path.exists(emb_dir):
        df = pd.read_csv(filtered_csv)
        expected_ids = set(df["id"].astype(str))
        suffix = ".pt"
        # Strip only the trailing suffix so ids that themselves contain
        # ".pt" are not mangled.
        existing_ids = {
            fname[: -len(suffix)]
            for fname in os.listdir(emb_dir)
            if fname.endswith(suffix)
        }

        if not expected_ids.issubset(existing_ids):
            print("Embedding mismatch detected → removing stale embeddings")
            shutil.rmtree(emb_dir)
            need_embeddings = True

    if need_embeddings or need_filtered_csv:
        print("Recomputing embeddings...")
        calculate_embeddings(
            dataset=config["Dataset"],
            modality=config["modality"],
            device=config["device"],
            max_len=max_len,
        )

    print("=" * 60)
    print("Training Configuration:")
    print(f"  Name: {config['name']}")
    print(f"  Dataset: {config['Dataset']}")
    print(f"  Modality: {config['modality']}")
    print(f"  Max Len (filter): {max_len}")
    print("=" * 60)

    print("\nLoading data...")
    train_loader, val_loader, test_loader = get_loaders(
        config["Dataset"],
        batch_size,
        modality=config["modality"],
        max_len=max_len,
    )

    # Split sizes before subsampling.
    print(
        len(train_loader.dataset),
        len(val_loader.dataset),
        len(test_loader.dataset)
    )

    train_loader = subsample_loader(train_loader, fraction=1)
    val_loader = subsample_loader(val_loader, fraction=1)
    test_loader = subsample_loader(test_loader, fraction=1)

    # Split sizes after subsampling.
    print(
        len(train_loader.dataset),
        len(val_loader.dataset),
        len(test_loader.dataset)
    )

    print("\nInitializing model...")
    model = build_model(config)

    print(model)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params
    print(f"Total number of parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {non_trainable_params}")

    print("\nInitializing trainer...")

    trainer = ClassificationTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=config["device"],
        save_dir=f"./plots/{config['name']}/{config['Dataset']}",
    )

    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60 + "\n")

    trainer.train(epochs=epochs)

    print("\n" + "=" * 60)
    print("Training completed!")
    print("=" * 60)


if __name__ == "__main__":
    # Command-line entry point for unimodal training.
    parser = argparse.ArgumentParser(description="Train unimodal model")
    parser.add_argument(
        "--name",
        type=str,
        default="uni_rna",
        help="Name of config file (without .yml extension) to use. Default: uni_rna",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ecoli_proteins",
        # %(default)s is expanded by argparse, so the help text always
        # reports the real default (it previously claimed a different one).
        help=(
            "Dataset to use. Default: %(default)s. Options: "
            "'mrna_stability', 'ecoli_proteins', 'cov_vaccine_degradation', 'fungal_expression'"
        ),
    )
    parser.add_argument(
        "--max-len",
        type=int,
        default=1024,
        help="Filter threshold: keep only sequences with raw length <= max_len before embedding/training.",
    )
    # parse_known_args ignores unrecognized arguments instead of exiting
    # (presumably needed because the notebook kernel adds its own argv).
    args, _ = parser.parse_known_args()
    main(args.name, args.dataset, args.max_len)

Training Configuration:
  Name: uni_rna
  Dataset: ecoli_proteins
  Modality: RNA
  Max Len (filter): 1024

Loading data...
3073 742 721
3073 742 721

Initializing model...
UnimodalClassificationModel(
  (net): TextCNNHead(
    (project): Linear(in_features=640, out_features=640, bias=True)
    (convs): ModuleList(
      (0): Conv1d(640, 100, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): Conv1d(640, 100, kernel_size=(4,), stride=(1,), padding=(2,))
      (2): Conv1d(640, 100, kernel_size=(5,), stride=(1,), padding=(2,))
    )
    (dropout): Dropout(p=0.2, inplace=False)
    (fc): Linear(in_features=300, out_features=3, bias=True)
  )
)
Total number of parameters: 1179443
Trainable parameters: 1179443
Non-trainable parameters: 0

Initializing trainer...

Starting training...

Starting classification training for 500 epochs
Device: cpu
Classes: 3


Training:  10%|███▍                             | 10/97 [00:15<02:12,  1.52s/it]


KeyboardInterrupt: 

In [1]:
import torch
import argparse
import os
import pandas as pd
import shutil

from data.dataloaders import get_loaders
from data.subsampler.subsample import subsample_loader
from utils.load_config import load_config
from utils.calculate_embeddings import calculate_embeddings
from models.unimodel import build_model
from trainer import RegressionTrainer, ClassificationTrainer


def main(name, dataset, max_len, batch_size=32, epochs=500):
    """Train a unimodal regression model on precomputed embeddings.

    Loads ``{name}.yml``, forces the task to regression, recomputes
    embeddings when they are missing or out of sync with the filtered CSV,
    then builds the data loaders and model and runs ``RegressionTrainer``.

    Args:
        name: Config file name without the ``.yml`` extension.
        dataset: Dataset identifier; stored in ``config["Dataset"]`` and used
            to build the CSV and embedding paths.
        max_len: Length filter threshold; coerced with ``int()``.
        batch_size: Batch size passed to ``get_loaders``. Defaults to 32,
            the value that used to be hard-coded.
        epochs: Number of epochs passed to ``trainer.train``. Defaults to
            500, the value that used to be hard-coded.
    """
    config = load_config(f"{name}.yml")
    config["task"] = "regression"
    # NOTE(review): the recorded run of this cell fails inside build_model
    # with "TypeError: __init__() got an unexpected keyword argument
    # 'num_classes'". The cause is not visible from here; check whether the
    # YAML config or build_model still forwards num_classes for regression.
    config["Dataset"] = dataset
    config["device"] = "cuda" if torch.cuda.is_available() else "cpu"

    max_len = int(max_len)

    # Paths
    filtered_csv = f"data/datasets/{config['Dataset']}_multimodal_filtered_maxlen{max_len}.csv"
    emb_dir = f"embeddings/{config['Dataset']}/{config['modality']}/maxlen{max_len}"

    need_filtered_csv = not os.path.exists(filtered_csv)
    need_embeddings = not os.path.exists(emb_dir)

    # If the CSV has to be regenerated, existing embeddings are treated as
    # invalid and removed; the recompute below is triggered by the missing CSV.
    if need_filtered_csv and os.path.exists(emb_dir):
        print("Filtered CSV missing → removing stale embeddings")
        shutil.rmtree(emb_dir)

    # If both exist, make sure every id listed in the CSV has an embedding
    # file; otherwise the embedding directory is stale and is rebuilt.
    if not need_filtered_csv and os.path.exists(emb_dir):
        df = pd.read_csv(filtered_csv)
        expected_ids = set(df["id"].astype(str))
        suffix = ".pt"
        # Strip only the trailing suffix so ids that themselves contain
        # ".pt" are not mangled.
        existing_ids = {
            fname[: -len(suffix)]
            for fname in os.listdir(emb_dir)
            if fname.endswith(suffix)
        }

        if not expected_ids.issubset(existing_ids):
            print("Embedding mismatch detected → removing stale embeddings")
            shutil.rmtree(emb_dir)
            need_embeddings = True

    if need_embeddings or need_filtered_csv:
        print("Recomputing embeddings...")
        calculate_embeddings(
            dataset=config["Dataset"],
            modality=config["modality"],
            device=config["device"],
            max_len=max_len,
        )

    print("=" * 60)
    print("Training Configuration:")
    print(f"  Name: {config['name']}")
    print(f"  Dataset: {config['Dataset']}")
    print(f"  Modality: {config['modality']}")
    print(f"  Max Len (filter): {max_len}")
    print("=" * 60)

    print("\nLoading data...")
    train_loader, val_loader, test_loader = get_loaders(
        config["Dataset"],
        batch_size,
        modality=config["modality"],
        max_len=max_len,
    )

    # Split sizes before subsampling.
    print(
        len(train_loader.dataset),
        len(val_loader.dataset),
        len(test_loader.dataset)
    )

    train_loader = subsample_loader(train_loader, fraction=1)
    val_loader = subsample_loader(val_loader, fraction=1)
    test_loader = subsample_loader(test_loader, fraction=1)

    # Split sizes after subsampling.
    print(
        len(train_loader.dataset),
        len(val_loader.dataset),
        len(test_loader.dataset)
    )

    print("\nInitializing model...")
    model = build_model(config)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params
    print(f"Total number of parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {non_trainable_params}")

    print("\nInitializing trainer...")
    trainer = RegressionTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=config["device"],
        save_dir=f"./plots/{config['name']}/{config['Dataset']}",
    )

    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60 + "\n")

    trainer.train(epochs=epochs)

    print("\n" + "=" * 60)
    print("Training completed!")
    print("=" * 60)


if __name__ == "__main__":
    # Command-line entry point for unimodal training. Each option is listed
    # once as (flag, type, default, help) and registered in a loop.
    cli = argparse.ArgumentParser(description="Train unimodal model")
    option_specs = (
        (
            "--name",
            str,
            "uni_rna",
            "Name of config file (without .yml extension) to use. Default: uni_rna",
        ),
        (
            "--dataset",
            str,
            "fungal_expression",
            "Dataset to use. Default: fungal_expression. Options: "
            "'mrna_stability', 'ecoli_proteins', 'cov_vaccine_degradation', 'fungal_expression'",
        ),
        (
            "--max-len",
            int,
            1024,
            "Filter threshold: keep only sequences with raw length <= max_len before embedding/training.",
        ),
    )
    for flag, value_type, fallback, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)

    # parse_known_args returns (namespace, leftovers); unrecognized arguments
    # are discarded rather than treated as an error.
    parsed, _unused = cli.parse_known_args()
    main(parsed.name, parsed.dataset, parsed.max_len)

Training Configuration:
  Name: uni_rna
  Dataset: fungal_expression
  Modality: RNA
  Max Len (filter): 1024

Loading data...
2232 566 445
2232 566 445

Initializing model...


TypeError: __init__() got an unexpected keyword argument 'num_classes'