In [1]:
import os
import h5py
from PIL import Image
from imageio.v2 import imread
import numpy as np

import torch
from torch import nn
import torchvision
from torchvision import models

In [14]:
# Inspect the pretrained ResNet-101 backbone. The truncated extractor mirrors
# build_model(): the stem (conv1, bn1, relu, maxpool) plus residual stages layer1-layer3.
cnn = models.resnet101(weights='ResNet101_Weights.IMAGENET1K_V2')
layers = [cnn.conv1, cnn.bn1, cnn.relu, cnn.maxpool]
layers += [getattr(cnn, 'layer%d' % (stage + 1)) for stage in range(3)]
model = torch.nn.Sequential(*layers)
cnn

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [2]:
def build_model(model_type, model_stage, device=None):
    """Build a truncated, pretrained ResNet feature extractor in eval mode.

    The returned ``torch.nn.Sequential`` contains the ResNet stem
    (conv1, bn1, relu, maxpool) followed by the first ``model_stage``
    residual stages (``layer1`` ... ``layer<model_stage>``).

    Args:
        model_type: name of a torchvision ResNet constructor, e.g. ``'resnet101'``.
        model_stage: number of residual stages to keep (0-4).
        device: device to move the model to. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        The truncated model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: if ``model_type`` is not a callable torchvision model
            constructor, is not a ResNet, or ``model_stage`` is outside 0-4.
    """
    # `hasattr` alone would also accept submodules such as torchvision.models.resnet.
    if not callable(getattr(torchvision.models, model_type, None)):
        raise ValueError('Invalid model "%s"' % model_type)
    if 'resnet' not in model_type:
        raise ValueError('Feature extraction only supports ResNets')
    # Validate before downloading weights; torchvision ResNets only have layer1..layer4.
    if not 0 <= model_stage <= 4:
        raise ValueError('model_stage must be between 0 and 4, got %r' % (model_stage,))

    # 'DEFAULT' selects the recommended pretrained weights for *this* architecture
    # (IMAGENET1K_V2 for resnet101). A hard-coded ResNet101 weights name is
    # rejected by the other ResNet constructors.
    cnn = getattr(torchvision.models, model_type)(weights='DEFAULT')
    layers = [cnn.conv1, cnn.bn1, cnn.relu, cnn.maxpool]
    layers.extend(getattr(cnn, 'layer%d' % (i + 1)) for i in range(model_stage))
    model = torch.nn.Sequential(*layers)

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()
    return model

In [3]:
def get_input_paths(input_image_dir, max_images=None):
    """Collect the ``.png`` images in ``input_image_dir``, ordered by numeric index.

    The index is parsed from the last ``_``-separated token of the file stem
    (e.g. ``CLEVR_train_000042.png`` -> 42). Indices must be unique and
    contiguous starting at 0. The first and last entries are printed as a
    sanity check.

    Args:
        input_image_dir: directory to scan (non-recursively).
        max_images: optional cap on how many paths to return. When ``None``,
            the notebook-level ``max_images`` setting is used if it is defined.

    Returns:
        List of ``(path, index)`` tuples sorted by ``index``.

    Raises:
        ValueError: if the directory contains no ``.png`` files, or the indices
            are duplicated or not contiguous from 0.
    """
    if max_images is None:
        # Keep honouring the config cell's global when called with one argument;
        # `.get` avoids a NameError if that cell has not been run yet.
        max_images = globals().get('max_images')

    input_paths = []
    for fn in os.listdir(input_image_dir):
        if not fn.endswith('.png'):
            continue
        idx = int(os.path.splitext(fn)[0].split('_')[-1])
        input_paths.append((os.path.join(input_image_dir, fn), idx))
    if not input_paths:
        raise ValueError('No .png images found in %s' % input_image_dir)

    input_paths.sort(key=lambda item: item[1])
    idx_set = {idx for _, idx in input_paths}
    if len(idx_set) != len(input_paths):
        raise ValueError('Duplicate image indices in %s' % input_image_dir)
    if min(idx_set) != 0 or max(idx_set) != len(idx_set) - 1:
        raise ValueError('Image indices in %s are not contiguous from 0' % input_image_dir)

    if max_images is not None:
        input_paths = input_paths[:max_images]
    # Guard against max_images == 0, which leaves nothing to print.
    if input_paths:
        print(input_paths[0])
        print(input_paths[-1])
    return input_paths

