# Setup

In [2]:
import sys
import os

# Resolve the repository root as two directories above the current working
# directory, so the `src.*` packages can be imported from this notebook.
project_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))

# Only register the path once: re-running this cell must not keep appending
# duplicate entries to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

print(project_root)

/data/data3/junibg-ego/Modelo_leo_coi


In [3]:
import torch
# Query CUDA once and report availability, device count and the name of device 0.
gpu_available = torch.cuda.is_available()
gpu_name = torch.cuda.get_device_name(0) if gpu_available else 'No GPU'

print(f"¬øGPU disponible? {gpu_available}")
print(f"N√∫mero de GPUs: {torch.cuda.device_count()}")
print(f"Nombre GPU: {gpu_name}")


¬øGPU disponible? True
N√∫mero de GPUs: 2
Nombre GPU: NVIDIA GeForce RTX 3090


In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import json

# Tus imports
from src.combined_model.combined_model_embedding import *
from src.combined_model.combined_models import *
from src.decoders.decoder_simple import *
from src.encoders_model.DNABERT_Embedder import *
from src.encoders_model.embdeeding_encoders import *
from src.encoders_model.simple_encoders import *
from src.evaluators.linear_evaluator import *
from src.decoders.sequence_decoder import *
import numpy as np


In [5]:
from src.encoders_model.simple_encoders import *
from src.utils.load_fastaDataset import *
from src.training.experimentRunner import *
from torch.utils.data import DataLoader

In [6]:
def load_hierarchy_from_json(json_path):
    """Load the taxonomic hierarchy from a JSON file.

    The file is expected to hold a two-level mapping:
    ``{taxon_name: {parent_id: [child_id, ...]}}``. JSON object keys are
    always strings, so parent ids are converted back to ``int``; child ids
    are normalised to ``int`` as well.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        dict: ``{taxon_name: {parent_id (int): [child_id (int), ...]}}``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If an id cannot be interpreted as a number.
    """
    def _to_int(value):
        # Integers and integer-looking strings convert exactly; ids that were
        # serialised as floats (e.g. "12.0") need the detour through float().
        try:
            return int(value)
        except ValueError:
            return int(float(value))

    # Explicit encoding so the result does not depend on the platform locale.
    with open(json_path, 'r', encoding='utf-8') as f:
        hierarchy_raw = json.load(f)

    # Rebuild the mapping with integer parent keys and integer child ids,
    # using the same conversion for both sides.
    hierarchy = {
        child_taxon: {
            _to_int(parent_key): [_to_int(c) for c in children_list]
            for parent_key, children_list in parent_dict.items()
        }
        for child_taxon, parent_dict in hierarchy_raw.items()
    }

    # Summary per taxon: number of parents and total number of children.
    print("‚úÖ Jerarqu√≠a cargada desde JSON")
    for taxon, mapping in hierarchy.items():
        n_parents = len(mapping)
        n_children = sum(len(v) for v in mapping.values())
        print(f"  {taxon:10s}: {n_parents:4d} padres ‚Üí {n_children:5d} hijos")

    return hierarchy

In [7]:
# Location of the taxonomy hierarchy JSON inside the project tree.
hierarchy_path = os.path.join(
    project_root,
    "src",
    "data",
    "taxonomy_hierarchy_fixed_with_class.json",
)
# Parse it into {taxon: {parent_id: [child_ids]}}.
hierarchy = load_hierarchy_from_json(json_path=hierarchy_path)

‚úÖ Jerarqu√≠a cargada desde JSON
  class     :   49 padres ‚Üí   187 hijos
  order     :  173 padres ‚Üí   831 hijos
  family    :  797 padres ‚Üí  5446 hijos
  genus     : 5393 padres ‚Üí 50568 hijos
  species   : 50510 padres ‚Üí 205075 hijos


In [8]:
# Keep the CSV path under its own name instead of temporarily binding `df`
# to a string; `df` is only ever the loaded DataFrame.
taxa_csv_path = os.path.join(project_root, "src", "data", "all_taxa_numeric.csv")
df = pd.read_csv(taxa_csv_path)

In [9]:
# Taxonomic ranks handled by the model, in the order they are iterated everywhere below.
taxon_order = ['phylum', 'class', 'order', 'family', 'genus', 'species']

# Number of distinct label values per rank, counted over the whole table.
total_classes = {taxon: df[taxon].nunique() for taxon in taxon_order}

