# New 02 - Task 01 - VetCyto - MOBILENETv3L - Training Code (Automatically Labeled)

### Alternative code 03: PyTorch port of the original TensorFlow code

# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import json
from datetime import datetime
import seaborn as sns
from sklearn.metrics import classification_report, balanced_accuracy_score, f1_score, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import subprocess
from PIL import Image

def get_device_config():
    """
    Detect the CUDA GPU and derive training settings from its memory size.

    If PyTorch cannot see a CUDA device, diagnostic information (PyTorch and
    Python versions and, when available, plain ``nvidia-smi`` output) is printed
    and the process exits with status 1, because this script requires a GPU.

    Returns:
    tuple: (device, batch_size, image_size, device_info)
        device (str): always "cuda"
        batch_size (int): 32 / 24 / 16 / 8 for >= 8 / >= 6 / >= 4 / < 4 GB of GPU memory
        image_size (int): fixed at 112
        device_info (dict): GPU name, GPU memory, PyTorch and CUDA versions (for logging)
    """
    if not torch.cuda.is_available():
        print("\n❌ ERROR: No GPU detected by PyTorch. This script requires a GPU to run.")
        print("📋 Debugging information:")
        print("  - PyTorch version:", torch.__version__)
        print("  - CUDA available:", torch.cuda.is_available())
        print("  - Python version:", sys.version)

        # Try to run nvidia-smi as a fallback detection method
        try:
            result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
            # nvidia-smi is not on PATH / not executable, or its output could not be decoded
            print("\nFailed to run nvidia-smi. NVIDIA drivers may not be installed.")
        else:
            print("\n📊 NVIDIA-SMI Output:")
            print(result.stdout)
            if result.stderr:
                print("NVIDIA-SMI Error:", result.stderr)

            print("\n💡 Possible issues:")
            print("  - NVIDIA drivers not installed or outdated")
            print("  - CUDA/cuDNN not installed or not compatible with PyTorch")
            print("  - PyTorch not built with GPU support")
            print("\n📝 Recommendation: Verify GPU installation with 'nvidia-smi' command and ensure PyTorch is properly installed with GPU support")

        sys.exit(1)

    # Default GPU name from PyTorch; replaced by nvidia-smi's name if that query succeeds
    gpu_name = torch.cuda.get_device_name(0)

    try:
        result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader,nounits'],
                                stdout=subprocess.PIPE, text=True, check=True)
        # nvidia-smi prints one "name, memory" line per GPU. Only the first line is used;
        # it is assumed to describe CUDA device 0 (ordering can differ if CUDA_DEVICE_ORDER
        # is changed). rsplit keeps any comma inside the GPU name attached to the name.
        smi_name, smi_memory_mib = result.stdout.strip().splitlines()[0].rsplit(',', 1)
        gpu_memory = float(smi_memory_mib) / 1024  # Convert MiB to GiB
        gpu_name = smi_name.strip()
    except (OSError, subprocess.SubprocessError, ValueError, IndexError):
        # nvidia-smi unavailable, failed, or printed something unparseable: ask PyTorch instead
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Convert bytes to GB
        if not gpu_memory:
            gpu_memory = 6.0  # Assume 6GB for RTX 3060 based on previous info

    device_info = {
        "device_name": gpu_name,
        "gpu_memory": f"{gpu_memory:.2f} GB",
        "pytorch_version": torch.__version__,
        "cuda_version": torch.version.cuda if hasattr(torch.version, 'cuda') else "N/A"
    }

    # Only the batch size adapts to GPU memory; the image size stays fixed
    if gpu_memory >= 8:
        batch_size = 32
    elif gpu_memory >= 6:  # RTX 3060 with 6GB
        batch_size = 24
    elif gpu_memory >= 4:
        batch_size = 16
    else:
        batch_size = 8

    # Fixed image size of 112x112 as requested
    image_size = 112

    return "cuda", batch_size, image_size, device_info

