In [None]:
import os
import time
import json
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

class ContinuedEpochMetricsCallback:
    """Per-epoch metrics history that is persisted as JSON and can be plotted.

    The history is a list of dicts kept in ``epoch_metrics`` and stored in
    ``<working_path>/epoch_metrics.json``. If that file already exists when
    the object is created, its content becomes the initial history, so new
    entries are appended to what an earlier run recorded.
    """

    def __init__(self, working_path):
        """Store ``working_path`` and load a previously saved history, if any."""
        self.working_path = working_path
        self.epoch_metrics_file = os.path.join(working_path, 'epoch_metrics.json')
        self.epoch_metrics = []

        if not os.path.exists(self.epoch_metrics_file):
            return
        with open(self.epoch_metrics_file, 'r') as history_file:
            self.epoch_metrics = json.load(history_file)

    def save_epoch_metrics(self):
        """Overwrite the JSON history file with the current ``epoch_metrics``."""
        with open(self.epoch_metrics_file, 'w') as history_file:
            json.dump(self.epoch_metrics, history_file, indent=4)

    def plot_metrics(self):
        """Write a 2x2 grid of training/validation curves to a PNG file.

        One panel each is drawn for accuracy, loss, precision and recall. The
        image is saved in ``working_path`` as
        ``training_metrics_epoch_<number of recorded epochs>.png``. Nothing
        happens when the history is empty.
        """
        history = self.epoch_metrics
        if not history:
            return

        epochs = list(range(1, len(history) + 1))
        # (panel title, training key, validation key) in drawing order.
        panels = (
            ('Accuracy', 'accuracy', 'val_accuracy'),
            ('Loss', 'loss', 'val_loss'),
            ('Precision', 'precision', 'val_precision'),
            ('Recall', 'recall', 'val_recall'),
        )

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        for ax, (title, train_key, val_key) in zip(axes.flatten(), panels):
            # Collect both series before drawing anything on this panel.
            train_series = [entry[train_key] for entry in history]
            val_series = [entry[val_key] for entry in history]

            ax.plot(epochs, train_series, label=f'Training {title}')
            ax.plot(epochs, val_series, label=f'Validation {title}')
            ax.set_title(f'Model {title}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(title)
            ax.legend()

        plt.tight_layout()
        output_png = os.path.join(
            self.working_path, f'training_metrics_epoch_{len(epochs)}.png'
        )
        plt.savefig(output_png)
        plt.close()

class GoogLeNetBirdClassifier(nn.Module):
    """torchvision GoogLeNet (``pretrained=True``) with a replaced classifier head.

    Backbone parameters are frozen except those whose name contains
    ``inception5a`` or ``inception5b``. The stock ``fc`` layer is replaced by
    a Linear/BatchNorm/ReLU/Dropout stack that ends in ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        """Build the backbone, set which parameters train, and attach the head."""
        super().__init__()
        backbone = models.googlenet(pretrained=True)

        # Single pass over the backbone: only the last two inception blocks
        # (inception5a, inception5b) remain trainable, everything else is frozen.
        trainable_blocks = ('inception5a', 'inception5b')
        for param_name, param in backbone.named_parameters():
            param.requires_grad = any(
                block in param_name for block in trainable_blocks
            )

        # The new head is created after the pass above, so that loop never
        # touches its parameters.
        in_features = backbone.fc.in_features
        head_layers = [
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),
        ]
        backbone.fc = nn.Sequential(*head_layers)

        self.googlenet = backbone

    def forward(self, x):
        """Return the wrapped GoogLeNet's output for the batch ``x``."""
        return self.googlenet(x)


def load_training_state(working_path):
    """Read the resume information stored in ``working_path``.

    Looks for ``training_state.json`` and returns a tuple
    ``(global_epoch, best_val_accuracy, model_path)``. Keys missing from the
    file fall back to ``0``, ``0.0`` and ``None`` respectively, and the same
    three defaults are returned when the file does not exist.
    """
    state_file = os.path.join(working_path, 'training_state.json')
    try:
        with open(state_file, 'r') as handle:
            saved = json.load(handle)
    except FileNotFoundError:
        # No earlier run left a state file: start from scratch.
        return 0, 0.0, None

    start_epoch = saved.get('global_epoch', 0)
    best_accuracy = saved.get('best_val_accuracy', 0.0)
    checkpoint_path = saved.get('model_path', None)
    return (start_epoch, best_accuracy, checkpoint_path)

def save_training_state(working_path, global_epoch, best_val_accuracy, model_path):
    """Write the resume information to ``<working_path>/training_state.json``.

    The file is overwritten with a JSON object holding ``global_epoch``,
    ``best_val_accuracy`` and ``model_path`` (4-space indentation).
    """
    state_file = os.path.join(working_path, 'training_state.json')
    snapshot = dict(
        global_epoch=global_epoch,
        best_val_accuracy=best_val_accuracy,
        model_path=model_path,
    )
    with open(state_file, 'w') as handle:
        json.dump(snapshot, handle, indent=4)

def train_model(model, train_loader, val_loader, device, num_epochs, working_path,
                global_epoch=0, best_val_accuracy=0.0):
    """Run the train/validate loop and checkpoint the best weights.

    Epochs ``global_epoch`` up to ``num_epochs - 1`` are executed. After each
    epoch the scheduler is stepped on the validation loss, the model is saved
    to ``<working_path>/best_model.pth`` when validation accuracy strictly
    exceeds the best value so far, the epoch's metrics are appended to the
    JSON history and re-plotted, and the resume state is rewritten.

    Args:
        model: Network to train; it is not moved to ``device`` here, so it is
            expected to live there already.
        train_loader: Loader yielding ``(images, labels)`` training batches.
        val_loader: Loader yielding ``(images, labels)`` validation batches.
        device: Device handed to ``pl.ParallelLoader`` for both loaders.
        num_epochs: Exclusive upper bound of the epoch counter.
        working_path: Directory that receives checkpoint, metrics and plots.
        global_epoch: Epoch index to start from.
        best_val_accuracy: Best validation accuracy known before this call.

    Returns:
        The best validation accuracy after the last executed epoch.
    """
    criterion = nn.CrossEntropyLoss()

    # Two optimizer groups: trainable parameters whose name contains 'fc'
    # (the classifier head) get the higher learning rate, all other trainable
    # parameters the lower one.
    trainable = [
        (param_name, param)
        for param_name, param in model.named_parameters()
        if param.requires_grad
    ]
    backbone_params = [param for param_name, param in trainable if 'fc' not in param_name]
    head_params = [param for param_name, param in trainable if 'fc' in param_name]

    optimizer = optim.Adam([
        {'params': backbone_params, 'lr': 1e-4},
        {'params': head_params, 'lr': 2e-4},
    ])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=5, factor=0.2, min_lr=1e-7
    )

    metrics_callback = ContinuedEpochMetricsCallback(working_path)
    best_model_path = os.path.join(working_path, 'best_model.pth')

    for epoch in range(global_epoch, num_epochs):
        # ---------------- training pass ----------------
        model.train()
        train_loss_sum = 0.0
        train_hits = 0
        train_seen = 0

        train_batches = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)

        for images, labels in train_batches:
            optimizer.zero_grad()
            logits = model(images)
            batch_loss = criterion(logits, labels)
            batch_loss.backward()
            xm.optimizer_step(optimizer)

            train_loss_sum += batch_loss.item()
            predicted = logits.max(1)[1]
            train_seen += labels.size(0)
            train_hits += predicted.eq(labels).sum().item()

        # Loss is the mean of per-batch losses; accuracy is per sample.
        train_loss = train_loss_sum / len(train_loader)
        train_acc = train_hits / train_seen

        # ---------------- validation pass ----------------
        model.eval()
        val_loss_sum = 0
        val_hits = 0
        val_seen = 0

        val_batches = pl.ParallelLoader(val_loader, [device]).per_device_loader(device)

        with torch.no_grad():
            for images, labels in val_batches:
                logits = model(images)
                batch_loss = criterion(logits, labels)

                val_loss_sum += batch_loss.item()
                predicted = logits.max(1)[1]
                val_seen += labels.size(0)
                val_hits += predicted.eq(labels).sum().item()

        val_loss = val_loss_sum / len(val_loader)
        val_acc = val_hits / val_seen

        # The scheduler watches the validation loss (mode='min').
        scheduler.step(val_loss)

        # Checkpoint only on a strict improvement in validation accuracy.
        if val_acc > best_val_accuracy:
            best_val_accuracy = val_acc
            xm.save(model.state_dict(), best_model_path)

        # Precision and recall are not computed here; they are recorded as 0.
        epoch_record = {
            'accuracy': train_acc,
            'val_accuracy': val_acc,
            'loss': train_loss,
            'val_loss': val_loss,
            'precision': 0,
            'val_precision': 0,
            'recall': 0,
            'val_recall': 0
        }
        metrics_callback.epoch_metrics.append(epoch_record)

        metrics_callback.save_epoch_metrics()
        metrics_callback.plot_metrics()
        save_training_state(working_path, epoch + 1, best_val_accuracy, best_model_path)

        summary_lines = (
            f'\nEpoch {epoch + 1}/{num_epochs}',
            f'Training Loss: {train_loss:.4f}',
            f'Training Accuracy: {train_acc:.4f}',
            f'Validation Loss: {val_loss:.4f}',
            f'Validation Accuracy: {val_acc:.4f}',
        )
        for line in summary_lines:
            xm.master_print(line)

    return best_val_accuracy

