In [82]:
from sentence_transformers import SentenceTransformer
import random
import torch
from torch import nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from d2l import torch as d2l
import PyPDF2
import re
import os
from IPython.display import clear_output

Dataset Loading

In [None]:
trainChance = 0.6
validationChance = 0.1
testChance = 0.3

file = open("SICK_annotated.txt", "r")

trainSet = open("TrainDataset.txt", "x")
validationSet = open("ValidationDataset.txt", "x")
testSet = open("TestDataset.txt", "x")

for i, line in enumerate(file):
    if i == 0:
        continue

    arr = line.split('\t')

    num = random.uniform(0, 1)

    if num <= trainChance:
        trainSet.write(str(arr[2] + "\t" + arr[4] + "\t" + arr[6] + "\n"))
    elif num <= trainChance + validationChance:
        validationSet.write(str(arr[2] + "\t" + arr[4] + "\t" + arr[6] + "\n"))
    else:
        testSet.write(str(arr[2] + "\t" + arr[4] + "\t" + arr[6] + "\n"))

    #print(f"{arr[2]} {arr[4]} {arr[6]}")

trainSet.close()
validationSet.close()
testSet.close()

In [None]:
trainFile = open("TrainDataset.txt", "r")
validationFile = open("ValidationDataset.txt", "r")
testFile = open("TestDataset.txt", "r")

trainSet = []
validationSet = []
testSet = []

for i, line in enumerate(trainFile):
    arr = line.split('\t')
    arr[2] = torch.tensor(float(arr[2]))
    trainSet.append(arr)

for i, line in enumerate(validationFile):
    arr = line.split('\t')
    arr[2] = torch.tensor(float(arr[2]))
    validationSet.append(arr)

for i, line in enumerate(testFile):
    arr = line.split('\t')
    arr[2] = torch.tensor(float(arr[2]))
    testSet.append(arr)

print(f"{len(trainSet)} {len(validationSet)} {len(testSet)}")

In [None]:
class SentenceDataset(Dataset):
    def __init__(self, arr):
        super().__init__()
        self.data = arr

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

Model Creation

In [None]:
def clamp_tanh(x, clamp=15):
    return x.clamp(-clamp, clamp).tanh()

def expmap0(u, c = 1, min_norm = 1e-15):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), min_norm)
    gamma_1 = clamp_tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1

def sqdist(p1, p2, c = 1):
        sqrt_c = c ** 0.5
        dist_c = torch.atanh(
            sqrt_c * mobius_add(-p1, p2, c, dim=-1).norm(dim=-1, p=2, keepdim=False)
        )
        dist = dist_c * 2 / sqrt_c
        return dist ** 2

def mobius_add(x, y, c = 1, dim=-1, min_norm = 1e-15):
        x2 = x.pow(2).sum(dim=dim, keepdim=True)
        y2 = y.pow(2).sum(dim=dim, keepdim=True)
        xy = (x * y).sum(dim=dim, keepdim=True)
        num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
        denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
        return num / denom.clamp_min(min_norm)

In [None]:
class SearchNetwork(nn.Module):
    def __init__(self, hyperbolic, explicit, device = None) -> None:
        super().__init__()

        self.embedding = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
        self.fineTune = nn.Sequential(nn.Linear(768, 768), nn.Linear(768, 768)).to(device)
        self.explicit = explicit
        self.hyperbolic = hyperbolic
        self.device = device

        if not explicit:
            self.regression = nn.Sequential(nn.Linear(1536, 768), nn.Linear(768, 1)).to(device)

    def forward(self, x1, x2):
        y1 = self.fineTune(torch.tensor(self.embedding.encode(x1)).to(self.device))
        y2 = self.fineTune(torch.tensor(self.embedding.encode(x2)).to(self.device))

        if self.hyperbolic:
            y1 = expmap0(y1)
            y2 = expmap0(y2)

        if not self.explicit:
            return self.regression(torch.cat((y1, y2), dim=-1))

        if self.hyperbolic:
            return sqdist(y1, y2).to(self.device)

        return (y1 - y2).pow(2).sum(-1).sqrt()

Training

In [None]:
def mse(prediction, target) -> float:
    return (prediction - target).pow(2)

def absError(prediction, target) -> float:
    return torch.abs(prediction - target)

