# In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, BertConfig
import torchvision.models as models
import torchvision.transforms as transforms
from torch_geometric.nn import GCNConv, GATConv
import torch_geometric.data
import spacy
from torch.utils.data import Dataset, DataLoader
from rouge_score import rouge_scorer

# Device configuration: use a CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# BERT base vocab size; it sizes the summary decoder's embedding table and
# output projection.
vocab_size = 30522

# 1. Textual Factual Information Representation
class TextEncoder(nn.Module):
    """Encode a text with BERT and refine it over a token/entity graph.

    The text is embedded with BERT (under ``torch.no_grad()``), turned into a
    graph holding one node per BERT token plus one node per spaCy named
    entity, and passed through a GCN layer, a ReLU and a GAT layer.
    """

    def __init__(self, bert_model_name="bert-base-uncased"):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.bert = BertModel.from_pretrained(bert_model_name).to(device)
        self.nlp = spacy.load("en_core_web_sm")
        self.gcn = GCNConv(768, 768).to(device)
        # NOTE(review): GATConv concatenates its heads by default, so this
        # layer emits 8 * 768 features per node rather than 768 -- confirm
        # that is the intended output width.
        self.gat = GATConv(768, 768, heads=8, dropout=0.6).to(device)

    def encode_text(self, text):
        """Tokenize ``text`` and run it through BERT without tracking gradients.

        Returns:
            ``(last_hidden_state, input_ids, attention_mask)``, all on ``device``.
            Inputs are truncated to at most 512 tokens.
        """
        encodings = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        encodings = {k: v.to(device) for k, v in encodings.items()}
        with torch.no_grad():
            outputs = self.bert(**encodings)
        return outputs.last_hidden_state, encodings["input_ids"], encodings["attention_mask"]

    @staticmethod
    def _clamp_span(start, end, seq_len):
        """Clip a half-open ``[start, end)`` span to ``[0, seq_len]``."""
        lo = min(max(start, 0), seq_len)
        hi = min(max(end, lo), seq_len)
        return lo, hi

    def build_text_graph(self, embeddings, text):
        """Build a graph of token nodes plus one node per named entity.

        Nodes ``0 .. seq_len - 1`` carry the embeddings of the first sequence
        in ``embeddings``; each following node is the mean of the embeddings
        inside one entity span and is linked in both directions to every
        position of that span.

        Args:
            embeddings: BERT hidden states; only ``embeddings[0]`` is used.
            text: The raw string, parsed with spaCy to find entities.

        Returns:
            A ``torch_geometric.data.Data`` with ``x`` and ``edge_index``.
            ``edge_index`` always has shape ``(2, num_edges)``, including when
            the text contains no entities.
        """
        doc = self.nlp(text)
        entities = [(ent.text, ent.start, ent.end) for ent in doc.ents]
        seq_len = embeddings.size(1)
        num_nodes = len(entities) + seq_len
        x = torch.zeros(num_nodes, embeddings.size(-1), dtype=embeddings.dtype, device=embeddings.device)
        x[:seq_len] = embeddings[0]

        # NOTE(review): ent.start / ent.end are spaCy token offsets, while the
        # rows of `embeddings` follow the BERT tokenizer (special tokens and
        # sub-word pieces included), so spans are only approximately aligned.
        edges = []
        for node, (_, start, end) in enumerate(entities, seq_len):
            start, end = self._clamp_span(start, end, seq_len)
            if start == end:
                # The span lies beyond the (possibly truncated) sequence: keep
                # the node as an isolated zero vector instead of the NaN that
                # the mean of an empty slice would produce.
                continue
            x[node] = embeddings[0, start:end].mean(dim=0)
            for position in range(start, end):
                edges.append([node, position])
                edges.append([position, node])

        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long, device=embeddings.device).t().contiguous()
        else:
            # torch.tensor([]) would be 1-D; graph layers need a (2, 0) index.
            edge_index = torch.empty((2, 0), dtype=torch.long, device=embeddings.device)
        return torch_geometric.data.Data(x=x, edge_index=edge_index)

    def process_graph(self, graph_data):
        """Apply GCN -> ReLU -> GAT to the node features of ``graph_data``."""
        x = self.gcn(graph_data.x, graph_data.edge_index)
        x = torch.relu(x)
        x = self.gat(x, graph_data.edge_index)
        return x

    def forward(self, text):
        """Return ``(graph_embeddings, graph_data, input_ids, attention_mask)`` for ``text``."""
        embeddings, input_ids, attention_mask = self.encode_text(text)
        graph_data = self.build_text_graph(embeddings, text)
        graph_embeddings = self.process_graph(graph_data)
        return graph_embeddings, graph_data, input_ids, attention_mask

