In [None]:
# Install required packages
!pip install -q timm
!pip install -q pytorch-lightning
!pip install -q torch>=2.0.0 torchvision>=0.15.0
!pip install pytorch-lightning torchmetrics

import os
import json
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
import pytorch_lightning as pl
import timm
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from datetime import datetime
import zipfile
import torchmetrics

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set up project directories (everything lives under one Drive folder)
PROJECT_DIR = Path('/content/drive/MyDrive/SpeciesDetectionAI')
ZIP_PATH = PROJECT_DIR / 'dataset.zip'
DATA_DIR = PROJECT_DIR / 'dataset'
OUTPUT_DIR = PROJECT_DIR / 'outputs'
CHECKPOINT_DIR = OUTPUT_DIR / 'checkpoints'
LOGS_DIR = OUTPUT_DIR / 'logs'
ANALYSIS_DIR = OUTPUT_DIR / 'analysis'

# Create directories if they don't exist.
# The loop variable is deliberately not called `dir`, which would shadow the
# builtin for the rest of the notebook session.
for output_directory in (OUTPUT_DIR, CHECKPOINT_DIR, LOGS_DIR, ANALYSIS_DIR):
    output_directory.mkdir(parents=True, exist_ok=True)

# Unzip dataset if needed
def setup_dataset():
    """Make sure the dataset is available under DATA_DIR and looks complete.

    If DATA_DIR does not exist, ZIP_PATH is extracted into PROJECT_DIR.
    Afterwards ``class_mappings.json`` and the ``splits`` directory must both
    exist inside DATA_DIR.

    Returns:
        Path to ``class_mappings.json``.

    Raises:
        FileNotFoundError: neither the dataset directory nor the zip exists,
            or one of the required entries is missing after setup.
    """
    if not DATA_DIR.exists():
        if not ZIP_PATH.exists():
            raise FileNotFoundError(f"Neither dataset directory nor zip file found at {DATA_DIR} or {ZIP_PATH}")
        print("Unzipping dataset...")
        with zipfile.ZipFile(ZIP_PATH, 'r') as archive:
            archive.extractall(PROJECT_DIR)
        print("Dataset unzipped successfully")

    # Sanity-check the layout the dataset class relies on
    mappings_path = DATA_DIR / 'class_mappings.json'
    if not mappings_path.exists():
        raise FileNotFoundError(f"class_mappings.json not found in {DATA_DIR}")

    if not (DATA_DIR / 'splits').exists():
        raise FileNotFoundError(f"splits directory not found in {DATA_DIR}")

    print("Dataset setup completed successfully")
    return mappings_path

