# JigSaw pretext task
Following [the original paper](https://arxiv.org/pdf/1603.09246) <br>
Also useful to look at [the FAIR paper](https://arxiv.org/pdf/1905.01235) (page 12) for details on the implementation.<br>
Adapted to use ResNet18 instead of CFN.
TODO
- Organize into separate .py modules for JigSaw utils
- Create runner script that can use up to 4 GPUs for faster training.
- Add ViT
- Change the training dataset to one with a resolution of ~255x255 to avoid the need to upscale the data
- Add the evaluations -> basically take the (pretrained) resnet module and plug it into another module w/ a clean classification head 
    - linear probing
    - full ft

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms, models
from PIL import Image
import numpy as np
import os
import random
from tqdm import tqdm
from datasets import load_dataset

In [5]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
# Seed the Python, NumPy and PyTorch global RNGs (in that order) for reproducibility.
# TODO: PyTorch Lightning's seed_everything could replace these calls.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)
del _seed_fn

<torch._C.Generator at 0x152750083750>

## Load Tiny-ImageNet from HF

In [7]:
# Load Tiny-ImageNet from the Hugging Face Hub. The same data is also published as a zip at
# http://cs231n.stanford.edu/tiny-imagenet-200.zip, but the HF loader is more convenient.
tinyImageNet_dataset = load_dataset(path="zh-plus/tiny-imagenet")

In [8]:
# Show the split names, features and row counts of the loaded DatasetDict.
tinyImageNet_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 100000
    })
    valid: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

In [78]:
# REDUCE TRAINING DATA TO K labels
# Keep only K randomly chosen classes so the downstream probes are quick to train.
K = 10

# Sort the unique labels so the sampled subset depends only on the seed below,
# not on set iteration order (an implementation detail).
train_labels = sorted(set(tinyImageNet_dataset["train"]["label"]))
random.seed(42)  # For reproducibility
selected_labels = random.sample(train_labels, K)
_selected_label_set = frozenset(selected_labels)  # O(1) membership test inside the predicate

# input_columns="label" hands the predicate only the label column, so the images
# do not have to be decoded just to be filtered out.
filtered_train = tinyImageNet_dataset["train"].filter(
    lambda label: label in _selected_label_set, input_columns="label"
)
filtered_valid = tinyImageNet_dataset["valid"].filter(
    lambda label: label in _selected_label_set, input_columns="label"
)
filtered_dataset = {
    "train": filtered_train,
    "valid": filtered_valid
}

In [79]:
# Show the row counts of the K-class subset for both splits.
filtered_dataset

{'train': Dataset({
     features: ['image', 'label'],
     num_rows: 5000
 }),
 'valid': Dataset({
     features: ['image', 'label'],
     num_rows: 500
 })}

### Loading pretrained model and extracting features


In [11]:
## aux -> model class
class JigsawNet(nn.Module):
    '''
    Siamese network for the Jigsaw pretext task.

    Every tile is encoded independently by one shared backbone. The per-tile feature vectors are
    concatenated, and a 2-layer MLP outputs one logit per permutation class.

    Parameters:
    - n_permutations : number of permutation classes predicted by the head
    - architecture : backbone type; only 'resnet' (ResNet-18) is implemented, 'vit' is a TODO
    Raises
    - NotImplementedError : for architecture='vit'
    - ValueError : for any other unknown architecture
    '''
    N_TILES = 9  # number of tiles per image expected by the head

    def __init__(self,
                 n_permutations,
                 architecture = 'resnet', # 'resnet' or 'vit'
                ):

        super().__init__()

        if architecture == 'resnet':
            # Backbone ResNet model TODO: replace by ResNet 50
            # models.resnet18() builds a randomly initialised network (`pretrained=False` is deprecated).
            self.resnet = models.resnet18()
            # Width of the pooled backbone features (512 for ResNet-18); read it before dropping the classifier.
            feat_dim = self.resnet.fc.in_features
            self.resnet.fc = nn.Identity()  # Remove the classification layer
        elif architecture == 'vit':
            # Fail here rather than with an AttributeError on self.resnet inside forward().
            raise NotImplementedError("architecture='vit' is not implemented yet")
        else:
            raise ValueError(f"unknown architecture {architecture!r}; expected 'resnet' or 'vit'")

        # Fully connected layers << to dispose after the PTT
        self.fc = nn.Sequential(
            nn.Linear(feat_dim * self.N_TILES, 4096), # one feat_dim vector per tile
            nn.ReLU(),
            nn.Linear(4096, n_permutations)
        )

    def forward(self, x):
        '''
        x : tensor [batch_size, n_tiles, C, H, W] (9 tiles of 3x64x64 in the original setup)
        returns : logits [batch_size, n_permutations]
        '''
        batch_size, n_tiles = x.shape[:2]

        # Fold the tile dimension into the batch so the shared (siamese) backbone sees all tiles
        # in a single pass. reshape (unlike view) also accepts non-contiguous inputs.
        tiles = x.reshape(batch_size * n_tiles, *x.shape[2:])
        features = self.resnet(tiles)  # Shape: [batch_size * n_tiles, feat_dim]

        # Concatenate the per-tile features of each image before the MLP that learns their arrangement.
        features = features.reshape(batch_size, n_tiles * features.shape[-1])  # [batch_size, n_tiles * feat_dim]

        return self.fc(features)  # Shape: [batch_size, n_permutations]

