In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

In [10]:
# 定义数据集类
class CustomDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.long)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)
#返回数据集中的一个样本
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


In [11]:
# 定义关系抽取模型
class RelationExtractionModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size):
        super(RelationExtractionModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.bilstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True, bidirectional=True)
        self.attention = nn.Linear(hidden_size * 2, 1)
        self.fc = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x):
        embedded = self.embedding(x)
        ##只输出lstm_out
        lstm_out, _ = self.bilstm(embedded)
        attention_weights = torch.softmax(self.attention(lstm_out).squeeze(-1), dim=1).unsqueeze(-1)
        context_vector = torch.sum(attention_weights * lstm_out, dim=1)
        output = self.fc(context_vector)
        return torch.log_softmax(output, dim=1)

In [16]:
# 示例新的训练集和测试集数据
train_data = [
    ("苹果", "手机", "属于", "苹果手机是苹果公司的产品。"),
    ("玫瑰", "花", "是", "玫瑰是一种美丽的花。"),
    ("太阳", "星", "围绕", "太阳是我们的星球。"),
    ("狗", "动物", "属于", "狗是一种忠诚的动物。")
]

test_data = [
    ("手机", "设备", "是", "手机是现代人们常用的通讯工具。"),
    ("花", "植物", "属于", "花是植物的一部分。"),
    ("星", "天体", "是", "星是夜空中的光点。"),
    ("猫", "动物", "属于", "猫是一种常见的宠物动物。")
]

# 构建词汇表
word_to_index = {"苹果": 1, "手机": 2, "属于": 3, "玫瑰": 4, "花": 5, "是": 6, "太阳": 7, "星": 8, "围绕": 9, "狗": 10, "动物": 11, "设备": 12, "植物": 13, "天体": 14, "猫": 15}
relation_to_id = {"属于": 0, "是": 1, "围绕": 2}

In [17]:
# 转换文本数据为数字序列，X取句子Y取关系
X_train = [[word_to_index.get(word, 0) for word in sentence.split()] for _, _, _, sentence in train_data]
y_train = [relation_to_id.get(relation, 0) for _, _, relation, _ in train_data]

X_test = [[word_to_index.get(word, 0) for word in sentence.split()] for _, _, _, sentence in test_data]
y_test = [relation_to_id.get(relation, 0) for _, _, relation, _ in test_data]

# 超参数
vocab_size = len(word_to_index) + 1  # 词汇表大小（加1是因为索引从1开始）
embedding_dim = 100  # 嵌入维度
hidden_size = 64  # LSTM隐藏单元数
output_size = len(relation_to_id)  # 输出类别数

# 创建数据集和数据加载器
train_dataset = CustomDataset(X_train, y_train)
test_dataset = CustomDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)



In [19]:
# 初始化模型、损失函数和优化器
model = RelationExtractionModel(vocab_size, embedding_dim, hidden_size, output_size)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10  # 设置为100表示训练100次

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        output = model(batch_X)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    average_loss = total_loss / len(train_loader)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}')
    
    # 在每个周期结束后进行评估
    model.eval()
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            output = model(batch_X)
            _, predicted = torch.max(output, 1)
            all_predictions.extend(predicted.tolist())
            all_targets.extend(batch_y.tolist())

    # 计算评估指标
    f1 = f1_score(all_targets, all_predictions, average='weighted')
    acc = accuracy_score(all_targets, all_predictions)
    precision = precision_score(all_targets, all_predictions, average='weighted')
    recall = recall_score(all_targets, all_predictions, average='weighted')
    
    print(f'F1 Score: {f1:.4f}, Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}')

Epoch [1/10], Loss: 1.1566
F1 Score: 0.3333, Accuracy: 0.5000, Precision: 0.2500, Recall: 0.5000
Epoch [2/10], Loss: 1.0896
F1 Score: 0.3333, Accuracy: 0.5000, Precision: 0.2500, Recall: 0.5000
Epoch [3/10], Loss: 1.0576
F1 Score: 0.3333, Accuracy: 0.5000, Precision: 0.2500, Recall: 0.5000
Epoch [4/10], Loss: 1.0709
F1 Score: 0.3333, Accuracy: 0.5000, Precision: 0.2500, Recall: 0.5000
Epoch [5/10], Loss: 1.0480
F1 Score: 0.3333, Accuracy: 0.5000, Precision: 0.2500, Recall: 0.5000
Epoch [6/10], Loss: 1.0496
F1 Score: 0.3333, Accuracy: 0.5000, Precision: 0.2500, Recall: 0.5000
Epoch [7/10], Loss: 1.0510
F1 Score: 0.3333, Accuracy: 0.5000, Precision: 0.2500, Recall: 0.5000
Epoch [8/10], Loss: 1.0569
F1 Score: 0.3333, Accuracy: 0.5000, Precision: 0.2500, Recall: 0.5000
Epoch [9/10], Loss: 1.0496
F1 Score: 0.3333, Accuracy: 0.5000, Precision: 0.2500, Recall: 0.5000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch [10/10], Loss: 1.0460
F1 Score: 0.3333, Accuracy: 0.5000, Precision: 0.2500, Recall: 0.5000


  _warn_prf(average, modifier, msg_start, len(result))
