In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import math

# -------------------------
# 1. CIFAR-10 Datasets & Dataloaders
# -------------------------
def get_cifar10_loaders(batch_size=64, num_workers=2):
    """
    Build the CIFAR-10 training and validation DataLoaders.

    Every image is resized from 32x32 to 224x224, converted to a tensor and
    normalized per channel. Training images additionally receive a padded
    random crop and a random horizontal flip. Both datasets are created with
    ``root='./data'`` and ``download=True``.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        num_workers: Number of loader worker processes, used for both loaders.

    Returns:
        Tuple ``(train_loader, val_loader)``. Only the training loader
        shuffles its samples.
    """
    # Per-channel CIFAR-10 statistics used for normalization
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)

    # Training pipeline: upscale, augment, then tensor conversion + normalization
    train_steps = [
        T.Resize(224),
        T.RandomCrop(224, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    # Validation pipeline: same preprocessing without any augmentation
    val_steps = [
        T.Resize(224),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]

    # The two splits differ only in the `train` flag and the transform
    dataset_kwargs = dict(root='./data', download=True)
    train_dataset = torchvision.datasets.CIFAR10(
        train=True, transform=T.Compose(train_steps), **dataset_kwargs
    )
    val_dataset = torchvision.datasets.CIFAR10(
        train=False, transform=T.Compose(val_steps), **dataset_kwargs
    )

    # The two loaders differ only in whether they shuffle
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    return train_loader, val_loader


# -------------------------
# 2. Progressive Transformer Model
# -------------------------

class PatchEmbedding(nn.Module):
    """
    Converts an image batch into the token sequence fed to the Transformer.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every non-overlapping patch to ``hidden_dim``. A [CLS] token is
    prepended and a learned positional embedding is added.

    With the defaults (224x224 input, 8x8 patches) the grid is 28 x 28, i.e.
    784 patch tokens and 785 tokens once [CLS] is included.
    """
    def __init__(self, img_size=224, patch_size=8, in_chans=3, hidden_dim=786):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size  # patches per side
        self.num_patches = self.grid_size ** 2

        # Patch projection: (B, in_chans, H, W) -> (B, hidden_dim, grid, grid)
        self.proj = nn.Conv2d(
            in_chans, hidden_dim, kernel_size=patch_size, stride=patch_size
        )

        # Learnable [CLS] token and positional table, both starting at zero;
        # the table has one extra row for the [CLS] position
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, hidden_dim)
        )

    def forward(self, x):
        """Map images (B, in_chans, H, W) to tokens (B, num_patches + 1, hidden_dim)."""
        batch = x.shape[0]

        # Project, then collapse the spatial grid into a sequence axis:
        # (B, hidden_dim, g, g) -> (B, hidden_dim, g*g) -> (B, g*g, hidden_dim)
        patches = self.proj(x).flatten(2).transpose(1, 2)

        # One copy of the [CLS] token per sample, placed at sequence index 0
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, patches), dim=1)

        # Positional embeddings broadcast over the batch dimension
        return tokens + self.pos_embed[:, :tokens.size(1), :]


