In [1]:
import sys
import os

# Project root = the directory two levels above the current working directory.
# It is added to sys.path so the `src.*` packages imported further down can be
# found (assumes the notebook is launched from a folder two levels deep).
project_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))

# Guard the insertion: re-running this cell must not keep appending the same
# entry to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

print(project_root)

/data/data3/junibg-ego/Modelo_leo_coi


In [2]:
import torch
# Report CUDA availability, the number of visible devices and the name of
# device 0 (or a placeholder when CUDA is unavailable) in a single call.
print(
    f"¿GPU disponible? {torch.cuda.is_available()}",
    f"Número de GPUs: {torch.cuda.device_count()}",
    f"Nombre GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU'}",
    sep="\n",
)


¿GPU disponible? True
Número de GPUs: 2
Nombre GPU: NVIDIA GeForce RTX 3090


In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import json

In [4]:
from src.utils.load_fastaDataset import *
from torch.utils.data import DataLoader

In [5]:
def _to_int(value):
    """Coerce a JSON key or list element to ``int``.

    Accepts ints, floats and numeric strings, including float-formatted
    strings such as ``"12.0"``. Values that are valid integer literals are
    converted directly, so they never take a lossy detour through ``float``.
    """
    try:
        return int(value)
    except ValueError:
        # e.g. "12.0": not an integer literal, but a valid float literal.
        return int(float(value))


