# JigSaw pretext task
Following [the original implementation](https://arxiv.org/pdf/1603.09246) <br>
Also useful to look at [the FAIR paper](https://arxiv.org/pdf/1905.01235) (page 12), for details on the implementation.<br>
Adapted to use ResNet18 instead of CFN.
TODO
- Organize into separate .py modules for JigSaw utils
- Create runner script that can use up to 4 GPUs for faster training.
- Add ViT
- Change the training dataset to something with a native resolution of ~255x255 to avoid the need to upscale the data
- Add the evaluations -> basically take the (pretrained) resnet module and plug it into another module w/ a clean classification head 
    - linear probing
    - full fine-tuning

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms, models
from PIL import Image
import numpy as np
import os
import random
from tqdm import tqdm
from datasets import load_dataset

In [2]:
# Seed the Python, NumPy and PyTorch RNGs for reproducibility.
# `random` drives the tile shifts and the permutation choice in the dataset;
# NumPy drives `generate_permutations`.
# TODO: maybe use PyTorch Lightning's seeding utilities instead.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x150d94bcedb0>

## Load Tiny-ImageNet from HF

In [3]:
# Load Tiny-ImageNet through the Hugging Face `datasets` library.
tinyImageNet_dataset = load_dataset("zh-plus/tiny-imagenet")
# It can also be downloaded from http://cs231n.stanford.edu/tiny-imagenet-200.zip, but the Hugging Face copy is easier to use.

In [4]:
# Display the dataset splits and their sizes.
tinyImageNet_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 100000
    })
    valid: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

## Prepare dataset

In [5]:
def generate_permutations(n_permutations, n_tiles):
    """
    Generate ``n_permutations`` distinct random permutations of ``range(n_tiles)``.

    The position of a permutation in the returned list is used as its class
    index, so these are essentially the 'gold truth' labels of the pretext task.

    Input:
        n_permutations: Number of distinct permutations to draw.
        n_tiles: Number of elements in every permutation.

    Returns:
        A list of ``n_permutations`` unique tuples, each holding the Python
        ints ``0 .. n_tiles - 1`` in random order.

    Raises:
        ValueError: If more permutations are requested than exist
            (``n_permutations > n_tiles!``). Without this check the rejection
            loop below would never terminate.
    """
    import math  # local import: the file does not import `math` at the top

    max_permutations = math.factorial(n_tiles)
    if n_permutations > max_permutations:
        raise ValueError(
            f"Cannot draw {n_permutations} distinct permutations of {n_tiles} "
            f"tiles: only {max_permutations} exist."
        )

    permutations = []
    seen = set()
    # Rejection sampling: keep drawing until enough unseen permutations are found.
    while len(permutations) < n_permutations:
        # .tolist() turns NumPy scalars into plain ints (cleaner repr, safe list indices).
        perm = tuple(np.random.permutation(n_tiles).tolist())
        if perm not in seen:
            seen.add(perm)
            permutations.append(perm)
    return permutations

In [6]:
# Generate n_permutations permutations of 9 tiles. Both values control the
# difficulty of the pretext task and should be updated as necessary.
n_permutations = 1000  # essentially the number of classes
n_tiles = 9
permutations = generate_permutations(n_permutations, n_tiles)

# Enumerate and store the permutations: class index -> permutation
permutations_dict = {i: perm for i, perm in enumerate(permutations)} 

In [7]:
# NOTE(review): this slice is not the last expression of the cell, so its value
# is discarded and nothing is displayed for it; only the print below produces output.
permutations[:3] 
print('These indices will basically be the possible "gold" labels :p')

These indices will basically be the possible "gold" labels :p


