In [175]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from collections import Counter
from torchtext.vocab import Vocab

In [None]:
import Tokenizer
from tqdm import tqdm
# Load the tab-separated "label<TAB>text" corpus, tokenize each sentence and
# keep at most `maxLen` tokens, joined back with spaces for the vocab step.
# Labels are kept as strings here; the next cell casts them to int.
trainset_path = '../corpus/sst2_train_500_zh_bo.txt'
labels = []
texts = []
maxLen = 30
with open(trainset_path, encoding='utf-8') as corpus_file:
    for line in tqdm(corpus_file):
        line = line.rstrip('\n')
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing empty line at EOF
        # Split on the first tab only so a tab inside the sentence cannot break unpacking.
        label, text = line.split('\t', 1)
        labels.append(label)
        tok = Tokenizer.sentence_tokenize(text)
        texts.append(' '.join(tok[:maxLen]))  # truncate to maxLen tokens

print(f'{len(texts)} examples loaded')
texts[:5]


In [None]:
# Convert the label strings read from the corpus into ints and keep a separate copy.
labels = [int(raw_label) for raw_label in labels]
origin_labels = list(labels)
origin_labels

In [179]:
# Build the vocabulary
def yield_tokens(data_iter):
    """Lazily yield the whitespace-split token list of every text in ``data_iter``."""
    yield from (sentence.split() for sentence in data_iter)

# # 使用Counter计算每个单词的频率
# counter = Counter()
# for text in texts:
#     counter.update(text.lower().split())


# print(counter)

# # 创建词汇表，你可以限制词汇表的大小，或者使用min_freq来指定最小频次
# # vocab = Vocab(counter)
# vocab = build_vocab_from_iterator(yield_tokens(texts), specials=["<unk>"])

# # 文本转换为整数序列
# def text_pipeline(x):
#     return [vocab[token] for token in x.split()]

# # 将文本数据集转换为整数序列
# sequences = [torch.tensor(text_pipeline(x)) for x in texts]

# print(len(sequences))
# print(len(labels))
# pSequences = pad_sequence(sequences, batch_first=True)
# print(pSequences.size())
# s = 0
# for i in range(len(sequences)):
#     s = s + len(sequences[i])
# print(s / len(sequences))

