# Notebook 3: Pretext Tasks (Rotation, Jigsaw, Colorization)

Before the rise of contrastive learning, much of the progress in visual self-supervised learning came from designing pretext tasks – artificial problems that a neural network could solve using the structure of the data itself. These tasks encourage the model to learn high-level features as a by-product of trying to solve them. In this notebook, we will explore three canonical pretext tasks in computer vision:
- **Rotation Prediction** – Have the model predict how an image has been rotated.
- **Jigsaw Puzzle Solving** – Have the model arrange shuffled image patches into the correct order.
- **Colorization** – Have the model colorize a grayscale image.

These tasks were important stepping stones in SSL research. While they may not match the representation quality of contrastive methods on their own, they introduced key ideas and are still useful, either alone or combined with other methods.


## Predicting Image Rotations

Gidaris et al. (ICLR 2018) introduced a simple yet effective pretext task: randomly rotate an image by one of four angles {0°, 90°, 180°, 270°}, and train a CNN to classify which rotation was applied. Why does this help? To succeed, the model must understand the content of the image to some degree – for example, an upside-down dog looks very different from an upright one. The network can’t rely on a single shortcut such as “sky at top”, because not all images contain sky; it has to learn more general features (object shapes, the orientation of canonical structures such as faces) to figure out the rotation.

**Task setup:** We define 4 classes corresponding to the 4 possible rotations. For each training image, we randomly rotate it by one of these angles and label it with that angle class (0,1,2,3 for 0°...270°). The network is a standard classifier (like a ResNet) with 4 outputs. It’s trained with a cross-entropy loss to predict the correct rotation.

Despite its simplicity, rotation prediction proved to be a powerful self-supervised signal:
- It forces the network to learn orientation-dependent features. For instance, many objects have a "correct" orientation (animals stand upright, cars on wheels, text is horizontal, etc.). To classify rotation, the network implicitly needs to recognize those objects or at least their asymmetry.
- The learned features transfer well to classification and detection tasks. In the original paper, a network pre-trained on rotation prediction significantly outperformed a randomly initialized one when fine-tuned on ImageNet or PASCAL VOC detection.

Training a rotation prediction model is straightforward. Let's implement a quick demonstration on CIFAR-10 (since it has objects where orientation can be meaningful, albeit low-resolution):


In [None]:
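import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader

# SmallCNN, device, train_subset, and train_dataset are assumed to be defined already
# (e.g., carried over from the earlier notebooks).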

# Simple CNN classifier (like earlier SmallCNN but ending in 4 logits for rotation)
class RotationNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.feat_extractor = SmallCNN(out_dim=64)  # reuse the SmallCNN conv trunk but smaller output
        self.classifier = nn.Linear(64, 4)  # 4 rotation classes
    def forward(self, x):
        feat = self.feat_extractor(x)
        logits = self.classifier(feat)
        return logits

# Create rotated versions of images with corresponding labels
def make_rotations(batch_images):
    images = []
    labels = []
    for img in batch_images:
        k = torch.randint(0, 4, (1,)).item()  # number of 90-degree turns: 0, 1, 2, or 3
        # torch.rot90 is exact for multiples of 90 degrees (no interpolation involved)
        images.append(torch.rot90(img, k, dims=(1, 2)))
        labels.append(k)  # class k corresponds to a rotation of k * 90 degrees
    return torch.stack(images), torch.tensor(labels)

# Training loop for rotation prediction
rot_model = RotationNet().to(device)
optimizer = torch.optim.Adam(rot_model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# We'll use CIFAR-10 data as unlabeled
train_loader = DataLoader(train_subset, batch_size=128, shuffle=True)
for epoch in range(3):
    rot_model.train()
    total_loss = 0
    for imgs, _ in train_loader:
        imgs = imgs.to(device)
        # generate random rotations
        imgs_rot, labels_rot = make_rotations(imgs)
        imgs_rot, labels_rot = imgs_rot.to(device), labels_rot.to(device)
        # forward pass
        logits = rot_model(imgs_rot)
        loss = criterion(logits, labels_rot)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}: Rotation prediction loss = {total_loss/len(train_loader):.3f}")


After training for a few epochs (again, this is just illustrative; real training might use more epochs and a deeper network), the model should be able to predict rotations significantly better than random (which would be 25% accuracy with 4 classes, or a cross-entropy loss of ln 4 ≈ 1.386). The learned convolutional filters in `rot_model.feat_extractor` will likely detect edges, corners, and other features that help orient the image.
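
The training loop above only prints the loss, so checking the "better than 25%" claim needs a separate accuracy measurement. The following is a minimal sketch of one way to do that. `rotate_batch` and `rotation_accuracy` are hypothetical helper names, and the tiny linear model and random tensors are stand-ins (assumptions) for `rot_model` and real held-out CIFAR-10 batches.

```python
import torch
import torch.nn as nn

def rotate_batch(images):
    """Rotate each image in a (B, C, H, W) batch by a random multiple of 90 degrees."""
    labels = torch.randint(0, 4, (images.size(0),), device=images.device)
    rotated = torch.stack([torch.rot90(img, k=int(k), dims=(1, 2))
                           for img, k in zip(images, labels)])
    return rotated, labels

@torch.no_grad()
def rotation_accuracy(model, batches):
    """Fraction of correctly predicted rotations over an iterable of image batches."""
    model.eval()
    correct, total = 0, 0
    for images in batches:
        rotated, labels = rotate_batch(images)
        preds = model(rotated).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total

# Stand-in model and data (assumptions, for illustration only)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 4))
toy_batches = [torch.rand(16, 3, 32, 32) for _ in range(4)]
acc = rotation_accuracy(toy_model, toy_batches)
print(f"Rotation accuracy: {acc:.2%} (chance is 25%)")
```

With an untrained model the result should hover around chance; passing the trained `rot_model` and batches from a held-out loader would give the number that actually matters.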

**Feature learning:** What did the model learn? Possibly, it learned to detect gravity direction (sky vs ground), or for objects like animals, the typical orientation of legs/head. For text in images, detecting readable orientation vs upside-down would be key. All these cues are useful for general vision understanding. Indeed, the rotation task was found to complement other tasks: later works combined rotation with contrastive losses to further boost performance, indicating it learns something slightly different and complementary.


## Solving Jigsaw Puzzles

Another iconic SSL task was proposed by Noroozi & Favaro (ECCV 2016): Jigsaw puzzles. The idea is to take an image, split it into a grid of patches (e.g., 3x3 grid = 9 patches), shuffle these patches, and ask the network to predict the correct order (i.e., which patch goes to which position).

*Figure: Solving jigsaw puzzles as self-supervision. (a) The original image with grid; (b) the jumbled puzzle given to the network; (c) the solved puzzle in correct arrangement. The model must learn to recognize objects or scenes to some extent to reassemble them correctly.*

**Formulation:** How do we get a label for a shuffled puzzle? There are many possible permutations (9! = 362,880 for 9 pieces), which is too many classes to classify directly. The authors simplified the problem: they did not allow all permutations, but sampled a subset of 1000 (pre-defined) permutations out of all possibilities and treated the task as 1000-class classification (the network predicts which permutation index is applied). Each permutation is a specific way the tiles could be shuffled, and the model must recognize which one by essentially learning to arrange them.
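
A common way to build such a permutation subset is to pick permutations that are far apart in Hamming distance, so that the classes are easy to tell apart. The sketch below illustrates that idea with a greedy max-min selection. It is not the authors' released code: `select_permutations` is a hypothetical helper, and for 9 tiles it draws from a random candidate pool instead of all 9! permutations (an assumption made to keep the example fast).

```python
import itertools
import numpy as np

def select_permutations(num_tiles, num_select, pool_size=5000, seed=0):
    """Greedily pick permutations that are far apart in Hamming distance."""
    rng = np.random.default_rng(seed)
    if num_tiles <= 6:
        # Small enough to enumerate every permutation
        pool = np.array(list(itertools.permutations(range(num_tiles))))
    else:
        # Otherwise sample a random candidate pool and drop duplicates
        pool = np.array([rng.permutation(num_tiles) for _ in range(pool_size)])
        pool = np.unique(pool, axis=0)
    chosen = [pool[rng.integers(len(pool))]]
    # min_dist[i] = Hamming distance from candidate i to its nearest chosen permutation
    min_dist = (pool != chosen[0]).sum(axis=1)
    while len(chosen) < num_select:
        idx = int(min_dist.argmax())  # candidate farthest from the chosen set
        chosen.append(pool[idx])
        min_dist = np.minimum(min_dist, (pool != pool[idx]).sum(axis=1))
    return np.stack(chosen)

perm_set = select_permutations(num_tiles=9, num_select=100)
print(perm_set.shape)  # (100, 9)
```

Each row of `perm_set` is one class: the label for a shuffled puzzle is simply the row index of the permutation that was applied.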

An alternative formulation could be multiple classification heads for each tile’s position, but the original method used the single-class approach.
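
To make the alternative concrete, here is a minimal sketch of a per-tile position formulation: for every slot in the shuffled puzzle, the model outputs a distribution over the 9 possible original positions. All names, sizes, and the random stand-in features are assumptions for illustration; this is not the method from the paper.

```python
import torch
import torch.nn as nn

num_tiles, feat_dim, batch = 9, 128, 8

# A single linear layer that produces, for every slot, logits over the
# num_tiles possible original positions.
position_head = nn.Linear(num_tiles * feat_dim, num_tiles * num_tiles)

tile_feats = torch.randn(batch, num_tiles, feat_dim)  # stand-in per-patch features
# targets[b, j] = original position of the tile shown in slot j
targets = torch.stack([torch.randperm(num_tiles) for _ in range(batch)])

logits = position_head(tile_feats.flatten(1)).view(batch, num_tiles, num_tiles)
# cross_entropy expects the class dimension at dim 1: (B, classes, slots) vs. (B, slots)
loss = nn.functional.cross_entropy(logits.transpose(1, 2), targets)
print(logits.shape, loss.item())
```

This avoids fixing a permutation set in advance, at the cost of not enforcing that the predicted positions form a valid permutation.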

**Network:** Interestingly, they designed a special network (called Context-Free Network) that processes each patch independently (siamese CNN towers with shared weights for each patch) and then a fully connected layer that combines the patch features to predict the permutation. By doing this, they forced the network to truly rely on the spatial configuration rather than trivial cues like continuity between patches (since each patch was processed separately until the final layers).

**Learning outcome:** The Jigsaw task is very challenging – the network has to learn the global context of the image. For example, if one patch has a bit of a dog’s face and another has a tail, the network should learn that the face patch likely belongs above the tail patch in the original image. It encourages learning about object parts and their configurations.

In practice:
- Features learned from jigsaw puzzles were useful for classification and detection. The paper showed improvements on PASCAL VOC detection after using jigsaw pretraining.
- The network likely learns to recognize salient objects and textures because aligning patches requires knowing that “this patch has part of a wheel, it should be bottom of a car” etc.

We can attempt a simplified jigsaw implementation, but doing a full 1000-way classification is complex for a quick demo. Instead, let's conceptually illustrate a simpler version:

We will do a 2x2 puzzle (4 pieces) for demonstration. That yields only 24 possible permutations (4! = 24), which is manageable to brute-force classify in a demo.

This is a toy puzzle: the network will take 4 patch embeddings and output a permutation prediction (0-23).


In [None]:
import itertools

# Prepare all permutations for 2x2 puzzle
permutations = list(itertools.permutations(range(4)))

# Network for 2x2 jigsaw (shared CNN for patches + permutation classifier)
class JigsawNet2x2(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_net = nn.Sequential(  # small conv for patch feature
            nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64*4*4, 128), nn.ReLU()  # two stride-2 convs reduce a 16x16 patch to 4x4
        )
        self.classifier = nn.Linear(128*4, len(permutations))  # combine 4 patch features
    def forward(self, patches):
        # patches shape: (batch, 4, 3, 16, 16)
        B = patches.size(0)
        patches = patches.view(B*4, 3, 16, 16)
        feat = self.patch_net(patches)          # shape (B*4, 128)
        feat = feat.view(B, 4*128)             # concatenate features
        out = self.classifier(feat)           # shape (B, 24) logits for each permutation
        return out

# Function to create 2x2 jigsaw puzzle from image
def make_puzzle_2x2(img):
    # img expected size 32x32
    # Split into four 16x16 patches
    patches = []
    # top-left
    patches.append(TF.crop(img, 0, 0, 16, 16))
    # top-right
    patches.append(TF.crop(img, 0, 16, 16, 16))
    # bottom-left
    patches.append(TF.crop(img, 16, 0, 16, 16))
    # bottom-right
    patches.append(TF.crop(img, 16, 16, 16, 16))
    # Choose a random permutation
    perm_idx = torch.randint(len(permutations), (1,)).item()
    perm = permutations[perm_idx]
    shuffled_patches = [patches[i] for i in perm]
    return torch.stack(shuffled_patches), perm_idx

# Example usage:
img, _ = train_dataset[0]
puzzle, perm_label = make_puzzle_2x2(img)
print("True permutation:", permutations[perm_label])
# Visualize puzzle
import matplotlib.pyplot as plt
plt.figure(figsize=(2,2))
for i, patch in enumerate(puzzle):
    plt.subplot(2,2,i+1)
    plt.imshow(patch.permute(1,2,0))
    plt.axis('off')
plt.suptitle("Shuffled 2x2 Puzzle")
plt.show()


We could train `JigsawNet2x2` similarly to how we trained `RotationNet`, feeding in puzzles and labels and using cross-entropy. Given the limited scope, we'll not run a full training here. But if we did, the network would learn to predict the correct arrangement out of 24 possibilities.
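
For readers who want to try it, the following is a self-contained sketch of what such a training loop might look like. `TinyJigsawNet` and `make_puzzle_batch` are hypothetical names that mirror the 2x2 setup above, and the random tensor is a stand-in (assumption) for a real CIFAR-10 batch.

```python
import itertools
import torch
import torch.nn as nn

PERMS = list(itertools.permutations(range(4)))  # 24 permutation classes

class TinyJigsawNet(nn.Module):
    """Shared per-patch encoder followed by a permutation classifier."""
    def __init__(self):
        super().__init__()
        self.patch_net = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),   # 16x16 -> 8x8
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),  # 8x8 -> 4x4
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 128), nn.ReLU(),
        )
        self.classifier = nn.Linear(128 * 4, len(PERMS))

    def forward(self, patches):  # patches: (B, 4, 3, 16, 16)
        B = patches.size(0)
        feats = self.patch_net(patches.reshape(B * 4, 3, 16, 16))
        return self.classifier(feats.reshape(B, 4 * 128))

def make_puzzle_batch(images):
    """Cut (B, 3, 32, 32) images into 2x2 tiles and shuffle each with a random permutation."""
    # Tile order: top-left, top-right, bottom-left, bottom-right
    tiles = torch.stack([images[:, :, r:r + 16, c:c + 16]
                         for r in (0, 16) for c in (0, 16)], dim=1)  # (B, 4, 3, 16, 16)
    labels = torch.randint(len(PERMS), (images.size(0),))
    order = torch.tensor([PERMS[i] for i in labels.tolist()])        # (B, 4)
    # shuffled[b, j] = tiles[b, order[b, j]]
    shuffled = tiles[torch.arange(images.size(0)).unsqueeze(1), order]
    return shuffled, labels

model = TinyJigsawNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

images = torch.rand(64, 3, 32, 32)  # stand-in for a CIFAR-10 batch
for step in range(5):
    puzzles, labels = make_puzzle_batch(images)
    loss = criterion(model(puzzles), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss = {loss.item():.3f}")
```

On random noise the loss should stay near ln 24 ≈ 3.18; with real images and more steps it would be expected to drop as the network picks up on which content belongs where.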

**Scaling to 3x3 puzzles:** The real jigsaw task uses 9 pieces and a subset of 1000 permutations. That task is considerably harder: the network must reason about how parts fit together at a higher level, relating content across patches (e.g., a fence or a stripe that continues from one patch into its neighbor). This encourages a form of contextual understanding.


## Image Colorization

Colorization as a self-supervised task was explored by Larsson et al. 2016 and Zhang et al. 2016. Here, the pretext task is: take a grayscale image and predict the color (the output could be the chromatic channels or full color image). The network is trained on millions of color photos, but during training it only receives the grayscale version as input; the target is the original color image. Essentially, the model learns to add color to black-and-white images.

Why is this meaningful? Consider what it takes to colorize an image:
- The model needs to infer what objects or materials are present to color them realistically (grass is likely green, sky blue, oranges are orange, etc.).
- It also must capture contextual cues: e.g., a football field vs a dry field might have different shades of green/brown.
  
Some ambiguity is involved (a shirt could be any color), but models address this by predicting a distribution of colors or using classification in color space.

Zhang et al. ("Colorful Image Colorization") treated colorization as a classification problem in Lab color space:
- They use the L (lightness) channel as input, and predict a and b color channels.
- Instead of regressing raw color values (which can lead to dull averages), they quantize the ab color space into 313 bins and have the network output a probability distribution over those bins for each pixel. The loss is a multinomial cross-entropy with class rebalancing, which up-weights rare, saturated colors to encourage vibrant outputs.
- At test time, they decode a single color per pixel from the predicted distribution: either the most likely bin or an annealed mean, which interpolates between the mode (vibrant but sometimes spatially inconsistent) and the mean (consistent but desaturated).
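
The decoding step in the last bullet can be sketched as follows. The annealed mean re-sharpens the predicted distribution with a temperature and then takes its expectation over the bin centers. This is a minimal illustration, not the authors' code: `annealed_mean` is a hypothetical helper, the default temperature of 0.38 follows the value reported in the paper, and the regular bin grid is a simplification (the paper keeps only the 313 in-gamut bins).

```python
import torch

def annealed_mean(logits, bin_centers, temperature=0.38):
    """Decode per-pixel ab values from class logits.

    logits:      (B, Q, H, W) unnormalized scores over Q color bins
    bin_centers: (Q, 2) ab value at the center of each bin
    temperature: 1.0 gives the plain mean, values near 0 approach the mode
    """
    log_probs = torch.log_softmax(logits, dim=1)
    probs = torch.softmax(log_probs / temperature, dim=1)  # re-sharpened distribution
    # Expected ab value under the sharpened distribution -> (B, 2, H, W)
    return torch.einsum("bqhw,qc->bchw", probs, bin_centers)

# Hypothetical bin grid: a regular 10-unit grid over ab in [-110, 110]
axis = torch.arange(-110.0, 111.0, 10.0)
bin_centers = torch.cartesian_prod(axis, axis)  # (529, 2)

logits = torch.randn(1, bin_centers.size(0), 4, 4)  # stand-in network output
ab_mean = annealed_mean(logits, bin_centers, temperature=1.0)
ab_annealed = annealed_mean(logits, bin_centers)
print(ab_mean.shape, ab_annealed.shape)
```

Lowering the temperature moves the output toward the single most likely bin, which trades smoothness for saturation.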

The colorization network often has an encoder-decoder architecture (to produce an output map the same size as input). By learning to colorize, the encoder part of the network learns rich features about objects – because to color them properly, it must recognize what they are:
- If the input is a gray photo of a banana, the network likely outputs yellow tints for that banana region; to do that it must have some representation of "banana-ness".
- If the input is a gray sky, it knows skies are usually blue or gray; a lawn is green, etc.

Colorization was shown to be a strong supervisory signal. Zhang et al. reported that their colorization-pretrained model, when used as a feature extractor, performed well on classification and detection tasks (not as well as supervised pre-training, but impressively close for an unsupervised method of that time).

**Visualization:** Colorization networks produce impressive visual results – turning old black & white photos to color. But keep in mind, for self-supervision, we don't actually need perfect photo-realistic colorization; we only care that the network learned good internal representations. Sometimes the output might be nonsensical (because multiple color assignments are plausible), but as long as the process forced the network to learn structure (e.g., outline of objects, texture details), it succeeded as a pretext.

Let's do a quick code sketch using a pre-trained colorization model from OpenCV to see how colorization works (OpenCV ships a deep learning colorization sample built on Richard Zhang's model). This assumes the model files are available locally; if not, treat the cell as a conceptual outline rather than something to execute.


In [None]:
import cv2
import numpy as np

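# Pre-trained colorization model (Zhang et al.), set up as in OpenCV's colorization sample.
# The prototxt, caffemodel, and pts_in_hull.npy files must be downloaded beforehand.
net = cv2.dnn.readNetFromCaffe("colorization_deploy_v2.prototxt", "colorization_release_v2.caffemodel")
pts = np.load("pts_in_hull.npy").transpose().reshape(2, 313, 1, 1)  # centers of the 313 ab bins
net.getLayer(net.getLayerId("class8_ab")).blobs = [pts.astype(np.float32)]
net.getLayer(net.getLayerId("conv8_313_rh")).blobs = [np.full([1, 313], 2.606, dtype=np.float32)]
# The network takes only the L (lightness) channel, resized to 224x224 and shifted by -50.
bgr = cv2.imread('grayscale_image.jpg').astype(np.float32) / 255.0  # imread returns 3-channel BGR
L = cv2.cvtColor(bgr, cv2.COLOR_BGR2Lab)[:, :, 0]                   # L in [0, 100]
blob = cv2.dnn.blobFromImage(cv2.resize(L, (224, 224)) - 50)
net.setInput(blob)
output = net.forward()[0].transpose(1, 2, 0)  # predicted ab channels at low resolution
# Post-process: upsample ab, recombine with the original L channel, convert back to BGR
ab = cv2.resize(output, (bgr.shape[1], bgr.shape[0]))
lab = np.concatenate([L[:, :, np.newaxis], ab], axis=2)
colorized = np.clip(cv2.cvtColor(lab, cv2.COLOR_Lab2BGR), 0, 1)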


## Summary and Usage

Each of these tasks – rotation, jigsaw, colorization – can be used to pre-train a model on unlabeled images. Typically, one would:
- Set up the chosen pretext task and train the network until it solves that task well (convergence of loss).
- Then either directly use the learned features or fine-tune the network on the actual task of interest.
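
The two options in the second step can be sketched like this. The small `encoder` below is a stand-in (assumption) for whatever network was trained on the pretext task, and the random tensors stand in for a labeled downstream batch.

```python
import torch
import torch.nn as nn

# Stand-in for a pretext-trained encoder: any module mapping images to feature vectors
encoder = nn.Sequential(
    nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),  # -> (B, 32)
)
before = encoder[0].weight.clone()  # snapshot to confirm the encoder stays fixed

# Option 1: linear probe. Freeze the encoder and train only a linear classifier on top.
for p in encoder.parameters():
    p.requires_grad = False
encoder.eval()

probe = nn.Linear(32, 10)  # 10 downstream classes, e.g. CIFAR-10
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

images = torch.rand(64, 3, 32, 32)      # stand-in labeled batch
targets = torch.randint(0, 10, (64,))
for step in range(5):
    with torch.no_grad():
        feats = encoder(images)         # features are fixed
    loss = criterion(probe(feats), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Option 2: fine-tuning. Unfreeze the encoder and optimize everything,
# usually with a smaller learning rate than for the probe.
for p in encoder.parameters():
    p.requires_grad = True
encoder.train()
finetune_opt = torch.optim.Adam(
    list(encoder.parameters()) + list(probe.parameters()), lr=1e-4
)
```

A linear probe is the cheaper of the two and is a common way to compare representations, since the classifier cannot compensate for weak features.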

In practice, these tasks can also be combined. For example, one could train a network to simultaneously predict rotation and colorize an image (multi-task self-supervision), hoping to learn an even stronger representation. Some works added rotation prediction heads to contrastive learning models (as an auxiliary loss) to boost performance. While these pretext tasks have been somewhat overtaken in recent years by contrastive and transformer-based methods (which reach higher downstream accuracy), they are still very relevant:
- They are simple and lightweight to implement (no need for large batches or tricky losses).
- They provide intuition about what features are learned (e.g., one can inspect a rotation model and see neurons firing for upright vs inverted patterns).
- They might be useful in scenarios where contrastive learning is hard to apply (e.g., small datasets), or can serve as a quick initialization.

In the next notebook, we'll shift focus to masked modeling, which is a pretext task paradigm that has revolutionized NLP (with BERT) and is now making waves in vision (with Masked Autoencoders). This will connect the ideas from this notebook (predicting missing information) with powerful modern architectures.

**Bonus Exercise:** Design your own jigsaw puzzle variant: for instance, a 5x5 puzzle but have the network predict the relative location of each patch (a series of pairwise position classifications). How might that be set up? Alternatively, think about video jigsaw – shuffling frames in a video and predicting the correct order. What would a model have to learn in order to solve a video jigsaw puzzle?

## References
- Gidaris, S., Singh, P., Komodakis, N. (2018). "Unsupervised Representation Learning by Predicting Image Rotations." ICLR. – Rotation prediction task introduced.
- Noroozi, M., Favaro, P. (2016). "Unsupervised Learning of Visual Representations by Solving Jigsaw Puzzles." ECCV. – Jigsaw puzzle task for SSL.
- Zhang, R., Isola, P., Efros, A.A. (2016). "Colorful Image Colorization." ECCV. – Treats colorization as self-supervision, with a class-balanced loss for vibrant outputs.
- Larsson, G. et al. (2016). "Learning Representations for Automatic Colorization." ECCV. – Another approach to colorization-based SSL.
- Doersch, C., Gupta, A., Efros, A.A. (2015). "Unsupervised Visual Representation Learning by Context Prediction." ICCV. – Early work predicting relative patch positions (the precursor to jigsaw idea).
- Pathak, D. et al. (2016). "Context Encoders: Feature Learning by Inpainting." CVPR. – Another pretext task: image inpainting (predict missing central patch), which is akin to colorization in spirit.
- Reed, J. et al. (2019). "Self-supervised Learning for Video: Sequence Ordering." – Extends the puzzle idea to video frames (temporal order).
- Jenni, S., Favaro, P. (2018). "Self-Supervised Feature Learning by Learning to Spot Artifacts." CVPR. – A creative task where artifacts (like inpainting with noise) are introduced and the network must detect them; another form of self-supervision.