In [8]:
class JigsawPuzzleDataset(data.Dataset):
    """
    Jigsaw-puzzle pretext-task dataset built on top of an image dataset.

    Every image is resized to 255x255 and split into a 3x3 grid of 85x85
    cells. One randomly placed 64x64 tile is cropped from each cell, the nine
    tiles are reordered with a randomly chosen permutation, and the index of
    that permutation is returned as the label.

    The original paper crops a 225x225 window and picks 64x64 tiles from
    75x75 cells; here the whole image is resized to 255x255 instead, which
    gives 85x85 cells.
    """

    GRID_SIZE = 3   # tiles per side of the puzzle
    CELL_SIZE = 85  # side of one grid cell in pixels (3 * 85 = 255)
    TILE_SIZE = 64  # side of the tile cropped at a random offset inside a cell

    def __init__(self, hf_dataset, permutations, transform=None):
        """
        Input:
            hf_dataset: HuggingFace Dataset object; each item is read as
                ``item['image']`` and must support ``.convert('RGB')``.
            permutations: List of permutations.
            transform: Optional transform to be applied on every tile. If it
                is omitted, or does not return a tensor, the tile is
                converted to a tensor so that the tiles can be stacked.
        """
        self.dataset = hf_dataset
        self.permutations = permutations
        self.n_permutations = len(permutations)
        self.n_tiles = self.GRID_SIZE ** 2  # 3x3 grid -> this can be modified later
        self.transform = transform

    def __len__(self):
        """Return the number of images in the wrapped dataset."""
        return len(self.dataset)

    @staticmethod
    def _as_tensor(tile):
        """Return ``tile`` unchanged if it is a tensor, else convert the image to one."""
        if torch.is_tensor(tile):
            return tile
        # Local import: only needed on this fallback path.
        from torchvision.transforms.functional import to_tensor
        return to_tensor(tile)

    def __getitem__(self, idx):
        """
        Build one shuffled puzzle.

        Returns:
            tiles_tensor: The tiles stacked along dim 0 in permuted order;
                with ``ToTensor`` tiles the shape is [9, 3, 64, 64].
            perm_idx: Index of the permutation that was applied (the gold
                label the model has to predict).
        """
        # Load image from the HuggingFace dataset and make sure it is RGB
        image = self.dataset[idx]['image'].convert('RGB')

        # Resize so that the image splits exactly into the grid of cells
        side = self.GRID_SIZE * self.CELL_SIZE
        image = image.resize((side, side))

        # Largest offset that keeps a tile fully inside its cell
        shift_max = self.CELL_SIZE - self.TILE_SIZE

        tiles = []
        for row in range(self.GRID_SIZE):
            for col in range(self.GRID_SIZE):
                # Cell origin plus a random shift; the horizontal shift is
                # drawn before the vertical one.
                left = col * self.CELL_SIZE + random.randint(0, shift_max)
                upper = row * self.CELL_SIZE + random.randint(0, shift_max)
                tile = image.crop(
                    (left, upper, left + self.TILE_SIZE, upper + self.TILE_SIZE)
                )

                # Apply any transform passed as argument
                if self.transform is not None:
                    tile = self.transform(tile)

                # torch.stack below only accepts tensors
                tiles.append(self._as_tensor(tile))

        # Select a random permutation from the pre-computed permutations
        perm_idx = random.randint(0, self.n_permutations - 1)
        perm = self.permutations[perm_idx]

        # Shuffle the tiles according to the permutation
        shuffled_tiles = [tiles[p] for p in perm]

        # Stack tiles into a single tensor
        tiles_tensor = torch.stack(shuffled_tiles, dim=0)

        return tiles_tensor, perm_idx


In [9]:
# Per-tile pipeline: a light colour jitter followed by conversion to a tensor.
transform = transforms.Compose(
    [
        transforms.ColorJitter(
            brightness=0.1,
            contrast=0.1,
            saturation=0.1,
            hue=0.1,
        ),
        transforms.ToTensor(),
    ]
)

## Define the torch module

In [10]:
class JigsawNet(nn.Module):
    """
    Siamese network for the jigsaw pretext task.

    Every tile goes through the same ResNet-18 backbone (classification layer
    removed); the nine 512-dimensional tile features are concatenated and
    classified into one of ``n_permutations`` classes.
    """

    def __init__(self, 
                 n_permutations,
                 architecture = 'resnet', # 'resnet' or 'vit'
                ):
        """
        Input:
            n_permutations: Number of output classes (one per permutation).
            architecture: Backbone to use. Only 'resnet' is implemented.

        Raises:
            NotImplementedError: If ``architecture`` is 'vit' (still a TODO).
            ValueError: If ``architecture`` is any other unknown value.
        """
        super().__init__()

        if architecture == 'resnet':
            # Backbone ResNet model TODO: replace by ResNet 50
            # Randomly initialised weights (no pretrained checkpoint is loaded)
            self.resnet = models.resnet18()
            self.resnet.fc = nn.Identity()  # Remove the classification layer
        elif architecture == 'vit':
            # Fail at construction time instead of building a model without a
            # backbone that would only break later inside forward().
            raise NotImplementedError("The 'vit' backbone is not implemented yet.")
        else:
            raise ValueError(
                f"Unknown architecture {architecture!r}; expected 'resnet' or 'vit'."
            )

        # Fully connected layers << to dispose after the PTT
        self.fc = nn.Sequential(
            nn.Linear(512 * 9, 4096),  # each tile generates a 512-dimensional vector
            nn.ReLU(),
            nn.Linear(4096, n_permutations)
        )

    def forward(self, x):
        """
        Input:
            x: Tensor of shape [batch_size, 9, 3, H, W] (64x64 tiles in this
               notebook).

        Returns:
            Logits of shape [batch_size, n_permutations].
        """
        batch_size, n_tiles = x.shape[:2]

        # Combine batch and tile dimensions (siamese network -> the same weights
        # see all the patches at once). reshape also accepts non-contiguous input
        # and does not pin the tile resolution.
        x = x.reshape(batch_size * n_tiles, *x.shape[2:])
        features = self.resnet(x)  # Shape: [batch_size * 9, 512]

        # Concatenate the tile features before the linear layers that learn
        # the differences between them
        features = features.reshape(batch_size, -1)  # Shape: [batch_size, 9 * 512]

        out = self.fc(features)  # Shape: [batch_size, n_permutations]
        return out

