In [1]:
# SPRINT VERSION: EfficientNet-B1
import torch, timm, pandas as pd, numpy as np, os
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# --- SPRINT CONFIGURATION ---
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
MODEL_NAME = "efficientnet_b1"  # smaller, faster backbone for the sprint run
IMAGE_SIZE = 240                # square input resolution fed to the model
BATCH_SIZE = 32                 # larger batches for throughput
print(f"Using device: {DEVICE}")

# Pretrained backbone built with num_classes=0 (timm's option for a model
# without a classifier head), moved to DEVICE and put in eval() mode for inference.
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=0)
model = model.to(DEVICE)
model.eval()

# Preprocessing: resize to exactly IMAGE_SIZE x IMAGE_SIZE (a tuple size does not
# preserve aspect ratio), convert to a tensor, then normalise with the standard
# ImageNet channel mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Batched image -> embedding extraction using the module-level model/transform.
def get_image_embeddings(image_paths):
    """Embed images with the module-level ``model`` in batches of ``BATCH_SIZE``.

    Each path is opened with PIL, converted to RGB and preprocessed with the
    module-level ``transform``. Each batch is then stacked and run through
    ``model`` on ``DEVICE`` under ``torch.no_grad()``.

    Images that cannot be opened or transformed are replaced by an all-zero
    ``(3, IMAGE_SIZE, IMAGE_SIZE)`` tensor. This keeps output row ``i`` aligned
    with ``image_paths[i]``. Failures are collected and reported once at the end
    rather than being silently discarded.

    Args:
        image_paths: Sequence of image file paths; must support ``len`` and slicing.

    Returns:
        A 2-D ``np.ndarray`` with one row per input path, in input order.
        For an empty input, an array with zero rows is returned instead of raising.
    """
    # len() rather than truthiness so numpy arrays / pandas Series also work.
    if len(image_paths) == 0:
        # np.vstack([]) would raise ValueError on an empty list.
        # NOTE(review): relies on timm models exposing ``num_features``; falls back to 0 columns.
        return np.empty((0, getattr(model, "num_features", 0)), dtype=np.float32)

    all_embeddings = []
    failed = []  # (path, exception) for every image replaced by a zero tensor
    for start in tqdm(range(0, len(image_paths), BATCH_SIZE), desc="Processing Image Batches (SPRINT)"):
        batch_images = []
        for path in image_paths[start:start + BATCH_SIZE]:
            try:
                # ``with`` releases the underlying file handle deterministically.
                with Image.open(path) as img:
                    batch_images.append(transform(img.convert("RGB")))
            except Exception as exc:  # deliberate best-effort: keep going, but record why
                failed.append((path, exc))
                # Zero tensor (in normalised space) keeps the batch/row alignment intact.
                batch_images.append(torch.zeros((3, IMAGE_SIZE, IMAGE_SIZE)))
        with torch.no_grad():
            embeddings = model(torch.stack(batch_images).to(DEVICE))
        all_embeddings.append(embeddings.cpu().numpy())

    if failed:
        print(
            f"WARNING: {len(failed)} of {len(image_paths)} image(s) could not be loaded "
            "and were embedded as zero tensors. First few:"
        )
        for path, exc in failed[:5]:
            print(f"  {path}: {type(exc).__name__}: {exc}")
    return np.vstack(all_embeddings)

# --- DATA LOADING ---
# Stack train and test into one frame; ``is_train`` records each row's origin.
train_df = pd.read_csv("../data/raw/dataset/train.csv")
test_df = pd.read_csv("../data/raw/dataset/test.csv")
full_df = pd.concat([train_df.assign(is_train=1), test_df.assign(is_train=0)], ignore_index=True)

# Local image files are assumed to be named after the last path component of ``image_link``.
full_df['image_filename'] = full_df['image_link'].apply(os.path.basename)

# All train rows precede all test rows in ``full_df`` (concat order above), so
# concatenating train paths then test paths gives exactly one path per
# ``full_df`` row, in the same order. Row ``i`` of the embeddings computed from
# ``all_image_paths`` therefore describes ``full_df.iloc[i]``.
train_paths = '../data/images/train/' + full_df[full_df['is_train'] == 1]['image_filename']
test_paths = '../data/images/test/' + full_df[full_df['is_train'] == 0]['image_filename']
all_image_paths = pd.concat([train_paths, test_paths]).tolist()

print("Starting SPRINT image embedding generation...")
image_embeddings = get_image_embeddings(all_image_paths)
# One feature column per embedding dimension; the "img_b1_" prefix tags the
# columns with the backbone (efficientnet_b1) that produced them.
embedding_cols = ["img_b1_" + str(k) for k in range(image_embeddings.shape[-1])]
image_features_df = pd.DataFrame(data=image_embeddings, columns=embedding_cols)

# --- SAVE TO A NEW FILE ---
PROCESSED_DIR = "../data/processed"
SAVE_PATH = os.path.join(PROCESSED_DIR, "image_features_b1_sprint.parquet")
# Create the output directory if needed (no error if it already exists).
os.makedirs(PROCESSED_DIR, exist_ok=True)
# index=False: only the embedding columns are written; row order is the only link back to the inputs.
image_features_df.to_parquet(SAVE_PATH, index=False)
print(f"\n SPRINT Image features saved to: {SAVE_PATH}")


Using device: cpu
Starting SPRINT image embedding generation...


Processing Image Batches (SPRINT): 100%|██████████| 4688/4688 [5:13:22<00:00,  4.01s/it]   



 SPRINT Image features saved to: ../data/processed\image_features_b1_sprint.parquet
