# Image Feature Pair Extraction — CLIP and ResNet18
Environment: `conda activate clip`


clip_image_features_list (118287, 512)
target_image_features_list (118287, 512)
clip_image_features_list (5000, 512)
target_image_features_list (5000, 512)

Feature extraction complete in 6m 16s

In [21]:
import numpy as np
import torch
import pickle
import time
print("Torch version:", torch.__version__)

assert torch.__version__.split(".") >= ["1", "7", "1"], "PyTorch 1.7.1 or later is required"

import os
import matplotlib.pyplot as plt
from collections import OrderedDict
import torch

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

Torch version: 1.7.1


# Load CLIP

In [22]:
import clip

# List the checkpoint names the installed `clip` package knows about.
# Informational only: the CLIP model used later is constructed from `model_info`
# via `CLIP(**model_info)`, not loaded from one of these checkpoints.
clip.available_models()

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']

In [23]:
# ViT-B-32.json
# copied from https://github.com/mlfoundations/open_clip/blob/91f6cce16b7bee90b3b5d38ca305b5b3b67cc200/src/training/model_configs/ViT-B-32.json
model_info = dict(
    # shared embedding size of the image and text towers
    embed_dim=512,
    # vision tower (ViT-B/32: 224px input, 32px patches)
    image_resolution=224,
    vision_layers=12,
    vision_width=768,
    vision_patch_size=32,
    # text tower (token sequence length, vocabulary, transformer size)
    context_length=77,
    vocab_size=49408,
    transformer_width=512,
    transformer_heads=8,
    transformer_layers=12,
)