for taxon, n_classes in total_classes.items():
    print(f"  {taxon:10s}: {n_classes:6d} clases")

  phylum    :     49 clases
  class     :    173 clases
  order     :    797 clases
  family    :   5393 clases
  genus     :  50510 clases
  species   : 205075 clases


# Definimos Variables

In [10]:
# Sequence length passed to the embedder, the decoders and the dataset.
max_length: int = 750
# Number of samples per DataLoader batch.
batch_size: int = 8

# Cargamos el Modelo

In [11]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"üñ•Ô∏è  Dispositivo: {device}")

# DNABERT embedder, loaded from the local archive directory.
dnabert_path = os.path.join(project_root, "src", "data", "archives")
dnabert = DNABERTEmbedder(model_name=dnabert_path, max_length=max_length, device=device)
embed_dim = dnabert.get_embedding_dim()

# Encoder that maps DNABERT embeddings to the latent space.
latent_dim = 256
encoder = SimpleEmbeddingEncoder(embed_dim=embed_dim, latent_dim=latent_dim, dropout=0.1)

# Every sequence decoder is built with the same hyper-parameters.
# NOTE(review): vocab_size=4 presumably corresponds to the four nucleotides; confirm.
sequence_decoder_kwargs = dict(
    latent_dim=latent_dim,
    seq_len=max_length,
    vocab_size=4,
    dropout=0.1,
)

# One decoder per taxonomic rank, followed by a single global decoder.
decoders_dict = {taxon: SequenceDecoder(**sequence_decoder_kwargs) for taxon in taxon_order}
global_decoder = SequenceDecoder(**sequence_decoder_kwargs)

# One cosine classifier per taxonomic rank.
classifiers_dict = {}
print(f"\nüîß Creando classifiers con n√∫mero TOTAL de clases:")
for taxon in taxon_order:
    # Size each head from total_classes (counted on the full table), not from
    # df_train[taxon].nunique().
    n_classes = total_classes[taxon]
    classifiers_dict[taxon] = CosineClassifier(
        latent_dim=latent_dim,
        num_classes=n_classes,
        scale=20.0,  # tunable if needed
    )
    print(f"  {taxon:10s}: {n_classes:6d} clases")

# Assemble the hierarchical model; the taxonomy hierarchy is handed over here.
model = HierarchicalCombinedModelFixed(
    dnabert=dnabert,
    encoder=encoder,
    decoders_dict=decoders_dict,
    classifiers_dict=classifiers_dict,
    global_decoder=global_decoder,
    taxonomy_hierarchy=hierarchy,
)

print(f"\n‚úÖ Modelo jer√°rquico creado")

üñ•Ô∏è  Dispositivo: cuda


Some weights of BertModel were not initialized from the model checkpoint at /data/data3/junibg-ego/Modelo_leo_coi/src/data/archives and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



üîß Creando classifiers con n√∫mero TOTAL de clases:
  phylum    :     49 clases
  class     :    173 clases
  order     :    797 clases
  family    :   5393 clases
  genus     :  50510 clases
  species   : 205075 clases
‚úÖ Modelo jer√°rquico creado con m√°scaras suaves (no -inf)

‚úÖ Modelo jer√°rquico creado


In [12]:

# NOTE(review): the directory name "checkpointss" looks like a typo; confirm it
# matches the directory on disk before changing it.
checkpoint_path = os.path.join(project_root, "src", "data", "checkpointss", "final_model.pt")

# Deserialize onto the CPU so loading does not depend on the device the
# checkpoint was saved from; load_state_dict copies the values into the
# model's existing parameters.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])

# Switch to inference mode (disables dropout).
model.eval()


