In [1]:
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
sys.path.append(project_root)

print(project_root)

/data/data3/junibg-ego/Modelo_leo_coi


In [2]:
import torch
print(f"¿GPU disponible? {torch.cuda.is_available()}")
print(f"Número de GPUs: {torch.cuda.device_count()}")
print(f"Nombre GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU'}")


¿GPU disponible? True
Número de GPUs: 2
Nombre GPU: NVIDIA GeForce RTX 3090


In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import json

In [4]:
from src.utils.load_fastaDataset import *
from torch.utils.data import DataLoader

In [5]:
def load_hierarchy_from_json(json_path):
    """Carga la jerarquía taxonómica desde JSON"""
    with open(json_path, 'r') as f:
        hierarchy_raw = json.load(f)
    
    # Convertir claves string a int
    hierarchy = {}
    for child_taxon, parent_dict in hierarchy_raw.items():
        hierarchy[child_taxon] = {}
        for parent_key, children_list in parent_dict.items():
            parent_int = int(float(parent_key))
            children_int = [int(c) for c in children_list]
            hierarchy[child_taxon][parent_int] = children_int
    
    print("✅ Jerarquía cargada desde JSON")
    for taxon, mapping in hierarchy.items():
        n_parents = len(mapping)
        n_children = sum(len(v) for v in mapping.values())
        print(f"  {taxon:10s}: {n_parents:4d} padres → {n_children:5d} hijos")
    
    return hierarchy

In [6]:
hierarchy_path = os.path.join(project_root, "src", "data", "taxonomy_hierarchy_fixed_with_class.json")
hierarchy = load_hierarchy_from_json(hierarchy_path)

✅ Jerarquía cargada desde JSON
  class     :   49 padres →   187 hijos
  order     :  173 padres →   831 hijos
  family    :  797 padres →  5446 hijos
  genus     : 5393 padres → 50568 hijos
  species   : 50510 padres → 205075 hijos


In [7]:
df = os.path.join(project_root, "src", "data", "all_taxa_numeric.csv")
df = pd.read_csv(df)

In [8]:
taxon_order = ['phylum', 'class','order', 'family', 'genus', 'species']
total_classes = {}
for taxon in taxon_order:
    n_classes = df[taxon].nunique()
    total_classes[taxon] = n_classes
    print(f"  {taxon:10s}: {n_classes:6d} clases")

  phylum    :     49 clases
  class     :    173 clases
  order     :    797 clases
  family    :   5393 clases
  genus     :  50510 clases
  species   : 205075 clases


In [9]:
from src.basic_classifier.build_classifier import *

number_of_classes = []
for class_name, num_classes in total_classes.items():
    number_of_classes.append(num_classes)

print(number_of_classes)

model_configuration = get_model_config(number_of_classes, project_root)
classifier = build_model(model_configuration)
print(classifier)

[49, 173, 797, 5393, 50510, 205075]


