In [None]:
import os 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
import numpy as np
from tqdm import tqdm
from GCNFrame import Biodata
import pandas as pd
from sklearn.preprocessing import LabelEncoder, StandardScaler
import umap
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm



In [None]:
# Train a GCNFrame model on the TaxoChallenge sequences. The model is built
# with label_num=20 and no extra per-sequence features (other_feature_dim=0).
# NOTE(review): the label_file used below is the same path that
# prepare_phylum_labels writes its numeric labels to, and that function is
# defined and called in a later cell; on a fresh kernel/machine that cell has
# to be run before this one.
from GCNFrame import Biodata, GCNmodel
import torch

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Pair the FASTA sequences with their numeric labels and encode them.
data = Biodata(
    fasta_file="/root/autodl-tmp/TaxoChallenge/ICVTTaxoChallenge_43587.fa",
    label_file="/root/autodl-tmp/TaxoChallenge/phylum_numeric_labels.txt",
    feature_file=None,
)
dataset = data.encode(thread=20)

# Build the model, move it to the chosen device, then train it; the trained
# model is written to the path given as model_name.
model = GCNmodel.model(label_num=20, other_feature_dim=0)
model = model.to(device)
GCNmodel.train(
    dataset,
    model,
    weighted_sampling=True,
    batch_size=16,
    model_name="/root/autodl-tmp/TaxoChallenge/GCN_model_43587.pt",
)


In [None]:
# The phylum label file is read as tab-separated text with no header row and
# two columns: sequence ID, then phylum name.
def prepare_phylum_labels(label_file,
                          numeric_labels_path='/root/autodl-tmp/TaxoChallenge/phylum_numeric_labels.txt',
                          label_mapping_path='/root/autodl-tmp/TaxoChallenge/phylum_label_mapping.txt'):
    """Encode phylum names as integer labels and write them to disk.

    Args:
        label_file: Path to a tab-separated file without a header whose two
            columns are read as ``sequence_id`` and ``phylum``.
        numeric_labels_path: Where the integer labels are written, one per
            line, in the same row order as ``label_file``.
        label_mapping_path: Where the ``phylum<TAB>index`` mapping is written,
            one class per line.

    Returns:
        A ``(numeric_labels, label_mapping)`` tuple: the array of integer
        labels produced by ``LabelEncoder`` and a dict mapping each phylum
        name to its integer index.
    """
    # NOTE(review): pandas parses empty fields and strings such as "NA" as
    # missing values by default; confirm the input has no such rows, or check
    # how LabelEncoder treats them in the installed scikit-learn version.
    labels_df = pd.read_csv(label_file, sep='\t', header=None,
                            names=['sequence_id', 'phylum'])

    # Convert text labels to numeric labels.
    label_encoder = LabelEncoder()
    numeric_labels = label_encoder.fit_transform(labels_df['phylum'])

    # classes_ is ordered by class index, so enumerating it gives the mapping
    # needed to interpret predictions later.
    label_mapping = {phylum: idx for idx, phylum in enumerate(label_encoder.classes_)}

    # Numeric labels, one integer per line.
    np.savetxt(numeric_labels_path, numeric_labels, fmt='%d')

    # Human-readable phylum -> index table.
    with open(label_mapping_path, 'w') as f:
        for phylum, idx in label_mapping.items():
            f.write(f"{phylum}\t{idx}\n")

    return numeric_labels, label_mapping

# Encode the TaxoChallenge ID/phylum table and write the label files.
prepare_phylum_labels(
    label_file="/root/autodl-tmp/TaxoChallenge/ICTV_TaxoChallenge_id_phylum.txt"
)


In [None]:

