In [None]:
import pickle
import random
import time
from pathlib import Path

import tensorflow as tf
import tensorboard as tb

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn 
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter 

import cv2
import numpy as np
from tqdm import tqdm
from PIL import Image

import torchreid

# Replace TensorFlow's gfile module with TensorBoard's stub implementation.
# NOTE(review): presumably the known workaround that lets torch.utils.tensorboard's
# SummaryWriter.add_embedding work when TensorFlow is also installed (otherwise it can fail
# looking up gfile.get_filesystem) — confirm it is still needed.
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

In [None]:
class BackboneFeatureExtractor(nn.Module):
    """Wrap a classification backbone so its forward pass returns features instead of logits.

    The classification head stored on ``backbone`` under the attribute ``head_attr`` is
    replaced in place with ``nn.Identity``. ``forward`` then returns whatever tensor the
    head would have received.

    Args:
        backbone: Model whose classification head lives in the attribute named ``head_attr``.
            The backbone object itself is modified.
        head_attr: Name of the head attribute to strip. Defaults to ``'fc'`` (the name used
            by torchvision ResNets). torchvision VGG models keep their head under
            ``'classifier'``.

    Raises:
        AttributeError: If ``backbone`` has no attribute named ``head_attr``.
    """

    def __init__(self, backbone: nn.Module, head_attr: str = 'fc'):
        super().__init__()
        # Without this check, assigning a missing attribute just adds a new, unused module.
        # The real head would stay in place and forward() would silently keep returning logits.
        if not hasattr(backbone, head_attr):
            raise AttributeError(
                f'{type(backbone).__name__} has no {head_attr!r} attribute to replace; '
                'pass head_attr= naming its classification head'
            )
        self.backbone = backbone
        setattr(self.backbone, head_attr, nn.Identity())

    def forward(self, x):
        """Run the headless backbone on ``x`` and return its output features."""
        return self.backbone(x)

In [None]:
# Standard ImageNet per-channel statistics used to normalize inputs.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Target size as (height, width), giving a tall crop shape.
REID_INPUT_SIZE = (256, 128)

normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# PIL image -> resized float tensor of shape (C, 256, 128), then per-channel normalized.
# The 3-element mean/std means the input image must have 3 channels (RGB).
_preprocess_steps = [
    transforms.Resize(REID_INPUT_SIZE),
    transforms.ToTensor(),
    normalize,
]
preprocess_image = transforms.Compose(_preprocess_steps)

In [None]:
# backbone = torchvision.models.vgg19(pretrained=True)
# model = BackboneFeatureExtractor(backbone).cuda(1)
# _ = model.eval()

In [None]:
# Build the torchreid ResNet-50 re-ID model and load its weights from a local checkpoint.
REID_CHECKPOINT_PATH = Path('../resnet50_msmt17.pth')
REID_DEVICE = torch.device('cuda', 1)  # GPU index 1, as before; change here to move the model

model = torchreid.models.build_model(
    name='resnet50',
    # Must equal the classifier size stored in the checkpoint, or the strict load below fails.
    # NOTE(review): presumably the number of MSMT17 identities the checkpoint was trained on — confirm.
    num_classes=4101,
    loss='softmax',
    # The strict load_state_dict below overwrites every weight, so downloading ImageNet
    # weights first would be wasted work.
    pretrained=False,
)

# Load onto the CPU first so loading doesn't depend on the device the checkpoint was saved from.
# torch.load unpickles the file, which can execute arbitrary code — only load trusted checkpoints.
checkpoint = torch.load(str(REID_CHECKPOINT_PATH), map_location='cpu')
# Accept either a bare state_dict or a training checkpoint that nests it under 'state_dict'.
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint
# Drop the 'module.' prefix that nn.DataParallel adds to parameter names; other keys are unchanged.
_dp_prefix = 'module.'
state_dict = {
    (key[len(_dp_prefix):] if key.startswith(_dp_prefix) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model = model.to(REID_DEVICE)
model.eval();

In [None]:
# Generate an embedding for every image crop and save it next to the image as `<stem>.npy`.
images_path = Path('/media/svakhreev/fast/generated_images/crops/body/')
IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg')

# Snapshot and sort the listing up front. The loop writes new .npy files into this same
# directory, and a fixed list also gives a deterministic order and a known total for tqdm.
image_paths = sorted(p for p in images_path.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES)
# Send inputs to whatever device the model lives on instead of hard-coding a GPU index.
device = next(model.parameters()).device

with torch.no_grad():
    for image_filename in tqdm(image_paths, desc='embedding'):
        # `with` closes the file handle after use. convert('RGB') turns grayscale/palette/RGBA
        # images into 3 channels, which the 3-element Normalize in preprocess_image requires.
        with Image.open(image_filename) as image:
            input_data = preprocess_image(image.convert('RGB')).unsqueeze(0).to(device)
        features = model(input_data).cpu().numpy()
        # Save the embeddings as .npy at the given path; the leading batch dimension of size 1 is kept.
        # NOTE(review): crops sharing a stem (e.g. a.png and a.jpg) overwrite each other's .npy.
        np.save(str(images_path / f'{image_filename.stem}.npy'), features)
