In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from transformers import GPT2LMHeadModel, GPT2Tokenizer, BertModel, BertTokenizer
from datasets import load_dataset
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load a custom dataset (replace with real authorship data)
# This is a placeholder; you will need to replace this with your own dataset (texts, authorship labels)
dataset = load_dataset('ag_news', split='train[:5%]')  # Small sample for testing
texts = dataset['text']
labels = dataset['label']  # You should replace this with actual authorship labels

In [3]:
# Initialize GPT-2 (or BERT for a different approach)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')  # Change to 'bert-base-uncased' for BERT
model = GPT2LMHeadModel.from_pretrained('gpt2')  # Change to BertModel.from_pretrained for BERT

In [4]:
def get_llm_embeddings(texts, model, tokenizer):
    embeddings = []

    # Ensure that the GPT-2 tokenizer has a pad_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # Use eos_token as pad_token for GPT-2

    for text in texts:
        # Tokenize text with padding, truncation, and attention masks
        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)

        # Make sure model is set to output hidden states
        model.config.output_hidden_states = True

        with torch.no_grad():
            outputs = model(**inputs)

        # The hidden states are returned in a tuple of tensors
        # outputs.hidden_states[-1] is the last layer hidden states (shape: [batch_size, seq_len, hidden_size])
        last_hidden_state = outputs.hidden_states[-1]  # Take the last layer's hidden states
        embeddings.append(last_hidden_state.mean(dim=1).squeeze().numpy())  # Mean pooling over the sequence length

    return np.array(embeddings)

# Extract LLM embeddings for the texts
llm_embeddings = get_llm_embeddings(texts, model, tokenizer)

In [5]:
def construct_graph_from_embeddings(embeddings, threshold=0.8):
    # Compute cosine similarity matrix
    similarity_matrix = cosine_similarity(embeddings)

    # Create edges for similarity above threshold
    edge_index = np.array(np.nonzero(similarity_matrix > threshold))

    return torch.tensor(edge_index, dtype=torch.long)

# Construct the graph based on LLM embeddings
edge_index = construct_graph_from_embeddings(llm_embeddings)

In [6]:
class AuthorshipGNN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(AuthorshipGNN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

In [7]:
# Convert LLM embeddings to a PyTorch tensor
x = torch.tensor(llm_embeddings, dtype=torch.float)

# Prepare graph data in PyTorch Geometric format
data = Data(x=x, edge_index=edge_index)

In [8]:
# Initialize the GNN model
gnn_model = AuthorshipGNN(num_features=llm_embeddings.shape[1], hidden_channels=64, num_classes=4)  # For 4 classes in AG News

# Define optimizer
optimizer = torch.optim.Adam(gnn_model.parameters(), lr=0.01)

# Train function
def train(model, data, optimizer, labels):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out, torch.tensor(labels[:len(out)], dtype=torch.long))  # Match labels to prediction size
    loss.backward()
    optimizer.step()
    return loss.item()

# Training loop
for epoch in range(10):  # Adjust number of epochs for your dataset
    loss = train(gnn_model, data, optimizer, labels)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")

Epoch 1, Loss: 10.0325
Epoch 2, Loss: 7.7568
Epoch 3, Loss: 11.2341
Epoch 4, Loss: 16.9262
Epoch 5, Loss: 13.7347
Epoch 6, Loss: 7.1787
Epoch 7, Loss: 8.7896
Epoch 8, Loss: 9.5433
Epoch 9, Loss: 7.6898
Epoch 10, Loss: 5.6032


In [9]:
# Evaluation function
def evaluate(model, data, labels):
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    accuracy = (pred == torch.tensor(labels[:len(pred)], dtype=torch.long)).sum().item() / len(pred)
    return accuracy

accuracy = evaluate(gnn_model, data, labels)
print(f"Model Accuracy: {accuracy:.4f}")

Model Accuracy: 0.2578


In [None]:
# Visualize the graph (Optional)
def visualize_graph(similarity_matrix, threshold=0.8):
    G = nx.Graph()
    for i in range(similarity_matrix.shape[0]):
        G.add_node(i, label=str(i))
    for i, j in zip(*np.where(similarity_matrix > threshold)):
        G.add_edge(i, j)

    plt.figure(figsize=(10, 10))
    nx.draw(G, with_labels=True, node_color='skyblue', font_size=10)
    plt.show()
    
visualize_graph(cosine_similarity(llm_embeddings))

In [None]:
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn

# Define FastAPI app
app = FastAPI()

# Define input format
class AuthorshipRequest(BaseModel):
    text: str

@app.post("/verify_authorship/")
async def verify_authorship(request: AuthorshipRequest):
    # Process the text and get LLM embeddings
    embedding = get_llm_embeddings([request.text], model, tokenizer)

    # Make prediction using the trained GNN model
    model.eval()
    with torch.no_grad():
        out = gnn_model(torch.tensor(embedding, dtype=torch.float), data.edge_index)
    pred = out.argmax(dim=1).item()

    return {"predicted_label": pred}

# Run the FastAPI app (in production, use an ASGI server like Uvicorn)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

In [None]:
# Example: Testing the API
import requests

response = requests.post("http://127.0.0.1:8000/verify_authorship/", json={"text": "Sample text to verify authorship."})
print(response.json())
