In [2]:
import pandas as pd
import os
from dotenv import load_dotenv
from sqlalchemy import create_engine
from tqdm import tqdm

In [3]:
# Load credentials.
# Read a local .env file (if one exists) into the process environment first.
# Without this call the os.getenv lookups below only see variables that were
# already exported in the environment that started the kernel. load_dotenv()
# does not override variables that are already set.
load_dotenv()

# Each value is None when the corresponding variable is not set.
db_username = os.getenv('POSTGRES_USERNAME')
db_password = os.getenv('POSTGRES_PASSWORD')
db_host = os.getenv('POSTGRES_HOST')
db_port = os.getenv('POSTGRES_PORT')
db_name = os.getenv('POSTGRES_DATABASE')



In [4]:
# Load the SQL IPython extension; presumably it provides the %sql and %sqlcmd
# magics used in the cells below.
%load_ext sql

In [5]:
# Inspect the column names and types of the lse_doc table before querying it.
%sqlcmd columns --table lse_doc

  columns = inspector.get_columns(name, schema) or []


name,type,nullable,default,autoincrement,comment
id,TEXT,False,,False,
doc_id,TEXT,True,,False,
chunk_id,TEXT,True,,False,
type,TEXT,True,,False,
url,TEXT,True,,False,
title,TEXT,True,,False,
content,TEXT,True,,False,
date_scraped,TIMESTAMP,True,,False,
embedding,,True,,False,


In [6]:
# Fetch every row of lse_doc. The magic returns a result object, which is
# converted to a pandas DataFrame in the next cell.
# NOTE(review): there is no ORDER BY, so row order is whatever the database returns.
df = %sql SELECT * FROM lse_doc

In [7]:
# Convert the %sql result object into a pandas DataFrame. The name `df` is
# reused for the converted frame, so the conversion is guarded: when this cell
# is re-run, `df` is already a DataFrame (which has no .DataFrame() method)
# and is left untouched instead of raising AttributeError.
if not isinstance(df, pd.DataFrame):
    df = df.DataFrame()

In [8]:
# Print the column labels of the loaded frame as a quick sanity check.
print(df.columns)

Index(['id', 'doc_id', 'chunk_id', 'type', 'url', 'title', 'content',
       'date_scraped', 'embedding'],
      dtype='object')


In [9]:
import torch
import torch.nn as nn

