$$\textbf{Proyecto de Verano: PLN aplicado a la Bioinformática}$$
$$\textit{Y. Sarahi García Gozález}$$

In [3]:
import numpy as np
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import numpy as np
import torch
import urllib.request
import sys
# sys.path.append('/kaggle/input/proyecto-archivos/')
# import convert_encodings
# from convert_encodings import m2
from transformers import BertModel, BertConfig, logging
from tqdm import tqdm
from sklearn.metrics import accuracy_score
#import wandb
from datetime import datetime
import yaml
import os
import shutil


In [2]:
pip install torchdrug -q

Note: you may need to restart the kernel to use updated packages.


In [3]:
import torchdrug
from torchdrug.datasets import Solubility

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}\n')

Device: cpu



Descargar los datos

In [5]:
class PeptideBERTDataset(torch.utils.data.Dataset):
    def __init__(self, input_ids, attention_masks, labels):
        self.input_ids = input_ids
        self.attention_masks = attention_masks
        self.labels = labels

        self.length = len(self.input_ids)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        input_id = self.input_ids[idx]
        attention_mask = self.attention_masks[idx]
        label = self.labels[idx]

        return {
            'input_ids': torch.tensor(input_id, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(label, dtype=torch.float)
        }

In [6]:
logging.set_verbosity_error()

# Definimos la clase PeptideBERT, que hereda de torch.nn.Module (la clase base para todas las redes neuronales en PyTorch)
class PeptideBERT(torch.nn.Module):
    def __init__(self, bert_config):
        super(PeptideBERT, self).__init__()

        # Cargamos el modelo preentrenado
        self.protbert = BertModel.from_pretrained(
            'Rostlab/prot_bert_bfd', 
            config=bert_config,
            ignore_mismatched_sizes=True
        ) 
        #clasificacion
        self.head = torch.nn.Sequential( 
            torch.nn.Linear(bert_config.hidden_size, 1), #toma la salida de protVert y la convierte en un valor
            torch.nn.Sigmoid()
        )
    # Definimos el método forward que especifica cómo procesar los datos
    def forward(self, inputs, attention_mask):
        # Pasamos las entradas a través de ProtBert
        output = self.protbert(inputs, attention_mask=attention_mask)
        # Usamos la salida de ProtBert como entrada a la capa de regresión
        return self.head(output.pooler_output)


In [7]:
#criterio de pérdida,optimizador y el planificador de learning rate  para el entrenamiento del modelo

def cri_opt_sch(config, model):
    #criterio de loss (BinaryCrossEntropy)
    criterion = torch.nn.BCELoss()
    #optimizador AmadW
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['optim']['lr'])
    #Scheduler
    if config['sch']['name'] == 'onecycle':
        ## Durante el entrenamiento, el learning-rate empieza en un valor inicial, aumenta hasta el valor máximo especificado (max_lr), y luego disminuye nuevamente hacia el final del entrenamiento.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=config['optim']['lr'],
            epochs=config['epochs'],
            steps_per_epoch=config['sch']['steps']
        ) #Ajusta el learning-rate utilizando un ciclo de una sola pasada
    elif config['sch']['name'] == 'lronplateau':
        ## ajusta el learning-rate basándose en el rendimiento del modelo. Específicamente, reduce la tasa de aprendizaje cuando una métrica de rendimiento ha dejado de mejorar.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=config['sch']['factor'],
            patience=config['sch']['patience']
        )# Reduce lr cuando la métrica especificada ha dejado de mejorar.

    return criterion, optimizer, scheduler


