In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
import numpy as np
from tqdm import tqdm
from GCNFrame import Biodata, GCNmodel
import pandas as pd
from sklearn.preprocessing import LabelEncoder, StandardScaler
import umap
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from Bio import SeqIO
import pickle

In [None]:
# Encode phylum names into integer class ids and persist both the ids and the name->id mapping.
def prepare_phylum_labels(label_file,
                          labels_out='/root/autodl-tmp/TaxoChallenge/phylum_numeric_labels.txt',
                          mapping_out='/root/autodl-tmp/TaxoChallenge/phylum_label_mapping.txt'):
    """Convert a ``sequence_id<TAB>phylum`` table into integer labels.

    Args:
        label_file: headerless, tab-separated file with two columns (sequence id, phylum name).
        labels_out: where to write one integer label per input row (row order preserved).
        mapping_out: where to write ``phylum<TAB>id`` lines for every encoded phylum.

    Returns:
        ``(numeric_labels, label_mapping)`` - a numpy int array aligned with the rows of
        ``label_file`` and a dict mapping each phylum name to its integer id. Ids follow
        ``LabelEncoder``'s sorted ``classes_`` order, so they run from 0 to n_classes - 1.
    """
    labels_df = pd.read_csv(label_file, sep='\t', header=None,
                            names=['sequence_id', 'phylum'])

    label_encoder = LabelEncoder()
    numeric_labels = label_encoder.fit_transform(labels_df['phylum'])

    # classes_[i] is the phylum encoded as i
    label_mapping = {phylum: idx for idx, phylum in enumerate(label_encoder.classes_)}

    # The numeric-label file is consumed downstream (Biodata label_file), one id per line.
    np.savetxt(labels_out, numeric_labels, fmt='%d')

    with open(mapping_out, 'w', encoding='utf-8') as f:
        for phylum, idx in label_mapping.items():
            f.write(f"{phylum}\t{idx}\n")

    return numeric_labels, label_mapping
phylum_numeric_labels, phylum_label_mapping = prepare_phylum_labels(
    "/root/autodl-tmp/TaxoChallenge/ICTV_TaxoChallenge_id_phylum.txt")
# Display only the phylum -> id mapping instead of dumping the per-sequence label array.
phylum_label_mapping

In [None]:

################## Train a multi-class phylum classifier (one output per encoded phylum) #################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

label_file = "/root/autodl-tmp/TaxoChallenge/phylum_numeric_labels.txt"
# Labels are 0..K-1 integer ids written by prepare_phylum_labels; size the classifier head
# from the data instead of a hard-coded class count so it always matches the label file.
label_num = int(np.loadtxt(label_file, dtype=int).max()) + 1
print(f"Training with {label_num} classes")

data = Biodata(fasta_file="/root/autodl-tmp/TaxoChallenge/ICVTTaxoChallenge_43587.fa",
               label_file=label_file,
               feature_file=None)
dataset = data.encode(thread=20)
model = GCNmodel.model(label_num=label_num, other_feature_dim=0).to(device)
GCNmodel.train(dataset, model, weighted_sampling=True, batch_size=16, model_name="/root/autodl-tmp/TaxoChallenge/GCN_model_43587.pt")


In [3]:

# BipartiteData Class
class BipartiteData(Data):
    """Graph sample with two node sets: source nodes ``x_src`` and destination nodes ``x_dst``.

    Row 0 of ``edge_index`` indexes into ``x_src`` and row 1 into ``x_dst``, so when samples
    are collated into a batch each row must be shifted by its own node count (see ``__inc__``).
    """

    def _add_other_feature(self, other_feature):
        """Attach an extra per-sample feature under ``other_feature``."""
        self.other_feature = other_feature

    def __inc__(self, key, value, *args, **kwargs):
        """Return the batching increment for attribute ``key``.

        PyG >= 2.0 may call ``__inc__`` with extra positional/keyword arguments; they are
        accepted and forwarded so collation works on both old and new PyG versions.
        """
        if key == 'edge_index':
            # Offset source indices by #x_src and destination indices by #x_dst.
            return torch.tensor([[self.x_src.size(0)], [self.x_dst.size(0)]])
        return super().__inc__(key, value, *args, **kwargs)

