In [18]:
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

In [None]:
def create_2x2_chunk_order(num_rows=28, num_cols=28):
    """
    Build a permutation of patch indices that visits a (num_rows x num_cols)
    row-major grid in 2x2 chunks.

    Chunks are visited in row-major order, and within each chunk the four cells
    are taken top-left, top-right, bottom-left, bottom-right. For a 4x4 grid the
    result is [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15].

    Args:
        num_rows: number of patch rows; must be a non-negative even integer.
        num_cols: number of patch columns; must be a non-negative even integer.

    Returns:
        1-D np.ndarray of ints of length num_rows * num_cols that contains every
        0-based index exactly once.

    Raises:
        ValueError: if num_rows or num_cols is negative or odd.
    """
    if num_rows < 0 or num_cols < 0 or num_rows % 2 or num_cols % 2:
        raise ValueError(
            f"num_rows and num_cols must be non-negative even integers, "
            f"got {num_rows}x{num_cols}"
        )
    # View the grid as [row_chunk, row_in_chunk, col_chunk, col_in_chunk] and
    # move col_chunk ahead of row_in_chunk so that flattening walks the grid
    # one 2x2 chunk at a time.
    grid = np.arange(num_rows * num_cols, dtype=int).reshape(
        num_rows // 2, 2, num_cols // 2, 2
    )
    return grid.transpose(0, 2, 1, 3).reshape(-1)


In [19]:
class PatchEmbed(nn.Module):
    """
    Tokenize an image into non-overlapping square patches.

    A Conv2d whose kernel size and stride both equal `patch_size` maps every
    patch_size x patch_size patch to one `embed_dim`-dimensional vector; the
    resulting patch grid is then flattened into a token sequence in row-major
    order.

    Args:
        in_channels: number of input image channels.
        embed_dim: size of each output token.
        patch_size: side length of a (square) patch, in pixels.
        img_size: side length of the (square) input image, in pixels.
    """
    def __init__(self, in_channels=3, embed_dim=768, patch_size=16, img_size=224):
        super().__init__()
        grid_side = img_size // patch_size  # patches per image side

        self.patch_size = patch_size
        self.img_size = img_size
        self.num_patches_h = grid_side
        self.num_patches_w = grid_side
        self.num_patches = grid_side * grid_side
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        Args:
            x: [B, in_channels, H, W] images.

        Returns:
            [B, num_patches, embed_dim] tokens, row-major over the patch grid.
        """
        # [B, E, h, w] -> [B, E, h*w] -> [B, h*w, E]
        patch_grid = self.proj(x)
        tokens = patch_grid.flatten(start_dim=2)
        return tokens.transpose(1, 2)


In [20]:
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.

    Args:
        embed_dim: token embedding size D; must be divisible by num_heads.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights and to
            the output projection.

    Raises:
        ValueError: if num_heads is not a positive divisor of embed_dim.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        # Fail at construction time rather than with an opaque reshape error
        # inside forward().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive divisor of "
                f"embed_dim ({embed_dim})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # One fused projection produces Q, K and V.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: [B, N, D] tokens, where N is the number of tokens and D == embed_dim.
            mask: optional additive mask broadcastable to [B, num_heads, N, N]
                (e.g. 0 for allowed positions, -inf for blocked ones). It is
                added to the attention logits before the softmax.

        Returns:
            [B, N, D] attended tokens.
        """
        B, N, D = x.shape
        assert D == self.embed_dim

        qkv = self.qkv(x)  # [B, N, 3D]
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, num_heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: [B, num_heads, N, head_dim]

        # Scaled dot-product logits: [B, num_heads, N, N]
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            attn_logits = attn_logits + mask

        attn = F.softmax(attn_logits, dim=-1)
        attn = self.attn_dropout(attn)

        out = torch.matmul(attn, v)  # [B, num_heads, N, head_dim]
        # Merge heads back into the embedding dimension.
        out = out.permute(0, 2, 1, 3).reshape(B, N, D)
        out = self.out_proj(out)
        out = self.proj_dropout(out)
        return out


