In [1]:
from PIL import Image
import os
from huggingface_hub import hf_hub_download
import numpy as np
from pycocotools.coco import COCO
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
class EfficientNetEncoder(nn.Module):
    """EfficientNet-V2 feature extractor followed by a 1x1 projection to `c_latent` channels.

    Args:
        c_latent: Number of output channels produced by the projection head.
        effnet: "efficientnet_v2_s" selects the small backbone; any other
            value selects the large (`efficientnet_v2_l`) backbone.
    """

    def __init__(self, c_latent=16, effnet="efficientnet_v2_s"):
        super().__init__()
        use_small = effnet == "efficientnet_v2_s"
        if not use_small:
            print("Using EffNet L.")
        if use_small:
            build_backbone = torchvision.models.efficientnet_v2_s
        else:
            build_backbone = torchvision.models.efficientnet_v2_l
        # Only the convolutional feature stack is kept, switched to eval mode.
        pretrained = build_backbone(weights='DEFAULT')
        self.backbone = pretrained.features.eval()

        # 1x1 convolution from the 1280 backbone channels down to c_latent,
        # followed by batch normalisation of the projected channels.
        projection = nn.Conv2d(1280, c_latent, kernel_size=1, bias=False)
        normalisation = nn.BatchNorm2d(c_latent)
        self.mapper = nn.Sequential(projection, normalisation)

    def forward(self, x):
        """Encode `x` and return the projected features shifted by 1 and divided by 42."""
        features = self.backbone(x)
        latents = self.mapper(features)
        return (latents + 1.) / 42.

In [4]:
# Directory of COCO train2017 images; files are opened from here with PIL.
image_dir = "../../Data/coco_train2017"
# COCO train2017 captions annotation file, passed to pycocotools' COCO().
caption_file = "../../Data/coco_captions/captions_train2017.json"

In [5]:
# Load the caption annotations into memory and let pycocotools build its index.
coco_captions = COCO(caption_file)

loading annotations into memory...
Done (t=0.56s)
creating index...
index created!


In [6]:
# Peek at the first image id returned by the caption index.
coco_captions.getImgIds()[0]

391895

In [7]:
def _make_effnet_preprocess(size):
    """Return the encoder input pipeline for a square ``size`` x ``size`` crop.

    The pipeline resizes with antialiased bilinear interpolation, center-crops
    to ``size``, and normalises each channel with the fixed mean/std below.
    It operates on tensors (it contains no ToTensor step).
    """
    transforms = torchvision.transforms
    return transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
        transforms.CenterCrop(size),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        ),
    ])


# First stage: PIL image -> tensor, resized and center-cropped to 512x512.
init_preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(512),
    torchvision.transforms.CenterCrop(512),
    torchvision.transforms.ToTensor()
])

# Second stage: the same pipeline at two encoder input resolutions.
effnet_preprocess = _make_effnet_preprocess(384)
effnet_preprocess_768 = _make_effnet_preprocess(768)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints that come from a trusted source.
pretrained_checkpoint = torch.load("./models/model_v2_stage_b.pt", map_location=device)

# Build the small-backbone encoder, restore its trained weights from the
# checkpoint, and freeze it for inference.
effnet = EfficientNetEncoder(effnet="efficientnet_v2_s").to(device)
effnet.load_state_dict(pretrained_checkpoint['effnet_state_dict'])
effnet.eval().requires_grad_(False)

EfficientNetEncoder(
  (backbone): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=T

In [8]:
# Sanity check: encode a single image through the 384px pipeline and
# inspect the shape of the resulting latent.
sample_path = os.path.join(image_dir, "000000000009.jpg")
sample_rgb = Image.open(sample_path).convert('RGB')
sample_batch = effnet_preprocess(init_preprocess(sample_rgb)).unsqueeze(0)
effnet_features = effnet(sample_batch.to(device))
print(effnet_features.shape)

torch.Size([1, 16, 12, 12])


In [9]:
# Sanity check: encode the same image through the 768px pipeline and
# inspect the shape of the resulting latent.
sample_path = os.path.join(image_dir, "000000000009.jpg")
sample_rgb = Image.open(sample_path).convert('RGB')
sample_batch = effnet_preprocess_768(init_preprocess(sample_rgb)).unsqueeze(0)
effnet_features = effnet(sample_batch.to(device))
print(effnet_features.shape)

torch.Size([1, 16, 24, 24])


In [10]:
def batch_encode(image_dir, output_filename, preprocess=effnet_preprocess, feats_per_file=200000, batch_size=64):
    """Encode every file in `image_dir` with `effnet` and save the latents as one .npy file.

    Each file is opened with PIL, converted to RGB, passed through
    `init_preprocess`, stacked into batches of `batch_size`, transformed with
    `preprocess`, and encoded on `device`. The concatenated latents are written
    to ``./stage_b_latents/<output_filename>``.

    Args:
        image_dir: Directory whose entries are all treated as image files.
        output_filename: File name of the saved array inside ``./stage_b_latents``.
        preprocess: Transform applied to each stacked batch before encoding.
        feats_per_file: Currently unused; kept so existing callers keep working.
        batch_size: Number of images encoded per forward pass.

    Returns:
        None. The latents are written to disk as a side effect. Rows follow the
        order in which ``os.listdir(image_dir)`` returned the file names.
    """
    output_dir = "./stage_b_latents"
    # Create the target directory before the long encoding loop so a missing
    # folder cannot make np.save fail after all the work has been done.
    os.makedirs(output_dir, exist_ok=True)

    # List the directory once: re-listing it for every batch is wasted work
    # and could shift the batches if the directory changes during the run.
    filenames = os.listdir(image_dir)

    latent_tensors = []
    with torch.no_grad():
        for start in tqdm(range(0, len(filenames), batch_size)):
            images = []
            for filename in filenames[start:start + batch_size]:
                # Close each image file explicitly once it has been decoded.
                with Image.open(os.path.join(image_dir, filename)) as img:
                    images.append(init_preprocess(img.convert('RGB')).unsqueeze(0))

            batch = preprocess(torch.cat(images, dim=0)).to(device)
            # Move results to the CPU so latents do not accumulate on the device.
            latent_tensors.append(effnet(batch).detach().cpu())

    np.save(f"{output_dir}/{output_filename}", torch.cat(latent_tensors, dim=0).numpy(), allow_pickle=True)

In [12]:
# Encode the whole image directory with the 384px pipeline and save the latents under ./stage_b_latents/.
batch_encode(image_dir, output_filename="coco_train2017_latents.npy", preprocess=effnet_preprocess)

100%|██████████| 1849/1849 [26:45<00:00,  1.15it/s]


In [13]:
# Encode the whole image directory again with the 768px pipeline, saved under a separate file name.
batch_encode(image_dir, output_filename="coco_train2017_latents_v2.npy", preprocess=effnet_preprocess_768)

100%|██████████| 1849/1849 [39:49<00:00,  1.29s/it]