# 2. Image Knowledge Representation
class ImageProcessor(nn.Module):
    """Turn an image into a small graph of GRU-projected ResNet-101 features."""

    def __init__(self):
        super().__init__()
        self.cnn = models.resnet101(pretrained=True).to(device)
        # Drop the classification head so the network returns pooled features.
        self.cnn.fc = nn.Identity()
        resize = transforms.Resize((224, 224))
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # Pipeline for PIL images / ndarrays.
        self.transform = transforms.Compose([resize, transforms.ToTensor(), normalize])
        # Pipeline for inputs that already are tensors: ToTensor() rejects
        # torch.Tensor input, so it is left out here.
        self.tensor_transform = transforms.Compose([resize, normalize])
        self.gru = nn.GRU(input_size=2048, hidden_size=768, batch_first=True).to(device)

    def preprocess_image(self, image):
        """Return ``image`` as a batch on ``device``.

        * 3-D tensor: resized, normalized and given a batch dimension.
          Integer tensors are first scaled to ``[0, 1]`` as ``ToTensor`` would
          do; float tensors are assumed to be in ``[0, 1]`` already.
        * Any other tensor (e.g. an already batched 4-D tensor): only moved
          to ``device`` -- presumably preprocessed by the caller.
        * Non-tensor input (e.g. a PIL image): run through the full pipeline.
        """
        if isinstance(image, torch.Tensor):
            if image.dim() != 3:
                return image.to(device)
            if not image.is_floating_point():
                image = image.float() / 255.0
            return self.tensor_transform(image).unsqueeze(0).to(device)
        return self.transform(image).unsqueeze(0).to(device)

    def extract_features(self, image):
        """Run the CNN without tracking gradients and return its features."""
        # NOTE(review): no_grad() freezes the weights, but in train mode the
        # BatchNorm running statistics are still updated -- confirm whether
        # the backbone should be kept in eval mode.
        with torch.no_grad():
            features = self.cnn(image)
        return features

    def build_erg(self, features):
        """Project ``features`` through the GRU and wrap them as graph nodes.

        Each row of ``features`` is treated as a length-1 sequence; the GRU
        output becomes the node features. The only edge is a self-loop on
        node 0.
        """
        features = features.unsqueeze(1)
        gru_out, _ = self.gru(features)
        x = gru_out.squeeze(1)
        edge_index = torch.tensor([[0], [0]], dtype=torch.long).to(device)
        return torch_geometric.data.Data(x=x, edge_index=edge_index)

    def forward(self, image):
        """Preprocess ``image``, extract CNN features and return the image graph."""
        image = self.preprocess_image(image)
        features = self.extract_features(image)
        erg = self.build_erg(features)
        return erg

# 3. Multimodal Knowledge Graph Construction
class MKGConstructor(nn.Module):
    """Merge the text graph and the image graph and score the merged edges."""

    def __init__(self, entity_dim=768):
        super().__init__()
        self.entity_dim = entity_dim
        self.entity_projection = nn.Linear(entity_dim, entity_dim).to(device)
        self.relation_scorer = nn.Linear(2 * entity_dim, 1).to(device)

    def combine_graphs(self, text_graph, image_erg):
        """Stack text nodes and image nodes into a single graph.

        The text graph's edges are kept and one extra edge is added from the
        last text node to the first image node. The image graph's own edges
        are not carried over.
        """
        text_nodes = text_graph.x
        first_image_node = text_nodes.size(0)
        nodes = torch.cat((text_nodes, image_erg.x), dim=0)
        bridge = torch.tensor(
            [[first_image_node - 1], [first_image_node]], dtype=torch.long
        ).to(device)
        edges = torch.cat((text_graph.edge_index, bridge), dim=1)
        return torch_geometric.data.Data(x=nodes, edge_index=edges)

    def learn_representations(self, graph_data):
        """Project every node and score every edge.

        Returns:
            ``(entities, scores)``: the linearly projected node features and
            one score per edge, computed from the concatenated projections of
            the edge's two endpoints.
        """
        projected = self.entity_projection(graph_data.x)
        sources = graph_data.edge_index[0]
        destinations = graph_data.edge_index[1]
        pair_features = torch.cat((projected[sources], projected[destinations]), dim=1)
        return projected, self.relation_scorer(pair_features)

    def forward(self, text_graph, image_erg):
        """Return ``(node_embeddings, relation_scores)`` for the merged graph."""
        merged = self.combine_graphs(text_graph, image_erg)
        return self.learn_representations(merged)

