In [None]:
import torch
import torch.nn as nn
import numpy as np
from collections import defaultdict
import os
import matplotlib.pyplot as plt
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.optim as optim
from sklearn.metrics import accuracy_score, classification_report

# Set device, and seed the global torch / numpy RNGs so that stochastic steps that
# draw from them (head weight init, random_split, torchvision augmentations) are
# repeatable across Restart & Run All. GPU kernels may still add nondeterminism.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


In [None]:

# Build the MAE ViT-Base model (encoder + decoder config below) and load pretrained weights.
# NOTE(review): MaskedAutoencoderViT is not imported anywhere in this notebook; it presumably
# comes from the MAE reference code (models_mae.py) and must be in scope before this cell -- confirm.
model = MaskedAutoencoderViT(
    img_size=224, patch_size=16, in_chans=3,
    embed_dim=768, depth=12, num_heads=12,
    decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
    mlp_ratio=4.
).to(device)

checkpoint_path = "/content/drive/MyDrive/MAE_demo/MAE_demo/mae/mae_pretrain_vit_base.pth" # change path

# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
ckpt = torch.load(checkpoint_path, map_location=device)
# Accept both {"model": state_dict} checkpoints and bare state dicts.
state_dict = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt

# strict=False tolerates keys absent from the checkpoint (presumably the decoder, if this is
# an encoder-only pretrain file -- confirm). Report the mismatch instead of hiding it, so a
# wrong or incompatible checkpoint is noticed rather than silently leaving random weights.
load_result = model.load_state_dict(state_dict, strict=False)
print(f"Loaded {checkpoint_path}: "
      f"{len(load_result.missing_keys)} missing keys, "
      f"{len(load_result.unexpected_keys)} unexpected keys")
if load_result.unexpected_keys:
    print("Unexpected keys (first 10):", load_result.unexpected_keys[:10])
model.eval()



In [None]:


class AnimalCLEFDataset(Dataset):
    """AnimalCLEF image dataset for a single split (e.g. "database" or "query").

    Rows are taken from the metadata CSV whose ``path`` column contains
    ``/<split>/``. For the "database" split each row's ``identity`` is mapped
    to an integer label (identities sorted as strings). Every other split gets
    the placeholder label -1.

    Args:
        root: Directory that the metadata ``path`` entries are joined onto.
        split: Split name matched literally against the ``path`` column.
        transform: Optional callable applied to each loaded PIL image.
        metadata_path: CSV file to read. Defaults to ``<root>/metadata.csv``.

    Attributes:
        paths, image_ids, labels: Per-sample lists, aligned by index.
        id2idx: ("database" split only) identity string -> label index.

    Raises:
        ValueError: if no metadata row belongs to ``split``.
    """

    def __init__(self, root, split="database", transform=None, metadata_path=None):
        self.root = root.rstrip('/')  # Remove trailing slash if present
        # The metadata location used to be hard-coded and ignored `root`; derive it from root.
        if metadata_path is None:
            metadata_path = os.path.join(self.root, "metadata.csv")
        meta = pd.read_csv(metadata_path)

        # regex=False: match the split name literally; na=False: rows with no path never match.
        in_split = meta['path'].str.contains(f"/{split}/", regex=False, na=False)
        sel = meta[in_split].reset_index(drop=True)
        if sel.empty:
            raise ValueError(f"No entries for split '{split}'")

        self.paths = sel['path'].tolist()
        self.image_ids = sel['image_id'].tolist()

        if split == 'database':
            # Use the individual identity (as a string) as the class.
            ids = sel['identity'].astype(str)
            # Identity string -> contiguous label index, in sorted order.
            self.id2idx = {iid: i for i, iid in enumerate(sorted(ids.unique()))}
            self.labels = ids.map(self.id2idx).tolist()

            # Safety check: every label must be a valid class index.
            num_classes = len(self.id2idx)
            assert all(0 <= label < num_classes for label in self.labels), "Invalid labels found"
        else:
            # No ground-truth identities are used for non-database splits.
            self.labels = [-1] * len(sel)

        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        """Return ``(image, label)`` for sample ``i``; image is transformed if a transform is set."""
        # Paths are joined without stripping a leading '/'; an absolute path would
        # therefore override root (os.path.join semantics).
        img_path = os.path.join(self.root, self.paths[i])
        # The context manager closes the file handle; convert() materialises the pixels first.
        with Image.open(img_path) as im:
            img = im.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.labels[i]





