In [1]:
import ast
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import roc_curve

tokenizer = AutoTokenizer.from_pretrained("SIKU-BERT/sikubert")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [2]:
# Import data for SNN training and testing

snn_data = pd.read_csv('couplets_snn_v2.csv').iloc[:, 1:]

np.random.seed(6)
shuffled_index_data = np.random.permutation(snn_data.index)
data_shuffled = snn_data.iloc[shuffled_index_data]
train, test = train_test_split(data_shuffled, test_size=0.2, random_state=6, stratify=data_shuffled['label'])
# train.to_csv('couplets_snn.csv', index=False)
train = train.reset_index(drop=True)
test = test.reset_index(drop=True)

display(train)
display(test)

Unnamed: 0,label,text1,text2,word_seg1,word_seg2,POS1,POS2
0,1,岁通盛世家家富,人遇年华个个欢,"[1, 1, 1, 0, 1, 0, 1]","[1, 1, 1, 0, 1, 0, 1]","[11, 11, 7, 11, 7, 7]","[7, 15, 15, 7, 2, 2]"
1,1,森林繁茂生梁栋,大海深沉蕴宝珍,"[1, 0, 1, 0, 1, 1, 0]","[1, 0, 1, 0, 1, 1, 0]","[7, 7, 0, 0, 15, 11]","[7, 7, 0, 0, 15, 0]"
2,1,换个角度看世界,留些空间还中国,"[1, 1, 1, 0, 1, 1, 0]","[1, 1, 1, 0, 1, 1, 0]","[15, 7, 7, 7, 15, 7]","[15, 0, 7, 7, 15, 11]"
3,1,崛起新村莺燕舞,腾飞伟业凤龙来,"[1, 0, 1, 0, 1, 0, 1]","[1, 0, 1, 0, 1, 0, 1]","[15, 15, 7, 7, 7, 7]","[11, 15, 7, 7, 11, 11]"
4,1,一湖芳树洇诗绿,万簇鲜花沁画红,"[1, 1, 1, 0, 1, 1, 1]","[1, 1, 1, 0, 1, 1, 1]","[8, 7, 7, 7, 11, 11]","[8, 7, 7, 7, 15, 15]"
...,...,...,...,...,...,...,...
159995,1,胸中块垒最添堵,眼下风光正入时,"[1, 0, 1, 0, 1, 1, 1]","[1, 0, 1, 0, 1, 1, 1]","[7, 7, 7, 7, 2, 0]","[7, 7, 7, 7, 2, 15]"
159996,0,不愿折腰求寸进,杨柳屯前劳燕分,"[1, 1, 1, 0, 0, 0, 0]","[1, 0, 0, 1, 0, 0, 1]","[3, 3, 15, 15, 15, 15]","[11, 11, 9, 7, 15, 11]"
159997,1,丑角登台添笑料,时髦抢眼挺风流,"[1, 0, 1, 0, 1, 1, 0]","[1, 0, 1, 0, 1, 1, 0]","[7, 7, 15, 15, 15, 7]","[0, 0, 0, 0, 2, 0]"
159998,1,雨打芳林花溅泪,风吹古柏树参天,"[1, 1, 1, 0, 0, 1, 0]","[1, 1, 1, 0, 0, 1, 0]","[15, 15, 7, 7, 7, 15]","[7, 15, 7, 7, 7, 15]"