# 4. Entity Memory Embedding
class EntityMemory(nn.Module):
    """A table of entity vectors that is blended with GAT-refined entities."""

    def __init__(self, entity_dim=768, memory_size=1000):
        super().__init__()
        # Create the tensor on the target device first: calling .to(device) on
        # an nn.Parameter returns a plain tensor when the device differs, and
        # the memory would then not be registered as a parameter.
        self.memory = nn.Parameter(torch.randn(memory_size, entity_dim, device=device))
        # concat=False averages the heads so the output stays entity_dim wide
        # and can be blended with the memory rows.
        self.gat = GATConv(entity_dim, entity_dim, heads=8, concat=False, dropout=0.6).to(device)
        self.update_rate = 0.1

    def update_memory(self, entities, graph_data):
        """Blend GAT-refined entity vectors into the leading memory rows.

        Args:
            entities: Node embeddings, one row per graph node.
            graph_data: Object whose ``edge_index`` connects those nodes.

        Returns:
            The updated memory rows. At most ``memory_size`` rows are updated
            and returned, even if more entities are given.
        """
        num_slots = min(entities.size(0), self.memory.size(0))
        # All entities stay in the graph so edge_index keeps pointing at valid
        # rows; the memory rows are appended after them.
        x = torch.cat([entities, self.memory[:num_slots]], dim=0)
        edge_index = graph_data.edge_index
        updated = self.gat(x, edge_index)[:num_slots]
        # The moving-average write is done outside autograd.
        with torch.no_grad():
            self.memory[:num_slots] = (1 - self.update_rate) * self.memory[:num_slots] + self.update_rate * updated
        return self.memory[:num_slots]

    def get_entity_embedding(self, entity_ids):
        """Return the memory rows selected by ``entity_ids``."""
        return self.memory[entity_ids]

    def forward(self, entities, graph_data):
        """Alias for :meth:`update_memory`."""
        return self.update_memory(entities, graph_data)

# 5. Modified BERT with MKG Embeddings
class CKGMBert(BertModel):
    """BERT encoder whose input embeddings can be enriched with entity memory.

    NOTE(review): the model is built from the ``bert-base-uncased``
    configuration only, so its weights are freshly initialized rather than
    loaded from the pretrained checkpoint -- confirm that is intended.
    """

    def __init__(self, entity_memory):
        config = BertConfig.from_pretrained("bert-base-uncased")
        super().__init__(config)
        self.entity_memory = entity_memory
        self.entity_embedding_layer = nn.Linear(768, 768).to(device)
        self.to(device)

    def forward(self, input_ids, attention_mask, entity_ids):
        """Encode ``input_ids`` and return the last hidden state.

        Args:
            input_ids: Token ids from the tokenizer.
            attention_mask: Attention mask matching ``input_ids``.
            entity_ids: Indices of the entity-memory rows to look up.
        """
        # Only the word-embedding lookup is done here. The embeddings module
        # adds position/token-type embeddings and LayerNorm itself when it
        # receives `inputs_embeds`, so passing its full output would apply
        # them twice.
        token_embeddings = self.embeddings.word_embeddings(input_ids)
        entity_embeddings = self.entity_embedding_layer(self.entity_memory.get_entity_embedding(entity_ids))
        # NOTE(review): this compares the number of entities with the first
        # dimension of the token embeddings (presumably the batch size), so
        # the entity vectors are only added when those happen to be equal --
        # verify the intended alignment.
        if entity_embeddings.size(0) == token_embeddings.size(0):
            # Out-of-place add keeps the embedding lookup result unmodified.
            token_embeddings = token_embeddings + entity_embeddings.unsqueeze(1).expand_as(token_embeddings)
        outputs = super().forward(inputs_embeds=token_embeddings, attention_mask=attention_mask)
        return outputs.last_hidden_state

