In [None]:
# ## Scratch
# import torch
# from torch import nn
# from torch.utils.data import DataLoader
# from datasets.topic_datasets import TopicDataset
# from models.deepflybrain import DeepFlyBrain
# import numpy as np

# device = torch.device(f'cuda:3')  # Use same device as training
# print(f"Using {device} device")

# model = DeepFlyBrain().to(device)

# checkpoint_path = './checkpoints/dfb_2025-05-28_14-46-48/model_epoch_100.pth'
# checkpoint = torch.load(checkpoint_path, map_location=device)
# model.load_state_dict(checkpoint)

# dataset = TopicDataset(
#     genome='data/resources/mm10.fa',
#     region_topic_bed='data/Furlanis_Topics_top_3k/regions_and_topics_sorted.bed',
#     transform=None,
#     target_transform=None
# )
# dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=1)

# model.eval()

# all_probs = []
# all_targets = []
# all_predictions = []

# total_samples = 0
# exact_matches = 0

# with torch.no_grad():
#     for batch_idx, batch in enumerate(dataloader):
#         X, y = batch['sequence'], batch['label']
#         X, y = X.to(device), y.to(device)
#         y = y.float()
        
#         # Make predictions
#         pred_logits = model(X)
#         pred_probs = torch.sigmoid(pred_logits)
        
#         all_probs.append(pred_probs.cpu())
#         all_targets.append(y.cpu())
        
#         # Test multiple thresholds
#         for threshold in [0.1, 0.2, 0.3, 0.4, 0.5]:
#             pred_binary = (pred_probs > threshold).float()
            
#             if threshold == 0.5:  # Store predictions at default threshold
#                 all_predictions.append(pred_binary.cpu())
            
#             # Calculate exact match for this threshold
#             exact_match = ((pred_binary == y).sum(dim=1) == y.shape[1]).float().sum().item()
            
#             if batch_idx == 0:  # Print for first batch only
#                 print(f"Threshold {threshold}: {exact_match}/{y.shape[0]} exact matches")
        
#         total_samples += y.shape[0]
        
#         # Calculate exact matches at 0.5 threshold
#         pred_binary_05 = (pred_probs > 0.5).float()
#         exact_matches += ((pred_binary_05 == y).sum(dim=1) == y.shape[1]).float().sum().item()

# # Concatenate all results
# all_probs = torch.cat(all_probs, dim=0)
# all_targets = torch.cat(all_targets, dim=0)
# all_predictions = torch.cat(all_predictions, dim=0)

# # Calculate comprehensive metrics
# results = {}

# # Test different thresholds
# for threshold in [0.1, 0.2, 0.3, 0.4, 0.5]:
#     pred_binary = (all_probs > threshold).float()
    
#     # Exact match accuracy
#     exact_match = ((pred_binary == all_targets).sum(dim=1) == all_targets.shape[1]).float().mean()
    
#     # Hamming accuracy (per-label accuracy)
#     hamming_acc = (pred_binary == all_targets).float().mean()
    
#     # Jaccard accuracy (intersection over union)
#     intersection = (pred_binary * all_targets).sum(dim=1)
#     union = ((pred_binary + all_targets) > 0).float().sum(dim=1)
#     jaccard_acc = (intersection / (union + 1e-7)).mean()
    
#     # Count predictions vs targets
#     pred_count = pred_binary.sum().item()
#     target_count = all_targets.sum().item()
    
#     # Per-class F1 scores
#     tp = (pred_binary * all_targets).sum(dim=0)
#     fp = (pred_binary * (1 - all_targets)).sum(dim=0)
#     fn = ((1 - pred_binary) * all_targets).sum(dim=0)
    
#     precision = tp / (tp + fp + 1e-7)
#     recall = tp / (tp + fn + 1e-7)
#     f1_per_class = 2 * precision * recall / (precision + recall + 1e-7)
#     avg_f1 = f1_per_class.mean()
    
