In [4]:
import math
from collections import Counter, OrderedDict

import torch
import torch.nn as nn
from torchtext.vocab import GloVe, Vocab
from torchtext.vocab import vocab as build_vocab

from conceptnet import get_conceptnet_facts_for_image

## Common Sense Knowledge Retrieval (ConceptNet)

In this section, we use helper functions from the local `conceptnet` module to fetch and process **commonsense knowledge** from **ConceptNet**.

Given a set of object concepts (e.g. extracted from MS COCO annotations or an object detector), these functions:
- query ConceptNet for relevant relations,
- filter and rank the retrieved facts,
- and format them into a compact representation suitable for use as model input.

The resulting knowledge will later be combined with visual features to augment image caption generation with external commonsense information.

In [3]:
# get_conceptnet_facts_for_image(["dog", "park", "ball"])

## Load GloVe and Build a Vocab

Here we:

1. Load pretrained GloVe vectors using `torchtext.vocab.GloVe`.
2. Build a `Vocab` object from `glove.stoi` (string-to-index mapping).
3. Add special tokens:
   - `<unk>` for unknown words
   - `<pad>` for padding

The important idea:

> **GloVe is just a big lookup table.**

Each word has a fixed vector, and we wrap that into a PyTorch `nn.Embedding` layer later.

In [None]:
# Load pretrained GloVe vectors (the "6B" release, 100-dimensional).
glove = GloVe(name="6B", dim=100)

# Special tokens take the lowest indices: <unk> = 0, <pad> = 1.
specials = ["<unk>", "<pad>"]

# Token order = specials first, then GloVe's own row order (glove.itos).
# torchtext's `vocab()` factory keeps the insertion order of the OrderedDict, so every
# GloVe word's vocab index equals its row in `glove.vectors` shifted by len(specials).
# NOTE: the Counter-based `Vocab(counter, specials=...)` constructor is the legacy
# (pre-0.10) torchtext API; the current `Vocab` rejects those arguments. Even the legacy
# version re-sorted equal-frequency tokens alphabetically, which broke that alignment.
# `fromkeys` also collapses duplicates, keeping each token's first position.
token_order = OrderedDict.fromkeys([*specials, *glove.itos], 1)
my_vocab = build_vocab(token_order, min_freq=1)

# Map out-of-vocabulary tokens to <unk> instead of raising on lookup.
my_vocab.set_default_index(my_vocab["<unk>"])

vocab_size = len(my_vocab)
embedding_dim = glove.dim  # same as glove.vectors.size(1)

print("Vocab size:", vocab_size)
print("Embedding dim:", embedding_dim)

## Create an Embedding Layer from Pretrained Vectors

We now wrap the GloVe tensor into an `nn.Embedding` using
`nn.Embedding.from_pretrained`.

- `glove.vectors` is a tensor of shape `[vocab_size_without_specials, embedding_dim]`.
- We need to **extend** it to include our `<unk>` and `<pad>` rows.
- `freeze=True` means we do **not** train the embeddings; they stay as GloVe.

This layer is still just a **lookup table**: it maps token IDs → word vectors.

In [None]:
# Build the weight matrix by looking every vocab token up in GloVe, so that row i is
# always the vector of the token with vocab index i, regardless of how the vocab was
# ordered. Tokens missing from GloVe (e.g. the special tokens) get a zero vector.
vocab_tokens = my_vocab.get_itos() if hasattr(my_vocab, "get_itos") else my_vocab.itos
glove_rows = torch.tensor([glove.stoi.get(tok, -1) for tok in vocab_tokens])
in_glove = glove_rows >= 0

embedding_weights = torch.zeros(len(vocab_tokens), embedding_dim)
embedding_weights[in_glove] = glove.vectors[glove_rows[in_glove]]
assert embedding_weights.shape == (vocab_size, embedding_dim)

# padding_idx stops the <pad> row (zeros here) from being updated if the layer is
# ever unfrozen.
pad_idx = my_vocab["<pad>"]
embedding_layer = nn.Embedding.from_pretrained(
    embedding_weights,
    freeze=True,  # set to False if you want to fine-tune the embeddings
    padding_idx=pad_idx,
)

## Positional Encoding

Self-attention by itself is **position-agnostic**: it has no notion of which token
came first, so reordering the input tokens simply reorders the outputs.

We add a standard sinusoidal positional encoding (as in the original Transformer paper):

- Same `d_model` as the embeddings
- Precomputed for a maximum sequence length
- Added to the embeddings before passing them to the Transformer

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from "Attention Is All You Need" (Vaswani et al., 2017).

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    The table is precomputed for ``max_len`` positions and registered as a buffer
    (not a trainable parameter). Odd ``d_model`` values are supported.

    Args:
        d_model: embedding size of the inputs this encoding is added to.
        dropout: dropout probability applied after adding the encoding.
        max_len: maximum sequence length supported.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Even columns get sin, odd columns get cos, as in the paper.
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than frequencies in div_term.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # shape: [1, max_len, d_model]
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to input, then apply dropout.

        x shape: [batch_size, seq_len, d_model]
        Raises ValueError if seq_len exceeds the precomputed max_len.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len={max_len} of the positional encoding"
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)

## Model Architecture: Dual-Encoder + Single Decoder (Image + Commonsense)

We use a **dual-encoder, single-decoder** architecture to generate captions conditioned on:
1) the **image content**, and  
2) a compact set of **commonsense facts** retrieved from ConceptNet (derived from the objects in the image).

### High-level idea
- The **vision encoder** converts the image into a sequence of **image tokens**.
- The **knowledge encoder** converts the retrieved commonsense facts into a sequence of **knowledge tokens**.
- A **caption decoder** generates the caption autoregressively while attending to **both** token sequences via cross-attention.

---

### Components

**1) Vision Encoder (ViT)**
- Input: an RGB image (e.g., `256×256×3`)
- The image is split into fixed-size patches (e.g., `16×16`), projected into an embedding space, and processed by a Transformer encoder.
- Output: `V_img ∈ R^{N_img × d}` (a sequence of image token embeddings)

**2) Knowledge Encoder (Text Transformer)**
- Input: a list of ConceptNet facts converted into short natural-language sentences (e.g., `"A frisbee is used for play."`)
- Facts are tokenized and embedded, then processed by a Transformer encoder.
- Output: `V_kg ∈ R^{N_kg × d}` (a sequence of knowledge token embeddings)

**3) Caption Decoder (Transformer)**
- Input: caption tokens generated so far (teacher forcing during training)
- The decoder performs:
  - self-attention over caption tokens
  - cross-attention to `V_img` (visual grounding)
  - cross-attention to `V_kg` (commonsense augmentation)
- Output: next-token distribution over the caption vocabulary

---

### Cross-attention fusion (two memories)
At each decoder layer, we attend to two separate memories:
- **Image memory:** keys/values from `V_img`
- **Knowledge memory:** keys/values from `V_kg`

A common implementation is **sequential cross-attention**:
1. cross-attend to image tokens (grounding)
2. cross-attend to knowledge tokens (augmentation)

This keeps image evidence primary while allowing knowledge to refine the generation.

--------
## ViT Encoder

In [None]:
class ViTImageEncoder(nn.Module):
    """
    ViT-style image encoder (patch embeddings + pre-norm Transformer encoder).

    Key detail: we use a Conv2d as a *patch embedding* layer.
    - kernel_size = patch_size and stride = patch_size
    - so each convolution "step" looks at one non-overlapping patch (e.g. 16x16)
    - and produces a d_model-dimensional vector for that patch
    This is equivalent to: flatten(patch) -> Linear(patch_dim -> d_model),
    but implemented efficiently as a convolution.

    The encoder layers are pre-norm (norm_first=True), so the residual stream is never
    normalized inside the stack. A final LayerNorm is therefore applied to the output
    (as in ViT and PyTorch's nn.Transformer). This keeps the image memory that the
    decoder cross-attends to normalized.

    Args:
        image_size: height and width of the square input image, in pixels.
        patch_size: side length of each square patch; must divide image_size.
        in_channels: number of input channels (3 for RGB).
        d_model: token embedding size.
        n_layers: number of Transformer encoder layers.
        n_heads: attention heads per layer (must divide d_model).
        mlp_ratio: feed-forward hidden size as a multiple of d_model.
        dropout: dropout after the positional embeddings and inside each layer.
        use_cls_token: prepend a learned CLS token to the patch sequence.
    """
    def __init__(
        self,
        image_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 3,
        d_model: int = 768,
        n_layers: int = 6,
        n_heads: int = 12,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        use_cls_token: bool = True,
    ):
        super().__init__()
        assert image_size % patch_size == 0, "image_size must be divisible by patch_size"

        self.image_size = image_size
        self.patch_size = patch_size
        self.d_model = d_model
        self.use_cls_token = use_cls_token

        grid = image_size // patch_size
        self.num_patches = grid * grid

        # Patch embedding via Conv2d:
        #   input  (B, in_channels, image_size, image_size)
        #   output (B, d_model, grid, grid), e.g. grid = 256 / 16 = 16
        # Each of the grid x grid output locations is one patch's d_model-dim token.
        self.patch_embed = nn.Conv2d(
            in_channels=in_channels,
            out_channels=d_model,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
        )

        seq_len = self.num_patches + (1 if use_cls_token else 0)

        if use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        else:
            self.cls_token = None

        # Learned absolute positional embedding for each token in the sequence
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, d_model))
        self.pos_drop = nn.Dropout(dropout)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=int(d_model * mlp_ratio),
            dropout=dropout,
            batch_first=True,   # (B, T, C)
            norm_first=True,    # pre-norm
            activation="gelu",
        )
        # Final LayerNorm: required with pre-norm layers to normalize the stack's output.
        self.encoder = nn.TransformerEncoder(
            enc_layer, num_layers=n_layers, norm=nn.LayerNorm(d_model)
        )

        self._init_params()

    def _init_params(self):
        """Initialize positional/CLS embeddings (N(0, 0.02)) and the patch projection."""
        nn.init.normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=0.02)
        nn.init.kaiming_normal_(self.patch_embed.weight, mode="fan_out", nonlinearity="relu")
        if self.patch_embed.bias is not None:
            nn.init.zeros_(self.patch_embed.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, in_channels, H, W) with H == W == image_size
        returns: (B, T, d_model) where T = num_patches (+1 if CLS)
        """
        B, C, H, W = x.shape
        assert H == self.image_size and W == self.image_size, \
            f"Expected {self.image_size}x{self.image_size}, got {H}x{W}"

        # 1) Patch embedding -> (B, d_model, grid, grid)
        x = self.patch_embed(x)

        # 2) Flatten the spatial grid into a token sequence -> (B, num_patches, d_model)
        x = x.flatten(2).transpose(1, 2)

        # 3) Optional CLS token -> (B, 1 + num_patches, d_model)
        if self.use_cls_token:
            cls = self.cls_token.expand(B, -1, -1)
            x = torch.cat([cls, x], dim=1)

        # 4) Add learned positional embeddings (element-wise) and apply dropout
        x = x + self.pos_embed[:, :x.size(1), :]
        x = self.pos_drop(x)

        # 5) Transformer encoder (+ final LayerNorm) -> (B, T, d_model)
        return self.encoder(x)

## Knowledge Encoder

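Encodes the retrieved ConceptNet facts (already tokenized) into a sequence of knowledge tokens `V_kg`, which the caption decoder cross-attends to.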

In [None]:
class KnowledgeTextEncoder(nn.Module):
    """
    Encodes commonsense facts (already tokenized) into a sequence of embeddings.

    Pipeline: token embedding -> (optional) Linear projection to d_model ->
    sinusoidal positions -> pre-norm Transformer encoder -> final LayerNorm.

    The projection exists because the supplied embedding layer (e.g. 100-d GloVe) need
    not match d_model (e.g. 768). When the sizes already match it is an identity.

    Inputs:
      - input_ids: (B, T_kg)
      - attention_mask: optional (B, T_kg) with 1 for keep, 0 for pad.
        When omitted, every position is attended to.

    Output:
      - V_kg: (B, T_kg, d_model)
    """
    def __init__(
        self,
        embedding_layer: nn.Embedding,
        d_model: int = 768,
        n_layers: int = 4,
        n_heads: int = 12,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        pad_token_id: int = 0,
    ):
        super().__init__()
        # Stored for callers; padding is taken from `attention_mask` in forward.
        self.pad_token_id = pad_token_id
        self.d_model = d_model

        self.token_embed = embedding_layer
        self.pos_encoder = PositionalEncoding(d_model=d_model, dropout=dropout)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=int(d_model * mlp_ratio),
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,  # pre-norm
        )
        # Pre-norm layers leave the output unnormalized, so add a final LayerNorm.
        self.encoder = nn.TransformerEncoder(
            enc_layer, num_layers=n_layers, norm=nn.LayerNorm(d_model)
        )

        # Map embedding_dim -> d_model only when they differ.
        embed_dim = embedding_layer.embedding_dim
        self.embed_proj = nn.Identity() if embed_dim == d_model else nn.Linear(embed_dim, d_model)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        """
        input_ids: (B, T) token ids.
        attention_mask: optional (B, T), 1 = keep, 0 = pad.
        returns: (B, T, d_model)
        """
        x = self.embed_proj(self.token_embed(input_ids))  # (B, T, d_model)

        # Add positional encodings (+ dropout)
        x = self.pos_encoder(x)

        key_padding = None
        if attention_mask is not None:
            key_padding = attention_mask == 0  # True = pad (MultiheadAttention convention)
            # A row with no real tokens (e.g. an image with no retrieved facts) would be
            # fully masked, and the attention softmax would return NaN. Leave such rows
            # unmasked so they attend to their padding positions instead.
            key_padding = key_padding & ~key_padding.all(dim=1, keepdim=True)

        return self.encoder(x, src_key_padding_mask=key_padding)


## Dual-memory decoder (self-attn + cross-attn to image + cross-attn to KG)

A single decoder layer. It is custom because `nn.TransformerDecoderLayer` supports only one memory (encoder output), while we cross-attend to two: image tokens and knowledge tokens.

In [None]:
class DualCrossAttnDecoderLayer(nn.Module):
    """
    One decoder layer with:
      1) causal self-attention over caption tokens
      2) cross-attention to image tokens
      3) cross-attention to knowledge tokens
      4) FFN

    Pre-norm design for stability.
    """
    def __init__(
        self,
        d_model: int = 768,
        n_heads: int = 12,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn_img = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn_kg = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)

        self.drop = nn.Dropout(dropout)

        hidden = int(d_model * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
        )

    def forward(
        self,
        x: torch.Tensor,                         # (B, T_txt, d)
        img_mem: torch.Tensor,                   # (B, T_img, d)
        kg_mem: torch.Tensor,                    # (B, T_kg, d)
        *,
        txt_key_padding: torch.Tensor = None,    # (B, T_txt) True=pad
        img_key_padding: torch.Tensor = None,    # (B, T_img) True=pad (often None)
        kg_key_padding: torch.Tensor = None,     # (B, T_kg) True=pad
        causal_mask: torch.Tensor = None,        # (T_txt, T_txt) True=block
    ) -> torch.Tensor:
        # 1) Masked causal self-attention -> only look at previous tokens
        h = self.norm1(x)
        attn, _ = self.self_attn(
            h, h, h,
            attn_mask=causal_mask,
            key_padding_mask=txt_key_padding,
            need_weights=False,
        )
        x = x + self.drop(attn)

        # 2) cross-attention to image memory
        h = self.norm2(x)
        attn, _ = self.cross_attn_img(
            h, img_mem, img_mem,
            key_padding_mask=img_key_padding,
            need_weights=False,
        )

        # Add & Norm
        x = x + self.drop(attn)
        h = self.norm3(x)

        # 3) cross-attention to knowledge memory
        attn, _ = self.cross_attn_kg(
            h, kg_mem, kg_mem,
            key_padding_mask=kg_key_padding,
            need_weights=False,
        )

        # Add & Norm
        x = x + self.drop(attn)
        h = self.norm4(x)

        # 4) FFN
        x = x + self.drop(self.ffn(h))
        return x


In [None]:
class DualEncoderCaptionDecoder(nn.Module):
    """
    Autoregressive caption decoder that attends to two encoder memories (image + knowledge).

    Pipeline: token embedding -> (optional) Linear projection to d_model ->
    sinusoidal positions (+ dropout) -> n_layers x DualCrossAttnDecoderLayer ->
    final LayerNorm -> vocabulary logits.

    Args:
        vocab_size: number of output logits per position.
        embedding_layer: token embedding. Its embedding_dim may differ from d_model
            (e.g. 100-d GloVe), in which case a learned Linear projection is added.
        d_model, n_layers, n_heads, mlp_ratio, dropout: decoder hyper-parameters.
        pad_token_id: stored for callers; not used inside this module.
    """
    def __init__(
        self,
        vocab_size: int,
        embedding_layer: nn.Embedding,
        d_model: int = 768,
        n_layers: int = 6,
        n_heads: int = 12,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        pad_token_id: int = 0,
    ):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.d_model = d_model

        self.token_embed = embedding_layer
        # PositionalEncoding already applies dropout, so no second dropout is needed.
        self.pos_encoder = PositionalEncoding(d_model=d_model, dropout=dropout)

        self.layers = nn.ModuleList([
            DualCrossAttnDecoderLayer(d_model, n_heads, mlp_ratio, dropout)
            for _ in range(n_layers)
        ])

        self.final_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Map embedding_dim -> d_model only when they differ.
        embed_dim = embedding_layer.embedding_dim
        self.embed_proj = nn.Identity() if embed_dim == d_model else nn.Linear(embed_dim, d_model)

    @staticmethod
    def _causal_mask(T: int, device) -> torch.Tensor:
        # MultiheadAttention expects True where positions are NOT allowed to attend
        return torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)

    @staticmethod
    def _memory_padding(mask: torch.Tensor = None) -> torch.Tensor:
        """Convert a (B, T) 1=keep/0=pad mask into MultiheadAttention's True=pad form.

        Rows with no real tokens are left unmasked: a fully masked key row makes the
        attention softmax return NaN, which would poison the whole caption.
        """
        if mask is None:
            return None
        pad = mask == 0
        return pad & ~pad.all(dim=1, keepdim=True)

    def forward(
        self,
        input_ids: torch.Tensor,            # (B, T_txt)
        attention_mask: torch.Tensor,       # (B, T_txt) 1=keep,0=pad (None = no padding)
        img_mem: torch.Tensor,              # (B, T_img, d)
        kg_mem: torch.Tensor,               # (B, T_kg, d)
        kg_attention_mask: torch.Tensor = None,  # (B, T_kg) 1=keep,0=pad
        img_attention_mask: torch.Tensor = None, # (B, T_img) (often None/all ones)
    ) -> torch.Tensor:
        """Return next-token logits of shape (B, T_txt, vocab_size)."""
        T = input_ids.size(1)
        x = self.embed_proj(self.token_embed(input_ids))  # (B, T, d_model)
        x = self.pos_encoder(x)  # adds positions and applies dropout once

        causal = self._causal_mask(T, input_ids.device)
        txt_key_padding = None if attention_mask is None else (attention_mask == 0)
        kg_key_padding = self._memory_padding(kg_attention_mask)
        img_key_padding = self._memory_padding(img_attention_mask)

        for layer in self.layers:
            x = layer(
                x,
                img_mem=img_mem,
                kg_mem=kg_mem,
                txt_key_padding=txt_key_padding,
                img_key_padding=img_key_padding,
                kg_key_padding=kg_key_padding,
                causal_mask=causal,
            )

        x = self.final_norm(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        return logits


## Full transformer model

In [None]:
class KnowledgeEnhancedTransformer(nn.Module):
    """
    End-to-end captioner: an image encoder and a knowledge encoder produce two
    memories, and the caption decoder cross-attends to both while predicting tokens.
    """
    def __init__(
        self,
        vision_encoder: ViTImageEncoder,
        knowledge_encoder: KnowledgeTextEncoder,
        caption_decoder: DualEncoderCaptionDecoder,
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.knowledge_encoder = knowledge_encoder
        self.caption_decoder = caption_decoder

    def forward(
        self,
        pixel_values,        # Raw image pixels (B, 3, H, W)
        knowledge_ids,       # ConceptNet fact IDs (B, T_kg)
        caption_ids,         # Target caption IDs for teacher forcing (B, T_txt)
        caption_mask,        # Mask for captions
        knowledge_mask=None  # Mask for knowledge padding
    ) -> torch.Tensor:
        """Encode image and knowledge, then return the decoder's caption logits."""
        image_tokens = self.vision_encoder(pixel_values)          # (B, T_img, d_model)
        # NOTE(review): knowledge_mask is only used by the decoder's cross-attention;
        # the knowledge encoder's self-attention here runs without a padding mask.
        knowledge_tokens = self.knowledge_encoder(knowledge_ids)  # (B, T_kg, d_model)

        return self.caption_decoder(
            input_ids=caption_ids,
            attention_mask=caption_mask,
            img_mem=image_tokens,
            kg_mem=knowledge_tokens,
            kg_attention_mask=knowledge_mask,
        )