In [12]:
import pickle

# Load the two pickled arrays from the temporary data directory.
# `with` guarantees each file handle is closed as soon as it has been read.
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only load these files if they come from a trusted source.
with open("../data/tmp/datax", "rb") as datax_file:
    datax = pickle.load(datax_file)
with open("../data/tmp/datay", "rb") as datay_file:
    datay = pickle.load(datay_file)
datax

array([[-0.52269083, -0.22415945, -0.33585399, ...,  0.69791514,
         1.19434575,  0.81851722],
       [ 0.12289801,  0.23949905,  0.36805204, ...,  1.05240722,
         1.19434575,  0.81851722],
       [-0.61547699, -1.04043621, -0.2512704 , ...,  0.69791514,
         0.59526365,  0.25281025],
       ...,
       [ 1.52002204,  0.64506514,  0.12458575, ...,  1.05240722,
         1.02465893,  0.81851722],
       [ 1.19234002,  0.9515723 ,  1.34588799, ..., -0.7473218 ,
        -0.3527519 ,  0.81851722],
       [ 1.03877276,  1.59472997,  1.1641744 , ..., -0.7473218 ,
         0.57614401,  0.81851722]])

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
# Convert the loaded arrays into tensors (torch.tensor always copies the data).
# Labels are stored as 64-bit integers, features as 32-bit floats.
y_tensor = torch.tensor(datay, dtype=torch.int64)
x_tensor = torch.tensor(datax, dtype=torch.float)
y_tensor

tensor([[    8,    18,    78,   295,  1482,  7483],
        [    8,    18,    52,    41,   595,  3130],
        [    8,    18,    78,   295,   704,  3690],
        ...,
        [   17,     5,    86,   334,  3478, 16877],
        [   17,     5,    86,   136,  1749,  8798],
        [   17,     5,    86,   136,  2915, 14337]])

In [15]:
# Pair every feature row with its label row, then serve the pairs in
# shuffled mini-batches of 16 samples.
dataset = TensorDataset(x_tensor, y_tensor)
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

