# In [None]:
import pandas as pd
from transformers import AlbertTokenizer
import torch
from torch_geometric.data import Data
import networkx as nx
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader

# Function to create an edge index from text data
def create_edge_index_from_text(df):
    """Build a word co-occurrence edge index from the ``'text'`` column of *df*.

    Every whitespace-separated word becomes a node; node ids are assigned in
    first-seen order across all texts. Two words are linked when they occur
    in the same text. Each undirected edge is emitted once (smaller id
    first), no matter how many texts repeat the pair.

    Args:
        df: Any object whose ``df['text']`` yields strings (for example a
            pandas DataFrame).

    Returns:
        A contiguous ``torch.long`` tensor of shape ``(2, num_edges)``.
        When no text contains two words the shape is ``(2, 0)``.
    """
    node_ids = {}  # word -> integer node id
    edges = {}     # dict used as an insertion-ordered set of (low, high) pairs

    for text in df['text']:
        # Tensors cannot hold strings, so translate words to integer ids first.
        ids = [node_ids.setdefault(word, len(node_ids)) for word in text.split()]
        for pos, src in enumerate(ids):
            for dst in ids[pos + 1:]:
                # Canonical ordering makes (a, b) and (b, a) the same edge.
                edges[(src, dst) if src <= dst else (dst, src)] = None

    if not edges:
        # Keep the (2, E) layout even when there is nothing to connect.
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(list(edges), dtype=torch.long).t().contiguous()

# Load datasets: two small in-memory samples of Urdu headlines with binary labels
data1 = {
    'text': [
        "پاکستان میں لاک ڈاؤن عمران کی حکومت نے لگایا اور 2 کروڑ لوگوں کو بے روزگار کیا۔",
        "چین نے محفوظ فضائی سفرکے لیے کوویڈ-19سے بچاو کے اقدامات سخت کردئے",
    ],
    'label': [1, 0],
}

data2 = {
    'text': [
        "چین کی جانب سے پاکستان کی مدد کرنے سے انکار کا اعلان۔",
        "1 لاکھ 29 لوگوں نے روبوٹس کی جگہ لے لی کرونا وائرس کی وجہ سے۔",
    ],
    'label': [1, 0],
}

# Convert to DataFrame
df1 = pd.DataFrame(data1)
df2 = pd.DataFrame(data2)

# Combine datasets. ignore_index=True renumbers the rows 0..n-1; without it
# the combined frame keeps each source's labels (0, 1, 0, 1) and label-based
# lookups would return two rows.
datasets = pd.concat([df1, df2], ignore_index=True)

# Tokenization
# NOTE(review): 'albert-base-v2' is an English checkpoint; confirm its
# vocabulary gives usable coverage for the Urdu text above.
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
tokenized_data = tokenizer(datasets['text'].tolist(), padding=True, truncation=True, return_tensors='pt')

# Create graph data
edge_index = create_edge_index_from_text(datasets)
# return_tensors='pt' already yields a tensor, so convert its dtype directly
# instead of re-wrapping it in torch.tensor() (which warns about copy
# construction). One row of padded token ids per text.
x = tokenized_data['input_ids'].to(torch.float)  # Node features
y = torch.tensor(datasets['label'].tolist(), dtype=torch.long)  # Labels

# NOTE(review): x has one row per text, while edge_index is built from word
# co-occurrences and therefore indexes word nodes. Confirm which of the two
# is meant to be the node set before training on this object.
graph_data = Data(x=x, edge_index=edge_index, y=y)

# Create train and test split
graphs = [graph_data]
if len(graphs) > 1:
    train_data, test_data = train_test_split(graphs, test_size=0.2, random_state=42)
else:
    # train_test_split raises ValueError for a single sample (a 0.2 test
    # share would leave the train set empty). With only one graph available,
    # reuse it for both loaders. The reported accuracy is then measured on
    # the training graph and is only a smoke test, not a held-out score.
    train_data, test_data = graphs, graphs

# Create dataloaders
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
# Model Implementation with Multi-Head Attention and Prompt Learning
import torch.nn as nn
from torch_geometric.nn import GCNConv
from transformers import AlbertModel

class MultiHeadAttention(nn.Module):
    """Self-attention layer: the input serves as query, key and value."""

    def __init__(self, embed_dim, num_heads):
        """Create the wrapped ``nn.MultiheadAttention``.

        Args:
            embed_dim: Size of the last dimension of the inputs.
            num_heads: Number of attention heads.
        """
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        """Attend *x* to itself and return only the attended values.

        The attention weights produced by the wrapped layer are discarded.
        """
        attended, _weights = self.multihead_attn(query=x, key=x, value=x)
        return attended