class WildlifeDataset(Dataset):
    """Wildlife image dataset laid out as ``<root_dir>/splits/<split>/<species>/``.

    Class ids come from ``species_to_idx`` in ``<root_dir>/class_mappings.json``.
    For every species, images are collected from the species directory itself
    and from its optional ``images`` subdirectory.

    Args:
        root_dir: Dataset root containing ``class_mappings.json`` and ``splits``.
        split: Name of the split directory to read (default ``'train'``).
        transform: Optional callable applied to each loaded RGB PIL image.

    After construction the instance exposes:
        images / labels: parallel lists of image paths and class ids.
        class_counts: NumPy array with the number of images per class id.
        class_weights: inverse-frequency weight per class id (0 for classes
            without images).
        weights: per-sample weight, i.e. the class weight of each sample's label.

    Raises:
        FileNotFoundError: ``splits`` or the requested split directory is missing.
        ValueError: no images were found for the split.
    """

    # File suffixes treated as images; compared case-insensitively so that
    # camera-style names such as ``IMG_001.JPG`` are picked up as well.
    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = Path(root_dir)
        self.split = split
        self.transform = transform
        self.images = []
        self.labels = []

        # Load class mappings
        class_mappings_file = self.root_dir / 'class_mappings.json'
        print(f"Loading class mappings from: {class_mappings_file}")
        with open(class_mappings_file, 'r') as f:
            self.class_info = json.load(f)

        self.num_classes = len(self.class_info['species_list'])
        print(f"Number of classes: {self.num_classes}")

        # Look for splits directory
        splits_dir = self.root_dir / 'splits'
        if not splits_dir.is_dir():
            raise FileNotFoundError(f"Could not find splits directory in {self.root_dir}")

        split_dir = splits_dir / split
        if not split_dir.exists():
            if split == 'val':
                # NOTE: this fallback reads the *entire* train split, so the
                # resulting "validation" set is not held out from training.
                print("Validation directory not found, using a subset of training data for validation")
                split_dir = splits_dir / 'train'
                if not split_dir.exists():
                    raise FileNotFoundError("Could not find training split directory")
            else:
                raise FileNotFoundError(f"Could not find directory for {split} split")

        print(f"Using split directory: {split_dir}")

        # Load images for each species
        for species_name, class_id in self.class_info['species_to_idx'].items():
            species_dir = split_dir / species_name
            if not species_dir.exists():
                print(f"Warning: Directory not found for species {species_name}")
                continue

            # Try both 'images' subdirectory and direct image files
            image_paths = self._list_images(species_dir / 'images')
            image_paths.extend(self._list_images(species_dir))

            self.images.extend(image_paths)
            self.labels.extend([class_id] * len(image_paths))
            print(f"Found {len(image_paths)} images for species {species_name}")

        print(f"\nTotal images found for {split} split: {len(self.images)}")
        if len(self.images) == 0:
            raise ValueError(f"No valid images found in {split_dir}")

        # Calculate class weights for handling imbalance.
        # minlength keeps one entry per class even when the highest class ids
        # have no images in this split.
        self.class_counts = np.bincount(self.labels, minlength=self.num_classes)

        # Reverse lookup built once; the first name wins if ids are duplicated.
        idx_to_species = {}
        for name, idx in self.class_info['species_to_idx'].items():
            idx_to_species.setdefault(idx, name)

        print("\nClass distribution:")
        for class_idx, count in enumerate(self.class_counts):
            species_name = idx_to_species.get(class_idx, f"<unmapped class {class_idx}>")
            print(f"{species_name}: {count} images")

        # Inverse frequency; classes without images get weight 0 instead of inf.
        counts = torch.tensor(self.class_counts, dtype=torch.float)
        self.class_weights = torch.where(
            counts > 0, 1. / counts.clamp(min=1), torch.zeros_like(counts)
        )
        self.weights = self.class_weights[self.labels]

    @classmethod
    def _list_images(cls, directory):
        """Return the image files directly inside ``directory``, sorted by path.

        A missing directory yields an empty list.
        """
        if not directory.is_dir():
            return []
        return sorted(
            entry for entry in directory.iterdir()
            if entry.is_file() and entry.suffix.lower() in cls.IMAGE_SUFFIXES
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is RGB."""
        img_path = self.images[idx]
        label = self.labels[idx]

        # Load and transform image
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label

class WildlifeClassifier(pl.LightningModule):
    """timm ``tf_efficientnet_b2`` image classifier wrapped as a LightningModule.

    Args:
        num_classes: Size of the classification head.
        learning_rate: Learning rate passed to AdamW (default ``1e-4``).

    Logged metrics: ``train_loss``, ``train_acc``, ``val_loss``, ``val_acc``.
    """

    def __init__(self, num_classes, learning_rate=1e-4):
        super().__init__()
        # Makes num_classes / learning_rate available through self.hparams
        self.save_hyperparameters()

        # Load EfficientNet-B2 model with a pretrained backbone
        self.model = timm.create_model('tf_efficientnet_b2',
                                     pretrained=True,
                                     num_classes=num_classes)

        # Plain, unweighted cross-entropy: no class weights are passed here.
        self.criterion = nn.CrossEntropyLoss()

        # Separate metric objects so train and validation state never mix
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of images."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log loss/accuracy for one training batch; return the loss."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = self.train_acc(preds, y)

        # Log metrics
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log loss/accuracy for one validation batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = self.val_acc(preds, y)

        # Log metrics
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

        return {'val_loss': loss, 'val_acc': acc}

    def configure_optimizers(self):
        """Return AdamW plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=0.0001
        )

        # `verbose=True` is intentionally not passed: the argument is
        # deprecated in recent PyTorch releases (the install cell only pins
        # torch>=2.0.0) and only controlled a message printed on each LR change.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=3
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                # Lightning feeds this logged metric to scheduler.step()
                'monitor': 'val_loss'
            }
        }

def main():
    """Train the wildlife classifier and save its weights and configuration.

    Steps: prepare the dataset, build the train/validation loaders (class
    imbalance is handled with a weighted sampler), fit the model, then write
    ``final_model_<timestamp>.pth`` and ``model_config_<timestamp>.json`` to
    OUTPUT_DIR. Returns early (after printing the error) if dataset setup fails.
    """
    import copy  # only needed for the hold-out fallback below

    # Setup dataset first
    try:
        setup_dataset()
    except Exception as e:
        print(f"Error setting up dataset: {str(e)}")
        return

    # Get current timestamp (used to name every artifact of this run)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Data transforms
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create datasets
    full_train_dataset = WildlifeDataset(
        root_dir=DATA_DIR,
        split='train',
        transform=train_transform
    )
    # Keep the class metadata separately: after a random split the training
    # set is a Subset, which does not expose `class_info`.
    class_info = full_train_dataset.class_info

    # Only use a dedicated validation split when its directory really exists.
    # WildlifeDataset(split='val') silently falls back to the whole train
    # split otherwise, which would validate on the training images.
    val_dataset = None
    if (DATA_DIR / 'splits' / 'val').is_dir():
        try:
            val_dataset = WildlifeDataset(
                root_dir=DATA_DIR,
                split='val',
                transform=val_transform
            )
        except FileNotFoundError:
            val_dataset = None

    if val_dataset is None:
        print("Using 20% of training data for validation")
        train_size = int(0.8 * len(full_train_dataset))
        val_size = len(full_train_dataset) - train_size
        train_dataset, val_subset = torch.utils.data.random_split(
            full_train_dataset,
            [train_size, val_size]
        )
        # Same samples, but without the training augmentations
        eval_view = copy.copy(full_train_dataset)
        eval_view.transform = val_transform
        val_dataset = torch.utils.data.Subset(eval_view, val_subset.indices)
        # Sampler weights must line up with the positions inside the subset
        sample_weights = full_train_dataset.weights[train_dataset.indices]
    else:
        train_dataset = full_train_dataset
        sample_weights = full_train_dataset.weights

    # Setup weighted sampler for handling class imbalance.
    # Built after the split so held-out images are never drawn for training.
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(train_dataset),
        replacement=True
    )

    # Create train loader with sampler
    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        sampler=sampler,
        num_workers=2,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    # Initialize model
    num_classes = len(class_info['species_list'])
    model = WildlifeClassifier(
        num_classes=num_classes,
        learning_rate=1e-4
    )

    # Initialize trainer
    trainer = pl.Trainer(
        max_epochs=15,
        # 'auto' uses the GPU when one is present and otherwise falls back to
        # CPU instead of failing on a runtime without an accelerator.
        accelerator='auto',
        devices=1,
        callbacks=[
            pl.callbacks.ModelCheckpoint(
                dirpath=str(CHECKPOINT_DIR),
                filename=f'efficientnet_b2_{timestamp}' + '-{epoch:02d}-{val_loss:.2f}',
                save_top_k=3,
                monitor='val_loss',
                mode='min'
            ),
            pl.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=5,
                mode='min'
            )
        ],
        logger=pl.loggers.CSVLogger(save_dir=str(LOGS_DIR), name='wildlife_classification')
    )

    # Train model
    trainer.fit(model, train_loader, val_loader)

    # Save final model (state dict of the LightningModule, so the network's
    # keys carry a 'model.' prefix)
    final_model_path = OUTPUT_DIR / f'final_model_{timestamp}.pth'
    torch.save(model.state_dict(), str(final_model_path))

    # Save model configuration
    config = {
        'num_classes': num_classes,
        'class_mapping': class_info['species_to_idx'],
        'image_size': (224, 224),
        'timestamp': timestamp
    }

    with open(OUTPUT_DIR / f'model_config_{timestamp}.json', 'w') as f:
        json.dump(config, f, indent=4)