In [16]:
class ChainModel(nn.Module):
    """Chained classifier over six hierarchical label levels.

    The levels are, in order: phylum, class, order, family, genus, species.
    A shared two-layer MLP encodes the input. The phylum head reads the shared
    features alone; every later head reads the shared features concatenated
    with an embedding of the label one level above it. That parent label is
    the ground truth when teacher forcing is active, and the argmax of the
    parent's logits otherwise.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        embed_dim: int,
        num_phylum: int,
        num_class: int,
        num_order: int,
        num_family: int,
        num_genus: int,
        num_species: int,
    ):
        """Create the shared trunk plus one head (and parent embedding) per level.

        Args:
            input_dim: number of input features per sample.
            hidden_dim: width of both shared hidden layers.
            embed_dim: size of the embedding used to feed a level's label
                into the next level's head.
            num_phylum, num_class, num_order, num_family, num_genus,
            num_species: number of distinct labels at each level.
        """
        super().__init__()
        # NOTE: attribute names and creation order are kept stable on purpose;
        # they determine the state_dict keys and the order in which the RNG is
        # consumed during initialisation.

        # Shared feature extractor
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Level 1: phylum (reads the shared features only)
        self.phylum_head = nn.Linear(hidden_dim, num_phylum)
        self.phylum_embed = nn.Embedding(num_phylum, embed_dim)

        # Level 2: class
        self.class_head = nn.Linear(hidden_dim + embed_dim, num_class)
        self.class_embed = nn.Embedding(num_class, embed_dim)

        # Level 3: order
        self.order_head = nn.Linear(hidden_dim + embed_dim, num_order)
        self.order_embed = nn.Embedding(num_order, embed_dim)

        # Level 4: family
        self.family_head = nn.Linear(hidden_dim + embed_dim, num_family)
        self.family_embed = nn.Embedding(num_family, embed_dim)

        # Level 5: genus
        self.genus_head = nn.Linear(hidden_dim + embed_dim, num_genus)
        self.genus_embed = nn.Embedding(num_genus, embed_dim)

        # Level 6: species (last level, so it needs no embedding of its own)
        self.species_head = nn.Linear(hidden_dim + embed_dim, num_species)

    def forward(
        self,
        x: torch.Tensor,
        y=None,
        teacher_forcing: bool = True,
    ) -> tuple:
        """Compute the logits of all six levels.

        Args:
            x: [batch_size, input_dim] input features.
            y: optional [batch_size, 6] integer labels ordered
                (phylum, class, order, family, genus, species). Only the
                first five columns are read, as inputs to the next level.
            teacher_forcing: when True and ``y`` is given, each level is
                conditioned on the true label of the level above it.
                When False, or when ``y`` is None, it is conditioned on the
                argmax prediction of the level above it instead.

        Returns:
            A 6-tuple ``(logits_phylum, logits_class, logits_order,
            logits_family, logits_genus, logits_species)``; element ``i`` has
            shape [batch_size, number of labels at level ``i``].
        """
        shared_features = self.shared_layers(x)

        # Ground-truth parents are only usable if they were actually supplied.
        use_truth = (y is not None) and teacher_forcing

        # (embedding of level i, head of level i + 1), walked from the top down.
        links = (
            (self.phylum_embed, self.class_head),
            (self.class_embed, self.order_head),
            (self.order_embed, self.family_head),
            (self.family_embed, self.genus_head),
            (self.genus_embed, self.species_head),
        )

        # Level 1 has no parent; it is predicted from the shared features alone.
        logits = self.phylum_head(shared_features)
        all_logits = [logits]

        for level, (parent_embed, child_head) in enumerate(links):
            # Parent label ids: truth under teacher forcing, else the prediction.
            parent_ids = y[:, level] if use_truth else logits.argmax(dim=1)
            child_in = torch.cat([shared_features, parent_embed(parent_ids)], dim=1)
            logits = child_head(child_in)
            all_logits.append(logits)

        return tuple(all_logits)

In [17]:
import numpy as np
# Number of samples per distinct value in the first label column (phylum).
# The counts get their own name so they are not mistaken for the number of
# phylum classes, which is what `num_phylum` means everywhere else.
unique_elements, phylum_counts = np.unique(datay[:, 0], return_counts=True)
phylum_counts

array([  20,  290,  651, 1757,    9,  331,   86,  814, 2815, 1624,  183,
        131,   21, 1541,  271,   16,   32, 4861,  614])

In [18]:
# Size of each label level, taken as (largest label id in that column + 1).
# Columns 0..5 are phylum, class, order, family, genus, species.
level_sizes = datay[:, :6].max(axis=0) + 1
(
    num_phylum,
    num_class,
    num_order,
    num_family,
    num_genus,
    num_species,
) = level_sizes

# One size per line, in level order.
for level_size in level_sizes:
    print(level_size)


19
44
87
335
3692
17665


In [19]:


# Model hyper-parameters
input_dim = x_tensor.shape[1]  # number of input features per sample
hidden_dim = 1024              # width of the shared hidden layers
embed_dim = 128                # size of the parent-level label embedding

model = ChainModel(
    input_dim,
    hidden_dim,
    embed_dim,
    num_phylum,
    num_class,
    num_order,
    num_family,
    num_genus,
    num_species,
)

# Optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training loop
model.train()
num_epochs = 20

for epoch in range(num_epochs):
    total_loss = 0.0
    for batch_x, batch_y in dataloader:
        # Teacher-forced forward pass: one logits tensor per label level.
        level_logits = model(batch_x, y=batch_y, teacher_forcing=True)

        # Cross-entropy of every level against its own label column.
        level_losses = [
            criterion(logits, batch_y[:, level])
            for level, logits in enumerate(level_logits)
        ]

        # Unweighted sum over the levels, accumulated from the first level to
        # the last (per-level weights could be introduced here).
        loss = level_losses[0]
        for level_loss in level_losses[1:]:
            loss = loss + level_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

Epoch [1/20], Loss: 16.5900
Epoch [2/20], Loss: 8.7436
Epoch [3/20], Loss: 5.2702
Epoch [4/20], Loss: 3.7572
Epoch [5/20], Loss: 2.9236
Epoch [6/20], Loss: 2.3807
Epoch [7/20], Loss: 2.0191
Epoch [8/20], Loss: 1.7803
Epoch [9/20], Loss: 1.5732
Epoch [10/20], Loss: 1.4327
Epoch [11/20], Loss: 1.3338
Epoch [12/20], Loss: 1.2526
Epoch [13/20], Loss: 1.1625
Epoch [14/20], Loss: 1.1167
Epoch [15/20], Loss: 1.0588
Epoch [16/20], Loss: 1.0108
Epoch [17/20], Loss: 0.9638
Epoch [18/20], Loss: 0.9379
Epoch [19/20], Loss: 0.8875
Epoch [20/20], Loss: 0.8763


In [20]:
# Put the model in evaluation mode
model.eval()

# Take the first batch the data loader yields
first_batch_x, first_batch_y = next(iter(dataloader))

# Forward pass without gradient tracking. Teacher forcing is off, so every
# level is conditioned on the model's own prediction for the level above it.
with torch.no_grad():
    first_batch_logits = model(
        first_batch_x,
        y=first_batch_y,
        teacher_forcing=False,
    )

(
    logits_phylum,
    logits_class,
    logits_order,
    logits_family,
    logits_genus,
    logits_species,
) = first_batch_logits

# Turn each level's logits into predicted label indices
pred_phylum, pred_class, pred_order, pred_family, pred_genus, pred_species = (
    logits.argmax(dim=1) for logits in first_batch_logits
)

# Show predicted vs. actual labels for the first sample of the batch
print("Predictions:")
print(f"Phylum: {pred_phylum[0]}, Class: {pred_class[0]}, Order: {pred_order[0]}, Family: {pred_family[0]}, Genus: {pred_genus[0]}, Species: {pred_species[0]}")
print("Actual Labels:")
print(f"Phylum: {first_batch_y[0, 0]}, Class: {first_batch_y[0, 1]}, Order: {first_batch_y[0, 2]}, Family: {first_batch_y[0, 3]}, Genus: {first_batch_y[0, 4]}, Species: {first_batch_y[0, 5]}")

Predictions:
Phylum: 3, Class: 33, Order: 19, Family: 116, Genus: 1295, Species: 6494
Actual Labels:
Phylum: 3, Class: 33, Order: 19, Family: 116, Genus: 1295, Species: 6494


In [47]:
import torch

# Put the model in evaluation mode
model.eval()

# Label levels, in the same order as the model outputs and the label columns
levels = ('phylum', 'class', 'order', 'family', 'genus', 'species')

# Count correct predictions and samples instead of averaging per-batch
# accuracies: batches can differ in size (the last one may be shorter), and a
# mean of per-batch means would give every batch the same weight.
correct_counts = {level: 0 for level in levels}
total_samples = 0

# Iterate over the data loader without tracking gradients
with torch.no_grad():
    for batch_x, batch_y in dataloader:
        # Teacher forcing is off: each level is conditioned on the model's own
        # prediction for the level above it.
        batch_logits = model(
            batch_x,
            y=batch_y,
            teacher_forcing=False
        )

        total_samples += batch_y.shape[0]
        for column, (level, logits) in enumerate(zip(levels, batch_logits)):
            hits = logits.argmax(dim=1) == batch_y[:, column]
            correct_counts[level] += hits.sum().item()

# Per-level accuracy over all samples (NaN if the loader yielded nothing)
accuracies = {
    level: (correct_counts[level] / total_samples if total_samples else float('nan'))
    for level in levels
}

# Print the results
print("Training Set Accuracy:")
for level, acc in accuracies.items():
    print(f"{level.capitalize()}: {acc:.4f}")

Training Set Accuracy:
Phylum: 0.9939
Class: 0.9923
Order: 0.9912
Family: 0.9878
Genus: 0.9721
Species: 0.8213


In [48]:
# Write the trained model object to 'tax_model.pth' (relative to the current
# working directory).
# NOTE(review): this saves the whole module object rather than
# model.state_dict(); presumably loading it back requires the ChainModel class
# definition to be available at load time — confirm before sharing the file.
torch.save(model, 'tax_model.pth')