In [None]:
pip install transformers

In [None]:
pip install sentence_transformers

In [None]:
!pip install llama-index==0.9.38

# hej

In [None]:
import numpy as np
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import Document

# Step 1: Use LlamaIndex chunkers for better text splitting
def split_text_with_sentence_splitter(text, chunk_size=512, chunk_overlap=50):
    """
    Break ``text`` into chunks using LlamaIndex's SentenceSplitter.

    Args:
        text: Raw text to split.
        chunk_size: Passed through to SentenceSplitter as ``chunk_size``.
        chunk_overlap: Passed through to SentenceSplitter as ``chunk_overlap``.

    Returns:
        A list holding the ``text`` of every node the splitter produced,
        in the order the splitter returned them.

    Side effects:
        Prints every chunk (in full) to stdout for debugging.
    """
    splitter = SentenceSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        paragraph_separator="\n\n",  # paragraphs are delimited by blank lines
        separator=" ",               # separator handed to the splitter
    )

    # The splitter operates on Document objects, so wrap the raw string first.
    wrapped = Document(text=text)
    parsed_nodes = splitter.get_nodes_from_documents([wrapped])

    chunks = [parsed.text for parsed in parsed_nodes]

    # Debug output: one line per chunk, showing the complete chunk text.
    for i, chunk in enumerate(chunks):
        print(f"Chunk {i}: {chunk}")

    return chunks

# Step 2: Embed each chunk using SentenceTransformer
class TextEmbedder:
    """Small wrapper that owns a SentenceTransformer model and embeds text chunks."""

    def __init__(self, model_name="paraphrase-MiniLM-L6-v2"):
        # Load the model once so every embed_text call reuses the same instance.
        self.model = SentenceTransformer(model_name)

    def embed_text(self, chunks):
        """
        Encode ``chunks`` with the wrapped model.

        Args:
            chunks: The text chunks to embed; handed unchanged to ``model.encode``.

        Returns:
            Whatever ``self.model.encode(chunks)`` returns.
        """
        embeddings = self.model.encode(chunks)
        return embeddings

# Step 3: Group embeddings by text
def process_texts(texts, chunk_size=20, chunk_overlap=20):
    """
    Chunk and embed every text, returning the embeddings keyed per text.

    Args:
        texts: Iterable of raw text strings.
        chunk_size: Forwarded to split_text_with_sentence_splitter.
        chunk_overlap: Forwarded to split_text_with_sentence_splitter.
            NOTE(review): the default overlap equals the default chunk size
            (both 20) - confirm this is intended.

    Returns:
        Dict mapping ``"text_<index>"`` to a NumPy array of that text's chunk
        embeddings. Texts that yield no chunks get no entry.
    """
    embedder = TextEmbedder()  # one model instance shared by all texts
    grouped_embeddings = {}

    for idx, text in enumerate(texts):
        chunks = split_text_with_sentence_splitter(text, chunk_size, chunk_overlap)
        if not chunks:
            # Nothing to embed for this text; leave it out of the result.
            continue
        vectors = embedder.embed_text(chunks)
        grouped_embeddings[f"text_{idx}"] = np.array(vectors)

    return grouped_embeddings

# The rest of your GNN code remains the same
# ...

# Example usage
if __name__ == "__main__":
    # Three sample inputs of different lengths; each becomes one "text_<idx>" entry
    # in the result (provided the splitter yields at least one chunk for it).
    texts = [
        "This is the first example text. It is quite long and needs to be split into chunks . It is quite long and needs to be split into chunks . It is quite long and needs to be split into chunks . It is quite long and needs to be split into chunks . It is quite long and needs to be split into chunks . It is quite long and needs to be split into chunks . It is quite long and needs to be split into chunks .",
        "Here is another text that will also be split and embedded.Here is another text that will also be split and embedded.Here is another text that will also be split and embedded.Here is another text that will also be split and embedded.Here is another text that will also be split and embedded.",
         "Here is another text that will also be split and embedded.Here is another text that will also be split and embedded."
    ]

    # Process texts with LlamaIndex chunker (uses process_texts' default
    # chunk_size / chunk_overlap). The result is kept in a module-level name
    # so later notebook cells can inspect it.
    grouped_embeddings = process_texts(texts)
    
    # Print information about the chunks: number of chunk embeddings per text
    # and the size of each embedding vector (second array dimension).
    for text_id, embeddings in grouped_embeddings.items():
        print(f"{text_id}: {len(embeddings)} chunks with embedding dimension {embeddings.shape[1]}")

In [None]:
# Notebook display cell: show the current contents of grouped_embeddings
# (assigned by the example-usage cell; it must have been run first).
grouped_embeddings

In [None]:
pip install torch-geometric