In [24]:
from torchvision import transforms
input_size = model_info['image_resolution']
# Resize the shorter side to `input_size`, center-crop to a square, convert to a
# tensor and normalize each channel with the ImageNet mean / std.
# NOTE(review): these are torchvision's ImageNet statistics (what ResNet18 expects);
# OpenAI CLIP's own preprocessing uses different mean/std values and bicubic
# resizing - confirm this shared pipeline is intended if real images reach CLIP.
preprocess = transforms.Compose(
    [
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [25]:
# Sanity check: display the class of the `preprocess` pipeline built above.
type(preprocess)

torchvision.transforms.transforms.Compose

# Load Data

In [26]:
import torchvision
from torch.utils.data import DataLoader

def target_transform(caption_list):
    """Tokenize only the first caption of an image's caption list.

    Args:
        caption_list: the captions for one image (indexable; only element 0 is used).

    Returns:
        Row 0 of ``clip.tokenize(first_caption)``, i.e. the token ids for that
        single caption (presumably a 1-D tensor of length ``context_length`` -
        confirm against the installed ``clip`` version).
    """
    first_caption = caption_list[0]
    token_rows = clip.tokenize(first_caption)
    return token_rows[0]

# Root of the local COCO 2017 copy. Set the COCO_ROOT environment variable to
# point elsewhere instead of editing hard-coded absolute paths.
COCO_ROOT = os.environ.get('COCO_ROOT', '/home/ubuntu/data/coco')

# coco_train_dataset = torchvision.datasets.CocoCaptions(
#                         root=os.path.join(COCO_ROOT, 'train2017'),
#                         annFile=os.path.join(COCO_ROOT, 'annotations', 'captions_train2017.json'),
#                         transform=preprocess,
#                         target_transform=target_transform,
#                         )

# COCO val2017 images paired with their captions: each image is passed through
# `preprocess` and each caption list through `target_transform`.
coco_val_dataset = torchvision.datasets.CocoCaptions(
    root=os.path.join(COCO_ROOT, 'val2017'),
    annFile=os.path.join(COCO_ROOT, 'annotations', 'captions_val2017.json'),
    transform=preprocess,
    target_transform=target_transform,
)

loading annotations into memory...
Done (t=0.20s)
creating index...
index created!


In [27]:
# coco_train_dataloader = DataLoader(coco_train_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)
coco_val_dataloader = DataLoader(
    dataset=coco_val_dataset,
    batch_size=64,
    shuffle=False,  # sequential order, so batch composition is the same on every pass
    num_workers=8,
    pin_memory=True,
)

# Models: randomly initialised CLIP and ResNet18

In [28]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable
import random

In [29]:
from clip.model import CLIP
def get_random_init_models(random_seed, pretrained_resnet=False):
    """Build a CLIP model and a headless ResNet18, both on CUDA in eval mode.

    The ``random``, ``numpy`` and ``torch`` global RNGs are all seeded with
    ``random_seed`` before any model is constructed, so the same seed gives
    the same randomly initialised weights.

    Args:
        random_seed: seed applied to ``random``, ``numpy`` and ``torch``.
        pretrained_resnet: if True, ResNet18 is created with torchvision's
            ``pretrained=True`` weights instead of a random initialisation.
            CLIP is always randomly initialised. Defaults to False, which
            reproduces the original behaviour.

    Returns:
        (clip_model, target_model):
        - ``clip_model`` is ``CLIP(**model_info)``.
        - ``target_model`` is ResNet18 with its last child module removed
          (the ``fc`` classifier in torchvision's ResNet). Its parameters
          are frozen (``requires_grad = False``).
    """
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    # Order matters: both constructors draw from the torch RNG, so building
    # CLIP before ResNet18 is what fixes each model's weights for a given seed.
    clip_model = CLIP(**model_info)
    clip_model.cuda().eval()

    resnet18 = models.resnet18(pretrained=pretrained_resnet)
    # Drop the final classification layer so the network outputs pooled features.
    target_model = nn.Sequential(*list(resnet18.children())[:-1])
    for param in target_model.parameters():
        param.requires_grad = False
    target_model.cuda().eval()

    return clip_model, target_model


# Feature extraction loop


In [30]:
# Run NUM_EXPERIMENTS independent extractions. Experiment k uses models freshly
# initialised from seed k, but every experiment feeds them the same synthetic
# inputs (random images and random token ids drawn after reseeding with a fixed
# seed). The pickled dumps therefore differ only in the model weights.
NUM_EXPERIMENTS = 25
FEATURE_DIR = 'features'
os.makedirs(FEATURE_DIR, exist_ok=True)  # open(..., "wb") below fails if the folder is missing

since = time.time()
dataloaders = {
    # 'train': coco_train_dataloader,
    'val': coco_val_dataloader,
}
# Input geometry taken from the CLIP config, so it stays in sync with the model.
image_resolution = model_info['image_resolution']
context_length = model_info['context_length']
vocab_size = model_info['vocab_size']

for expriment_idx in range(NUM_EXPERIMENTS):
    phase = 'val'
    clip_model, target_model = get_random_init_models(random_seed=expriment_idx)
    clip_model.eval()   # evaluation mode for extraction
    ##################################
    # Fields to be stored for postprocessing
    ##################################
    clip_image_features_list = []
    clip_text_features_list = []
    target_image_features_list = []

    # Reseed after the models are built so every experiment draws identical inputs.
    data_seed = 42
    random.seed(data_seed)
    torch.manual_seed(data_seed)
    np.random.seed(data_seed)

    # The dataloader only supplies the batch sizes: its decoded images and
    # tokenized captions are replaced by random tensors below. Creating its
    # iterator may itself draw from the torch RNG (worker base seed), so it must
    # stay after the reseed above to reproduce earlier dumps exactly.
    for inputs, captions in dataloaders[phase]:
        batch_size = len(captions)
        image_input = torch.randn(
            (batch_size, 3, image_resolution, image_resolution)
        ).cuda(non_blocking=True)
        text_input = torch.randint(
            0, vocab_size, (batch_size, context_length)
        ).cuda(non_blocking=True)

        with torch.no_grad():
            clip_image_features = clip_model.encode_image(image_input).float()
            clip_text_features = clip_model.encode_text(text_input).float()
            # flatten(1) keeps the batch axis: the pooled ResNet output is
            # presumably (B, 512, 1, 1). A bare squeeze() would also drop the
            # batch axis for a final batch of one sample, which would break
            # the np.concatenate below.
            target_image_features = target_model(image_input).flatten(1)
            ##################################
            # Evaluation book-keeping Field
            ##################################
            clip_image_features_list.append(clip_image_features.cpu().numpy())
            clip_text_features_list.append(clip_text_features.cpu().numpy())
            target_image_features_list.append(target_image_features.cpu().numpy())

    ##################################
    # Evaluation book-keeping Field
    ##################################
    clip_image_features_list = np.concatenate(clip_image_features_list, axis=0)
    clip_text_features_list = np.concatenate(clip_text_features_list, axis=0)
    target_image_features_list = np.concatenate(target_image_features_list, axis=0)
    print('clip_image_features_list', clip_image_features_list.shape)
    print('target_image_features_list', target_image_features_list.shape)

    dump_result_dict = {
        "clip_image_features_list":   clip_image_features_list,
        "clip_text_features_list":    clip_text_features_list,
        "target_image_features_list": target_image_features_list,
    }
    dump_path = os.path.join(FEATURE_DIR, 'feature_dump_{}.pkl'.format(expriment_idx))
    with open(dump_path, "wb") as pkl_file:
        pickle.dump(dump_result_dict, pkl_file)

    # Elapsed time is cumulative since the first experiment started.
    time_elapsed = time.time() - since
    print('expriment_idx', expriment_idx)
    print('Feature Extraction completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

clip_image_features_list (5000, 512)
target_image_features_list (5000, 512)
expriment_idx 0
Feature Extraction completed in 1m 4s
clip_image_features_list (5000, 512)
target_image_features_list (5000, 512)
expriment_idx 1
Feature Extraction completed in 2m 10s
clip_image_features_list (5000, 512)
target_image_features_list (5000, 512)
expriment_idx 2
Feature Extraction completed in 3m 14s
clip_image_features_list (5000, 512)
target_image_features_list (5000, 512)
expriment_idx 3
Feature Extraction completed in 4m 20s
clip_image_features_list (5000, 512)
target_image_features_list (5000, 512)
expriment_idx 4
Feature Extraction completed in 5m 27s
clip_image_features_list (5000, 512)
target_image_features_list (5000, 512)
expriment_idx 5
Feature Extraction completed in 6m 33s
clip_image_features_list (5000, 512)
target_image_features_list (5000, 512)
expriment_idx 6
Feature Extraction completed in 7m 36s
clip_image_features_list (5000, 512)
target_image_features_list (5000, 512)
exprimen