#     results[threshold] = {
#         'exact_match': exact_match.item(),
#         'hamming_acc': hamming_acc.item(),
#         'jaccard_acc': jaccard_acc.item(),
#         'avg_f1': avg_f1.item(),
#         'pred_count': pred_count,
#         'target_count': target_count,
#         'pred_ratio': pred_count / target_count if target_count > 0 else 0
#     }

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from datasets.topic_datasets import TopicDataset
from models.deepflybrain import DeepFlyBrain
import numpy as np

# Index of the GPU used for training; evaluation runs on the same device when it exists.
GPU_INDEX = 3
if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    device = torch.device(f'cuda:{GPU_INDEX}')  # Use same device as training
else:
    # Fall back to CPU so the notebook still runs on machines without that GPU
    # (checkpoints are loaded with map_location=device, so weights remap cleanly).
    device = torch.device('cpu')
print(f"Using {device} device")

def load_model(checkpoint_path, device):
    """Construct a DeepFlyBrain network and restore its weights from a checkpoint.

    Args:
        checkpoint_path: Path to a file written by ``torch.save``. Its contents are
            handed directly to ``load_state_dict``, so it is assumed to hold a bare
            state dict rather than a wrapper dict (TODO confirm against training).
        device: ``torch.device`` the network is moved to; stored tensors are mapped
            onto it while loading.

    Returns:
        The restored network, already switched to evaluation mode.
    """
    network = DeepFlyBrain().to(device)

    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(checkpoint_path, map_location=device)
    network.load_state_dict(state_dict)

    # Inference only: disable training-time behaviour such as dropout.
    network.eval()

    print(f"Model loaded from: {checkpoint_path}")
    return network

def test_single_sample(model, sample, device):
    """Test model on a single sample"""
    model.eval()
    
    with torch.no_grad():
        # Get sequence and label
        sequence = sample['sequence'].unsqueeze(0).to(device)  # Add batch dimension
        true_label = sample['label'].unsqueeze(0).to(device)
        
        # Make prediction
        pred_logits = model(sequence)
        pred_probs = torch.sigmoid(pred_logits)
        
        # Apply threshold
        pred_binary = (pred_probs > 0.5).float()
        
        return {
            'true_label': true_label.cpu().numpy()[0],
            'pred_probs': pred_probs.cpu().numpy()[0],
            'pred_binary': pred_binary.cpu().numpy()[0],
            'sequence_shape': sequence.shape
        }