# BipartiteData class
class BipartiteData(Data):
    """``Data`` subclass whose ``edge_index`` rows index two different tensors.

    The first row of ``edge_index`` is offset by the number of rows in
    ``x_src`` and the second row by the number of rows in ``x_dst`` (see
    ``__inc__``).
    """

    def _add_other_feature(self, other_feature):
        """Store ``other_feature`` on this sample as ``self.other_feature``."""
        self.other_feature = other_feature

    def __inc__(self, key, value, *args, **kwargs):
        """Return the increment applied to attribute ``key`` when batching.

        For ``edge_index`` the increment is a 2x1 tensor holding the row
        counts of ``x_src`` and ``x_dst``; every other key defers to
        ``Data.__inc__``.

        ``*args``/``**kwargs`` are accepted and forwarded because PyTorch
        Geometric 2.x calls ``__inc__`` with an extra positional argument;
        with the plain ``(key, value)`` signature that call raises
        ``TypeError``. Older versions pass nothing extra, so forwarding is a
        no-op there.
        """
        if key == 'edge_index':
            return torch.tensor([[self.x_src.size(0)], [self.x_dst.size(0)]])
        return super().__inc__(key, value, *args, **kwargs)

In [None]:
def get_matrix_embeddings_from_model(model, data, device):
    """Run the model's node, graph-convolution and convolution layers on ``data``.

    The model is switched to eval mode and everything runs under
    ``torch.no_grad()``. No layer after ``model.convs`` is applied, so the
    result is the intermediate matrix representation rather than a prediction.

    Args:
        model: Object exposing the attributes read below (``pnode_nn``,
            ``fnode_nn``, ``gconvs_1``, ``gconvs_2``, ``lns``, ``convs`` and
            the associated size attributes).
        data: Sample or batch providing ``x_src``, ``x_dst`` and
            ``edge_index``; ``y`` is optional and may be ``None``.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(embeddings, labels)``: the embeddings on the CPU, reshaped to
        ``(-1, model.gcn_dim, model.pnode_num)`` before the ``model.convs``
        layers are applied, and ``data.y`` on the CPU, or ``None`` when the
        data carries no labels.
    """
    model.eval()
    with torch.no_grad():
        # Move data to the specified device.
        x_f = data.x_src.to(device)
        x_p = data.x_dst.to(device)
        # Even-numbered edge columns are used as stored; odd-numbered columns
        # are used with their two rows swapped.
        edge_index_forward = data.edge_index[:, ::2].to(device)
        edge_index_backward = data.edge_index[[1, 0], :][:, 1::2].to(device)

        # `y` can be present but None, so test the value, not the attribute.
        labels = getattr(data, 'y', None)

        # Process primary node features.
        if model.pnode_nn:
            x_p = torch.reshape(x_p, (-1, model.pnode_num * model.pnode_dim))
            x_p = model.pnode_d(x_p)
            x_p = torch.reshape(x_p, (-1, model.node_hidden_dim))
        else:
            x_p = torch.reshape(x_p, (-1, model.pnode_dim))

        # Process feature node features.
        if model.fnode_nn:
            x_f = torch.reshape(x_f, (-1, model.fnode_num))
            x_f = model.fnode_d(x_f)
            x_f = torch.reshape(x_f, (-1, model.node_hidden_dim))
        else:
            x_f = torch.reshape(x_f, (-1, 1))

        # Add label embeddings (only if the model has them and labels exist).
        if hasattr(model, 'label_embedding') and labels is not None:
            # Labels must live on the same device as the embedding weights.
            label_embedding = model.label_embedding(labels.to(device))
            # NOTE(review): x_p was reshaped to 2-D above while the expanded
            # label embedding is 3-D — confirm these shapes broadcast as
            # intended for models that define `label_embedding`.
            x_p = x_p + label_embedding.unsqueeze(1).expand(-1, x_p.size(1), -1)

        # Graph convolution layers: update primary nodes from feature nodes,
        # then feature nodes from the updated primary nodes.
        last_layer = model.gcn_layer_num - 1
        for i in range(model.gcn_layer_num):
            x_p = model.gconvs_1[i]((x_f, x_p), edge_index_forward)
            x_p = F.relu(x_p)
            x_f = model.gconvs_2[i]((x_p, x_f), edge_index_backward)
            x_f = F.relu(x_f)
            # The same lns[i] module is applied to both node sets, and it is
            # skipped after the final graph-convolution layer.
            if i != last_layer:
                x_p = model.lns[i](x_p)
                x_f = model.lns[i](x_f)

        # Convert to matrix form.
        x_p = torch.reshape(x_p, (-1, model.gcn_dim, model.pnode_num))
        for i in range(model.cnn_layer_num):
            x_p = model.convs[i](x_p)
            x_p = F.relu(x_p)

        # Return matrix form and labels (None when the data has no labels).
        labels_cpu = labels.cpu() if labels is not None else None
        return x_p.cpu(), labels_cpu


