In [1]:
import numpy as np
import pandas as pd
import torch

In [2]:
data = pd.read_csv('prepared_data.csv')

In [61]:
test = pd.read_csv('prepared_test.csv')

In [3]:
mapping = pd.read_csv('mapping_ids.csv')

In [66]:
TRAIN_IMG = 'train_images_archive/'
TEST_IMG = 'test_images_archive/'

In [5]:
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

extractor = AutoFeatureExtractor.from_pretrained("therealcyberlord/stanford-car-vit-patch16")
model = AutoModelForImageClassification.from_pretrained("therealcyberlord/stanford-car-vit-patch16")



In [68]:
from PIL import Image

def get_img_embedding(id):
    embedding = np.zeros((1, 197, 768))
    count = 0
    img_ids = mapping[mapping['id'] == id]['image_id']
    for img_id in img_ids:
        try:
            image = Image.open(TRAIN_IMG + str(img_id) + '.jpg')
            inputs = extractor(images=image, return_tensors='pt')
            inputs = {name: tensor.to('cuda') for name, tensor in inputs.items()}
            outputs = model.base_model(**inputs)
            embedding += outputs.last_hidden_state.cpu().detach().numpy()
            count += 1
        except:
            pass
        
    return embedding.mean(axis=1).flatten() / (count if count > 0 else 1)

def get_test_img_embedding(id):
    embedding = np.zeros((1, 197, 768))
    count = 0
    img_ids = mapping[mapping['id'] == id]['image_id']
    for img_id in img_ids:
        try:
            image = Image.open(TEST_IMG + str(img_id) + '.jpg')
            inputs = extractor(images=image, return_tensors='pt')
            inputs = {name: tensor.to('cuda') for name, tensor in inputs.items()}
            outputs = model.base_model(**inputs)
            embedding += outputs.last_hidden_state.cpu().detach().numpy()
            count += 1
        except:
            pass

    return embedding.mean(axis=1).flatten() / (count if count > 0 else 1)

In [39]:
from tqdm import tqdm

In [40]:
model = model.to('cuda')
model.base_model.eval()
model.eval()

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [58]:
with torch.no_grad():
    data_img_embeddings = data['id'].apply(get_img_embedding)

In [59]:
from sklearn.decomposition import PCA

embeddings_matrix = np.array(data_img_embeddings.tolist())

# Initialize PCA
pca = PCA(n_components=128)

# Fit and transform the data
reduced_embeddings = pca.fit_transform(embeddings_matrix)

In [60]:
np.save('img_embeddings.npy', reduced_embeddings)

In [62]:
del data_img_embeddings
del embeddings_matrix

In [69]:
with torch.no_grad():
    test_img_embeddings = test['id'].apply(get_test_img_embedding)

In [70]:
embeddings_matrix = np.array(test_img_embeddings.tolist())
reduced_embeddings = pca.transform(embeddings_matrix)

In [71]:
np.save('test_img_embeddings.npy', reduced_embeddings)

In [42]:
# get rows with not all zeros

print(data_img_embeddings[data_img_embeddings.apply(lambda x: np.sum(x) != 0)])

31      [-0.6503347507754548, 0.41770094290664583, 0.1...
168     [-0.20116988210982534, 0.2564373272659903, 0.3...
223     [-0.01812519282623269, 0.23617112564138168, -0...
325     [-0.015530212395718123, 0.8052608835690164, 0....
430     [0.009833476389471928, 0.41649819359830814, -0...
437     [-0.5263142807433884, 0.5468882173018395, -0.3...
447     [0.16715555636641843, 0.29856856315957286, -0....
471     [0.13455632621388067, 0.3916310768624672, -0.3...
502     [-0.5850567846782908, 0.45216989319586565, 0.0...
521     [-0.5947344552093691, 0.4608764386721859, 0.12...
600     [-0.01810001980494941, 0.3005775533261923, 0.2...
659     [-0.04033050550249763, 0.40188624662734507, -0...
717     [-0.02304378238760655, 0.28246198238952674, 0....
743     [-0.24775981457517948, 0.2860230819318402, 0.0...
859     [-0.04446852575481732, 0.6772535775876178, -0....
1000    [-0.08668337500872215, 0.21097299145680973, 0....
Name: id, dtype: object


In [52]:
? model