In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.data import DataLoader as GeometricDataLoader

import os
import sys
import nbimporter

project_root = os.path.join(os.getcwd(), '..')
sys.path.append(project_root)

from datapreparation.Process_1D_data import *
from datapreparation.Process_graph_2d_data import *
from datapreparation.Process_graph_3d_data import *
from datapreparation.Process_mlp_data import *
from cv_strategies.train_cv_strategy_123D import get_data_label

In [2]:
import numpy as np

def extract_features(model, loader, device, data_type):
    """Run ``model`` over every batch of ``loader`` and collect its feature outputs.

    The model is switched to eval mode (and left there) and the loop runs under
    ``torch.no_grad()``. Each batch is unpacked with ``get_data_label`` and passed to
    the model as ``model(inputs, extract_features=True)``.

    Args:
        model: Model whose forward accepts an ``extract_features`` keyword
            (assumed to return an embedding tensor when it is True — TODO confirm).
        loader: Iterable of batches understood by ``get_data_label``.
        device: Device the batches are moved to by ``get_data_label``.
        data_type: Representation tag forwarded to ``get_data_label``.

    Returns:
        ``(features, labels)``: NumPy arrays built by concatenating the per-batch
        outputs along axis 0, in loader order.

    Raises:
        ValueError: If ``get_data_label`` yields ``None`` for inputs or labels, or
            (from ``np.concatenate``) if ``loader`` produced no batches.
    """
    model.eval()
    feature_batches, label_batches = [], []

    with torch.no_grad():
        for batch in loader:
            batch_inputs, batch_labels = get_data_label(batch, data_type, device)
            # get_data_label signals an unknown representation by returning None.
            if batch_inputs is None or batch_labels is None:
                raise ValueError(f"Unsupported data type: {data_type}")

            embedded = model(batch_inputs, extract_features=True)
            # Move to host memory so results can be stacked with NumPy.
            feature_batches.append(embedded.cpu().numpy())
            label_batches.append(batch_labels.cpu().numpy())

    return (
        np.concatenate(feature_batches, axis=0),
        np.concatenate(label_batches, axis=0),
    )


# Preprocess for feature extraction

In [3]:
def integrated_feature_extraction(data_type, model, model_path, smiles_df, device, batch_size, tokenizer, atom_numbers, mode='test', Preprocess=None, scale_path=None):
    """Load trained weights into ``model`` and extract features for the molecules in ``smiles_df``.

    Args:
        data_type: Representation to use: ``'1d'`` (SMILES sequence dataset),
            ``'2d'`` or ``'3d'`` (molecular graphs).
        model: Model instance whose ``state_dict`` matches the checkpoint at ``model_path``.
        model_path: Path to the saved ``state_dict``.
        smiles_df: DataFrame with ``'SMILES'`` and ``'Label'`` columns.
        device: Device to run on; also used as ``map_location`` when loading the checkpoint.
        batch_size: Batch size of the data loader.
        tokenizer: Passed to ``SMILESDataset`` (``'1d'`` only).
        atom_numbers: Passed to the 2D/3D graph preprocessing (unused for ``'1d'``).
        mode, Preprocess, scale_path: Forwarded to ``get_torch_graph_data_list``
            (``'2d'``/``'3d'`` only).

    Returns:
        ``(features, labels)`` as returned by ``extract_features``. Batches are never
        shuffled, so the rows keep the order of the underlying dataset and repeated
        calls return identical results.

    Raises:
        ValueError: If ``data_type`` is not one of ``'1d'``, ``'2d'``, ``'3d'``.
    """
    # Validate first: previously an unknown type loaded the checkpoint and then
    # crashed with UnboundLocalError on the never-assigned loader variable.
    if data_type not in ('1d', '2d', '3d'):
        raise ValueError(f"Unsupported data type: {data_type}")

    # map_location lets a checkpoint saved on one device (e.g. GPU) load on another (e.g. CPU).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)

    data_loader = _build_feature_loader(
        data_type, smiles_df, batch_size, tokenizer, atom_numbers, mode, Preprocess, scale_path
    )
    return extract_features(model, data_loader, device, data_type)


def _build_feature_loader(data_type, smiles_df, batch_size, tokenizer, atom_numbers, mode, Preprocess, scale_path):
    """Build an unshuffled data loader over ``smiles_df`` for an already-validated ``data_type``."""
    if data_type == '1d':
        smiles_list = smiles_df['SMILES'].tolist()
        # Labels are binarised: 'Positive' -> 1, any other value -> 0.
        labels = smiles_df['Label'].apply(lambda x: 1 if x == 'Positive' else 0).tolist()
        dataset = SMILESDataset(smiles_list, labels, tokenizer)
        return DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # 2D and 3D share one pipeline and differ only in the graph builder.
    # NOTE(review): labels are passed through unencoded here, unlike the 1D branch —
    # confirm the graph preprocessing expects the raw 'Label' values.
    build_graphs = preprocess_2d_graph_data if data_type == '2d' else preprocess_3d_graph
    graph_data = build_graphs(smiles_df["SMILES"].values, smiles_df['Label'].values, atom_numbers)
    torch_graph_data_list = get_torch_graph_data_list(graph_data, mode, Preprocess, scale_path)
    # shuffle=False: extraction must be deterministic and keep dataset order (matches the 1D branch).
    return GeometricDataLoader(torch_graph_data_list, batch_size=batch_size, shuffle=False)