def compute_class_weights(directory):
    """
    Compute 'balanced' class weights from the number of images per class folder.

    Class indices follow ``sorted()`` order of the sub-directory names, which is
    the order ``torchvision.datasets.ImageFolder`` uses to assign labels. Index
    ``i`` of the result therefore lines up with label ``i`` of an ImageFolder
    built on the same directory; ``os.listdir`` order is arbitrary and would not.
    Only files with an ImageFolder-style image extension (case-insensitive,
    searched recursively like ImageFolder) are counted, so stray files such as
    ``Thumbs.db`` do not skew the counts.

    The weight of class c is ``n_samples / (n_classes * count_c)``, the same
    formula scikit-learn's ``compute_class_weight(class_weight='balanced')`` uses.

    Parameters:
    directory (str): Root folder containing one sub-directory per class

    Returns:
    tuple: (class_weight_dict, weights_tensor) — a {class_index: weight} dict and
    the same weights as a float32 tensor, suitable for ``nn.CrossEntropyLoss(weight=...)``

    Raises:
    ValueError: if no class sub-directories exist or a class folder holds no images
    """
    # Mirrors torchvision.datasets.folder.IMG_EXTENSIONS (the ImageFolder default)
    image_extensions = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

    class_names = sorted(
        entry for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )
    if not class_names:
        raise ValueError(f"No class sub-directories found in {directory}")

    class_counts = {}
    for class_name in class_names:
        class_dir = os.path.join(directory, class_name)
        class_counts[class_name] = sum(
            1
            for _, _, files in os.walk(class_dir, followlinks=True)
            for file_name in files
            if file_name.lower().endswith(image_extensions)
        )

    empty_classes = [name for name, count in class_counts.items() if count == 0]
    if empty_classes:
        # A zero count would make the balanced weight infinite
        raise ValueError(f"Class folders without images in {directory}: {empty_classes}")

    counts = np.array(list(class_counts.values()), dtype=np.float64)
    weights = counts.sum() / (len(counts) * counts)

    class_weight_dict = {i: float(weight) for i, weight in enumerate(weights)}

    print("Class distribution:", class_counts)
    print("Computed class weights:", class_weight_dict)

    weights_tensor = torch.tensor(weights, dtype=torch.float32)

    return class_weight_dict, weights_tensor