# 6. Transformer Decoder for Summarization
class TransformerDecoder(nn.Module):
    """Transformer decoder that maps target tokens to vocabulary logits."""

    def __init__(self, hidden_dim=768, num_layers=6, num_heads=8):
        super().__init__()
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=num_heads)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers).to(device)
        self.fc_out = nn.Linear(hidden_dim, vocab_size).to(device)
        self.embedding = nn.Embedding(vocab_size, hidden_dim).to(device)

    def forward(self, encoder_output, target_ids):
        """Decode ``target_ids`` against ``encoder_output``.

        Args:
            encoder_output: Encoder states, batch-first
                ``(batch, src_len, hidden_dim)``.
            target_ids: Target token ids, ``(batch, tgt_len)``.

        Returns:
            Logits of shape ``(batch, tgt_len, vocab)``. Position ``t`` only
            attends to target positions ``<= t``.
        """
        # The decoder layers are sequence-first, hence the transposes.
        target_emb = self.embedding(target_ids).transpose(0, 1)
        tgt_len = target_emb.size(0)
        # Additive causal mask: -inf above the diagonal hides future tokens,
        # which would otherwise be visible during teacher forcing.
        causal_mask = torch.triu(
            torch.full((tgt_len, tgt_len), float("-inf"), dtype=target_emb.dtype, device=target_emb.device),
            diagonal=1,
        )
        # NOTE(review): no positional encoding is added to the target
        # embeddings -- confirm whether that is intended.
        decoder_output = self.decoder(target_emb, encoder_output.transpose(0, 1), tgt_mask=causal_mask)
        return self.fc_out(decoder_output.transpose(0, 1))

# 7. Full CKGM Model
class CKGM(nn.Module):
    """Full model: text/image encoders, graph fusion, entity memory, BERT and decoder."""

    def __init__(self):
        super().__init__()
        self.text_encoder = TextEncoder()
        self.image_processor = ImageProcessor()
        self.mkg_constructor = MKGConstructor()
        self.entity_memory = EntityMemory()
        self.bert = CKGMBert(self.entity_memory)
        self.decoder = TransformerDecoder()

    def forward(self, text, image, target_text):
        """Return summary logits for one ``(text, image)`` pair.

        Args:
            text: Source text.
            image: Image accompanying the text.
            target_text: Reference summary; it is tokenized and fed to the
                decoder (teacher forcing).

        Returns:
            The decoder's logits over the vocabulary for each target position.
        """
        # The graph embeddings returned first by the text encoder are not
        # used further here.
        _, text_graph, input_ids, attention_mask = self.text_encoder(text)
        image_erg = self.image_processor(image)
        mkg_embeddings, _ = self.mkg_constructor(text_graph, image_erg)
        # Limit the ids to the actual memory size instead of a hard-coded
        # constant, so a differently sized memory cannot be indexed out of range.
        memory_size = self.entity_memory.memory.size(0)
        entity_ids = torch.arange(min(mkg_embeddings.size(0), memory_size)).to(device)
        # Called for its side effect: it writes the blended vectors into the
        # memory that self.bert reads from.
        self.entity_memory(mkg_embeddings, text_graph)
        encoder_output = self.bert(input_ids, attention_mask, entity_ids)
        target_ids = self.text_encoder.tokenizer(target_text, return_tensors="pt", padding=True, truncation=True, max_length=512)["input_ids"].to(device)
        return self.decoder(encoder_output, target_ids)

# 8. Dataset
class SummaryDataset(Dataset):
    """Index-aligned collection of texts, images and target summaries."""

    def __init__(self, texts, images, targets):
        # The three sequences are stored as given; item i of each one
        # together forms sample i.
        self.texts, self.images, self.targets = texts, images, targets

    def __len__(self):
        """Number of samples, taken from the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the ``(text, image, target)`` triple stored at ``idx``."""
        text = self.texts[idx]
        image = self.images[idx]
        target = self.targets[idx]
        return text, image, target