Some weights of BertModel were not initialized from the model checkpoint at /data/data3/junibg-ego/Modelo_leo_coi/src/data/archives and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DNAClassifier(
  (embedder): DNABERTEmbedder(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(4096, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertUnpadAttention(
              (self): BertUnpadSelfAttention(
                (dropout): Dropout(p=0.0, inplace=False)
                (Wqkv): Linear(in_features=768, out_features=2304, bias=True)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (mlp): BertGatedLinearUnitML

In [10]:
from sklearn.model_selection import train_test_split

# Primero separar test (20%)
df_temp, df_test = train_test_split(
    df, test_size=0.2, random_state=42, stratify=df['phylum']
)

# Luego separar train/val (80/20 del 80% restante = 64/16 del total)
df_train, df_val = train_test_split(
    df_temp, test_size=0.2, random_state=42, stratify=df_temp['phylum']
)

In [11]:
val_dataset = MultiTaxaFastaDataset(
    df_val.reset_index(drop=True),
    max_length=750,
    taxon_cols=taxon_order
)

training_dataset = MultiTaxaFastaDataset(
    df_train.reset_index(drop=True),
    max_length=750,
    taxon_cols=taxon_order
)

In [12]:
def collate_multitask(batch, taxon_cols=['phylum', 'class','order','family','genus','species'], max_length=900):
    sequences, labels_dict_list, recon_targets_list, true_tokens_list = zip(*batch)

    # Labels: dict de tensors
    labels_dict = {taxon: torch.stack([d[taxon] for d in labels_dict_list]) for taxon in taxon_cols}

    # Recon targets: dict de tensors
    recon_targets_dict = {taxon: torch.stack([d[taxon] for d in recon_targets_list]) for taxon in taxon_cols}

    # True tokens
    true_tokens = torch.stack(true_tokens_list)

    return sequences, labels_dict

In [13]:
from src.basic_classifier.train_classifier import train_basic_classifier

training_config = {
        "batch_size" : 16,
        "num_epochs" : 15,
        "lr" : 5e-4,
        "weight_decay": 1e-4,  # Add regularization
        "patience": 5,  # More patience for large class counts
        "label_smoothing": 0.08,  # Add for better generalization
        "gradient_clip": 1.0,  # Prevent gradient explosion
        "warmup_epochs": 2,  # Gradual LR warmup
        "lr_schedule": "cosine",  # Cosine annealing
        "accumulation_steps": 2,  # Gradient accumulation for effective batch_size=128
    }

# Optimizer config
optimizer_config = {
    "type": "AdamW",
    "lr": 5e-4,
    "betas": (0.9, 0.999),
    "weight_decay": 1e-4,
    "eps": 1e-8
}   


In [14]:
val_loader = DataLoader(
    val_dataset,
    batch_size=training_config["batch_size"],
    shuffle=False,
    drop_last=True,
    collate_fn=lambda b: collate_multitask(b, taxon_cols=val_dataset.taxon_cols, max_length=val_dataset.max_length),
    num_workers=6
)

training_loader = DataLoader(
    training_dataset,
    batch_size=training_config["batch_size"],
    shuffle=True,
    drop_last=True,
    collate_fn=lambda b: collate_multitask(b, taxon_cols=val_dataset.taxon_cols, max_length=val_dataset.max_length),
    num_workers=6
)

In [17]:
from datetime import datetime

os.environ["TOKENIZERS_PARALLELISM"] = "false"

classifier.to("cuda")

max_length = model_configuration["config_embedder"]["max_length"]
best_val_acc, best_model_state, history, last_improved = train_basic_classifier(classifier, training_loader, val_loader, training_config, optimizer_config)

date = datetime.now().strftime("%Y%m%d")
torch.save(best_model_state, f"best_model_ddp_{date}.pt")

best_model_train_metrics = history["train"][last_improved]
best_model_val_metrics = history["val"][last_improved]

print("\nBEST MODEL")
print(f"  Train - Loss: {best_model_train_metrics['loss_avg']:.4f}, Rank Loss: {best_model_train_metrics['loss_rank_avg']}")
print(f"          Top-1: {best_model_train_metrics['top1_acc']:.4f}, Top-5: {best_model_train_metrics['top5_acc']:.4f}")
print(f"          Rank Top-1: {best_model_train_metrics['top1_rank_acc']}, Rank Top-5: {best_model_train_metrics['top5_rank_acc']}")
print(f"  Val   - Loss: {best_model_val_metrics['loss_avg']:.4f}, Rank Loss: {best_model_val_metrics['loss_rank_avg']}")
print(f"          Top-1: {best_model_val_metrics['top1_acc']:.4f}, Top-5: {best_model_val_metrics['top5_acc']:.4f}")
print(f"          Rank Top-1: {best_model_val_metrics['top1_rank_acc']}, Rank Top-5: {best_model_val_metrics['top5_rank_acc']}")



Starting training...


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [1/15]
  Train - Loss: 0.0007, Rank Loss: [1.46898676e-05 4.24142832e-05 7.06177153e-05 1.46783521e-04
 1.90667467e-04 2.12816267e-04]
          Top-1: 0.2917, Top-5: 0.4062
          Rank Top-1: [0.875 0.625 0.25  0.    0.    0.   ], Rank Top-5: [1.     0.75   0.625  0.0625 0.     0.    ]
  Val   - Loss: 0.0028, Rank Loss: [8.27700855e-05 2.19678925e-04 3.12766800e-04 5.84616064e-04
 7.54985396e-04 8.36123969e-04]
          Top-1: 0.2188, Top-5: 0.3021
          Rank Top-1: [0.875  0.375  0.0625 0.     0.     0.    ], Rank Top-5: [0.9375 0.4375 0.375  0.0625 0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [2/15]
  Train - Loss: 0.0007, Rank Loss: [1.80021078e-05 4.56371019e-05 8.20634202e-05 1.41837057e-04
 1.87847042e-04 2.13746910e-04]
          Top-1: 0.2500, Top-5: 0.3750
          Rank Top-1: [0.875 0.5   0.125 0.    0.    0.   ], Rank Top-5: [1.   0.75 0.5  0.   0.   0.  ]
  Val   - Loss: 0.0028, Rank Loss: [0.000128   0.00019858 0.00030489 0.00057946 0.00075319 0.00083469]
          Top-1: 0.1667, Top-5: 0.3229
          Rank Top-1: [0.4375 0.375  0.1875 0.     0.     0.    ], Rank Top-5: [0.9375 0.5625 0.375  0.0625 0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [3/15]
  Train - Loss: 0.0007, Rank Loss: [3.05461011e-05 3.83575667e-05 6.35747463e-05 1.37458121e-04
 1.90926987e-04 2.08955523e-04]
          Top-1: 0.2292, Top-5: 0.4167
          Rank Top-1: [0.375 0.75  0.25  0.    0.    0.   ], Rank Top-5: [0.9375 0.875  0.6875 0.     0.     0.    ]
  Val   - Loss: 0.0028, Rank Loss: [9.95775147e-05 1.94830970e-04 3.34486185e-04 5.74254943e-04
 7.47431949e-04 8.32547614e-04]
          Top-1: 0.1875, Top-5: 0.3125
          Rank Top-1: [0.5625 0.375  0.1875 0.     0.     0.    ], Rank Top-5: [0.9375 0.5    0.375  0.0625 0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [4/15]
  Train - Loss: 0.0007, Rank Loss: [2.27386954e-05 4.31874069e-05 8.81233506e-05 1.42116723e-04
 1.83933566e-04 2.19875094e-04]
          Top-1: 0.2500, Top-5: 0.3542
          Rank Top-1: [0.8125 0.625  0.0625 0.     0.     0.    ], Rank Top-5: [0.9375 0.75   0.4375 0.     0.     0.    ]
  Val   - Loss: 0.0028, Rank Loss: [0.00014401 0.00019689 0.00030483 0.00057297 0.00073839 0.00083079]
          Top-1: 0.1875, Top-5: 0.3229
          Rank Top-1: [0.5625 0.375  0.1875 0.     0.     0.    ], Rank Top-5: [0.9375 0.625  0.375  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [5/15]
  Train - Loss: 0.0007, Rank Loss: [1.61002642e-05 3.32229305e-05 6.90416211e-05 1.41825001e-04
 1.83904962e-04 2.13009256e-04]
          Top-1: 0.2917, Top-5: 0.4062
          Rank Top-1: [0.875 0.75  0.125 0.    0.    0.   ], Rank Top-5: [1.     0.8125 0.625  0.     0.     0.    ]
  Val   - Loss: 0.0028, Rank Loss: [0.0001544  0.00020954 0.00030752 0.00057121 0.00073391 0.00082723]
          Top-1: 0.1771, Top-5: 0.3229
          Rank Top-1: [0.5625 0.375  0.125  0.     0.     0.    ], Rank Top-5: [0.875  0.6875 0.375  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [6/15]
  Train - Loss: 0.0007, Rank Loss: [3.08509898e-05 5.08057102e-05 7.64946754e-05 1.45457593e-04
 1.90362378e-04 2.16878676e-04]
          Top-1: 0.2188, Top-5: 0.3542
          Rank Top-1: [0.75   0.4375 0.125  0.     0.     0.    ], Rank Top-5: [0.8125 0.75   0.5625 0.     0.     0.    ]
  Val   - Loss: 0.0028, Rank Loss: [0.00013824 0.00020423 0.00031383 0.00056989 0.00073192 0.00082657]
          Top-1: 0.1771, Top-5: 0.3333
          Rank Top-1: [0.5625 0.375  0.125  0.     0.     0.    ], Rank Top-5: [0.9375 0.6875 0.375  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [7/15]
  Train - Loss: 0.0007, Rank Loss: [3.99185562e-05 3.14477290e-05 9.03443404e-05 1.42007941e-04
 1.96540869e-04 2.12576043e-04]
          Top-1: 0.1979, Top-5: 0.3333
          Rank Top-1: [0.4375 0.6875 0.0625 0.     0.     0.    ], Rank Top-5: [0.875  0.875  0.1875 0.0625 0.     0.    ]
  Val   - Loss: 0.0028, Rank Loss: [0.00011401 0.00021612 0.00032238 0.00056853 0.00073126 0.00082994]
          Top-1: 0.1979, Top-5: 0.3229
          Rank Top-1: [0.625  0.4375 0.125  0.     0.     0.    ], Rank Top-5: [0.9375 0.6875 0.3125 0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [8/15]
  Train - Loss: 0.0007, Rank Loss: [3.86945867e-05 5.93720149e-05 9.03109742e-05 1.50916167e-04
 1.84989662e-04 2.13868875e-04]
          Top-1: 0.1562, Top-5: 0.2292
          Rank Top-1: [0.5    0.4375 0.     0.     0.     0.    ], Rank Top-5: [0.6875 0.5    0.1875 0.     0.     0.    ]
  Val   - Loss: 0.0028, Rank Loss: [0.00011245 0.00021746 0.00032657 0.00056858 0.00072849 0.00083243]
          Top-1: 0.1979, Top-5: 0.2917
          Rank Top-1: [0.8125 0.375  0.     0.     0.     0.    ], Rank Top-5: [0.9375 0.6875 0.125  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [9/15]
  Train - Loss: 0.0007, Rank Loss: [2.22539238e-05 4.81149301e-05 8.86925759e-05 1.42252580e-04
 1.91674814e-04 2.12464575e-04]
          Top-1: 0.2292, Top-5: 0.3542
          Rank Top-1: [0.75   0.375  0.1875 0.0625 0.     0.    ], Rank Top-5: [1.     0.6875 0.375  0.0625 0.     0.    ]
  Val   - Loss: 0.0028, Rank Loss: [0.00010217 0.00020979 0.00032828 0.00056775 0.00072648 0.00083483]
          Top-1: 0.2188, Top-5: 0.3125
          Rank Top-1: [0.8125 0.5    0.     0.     0.     0.    ], Rank Top-5: [0.9375 0.625  0.3125 0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [10/15]
  Train - Loss: 0.0007, Rank Loss: [2.93261374e-05 3.66565903e-05 7.75575020e-05 1.42158293e-04
 1.72052479e-04 2.15045655e-04]
          Top-1: 0.2396, Top-5: 0.3750
          Rank Top-1: [0.75   0.6875 0.     0.     0.     0.    ], Rank Top-5: [0.875  0.8125 0.5    0.0625 0.     0.    ]
  Val   - Loss: 0.0028, Rank Loss: [9.21966035e-05 2.01976596e-04 3.26059952e-04 5.67122917e-04
 7.26307739e-04 8.36716674e-04]
          Top-1: 0.2083, Top-5: 0.3125
          Rank Top-1: [0.8125 0.4375 0.     0.     0.     0.    ], Rank Top-5: [0.9375 0.625  0.3125 0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [11/15]
  Train - Loss: 0.0007, Rank Loss: [2.59860892e-05 3.95495786e-05 7.95586976e-05 1.45046608e-04
 1.92561988e-04 2.14376416e-04]
          Top-1: 0.2500, Top-5: 0.3542
          Rank Top-1: [0.6875 0.6875 0.125  0.     0.     0.    ], Rank Top-5: [0.875 0.75  0.5   0.    0.    0.   ]
  Val   - Loss: 0.0027, Rank Loss: [8.45839114e-05 1.98974776e-04 3.22381776e-04 5.65657364e-04
 7.24915294e-04 8.38909970e-04]
          Top-1: 0.1979, Top-5: 0.3229
          Rank Top-1: [0.8125 0.375  0.     0.     0.     0.    ], Rank Top-5: [0.9375 0.625  0.375  0.     0.     0.    ]
Early stopping.


RuntimeError: Parent directory models does not exist.