# Clasificación de especies de hongos

## Preparacion de entorno


### Uso en Kaggle

In [None]:
# Dataset location when the notebook runs on Kaggle.
path = "/kaggle/input/mushroom1"

### Uso en Colab u otro lado

In [None]:
import kagglehub

# Outside Kaggle the dataset is fetched through kagglehub; `path` ends up
# pointing at whatever location dataset_download() returns.
print('Downloading dataset (11.3GB - this will take a few minutes)...')
path = kagglehub.dataset_download('zlatan599/mushroom1')
print('✓ Dataset downloaded to: {}'.format(path))

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
from tqdm.auto import tqdm
import warnings
import os
import time

In [None]:
# Seed NumPy and PyTorch (CPU, plus the current CUDA device when there is one)
# with the same fixed value to make runs repeatable.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Report the hardware and library version this run is using.
print("\n" + "=" * 70)
print("SYSTEM CONFIGURATION")
print("=" * 70)
print("Device: {}".format(device))
if device.type == 'cuda':
    print("GPU: {}".format(torch.cuda.get_device_name(0)))
    # total_memory is reported in bytes; shown here in decimal gigabytes.
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print("PyTorch version: {}".format(torch.__version__))
print("=" * 70 + "\n")

## Carga de datasets

In [None]:
# Fraction of the train / val splits kept for this run; the test split is
# loaded in full.
TRAIN_SUBSET_FRACTION = 0.05
VAL_SUBSET_FRACTION = 0.2

# The three split files sit directly under the dataset root.
train_csv = os.path.join(path, 'train.csv')
val_csv = os.path.join(path, 'val.csv')
test_csv = os.path.join(path, 'test.csv')

# Train and val are reduced to a random subset (fixed random_state, so the same
# rows are picked on every run) and re-indexed from 0.
train_df = (
    pd.read_csv(train_csv)
    .sample(frac=TRAIN_SUBSET_FRACTION, random_state=42)
    .reset_index(drop=True)
)
val_df = (
    pd.read_csv(val_csv)
    .sample(frac=VAL_SUBSET_FRACTION, random_state=42)
    .reset_index(drop=True)
)
test_df = pd.read_csv(test_csv)

# Columns are identified by position: the first is used as the image column,
# the second as the label column.
img_col = train_df.columns[0]
label_col = train_df.columns[1]


In [None]:
# Summarise the size of each split (train/val after sub-sampling) and the CSV
# columns in use. A single print keeps the output identical to line-by-line
# printing while making the report layout visible at a glance.
print("\n".join([
    "\n" + "=" * 70,
    "Estadisticas del dataset",
    "=" * 70,
    f"Train: {len(train_df):,}",
    f"Val: {len(val_df):,}",
    f"Test: {len(test_df):,}",
    f"Total: {len(train_df) + len(val_df) + len(test_df):,}",
    f"Columna de imagen: '{img_col}'",
    f"Columna de etiqueta: '{label_col}'",
    "=" * 70 + "\n",
]))

## Preparar Etiquetas

In [None]:
# Gather the labels of every split so the encoder knows each species, including
# any that only show up in val or test.
all_labels = pd.concat([
    train_df[label_col],
    val_df[label_col],
    test_df[label_col]
])

# Map species names to integer class indices.
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)

num_classes = len(label_encoder.classes_)

print(f"\n{'='*70}")
print(f"Informacion de especies")
print(f"{'='*70}")
print(f"Especies totales: {num_classes}")
# "\n" starts the listing on its own line. The previous "\P" was an invalid
# escape sequence that printed a literal backslash.
print(f"\nPrimeras 20 especies:")
for i, species in enumerate(label_encoder.classes_[:20], 1):
    # Number of (sub-sampled) training rows for this species.
    count = (train_df[label_col] == species).sum()
    print(f"{i:2d}. {species:45s} | {count:,} samples")
if num_classes > 20:
    print(f"... and {num_classes - 20} more species")