# 9. Training and Evaluation
def compute_loss(output, target_ids, pad_token_id=0):
    """Next-token cross-entropy between decoder logits and target ids.

    The logits at position ``t`` are scored against the target token at
    position ``t + 1``. Scoring them against the token at ``t`` itself would
    reward copying the decoder input, because the same ids are fed to the
    decoder.

    Args:
        output: Logits of shape ``(batch, seq_len, vocab)``.
        target_ids: Token ids of shape ``(batch, seq_len)``.
        pad_token_id: Target id excluded from the loss.

    Returns:
        The mean cross-entropy over the non-padding target positions.
    """
    logits = output[:, :-1, :]
    labels = target_ids[:, 1:]
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
    # reshape() rather than view(): the shifted slices are not contiguous.
    # The vocabulary width is read from the logits themselves.
    return criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

def train(model, loader, optimizer, epochs=1):
    """Train ``model`` on ``loader`` and print the mean loss of every epoch.

    Only the first sample of each batch is used. An empty loader prints a
    ``nan`` loss instead of raising ``ZeroDivisionError``.

    Args:
        model: Model called as ``model(text, image, target_text)`` and
            exposing ``text_encoder.tokenizer``.
        loader: Iterable of ``(text, image, target_text)`` batches.
        optimizer: Optimizer over the model's parameters.
        epochs: Number of passes over ``loader``.
    """
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        # Counted while iterating so loaders without __len__ work too.
        num_batches = 0
        for text, image, target_text in loader:
            optimizer.zero_grad()
            output = model(text[0], image[0], target_text[0])
            # Re-tokenized here because the model returns only the logits.
            target_ids = model.text_encoder.tokenizer(target_text[0], return_tensors="pt", padding=True, truncation=True, max_length=512)["input_ids"].to(device)
            loss = compute_loss(output, target_ids)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")

def evaluate(model, loader):
    """Print the average ROUGE-1/2/L F-measures of ``model`` over ``loader``.

    Only the first sample of each batch is scored. The reference summary is
    also passed to the model as decoder input, so the scores are
    teacher-forced rather than those of free-running generation. An empty
    loader prints ``nan`` scores instead of raising ``ZeroDivisionError``.

    Args:
        model: Model called as ``model(text, image, target_text)`` and
            exposing ``text_encoder.tokenizer``.
        loader: Iterable of ``(text, image, target_text)`` batches.
    """
    model.eval()
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    total_scores = {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}
    # Counted while iterating so loaders without __len__ work too.
    num_batches = 0

    with torch.no_grad():
        for text, image, target_text in loader:
            output = model(text[0], image[0], target_text[0])
            # Greedy choice: the most likely token at every position.
            pred_ids = torch.argmax(output, dim=-1)
            pred_summary = model.text_encoder.tokenizer.decode(pred_ids[0], skip_special_tokens=True)
            scores = scorer.score(target_text[0], pred_summary)
            for key in total_scores:
                total_scores[key] += scores[key].fmeasure
            num_batches += 1

    if num_batches:
        avg_scores = {key: total / num_batches for key, total in total_scores.items()}
    else:
        avg_scores = {key: float("nan") for key in total_scores}
    print(f"ROUGE-1: {avg_scores['rouge1']:.4f}, ROUGE-2: {avg_scores['rouge2']:.4f}, ROUGE-L: {avg_scores['rougeL']:.4f}")

# 10. Main Execution
if __name__ == "__main__":
    # Two toy (text, image, reference summary) samples; the images are random
    # 3x224x224 tensors.
    samples = [
        ("This is a sample text about a dog and a cat.", torch.rand(3, 224, 224), "Dog and cat are friends."),
        ("Birds fly in the sky.", torch.rand(3, 224, 224), "Birds soar high."),
    ]
    texts = [sample[0] for sample in samples]
    images = [sample[1] for sample in samples]
    targets = [sample[2] for sample in samples]
    dataset = SummaryDataset(texts, images, targets)
    loader = DataLoader(dataset, shuffle=True, batch_size=1)

    # Build the model, run one training epoch, then report ROUGE on the same data.
    model = CKGM()
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    train(model, loader, optimizer, epochs=1)
    evaluate(model, loader)