In [None]:
###########################extract vector embedding##################################
def get_flattened_embeddings_from_model(model, data, device):
    """Run one batch through the model's encoder stages and return flattened embeddings.

    Applies the model's own sub-modules step by step (node feature layers, bipartite GCN
    layers, then CNN layers) and flattens the CNN output per graph.
    NOTE(review): this presumably mirrors the encoder part of the GCNFrame model's forward
    pass - keep it in sync if the model changes.

    Args:
        model: trained model exposing the attributes read below (``pnode_nn``, ``fnode_nn``,
            ``gconvs_1``, ``gconvs_2``, ``lns``, ``convs``, and the size attributes).
        data: a collated batch with ``x_src``, ``x_dst``, ``edge_index`` and optionally ``y``.
        device: device to run on; should be the device the model lives on.

    Returns:
        ``(embeddings, labels)``: ``embeddings`` is a CPU tensor with one flattened row per
        graph (assuming each graph contributes ``model.pnode_num`` primary nodes);
        ``labels`` is ``data.y`` on CPU, or ``None`` when the batch carries no labels.
    """
    model.eval()
    # PyG Data/Batch objects expose ``y`` even when it was never set (value None), so
    # hasattr() is not a usable "has labels" test - check the value instead.
    labels = getattr(data, 'y', None)

    with torch.no_grad():
        x_f = data.x_src.to(device)
        x_p = data.x_dst.to(device)
        # Even edge columns are used src->dst; odd columns, with rows swapped, dst->src.
        edge_index_forward = data.edge_index[:, ::2].to(device)
        edge_index_backward = data.edge_index[[1, 0], :][:, 1::2].to(device)

        # primary node feature processing
        if model.pnode_nn:
            x_p = torch.reshape(x_p, (-1, model.pnode_num * model.pnode_dim))
            x_p = model.pnode_d(x_p)
            x_p = torch.reshape(x_p, (-1, model.node_hidden_dim))
        else:
            x_p = torch.reshape(x_p, (-1, model.pnode_dim))

        # feature node processing
        if model.fnode_nn:
            x_f = torch.reshape(x_f, (-1, model.fnode_num))
            x_f = model.fnode_d(x_f)
            x_f = torch.reshape(x_f, (-1, model.node_hidden_dim))
        else:
            x_f = torch.reshape(x_f, (-1, 1))

        # Optional label embedding (only if the model defines one).
        # NOTE(review): x_p is 2-D at this point, so this 3-D broadcast looks suspect, and
        # feeding true labels into an embedding leaks label information - confirm intent.
        if hasattr(model, 'label_embedding') and labels is not None:
            label_embedding = model.label_embedding(labels.to(device))
            x_p = x_p + label_embedding.unsqueeze(1).expand(-1, x_p.size(1), -1)

        # GCN layers: alternate src->dst and dst->src message passing
        for i in range(model.gcn_layer_num):
            x_p = model.gconvs_1[i]((x_f, x_p), edge_index_forward)
            x_p = F.relu(x_p)
            x_f = model.gconvs_2[i]((x_p, x_f), edge_index_backward)
            x_f = F.relu(x_f)
            if i != model.gcn_layer_num - 1:
                x_p = model.lns[i](x_p)
                x_f = model.lns[i](x_f)

        # convolutional layers over the (gcn_dim, pnode_num) view of each graph
        x_p = torch.reshape(x_p, (-1, model.gcn_dim, model.pnode_num))
        for i in range(model.cnn_layer_num):
            x_p = model.convs[i](x_p)
            x_p = F.relu(x_p)

        flattened_embedding = x_p.flatten(start_dim=1)

    return flattened_embedding.cpu(), (labels.cpu() if labels is not None else None)


