In [1]:
import torch
import torchvision.transforms.v2 as transforms
from torchvision import models
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
import pickle
import numpy as np
import firebase_admin
from firebase_admin import credentials, firestore,storage
import requests
from PIL import Image, UnidentifiedImageError
from io import BytesIO
import tempfile
import os

In [2]:
# Initialise the Firebase Admin SDK only when no app has been registered yet, so
# re-running this cell in the same kernel does not try to initialise it twice.
# NOTE(review): the service-account key path is hard-coded and relative to the
# working directory; consider reading it from configuration instead.
firebase_is_ready = bool(firebase_admin._apps)
if not firebase_is_ready:
    cred = credentials.Certificate(
        "imagequest-aab50-firebase-adminsdk-fbsvc-44dd473055.json"
    )
    app_options = {"storageBucket": "imagequest-aab50.firebasestorage.app"}
    firebase_admin.initialize_app(cred, app_options)

# Handles used by the later cells: Firestore for the image metadata documents and
# a storage bucket for uploading the pickled feature vectors.
db = firestore.client()
bucket = storage.bucket()

In [4]:
# Load VGG16 with torchvision's default pretrained weights and put it in
# evaluation mode. `eval()` returns the module itself, so keeping the call as the
# last expression of the cell also renders the architecture below.
pretrained_weights = models.VGG16_Weights.DEFAULT
model = models.vgg16(weights=pretrained_weights)
model.eval()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [5]:
# List the graph node names reported for the model in train mode and in eval
# mode; these are the names that can be requested from the feature extractor.
train_nodes, eval_nodes = get_graph_node_names(model)
for node_names in (train_nodes, eval_nodes):
    print(node_names)

['x', 'features.0', 'features.1', 'features.2', 'features.3', 'features.4', 'features.5', 'features.6', 'features.7', 'features.8', 'features.9', 'features.10', 'features.11', 'features.12', 'features.13', 'features.14', 'features.15', 'features.16', 'features.17', 'features.18', 'features.19', 'features.20', 'features.21', 'features.22', 'features.23', 'features.24', 'features.25', 'features.26', 'features.27', 'features.28', 'features.29', 'features.30', 'avgpool', 'flatten', 'classifier.0', 'classifier.1', 'classifier.2', 'classifier.3', 'classifier.4', 'classifier.5', 'classifier.6']
['x', 'features.0', 'features.1', 'features.2', 'features.3', 'features.4', 'features.5', 'features.6', 'features.7', 'features.8', 'features.9', 'features.10', 'features.11', 'features.12', 'features.13', 'features.14', 'features.15', 'features.16', 'features.17', 'features.18', 'features.19', 'features.20', 'features.21', 'features.22', 'features.23', 'features.24', 'features.25', 'features.26', 'fea

In [6]:
# Per-channel statistics applied by the final normalisation step.
# NOTE(review): these look like the usual ImageNet mean/std used with torchvision
# pretrained weights — confirm against the weights' documented preprocessing.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every downloaded image, in order: resize to 256,
# centre-crop to 224, convert to a tensor image, scale to float32, normalise.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(channel_mean, channel_std),
]
data_transforms = transforms.Compose(preprocessing_steps)

In [7]:
# Build a feature extractor that returns the activation of the 'features.28'
# node, exposed under the same key in the output dict.
return_nodes = {
    'features.28': 'features.28',
}
features = create_feature_extractor(model, return_nodes=return_nodes)

# `stream()` yields documents lazily and can be consumed only once. Materialise
# the results into a list so that:
#   * the extraction cell can be re-run without silently iterating an exhausted
#     stream (which would leave `extract` empty and overwrite the uploaded
#     features with nothing), and
#   * the Firestore stream is not held open while each image is downloaded and
#     pushed through the network.
collection = list(db.collection('main2').stream(timeout=300))

In [None]:
# For every Firestore document: download the referenced image, preprocess it,
# run it through the feature extractor and store the flattened 'features.28'
# activation next to the image's id and url. Failures are reported per image and
# skipped so that one bad document or download does not abort the whole run.
extract = []
for x in collection:
    image = x.to_dict() or {}
    image_url = image.get('url')
    image_id = image.get('image_id')
    # A document without the expected fields used to raise KeyError outside the
    # try-block and stop the loop; report it and move on instead.
    if not image_url or image_id is None:
        print(f"SKIPPING DOCUMENT {x.id}: missing 'url' or 'image_id'")
        continue

    try:
        get_image = requests.get(image_url, timeout=300)
        if get_image.status_code != 200:
            print(f"CANNOT GET IMAGE: {image_id}: HTTP {get_image.status_code}")
            continue

        # Force three channels so the 3-channel normalisation always applies.
        read_image = Image.open(BytesIO(get_image.content)).convert('RGB')
        # Add a leading batch dimension of size 1.
        image_transform = data_transforms(read_image).unsqueeze(0)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            output = features(image_transform)

        extraction = output['features.28'].cpu().numpy()
        # Flatten the activation into a single row vector of shape (1, N).
        flat_features = extraction.reshape(1, -1)
        extract.append({
            'image_id': image_id,
            'url': image_url,
            'features': flat_features.tolist(),
        })

        print(f"FEATURES EXTRACTED: {image_id}: {flat_features.shape}")
    except UnidentifiedImageError:
        # The payload downloaded fine but PIL could not decode it as an image.
        print(f"UNIDENTIFIED IMAGE ERROR: {image_id}")
    except Exception as e:
        # Keep the run best-effort, but surface what actually went wrong.
        print(f"ERROR: {image_id}: {e!r}")


FEATURES EXTRACTED: chess_dataset2/00000021 (2).jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000030 (2).jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000086 (2).jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000007 (5).jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000072 (3).jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000002.JPG: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000052.jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000075 (2).jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000169.jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000112.jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000027.png: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000071 (4).jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000073.jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000028 (2).jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000045.jpg: (1, 100352)
FEATURES EXTRACTED: chess_dataset2/00000194 (2).jpg: 

In [None]:
# Serialise the collected feature records to a temporary pickle file and upload
# it to the storage bucket. The file is created with delete=False so it can be
# closed before the upload reads it by name; the `finally` clause guarantees it
# is removed even when pickling or the upload fails.
# NOTE: consumers should only unpickle this blob from a trusted bucket, since
# `pickle.load` can execute arbitrary code.
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.pkl')
temp_filename = temp_file.name
try:
    with temp_file:
        pickle.dump(extract, temp_file)
    blob = bucket.blob("feature_vectors/vgg_features.pkl")
    blob.upload_from_filename(temp_filename)
finally:
    os.remove(temp_filename)