print(f"{'='*70}\n")

## Clase para manejar el dataset

In [None]:
class MushroomDataset(Dataset):
    """Image dataset backed by a dataframe, tolerant of missing or corrupt files.

    Each row supplies an image path (only its basename is used) and a species
    name. The file is looked up under ``root_dir/merged_dataset/<species>/``
    and then ``root_dir/<species>/``. When no readable file is found, a uniform
    grey 224x224 placeholder is returned with the sample's real label.

    Args:
        dataframe: Table containing at least ``img_col`` and ``label_col``.
        root_dir: Directory the species folders are resolved against.
        img_col: Name of the column holding the image path.
        label_col: Name of the column holding the species name.
        label_encoder: Fitted encoder whose ``transform`` maps species names
            to class indices.
        transform: Optional callable applied to the loaded PIL image.
    """

    # Size and colour of the stand-in image used when a file cannot be read.
    _PLACEHOLDER_SIZE = (224, 224)
    _PLACEHOLDER_COLOR = (128, 128, 128)
    # Only the first few load failures are printed, to avoid flooding the log.
    _MAX_WARNINGS = 10

    def __init__(self, dataframe, root_dir, img_col, label_col,
                 label_encoder, transform=None):
        self.dataframe = dataframe.reset_index(drop=True)
        self.root_dir = root_dir
        self.img_col = img_col
        self.label_col = label_col
        self.label_encoder = label_encoder
        self.transform = transform

        # Pre-compute encoded labels once instead of encoding per sample.
        self.labels = label_encoder.transform(self.dataframe[label_col])

        # Number of samples for which no image could be loaded.
        self.failed_loads = 0

    def _placeholder(self):
        """Return the grey stand-in image used for unreadable samples."""
        return Image.new('RGB', self._PLACEHOLDER_SIZE,
                         color=self._PLACEHOLDER_COLOR)

    def _find_and_load_image(self, img_path_from_csv, species_name):
        """Return the sample's image as RGB, or ``None`` if no candidate loads.

        Args:
            img_path_from_csv: Path as stored in the CSV; only the file name
                part is used.
            species_name: Species folder the file is expected in.
        """
        filename = os.path.basename(img_path_from_csv)

        # Candidate locations, tried in order.
        possible_paths = (
            os.path.join(self.root_dir, 'merged_dataset', species_name, filename),
            os.path.join(self.root_dir, species_name, filename),
        )

        for candidate in possible_paths:
            if not os.path.exists(candidate):
                continue
            try:
                # The context manager closes the file handle deterministically.
                # convert() decodes the pixels into a new image, so the result
                # stays usable after the file is closed and a corrupt file
                # fails here rather than later in the transform.
                with Image.open(candidate) as raw_image:
                    return raw_image.convert('RGB')
            except Exception:
                # Unreadable or corrupt file: deliberately best-effort, try
                # the next candidate location.
                continue

        return None

    def __len__(self):
        """Return the number of rows (samples) in the dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        Never raises for a bad image: an unreadable file, or any error while
        preparing the sample, yields the grey placeholder instead.
        """
        try:
            # Single row lookup for both the image path and the species.
            row = self.dataframe.iloc[idx]
            species_name = row[self.label_col]

            image = self._find_and_load_image(row[self.img_col], species_name)

            # If loading failed, substitute the placeholder but keep the label.
            if image is None:
                self.failed_loads += 1
                if self.failed_loads <= self._MAX_WARNINGS:
                    print(f"Warning: Could not load image for {species_name}")
                image = self._placeholder()

            label = self.labels[idx]

            if self.transform:
                image = self.transform(image)

            return image, label

        except Exception as e:
            # Catch-all so one bad sample cannot abort a whole epoch.
            print(f"Error in __getitem__ at index {idx}: {e}")
            placeholder = self._placeholder()
            if self.transform:
                placeholder = self.transform(placeholder)
            # Keep the sample's real label whenever it can still be looked up;
            # always answering class 0 would feed wrong targets to the loss.
            try:
                label = self.labels[idx]
            except Exception:
                label = 0
            return placeholder, label

