In [19]:
import pandas as pd
import os
from dotenv import load_dotenv
from sqlalchemy import create_engine
from tqdm import tqdm

In [12]:
# Load credentials.
# load_dotenv() copies KEY=value pairs from a local .env file into os.environ
# (variables already set in the environment are not overridden by default).
# Without this call the python-dotenv import has no effect and os.getenv only
# sees variables exported in the shell that launched Jupyter.
load_dotenv()

db_username = os.getenv('POSTGRES_USERNAME')
db_password = os.getenv('POSTGRES_PASSWORD')
db_host = os.getenv('POSTGRES_HOST')
db_port = os.getenv('POSTGRES_PORT')
db_name = os.getenv('POSTGRES_DATABASE')
# NOTE(review): none of these variables are used by any visible cell; the %sql
# connection is presumably configured elsewhere (e.g. a deleted cell) — confirm,
# or build the connection URL from them here so the notebook re-runs cleanly.



In [14]:
# Load the SQL extension that provides the %sql / %sqlcmd magics used in later cells
# (presumably JupySQL, given %sqlcmd).
# NOTE(review): the captured output says it was already loaded, i.e. this cell was re-run
# in an existing kernel; `%reload_ext sql` avoids that message on re-runs.
%load_ext sql

The sql extension is already loaded. To reload it, use:
  %reload_ext sql


In [15]:
# Inspect the schema of the lse_doc table (name, type, nullability, default, ...).
# NOTE(review): `embedding` is reported with an empty type, i.e. SQLAlchemy could not
# reflect it (presumably a pgvector column); it is parsed from text further below.
%sqlcmd columns --table lse_doc

  columns = inspector.get_columns(name, schema) or []


name,type,nullable,default,autoincrement,comment
id,TEXT,False,,False,
doc_id,TEXT,True,,False,
chunk_id,TEXT,True,,False,
type,TEXT,True,,False,
url,TEXT,True,,False,
title,TEXT,True,,False,
content,TEXT,True,,False,
date_scraped,TIMESTAMP,True,,False,
embedding,,True,,False,


In [17]:
# Fetch the whole lse_doc table via the %sql magic; the returned result object is
# converted to pandas with its .DataFrame() method in the next cell.
# NOTE(review): there is no ORDER BY, so row order is whatever the database returns.
df = %sql SELECT * FROM lse_doc

In [31]:
# Convert the %sql result set into a pandas DataFrame.
# Guarded so that re-running this cell (when `df` is already a DataFrame) is a
# no-op instead of raising AttributeError on DataFrame.DataFrame.
if not isinstance(df, pd.DataFrame):
    df = df.DataFrame()

In [43]:
## Define SCCA and parameters

import torch
import torch.nn as nn

class ShiftedCrossChunkAttention(nn.Module):
    """Shifted cross-chunk attention (SCCA).

    Queries are the chunks themselves; keys and values are the same chunks
    cyclically rolled by ``shift_size`` along the chunk axis, so chunk ``i``
    attends to chunk ``i - shift_size`` (wrapping around at the ends).

    Args:
        embed_dim: Embedding size; must be divisible by ``num_heads``
            (required by ``nn.MultiheadAttention``).
        num_heads: Number of heads of the wrapped ``nn.MultiheadAttention``.
        shift_size: Number of positions keys/values are rolled along the chunk axis.

    NOTE(review): for 2-D input ``(num_chunks, embed_dim)`` each query sees exactly
    one key, so the softmax weight is always 1 and the output is just
    ``out_proj(v_proj(shifted chunk))`` — a linear map of the neighbouring chunk,
    not a learned mixture. Real cross-chunk mixing needs a token axis (3-D input).
    """

    def __init__(self, embed_dim, num_heads, shift_size):
        super().__init__()
        # Default layout (batch_first=False): inputs are (seq_len, batch, embed_dim).
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
        self.shift_size = shift_size

    def shift_key_value(self, embeddings, shift_size):
        """Cyclically roll ``embeddings`` by ``shift_size`` along dim 1.

        Inside ``forward`` dim 1 is the chunk axis, so this pairs every chunk with
        the chunk ``shift_size`` positions earlier.
        """
        return torch.roll(embeddings, shifts=shift_size, dims=1)

    def forward(self, chunk_embeddings):
        """Apply shifted cross-chunk attention.

        Args:
            chunk_embeddings: ``(num_chunks, embed_dim)`` or
                ``(num_chunks, chunk_size, embed_dim)``; a 2-D input is treated
                as ``chunk_size == 1``.

        Returns:
            Tensor of shape ``(num_chunks, chunk_size, embed_dim)``.
        """
        stacked = chunk_embeddings.unsqueeze(1) if chunk_embeddings.dim() == 2 else chunk_embeddings

        # (num_chunks, chunk_size, E) -> (chunk_size, num_chunks, E):
        # tokens form the sequence axis, chunks form the batch axis.
        queries = stacked.permute(1, 0, 2)
        keys_values = self.shift_key_value(queries, self.shift_size)

        mixed, _unused_weights = self.attention(queries, keys_values, keys_values)

        # Back to (num_chunks, chunk_size, E).
        return mixed.permute(1, 0, 2)

In [40]:
#Convert strings in df["embedding"] to list

import ast

def string_to_list(embedding_str):
    """Parse an embedding stored as literal text (e.g. ``"[0.1, -0.2, ...]"``) into a list.

    Args:
        embedding_str: The raw column value.

    Returns:
        The parsed Python object for string input. Any non-string value (an
        already-parsed list/tensor, or None for a NULL embedding) is returned
        unchanged, so re-running the conversion cell is safe.

    Raises:
        ValueError / SyntaxError: if a string is not a valid Python literal.
    """
    if not isinstance(embedding_str, str):
        return embedding_str
    # literal_eval only evaluates Python literals, never arbitrary code.
    return ast.literal_eval(embedding_str)