In [8]:
###Función que se encarga del proceso de entrenamiento
def train(model, dataloader, optimizer, criterion, scheduler, device):
    model.train()  # Pone el modelo en modo de entrenamiento
    total_loss = 0.0

    for batch in tqdm(dataloader):  # Itera sobre los lotes de datos en el dataloader
        inputs = batch['input_ids'].to(device)  # Mueve las entradas al dispositivo (CPU o GPU)
        attention_mask = batch['attention_mask'].to(device)  # Mueve la máscara de atención al dispositivo
        labels = batch['labels'].to(device)  # Mueve las etiquetas al dispositivo

        optimizer.zero_grad()  # Resetea los gradientes del optimizador

        logits = model(inputs, attention_mask).squeeze(1)  # Pasa las entradas a través del modelo y ajusta las dimensiones
        loss = criterion(logits, labels)  # Calcula la pérdida

        loss.backward()  # Calcula los gradientes
        optimizer.step()  # Actualiza los parámetros del modelo
        # scheduler.step()  # Si el scheduler es OneCycleLR, ajusta la tasa de aprendizaje en cada paso

        total_loss += loss.item()  # Acumula la pérdida total

    return total_loss / len(dataloader)  # Retorna la pérdida promedio por lote

In [9]:
def validate(model, dataloader, criterion, device):
    model.eval()  # Pone el modelo en modo de evaluación
    total_loss = 0.0

    ground_truth = []
    predictions = []

    for batch in tqdm(dataloader):  # Itera sobre los lotes de datos en el dataloader
        inputs = batch['input_ids'].to(device)  # Mueve las entradas al dispositivo
        attention_mask = batch['attention_mask'].to(device)  # Mueve la máscara de atención al dispositivo
        labels = batch['labels'].to(device)  # Mueve las etiquetas al dispositivo

        with torch.inference_mode():  # Desactiva el cálculo de gradientes
            logits = model(inputs, attention_mask).squeeze(1)  # Pasa las entradas a través del modelo
            loss = criterion(logits, labels)  # Calcula la pérdida

        total_loss += loss.item()  # Acumula la pérdida total
        # Genera predicciones binarias
        preds = torch.where(logits > 0.5, 1, 0)  
        predictions.extend(preds.cpu().tolist())  # Añade las predicciones a la lista
        ground_truth.extend(labels.cpu().tolist())  # Añade las etiquetas reales a la lista

    total_loss = total_loss / len(dataloader)  # Calcula la pérdida promedio
    accuracy = 100 * accuracy_score(ground_truth, predictions)  # Calcula la precisión

    return total_loss, accuracy  # Retorna la pérdida promedio y la precisión

In [10]:
def test(model, dataloader, device):
    model.eval()  # Pone el modelo en modo de evaluación

    ground_truth = []
    predictions = []

    for batch in tqdm(dataloader):  # Itera sobre los lotes de datos en el dataloader
        inputs = batch['input_ids'].to(device)  # Mueve las entradas al dispositivo
        attention_mask = batch['attention_mask'].to(device)  # Mueve la máscara de atención al dispositivo
        labels = batch['labels']  # Las etiquetas permanecen en la CPU

        with torch.inference_mode():  # Desactiva el cálculo de gradientes
            logits = model(inputs, attention_mask).squeeze(1)  # Pasa las entradas a través del modelo

        preds = torch.where(logits > 0.5, 1, 0)  # Genera predicciones binarias
        predictions.extend(preds.cpu().tolist())  # Añade las predicciones a la lista
        ground_truth.extend(labels.tolist())  # Añade las etiquetas reales a la lista

    accuracy = 100 * accuracy_score(ground_truth, predictions)  # Calcula la precisión

    return accuracy  # Retorna la precisión


In [11]:
def train_model(model):
    print(f'{"="*30}{"TRAINING":^20}{"="*30}')

    best_acc = 0 #inicializa la mejor precisión en cero

    #iteramos cada época
    for epoch in range(config['epochs']):

        #llamamos a la funcion de entrenamiento
        train_loss = train(model, train_data_loader, optimizer, criterion, scheduler, device)
        #obtenemos learning rate
        curr_lr = optimizer.param_groups[0]['lr']
        #imprimimos loss de entrenamiento y learning rate
        print(f'Epoch {epoch+1}/{config["epochs"]} - Train Loss: {train_loss}\tLR: {curr_lr}')
        #imprimimos loss y accuracy de validacion
        val_loss, val_acc = validate(model, val_data_loader, criterion, device)
        print(f'Epoch {epoch+1}/{config["epochs"]} - Validation Loss: {val_loss}\tValidation Accuracy: {val_acc}\n')
        #Actualizar el Scheduler:
        scheduler.step(val_acc)

        #Registrar Métricas con wandb
        if not config['debug']:
            wandb.log({
                'train_loss': train_loss, 
                'val_loss': val_loss, 
                'val_accuracy': val_acc, 
                'lr': curr_lr
            })
        #Guardamos mejor modelo
        if val_acc >= best_acc and not config['debug']:
            best_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'acc': val_acc, 
                'lr': curr_lr
            }, f'{save_dir}/model.pt')
            print('Model Saved\n')
    wandb.finish()




