In [60]:
# imports
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from torchvision import transforms


#Set GPU as device to use.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [61]:
df_train = pd.read_csv("../aml-2025-feathers-in-focus/train_images.csv")
df_train.head()

complete_bird_attributes = pd.read_csv("complete_bird_attributes.csv", index_col='class_key')

In [62]:
num_concepts = complete_bird_attributes.shape[1]

# test
example_label = df_train["label"].iloc[0]
example_concepts = complete_bird_attributes.loc[example_label]
print(example_concepts)


has_bill_shape::curved_(up_or_down)    0
has_bill_shape::dagger                 0
has_bill_shape::hooked                 0
has_bill_shape::needle                 0
has_bill_shape::hooked_seabird         1
                                      ..
has_crown_color::buff                  0
has_wing_pattern::solid                1
has_wing_pattern::spotted              0
has_wing_pattern::striped              0
has_wing_pattern::multi-colored        0
Name: 1, Length: 312, dtype: int64


In [63]:
class BirdConceptDataset(Dataset):
    def __init__(self, csv_df, attributes_df, images_root):
        self.df = csv_df
        self.attributes = attributes_df
        self.images_root = images_root
        
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Load label
        label = self.df.iloc[idx]["label"]
        
        # Load concept vector
        concept_vec = torch.tensor(self.attributes.loc[label].values, dtype=torch.float32)
        
        # Build full image path
        img_rel_path = self.df.iloc[idx]["image_path"]
        img_path = os.path.join(self.images_root, os.path.basename(img_rel_path))

        # Load image using your method
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)

        return image, concept_vec


In [64]:
train_images_dir = "../aml-2025-feathers-in-focus/train_images/train_images/"
train_images = [
    os.path.join(train_images_dir, f)
    for f in os.listdir(train_images_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))
]


In [65]:
train_dataset = BirdConceptDataset(
    csv_df=df_train,
    attributes_df=complete_bird_attributes,
    images_root=train_images_dir
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [None]:
class ConceptNet(nn.Module):
    def __init__(self, num_outputs):
        super().__init__()
        
        self.backbone = models.resnet50(pretrained=True)
        self.backbone.fc = nn.Identity()   # remove classification head
        
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_outputs)
            # nn.Sigmoid()   # multilabel output
        )
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.fc(x)
        return x


In [67]:
model = ConceptNet(num_outputs=num_concepts)
model = model.to(device)
criterion = nn.BCEWithLogitsLoss()    # output is already sigmoid
criterion = criterion.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)



In [76]:
num_epochs = 50

model.train()
for epoch in range(num_epochs):
    total_loss = 0

    for images, targets in train_loader:
        optimizer.zero_grad()

        images = images.to(device)
        targets = targets.to(device)
        preds = model(images)         # predictions (batch_size Ã— num_concepts)
        loss = criterion(preds, targets)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

    torch.save(model.state_dict(), "best_model.pth")


Epoch 1/50, Loss: 0.7226
Epoch 2/50, Loss: 0.7074
Epoch 3/50, Loss: 0.7053
Epoch 4/50, Loss: 0.7108
Epoch 5/50, Loss: 0.7141
Epoch 6/50, Loss: 0.7286
Epoch 7/50, Loss: 0.7242
Epoch 8/50, Loss: 0.7137
Epoch 9/50, Loss: 0.7208
Epoch 10/50, Loss: 0.7049
Epoch 11/50, Loss: 0.7330
Epoch 12/50, Loss: 0.7239
Epoch 13/50, Loss: 0.7182
Epoch 14/50, Loss: 0.7063
Epoch 15/50, Loss: 0.7215
Epoch 16/50, Loss: 0.7137
Epoch 17/50, Loss: 0.7229
Epoch 18/50, Loss: 0.7114
Epoch 19/50, Loss: 0.7124
Epoch 20/50, Loss: 0.7347
Epoch 21/50, Loss: 0.6951
Epoch 22/50, Loss: 0.7044
Epoch 23/50, Loss: 0.7053
Epoch 24/50, Loss: 0.6959
Epoch 25/50, Loss: 0.7139
Epoch 26/50, Loss: 0.7263
Epoch 27/50, Loss: 0.7262
Epoch 28/50, Loss: 0.7160
Epoch 29/50, Loss: 0.7255
Epoch 30/50, Loss: 0.7041
Epoch 31/50, Loss: 0.7018
Epoch 32/50, Loss: 0.7254
Epoch 33/50, Loss: 0.7114
Epoch 34/50, Loss: 0.7098
Epoch 35/50, Loss: 0.7132
Epoch 36/50, Loss: 0.7259
Epoch 37/50, Loss: 0.7087
Epoch 38/50, Loss: 0.7032
Epoch 39/50, Loss: 0.