def create_dataset(train_folder, val_folder, test_folder, image_size=(112, 112), batch_size=32, num_workers=4):
    """
    Build ImageFolder datasets and DataLoaders for the train/val/test splits.

    Training images are augmented (rotation, flip, translation, random resized
    crop, color jitter); validation and test images are only resized and
    center-cropped. All splits are normalized with the ImageNet mean/std.

    Parameters:
    train_folder, val_folder, test_folder (str): Roots laid out as <root>/<class_name>/<image>
    image_size (tuple or int): Target size for Resize / RandomResizedCrop / CenterCrop
    batch_size (int): Samples per batch for all three loaders
    num_workers (int): Worker processes per DataLoader (default 4)

    Returns:
    tuple: (train_loader, val_loader, test_loader, class_names, augmentation_layers)
    where class_names is the training dataset's class list and augmentation_layers
    is an empty nn.Sequential kept only for structural similarity with the
    TensorFlow version (all augmentation happens in the transforms).
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Training transforms with augmentation
    train_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomRotation(40),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=0, translate=(0.2, 0.2)),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation and test transforms (no augmentation)
    basic_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.ImageFolder(root=train_folder, transform=train_transform)
    val_dataset = datasets.ImageFolder(root=val_folder, transform=basic_transform)
    test_dataset = datasets.ImageFolder(root=test_folder, transform=basic_transform)

    # Placeholder kept for structure similarity with the TensorFlow code
    augmentation_layers = nn.Sequential()

    loader_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    # The classifier head built by build_model starts with BatchNorm1d, which raises
    # "Expected more than 1 value per channel when training" on a one-sample batch.
    # Drop the trailing training batch only when it would contain exactly one sample.
    train_size = len(train_dataset)
    drop_singleton_batch = train_size > 1 and train_size % batch_size == 1

    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=drop_singleton_batch, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

    return train_loader, val_loader, test_loader, train_dataset.classes, augmentation_layers

def build_model(input_shape=(112, 112, 3), num_classes=4):
    """
    Create an ImageNet-pretrained MobileNetV3-Large with a frozen backbone and a
    new trainable classification head.

    Parameters:
    input_shape (tuple): Kept for API compatibility with the TensorFlow version; unused
    num_classes (int): Number of output logits

    Returns:
    nn.Module: The model. Every pretrained parameter has requires_grad=False; the
    replacement head (BatchNorm1d -> Dropout(0.5) -> Linear) is trainable.
    """
    network = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V1)

    # Freeze everything that was loaded; the head assigned below gets fresh,
    # trainable parameters.
    network.requires_grad_(False)

    # Width of the pooled feature vector the stock classifier consumes (960 for MobileNetV3-Large)
    pooled_width = network.classifier[0].in_features

    network.classifier = nn.Sequential(
        nn.BatchNorm1d(pooled_width),
        nn.Dropout(p=0.5),
        nn.Linear(pooled_width, num_classes),
    )
    return network

def plot_hist(train_hist, val_hist, metric="accuracy", save_path=None):
    """
    Plot train vs. validation loss (left) next to another metric (right).

    The figure is optionally saved as a timestamped PNG, then displayed and
    closed so it does not stay registered with pyplot.

    Parameters:
    train_hist, val_hist (dict): Per-epoch value lists keyed by metric name;
        both must contain 'loss' and ``metric``
    metric (str): Metric for the right-hand panel; "loss" is skipped because it
        would just duplicate the left panel
    save_path (str or None): Directory for the PNG; None means don't save
    """
    if metric == "loss":
        return

    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 5))

    panels = (
        (loss_ax, 'loss', "Model Loss", "Loss"),
        (metric_ax, metric, f"Model {metric.capitalize()}", metric.capitalize()),
    )
    for ax, key, title, ylabel in panels:
        ax.plot(train_hist[key])
        ax.plot(val_hist[key])
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel("Epoch")
        ax.legend(["Train", "Validation"], loc="upper left")

    fig.tight_layout()

    if save_path:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = os.path.join(save_path, f"training_history_loss_vs_{metric}_{timestamp}.png")
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Saved training history plot to: {filename}")

    plt.show()
    # Free the figure; under non-interactive backends it would otherwise accumulate
    plt.close(fig)

def plot_confusion_matrix(y_true, y_pred, class_names, save_path=None):
    """
    Plot a confusion-matrix heatmap, optionally save it, display it, then close it.

    Parameters:
    y_true: Ground-truth class indices
    y_pred: Predicted class indices
    class_names (list): Names for indices 0..len(class_names)-1, used as tick labels
    save_path (str or None): Directory for a timestamped PNG; None means don't save
    """
    # Pin the label set so the matrix is always len(class_names) x len(class_names),
    # even if some class never appears in y_true or y_pred (otherwise the matrix
    # shrinks and no longer matches the tick labels).
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')

    if save_path:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = os.path.join(save_path, f"confusion_matrix_{timestamp}.png")
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Saved confusion matrix to: {filename}")

    plt.show()
    # Free the figure; under non-interactive backends it would otherwise accumulate
    plt.close(fig)

def unfreeze_model(model, num_layers=20):
    """
    Make the last ``num_layers`` layers of ``model.features`` trainable for fine-tuning.

    "Layers" are the leaf modules of ``model.features`` (modules without
    sub-modules, e.g. Conv2d, BatchNorm2d, activations), taken in registration
    order. Iterating ``named_children()`` instead would count whole blocks; for
    MobileNetV3-Large that is only 17 blocks, so "last 20" would unfreeze the
    entire backbone and the BatchNorm check could never match.

    BatchNorm2d layers inside the window keep requires_grad=False. Note that
    ``model.train()`` still updates their running statistics; this function
    only changes requires_grad.

    Parameters:
    model (nn.Module): Model exposing a ``features`` sub-module
    num_layers (int): How many trailing leaf layers to unfreeze (default 20);
        values <= 0 unfreeze nothing

    The classifier head is not touched here.
    """
    if num_layers <= 0:
        return

    leaf_layers = [
        module for module in model.features.modules()
        if next(module.children(), None) is None
    ]

    for layer in leaf_layers[-num_layers:]:
        if isinstance(layer, nn.BatchNorm2d):
            continue
        for param in layer.parameters():
            param.requires_grad = True

def calculate_f1(precision, recall):
    """
    Return the F1 score, i.e. the harmonic mean of precision and recall.

    A 1e-10 term in the denominator keeps the result finite (0.0) when both
    inputs are zero.
    """
    numerator = 2 * (precision * recall)
    denominator = precision + recall + 1e-10
    return numerator / denominator

def train_epoch(model, train_loader, criterion, optimizer, device, class_weights=None, steps_per_epoch=None):
    """
    Run one optimisation pass over the training loader.

    Parameters:
    model (nn.Module): Network to train; switched to training mode here
    train_loader (iterable): Yields (inputs, targets) batches
    criterion: Loss applied to (outputs, targets)
    optimizer: Optimizer over the model's trainable parameters
    device: Where each batch is moved before the forward pass
    class_weights: Not used here; any class weighting must already be configured
        on ``criterion``
    steps_per_epoch (int or None): Maximum number of batches to process; None
        processes the whole loader

    Returns:
    dict: 'loss' (per-sample mean), 'accuracy', support-weighted 'precision' and
    'recall', and 'f1' (harmonic mean of that precision and recall). Every value
    is 0 when no batch was processed.
    """
    model.train()

    loss_total = 0.0
    hits = 0
    seen = 0
    true_labels = []
    predicted_labels = []

    for step, (inputs, targets) in enumerate(train_loader):
        # Respect the optional per-epoch batch budget
        if steps_per_epoch is not None and step >= steps_per_epoch:
            break

        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        predicted = outputs.max(dim=1).indices

        # Scale by batch size so the epoch loss is a per-sample mean
        # (assumes criterion uses mean reduction, the nn.CrossEntropyLoss default)
        loss_total += loss.item() * inputs.size(0)
        seen += targets.size(0)
        hits += (predicted == targets).sum().item()

        true_labels.extend(targets.cpu().numpy())
        predicted_labels.extend(predicted.cpu().numpy())

    if not seen:
        return dict(loss=0, accuracy=0, precision=0, recall=0, f1=0)

    # Imported only once there is something to score
    from sklearn.metrics import precision_score, recall_score
    precision = precision_score(true_labels, predicted_labels, average='weighted', zero_division=0)
    recall = recall_score(true_labels, predicted_labels, average='weighted', zero_division=0)

    return {
        'loss': loss_total / seen,
        'accuracy': hits / seen,
        'precision': precision,
        'recall': recall,
        'f1': calculate_f1(precision, recall),
    }

def validate(model, val_loader, criterion, device, steps=None):
    """
    Evaluate the model on a loader without tracking gradients.

    Parameters:
    model (nn.Module): Network to evaluate; switched to eval mode here
    val_loader (iterable): Yields (inputs, targets) batches
    criterion: Loss applied to (outputs, targets)
    device: Where each batch is moved before the forward pass
    steps (int or None): Maximum number of batches to evaluate; None evaluates
        the whole loader

    Returns:
    dict: 'loss' (per-sample mean), 'accuracy', support-weighted 'precision' and
    'recall', 'f1' (harmonic mean of that precision and recall), plus the raw
    'predictions' and 'targets' lists. When no batch is evaluated, all metrics
    are 0 and both lists are empty, so the same keys are always present.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_targets = []
    all_predictions = []

    with torch.no_grad():
        for batch_index, (inputs, targets) in enumerate(val_loader):
            # Respect the optional batch budget
            if steps is not None and batch_index >= steps:
                break

            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Scale by batch size so the reported loss is a per-sample mean
            # (assumes criterion uses mean reduction, the nn.CrossEntropyLoss default)
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

            all_targets.extend(targets.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    if total == 0:
        # Avoid division by zero, but keep the same keys as the normal return
        # so callers can always read 'predictions' / 'targets'
        return {
            'loss': 0,
            'accuracy': 0,
            'precision': 0,
            'recall': 0,
            'f1': 0,
            'predictions': all_predictions,
            'targets': all_targets
        }

    epoch_loss = running_loss / total
    epoch_accuracy = correct / total

    from sklearn.metrics import precision_score, recall_score
    epoch_precision = precision_score(all_targets, all_predictions, average='weighted', zero_division=0)
    epoch_recall = recall_score(all_targets, all_predictions, average='weighted', zero_division=0)
    epoch_f1 = calculate_f1(epoch_precision, epoch_recall)

    return {
        'loss': epoch_loss,
        'accuracy': epoch_accuracy,
        'precision': epoch_precision,
        'recall': epoch_recall,
        'f1': epoch_f1,
        'predictions': all_predictions,
        'targets': all_targets
    }

def _run_training_phase(model, train_loader, val_loader, criterion, optimizer, scheduler, device,
                        epochs, patience, checkpoint_path, best_val_accuracy,
                        class_weights=None, steps_per_epoch=None, validation_steps=None):
    """
    Run one training phase: up to ``epochs`` epochs with LR scheduling on
    validation loss, best-accuracy checkpointing and early stopping.

    The model's state_dict is written to ``checkpoint_path`` every time validation
    accuracy strictly exceeds the running best, which starts at ``best_val_accuracy``.
    Training stops after ``patience`` consecutive epochs without a new lowest
    validation loss.

    Returns:
    tuple: (train_hist, val_hist, best_val_accuracy); each history maps 'loss',
    'accuracy', 'precision', 'recall' and 'f1' to per-epoch values.
    """
    metric_names = ('loss', 'accuracy', 'precision', 'recall', 'f1')
    train_hist = {name: [] for name in metric_names}
    val_hist = {name: [] for name in metric_names}
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        train_metrics = train_epoch(
            model, train_loader, criterion, optimizer, device,
            class_weights=class_weights,
            steps_per_epoch=steps_per_epoch
        )
        val_metrics = validate(model, val_loader, criterion, device, steps=validation_steps)

        scheduler.step(val_metrics['loss'])

        for name in metric_names:
            train_hist[name].append(train_metrics[name])
            val_hist[name].append(val_metrics[name])

        print(f"Epoch {epoch+1}/{epochs}, "
              f"Train Loss: {train_metrics['loss']:.4f}, Train Acc: {train_metrics['accuracy']:.4f}, "
              f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}")

        # Save best model
        if val_metrics['accuracy'] > best_val_accuracy:
            best_val_accuracy = val_metrics['accuracy']
            torch.save(model.state_dict(), checkpoint_path)
            print(f"    Saving model with validation accuracy: {best_val_accuracy:.4f}")

        # Early stopping on validation loss
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    return train_hist, val_hist, best_val_accuracy


def train_and_evaluate(train_folder, val_folder, test_folder, save_path=None, batch_size=32, 
                      image_size=(112, 112), nbs=None):
    """
    Train MobileNetV3-Large in two phases, then evaluate the best checkpoint on the test set.

    Phase 1 trains with the backbone frozen. Phase 2 calls ``unfreeze_model`` and
    fine-tunes at a lower learning rate. The checkpoint with the highest validation
    accuracy across both phases is kept.

    Files written to ``save_path``:
    - best_model_MobileNet_epoch*_acc*_<timestamp>.pt (best checkpoint)
    - training-history plots
    - a confusion-matrix plot
    - classification_report_<timestamp>.txt
    - last_model_acc*_<timestamp>.pt (weights of the returned model)

    Parameters:
    train_folder, val_folder, test_folder (str): ImageFolder-style split roots
    save_path (str): Output directory, created if missing. Required despite the
        None default; passing None raises ValueError.
    batch_size (int): Batch size for all loaders
    image_size (tuple): (height, width) used by the transforms
    nbs (int or None): Training batches per epoch (validation uses half, at least
        one); None means every batch

    Returns:
    nn.Module: the model loaded with the best checkpoint's weights
    """
    if save_path is None:
        raise ValueError("save_path is required: checkpoints, plots and reports are written there")
    os.makedirs(save_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_loader, val_loader, test_loader, class_names, _ = create_dataset(
        train_folder, val_folder, test_folder, image_size=image_size, batch_size=batch_size
    )

    # Batch budgets per epoch; keep at least one validation batch even if nbs is 1
    steps_per_epoch = nbs
    validation_steps = max(1, nbs // 2) if nbs is not None else None

    # Class weights for the loss; their order must match the dataset's labels
    _, class_weights_tensor = compute_class_weights(train_folder)
    if len(class_weights_tensor) != len(class_names):
        raise ValueError(
            f"Computed {len(class_weights_tensor)} class weights but the training set "
            f"has {len(class_names)} classes: {class_names}"
        )
    class_weights_tensor = class_weights_tensor.to(device)

    model = build_model(input_shape=(*image_size, 3), num_classes=len(class_names))
    model = model.to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

    checkpoint_path_temp = os.path.join(save_path, "temporary_checkpoint.pt")
    patience = 10

    # ---- Phase 1: frozen backbone ----
    print("\n=== Initial Training Phase ===")
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
    )
    # Seed with -inf so the first epoch always writes a checkpoint; seeding with 0
    # left no file to load when validation accuracy never rose above zero.
    frozen_train_hist, frozen_val_hist, best_val_accuracy = _run_training_phase(
        model, train_loader, val_loader, criterion, optimizer, scheduler, device,
        epochs=30, patience=patience, checkpoint_path=checkpoint_path_temp,
        best_val_accuracy=float('-inf'), class_weights=class_weights_tensor,
        steps_per_epoch=steps_per_epoch, validation_steps=validation_steps
    )

    print("Phase 1 Training completed")
    print("Best val accuracy:", best_val_accuracy)

    # Continue from the best phase-1 weights
    model.load_state_dict(torch.load(checkpoint_path_temp, weights_only=True))

    # ---- Phase 2: partial fine-tuning ----
    print("\n=== Fine-tuning Phase ===")
    unfreeze_model(model)

    optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
    )
    # The checkpoint is only overwritten if phase 2 beats the phase-1 best
    train_hist, val_hist, best_val_accuracy_unfrozen = _run_training_phase(
        model, train_loader, val_loader, criterion, optimizer, scheduler, device,
        epochs=30, patience=patience, checkpoint_path=checkpoint_path_temp,
        best_val_accuracy=best_val_accuracy, class_weights=class_weights_tensor,
        steps_per_epoch=steps_per_epoch, validation_steps=validation_steps
    )

    print("Phase 2 Training completed")
    print("Best val accuracy (unfrozen):", best_val_accuracy_unfrozen)

    # Concatenate both phases (every metric, including f1, covers both phases)
    combined_train_hist = {name: frozen_train_hist[name] + train_hist[name] for name in frozen_train_hist}
    combined_val_hist = {name: frozen_val_hist[name] + val_hist[name] for name in frozen_val_hist}

    # The checkpoint only moves on strict improvement, so the first epoch reaching
    # the maximum is the one stored on disk.
    all_val_accuracies = combined_val_hist['accuracy']
    max_val_accuracy = max(all_val_accuracies)
    epoch_with_max_val_accuracy = all_val_accuracies.index(max_val_accuracy) + 1

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_name = f"best_model_MobileNet_epoch{epoch_with_max_val_accuracy}_acc{int(max_val_accuracy * 100)}_{timestamp}.pt"
    final_checkpoint_path = os.path.join(save_path, checkpoint_name)
    os.replace(checkpoint_path_temp, final_checkpoint_path)

    print(f"\nBest model: Epoch {epoch_with_max_val_accuracy} with validation accuracy {max_val_accuracy:.4f}")
    print(f"Saved best model to: {final_checkpoint_path}")

    model.load_state_dict(torch.load(final_checkpoint_path, weights_only=True))

    # Training history plots (plot_hist skips a loss-vs-loss plot)
    for metric in ("accuracy", "precision", "recall", "f1"):
        plot_hist(combined_train_hist, combined_val_hist, metric=metric, save_path=save_path)

    # ---- Test set evaluation ----
    print("\n=== Test Set Evaluation ===")
    test_metrics = validate(model, test_loader, criterion, device)

    test_loss = test_metrics['loss']
    test_accuracy = test_metrics['accuracy']
    test_precision = test_metrics['precision']
    test_recall = test_metrics['recall']
    test_f1 = test_metrics['f1']

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print(f"Test Precision: {test_precision:.4f}")
    print(f"Test Recall: {test_recall:.4f}")
    print(f"Test F1 Score: {test_f1:.4f}")

    y_true = test_metrics['targets']
    y_pred = test_metrics['predictions']

    balanced_acc = balanced_accuracy_score(y_true, y_pred)
    print(f"Balanced Accuracy: {balanced_acc:.4f}")

    # Pin the label set so per-class outputs line up with class_names even if
    # some class is absent from the test set (otherwise f1_per_class[i] can be out
    # of range and classification_report rejects the target_names length).
    label_ids = list(range(len(class_names)))
    f1_per_class = f1_score(y_true, y_pred, labels=label_ids, average=None, zero_division=0)
    for class_name, class_f1 in zip(class_names, f1_per_class):
        print(f"F1 Score for class {class_name}: {class_f1:.4f}")

    print("\nClassification Report:")
    report = classification_report(y_true, y_pred, labels=label_ids, target_names=class_names, zero_division=0)
    print(report)

    report_path = os.path.join(save_path, f"classification_report_{timestamp}.txt")
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(f"Test Loss: {test_loss:.4f}\n")
        f.write(f"Test Accuracy: {test_accuracy:.4f}\n")
        f.write(f"Test Precision: {test_precision:.4f}\n")
        f.write(f"Test Recall: {test_recall:.4f}\n")
        f.write(f"Test F1 Score: {test_f1:.4f}\n")
        f.write(f"Balanced Accuracy: {balanced_acc:.4f}\n\n")
        f.write("Per-class F1 Scores:\n")
        for class_name, class_f1 in zip(class_names, f1_per_class):
            f.write(f"F1 Score for class {class_name}: {class_f1:.4f}\n")
        f.write("\nClassification Report:\n")
        f.write(report)

    print(f"Saved classification report to: {report_path}")

    plot_confusion_matrix(y_true, y_pred, class_names, save_path=save_path)

    # Save the returned model under a name carrying its test accuracy. Despite the
    # "last_model" filename, these are the best checkpoint's weights reloaded
    # above, not the final-epoch weights.
    final_model_path = os.path.join(save_path, f"last_model_acc{int(test_accuracy * 100)}_{timestamp}.pt")
    torch.save(model.state_dict(), final_model_path)
    print(f"Saved last model to: {final_model_path}")

    # Save model in torchscript format for deployment
    # Commented out as requested - not needed for current use case
    # try:
    #     model.eval()
    #     scripted_model = torch.jit.script(model)
    #     script_path = os.path.join(save_path, f"model_scripted_{timestamp}.pt")
    #     scripted_model.save(script_path)
    #     print(f"Saved TorchScript model to: {script_path}")
    # except Exception as e:
    #     print(f"Failed to save TorchScript model: {e}")

    return model

def recommended_nbs(dataset_size, batch_size):
    """
    Suggest how many batches to run per epoch.

    Parameters:
    dataset_size (int): Number of training samples
    batch_size (int): Samples per batch

    Returns:
    int or None: None (meaning "use every batch") when there are fewer than 100
    full batches; otherwise the number of full batches, capped at 200.
    """
    full_batches = dataset_size // batch_size

    # Small datasets: iterate the whole loader every epoch
    if full_batches < 100:
        return None

    # Larger datasets: bound the epoch length
    return 200 if full_batches > 200 else full_batches

# Main execution block - moved closer to the end as requested
if __name__ == "__main__":
    # Script entry point: configure paths, detect the GPU, pick batch settings,
    # record the configuration as JSON, then run training and evaluation.

    # Configuration flag to enable/disable user input prompts (default: False)
    ENABLE_USER_INPUT = False
    
    # notes : automated annotations ~ 100% ;
    # example 02 A : 'C:\\Users\\praam\\Desktop\\havetai+vetcyto\\04th-setup_task-04_new-work-03\\vet_images_sliced_split\\TrainingStepSet_automated-labels_T_640_100-pc\\' ~ train/val/test ; 
    # example 02 B : 'C:\\Users\\praam\\Desktop\\havetai+vetcyto\\04th-setup_task-04_new-work-03\\outputMOBILENETv3L_AutomatedAnnotation_100-pc' ;
    # notes : automated annotations ~ 125% ;
    # example 04 A : 'C:\\Users\\praam\\Desktop\\havetai+vetcyto\\04th-setup_task-04_new-work-03\\vet_images_sliced_split\\TrainingStepSet_automated-labels_T_640_125-pc\\' ~ train/val/test ; 
    # example 04 B : 'C:\\Users\\praam\\Desktop\\havetai+vetcyto\\04th-setup_task-04_new-work-03\\outputMOBILENETv3L_AutomatedAnnotation_125-pc' ;
    # notes : automated annotations ~ 150% ;
    # example 06 A : 'C:\\Users\\praam\\Desktop\\havetai+vetcyto\\04th-setup_task-04_new-work-03\\vet_images_sliced_split\\TrainingStepSet_automated-labels_T_640_150-pc\\' ~ train/val/test ; 
    # example 06 B : 'C:\\Users\\praam\\Desktop\\havetai+vetcyto\\04th-setup_task-04_new-work-03\\outputMOBILENETv3L_AutomatedAnnotation_150-pc' ;
    
    # default : 'C:/Users/karli/Desktop/vet_images_sliced_split(training)/train' ; 
    # alt : 'E:\\-_EDI_-\\notes\\havetai+vetcyto\\vet_images_sliced_split\\train' ;
    train_folder = 'C:\\Users\\praam\\Desktop\\havetai+vetcyto\\04th-setup_task-04_new-work-03\\vet_images_sliced_split\\TrainingStepSet_automated-labels_T_640_150-pc\\train' 
    # default : 'C:/Users/karli/Desktop/vet_images_sliced_split(training)/val' ; 
    # alt : 'E:\\-_EDI_-\\notes\\havetai+vetcyto\\vet_images_sliced_split\\val' ;
    val_folder = 'C:\\Users\\praam\\Desktop\\havetai+vetcyto\\04th-setup_task-04_new-work-03\\vet_images_sliced_split\\TrainingStepSet_automated-labels_T_640_150-pc\\val'      
    # default : 'C:/Users/karli/Desktop/vet_images_sliced_split(training)/test' ; 
    # alt : 'E:\\-_EDI_-\\notes\\havetai+vetcyto\\vet_images_sliced_split\\test' ;
    test_folder = 'C:\\Users\\praam\\Desktop\\havetai+vetcyto\\04th-setup_task-04_new-work-03\\vet_images_sliced_split\\TrainingStepSet_automated-labels_T_640_150-pc\\test'    
    # default : 'C:/Users/karli/Desktop/outputMobile' ; 
    # alt : 'E:\\-_EDI_-\\notes\\havetai+vetcyto\\outputMOBILENETv3L' ;
    save_model_path = 'C:\\Users\\praam\\Desktop\\havetai+vetcyto\\04th-setup_task-04_new-work-03\\outputMOBILENETv3L_AutomatedAnnotation_150-pc'

    # Create the output directory if needed (exist_ok avoids a race with other writers)
    if not os.path.exists(save_model_path):
        os.makedirs(save_model_path, exist_ok=True)
        print(f"Created output directory: {save_model_path}")

    # Detect device and get the recommended configuration (exits if no GPU)
    device, auto_batch_size, _, device_info = get_device_config()

    print("\n=== Hardware Configuration ===")
    for key, value in device_info.items():
        print(f"{key}: {value}")

    # Count the files in each class folder of the training split
    dataset_size = 0
    for class_name in os.listdir(train_folder):
        class_dir = os.path.join(train_folder, class_name)
        if os.path.isdir(class_dir):
            dataset_size += len([f for f in os.listdir(class_dir) if os.path.isfile(os.path.join(class_dir, f))])

    nbs = recommended_nbs(dataset_size, auto_batch_size)

    print("\n=== Dataset Information ===")
    print(f"Dataset size: {dataset_size} images")
    print("Optimal configuration for your GPU:")
    print(f"  - Batch size: {auto_batch_size}")
    print("  - Image size: 112x112 (fixed)")
    print(f"  - Batches per epoch: {nbs if nbs else 'All'}")

    # Allow user to override batch size and number of batches if enabled
    if ENABLE_USER_INPUT:
        user_batch = input(f"\nEnter batch size (optimal: {auto_batch_size}, press Enter to use optimal): ")
        batch_size = int(user_batch) if user_batch.strip() else auto_batch_size

        # The recommended batch count depends on the batch size, so recompute it
        # when the user chose a different one
        if batch_size != auto_batch_size:
            nbs = recommended_nbs(dataset_size, batch_size)

        user_nbs = input(f"Enter number of batches per epoch (optimal: {nbs if nbs else 'All'}, press Enter to use optimal): ")
        nbs_choice = user_nbs.strip().lower()
        if nbs_choice == "all":
            nbs = None
        elif nbs_choice:
            nbs = int(nbs_choice)
    else:
        # Use optimal values without prompting
        batch_size = auto_batch_size

    # Fixed image size of 112x112 as requested (not adjustable)
    image_size = 112

    print("\n=== Training Configuration ===")
    print(f"Using batch size: {batch_size}")
    print("Using image size: 112x112 (fixed)")
    print(f"Using batches per epoch: {nbs if nbs else 'All'}")

    # Save configuration for reproducibility
    config = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "device_info": device_info,
        "dataset_size": dataset_size,
        "batch_size": batch_size,
        "image_size": image_size,
        "batches_per_epoch": nbs if nbs else "All",
        "train_folder": train_folder,
        "val_folder": val_folder,
        "test_folder": test_folder,
        "save_path": save_model_path
    }

    with open(os.path.join(save_model_path, "training_config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)

    print("\n=== Starting Training ===")
    # Clear GPU memory before training
    if device == "cuda":
        torch.cuda.empty_cache()

    model = train_and_evaluate(
        train_folder, 
        val_folder, 
        test_folder, 
        save_path=save_model_path, 
        batch_size=batch_size, 
        image_size=(image_size, image_size),
        nbs=nbs
    )

    print("\n=== Training Complete ===")
    print(f"Model and training logs saved to: {save_model_path}")