In [4]:
def run_batch(cur_batch, model, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Normalize a batch of images and run it through ``model`` without autograd.

    Args:
        cur_batch: non-empty list of arrays concatenated along axis 0; the main
            loop supplies ``(1, 3, H, W)`` uint8 images with values in 0-255.
        model: ``torch.nn.Module`` applied to the normalized float32 batch.
        mean, std: per-channel statistics applied after scaling pixels to
            [0, 1]. The defaults are the ImageNet values torchvision's
            pretrained weights expect. The previous code used 0.224 for the
            third std entry; pass ``std=(0.229, 0.224, 0.224)`` to reproduce
            features extracted that way.

    Returns:
        The model output as a float32 numpy array on the CPU.

    Raises:
        ValueError: if ``cur_batch`` is empty.
    """
    if not cur_batch:
        raise ValueError('cur_batch must contain at least one image')
    mean = np.asarray(mean, dtype=np.float32).reshape(1, 3, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(1, 3, 1, 1)

    image_batch = np.concatenate(cur_batch, 0).astype(np.float32)
    image_batch = (image_batch / 255.0 - mean) / std

    # Send the input to whichever device holds the model's parameters; a
    # parameter-free model falls back to CUDA when available, else CPU.
    param = next(model.parameters(), None)
    if param is not None:
        device = param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        feats = model(torch.from_numpy(image_batch).to(device))
    return feats.cpu().numpy()

## Args

Input/output locations and feature-extraction settings used by the main loop.

In [5]:
# Per-split image directories and their matching output HDF5 files (paired by position).
input_image_dirs = [
    'data/CLEVR_v1.0/images/train',
    'data/CLEVR_v1.0/images/val',
    'data/CLEVR_v1.0/images/test',
]
output_h5_files = [
    'data/train_features.h5',
    'data/val_features.h5',
    'data/test_features.h5',
]

# Optional cap on the number of images per split (None keeps every image).
max_images = None

# Every image is resized to this resolution before feature extraction.
image_height, image_width = 224, 224

# Backbone: ResNet-101 truncated after its third residual stage.
model_type = 'resnet101'
model_stage = 3
batch_size = 128

## Main

Extract backbone features for each split and write them to the corresponding HDF5 file.

In [6]:
def _load_image(path, width, height):
    """Load ``path`` as RGB, resize it bicubically and return a (1, 3, height, width) uint8 array."""
    # The context manager releases the file handle as soon as the pixels are read.
    with Image.open(path) as img:
        rgb = img.convert('RGB').resize((width, height), Image.Resampling.BICUBIC)
    return np.array(rgb).transpose(2, 0, 1)[None]


# The backbone is identical for every split, so build it (and load its weights) once.
model = build_model(model_type, model_stage)

for input_image_dir, output_h5_file in zip(input_image_dirs, output_h5_files):
    # List and validate inputs before opening the output, so a bad input
    # directory does not truncate an existing feature file.
    input_paths = get_input_paths(input_image_dir)
    num_images = len(input_paths)
    with h5py.File(output_h5_file, 'w') as f:
        feat_dset = None  # created lazily: the feature shape is known only after a forward pass
        i0 = 0
        cur_batch = []
        for n_loaded, (path, _) in enumerate(input_paths, start=1):
            cur_batch.append(_load_image(path, image_width, image_height))
            # Flush every full batch and also the final partial batch. Because the
            # dataset is created on this single flush path, splits with fewer than
            # batch_size images no longer hit an uninitialised dataset.
            if len(cur_batch) == batch_size or n_loaded == num_images:
                feats = run_batch(cur_batch, model)
                if feat_dset is None:
                    _, C, H, W = feats.shape
                    feat_dset = f.create_dataset('features', (num_images, C, H, W), dtype=np.float32)
                i1 = i0 + len(cur_batch)
                feat_dset[i0:i1] = feats
                i0 = i1
                cur_batch = []

    print('Processed %d / %d images in %s' % (i0, num_images, output_h5_file))

('data/CLEVR_v1.0/images/train\\CLEVR_train_000000.png', 0)
('data/CLEVR_v1.0/images/train\\CLEVR_train_069999.png', 69999)
Processed 70000 / 70000 images in data/train_features.h5
('data/CLEVR_v1.0/images/val\\CLEVR_val_000000.png', 0)
('data/CLEVR_v1.0/images/val\\CLEVR_val_014999.png', 14999)
Processed 15000 / 15000 images in data/val_features.h5
('data/CLEVR_v1.0/images/test\\CLEVR_test_000000.png', 0)
('data/CLEVR_v1.0/images/test\\CLEVR_test_014999.png', 14999)
Processed 15000 / 15000 images in data/test_features.h5