class TransformerBlock(nn.Module):
    """
    Pre-norm ViT encoder block.

    Computes
        x = x + Attn(LayerNorm(x))
        x = x + MLP(LayerNorm(x))
    where MLP is Linear -> GELU -> Dropout -> Linear -> Dropout with a hidden
    width of int(embed_dim * mlp_ratio).
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        mlp_hidden = int(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Layer order fixes the state_dict keys mlp.0 ... mlp.4; keep it stable.
        mlp_layers = [
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Map [B, N, D] tokens to [B, N, D] tokens through attention and MLP residuals."""
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))


In [21]:
class MergeBlock(nn.Module):
    """
    Fusion block (the 9th layer) that merges two equally shaped token streams.

    With per-branch projections Q_X, K_X, V_X = Linear_X(LayerNorm_X(x_X)) for
    X in {A, B}:

        logits    = (Q_A K_A^T + Q_B K_B^T) / sqrt(head_dim)
        attention = softmax(logits)
        out       = attention @ (V_A + V_B)

    This is followed by an output projection, a residual connection onto
    (xA + xB), and a pre-norm MLP with its own residual connection.

    Args:
        embed_dim: token embedding size D; must be divisible by num_heads.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of embed_dim.
        dropout: dropout probability for attention weights, projection and MLP.

    Raises:
        ValueError: if num_heads is not a positive divisor of embed_dim.
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive divisor of "
                f"embed_dim ({embed_dim})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Separate fused Q/K/V projections for branch A and branch B.
        self.qkv_A = nn.Linear(embed_dim, 3 * embed_dim)
        self.qkv_B = nn.Linear(embed_dim, 3 * embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout_attn = nn.Dropout(dropout)
        self.dropout_proj = nn.Dropout(dropout)

        self.norm1_A = nn.LayerNorm(embed_dim)
        self.norm1_B = nn.LayerNorm(embed_dim)

        # MLP applied to the merged stream.
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def _split_heads(self, x_norm, qkv_proj):
        """Project normed [B, N, D] tokens to q, k, v, each [B, num_heads, N, head_dim]."""
        B, N, _ = x_norm.shape
        qkv = qkv_proj(x_norm).view(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, num_heads, N, head_dim]
        return qkv[0], qkv[1], qkv[2]

    def forward(self, xA, xB):
        """
        Args:
            xA: [B, N, D] branch-A tokens.
            xB: [B, N, D] branch-B tokens, token-aligned with xA.

        Returns:
            [B, N, D] merged tokens.

        Raises:
            ValueError: if xA and xB do not have the same shape (otherwise a
                size-1 batch could silently broadcast against the other).
        """
        if xA.shape != xB.shape:
            raise ValueError(
                f"xA and xB must have the same shape, got "
                f"{tuple(xA.shape)} and {tuple(xB.shape)}"
            )
        B, N, D = xA.shape

        qA, kA, vA = self._split_heads(self.norm1_A(xA), self.qkv_A)
        qB, kB, vB = self._split_heads(self.norm1_B(xB), self.qkv_B)

        # Joint attention pattern: the sum of both branches' scaled logits.
        logits_A = torch.matmul(qA, kA.transpose(-2, -1)) / math.sqrt(self.head_dim)
        logits_B = torch.matmul(qB, kB.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(logits_A + logits_B, dim=-1)  # [B, num_heads, N, N]
        attn = self.dropout_attn(attn)

        # The shared pattern aggregates the summed values of both branches.
        out = torch.matmul(attn, vA + vB)  # [B, num_heads, N, head_dim]
        out = out.permute(0, 2, 1, 3).reshape(B, N, D)
        out = self.out_proj(out)
        out = self.dropout_proj(out)

        # Residual onto the sum of both input streams, so the single output
        # stream keeps information from both branches.
        x_merged = (xA + xB) + out

        x_merged = x_merged + self.mlp(self.norm2(x_merged))
        return x_merged


In [22]:
class CustomTwoBranchViT(nn.Module):
    """
    Two-branch Vision Transformer classifier.

    Branch A tokenizes the image into 8x8 patches and branch B into 16x16
    patches. Branch A's tokens are reordered into 2x2 chunks and each branch-B
    token is repeated 4 times, so token i of both branches covers the same
    16x16 image region. Each branch runs its own stack of TransformerBlocks.
    The streams are then fused by a MergeBlock, further TransformerBlocks run
    on the merged stream, and a mean-pooled linear head produces the logits.

    Args:
        img_size: side of the square input image; must be a positive multiple of 16.
        embed_dim, num_heads, mlp_ratio, dropout: hyper-parameters shared by
            every transformer block.
        num_classes: number of output logits.
        in_channels: number of input image channels.
        branch_depth: number of TransformerBlocks in each branch before merging.
        merged_depth: number of TransformerBlocks after the MergeBlock.

    Raises:
        ValueError: if img_size is not a positive multiple of 16.
    """
    def __init__(
        self,
        img_size=224,
        embed_dim=256,
        num_heads=4,
        mlp_ratio=4.0,
        num_classes=10,
        dropout=0.0,
        in_channels=3,
        branch_depth=8,
        merged_depth=3,
    ):
        super().__init__()
        # Token alignment requires every 16x16 patch to split into exactly
        # 2x2 patches of 8x8, i.e. an even 8x8 grid.
        if img_size <= 0 or img_size % 16 != 0:
            raise ValueError(f"img_size must be a positive multiple of 16, got {img_size}")
        self.embed_dim = embed_dim

        # Branch A: 8x8 patches -> (img_size // 8) ** 2 tokens (784 for 224).
        self.patch_embed_A = PatchEmbed(
            in_channels=in_channels, embed_dim=embed_dim, patch_size=8, img_size=img_size
        )
        self.num_patches_A = (img_size // 8) ** 2

        # Branch B: 16x16 patches -> (img_size // 16) ** 2 tokens (196 for 224),
        # each repeated 4x in forward() so the sequence length matches branch A.
        self.patch_embed_B = PatchEmbed(
            in_channels=in_channels, embed_dim=embed_dim, patch_size=16, img_size=img_size
        )
        self.num_patches_B = (img_size // 16) ** 2

        # Learned positional embeddings, one table per branch. Branch B's table
        # has one entry per *repeated* token (num_patches_A), so the four copies
        # of a 16x16 patch get distinct positions.
        self.pos_embed_A = nn.Parameter(torch.zeros(1, self.num_patches_A, embed_dim))
        self.pos_embed_B = nn.Parameter(torch.zeros(1, self.num_patches_A, embed_dim))

        # Permutation that regroups branch A's row-major tokens into 2x2 chunks.
        # Built here instead of read from a notebook global; non-persistent
        # because it is derived from img_size, which keeps state_dict keys (and
        # previously saved checkpoints) unchanged.
        self.register_buffer(
            "reorder_indices_2x2", self._chunk_order_2x2(img_size // 8), persistent=False
        )

        # Separate per-branch blocks before the merge.
        self.blocks_A = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout=dropout)
            for _ in range(branch_depth)
        ])
        self.blocks_B = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout=dropout)
            for _ in range(branch_depth)
        ])

        self.merge_block = MergeBlock(embed_dim, num_heads, mlp_ratio, dropout=dropout)

        # Single-route blocks after the merge.
        self.blocks_merged = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout=dropout)
            for _ in range(merged_depth)
        ])

        self.norm_final = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    @staticmethod
    def _chunk_order_2x2(grid_side):
        """
        Return a LongTensor permutation of range(grid_side ** 2) that visits a
        row-major grid_side x grid_side grid in 2x2 chunks (chunks in row-major
        order, cells within a chunk in row-major order). This is the same order
        that create_2x2_chunk_order produces.
        """
        half = grid_side // 2
        grid = torch.arange(grid_side * grid_side).reshape(half, 2, half, 2)
        return grid.permute(0, 2, 1, 3).reshape(-1)

    def _init_weights(self):
        """Initialize Linear/Conv2d/LayerNorm modules and the positional embeddings."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        # Small random positional embeddings instead of all zeros; common ViT
        # practice (e.g. timm) that gives each position a distinct start.
        nn.init.trunc_normal_(self.pos_embed_A, std=0.02)
        nn.init.trunc_normal_(self.pos_embed_B, std=0.02)

    def forward(self, x):
        """
        Args:
            x: [B, in_channels, img_size, img_size] images.

        Returns:
            [B, num_classes] logits.
        """
        # Branch A: 8x8 tokens, regrouped so every 4 consecutive tokens tile one
        # 16x16 patch.
        xA = self.patch_embed_A(x)                 # [B, N, D], N = num_patches_A
        xA = xA[:, self.reorder_indices_2x2]       # 2x2-chunk order
        xA = xA + self.pos_embed_A
        for blk in self.blocks_A:
            xA = blk(xA)

        # Branch B: 16x16 tokens, each repeated 4x so token i lines up with
        # branch-A token i.
        xB = self.patch_embed_B(x)                 # [B, N / 4, D]
        xB = xB.repeat_interleave(4, dim=1)        # [B, N, D]
        xB = xB + self.pos_embed_B
        for blk in self.blocks_B:
            xB = blk(xB)

        # Fuse both branches, then continue on a single route.
        x_merged = self.merge_block(xA, xB)        # [B, N, D]
        for blk in self.blocks_merged:
            x_merged = blk(x_merged)

        # No [CLS] token: mean-pool the normalized tokens and classify.
        x_merged = self.norm_final(x_merged)
        x_pooled = x_merged.mean(dim=1)            # [B, D]
        return self.head(x_pooled)                 # [B, num_classes]