In [12]:
# Rebuild the Jigsaw network only to get its ResNet-18 backbone. The permutation head is discarded,
# so n_permutations does not matter here. Then load the backbone weights saved after pretext training.
model = JigsawNet(n_permutations=1).resnet
pretrained_path = '/home/emilio.villa/nlp_local/cv_ptt/jigsaw_rn18_tinyimnt_resnet.pth'
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only machine;
# load_state_dict copies the values into the freshly built module's parameters anyway.
model.load_state_dict(torch.load(pretrained_path, map_location='cpu', weights_only=True))

<All keys matched successfully>

In [102]:
class ResNetFeatureExtractor(nn.Module):
    '''
    Truncated ResNet used to extract features.

    Keeps the stem (conv1, bn1, relu, maxpool), followed by residual stage k (layer1 ... layer4)
    for every k with N >= k. The result is then average-pooled and flattened from dim 1 onwards.

    Parameters:
    - N : number of residual stages to keep, nominally in [0, 4]
    - pretrained_model : object exposing conv1, bn1, relu, maxpool, layer1..layer4 and avgpool
        (e.g. a torchvision ResNet). Its modules are reused directly (shared, not copied).
        When None, a randomly initialised ResNet-18 is built instead.

    TODO -> update for resnet50
    '''
    def __init__(self,
                 N,
                 pretrained_model = None
                ):

        super().__init__()
        self.N = N

        if pretrained_model is not None:
            # reuse the previously trained resnet
            print('resnet model provided, using it to initialize features')
        else:
            print('resnet not provided, using random initialization')
            pretrained_model = models.resnet18(weights=None)

        stem = [
            pretrained_model.conv1,
            pretrained_model.bn1,
            pretrained_model.relu,
            pretrained_model.maxpool,
        ]
        # Stage k (1-based) is appended only when N >= k; the condition is checked before the lookup.
        stages = [
            getattr(pretrained_model, stage_name)
            for k, stage_name in enumerate(('layer1', 'layer2', 'layer3', 'layer4'), start=1)
            if N >= k
        ]

        self.features = nn.Sequential(*stem, *stages)
        self.avgpool = pretrained_model.avgpool

    def forward(self, x):
        # stem + kept stages -> pooling -> flatten everything but the batch dimension
        return torch.flatten(self.avgpool(self.features(x)), 1)

### Extracting train/test features

In [103]:
# Define the transforms
# Preprocessing for the downstream classification probes: resize so the shorter side is 224 px
# (upscaling the Tiny-ImageNet images), then turn the PIL image into a float tensor.
# NOTE(review): there is no Normalize step -- confirm this matches the preprocessing used
# during Jigsaw pretraining.
classification_transform = transforms.Compose(
    [
        transforms.Resize(224),  # int size -> match the shorter side, keep the aspect ratio
        transforms.ToTensor(),   # PIL image -> float tensor [C, H, W] scaled to [0, 1]
    ]
)

