In [1]:
import time
import torch
from PIL import Image
from pl_bolts.models.autoencoders import AE
from pl_bolts.models.autoencoders import VAE
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised import SwAV
from torchvision import transforms

# Config

In [2]:
# Input image fed to the encoder cells below (relative to the kernel's working directory).
IMAGE_PATH = "data/image_1.jpg"

# Utils

In [3]:
def image_loader(image_name, device=None):
    """Load an image file as a one-image float tensor batch.

    Args:
        image_name: Path to the image file (anything ``PIL.Image.open`` accepts).
        device: Target ``torch.device`` or device string. When ``None`` (default),
            CUDA is used if available, otherwise CPU -- the original behaviour.
            Pass the device your models live on to avoid device-mismatch errors.

    Returns:
        A float32 tensor of shape ``(1, 3, H, W)``: ``convert('RGB')`` forces three
        channels, ``ToTensor`` yields ``(C, H, W)``, and ``unsqueeze(0)`` adds the
        batch dimension.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loader = transforms.Compose([transforms.ToTensor()])
    # The context manager closes the file handle; convert() returns an in-memory copy,
    # so the RGB image stays valid after the file is closed.
    with Image.open(image_name) as img:
        rgb = img.convert('RGB')
    image = loader(rgb).unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
    return image.to(device, torch.float)

---

In [4]:
# Load the input image once as a batch of one; the encoder cells below all consume `image`.
image = image_loader(image_name=IMAGE_PATH)

# AE (CIFAR-10)

In [5]:
print(f'AE: {AE.pretrained_weights_available()}')

# init
# NOTE(review): `from_pretrained` returns a new object that replaces `ae`, so the
# `input_height=512` given to the constructor is probably not used by the loaded
# model -- confirm against the pl_bolts AE source.
# TODO: try enc_type='resnet50' (check pretrained_weights_available() for weights first).
ae = AE(input_height=512)
ae = ae.from_pretrained('cifar10-resnet18')
# Keep the encoder on the same device as `image` (image_loader may have chosen CUDA).
ae_encoder = ae.encoder.to(image.device)
ae_encoder.eval()

# inference: no_grad avoids building an autograd graph for pure feature extraction
ae_time = time.perf_counter()
with torch.no_grad():
    ae_features = ae_encoder(image)
if image.device.type == 'cuda':
    torch.cuda.synchronize()  # CUDA kernels are asynchronous; wait before stopping the clock
ae_elapsed = time.perf_counter() - ae_time

# result
print(f'ae_features: {ae_features.shape}')
print(f"ae_time: {round(ae_elapsed, 3)}")

AE: ['cifar10-resnet18']


Downloading: "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/ae/ae-cifar10/checkpoints/epoch%3D96.ckpt" to /home/jovyan/.cache/torch/hub/checkpoints/epoch%3D96.ckpt


  0%|          | 0.00/228M [00:00<?, ?B/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


ae_features: torch.Size([1, 512])
ae_time: 10.897


# VAE (CIFAR-10)

In [6]:
print(f'VAE: {VAE.pretrained_weights_available()}')

# init
# NOTE(review): `from_pretrained` returns a new object that replaces `vae`, so the
# `input_height=512` given to the constructor is probably not used -- confirm.
# TODO: try the 'stl10-resnet18' checkpoint instead of 'cifar10-resnet18'.
vae = VAE(input_height=512)
vae = vae.from_pretrained('cifar10-resnet18')
# Keep the encoder on the same device as `image` (image_loader may have chosen CUDA).
vae_encoder = vae.encoder.to(image.device)
vae_encoder.eval()

# inference: no_grad avoids building an autograd graph for pure feature extraction
vae_time = time.perf_counter()
with torch.no_grad():
    vae_features = vae_encoder(image)
if image.device.type == 'cuda':
    torch.cuda.synchronize()  # CUDA kernels are asynchronous; wait before stopping the clock
vae_elapsed = time.perf_counter() - vae_time

# result
print(f'vae_features: {vae_features.shape}')
print(f"vae_time: {round(vae_elapsed, 3)}")

VAE: ['cifar10-resnet18', 'stl10-resnet18']


Downloading: "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/vae/vae-cifar10/checkpoints/epoch%3D89.ckpt" to /home/jovyan/.cache/torch/hub/checkpoints/epoch%3D89.ckpt


  0%|          | 0.00/230M [00:00<?, ?B/s]

vae_features: torch.Size([1, 512])
vae_time: 14.146


# SimCLR (ImageNet)

In [7]:
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'  # IMAGENET

# init
# strict=False: state-dict keys that don't match the module are not reported as errors.
# NOTE(review): this can silently leave some weights at their initial values -- verify.
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
# Keep the encoder on the same device as `image` (image_loader may have chosen CUDA).
simclr_encoder = simclr.encoder.to(image.device)
simclr_encoder.eval()

# inference: no_grad avoids building an autograd graph for pure feature extraction
simclr_time = time.perf_counter()
with torch.no_grad():
    simclr_features = simclr_encoder(image)
if image.device.type == 'cuda':
    torch.cuda.synchronize()  # CUDA kernels are asynchronous; wait before stopping the clock
simclr_elapsed = time.perf_counter() - simclr_time
# The encoder output is indexed with [0]; presumably it returns a list of outputs
# whose first element is the feature tensor -- confirm against pl_bolts.
simclr_features = simclr_features[0]

# result
print(f'simclr_features: {simclr_features.shape}')
print(f"simclr_time: {round(simclr_elapsed, 3)}")

Downloading: "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt" to /home/jovyan/.cache/torch/hub/checkpoints/simclr_imagenet.ckpt


  0%|          | 0.00/229M [00:00<?, ?B/s]

simclr_features: torch.Size([1, 2048])
simclr_time: 4.602


# SwAV (ImageNet)

In [8]:
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar'  # IMAGENET
# Alternative checkpoint (swap in to use STL-10 weights):
# weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar' # STL-10

# init
swav = SwAV.load_from_checkpoint(weight_path, strict=True)
# Keep the model on the same device as `image` (image_loader may have chosen CUDA).
swav = swav.to(image.device)
swav.eval()

# inference: unlike the cells above, the whole module is called (not a `.encoder`
# attribute); no_grad avoids building an autograd graph for pure feature extraction.
swav_time = time.perf_counter()
with torch.no_grad():
    swav_features = swav(image)
if image.device.type == 'cuda':
    torch.cuda.synchronize()  # CUDA kernels are asynchronous; wait before stopping the clock
swav_elapsed = time.perf_counter() - swav_time

# result
print(f'swav_features: {swav_features.shape}')
print(f"swav_time: {round(swav_elapsed, 3)}")

Downloading: "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar" to /home/jovyan/.cache/torch/hub/checkpoints/swav_imagenet.pth.tar


  0%|          | 0.00/322M [00:00<?, ?B/s]

swav_features: torch.Size([1, 2048])
swav_time: 9.044