def test_model_comprehensive(model, test_dataloader, device):
    """Comprehensive model testing"""
    model.eval()
    
    all_probs = []
    all_targets = []
    all_predictions = []
    
    total_samples = 0
    exact_matches = 0
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_dataloader):
            X, y = batch['sequence'], batch['label']
            X, y = X.to(device), y.to(device)
            y = y.float()
            
            # Make predictions
            pred_logits = model(X)
            pred_probs = torch.sigmoid(pred_logits)
            
            all_probs.append(pred_probs.cpu())
            all_targets.append(y.cpu())
            
            # Test multiple thresholds
            for threshold in [0.1, 0.2, 0.3, 0.4, 0.5]:
                pred_binary = (pred_probs > threshold).float()
                
                if threshold == 0.5:  # Store predictions at default threshold
                    all_predictions.append(pred_binary.cpu())
                
                # Calculate exact match for this threshold
                exact_match = ((pred_binary == y).sum(dim=1) == y.shape[1]).float().sum().item()
                
                if batch_idx == 0:  # Print for first batch only
                    print(f"Threshold {threshold}: {exact_match}/{y.shape[0]} exact matches")
            
            total_samples += y.shape[0]
            
            # Calculate exact matches at 0.5 threshold
            pred_binary_05 = (pred_probs > 0.5).float()
            exact_matches += ((pred_binary_05 == y).sum(dim=1) == y.shape[1]).float().sum().item()
    
    # Concatenate all results
    all_probs = torch.cat(all_probs, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    all_predictions = torch.cat(all_predictions, dim=0)
    
    # Calculate comprehensive metrics
    results = {}
    
    # Test different thresholds
    for threshold in [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]:
        pred_binary = (all_probs > threshold).float()
        
        # Exact match accuracy
        exact_match = ((pred_binary == all_targets).sum(dim=1) == all_targets.shape[1]).float().mean()
        
        # Hamming accuracy (per-label accuracy)
        hamming_acc = (pred_binary == all_targets).float().mean()
        
        # Jaccard accuracy (intersection over union)
        intersection = (pred_binary * all_targets).sum(dim=1)
        union = ((pred_binary + all_targets) > 0).float().sum(dim=1)
        jaccard_acc = (intersection / (union + 1e-7)).mean()
        
        # Count predictions vs targets
        pred_count = pred_binary.sum().item()
        target_count = all_targets.sum().item()
        
        # Per-class F1 scores
        tp = (pred_binary * all_targets).sum(dim=0)
        fp = (pred_binary * (1 - all_targets)).sum(dim=0)
        fn = ((1 - pred_binary) * all_targets).sum(dim=0)
        
        precision = tp / (tp + fp + 1e-7)
        recall = tp / (tp + fn + 1e-7)
        f1_per_class = 2 * precision * recall / (precision + recall + 1e-7)
        avg_f1 = f1_per_class.mean()
        
        results[threshold] = {
            'exact_match': exact_match.item(),
            'hamming_acc': hamming_acc.item(),
            'jaccard_acc': jaccard_acc.item(),
            'avg_f1': avg_f1.item(),
            'pred_count': pred_count,
            'target_count': target_count,
            'pred_ratio': pred_count / target_count if target_count > 0 else 0
        }
    
    return results, all_probs, all_targets, all_predictions

def analyze_predictions(all_probs, all_targets, threshold=0.5):
    """Print overall and per-class prediction statistics at a given threshold.

    Args:
        all_probs: 2-D tensor of predicted probabilities (samples x classes).
        all_targets: 2-D tensor of 0/1 targets with the same shape.
        threshold: A class is predicted positive when its probability is strictly
            greater than this value.
    """
    pred_binary = (all_probs > threshold).float()
    total_targets = all_targets.sum().item()
    total_preds = pred_binary.sum().item()

    print(f"\n=== PREDICTION ANALYSIS (threshold={threshold}) ===")

    # Overall statistics
    print(f"Total samples: {all_targets.shape[0]}")
    print(f"Total classes: {all_targets.shape[1]}")
    print(f"Total target positives: {total_targets:.0f}")
    print(f"Total predicted positives: {total_preds:.0f}")
    # Guard against a split with no positive labels (previously a ZeroDivisionError);
    # 0 mirrors the pred_ratio convention used in test_model_comprehensive.
    ratio = total_preds / total_targets if total_targets > 0 else 0
    print(f"Prediction ratio: {ratio:.2f}")

    # Per-class confusion counts, computed for all classes at once.
    tp_per_class = (pred_binary * all_targets).sum(dim=0).tolist()
    fp_per_class = (pred_binary * (1 - all_targets)).sum(dim=0).tolist()
    fn_per_class = ((1 - pred_binary) * all_targets).sum(dim=0).tolist()

    print("\nPer-class analysis:")
    for class_idx, (tp, fp, fn) in enumerate(zip(tp_per_class, fp_per_class, fn_per_class)):
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        print(f"  Class {class_idx:2d}: P={precision:.3f}, R={recall:.3f}, F1={f1:.3f}, "
              f"TP={tp:3.0f}, FP={fp:3.0f}, FN={fn:3.0f}")

if __name__ == "__main__":

    # Evaluation settings. SPLIT_SEED and the split fractions must match the training
    # run, or the "test" split below may overlap the training data
    # (assumes training seeded with torch.manual_seed before random_split — TODO confirm).
    checkpoint_path = './checkpoints/dfb_2025-05-28_13-27-01/final_model.pth'  # Update this path
    SPLIT_SEED = 42
    TRAIN_FRACTION = 0.7
    VAL_FRACTION = 0.15
    BATCH_SIZE = 32
    N_EXAMPLES = 5

    # 1. Load the trained model
    model = load_model(checkpoint_path, device)

    # 2. Load the full dataset
    dataset = TopicDataset(
        genome='data/resources/mm10.fa',
        region_topic_bed='data/Furlanis_Topics_top_3k/regions_and_topics_sorted.bed',
        transform=None,
        target_transform=None
    )

    # 3. Re-create the train/val/test split (same as training)
    torch.manual_seed(SPLIT_SEED)  # Same seed as training for consistent split
    train_size = int(TRAIN_FRACTION * len(dataset))
    val_size = int(VAL_FRACTION * len(dataset))
    test_size = len(dataset) - val_size - train_size
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

    print(f"Test dataset size: {len(test_dataset)}")

    # 4. Test single sample
    print("\n=== SINGLE SAMPLE TEST ===")
    sample = test_dataset[0]
    result = test_single_sample(model, sample, device)

    print(f"Sample sequence shape: {result['sequence_shape']}")
    print(f"True labels: {result['true_label']}")
    print(f"Predicted probabilities: {result['pred_probs']}")
    print(f"Predicted binary (>0.5): {result['pred_binary']}")
    # These are counts of positive labels, not true positives in the TP/FP sense.
    print(f"Number of positive labels: {result['true_label'].sum()}")
    print(f"Number of predicted positives: {result['pred_binary'].sum()}")

    # 5. Comprehensive testing on the held-out test split
    print("\n=== COMPREHENSIVE TESTING ===")
    results, all_probs, all_targets, all_predictions = test_model_comprehensive(model, test_dataloader, device)

    # 6. Print test results for all thresholds
    print("\nResults across different thresholds:")
    print("Thresh | Exact  | Hamming| Jaccard|  F1   | Pred/Target")
    print("-------|--------|--------|--------|-------|------------")
    for threshold, metrics in results.items():
        print(f"{threshold:6.2f} | {metrics['exact_match']:6.4f} | {metrics['hamming_acc']:6.4f} | "
              f"{metrics['jaccard_acc']:6.4f} | {metrics['avg_f1']:5.3f} | "
              f"{metrics['pred_ratio']:6.2f}")

    # 7. Choose the threshold on the validation split. Picking it on the test split
    #    and then reporting test metrics at that threshold is optimistically biased.
    print("\n=== THRESHOLD SELECTION (validation split) ===")
    val_results, _, _, _ = test_model_comprehensive(model, val_dataloader, device)
    best_threshold = max(val_results, key=lambda t: val_results[t]['exact_match'])
    print(f"\nBest threshold for exact match: {best_threshold} "
          f"(Validation Exact Match: {val_results[best_threshold]['exact_match']:.4f}, "
          f"Test Exact Match: {results[best_threshold]['exact_match']:.4f})")

    # 8. Detailed analysis on the test split at the selected threshold
    analyze_predictions(all_probs, all_targets, threshold=best_threshold)

    # 9. Inspect a few specific test examples
    print("\n=== SPECIFIC EXAMPLE ANALYSIS ===")
    for i in range(min(N_EXAMPLES, len(test_dataset))):
        sample = test_dataset[i]
        result = test_single_sample(model, sample, device)
        print(f"\nSample {i}:")
        print(f"  Positive labels: {result['true_label'].sum():.0f}")
        print(f"  Pred positives: {result['pred_binary'].sum():.0f}")
        print(f"  Max prob: {result['pred_probs'].max():.3f}")
        print(f"  Min prob: {result['pred_probs'].min():.3f}")
        print(f"  Exact match: {(result['pred_binary'] == result['true_label']).all()}")