# Parse every stored embedding value element-wise, overwriting the column in place.
df['embedding'] = df['embedding'].map(string_to_list)

In [46]:
# SCCA hyper-parameters (also encoded in the output column name).
NUM_HEADS = 16
SHIFT_SIZE = 2
SEED = 0
NEW_EMBEDDING_COL = f'new_embedding_{NUM_HEADS}heads_{SHIFT_SIZE}shift'

# The attention weights are randomly initialised and never trained here, so seed
# torch to make the resulting embeddings reproducible across Restart & Run All.
torch.manual_seed(SEED)

# len() works whether embeddings are Python lists (output of the literal_eval cell)
# or 1-D tensors; tensor-only .size(0) broke on a fresh kernel.
embed_dim = len(df['embedding'].iloc[0])
scca = ShiftedCrossChunkAttention(embed_dim=embed_dim, num_heads=NUM_HEADS, shift_size=SHIFT_SIZE)
scca.eval()


def apply_attention(group):
    """Run SCCA over all chunks of one document.

    Args:
        group: DataFrame rows of a single ``doc_id``; their row order is the chunk
            order used for the shift.

    Returns:
        A list with one entry per row of ``group`` (same order): the nested-list form
        of that chunk's output, shape ``(1, embed_dim)`` for 1-D input embeddings.
    """
    # Accept lists or tensors; cast to float32 to match the module's weights.
    chunk_embeddings = torch.stack(
        [torch.as_tensor(emb, dtype=torch.float32) for emb in group['embedding']]
    )  # (num_chunks, embed_dim)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        attn_output = scca(chunk_embeddings)
    return attn_output.tolist()


# Collect outputs keyed by the original row index. groupby() yields documents in
# sorted doc_id order, which need not match the DataFrame's row order (the SQL query
# has no ORDER BY), so positional assignment could attach embeddings to wrong rows.
new_embeddings_by_row = {}
for _, group in tqdm(df.groupby('doc_id')):
    # chunk_id is a TEXT column: sort numerically so "10" follows "2" and the shift
    # pairs each chunk with its true neighbour.
    ordered = group.sort_values('chunk_id', key=lambda ids: pd.to_numeric(ids, errors='coerce'))
    new_embeddings_by_row.update(zip(ordered.index, apply_attention(ordered)))

# Index-aligned assignment; rows with a NULL doc_id (dropped by groupby) get NaN.
df[NEW_EMBEDDING_COL] = pd.Series(new_embeddings_by_row, dtype=object)

# Verify the result
df.head()

100%|██████████| 4097/4097 [00:10<00:00, 407.35it/s]


                                   id                            doc_id  \
0  4512ceec4271e3dae865963cd56e4c43_0  4512ceec4271e3dae865963cd56e4c43   
1  4512ceec4271e3dae865963cd56e4c43_1  4512ceec4271e3dae865963cd56e4c43   
2  4512ceec4271e3dae865963cd56e4c43_2  4512ceec4271e3dae865963cd56e4c43   
3  4512ceec4271e3dae865963cd56e4c43_3  4512ceec4271e3dae865963cd56e4c43   
4  80c96c0909e63bd720109e2fe0d19306_0  80c96c0909e63bd720109e2fe0d19306   

  chunk_id     type                                                url  \
0        0      pdf  https://www.lse.ac.uk/africa/assets/Documents/...   
1        1      pdf  https://www.lse.ac.uk/africa/assets/Documents/...   
2        2      pdf  https://www.lse.ac.uk/africa/assets/Documents/...   
3        3      pdf  https://www.lse.ac.uk/africa/assets/Documents/...   
4        0  webpage  https://www.lse.ac.uk/News/Latest-news-from-LS...   

                                               title  \
0                        FLIA-Strategy-2023-2026

In [45]:
# Compare the first row's original embedding with its SCCA output.
# Read the column actually written by the SCCA cell ("new_embedding_16heads_2shift");
# a bare "new_embedding" column does not exist after Restart & Run All (KeyError).
print(df["embedding"].iloc[0])
print(df["new_embedding_16heads_2shift"].iloc[0])

tensor([-0.0192, -0.0114, -0.0225,  ..., -0.0168, -0.0105,  0.0187])
[[0.02040272019803524, -0.009144347161054611, -0.015597495250403881, -0.0023359833285212517, 0.012673395685851574, 0.00904243066906929, 0.010066963732242584, -0.02606111951172352, 0.007092597428709269, -0.014990451745688915, -0.005179115105420351, 0.007910103537142277, -0.003317869734019041, -0.007480286527425051, 0.001792743569239974, 0.004515698179602623, 0.006688226014375687, 0.013258748687803745, 0.008924489840865135, 0.004127077758312225, -0.002436896786093712, 0.019145485013723373, 0.00438795005902648, 0.018516404554247856, 0.009216328151524067, -0.0017159337876364589, 0.008344014175236225, 0.0024057687260210514, -0.004002937115728855, 0.005214898847043514, 8.214887202484533e-05, -0.0027626007795333862, -0.0008861299720592797, -0.014175898395478725, 0.016800688579678535, 0.00086645083501935, 0.0033337674103677273, 0.005782128311693668, -0.029178712517023087, 0.00042248269892297685, 0.004733374807983637, -0.00655