## 1. Install Libraries

In [None]:
# Install libraries
%pip install --upgrade timm albumentations -q

## 2. Import Libraries

In [None]:
import os
import time
import random
import warnings
warnings.filterwarnings('ignore')

import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from albumentations.pytorch import ToTensorV2
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold

# Report the runtime so logs show which framework/GPU produced the results.
print(f"PyTorch: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA Available: {has_cuda}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 3. Set Random Seed

In [None]:
# Seed
SEED = 42


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU + every CUDA device) RNGs and pin cuDNN
    to deterministic, non-benchmarked kernels so training runs are repeatable."""
    # Hash randomization is fixed at interpreter start; this only reaches child interpreters.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


_seed_everything(SEED)
print(f"Random seed set to {SEED}")

## 4. Configuration

In [None]:
# Configuration
# Every path is derived from BASE_DIR. Set the CV_BASE_DIR environment variable to run from
# another machine/checkout without editing this cell.
BASE_DIR = os.environ.get('CV_BASE_DIR', '/home/realtheai/cv_competetion')
DATA_DIR = os.path.join(BASE_DIR, 'data')

CFG = {
    'IMG_SIZE': 384,
    'BATCH_SIZE': 8,
    'NUM_CLASSES': 17,
    'EPOCHS': 30,
    'LR': 5e-5,
    'WEIGHT_DECAY': 0.01,
    'MODEL_NAME': 'convnextv2_base.fcmae_ft_in22k_in1k_384',
    'DROPOUT': 0.6,
    'MIXUP_ALPHA': 0.3,    # Beta(alpha, alpha) parameter for in-dataset mixup; 0 disables mixing
    'FOCAL_GAMMA': 2.0,
    'FOCAL_ALPHA': 1.0,
    'WARMUP_EPOCHS': 5,    # leading epochs before the cosine LR scheduler is stepped
    # NOTE(review): USE_TTA is only printed in this notebook; no inference cell reads it yet.
    'USE_TTA': True,
    'N_FOLDS': 5,
    'BASE_DIR': BASE_DIR,
    'TRAIN_CSV': os.path.join(DATA_DIR, 'train.csv'),
    'SAMPLE_SUBMISSION_CSV': os.path.join(DATA_DIR, 'sample_submission.csv'),
    'TRAIN_IMG_DIR': os.path.join(DATA_DIR, 'train'),
    'TEST_IMG_DIR': os.path.join(DATA_DIR, 'test'),
    'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu',
}

print("\n" + "="*60)
print("Configuration - Strategy A")
print("="*60)
print(f"Model: {CFG['MODEL_NAME']}")
print(f"Image Size: {CFG['IMG_SIZE']} (↑ from 224)")
print(f"Batch Size: {CFG['BATCH_SIZE']} (↓ due to larger images)")
print(f"Epochs: {CFG['EPOCHS']}")
print(f"Learning Rate: {CFG['LR']} (↓ from 1e-4)")
print(f"Warmup Epochs: {CFG['WARMUP_EPOCHS']}")
print(f"Dropout: {CFG['DROPOUT']} (↑ from 0.5)")
print(f"Mixup Alpha: {CFG['MIXUP_ALPHA']}")
print(f"Focal Loss Gamma: {CFG['FOCAL_GAMMA']}")
print(f"TTA: {CFG['USE_TTA']}")
print(f"N-Folds: {CFG['N_FOLDS']}")
print(f"Device: {CFG['DEVICE']}")
print("="*60)

## 5. Dataset Class

In [5]:
# Dataset with Mixup support
class DocumentDataset(Dataset):
    """Image dataset backed by a DataFrame with an 'ID' column (plus 'target' unless is_test).

    Items:
      * is_test=True: just the (transformed) image.
      * otherwise: (img, target, mix_target, lam). When mixup_alpha > 0, with probability 0.5
        the image is blended with a randomly chosen row: img = lam * img + (1 - lam) * mix_img,
        lam ~ Beta(mixup_alpha, mixup_alpha). Unmixed items return (img, target, target, 1.0).
        lam is always a Python float so every item collates to the same type.
    """

    def __init__(self, df, img_dir, transform=None, is_test=False, mixup_alpha=0.0):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test
        self.mixup_alpha = mixup_alpha

    def __len__(self):
        return len(self.df)

    def _load_image(self, img_name):
        """Read `img_name` from img_dir as an RGB array and apply the transform, if any."""
        img_path = os.path.join(self.img_dir, img_name)
        # Context manager releases the file handle deterministically in every DataLoader worker.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        if self.transform:
            img = self.transform(image=img)['image']
        return img

    def __getitem__(self, idx):
        img = self._load_image(self.df.iloc[idx]['ID'])

        if self.is_test:
            return img

        target = self.df.iloc[idx]['target']

        if self.mixup_alpha > 0 and random.random() < 0.5:
            mix_idx = random.randint(0, len(self.df) - 1)
            mix_img = self._load_image(self.df.iloc[mix_idx]['ID'])

            # Plain float: matches the 1.0 returned by the unmixed branch below.
            lam = float(np.random.beta(self.mixup_alpha, self.mixup_alpha))
            img = lam * img + (1 - lam) * mix_img
            mix_target = self.df.iloc[mix_idx]['target']

            return img, target, mix_target, lam

        return img, target, target, 1.0

## 6. Data Augmentation

In [None]:
# Augmentation - Redesigned for Train/Test quality gap
# Uses albumentations 2.x argument names. The 1.x names (quality_lower/quality_upper,
# scale_min/scale_max, var_limit) are not valid in 2.x. With warnings filtered in the import
# cell they would be dropped silently, and the transforms would fall back to library defaults.
trn_transform = A.Compose([
    A.Resize(CFG['IMG_SIZE'], CFG['IMG_SIZE']),
    A.OneOf([
        A.GaussianBlur(blur_limit=(3, 11), p=1.0),
        A.MotionBlur(blur_limit=9, p=1.0),
        A.MedianBlur(blur_limit=7, p=1.0),
    ], p=0.5),
    A.ImageCompression(quality_range=(50, 100), p=0.5),
    A.Downscale(scale_range=(0.6, 0.8), p=0.8),
    A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.3),
    # Same noise range as the former var_limit=(10, 50) on uint8 pixels:
    # std = sqrt(variance) / 255, expressed as a fraction of the max pixel value.
    A.GaussNoise(std_range=(10 ** 0.5 / 255, 50 ** 0.5 / 255), p=0.4),
    A.Rotate(limit=5, p=0.3),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=5, p=0.3),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
    A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=15, p=0.3),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Evaluation/test pipeline: deterministic resize + ImageNet normalization only.
tst_transform = A.Compose([
    A.Resize(CFG['IMG_SIZE'], CFG['IMG_SIZE']),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

print("Augmentation setup complete - Redesigned for quality gap")
print("  - Blur: 0.5 (reduced from 0.8)")
print("  - Downscale: 0.8 with 0.6-0.8 range (stronger)")
print("  - Sharpen: 0.3 (new)")
print("  - Other aug probabilities reduced")

## 7. Model & Loss Function

In [None]:
# Focal Loss for hard samples
class FocalLoss(nn.Module):
    """Multi-class focal loss: alpha * (1 - p_t) ** gamma * CE, where p_t = exp(-CE).

    Args:
        alpha: constant factor applied to every sample's loss.
        gamma: focusing exponent; gamma=0 gives alpha-scaled cross-entropy.
        reduction: 'mean' (default, original behavior), 'sum', or 'none' for per-sample losses.
    """

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # softmax probability assigned to the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# ConvNeXt V2 Model
class ConvNeXtClassifier(nn.Module):
    """timm backbone created with num_classes=0, followed by an MLP head.

    Head layout: Linear(num_features, 512) -> ReLU -> Dropout -> Linear(512, 256) -> ReLU
    -> Dropout -> Linear(256, num_classes).
    """

    def __init__(self, model_name='convnextv2_base.fcmae_ft_in22k_in1k_384', num_classes=17, dropout=0.6):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=True, num_classes=0)

        widths = [self.backbone.num_features, 512, 256]
        head_layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            head_layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)])
        head_layers.append(nn.Linear(widths[-1], num_classes))
        # Same layer order as a literal Sequential, so state_dict keys stay classifier.0/3/6.*
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

# Echo the model / loss settings picked from CFG.
model_summary_lines = (
    f"Model: {CFG['MODEL_NAME']}",
    f"Focal Loss: alpha={CFG['FOCAL_ALPHA']}, gamma={CFG['FOCAL_GAMMA']}",
)
for summary_line in model_summary_lines:
    print(summary_line)

## 8. Load Data

In [None]:
# Load data
train_df = pd.read_csv(CFG['TRAIN_CSV'])
test_df = pd.read_csv(CFG['SAMPLE_SUBMISSION_CSV'])
meta_df = pd.read_csv(f"{CFG['BASE_DIR']}/data/meta.csv")

# Fail fast on the assumptions later cells rely on: datasets index 'ID'/'target',
# targets are class ids in [0, NUM_CLASSES), and every class id has a row in meta.csv.
assert {'ID', 'target'} <= set(train_df.columns), "train.csv must have ID and target columns"
assert 'ID' in test_df.columns, "sample_submission.csv must have an ID column"
assert train_df['target'].between(0, CFG['NUM_CLASSES'] - 1).all(), "train targets outside [0, NUM_CLASSES)"
assert set(range(CFG['NUM_CLASSES'])) <= set(meta_df['target']), "meta.csv is missing class ids"

print(f"\nData loaded:")
print(f"  Train: {len(train_df)} samples")
print(f"  Test: {len(test_df)} samples")
print(f"  Classes: {CFG['NUM_CLASSES']}")

## 9. Training Functions

In [9]:
# Training functions with Mixup support
def _per_sample_loss(criterion, outputs, targets):
    """Return a (batch,) tensor of per-sample losses from a batch-reducing criterion.

    Calling the criterion on one-row slices makes its batch reduction equal that row's loss,
    so this works for FocalLoss and nn.CrossEntropyLoss without modifying either.
    """
    return torch.cat([
        criterion(outputs[i:i + 1], targets[i:i + 1]).reshape(-1)
        for i in range(outputs.size(0))
    ])


def train_epoch(model, loader, optimizer, criterion, scheduler, device, current_epoch, warmup_epochs):
    """Run one training epoch with per-sample mixup targets and return the mean batch loss.

    Args:
        model: network being trained (switched to train mode).
        loader: yields (imgs, targets, mix_targets, lam) batches; lam holds one mixing
            weight per sample (1.0 for samples that were not mixed).
        optimizer: stepped once per batch.
        criterion: classification loss reducing over the batch (e.g. FocalLoss).
        scheduler: per-batch LR scheduler, stepped only after the warmup epochs.
        device: device the batch tensors are moved to.
        current_epoch: 1-based epoch index (the CV loop iterates range(1, EPOCHS + 1)).
        warmup_epochs: number of leading epochs in which the scheduler is not stepped.
            The LR stays at its initial value during those epochs; it is not ramped.

    Returns:
        Average of the per-batch training losses.
    """
    model.train()
    total_loss = 0.0

    for imgs, targets, mix_targets, lam in tqdm(loader, desc="Training"):
        imgs = imgs.to(device)
        targets = targets.to(device)
        mix_targets = mix_targets.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)

        # Each sample has its own mixing weight, so weight its two losses individually.
        # A batch-mean lam would mis-weight mixed samples whenever lam differs within the batch.
        lam = torch.as_tensor(lam, dtype=outputs.dtype, device=outputs.device)
        loss_target = _per_sample_loss(criterion, outputs, targets)
        loss_mix = _per_sample_loss(criterion, outputs, mix_targets)
        loss = (lam * loss_target + (1 - lam) * loss_mix).mean()

        loss.backward()
        optimizer.step()

        # Epochs 1..warmup_epochs (1-based) keep the initial LR. Stepping only afterwards makes the
        # total number of scheduler steps equal len(loader) * (EPOCHS - WARMUP_EPOCHS).
        if current_epoch > warmup_epochs:
            scheduler.step()

        total_loss += loss.item()

    return total_loss / len(loader)

def validate(model, loader, criterion, device):
    """Evaluate `model` on `loader` with gradients disabled.

    Each batch is (imgs, targets, mix_targets, lam); only imgs and targets are used.

    Returns:
        (mean batch loss, accuracy, macro-F1, predicted labels, true labels)
    """
    model.eval()
    running_loss = 0
    predicted, expected = [], []

    with torch.no_grad():
        for imgs, targets, _unused_mix, _unused_lam in tqdm(loader, desc="Validation"):
            imgs = imgs.to(device)
            targets = targets.to(device)
            logits = model(imgs)

            running_loss += criterion(logits, targets).item()
            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            expected.extend(targets.cpu().numpy())

    mean_loss = running_loss / len(loader)
    acc = accuracy_score(expected, predicted)
    macro_f1 = f1_score(expected, predicted, average='macro')
    return mean_loss, acc, macro_f1, predicted, expected

## 10. K-Fold Cross Validation

In [None]:
# K-Fold Cross Validation
# Epochs without a validation-F1 improvement before a fold stops early.
EARLY_STOP_PATIENCE = 7

skf = StratifiedKFold(n_splits=CFG['N_FOLDS'], shuffle=True, random_state=SEED)
fold_results = []
oof_preds = np.zeros(len(train_df))
test_preds_proba = np.zeros((len(test_df), CFG['NUM_CLASSES'], CFG['N_FOLDS']))

# The test set is identical for every fold, so its dataset/loader are built once.
# NOTE(review): CFG['USE_TTA'] is not read here; each fold runs a single un-augmented test pass.
test_dataset = DocumentDataset(test_df, CFG['TEST_IMG_DIR'], tst_transform, is_test=True, mixup_alpha=0.0)
test_loader = DataLoader(test_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False, num_workers=4)

print("\n" + "="*80)
print(f"Starting {CFG['N_FOLDS']}-Fold Cross Validation - Strategy A")
print("="*80)

for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df['target'])):
    print(f"\n{'='*80}")
    print(f"Fold {fold + 1}/{CFG['N_FOLDS']}")
    print(f"{'='*80}")

    trn_df = train_df.iloc[train_idx].reset_index(drop=True)
    val_df = train_df.iloc[val_idx].reset_index(drop=True)
    print(f"Train: {len(trn_df)} samples | Val: {len(val_df)} samples")

    trn_dataset = DocumentDataset(trn_df, CFG['TRAIN_IMG_DIR'], trn_transform, mixup_alpha=CFG['MIXUP_ALPHA'])
    val_dataset = DocumentDataset(val_df, CFG['TRAIN_IMG_DIR'], tst_transform, mixup_alpha=0.0)

    trn_loader = DataLoader(trn_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False, num_workers=4, pin_memory=True)

    model = ConvNeXtClassifier(CFG['MODEL_NAME'], CFG['NUM_CLASSES'], CFG['DROPOUT']).to(CFG['DEVICE'])
    criterion = FocalLoss(alpha=CFG['FOCAL_ALPHA'], gamma=CFG['FOCAL_GAMMA'])
    optimizer = AdamW(model.parameters(), lr=CFG['LR'], weight_decay=CFG['WEIGHT_DECAY'])
    # T_max is sized for per-batch steps over the (EPOCHS - WARMUP_EPOCHS) post-warmup epochs.
    scheduler = CosineAnnealingLR(optimizer, T_max=len(trn_loader) * (CFG['EPOCHS'] - CFG['WARMUP_EPOCHS']), eta_min=1e-7)

    ckpt_path = f"{CFG['BASE_DIR']}/fold{fold+1}_best.pth"
    # Start below any attainable F1 so epoch 1 always writes a checkpoint; with 0 a fold whose
    # F1 never exceeds 0.0 would leave nothing for the test-inference load below.
    best_f1, best_epoch, patience = -1.0, 0, 0

    for epoch in range(1, CFG['EPOCHS'] + 1):
        train_loss = train_epoch(model, trn_loader, optimizer, criterion, scheduler, CFG['DEVICE'], epoch, CFG['WARMUP_EPOCHS'])
        val_loss, val_acc, val_f1, val_preds, _ = validate(model, val_loader, criterion, CFG['DEVICE'])
        warmup_status = "[Warmup]" if epoch <= CFG['WARMUP_EPOCHS'] else ""

        if val_f1 > best_f1:
            best_f1, best_epoch, patience = val_f1, epoch, 0
            torch.save(model.state_dict(), ckpt_path)
            oof_preds[val_idx] = val_preds
            status = "(SAVED)"
        else:
            patience += 1
            status = f"(Patience {patience}/{EARLY_STOP_PATIENCE})"
        print(f"Epoch {epoch:2d} {warmup_status}: Loss {train_loss:.4f} | Val Loss {val_loss:.4f} Acc {val_acc:.4f} F1 {val_f1:.4f} | Best {best_f1:.4f} {status}")

        if patience >= EARLY_STOP_PATIENCE:
            print(f"Early stopping at epoch {epoch}")
            break

    fold_results.append({'fold': fold + 1, 'best_f1': best_f1, 'best_epoch': best_epoch})
    print(f"\nFold {fold + 1} complete: Best F1 = {best_f1:.4f} at epoch {best_epoch}")

    # Test inference with this fold's best checkpoint
    print(f"\nRunning test inference...")
    model.load_state_dict(torch.load(ckpt_path, map_location=CFG['DEVICE']))
    model.eval()

    fold_test_proba = []
    with torch.no_grad():
        for imgs in tqdm(test_loader, desc=f"Fold {fold+1} Test"):
            proba = F.softmax(model(imgs.to(CFG['DEVICE'])), dim=1)
            fold_test_proba.append(proba.cpu().numpy())

    test_preds_proba[:, :, fold] = np.concatenate(fold_test_proba, axis=0)
    print(f"Fold {fold + 1} test predictions complete")

    # Drop this fold's model/optimizer before the next fold allocates its own.
    del model, optimizer, scheduler
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n" + "="*80)
print("All folds complete!")
print("="*80)

## 11. Results & Submission

In [None]:
# Results
rule = "=" * 80
print("\n" + rule)
print("Final Results - Strategy A")
print(rule)

print("\nFold Results:")
for res in fold_results:
    print(f"  Fold {res['fold']}: F1 = {res['best_f1']:.4f} (Epoch {res['best_epoch']})")

best_f1_per_fold = [res['best_f1'] for res in fold_results]
mean_f1 = np.mean(best_f1_per_fold)
std_f1 = np.std(best_f1_per_fold)
oof_true = train_df['target']
oof_f1 = f1_score(oof_true, oof_preds, average='macro')
oof_acc = accuracy_score(oof_true, oof_preds)

print(f"\nCross-Validation:")
print(f"  Mean Fold F1: {mean_f1:.4f} ± {std_f1:.4f}")
print(f"  OOF F1: {oof_f1:.4f}")
print(f"  OOF Accuracy: {oof_acc:.4f}")


# Submission: average the per-fold class probabilities, then take the most likely class.
fold_mean_proba = test_preds_proba.mean(axis=2)
final_preds = fold_mean_proba.argmax(axis=1)
submission = test_df.copy()
submission['target'] = final_preds
submission.to_csv(f"{CFG['BASE_DIR']}/submission.csv", index=False)

print(f"\nSubmission file saved: submission.csv")
print(f"Total predictions: {len(submission)}")

# Prediction distribution
print(f"\nPrediction Distribution:")
pred_counts = pd.Series(final_preds).value_counts().sort_index()
for class_id, count in pred_counts.items():
    class_rows = meta_df[meta_df['target'] == class_id]
    class_name = class_rows['class_name'].values[0]
    print(f"  Class {class_id:2d} ({class_name[:25]:25s}): {count:4d}")

print("\n" + rule)
print("Complete!")
print(rule)

## 12. Visualization: K-Fold Results

In [None]:
# Visualization: K-Fold Results
# Reference scores from the earlier 09_code run, drawn/listed for comparison.
BASELINE_OOF_F1 = 0.8827
BASELINE_LB = 0.6669
FOLD_PALETTE = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#9b59b6']

fold_nums = [r['fold'] for r in fold_results]
fold_f1s = [r['best_f1'] for r in fold_results]
fold_epochs = [r['best_epoch'] for r in fold_results]
# One color per fold, cycling the palette so any N_FOLDS works (scatter needs len(c) == len(x)).
fold_colors = [FOLD_PALETTE[i % len(FOLD_PALETTE)] for i in range(len(fold_nums))]
mean_fold_f1 = np.mean(fold_f1s)
std_fold_f1 = np.std(fold_f1s)

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Fold F1 Scores
ax_f1 = axes[0, 0]
ax_f1.bar(fold_nums, fold_f1s, color=fold_colors)
ax_f1.axhline(y=mean_fold_f1, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_fold_f1:.4f}')
ax_f1.axhline(y=BASELINE_OOF_F1, color='orange', linestyle=':', linewidth=2, label=f'09_code OOF: {BASELINE_OOF_F1:.4f}')
ax_f1.set_title('Best F1 Score per Fold (Strategy A)', fontsize=14, fontweight='bold')
ax_f1.set_xlabel('Fold')
ax_f1.set_ylabel('F1 Score')
ax_f1.set_ylim([min(fold_f1s) - 0.05, max(max(fold_f1s), 0.89) + 0.02])
ax_f1.legend()
ax_f1.grid(alpha=0.3, axis='y')
for fold_num, f1, epoch in zip(fold_nums, fold_f1s, fold_epochs):
    ax_f1.text(fold_num, f1 + 0.01, f'{f1:.4f}\n(E{epoch})', ha='center', fontsize=10, fontweight='bold')

# Best Epochs
ax_ep = axes[0, 1]
ax_ep.bar(fold_nums, fold_epochs, color=fold_colors)
ax_ep.axhline(y=np.mean(fold_epochs), color='red', linestyle='--', linewidth=2, label=f'Mean: {np.mean(fold_epochs):.1f}')
ax_ep.axhline(y=CFG['WARMUP_EPOCHS'], color='green', linestyle=':', linewidth=2, label=f'Warmup: {CFG["WARMUP_EPOCHS"]}')
ax_ep.set_title('Best Epoch per Fold', fontsize=14, fontweight='bold')
ax_ep.set_xlabel('Fold')
ax_ep.set_ylabel('Epoch')
ax_ep.legend()
ax_ep.grid(alpha=0.3, axis='y')
for fold_num, epoch in zip(fold_nums, fold_epochs):
    ax_ep.text(fold_num, epoch + 0.3, f'{epoch}', ha='center', fontsize=11, fontweight='bold')

# F1 Distribution
ax_dist = axes[1, 0]
# Tick label set explicitly: boxplot's `labels=` keyword was renamed `tick_labels` in Matplotlib 3.9.
ax_dist.boxplot([fold_f1s], widths=0.5)
ax_dist.set_xticks([1])
ax_dist.set_xticklabels(['All Folds'])
ax_dist.scatter([1] * len(fold_f1s), fold_f1s, c=fold_colors, s=100, alpha=0.6, edgecolors='black', linewidth=1.5)
ax_dist.set_title('F1 Score Distribution', fontsize=14, fontweight='bold')
ax_dist.set_ylabel('F1 Score')
ax_dist.grid(alpha=0.3, axis='y')
ax_dist.text(1.3, mean_fold_f1, f'Mean: {mean_fold_f1:.4f}\nStd: {std_fold_f1:.4f}\nOOF: {oof_f1:.4f}', fontsize=11, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Summary Table
ax_tbl = axes[1, 1]
ax_tbl.axis('off')
separator = ['─' * 20, '─' * 20]
summary_data = [
    ['Metric', 'Value'],
    separator,
    ['Mean F1', f'{mean_fold_f1:.4f}'],
    ['Std F1', f'{std_fold_f1:.4f}'],
    ['Min F1', f'{min(fold_f1s):.4f}'],
    ['Max F1', f'{max(fold_f1s):.4f}'],
    separator,
    ['OOF F1', f'{oof_f1:.4f}'],
    ['OOF Accuracy', f'{oof_acc:.4f}'],
    separator,
    ['09_code OOF', f'{BASELINE_OOF_F1:.4f}'],
    ['09_code LB', f'{BASELINE_LB:.4f}'],
]
# Highlight the OOF rows by name instead of hard-coded row indices.
oof_rows = [row for row, (metric, _) in enumerate(summary_data) if metric.startswith('OOF')]

table = ax_tbl.table(cellText=summary_data, cellLoc='left', loc='center', colWidths=[0.6, 0.4])
table.auto_set_font_size(False)
table.set_fontsize(11)
table.scale(1, 2.2)

for col in range(2):
    table[(0, col)].set_facecolor('#3498db')
    table[(0, col)].set_text_props(weight='bold', color='white')
    for row in oof_rows:
        table[(row, col)].set_facecolor('#2ecc71')
        table[(row, col)].set_text_props(weight='bold')

ax_tbl.set_title('Performance Summary', fontsize=14, fontweight='bold', pad=20)

plt.tight_layout()
plt.savefig(f"{CFG['BASE_DIR']}/kfold_results.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"\nVisualization saved: kfold_results.png")

## 13. Wrong Predictions Analysis

In [None]:
# Wrong Predictions Analysis
print("\n" + "="*80)
print("Wrong Predictions Analysis")
print("="*80)

# target id -> class name, keeping the first meta row per id (as a row filter + [0] would).
class_names = meta_df.drop_duplicates('target').set_index('target')['class_name'].to_dict()

oof_preds_int = oof_preds.astype(int)
actual_labels = train_df['target'].values.astype(int)

wrong_mask = (oof_preds_int != actual_labels)
wrong_df = train_df[wrong_mask].copy()
wrong_df['predicted'] = oof_preds_int[wrong_mask]
wrong_df['actual'] = actual_labels[wrong_mask]

print(f"\nTotal wrong: {len(wrong_df)}/{len(train_df)} (Accuracy: {(1 - len(wrong_df)/len(train_df))*100:.2f}%)")

print(f"\nWrong by class:")
wrong_by_class = wrong_df.groupby('actual').size().sort_values(ascending=False)
for class_id, count in wrong_by_class.items():
    class_id = int(class_id)
    class_name = class_names[class_id]
    total_in_class = (train_df['target'] == class_id).sum()
    error_rate = (count / total_in_class) * 100
    print(f"  Class {class_id:2d} ({class_name[:25]:25s}): {count:2d}/{total_in_class:3d} ({error_rate:5.1f}% error)")

print(f"\nMost confused pairs:")
confusion_pairs = wrong_df.groupby(['actual', 'predicted']).size().sort_values(ascending=False).head(10)
for rank, ((actual, pred), count) in enumerate(confusion_pairs.items(), 1):
    actual, pred = int(actual), int(pred)
    print(f"  {rank:2d}. Class {actual:2d} ({class_names[actual][:20]:20s}) -> Class {pred:2d} ({class_names[pred][:20]:20s}): {count}")

# Visualize wrong predictions
n_samples = min(12, len(wrong_df))
sample_indices = np.random.choice(len(wrong_df), n_samples, replace=False)
samples = wrong_df.iloc[sample_indices]

fig, axes = plt.subplots(3, 4, figsize=(16, 12))
axes = axes.flatten()
# Hide every panel first so unused ones (fewer than 12 errors) stay blank.
for ax in axes:
    ax.axis('off')

for ax, (_, row) in zip(axes, samples.iterrows()):
    img_path = os.path.join(CFG['TRAIN_IMG_DIR'], row['ID'])
    with Image.open(img_path) as pil_img:
        img = pil_img.convert('RGB')

    actual = int(row['actual'])
    predicted = int(row['predicted'])

    ax.imshow(img)
    ax.set_title(f"Actual: {actual} ({class_names[actual][:15]})\nPredicted: {predicted} ({class_names[predicted][:15]})", fontsize=9, color='red', fontweight='bold')

plt.tight_layout()
plt.savefig(f"{CFG['BASE_DIR']}/wrong_predictions.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"\nWrong predictions saved: wrong_predictions.png")

## 14. Confusion Matrix

In [None]:
# Confusion Matrix
print("\n" + "="*80)
print("Confusion Matrix")
print("="*80)

class_ids = list(range(CFG['NUM_CLASSES']))
# Fix the label set so cm is always NUM_CLASSES x NUM_CLASSES, even when a class never appears
# in either labels or predictions. The tick labels and per-class loops below index it by class id.
cm = confusion_matrix(actual_labels, oof_preds_int, labels=class_ids)
# target id -> class name, keeping the first meta row per id.
class_names = meta_df.drop_duplicates('target').set_index('target')['class_name'].to_dict()

fig, ax = plt.subplots(figsize=(16, 14))
im = ax.imshow(cm, cmap='Blues', interpolation='nearest')
ax.set_title(f"Confusion Matrix (OOF - {CFG['N_FOLDS']}-Fold) - Strategy A", fontsize=18, fontweight='bold', pad=20)

cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label('Count', fontsize=12)

tick_labels = []
for i in class_ids:
    class_name = class_names[i]
    if len(class_name) > 15:
        class_name = class_name[:12] + '...'
    tick_labels.append(f"{i}\n{class_name}")

ax.set_xticks(class_ids)
ax.set_yticks(class_ids)
ax.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=9)
ax.set_yticklabels(tick_labels, fontsize=9)
ax.set_xlabel('Predicted Label', fontsize=14, fontweight='bold')
ax.set_ylabel('True Label', fontsize=14, fontweight='bold')

# Annotate non-zero cells; diagonal in green, off-diagonal contrasted against the cell shade.
threshold = cm.max() / 2
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        if i == j:
            color = 'darkgreen' if cm[i, j] > threshold else 'green'
            weight = 'bold'
        else:
            color = 'white' if cm[i, j] > threshold else 'black'
            weight = 'normal'

        if cm[i, j] > 0:
            ax.text(j, i, str(cm[i, j]), ha='center', va='center', color=color, fontsize=9, fontweight=weight)

plt.tight_layout()
plt.savefig(f"{CFG['BASE_DIR']}/confusion_matrix.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"\nConfusion Matrix saved: confusion_matrix.png")

print(f"\nStatistics:")
print(f"  Total samples: {cm.sum():.0f}")
print(f"  Correct: {np.trace(cm):.0f}")
print(f"  Wrong: {cm.sum() - np.trace(cm):.0f}")
print(f"  Accuracy: {np.trace(cm) / cm.sum() * 100:.2f}%")

print(f"\nClass Recall:")
for i in class_ids:
    total = cm[i, :].sum()
    correct = cm[i, i]
    recall = (correct / total * 100) if total > 0 else 0
    print(f"  Class {i:2d} ({class_names[i][:25]:25s}): {correct:3.0f}/{total:3.0f} ({recall:5.1f}%)")

print(f"\nMost confused pairs (top 5):")
confusion_list = [
    (i, j, cm[i, j])
    for i in range(cm.shape[0])
    for j in range(cm.shape[1])
    if i != j and cm[i, j] > 0
]
confusion_list.sort(key=lambda x: x[2], reverse=True)
for rank, (true_class, pred_class, count) in enumerate(confusion_list[:5], 1):
    print(f"  {rank}. Class {true_class:2d} ({class_names[true_class][:20]:20s}) -> Class {pred_class:2d} ({class_names[pred_class][:20]:20s}): {count:.0f}")