In [None]:
# Import modules
import argparse
import os
import pickle
import random
import subprocess

import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable as V
from torchvision import models
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names
from tqdm import tqdm

# Input frames; each file is named "<NNNN>_frame_20.jpg".
images_dir = "/scratch/alexandel91/single_frames"
# Output directory for the pickled activations ("_pretrained" is appended
# to it when init is True).
save_dir = "/scratch/agnek95/Unreal/CNN_activations_redone/2D_ResNet18/extracted/"
# True  -> load the Places365 pre-trained weights into the network.
# False -> keep the randomly initialized (seeded) network.
init = True

In [None]:

# STEP 1: LOAD RESNET2D MODEL #
# Seed every RNG in use so that a randomly initialized network (init=False)
# is reproducible across runs.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Architecture; also determines the Places365 checkpoint file name.
arch = "resnet18"

model_file = "%s_places365.pth.tar" % arch

# ResNet-18 with a 365-way output layer to match the Places365 checkpoint.
model = models.resnet18(num_classes=365, weights=None)
if init:  # Initialize model with pre-trained weights
    # Fetch the checkpoint only when it is actually needed and not already
    # present. (The old os.access(..., W_OK) test re-downloaded read-only
    # files, and wget would then write "<name>.1" next to them.)
    if not os.path.isfile(model_file):
        weight_url = (
            "http://places2.csail.mit.edu/models_places365/" + model_file
        )
        # check=True stops the script here if the download fails, instead of
        # failing later inside torch.load on a missing/partial file.
        subprocess.run(["wget", weight_url], check=True)

    # NOTE(review): save_dir is configured with a trailing "/", so this yields
    # ".../extracted/_pretrained" (a subdirectory) rather than
    # ".../extracted_pretrained" — confirm which layout is intended.
    save_dir = save_dir + "_pretrained"

    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source. Tensors are mapped onto the CPU.
    checkpoint = torch.load(
        model_file, map_location=lambda storage, loc: storage
    )
    # Remove every "module." occurrence from the parameter names (checkpoints
    # saved from nn.DataParallel carry this prefix) so they match the
    # plain torchvision model.
    state_dict = {
        k.replace("module.", ""): v
        for k, v in checkpoint["state_dict"].items()
    }
    model.load_state_dict(state_dict)

# Inference mode (e.g. BatchNorm uses running statistics, not batch stats).
model.eval()

# STEP 2: DEFINE DATA VARIABLES #
# Number of videos to process (one frame image per video).
num_videos = 1440

# Per-frame preprocessing, built from two stages:
#   1. geometry: squash to 256x256, then take the central 224x224 patch;
#   2. values: convert to a float tensor and normalize each RGB channel
#      with the usual ImageNet mean/std.
_geometry_steps = [
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
]
_value_steps = [
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
centre_crop = transforms.Compose(_geometry_steps + _value_steps)

# --------------------------------------
# STEP 3: EXTRACT UNIT ACTIVATIONS #
# --------------------------------------
# Graph nodes to read out, mapped to their (channels, height, width) shape
# for a 224x224 input. "relu_1" is the name torchvision's FX tracing gives the
# second ReLU call in each BasicBlock (presumably the post-residual-addition
# ReLU, i.e. the block output — confirm with get_graph_node_names(model)).
layer_shapes = {
    "layer1.0.relu_1": (64, 56, 56),
    "layer1.1.relu_1": (64, 56, 56),
    "layer2.0.relu_1": (128, 28, 28),
    "layer2.1.relu_1": (128, 28, 28),
    "layer3.0.relu_1": (256, 14, 14),
    "layer3.1.relu_1": (256, 14, 14),
    "layer4.0.relu_1": (512, 7, 7),
    "layer4.1.relu_1": (512, 7, 7),
}
return_layers = list(layer_shapes)

# One (num_videos, C*H*W) matrix per layer: row = image, column = unit,
# units flattened channel-major (C order). Kept float64 as before; with
# num_videos=1440 this is ~8.7 GB in total (each layer1 matrix ~2.3 GB).
features = {
    layer: np.zeros((num_videos, int(np.prod(shape))))
    for layer, shape in layer_shapes.items()
}

# Sanity check: the traced graph must be the same in train and eval mode,
# otherwise the node names above could refer to different operations.
# (The previous `assert [...]` asserted a non-empty list and never failed.)
train_nodes, eval_nodes = get_graph_node_names(model)
if train_nodes != eval_nodes:
    raise RuntimeError("Graph nodes differ between train and eval mode")

feature_extractor = create_feature_extractor(
    model, return_nodes=return_layers
)

# Loop through all images
print("Extracting features...")
with torch.no_grad():
    for idx in tqdm(range(num_videos)):
        # Frame files are numbered from 1 and zero-padded to 4 digits.
        image_path = os.path.join(images_dir, f"{idx + 1:04}_frame_20.jpg")

        # convert("RGB") guarantees the 3 channels Normalize expects (a
        # grayscale/CMYK JPEG would otherwise fail); the context manager
        # closes the file handle.
        with Image.open(image_path) as image:
            batch_t = centre_crop(image.convert("RGB")).unsqueeze(0)

        out = feature_extractor(batch_t)

        for layer in return_layers:
            # Batch size is 1, so flattening (1, C, H, W) in C order gives
            # exactly the C*H*W unit vector for this image.
            features[layer][idx, :] = out[layer].numpy().reshape(-1)

# --------------------------------------
# STEP 4: SAVE FEATURES WITHOUT PCA #
# --------------------------------------
# Save each layer separately as a pickled (num_videos, n_units) array.
os.makedirs(save_dir, exist_ok=True)
for layer, layer_features in features.items():
    features_path = os.path.join(
        save_dir, "new_features_resnet_" + layer + ".pkl"
    )
    print(f"Saving {layer} {layer_features.shape} -> {features_path}")
    with open(features_path, "wb") as f:
        pickle.dump(layer_features, f)