def evaluate_model(model, test_loader, class_names, device, working_path):
    """Evaluate ``model`` on ``test_loader`` and write the result artifacts.

    Computes the mean per-batch cross-entropy loss and the per-sample accuracy,
    then saves ``confusion_matrix.png`` and ``classification_report.json``
    into ``working_path``.

    Args:
        model: Network to evaluate; it is switched to eval mode here and is
            not moved to ``device`` by this function.
        test_loader: Loader yielding ``(images, labels)`` batches.
        class_names: Class names used as heatmap tick labels and as
            ``target_names`` of the report. Label ``i`` is assumed to belong
            to ``class_names[i]``.
        device: Device handed to ``pl.ParallelLoader``.
        working_path: Directory that receives the PNG and JSON files.

    Returns:
        Dict with the keys ``'test_loss'`` and ``'test_accuracy'``.
    """
    model.eval()
    all_preds = []
    all_labels = []
    test_loss = 0
    correct = 0
    total = 0

    criterion = nn.CrossEntropyLoss()
    para_test_loader = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)

    with torch.no_grad():
        for images, labels in para_test_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Bring predictions and labels to the host for the sklearn metrics.
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = correct / total
    test_loss = test_loss / len(test_loader)

    # Pin the label set to every known class index. Without this, sklearn
    # derives the labels from the data: if a class shows up in neither the
    # labels nor the predictions (e.g. samples dropped by the loader), the
    # confusion matrix no longer lines up with the tick labels and
    # classification_report raises because target_names has a different length.
    label_ids = list(range(len(class_names)))

    # Generate and save confusion matrix
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    plt.figure(figsize=(20, 20))
    sns.heatmap(cm, xticklabels=class_names, yticklabels=class_names,
                annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(os.path.join(working_path, 'confusion_matrix.png'))
    plt.close()

    # Generate and save classification report
    report = classification_report(all_labels, all_preds,
                                   labels=label_ids,
                                   target_names=class_names,
                                   output_dict=True)

    with open(os.path.join(working_path, 'classification_report.json'), 'w') as f:
        json.dump(report, f, indent=4)

    return {
        'test_loss': test_loss,
        'test_accuracy': accuracy
    }

def main():
    """Train (or resume training) the bird classifier and evaluate it on the test set.

    Builds the train/valid/test ``ImageFolder`` datasets under ``BASE_PATH``,
    restores the saved epoch counter and best accuracy from ``WORKING_PATH``,
    loads ``best_model.pth`` if it exists, writes ``class_mapping.json``,
    trains up to epoch 100, reloads the best checkpoint and prints the test
    loss and accuracy.
    """
    BASE_PATH = '/kaggle/input/400birds/400BirdSpecies'
    WORKING_PATH = '/kaggle/working/'
    TRAIN_PATH = os.path.join(BASE_PATH, 'train')
    VALID_PATH = os.path.join(BASE_PATH, 'valid')
    TEST_PATH = os.path.join(BASE_PATH, 'test')
    best_model_file = os.path.join(WORKING_PATH, 'best_model.pth')

    # TPU device initialization
    device = xm.xla_device()

    # Data augmentation and normalization
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Validation and test images are only resized and normalized.
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Load datasets
    train_dataset = ImageFolder(TRAIN_PATH, transform=train_transform)
    val_dataset = ImageFolder(VALID_PATH, transform=val_transform)
    test_dataset = ImageFolder(TEST_PATH, transform=val_transform)

    # Training keeps drop_last=True so every training batch has the same size.
    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        drop_last=True
    )
    # Evaluation loaders must see every sample: dropping the final partial
    # batch would silently exclude images (with an unshuffled ImageFolder,
    # the last classes) from the reported validation and test metrics.
    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        drop_last=False
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        drop_last=False
    )

    # Load previous state
    global_epoch, best_val_accuracy, _ = load_training_state(WORKING_PATH)

    xm.master_print(f"Starting training from epoch {global_epoch + 1}")
    xm.master_print(f"Previous best validation accuracy: {best_val_accuracy:.4f}")

    # Create and load model
    class_names = train_dataset.classes
    num_classes = len(class_names)
    model = GoogLeNetBirdClassifier(num_classes)

    # Continue from the best checkpoint of an earlier run when one exists.
    if os.path.exists(best_model_file):
        model.load_state_dict(torch.load(best_model_file))

    model = model.to(device)

    # Save class mapping
    class_mapping = {i: name for i, name in enumerate(class_names)}
    with open(os.path.join(WORKING_PATH, 'class_mapping.json'), 'w') as f:
        json.dump(class_mapping, f, indent=4)

    # Train model
    best_val_accuracy = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=100,
        working_path=WORKING_PATH,
        global_epoch=global_epoch,
        best_val_accuracy=best_val_accuracy
    )

    # Evaluate best model
    model.load_state_dict(torch.load(best_model_file))
    test_metrics = evaluate_model(model, test_loader, class_names, device, WORKING_PATH)

    xm.master_print("\nTest Set Metrics:")
    xm.master_print(f"Test Loss: {test_metrics['test_loss']:.4f}")
    xm.master_print(f"Test Accuracy: {test_metrics['test_accuracy']:.4f}")

# Script entry point: when the file is run directly, main() is launched through
# xmp.spawn with a single process (nprocs=1) and the 'fork' start method.
if __name__ == "__main__":
    def _mp_fn(rank, flags):
        # Spawn target. `flags` receives the empty dict passed via `args` below;
        # `rank` is presumably the process index supplied by xmp.spawn — confirm
        # against the torch_xla docs. Neither argument is used.
        main()
    xmp.spawn(_mp_fn, args=({},), nprocs=1, start_method='fork')