In [None]:
def evaluate(network: nn.Module, dataset, lossFunc = mse, evalFunc = absError, device = None):
    network.eval()

    #Loss, evaluation metric
    metric = d2l.Accumulator(2)

    with torch.no_grad():
        for i, (x1, x2, y) in enumerate(dataset):
            y = y.to(device)

            yhat = network(x1, x2).to(device)

            loss = torch.mean(lossFunc(yhat, y))
            eval = torch.mean(evalFunc(yhat, y))

            metric.add(loss, eval)

    return metric[0] / len(dataset), metric[1] / len(dataset)

In [None]:
def train(network: nn.Module, trainDataset, validationDataset, testDataset, learnRate, epochs, modelFileName, epochsToSave=10, lossFunc = mse, evalFunc = absError, device = None):
    network.train()
    optimizer = torch.optim.Adam(network.parameters(), lr=learnRate)

    bestLoss = float('inf')

    for epoch in range(epochs):
        network.train()

        #Loss, evaluation metric
        metric = d2l.Accumulator(2)

        for i, (x1, x2, y) in enumerate(trainDataset):
            y = y.to(device)
            yhat = network(x1, x2)

            loss = torch.mean(lossFunc(yhat, y))
            eval = torch.mean(evalFunc(yhat, y))

            loss.backward()
            optimizer.step()

            metric.add(torch.mean(loss), eval)

        validationLoss, validationEval = evaluate(network, validationDataset, device=device)
        testLoss, testEval = evaluate(network, testDataset, device=device)

        if validationLoss < bestLoss:
            bestLoss = validationLoss
            torch.save(network.state_dict(), modelFileName + "BestLoss")

        if (epoch + 1) % epochsToSave == 0:
            torch.save(network.state_dict(), modelFileName + "Epoch" + str(epoch))

        print(f"Epoch: {epoch}\n\t Train Loss: {metric[0] / len(trainDataset)}\t\tTrain Eval: {metric[1] / len(trainDataset)}\n\t \
Validation Loss: {validationLoss}\t\tValidation Eval: {validationEval}\n \
\tTest Loss: {testLoss}\t\tTest Eval: {testEval}")

Running

In [None]:
batchSize = 10
learnRate = 0.0000001
epochs = 25
modelFileName = "EuclideanImplicit5"

In [None]:
trainDataset = SentenceDataset(trainSet)
trainIter = DataLoader(trainDataset, shuffle=True, batch_size=batchSize)

validationDataset = SentenceDataset(validationSet)
validationIter = DataLoader(validationDataset, batch_size=batchSize)

testDataset = SentenceDataset(testSet)
testIter = DataLoader(testDataset, batch_size=batchSize)

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#device = torch.device("cpu")

net = SearchNetwork(True, True, device=device)
net = net.to(device)

In [None]:
train(net, trainIter, validationIter, testIter, learnRate, epochs, modelFileName, device=device)

Demo

In [83]:
documentPath = "Documents/"
fileNames = os.listdir(documentPath)

maxSentences = 250

punctuation = '\.|;|\?|!'

#1: average, 2: lowest
mode = 1

In [84]:
userInput = input("Please input your query (type quit to exit): ")

while userInput.lower() != "quit":
    clear_output()
    vals = []

    for name in fileNames:
        print(f"Reading {name}")
        if mode == 1:
            vals.append(0)
        elif mode == 2:
            vals.append(float('inf'))

        file = open(documentPath + name, 'rb')
        reader = PyPDF2.PdfReader(file)

        numSentences = 0

        for i, page in enumerate(reader.pages):
            if numSentences >= maxSentences:
                break

            arr = re.split(punctuation, page.extract_text())[:maxSentences - numSentences]
            numSentences += len(arr)

            for sentence in arr:
                dist = net(sentence, userInput)

                if mode == 1:
                    vals[-1] += dist
                elif mode == 2:
                    vals[-1] = min(dist, vals[-1])

        if mode == 1:
            vals[-1] /= numSentences

        file.close()

    newFileNames = [x for _,x in sorted(zip(vals, fileNames))]

    clear_output()
    print(f"Query: {userInput}")
    for i, name in enumerate(newFileNames):
        print(f"{i}: {name}")

    userInput = input("Please input your query (type quit to exit): ")

Query: different geometric space for information extraction
0: ReinforcementLearningHumanPreferences.pdf
1: HyperbolicRelevanceMatching.pdf
2: ChainOfThoughtPrompting.pdf
3: FactualAssociations.pdf
4: ContextAwareModeling.pdf
5: FewShotRelationExtraction.pdf
6: PromptTransferTextGeneration.pdf
7: FeedForwardLayersKeyValue.pdf
