In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import timm
from tqdm import tqdm
import glob


# Module-level compute device shared by every function below:
# CUDA when torch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


class ConvNeXtClassifier(nn.Module):
    """timm backbone followed by a small MLP head that emits one logit per image.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        pretrained: Whether timm should load pretrained backbone weights.
        dropout: Dropout probability before the first linear layer; the second
            dropout layer uses half of this value.
    """

    def __init__(self, model_name='convnext_tiny.fb_in22k_ft_in1k', pretrained=False, dropout=0.3):
        super().__init__()
        # num_classes=0 asks timm for the backbone without its classifier layer;
        # the head below is sized from the backbone's reported feature width.
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        in_features = self.model.num_features

        self.head = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Dropout(dropout),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout / 2),
            nn.Linear(512, 1)
        )

    def forward(self, x):
        """Return one raw logit per input sample (apply sigmoid for a probability)."""
        features = self.model(x)
        # Squeeze only the trailing logit dimension. A bare ``squeeze()`` would
        # also drop the batch dimension for a single-image batch and return a
        # 0-d tensor, so the output rank would depend on the batch size.
        return self.head(features).squeeze(-1)


class InferenceDataset(Dataset):
    """Dataset that reads images from disk for inference.

    Each item is ``(image, img_path)``. ``image`` is the RGB image after
    ``transform`` has been applied, or the RGB PIL image itself when no
    transform was given.

    Args:
        image_paths: Sequence of image file paths.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Close the file handle as soon as the pixels are decoded; without the
        # context manager every image keeps a descriptor open until it is
        # garbage-collected, which can exhaust the limit on large folders.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, img_path


def get_tta_transforms(img_size=256):
    """Return the list of test-time-augmentation pipelines.

    Three ``transforms.Compose`` pipelines are returned, in this order:
    plain resize, resize + always-applied horizontal flip, and resize +
    random rotation of up to 5 degrees. Every pipeline ends with
    ``ToTensor`` and the same mean/std normalization.
    """
    def build_pipeline(*augmentations):
        # Resize first, then the view-specific augmentation(s), then tensor
        # conversion and normalization.
        steps = [transforms.Resize((img_size, img_size))]
        steps.extend(augmentations)
        steps.append(transforms.ToTensor())
        steps.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        )
        return transforms.Compose(steps)

    return [
        build_pipeline(),
        build_pipeline(transforms.RandomHorizontalFlip(p=1.0)),
        build_pipeline(transforms.RandomRotation(degrees=5)),
    ]


def predict_single_model(model, loader, device, use_tta=False, tta_transforms=None):
    """Run one model over ``loader`` and return sigmoid probabilities.

    Args:
        model: Module returning one logit per image.
        loader: Iterable yielding ``(images, paths)`` batches.
        device: Device the input batches are moved to.
        use_tta: When true (and ``tta_transforms`` is non-empty), the images
            are re-read from ``paths`` and the prediction is the mean over
            all TTA views; the ``images`` yielded by the loader are ignored.
        tta_transforms: List of callables mapping a PIL image to a tensor.

    Returns:
        ``(predictions, image_paths)``: a 1-D numpy array of probabilities and
        the list of paths in the order they were processed.
    """
    model.eval()
    predictions = []
    image_paths = []

    def _probabilities(batch):
        # reshape(-1) always yields a 1-D result, even when the model returns
        # a 0-d tensor for a single-image batch, so no scalar special-casing
        # is needed afterwards.
        return torch.sigmoid(model(batch)).reshape(-1).cpu().numpy()

    def _load_rgb(path):
        # The context manager releases the file handle right after decoding.
        with Image.open(path) as img:
            return img.convert('RGB')

    with torch.no_grad():
        for images, paths in tqdm(loader, desc='Predicting'):
            if use_tta and tta_transforms:
                # Decode each file once and reuse it for every TTA view
                # instead of re-opening it per transform.
                pil_images = [_load_rgb(path) for path in paths]
                batch_preds = [
                    _probabilities(
                        torch.stack([tta_transform(img) for img in pil_images]).to(device)
                    )
                    for tta_transform in tta_transforms
                ]
                # Average the per-view probabilities for each image.
                preds = np.mean(batch_preds, axis=0)
            else:
                preds = _probabilities(images.to(device))

            predictions.extend(preds.tolist())
            image_paths.extend(paths)

    return np.array(predictions), image_paths


def _format_metric(value):
    """Format a checkpoint metric with 4 decimals, or 'N/A' if missing or non-numeric."""
    try:
        return f"{float(value):.4f}"
    except (TypeError, ValueError):
        return 'N/A'


def ensemble_predict(weights_dir, image_paths, img_size=256, batch_size=32, 
                     num_folds=5, use_tta=False, model_name='convnext_tiny.fb_in22k_ft_in1k'):
    """
    Make predictions using ensemble of all fold models
    
    Weights are looked up as ``convnext_fold{k}_best.pth`` for k in
    1..num_folds; missing files are skipped with a warning.
    
    Args:
        weights_dir: Directory containing model weights
        image_paths: List of image paths to predict
        img_size: Input image size
        batch_size: Batch size for inference
        num_folds: Number of CV folds
        use_tta: Whether to use test-time augmentation
        model_name: ConvNeXt model variant
    
    Returns:
        ensemble_preds: Averaged predictions from all folds
        fold_preds: Individual predictions from each fold that was loaded
        image_paths: List of image paths
    
    Raises:
        ValueError: If no fold weights were found in ``weights_dir``.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])

    tta_transforms = get_tta_transforms(img_size) if use_tta else None

    # Always give the dataset a tensor-producing transform. With
    # transform=None it yields PIL images, which the DataLoader's default
    # collate function cannot batch, so TTA mode would crash on the first
    # batch. In TTA mode the views are rebuilt from the file paths and these
    # tensors are simply not used.
    dataset = InferenceDataset(image_paths, transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, 
                       num_workers=2, pin_memory=(device.type == 'cuda'))

    all_predictions = []
    fold_info = []
    paths = []

    for fold in range(num_folds):
        weight_path = os.path.join(weights_dir, f'convnext_fold{fold+1}_best.pth')

        if not os.path.exists(weight_path):
            print(f"\u26a0 Warning: {weight_path} not found, skipping...")
            continue

        print(f"\n{'='*60}")
        print(f"Loading Fold {fold + 1} weights...")
        print(f"{'='*60}")

        model = ConvNeXtClassifier(model_name=model_name, pretrained=False).to(device)
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Only load checkpoints that come from a trusted source.
        checkpoint = torch.load(weight_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])

        # Metrics are optional in the checkpoint; _format_metric avoids
        # applying a float format spec to the 'N/A' placeholder.
        print("\u2713 Model loaded successfully")
        print(f"  - Epoch: {checkpoint.get('epoch', 'N/A')}")
        print(f"  - Val AUC: {_format_metric(checkpoint.get('val_auc'))}")
        print(f"  - Val Acc: {_format_metric(checkpoint.get('val_acc'))}")
        print(f"  - Val F1: {_format_metric(checkpoint.get('val_f1'))}")

        predictions, paths = predict_single_model(
            model, loader, device, 
            use_tta=use_tta, 
            tta_transforms=tta_transforms
        )
        all_predictions.append(predictions)
        fold_info.append({
            'fold': fold + 1,
            'val_auc': checkpoint.get('val_auc', 0),
            'val_acc': checkpoint.get('val_acc', 0),
        })

        # Release this fold's weights before loading the next one.
        del model, checkpoint
        torch.cuda.empty_cache()

    if not all_predictions:
        raise ValueError("No model weights found! Check weights_dir path.")

    # Unweighted mean of the per-fold probabilities.
    ensemble_preds = np.mean(all_predictions, axis=0)

    print(f"\n{'='*60}")
    print("ENSEMBLE SUMMARY")
    print(f"{'='*60}")
    print(f"Models used: {len(all_predictions)}/{num_folds}")
    for info in fold_info:
        print(f"  Fold {info['fold']}: AUC={info['val_auc']:.4f}, Acc={info['val_acc']:.4f}")
    print(f"{'='*60}\n")

    return ensemble_preds, all_predictions, paths


