In [5]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

from utils.DataLoad import load_embedding, get_embedding_tensor

In [2]:
# Load the embedding vectors from data/model/embedding.txt via the project helper.
# NOTE(review): the structure returned by load_embedding is not visible here —
# confirm in utils.DataLoad.
embedding_vecs = load_embedding("data/model/embedding.txt")

In [4]:
# Build the embedding tensor from the loaded vectors; the recorded cell output
# reports a resulting shape of [13308, 128].
# NOTE(review): 13308 and 128 are presumably (number of nodes, embedding dim) —
# confirm against get_embedding_tensor's signature in utils.DataLoad.
embedding_tensor = get_embedding_tensor(embedding_vecs, 13308, 128)

Embedding tensor shape: torch.Size([13308, 128])


In [None]:
# Define the generator
class Generator(nn.Module):
    """
    Conditional generator: maps random noise plus a phylogenetic-tree embedding
    to an output vector.

    The noise and the embedding are each projected to ``hidden_size`` features
    (Linear -> BatchNorm1d -> ReLU). The two projections are concatenated along
    the feature axis, mixed by a second Linear -> BatchNorm1d -> ReLU stage, and
    finally mapped to ``output_size`` values passed through a sigmoid.

    :param input_size: dimensionality of the random-noise input
    :param phylogenetic_embedding_size: dimensionality of the phylogenetic-tree node embedding
    :param hidden_size: width of the hidden layers
    :param output_size: dimensionality of the generated output vector
    """
    def __init__(self, input_size: int, phylogenetic_embedding_size: int, hidden_size: int, output_size: int):
        super(Generator, self).__init__()
        # Noise branch
        self.fc1_1 = nn.Linear(input_size, hidden_size)
        self.fc1_1_bn = nn.BatchNorm1d(hidden_size)
        # Embedding branch
        self.fc1_2 = nn.Linear(phylogenetic_embedding_size, hidden_size)
        self.fc1_2_bn = nn.BatchNorm1d(hidden_size)
        # Fusion stage: takes the two concatenated branches, hence hidden_size * 2
        self.fc2 = nn.Linear(hidden_size * 2, hidden_size)
        self.fc2_bn = nn.BatchNorm1d(hidden_size)
        # Output head
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, phylogenetic_embedding):
        """
        Forward pass of the generator.

        Both inputs must share the same batch size, because their hidden
        projections are concatenated along dim 1. In training mode BatchNorm1d
        needs more than one sample per batch.

        :param x: random noise, shape (batch, input_size)
        :param phylogenetic_embedding: phylogenetic-tree embedding, shape (batch, phylogenetic_embedding_size)
        :return: generated tensor of shape (batch, output_size) after the sigmoid
        """
        # Use the `x` argument here: the previous code passed the builtin
        # `input` function to the linear layer, which fails on every call.
        x = self.relu(self.fc1_1_bn(self.fc1_1(x)))
        y = self.relu(self.fc1_2_bn(self.fc1_2(phylogenetic_embedding)))
        z = torch.cat([x, y], 1)
        z = self.relu(self.fc2_bn(self.fc2(z)))
        z = self.sigmoid(self.fc3(z))
        return z


