# CQKP (Contrastive Question-Knowledge Pretraining)
I needed a tool that automatically determines whether a passage of text is about the same subject as a question. I have a bot that searches DuckDuckGo using keywords taken from a question, and I need to reliably judge whether each passage returned by the DuckDuckGo API is actually related to that question.

Plan: score the top 10 or so search results, keep the highest-scoring one, and then use an ordinary GPT model with a suitable prompt to summarize its text.

In [4]:
"""
    Adapted from https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2
"""

'\n    Adapted from https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2\n'

In [5]:
!pip install transformers 
from google.colab import output
output.clear()
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")   

In [6]:
class TextEncoder(nn.Module):
    """Wrap a DistilBERT model and return one vector per input sequence.

    The vector is the final hidden state at position ``target_token_idx``
    (0, i.e. the first token of every sequence).

    Args:
        model_name: checkpoint passed to ``DistilBertModel.from_pretrained``;
            only consulted when ``pretrained`` is True.
        pretrained: load pretrained weights (True) or build a randomly
            initialised model from a default ``DistilBertConfig`` (False).
        trainable: value given to ``requires_grad`` on every model parameter.
    """

    def __init__(self, model_name="distilbert-base-uncased", pretrained=False, trainable=False):
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained(model_name)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        # Freeze (or unfreeze) every parameter of the wrapped model in one call.
        self.model.requires_grad_(trainable)

        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return the hidden state at ``target_token_idx`` for each sequence in the batch."""
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return hidden[:, self.target_token_idx, :]
class ProjectionHead(nn.Module):
    """Map an encoder vector into the shared embedding space.

    Computes ``LayerNorm(p + Dropout(fc(GELU(p))))``, where ``p`` is a linear
    projection of the input to ``projection_dim`` features.

    Args:
        embedding_dim: size of the last dimension of the input.
        projection_dim: size of the output embedding.
        dropout: dropout probability applied to the residual branch.
    """

    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        # Submodules are created in this fixed order (projection, gelu, fc,
        # dropout, layer_norm) so parameter initialisation under a given seed
        # and the state_dict key names stay stable.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Project ``x`` and refine it with a residual GELU/Linear branch."""
        projected = self.projection(x)
        residual = self.dropout(self.fc(self.gelu(projected)))
        return self.layer_norm(projected + residual)

