In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

#######################################
# Basic Multi-Head Attention Module
#######################################
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected by separate linear layers, split into
    ``num_heads`` heads of size ``embed_dim // num_heads``, attended independently,
    then concatenated and passed through an output projection.

    Args:
        embed_dim: width of the inputs and of the output; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Validate with an exception rather than ``assert`` (asserts vanish under -O).
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear projections for Q, K, and V, plus the output projection.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, x, batch_size):
        """Reshape (B, seq_len, embed_dim) -> (B, num_heads, seq_len, head_dim)."""
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (B, seq_len_q, embed_dim).
            key: (B, seq_len_k, embed_dim).
            value: (B, seq_len_k, embed_dim).
            mask: optional tensor broadcastable to (B, num_heads, seq_len_q, seq_len_k);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (B, seq_len_q, embed_dim).
        """
        batch_size = query.size(0)

        q = self._split_heads(self.q_proj(query), batch_size)
        k = self._split_heads(self.k_proj(key), batch_size)
        v = self._split_heads(self.v_proj(value), batch_size)

        # Scaled dot-product scores: (B, num_heads, seq_len_q, seq_len_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        if mask is not None:
            # A query row whose keys are ALL blocked yields softmax(-inf, ..., -inf) = NaN.
            # Give such rows zero attention weight so NaN cannot leak into outputs/loss.
            attn = attn.masked_fill(blocked.all(dim=-1, keepdim=True), 0.0)

        # Weighted sum of values, then merge heads back to (B, seq_len_q, embed_dim).
        context = torch.matmul(attn, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        return self.out_proj(context)

#######################################
# Transformer Encoder Layer for Text Encoder
#######################################
class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder block: self-attention, then a position-wise FFN.

    Each sub-layer output passes through dropout, is added to its input (residual),
    and the sum is layer-normalised.

    Args:
        embed_dim: model width.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        dropout: dropout probability applied to both sub-layer outputs.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run self-attention (query = key = value = ``x``) followed by the FFN."""
        attended = self.self_attn(x, x, x, mask)
        hidden = self.norm1(x + self.dropout(attended))
        return self.norm2(hidden + self.dropout(self.ffn(hidden)))

#######################################
# Text Encoder Module
#######################################
class TextEncoder(nn.Module):
    """Transformer encoder over token ids with learnable positional embeddings.

    Args:
        vocab_size: number of rows in the token embedding table.
        embed_dim: model width.
        num_layers: number of ``TransformerEncoderLayer`` blocks.
        num_heads: attention heads per layer.
        ff_dim: hidden width of each layer's feed-forward network.
        max_seq_len: longest input the positional-embedding table supports.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, ff_dim, max_seq_len=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # Learnable positional embeddings (zero-initialised), one per position.
        self.pos_embed = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, text, mask=None):
        """Encode token ids ``text`` (B, seq_len) into (B, seq_len, embed_dim).

        Args:
            text: integer token ids of shape (B, seq_len), seq_len <= max_seq_len.
            mask: optional attention mask passed to every layer (positions where
                ``mask == 0`` are not attended to). ``None`` means no masking,
                matching the previous behaviour.

        Raises:
            ValueError: if seq_len exceeds the positional-embedding table; without
                this check the addition below fails with an opaque shape error.
        """
        seq_len = text.size(1)
        max_len = self.pos_embed.size(1)
        if seq_len > max_len:
            raise ValueError(f"text length {seq_len} exceeds max_seq_len={max_len}")
        x = self.embed(text) + self.pos_embed[:, :seq_len, :]
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

#######################################
# Image Encoder Module
#######################################
class ImageEncoder(nn.Module):
    """Turn an RGB image into a sequence of patch embeddings with a tiny CNN.

    A 16x16 convolution with stride 16 cuts the image into non-overlapping 16x16
    patches (a 256x256 input gives a 16x16 grid, i.e. 256 patches) with 64 channels.
    A ReLU follows, then a 1x1 convolution lifts the channels to ``embed_dim``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        patchify = nn.Conv2d(3, 64, kernel_size=16, stride=16)
        lift = nn.Conv2d(64, embed_dim, kernel_size=1)
        self.conv = nn.Sequential(patchify, nn.ReLU(), lift)
        # Collapse the spatial grid (dims 2 and 3) into a single patch axis.
        self.flatten = nn.Flatten(start_dim=2)

    def forward(self, x):
        """Map images (B, 3, H, W) to patch tokens (B, num_patches, embed_dim)."""
        feature_map = self.conv(x)           # (B, embed_dim, H // 16, W // 16)
        patch_grid = self.flatten(feature_map)  # (B, embed_dim, num_patches)
        return patch_grid.transpose(1, 2)