In [None]:
# Inspect the size of master_embedding along one dimension.
# NOTE(review): master_embedding is only assigned by the GNN example cell
# further down in this notebook, so that cell has to be executed before this one.
master_embedding.size(0)  # Replace 0 with the dimension you want to check

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# Step 1: Define the GNN models
class ChunkGNN(nn.Module):
    """
    Two GCNConv layers, mean pooling and a linear head.

    Turns the chunk nodes of a graph into one embedding per graph
    (i.e. per document, when each graph holds one document's chunks).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers are created in a fixed order: conv1, conv2, fc.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        """
        Args:
            x: Node feature tensor (one row per chunk).
            edge_index: Edge list tensor passed to both GCNConv layers.
            batch: Graph-assignment vector handed to ``global_mean_pool``.

        Returns:
            Output of the linear head applied to the mean-pooled node features.
        """
        hidden = torch.relu(self.conv1(x, edge_index))
        hidden = torch.relu(self.conv2(hidden, edge_index))
        # Average node features per graph to get the document-level vector.
        pooled = global_mean_pool(hidden, batch)
        return self.fc(pooled)

class MasterGNN(nn.Module):
    """Graph neural network for aggregating document embeddings into a fixed-size master embedding"""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MasterGNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.output_dim = output_dim
        
    def forward(self, document_embeddings):
        """
        Create a fixed-size master embedding from document embeddings using a graph approach
        
        Args:
            document_embeddings: Tensor of shape [num_documents, embedding_dim]
        
        Returns:
            Fixed-size master embedding regardless of number of input documents
        """
        num_documents = document_embeddings.size(0)
        # Build the helper tensors on the same device as the input; creating
        # them on the default (CPU) device breaks the forward pass as soon as
        # the embeddings and the model live on a GPU.
        device = document_embeddings.device

        # Create a single batch for all documents (all in one graph)
        batch = torch.zeros(num_documents, dtype=torch.long, device=device)
        
        # Create a fully connected graph between documents
        if num_documents > 1:
            # Generate all pairs of indices for a fully connected graph
            node_ids = torch.arange(num_documents, device=device)
            edge_index = torch.combinations(node_ids, r=2).t()
            # Make edges bidirectional
            edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
        else:
            # Self-loop for a single document
            edge_index = torch.tensor([[0], [0]], dtype=torch.long, device=device)
        
        # Process through GNN
        x = self.conv1(document_embeddings, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        x = torch.relu(x)
        
        # Global pooling to get fixed-size representation regardless of document count
        x = global_mean_pool(x, batch)
        x = self.fc(x)
        
        return x

# Step 2: Prepare data for chunk-level GNN
def prepare_chunk_data(grouped_embeddings):
    """
    Build one PyTorch Geometric ``Data`` graph per entry of ``grouped_embeddings``.

    Args:
        grouped_embeddings: Mapping whose values are per-text chunk embeddings
            (one row per chunk). The keys are not used.

    Returns:
        List of ``Data`` objects, in the mapping's iteration order. Each graph
        has the chunk embeddings as float node features and is fully connected
        in both directions; a graph that does not have more than one node gets
        a single 0 -> 0 self-loop instead.
    """
    data_list = []
    for embeddings in grouped_embeddings.values():
        node_features = torch.tensor(embeddings, dtype=torch.float)
        node_count = node_features.size(0)

        if node_count <= 1:
            # No pairs to connect: fall back to a self-loop on node 0.
            edge_index = torch.tensor([[0], [0]], dtype=torch.long)
        else:
            # Every unordered pair (i, j), then the mirrored (j, i) edges.
            one_way = torch.combinations(torch.arange(node_count), r=2).t()
            edge_index = torch.cat([one_way, one_way.flip(0)], dim=1)

        data_list.append(Data(x=node_features, edge_index=edge_index))

    return data_list

# Step 3: Process grouped embeddings through the GNNs
def process_with_gnns(grouped_embeddings, chunk_gnn, master_gnn):
    """
    Run the two-stage GNN pipeline and return the master embedding.

    Args:
        grouped_embeddings: Mapping of text id to chunk embeddings, as consumed
            by prepare_chunk_data.
        chunk_gnn: Callable/model invoked as ``chunk_gnn(x, edge_index, batch)``
            once per document graph.
        master_gnn: Callable/model invoked once on the stacked document embeddings.

    Returns:
        The value returned by ``master_gnn``.
    """
    graphs = prepare_chunk_data(grouped_embeddings)

    # batch_size=1 and shuffle=False: graphs are fed one at a time, in order.
    graph_loader = DataLoader(graphs, batch_size=1, shuffle=False)

    # One embedding per document; squeeze(0) drops the leading batch dimension.
    per_document = [
        chunk_gnn(graph.x, graph.edge_index, graph.batch).squeeze(0)
        for graph in graph_loader
    ]

    stacked = torch.stack(per_document)
    return master_gnn(stacked)

# Example usage
if __name__ == "__main__":
    # Smoke test: run the GNN pipeline on random embeddings for 1, 2 and 4
    # documents and print the resulting master-embedding shape each time.
    # No random seed is set, so the values differ between runs.
    # For testing with different numbers of documents
    test_cases = [
        {"text_0": np.random.rand(5, 384)},  # 1 document
        {"text_0": np.random.rand(5, 384), "text_1": np.random.rand(3, 384)},  # 2 documents
        {"text_0": np.random.rand(5, 384), "text_1": np.random.rand(3, 384), 
         "text_2": np.random.rand(4, 384), "text_3": np.random.rand(4, 384)}  # 4 documents
    ]
    
    # Define fixed embedding dimensions
    # (384 matches the second dimension of the random arrays above)
    embedding_dim = 384
    output_dim = 64
    
    # Initialize models
    # The chunk GNN's output size is the master GNN's input size; the models are
    # used with their freshly initialised weights (no training happens here).
    chunk_gnn = ChunkGNN(input_dim=embedding_dim, hidden_dim=128, output_dim=output_dim)
    master_gnn = MasterGNN(input_dim=output_dim, hidden_dim=32, output_dim=16)
    
    # Test with different numbers of documents
    # NOTE: this loop rebinds the module-level names grouped_embeddings and
    # master_embedding, which other notebook cells read.
    for i, grouped_embeddings in enumerate(test_cases):
        print(f"\nTest case {i+1}: {len(grouped_embeddings)} documents")
        
        # Process embeddings
        master_embedding = process_with_gnns(grouped_embeddings, chunk_gnn, master_gnn)
        
        # Verify shape is consistent
        # (expected to be identical for every test case - presumably (1, 16) given
        # MasterGNN's output_dim; confirm against the printed output)
        print(f"Master embedding shape: {master_embedding.shape}")