# Implementation: Patchifying Latents

**Goal**: Convert a latent image into a sequence of patch tokens that a Transformer can consume.

In [None]:
import torch
import torch.nn as nn

# 1. Mock latent input: batch 1, 4 channels, 32x32 spatial grid.
x = torch.randn(1, 4, 32, 32)

# 2. Parameters
patch_size = 2
dim = 4  # NOTE(review): not used in this cell; kept in case a later cell reads it.

# Read every size from the tensor itself so the cell keeps working when the
# mock input or patch_size above is changed.
batch, channels, height, width = x.shape

# unfold() silently drops trailing rows/columns that do not fill a whole
# window, so fail loudly instead of quietly losing pixels.
if height % patch_size or width % patch_size:
    raise ValueError(
        f"Spatial size {height}x{width} is not divisible by "
        f"patch_size={patch_size}"
    )

grid_h = height // patch_size
grid_w = width // patch_size

# 3. Patchify logic (unfold)
# Slide a non-overlapping patch_size x patch_size window over the height
# axis (dim 2), then over the width axis (dim 3).
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
# patches shape: [batch, channels, grid_h, grid_w, patch_size, patch_size]
# -> [1, 4, 16, 16, 2, 2] for the mock input.

# Flatten into a sequence: move the patch-grid axes in front of the channel
# axis so each token holds all channels of one patch, then merge the grid
# into a single token axis and (channels, patch rows, patch cols) into the
# feature axis.
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(
    batch, grid_h * grid_w, channels * patch_size * patch_size
)
# Result: [1, 256, 16] for the mock input.
# 256 tokens. Each token has size 16 (4 channels * 2 * 2 pixels).

print(f"Original Shape: {x.shape}")
print(f"Sequence Shape: {patches.shape}")

# Now we can feed 'patches' into a standard Transformer Encoder.

## Conclusion
DiT treats an image much like a sentence: each patch becomes one token in a sequence, just as each word does in text.