In [None]:
def save_model_and_checkpoint(model, epoch, optimizer, loss, save_path):
    """Serialize the current training state to ``save_path`` with ``torch.save``.

    The written dict has the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'loss'. A confirmation line is printed afterwards.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(state, save_path)
    print(f"Model and checkpoint saved to: {save_path}")


In [None]:
import torch.nn as nn

class ClassificationHead(nn.Module):
    """MLP head that classifies a token sequence from its first token.

    Layers: Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear.
    ``forward`` returns both the raw logits and their softmax over dim 1.
    """

    def __init__(self, embed_dim=768, hidden_dim=256, num_classes=10, dropout_prob=0.5):
        super().__init__()
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim, num_classes),
        ]
        # Attribute names `net` / `softmax` are kept: they determine the state_dict keys.
        self.net = nn.Sequential(*layers)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # x[:, 0] takes the first token along dim 1 (assumed to be the CLS token of the encoder output).
        logits = self.net(x[:, 0])
        return logits, self.softmax(logits)

class MAE_Classifier(nn.Module):
    """Identity classifier: an (optionally frozen) MAE encoder plus an MLP head.

    Besides closed-set classification, the model stores per-class statistics
    of training softmax outputs (``class_probs`` / ``class_thresholds``).
    ``predict_with_openmax`` uses them to label low-confidence queries as
    ``"new_individual"``.

    Args:
        mae_model: Model exposing ``patch_embed.proj`` and
            ``forward_encoder(img, mask_ratio=...)``, as used below.
        num_classes: Number of known identities.
        freeze_encoder: If True, encoder parameters get ``requires_grad=False``.
    """

    def __init__(self, mae_model, num_classes=10, freeze_encoder=True):
        super().__init__()
        # load mae model
        self.encoder = mae_model.to(device)
        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False

        # Embedding width = output channels of the patch-embedding projection.
        embed_dim = self.encoder.patch_embed.proj.out_channels
        self.head = ClassificationHead(embed_dim=embed_dim, num_classes=num_classes).to(device)
        # Filled by collect_train_probs(): class id -> mean softmax vector,
        # and class id -> max entry of that mean vector.
        self.class_probs = defaultdict(list)
        self.class_thresholds = {}

    def forward(self, img):
        """Encode ``img`` with no masking and classify it. Returns ``(logits, probs)``."""
        img = img.to(device)
        # mask_ratio=0.0 keeps every patch. [0] takes the first element of the returned
        # tuple (presumably the latent token sequence -- the head reads its first token).
        enc_out = self.encoder.forward_encoder(img, mask_ratio=0.0)[0]
        logits, probs = self.head(enc_out)
        return logits, probs

    def collect_train_probs(self, train_loader):
        """Compute per-class mean softmax vectors over ``train_loader``.

        Replaces ``self.class_probs`` (label -> mean probability vector) and
        ``self.class_thresholds`` (label -> max entry of that mean vector).
        Previous statistics are discarded, so repeated calls are safe.
        (Previously a second call crashed on ``.append`` to an ndarray.)
        Feed an un-augmented loader for stable statistics.
        """
        self.eval()
        per_class = defaultdict(list)
        with torch.no_grad():
            for images, labels in train_loader:
                _, probs = self.forward(images)
                for label, prob in zip(torch.as_tensor(labels).tolist(), probs.cpu().numpy()):
                    per_class[label].append(prob)

        self.class_probs = defaultdict(list)
        self.class_thresholds = {}
        for class_id, rows in per_class.items():
            mean_probs = np.mean(np.stack(rows), axis=0)
            self.class_probs[class_id] = mean_probs
            # Threshold = highest entry of this class's mean probability vector.
            self.class_thresholds[class_id] = np.max(mean_probs)

    def train_classifier(self, train_loader, val_loader, epochs=40, lr=1e-3):
        """Train all trainable parameters, log loss/accuracy per epoch, plot loss.

        Uses label-smoothed (0.1) cross-entropy and AdamW (weight_decay=5e-4).
        The optimizer covers every parameter with ``requires_grad=True``: only
        the head when the encoder is frozen, encoder + head otherwise. A frozen
        encoder is kept in eval mode so its features are not affected by
        train-mode layers.

        Args:
            train_loader: Yields ``(images, integer label tensor)`` batches.
            val_loader: Same format; evaluated after every epoch.
            epochs: Number of passes over ``train_loader``.
            lr: AdamW learning rate.

        Returns:
            dict with per-epoch lists 'train_loss', 'val_loss', 'train_acc' and
            'val_acc'. Accuracies are in percent; NaN when a loader yields nothing.
        """
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        trainable = [p for p in self.parameters() if p.requires_grad]
        optimizer = optim.AdamW(trainable, lr=lr, weight_decay=5e-4)
        encoder_frozen = not any(p.requires_grad for p in self.encoder.parameters())

        history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

        for epoch in range(epochs):
            self.train()
            if encoder_frozen:
                self.encoder.eval()
            running_loss, correct, total, n_batches = 0.0, 0, 0, 0

            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                logits, _ = self.forward(images)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                n_batches += 1
                predicted = logits.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            train_loss = running_loss / n_batches if n_batches else float('nan')
            train_acc = 100 * correct / total if total else float('nan')

            # Validation
            self.eval()
            val_loss, val_correct, val_total, val_batches = 0.0, 0, 0, 0
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(device), labels.to(device)
                    logits, _ = self.forward(images)
                    val_loss += criterion(logits, labels).item()
                    val_batches += 1
                    predicted = logits.argmax(dim=1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

            val_loss = val_loss / val_batches if val_batches else float('nan')
            val_acc = 100 * val_correct / val_total if val_total else float('nan')

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

            print(f"Epoch {epoch+1}/{epochs} "
                  f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% "
                  f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # Plot loss curves
        plt.figure(figsize=(10, 5))
        plt.plot(history['train_loss'], label='Train Loss')
        plt.plot(history['val_loss'], label='Validation Loss')
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.title("Training vs Validation Loss")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

        return history

    def predict_with_openmax(self, test_loader, margin=0.05):
        """Assign each query to its nearest known class or to ``"new_individual"``.

        Nearest class: the one whose mean training softmax vector
        (``class_probs``) is closest to the query's softmax vector in Euclidean
        distance. Ties go to the class collected first.

        The query is labelled ``"new_individual"`` when its top probability is
        more than ``margin`` below that class's threshold (``class_thresholds``).
        In other words, the model is noticeably less confident than it was, on
        average, on that class's training images. The previous rule was
        inverted: it flagged queries that were *more* confident than training.

        Note: despite the name, this is a nearest-mean-softmax rule, not the
        Weibull-calibrated OpenMax procedure.

        Returns:
            ``(predictions, all_probs)``. ``predictions`` is a list of class ids
            or "new_individual"; ``all_probs`` stacks every query's softmax vector.

        Raises:
            RuntimeError: if ``collect_train_probs`` has not been run.
        """
        if not self.class_probs:
            raise RuntimeError("collect_train_probs() must be called before predict_with_openmax()")

        class_ids = list(self.class_probs.keys())
        # One row per known class; compared against each query in a single vectorised call.
        class_means = np.stack([np.asarray(self.class_probs[c]) for c in class_ids])
        class_thr = np.array([self.class_thresholds[c] for c in class_ids])

        self.eval()
        predictions = []
        all_probs = []
        with torch.no_grad():
            for images, _ in test_loader:
                _, probs = self.forward(images)
                batch_probs = probs.cpu().numpy()
                all_probs.append(batch_probs)

                for prob in batch_probs:
                    nearest = int(np.argmin(np.linalg.norm(class_means - prob, axis=1)))
                    # Assumption: images of unseen individuals yield a flatter (less confident)
                    # softmax, so LOW confidence relative to the matched class marks them as new.
                    if np.max(prob) < class_thr[nearest] - margin:
                        predictions.append("new_individual")
                    else:
                        predictions.append(class_ids[nearest])

        stacked = np.concatenate(all_probs, axis=0) if all_probs else np.array([])
        return predictions, stacked

In [None]:
def create_sample_submission(dataset_query, predictions, id2identity, file_name='/content/drive/MyDrive/MAE_demo/sample_submission.csv'):
    """Write a submission CSV with columns ``image_id`` and ``identity``.

    Args:
        dataset_query: Object whose ``image_ids`` list fills the ``image_id``
            column, aligned element-wise with ``predictions``.
        predictions: Per-image prediction, either a label index or "new_individual".
        id2identity: Mapping label index -> identity string.
        file_name: Output CSV path (written without the DataFrame index).
    """
    unknown = "new_individual"

    def to_identity(pred):
        # Explicit unknowns stay unknown. Known indices map to their identity string.
        # Any index missing from the mapping is also treated as unknown.
        if pred == unknown:
            return unknown
        return id2identity[pred] if pred in id2identity else unknown

    submission = pd.DataFrame({
        'image_id': dataset_query.image_ids,
        'identity': [to_identity(p) for p in predictions],
    })
    submission.to_csv(file_name, index=False)

In [None]:


# --- Main Execution ---
if __name__ == "__main__":
    # Training-time augmentation (random crop, flip, colour jitter, rotation).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Deterministic preprocessing for validation, class statistics and query inference.
    # Previously the random training augmentation was applied there too, which made
    # predictions irreproducible and degraded them.
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    images_path = '/content/drive/MyDrive/animalclef2025'
    split_seed = 42  # makes the train/val split reproducible

    try:
        train_dataset = AnimalCLEFDataset(images_path, split="database", transform=transform)
        # Same rows in the same order as train_dataset, but with eval preprocessing.
        train_eval_dataset = AnimalCLEFDataset(images_path, split="database", transform=eval_transform)
        test_dataset = AnimalCLEFDataset(images_path, split="query", transform=eval_transform)
    except Exception as e:
        print(f"Error creating datasets: {e}")
        raise
    # Index-aligned datasets are required for the shared split below.
    assert train_eval_dataset.image_ids == train_dataset.image_ids

    # Print dataset statistics
    print(f"Found {len(train_dataset)} training samples across {len(train_dataset.id2idx)} classes")
    print(f"Found {len(test_dataset)} query images")

    # Split database indices 80/20: training uses augmented images, validation uses eval images.
    train_len = int(0.8 * len(train_dataset))
    val_len = len(train_dataset) - train_len
    train_set, val_split = torch.utils.data.random_split(
        train_dataset, [train_len, val_len],
        generator=torch.Generator().manual_seed(split_seed),
    )
    val_set = torch.utils.data.Subset(train_eval_dataset, val_split.indices)
    # Un-augmented view of the training subset, used for the per-class probability statistics.
    stats_set = torch.utils.data.Subset(train_eval_dataset, train_set.indices)

    # drop_last avoids a size-1 final training batch, which the head's BatchNorm1d rejects in train mode.
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4)
    stats_loader = DataLoader(stats_set, batch_size=32, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

    # Initialize classifier
    num_classes = len(train_dataset.id2idx)
    print(f"Number of classes: {num_classes}")
    classifier = MAE_Classifier(model, num_classes=num_classes, freeze_encoder=True)

    # Train the classifier head
    print("Training classifier head...")
    classifier.train_classifier(train_loader, val_loader, epochs=20, lr=1e-2)

    # Collect training probabilities (after softmax) on un-augmented training images
    print("Collecting training probabilities...")
    classifier.collect_train_probs(stats_loader)

    # Predict on test set
    print("Predicting on query images...")
    predictions, all_probs = classifier.predict_with_openmax(test_loader)

    # Reverse the id2idx mapping: label index -> identity string
    idx2id = {v: k for k, v in train_dataset.id2idx.items()}

    # Create submission file with proper identity strings
    print("Creating submission file...")
    create_sample_submission(test_dataset, predictions, idx2id)

    print("\nSubmission file created successfully!")
    print("Sample predictions:")

    # Map the first few predictions to identity strings for a quick look
    sample_df = pd.DataFrame({
        'image_id': test_dataset.image_ids[:5],
        'identity': [idx2id.get(p, "new_individual") for p in predictions[:5]]
    })
    print(sample_df)