def _embed_dataset_once(model, dataset, fasta_ids, batch_size, device):
    """Make one pass over ``dataset`` and map each sequence ID to its embedding."""
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,  # keep dataset order so IDs line up by position
                        follow_batch=['x_src', 'x_dst'])

    embeddings_dict = {}
    next_id = 0  # position in fasta_ids of the next sample to store

    with torch.no_grad():
        for batch in tqdm(loader, desc="Processing batches"):
            # Labels are returned as well but are not needed here.
            embedding, _ = get_matrix_embeddings_from_model(model, batch, device)

            # One row of `embedding` per sample, in loader order.
            for idx in range(embedding.shape[0]):
                embeddings_dict[fasta_ids[next_id]] = embedding[idx].numpy()
                next_id += 1

            torch.cuda.empty_cache()

    return embeddings_dict


def get_matrix_dataset_embeddings(dataset, model_path, fasta_ids, batch_size=8, device=None):
    """Compute matrix embeddings for a whole dataset, keyed by sequence ID.

    Args:
        dataset: Samples to embed; iterated in order without shuffling.
        model_path: Path of a model saved with ``torch.save``.
        fasta_ids: Sequence IDs indexed by position; entry ``i`` becomes the
            key for the ``i``-th embedding produced.
        batch_size: Initial batch size. It is halved and the whole pass is
            restarted whenever a ``RuntimeError`` is raised.
        device: Device to run on; defaults to CUDA when available, else CPU.

    Returns:
        Dict mapping each sequence ID to its embedding as a NumPy array.

    Raises:
        ValueError: If ``fasta_ids`` has fewer entries than ``dataset``.
        RuntimeError: If a pass still fails with a batch size of 1.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # IDs are matched to samples purely by position, so check the counts up
    # front instead of failing with an IndexError part-way through a run.
    n_samples = len(dataset)
    n_ids = len(fasta_ids)
    if n_ids < n_samples:
        raise ValueError(
            f"fasta_ids has {n_ids} entries but the dataset has {n_samples} samples")
    if n_ids > n_samples:
        print(f"Warning: fasta_ids has {n_ids} entries but the dataset has only "
              f"{n_samples} samples; the extra IDs will be unused.")

    print(f"Loading model from {model_path}...")
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only load model files from a trusted source.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    model.eval()

    print(f"Processing dataset with {n_samples} samples...")

    # The model is loaded once; only the pass over the data is repeated.
    # Retrying in a loop (not recursively from inside the except block) lets
    # the failed attempt's state be released before the next attempt starts.
    while True:
        try:
            return _embed_dataset_once(model, dataset, fasta_ids, batch_size, device)
        except RuntimeError as e:
            # Any RuntimeError triggers a retry, not only out-of-memory ones.
            print(f"Error processing batch: {e}")
            if batch_size <= 1:
                raise
            print("Reducing batch size and retrying...")
            batch_size //= 2
        torch.cuda.empty_cache()





In [None]:
# ---- Save the embeddings as a .pkl file ----
import pickle

save_path = '/root/autodl-tmp/TaxoChallenge/embeddings_with_phylum_43587_matrix.pkl'

# NOTE(review): `matrix_embeddings` is not assigned in any cell visible here —
# presumably it is the dict returned by get_matrix_dataset_embeddings(...);
# confirm a cell that defines it runs before this one.
with open(save_path, 'wb') as pkl_file:
    pickle.dump(matrix_embeddings, pkl_file)