In [23]:
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over `loader` and return the mean loss per sample."""
    model.train()
    loss_sum = 0.0
    num_samples = 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        # Assumes criterion uses reduction='mean' (the CrossEntropyLoss default);
        # re-weighting by batch size keeps a short final batch from skewing the mean.
        batch_len = labels.size(0)
        loss_sum += loss.item() * batch_len
        num_samples += batch_len
    return loss_sum / max(num_samples, 1)


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return the top-1 accuracy of `model` on `loader`, in percent."""
    model.eval()
    correct = 0
    total = 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return 100.0 * correct / max(total, 1)


def train_on_cifar10(
    num_epochs=2,
    batch_size=16,
    learning_rate=1e-4,
    data_root="./data",
    checkpoint_path="best_two_branch_vit_cifar10.pth",
    num_workers=2,
    seed=42,
):
    """
    Train CustomTwoBranchViT on CIFAR-10 and checkpoint the best test accuracy.

    After each epoch the model is evaluated on the CIFAR-10 test split. Whenever
    the accuracy improves, its state_dict is saved to `checkpoint_path`.

    Args:
        num_epochs: number of passes over the training set.
        batch_size: mini-batch size for both loaders.
        learning_rate: AdamW learning rate.
        data_root: directory CIFAR-10 is downloaded to / read from.
        checkpoint_path: file the best model's state_dict is written to.
        num_workers: DataLoader worker processes per loader.
        seed: value passed to torch.manual_seed (weight init, shuffling,
            augmentation); None leaves torch's RNG state untouched.

    Returns:
        The best test accuracy seen, in percent.
    """
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Images are resized to the model's input resolution.
    img_size = 224
    normalize = T.Normalize((0.4914, 0.4822, 0.4465),
                            (0.2023, 0.1994, 0.2010))
    transform_train = T.Compose([
        T.Resize(img_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize,
    ])
    transform_test = T.Compose([
        T.Resize(img_size),
        T.ToTensor(),
        normalize,
    ])

    train_dataset = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=transform_train
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=True, transform=transform_test
    )

    # Pinned host memory only helps host->GPU copies.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)

    model = CustomTwoBranchViT(
        img_size=img_size,
        embed_dim=256,     # smaller than typical ViT-Base for demonstration
        num_heads=4,
        mlp_ratio=4.0,
        num_classes=10,
        dropout=0.1
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0
    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

        acc = _evaluate(model, test_loader, device)
        print(f"Test Accuracy: {acc:.2f}%")

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), checkpoint_path)
            print("  [*] Best model saved.")

    print(f"Training completed. Best accuracy: {best_acc:.2f}%")
    return best_acc

# Full training run with the default settings. It downloads CIFAR-10 into ./data
# if needed and writes the best checkpoint to best_two_branch_vit_cifar10.pth.
# Images are resized to 224, so each branch attends over 784 tokens per image;
# expect a long run (the saved output of this cell ends in a KeyboardInterrupt).
train_on_cifar10()


Using device: cuda
Files already downloaded and verified
Files already downloaded and verified


KeyboardInterrupt: 