In [1]:
import os
import time
import json
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import timm
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

class ContinuedEpochMetricsCallback:
    def __init__(self, working_path):
        self.working_path = working_path
        self.epoch_metrics = []
        self.epoch_metrics_file = os.path.join(working_path, 'epoch_metrics.json')
        
    def save_epoch_metrics(self):
        with open(self.epoch_metrics_file, 'w') as f:
            json.dump(self.epoch_metrics, f, indent=4)
    
    def on_epoch_start(self):
        self.epoch_start_time = time.time()
    
    def on_epoch_end(self, epoch, logs):
        epoch_duration = time.time() - self.epoch_start_time
        actual_epoch = epoch + 1
        
        epoch_metrics = {
            'global_epoch': actual_epoch,
            'local_epoch': epoch + 1,
            'duration_minutes': epoch_duration / 60,
            'accuracy': logs['accuracy'],
            'val_accuracy': logs['val_accuracy'],
            'loss': logs['loss'],
            'val_loss': logs['val_loss'],
            'precision': logs.get('precision', 0),
            'val_precision': logs.get('val_precision', 0),
            'recall': logs.get('recall', 0),
            'val_recall': logs.get('val_recall', 0)
        }
        
        self.epoch_metrics.append(epoch_metrics)
        self.save_epoch_metrics()
        
        print(f"\nGlobal Epoch {actual_epoch} Metrics:")
        print(f"Time taken: {epoch_metrics['duration_minutes']:.2f} minutes")
        print(f"Training Accuracy: {epoch_metrics['accuracy']:.4f}")
        print(f"Validation Accuracy: {epoch_metrics['val_accuracy']:.4f}")
        print(f"Training Loss: {epoch_metrics['loss']:.4f}")
        print(f"Validation Loss: {epoch_metrics['val_loss']:.4f}")
        
        self.plot_metrics()

    def plot_metrics(self):
        if not self.epoch_metrics:
            return
            
        epochs = [m['global_epoch'] for m in self.epoch_metrics]
        accuracy = [m['accuracy'] for m in self.epoch_metrics]
        val_accuracy = [m['val_accuracy'] for m in self.epoch_metrics]
        loss = [m['loss'] for m in self.epoch_metrics]
        val_loss = [m['val_loss'] for m in self.epoch_metrics]
        precision = [m['precision'] for m in self.epoch_metrics]
        val_precision = [m['val_precision'] for m in self.epoch_metrics]
        recall = [m['recall'] for m in self.epoch_metrics]
        val_recall = [m['val_recall'] for m in self.epoch_metrics]

        fig, axs = plt.subplots(2, 2, figsize=(15, 10))

        axs[0, 0].plot(epochs, accuracy, label='Training Accuracy')
        axs[0, 0].plot(epochs, val_accuracy, label='Validation Accuracy')
        axs[0, 0].set_title('Model Accuracy')
        axs[0, 0].set_xlabel('Epoch')
        axs[0, 0].set_ylabel('Accuracy')
        axs[0, 0].legend()

        axs[0, 1].plot(epochs, loss, label='Training Loss')
        axs[0, 1].plot(epochs, val_loss, label='Validation Loss')
        axs[0, 1].set_title('Model Loss')
        axs[0, 1].set_xlabel('Epoch')
        axs[0, 1].set_ylabel('Loss')
        axs[0, 1].legend()

        axs[1, 0].plot(epochs, precision, label='Training Precision')
        axs[1, 0].plot(epochs, val_precision, label='Validation Precision')
        axs[1, 0].set_title('Model Precision')
        axs[1, 0].set_xlabel('Epoch')
        axs[1, 0].set_ylabel('Precision')
        axs[1, 0].legend()

        axs[1, 1].plot(epochs, recall, label='Training Recall')
        axs[1, 1].plot(epochs, val_recall, label='Validation Recall')
        axs[1, 1].set_title('Model Recall')
        axs[1, 1].set_xlabel('Epoch')
        axs[1, 1].set_ylabel('Recall')
        axs[1, 1].legend()

        plt.tight_layout()
        plt.savefig(os.path.join(self.working_path, f'training_metrics_epoch_{len(epochs)}.png'))
        plt.close()

class XceptionBirdClassifier(nn.Module):
    def __init__(self, num_classes):
        super(XceptionBirdClassifier, self).__init__()
        self.xception = timm.create_model('xception', pretrained=True)
        
        # Freeze all parameters initially
        for param in self.xception.parameters():
            param.requires_grad = False
            
        # Unfreeze the last three blocks for more capacity
        blocks_to_unfreeze = ['block4', 'block3', 'block2']
        for name, param in self.xception.named_parameters():
            if any(block in name for block in blocks_to_unfreeze):
                param.requires_grad = True
        
        # Get the number of features from the last layer
        num_features = self.xception.num_features
        
        # Replace the classification head
        self.xception.fc = nn.Sequential(
            nn.Linear(num_features, 1536),
            nn.BatchNorm1d(1536),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1536, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, num_classes)
        )
        
    def forward(self, x):
        return self.xception(x)

