In [44]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import os
import wandb
import tqdm
import json

In [13]:
hyperparams = {
    "batch_size": 16,
    "learning_rate": 0.001,
    "num_epochs": 20,
    "num_classes": 2
}

In [14]:
wandb.init(project="tiger-flank-classifier", name="baseline")
config = wandb.config
config.update(hyperparams)

In [6]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [15]:
train_dataset = datasets.ImageFolder(root='/mnt/nas/WII-flanks/train_images', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

val_dataset = datasets.ImageFolder(root='/mnt/nas/WII-flanks/val_images', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

test_dataset = datasets.ImageFolder(root='/mnt/nas/WII-flanks/test_images', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

In [16]:
print(len(train_dataset), len(val_dataset), len(test_dataset))

2038 267 360


In [11]:
# Define the model
model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, config.num_classes)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 59.9MB/s]


In [17]:
# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

In [18]:
for epoch in tqdm.tqdm(range(config.num_epochs)):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
    
    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch+1}/{config.num_epochs}], Loss: {epoch_loss:.4f}')
    
    # Log training loss to wandb
    wandb.log({"Training Loss": epoch_loss, "epoch": epoch + 1})
    
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    val_loss /= len(val_loader.dataset)
    val_accuracy = 100 * correct / total
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')
    
    # Log validation loss and accuracy to wandb
    wandb.log({"Validation Loss": val_loss, "Validation Accuracy": val_accuracy, "epoch": epoch + 1})



  0%|          | 0/20 [00:00<?, ?it/s]

Epoch [1/20], Loss: 0.2498


  5%|▌         | 1/20 [01:05<20:38, 65.16s/it]

Validation Loss: 0.1866, Validation Accuracy: 94.01%
Epoch [2/20], Loss: 0.1632


 10%|█         | 2/20 [01:38<13:50, 46.16s/it]

Validation Loss: 0.2167, Validation Accuracy: 89.51%
Epoch [3/20], Loss: 0.1349


 15%|█▌        | 3/20 [02:08<11:06, 39.18s/it]

Validation Loss: 0.1456, Validation Accuracy: 94.01%
Epoch [4/20], Loss: 0.0870


 20%|██        | 4/20 [02:41<09:45, 36.60s/it]

Validation Loss: 0.1170, Validation Accuracy: 95.13%
Epoch [5/20], Loss: 0.0760


 25%|██▌       | 5/20 [03:12<08:40, 34.72s/it]

Validation Loss: 0.1372, Validation Accuracy: 94.38%
Epoch [6/20], Loss: 0.1489


 30%|███       | 6/20 [03:42<07:41, 32.95s/it]

Validation Loss: 0.5804, Validation Accuracy: 77.90%
Epoch [7/20], Loss: 0.1026


 35%|███▌      | 7/20 [04:11<06:53, 31.82s/it]

Validation Loss: 0.0865, Validation Accuracy: 96.63%
Epoch [8/20], Loss: 0.0411


 40%|████      | 8/20 [04:41<06:12, 31.03s/it]

Validation Loss: 0.1397, Validation Accuracy: 96.63%
Epoch [9/20], Loss: 0.0578


 45%|████▌     | 9/20 [05:12<05:40, 30.94s/it]

Validation Loss: 0.1404, Validation Accuracy: 96.25%
Epoch [10/20], Loss: 0.0600


 50%|█████     | 10/20 [05:44<05:12, 31.27s/it]

Validation Loss: 0.1792, Validation Accuracy: 91.76%
Epoch [11/20], Loss: 0.0667


 55%|█████▌    | 11/20 [06:16<04:44, 31.59s/it]

Validation Loss: 0.0785, Validation Accuracy: 96.25%
Epoch [12/20], Loss: 0.0367


 60%|██████    | 12/20 [06:49<04:15, 31.97s/it]

Validation Loss: 0.1233, Validation Accuracy: 96.63%
Epoch [13/20], Loss: 0.0268


 65%|██████▌   | 13/20 [07:21<03:45, 32.20s/it]

Validation Loss: 0.1755, Validation Accuracy: 95.88%
Epoch [14/20], Loss: 0.0313


 70%|███████   | 14/20 [07:56<03:16, 32.78s/it]

Validation Loss: 0.1167, Validation Accuracy: 95.88%
Epoch [15/20], Loss: 0.0259


 75%|███████▌  | 15/20 [08:29<02:45, 33.02s/it]

Validation Loss: 0.1669, Validation Accuracy: 95.13%
Epoch [16/20], Loss: 0.0211


 80%|████████  | 16/20 [09:03<02:12, 33.18s/it]

Validation Loss: 0.1276, Validation Accuracy: 97.00%
Epoch [17/20], Loss: 0.0515


 85%|████████▌ | 17/20 [09:36<01:39, 33.12s/it]

Validation Loss: 0.0935, Validation Accuracy: 97.75%
Epoch [18/20], Loss: 0.0331


 90%|█████████ | 18/20 [10:08<01:05, 32.94s/it]

Validation Loss: 0.1171, Validation Accuracy: 96.25%
Epoch [19/20], Loss: 0.0261


 95%|█████████▌| 19/20 [10:41<00:33, 33.06s/it]

Validation Loss: 0.1151, Validation Accuracy: 95.88%
Epoch [20/20], Loss: 0.0200


