In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
class ChainModel(nn.Module):
    """Chained classifier over six taxonomic ranks.

    A shared MLP encodes the input once. The phylum head reads the shared
    features alone; each deeper head (class, order, family, genus, species)
    reads the shared features concatenated with an embedding of the label of
    the rank directly above it.

    The sub-module attribute names double as ``state_dict`` keys, so they must
    stay unchanged for existing checkpoints to keep loading.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        embed_dim: int,
        num_phylum: int,
        num_class: int,
        num_order: int,
        num_family: int,
        num_genus: int,
        num_species: int,
    ):
        """Build the shared encoder plus one head (and label embedding) per rank.

        Args:
            input_dim: Width of the raw input features.
            hidden_dim: Width of the shared hidden representation.
            embed_dim: Width of every parent-label embedding.
            num_phylum .. num_species: Number of labels at each rank.
        """
        super().__init__()

        # Every head below phylum sees the shared features plus one embedding.
        chained_dim = hidden_dim + embed_dim

        # Shared feature extractor
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Rank 1: phylum (no parent, so the head reads shared features only)
        self.phylum_head = nn.Linear(hidden_dim, num_phylum)
        self.phylum_embed = nn.Embedding(num_phylum, embed_dim)

        # Rank 2: class
        self.class_head = nn.Linear(chained_dim, num_class)
        self.class_embed = nn.Embedding(num_class, embed_dim)

        # Rank 3: order
        self.order_head = nn.Linear(chained_dim, num_order)
        self.order_embed = nn.Embedding(num_order, embed_dim)

        # Rank 4: family
        self.family_head = nn.Linear(chained_dim, num_family)
        self.family_embed = nn.Embedding(num_family, embed_dim)

        # Rank 5: genus
        self.genus_head = nn.Linear(chained_dim, num_genus)
        self.genus_embed = nn.Embedding(num_genus, embed_dim)

        # Rank 6: species (last rank, nothing is conditioned on it -> no embedding)
        self.species_head = nn.Linear(chained_dim, num_species)

    @staticmethod
    def _parent_embedding(embed, logits, y, column, teacher_forcing):
        """Embed the label that conditions the next rank.

        Uses the ground-truth column ``y[:, column]`` when labels are given and
        teacher forcing is on; otherwise uses the argmax of ``logits``.
        """
        if (y is not None) and teacher_forcing:
            labels = y[:, column]
        else:
            labels = logits.argmax(dim=1)
        return embed(labels)

    def forward(
        self,
        x: torch.Tensor,
        y=None,
        teacher_forcing: bool = True,
    ):
        """Compute the logits of all six ranks.

        Args:
            x: [batch_size, input_dim] raw input features.
            y: Optional [batch_size, 6] ground-truth labels ordered as
                (phylum, class, order, family, genus, species). Only the first
                five columns are read, because nothing is conditioned on species.
            teacher_forcing: If True and ``y`` is given, each rank is
                conditioned on the true label of the rank above it; otherwise
                it is conditioned on that rank's own argmax prediction.

        Returns:
            Tuple ``(logits_phylum, logits_class, logits_order, logits_family,
            logits_genus, logits_species)``.
        """
        shared = self.shared_layers(x)

        # Rank 1: phylum
        logits_phylum = self.phylum_head(shared)
        parent = self._parent_embedding(
            self.phylum_embed, logits_phylum, y, 0, teacher_forcing
        )

        # Rank 2: class
        logits_class = self.class_head(torch.cat([shared, parent], dim=1))
        parent = self._parent_embedding(
            self.class_embed, logits_class, y, 1, teacher_forcing
        )

        # Rank 3: order
        logits_order = self.order_head(torch.cat([shared, parent], dim=1))
        parent = self._parent_embedding(
            self.order_embed, logits_order, y, 2, teacher_forcing
        )

        # Rank 4: family
        logits_family = self.family_head(torch.cat([shared, parent], dim=1))
        parent = self._parent_embedding(
            self.family_embed, logits_family, y, 3, teacher_forcing
        )

        # Rank 5: genus
        logits_genus = self.genus_head(torch.cat([shared, parent], dim=1))
        parent = self._parent_embedding(
            self.genus_embed, logits_genus, y, 4, teacher_forcing
        )

        # Rank 6: species
        logits_species = self.species_head(torch.cat([shared, parent], dim=1))

        return (
            logits_phylum,
            logits_class,
            logits_order,
            logits_family,
            logits_genus,
            logits_species,
        )

In [None]:
# Restore the complete pickled module (the result is used directly as a model
# below, not as a state_dict), so the ChainModel class cell must run first.
#
# weights_only=False: from PyTorch 2.6 on, torch.load defaults to
#   weights_only=True and refuses to unpickle arbitrary classes. Unpickling can
#   execute arbitrary code, so only load checkpoints from a trusted source.
#   (The keyword requires PyTorch >= 1.13.)
# map_location='cpu': the inference cell builds its input tensor without an
#   explicit device, so keep the model on the CPU as well; this also lets a
#   checkpoint saved on a GPU load on a machine without one.
model = torch.load('tax_model.pth', map_location='cpu', weights_only=False)

In [None]:
import pickle

# Load the pre-processed input features.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only open pickles that come from a trusted source.
with open('/root/tax/datax_pre.pkl', 'rb') as features_file:
    datax_pre = pickle.load(features_file)

# Display the loaded object as the cell output.
datax_pre

In [None]:
# Turn the loaded features into a float32 tensor for the model.
x_tensor = torch.tensor(datax_pre, dtype=torch.float32)

# Evaluation mode, and no autograd bookkeeping during prediction.
model.eval()
with torch.no_grad():
    # No labels are passed and teacher forcing is off, so ChainModel.forward
    # conditions each rank on the argmax prediction of the rank above it.
    (
        logits_phylum,
        logits_class,
        logits_order,
        logits_family,
        logits_genus,
        logits_species,
    ) = model(x_tensor, teacher_forcing=False)

    # Highest-scoring label index per sample at every rank.
    pred_phylum = logits_phylum.argmax(dim=1)
    pred_class = logits_class.argmax(dim=1)
    pred_order = logits_order.argmax(dim=1)
    pred_family = logits_family.argmax(dim=1)
    pred_genus = logits_genus.argmax(dim=1)
    pred_species = logits_species.argmax(dim=1)


In [None]:
# Combine the per-rank predictions into one tensor with one row per sample and
# one column per rank: (phylum, class, order, family, genus, species).
pred_data = torch.stack(
    [
        pred_phylum,
        pred_class,
        pred_order,
        pred_family,
        pred_genus,
        pred_species,
    ],
    dim=1,
)
pred_data

In [None]:
# Persist the stacked predictions (`pickle` is imported in an earlier cell).
with open('/root/tax/pred_data.pkl', 'wb') as predictions_file:
    pickle.dump(pred_data, predictions_file)