In [1]:
import torch
import torchvision.transforms.v2 as transforms
from torchvision import models
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
import pickle
import numpy as np
import firebase_admin
from firebase_admin import credentials, firestore,storage
import requests
from PIL import Image, UnidentifiedImageError
from io import BytesIO
import tempfile
import os

In [2]:
cred = credentials.Certificate("imagequest-aab50-firebase-adminsdk-fbsvc-44dd473055.json")
firebase_admin.initialize_app(cred,{"storageBucket": "imagequest-aab50.firebasestorage.app"})
db = firestore.client()
bucket = storage.bucket()

In [3]:
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
model.eval()


Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

In [4]:
train_nodes, eval_nodes = get_graph_node_names(model)
print(train_nodes)
print(eval_nodes)

['x', 'getitem', 'unsqueeze', 'mul', 'add', 'getitem_1', 'unsqueeze_1', 'mul_1', 'add_1', 'getitem_2', 'unsqueeze_2', 'mul_2', 'add_2', 'cat', 'Conv2d_1a_3x3.conv', 'Conv2d_1a_3x3.bn', 'Conv2d_1a_3x3.relu', 'Conv2d_2a_3x3.conv', 'Conv2d_2a_3x3.bn', 'Conv2d_2a_3x3.relu', 'Conv2d_2b_3x3.conv', 'Conv2d_2b_3x3.bn', 'Conv2d_2b_3x3.relu', 'maxpool1', 'Conv2d_3b_1x1.conv', 'Conv2d_3b_1x1.bn', 'Conv2d_3b_1x1.relu', 'Conv2d_4a_3x3.conv', 'Conv2d_4a_3x3.bn', 'Conv2d_4a_3x3.relu', 'maxpool2', 'Mixed_5b.branch1x1.conv', 'Mixed_5b.branch1x1.bn', 'Mixed_5b.branch1x1.relu', 'Mixed_5b.branch5x5_1.conv', 'Mixed_5b.branch5x5_1.bn', 'Mixed_5b.branch5x5_1.relu', 'Mixed_5b.branch5x5_2.conv', 'Mixed_5b.branch5x5_2.bn', 'Mixed_5b.branch5x5_2.relu', 'Mixed_5b.branch3x3dbl_1.conv', 'Mixed_5b.branch3x3dbl_1.bn', 'Mixed_5b.branch3x3dbl_1.relu', 'Mixed_5b.branch3x3dbl_2.conv', 'Mixed_5b.branch3x3dbl_2.bn', 'Mixed_5b.branch3x3dbl_2.relu', 'Mixed_5b.branch3x3dbl_3.conv', 'Mixed_5b.branch3x3dbl_3.bn', 'Mixed_5b.bran



In [5]:
data_transforms = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),    
    transforms.RandomHorizontalFlip(),
    transforms.ToImage(), 
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 
])

In [6]:
return_nodes = {
    'Mixed_7c': 'Mixed_7c',  
}
features = create_feature_extractor(model, return_nodes=return_nodes)
collection = db.collection('main2').stream(timeout=300)




In [7]:
extract = []
for x in collection:
    image = x.to_dict()
    image_url = image['url']
    image_id = image['image_id']
    try:
        get_image = requests.get(image_url, timeout=300)
        if get_image.status_code != 200:
            print("CANNOT GET IMAGE")
            continue

        read_image = Image.open(BytesIO(get_image.content)).convert('RGB')
        image_transform = data_transforms(read_image).unsqueeze(0)

        with torch.no_grad():
            output = features(image_transform)

        extraction = output['Mixed_7c'].cpu().numpy()
        reduce = extraction.reshape(1, -1)
        extract.append(reduce)

        print(f"FEATURES EXTRACTED: {image_id}: {reduce.shape}")
    except UnidentifiedImageError:
        print("UNIDENTIFIED IMAGE ERROR")
    except Exception as e:
        print("ERROR")
stack_list = np.vstack(extract)

FEATURES EXTRACTED: chess_dataset2/00000021 (2).jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000030 (2).jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000086 (2).jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000007 (5).jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000072 (3).jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000002.JPG: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000052.jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000075 (2).jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000169.jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000112.jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000027.png: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000071 (4).jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000073.jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000028 (2).jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000045.jpg: (1, 131072)
FEATURES EXTRACTED: chess_dataset2/00000194 (2).jpg: 

In [8]:
with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as temp_file:
        pickle.dump(stack_list, temp_file)
        temp_filename = temp_file.name 
blob = bucket.blob("feature_vectors/inception_features.pkl")
blob.upload_from_filename(temp_filename)
os.remove(temp_filename)