class GCAWithPromptLearning(nn.Module):
    """Classifier combining GCN features, an ALBERT encoding and a prompt.

    Two ``GCNConv`` layers process the graph features, ALBERT encodes the
    token ids, its pooled output goes through self-attention, and the three
    parts are concatenated column-wise and classified by two linear layers.
    """

    def __init__(
        self,
        num_classes,
        gcn_input_dim,
        gcn_hidden_dim,
        albert_model_name,
        num_heads,
        attention_embed_dim,
        prompt_length,
    ):
        """Build the sub-modules.

        Args:
            num_classes: Number of output logits.
            gcn_input_dim: Width of the node features given to the first GCN layer.
            gcn_hidden_dim: Output width of both GCN layers.
            albert_model_name: Name or path passed to ``AlbertModel.from_pretrained``.
            num_heads: Number of attention heads.
            attention_embed_dim: Embedding size of the attention layer; must
                equal the ALBERT hidden size because the attention runs on
                ALBERT's pooled output.
            prompt_length: Number of prompt columns concatenated before ``fc1``.

        Raises:
            ValueError: If ``attention_embed_dim`` differs from the ALBERT
                hidden size.
        """
        super().__init__()
        self.gcn1 = GCNConv(gcn_input_dim, gcn_hidden_dim)
        self.gcn2 = GCNConv(gcn_hidden_dim, gcn_hidden_dim)
        self.albert = AlbertModel.from_pretrained(albert_model_name)
        albert_dim = self.albert.config.hidden_size
        if attention_embed_dim != albert_dim:
            # Fail early with a readable message instead of a shape error
            # deep inside the attention layer or fc1.
            raise ValueError(
                f"attention_embed_dim ({attention_embed_dim}) must equal the "
                f"ALBERT hidden size ({albert_dim})"
            )
        self.multihead_attn = MultiHeadAttention(attention_embed_dim, num_heads)
        self.prompt_length = prompt_length
        self.fc1 = nn.Linear(albert_dim + gcn_hidden_dim + prompt_length, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, input_ids, attention_mask, x, edge_index, prompt):
        """Compute class logits.

        Args:
            input_ids: Token ids for ALBERT.
            attention_mask: Attention mask for ALBERT.
            x: Node features for the GCN layers. After the GCN it is
                concatenated row-by-row with the text features, so it must
                have as many rows as ``input_ids``.
            edge_index: Graph connectivity for the GCN layers.
            prompt: Tensor with ``prompt_length`` columns and one row per
                input row.

        Returns:
            Logits with ``num_classes`` columns.

        Raises:
            ValueError: If ``prompt`` does not have ``prompt_length`` columns.
        """
        # GCN forward
        x = torch.relu(self.gcn1(x, edge_index))
        x = torch.relu(self.gcn2(x, edge_index))

        # ALBERT forward; index 1 of the model output is the pooled output
        albert_outputs = self.albert(input_ids=input_ids, attention_mask=attention_mask)
        albert_cls_output = albert_outputs[1]

        # Multi-head attention over a length-1 sequence: add the sequence
        # dimension in front, attend, then drop it again.
        attn_output = self.multihead_attn(albert_cls_output.unsqueeze(0))
        attn_output = attn_output.squeeze(0)

        if prompt.size(1) != self.prompt_length:
            raise ValueError(
                f"prompt has {prompt.size(1)} columns, expected {self.prompt_length}"
            )
        # Prompts may arrive as integer token ids; cast explicitly so the
        # concatenation does not depend on implicit dtype promotion.
        prompt = prompt.to(device=attn_output.device, dtype=attn_output.dtype)

        # Concatenate ALBERT output with GCN output and prompt
        combined = torch.cat((attn_output, x, prompt), dim=1)
        combined = torch.relu(self.fc1(combined))
        logits = self.fc2(combined)

        return logits
# Training and Evaluation Code
import torch.optim as optim

# Create prompt tokens
def create_prompt_tokens(prompt_length, tokenizer):
    """Return a prompt made of ``prompt_length`` mask-token ids.

    The ids are built directly from ``tokenizer.mask_token_id`` rather than
    by tokenizing a string of "[MASK]" tokens: a tokenizer call adds special
    tokens around the text by default, so its output would be wider than
    ``prompt_length``.

    Args:
        prompt_length: Number of mask tokens in the prompt.
        tokenizer: Object exposing a ``mask_token_id`` attribute.

    Returns:
        A ``torch.long`` tensor of shape ``(1, prompt_length)``.

    Raises:
        ValueError: If the tokenizer has no mask token.
    """
    mask_id = tokenizer.mask_token_id
    if mask_id is None:
        raise ValueError("tokenizer has no mask token to build the prompt from")
    return torch.full((1, prompt_length), mask_id, dtype=torch.long)

# Initialize model, loss function, and optimizer
prompt_length = 5  # Example prompt length
prompt_tokens = create_prompt_tokens(prompt_length, tokenizer)
model = GCAWithPromptLearning(
    num_classes=2,
    # The GCN consumes graph_data.x, whose rows are padded token-id
    # sequences; the first layer must accept that width, which depends on
    # the longest text and is not 768.
    gcn_input_dim=graph_data.num_node_features,
    gcn_hidden_dim=128,
    albert_model_name='albert-base-v2',
    num_heads=8,
    attention_embed_dim=768,  # must match the ALBERT hidden size
    prompt_length=prompt_length,
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
model.train()
for epoch in range(10):  # Number of epochs
    for batch in train_loader:
        optimizer.zero_grad()
        x = batch.x  # Float node features consumed by the GCN layers
        # batch.x stores the token ids as floats; the ALBERT embedding lookup
        # needs integer indices, so cast back before feeding the encoder.
        input_ids = x.long()
        attention_mask = (input_ids != tokenizer.pad_token_id).float()  # 1 for real tokens, 0 for padding
        edge_index = batch.edge_index
        labels = batch.y
        prompt = prompt_tokens.repeat(input_ids.size(0), 1)  # One prompt row per input row

        # Forward pass
        outputs = model(input_ids, attention_mask, x, edge_index, prompt)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

# Evaluation
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        x = batch.x  # Float node features consumed by the GCN layers
        # batch.x stores the token ids as floats; the ALBERT embedding lookup
        # needs integer indices.
        input_ids = x.long()
        attention_mask = (input_ids != tokenizer.pad_token_id).float()
        edge_index = batch.edge_index
        labels = batch.y
        prompt = prompt_tokens.repeat(input_ids.size(0), 1)

        outputs = model(input_ids, attention_mask, x, edge_index, prompt)
        predicted = outputs.argmax(dim=1)  # Index of the highest logit per row
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Percentage of correct predictions (scale by 100); report 0.0 instead of
# dividing by zero when the test loader yielded nothing.
accuracy = 100 * correct / total if total else 0.0
print(f'Accuracy: {accuracy}%')