In [None]:
"""
Swin Transformer V1 Tiny for Coffee Bean Classification
Dataset: 54 Indonesian Coffee Bean Varieties
Platform: Kaggle with GPU T4 x2
"""
import json
import os
import random
import time
import warnings
from pathlib import Path

# Silence warnings before the heavy third-party imports below emit theirs.
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report, top_k_accuracy_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

In [None]:
# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (including all CUDA devices), then puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable cuDNN autotuning and request deterministic kernels so repeated
    # runs pick the same algorithms.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

In [None]:
# CONFIGURATION
class Config:
    """Single place holding paths, model choice and hyper-parameters."""

    # --- Dataset locations ---
    DATA_DIR = '/kaggle/input/dataset'
    TRAIN_DIR = os.path.join(DATA_DIR, 'train')
    VAL_DIR = os.path.join(DATA_DIR, 'valid')
    TEST_DIR = os.path.join(DATA_DIR, 'test')

    # --- Where artefacts (checkpoint, plots, reports) are written ---
    SAVE_DIR = './output'
    CHECKPOINT_PATH = os.path.join(SAVE_DIR, 'best_swin_model.pth')

    # --- Model ---
    MODEL_NAME = 'swin_tiny_patch4_window7_224'
    NUM_CLASSES = 54
    IMG_SIZE = 224

    # --- Optimisation ---
    BATCH_SIZE = 32  # lower this if the GPU runs out of memory
    NUM_EPOCHS = 50
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 0.05
    NUM_WORKERS = 4
    EARLY_STOPPING_PATIENCE = 10

    # --- Hardware: CUDA when present, DataParallel when more than one GPU ---
    DEVICE = 'cpu' if not torch.cuda.is_available() else 'cuda'
    USE_MULTI_GPU = torch.cuda.device_count() >= 2

    # --- Fine-tuning ---
    # Fraction of layers to fine-tune (1.0 = all, 0.5 = half).
    # NOTE(review): not referenced by any other code visible in this notebook.
    FINETUNE_LAYERS = 1.0


# Make sure the artefact directory exists before anything tries to write to it.
os.makedirs(Config.SAVE_DIR, exist_ok=True)


In [None]:
# DATASET
class CoffeeDataset(Dataset):
    """Image-folder dataset: each sub-directory of ``root_dir`` is one class.

    Class indices are assigned by the alphabetical order of the class
    directory names. Samples are listed class by class, with file names sorted
    inside each class so the sample order does not depend on the filesystem.

    Attributes:
        root_dir: Directory that was scanned.
        transform: Optional callable applied to each loaded PIL image.
        classes: Sorted list of class (directory) names.
        class_to_idx: Mapping from class name to integer label.
        images: List of image file paths.
        labels: Integer label for each entry of ``images``.
    """

    # Lower-case file extensions accepted as images.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        """Scan ``root_dir`` and index every image file found under it.

        Args:
            root_dir: Directory containing one sub-directory per class.
            transform: Optional callable applied to the image in
                ``__getitem__``.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Only real, non-hidden directories count as classes. A stray file or
        # a hidden folder such as ``.ipynb_checkpoints`` must not take a class
        # index, otherwise every label after it is shifted by one.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if not entry.startswith('.')
            and os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        self.images = []
        self.labels = []

        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            label = self.class_to_idx[class_name]
            # Sorted so the sample order is the same on every machine.
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(self.IMG_EXTENSIONS):
                    self.images.append(os.path.join(class_dir, img_name))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position in ``self.images``.

        Returns:
            Tuple ``(image, label)``: the RGB image (after ``transform`` when
            one was given) and its integer class label.
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        # The context manager closes the file handle right away; ``convert``
        # returns an independent in-memory image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# DATA AUGMENTATION 