#######################################
# Fused Knowledge Module
#######################################
class FusedKnowledge(nn.Module):
    """Fuse text and image features with a single cross-attention step.

    Text tokens act as queries; image patches supply keys and values. The attended
    result goes through dropout, is added back to the text representation
    (residual), and the sum is layer-normalised.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, text_repr, image_repr, mask=None):
        """Return fused features with the same shape as ``text_repr``.

        text_repr: (batch, text_seq_len, embed_dim) — queries.
        image_repr: (batch, image_patches, embed_dim) — keys and values.
        """
        attended = self.cross_attn(text_repr, image_repr, image_repr, mask)
        residual = text_repr + self.dropout(attended)
        return self.norm(residual)

#######################################
# Transformer Decoder for Captioning
#######################################
class Decoder(nn.Module):
    """Transformer decoder mapping caption tokens plus a memory sequence to logits.

    Args:
        vocab_size: number of token ids (rows of the embedding, columns of the logits).
        embed_dim: model width.
        num_layers: number of ``TransformerDecoderLayer`` blocks.
        num_heads: attention heads per layer.
        ff_dim: hidden width of each layer's feed-forward network.
        max_seq_len: longest target the positional-embedding table supports.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, ff_dim, max_seq_len=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # Learnable positional embeddings (zero-initialised), one per position.
        self.pos_embed = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(embed_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Projection from hidden states to vocabulary logits. NOTE: this layer has its
        # own weight and bias and is NOT tied to ``self.embed``.
        self.output_proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Decode ``tgt`` (batch, tgt_seq_len) against ``memory``.

        Args:
            tgt: integer token ids, (batch, tgt_seq_len) with tgt_seq_len <= max_seq_len.
            memory: encoder/fusion output used as keys/values for cross-attention.
            tgt_mask: optional mask for decoder self-attention (0 = blocked).
            memory_mask: optional mask for cross-attention (0 = blocked).

        Returns:
            Logits of shape (batch, tgt_seq_len, vocab_size).

        Raises:
            ValueError: if tgt_seq_len exceeds the positional-embedding table.
        """
        seq_len = tgt.size(1)
        max_len = self.pos_embed.size(1)
        if seq_len > max_len:
            raise ValueError(f"target length {seq_len} exceeds max_seq_len={max_len}")
        x = self.embed(tgt) + self.pos_embed[:, :seq_len, :]
        for layer in self.layers:
            x = layer(x, memory, tgt_mask, memory_mask)
        return self.output_proj(self.norm(x))

#######################################
# Transformer Decoder Layer (for decoder)
#######################################
class TransformerDecoderLayer(nn.Module):
    """Post-norm Transformer decoder block.

    Three sub-layers run in order: masked self-attention over the target, cross-
    attention into ``memory``, and a position-wise feed-forward network. Each
    sub-layer output is dropped out, added to its input, and layer-normalised.

    Args:
        embed_dim: model width.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        dropout: dropout probability applied to every sub-layer output.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Apply self-attention, cross-attention and FFN; output has ``tgt``'s shape."""
        h = self.norm1(tgt + self.dropout(self.self_attn(tgt, tgt, tgt, tgt_mask)))
        h = self.norm2(h + self.dropout(self.cross_attn(h, memory, memory, memory_mask)))
        return self.norm3(h + self.dropout(self.ffn(h)))

#######################################
# Full Multimodal Fusion Model for Image Captioning
#######################################
class MultimodalFusionModel(nn.Module):
    """Image-captioning model.

    Pipeline: a text encoder and an image encoder, followed by cross-attention
    fusion (text queries over image patches), followed by a Transformer decoder
    that treats the fused sequence as its memory.
    """

    def __init__(self, vocab_size,
                 embed_dim=512,
                 num_layers_enc=2,
                 num_layers_dec=4,
                 num_heads=8,
                 ff_dim=2048,
                 max_seq_len=128):
        super().__init__()
        # Text encoder for any prompt or auxiliary text input
        self.text_encoder = TextEncoder(vocab_size, embed_dim, num_layers_enc, num_heads, ff_dim, max_seq_len)
        # Image encoder to extract patch features from 256x256 images
        self.image_encoder = ImageEncoder(embed_dim)
        # Fusion module to fuse text (query) with image (key,value)
        self.fusion = FusedKnowledge(embed_dim, num_heads)
        # Decoder to generate captions from fused memory representation
        self.decoder = Decoder(vocab_size, embed_dim, num_layers_dec, num_heads, ff_dim, max_seq_len)
        self.vocab_size = vocab_size

    @staticmethod
    def _causal_mask(size, device=None):
        """Return a (1, 1, size, size) bool mask, True on and below the diagonal."""
        return torch.ones(size, size, device=device).tril().bool()[None, None, :, :]

    @classmethod
    def _build_tgt_mask(cls, tgt, pad_token_id=0):
        """Combine the causal mask with a key-padding mask for ``tgt`` (B, T).

        Returns a float (B, 1, T, T) mask: 1.0 where attention is allowed, 0.0 where
        the key is in the future or is a padding token.
        """
        causal = cls._causal_mask(tgt.size(1), tgt.device)
        not_pad = (tgt != pad_token_id)[:, None, None, :]
        return (causal & not_pad).to(torch.float)

    def _encode(self, image, text_input):
        """Run both encoders and the fusion module; returns (B, text_seq_len, embed_dim)."""
        text_repr = self.text_encoder(text_input)
        image_repr = self.image_encoder(image)
        return self.fusion(text_repr, image_repr)

    def forward(self, image, tgt, text_input):
        """
        image: tensor of shape (batch, 3, 256, 256)
        tgt: caption token sequence (batch, tgt_seq_len) for teacher forcing
        text_input: text tokens for the text encoder (batch, text_seq_len)
                    This could be a fixed prompt or additional text description.

        Returns logits of shape (batch, tgt_seq_len, vocab_size).
        """
        fused_memory = self._encode(image, text_input)
        # Causal + padding mask so each position sees only earlier, non-pad tokens.
        tgt_mask = self._build_tgt_mask(tgt)
        return self.decoder(tgt, fused_memory, tgt_mask=tgt_mask)

    def generate(self, image, text_input, max_len=50, *,
                 start_token_id=1, end_token_id=2, pad_token_id=0):
        """Greedily generate captions, one token at a time.

        image: (batch, 3, 256, 256)
        text_input: (batch, text_seq_len)
        max_len: maximum number of tokens to generate per caption.
        start_token_id / end_token_id / pad_token_id: special ids (defaults match
            the dataset vocabulary: <start>=1, <end>=2, <pad>=0).

        Returns a LongTensor (batch, n) with n <= max_len, without the leading <start>.
        Each row contains its <end> token, if one was produced. Positions after a
        row's <end> are filled with ``pad_token_id``. Decoding stops early once every
        row has produced <end>.

        The causal mask used in training is applied here too, so inference matches
        how the decoder was trained. Runs without autograd. The module's previous
        train/eval mode is restored afterwards.
        """
        batch_size = image.size(0)
        device = image.device
        generated_tokens = []
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                fused_memory = self._encode(image, text_input)
                tgt = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
                finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

                for _ in range(max_len):
                    tgt_mask = self._causal_mask(tgt.size(1), device).to(torch.float)
                    logits = self.decoder(tgt, fused_memory, tgt_mask=tgt_mask)
                    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                    # Rows that already emitted <end> only receive padding from now on.
                    next_token = next_token.masked_fill(finished.unsqueeze(1), pad_token_id)
                    generated_tokens.append(next_token)
                    tgt = torch.cat([tgt, next_token], dim=1)
                    finished |= next_token.squeeze(1) == end_token_id
                    if finished.all():
                        break
        finally:
            self.train(was_training)

        if not generated_tokens:
            return torch.empty(batch_size, 0, dtype=torch.long, device=device)
        return torch.cat(generated_tokens, dim=1)

#######################################
# Example instantiation and usage
#######################################
if __name__ == '__main__':
    # Assume vocabulary size of 10000 tokens
    vocab_size = 10000
    model = MultimodalFusionModel(vocab_size)
    # Dummy inputs: a batch of 2 images (3,256,256) and text prompt tokens (batch, text_seq_len)
    dummy_images = torch.randn(2, 3, 256, 256)
    dummy_text = torch.randint(0, vocab_size, (2, 10))  # e.g., a prompt with 10 tokens
    # Target captions for teacher forcing. Real captions begin with <start> (id=1);
    # drawing ids >= 3 avoids <pad>/<start>/<end> in the body. A random <pad> (0) in
    # position 0 would block every key for the first query row under the causal mask.
    dummy_tgt = torch.randint(3, vocab_size, (2, 12))
    dummy_tgt[:, 0] = 1
    # Forward pass
    logits = model(dummy_images, dummy_tgt, dummy_text)
    print("Logits shape:", logits.shape)
    # Generation (auto-regressive decoding)
    generated = model.generate(dummy_images, dummy_text)
    print("Generated shape:", generated.shape)


Logits shape: torch.Size([2, 12, 10000])
Generated shape: torch.Size([2, 50])


In [None]:
import json
from collections import Counter
import os
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import GradScaler, autocast
from PIL import Image
import torchvision.transforms as transforms

# Import the updated multimodal fusion model with two encoders, fusion, and decoder
# from ModelFlow import MultimodalFusionModel

#############################################
# Dataset for Image Captioning
#############################################
class CaptioningDataset(Dataset):
    """Image/caption pairs loaded from a JSON object ``{image_path: caption}``.

    A word-level vocabulary is built from all lower-cased, whitespace-split captions.
    Ids 0-3 are reserved for ``<pad>``, ``<start>``, ``<end>`` and ``<unk>``; the
    most frequent words fill the remaining ids. Each caption is stored as a list of
    ids wrapped in ``<start>`` ... ``<end>``.

    Args:
        json_file: path to a UTF-8 JSON file mapping image paths to caption strings.
        transform: optional callable applied to each loaded PIL image.
        max_vocab_size: total vocabulary size including the 4 special tokens
            (default 30000, i.e. up to 29996 caption words).
    """

    def __init__(self, json_file, transform=None, max_vocab_size=30000):
        # Explicit UTF-8 so non-ASCII captions decode the same on every platform.
        with open(json_file, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.image_paths = list(self.data.keys())

        # Tokenize once; reused for both frequency counting and id conversion.
        tokenized = [caption.lower().split() for caption in self.data.values()]
        word_freq = Counter(word for words in tokenized for word in words)

        self.vocab = {'<pad>': 0, '<start>': 1, '<end>': 2, '<unk>': 3}
        next_idx = len(self.vocab)
        for word, _ in word_freq.most_common():
            if next_idx >= max_vocab_size:
                break
            # A caption word spelled like a special token must not overwrite its id.
            if word in self.vocab:
                continue
            self.vocab[word] = next_idx
            next_idx += 1
        self.idx_to_word = {idx: word for word, idx in self.vocab.items()}

        start_id, end_id, unk_id = self.vocab['<start>'], self.vocab['<end>'], self.vocab['<unk>']
        self.captions = [
            [start_id] + [self.vocab.get(word, unk_id) for word in words] + [end_id]
            for words in tokenized
        ]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, caption_ids)``; the image is RGB, transformed if configured."""
        # Close the underlying file deterministically; convert() yields a loaded copy.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.captions[idx]

#############################################
# Collate function to batch data
#############################################
def collate_fn(batch):
    """Stack images and right-pad caption id lists with 0 (``<pad>``).

    Args:
        batch: sequence of ``(image_tensor, caption_id_list)`` pairs.

    Returns:
        ``(images, captions)``. ``images`` is the stacked image tensor; ``captions``
        is a tensor of shape (batch, longest_caption).
    """
    image_list = [image for image, _ in batch]
    caption_list = [caption for _, caption in batch]
    if not batch:
        # Keep the original's failure mode for an empty batch.
        raise ValueError("not enough values to unpack (expected 2, got 0)")
    width = max(map(len, caption_list))
    rows = [caption + [0] * (width - len(caption)) for caption in caption_list]
    return torch.stack(image_list), torch.tensor(rows)

#############################################
# Utility: Denormalize image for visualization
#############################################
def denormalize(image):
    """Undo ImageNet mean/std normalisation and return an (H, W, 3) numpy image in [0, 1].

    ``image`` is a (3, H, W) tensor normalised with the same mean/std as the
    training transform. It is moved to CPU first.
    """
    channel_mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
    channel_std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
    restored = torch.clamp(image.cpu() * channel_std + channel_mean, 0, 1)
    return restored.permute(1, 2, 0).numpy()

#############################################
# Training Loop
#############################################
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Mixed precision via torch.cuda.amp only applies on CUDA; disable it explicitly
    # elsewhere instead of relying on the helpers' runtime warnings.
    use_amp = device.type == 'cuda'

    # Transformation for images
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create dataset and compute vocabulary size
    dataset = CaptioningDataset('assets/processed_captions.json', transform=transform)
    vocab_size = len(dataset.vocab)
    pad_id = dataset.vocab['<pad>']
    start_id = dataset.vocab['<start>']
    end_id = dataset.vocab['<end>']

    def decode_caption(indices):
        """Map ids to words, stopping at the first <end> and skipping <pad>/<start>."""
        words = []
        for idx in indices:
            if idx == end_id:
                break
            if idx in (pad_id, start_id):
                continue
            words.append(dataset.idx_to_word.get(idx, '<unk>'))
        return ' '.join(words)

    # Instantiate the multimodal model.
    # Note: MultimodalFusionModel requires three inputs: image, tgt, and text_input.
    model = MultimodalFusionModel(vocab_size=vocab_size,
                                  embed_dim=512,
                                  num_layers_enc=2,  # Number of layers in the text encoder
                                  num_layers_dec=4,  # Number of layers in the decoder
                                  num_heads=8,
                                  ff_dim=2048,
                                  max_seq_len=128).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = GradScaler(enabled=use_amp)

    # Split dataset into training and validation sets.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset,
                              batch_size=32,
                              shuffle=True,
                              collate_fn=collate_fn,
                              num_workers=4,
                              pin_memory=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=32,
                            shuffle=False,
                            collate_fn=collate_fn,
                            num_workers=4,
                            pin_memory=True)

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('visualizations', exist_ok=True)

    num_epochs = 100

    # A fixed text prompt is fed to the text encoder: prompt_length copies of <start>.
    prompt_length = 10
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0

        for images, captions in train_loader:
            images = images.to(device, non_blocking=True)
            captions = captions.to(device, non_blocking=True)
            # Teacher forcing: the decoder sees tokens [0, T-1) and predicts [1, T).
            tgt = captions[:, :-1]
            target = captions[:, 1:]

            text_prompt = torch.full((images.size(0), prompt_length), start_id,
                                     dtype=torch.long, device=device)

            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                logits = model(image=images, tgt=tgt, text_input=text_prompt)
                loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                       target.reshape(-1),
                                       ignore_index=pad_id)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
            num_batches += 1

        # Guard against an empty training split.
        avg_loss = total_loss / max(num_batches, 1)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}')

        # Save checkpoint and visualize predictions every 10 epochs (or on the final epoch)
        if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
            checkpoint_path = f'checkpoints/model_epoch_{epoch+1}.pth'
            torch.save(model.state_dict(), checkpoint_path)
            print(f'Model saved to {checkpoint_path}')

            model.eval()
            samples = []
            with torch.no_grad():
                for val_images, val_captions in val_loader:
                    val_images = val_images.to(device)
                    text_prompt = torch.full((val_images.size(0), prompt_length), start_id,
                                             dtype=torch.long, device=device)

                    # Generate captions auto-regressively
                    generated = model.generate(val_images, text_input=text_prompt, max_len=50)

                    # Collect up to 3 samples from the first validation batch only.
                    for i in range(min(3, generated.size(0))):
                        samples.append({
                            'image': val_images[i],
                            'gen_caption': decode_caption(generated[i].tolist()),
                            'gt_caption': decode_caption(val_captions[i].tolist()),
                        })
                    break

            # Plot one row per collected sample (nothing to plot if validation is empty).
            if samples:
                fig, axes = plt.subplots(len(samples), 1, figsize=(8, 8 * len(samples)),
                                         squeeze=False)
                for ax, sample in zip(axes[:, 0], samples):
                    ax.imshow(denormalize(sample['image']))
                    ax.set_title(
                        f"Generated: {sample['gen_caption']}\nGround Truth: {sample['gt_caption']}",
                        fontsize=10
                    )
                    ax.axis('off')
                plt.tight_layout()
                plt.savefig(f'visualizations/epoch_{epoch+1}.png')
                plt.close(fig)

    print("Training completed. Model is ready for finetuning on logical reasoning.")
print()


Using device: cuda