def predict_best_model(weights_dir, image_paths, img_size=256, batch_size=32, 
                      use_tta=False, model_name='convnext_tiny.fb_in22k_ft_in1k'):
    """
    Make predictions using the single best performing fold model
    
    The best fold is the ``convnext_fold*_best.pth`` checkpoint with the
    highest stored ``val_auc`` (treated as 0 when the key is missing).
    
    Args:
        weights_dir: Directory containing model weights
        image_paths: List of image paths to predict
        img_size: Input image size
        batch_size: Batch size for inference
        use_tta: Whether to use test-time augmentation
        model_name: ConvNeXt model variant
    
    Returns:
        predictions: Model predictions
        image_paths: List of image paths
        best_fold: Best fold number
    
    Raises:
        ValueError: If no matching weight files exist in ``weights_dir``.
    """
    # Sorted so that ties on val_auc are broken deterministically (glob order
    # is filesystem-dependent).
    weight_files = sorted(glob.glob(os.path.join(weights_dir, 'convnext_fold*_best.pth')))

    if not weight_files:
        raise ValueError(f"No model weights found in {weights_dir}")

    # Start below any possible score so a checkpoint is always selected, even
    # when every file lacks 'val_auc' or stores a value <= 0. Starting at 0
    # would leave the best path as None and crash in torch.load.
    best_auc = float('-inf')
    best_fold = None
    best_state_dict = None

    for weight_path in weight_files:
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Only load checkpoints that come from a trusted source.
        checkpoint = torch.load(weight_path, map_location='cpu', weights_only=False)
        auc = checkpoint.get('val_auc', 0)
        if auc > best_auc:
            best_auc = auc
            # Keep the winning weights so the file is not read a second time.
            best_state_dict = checkpoint['model_state_dict']
            fold_num = os.path.basename(weight_path).split('fold')[1].split('_')[0]
            best_fold = int(fold_num)
        del checkpoint

    print(f"\n{'='*60}")
    print(f"Using Best Model: Fold {best_fold}")
    print(f"Validation AUC: {best_auc:.4f}")
    print(f"{'='*60}\n")

    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])

    tta_transforms = get_tta_transforms(img_size) if use_tta else None

    # Always give the dataset a tensor-producing transform. With
    # transform=None it yields PIL images, which the DataLoader's default
    # collate function cannot batch, so TTA mode would crash on the first
    # batch. In TTA mode the views are rebuilt from the file paths.
    dataset = InferenceDataset(image_paths, transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, 
                       num_workers=2, pin_memory=(device.type == 'cuda'))

    model = ConvNeXtClassifier(model_name=model_name, pretrained=False).to(device)
    # load_state_dict copies the CPU tensors into the model's parameters on
    # whichever device the model lives on.
    model.load_state_dict(best_state_dict)

    predictions, paths = predict_single_model(
        model, loader, device,
        use_tta=use_tta,
        tta_transforms=tta_transforms
    )

    del model, best_state_dict
    torch.cuda.empty_cache()

    return predictions, paths, best_fold