class CKQP_Model(nn.Module):
    """CLIP-style dual encoder that scores how well a passage matches a question.

    Passages ("texts") and questions go through separate DistilBERT encoders,
    each followed by its own ProjectionHead, so both end up in a shared
    embedding space. ``forward`` returns a symmetric contrastive loss over a
    batch in which row i of the texts pairs with row i of the questions.

    Args:
        temperature: divisor applied to the text/question similarity logits.
        image_embedding: unused; a leftover from the CLIP tutorial this code is
            adapted from, kept so existing calls that pass it still work.
        text_embedding: input width of both projection heads; must equal the
            encoders' hidden size.
        max_length: maximum number of tokens kept by ``tokenize``.
        model_name: checkpoint used for the tokenizer and, when ``pretrained``
            is True, for both encoders.
        pretrained: load pretrained encoder weights instead of random ones.
        trainable: let gradients flow into the encoder weights.
    """

    def __init__(
        self,
        temperature=1.,
        image_embedding=768,
        text_embedding=768,
        max_length=192,  # somewhat arbitrary; in my testing ~192 tokens is about one long paragraph
        model_name="distilbert-base-uncased",
        pretrained=True,
        trainable=True,
    ):
        super().__init__()
        # TextEncoder's own defaults (pretrained=False, trainable=False) give
        # randomly initialised *and* frozen DistilBERT weights, which leaves only
        # the projection heads to learn on top of random features. Start from
        # the same checkpoint the tokenizer uses and fine-tune it instead.
        self.text_encoder = TextEncoder(
            model_name=model_name, pretrained=pretrained, trainable=trainable
        )
        self.question_encoder = TextEncoder(
            model_name=model_name, pretrained=pretrained, trainable=trainable
        )

        self.question_projection = ProjectionHead(embedding_dim=text_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.max_length = max_length

    def tokenize(self, texts):
        """Tokenize ``texts`` (each converted with ``str``), padded and truncated to ``max_length``.

        No ``return_tensors`` is requested, so the tokenizer's default output
        format is returned.
        """
        return self.tokenizer(
            [str(text) for text in texts],
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )

    def forward(self, text_features, question_features, mask1, mask2):
        """Return the mean symmetric contrastive loss for a batch of (text, question) pairs.

        Args:
            text_features: token ids of the passages; row i is paired with row i
                of ``question_features``.
            question_features: token ids of the questions.
            mask1: attention mask for ``text_features``.
            mask2: attention mask for ``question_features``.
        """
        text_embeddings = self.text_projection(self.text_encoder(text_features, mask1))
        question_embeddings = self.question_projection(
            self.question_encoder(question_features, mask2)
        )

        # logits[i, j] = similarity between text i and question j.
        logits = (text_embeddings @ question_embeddings.T) / self.temperature

        # Soft targets as in the CLIP tutorial. Both similarity matrices must
        # come from the *projected* embeddings so they share the same space and
        # scale (the question side previously used raw encoder features).
        questions_similarity = question_embeddings @ question_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        # NOTE(review): targets are multiplied by temperature while logits are
        # divided by it; the two agree only at the default temperature of 1.
        targets = F.softmax(
            (questions_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        questions_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (questions_loss + texts_loss) / 2.0  # shape: (batch_size,)
        return loss.mean()

def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and soft targets.

    Args:
        preds: unnormalised scores; softmax is taken over the last dimension.
        targets: same shape as ``preds``; per-class weights over the last
            dimension (typically a probability distribution).
        reduction: ``"none"`` returns one loss per row (shape
            ``preds.shape[:-1]``), ``"mean"`` their mean, ``"sum"`` their sum.

    Raises:
        ValueError: if ``reduction`` is not one of the values above (this
            previously fell through and silently returned ``None``).
    """
    # Sum over the same (last) dimension the log-softmax normalises over.
    loss = -(targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(
        f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}"
    )

### Training

In [7]:
# the gross part: loading data
import itertools

import pandas as pd
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# One row per (context, question) pair.
csv = pd.read_csv("SQuAD_csv.csv")
# A row with a missing context or question cannot form a training pair, and a
# NaN float in a batch of strings breaks collation/tokenization, so drop it.
csv = csv.dropna(subset=['context', 'question'])
contexts = csv['context']
questions = csv['question']
model = CKQP_Model().to(device)

dataset = list(zip(contexts, questions))
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Pretrained transformer encoders need a small fine-tuning learning rate
# (1e-3 would wreck DistilBERT's weights); only the freshly initialised
# projection heads use 1e-3. Encoder groups only train if their parameters
# have requires_grad=True.
params = [
    {"params": model.text_encoder.parameters(), "lr": 1e-5},
    {"params": model.question_encoder.parameters(), "lr": 1e-5},
    {"params": itertools.chain(
        model.question_projection.parameters(), model.text_projection.parameters()
    ), "lr": 1e-3, "weight_decay": 1e-3},
]

optimizer = torch.optim.AdamW(params, weight_decay=0.)
# Driven by the mean loss of each epoch in the training loop below.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=2, factor=.5
)

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

In [None]:
def _encode(texts):
    """Tokenize a batch of strings into padded, truncated tensors on ``device``."""
    return model.tokenizer(
        [str(text) for text in texts],
        padding=True,
        truncation=True,
        max_length=model.max_length,
        return_tensors="pt",
    ).to(device)


# ReduceLROnPlateau must see an epoch-level metric. Stepping it on every noisy
# per-batch loss with patience=2 (as before) halves the learning rate over and
# over until it is effectively zero, so it is stepped once per epoch instead.
num_epochs = 1  # the original cell made a single pass over the data

model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    for batch_contexts, batch_questions in tqdm(dataloader, desc=f"epoch {epoch + 1}"):
        text_tokens = _encode(batch_contexts)
        question_tokens = _encode(batch_questions)
        loss = model(text_tokens['input_ids'],
                     question_tokens['input_ids'],
                     text_tokens['attention_mask'],
                     question_tokens['attention_mask'])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    if num_batches:
        lr_scheduler.step(epoch_loss / num_batches)

  0%|          | 0/10853 [00:00<?, ?it/s]

In [None]:
# Persist the trained model. This pickles the entire module object (not just
# model.state_dict()), so loading it later requires these class definitions
# to be importable.
torch.save(obj=model, f="SQuAD_CQKP.pt")