def get_flattened_dataset_embeddings(dataset, model_path, batch_size=8, device=None):
    """Embed every sample of ``dataset`` with a saved model and return flattened vectors.

    Samples are processed in dataset order (no shuffling) via
    ``get_flattened_embeddings_from_model``. If a batch runs out of memory, the whole pass
    restarts with half the batch size.

    Args:
        dataset: sequence of graph samples (e.g. the output of ``Biodata.encode``).
        model_path: checkpoint holding a whole pickled model object (not a state_dict),
            since model attributes such as ``pnode_nn`` are read from it.
        batch_size: initial number of graphs per batch.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        ``(embeddings, labels)``: ``embeddings`` is a CPU tensor with one row per sample;
        ``labels`` concatenates the per-batch labels, or is ``None`` if no batch had labels.

    Raises:
        RuntimeError: for non-memory errors, or an out-of-memory error at ``batch_size == 1``.
    """
    import inspect  # local: only needed for the torch.load compatibility check

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Loading model from {model_path}...")
    load_kwargs = {'map_location': device}
    # torch >= 2.6 defaults to weights_only=True, which refuses whole pickled model objects.
    # SECURITY: unpickling can execute arbitrary code - only load checkpoints you trust.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(model_path, **load_kwargs)
    model.to(device)
    model.eval()

    print(f"Processing dataset with {len(dataset)} samples...")
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        follow_batch=['x_src', 'x_dst'])

    embeddings = []
    labels = []
    retry_smaller = False

    with torch.no_grad():
        for batch in tqdm(loader, desc="Processing batches"):
            try:
                embedding, label = get_flattened_embeddings_from_model(model, batch, device)
            except RuntimeError as e:
                print(f"Error processing batch: {e}")
                message = str(e).lower()
                is_oom = 'out of memory' in message or "can't allocate memory" in message
                # Only memory errors can be fixed by a smaller batch; anything else would
                # fail identically at every batch size, so surface it immediately.
                if not is_oom or batch_size <= 1:
                    raise
                retry_smaller = True
                break

            embeddings.append(embedding)
            if label is not None:
                labels.append(label)
            torch.cuda.empty_cache()

    if retry_smaller:
        # Retry outside the except block so the exception traceback (which pins the failed
        # batch's device tensors) is released, and drop this model copy before reloading.
        print("Reducing batch size and retrying...")
        del model, embeddings, labels
        torch.cuda.empty_cache()
        return get_flattened_dataset_embeddings(dataset, model_path, batch_size=batch_size // 2, device=device)

    print("Concatenating flattened embeddings...")
    embeddings = torch.cat(embeddings, dim=0)

    if labels:
        return embeddings, torch.cat(labels, dim=0)

    return embeddings, None


In [4]:
# Paths for the 1000-sequence ICTV subset - adjust to your environment.
# NOTE: the embedding cell below only reads `model_path`; it embeds `dataset` from the training cell.
ictv_1000_dir = '/root/autodl-tmp/1000-ICTV'
fasta_file = f'{ictv_1000_dir}/virus_new_sorted_1000.fasta'                 # input FASTA
phylum_file = '/root/autodl-tmp/workspace/ICTV/1000/id-phylum-1000.txt'      # sequence id -> phylum table
model_path = f'{ictv_1000_dir}/GCN_model.pt'                                 # trained GCN checkpoint
save_path = f'{ictv_1000_dir}/embeddings_with_phylum_1000_lowdimension_labels'


In [None]:
############################get vector embedding############################
# NOTE(review): `dataset` is the one encoded in the training cell (from
# ICVTTaxoChallenge_43587.fa), while `model_path` comes from the 1000-ICTV path cell above -
# confirm the checkpoint was trained on data encoded the same way.
flattened_embeddings, labels = get_flattened_dataset_embeddings(dataset, model_path, batch_size=32)
print(f"Flattened embeddings shape: {flattened_embeddings.shape}")
# labels is None when the dataset carries no labels; avoid an AttributeError on .shape.
print(f"Labels shape: {labels.shape if labels is not None else None}")

In [None]:
############################umap Dimensionality Reduction and Visualization################################
def visualize_embeddings_with_umap(flattened_embeddings, labels=None, n_neighbors=15, min_dist=0.1,
                                   n_components=2, metric='euclidean', random_state=None):
    """Project embeddings with UMAP and scatter-plot the first two UMAP dimensions.

    Args:
        flattened_embeddings: 2-D array or tensor, one row per sample.
        labels: optional per-sample values used to colour the points (array or tensor).
        n_neighbors, min_dist, n_components, metric: forwarded to ``umap.UMAP``.
        random_state: forwarded to ``umap.UMAP``. UMAP is stochastic; pass an int for a
            reproducible layout. ``None`` keeps the library default.

    Raises:
        ValueError: if ``n_components < 2`` (a 2-D scatter needs two dimensions).
    """
    # Fail before the (expensive) fit rather than on indexing column 1 afterwards.
    if n_components < 2:
        raise ValueError(f"n_components must be >= 2 for a 2-D plot, got {n_components}")

    # Convert tensors explicitly; UMAP/matplotlib cannot consume GPU tensors.
    if isinstance(flattened_embeddings, torch.Tensor):
        flattened_embeddings = flattened_embeddings.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()

    print("Starting UMAP dimensionality reduction...")
    reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components,
                        metric=metric, random_state=random_state)
    embeddings_2d = reducer.fit_transform(flattened_embeddings)

    print("Visualizing embeddings...")
    fig, ax = plt.subplots(figsize=(10, 8))

    if labels is not None:
        scatter = ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=labels, cmap='Spectral', s=5, alpha=0.7)
        fig.colorbar(scatter, ax=ax, label="Labels")
    else:
        ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], s=5, alpha=0.7)

    ax.set_title("UMAP Visualization of Embeddings")
    ax.set_xlabel("UMAP Dimension 1")
    ax.set_ylabel("UMAP Dimension 2")
    plt.show()

