# Extract Predictions from ImageNet using various `timm` models

In [None]:
import blackhc.project.script

In [None]:
import timm

In [None]:
import os

In [None]:
os.getcwd()

In [1]:
import os

In [4]:
os.getcwd()

'/home/blackhc/PycharmProjects/gde_repro/notebooks'

In [3]:
from timm.data.transforms_factory import transforms_imagenet_eval
from torchvision import transforms

In [4]:
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.data.transforms_factory import create_transform

In [5]:
import torch
import torch.nn

In [6]:
# Turn on cuDNN benchmark mode globally, before the model is created and run.
torch.backends.cudnn.benchmark = True

In [7]:
# Architecture to evaluate. Change MODEL_NAME to switch models instead of
# commenting create_model() lines in and out. Names used with this notebook so far:
#   "beit_large_patch16_224"
#   "deit3_large_patch16_224_in21ft1k"
#   "vit_base_patch16_384"
#   "convnext_large_in22ft1k"
#   "resnet152d"
MODEL_NAME = "resnet152d"

# pretrained=True loads the published weights; scriptable=True is requested so the
# model can still be passed to torch.jit.script later on.
model = timm.create_model(MODEL_NAME, pretrained=True, scriptable=True)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth" to /home/blackhc/.cache/torch/hub/checkpoints/resnet152d_ra2-5cac0439.pth


In [8]:
# Move the model to the GPU and switch it to evaluation mode
# (the trailing semicolon suppresses the module repr in the cell output).
model.cuda().eval();

In [9]:
model.pretrained_cfg

{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth',
 'num_classes': 1000,
 'input_size': (3, 256, 256),
 'pool_size': (8, 8),
 'crop_pct': 1.0,
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'first_conv': 'conv1.0',
 'classifier': 'fc',
 'test_input_size': (3, 320, 320),
 'architecture': 'resnet152d'}

In [10]:
# Resolve the preprocessing settings for this model; the loader cell reads
# input_size, interpolation, mean, std and crop_pct from the result.
# NOTE(review): crop_pct=0.0 is falsy — presumably timm ignores it and falls back to the
# model's own crop_pct; confirm against the installed timm version.
data_config = resolve_data_config({"crop_pct": 0.0}, model=model)

# Stand-alone evaluation transform built from the same settings.
eval_transform = create_transform(**data_config)

In [11]:
from timm.models import apply_test_time_pool

In [12]:
# `test_time_pool` starts out as the request flag and is then rebound to the second value
# returned by `apply_test_time_pool`; the loader cell reads it to choose `crop_pct`.
# NOTE(review): `model` is rebound here as well — later cells read `model.pretrained_cfg`
# and patch `model.__class__`, so confirm they still work if a different object is returned.
test_time_pool = True
if test_time_pool:
    model, test_time_pool = apply_test_time_pool(model, data_config)

In [13]:
# Monkey patch timm's forward_features
model.__class__.forward_head = torch.jit.export(model.__class__.forward_head)
model.__class__.forward_features = torch.jit.export(model.__class__.forward_features)

In [14]:
# NOTE: torch.jit.optimized_execution is a context manager. A bare call such as
# `torch.jit.optimized_execution(True)` only builds the manager and discards it, so it
# changes nothing. To control JIT graph optimisation, wrap the code that runs the
# scripted model instead:
#
#     with torch.jit.optimized_execution(True):
#         ...
#
# Scripting is currently disabled; the eager model is used under the `jit_model` name
# so the cells below work either way.
# jit_model = torch.jit.script(model)
jit_model = model

In [15]:
import os.path

In [16]:
# ImageNet validation split, read from the `imagenet` directory under the user's home.
imagenet_root = os.path.expanduser("~/imagenet")
dataset = create_dataset("imagenet", imagenet_root, "validation")

In [17]:
# With test-time pooling active use crop_pct 1.0; otherwise use the resolved crop_pct.
if test_time_pool:
    crop_pct = 1.0
else:
    crop_pct = data_config["crop_pct"]

# Validation loader: preprocessing parameters come from the resolved data config.
loader = create_loader(
    dataset,
    batch_size=96,
    num_workers=8,
    pin_memory=True,
    use_prefetcher=True,
    tf_preprocessing=False,
    input_size=data_config["input_size"],
    interpolation=data_config["interpolation"],
    mean=data_config["mean"],
    std=data_config["std"],
    crop_pct=crop_pct,
)

In [18]:
from tqdm.auto import tqdm

In [19]:
# One pass over the loader: keep the softmax probabilities and the labels of every
# batch on the CPU, then concatenate them into single tensors.
label_chunks = []
prob_chunks = []
with torch.inference_mode():
    for images, targets in tqdm(loader):
        logits = jit_model(images.cuda()).cpu()
        label_chunks.append(targets.cpu())
        prob_chunks.append(torch.nn.functional.softmax(logits, dim=-1))