# Define custom dataset class
class ClassificationDataset(data.Dataset):
    '''
    Map-style dataset over a Hugging Face split of (image, label) rows.

    Parameters:
    - hf_dataset : indexable split whose rows are dicts with an 'image' (PIL image) and a 'label' entry
    - transform : optional callable applied to the RGB image
    Items
    - (image, label) : the (transformed) RGB image and its label
    '''
    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the row once. Indexing a Hugging Face dataset decodes the whole row (image included),
        # so two separate lookups for 'image' and 'label' would decode the image twice.
        row = self.dataset[idx]
        # Force 3 channels so grayscale/palette images match the RGB ones.
        image = row['image'].convert('RGB')
        label = row['label']
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [104]:
def extract_features(
    image_dataset,
    feature_extraction_model,
    batch_size = 512,
    device = 'cuda',
    num_workers = 0):
    '''
    Generate numpy arrays with features (and labels) given a feature extraction model.

    Parameters:
    - image_dataset : Dataset yielding (image, label) pairs, e.g. <ClassificationDataset>.
        NOTE: labels are collected as a side product; a dataset that returns only images
        is not supported yet.
    - feature_extraction_model : Module applied to each batch of images. It must already live on
        `device` (only the images are moved here). It is switched to eval mode and left there.
    - batch_size : number of items per forward pass
    - device : device each image batch is moved to before the forward pass
    - num_workers : DataLoader worker processes (default 0 = load in the main process)
    Returns
    - features : numpy array with the model outputs of all items, concatenated along the
        first axis in dataset order
    - labels : numpy array with the matching labels
    '''
    # shuffle=False keeps features and labels aligned with the dataset order
    loader = data.DataLoader(
        image_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    feat_list = []
    y_list = []  # labels are collected alongside the features
    feature_extraction_model.eval()
    with torch.no_grad():
        for images, labels in tqdm(loader):
            outputs = feature_extraction_model(images.to(device))
            # Move each batch to host memory right away, so device memory does not grow
            # with the size of the dataset.
            feat_list.append(outputs.cpu())
            y_list.append(labels)

    return torch.cat(feat_list).numpy(), torch.cat(y_list).cpu().numpy()

In [117]:
# Choose the backbone to probe. Either the Jigsaw-pretrained ResNet truncated after layer1
# (commented out), or, as a baseline, a randomly initialised ResNet-18 keeping all four stages.
# feature_extractor = ResNetFeatureExtractor(N=1, pretrained_model = model).to(device)
feature_extractor = ResNetFeatureExtractor(N=4, pretrained_model=None)
feature_extractor = feature_extractor.to(device)

resnet not provided, using random initialization


In [118]:
# Sanity check: the first module of the extractor is the ResNet stem convolution (conv1).
feature_extractor.features[0]

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [119]:
# Wrap the K-class subsets (not the full Tiny-ImageNet splits) for feature extraction.
# Full-dataset alternative:
# train_classification_dataset = ClassificationDataset(tinyImageNet_dataset['train'], transform=classification_transform)
# val_classification_dataset = ClassificationDataset(tinyImageNet_dataset['valid'], transform=classification_transform)
train_classification_dataset, val_classification_dataset = (
    ClassificationDataset(filtered_dataset[split], transform=classification_transform)
    for split in ('train', 'valid')
)

In [120]:
# Extract features for both splits with the same extractor. Use the `device` chosen at the top of
# the notebook (which falls back to CPU without CUDA); it is also where feature_extractor was moved.
X_train, y_train = extract_features(
    image_dataset = train_classification_dataset,
    feature_extraction_model = feature_extractor,
    batch_size = 1024,
    device = device
)

X_test, y_test = extract_features(
    image_dataset = val_classification_dataset,
    feature_extraction_model = feature_extractor,
    batch_size = 1024,
    device = device
)

100%|██████████| 5/5 [00:06<00:00,  1.39s/it]
100%|██████████| 1/1 [00:00<00:00,  1.49it/s]


In [121]:
# Sanity check: one feature row per label in each split.
print(*(arr.shape for arr in (X_train, y_train)))
print(*(arr.shape for arr in (X_test, y_test)))

(5000, 512) (5000,)
(500, 512) (500,)


### Fitting an SVM classifier on top of these features

In [122]:
import numpy as np
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
from tqdm import tqdm


# --- Linear probe: SVM on the extracted features ---------------------------------------------
# Standardise the features. Statistics come from the training split and are reused for the test split.
# NOTE(review): the scaler is fit on the whole training split *before* cross-validation, so each
# CV validation fold contributed to the statistics it is scaled with. A Pipeline(StandardScaler, SVC)
# inside GridSearchCV would avoid this small leakage.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Search grid. Only the linear kernel is enabled. 'gamma' has no effect on it and is kept so the
# commented-out 'rbf' / 'auto' options can be switched back on.
param_grid = dict(
    C=[0.1, 1, 10],      # regularisation parameter (strength is inversely proportional to C)
    kernel=['linear'],   # , 'rbf'
    gamma=['scale'],     # , 'auto'  -- only used by 'rbf'
)

svm_classifier = SVC(random_state=42)

# 3-fold CV on all CPU cores; verbose=3 makes sklearn print per-fit progress (no tqdm bar is used).
grid_search = GridSearchCV(
    svm_classifier,
    param_grid,
    scoring='accuracy',
    cv=3,
    n_jobs=-1,
    verbose=3,
)

print("Starting Grid Search...")
grid_search.fit(X_train_scaled, y_train)
print("Grid Search Completed.")

# best_estimator_ is the best configuration refit on the whole scaled training split.
best_svm = grid_search.best_estimator_
print(f"Best Parameters: {grid_search.best_params_}")

# Evaluate on the held-out split.
y_pred = best_svm.predict(X_test_scaled)
accuracy = accuracy_score(y_test, y_pred)
report = classification_report(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)

print(f'Accuracy: {accuracy:.4f}')
print('Classification Report:')
print(report)
print('Confusion Matrix:')
print(conf_matrix)

Starting Grid Search...
Fitting 3 folds for each of 3 candidates, totalling 9 fits
Grid Search Completed.
Best Parameters: {'C': 0.1, 'gamma': 'scale', 'kernel': 'linear'}
Accuracy: 0.4480
Classification Report:
              precision    recall  f1-score   support

           6       0.46      0.52      0.49        50
          26       0.30      0.46      0.37        50
          28       0.44      0.54      0.48        50
          35       0.42      0.38      0.40        50
          57       0.37      0.40      0.38        50
          62       0.59      0.44      0.51        50
          70       0.33      0.30      0.32        50
         163       0.55      0.46      0.50        50
         188       0.58      0.56      0.57        50
         189       0.60      0.42      0.49        50

    accuracy                           0.45       500
   macro avg       0.47      0.45      0.45       500
weighted avg       0.47      0.45      0.45       500

Confusion Matrix:
[[26  5  3 

In [None]:
###
class ClassificationModel(nn.Module):
    '''
    Image classifier: a backbone followed by a linear classification head.

    Parameters:
    - num_classes : number of output classes
    - architecture : backbone type; only 'resnet' is implemented, 'vit' is a TODO
    - pretrained_model : optional backbone module, e.g. the ResNet-18 trained on the Jigsaw
        pretext task. A deep copy is used, so training this model never modifies the module
        passed in. When None, a randomly initialised ResNet-18 is built.
    - feature_dim : width of the backbone output fed to the head (512 for ResNet-18)
    Raises
    - NotImplementedError : for architecture='vit'
    - ValueError : for any other unknown architecture
    '''
    def __init__(
        self,
        num_classes,
        architecture = 'resnet', # 'resnet' or 'vit'
        pretrained_model = None,
        feature_dim = 512,
        ):
        import copy

        super().__init__()

        # Fail here rather than with an AttributeError on self.features inside forward().
        if architecture == 'vit':
            raise NotImplementedError("architecture='vit' is not implemented yet")
        if architecture != 'resnet':
            raise ValueError(f"unknown architecture {architecture!r}; expected 'resnet' or 'vit'")

        if pretrained_model is not None:
            # Use the pretrained backbone from the PTT task. Copy it so that, e.g., full fine-tuning
            # does not change the weights another evaluation (linear probing) would start from.
            self.features = copy.deepcopy(pretrained_model)
        else:
            # Backbone ResNet model TODO: replace by ResNet 50
            self.features = models.resnet18()

        # Drop the backbone's own classifier (if it has one) so it outputs pooled features.
        # The Jigsaw backbone already has fc = Identity; a fresh torchvision ResNet does not.
        if hasattr(self.features, 'fc'):
            self.features.fc = nn.Identity()

        # Classification layer
        self.linear_proj = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        # x shape: [batch_size, 3, H, W]
        features = self.features(x)  # Shape: [batch_size, feature_dim]
        return self.linear_proj(features)  # Shape: [batch_size, num_classes]