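# Extract pre-logit features from PACS using various `timm` models

For each PACS domain (art_painting, cartoon, photo, sketch), run a pretrained `timm` backbone and save its pre-logit features together with the labels to `pacs_<domain>_features_labels_<architecture>.pt`.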

In [1]:
import blackhc.project.script

Appended /home/blackhc/PycharmProjects/gde_repro/src to paths
Switched to directory /home/blackhc/PycharmProjects/gde_repro
%load_ext autoreload
%autoreload 2


In [2]:
import timm

In [3]:
from timm.data.transforms_factory import transforms_imagenet_eval
from torchvision import transforms

In [4]:
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.data.transforms_factory import create_transform

In [5]:
import torch
import torch.nn

In [6]:
# Let cuDNN benchmark and cache the fastest convolution algorithms; worthwhile here because
# every image is resized/cropped to the same data_config["input_size"].
torch.backends.cudnn.benchmark = True

In [7]:
# Backbone used for feature extraction. Every later cell needs `model`, so it must be
# created here for Restart & Run All to work. The saved outputs in this notebook come from
# convnext_large_in22ft1k; other architectures tried with this notebook include
# beit_large_patch16_224, deit3_large_patch16_224_in21ft1k, vit_base_patch16_384 and resnet152d.
MODEL_NAME = "convnext_large_in22ft1k"
model = timm.create_model(MODEL_NAME, pretrained=True, scriptable=True)

In [8]:
# Put the weights on the GPU and switch the network to evaluation mode in one chained call
# (both Module.cuda() and Module.eval() return the module itself).
model = model.cuda().eval()

In [9]:
# Show the model's pretrained preprocessing defaults (input size, crop, interpolation, mean/std);
# its 'architecture' entry is reused in the output file names below.
model.pretrained_cfg

{'url': 'https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
 'num_classes': 1000,
 'input_size': (3, 224, 224),
 'pool_size': (7, 7),
 'crop_pct': 0.875,
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'first_conv': 'stem.0',
 'classifier': 'head.fc',
 'architecture': 'convnext_large_in22ft1k'}

In [10]:
# Resolve the preprocessing settings (input size, interpolation, mean/std, crop) for this model.
# NOTE(review): crop_pct=0.0 is falsy. Depending on the timm version it may be stored as 0.0 and
# later replaced by timm's default crop when a transform is built. TODO: confirm the intended crop.
data_overrides = {"crop_pct": 0.0}
data_config = resolve_data_config(data_overrides, model=model)
# NOTE(review): eval_transform is not referenced below; the loaders are configured from
# data_config directly.
eval_transform = create_transform(**data_config)

In [11]:
from timm.models import apply_test_time_pool

In [12]:
# Request test-time pooling. apply_test_time_pool hands back the (possibly wrapped) model plus
# a flag saying whether pooling was actually applied; that flag later picks the loader crop.
REQUEST_TEST_TIME_POOL = True
test_time_pool = REQUEST_TEST_TIME_POOL
if REQUEST_TEST_TIME_POOL:
    # NOTE(review): if a wrapper module is returned, confirm it still exposes
    # forward_features / forward_head / pretrained_cfg, which later cells call directly.
    model, test_time_pool = apply_test_time_pool(model, data_config)

In [13]:
# Monkey patch timm's forward_features
model.__class__.forward_head = torch.jit.export(model.__class__.forward_head)
model.__class__.forward_features = torch.jit.export(model.__class__.forward_features)

In [14]:
# TorchScript compilation is disabled by default; the eager model is used for extraction.
# Set USE_TORCHSCRIPT = True to script it (this relies on the torch.jit.export patch above).
# torch.jit.optimized_execution(...) is a context manager and only takes effect inside a
# `with` block, so it is deliberately not called bare here.
USE_TORCHSCRIPT = False
jit_model = torch.jit.script(model) if USE_TORCHSCRIPT else model

In [15]:
import os.path

In [16]:
# The four PACS domains, each loaded from its own sub-directory of PACS_ROOT.
PACS_ROOT = os.path.expanduser("~/datasets/pacs")
sub_dataset_names = ["art_painting", "cartoon", "photo", "sketch"]
# NOTE(review): "pacs" is passed through to timm's create_dataset as the dataset name;
# confirm timm handles it as a plain folder dataset.
datasets = {
    domain: create_dataset("pacs", os.path.join(PACS_ROOT, domain))
    for domain in sub_dataset_names
}

In [17]:
from tqdm.auto import tqdm

In [18]:
import gc

import torch


def gc_cuda():
    """Garbage-collect Python objects and release cached CUDA memory.

    Runs ``gc.collect()`` so unreachable tensors are freed. Then, when CUDA is available,
    calls ``torch.cuda.empty_cache()`` so PyTorch's caching allocator returns unoccupied
    cached blocks. Returns ``None``.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [19]:
def collect_features(jit_model, loader):
    """Run ``jit_model`` over ``loader`` and gather pre-logit features and labels on the CPU.

    Args:
        jit_model: Model (eager or TorchScript) exposing ``forward_features(images)`` and
            ``forward_head(features, pre_logits=True)``.
        loader: Iterable of ``(images, labels)`` batches; ``images`` must support ``.cuda()``.

    Returns:
        ``(labels, features)``: the per-batch labels and features, each concatenated along
        dim 0 and moved to the CPU.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    label_batches = []
    feature_batches = []
    with torch.inference_mode():
        for batch_images, batch_labels in tqdm(loader):
            # .cuda() is a no-op if the loader (e.g. with a prefetcher) already yields CUDA tensors.
            pooled = jit_model.forward_features(batch_images.cuda())
            # pre_logits=True requests the head output before the final classifier layer.
            batch_features = jit_model.forward_head(pooled, pre_logits=True)
            label_batches.append(batch_labels.cpu())
            feature_batches.append(batch_features.cpu())

    # Return cached CUDA memory between datasets.
    gc_cuda()

    if not feature_batches:
        # torch.cat([]) would otherwise fail with an opaque RuntimeError.
        raise ValueError("collect_features: loader yielded no batches")

    return torch.cat(label_batches), torch.cat(feature_batches)

In [20]:
# With test-time pooling the full image is kept (crop 1.0); otherwise use the resolved crop.
crop_pct = 1.0 if test_time_pool else data_config["crop_pct"]

# Loader settings shared by every PACS domain; only the dataset differs per iteration.
loader_kwargs = dict(
    input_size=data_config["input_size"],
    batch_size=64,
    use_prefetcher=True,
    interpolation=data_config["interpolation"],
    mean=data_config["mean"],
    std=data_config["std"],
    num_workers=8,
    crop_pct=crop_pct,
    pin_memory=True,
    tf_preprocessing=False,
)

for dataset_name in sub_dataset_names:
    loader = create_loader(datasets[dataset_name], **loader_kwargs)
    labels, features = collect_features(jit_model, loader)

    # Save features, labels and the model config for this domain next to the notebook.
    validation_features_info = dict(
        features=features,
        labels=labels,
        pretrained_cfg=model.pretrained_cfg,
        subdataset=dataset_name,
    )
    architecture = model.pretrained_cfg['architecture']
    torch.save(validation_features_info, f"pacs_{dataset_name}_features_labels_{architecture}.pt")

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

  0%|          | 0/62 [00:00<?, ?it/s]