Solubilidad

In [12]:
ds=Solubility("f'/kaggle/working/",lazy=True)

19:55:32   Extracting f'/kaggle/working/solubility.tar.gz to f'/kaggle/working


Constructing proteins from sequences: 100%|██████████| 71419/71419 [00:00<00:00, 118807.44it/s]


In [13]:
sequences=ds.sequences
targets=ds.targets['solubility']

In [14]:
aminoacidos = ['G', 'A', 'S', 'P', 'V', 'T', 'C', 'I', 'L', 'N', 'D', 'Q', 'K', 'E', 'M', 'H', 'F', 'R', 'Y', 'W']
#diccionario de mapeo
letter_to_number = {letter: index  for index, letter in enumerate(aminoacidos)}
# Función para convertir secuencias de letras a secuencias de números
def convert_sequences_to_numbers(sequences, mapping):
    return [[mapping[letter] for letter in seq] for seq in sequences]

sequences_number=convert_sequences_to_numbers(sequences, letter_to_number)

In [15]:
#longitud del string más larga
max_length = max(len(s) for s in sequences)

print("La longitud del string más largo es:", max_length)

La longitud del string más largo es: 1200


In [16]:
ds.num_samples #train,valid,test

[62478, 6942, 1999]

In [17]:
from torch.utils.data import DataLoader
import numpy as np
from keras.preprocessing.sequence import pad_sequences

def atention_mask(array_sequences, max_length):
    m = len(array_sequences)
    atention_mask_sequence = np.zeros((m, max_length), dtype=np.float64)

    for i, seq in enumerate(array_sequences):
        seq_len = min(len(seq), max_length)
        atention_mask_sequence[i, :seq_len] = 1

    return atention_mask_sequence


def load_data_torchdrug(sequences, targets, ds, max_length,truncate=True):
    print(f'{"="*30}{"DATA":^20}{"="*30}')
    
    n0=ds.num_samples[0] #lista que contiene el número de muestras por set:train,val,test
    n1=ds.num_samples[0]+ds.num_samples[1]
   
    
    train_sequences=[seq for seq in sequences[0:n0] if len(seq)<500]
    train_targets=np.array([target for seq,target in zip(sequences[0:n0],targets[0:n0]) if len(seq)<500])
  
    val_sequences=[seq for seq in sequences[n0:n1] if len(seq)<500]
    val_targets=np.array([target for seq,target in zip(sequences[n0:n1],targets[n0:n1]) if len(seq)<500])

    test_sequences=[seq for seq in sequences[n1:] if len(seq)<500]
    test_targets=np.array([target for seq,target in zip(sequences[n1:],targets[n1:]) if len(seq)<500])

    #cnvertir a array
    # Padear las secuencias para que todas tengan la misma longitud
    max_len = 500
    train_sequences = pad_sequences(train_sequences, maxlen=max_len, padding='post')
    val_sequences= pad_sequences(val_sequences, maxlen=max_len, padding='post')
    test_sequences = pad_sequences(test_sequences, maxlen=max_len, padding='post')
    
    
    # Crear las máscaras de atención
    attention_mask_train = (train_sequences > 0).astype(np.float64)
    attention_mask_val = (val_sequences > 0).astype(np.float64)
    attention_mask_test = (test_sequences > 0).astype(np.float64)
    
    
    
    train_dataset = PeptideBERTDataset(input_ids=train_sequences, attention_masks=attention_mask_train, labels=train_targets)
    val_dataset = PeptideBERTDataset(input_ids=val_sequences, attention_masks=attention_mask_val, labels=val_targets)
    test_dataset = PeptideBERTDataset(input_ids=test_sequences, attention_masks=attention_mask_test, labels=test_targets)

    train_data_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True
    )

    val_data_loader = DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=False
    )

    test_data_loader = DataLoader(
        test_dataset,
        batch_size=16,
        shuffle=False
    )

    print('Batch size: ', 8)

    print('Train dataset samples: ', len(train_dataset))
    print('Validation dataset samples: ', len(val_dataset))
    print('Test dataset samples: ', len(test_dataset))

    print('Train dataset batches: ', len(train_data_loader))
    print('Validation dataset batches: ', len(val_data_loader))
    print('Test dataset batches: ', len(test_data_loader))

    print()

    return train_data_loader, val_data_loader, test_data_loader