def _build_results_frame(paths, probs, fold_preds=None):
    """Assemble the per-image results table shared by all inference modes.

    Args:
        paths: Image paths, aligned with ``probs``.
        probs: 1-D numpy array of predicted probabilities.
        fold_preds: Optional list of per-model probability arrays; when given,
            one ``fold_{i}_prob`` column per model and a variance column are
            appended.
    """
    results = pd.DataFrame({
        'filename': [os.path.basename(p) for p in paths],
        'image_path': paths,
        'prediction_prob': probs,
        'prediction_class': (probs >= 0.5).astype(int),
        # Probability assigned to the predicted class.
        'confidence': np.maximum(probs, 1 - probs)
    })

    if fold_preds is not None:
        # Columns are numbered by position among the models that were
        # actually loaded, which equals the fold number only when no fold
        # was skipped.
        for i, preds in enumerate(fold_preds):
            results[f'fold_{i+1}_prob'] = preds
        # Disagreement between the models for each image.
        results['prediction_variance'] = np.var(fold_preds, axis=0)

    return results


def main():
    """
    Example usage of inference functions

    Collects the images in IMAGE_DIR, runs the inference mode selected by
    ``choice``, writes the per-image results to OUTPUT_FILE and prints a
    summary.
    """
    IMAGE_DIR = '/kaggle/input/notebooks/kagglertw/shemagh-right-placev2/images'
    WEIGHTS_DIR = '/kaggle/input/notebooks/kagglertw/shemagh-right-place-binaryv2/weights_convnext/'
    OUTPUT_FILE = 'convnext_predictions.csv'

    image_extensions = ('*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG')
    found = set()
    for ext in image_extensions:
        found.update(glob.glob(os.path.join(IMAGE_DIR, ext)))
    # On case-insensitive filesystems '*.jpg' and '*.JPG' match the same
    # files; the set removes those duplicates and sorting gives a stable
    # processing order.
    image_paths = sorted(found)

    if not image_paths:
        print(f"\u26a0 No images found in {IMAGE_DIR}")
        return

    print(f"Found {len(image_paths)} images")

    print("\nInference Options:")
    print("1. Ensemble (All folds) - Most accurate")
    print("2. Best single model - Faster")
    print("3. Ensemble with TTA - Highest accuracy (slower)")

    # Hard-coded mode selection; any value other than "1" or "2" runs the
    # TTA ensemble.
    choice = "1"

    if choice == "1":
        print("\n\U0001F680 Running Ensemble Prediction (All Folds)...")
        ensemble_preds, fold_preds, paths = ensemble_predict(
            WEIGHTS_DIR, 
            image_paths,
            img_size=256,
            batch_size=32,
            num_folds=5,
            use_tta=False
        )
        results_df = _build_results_frame(paths, ensemble_preds, fold_preds)

    elif choice == "2":
        print("\n\U0001F680 Running Best Single Model Prediction...")
        predictions, paths, best_fold = predict_best_model(
            WEIGHTS_DIR,
            image_paths,
            img_size=256,
            batch_size=32,
            use_tta=False
        )
        results_df = _build_results_frame(paths, predictions)
        results_df['model_fold'] = best_fold

    else:
        print("\n\U0001F680 Running Ensemble Prediction with TTA...")
        print("\u26a0 This will be slower but more accurate")
        ensemble_preds, fold_preds, paths = ensemble_predict(
            WEIGHTS_DIR,
            image_paths,
            img_size=256,
            batch_size=16,
            num_folds=5,
            use_tta=True
        )
        results_df = _build_results_frame(paths, ensemble_preds, fold_preds)

    # Most confident predictions first; the tail holds the ones to review.
    results_df = results_df.sort_values('confidence', ascending=False)

    results_df.to_csv(OUTPUT_FILE, index=False)

    print(f"\n{'='*60}")
    print("PREDICTION SUMMARY")
    print(f"{'='*60}")
    print(f"Total images: {len(results_df)}")
    print("\nClass distribution:")
    print(results_df['prediction_class'].value_counts().sort_index())
    print("\nConfidence statistics:")
    print(results_df['confidence'].describe())

    summary_columns = ['filename', 'prediction_class', 'prediction_prob', 'confidence']

    print(f"\n{'='*60}")
    print("Top 10 Most Confident Predictions:")
    print(f"{'='*60}")
    print(results_df[summary_columns].head(10).to_string(index=False))

    print(f"\n{'='*60}")
    print("Top 10 Least Confident Predictions (Review These):")
    print(f"{'='*60}")
    print(results_df[summary_columns].tail(10).to_string(index=False))

    print(f"\n{'='*60}")
    print(f"\u2713 Results saved to {OUTPUT_FILE}")
    print(f"{'='*60}\n")