if __name__ == '__main__':
    main()

Mounted at /content/drive
Unzipping dataset...
Dataset unzipped successfully
Dataset setup completed successfully
Loading class mappings from: /content/drive/MyDrive/SpeciesDetectionAI/dataset/class_mappings.json
Number of classes: 38
Using split directory: /content/drive/MyDrive/SpeciesDetectionAI/dataset/splits/train
Found 187 images for species banded civet
Found 2 images for species banded linsang
Found 5 images for species bay cat
Found 316 images for species bearded pig
Found 4 images for species binturong
Found 4 images for species bird sp
Found 102 images for species bulwer's pheasant
Found 343 images for species great argus pheasant
Found 3 images for species ground tufted squirrel
Found 204 images for species human
Found 105 images for species leopard cat
Found 293 images for species long-tailed macaque
Found 20 images for species long-tailed porcupine
Found 220 images for species malayan civet
Found 259 images for species malayan porcupine
Found 5 images for species marbled 

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/36.8M [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | EfficientNet       | 7.8 M  | train
1 | criterion | CrossEntropyLoss   | 0      | train
2 | train_acc | MulticlassAccuracy | 0      | train
3 | val_acc   | MulticlassAccuracy | 0      | train
---------------------------------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.018    Total estimated model params size (MB)
476       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=15` reached.


In [None]:
!pip install -q timm seaborn scikit-learn tqdm

In [None]:
!pip install -q timm seaborn scikit-learn tqdm

import torch
import timm
import json
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_curve, roc_curve, auc
from tqdm.notebook import tqdm
import zipfile
import os

# Mount Google Drive so the trained model and the dataset are reachable
from google.colab import drive
drive.mount('/content/drive')

# Project layout on Drive
PROJECT_DIR = Path('/content/drive/MyDrive/SpeciesDetectionAI')
OUTPUT_DIR = PROJECT_DIR / 'outputs'
ANALYSIS_DIR = OUTPUT_DIR / 'analysis'
DATA_DIR = PROJECT_DIR / 'dataset'
ZIP_PATH = PROJECT_DIR / 'dataset.zip'

# Artifacts of the training run under evaluation
MODEL_PATH = OUTPUT_DIR / 'final_model_20241125_041231.pth'
CONFIG_PATH = OUTPUT_DIR / 'model_config_20241125_041231.json'

# Output folders for the analysis results
print("Creating directories...")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
ANALYSIS_DIR.mkdir(parents=True, exist_ok=True)

# Report which inputs are present (informational only, nothing is enforced)
print("\nVerifying paths...")
print(f"Model exists: {MODEL_PATH.exists()}")
print(f"Config exists: {CONFIG_PATH.exists()}")
print(f"Dataset zip exists: {ZIP_PATH.exists()}")

# The config provides the class mapping and the number of classes
print("\nLoading model configuration...")
with open(CONFIG_PATH, 'r') as f:
    model_config = json.load(f)
    print("Model config loaded successfully")

class WildlifeDataset(Dataset):
    """Evaluation dataset read from ``<root_dir>/splits/<split>/<species>/``.

    The class mapping is taken from the ``class_mapping`` entry of the JSON
    file at CONFIG_PATH. If ``splits/<split>`` is missing, ``<root_dir>/<split>``
    is tried instead. Images are collected from each species directory and
    from its optional ``images`` subdirectory.

    Args:
        root_dir: Dataset root directory.
        split: Split name to load (default ``'train'``).
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        FileNotFoundError: no directory exists for the requested split.
        ValueError: the split directory contains no images.
    """

    # Suffixes are compared case-insensitively: the images in this dataset are
    # named like ``banded_civet_104.JPG``, which a plain ``*.jpg`` glob misses
    # on a case-sensitive filesystem.
    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = Path(root_dir)
        self.split = split
        self.transform = transform
        self.images = []
        self.labels = []

        # Load class mappings
        with open(CONFIG_PATH, 'r') as f:
            self.class_info = json.load(f)

        self.num_classes = len(self.class_info['class_mapping'])
        print(f"Number of classes: {self.num_classes}")

        # Look for splits directory
        splits_dir = self.root_dir / 'splits'
        split_dir = splits_dir / split

        if not split_dir.exists():
            # Try looking in the root of the dataset directory
            split_dir = self.root_dir / split
            if not split_dir.exists():
                raise FileNotFoundError(f"Could not find directory for {split} split")

        print(f"Using split directory: {split_dir}")

        # Load images for each species
        for species_name, class_id in self.class_info['class_mapping'].items():
            species_dir = split_dir / species_name
            if not species_dir.exists():
                print(f"Warning: Directory not found for species {species_name}")
                continue

            # Try both 'images' subdirectory and direct image files
            image_paths = self._list_images(species_dir / 'images')
            image_paths.extend(self._list_images(species_dir))

            self.images.extend(image_paths)
            self.labels.extend([class_id] * len(image_paths))

        print(f"Total images found for {split} split: {len(self.images)}")
        if len(self.images) == 0:
            raise ValueError(f"No valid images found in {split_dir}")

    @classmethod
    def _list_images(cls, directory):
        """Return the image files directly inside ``directory``, sorted by path.

        A missing directory yields an empty list.
        """
        if not directory.is_dir():
            return []
        return sorted(
            entry for entry in directory.iterdir()
            if entry.is_file() and entry.suffix.lower() in cls.IMAGE_SUFFIXES
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is RGB."""
        img_path = self.images[idx]
        label = self.labels[idx]

        # Load and transform image
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label

def unzip_dataset():
    """Extract ZIP_PATH into PROJECT_DIR unless DATA_DIR already exists.

    Raises:
        FileNotFoundError: neither DATA_DIR nor ZIP_PATH exists.
    """
    # Nothing to do when a previous run already extracted the data
    if DATA_DIR.exists():
        print("\nDataset directory already exists")
        return

    if not ZIP_PATH.exists():
        raise FileNotFoundError("Neither dataset directory nor zip file found")

    print("\nUnzipping dataset...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as archive:
        archive.extractall(PROJECT_DIR)
    print("Dataset unzipped successfully")

# Unzip dataset if needed
unzip_dataset()

# Initialize model
print("\nInitializing model...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = timm.create_model('tf_efficientnet_b2', pretrained=False, num_classes=model_config['num_classes'])
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict holds, and avoids the warning torch.load emits
# when the argument is left unset.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
# The checkpoint is the state dict of the Lightning wrapper, whose network is
# stored under the attribute `model`. Strip that name only where it is the
# leading prefix; str.replace would also rewrite any later 'model.' in a key.
LIGHTNING_PREFIX = 'model.'
state_dict = {
    (k[len(LIGHTNING_PREFIX):] if k.startswith(LIGHTNING_PREFIX) else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()
model = model.to(device)
print(f"Model loaded and moved to {device}")

# Deterministic preprocessing: resize, tensor conversion, normalization
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Build the test set and a fixed-order loader over it
print("\nLoading test dataset...")
test_dataset = WildlifeDataset(DATA_DIR, split='test', transform=test_transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=2)
print(f"Test dataset loaded with {len(test_dataset)} images")

# Run the model over every batch, keeping per-sample results on the CPU
print("\nRunning predictions...")
all_preds, all_labels, all_probs = [], [], []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating"):
        images = images.to(device)
        outputs = model(images)
        # Softmax turns logits into class probabilities; the predicted class
        # is the position of the largest logit.
        probabilities = outputs.softmax(dim=1)
        predictions = torch.max(outputs, dim=1).indices

        all_probs.extend(probabilities.cpu().numpy())
        all_preds.extend(predictions.cpu().numpy())
        all_labels.extend(labels.numpy())

# One array per quantity, aligned by sample position
all_probs = np.asarray(all_probs)
all_preds = np.asarray(all_preds)
all_labels = np.asarray(all_labels)

# Class names ordered by class index (indices are expected to be 0..n-1)
idx_to_class = {index: name for name, index in model_config['class_mapping'].items()}
class_names = [idx_to_class[position] for position in range(len(idx_to_class))]

print("\nGenerating visualizations...")

# 1. Confusion Matrix
plt.figure(figsize=(15, 15))
# Pin the label set so the matrix always has one row/column per class, in
# class_names order. Without it, classes that never occur in the labels or
# predictions are dropped and the tick labels no longer line up.
cm = confusion_matrix(all_labels, all_preds, labels=np.arange(len(class_names)))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=45)
plt.tight_layout()
plt.savefig(ANALYSIS_DIR / 'confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.close()

# 2. Per-class Accuracy Bar Plot
plt.figure(figsize=(15, 6))
# Row sums are the number of true samples per class. Classes without any test
# sample keep NaN (accuracy is undefined there) without a 0/0 warning.
class_support = cm.sum(axis=1)
class_accuracy = np.divide(
    cm.diagonal(),
    class_support,
    out=np.full(len(class_names), np.nan),
    where=class_support > 0,
)
sns.barplot(x=class_names, y=class_accuracy)
plt.title('Per-class Accuracy')
plt.xticks(rotation=45, ha='right')
plt.ylabel('Accuracy')
plt.tight_layout()
plt.savefig(ANALYSIS_DIR / 'per_class_accuracy.png', dpi=300, bbox_inches='tight')
plt.close()

# 3. ROC Curves (one-vs-rest, one curve per class)
plt.figure(figsize=(12, 8))
for i, class_name in enumerate(class_names):
    is_class = all_labels == i
    # A ROC curve needs both positive and negative samples; skip classes
    # that are absent from (or make up all of) the test set.
    if not is_class.any() or is_class.all():
        continue
    fpr, tpr, _ = roc_curve(is_class, all_probs[:, i])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f'{class_name[:10]}... (AUC = {roc_auc:.2f})')

plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
plt.tight_layout()
plt.savefig(ANALYSIS_DIR / 'roc_curves.png', dpi=300, bbox_inches='tight')
plt.close()

# 4. Precision-Recall Curves (one-vs-rest, one curve per class)
plt.figure(figsize=(12, 8))
for i, class_name in enumerate(class_names):
    is_class = all_labels == i
    # Recall is undefined without at least one positive sample
    if not is_class.any():
        continue
    precision, recall, _ = precision_recall_curve(is_class, all_probs[:, i])
    plt.plot(recall, precision, label=f'{class_name[:10]}...')

plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curves')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
plt.tight_layout()
plt.savefig(ANALYSIS_DIR / 'precision_recall_curves.png', dpi=300, bbox_inches='tight')
plt.close()

# 5. Generate Classification Report
# `labels` fixes the reported classes to the full class list, so target_names
# always matches even when some classes never occur in the test results.
# zero_division=0 reports 0 (without warnings) for classes lacking samples or
# predictions.
report = classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(class_names))),
    target_names=class_names,
    zero_division=0,
)
print("\nClassification Report:")
print(report)

# Save classification report
with open(ANALYSIS_DIR / 'classification_report.txt', 'w') as f:
    f.write(report)

# 6. Top Misclassifications Analysis
# One record per wrongly classified sample, keeping the probability the model
# assigned to the (wrong) class it predicted.
misclassified = [
    {
        'true': idx_to_class[true],
        'predicted': idx_to_class[pred],
        'confidence': prob[pred]
    }
    for pred, true, prob in zip(all_preds, all_labels, all_probs)
    if pred != true
]

if misclassified:
    # Most confident mistakes first
    misclassified.sort(key=lambda x: x['confidence'], reverse=True)

    # Plot top misclassifications
    top_n = min(10, len(misclassified))
    top_misclassified = misclassified[:top_n]
    bar_heights = [x['confidence'] for x in top_misclassified]
    bar_labels = [f"{x['true'][:10]}...\n→\n{x['predicted'][:10]}..." for x in top_misclassified]

    plt.figure(figsize=(12, 6))
    plt.bar(range(top_n), bar_heights, tick_label=bar_labels)
    plt.title('Top Misclassifications with Confidence')
    plt.ylabel('Confidence')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(ANALYSIS_DIR / 'top_misclassifications.png', dpi=300, bbox_inches='tight')
plt.close()

# Calculate and save overall metrics
accuracy = (all_preds == all_labels).mean()
print(f"\nOverall Test Accuracy: {accuracy:.4f}")

test_metrics = {
    'overall_accuracy': float(accuracy),
    # Accuracy is NaN for classes without test samples; store null instead,
    # because a bare NaN token is not valid JSON for strict parsers.
    'per_class_accuracy': {
        class_names[i]: (None if np.isnan(class_accuracy[i]) else float(class_accuracy[i]))
        for i in range(len(class_names))
    },
    'number_of_test_samples': len(test_dataset),
    'number_of_classes': len(class_names)
}

with open(ANALYSIS_DIR / 'test_metrics.json', 'w') as f:
    json.dump(test_metrics, f, indent=4)

print(f"\nAll visualizations have been saved to {ANALYSIS_DIR}")

Mounted at /content/drive
Creating directories...

Verifying paths...
Model exists: True
Config exists: True
Dataset zip exists: True

Loading model configuration...
Model config loaded successfully

Dataset directory already exists

Initializing model...


  state_dict = torch.load(MODEL_PATH, map_location='cuda' if torch.cuda.is_available() else 'cpu')


Model loaded and moved to cuda

Loading test dataset...
Number of classes: 38
Using split directory: /content/drive/MyDrive/SpeciesDetectionAI/dataset/splits/test
Total images found for test split: 0


ValueError: No valid images found in /content/drive/MyDrive/SpeciesDetectionAI/dataset/splits/test

In [None]:
# First, let's check the actual structure of the test folder
import os

def print_directory_structure(startpath):
    """Print an indented tree of ``startpath``.

    Every directory is printed with a trailing ``/``, followed by at most five
    of its files (alphabetically) and a summary line for the remainder.
    Directories and files are visited in sorted order so the output is
    reproducible.

    Args:
        startpath: Directory to walk; a ``str`` or path-like object. A
            trailing separator is accepted.
    """
    # Normalizing drops a trailing separator, which would otherwise produce an
    # empty root name and shift the indentation.
    startpath = os.path.normpath(os.fspath(startpath))
    for root, dirs, files in os.walk(startpath):
        dirs.sort()   # in place, so os.walk descends in sorted order
        files.sort()
        relative = os.path.relpath(root, startpath)
        level = 0 if relative == os.curdir else relative.count(os.sep) + 1
        indent = ' ' * 4 * level
        print(f'{indent}{os.path.basename(root)}/')
        subindent = ' ' * 4 * (level + 1)
        for f in files[:5]:  # Print first 5 files in each directory
            print(f'{subindent}{f}')
        if len(files) > 5:
            print(f'{subindent}... and {len(files)-5} more files')

# Dump the on-disk layout of the dataset directory to see where the images
# for each split actually live.
print("Checking dataset structure:")
print("\nMain dataset directory structure:")
# The helper is given a plain string, so DATA_DIR is converted first.
print_directory_structure(str(DATA_DIR))

# Now let's modify the WildlifeDataset class to be more flexible
class WildlifeDataset(Dataset):
    """Dataset that searches several candidate locations for a split's images.

    The class mapping is taken from the ``class_mapping`` entry of the JSON
    file at CONFIG_PATH. The first candidate split directory that contains at
    least one image (searched recursively) is used. For each species the
    first candidate directory that yields images is taken.

    Args:
        root_dir: Dataset root directory.
        split: Split name to look for (default ``'train'``).
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        FileNotFoundError: none of the candidate split directories holds images.
        ValueError: no image was found for any species.
    """

    # Compared case-insensitively so that names such as ``*.JPG`` match too.
    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = Path(root_dir)
        self.split = split
        self.transform = transform
        self.images = []
        self.labels = []

        # Load class mappings
        with open(CONFIG_PATH, 'r') as f:
            self.class_info = json.load(f)

        self.num_classes = len(self.class_info['class_mapping'])
        print(f"Number of classes: {self.num_classes}")

        # Try different possible test directory locations.
        # NOTE: the last two candidates ignore `split`; with the dataset root
        # as a fallback, images of other splits can end up in the result.
        possible_paths = [
            self.root_dir / 'splits' / split,
            self.root_dir / split,
            self.root_dir / 'test',
            self.root_dir
        ]

        split_dir = None
        for path in possible_paths:
            if path.exists():
                print(f"Found potential split directory: {path}")
                # Verify it contains image data (stops at the first hit)
                if any(self._iter_images(path)):
                    split_dir = path
                    break

        if split_dir is None:
            raise FileNotFoundError(f"Could not find valid directory for {split} split")

        print(f"Using split directory: {split_dir}")

        # Load images using different possible directory structures
        for species_name, class_id in self.class_info['class_mapping'].items():
            # Try different possible paths for each species
            possible_species_paths = [
                split_dir / species_name,
                split_dir / 'images' / species_name,
                split_dir / species_name / 'images',
            ]

            images_found = False
            for species_dir in possible_species_paths:
                if species_dir.exists():
                    print(f"Found directory for species {species_name} at {species_dir}")
                    # Look for images in this directory and all subdirectories
                    # (one traversal; sorted for a reproducible sample order)
                    image_paths = sorted(self._iter_images(species_dir))

                    if image_paths:
                        self.images.extend(image_paths)
                        self.labels.extend([class_id] * len(image_paths))
                        print(f"Found {len(image_paths)} images for species {species_name}")
                        images_found = True
                        break

            if not images_found:
                print(f"Warning: No images found for species {species_name}")

        print(f"\nTotal images found for {split} split: {len(self.images)}")
        if len(self.images) == 0:
            raise ValueError("No valid images found in any of the searched directories")

        # Print some sample paths to verify
        print("\nSample image paths:")
        for path in self.images[:5]:
            print(path)

    @classmethod
    def _iter_images(cls, directory):
        """Yield every image file below ``directory`` (recursive, lazy).

        A missing directory yields nothing.
        """
        if not directory.is_dir():
            return
        for entry in directory.rglob('*'):
            if entry.is_file() and entry.suffix.lower() in cls.IMAGE_SUFFIXES:
                yield entry

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is RGB."""
        img_path = self.images[idx]
        label = self.labels[idx]

        # Load and transform image
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label

# Try the more flexible loader on the test split and report the outcome
print("\nAttempting to load test dataset...")

try:
    test_dataset = WildlifeDataset(DATA_DIR, split='test', transform=test_transform)
    print("\nSuccess! Found test dataset.")
    print(f"Number of test images: {len(test_dataset)}")
except Exception as load_error:
    # Diagnostic cell: report the problem instead of stopping the notebook
    print(
        f"\nError loading test dataset: {load_error}",
        "\nPlease verify the correct location of your test images and their organization.",
        sep="\n",
    )

Checking dataset structure:

Main dataset directory structure:
dataset/
    class_mappings.json
    split_statistics.json
    config/
    splits/
        test/
            banded civet/
                images/
                    banded_civet_104.JPG
                    banded_civet_105.JPG
                    banded_civet_106.JPG
                    banded_civet_119.JPG
                    banded_civet_12.JPG
                    ... and 21 more files
                labels/
            banded linsang/
                images/
                labels/
            bay cat/
                images/
                labels/
            bearded pig/
                images/
                    bearded_pig_100.JPG
                    bearded_pig_1005.JPG
                    bearded_pig_101.JPG
                    bearded_pig_1021.JPG
                    bearded_pig_1024.JPG
                    ... and 248 more files
                labels/
            binturong/
                images/
         

In [None]:
import torch
import timm
import json
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_curve, roc_curve, auc
from tqdm.notebook import tqdm

# Modify dataset class to load validation data
class WildlifeDataset(Dataset):
    """Validation dataset read from ``<root_dir>/splits/val/<species>/images``.

    The class mapping is taken from the ``class_mapping`` entry of the JSON
    file at CONFIG_PATH. Species whose ``images`` directory is missing or
    empty are skipped silently, so the dataset can end up empty.

    Args:
        root_dir: Dataset root directory.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    # Compared case-insensitively so that names such as ``*.JPG`` match too.
    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.images = []
        self.labels = []

        # Load class mappings
        with open(CONFIG_PATH, 'r') as f:
            self.class_info = json.load(f)

        self.num_classes = len(self.class_info['class_mapping'])
        print(f"Number of classes: {self.num_classes}")

        # Use validation split
        split_dir = self.root_dir / 'splits' / 'val'
        print(f"Using validation directory: {split_dir}")

        # Load images for each species
        for species_name, class_id in self.class_info['class_mapping'].items():
            image_paths = self._list_images(split_dir / species_name / 'images')
            if image_paths:
                self.images.extend(image_paths)
                self.labels.extend([class_id] * len(image_paths))
                print(f"Found {len(image_paths)} images for species {species_name}")

        print(f"\nTotal images found: {len(self.images)}")

    @classmethod
    def _list_images(cls, directory):
        """Return the image files directly inside ``directory``, sorted by path.

        A missing directory yields an empty list.
        """
        if not directory.is_dir():
            return []
        return sorted(
            entry for entry in directory.iterdir()
            if entry.is_file() and entry.suffix.lower() in cls.IMAGE_SUFFIXES
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is RGB."""
        img_path = self.images[idx]
        label = self.labels[idx]

        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label

# Evaluation-time preprocessing (no augmentation): fixed 224x224 resize,
# conversion to a tensor, then per-channel normalization.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

# Load model
print("\nInitializing model...")
# Resolve the device once so loading and inference agree on it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# num_classes has to match the classifier size stored in the checkpoint,
# otherwise load_state_dict below fails on a shape mismatch.
model = timm.create_model('tf_efficientnet_b2', pretrained=False, num_classes=38)
# weights_only=True restricts unpickling to tensors and plain containers, so
# loading the checkpoint cannot execute arbitrary pickled code. It also removes
# the warning torch.load emits when the argument is left unset.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
# Strip only a *leading* 'model.' (presumably added by a wrapper module that
# held the network as `.model` -- confirm against the training code).
# str.replace would also rewrite keys that contain 'model.' further inside.
state_dict = {
    (k[len('model.'):] if k.startswith('model.') else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()
model = model.to(device)
print(f"Model loaded and moved to {device}")

# Build the validation loader. shuffle=False keeps the batches in dataset
# order, so collected predictions line up with the dataset's samples.
print("\nLoading validation dataset...")
val_dataset = WildlifeDataset(DATA_DIR, transform=val_transform)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)

# Run the model over every batch, accumulating per-sample softmax
# probabilities, arg-max predictions and the ground-truth labels.
print("\nRunning predictions...")
all_preds, all_labels, all_probs = [], [], []

with torch.no_grad():
    for images, labels in tqdm(val_loader, desc="Evaluating"):
        images = images.to(device)
        outputs = model(images)
        predictions = outputs.argmax(dim=1)
        probabilities = outputs.softmax(dim=1)

        # Results go back to the CPU before conversion to NumPy.
        all_labels.extend(labels.numpy())
        all_preds.extend(predictions.cpu().numpy())
        all_probs.extend(probabilities.cpu().numpy())

# Turn the per-sample lists into arrays for the metric code that follows.
all_probs, all_preds, all_labels = (
    np.array(all_probs),
    np.array(all_preds),
    np.array(all_labels),
)

# Invert the name -> id mapping so class ids can be turned back into species
# names for plot and report labels.
idx_to_class = {
    class_id: species
    for species, class_id in val_dataset.class_info['class_mapping'].items()
}
# Position i holds the name of class id i; a gap in the ids raises KeyError.
class_names = list(map(idx_to_class.__getitem__, range(len(idx_to_class))))

print("\nGenerating visualizations...")

# Make sure the output directory for plots and reports exists.
ANALYSIS_DIR = PROJECT_DIR.joinpath('outputs', 'analysis')
ANALYSIS_DIR.mkdir(parents=True, exist_ok=True)

# 1. Confusion Matrix
# Pin the label set to every known class id. By default sklearn only includes
# classes that occur in y_true or y_pred, which would shrink the matrix and
# misalign it with class_names whenever a class is absent from the split.
class_ids = np.arange(len(class_names))
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(20, 20))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=45)
plt.tight_layout()
plt.savefig(ANALYSIS_DIR / 'confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.close()

# 2. Per-class Accuracy
# Fraction of each true class that was predicted correctly (diagonal over row
# total). Classes with no true samples have no defined value: they stay NaN
# instead of triggering a divide-by-zero.
row_totals = cm.sum(axis=1)
class_accuracy = np.full(len(class_names), np.nan)
np.divide(cm.diagonal(), row_totals, out=class_accuracy, where=row_totals > 0)

plt.figure(figsize=(20, 8))
# Undefined (NaN) entries are drawn as zero-height bars.
plt.bar(class_names, np.nan_to_num(class_accuracy, nan=0.0))
plt.title('Per-class Accuracy')
plt.xticks(rotation=45, ha='right')
plt.ylabel('Accuracy')
plt.tight_layout()
plt.savefig(ANALYSIS_DIR / 'per_class_accuracy.png', dpi=300, bbox_inches='tight')
plt.close()

# 3. ROC Curves
# One-vs-rest ROC curve per class, using that class's softmax probability as
# the score.
plt.figure(figsize=(15, 10))
for i, class_name in enumerate(class_names):
    is_positive = all_labels == i
    # Only plot if we have both positive and negative samples
    if is_positive.any() and not is_positive.all():
        fpr, tpr, _ = roc_curve(is_positive, all_probs[:, i])
        roc_auc = auc(fpr, tpr)
        # Add the ellipsis only when the name really was truncated.
        short_name = class_name if len(class_name) <= 15 else f'{class_name[:15]}...'
        plt.plot(fpr, tpr, label=f'{short_name} (AUC = {roc_auc:.2f})')

# Diagonal reference line (TPR == FPR).
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
plt.tight_layout()
plt.savefig(ANALYSIS_DIR / 'roc_curves.png', dpi=300, bbox_inches='tight')
plt.close()

# 4. Classification Report
# labels= pins the report rows to every known class id so that target_names
# always lines up, even when a class is absent from the split. zero_division=0
# reports 0 for metrics that are undefined for such classes, without warnings.
report = classification_report(
    all_labels,
    all_preds,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    zero_division=0,
)
print("\nClassification Report:")
print(report)

# Explicit encoding so the file content does not depend on the platform default.
with open(ANALYSIS_DIR / 'classification_report.txt', 'w', encoding='utf-8') as f:
    f.write(report)

# 5. Overall Metrics
# Fraction of samples whose predicted class equals the true class.
accuracy = (all_preds == all_labels).mean()
print(f"\nOverall Validation Accuracy: {accuracy:.4f}")

# Save metrics
metrics = {
    'overall_accuracy': float(accuracy),
    # A class without true samples has a NaN accuracy. json.dump would write
    # the bare token NaN, which is not valid JSON, so store null instead.
    'per_class_accuracy': {
        name: (None if np.isnan(class_accuracy[i]) else float(class_accuracy[i]))
        for i, name in enumerate(class_names)
    },
    'number_of_validation_samples': len(val_dataset),
    'number_of_classes': len(class_names)
}

with open(ANALYSIS_DIR / 'validation_metrics.json', 'w', encoding='utf-8') as f:
    json.dump(metrics, f, indent=4)

print(f"\nAll visualizations and metrics have been saved to {ANALYSIS_DIR}")


Initializing model...


  state_dict = torch.load(MODEL_PATH, map_location='cuda' if torch.cuda.is_available() else 'cpu')


Model loaded and moved to cuda

Loading validation dataset...
Number of classes: 38
Using validation directory: /content/drive/MyDrive/SpeciesDetectionAI/dataset/splits/val
Found 37 images for species banded civet
Found 1 images for species banded linsang
Found 1 images for species bay cat
Found 63 images for species bearded pig
Found 1 images for species binturong
Found 1 images for species bird sp
Found 20 images for species bulwer's pheasant
Found 68 images for species great argus pheasant
Found 1 images for species ground tufted squirrel
Found 40 images for species human
Found 21 images for species leopard cat
Found 58 images for species long-tailed macaque
Found 4 images for species long-tailed porcupine
Found 44 images for species malayan civet
Found 51 images for species malayan porcupine
Found 1 images for species marbled cat
Found 1 images for species masked palm civet
Found 25 images for species mongoose sp
Found 2 images for species monitor lizard
Found 9 images for species 

Evaluating:   0%|          | 0/28 [00:00<?, ?it/s]


Generating visualizations...

Classification Report:
                          precision    recall  f1-score   support

            banded civet       0.97      1.00      0.99        37
          banded linsang       1.00      1.00      1.00         1
                 bay cat       1.00      1.00      1.00         1
             bearded pig       1.00      0.98      0.99        63
               binturong       1.00      1.00      1.00         1
                 bird sp       1.00      1.00      1.00         1
       bulwer's pheasant       1.00      1.00      1.00        20
    great argus pheasant       1.00      0.99      0.99        68
  ground tufted squirrel       1.00      1.00      1.00         1
                   human       1.00      1.00      1.00        40
             leopard cat       1.00      1.00      1.00        21
     long-tailed macaque       0.97      1.00      0.98        58
   long-tailed porcupine       0.80      1.00      0.89         4
           malayan ci