In [180]:
# Define the dataset class
class TextDataset(Dataset):
    """Map-style dataset of ``(token_ids, label)`` pairs.

    Args:
        texts: sequence of strings; each one is split on whitespace.
        labels: sequence of integer class labels, parallel to ``texts``.
        vocab: object supporting ``vocab[token] -> int`` (e.g. a torchtext
            ``Vocab`` with a default index set, or a dict containing every token).

    Raises:
        ValueError: if ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, vocab):
        if len(texts) != len(labels):
            raise ValueError(
                f'texts and labels must have the same length, got {len(texts)} and {len(labels)}'
            )
        # dtype=torch.long keeps empty texts integer-typed; torch.tensor([])
        # would otherwise default to float32 and break the embedding lookup.
        self.texts = [
            torch.tensor([vocab[token] for token in text.split()], dtype=torch.long)
            for text in texts
        ]
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Returns (1-D LongTensor of token ids, 0-d label tensor).
        return self.texts[idx], self.labels[idx]

In [None]:
# Build the token -> id vocabulary from the whitespace-split corpus.
# "<unk>" is placed first (index 0) and is also registered as the fallback id
# returned for any token not seen while building the vocabulary.
# NOTE(review): this uses all of `texts`, including the later test split, so
# test-only words get their own (never-trained) ids — consider building it
# from the training portion only.
vocab = build_vocab_from_iterator(
    yield_tokens(texts),
    specials=["<unk>"],
    special_first=True,
)
vocab.set_default_index(vocab["<unk>"])

def collate_batch(batch):
    """Collate ``(token_ids, label)`` pairs into a padded, length-sorted batch.

    Args:
        batch: list of ``(sequence, label)`` pairs, where ``sequence`` is a
            1-D tensor of token ids and ``label`` a 0-d tensor (as produced by
            ``TextDataset``). The list is not modified.

    Returns:
        ``(sequences_padded, lengths, labels)``:
          * ``sequences_padded``: ``(batch, max_len)`` tensor, right-padded with 0;
          * ``lengths``: 1-D int64 CPU tensor with each row's true (unpadded) length;
          * ``labels``: 1-D tensor of the stacked labels.
        Rows are ordered longest-first so the batch also works with
        ``pack_padded_sequence(..., enforce_sorted=True)``.
    """
    # Sort longest-first (stable) without mutating the caller's list.
    batch = sorted(batch, key=lambda pair: len(pair[0]), reverse=True)
    sequences, labels = zip(*batch)

    # Lengths must be measured BEFORE padding: after pad_sequence every row has
    # the same length, which would make packing treat padding as real tokens.
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)

    sequences_padded = pad_sequence(list(sequences), batch_first=True)
    labels = torch.stack(labels)
    return sequences_padded, lengths, labels

# Create the datasets: an 80/20 train/test split in file order, each wrapped in a DataLoader.
BATCH_SIZE = 30
trainSize = int(len(texts) * 0.8)
train_dataset = TextDataset(texts[:trainSize], labels[:trainSize], vocab)
# Reshuffle training data every epoch; dropping the last partial batch is harmless here.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=collate_batch, drop_last=True)
test_dataset = TextDataset(texts[trainSize:], labels[trainSize:], vocab)
# Keep every test example: drop_last=True would silently exclude up to BATCH_SIZE - 1 of them.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             collate_fn=collate_batch, drop_last=False)

# Sanity check on a single training batch instead of printing every batch.
sample_text, sample_lengths, sample_labels = next(iter(train_dataloader))
sample_text.shape, sample_lengths.shape, sample_labels.shape

In [202]:
# Define the model
class BiLSTMClassifier(nn.Module):
    """Two-layer bidirectional LSTM sentence classifier.

    Pipeline: embedding -> BiLSTM (``hidden_dim`` units per direction) ->
    dropout -> BiLSTM (``hidden_dim // 2`` units per direction). The final
    forward and backward hidden states of the second layer are concatenated,
    then go through dropout and a linear layer, and ``LogSoftmax`` turns them
    into log-probabilities (pair with ``nn.NLLLoss``).

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding width.
        hidden_dim: hidden size per direction of the first BiLSTM.
        num_classes: number of output classes.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super(BiLSTMClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True: inputs arrive as (batch, seq_len) padded id tensors.
        self.bilstm1 = nn.LSTM(embed_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.dropout1 = nn.Dropout(p=0.5)
        self.bilstm2 = nn.LSTM(hidden_dim * 2, hidden_dim // 2, bidirectional=True, batch_first=True)
        self.dropout2 = nn.Dropout(p=0.5)
        # Two directions of hidden_dim // 2 units each (equals hidden_dim when it is even).
        self.fc = nn.Linear(2 * (hidden_dim // 2), num_classes)
        # Kept as an attribute for compatibility; deliberately NOT applied to the
        # logits, since clamping negative logits to 0 distorts the softmax.
        self.relu = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, text, text_lengths):
        """Return log-probabilities of shape ``(batch, num_classes)``.

        Args:
            text: ``(batch, seq_len)`` LongTensor of right-padded token ids.
            text_lengths: true length of each row (1-D tensor or list of ints).
                It does not need to be sorted, and every entry must be >= 1.
        """
        # pack_padded_sequence requires the lengths as an int64 CPU tensor.
        lengths = torch.as_tensor(text_lengths, dtype=torch.long).cpu()
        embedded = self.embedding(text)

        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False)
        packed_output1, _ = self.bilstm1(packed_embedded)
        # Unpacking restores the original row order and returns the matching lengths.
        output1, output_lengths1 = nn.utils.rnn.pad_packed_sequence(packed_output1, batch_first=True)
        dropped1 = self.dropout1(output1)

        packed_dropped1 = nn.utils.rnn.pack_padded_sequence(
            dropped1, output_lengths1, batch_first=True, enforce_sorted=False)
        _, (hidden2, _) = self.bilstm2(packed_dropped1)

        # hidden2 is (num_directions, batch, hidden_dim // 2) for this single-layer LSTM:
        # row -2 is the final forward state and row -1 the final backward state.
        hidden_final = torch.cat((hidden2[-2], hidden2[-1]), dim=1)
        hidden_final = self.dropout2(hidden_final)

        logits = self.fc(hidden_final)
        return self.softmax(logits)

In [None]:
# Hyperparameters
SEED = 42
vocab_size = len(vocab)
embed_dim = 128
hidden_dim = 64
num_classes = 2
num_epochs = 200
LOG_EVERY = 10  # print the epoch loss every LOG_EVERY epochs (and after the first)

# Seed torch's global RNG so weight init, dropout masks and any DataLoader
# shuffling are reproducible across Restart & Run All.
torch.manual_seed(SEED)

# Initialize the model
model = BiLSTMClassifier(vocab_size, embed_dim, hidden_dim, num_classes)

# Loss and optimizer: the model returns log-probabilities (LogSoftmax), which is what NLLLoss expects.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())

# Train the model
for epoch in range(num_epochs):
    model.train()  # enable dropout
    epoch_loss = 0.0
    epoch_samples = 0
    for text, length, batch_labels in train_dataloader:
        # Clear gradients from the previous step
        optimizer.zero_grad()

        # Forward pass
        outputs = model(text, length)
        loss = criterion(outputs, batch_labels)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # NLLLoss is a per-batch mean; weight by batch size for a per-sample epoch average.
        epoch_loss += loss.item() * batch_labels.size(0)
        epoch_samples += batch_labels.size(0)

    if epoch_samples == 0:
        raise RuntimeError('train_dataloader yielded no batches')

    # Report the average training loss over the whole epoch
    if epoch == 0 or (epoch + 1) % LOG_EVERY == 0:
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss / epoch_samples:.4f}')

    # TODO: early stopping — evaluate on a validation split and stop once the
    # metric has not improved for a fixed number of epochs.

In [None]:
# Switch the model to evaluation mode (disables dropout)
model.eval()

# Running totals for loss and accuracy
total_loss = 0.0
total_correct = 0
total_samples = 0

with torch.no_grad():  # no gradients needed for evaluation
    # `batch_labels` (not `labels`) so the global label list from data loading
    # is not overwritten, which would break re-running the earlier cells.
    for text, lengths, batch_labels in test_dataloader:
        # Forward pass
        outputs = model(text, lengths)
        n_batch = batch_labels.size(0)

        # NLLLoss averages over the batch; re-weight so unequal batch sizes average correctly.
        total_loss += criterion(outputs, batch_labels).item() * n_batch

        # Accuracy
        predicted = outputs.argmax(dim=1)
        total_correct += (predicted == batch_labels).sum().item()
        total_samples += n_batch

if total_samples == 0:
    raise RuntimeError('test_dataloader yielded no batches; check the split size and drop_last')

# Per-sample average loss and overall accuracy
avg_loss = total_loss / total_samples
accuracy = total_correct / total_samples
print(total_correct, total_samples)

print(f'Test Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}')