# JigSaw pretext task
Follows [the original Jigsaw paper](https://arxiv.org/pdf/1603.09246).<br>
[The FAIR paper](https://arxiv.org/pdf/1905.01235) (page 12) is also useful for implementation details.<br>
Adapted to use ResNet18 instead of the paper's context-free network (CFN).

TODO
- Organize the JigSaw utils into separate .py modules
- Add ViT
- Change the training dataset to one with a resolution of ~255x255 to avoid the need to upscale the data
- Add the evaluations: take the (pretrained) ResNet module and plug it into another module with a clean classification head
    - linear probing
    - full fine-tuning
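
The evaluation item above could look roughly like the following. This is a minimal sketch, not the notebook's implementation: `ProbeClassifier` and its `freeze_backbone` flag are hypothetical names, and it assumes the pretrained backbone returns a flat feature vector per image (512-dimensional for ResNet18 once its `fc` layer is replaced by `nn.Identity()`).

```python
import torch
import torch.nn as nn


class ProbeClassifier(nn.Module):
    """Hypothetical wrapper: a pretrained backbone plus a fresh linear head."""

    def __init__(self, backbone, feature_dim, n_classes, freeze_backbone=True):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(feature_dim, n_classes)
        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            # Linear probing: only the new head receives gradients
            for param in self.backbone.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super().train(mode)
        if self.freeze_backbone:
            # Keep BatchNorm statistics of the frozen backbone fixed
            self.backbone.eval()
        return self

    def forward(self, x):
        features = self.backbone(x)  # Shape: [batch_size, feature_dim]
        return self.head(features)   # Shape: [batch_size, n_classes]
```

With `freeze_backbone=True` this corresponds to linear probing, and with `freeze_backbone=False` to full fine-tuning. In this notebook the backbone would presumably be `model.resnet` with `feature_dim=512`, and the optimizer would receive only the parameters that still have `requires_grad=True`.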

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms, models
from PIL import Image
import numpy as np
import os
import random
from tqdm import tqdm
from datasets import load_dataset

In [2]:
# Set random seeds for reproducibility (TODO: maybe use PyTorch Lightning for this)
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x14c6c03dedb0>

## Load Tiny-ImageNet from HF

In [3]:
tinyImageNet_dataset = load_dataset("zh-plus/tiny-imagenet")
# It can also be downloaded from http://cs231n.stanford.edu/tiny-imagenet-200.zip, but the HF version is easier to use

In [4]:
tinyImageNet_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 100000
    })
    valid: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

## Prepare dataset

In [5]:
def generate_permutations(n_permutations, n_tiles):
    """
    Generates a list of unique random permutations of the tile indices.
    The index of each permutation in this list serves as a ground-truth label.
    """
    permutations = []
    seen = set()
    while len(permutations) < n_permutations:
        perm = tuple(np.random.permutation(n_tiles))
        if perm not in seen:
            permutations.append(perm)
            seen.add(perm)
    return permutations
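
`generate_permutations` samples the permutations uniformly at random. The original paper discusses how the Hamming distance between the chosen permutations affects the difficulty of the task, so it can be useful to measure that distance for a sampled set. The sketch below is one way to do so; `hamming_distance` and `min_pairwise_hamming` are hypothetical helper names that are not part of the notebook.

```python
from itertools import combinations


def hamming_distance(perm_a, perm_b):
    """Number of positions at which two permutations differ."""
    return sum(a != b for a, b in zip(perm_a, perm_b))


def min_pairwise_hamming(permutations):
    """Smallest Hamming distance over all pairs in a permutation set."""
    return min(hamming_distance(a, b) for a, b in combinations(permutations, 2))


# Toy example with 3 tiles: the three cyclic shifts differ in every position
toy_set = [(0, 1, 2), (1, 2, 0), (2, 0, 1)]
print(min_pairwise_hamming(toy_set))
```

Two distinct permutations always differ in at least 2 positions, so a minimum of 2 means the set contains a pair that differs by a single swap of two tiles.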

In [6]:
# Generate n_permutations permutations of 9 tiles. Both values control the difficulty of the task and should be updated as necessary.
n_permutations = 1000  # effectively the number of classes
n_tiles = 9
permutations = generate_permutations(n_permutations, n_tiles)

# Enumerate and store permutations
permutations_dict = {i: perm for i, perm in enumerate(permutations)} 
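
As a reference point for these settings, the short calculation below shows how many orderings of 9 tiles exist and what a model that guesses uniformly over the `n_permutations` classes would score. It is plain arithmetic and repeats the values used above so that it runs on its own.

```python
import math

n_tiles = 9
n_permutations = 1000

# Number of possible orderings of the tiles (9!)
total_orderings = math.factorial(n_tiles)

# A uniform guess over the classes gives accuracy 1/N and cross-entropy ln(N)
chance_accuracy = 1.0 / n_permutations
chance_loss = math.log(n_permutations)

print(f'Possible orderings: {total_orderings}')            # 362880
print(f'Chance accuracy: {100 * chance_accuracy:.2f}%')    # 0.10%
print(f'Chance cross-entropy loss: {chance_loss:.4f}')     # 6.9078
```

A training loss near 6.91 therefore means the model has not yet learned anything about the permutations.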

In [7]:
permutations[:3] 
print('These indices will basically be the possible "gold" labels :p')

These indices will basically be the possible "gold" labels :p


In [14]:
class JigsawPuzzleDataset(data.Dataset):
    def __init__(self, hf_dataset, permutations, transform=None):
        """
        Input:
            hf_dataset: HuggingFace Dataset object.
            permutations: List of permutations.
            transform: Optional transform to be applied on a sample.
        """
        self.dataset = hf_dataset
        self.permutations = permutations
        self.n_permutations = len(permutations)
        self.n_tiles = 9  # 3x3 grid -> this can be modified later
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Load image from the HuggingFace dataset and convert to RGB
        image = self.dataset[idx]['image'].convert('RGB')  # Ensure image is in RGB

        # Resize to 255x255 so that the image splits evenly into a 3x3 grid of 85x85 cells
        image = image.resize((255, 255))

        # Divide the image into 3x3 grid of tiles (85x85 pixels each)
        tiles = []
        tile_size = 85 # 85 * 3 = 255

        # Iterate over the grid cells and create the patches.
        # Original paper explanation: "We randomly crop a 225 × 225 pixel window from an image
        # (red dashed box), divide it into a 3 × 3 grid, and randomly pick a 64 × 64 pixel tiles
        # from each 75 × 75 pixel cell." Here the cells are 85 × 85 because the whole image
        # is resized to 255 × 255.
        for i in range(3):
            for j in range(3):

                # Get boundaries and crop
                left = j * tile_size
                upper = i * tile_size
                right = left + tile_size
                lower = upper + tile_size
                tile = image.crop((left, upper, right, lower))
                
                # Random crop of 64x64 pixels with random shifts
                shift_max = tile_size - 64  # Max shift to introduce randomness
                left_shift = random.randint(0, shift_max)
                upper_shift = random.randint(0, shift_max)
                tile = tile.crop((left_shift, upper_shift, left_shift + 64, upper_shift + 64))
                
                # Apply any transform passed as argument
                if self.transform is not None:
                    tile = self.transform(tile)
                    
                tiles.append(tile)
                
        # Select a random permutation from the pre-computed permutations
        perm_idx = random.randint(0, self.n_permutations - 1)
        perm = self.permutations[perm_idx]

        # Shuffle the tiles according to the permutation
        shuffled_tiles = [tiles[p] for p in perm]

        # Stack tiles into a tensor
        tiles_tensor = torch.stack(shuffled_tiles, dim=0)  # Shape: [9, 3, 64, 64]

        # Return the shuffled tiles and the permutation index, which is the label the model should predict
        return tiles_tensor, perm_idx
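
The shuffling step in `__getitem__` places original tile `perm[k]` at position `k`. When inspecting samples it helps to be able to undo that step. The sketch below shows the inverse on plain Python lists; `shuffle_tiles`, `invert_permutation` and `unshuffle_tiles` are hypothetical helpers, not part of the dataset class.

```python
def shuffle_tiles(tiles, perm):
    """Same rule as the dataset: position k receives original tile perm[k]."""
    return [tiles[p] for p in perm]


def invert_permutation(perm):
    """inverse[j] is the shuffled position that holds original tile j."""
    inverse = [0] * len(perm)
    for new_pos, original_pos in enumerate(perm):
        inverse[original_pos] = new_pos
    return inverse


def unshuffle_tiles(shuffled, perm):
    """Restore the original tile order from a shuffled list."""
    inverse = invert_permutation(perm)
    return [shuffled[i] for i in inverse]


tiles = list('abcdefghi')            # stand-ins for the 9 image tiles
perm = (3, 0, 8, 1, 2, 7, 6, 5, 4)
shuffled = shuffle_tiles(tiles, perm)
restored = unshuffle_tiles(shuffled, perm)
```

The same index logic would apply to the `[9, 3, 64, 64]` tensor by indexing its first dimension with the inverse permutation.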


In [15]:
transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
])

## Define the torch module

In [16]:
class JigsawNet(nn.Module):
    def __init__(self, 
                 n_permutations,
                 architecture = 'resnet', # 'resnet' or 'vit'
                ):
        
        super(JigsawNet, self).__init__()

        if architecture == 'resnet':
            # ResNet backbone. TODO: replace with ResNet50
            # models.resnet18(pretrained=False) is the older, deprecated form of this call
            self.resnet = models.resnet18()
            self.resnet.fc = nn.Identity()  # Remove the classification layer
        elif architecture == 'vit':
            raise NotImplementedError("ViT backbone is not implemented yet")  # TODO
        else:
            raise ValueError(f"Unknown architecture: {architecture!r}")

        # Fully connected head, to be discarded after the pretext task
        self.fc = nn.Sequential(
            nn.Linear(512 * 9, 4096),  # each tile generates a 512-dimensional vector
            nn.ReLU(),
            nn.Linear(4096, n_permutations)
        )

    def forward(self, x):
        # x shape: [batch_size, 9, 3, 64, 64]
        batch_size = x.size(0)

        # Merge the batch and tile dimensions so that all tiles go through the same backbone weights at once (siamese network)
        x = x.view(batch_size * 9, 3, 64, 64)
        features = self.resnet(x)  # Shape: [batch_size * 9, 512]

        # Concatenate the per-tile features before the linear layers, which learn the relations between tiles
        features = features.view(batch_size, 9 * 512)  # Shape: [batch_size, 9 * 512]

        # Predict the permutation index
        out = self.fc(features)  # Shape: [batch_size, n_permutations]
        return out
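
To get a feel for the size of the fully connected head defined in `self.fc`, the short calculation below counts its parameters. It is plain arithmetic on the layer sizes used above; `linear_params` is a hypothetical helper.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


feature_dim = 512       # per-tile feature size from the ResNet18 backbone
n_tiles = 9
hidden_dim = 4096
n_permutations = 1000

first_layer = linear_params(feature_dim * n_tiles, hidden_dim)
second_layer = linear_params(hidden_dim, n_permutations)
head_params = first_layer + second_layer

print(f'First layer:  {first_layer:,}')
print(f'Second layer: {second_layer:,}')
print(f'Head total:   {head_params:,}')
```

The head comes to roughly 23 million parameters, which is more than a ResNet18 backbone (on the order of 11 million), even though the head is discarded after the pretext task.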

## Training
The hyperparameters below should be adapted as needed.

In [18]:
model = JigsawNet(n_permutations=n_permutations)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [19]:
# Create the datasets and dataloaders
train_dataset = JigsawPuzzleDataset(tinyImageNet_dataset['train'], permutations, transform=transform)
valid_dataset = JigsawPuzzleDataset(tinyImageNet_dataset['valid'], permutations, transform=transform)

train_loader = data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
valid_loader = data.DataLoader(valid_dataset, batch_size=128, shuffle=False, num_workers=4)
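
The number of batches per epoch follows directly from the dataset sizes and the batch size. The sketch below works it out for the splits loaded above; `batches_per_epoch` is a hypothetical helper that assumes the last, smaller batch is kept, as it is by default in `DataLoader`.

```python
import math


def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches a loader yields for one pass over the data."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


train_batches = batches_per_epoch(100000, 128)   # 782
valid_batches = batches_per_epoch(10000, 128)    # 79
last_train_batch = 100000 - (train_batches - 1) * 128  # 32 samples in the final batch

print(train_batches, valid_batches, last_train_batch)
```

These counts match the totals that `tqdm` reports for the training and validation loops.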



In [20]:
print('Shape of dataset output: {}'.format(next(iter(train_loader))[0].shape))

Shape of dataset output: torch.Size([128, 9, 3, 64, 64])


In [21]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
num_epochs = 10  
log_each = 100

In [None]:
for epoch in range(num_epochs):
    model.train()
    avg_loss = 0.0
    for batch_idx, (tiles, perm_idx) in enumerate(tqdm(train_loader)):
        tiles = tiles.to(device)  # Shape: [batch_size, 9, 3, 64, 64]
        perm_idx = perm_idx.to(device)  # Shape: [batch_size]

        optimizer.zero_grad()
        outputs = model(tiles)  # Shape: [batch_size, n_permutations]
        loss = criterion(outputs, perm_idx)
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()
        if batch_idx % log_each == log_each - 1:
            # Report the mean loss over the last log_each batches
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {avg_loss / log_each:.4f}')
            avg_loss = 0.0
            
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for tiles, perm_idx in tqdm(valid_loader):
            tiles = tiles.to(device)
            perm_idx = perm_idx.to(device)

            outputs = model(tiles)
            loss = criterion(outputs, perm_idx)
            val_loss += loss.item()

            predicted = outputs.argmax(dim=1)  # index of the most likely permutation
            total += perm_idx.size(0)
            correct += (predicted == perm_idx).sum().item()

    val_accuracy = 100 * correct / total
    avg_val_loss = val_loss / len(valid_loader)
    print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')


Epoch [1/10], Batch [100], Loss: 6.9271
Epoch [1/10], Batch [200], Loss: 6.9242
Epoch [1/10], Batch [300], Loss: 6.9202
Epoch [1/10], Batch [400], Loss: 6.9152
Epoch [1/10], Batch [500], Loss: 6.9097
Epoch [1/10], Batch [600], Loss: 6.9106
Epoch [1/10], Batch [700], Loss: 6.9027


100%|█████████████████████████████████████████| 782/782 [08:17<00:00,  1.57it/s]
100%|███████████████████████████████████████████| 79/79 [00:50<00:00,  1.57it/s]


Validation Loss: 6.8944, Validation Accuracy: 0.27%


Epoch [2/10], Batch [100], Loss: 6.8868
Epoch [2/10], Batch [200], Loss: 6.8753
Epoch [2/10], Batch [300], Loss: 6.8540
Epoch [2/10], Batch [400], Loss: 6.8387
Epoch [2/10], Batch [500], Loss: 6.8103
Epoch [2/10], Batch [600], Loss: 6.7797
Epoch [2/10], Batch [700], Loss: 6.7403


100%|█████████████████████████████████████████| 782/782 [08:09<00:00,  1.60it/s]
100%|███████████████████████████████████████████| 79/79 [00:49<00:00,  1.58it/s]


Validation Loss: 6.6849, Validation Accuracy: 1.14%


Epoch [3/10], Batch [100], Loss: 6.6693
Epoch [3/10], Batch [200], Loss: 6.6247
Epoch [3/10], Batch [300], Loss: 6.5805
Epoch [3/10], Batch [400], Loss: 6.5319
Epoch [3/10], Batch [500], Loss: 6.4804
Epoch [3/10], Batch [600], Loss: 6.4112
Epoch [3/10], Batch [700], Loss: 6.3760


100%|█████████████████████████████████████████| 782/782 [08:10<00:00,  1.59it/s]
100%|███████████████████████████████████████████| 79/79 [00:50<00:00,  1.57it/s]


Validation Loss: 6.3848, Validation Accuracy: 1.20%


Epoch [4/10], Batch [100], Loss: 6.2868
Epoch [4/10], Batch [200], Loss: 6.2685
Epoch [4/10], Batch [300], Loss: 6.2200
Epoch [4/10], Batch [400], Loss: 6.1812
Epoch [4/10], Batch [500], Loss: 6.1706
Epoch [4/10], Batch [600], Loss: 6.1220
Epoch [4/10], Batch [700], Loss: 6.0830


100%|█████████████████████████████████████████| 782/782 [08:10<00:00,  1.59it/s]
100%|███████████████████████████████████████████| 79/79 [00:50<00:00,  1.57it/s]


Validation Loss: 6.1697, Validation Accuracy: 1.93%


Epoch [5/10], Batch [100], Loss: 6.0083
Epoch [5/10], Batch [200], Loss: 6.0161
Epoch [5/10], Batch [300], Loss: 5.9420
Epoch [5/10], Batch [400], Loss: 5.9242
Epoch [5/10], Batch [500], Loss: 5.8773
Epoch [5/10], Batch [600], Loss: 5.8557
Epoch [5/10], Batch [700], Loss: 5.7754


100%|█████████████████████████████████████████| 782/782 [08:10<00:00,  1.59it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:50<00:00,  1.56it/s]


Validation Loss: 5.6291, Validation Accuracy: 3.43%


Epoch [6/10], Batch [100], Loss: 5.4171
Epoch [6/10], Batch [200], Loss: 5.0826
Epoch [6/10], Batch [300], Loss: 4.5067
Epoch [6/10], Batch [400], Loss: 4.1271
Epoch [6/10], Batch [500], Loss: 3.8600
Epoch [6/10], Batch [600], Loss: 3.6631
Epoch [6/10], Batch [700], Loss: 3.4043


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [08:10<00:00,  1.59it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:50<00:00,  1.55it/s]


Validation Loss: 3.7824, Validation Accuracy: 18.58%


Epoch [7/10], Batch [100], Loss: 2.8949
Epoch [7/10], Batch [200], Loss: 2.6062
Epoch [7/10], Batch [300], Loss: 2.3179
Epoch [7/10], Batch [400], Loss: 2.1039
Epoch [7/10], Batch [500], Loss: 1.9098
Epoch [7/10], Batch [600], Loss: 1.7420
Epoch [7/10], Batch [700], Loss: 1.6367


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [08:11<00:00,  1.59it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:50<00:00,  1.56it/s]


Validation Loss: 1.7197, Validation Accuracy: 53.48%


Epoch [8/10], Batch [100], Loss: 1.4394
Epoch [8/10], Batch [200], Loss: 1.3619


 33%|██████████████████████████████▎                                                              | 255/782 [02:43<04:30,  1.95it/s]