In [77]:
torch.save(model.state_dict(), "best_model.pth")

In [91]:
test_images_dir = "../aml-2025-feathers-in-focus/test_images/test_images/"

test_images = [
    os.path.join(test_images_dir, f)
    for f in os.listdir(test_images_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))
]

print(f"Found {len(test_images)} test images")


img_size = 224
test_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor()
])


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# make sure model structure matches training
model = ConceptNet(num_outputs=complete_bird_attributes.shape[1])
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model = model.to(device)
model.eval()


#num_concepts = len(complete_bird_attributes.columns)
#topk = 1  # number of concepts to keep per image

threshold = 0.5
pred_list = []

with torch.no_grad():
    for img_path in tqdm(test_images, desc="Predicting test images"):
        img = Image.open(img_path).convert("RGB")
        x = test_transform(img).unsqueeze(0).to(device)

        logits = model(x)
        probs = torch.sigmoid(logits).cpu().numpy().squeeze()  # (num_concepts,)

        # ðŸ”¥ Every predicted concept above threshold becomes 1
        pred_vec = (probs >= threshold).astype(int)

        pred_list.append(pred_vec)


pred_array = np.stack(pred_list, axis=0)  # shape: (4000, 312)

# Use column names from complete_bird_attributes
pred_df = pd.DataFrame(pred_array, columns=complete_bird_attributes.columns)
pred_df.insert(0, "image_path", [os.path.basename(f) for f in test_images])

print(pred_df.shape)  # should be (4000, 312)

pred_df.head()


pred_df.to_csv("test_predictions.csv", index=False)


Found 4000 test images


Predicting test images: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 4000/4000 [00:41<00:00, 97.50it/s] 


(4000, 313)


In [94]:
#KNN with K=1 on the predicted concept vectors to find the nearest training image
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import numpy as np

# Drop image_path column for NN, only keep concept vectors
test_features = pred_df.drop(columns=["image_path"]).values  # shape (4000, 312)

# Make sure training attribute matrix is aligned (columns in same order)
train_features = complete_bird_attributes.values  # shape (num_birds, 312)
train_index = complete_bird_attributes.index.values  # class_key values

# Fit Nearest Neighbors model
k = 1
nn_model = NearestNeighbors(n_neighbors=k, metric='euclidean')  # can use 'cosine' or 'euclidean'
nn_model.fit(train_features)


distances, indices = nn_model.kneighbors(test_features)  # indices shape (4000, k)

nearest_class_keys = train_index[indices[:, 0]]  # shape (4000,)
output_df = pd.DataFrame({
    "image_path": pred_df["image_path"].values,
    "class_key": nearest_class_keys
})

output_df.head()

output_df.to_csv("test_nearest_neighbors.csv", index=False)



In [95]:
output_df.class_key.unique()

array([153, 105,  68,  20,  24, 195, 167, 137,  17,  21,  26,   5,  55,
       186,   1,  54, 115,  51,  48, 125,  43, 187,  36,  73,  86, 182,
       176,  88, 143,   8, 133, 111, 188,  53,   7,  45,  56,  62, 138,
       135,  22, 119,  13,  75,  59, 148, 183,  74,  23,  16,  94, 146,
       190, 121, 166,  58,  14,   2,  31,  29, 158, 129,  35,  65, 100,
        57,  50, 109,  96,  47,  25,  39, 120,  79, 160,  83, 178, 155,
        12,  69, 159,   3,  85, 128, 144,  52, 139, 163, 141, 161, 165,
       110, 112, 168,  95, 145,  82,  40, 124,  90, 152, 104,  63,  46,
       116, 140,  60,  67,  76,  15, 136,  81, 193, 123,  87,  99, 131,
       191, 149, 132, 127,  18, 103,  37,  91,  30,  27,   9,  42,  89,
        77,  98, 130, 200,  66,  80,  44,  61,  10, 108, 134,  70, 114,
        97, 122, 175, 151,  78, 192,  38, 196, 194,   4, 147,  93, 185,
       199, 106,  33,  32,  41, 171, 164,   6,  92, 189, 101,  64,  28,
       107,  19, 142,  84, 181, 113, 154,  34, 126, 117, 184,  1

In [82]:
output_df.class_key.unique()

array([ 7,  2,  4, 15, 12,  8, 14, 10, 25, 21,  3, 51, 22, 49, 26, 16, 23,
        5, 29, 63, 17, 62, 13])