## Transformaciones y aumentaciones

In [None]:
# Side length (pixels) every image is resized to before entering a model.
IMG_SIZE = 224

# Validation / test pipeline: deterministic resize, tensor conversion and
# per-channel normalisation (the commonly used ImageNet statistics).
val_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Training pipeline: the same resize and normalisation, with random flip,
# rotation, colour jitter and translation applied in between as augmentation.
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1,
        ),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

## Datasets y Dataloaders

In [None]:
# Batch size shared by all loaders (use 200 when training the baseline model).
BATCH_SIZE = 150

# One dataset per split; only the training split gets the augmenting pipeline.
train_dataset = MushroomDataset(
    dataframe=train_df,
    root_dir=path,
    img_col=img_col,
    label_col=label_col,
    label_encoder=label_encoder,
    transform=train_transforms,
)
val_dataset = MushroomDataset(
    dataframe=val_df,
    root_dir=path,
    img_col=img_col,
    label_col=label_col,
    label_encoder=label_encoder,
    transform=val_transforms,
)
test_dataset = MushroomDataset(
    dataframe=test_df,
    root_dir=path,
    img_col=img_col,
    label_col=label_col,
    label_encoder=label_encoder,
    transform=val_transforms,
)

# Loading happens in the main process (num_workers=0) without pinned memory.
# Only the training loader shuffles.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=0, pin_memory=False,
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=0, pin_memory=False,
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=0, pin_memory=False,
)

## Visualizar imagenes ejemplo

In [None]:
def show_batch(loader, label_encoder, n_images=8, title="Sample Images"):
    """Plot the first images of one batch, titled with their species names.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches. Images are
            assumed to be normalised with the mean/std undone below.
        label_encoder: Fitted encoder used to turn class indices back into
            species names.
        n_images: Maximum number of images to show. The grid grows to fit it
            (up to 4 per row) instead of being fixed at 2x4.
        title: Figure title.
    """
    images, labels = next(iter(loader))

    n_show = min(n_images, len(images))
    if n_show <= 0:
        # Nothing to draw (empty batch or non-positive n_images).
        return

    # Undo the per-channel normalisation so colours display correctly.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    images = torch.clamp(images[:n_show] * std + mean, 0, 1)

    # Decode all displayed labels in one call.
    species_names = label_encoder.inverse_transform(labels[:n_show].tolist())

    # Up to 4 columns; 8 images give the original 2x4 grid at 16x8 inches.
    n_cols = min(4, n_show)
    n_rows = (n_show + n_cols - 1) // n_cols
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
    )
    fig.suptitle(title, fontsize=16, fontweight='bold')

    for idx, ax in enumerate(axes.flat):
        # Hide ticks on every cell, including cells left without an image.
        ax.axis('off')
        if idx < n_show:
            # imshow expects channels last: (C, H, W) -> (H, W, C).
            ax.imshow(images[idx].permute(1, 2, 0))
            ax.set_title(species_names[idx], fontsize=9)

    plt.tight_layout()
    plt.show()

# Visual sanity check of the input pipeline on one training batch.
print('Visualizing training samples...')
show_batch(
    loader=train_loader,
    label_encoder=label_encoder,
    title='Training Samples',
)

## Definir modelo

### Definir baseline model

