In [1]:
#!/usr/bin/env python3

# This script pulls the feature maps from the specified layer of the CNN for each subject and runs
# dimensionality reduction on them using incremental PCA. It can take a while to run and can be adapted.

import os
# conda
# Limit the number of CPUs used to 2
# os.environ["OMP_NUM_THREADS"] = "1" # For layer 0 and 2 try to limit it to 1, so that there is no multi-threading issue

import sys
import numpy as np
import torch
import torch.nn as nn
import argparse
import joblib
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)
from tqdm import tqdm
from sklearn.decomposition import IncrementalPCA
from torchvision import models
from typing import Dict, Tuple, Union, Optional

# Work from the project root and register the project directory plus the
# rfenv environment's site-packages (and its nsdcode folder) on the import path.
os.chdir("/home/rfpred")
sys.path.extend(
    [
        "/home/rfpred/",
        "/home/rfpred/envs/rfenv/lib/python3.11/site-packages/",
        "/home/rfpred/envs/rfenv/lib/python3.11/site-packages/nsdcode",
    ]
)

# Project-local import; presumably only resolvable after the path setup above.
from classes.natspatpred import NatSpatPred

# Single shared NatSpatPred instance used by the rest of the notebook.
NSP = NatSpatPred()
NSP.initialise()


Naturalistic Spatial Prediction class: [97mInitialised[0m

Class contains the following attributes:
[34m .analyse[0m
[34m .attributes[0m
[34m .cortex[0m
[34m .datafetch[0m
[34m .explore[0m
[34m .hidden_methods[0m
[34m .initialise[0m
[34m .nsd_datapath[0m
[34m .own_datapath[0m
[34m .stimuli[0m
[34m .subjects[0m
[34m .utils[0m


In [5]:
# Run configuration (stand-ins for the script's argparse arguments)
pca_fit_batch = 1_000  # reused further down as the batch size for fitting the PCA
n_comps = 1_000  # reused further down as the fixed number of PCA components
cnn_layer = 0  # index of the CNN graph node to extract (offset by one when looked up)

In [9]:
prf_region = "center_strict"

# Load the ImageNet-pretrained VGG16 model.
# `weights=` replaces torchvision's deprecated `pretrained=True` flag and selects
# the same checkpoint that the legacy flag resolved to.
# Batch-norm variant, if needed:
# model = models.vgg16_bn(weights=models.VGG16_BN_Weights.IMAGENET1K_V1)
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
modeltype = model._get_name()  # model class name; used later to build output paths
model.eval()  # Set the model to evaluation mode




VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [6]:
class ImageDataset(Dataset):
    """Dataset that serves stimulus images, looked up by image number.

    Images are fetched through the notebook-level ``NSP.stimuli.show_stim``
    helper, optionally cropped to a fixed window, converted to PIL and then
    passed through ``transform`` when one is given.
    """

    def __init__(self, image_ids, transform=None, crop: bool = True):
        """Store the lookup ids and the per-image options.

        Args:
            image_ids: Sequence of image numbers; item ``idx`` maps to ``image_ids[idx]``.
            transform: Optional callable applied to the PIL image before it is returned.
            crop: When True, only rows/columns 163:263 of the image are kept.
        """
        self.image_ids = image_ids
        self.transform = transform
        self.crop = crop

    def __len__(self):
        """Number of images this dataset can serve."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return the (optionally cropped and transformed) image at position ``idx``."""
        # First element of show_stim's return value is the image array
        # (it is handed to Image.fromarray below).
        pixels = NSP.stimuli.show_stim(
            img_no=self.image_ids[idx], hide=True, small=True, crop=False
        )[0]

        if self.crop:
            # Keep only the fixed 100x100 window of the image
            pixels = pixels[163:263, 163:263]

        image = Image.fromarray(pixels)  # Convert into PIL from np

        return self.transform(image) if self.transform else image


# Preprocessing pipeline applied to every PIL image before it enters the network.
preprocess = transforms.Compose(
    [
        # Resize the images to 224x224 pixels
        transforms.Resize(size=(224, 224)),
        # Convert the images to a PyTorch tensor
        transforms.ToTensor(),
        # Normalize the images' color channels with per-channel mean and std
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [22]:
# NOTE(review): `dense_ices` is not referenced anywhere else in the visible notebook;
# presumably these are indices of the classifier layers of interest — confirm.
dense_ices = [0, 3, 6,]
# Build an extractor that returns the output of the "classifier.6" graph node.
# NOTE(review): another cell further down also assigns `feature_extractor`, so which
# extractor the rest of the notebook uses depends on the order the cells are run in.
feature_extractor = create_feature_extractor(model, return_nodes=["classifier.6"]) # Here the layer is specified !!!!

# Display the extractor's module structure as the cell output
feature_extractor

VGG(
  (features): Module(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ce

In [7]:

# Names of all graph nodes of the model that can be tapped for features.
train_nodes, _ = get_graph_node_names(model)
print(train_nodes)

# The first node is the network input 'x', hence the +1 offset when indexing by layer.
this_layer = train_nodes[cnn_layer + 1]

# Extractor that returns the output of the selected node only.
# (The layer could also be given by name, e.g. "features.2" or "classifier.6".)
feature_extractor = create_feature_extractor(model, return_nodes=[this_layer])

train_batch = pca_fit_batch  # Number of images the PCA is fitted on
apply_batch = 500  # The image batch over which the fitted PCA is applied later on.
fixed_n_comps = n_comps
crop_imgs = True  # IMPORTANT: this flag is what tags the output filenames
# NOTE(review): the dataset below is built with crop=False while crop_imgs is True,
# so the images are not cropped although the flag says so — check which is intended.

# Fit on the first `train_batch` images, served as one single batch.
# (Subject-specific image indices were used here before.)
image_ids = list(range(train_batch))
dataset = ImageDataset(image_ids, transform=preprocess, crop=False)
dataloader = DataLoader(dataset, shuffle=False, batch_size=train_batch)

['x', 'features.0', 'features.1', 'features.2', 'features.3', 'features.4', 'features.5', 'features.6', 'features.7', 'features.8', 'features.9', 'features.10', 'features.11', 'features.12', 'features.13', 'features.14', 'features.15', 'features.16', 'features.17', 'features.18', 'features.19', 'features.20', 'features.21', 'features.22', 'features.23', 'features.24', 'features.25', 'features.26', 'features.27', 'features.28', 'features.29', 'features.30', 'avgpool', 'flatten', 'classifier.0', 'classifier.1', 'classifier.2', 'classifier.3', 'classifier.4', 'classifier.5', 'classifier.6']


#### Functions for extracting features and fitting the PCA

In [10]:

def extract_features(feature_extractor, dataloader, pca, cnn_layer: int|str):
    while True:  # Keep trying until successful
        try:
            features = []
            for i, d in tqdm(enumerate(dataloader), total=len(dataloader)):

                ft = feature_extractor(d)
                # Flatten the features
                ft = torch.hstack([torch.flatten(l, start_dim=1) for l in ft.values()])

                # Print out some summary statistics of the features
                print(
                    f"AlexNet layer: {cnn_layer}, Mean: {ft.mean()}, Std: {ft.std()}, Min: {ft.min()}, Max: {ft.max()}"
                )

                # Check if the features contain NaN values
                if np.isnan(ft.detach().numpy()).any():
                    raise ValueError("NaN value detected")

                # Check for extreme outliers
                if (ft.detach().numpy() < -100000).any() or (
                    ft.detach().numpy() > 100000
                ).any():
                    raise ValueError("Extreme outlier detected before PCA fit")

                # Apply PCA transform
                ft = pca.transform(ft.cpu().detach().numpy())
                features.append(ft)
            return np.vstack(features)  # Return the features
        except ValueError as e:
            print(f"Error occurred: {e}")
            print("Restarting feature extraction...")



def extract_features_and_check(d, feature_extractor, cnn_layer):
    """Extract the flattened features of one batch, retrying until they look sane.

    Args:
        d: One input batch, passed straight to ``feature_extractor``.
        feature_extractor: Callable returning a mapping whose values are tensors
            with the batch dimension first.
        cnn_layer: Not used inside this function; kept so existing callers work.

    Returns:
        torch.Tensor: 2-D tensor with, per sample, all returned feature maps
        flattened and concatenated.

    Notes:
        Only a failed sanity check (NaN values or values beyond +/-100000)
        triggers a retry; exceptions raised by ``feature_extractor`` itself
        propagate to the caller.
    """
    while True:  # Keep trying until successful
        # Inference only: skip building the autograd graph to save memory
        with torch.no_grad():
            # Extract features
            ft = feature_extractor(d)
            # Flatten the features
            ft = torch.hstack([torch.flatten(l, start_dim=1) for l in ft.values()])

        ft_np = ft.detach().cpu().numpy()

        # Check for NaN values. `.any()` must be applied to the isnan mask:
        # np.isnan(arr.any()) tests a single bool and is therefore never True.
        if np.isnan(ft_np).any():
            problem = "NaN value detected before PCA fit"
        # Check for extreme outliers
        elif (ft_np < -100000).any() or (ft_np > 100000).any():
            problem = "Extreme outlier detected before PCA fit"
        else:
            return ft  # If everything is fine, return the features

        print(f"Error occurred: {problem}")
        print("Restarting feature extraction...")


def fit_pca(
    feature_extractor,
    dataloader,
    pca_save_path=None,
    fixed_n_comps: Optional[int] = None,
    train_batch: int = None,
    cnn_layer: int|str = None,
):
    """Fit an IncrementalPCA on the features of every batch in ``dataloader``.

    Args:
        feature_extractor: Passed on to ``extract_features_and_check``.
        dataloader: Sized iterable yielding the batches to fit on.
        pca_save_path: When truthy, the fitted PCA is dumped there with joblib.
        fixed_n_comps: Number of components to use. When None, a first pass
            determines how many components reach 95% cumulative explained variance.
        train_batch: Forwarded as ``batch_size`` to IncrementalPCA.
        cnn_layer: Passed on to ``extract_features_and_check``.

    Returns:
        The fitted IncrementalPCA, or None when any exception occurred
        (the error is printed instead of raised).
    """

    def _fit_on_all_batches(estimator):
        # One pass over the dataloader, updating `estimator` batch by batch.
        for batch in tqdm(dataloader, total=len(dataloader)):
            checked = extract_features_and_check(batch, feature_extractor, cnn_layer)
            estimator.partial_fit(checked.detach().cpu().numpy())

    # Unrestricted PCA; only fitted when the number of components must be derived
    probe_pca = IncrementalPCA(n_components=None, batch_size=train_batch)

    try:
        if fixed_n_comps is not None:
            n_comps = fixed_n_comps
            print(f"Using fixed number of components: {n_comps}")
        else:
            print(
                "Determining the number of components to maintain 95% of the variance..."
            )
            _fit_on_all_batches(probe_pca)

            # First position at which the cumulative explained variance reaches 95%
            explained = np.cumsum(probe_pca.explained_variance_ratio_)
            n_comps = np.argmax(explained >= 0.95) + 1
            print(f"Number of components to maintain 95% of the variance: {n_comps}")

        # Fresh PCA restricted to the chosen number of components
        pca = IncrementalPCA(n_components=n_comps, batch_size=train_batch)

        print("Fitting PCA with determined number of PCs to batch...")
        _fit_on_all_batches(pca)

        # Persist the fitted object if a path was given
        if pca_save_path:
            print(f"Saving fitted PCA object to: {pca_save_path}")
            joblib.dump(pca, pca_save_path)

        print("PCA fitting completed.")
        return pca

    except Exception as err:
        # Report instead of raising; callers check for a None result.
        print(f"Error occurred: {err}")
        print("PCA fitting failed.")
        return None

### Boolean argument to include dense layers

In [11]:
# When True, the fitted PCA is saved under an extra "dense/" subfolder.
dense = True

In [None]:
# Path fragments derived from the run flags
dense_str = "dense/" if dense else ""
smallpatch_str = "smallpatch_" if crop_imgs else ""

# Make sure the output folder for the fitted PCA exists
os.makedirs(f"{NSP.own_datapath}/visfeats/cnn_featmaps/{modeltype}/{dense_str}", exist_ok=True)

# Fit PCA and get the fitted PCA object (None when fitting failed)
pca = fit_pca(
    feature_extractor=feature_extractor,
    dataloader=dataloader,
    pca_save_path=f"{NSP.own_datapath}/visfeats/cnn_featmaps/{modeltype}/{dense_str}pca_{smallpatch_str}{cnn_layer}_{fixed_n_comps}pcs.joblib",
    fixed_n_comps=fixed_n_comps,
    train_batch=train_batch,
    cnn_layer=cnn_layer,
)

# Drop the names of the fitting dataset and loader; they are rebuilt for the full image set
del dataloader, dataset

In [None]:
# Redefine the dataset and dataloader with the entire image set to apply the fitted PCA to.
all_img_ids = list(range(0, 73000))  # All the NSD images
# all_img_ids = list(NSP.stimuli.imgs_designmx()["subj01"]) # If it still is too heavy
full_dataset = ImageDataset(all_img_ids, transform=preprocess, crop=False)
full_dataloader = DataLoader(full_dataset, batch_size=apply_batch, shuffle=False)

# Only extract and save when PCA fitting was successful. Saving outside this
# branch would either fail on an undefined `features_algo` or silently write
# features left over from an earlier run.
if pca is not None:
    # Apply the fitted PCA to the rest of the dataset
    features_algo = extract_features(
        feature_extractor, full_dataloader, pca, cnn_layer
    )

    os.makedirs(f"{NSP.own_datapath}/visfeats/cnn_featmaps/{modeltype}/featmaps/", exist_ok=True)

    # Unpacking passes every row of `features_algo` as its own positional array,
    # so each row ends up as a separate entry in the archive.
    np.savez(
        f"{NSP.own_datapath}/visfeats/cnn_featmaps/{modeltype}/featmaps/featmaps_{smallpatch_str}lay{this_layer}.npz",
        *features_algo,
    )
else:
    print("PCA fitting failed. Unable to apply PCA, fock.")
