In [1]:
import torch
import numpy as np
from transformers import AutoFeatureExtractor, ViTModel
from PIL import Image, ImageEnhance
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm
from torchvision import transforms

In [2]:
# Load model and feature extractor, disabling automatic rescaling
model_name = "google/vit-base-patch32-384"
model = ViTModel.from_pretrained(model_name).eval()
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name, do_rescale=False)

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch32-384 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



In [3]:
# Move model to GPU if available for faster processing
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Custom Dataset Class with Centered Cropping Transformation
class ImageDataset(Dataset):
    def __init__(self, df, feature_extractor):
        self.df = df
        self.feature_extractor = feature_extractor
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor()
        ])
#         self.transform = transforms.Compose([
#             transforms.RandomRotation(degrees=15),
#             transforms.Resize((384, 384)),
#             transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
#             transforms.ToTensor()
#         ])
#         self.transform = transforms.Compose([
#             RandomResizedCrop(size=(384, 384), scale=(0.8, 1.0)),  # Randomly zoom in with 80%-100% crop scale
#             transforms.ToTensor()
#         ])
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = self.df.iloc[idx]["image_path"]
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        return inputs["pixel_values"].squeeze(0)

# Efficient DataLoader function
def create_dataloader(df, feature_extractor, batch_size=32, num_workers=4):
    dataset = ImageDataset(df, feature_extractor)
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)

# Processing function that returns numpy arrays with progress bar
def process_images(df, model, feature_extractor, batch_size=32):
    dataloader = create_dataloader(df, feature_extractor, batch_size=batch_size)
    features = []
    model.eval()  # Ensure model is in evaluation mode
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Processing Images"):
            batch = batch.to(device)  # Send batch to device (GPU if available)
            outputs = model(batch)
            cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()  # Extract CLS token to numpy
            features.append(cls_embeddings)
    return np.vstack(features)  # Stack all arrays into a single numpy array

In [4]:
# Load train and test dataframes
train_df = pd.read_csv("/kaggle/input/visual-taxonomy/train.csv")
test_df = pd.read_csv("/kaggle/input/visual-taxonomy/test.csv")
train_images="/kaggle/input/visual-taxonomy/train_images/"
test_images="/kaggle/input/visual-taxonomy/test_images/"
train_df["image_path"]=train_df["id"].apply(lambda Id: train_images+str(Id).zfill(6)+".jpg")
test_df["image_path"]=test_df["id"].apply(lambda Id: test_images+str(Id).zfill(6)+".jpg")

In [5]:
# Process images and get output as numpy arrays
train_features = process_images(train_df, model, feature_extractor, batch_size=32)
test_features = process_images(test_df, model, feature_extractor, batch_size=32)

# Optionally save or use features as numpy arrays
np.save("train_features.npy", train_features)
np.save("test_features.npy", test_features)

  self.pid = os.fork()
  self.pid = os.fork()
Processing Images: 100%|██████████| 2195/2195 [12:16<00:00,  2.98it/s]
Processing Images: 100%|██████████| 944/944 [05:22<00:00,  2.92it/s]