class TransformerEncoderBlock(nn.Module):
    """
    Pre-norm Transformer encoder block: multi-head self-attention followed by
    an MLP, each wrapped in a residual (skip) connection.

    Args:
        hidden_dim: Token embedding width; must be divisible by ``num_heads``
            (enforced by ``nn.MultiheadAttention``).
        num_heads: Number of attention heads.
        mlp_ratio: Width of the MLP hidden layer as a multiple of
            ``hidden_dim`` (truncated to an int).
        dropout: Dropout probability used inside attention and after each
            MLP linear layer.
    """
    def __init__(self, hidden_dim=786, num_heads=6, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(embed_dim=hidden_dim,
                                          num_heads=num_heads,
                                          dropout=dropout,
                                          batch_first=True)  # batch_first => (B, N, C)

        self.ln2 = nn.LayerNorm(hidden_dim)

        # Compute the MLP hidden width once so both linear layers agree on it
        mlp_dim = int(hidden_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, hidden_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to tokens of shape (B, N, hidden_dim); shape is preserved."""
        # Self-attention sub-layer (query = key = value = normalized input).
        # The attention map is never used, so need_weights=False avoids
        # materializing and averaging the (N x N) weights and lets PyTorch
        # pick its optimized attention implementation.
        x_norm = self.ln1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_out

        # Feed-forward sub-layer with its own normalization and residual
        x = x + self.mlp(self.ln2(x))

        return x


def pairwise_mean_downsample(x):
    """
    Halve the number of patch tokens by averaging consecutive pairs.

    x has shape (B, N, C) where token 0 is the CLS token and tokens 1..N-1
    are patch tokens. The CLS token passes through unchanged; patch tokens
    are grouped as (0,1), (2,3), ... (indices within the patch sequence) and
    each pair is replaced by its mean.

    Returns a tensor of shape (B, (N-1)//2 + 1, C), e.g. 785 -> 393 tokens.

    Raises AssertionError when there are no patch tokens (N <= 1) or when the
    number of patch tokens is odd.
    """
    batch, num_tokens, channels = x.shape
    assert num_tokens > 1, "We need patch tokens + CLS at least."

    # Pairing only works when the patch tokens split evenly into twos
    num_patches = num_tokens - 1
    assert (num_patches % 2) == 0, \
        f"Number of patch tokens must be even, got {num_patches}"

    cls_token, patches = x[:, :1, :], x[:, 1:, :]

    # (B, N-1, C) -> (B, (N-1)/2, 2, C), then collapse the pair axis by its mean
    merged = patches.reshape(batch, num_patches // 2, 2, channels).mean(dim=2)

    # CLS goes back in front of the reduced patch sequence
    return torch.cat([cls_token, merged], dim=1)


class ProgressiveTransformer(nn.Module):
    """
    Vision Transformer whose patch-token count shrinks as depth increases.

    With the defaults (12 layers, hidden_dim=786, patch size 8 on 224x224):
      - the embedding yields 784 patch tokens + CLS = 785 tokens
      - after block 3 the patches are halved: 784 -> 392 (393 tokens)
      - after block 6: 392 -> 196 (197 tokens)
      - after block 9: 196 -> 98 (99 tokens)
      - blocks 10-12 run on those 99 tokens
      - the normalized CLS token is mapped to ``num_classes`` logits
    """
    def __init__(self,
                 img_size=224,
                 patch_size=8,
                 in_chans=3,
                 hidden_dim=786,
                 num_heads=6,       # 6 * 131 = 786
                 mlp_ratio=4.0,
                 num_layers=12,
                 num_classes=10,
                 dropout=0.0):
        super().__init__()

        # Image -> token sequence (CLS + patches, with positional embeddings)
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, hidden_dim)

        # Stack of identical encoder blocks
        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(
                TransformerEncoderBlock(hidden_dim, num_heads, mlp_ratio, dropout)
            )
        self.blocks = nn.ModuleList(encoder_layers)

        # Output normalization and classification head
        self.norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Classify images (B, in_chans, H, W); returns logits (B, num_classes)."""
        tokens = self.patch_embed(x)  # (B, 785, 786) with the defaults

        # Blocks are numbered from 1; patch tokens are halved right after
        # blocks 3, 6 and 9: 785 -> 393 -> 197 -> 99 tokens
        for depth, layer in enumerate(self.blocks, start=1):
            tokens = layer(tokens)
            if depth in (3, 6, 9):
                tokens = pairwise_mean_downsample(tokens)

        # Normalize every token, then classify from the CLS token (index 0)
        cls_features = self.norm(tokens)[:, 0]  # (B, hidden_dim)
        return self.head(cls_features)          # (B, num_classes)


# -------------------------
# 3. Training / Evaluation Helpers
# -------------------------

def count_parameters(model):
    """Return the total element count of the model's parameters that require grad."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

def accuracy_topk(logits, targets, topk=(1,5)):
    """
    Computes top-k accuracy for specified k values.

    Args:
        logits: Score tensor of shape (B, num_classes).
        targets: Integer class labels of shape (B,).
        topk: Iterable of k values to evaluate.

    Returns:
        List of Python floats, one accuracy in percent per entry of ``topk``,
        in the same order.

    Notes:
        - Any k larger than the number of classes is treated as
          ``num_classes`` (every target is then trivially within the top-k),
          instead of making ``topk`` raise.
        - An empty batch yields 0.0 for every k rather than NaN, so callers
          that weight the result by the batch size are not poisoned.
    """
    batch_size = targets.size(0)
    if batch_size == 0:
        return [0.0 for _ in topk]

    # Never request more predictions than there are classes
    max_k = min(max(topk), logits.size(1))

    # Indices of the max_k highest-scoring classes per sample, best first
    _, pred = logits.topk(max_k, dim=1, largest=True, sorted=True)  # (B, max_k)

    # hits[b, j] is True when the j-th ranked prediction equals the target
    hits = pred.eq(targets.view(-1, 1))

    res = []
    for k in topk:
        # A sample is correct at k if its target is among the first k guesses;
        # slicing past max_k simply keeps every available column
        correct_k = hits[:, :k].float().sum()
        acc_k = correct_k * 100.0 / batch_size
        res.append(acc_k.item())
    return res


def train_one_epoch(model, loader, optimizer, device):
    """
    Run one optimization pass over ``loader`` with cross-entropy loss.

    The model is put in training mode and the optimizer is stepped once per
    batch. Loss and accuracies are averaged per sample (each batch weighted
    by its size).

    Returns:
        Tuple ``(avg_loss, avg_top1, avg_top5)``; accuracies are percentages.
    """
    model.train()
    loss_sum = 0.0
    acc1_sum = 0.0
    acc5_sum = 0.0
    seen = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        # Standard update: clear grads, forward, backward, step
        optimizer.zero_grad()
        logits = model(images)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch statistic by its size so a smaller final batch
        # does not skew the epoch averages
        n = images.size(0)
        loss_sum += loss.item() * n
        seen += n

        acc1, acc5 = accuracy_topk(logits, labels, topk=(1,5))
        acc1_sum += acc1 * n
        acc5_sum += acc5 * n

    return loss_sum / seen, acc1_sum / seen, acc5_sum / seen


@torch.no_grad()
def evaluate(model, loader, device):
    """
    Measure cross-entropy loss and top-1/top-5 accuracy over ``loader``.

    Runs with gradients disabled and the model in eval mode. Statistics are
    averaged per sample (each batch weighted by its size).

    Returns:
        Tuple ``(avg_loss, avg_top1, avg_top5)``; accuracies are percentages.
    """
    model.eval()
    loss_sum = 0.0
    acc1_sum = 0.0
    acc5_sum = 0.0
    seen = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        logits = model(images)
        loss = F.cross_entropy(logits, labels)

        # Per-batch values are means, so scale by batch size before summing
        n = images.size(0)
        loss_sum += loss.item() * n
        seen += n

        acc1, acc5 = accuracy_topk(logits, labels, topk=(1,5))
        acc1_sum += acc1 * n
        acc5_sum += acc5 * n

    return loss_sum / seen, acc1_sum / seen, acc5_sum / seen


# -------------------------
# 4. Putting It All Together (Example Training Script)
# -------------------------
# Example end-to-end run: build the CIFAR-10 loaders, instantiate the model,
# report its size, then alternate training and validation for a few epochs.
if __name__ == "__main__":
    # Use the GPU when one is available, otherwise fall back to CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Batch size 32 overrides the loader helper's default of 64
    train_loader, val_loader = get_cifar10_loaders(batch_size=32)

    model = ProgressiveTransformer(
        img_size=224,
        patch_size=8,
        hidden_dim=786,   # user-specified
        num_heads=6,      # 6 heads x 131 dim each = 786
        mlp_ratio=4.0,
        num_layers=12,
        num_classes=10,
        dropout=0.1
    ).to(device)

    # Number of parameters
    num_params = count_parameters(model)
    print(f"Model has {num_params/1e6:.2f} M learnable parameters.")

    # Simple AdamW optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

    # Train for a few epochs (example)
    epochs = 5
    for epoch in range(1, epochs+1):
        # One full pass over the training set, then a full validation pass
        train_loss, train_top1, train_top5 = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_top1, val_top5 = evaluate(model, val_loader, device)

        # Per-epoch report: average loss plus top-1 / top-5 accuracy in percent
        print(f"Epoch [{epoch}/{epochs}]")
        print(f"  Train Loss: {train_loss:.4f}, Top1: {train_top1:.2f}%, Top5: {train_top5:.2f}%")
        print(f"  Val   Loss: {val_loss:.4f},   Top1: {val_top1:.2f}%,   Top5: {val_top5:.2f}%")

    print("Training complete.")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [02:12<00:00, 1286529.23it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Model has 89.86 M learnable parameters.


KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import math

# -------------------------
# 1. CIFAR-10 Datasets & Dataloaders
# -------------------------
def get_cifar10_loaders(batch_size=64, num_workers=2):
    """
    Create the CIFAR-10 train and validation DataLoaders.

    All images are resized from 32x32 to 224x224, turned into tensors and
    normalized per channel; the training split also gets a padded random
    crop and a random horizontal flip. Both datasets use ``root='./data'``
    and ``download=True``.

    Args:
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes for both loaders.

    Returns:
        ``(train_loader, val_loader)``; the training loader shuffles, the
        validation loader does not.
    """
    # CIFAR-10 mean/std, one value per colour channel
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2470, 0.2435, 0.2616)

    def make_split(is_train, steps):
        # Dataset for one split; everything but the flag and transform is shared
        return torchvision.datasets.CIFAR10(
            root='./data',
            train=is_train,
            download=True,
            transform=T.Compose(steps),
        )

    # Train: resize + random augmentations; validation: resize only
    train_dataset = make_split(True, [
        T.Resize(224),
        T.RandomCrop(224, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(cifar_mean, cifar_std),
    ])
    val_dataset = make_split(False, [
        T.Resize(224),
        T.ToTensor(),
        T.Normalize(cifar_mean, cifar_std),
    ])

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, val_loader


# -------------------------
# 2. Model Components
# -------------------------

class PatchEmbedding(nn.Module):
    """
    Patchify-and-embed front end of the Transformer.

    Each non-overlapping ``patch_size`` x ``patch_size`` patch is projected
    to ``hidden_dim`` by a strided convolution. A [CLS] token is prepended
    and learned positional embeddings are added.

    Defaults: 224x224 images with 8x8 patches give a 28 x 28 grid = 784
    patch tokens (785 with [CLS]), each of width 768.
    """
    def __init__(self, img_size=224, patch_size=8, in_chans=3, hidden_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size  # patches along each side
        self.num_patches = self.grid_size ** 2

        # kernel_size == stride, so every output position sees exactly one patch
        self.proj = nn.Conv2d(
            in_chans, hidden_dim, kernel_size=patch_size, stride=patch_size
        )

        # Zero-initialized learnable [CLS] token and positional table
        # (num_patches rows plus one for [CLS])
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, hidden_dim)
        )

    def forward(self, x):
        """Embed images (B, in_chans, H, W) as tokens (B, num_patches + 1, hidden_dim)."""
        n_samples = x.shape[0]

        # (B, hidden_dim, g, g) -> (B, hidden_dim, g*g) -> (B, g*g, hidden_dim)
        embedded = self.proj(x).flatten(2).transpose(1, 2)

        # Prepend a per-sample view of the shared [CLS] token
        cls = self.cls_token.expand(n_samples, -1, -1)
        sequence = torch.cat((cls, embedded), dim=1)

        # Add positional embeddings (broadcast across the batch)
        return sequence + self.pos_embed[:, :sequence.size(1), :]


class TransformerEncoderBlock(nn.Module):
    """
    Pre-norm Transformer encoder block: multi-head self-attention followed by
    an MLP, each wrapped in a skip connection.

    Args:
        hidden_dim: Token embedding width; must be divisible by ``num_heads``
            (enforced by ``nn.MultiheadAttention``).
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_dim``
            (truncated to an int).
        dropout: Dropout probability used inside attention and after each
            MLP linear layer.
    """
    def __init__(self, hidden_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(embed_dim=hidden_dim,
                                          num_heads=num_heads,
                                          dropout=dropout,
                                          batch_first=True)  # (B, N, C)
        self.ln2 = nn.LayerNorm(hidden_dim)

        # Single source of truth for the MLP hidden width
        mlp_dim = int(hidden_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, hidden_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to tokens of shape (B, N, hidden_dim); shape is preserved."""
        # Self-attention on the normalized tokens. The attention weights are
        # discarded, so need_weights=False skips building the averaged
        # (N x N) map and allows PyTorch's optimized attention path.
        x_norm = self.ln1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_out

        # Position-wise MLP on the re-normalized tokens, added back residually
        x = x + self.mlp(self.ln2(x))
        return x


def pairwise_mean_downsample(x):
    """
    Halve the patch tokens of x (B, N, C) by averaging neighbouring pairs.

      - Token 0 (CLS) is kept as is.
      - Patch tokens (indices 1...N-1) are grouped two at a time in sequence
        order and each pair is replaced by its mean.
      - The patch-token count must be even, otherwise an AssertionError is
        raised.

    Returns a tensor of shape (B, (N-1)//2 + 1, C).
    """
    batch, num_tokens, channels = x.shape
    num_patches = num_tokens - 1

    # Pairing requires an even number of patch tokens
    assert (num_patches % 2) == 0, \
        f"Number of patch tokens must be even, got {num_patches}"

    cls_token = x[:, :1, :]
    patches = x[:, 1:, :]

    # Expose each pair on its own axis, then average that axis away
    paired = patches.reshape(batch, num_patches // 2, 2, channels)
    merged = paired.mean(dim=2)

    # Re-attach CLS at the front
    return torch.cat([cls_token, merged], dim=1)


class ProgressiveTransformer(nn.Module):
    """
    Transformer classifier that progressively reduces its patch tokens.

    With the defaults (12 layers, hidden_dim=768):
    - Patch size 8 on 224x224 -> 784 patches + 1 CLS = 785 tokens
    - After blocks #3, #6 and #9 the patch tokens are halved by pairwise
      averaging: 785 -> 393 -> 197 -> 99 tokens.
    - The normalized CLS token is mapped to ``num_classes`` logits
      (10 for CIFAR-10).
    """
    def __init__(self,
                 img_size=224,
                 patch_size=8,
                 in_chans=3,
                 hidden_dim=768,
                 num_heads=12,
                 mlp_ratio=4.0,
                 num_layers=12,
                 num_classes=10,
                 dropout=0.0):
        super().__init__()

        # Image -> (CLS + patch) token sequence
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, hidden_dim)

        # num_layers identical encoder blocks
        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(
                TransformerEncoderBlock(hidden_dim, num_heads, mlp_ratio, dropout)
            )
        self.blocks = nn.ModuleList(encoder_layers)

        # Final normalization and linear classifier
        self.norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Classify images (B, in_chans, H, W); returns logits (B, num_classes)."""
        tokens = self.patch_embed(x)  # (B, 785, 768) with the defaults

        # 1-based block index; downsample immediately after blocks 3, 6, 9
        for depth, layer in enumerate(self.blocks, start=1):
            tokens = layer(tokens)
            if depth in (3, 6, 9):
                tokens = pairwise_mean_downsample(tokens)

        # All tokens are normalized; only CLS (index 0) feeds the classifier
        cls_features = self.norm(tokens)[:, 0]  # (B, hidden_dim)
        return self.head(cls_features)          # (B, num_classes)


# -------------------------
# 3. Training / Evaluation Helpers
# -------------------------
def count_parameters(model):
    """Return how many scalar values the model's grad-requiring parameters hold."""
    sizes = [p.numel() for p in model.parameters() if p.requires_grad]
    return sum(sizes)

def accuracy_topk(logits, targets, topk=(1,5)):
    """
    Computes top-k accuracy for specified k values.

    Args:
        logits: Score tensor of shape (B, num_classes).
        targets: Integer class labels of shape (B,).
        topk: Iterable of k values to evaluate.

    Returns:
        List of Python floats, one accuracy in percent per entry of ``topk``,
        in the same order.

    Notes:
        - A k larger than the number of classes is capped at ``num_classes``
          (the target is then always within the top-k) instead of letting
          ``topk`` raise.
        - An empty batch returns 0.0 for every k rather than NaN, so callers
          that multiply by the batch size keep finite running sums.
    """
    batch_size = targets.size(0)
    if batch_size == 0:
        return [0.0 for _ in topk]

    # Cap the requested depth at the number of available classes
    max_k = min(max(topk), logits.size(1))

    # Class indices ranked by score, best first
    _, pred = logits.topk(max_k, dim=1, largest=True, sorted=True)  # (B, max_k)

    # matches[b, j]: does the j-th ranked guess for sample b equal its label?
    matches = pred.eq(targets.view(-1, 1))

    res = []
    for k in topk:
        # Count samples whose label appears among the first k guesses;
        # a slice beyond max_k just uses all available columns
        correct_k = matches[:, :k].float().sum()
        acc_k = correct_k * 100.0 / batch_size
        res.append(acc_k.item())
    return res

def train_one_epoch(model, loader, optimizer, device):
    """
    Train ``model`` for one pass over ``loader`` using cross-entropy loss.

    Puts the model in training mode and performs one optimizer step per
    batch. Returned statistics are per-sample averages over the epoch.

    Returns:
        ``(avg_loss, avg_top1, avg_top5)`` with accuracies in percent.
    """
    model.train()
    # Running sums, each weighted by batch size: [loss, top-1, top-5]
    sums = [0.0, 0.0, 0.0]
    seen = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        n = images.size(0)
        sums[0] += loss.item() * n
        seen += n

        acc1, acc5 = accuracy_topk(logits, labels, topk=(1,5))
        sums[1] += acc1 * n
        sums[2] += acc5 * n

    avg_loss, avg_top1, avg_top5 = (total / seen for total in sums)
    return avg_loss, avg_top1, avg_top5


@torch.no_grad()
def evaluate(model, loader, device):
    """
    Evaluate ``model`` on ``loader`` without tracking gradients.

    Puts the model in eval mode and accumulates cross-entropy loss and
    top-1/top-5 accuracy, averaged per sample.

    Returns:
        ``(avg_loss, avg_top1, avg_top5)`` with accuracies in percent.
    """
    model.eval()
    # Running sums, each weighted by batch size: [loss, top-1, top-5]
    sums = [0.0, 0.0, 0.0]
    seen = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = F.cross_entropy(logits, labels)

        n = images.size(0)
        sums[0] += loss.item() * n
        seen += n

        acc1, acc5 = accuracy_topk(logits, labels, topk=(1,5))
        sums[1] += acc1 * n
        sums[2] += acc5 * n

    avg_loss, avg_top1, avg_top5 = (total / seen for total in sums)
    return avg_loss, avg_top1, avg_top5


# -------------------------
# 4. Putting It All Together (Example Training Script)
# -------------------------
# Example end-to-end run: build the CIFAR-10 loaders, instantiate the model,
# report its size, then alternate training and validation for a few epochs.
if __name__ == "__main__":
    # Use the GPU when one is available, otherwise fall back to CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Get CIFAR-10 data (batch size 32 overrides the helper's default of 64)
    train_loader, val_loader = get_cifar10_loaders(batch_size=32)

    # Build Progressive Transformer (ViT-Base dimension=768, 12 heads, 12 layers)
    model = ProgressiveTransformer(
        img_size=224,
        patch_size=8,
        hidden_dim=768,
        num_heads=12,      # 12 x 64 = 768
        mlp_ratio=4.0,
        num_layers=12,
        num_classes=10,
        dropout=0.1
    ).to(device)

    # Count parameters
    num_params = count_parameters(model)
    print(f"Model has {num_params/1e6:.2f} M learnable parameters.")

    # Simple AdamW optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

    # Train for a few epochs as an example
    epochs = 5
    for epoch in range(1, epochs+1):
        # One full pass over the training set, then a full validation pass
        train_loss, train_top1, train_top5 = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_top1, val_top5 = evaluate(model, val_loader, device)

        # Per-epoch report: average loss plus top-1 / top-5 accuracy in percent
        print(f"Epoch [{epoch}/{epochs}]")
        print(f"  Train Loss: {train_loss:.4f}, Top1: {train_top1:.2f}%, Top5: {train_top5:.2f}%")
        print(f"  Val   Loss: {val_loss:.4f},   Top1: {val_top1:.2f}%,   Top5: {val_top5:.2f}%")

    print("Training complete.")