100%|██████████| 20/20 [11:14<00:00, 33.74s/it]

Validation Loss: 0.1284, Validation Accuracy: 96.63%





In [None]:
# Save the model
model_path = 'tiger_flank_classifier.pt'
torch.save(model.state_dict(), model_path)
wandb.save(model_path)

# Finish the wandb run
wandb.finish()

In [21]:
# Compute test accuracy

model.eval()
test_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        test_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_loss /= len(test_loader.dataset)
test_accuracy = 100 * correct / total
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')


Test Loss: 0.1343, Test Accuracy: 96.67%


In [67]:
from typing import Any, Tuple


class FlankDataset(datasets.ImageFolder):
    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)
        self.class_to_idx = {'left': 0, 'right': 1}
        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}

    # def __getitem__(self, idx):
    #     image, label = super().__getitem__(idx)

    #     # get the path of the image
    #     image_path, _ = self.samples[idx]
    #     return image, self.idx_to_class[label], image_path


    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target, path

In [73]:
inference_train_set = FlankDataset(root='datasets/WII/wii.coco/image_by_identity/train', transform=transform)
inference_train_loader = DataLoader(inference_train_set, batch_size=config.batch_size, shuffle=False)

inference_test_set = FlankDataset(root='datasets/WII/wii.coco/image_by_identity/test', transform=transform)
inference_test_loader = DataLoader(inference_test_set, batch_size=config.batch_size, shuffle=False)

print(len(inference_train_set), len(inference_test_set))


3730 4000


In [64]:
inference_train_set.class_to_idx

{'left': 0, 'right': 1}

In [75]:
train_map = {}
test_map = {}

# get the first item in the dataloadre
labels[0].item()

0

In [77]:
# Infer the labels on the infernce sets and save them
model.eval()

with torch.no_grad():
    for inputs, identities, img_paths in inference_train_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)


        # check if labels is in map
        for i, (flank_label, identity, path) in enumerate(zip(predicted.cpu().numpy(), identities, img_paths)):
            identity = int(inference_train_set.classes[identity.item()])
            # get the image name
            image_name = os.path.basename(path)
    
            if identity not in train_map:
                train_map[identity] = {"left_flank": [], "right_flank": []}

            if int(flank_label) == 0:
                train_map[identity]["left_flank"].append(image_name)
            elif int(flank_label) == 1:
                train_map[identity]["right_flank"].append(image_name)


with torch.no_grad():
    for inputs, identities, img_paths in inference_test_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)

        # check if labels is in map
        for i, (flank_label, identity, path) in enumerate(zip(predicted.cpu().numpy(), identities, img_paths)):
            identity = int(inference_test_set.classes[identity.item()])
            # get the image name
            image_name = os.path.basename(path)
    
            if identity not in test_map:
                test_map[identity] = {"left_flank": [], "right_flank": []}

            if int(flank_label) == 0:
                test_map[identity]["left_flank"].append(image_name)
            elif int(flank_label) == 1:
                test_map[identity]["right_flank"].append(image_name)


In [94]:
len(train_map), len(test_map)

(253, 191)

In [96]:
# keep only those items that have both left and right flanks in the training set
train_map_filter = {k: v for k, v in train_map.items() if len(v["left_flank"]) > 5 and len(v["right_flank"]) > 5}
test_map_filter = {k: v for k, v in test_map.items() if len(v["left_flank"]) > 5 and len(v["right_flank"]) > 5}

In [116]:
len(train_map_filter), len(test_map_filter)

(77, 78)

In [110]:
# for each item, choose 5 random images from each flank for gallery and the remaining for probe
import random

def split_flanks(map_filter):
    random.seed(42)

    gallery = {}
    query = {}

    for identity, flanks in map_filter.items():
        left_flank = flanks["left_flank"]
        right_flank = flanks["right_flank"]

        random.shuffle(left_flank)
        random.shuffle(right_flank)

        k_left = 6 if len(left_flank) > 6 else 5
        k_right = 6 if len(right_flank) > 6 else 5

        gallery[identity] = {
            "left_flank": left_flank[:k_left],
            "right_flank": right_flank[:k_right]
        }

        query[identity] = {
            "left_flank": left_flank[k_left:],
            "right_flank": right_flank[k_right:]
        }

    return gallery, query

In [111]:
train_gallery, train_query = split_flanks(train_map_filter)
test_gallery, test_query = split_flanks(test_map_filter)

In [117]:
def get_json(gallery, query):
    final_json = []
    for (identity, flanks), (identity_query, flanks_query) in zip(gallery.items(), query.items()):
        assert identity == identity_query
        final_json.append({
            "tiger_id": identity,
            "gallery": {
                "left_flank": flanks["left_flank"],
                "right_flank": flanks["right_flank"]
            },
            "query": {
                "left_flank": flanks_query["left_flank"],
                "right_flank": flanks_query["right_flank"]
            }
        })

    return final_json

In [118]:
final_train_json = get_json(train_gallery, train_query)
final_test_json = get_json(test_gallery, test_query)

In [120]:
# save json 
with open('datasets/WII/wii.coco/gallery_metadata_train.json', 'w') as f:
    json.dump(final_train_json, f, indent=4)

with open('datasets/WII/wii.coco/gallery_metadata_test.json', 'w') as f:
    json.dump(final_test_json, f, indent=4)