In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torchvision import models, transforms

from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import normalize

In [19]:
data_dir = r"data\challenge_three_bogatyrs\dataset"
output_csv = r"data\384_star_maps_2\submission.csv"

num_clusters = 983
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    )
])

In [7]:
model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
model.fc = nn.Identity()
model = model.to(device)
model.eval()

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to C:\Users\xz1v/.cache\torch\hub\checkpoints\resnet152-394f9c45.pth
100%|███████████████████████████████████████████████████████████████████████████████| 230M/230M [00:09<00:00, 25.4MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [9]:
filenames = sorted([f for f in os.listdir(data_dir) if f.endswith(".png")])

embeddings = []

with torch.no_grad():
    for fname in tqdm(filenames):
        img_path = os.path.join(data_dir, fname)
        img = Image.open(img_path).convert("RGB")
        x = transform(img).unsqueeze(0).to(device)
        emb = model(x)
        embeddings.append(emb.cpu().numpy()[0])

embeddings = np.stack(embeddings)

100%|██████████████████████████████████████████████████████████████████████████████| 9605/9605 [05:10<00:00, 30.95it/s]


In [10]:
embeddings = normalize(embeddings, norm="l2")

In [11]:
clustering = AgglomerativeClustering(
    n_clusters=num_clusters,
    metric="cosine",
    linkage="average"
)

labels = clustering.fit_predict(embeddings)

In [21]:
df = pd.DataFrame({
    "filename": filenames,
    "label": labels
})

df.to_csv(output_csv, index=False)