def get_transforms(split='train'):
    """Build the torchvision pipeline for a data split.

    ``'train'`` gets resize + augmentation + tensor conversion/normalisation;
    any other value (validation, test) gets resize + tensor
    conversion/normalisation only.
    """
    side = Config.IMG_SIZE
    resize = transforms.Resize((side, side))

    def tensor_steps():
        # Tensor conversion followed by per-channel normalisation.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]

    if split != 'train':
        # Evaluation splits: deterministic preprocessing only.
        return transforms.Compose([resize, *tensor_steps()])

    augmentation = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    ]
    return transforms.Compose([resize, *augmentation, *tensor_steps()])

In [None]:
# MODEL
def create_model(num_classes=54, pretrained=True):
    """Instantiate the timm model named by ``Config.MODEL_NAME``.

    Prints the total and trainable parameter counts, then returns the model.
    ``num_classes`` and ``pretrained`` are forwarded to ``timm.create_model``.
    """
    architecture = Config.MODEL_NAME
    print(f"Creating {architecture} model...")
    model = timm.create_model(architecture, pretrained=pretrained, num_classes=num_classes)

    # Single pass over the parameters, tallying both counts at once.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    return model

In [None]:
# TRAINING
class Trainer:
    """Train/validate loop with best-checkpoint saving and early stopping.

    Attributes:
        best_val_acc: Highest validation top-1 accuracy (percent) seen so far.
        best_val_f1: Macro F1 recorded at the epoch of ``best_val_acc``.
        epochs_without_improvement: Consecutive epochs without a new best
            validation accuracy.
        history: Dict of per-epoch lists with keys ``train_loss``,
            ``train_acc``, ``val_loss``, ``val_acc``, ``val_f1``,
            ``val_top5`` and ``lr``.
    """

    def __init__(self, model, train_loader, val_loader, criterion, optimizer, scheduler, device):
        """Store the training components and reset all bookkeeping.

        Args:
            model: Network to train; called as ``model(images)``.
            train_loader: Iterable of ``(images, labels)`` training batches.
            val_loader: Iterable of ``(images, labels)`` validation batches.
            criterion: Loss callable ``criterion(outputs, labels)``.
            optimizer: Optimizer stepped once per training batch.
            scheduler: LR scheduler stepped once per epoch.
            device: Device that batches are moved to.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device

        self.best_val_acc = 0.0
        self.best_val_f1 = 0.0
        self.epochs_without_improvement = 0
        self.history = {
            'train_loss': [], 'train_acc': [],
            'val_loss': [], 'val_acc': [], 'val_f1': [],
            'val_top5': [], 'lr': []
        }

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            Tuple ``(epoch_loss, epoch_acc)``: sample-weighted mean loss and
            top-1 accuracy in percent.

        Raises:
            ValueError: If the loader yields no samples.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc='Training')
        for images, labels in pbar:
            images, labels = images.to(self.device), labels.to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)

            # Backward pass
            loss.backward()
            self.optimizer.step()

            # Statistics: weight by batch size so a smaller last batch does
            # not skew the epoch mean.
            batch_loss = loss.item()
            running_loss += batch_loss * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

        if total == 0:
            raise ValueError("train_loader yielded no samples; check the training data directory")

        epoch_loss = running_loss / total
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc

    def validate(self):
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            Tuple ``(val_loss, val_acc, val_f1, top5_acc)``: sample-weighted
            mean loss, top-1 accuracy (percent), macro F1 and top-5 accuracy
            (percent).

        Raises:
            ValueError: If the loader yields no samples.
        """
        self.model.eval()
        running_loss = 0.0
        all_preds = []
        all_labels = []
        all_probs = []

        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc='Validation')
            for images, labels in pbar:
                images, labels = images.to(self.device), labels.to(self.device)

                # Forward pass
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                # Class probabilities are needed for top-k accuracy
                probs = torch.softmax(outputs, dim=1)

                running_loss += loss.item() * images.size(0)
                all_probs.append(probs.cpu().numpy())
                all_preds.extend(outputs.argmax(1).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        if not all_labels:
            raise ValueError("val_loader yielded no samples; check the validation data directory")

        # Calculate metrics
        val_loss = running_loss / len(all_labels)
        val_acc = 100. * accuracy_score(all_labels, all_preds)
        val_f1 = f1_score(all_labels, all_preds, average='macro')

        # Top-5 accuracy. The label set is taken from the width of the
        # probability matrix so it always matches the model's output size.
        all_probs_concat = np.vstack(all_probs)
        top5_acc = 100. * top_k_accuracy_score(
            all_labels, all_probs_concat, k=5,
            labels=np.arange(all_probs_concat.shape[1])
        )

        print(f"\nValidation - Loss: {val_loss:.4f}, Top-1 Acc: {val_acc:.2f}%, "
              f"Top-5 Acc: {top5_acc:.2f}%, Macro F1: {val_f1:.4f}")

        return val_loss, val_acc, val_f1, top5_acc

    def fit(self, num_epochs):
        """Train for up to ``num_epochs`` epochs with early stopping.

        After each epoch the model is validated and the scheduler is stepped.
        Whenever validation accuracy beats the best so far, a checkpoint is
        written to ``Config.CHECKPOINT_PATH``. Training stops early once
        ``Config.EARLY_STOPPING_PATIENCE`` consecutive epochs bring no
        improvement.

        Args:
            num_epochs: Maximum number of epochs to run.

        Returns:
            The ``history`` dict of per-epoch metrics.
        """
        print(f"\nStarting training for {num_epochs} epochs...")
        print(f"Using device: {self.device}")
        if Config.USE_MULTI_GPU:
            print(f"Using {torch.cuda.device_count()} GPUs")

        for epoch in range(num_epochs):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"{'='*60}")

            start_time = time.time()

            # Learning rate in effect for this epoch. Read it before
            # scheduler.step() moves it to the next epoch's value.
            epoch_lr = self.optimizer.param_groups[0]['lr']

            # Train
            train_loss, train_acc = self.train_epoch()

            # Validate
            val_loss, val_acc, val_f1, top5_acc = self.validate()

            # Update learning rate
            self.scheduler.step()

            # Save history
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['val_f1'].append(val_f1)
            self.history['val_top5'].append(top5_acc)
            self.history['lr'].append(epoch_lr)

            epoch_time = time.time() - start_time
            print(f"\nEpoch Summary:")
            print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
            print(f"  Val F1: {val_f1:.4f}, Top-5 Acc: {top5_acc:.2f}%")
            print(f"  Time: {epoch_time:.2f}s")
            print(f"  LR: {epoch_lr:.6f}")

            # Save best model
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.best_val_f1 = val_f1
                self.epochs_without_improvement = 0

                # The state dict is saved exactly as the model exposes it, so
                # a DataParallel-wrapped model keeps its own key names.
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_acc': val_acc,
                    'val_f1': val_f1,
                    'top5_acc': top5_acc
                }, Config.CHECKPOINT_PATH)
                print(f"  ✓ Best model saved! (Val Acc: {val_acc:.2f}%)")
            else:
                self.epochs_without_improvement += 1
                print(f"  No improvement for {self.epochs_without_improvement} epoch(s)")

            # Early stopping
            if self.epochs_without_improvement >= Config.EARLY_STOPPING_PATIENCE:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break

        print(f"\n{'='*60}")
        print(f"Training completed!")
        print(f"Best Validation Accuracy: {self.best_val_acc:.2f}%")
        print(f"Best Validation F1-Score: {self.best_val_f1:.4f}")
        print(f"{'='*60}")

        return self.history