# Plot the UMAP projection of the embeddings, coloured by their labels.
visualize_embeddings_with_umap(flattened_embeddings, labels=labels)

In [11]:
#######################extract matrix embedding##################################
def get_matrix_embeddings_from_model(model, data, device):
    """Run one batch through the model's encoder stages and return un-flattened embeddings.

    Same pipeline as the flattened variant (node feature layers, bipartite GCN layers, CNN
    layers) but returns the CNN output as-is instead of flattening it.
    NOTE(review): this presumably mirrors the encoder part of the GCNFrame model's forward
    pass - keep it in sync if the model changes.

    Args:
        model: trained model exposing the attributes read below (``pnode_nn``, ``fnode_nn``,
            ``gconvs_1``, ``gconvs_2``, ``lns``, ``convs``, and the size attributes).
        data: a collated batch with ``x_src``, ``x_dst``, ``edge_index`` and optionally ``y``.
        device: device to run on; should be the device the model lives on.

    Returns:
        ``(embeddings, labels)``: ``embeddings`` is a CPU tensor whose first dimension is the
        graph index; it starts as ``(num_graphs, gcn_dim, pnode_num)`` and is then reshaped by
        ``model.convs`` (presumably 1-D convolutions, giving (graphs, channels, length)).
        ``labels`` is ``data.y`` on CPU, or ``None`` when the batch carries no labels.
    """
    model.eval()
    # PyG Data/Batch objects expose ``y`` even when it was never set (value None), so
    # hasattr() is not a usable "has labels" test - check the value instead.
    labels = getattr(data, 'y', None)

    with torch.no_grad():
        # move data to device
        x_f = data.x_src.to(device)
        x_p = data.x_dst.to(device)
        # Even edge columns are used src->dst; odd columns, with rows swapped, dst->src.
        edge_index_forward = data.edge_index[:, ::2].to(device)
        edge_index_backward = data.edge_index[[1, 0], :][:, 1::2].to(device)

        # primary node feature processing
        if model.pnode_nn:
            x_p = torch.reshape(x_p, (-1, model.pnode_num * model.pnode_dim))
            x_p = model.pnode_d(x_p)
            x_p = torch.reshape(x_p, (-1, model.node_hidden_dim))
        else:
            x_p = torch.reshape(x_p, (-1, model.pnode_dim))

        # feature node processing
        if model.fnode_nn:
            x_f = torch.reshape(x_f, (-1, model.fnode_num))
            x_f = model.fnode_d(x_f)
            x_f = torch.reshape(x_f, (-1, model.node_hidden_dim))
        else:
            x_f = torch.reshape(x_f, (-1, 1))

        # Optional label embedding (only if the model defines one).
        # NOTE(review): x_p is 2-D at this point, so this 3-D broadcast looks suspect, and
        # feeding true labels into an embedding leaks label information - confirm intent.
        if hasattr(model, 'label_embedding') and labels is not None:
            label_embedding = model.label_embedding(labels.to(device))
            x_p = x_p + label_embedding.unsqueeze(1).expand(-1, x_p.size(1), -1)

        # GCN layers: alternate src->dst and dst->src message passing
        for i in range(model.gcn_layer_num):
            x_p = model.gconvs_1[i]((x_f, x_p), edge_index_forward)
            x_p = F.relu(x_p)
            x_f = model.gconvs_2[i]((x_p, x_f), edge_index_backward)
            x_f = F.relu(x_f)
            if i != model.gcn_layer_num - 1:
                x_p = model.lns[i](x_p)
                x_f = model.lns[i](x_f)

        # convolutional layers over the (gcn_dim, pnode_num) view of each graph
        x_p = torch.reshape(x_p, (-1, model.gcn_dim, model.pnode_num))
        for i in range(model.cnn_layer_num):
            x_p = model.convs[i](x_p)
            x_p = F.relu(x_p)

    return x_p.cpu(), (labels.cpu() if labels is not None else None)