def load_training_state_and_metrics(working_path):
    try:
        with open(os.path.join(working_path, 'training_state.json'), 'r') as f:
            state = json.load(f)
            global_epoch = state.get('global_epoch', 0)
            best_val_accuracy = state.get('best_val_accuracy', 0.0)
            previous_model_path = state.get('model_path', None)
    except FileNotFoundError:
        global_epoch = 0
        best_val_accuracy = 0.0
        previous_model_path = None
    
    try:
        with open(os.path.join(working_path, 'epoch_metrics.json'), 'r') as f:
            epoch_metrics = json.load(f)
    except FileNotFoundError:
        epoch_metrics = []
    
    return global_epoch, best_val_accuracy, previous_model_path, epoch_metrics

def save_training_state(working_path, global_epoch, best_val_accuracy, model_path):
    state = {
        'global_epoch': global_epoch,
        'best_val_accuracy': best_val_accuracy,
        'model_path': model_path
    }
    with open(os.path.join(working_path, 'training_state.json'), 'w') as f:
        json.dump(state, f, indent=4)

def validate_model(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            val_loss += loss.item()
    
    return val_loss / len(val_loader), correct / total

def evaluate_model(model, test_loader, class_names, device, working_path):
    model.eval()
    all_preds = []
    all_labels = []
    test_loss = 0
    correct = 0
    total = 0
    
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    accuracy = correct / total
    test_loss /= len(test_loader)
    
    # Calculate confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    plot_confusion_matrix(cm, class_names, working_path)
    
    # Generate classification report
    report = classification_report(all_labels, all_preds, 
                                 target_names=class_names, 
                                 output_dict=True)
    
    # Save classification report
    with open(os.path.join(working_path, 'classification_report.json'), 'w') as f:
        json.dump(report, f, indent=4)
    
    # Calculate precision and recall
    precision = report['weighted avg']['precision']
    recall = report['weighted avg']['recall']
    
    test_metrics = {
        'test_loss': test_loss,
        'test_accuracy': accuracy,
        'test_precision': precision,
        'test_recall': recall
    }
    
    # Save test metrics
    with open(os.path.join(working_path, 'test_metrics.json'), 'w') as f:
        json.dump(test_metrics, f, indent=4)
    
    return test_metrics

def plot_confusion_matrix(cm, classes, save_path):
    plt.figure(figsize=(20, 20))
    sns.heatmap(cm, xticklabels=classes, yticklabels=classes, 
                annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'confusion_matrix.png'))
    plt.close()

def train_model(model, train_loader, val_loader, device, num_epochs, working_path, 
                global_epoch=0, best_val_accuracy=0.0):
    criterion = nn.CrossEntropyLoss()
    
    # Xception-specific optimizer configuration with different learning rates for different blocks
    optimizer = optim.AdamW([
        {'params': (p for n, p in model.named_parameters() if 'block2' in n), 'lr': 5e-5},
        {'params': (p for n, p in model.named_parameters() if 'block3' in n), 'lr': 1e-4},
        {'params': (p for n, p in model.named_parameters() if 'block4' in n), 'lr': 1e-4},
        {'params': model.xception.fc.parameters(), 'lr': 3e-4}
    ], weight_decay=0.01)
    
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min',
        patience=5,
        factor=0.2,
        min_lr=1e-7
    )
    
    metrics_callback = ContinuedEpochMetricsCallback(working_path)
    best_model_path = os.path.join(working_path, 'best_model.pth')
    
    patience = 10
    patience_counter = 0
    min_delta = 0.0005
    
    for epoch in range(global_epoch, num_epochs):
        metrics_callback.on_epoch_start()
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        # Training phase
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            
            loss = criterion(outputs, labels)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
            
            running_loss += loss.item()
            
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        
        train_loss = running_loss / len(train_loader)
        train_acc = correct / total
        
        # Validation phase
        val_loss, val_acc = validate_model(model, val_loader, criterion, device)
        
        scheduler.step(val_loss)
        
        if val_acc > best_val_accuracy + min_delta:
            best_val_accuracy = val_acc
            torch.save(model.state_dict(), best_model_path)
            patience_counter = 0
        else:
            patience_counter += 1
        
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered after epoch {epoch + 1}")
            break
        
        # Log metrics
        logs = {
            'accuracy': train_acc,
            'val_accuracy': val_acc,
            'loss': train_loss,
            'val_loss': val_loss,
            'precision': 0,
            'val_precision': 0,
            'recall': 0,
            'val_recall': 0
        }
        
        metrics_callback.on_epoch_end(epoch, logs)
        save_training_state(working_path, epoch + 1, best_val_accuracy, best_model_path)
    
    return best_val_accuracy