In [None]:
class BaselineCNN(nn.Module):
    """Plain five-stage convolutional classifier trained from scratch.

    Every stage is Conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool, so each stage
    halves the spatial size (224 -> 112 -> 56 -> 28 -> 14 -> 7 for 224x224
    input) while the channels grow 3 -> 32 -> 64 -> 128 -> 256 -> 512. Global
    average pooling and a two-layer head with dropout produce the logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Channel count at the boundary of each convolutional stage.
        widths = (3, 32, 64, 128, 256, 512)

        # All layers go into one flat Sequential so parameters keep the names
        # features.<index>.* that saved checkpoints rely on.
        stage_layers = []
        for in_channels, out_channels in zip(widths[:-1], widths[1:]):
            stage_layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ])
        self.features = nn.Sequential(*stage_layers)

        # Collapse the remaining spatial grid to one 512-d vector per image.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Classification head: 512 -> 256 -> num_classes with dropout.
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        pooled = self.avgpool(self.features(x))
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)

In [None]:
# Build the from-scratch CNN, move it to the selected device and record its
# total parameter count for the later comparison table.
baseline_model = BaselineCNN(num_classes=num_classes).to(device)
baseline_params = sum(param.numel() for param in baseline_model.parameters())

### Clase para definir el modelo clasificador

In [None]:
class MushroomClassifier(nn.Module):
    """Wrapper around a timm EfficientNet-B0 created with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` to request pretrained
            weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()

        # timm builds the EfficientNet-B0 backbone for the requested classes.
        backbone_options = {
            'pretrained': pretrained,
            'num_classes': num_classes,
        }
        self.backbone = timm.create_model('efficientnet_b0', **backbone_options)

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        logits = self.backbone(x)
        return logits

In [None]:
# Build the pretrained EfficientNet-B0 classifier, move it to the selected
# device and record its total parameter count.
model = MushroomClassifier(num_classes=num_classes, pretrained=True).to(device)
total_params = sum(param.numel() for param in model.parameters())

## Configuracion de training

### Configuracion para baseline

In [None]:
# Number of epochs the baseline model is trained for.
NUM_EPOCHS_BASELINE = 15

# Cross-entropy loss with label smoothing of 0.1.
baseline_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# AdamW over all baseline parameters.
baseline_optimizer = optim.AdamW(
    params=baseline_model.parameters(),
    lr=0.001,
    weight_decay=0.0001,
)

# Halve the learning rate when the monitored metric stops increasing
# (mode='max') for more than `patience` epochs.
baseline_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=baseline_optimizer,
    mode='max',
    patience=2,
    factor=0.5,
)

### Configuracion para EfficientNet

In [None]:
# Cross-entropy loss with label smoothing of 0.1.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# AdamW over all EfficientNet parameters.
optimizer = optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4
)

# Halve the learning rate when the monitored metric plateaus.
# mode='max' is required because the training loop steps this scheduler with
# validation accuracy (higher is better). The default mode='min' would treat
# every accuracy gain as "no improvement" and cut the learning rate while the
# model is still improving.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=2,
)

# Number of epochs the EfficientNet model is trained for.
NUM_EPOCHS = 15

## Funciones de entrenamiento

In [None]:
def training_epoch(model, loader, criterion, optimizer, device, epoch,
                   num_epochs=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, used only in the progress-bar label.
        num_epochs: Total number of epochs shown in the progress-bar label.
            Defaults to the module-level ``NUM_EPOCHS`` so existing calls keep
            working. Pass the real total for runs of a different length.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        accuracy in percent. Both are 0.0 when the loader yields no samples.
    """
    if num_epochs is None:
        num_epochs = NUM_EPOCHS

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=f'Epoch {epoch+1:02d}/{num_epochs} [TRAIN]')

    for batch_idx, (images, labels) in enumerate(pbar):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics: the loss is weighted by batch size so that a smaller
        # final batch does not skew the epoch mean.
        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Refresh the progress bar only every 10 batches.
        if batch_idx % 10 == 0:
            pbar.set_postfix({
                'loss': f'{running_loss/total:.4f}',
                # '%' suffix matches the accuracy display used by validate().
                'acc': f'{100.*correct/total:.2f}%'
            })

    if total == 0:
        # Empty loader: avoid dividing by zero.
        return 0.0, 0.0

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

In [None]:
def validate(model, loader, criterion, device, phase='VAL'):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.
        phase: Label shown on the progress bar (e.g. ``'VAL'`` or ``'TEST'``).

    Returns:
        ``(val_loss, val_acc)``: the sample-weighted mean loss and the accuracy
        in percent. Both are 0.0 when the loader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=f'{phase:5s}')

    with torch.no_grad():
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # The loss is weighted by batch size so a smaller final batch does
            # not skew the mean.
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': f'{running_loss/total:.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

    # Final metrics are computed once, after the loop. Computing them inside
    # the loop left them undefined whenever the loader was empty.
    if total == 0:
        return 0.0, 0.0

    val_loss = running_loss / total
    val_acc = 100. * correct / total
    return val_loss, val_acc

## Entrenamiento

### Entrenamiento de baseline

In [None]:
# Per-epoch metrics of the baseline run, reused later for the plots.
baseline_history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'epoch_times': []
}

# Best validation accuracy seen so far and the (1-based) epoch it occurred in.
best_baseline_acc = 0
best_baseline_epoch = 0

print(f"\n{'='*70}")
print(f"Empezando entrenamiento")
print(f"{'='*70}")
# The baseline run has its own epoch count (NUM_EPOCHS_BASELINE).
print(f"Total epochs: {NUM_EPOCHS_BASELINE}")
print(f"Training samples: {len(train_dataset):,}")
print(f"Validation samples: {len(val_dataset):,}")
print(f"{'='*70}\n")

baseline_start_time = time.time()

for epoch in range(NUM_EPOCHS_BASELINE):
    epoch_start = time.time()

    # Train
    train_loss, train_acc = training_epoch(
        baseline_model, train_loader, baseline_criterion,
        baseline_optimizer, device, epoch
    )

    # Validate
    val_loss, val_acc = validate(
        baseline_model, val_loader, baseline_criterion, device, phase='VAL'
    )

    # The scheduler monitors validation accuracy (configured with mode='max').
    baseline_scheduler.step(val_acc)

    epoch_time = time.time() - epoch_start

    # Save history
    baseline_history['train_loss'].append(train_loss)
    baseline_history['train_acc'].append(train_acc)
    baseline_history['val_loss'].append(val_loss)
    baseline_history['val_acc'].append(val_acc)
    baseline_history['epoch_times'].append(epoch_time)

    # Epoch summary, printed once. The extra one-line recap that repeated the
    # same numbers was removed.
    print(f"\n{'='*70}")
    print(f"EPOCH {epoch+1}/{NUM_EPOCHS_BASELINE} SUMMARY")
    print(f"{'='*70}")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
    print(f"Epoch Time: {epoch_time:.1f}s | LR: {baseline_optimizer.param_groups[0]['lr']:.6f}")

    # Keep only the checkpoint with the best validation accuracy.
    if val_acc > best_baseline_acc:
        best_baseline_acc = val_acc
        best_baseline_epoch = epoch + 1
        # NOTE: the fitted LabelEncoder is stored as a pickled Python object,
        # so this file must be loaded with torch.load(..., weights_only=False)
        # and only from a trusted source.
        torch.save({
            'epoch': epoch,
            'model_state_dict': baseline_model.state_dict(),
            'val_acc': val_acc,
            'num_classes': num_classes,
            'label_encoder': label_encoder,
            'img_size': IMG_SIZE
        }, 'baseline_best.pth')
        print(f"Modelo baseline guardado")

baseline_total_time = time.time() - baseline_start_time

print(f"\n{'='*70}")
print(f"TRAINING DE BASELINE COMPLETADO")
print(f"{'='*70}")
print(f"Timepo total de training: {baseline_total_time/60:.1f} min")
print(f"AVG tiempo por epoch: {np.mean(baseline_history['epoch_times']):.1f}s")
print(f"Mejor val accuracy: {best_baseline_acc:.2f}% (Epoch {best_baseline_epoch})")
print(f"{'='*70}\n")

### Entrenamiento de EfficientNet

In [None]:
# Per-epoch metrics of the EfficientNet run, reused later for the plots.
history = {
    metric: []
    for metric in ('train_loss', 'train_acc', 'val_loss', 'val_acc', 'epoch_times')
}

# Best validation accuracy seen so far and the (1-based) epoch it occurred in.
best_val_acc, best_epoch = 0, 0

print("\n".join([
    "\n" + "=" * 70,
    "Empezando entrenamiento",
    "=" * 70,
    f"Total epochs: {NUM_EPOCHS}",
    f"Training samples: {len(train_dataset):,}",
    f"Validation samples: {len(val_dataset):,}",
    "=" * 70 + "\n",
]))

training_start_time = time.time()

for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()

    # One training pass followed by one validation pass.
    train_loss, train_acc = training_epoch(
        model, train_loader, criterion, optimizer, device, epoch
    )
    val_loss, val_acc = validate(
        model, val_loader, criterion, device, phase='VAL'
    )

    # The scheduler is driven by validation accuracy.
    scheduler.step(val_acc)

    # Wall-clock duration of the whole epoch (train + val + scheduler step).
    epoch_time = time.time() - epoch_start

    # Record this epoch's metrics.
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['epoch_times'].append(epoch_time)

    print("\n".join([
        "\n" + "=" * 70,
        f"EPOCH {epoch+1}/{NUM_EPOCHS} SUMMARY",
        "=" * 70,
        f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%",
        f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%",
        f"Epoch Time: {epoch_time:.1f}s | LR: {optimizer.param_groups[0]['lr']:.6f}",
    ]))

    # Keep only the checkpoint with the best validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        best_checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'num_classes': num_classes,
            'label_encoder': label_encoder,
            'img_size': IMG_SIZE,
        }
        torch.save(best_checkpoint, 'efficientNet_best.pth')
    print("=" * 70 + "\n")

total_training_time = time.time() - training_start_time

print("\n".join([
    "\n" + "=" * 70,
    "TRAINING COMPLETADO!",
    "=" * 70,
    f"Timepo total de training: {total_training_time/60:.1f} min",
    f"AVG tiempo por epoch: {np.mean(history['epoch_times']):.1f}s",
    f"Mejor val accuracy: {best_val_acc:.2f}% (Epoch {best_epoch})",
    "=" * 70 + "\n",
]))


## Validar los modelos con el set de test

### Baseline

In [None]:
# Restore the best baseline weights. The baseline training loop writes them to
# 'baseline_best.pth'; the EfficientNet checkpoint does not fit this model.
# weights_only=False is needed because the checkpoint also holds the pickled
# LabelEncoder. Unpickling can execute code, so only load trusted files.
baseline_checkpoint = torch.load(
    'baseline_best.pth', map_location=device, weights_only=False
)
baseline_model.load_state_dict(baseline_checkpoint['model_state_dict'])

# Evaluate on the held-out test split.
baseline_test_loss, baseline_test_acc = validate(
    baseline_model, test_loader, baseline_criterion, device, phase='TEST'
)

print(f"\n{'='*70}")
print(f"BASELINE TEST RESULTS")
print(f"{'='*70}")
print(f"Test Accuracy: {baseline_test_acc:.2f}%")
print(f"Test Loss: {baseline_test_loss:.4f}")
print(f"{'='*70}\n")

### EfficientNet

In [None]:
# Restore the best EfficientNet weights. The training loop saves them as
# 'efficientNet_best.pth'; no cell writes 'best_mushroom_classifier.pth'.
# weights_only=False is needed because the checkpoint also holds the pickled
# LabelEncoder. Unpickling can execute code, so only load trusted files.
checkpoint = torch.load(
    'efficientNet_best.pth', map_location=device, weights_only=False
)
model.load_state_dict(checkpoint['model_state_dict'])


# Evaluate on the held-out test split.
test_loss, test_acc = validate(
    model, test_loader, criterion, device, phase='TEST'
)

print(f"\n{'='*70}")
print(f"EFFICIENTNET TEST RESULTS")
print(f"{'='*70}")
print(f"Test Accuracy: {test_acc:.2f}%")
print(f"Test Loss: {test_loss:.4f}")
print(f"{'='*70}\n")

## Visualizar historial del entrenamiento

In [None]:
def _plot_history_panel(ax, key, ylabel, title):
    """Draw the baseline and transfer-learning curves of metric ``key`` on ``ax``.

    Each curve gets an x-axis built from its own history length, so the two
    runs may have been trained for different numbers of epochs.
    """
    curves = (
        (baseline_history, 'b-o', 'Baseline'),
        (history, 'r-s', 'Transfer Learning'),
    )
    for run_history, line_format, label in curves:
        values = run_history[key]
        ax.plot(range(1, len(values) + 1), values,
                line_format, label=label, linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle('Baseline CNN vs Transfer Learning (EfficientNet-B0)',
             fontsize=16, fontweight='bold')

# Kept for any later cell that still refers to it; now derived from the
# recorded baseline history instead of the configured epoch count.
epochs_range = range(1, len(baseline_history['train_loss']) + 1)

# Training Loss
_plot_history_panel(axes[0, 0], 'train_loss',
                    'Training Loss', 'Training Loss Comparison')

# Validation Loss
_plot_history_panel(axes[0, 1], 'val_loss',
                    'Validation Loss', 'Validation Loss Comparison')

# Validation Accuracy, with a dashed line at each model's best value
_plot_history_panel(axes[1, 0], 'val_acc',
                    'Validation Accuracy (%)', 'Validation Accuracy Comparison')
axes[1, 0].axhline(y=best_baseline_acc, color='b', linestyle='--', alpha=0.5)
axes[1, 0].axhline(y=best_val_acc, color='r', linestyle='--', alpha=0.5)

# Test Accuracy Bar Chart
models = ['Baseline\nCNN', 'Transfer\nLearning']
test_accs = [baseline_test_acc, test_acc]
colors = ['#3498db', '#2ecc71']

bars = axes[1, 1].bar(models, test_accs, color=colors, alpha=0.8, width=0.6)
axes[1, 1].set_ylabel('Test Accuracy (%)', fontsize=12)
axes[1, 1].set_title('Final Test Accuracy Comparison', fontsize=14, fontweight='bold')
axes[1, 1].set_ylim(0, 100)
axes[1, 1].grid(True, alpha=0.3, axis='y')

# Write each bar's value just above it.
for bar, acc in zip(bars, test_accs):
    height = bar.get_height()
    axes[1, 1].text(bar.get_x() + bar.get_width()/2., height + 2,
                    f'{acc:.1f}%', ha='center', va='bottom',
                    fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig('baseline_vs_transfer_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

### Comparacion de modelos

In [None]:
print("="*70 + "\n")

# Side-by-side summary of both models; every value is pre-formatted as text.
comparison_data = {
    'Metric': [
        'Model Architecture',
        'Parameters',
        'Pre-trained Weights',
        'Training Time (min)',
        'Best Val Accuracy (%)',
        'Test Accuracy (%)',
        'Accuracy Improvement'
    ],
    'Baseline CNN': [
        'Simple 5-layer CNN',
        f'{baseline_params:,}',
        'No',
        f'{baseline_total_time/60:.1f}',
        f'{best_baseline_acc:.2f}',
        f'{baseline_test_acc:.2f}',
        'Baseline'
    ],
    'Transfer Learning': [
        'EfficientNet-B0',
        f'{total_params:,}',
        'Yes (ImageNet)',
        f'{total_training_time/60:.1f}',
        f'{best_val_acc:.2f}',
        f'{test_acc:.2f}',
        # The '+' format flag writes the real sign. A hard-coded '+' prefix
        # rendered a negative difference as "+-x.xx%".
        f'{test_acc - baseline_test_acc:+.2f}%'
    ]
}

comparison_df = pd.DataFrame(comparison_data)
print(comparison_df.to_string(index=False))
print("\n" + "="*70)