In [90]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import pandas as pd
import pickle
from utilities import CustomDataset, trainLoop, testLoop, save_stats

# Hyperparameters

In [91]:
BATCH_SIZE = 64
LEARNING_RATE = 0.0001
NUM_EPOCHS = 100

# Siamese MLP

In [92]:
class SiameseMLPV1(nn.Module):
    def __init__(self, input_size = 768, hidden_size1 = 256, hidden_size2 = 128, output_size = 1):
        super(SiameseMLPV1, self).__init__()
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, output_size)

    def forward_once(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return x

    def forward(self, x1, x2):
        output_A = self.forward_once(x1)
        output_B = self.forward_once(x2)
        return torch.sigmoid(self.fc3(torch.abs(output_A - output_B))).squeeze()

In [93]:
SAVES_FOLDER = "saves/"
df = pd.read_csv(SAVES_FOLDER + "dataset.csv")

with open(SAVES_FOLDER + 'id2embedding.pkl', 'rb') as f:
    id2embedding = pickle.load(f)

#model = SiameseMLPV1(1024, 512, 256)
model = SiameseMLPV1()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCELoss()

features_cols = ["left_spec_id", "right_spec_id"]
target_col = "label"
dataset = CustomDataset(df, features_cols, target_col, id2embedding)

train_size = int(0.75 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [94]:
stats = dict()

for epoch in range(NUM_EPOCHS):
    loss = trainLoop(model, optimizer, criterion, train_loader)
    print(f"Epoch {epoch + 1}, Loss: {loss}")
    pred_function = lambda x: x > 0.2
    testLoop(model, criterion, test_loader, pred_function, stats)

Epoch 1, Loss: 0.6731143421248386
Test Loss: 0.6371, Test Accuracy: 0.3350
Precision: 0.3350, Recall: 1.0000, F1-score: 0.5019
TP: 268, FP: 532, TN: 0, FN: 0
Epoch 2, Loss: 0.6237532135687376
Test Loss: 0.5868, Test Accuracy: 0.3538
Precision: 0.3414, Recall: 1.0000, F1-score: 0.5090
TP: 268, FP: 517, TN: 15, FN: 0
Epoch 3, Loss: 0.589394831343701
Test Loss: 0.5647, Test Accuracy: 0.4350
Precision: 0.3711, Recall: 0.9888, F1-score: 0.5397
TP: 265, FP: 449, TN: 83, FN: 3
Epoch 4, Loss: 0.5737387914406625
Test Loss: 0.5555, Test Accuracy: 0.4975
Precision: 0.3972, Recall: 0.9664, F1-score: 0.5630
TP: 259, FP: 393, TN: 139, FN: 9
Epoch 5, Loss: 0.5624395389305917
Test Loss: 0.5483, Test Accuracy: 0.5200
Precision: 0.4073, Recall: 0.9515, F1-score: 0.5705
TP: 255, FP: 371, TN: 161, FN: 13
Epoch 6, Loss: 0.5504742641198007
Test Loss: 0.5430, Test Accuracy: 0.5262
Precision: 0.4115, Recall: 0.9627, F1-score: 0.5765
TP: 258, FP: 369, TN: 163, FN: 10
Epoch 7, Loss: 0.5416155614350971
Test Loss

In [95]:
save_stats("MLP_V1_bert_base", stats)