def main():
    BASE_PATH = '/kaggle/input/400birds/400BirdSpecies'
    WORKING_PATH = '/kaggle/working/'
    TRAIN_PATH = os.path.join(BASE_PATH, 'train')
    VALID_PATH = os.path.join(BASE_PATH, 'valid')
    TEST_PATH = os.path.join(BASE_PATH, 'test')
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Xception-specific data augmentation
    train_transform = transforms.Compose([
        transforms.Resize((299, 299)),  # Xception preferred input size
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((299, 299)),  # Xception preferred input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Load datasets
    train_dataset = ImageFolder(TRAIN_PATH, transform=train_transform)
    val_dataset = ImageFolder(VALID_PATH, transform=val_transform)
    test_dataset = ImageFolder(TEST_PATH, transform=val_transform)
    
    # Create data loaders with Xception-optimized batch size
    train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=24, shuffle=False, num_workers=4)
    
    # Load previous state
    global_epoch, best_val_accuracy, previous_model_path, epoch_metrics = load_training_state_and_metrics(WORKING_PATH)
    
    print(f"Starting training from epoch {global_epoch + 1}")
    print(f"Previous best validation accuracy: {best_val_accuracy:.4f}")
    
    # Create model
    num_classes = len(train_dataset.classes)
    model = XceptionBirdClassifier(num_classes)
    
    if os.path.exists(os.path.join(WORKING_PATH, 'best_model.pth')):
        model.load_state_dict(torch.load(os.path.join(WORKING_PATH, 'best_model.pth')))
    
    model = model.to(device)
    
    # Store class names
    class_names = train_dataset.classes
    class_mapping = {i: class_name for i, class_name in enumerate(class_names)}
    with open(os.path.join(WORKING_PATH, 'class_mapping.json'), 'w') as f:
        json.dump(class_mapping, f, indent=4)
    
    # Training
    total_epochs = 100
    best_val_accuracy = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=total_epochs,
        working_path=WORKING_PATH,
        global_epoch=global_epoch,
        best_val_accuracy=best_val_accuracy
    )
    
    # Evaluate best model
    model.load_state_dict(torch.load(os.path.join(WORKING_PATH, 'best_model.pth')))
    test_metrics = evaluate_model(model, test_loader, class_names, device, WORKING_PATH)
    
    print("\nTest Set Metrics:")
    print(f"Test Loss: {test_metrics['test_loss']:.4f}")
    print(f"Test Accuracy: {test_metrics['test_accuracy']:.4f}")
    print(f"Test Precision: {test_metrics['test_precision']:.4f}")
    print(f"Test Recall: {test_metrics['test_recall']:.4f}")

if __name__ == "__main__":
    main()



Starting training from epoch 1
Previous best validation accuracy: 0.0000


  model = create_fn(
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth" to /root/.cache/torch/hub/checkpoints/xception-43020ad28.pth



Global Epoch 1 Metrics:
Time taken: 10.63 minutes
Training Accuracy: 0.6523
Validation Accuracy: 0.9325
Training Loss: 1.5668
Validation Loss: 0.2719

Global Epoch 2 Metrics:
Time taken: 10.62 minutes
Training Accuracy: 0.8432
Validation Accuracy: 0.9550
Training Loss: 0.5757
Validation Loss: 0.1611

Global Epoch 3 Metrics:
Time taken: 10.63 minutes
Training Accuracy: 0.8781
Validation Accuracy: 0.9555
Training Loss: 0.4385
Validation Loss: 0.1397

Global Epoch 4 Metrics:
Time taken: 10.62 minutes
Training Accuracy: 0.8965
Validation Accuracy: 0.9665
Training Loss: 0.3678
Validation Loss: 0.1122

Global Epoch 5 Metrics:
Time taken: 10.62 minutes
Training Accuracy: 0.9081
Validation Accuracy: 0.9590
Training Loss: 0.3212
Validation Loss: 0.1136

Global Epoch 6 Metrics:
Time taken: 10.63 minutes
Training Accuracy: 0.9192
Validation Accuracy: 0.9750
Training Loss: 0.2824
Validation Loss: 0.0985

Global Epoch 7 Metrics:
Time taken: 10.62 minutes
Training Accuracy: 0.9250
Validation Accura

  model.load_state_dict(torch.load(os.path.join(WORKING_PATH, 'best_model.pth')))



Test Set Metrics:
Test Loss: 0.0260
Test Accuracy: 0.9940
Test Precision: 0.9948
Test Recall: 0.9940