In [None]:
# EVALUATION
def evaluate_model(model, test_loader, device):
    """Run the model over ``test_loader`` and report classification metrics.

    Args:
        model: Network to evaluate; put into eval mode and called as
            ``model(images)``.
        test_loader: Iterable of ``(images, labels)`` batches.
        device: Device that batches are moved to.

    Returns:
        Dict with keys ``predictions`` and ``labels`` (per-sample lists),
        ``probabilities`` (softmax outputs stacked into one array),
        ``top1_accuracy`` and ``top5_accuracy`` (percent) and ``macro_f1``.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    print("\nEvaluating on test set...")
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Testing'):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)

            all_probs.append(probs.cpu().numpy())
            all_preds.extend(outputs.argmax(1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if not all_labels:
        raise ValueError("test_loader yielded no samples; check the test data directory")

    # Calculate metrics
    all_probs_concat = np.vstack(all_probs)

    top1_acc = 100. * accuracy_score(all_labels, all_preds)
    # The label set comes from the probability matrix width, so it always
    # matches the number of outputs of the model being evaluated.
    top5_acc = 100. * top_k_accuracy_score(
        all_labels, all_probs_concat, k=5,
        labels=np.arange(all_probs_concat.shape[1])
    )
    macro_f1 = f1_score(all_labels, all_preds, average='macro')

    print(f"\n{'='*60}")
    print(f"TEST RESULTS:")
    print(f"{'='*60}")
    print(f"Top-1 Accuracy: {top1_acc:.2f}%")
    print(f"Top-5 Accuracy: {top5_acc:.2f}%")
    print(f"Macro F1-Score: {macro_f1:.4f}")
    print(f"{'='*60}\n")

    return {
        'predictions': all_preds,
        'labels': all_labels,
        'probabilities': all_probs_concat,
        'top1_accuracy': top1_acc,
        'top5_accuracy': top5_acc,
        'macro_f1': macro_f1
    }

In [None]:
# VISUALIZATION
def plot_training_history(history):
    """Draw loss, accuracy and F1 curves side by side and save the figure.

    Reads the ``train_loss``, ``val_loss``, ``train_acc``, ``val_acc`` and
    ``val_f1`` series from ``history``. The figure is written to
    ``training_history.png`` inside ``Config.SAVE_DIR`` and then shown.
    """
    # One entry per panel: (y-axis label, title, curves), where each curve is
    # (history key, legend label, extra line keyword arguments).
    panels = (
        ('Loss', 'Training and Validation Loss', (
            ('train_loss', 'Train Loss', {'marker': 'o'}),
            ('val_loss', 'Val Loss', {'marker': 's'}),
        )),
        ('Accuracy (%)', 'Training and Validation Accuracy', (
            ('train_acc', 'Train Acc', {'marker': 'o'}),
            ('val_acc', 'Val Acc', {'marker': 's'}),
        )),
        ('F1-Score', 'Validation F1-Score', (
            ('val_f1', 'Val F1', {'marker': 's', 'color': 'green'}),
        )),
    )

    _, axes = plt.subplots(1, 3, figsize=(18, 5))

    for ax, (y_label, title, curves) in zip(axes, panels):
        for key, legend_label, line_kwargs in curves:
            ax.plot(history[key], label=legend_label, **line_kwargs)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    target = os.path.join(Config.SAVE_DIR, 'training_history.png')
    plt.savefig(target, dpi=300, bbox_inches='tight')
    plt.show()

def plot_confusion_matrix(labels, predictions, class_names):
    """Plot and save the confusion matrix as a heatmap.

    The matrix always has one row and one column per entry of
    ``class_names``, even when a class never appears in ``labels`` or
    ``predictions``, so the tick labels line up with the cells.

    Args:
        labels: True integer class labels.
        predictions: Predicted integer class labels.
        class_names: Class names ordered by class index.

    Returns:
        The confusion matrix returned by ``confusion_matrix``.
    """
    # Fix the label set explicitly. Without it the matrix only covers classes
    # that actually occur and can be smaller than ``class_names``.
    cm = confusion_matrix(labels, predictions, labels=np.arange(len(class_names)))

    plt.figure(figsize=(20, 18))
    sns.heatmap(cm, annot=False, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Count'})
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.title('Confusion Matrix - Swin Transformer V1 Tiny', fontsize=14, fontweight='bold')
    plt.xticks(rotation=90, fontsize=8)
    plt.yticks(rotation=0, fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(Config.SAVE_DIR, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
    plt.show()

    return cm

def analyze_predictions(results, class_names):
    """Report per-class accuracy for the classes that have test samples.

    Prints the full per-class table and the five best and worst classes, and
    writes the table to ``per_class_analysis.csv`` in ``Config.SAVE_DIR``.

    Args:
        results: Dict with ``'labels'`` (true class indices) and
            ``'predictions'`` (predicted class indices).
        class_names: Class names ordered by class index.

    Returns:
        DataFrame with columns ``Class``, ``Accuracy`` (percent) and
        ``Samples``, sorted by ascending accuracy. Classes with no true
        samples are left out.
    """
    # Fix the label set so row i of the matrix is always class_names[i],
    # regardless of which classes occur in the labels or predictions.
    cm = confusion_matrix(
        results['labels'], results['predictions'],
        labels=np.arange(len(class_names))
    )

    samples_per_class = cm.sum(axis=1)
    # Accuracy is undefined for a class with no true samples; drop those rows
    # instead of dividing by zero.
    present = samples_per_class > 0
    per_class_acc = cm.diagonal()[present] / samples_per_class[present]

    # Create DataFrame
    class_analysis = pd.DataFrame({
        'Class': [name for name, keep in zip(class_names, present) if keep],
        'Accuracy': per_class_acc * 100,
        'Samples': samples_per_class[present]
    })
    class_analysis = class_analysis.sort_values('Accuracy')

    print("\nPer-Class Performance:")
    print("="*60)
    print(class_analysis.to_string(index=False))
    print("="*60)

    # Top 5 best and worst (the table is sorted ascending by accuracy)
    print("\n📊 TOP 5 BEST PERFORMING CLASSES:")
    print(class_analysis.tail(5).to_string(index=False))

    print("\n📉 TOP 5 WORST PERFORMING CLASSES:")
    print(class_analysis.head(5).to_string(index=False))

    # Save to CSV
    class_analysis.to_csv(os.path.join(Config.SAVE_DIR, 'per_class_analysis.csv'), index=False)

    return class_analysis

In [None]:
def _make_loader(dataset, shuffle):
    """Build a DataLoader using the batch size and worker count from Config."""
    return DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=shuffle,
        num_workers=Config.NUM_WORKERS,
        # Pinned host memory is only useful when batches are copied to a GPU.
        pin_memory=torch.cuda.is_available()
    )


def _align_labels(dataset, reference, split_name):
    """Re-index ``dataset`` labels to follow ``reference``'s class order.

    Each split derives its label indices from its own folder listing, so a
    split with a different set of folders would otherwise number the same
    class differently from the training split.

    Raises:
        ValueError: If ``dataset`` holds samples of a class that
            ``reference`` does not know.
    """
    if dataset.classes == reference.classes:
        return

    remap = {}
    for old_idx in set(dataset.labels):
        class_name = dataset.classes[old_idx]
        if class_name not in reference.class_to_idx:
            raise ValueError(
                f"{split_name} split contains class '{class_name}' "
                f"that is absent from the training split"
            )
        remap[old_idx] = reference.class_to_idx[class_name]

    dataset.labels = [remap[label] for label in dataset.labels]
    dataset.classes = list(reference.classes)
    dataset.class_to_idx = dict(reference.class_to_idx)
    print(f"NOTE: {split_name} class folders differ from train; "
          f"labels were re-indexed to the training class order")


def _load_best_weights(model, checkpoint_path, device):
    """Load the saved model weights into ``model`` and return the checkpoint.

    Works whether or not ``model`` is wrapped in ``nn.DataParallel`` and
    whether or not the saved keys carry the ``module.`` prefix.
    """
    # weights_only=False unpickles arbitrary objects. That is acceptable here
    # only because the file was written earlier by this same notebook; never
    # point this at an untrusted checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['model_state_dict']

    # Always load into the bare network, stripping the wrapper prefix from
    # the keys when present.
    target = model.module if isinstance(model, nn.DataParallel) else model
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    target.load_state_dict(state_dict)

    return checkpoint


def main():
    """Run the full pipeline: load data, train, evaluate and save reports.

    Writes the best checkpoint, the plots, the per-class CSV and
    ``final_results.json`` into ``Config.SAVE_DIR``.

    Raises:
        ValueError: If the training data has a label index that does not fit
            into ``Config.NUM_CLASSES`` outputs, or if the validation/test
            split contains a class that is missing from the training split.
    """
    print("="*60)
    print("Swin Transformer V1 Tiny - Coffee Bean Classification")
    print("54 Indonesian Coffee Varieties")
    print("="*60)

    # Check GPU availability
    print(f"\nGPU Available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU Count: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")

    # Create datasets
    print("\n" + "="*60)
    print("Loading datasets...")
    print("="*60)

    train_dataset = CoffeeDataset(Config.TRAIN_DIR, transform=get_transforms('train'))
    val_dataset = CoffeeDataset(Config.VAL_DIR, transform=get_transforms('val'))
    test_dataset = CoffeeDataset(Config.TEST_DIR, transform=get_transforms('test'))

    # All splits must share the training label indexing, otherwise the
    # validation and test metrics compare mismatched class ids.
    _align_labels(val_dataset, train_dataset, 'validation')
    _align_labels(test_dataset, train_dataset, 'test')

    # A label >= NUM_CLASSES would only fail later, inside the loss.
    if train_dataset.labels and max(train_dataset.labels) >= Config.NUM_CLASSES:
        raise ValueError(
            f"Training data uses label index {max(train_dataset.labels)} but the model "
            f"has only {Config.NUM_CLASSES} outputs; check Config.NUM_CLASSES and {Config.TRAIN_DIR}"
        )

    print(f"Train samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")
    print(f"Number of classes: {len(train_dataset.classes)}")

    # Create data loaders (only the training split is shuffled)
    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    # Create model
    print("\n" + "="*60)
    print("Creating model...")
    print("="*60)

    model = create_model(num_classes=Config.NUM_CLASSES, pretrained=True)

    # Multi-GPU support
    if Config.USE_MULTI_GPU:
        print(f"\nWrapping model with DataParallel for {torch.cuda.device_count()} GPUs")
        model = nn.DataParallel(model)

    model = model.to(Config.DEVICE)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=Config.LEARNING_RATE,
        weight_decay=Config.WEIGHT_DECAY
    )

    # Learning rate scheduler: cosine decay over the full epoch budget
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=Config.NUM_EPOCHS,
        eta_min=1e-6
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=Config.DEVICE
    )

    # Train model
    history = trainer.fit(Config.NUM_EPOCHS)

    # Plot training history
    plot_training_history(history)

    # Load best model for testing
    print("\n" + "="*60)
    print("Loading best model for testing...")
    print("="*60)

    _load_best_weights(model, Config.CHECKPOINT_PATH, Config.DEVICE)

    # Evaluate on test set
    results = evaluate_model(model, test_loader, Config.DEVICE)

    # Visualizations
    print("\nGenerating visualizations...")
    plot_confusion_matrix(results['labels'], results['predictions'], train_dataset.classes)
    analyze_predictions(results, train_dataset.classes)

    # Save final results
    final_results = {
        'model': Config.MODEL_NAME,
        'num_classes': Config.NUM_CLASSES,
        'top1_accuracy': results['top1_accuracy'],
        'top5_accuracy': results['top5_accuracy'],
        'macro_f1': results['macro_f1'],
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
        'test_samples': len(test_dataset)
    }

    # Save to file
    with open(os.path.join(Config.SAVE_DIR, 'final_results.json'), 'w') as f:
        json.dump(final_results, f, indent=4)

    print("\n" + "="*60)
    print("Training and evaluation completed successfully!")
    print(f"Results saved to: {Config.SAVE_DIR}")
    print("="*60)


if __name__ == "__main__":
    main()