labels = torch.cat(label_chunks)
predictions = torch.cat(prob_chunks)

  0%|          | 0/521 [00:00<?, ?it/s]

In [20]:
# Bundle probabilities, labels and the model's pretrained config, then save them to a
# file named after the architecture.
validation_info = {
    "predictions": predictions,
    "labels": labels,
    "pretrained_cfg": model.pretrained_cfg,
}
probs_path = f"imagenet_val_probs_labels_{model.pretrained_cfg['architecture']}.pt"
torch.save(validation_info, probs_path)

In [21]:
!ls *.pt

imagenet_val_features_labels_beit_large_patch16_224.pt
imagenet_val_features_labels_convnext_large_in22ft1k.pt
imagenet_val_features_labels_deit3_large_patch16_224_in21ft1k.pt
imagenet_val_features_labels_vit_base_patch16_384.pt
imagenet_val_probs_labels_beit_large_patch16_224.pt
imagenet_val_probs_labels_convnext_large_in22ft1k.pt
imagenet_val_probs_labels_deit3_large_patch16_224_in21ft1k.pt
imagenet_val_probs_labels_resnet152d.pt
imagenet_val_probs_labels_vit_base_patch16_384.pt


In [22]:
import gc

import torch


def gc_cuda():
    """Garbage-collect Python objects, then empty Torch's CUDA cache.

    Runs `gc.collect()` unconditionally; `torch.cuda.empty_cache()` is only
    called when CUDA is available. Returns None.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

In [23]:
def collect_features(jit_model, loader):
    """Run the model over `loader` and gather pre-logits features with their labels.

    Args:
        jit_model: model exposing `forward_features` and `forward_head`
            (the latter is called with `pre_logits=True`).
        loader: iterable yielding `(images, labels)` batches.

    Returns:
        Tuple `(labels, features)`: the per-batch labels and features, each moved to
        the CPU and concatenated with `torch.cat`.
    """
    label_chunks = []
    feature_chunks = []
    with torch.inference_mode():
        for images, targets in tqdm(loader):
            pre_logits = jit_model.forward_head(
                jit_model.forward_features(images.cuda()),
                pre_logits=True,
            )
            label_chunks.append(targets.cpu())
            feature_chunks.append(pre_logits.cpu())

    # Free cached CUDA memory once the pass over the loader is done.
    gc_cuda()

    return torch.cat(label_chunks), torch.cat(feature_chunks)


# Second pass over the same loader to collect features; this rebinds `labels`.
labels, features = collect_features(jit_model, loader)

  0%|          | 0/521 [00:00<?, ?it/s]

In [24]:
# Bundle features, labels and the model's pretrained config, then save them to a
# file named after the architecture.
validation_features_info = {
    "features": features,
    "labels": labels,
    "pretrained_cfg": model.pretrained_cfg,
}
features_path = f"imagenet_val_features_labels_{model.pretrained_cfg['architecture']}.pt"
torch.save(validation_features_info, features_path)

In [87]:
1+1

2

In [11]:
# Build the evaluation transform directly from the model's pretrained configuration
# (square input taken from input_size[1], crop_pct fixed to 1, no prefetcher).
cfg = model.pretrained_cfg
eval_transform = transforms_imagenet_eval(
    img_size=cfg["input_size"][1],
    interpolation=cfg["interpolation"],
    mean=cfg["mean"],
    std=cfg["std"],
    crop_pct=1,
    use_prefetcher=False,
)
print(eval_transform)

# Prepend two steps to the pipeline printed above: convert each sample to a PIL image
# and force it to RGB mode.
eval_transform.transforms = [
    transforms.ToPILImage(),
    transforms.Lambda(lambda image: image.convert("RGB")),
    *eval_transform.transforms,
]

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))
)


In [12]:
# Build a PyTorch dataloader from `ds`, applying `eval_transform` to the "images" tensor
# and leaving "labels" untransformed.
# NOTE(review): `ds` is not defined in any cell visible in this notebook — this cell
# presumably depends on a deleted or out-of-order cell and fails on a fresh kernel;
# confirm where `ds` comes from.
eval_dataloader = ds.pytorch(
    batch_size=128,
    transform=dict(images=eval_transform, labels=None),
    tensors=["images", "labels"],
    shuffle=False,
    pin_memory=True,
    num_workers=8,
    use_progress_bar=False,
    use_local_cache=True,
)

In [15]:
from torchvision.datasets import ImageNet

In [18]:
ImageNet("~/imagenet", split="val")

Dataset ImageNet
    Number of datapoints: 50000
    Root location: /home/blackhc/imagenet
    Split: val

In [17]:
!pip install scipy

Collecting scipy
  Downloading scipy-1.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (43.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.9/43.9 MB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: scipy
Successfully installed scipy-1.9.1