HierarchicalCombinedModelFixed(
  (dnabert): DNABERTEmbedder(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(4096, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertUnpadAttention(
              (self): BertUnpadSelfAttention(
                (dropout): Dropout(p=0.0, inplace=False)
                (Wqkv): Linear(in_features=768, out_features=2304, bias=True)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (mlp): BertG

In [13]:
# Pick the target device as a torch.device and move the whole model onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)


HierarchicalCombinedModelFixed(
  (dnabert): DNABERTEmbedder(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(4096, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertUnpadAttention(
              (self): BertUnpadSelfAttention(
                (dropout): Dropout(p=0.0, inplace=False)
                (Wqkv): Linear(in_features=768, out_features=2304, bias=True)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (mlp): BertG

# Cargamos el DataLoader

In [14]:
from sklearn.model_selection import train_test_split

# Both splits hold out 20% with the same seed and stratify on phylum.
split_kwargs = dict(test_size=0.2, random_state=42)

# First hold out the test set (20% of the full table).
df_temp, df_test = train_test_split(df, stratify=df['phylum'], **split_kwargs)

# Then split the remaining 80% into train/val (80/20 of it = 64%/16% of the total).
df_train, df_val = train_test_split(df_temp, stratify=df_temp['phylum'], **split_kwargs)

In [15]:
# Re-index the validation rows from 0 before handing them to the dataset.
val_frame = df_val.reset_index(drop=True)
val_dataset = MultiTaxaFastaDataset(val_frame, max_length=max_length, taxon_cols=taxon_order)

In [16]:
def collate_multitask(batch, taxon_cols=['phylum', 'class','order','family','genus','species'], max_length=900):
    """Collate dataset samples into ``(sequences, labels_dict)``.

    Each sample is a 4-tuple ``(sequence, labels, recon_targets, true_tokens)``.
    Only the sequences and the labels are returned, so the reconstruction
    targets and true tokens are unpacked but not stacked. Stacking them would
    be wasted work on every batch and would fail on payloads of unequal shape
    that the caller never sees.

    Args:
        batch: Iterable of 4-tuples as described above; ``labels`` is a mapping
            from taxon name to a tensor.
        taxon_cols: Taxon names to collect labels for. The default list is
            only read, never mutated.
        max_length: Unused; kept so existing callers that pass it keep working.

    Returns:
        tuple: ``(sequences, labels_dict)`` where ``sequences`` is a tuple with
        one entry per sample and ``labels_dict`` maps each taxon in
        ``taxon_cols`` to the per-sample label tensors stacked along a new
        first dimension.
    """
    sequences, labels_dict_list, _recon_targets_list, _true_tokens_list = zip(*batch)

    # Labels: one stacked tensor per taxon.
    labels_dict = {
        taxon: torch.stack([sample_labels[taxon] for sample_labels in labels_dict_list])
        for taxon in taxon_cols
    }

    return sequences, labels_dict

In [17]:
# Evaluation loader: iterate the validation set in a fixed order and keep the
# final partial batch, so every validation sample is visited exactly once and
# repeated runs see the samples in the same order.
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    # Forward the dataset's own taxon columns and max_length to the collate function.
    collate_fn=lambda b: collate_multitask(b, taxon_cols=val_dataset.taxon_cols, max_length=val_dataset.max_length),
    num_workers=6
)

# Obtenemos los Embeddings del Decoder

In [25]:

# Set before the loader (configured with worker processes) is iterated.
# NOTE(review): presumably this silences the tokenizers fork warning; confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ---- Step 1: gather latent embeddings and labels over the validation loader ----
print("Collecting embeddings from dataloader...")

# Taxonomic rank whose labels are stored next to the embeddings.
# Options: 'phylum', 'class', 'order', 'family', 'genus', 'species'
taxonomic_level = 'phylum'

all_embeddings = []
all_labels = []

with torch.no_grad():
    for sequences, batch_labels in tqdm(val_loader, desc="Collecting embeddings"):
        # The model output is a mapping; 'z' holds the embedding collected here.
        latent = model(sequences)['z']

        # Keep per-batch results on the host as numpy arrays.
        all_embeddings.append(latent.cpu().numpy())
        all_labels.append(batch_labels[taxonomic_level].cpu().numpy())

# Merge the per-batch arrays into one embedding matrix and one label vector.
embeddings = np.vstack(all_embeddings)
labels = np.concatenate(all_labels)

print(f"Collected {len(embeddings)} embeddings with dimension {embeddings.shape[1]}")
print(f"Number of unique classes: {len(np.unique(labels))}")

# Persist both arrays in the current working directory.
np.save(f'embeddings_{taxonomic_level}.npy', embeddings)
np.save(f'labels_{taxonomic_level}.npy', labels)
print(f"Embeddings and labels saved to disk")

Collecting embeddings from dataloader...


Collecting embeddings:   7%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè                                                                                               | 2139/30890 [14:16<3:11:47,  2.50it/s]


KeyboardInterrupt: 