Unnamed: 0,label,text1,text2,word_seg1,word_seg2,POS1,POS2
0,0,个个爱弹廉政曲,昂首挺胸不求人,"[1, 0, 1, 1, 1, 0, 0]","[1, 0, 0, 0, 1, 1, 1]","[7, 7, 3, 15, 7, 7]","[15, 15, 15, 15, 2, 15]"
1,1,春堤榆柳烟笼翠,夏苑牡丹影泛红,"[1, 1, 1, 1, 1, 0, 1]","[1, 0, 1, 0, 1, 1, 0]","[7, 7, 7, 7, 7, 15]","[11, 11, 7, 7, 7, 15]"
2,1,三溪水笑联花艳,双凤山清诗意浓,"[1, 0, 1, 1, 1, 1, 1]","[1, 0, 0, 1, 1, 0, 1]","[11, 11, 9, 15, 15, 7]","[11, 11, 9, 0, 7, 7]"
3,1,竹篙桂楫飞如箭,鹤性松心合在山,"[1, 0, 1, 1, 1, 1, 1]","[1, 0, 1, 1, 1, 1, 1]","[7, 7, 7, 7, 15, 15]","[11, 11, 11, 11, 15, 15]"
4,1,风添红袖三分瘦,梦枕秋肩一晚凉,"[1, 1, 1, 0, 1, 1, 1]","[1, 1, 1, 0, 1, 1, 1]","[7, 15, 7, 7, 8, 7]","[15, 15, 7, 7, 8, 7]"
...,...,...,...,...,...,...,...
39995,0,扭扭捏捏当公主,远景熙和岛上香,"[1, 0, 0, 0, 1, 1, 0]","[1, 0, 1, 0, 0, 1, 1]","[0, 0, 0, 2, 15, 7]","[11, 11, 11, 11, 7, 7]"
39996,0,千山景色常入眼,光明心事月当天,"[1, 1, 1, 0, 1, 1, 0]","[1, 0, 1, 0, 1, 1, 0]","[8, 7, 7, 7, 2, 15]","[11, 11, 7, 7, 9, 7]"
39997,1,一溪弹破空山寂,几树妆成野岭春,"[1, 1, 1, 1, 1, 1, 1]","[1, 1, 1, 1, 1, 1, 1]","[11, 11, 15, 15, 7, 7]","[8, 7, 15, 15, 11, 11]"
39998,0,苍凉是酒别开口,月色酬君满腹诗,"[1, 0, 1, 1, 1, 1, 0]","[1, 0, 1, 1, 1, 0, 1]","[7, 7, 3, 7, 2, 15]","[7, 7, 15, 15, 15, 15]"


In [9]:

# Define the class for dataloader
class CoupletsDataSNN(Dataset):
    def __init__(self, data, max_len, tokenizer):
        self.data = data
        self.max_len = max_len
        self.tokenizer = tokenizer
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        label = self.data.iloc[index, 0]
        text = "".join(self.data.iloc[index, 1:3])

        word_seg1 = torch.tensor(ast.literal_eval(self.data.iloc[index, 3]))
        word_seg2 = torch.tensor(ast.literal_eval(self.data.iloc[index, 4]))
        pos1 = torch.tensor(ast.literal_eval(self.data.iloc[index, 5]))
        pos2 = torch.tensor(ast.literal_eval(self.data.iloc[index, 6]))
        
        encoding = self.tokenizer(
            text, 
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        input1 = torch.concat([input_ids[:(self.max_len) // 2].float(), word_seg1.float(), pos1.float()])
        input2 = torch.concat([input_ids[(self.max_len) // 2:].float(), word_seg2.float(), pos2.float()])

        return input1, input2, label

# Define the shared subnetwork
class SharedSubnetwork(nn.Module):
    def __init__(self, input_dim=20, hidden_dim=120, output_dim=10):
        super(SharedSubnetwork, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, output_dim),
            nn.ReLU()
        )
    
    def forward(self, x):
        return self.network(x)


# Define the Siamese Neural Network
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.shared_network = SharedSubnetwork()
    
    def forward(self, input1, input2):
        # Pass both inputs through the shared subnetwork
        output1 = self.shared_network(input1)
        output2 = self.shared_network(input2)
        # Compute the cosine similarity
        similarity = F.cosine_similarity(output1, output2, dim=1)
        # print(f"distance: {distance}")
        return similarity

# Define the contrastive loss
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=0.5):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
    
    def forward(self, similarity, label):
        # Convert similarity to distance for loss computation
        loss = (1 - label) * torch.pow(similarity, 2) + \
               label * torch.pow(torch.clamp(self.margin - similarity, min=0.0), 2)
        return torch.mean(loss)


# Training function
def train_siamese_network(model, dataloader, optimizer, loss_fn, epochs=20):
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for input1, input2, label in dataloader:
            input1, input2, label = input1.to(device), input2.to(device), label.to(device)
            optimizer.zero_grad()
            similarity = model(input1, input2)
            loss = loss_fn(similarity, label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(dataloader):.4f}")


In [14]:
# Training start
if __name__ == "__main__":
    dataloader = DataLoader(CoupletsDataSNN(train, 14, tokenizer), batch_size=128, shuffle=True)
    
    # Initialize the model, loss function, and optimizer
    model = SiameseNetwork().to(device)
    loss_fn = ContrastiveLoss(margin=0.5)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Train the model
    train_siamese_network(model, dataloader, optimizer, loss_fn, epochs=20)

    torch.save(model, "siamese_model_complete.pth")
    print("Entire model saved successfully!")

Epoch [1/20], Loss: 0.0676
Epoch [2/20], Loss: 0.0625
Epoch [3/20], Loss: 0.0621
Epoch [4/20], Loss: 0.0618
Epoch [5/20], Loss: 0.0616
Epoch [6/20], Loss: 0.0615
Epoch [7/20], Loss: 0.0613
Epoch [8/20], Loss: 0.0612
Epoch [9/20], Loss: 0.0612
Epoch [10/20], Loss: 0.0610
Epoch [11/20], Loss: 0.0609
Epoch [12/20], Loss: 0.0610
Epoch [13/20], Loss: 0.0608
Epoch [14/20], Loss: 0.0608
Epoch [15/20], Loss: 0.0607
Epoch [16/20], Loss: 0.0607
Epoch [17/20], Loss: 0.0606
Epoch [18/20], Loss: 0.0606
Epoch [19/20], Loss: 0.0606
Epoch [20/20], Loss: 0.0606
Entire model saved successfully!


In [13]:

# test = test.drop(columns=['label'])
class TestDataSNN(Dataset):
    def __init__(self, data, tokenizer, max_len=14):
        self.data = data
        self.max_len = max_len
        self.tokenizer = tokenizer
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        text = "".join(self.data.iloc[index, 0:2])
        word_seg1 = torch.tensor(ast.literal_eval(self.data.iloc[index, 2]))
        word_seg2 = torch.tensor(ast.literal_eval(self.data.iloc[index, 3]))
        pos1 = torch.tensor(ast.literal_eval(self.data.iloc[index, 4]))
        pos2 = torch.tensor(ast.literal_eval(self.data.iloc[index, 5]))
        
        encoding = self.tokenizer(
            text, 
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        input1 = torch.concat([input_ids[:(self.max_len) // 2].float(), word_seg1.float(), pos1.float()])
        input2 = torch.concat([input_ids[(self.max_len) // 2:].float(), word_seg2.float(), pos2.float()])
        return input1, input2

# dataloader = DataLoader(TestDataSNN(test.iloc[:,1:], tokenizer), batch_size=1, shuffle=True)
# for input1, input2 in dataloader:
#     input1, input2 = input1.to(device), input2.to(device)

#     with torch.no_grad():
#         similarity = model(input1, input2)
#         print(similarity.item())



In [5]:
def find_optimal_threshold(model, val_loader):
    model.eval()
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for input1, input2, labels in val_loader:
            input1, input2, labels = input1.to(device), input2.to(device), labels.to(device)
            scores = model(input1, input2)  # Cosine similarity
            all_scores.extend(scores.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    fpr, tpr, thresholds = roc_curve(all_labels, all_scores)
    optimal_idx = (tpr - fpr).argmax()
    optimal_threshold = thresholds[optimal_idx]

    print(f"Optimal Threshold: {optimal_threshold:.4f}")
    return optimal_threshold

def test_snn_model(model, test_data, max_len, tokenizer, batch_size=32, threshold=0.3):
    """
    Test the trained Siamese Neural Network model on test data.
    
    Args:
        model (nn.Module): Trained SNN model.
        test_data (pd.DataFrame): Test dataset as a DataFrame.
        max_len (int): Maximum sequence length for tokenization.
        tokenizer: Tokenizer to preprocess the text.
        batch_size (int): Batch size for testing.
        threshold (float): Threshold to classify similarity.
    
    Returns:
        None: Prints test metrics (accuracy, precision, recall, F1-score).
    """
    # Prepare the test dataset and dataloader
    test_dataset = CoupletsDataSNN(test_data, max_len, tokenizer)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Set the model to evaluation mode
    model.eval()
    model.to(device)

    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for input1, input2, labels in test_loader:
            # Move data to the appropriate device
            input1, input2, labels = input1.to(device), input2.to(device), labels.to(device)
            
            # Forward pass
            similarity_scores = model(input1, input2)  # Cosine similarity values
            
            # Convert similarity scores to binary predictions
            predictions = (similarity_scores >= threshold).float()
            
            # Store labels and predictions
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())
    
    # Compute metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions, zero_division=1)
    recall = recall_score(all_labels, all_predictions, zero_division=1)
    f1 = f1_score(all_labels, all_predictions, zero_division=1)

    # Print metrics
    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test F1-Score: {f1:.4f}")



In [15]:

test_dataset = CoupletsDataSNN(test, 14, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
threshold = find_optimal_threshold(model, test_loader)
threshold

Optimal Threshold: 0.2538


0.2538063

In [16]:
test_snn_model(model, test, 14, tokenizer, batch_size=32, threshold=threshold)

Test Accuracy: 0.5680
Test Precision: 0.5753
Test Recall: 0.5191
Test F1-Score: 0.5458