## Training
The hyperparameters below are initial values and should be tuned.

In [11]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network with one output class per permutation and move it to that device.
model = JigsawNet(n_permutations=n_permutations).to(device)

In [12]:
# Wrap both Tiny-ImageNet splits in the jigsaw dataset; they share the same
# permutation set and the same per-tile transform.
train_dataset = JigsawPuzzleDataset(
    tinyImageNet_dataset['train'],
    permutations,
    transform=transform,
)
valid_dataset = JigsawPuzzleDataset(
    tinyImageNet_dataset['valid'],
    permutations,
    transform=transform,
)

# Batches of 256 puzzles loaded by 4 workers; only the training split is shuffled.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
)
valid_loader = data.DataLoader(
    valid_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=4,
)



In [13]:
# Sanity check: fetch one batch from the training loader and print the shape of its tiles tensor.
print('Shape of dataset output: {}'.format(next(iter(train_loader))[0].shape))

Shape of dataset output: torch.Size([256, 9, 3, 64, 64])


In [17]:
# Loss, optimizer and loop settings for the pretext task.
criterion = nn.CrossEntropyLoss()  # classification over permutation indices
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
num_epochs = 30
log_each = 100  # log the running training loss every `log_each` batches

In [21]:
# Base name of the checkpoint file; the training loop saves to f'{model_name}.pth'.
model_name = 'jigsaw_rn18_tinyimnt'

In [22]:
# Best validation accuracy seen so far. A value left over from a previous
# (e.g. interrupted) run of this cell is kept, so that resuming training does
# not overwrite a better checkpoint; otherwise start from zero.
try:
    best_accuracy
except NameError:
    best_accuracy = 0.0

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    avg_loss = 0.0
    for batch_idx, (tiles, perm_idx) in enumerate(tqdm(train_loader)):
        tiles = tiles.to(device)  # Shape: [batch_size, 9, 3, 64, 64]
        perm_idx = perm_idx.to(device)  # Shape: [batch_size]

        optimizer.zero_grad()
        outputs = model(tiles)  # Shape: [batch_size, n_permutations]
        loss = criterion(outputs, perm_idx)
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()
        if batch_idx % log_each == log_each - 1:
            # Average over the logging window (log_each batches, not a fixed 100)
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {avg_loss / log_each:.4f}')
            avg_loss = 0.0

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for tiles, perm_idx in tqdm(valid_loader):
            tiles = tiles.to(device)
            perm_idx = perm_idx.to(device)

            outputs = model(tiles)
            loss = criterion(outputs, perm_idx)

            # The criterion returns the batch mean; weight it by the batch size
            # so that a smaller last batch does not bias the epoch average.
            batch_size = perm_idx.size(0)
            val_loss += loss.item() * batch_size

            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == perm_idx).sum().item()

    val_accuracy = 100 * correct / total
    avg_val_loss = val_loss / total
    print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

    # Keep only the best model: record the new best before saving, so later
    # (worse) epochs cannot overwrite the checkpoint.
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        print('Saving checkpoint')
        torch.save(model.state_dict(), f'{model_name}.pth')


 26%|███████████████████▋                                                         | 100/391 [02:32<04:49,  1.01it/s]

Epoch [18/30], Batch [100], Loss: 0.8026


 51%|███████████████████████████████████████▍                                     | 200/391 [05:02<03:11,  1.00s/it]

Epoch [18/30], Batch [200], Loss: 0.8180


 65%|██████████████████████████████████████████████████▍                          | 256/391 [06:30<03:25,  1.53s/it]


KeyboardInterrupt: 

In [23]:
# Display the current value of best_accuracy (validation accuracy, in percent).
best_accuracy

65.96