def load_hierarchy_from_json(json_path):
    """Load the taxonomic hierarchy from a JSON file.

    The file is expected to hold a two-level mapping
    ``{taxon_name: {parent_id: [child_id, ...]}}``. JSON object keys are
    always strings, so parent ids are converted back to ``int``; child ids
    are converted with the same rule so both sides are parsed consistently.

    Args:
        json_path: Path to the JSON file.

    Returns:
        dict: ``{taxon_name: {parent_id (int): [child_id (int), ...]}}``.

    Raises:
        ValueError: If a parent key or child entry is not numeric.

    Side effects:
        Prints a summary line followed by one line per taxon with the number
        of parents and the total number of children.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(json_path, 'r', encoding='utf-8') as f:
        hierarchy_raw = json.load(f)

    # Convert string keys (and child ids) to int.
    hierarchy = {
        child_taxon: {
            _to_int(parent_key): [_to_int(c) for c in children_list]
            for parent_key, children_list in parent_dict.items()
        }
        for child_taxon, parent_dict in hierarchy_raw.items()
    }

    print("✅ Jerarquía cargada desde JSON")
    for taxon, mapping in hierarchy.items():
        n_parents = len(mapping)
        n_children = sum(len(v) for v in mapping.values())
        print(f"  {taxon:10s}: {n_parents:4d} padres → {n_children:5d} hijos")

    return hierarchy

In [6]:
# Locate the taxonomy hierarchy JSON under <project_root>/src/data and load it
# with the helper defined in the previous cell.
hierarchy_path = os.path.join(
    project_root,
    "src",
    "data",
    "taxonomy_hierarchy_fixed_with_class.json",
)
hierarchy = load_hierarchy_from_json(json_path=hierarchy_path)

✅ Jerarquía cargada desde JSON
  class     :   49 padres →   187 hijos
  order     :  173 padres →   831 hijos
  family    :  797 padres →  5446 hijos
  genus     : 5393 padres → 50568 hijos
  species   : 50510 padres → 205075 hijos


In [7]:
# Load the taxa table stored at <project_root>/src/data/all_taxa_numeric.csv.
# The path is built inline so `df` only ever refers to the DataFrame.
df = pd.read_csv(
    os.path.join(project_root, "src", "data", "all_taxa_numeric.csv")
)

In [8]:
# Taxonomic ranks, listed from the coarsest (phylum) to the finest (species).
taxon_order = ['phylum', 'class','order', 'family', 'genus', 'species']

# Number of distinct labels per rank, keyed in the same order as `taxon_order`.
total_classes = {rank: df[rank].nunique() for rank in taxon_order}

for taxon, n_classes in total_classes.items():
    print(f"  {taxon:10s}: {n_classes:6d} clases")

  phylum    :     49 clases
  class     :    173 clases
  order     :    797 clases
  family    :   5393 clases
  genus     :  50510 clases
  species   : 205075 clases


In [9]:
from src.basic_classifier.build_classifier import *

# Per-rank class counts as a plain list, in the insertion order of
# `total_classes`.
number_of_classes = list(total_classes.values())
print(number_of_classes)

# Build the model configuration from the class counts and the project root,
# then instantiate the classifier from that configuration.
model_configuration = get_model_config(number_of_classes, project_root)
classifier = build_model(model_configuration)

print(classifier)

[49, 173, 797, 5393, 50510, 205075]


Some weights of BertModel were not initialized from the model checkpoint at /data/data3/junibg-ego/Modelo_leo_coi/src/data/archives and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DNAClassifier(
  (embedder): DNABERTEmbedder(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(4096, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertUnpadAttention(
              (self): BertUnpadSelfAttention(
                (dropout): Dropout(p=0.0, inplace=False)
                (Wqkv): Linear(in_features=768, out_features=2304, bias=True)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (mlp): BertGatedLinearUnitML

In [10]:
from sklearn.model_selection import train_test_split

# Step 1: hold out 20% of all rows as the test split, stratified by phylum.
df_temp, df_test = train_test_split(
    df,
    stratify=df["phylum"],
    test_size=0.2,
    random_state=42,
)

# Step 2: split the remaining 80% again 80/20 into train and validation,
# i.e. 64% / 16% of the full table, again stratified by phylum.
df_train, df_val = train_test_split(
    df_temp,
    stratify=df_temp["phylum"],
    test_size=0.2,
    random_state=42,
)

In [11]:
# Wrap the validation and training frames (in that order) in
# MultiTaxaFastaDataset objects. The index is reset so each dataset sees a
# contiguous 0..n-1 index after the split.
# NOTE(review): max_length here is 750 while collate_multitask defaults to 900;
# confirm which value is intended.
val_dataset, training_dataset = (
    MultiTaxaFastaDataset(
        split_frame.reset_index(drop=True),
        max_length=750,
        taxon_cols=taxon_order,
    )
    for split_frame in (df_val, df_train)
)

In [12]:
def collate_multitask(batch, taxon_cols=('phylum', 'class', 'order', 'family', 'genus', 'species'), max_length=900):
    """Collate dataset items into ``(sequences, labels_dict)``.

    Each item in ``batch`` is a tuple whose first element is the sequence and
    whose second element is a dict mapping taxon name to a label tensor. Any
    further elements (reconstruction targets, true tokens) are ignored: they
    are not part of the return value, so they are no longer stacked. Stacking
    them cost time on every batch and raised on ragged or missing fields.

    Args:
        batch: Non-empty sequence of items as described above.
        taxon_cols: Taxon names to collect labels for. An immutable tuple is
            used as the default instead of a shared mutable list.
        max_length: Unused; kept so existing callers that pass it keep working.

    Returns:
        tuple: ``(sequences, labels_dict)`` where ``sequences`` is a tuple with
        one entry per item and ``labels_dict[taxon]`` is the per-item label
        tensors stacked along a new leading dimension.
    """
    # Only the first two fields are consumed; tolerate any number of extras.
    sequences, labels_dict_list, *_ = zip(*batch)

    labels_dict = {
        taxon: torch.stack([labels[taxon] for labels in labels_dict_list])
        for taxon in taxon_cols
    }

    return sequences, labels_dict

In [13]:
from src.basic_classifier.train_classifier import train_basic_classifier

# Hyper-parameters handed to train_basic_classifier. How each key is consumed
# is defined in that function, which is not visible in this notebook.
training_config = {
    "batch_size": 16,
    "num_epochs": 15,
    "lr": 5e-4,
    "weight_decay": 1e-4,      # regularization
    "patience": 5,             # more patience for large class counts
    "label_smoothing": 0.08,   # for better generalization
    "gradient_clip": 1.0,      # prevent gradient explosion
    "warmup_epochs": 2,        # gradual LR warmup
    "lr_schedule": "cosine",   # cosine annealing
    # Gradient accumulation: with batch_size=16 and 2 steps the effective
    # batch size is 16 * 2 = 32 (an earlier note claiming 128 was wrong).
    "accumulation_steps": 2,
}

# Optimizer config. lr and weight_decay are taken from training_config so the
# two dictionaries cannot silently drift apart.
optimizer_config = {
    "type": "AdamW",
    "lr": training_config["lr"],
    "betas": (0.9, 0.999),
    "weight_decay": training_config["weight_decay"],
    "eps": 1e-8,
}


In [14]:
from functools import partial  # local import keeps this cell self-contained

# Validation loader: fixed order, collate bound to the validation dataset.
# NOTE(review): drop_last=True discards the final partial validation batch, so
# a few validation samples are never scored; kept as in the original because
# the training loop may rely on a constant batch size. Confirm.
val_loader = DataLoader(
    val_dataset,
    batch_size=training_config["batch_size"],
    shuffle=False,
    drop_last=True,
    # partial over a named function instead of a lambda: same call behaviour,
    # and the arguments are bound once, when the loader is created.
    collate_fn=partial(
        collate_multitask,
        taxon_cols=val_dataset.taxon_cols,
        max_length=val_dataset.max_length,
    ),
    num_workers=6,
)

# Training loader: reshuffled every epoch. The collate settings now come from
# training_dataset; they were previously read from val_dataset by copy-paste.
training_loader = DataLoader(
    training_dataset,
    batch_size=training_config["batch_size"],
    shuffle=True,
    drop_last=True,
    collate_fn=partial(
        collate_multitask,
        taxon_cols=training_dataset.taxon_cols,
        max_length=training_dataset.max_length,
    ),
    num_workers=6,
)

In [18]:
from datetime import datetime

# Set TOKENIZERS_PARALLELISM to "false" before training starts (presumably to
# stop the tokenizer from parallelising alongside the DataLoader workers;
# confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Move the model to the GPU.
classifier.to("cuda")

max_length = model_configuration["config_embedder"]["max_length"]

# Run training; the call returns the best validation accuracy, the matching
# model state, the metric history and the index of the last improvement.
(
    best_val_acc,
    best_model_state,
    history,
    last_improved,
) = train_basic_classifier(
    classifier,
    training_loader,
    val_loader,
    training_config,
    optimizer_config,
)

# Save the best state under a name stamped with today's date (YYYYMMDD).
date = datetime.now().strftime("%Y%m%d")
torch.save(best_model_state, f"best_model_ddp_{date}.pt")

# Metrics recorded at the history entry where the model last improved.
best_model_train_metrics = history["train"][last_improved]
best_model_val_metrics = history["val"][last_improved]

# Single print call; joining with newlines yields the same text as one print
# per line.
print(
    "\nBEST MODEL",
    f"  Train - Loss: {best_model_train_metrics['loss_avg']:.4f}, Rank Loss: {best_model_train_metrics['loss_rank_avg']}",
    f"          Top-1: {best_model_train_metrics['top1_acc']:.4f}, Top-5: {best_model_train_metrics['top5_acc']:.4f}",
    f"          Rank Top-1: {best_model_train_metrics['top1_rank_acc']}, Rank Top-5: {best_model_train_metrics['top5_rank_acc']}",
    f"  Val   - Loss: {best_model_val_metrics['loss_avg']:.4f}, Rank Loss: {best_model_val_metrics['loss_rank_avg']}",
    f"          Top-1: {best_model_val_metrics['top1_acc']:.4f}, Top-5: {best_model_val_metrics['top5_acc']:.4f}",
    f"          Rank Top-1: {best_model_val_metrics['top1_rank_acc']}, Rank Top-5: {best_model_val_metrics['top5_rank_acc']}",
    sep="\n",
)



Starting training...


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [1/15]
  Train - Loss: 0.0007, Rank Loss: [1.79589007e-05 3.83539854e-05 8.83349634e-05 1.50031834e-04
 1.79388933e-04 2.09834762e-04]
          Top-1: 0.2604, Top-5: 0.3542
          Rank Top-1: [0.875  0.625  0.0625 0.     0.     0.    ], Rank Top-5: [1.     0.8125 0.3125 0.     0.     0.    ]
  Val   - Loss: 0.0027, Rank Loss: [7.91587271e-05 1.98339328e-04 3.18713974e-04 5.65789625e-04
 7.25041936e-04 8.38722261e-04]
          Top-1: 0.2083, Top-5: 0.3125
          Rank Top-1: [0.875 0.375 0.    0.    0.    0.   ], Rank Top-5: [0.9375 0.5625 0.375  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [2/15]
  Train - Loss: 0.0007, Rank Loss: [3.05824331e-05 4.86370737e-05 8.49233240e-05 1.39326245e-04
 1.77836611e-04 2.01813725e-04]
          Top-1: 0.1979, Top-5: 0.3125
          Rank Top-1: [0.5625 0.4375 0.125  0.0625 0.     0.    ], Rank Top-5: [0.875  0.5625 0.375  0.0625 0.     0.    ]
  Val   - Loss: 0.0028, Rank Loss: [8.21138748e-05 2.13151149e-04 3.21504297e-04 5.68163283e-04
 7.27218993e-04 8.38535725e-04]
          Top-1: 0.1771, Top-5: 0.3125
          Rank Top-1: [0.75   0.25   0.0625 0.     0.     0.    ], Rank Top-5: [0.9375 0.5625 0.375  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [3/15]
  Train - Loss: 0.0007, Rank Loss: [2.49456977e-05 3.92645530e-05 7.61268362e-05 1.41119424e-04
 1.90813405e-04 2.08163655e-04]
          Top-1: 0.2396, Top-5: 0.3750
          Rank Top-1: [0.8125 0.625  0.     0.     0.     0.    ], Rank Top-5: [0.875  0.8125 0.5625 0.     0.     0.    ]
  Val   - Loss: 0.0027, Rank Loss: [8.45942693e-05 1.97569997e-04 3.14870348e-04 5.73680516e-04
 7.32125492e-04 8.40759030e-04]
          Top-1: 0.2083, Top-5: 0.3333
          Rank Top-1: [0.75  0.375 0.125 0.    0.    0.   ], Rank Top-5: [0.9375 0.625  0.375  0.0625 0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [4/15]
  Train - Loss: 0.0007, Rank Loss: [1.42007082e-05 2.66235212e-05 7.42678736e-05 1.39052214e-04
 1.87379822e-04 2.17331633e-04]
          Top-1: 0.2917, Top-5: 0.4375
          Rank Top-1: [0.875 0.625 0.25  0.    0.    0.   ], Rank Top-5: [1.     1.     0.5625 0.0625 0.     0.    ]
  Val   - Loss: 0.0028, Rank Loss: [8.79814108e-05 2.04779794e-04 3.29153327e-04 5.76729866e-04
 7.32518323e-04 8.40266417e-04]
          Top-1: 0.2188, Top-5: 0.3229
          Rank Top-1: [0.75   0.375  0.1875 0.     0.     0.    ], Rank Top-5: [0.9375 0.625  0.375  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [5/15]
  Train - Loss: 0.0007, Rank Loss: [2.74613283e-05 2.90454727e-05 8.33029730e-05 1.43700690e-04
 1.88531517e-04 2.09965695e-04]
          Top-1: 0.2708, Top-5: 0.3542
          Rank Top-1: [0.75  0.75  0.125 0.    0.    0.   ], Rank Top-5: [0.875 0.875 0.375 0.    0.    0.   ]
  Val   - Loss: 0.0028, Rank Loss: [8.47112249e-05 2.14257029e-04 3.26659510e-04 5.80677873e-04
 7.34613257e-04 8.39722616e-04]
          Top-1: 0.2292, Top-5: 0.3333
          Rank Top-1: [0.8125 0.375  0.1875 0.     0.     0.    ], Rank Top-5: [0.9375 0.6875 0.375  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [6/15]
  Train - Loss: 0.0007, Rank Loss: [1.55205208e-05 3.80071478e-05 7.20571028e-05 1.41211149e-04
 1.97285886e-04 2.10512074e-04]
          Top-1: 0.2500, Top-5: 0.3958
          Rank Top-1: [0.8125 0.5625 0.125  0.     0.     0.    ], Rank Top-5: [1.     0.9375 0.4375 0.     0.     0.    ]
  Val   - Loss: 0.0027, Rank Loss: [7.28844902e-05 1.86223172e-04 3.17338355e-04 5.85013958e-04
 7.36033056e-04 8.38990488e-04]
          Top-1: 0.2500, Top-5: 0.3438
          Rank Top-1: [0.875  0.4375 0.1875 0.     0.     0.    ], Rank Top-5: [1.     0.6875 0.375  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [7/15]
  Train - Loss: 0.0006, Rank Loss: [1.93255393e-05 3.09778557e-05 6.92960320e-05 1.34740705e-04
 1.91935246e-04 2.03231115e-04]
          Top-1: 0.2500, Top-5: 0.4167
          Rank Top-1: [0.75   0.6875 0.0625 0.     0.     0.    ], Rank Top-5: [1.     0.8125 0.6875 0.     0.     0.    ]
  Val   - Loss: 0.0027, Rank Loss: [6.80433272e-05 1.86582351e-04 3.23060710e-04 5.87913387e-04
 7.36026202e-04 8.38929482e-04]
          Top-1: 0.2708, Top-5: 0.3542
          Rank Top-1: [0.875  0.5625 0.1875 0.     0.     0.    ], Rank Top-5: [1.    0.75  0.375 0.    0.    0.   ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [8/15]
  Train - Loss: 0.0007, Rank Loss: [1.75532341e-05 4.27429133e-05 7.60178691e-05 1.32682926e-04
 1.88760149e-04 2.17791706e-04]
          Top-1: 0.2708, Top-5: 0.3854
          Rank Top-1: [0.8125 0.625  0.1875 0.     0.     0.    ], Rank Top-5: [1.     0.75   0.5625 0.     0.     0.    ]
  Val   - Loss: 0.0027, Rank Loss: [6.65162442e-05 1.85500429e-04 3.25317636e-04 5.88584139e-04
 7.36173961e-04 8.37967225e-04]
          Top-1: 0.2708, Top-5: 0.3542
          Rank Top-1: [0.9375 0.5625 0.125  0.     0.     0.    ], Rank Top-5: [1.    0.75  0.375 0.    0.    0.   ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [9/15]
  Train - Loss: 0.0007, Rank Loss: [1.46225920e-05 2.67854455e-05 6.56646291e-05 1.42890282e-04
 1.91066365e-04 2.09815574e-04]
          Top-1: 0.3438, Top-5: 0.4375
          Rank Top-1: [0.9375 0.8125 0.3125 0.     0.     0.    ], Rank Top-5: [1.     0.9375 0.6875 0.     0.     0.    ]
  Val   - Loss: 0.0027, Rank Loss: [6.27414005e-05 1.80701462e-04 3.27962639e-04 5.88425080e-04
 7.36072512e-04 8.37133895e-04]
          Top-1: 0.2708, Top-5: 0.3438
          Rank Top-1: [0.875  0.5625 0.1875 0.     0.     0.    ], Rank Top-5: [1.     0.6875 0.375  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [10/15]
  Train - Loss: 0.0006, Rank Loss: [1.24921919e-05 3.15612942e-05 6.90698006e-05 1.41899992e-04
 1.83815306e-04 2.08376711e-04]
          Top-1: 0.3438, Top-5: 0.4167
          Rank Top-1: [1.     0.8125 0.25   0.     0.     0.    ], Rank Top-5: [1.     0.9375 0.5625 0.     0.     0.    ]
  Val   - Loss: 0.0027, Rank Loss: [6.34402625e-05 1.79976312e-04 3.28779853e-04 5.87065855e-04
 7.36351915e-04 8.37634535e-04]
          Top-1: 0.2292, Top-5: 0.3438
          Rank Top-1: [0.8125 0.5    0.0625 0.     0.     0.    ], Rank Top-5: [1.     0.6875 0.375  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [11/15]
  Train - Loss: 0.0007, Rank Loss: [1.45231801e-05 3.11304003e-05 7.86815272e-05 1.37024645e-04
 1.88049617e-04 2.19487341e-04]
          Top-1: 0.2708, Top-5: 0.3854
          Rank Top-1: [0.875 0.625 0.125 0.    0.    0.   ], Rank Top-5: [1.     0.875  0.4375 0.     0.     0.    ]
  Val   - Loss: 0.0027, Rank Loss: [6.32240688e-05 1.75668631e-04 3.29746093e-04 5.86895620e-04
 7.37180676e-04 8.37857563e-04]
          Top-1: 0.2292, Top-5: 0.3438
          Rank Top-1: [0.875  0.4375 0.0625 0.     0.     0.    ], Rank Top-5: [1.     0.6875 0.375  0.     0.     0.    ]


Training:   0%|                                                                                                                                 | 0/61780 [00:00<?, ?it/s]
Evaluation:   0%|                                                                                                                               | 0/15445 [00:00<?, ?it/s]



Epoch [12/15]
  Train - Loss: 0.0007, Rank Loss: [1.44774771e-05 3.56658449e-05 7.16520768e-05 1.38074999e-04
 1.81426242e-04 2.19491185e-04]
          Top-1: 0.2917, Top-5: 0.3854
          Rank Top-1: [0.875  0.6875 0.125  0.0625 0.     0.    ], Rank Top-5: [1.     0.8125 0.4375 0.0625 0.     0.    ]
  Val   - Loss: 0.0027, Rank Loss: [6.32612055e-05 1.76654166e-04 3.29262711e-04 5.86870551e-04
 7.37366039e-04 8.38729547e-04]
          Top-1: 0.2292, Top-5: 0.3438
          Rank Top-1: [0.875  0.4375 0.0625 0.     0.     0.    ], Rank Top-5: [1.     0.6875 0.375  0.     0.     0.    ]
Early stopping.

BEST MODEL
  Train - Loss: 0.0006, Rank Loss: [1.93255393e-05 3.09778557e-05 6.92960320e-05 1.34740705e-04
 1.91935246e-04 2.03231115e-04]
          Top-1: 0.2500, Top-5: 0.4167
          Rank Top-1: [0.75   0.6875 0.0625 0.     0.     0.    ], Rank Top-5: [1.     0.8125 0.6875 0.     0.     0.    ]
  Val   - Loss: 0.0027, Rank Loss: [6.80433272e-05 1.86582351e-04 3.23060710e-04 5.8791