In [None]:
class Generator(nn.Module):
    """
    Two-branch generator network.

    One branch embeds the noise input and the other embeds the phylogenetic
    embedding. Both go through Linear -> BatchNorm1d -> ReLU, are concatenated
    on the feature axis, pass through one more Linear -> BatchNorm1d -> ReLU
    stage, and are projected to ``output_size`` values followed by a sigmoid.

    NOTE(review): this cell redefines ``Generator`` and therefore shadows the
    class of the same name defined in an earlier cell.
    """

    def __init__(self, input_size, phylogenetic_embedding_size, hidden_size, output_size):
        super().__init__()
        # Layers are created in a fixed order so parameter registration (and
        # therefore random initialisation) stays the same.
        self.fc1_1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc1_1_bn = nn.BatchNorm1d(num_features=hidden_size)
        self.fc1_2 = nn.Linear(in_features=phylogenetic_embedding_size, out_features=hidden_size)
        self.fc1_2_bn = nn.BatchNorm1d(num_features=hidden_size)
        # The fusion layer consumes both branches side by side.
        self.fc2 = nn.Linear(in_features=2 * hidden_size, out_features=hidden_size)
        self.fc2_bn = nn.BatchNorm1d(num_features=hidden_size)
        self.fc3 = nn.Linear(in_features=hidden_size, out_features=output_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, phylogenetic_embedding):
        """
        Run both branches, fuse them and return the sigmoid-activated output.

        ``input`` keeps its original name (it shadows the Python builtin inside
        this method) so keyword callers continue to work.
        """
        # Noise branch
        noise_features = self.fc1_1(input)
        noise_features = self.fc1_1_bn(noise_features)
        noise_features = F.relu(noise_features)

        # Embedding branch
        embedding_features = self.fc1_2(phylogenetic_embedding)
        embedding_features = self.fc1_2_bn(embedding_features)
        embedding_features = F.relu(embedding_features)

        # Fuse along the feature axis and map to the output space
        fused = torch.cat((noise_features, embedding_features), dim=1)
        hidden = F.relu(self.fc2_bn(self.fc2(fused)))
        return self.sigmoid(self.fc3(hidden))

In [None]:

# Hyperparameters
input_size = 100  # size of the random-noise input
phylogenetic_embedding_size = 128  # size of the phylogenetic-tree node embedding
hidden_size = 256  # width of the hidden layers
output_size = 98  # dimensionality of the generator output
lr = 0.0002  # learning rate

# NOTE(review): Discriminator, AttentionMechanism, data_loader and num_epochs are
# not defined in the visible cells — they must exist before this cell is run.

# Initialise the generator, the discriminator and the attention module.
# Generator takes four sizes; the embedding size was previously omitted, which
# raised a TypeError at construction time.
generator = Generator(input_size, phylogenetic_embedding_size, hidden_size, output_size)
discriminator = Discriminator(output_size, hidden_size)
# NOTE(review): `attention` is created but never used in the training loop below.
attention = AttentionMechanism(phylogenetic_embedding_size, hidden_size)  # attention module, to be implemented separately

# Optimisers and loss function
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
criterion = nn.BCELoss()

# Training loop
for epoch in range(num_epochs):
    for i, (real_data, phylogenetic_embedding) in enumerate(data_loader):
        # Size the noise from the batch actually received: the generator
        # concatenates noise and embedding features along dim 1, so both must
        # have the same number of rows (the last batch may be smaller).
        current_batch_size = phylogenetic_embedding.size(0)
        noise_device = phylogenetic_embedding.device

        # ---- Train the discriminator ----
        optimizer_D.zero_grad()
        # Generate fake samples
        noise = torch.randn(current_batch_size, input_size, device=noise_device)
        fake_data = generator(noise, phylogenetic_embedding)
        # Loss on fake samples. detach() stops the discriminator loss from
        # back-propagating into the generator, which is not updated in this step.
        pred_fake = discriminator(fake_data.detach(), phylogenetic_embedding)
        loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))
        # Loss on real samples
        pred_real = discriminator(real_data, phylogenetic_embedding)
        loss_real = criterion(pred_real, torch.ones_like(pred_real))
        # Update the discriminator parameters
        loss_D = loss_fake + loss_real
        loss_D.backward()
        optimizer_D.step()

        # ---- Train the generator ----
        optimizer_G.zero_grad()
        # Generate a fresh batch of fake samples
        noise = torch.randn(current_batch_size, input_size, device=noise_device)
        fake_data = generator(noise, phylogenetic_embedding)
        # Discriminator output on the fake samples (not detached: gradients
        # must reach the generator here)
        pred_fake = discriminator(fake_data, phylogenetic_embedding)
        # Generator loss: the generator wants the fakes to be classified as real
        loss_G = criterion(pred_fake, torch.ones_like(pred_fake))
        # Update the generator parameters
        loss_G.backward()
        optimizer_G.step()