2024-07-29 19:55:37.383398: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-29 19:55:37.383500: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-29 19:55:37.509124: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [18]:
train_data_loader, val_data_loader, test_data_loader = load_data_torchdrug(sequences_number,targets,ds,max_length)

Batch size:  8
Train dataset samples:  56845
Validation dataset samples:  6308
Test dataset samples:  1774
Train dataset batches:  3553
Validation dataset batches:  395
Test dataset batches:  111



In [19]:
train_data_loader.dataset[0]

{'input_ids': tensor([ 0, 14,  7,  8, 12,  5,  9,  8, 16,  0, 15,  5, 18, 11, 16, 12,  2,  7,
          5, 10,  4,  8,  1, 12,  1,  9, 13, 13, 12,  2,  0, 10, 17,  8,  1,  0,
          4,  1,  1, 13,  2,  1, 13, 13, 17,  4,  1,  1, 12,  4,  4,  8,  2, 12,
         14,  5,  8,  0, 10,  8, 17,  9,  9,  3,  4,  4,  3, 18, 13,  5, 10, 13,
          4,  5, 17,  7,  7, 11, 10, 11,  4,  9, 10, 17,  7, 15, 10,  2,  7, 12,
          9, 19,  5,  4, 13, 13,  8, 17, 13, 19,  7,  8, 10, 15, 12,  5,  5, 10,
          1, 10,  7, 12, 17,  4,  1, 17,  0,  8,  5,  2, 13,  7,  7,  1,  1,  4,
          5, 12,  8, 14,  2,  9,  8, 10,  8,  7, 18,  0,  1, 12, 12,  7, 17,  4,
          7,  1, 15,  1,  9,  5,  5,  7,  0,  8,  3,  0,  5, 16,  2,  1, 17,  8,
         11,  3,  9, 15,  3,  5, 10, 10,  3, 10,  0,  7,  8,  1,  2,  8, 14, 13,
          0,  8,  5, 18,  0,  7,  0, 10,  1,  4,  7,  0,  8,  9,  3,  4, 10, 10,
          2,  5, 10,  2,  4,  4, 17,  8,  8,  9, 12, 16, 13, 13, 16, 17,  2, 12,
         19, 10

In [20]:
train_data_loader.dataset[0]['input_ids']

tensor([ 0, 14,  7,  8, 12,  5,  9,  8, 16,  0, 15,  5, 18, 11, 16, 12,  2,  7,
         5, 10,  4,  8,  1, 12,  1,  9, 13, 13, 12,  2,  0, 10, 17,  8,  1,  0,
         4,  1,  1, 13,  2,  1, 13, 13, 17,  4,  1,  1, 12,  4,  4,  8,  2, 12,
        14,  5,  8,  0, 10,  8, 17,  9,  9,  3,  4,  4,  3, 18, 13,  5, 10, 13,
         4,  5, 17,  7,  7, 11, 10, 11,  4,  9, 10, 17,  7, 15, 10,  2,  7, 12,
         9, 19,  5,  4, 13, 13,  8, 17, 13, 19,  7,  8, 10, 15, 12,  5,  5, 10,
         1, 10,  7, 12, 17,  4,  1, 17,  0,  8,  5,  2, 13,  7,  7,  1,  1,  4,
         5, 12,  8, 14,  2,  9,  8, 10,  8,  7, 18,  0,  1, 12, 12,  7, 17,  4,
         7,  1, 15,  1,  9,  5,  5,  7,  0,  8,  3,  0,  5, 16,  2,  1, 17,  8,
        11,  3,  9, 15,  3,  5, 10, 10,  3, 10,  0,  7,  8,  1,  2,  8, 14, 13,
         0,  8,  5, 18,  0,  7,  0, 10,  1,  4,  7,  0,  8,  9,  3,  4, 10, 10,
         2,  5, 10,  2,  4,  4, 17,  8,  8,  9, 12, 16, 13, 13, 16, 17,  2, 12,
        19, 10,  4,  3,  5, 11,  5,  6, 

In [21]:
train_data_loader.dataset[7410]['input_ids'].size()

torch.Size([500])

In [29]:
##Configuración y Preparación###
#llamamos al archivo donde se guarda la config del modelo peptidebert
config = yaml.load(open('/kaggle/input/proyecto-archivos/config.yaml', 'r'), Loader=yaml.FullLoader)
config['task'] = 'fluor'
config['batch_size'] = 16
config['epochs'] = 8
config['optim']['lr'] = 1.0e-5
config['sch']['steps'] = len(train_data_loader)

In [30]:
def create_model_torchdrug(config):
    bert_config = BertConfig(
        vocab_size=25,
        hidden_size=480,
        num_hidden_layers=12,
        num_attention_heads=12,
        hidden_dropout_prob=0.15,
        max_position_embeddings= 512 #maximo len de preentrenamiento HF
    )
    #creamos una istancia de PeptideBERT utilizando la configuración de BERT definida
    model = PeptideBERT(bert_config).to(device)
    #regresamos el modelo
    return model

In [31]:

#creamos el modelo
model_torchdrug = create_model_torchdrug(config)

#configuramos criterio de pérdida, optimizador y scheduler
criterion, optimizer, scheduler = cri_opt_sch(config, model_torchdrug)


#Configuración de Weights & Biases (WandB)
if not config['debug']:
    run_name = f'{config["task"]}-{datetime.now().strftime("%m%d_%H%M")}'
    wandb.init(project='PeptideBERT', name=run_name)

    save_dir = f'/kaggle/working/checkpoints/{run_name}'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    else:
        print('ya existe')
    shutil.copy('/kaggle/input/proyecto-archivos/config.yaml', f'{save_dir}/config.yaml')
    shutil.copy('/kaggle/input/model-peptidos/network.py', f'{save_dir}/network.py')

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
lr,▁▁
train_loss,█▁
val_accuracy,▁▁
val_loss,█▁

0,1
lr,0.0001
train_loss,0.68198
val_accuracy,58.33862
val_loss,0.67963


In [32]:
#Entrenamiento del Modelo
train_model(model_torchdrug)
if not config['debug']:
    model_torchdrug.load_state_dict(torch.load(f'{save_dir}/model.pt')['model_state_dict'], strict=False)




100%|██████████| 3553/3553 [31:51<00:00,  1.86it/s]


Epoch 1/8 - Train Loss: 0.6342966494226737	LR: 1e-05


100%|██████████| 395/395 [01:12<00:00,  5.45it/s]


Epoch 1/8 - Validation Loss: 0.5753310979921606	Validation Accuracy: 65.91629676601141

Model Saved



100%|██████████| 3553/3553 [31:51<00:00,  1.86it/s]


Epoch 2/8 - Train Loss: 0.5640089373051064	LR: 1e-05


100%|██████████| 395/395 [01:12<00:00,  5.46it/s]


Epoch 2/8 - Validation Loss: 0.5345281376114375	Validation Accuracy: 70.62460367786937

Model Saved



100%|██████████| 3553/3553 [31:50<00:00,  1.86it/s]


Epoch 3/8 - Train Loss: 0.5432118680438223	LR: 1e-05


100%|██████████| 395/395 [01:12<00:00,  5.47it/s]


Epoch 3/8 - Validation Loss: 0.5378148098535176	Validation Accuracy: 69.57831325301204



100%|██████████| 3553/3553 [31:50<00:00,  1.86it/s]


Epoch 4/8 - Train Loss: 0.5347400006444697	LR: 1e-05


100%|██████████| 395/395 [01:12<00:00,  5.46it/s]


Epoch 4/8 - Validation Loss: 0.5305518093742901	Validation Accuracy: 71.06848446417247

Model Saved



100%|██████████| 3553/3553 [31:50<00:00,  1.86it/s]


Epoch 5/8 - Train Loss: 0.5297306505703705	LR: 1e-05


100%|██████████| 395/395 [01:12<00:00,  5.46it/s]


Epoch 5/8 - Validation Loss: 0.5223847663100761	Validation Accuracy: 71.16360177552315

Model Saved



100%|██████████| 3553/3553 [31:50<00:00,  1.86it/s]


Epoch 6/8 - Train Loss: 0.524560482635184	LR: 1e-05


100%|██████████| 395/395 [01:12<00:00,  5.46it/s]


Epoch 6/8 - Validation Loss: 0.5186560472355614	Validation Accuracy: 71.8294229549778

Model Saved



100%|██████████| 3553/3553 [31:50<00:00,  1.86it/s]


Epoch 7/8 - Train Loss: 0.5208642824523858	LR: 1e-05


100%|██████████| 395/395 [01:12<00:00,  5.46it/s]


Epoch 7/8 - Validation Loss: 0.5182853369773188	Validation Accuracy: 72.28915662650603

Model Saved



100%|██████████| 3553/3553 [31:49<00:00,  1.86it/s]


Epoch 8/8 - Train Loss: 0.5171788369991932	LR: 1e-05


100%|██████████| 395/395 [01:12<00:00,  5.47it/s]


Epoch 8/8 - Validation Loss: 0.5202548097960557	Validation Accuracy: 71.70259987317692



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
lr,▁▁▁▁▁▁▁▁
train_loss,█▄▃▂▂▁▁▁
val_accuracy,▁▆▅▇▇▇█▇
val_loss,█▃▃▃▂▁▁▁

0,1
lr,1e-05
train_loss,0.51718
val_accuracy,71.7026
val_loss,0.52025


In [34]:
#test
test_acc = test(model_torchdrug, test_data_loader, device)
print(f'Test Accuracy: {test_acc}%')

100%|██████████| 111/111 [00:20<00:00,  5.48it/s]

Test Accuracy: 67.98196166854565%





In [38]:
# Función para cargar el modelo guardado
def load_trained_model(save_path, config):
    model = create_model_torchdrug(config)
    checkpoint = torch.load(save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

# Función para realizar la inferencia
def infer(model, dataloader, device):
    model.eval()  # Pone el modelo en modo de evaluación
    predictions = []
    ground_truth = []

    with torch.no_grad():  # Desactiva el cálculo de gradientes
        for batch in dataloader:
            inputs = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(inputs, attention_mask).squeeze(1)
            preds = torch.where(logits > 0.5, 1, 0)
            predictions.extend(preds.cpu().tolist())
            ground_truth.extend(labels.cpu().tolist())

    return predictions, ground_truth


save_path = f'{save_dir}/model.pt'  # Ruta del modelo guardado

# Cargar el modelo entrenado
model = load_trained_model(save_path, config)



# Realizar inferencia en los datos de prueba
predictions, ground_truth = infer(model, test_data_loader, device)

# Imprimir resultados
from sklearn.metrics import accuracy_score, classification_report

accuracy = accuracy_score(ground_truth, predictions)
report = classification_report(ground_truth, predictions)

print(f'Test Accuracy: {accuracy * 100:.2f}%')
print('Classification Report:')
print(report)

Test Accuracy: 67.98%
Classification Report:
              precision    recall  f1-score   support

         0.0       0.63      0.86      0.73       879
         1.0       0.79      0.50      0.61       895

    accuracy                           0.68      1774
   macro avg       0.71      0.68      0.67      1774
weighted avg       0.71      0.68      0.67      1774