def get_matrix_dataset_embeddings(dataset, model_path, fasta_ids, batch_size=8, device=None):
    """Embed every sample of ``dataset`` and key each matrix embedding by its FASTA ID.

    Samples are processed in dataset order (no shuffling) via
    ``get_matrix_embeddings_from_model`` and paired positionally with ``fasta_ids``, so
    ``fasta_ids[i]`` must name ``dataset[i]``. If a batch runs out of memory, the whole pass
    restarts with half the batch size.

    Args:
        dataset: sequence of graph samples (e.g. the output of ``Biodata.encode``).
        model_path: checkpoint holding a whole pickled model object (not a state_dict),
            since model attributes such as ``pnode_nn`` are read from it.
        fasta_ids: one ID per dataset sample, in the same order.
        batch_size: initial number of graphs per batch.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        dict mapping each FASTA ID to a numpy array with that sample's matrix embedding.

    Raises:
        ValueError: if ``len(fasta_ids) != len(dataset)`` (embeddings would be mis-keyed).
        RuntimeError: for non-memory errors, or an out-of-memory error at ``batch_size == 1``.
    """
    import inspect  # local: only needed for the torch.load compatibility check

    if len(fasta_ids) != len(dataset):
        raise ValueError(
            f"fasta_ids has {len(fasta_ids)} entries but dataset has {len(dataset)} samples; "
            "embeddings would be keyed to the wrong sequences")
    n_duplicates = len(fasta_ids) - len(set(fasta_ids))
    if n_duplicates:
        print(f"Warning: {n_duplicates} duplicate fasta IDs; later embeddings will overwrite earlier ones")

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Loading model from {model_path}...")
    load_kwargs = {'map_location': device}
    # torch >= 2.6 defaults to weights_only=True, which refuses whole pickled model objects.
    # SECURITY: unpickling can execute arbitrary code - only load checkpoints you trust.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(model_path, **load_kwargs)
    model.to(device)
    model.eval()

    print(f"Processing dataset with {len(dataset)} samples...")
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        follow_batch=['x_src', 'x_dst'])

    embeddings_dict = {}
    retry_smaller = False

    with torch.no_grad():
        fasta_id_index = 0  # position in fasta_ids of the next sample to store

        for batch in tqdm(loader, desc="Processing batches"):
            try:
                embedding, _ = get_matrix_embeddings_from_model(model, batch, device)
            except RuntimeError as e:
                print(f"Error processing batch: {e}")
                message = str(e).lower()
                is_oom = 'out of memory' in message or "can't allocate memory" in message
                # Only memory errors can be fixed by a smaller batch; anything else would
                # fail identically at every batch size, so surface it immediately.
                if not is_oom or batch_size <= 1:
                    raise
                retry_smaller = True
                break

            # One entry per graph in the batch, keyed by its FASTA ID.
            for sample_embedding in embedding:
                embeddings_dict[fasta_ids[fasta_id_index]] = sample_embedding.numpy()
                fasta_id_index += 1

            torch.cuda.empty_cache()

    if retry_smaller:
        # Retry outside the except block so the exception traceback (which pins the failed
        # batch's device tensors) is released, and drop this model copy before reloading.
        print("Reducing batch size and retrying...")
        del model, embeddings_dict
        torch.cuda.empty_cache()
        return get_matrix_dataset_embeddings(dataset, model_path, fasta_ids, batch_size=batch_size // 2, device=device)

    return embeddings_dict

In [None]:
# Paths for the full TaxoChallenge run - adjust to your environment.
taxo_dir = '/root/autodl-tmp/TaxoChallenge'
fasta_file = f'{taxo_dir}/ICVTTaxoChallenge_43587.fa'              # input FASTA
phylum_file = f'{taxo_dir}/ICTV_TaxoChallenge_id_phylum.txt'       # sequence id -> phylum table
model_path = f'{taxo_dir}/GCN_model_43587.pt'                      # trained GCN checkpoint
save_path = f'{taxo_dir}/embeddings_with_phylum_dict_matrix.pt'

In [None]:
######################### extract FASTA record IDs #########################
# IDs are collected in file order; they are paired positionally with dataset samples below.
fasta_file = "/root/autodl-tmp/TaxoChallenge/ICVTTaxoChallenge_43587.fa"
fasta_ids = list(rec.id for rec in SeqIO.parse(fasta_file, "fasta"))
print(f"Extracted {len(fasta_ids)} IDs from {fasta_file}")
print(f"First 5 IDs: {fasta_ids[:5]}")

In [None]:
#########################extract matrix embedding result##################################
matrix_embeddings = get_matrix_dataset_embeddings(dataset, model_path, fasta_ids, batch_size=32)
# Summarise instead of printing the whole dict (tens of thousands of matrices).
print(f"Number of matrix embeddings: {len(matrix_embeddings)}")
if matrix_embeddings:
    example_id = next(iter(matrix_embeddings))
    print(f"Matrix embedding shape (e.g. {example_id}): {matrix_embeddings[example_id].shape}")

In [13]:
######################### save as .pkl #########################
# Persist the {fasta_id: embedding matrix} dict with pickle (load only from trusted sources).
save_path = '/root/autodl-tmp/TaxoChallenge/embeddings_with_phylum_43587_matrix.pkl'
with open(save_path, mode='wb') as out_handle:
    pickle.dump(matrix_embeddings, out_handle)