# Run the example inference pipeline only when this module is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()



Using device: cuda
Found 842 images

Inference Options:
1. Ensemble (All folds) - Most accurate
2. Best single model - Faster
3. Ensemble with TTA - Highest accuracy (slower)

🚀 Running Ensemble Prediction (All Folds)...

Loading Fold 1 weights...
✓ Model loaded successfully
  - Epoch: 3
  - Val AUC: 0.8414
  - Val Acc: 0.6946
  - Val F1: 0.5854


Predicting: 100%|██████████| 27/27 [00:10<00:00,  2.61it/s]



Loading Fold 2 weights...
✓ Model loaded successfully
  - Epoch: 6
  - Val AUC: 0.8816
  - Val Acc: 0.7725
  - Val F1: 0.6607


Predicting: 100%|██████████| 27/27 [00:04<00:00,  6.13it/s]



Loading Fold 3 weights...
✓ Model loaded successfully
  - Epoch: 12
  - Val AUC: 0.8604
  - Val Acc: 0.8084
  - Val F1: 0.6863


Predicting: 100%|██████████| 27/27 [00:04<00:00,  6.19it/s]



Loading Fold 4 weights...
✓ Model loaded successfully
  - Epoch: 5
  - Val AUC: 0.8876
  - Val Acc: 0.7771
  - Val F1: 0.6542


Predicting: 100%|██████████| 27/27 [00:04<00:00,  6.40it/s]



Loading Fold 5 weights...
✓ Model loaded successfully
  - Epoch: 4
  - Val AUC: 0.7998
  - Val Acc: 0.6807
  - Val F1: 0.5760


Predicting: 100%|██████████| 27/27 [00:04<00:00,  6.17it/s]



ENSEMBLE SUMMARY
Models used: 5/5
  Fold 1: AUC=0.8414, Acc=0.6946
  Fold 2: AUC=0.8816, Acc=0.7725
  Fold 3: AUC=0.8604, Acc=0.8084
  Fold 4: AUC=0.8876, Acc=0.7771
  Fold 5: AUC=0.7998, Acc=0.6807


PREDICTION SUMMARY
Total images: 842

Class distribution:
prediction_class
0    625
1    217
Name: count, dtype: int64

Confidence statistics:
count    842.000000
mean       0.698573
std        0.084724
min        0.500010
25%        0.638446
50%        0.708972
75%        0.767214
max        0.876097
Name: confidence, dtype: float64

Top 10 Most Confident Predictions:
filename  prediction_class  prediction_prob  confidence
  41.jpg                 1         0.876097    0.876097
 319.jpg                 1         0.867351    0.867351
 693.jpg                 1         0.864327    0.864327
 712.jpg                 1         0.861979    0.861979
 117.jpg                 1         0.855168    0.855168
 267.jpg                 1         0.853722    0.853722
 558.jpg                 1        