class ShiftedCrossChunkAttention(nn.Module):
    """Multi-head attention whose keys/values are a cyclically shifted copy of the input.

    The queries are the chunk embeddings themselves; the keys and values are
    the same embeddings rolled by ``shift_size`` along the chunk axis, so the
    query from chunk ``i`` is matched against chunk ``i - shift_size``
    (wrapping around at the ends, because ``torch.roll`` is cyclic).

    NOTE(review): ``nn.MultiheadAttention`` is built with its default
    ``batch_first=False``, so after the permute in ``forward`` the
    ``chunk_size`` axis is the sequence axis and ``num_chunks`` is the batch
    axis. With ``chunk_size == 1`` every query sees exactly one key, which
    would make the attention weights trivially 1 - confirm this is intended.

    Args:
        embed_dim: Size of the last (feature) dimension of the embeddings.
        num_heads: Number of attention heads passed to ``nn.MultiheadAttention``.
        shift_size: Number of positions to roll along the chunk axis.
    """

    def __init__(self, embed_dim, num_heads=8, shift_size=1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
        self.shift_size = shift_size

    def shift_key_value(self, embeddings, shift_size):
        """Return ``embeddings`` rolled by ``shift_size`` along dimension 1.

        ``forward`` calls this with a tensor laid out as
        (chunk_size, num_chunks, embed_dim), so dimension 1 is the chunk axis.
        """
        return torch.roll(embeddings, shift_size, 1)

    def forward(self, chunk_embeddings):
        """Attend each chunk to its shifted neighbour.

        Args:
            chunk_embeddings: Tensor expected to be laid out as
                (num_chunks, chunk_size, embed_dim).

        Returns:
            Tensor with the same layout as the input,
            (num_chunks, chunk_size, embed_dim).
        """
        swap_first_two = (1, 0, 2)

        # Move to (chunk_size, num_chunks, embed_dim) for the attention layer.
        queries = chunk_embeddings.permute(*swap_first_two)

        # Keys and values are both the rolled copy of the queries.
        keys_values = self.shift_key_value(queries, self.shift_size)

        # The attention layer also returns the weights; only the output is kept.
        attended = self.attention(queries, keys_values, keys_values)[0]

        # Restore the caller's (num_chunks, chunk_size, embed_dim) layout.
        return attended.permute(*swap_first_two)

In [10]:
#Convert strings in df["embedding"] to list

import ast

def string_to_list(embedding_str):
    """Parse the textual form of an embedding into a Python list.

    Args:
        embedding_str: A string containing a Python list literal, for example
            "[0.1, 0.2]". A value that is already a list or tuple is accepted
            as well, so applying the conversion a second time (for example
            when the cell is re-run) is harmless instead of raising.

    Returns:
        The value parsed by ``ast.literal_eval`` for string input, or a list
        copy of the input when it is already a list or tuple.

    Raises:
        ValueError, SyntaxError: Propagated from ``ast.literal_eval`` when the
            input is neither an already-parsed sequence nor a valid literal.
    """
    # Already parsed: ast.literal_eval would reject a list with ValueError.
    if isinstance(embedding_str, (list, tuple)):
        return list(embedding_str)
    return ast.literal_eval(embedding_str)

# Overwrite the 'embedding' column in place with the parsed value of each cell.
df['embedding'] = df['embedding'].apply(string_to_list)

In [16]:
# The attention layer's weights are randomly initialised (no training step is
# visible in this notebook), so seed the RNG first; otherwise every kernel run
# produces different "new" embeddings for the same input.
torch.manual_seed(0)
scca = ShiftedCrossChunkAttention(embed_dim=1024, num_heads=32, shift_size=1)
scca.eval()  # the model is only used for inference below

# Function to apply the attention mechanism to each group of chunks with the same doc_id
def apply_attention(group, model=None):
    """Run the attention model over the embeddings of one group of rows.

    Args:
        group: DataFrame whose 'embedding' column holds one equal-length list
            of numbers per row (one row per chunk).
        model: Callable that maps a (num_chunks, 1, embed_dim) tensor to a
            tensor with the same number of elements. Defaults to the
            module-level ``scca`` model.

    Returns:
        A list of ``num_chunks`` lists, each of length ``embed_dim``, in the
        same order as the rows of ``group``.
    """
    attention_model = scca if model is None else model

    chunk_embeddings = torch.tensor(group['embedding'].tolist(), dtype=torch.float32)  # (num_chunks, embed_dim)
    num_chunks, embed_dim = chunk_embeddings.size()

    # Reshape to (num_chunks, chunk_size=1, embed_dim)
    chunk_embeddings = chunk_embeddings.view(num_chunks, 1, embed_dim)

    # Inference only: skip building the autograd graph for every group.
    with torch.no_grad():
        attended_embeddings = attention_model(chunk_embeddings)

    # Reshape back to (num_chunks, embed_dim) and convert to list.
    # reshape (not view) because the model output may be non-contiguous.
    return attended_embeddings.reshape(num_chunks, embed_dim).tolist()

# Apply the attention mechanism to each group of chunks with the same doc_id.
# groupby yields the groups in sorted-key order (and skips rows whose doc_id is
# null), which is generally NOT the row order of df. Each result is therefore
# keyed by its row label so it can be aligned back to the right row, instead
# of being appended to a flat list and assigned positionally.
# NOTE(review): within a group the chunks are processed in df row order;
# confirm that this is the intended chunk order (e.g. sort by chunk_id first).
assert df.index.is_unique, "row labels must be unique to align the results"
new_embeddings = {}
for doc_id, group in tqdm(df.groupby('doc_id')):
    new_embeddings.update(zip(group.index, apply_attention(group)))

# Add the new embeddings to the dataframe. Assigning a Series aligns on the
# index; rows that belonged to no group (null doc_id) are left as NaN.
df['new_embedding_32heads_1shift'] = pd.Series(new_embeddings, dtype=object)

# Verify the result
assert df['new_embedding_32heads_1shift'].notna().sum() == len(new_embeddings)
print(df.columns)

100%|██████████| 4097/4097 [00:12<00:00, 329.56it/s]

Index(['id', 'doc_id', 'chunk_id', 'type', 'url', 'title', 'content',
       'date_scraped', 'embedding', 'new_embedding_16heads_2shift',
       'new_embedding_16heads_1shift', 'new_embedding_32heads_2shift',
       'new_embedding_32heads_1shift'],
      dtype='object')





In [17]:
# Write the whole frame, including any added embedding columns, to a CSV file
# in the current working directory.
# NOTE(review): with the default arguments the row index is written as an
# extra first column, and list-valued cells are stored as their string form -
# confirm that downstream readers expect both.
df.to_csv("scca_test_data.csv")