In [1]:
!pip install transformers
!pip install torch
!pip install annoy

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m40.4 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1
Looking in indexes: https://pypi.org/simple, https://

In [None]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from annoy import AnnoyIndex

# Fix the global torch seed so randomly initialised layers created later
# (e.g. the torch.nn.Linear scoring layers) start from the same weights on
# every Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Load pre-trained FinBERT model and tokenizer
model_name = 'ProsusAI/finbert'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Load financial documents and queries
documents = pd.read_csv('financial_documents.csv')
queries = pd.read_csv('queries.csv')

# Embedding dimensionality, read from the encoder's config instead of being
# hard-coded, so it stays correct if model_name is changed to a model with a
# different hidden size.
embedding_size = model.config.hidden_size

# Number of trees for the Annoy index
num_trees = 50

# Each tower contributes one embedding, so the concatenated pair is twice as wide.
concatenated_size = 2 * embedding_size

# Define the query tower
class QueryTower(torch.nn.Module):
    """Query-side encoder: maps tokenized queries to fixed-size embeddings.

    The embedding is the encoder's hidden state at sequence position 0
    (``output[0][:, 0, :]``; for BERT tokenizers this is typically the [CLS]
    token). The encoder runs under ``torch.no_grad()``, so it is frozen: no
    gradients reach it and only layers downstream of the tower are trained.

    Args:
        encoder: module whose output's first element holds the hidden states.
            Defaults to the notebook-level ``model``, so ``QueryTower()``
            behaves exactly as before, while a different (or mock) encoder
            can be injected explicitly.
    """

    def __init__(self, encoder=None):
        super(QueryTower, self).__init__()
        # Fall back to the global encoder loaded earlier in the notebook.
        self.model = model if encoder is None else encoder

    def forward(self, input_ids, attention_mask):
        """Return the position-0 embeddings for a batch of queries.

        Assumes ``output[0]`` is (batch, seq_len, hidden) — TODO confirm for
        non-BERT encoders; the result is then (batch, hidden).
        """
        # NOTE(review): no_grad freezes the weights but does not put the
        # encoder in eval mode; under .train() any dropout inside the encoder
        # stays active. Confirm this is intended.
        with torch.no_grad():
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
            query_embedding = output[0][:, 0, :]
        return query_embedding

# Define the document tower
class DocumentTower(torch.nn.Module):
    """Document-side encoder: maps tokenized documents to fixed-size embeddings.

    The embedding is the encoder's hidden state at sequence position 0
    (``output[0][:, 0, :]``; for BERT tokenizers this is typically the [CLS]
    token). The encoder runs under ``torch.no_grad()``, so it is frozen: no
    gradients reach it and only layers downstream of the tower are trained.

    Args:
        encoder: module whose output's first element holds the hidden states.
            Defaults to the notebook-level ``model``, so ``DocumentTower()``
            behaves exactly as before, while a different (or mock) encoder
            can be injected explicitly.
    """

    def __init__(self, encoder=None):
        super(DocumentTower, self).__init__()
        # Fall back to the global encoder loaded earlier in the notebook.
        self.model = model if encoder is None else encoder

    def forward(self, input_ids, attention_mask):
        """Return the position-0 embeddings for a batch of documents.

        Assumes ``output[0]`` is (batch, seq_len, hidden) — TODO confirm for
        non-BERT encoders; the result is then (batch, hidden).
        """
        # NOTE(review): no_grad freezes the weights but does not put the
        # encoder in eval mode; under .train() any dropout inside the encoder
        # stays active. Confirm this is intended.
        with torch.no_grad():
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
            document_embedding = output[0][:, 0, :]
        return document_embedding

# Create the two-tower model
class TwoTower(torch.nn.Module):
    """Relevance scorer built from a query tower and a document tower.

    Each tower turns its tokenized input into an embedding. The two embeddings
    are concatenated along dim 1 and fed through a small MLP
    (concatenated_size -> 256 -> 128 -> 1, ReLU between layers). A final
    sigmoid maps the result to a score in (0, 1).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the order in which the layers draw from the RNG.
        self.query_tower = QueryTower()
        self.document_tower = DocumentTower()
        self.fc1 = torch.nn.Linear(concatenated_size, 256)
        self.fc2 = torch.nn.Linear(256, 128)
        self.fc3 = torch.nn.Linear(128, 1)

    def forward(self, query_input_ids, query_attention_mask, document_input_ids, document_attention_mask):
        """Score query/document pairs, returning one probability per pair (last dim 1)."""
        pair = torch.cat(
            [
                self.query_tower(query_input_ids, query_attention_mask),
                self.document_tower(document_input_ids, document_attention_mask),
            ],
            dim=1,
        )
        hidden = self.fc2(self.fc1(pair).relu()).relu()
        return self.fc3(hidden).sigmoid()

# Define the training function
def train(model, optimizer, data_loader, device):
    """Run one training epoch of a two-tower relevance model.

    Args:
        model: module called as ``model(query_input_ids, query_attention_mask,
            document_input_ids, document_attention_mask)`` that returns
            probabilities in [0, 1] (e.g. the sigmoid output of ``TwoTower``).
        optimizer: torch optimizer over the model's trainable parameters.
        data_loader: iterable of dict batches with keys ``'query_input_ids'``,
            ``'query_attention_mask'``, ``'document_input_ids'``,
            ``'document_attention_mask'`` and ``'label'``.
        device: device (or device string) the batch tensors are moved to.

    Returns:
        float: mean binary cross-entropy over the batches seen, or 0.0 if
        ``data_loader`` yielded no batches.
    """
    model.train()
    # BCELoss expects probabilities, matching the sigmoid output of the model.
    criterion = torch.nn.BCELoss()
    total_loss = 0.0
    num_batches = 0
    for batch in data_loader:
        optimizer.zero_grad()
        query_input_ids = batch['query_input_ids'].to(device)
        query_attention_mask = batch['query_attention_mask'].to(device)
        document_input_ids = batch['document_input_ids'].to(device)
        document_attention_mask = batch['document_attention_mask'].to(device)
        outputs = model(query_input_ids, query_attention_mask,
                        document_input_ids, document_attention_mask)
        # BCELoss requires a float target with exactly the output's shape.
        # Labels may arrive as integers and/or as shape (batch,) while a
        # TwoTower-style head emits (batch, 1), so cast and reshape them.
        labels = batch['label'].to(device, dtype=outputs.dtype).reshape_as(outputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0