# How to Use This Notebook: Poultry Disease Classifier (EfficientNet-B3)

This notebook trains an EfficientNet-B3 model for poultry disease image classification using a two-stage fine-tuning approach.

## 1. Prerequisites

*   Python environment with Jupyter support (Google Colab, Kaggle, local).
*   GPU recommended for faster training.
*   Kaggle account/API key for dataset download (if running locally or on Colab without native Kaggle integration).

## 2. Setup

Run the initial cells in order:

1.  **Download Dataset:** Fetches the dataset zip file from Kaggle.
2.  **Unzip Dataset:** Extracts images into the `/content/poultry-diseases` directory.
3.  **Install Dependencies:** Installs `torch`, `torchvision`, `efficientnet_pytorch`, and `thop`. The metrics cell also needs scikit-learn, which is preinstalled on Colab and Kaggle.
4.  **Import Libraries:** Loads necessary Python packages.

## 3. Configuration & Data Preparation

1.  **Model Setup:** Loads a pre-trained EfficientNet-B3 and adapts its final layer for 4 classes.
2.  **Data Transformations:** Defines image resizing, augmentation (for training), and normalization.
3.  **Dataset Loading & Sampling:**
    *   **CRITICAL:** Update the paths in the `ImageFolder` calls to point to your *actual* `train` and `test` directories (e.g., `/content/poultry-diseases/data/data/train`).
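    *   The code samples 10,000 images per class for training (`train_subset`). It then builds a validation subset (`val_subset`) that keeps, for each class, the same fraction of images that training keeps (10,000 divided by that class's number of training images).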
4.  **DataLoaders:** Prepares data batches for training and validation. Adjust `batch_size` based on your GPU memory.

## 4. Model Training (Two Stages)

1.  **Stage 1 (Classifier Only):**
    *   Freezes the backbone and trains only the final classifier (`model._fc`) for 17 epochs with AdamW (learning rate 0.0001).
    *   Run the cell under the "Stage 1" comment.
2.  **Stage 2 (Full Model):**
    *   Unfreezes all layers and fine-tunes the entire model for 10 epochs with SGD (learning rate 0.0001, the same as Stage 1, and momentum 0.8496).
    *   Saves the model weights (`best_model.pth`) whenever validation accuracy improves.
    *   Run the cell under the "Stage 2" comment.
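The two stages can be summarized as a small configuration sketch. `StageConfig`, `SCHEDULE`, and `epochs_per_run` are hypothetical names introduced only for illustration; the values are copied from the Stage 1 and Stage 2 cells, and nothing in the notebook reads this structure.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class StageConfig:
    """Hyperparameters of one fine-tuning stage (values copied from the notebook cells)."""
    name: str
    epochs: int
    optimizer: str             # optimizer class used in that stage's cell
    lr: float
    momentum: Optional[float]  # only the SGD stage sets momentum
    trainable: str             # which parameters have requires_grad=True


SCHEDULE = (
    StageConfig("Stage 1", 17, "AdamW", 1e-4, None, "model._fc only"),
    StageConfig("Stage 2", 10, "SGD", 1e-4, 0.8496, "all parameters"),
)


def epochs_per_run(schedule=SCHEDULE) -> int:
    """Total epochs in one Stage 1 + Stage 2 run."""
    return sum(stage.epochs for stage in schedule)


if __name__ == "__main__":
    for stage in SCHEDULE:
        print(stage)
    print("Epochs per run:", epochs_per_run())
```

Written out this way, it is easy to see that both stages use the same learning rate; what changes between them is the optimizer and the set of trainable parameters.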

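**Note on Iterations:** The code executes *one* full training run (Stage 1 + Stage 2). The training log at the end of the notebook shows *three* such iterations, which requires re-running the sampling and training cells, potentially with adjustments. Because the model is not re-created between runs, each iteration appears to continue from the previous iteration's weights: in the log, each iteration's first Stage 1 accuracy is close to the previous iteration's best.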

## 5. Evaluation and Analysis

After training, run the subsequent cells:

1.  **Load Best Model:** Loads the saved `best_model.pth`.
2.  **Calculate Metrics:** Computes and prints F1 scores (macro, weighted), mAP, per-class precision/recall/F1, and the confusion matrix using the validation subset. Ensure `class_names` match your dataset folders.
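3.  **Calculate FLOPs:** Estimates the total computational cost of training for both stages across the number of iterations set in the cell (`num_iterations`). The result is reported in TFLOPs, meaning 10^12 floating-point operations in total rather than a per-second rate. `thop.profile` counts multiply-accumulate operations (MACs), and one MAC corresponds to two FLOPs.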
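4.  **Calculate Inference Speed:** Measures the average time per image while classifying the validation subset in batches of `batch_size`. This is a throughput figure, not single-image latency.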
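A minimal sketch of the timing pattern, assuming a generic `run_batch` callable. `average_time_per_sample` and its `sync` parameter are hypothetical. In the notebook, `run_batch` would wrap `model(images)` under `torch.no_grad()`, `batches` would be `(images.to(device) for images, _ in val_loader)`, and `sync` would be `torch.cuda.synchronize` on a GPU, because CUDA kernels launch asynchronously.

```python
import time
from typing import Callable, Iterable, Optional, Sized


def average_time_per_sample(
    run_batch: Callable[[Sized], object],
    batches: Iterable[Sized],
    warmup_batches: int = 1,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Average wall-clock seconds per sample, skipping the first `warmup_batches` batches."""
    total_time, total_samples = 0.0, 0
    for i, batch in enumerate(batches):
        if i < warmup_batches:
            run_batch(batch)  # warm-up: executed but not timed
            continue
        if sync:
            sync()  # drain earlier queued work so it is not charged to this batch
        start = time.perf_counter()
        run_batch(batch)
        if sync:
            sync()  # wait until this batch has really finished
        total_time += time.perf_counter() - start
        total_samples += len(batch)
    if total_samples == 0:
        raise ValueError("no batches were timed; provide more than warmup_batches batches")
    return total_time / total_samples


if __name__ == "__main__":
    # Toy workload standing in for model(images)
    toy_batches = [list(range(32))] * 10
    print(average_time_per_sample(lambda b: sum(x * x for x in b), toy_batches))
```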

## 6. Interpreting Results

*   **Accuracy/Loss:** Monitor trends during training.
*   **F1/mAP/Per-Class:** Assess detailed performance, especially for imbalanced classes.
*   **Confusion Matrix:** Identify specific class confusions.
*   **FLOPs/Inference Speed:** Gauge model efficiency and prediction speed.
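Macro F1 averages the per-class scores equally, while weighted F1 weights them by support (the number of true samples per class). As a result, a poorly handled minority class pulls macro F1 down much more than weighted F1. The plain-Python sketch below computes both from a confusion matrix. `f1_from_confusion` is a hypothetical helper meant to follow scikit-learn's conventions (rows are true labels; F1 is 0 when a class has no true positives). It illustrates the arithmetic and is not a replacement for `f1_score`.

```python
from typing import List, Sequence, Tuple


def f1_from_confusion(cm: Sequence[Sequence[int]]) -> Tuple[List[float], float, float]:
    """Per-class, macro and weighted F1 from a square confusion matrix.

    Rows are true classes and columns are predicted classes, as in sklearn's confusion_matrix.
    """
    n = len(cm)
    support = [sum(row) for row in cm]  # number of true samples per class
    per_class: List[float] = []
    for k in range(n):
        tp = cm[k][k]
        fp = sum(cm[i][k] for i in range(n)) - tp  # predicted k, true label differs
        fn = support[k] - tp                        # true k, predicted something else
        denom = 2 * tp + fp + fn                    # F1 = 2TP / (2TP + FP + FN)
        per_class.append(2 * tp / denom if denom else 0.0)
    macro = sum(per_class) / n  # every class counts equally
    weighted = sum(f * s for f, s in zip(per_class, support)) / sum(support)  # weighted by support
    return per_class, macro, weighted


if __name__ == "__main__":
    # Toy 2-class matrix: a large class handled well and a small class handled poorly
    per_class, macro, weighted = f1_from_confusion([[90, 10], [5, 5]])
    print([round(f, 3) for f in per_class], round(macro, 3), round(weighted, 3))
```

With `[[90, 10], [5, 5]]`, the minority class has an F1 of 0.4. This gives a macro F1 of about 0.66 but a weighted F1 of about 0.88.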

Downloading the whole dataset

In [None]:
!curl -L -o /content/poultry-diseases.zip \
  https://www.kaggle.com/api/v1/datasets/download/chandrashekarnatesh/poultry-diseases

Unzip the dataset

In [None]:
!unzip /content/poultry-diseases.zip -d /content/poultry-diseases

Install all the dependencies

In [None]:
!pip install torch torchvision efficientnet_pytorch
!pip install thop

In [None]:
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from efficientnet_pytorch import EfficientNet
from sklearn.metrics import (
    average_precision_score,
    confusion_matrix,
    f1_score,
    precision_recall_fscore_support,
)
from thop import profile

In [None]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load pre-trained EfficientNet B3
model = EfficientNet.from_pretrained('efficientnet-b3').to(device)

In [None]:
# Modify the classifier
num_classes = 4
model._fc = nn.Linear(model._fc.in_features, num_classes).to(device)

In [None]:
# Data transformations for training dataset
train_transform = transforms.Compose([
    transforms.Resize((300, 300)),              # Resize images to 300x300 pixels
    transforms.RandomHorizontalFlip(),          # Randomly flip images horizontally for data augmentation
    transforms.ToTensor(),                      # Convert images to PyTorch tensors (C x H x W format)
    transforms.Normalize(                       # Normalize pixel values using ImageNet mean and std
        mean=[0.485, 0.456, 0.406],            # Mean for RGB channels
        std=[0.229, 0.224, 0.225]              # Standard deviation for RGB channels
    ),
])

# Data transformations for validation dataset
val_transform = transforms.Compose([
    transforms.Resize((300, 300)),              # Resize images to 300x300 pixels
    transforms.ToTensor(),                      # Convert images to PyTorch tensors (C x H x W format)
    transforms.Normalize(                       # Normalize pixel values using ImageNet mean and std
        mean=[0.485, 0.456, 0.406],            # Mean for RGB channels
        std=[0.229, 0.224, 0.225]              # Standard deviation for RGB channels
    ),
])
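`ToTensor()` rescales 8-bit pixel values to [0, 1]. `Normalize` then subtracts the ImageNet channel mean and divides by the channel standard deviation. The plain-Python sketch below (`normalize_pixel` is a hypothetical helper) applies the same arithmetic to one RGB pixel. Both transforms share these constants because the ImageNet-pretrained weights expect inputs on this scale.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_uint8):
    """Apply ToTensor()-style scaling and Normalize(mean, std) to one (R, G, B) pixel."""
    scaled = [channel / 255.0 for channel in rgb_uint8]  # ToTensor: 0..255 -> 0.0..1.0
    return [(x - m) / s for x, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]


if __name__ == "__main__":
    print(normalize_pixel((255, 255, 255)))  # brightest pixel
    print(normalize_pixel((0, 0, 0)))        # darkest pixel
```

A pure white pixel maps to roughly (2.25, 2.43, 2.64), so normalized inputs are no longer confined to [0, 1].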

In [None]:
# Define the number of classes in the dataset
num_classes = 4  # Poultry disease dataset has 4 distinct classes

# Initialize training dataset
# ImageFolder loads images from the specified directory, applying train_transform
train_dataset = ImageFolder(root='path_to/poultry-diseases/data/data/train', transform=train_transform)

# Initialize validation dataset
# ImageFolder loads images from the specified directory, applying val_transform
val_dataset = ImageFolder(root='path_to/poultry-diseases/data/data/test', transform=val_transform)

# Get training dataset details
train_targets = train_dataset.targets  # Extract list of class labels for training images
# Create dictionary mapping each class index to indices of images belonging to that class
train_class_indices = {
    class_idx: [i for i, t in enumerate(train_targets) if t == class_idx]
    for class_idx in range(num_classes)
}
# Calculate number of training images per class
T_per_class = [len(train_class_indices[class_idx]) for class_idx in range(num_classes)]
# Print the number of training images for each class (expected ~100,000 per class)
print("Original training images per class:", T_per_class)

# Get validation dataset details
val_targets = val_dataset.targets  # Extract list of class labels for validation images
# Create dictionary mapping each class index to indices of images belonging to that class
val_class_indices = {
    class_idx: [i for i, t in enumerate(val_targets) if t == class_idx]
    for class_idx in range(num_classes)
}
# Calculate number of validation images per class
V_per_class = [len(val_class_indices[class_idx]) for class_idx in range(num_classes)]
# Print the number of validation images for each class (size may vary)
print("Original validation images per class:", V_per_class)
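The dictionary comprehensions above scan the full `targets` list once per class. The sketch below builds the same class-to-indices mapping in a single pass. `indices_by_class` is a hypothetical helper; with four classes the saving is modest, but it grows with the number of classes.

```python
from collections import defaultdict


def indices_by_class(targets, num_classes):
    """Map each class index to the dataset indices with that label, in one pass."""
    groups = defaultdict(list)
    for idx, label in enumerate(targets):
        groups[label].append(idx)
    # Include every class, even one with no images, like the notebook's dictionary
    return {class_idx: groups.get(class_idx, []) for class_idx in range(num_classes)}


if __name__ == "__main__":
    toy_targets = [0, 2, 1, 0, 2]  # toy labels in the style of ImageFolder.targets
    mapping = indices_by_class(toy_targets, 4)
    print(mapping)
    print("Images per class:", [len(mapping[c]) for c in range(4)])
```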

In [None]:
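# NOTE: The next three cells re-sample `train_subset`, build a proportional `val_subset`,
# and rebuild both DataLoaders, overriding the subset and loaders created here (this cell's
# val_loader uses the full test set). The datasets it reloads are identical to those loaded above,
# so this cell is redundant and can be skipped.
# Load the training dataset using ImageFolder, applying the transformation defined for training data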
train_dataset = ImageFolder(root='path_to/poultry-diseases/data/data/train', transform=train_transform)

# Get the class labels (targets) from the training dataset
targets = train_dataset.targets

# Create a dictionary where the key is the class index and the value is a list of image indices for that class
class_indices = {class_idx: [i for i, t in enumerate(targets) if t == class_idx] for class_idx in range(num_classes)}

# Initialize an empty list to store the indices of the sampled images
sampled_indices = []

# For each class, sample a specified number of images (10000 in this case, but adjust according to your needs)
for class_idx in range(num_classes):
    class_idx_list = class_indices[class_idx]
    sampled = random.sample(class_idx_list, 10000)  # Sample 10000 images per class for the first iteration
    sampled_indices.extend(sampled)  # Add the sampled indices to the list

# Create a subset of the training dataset using the sampled indices, to ensure only the sampled images are used
train_subset = Subset(train_dataset, sampled_indices)

# Load the validation dataset (using the full test set as validation data here)
val_dataset = ImageFolder(root='path_to/poultry-diseases/data/data/test', transform=val_transform)

# Define the batch size for training and validation
batch_size = 32

# Create a DataLoader for the training subset, which will load the training data in batches, shuffle, and use multiple workers
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

# Create a DataLoader for the validation dataset, which will load the validation data in batches without shuffling
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)


In [None]:
# Number of sampled images per class for training
t = 10000

# Initialize an empty list to store the indices of the sampled images for training
train_sampled_indices = []

# Loop through each class index (from 0 to num_classes - 1)
for class_idx in range(num_classes):
    # Randomly sample 't' images from the indices of the current class
    sampled = random.sample(train_class_indices[class_idx], t)
    # Add the sampled indices to the list of training sampled indices
    train_sampled_indices.extend(sampled)

# Create a subset of the training dataset containing only the sampled indices
train_subset = Subset(train_dataset, train_sampled_indices)

# Verify the number of images per class in the training subset
print("Training subset images per class:")
for class_idx in range(num_classes):
    # Count the number of images in train_sampled_indices belonging to the current class
    count = len([i for i in train_sampled_indices if train_targets[i] == class_idx])
    print(f"Class {class_idx}: {count}")


Training subset images per class:
Class 0: 10000
Class 1: 10000
Class 2: 10000
Class 3: 10000
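`random.sample` is called without a seed, so every run draws a different 40,000-image subset. If reproducible subsets are wanted, one option is a dedicated `random.Random` generator with a fixed seed. The notebook does not do this, so treat it as an assumption. `balanced_sample` below is a hypothetical helper.

```python
import random


def balanced_sample(class_indices, per_class, seed=0):
    """Draw `per_class` indices without replacement from every class, reproducibly."""
    rng = random.Random(seed)  # private generator: same seed -> same subset
    sampled = []
    for class_idx in sorted(class_indices):
        pool = class_indices[class_idx]
        if per_class > len(pool):
            raise ValueError(f"class {class_idx} has only {len(pool)} images")
        sampled.extend(rng.sample(pool, per_class))
    return sampled


if __name__ == "__main__":
    toy = {0: list(range(0, 10)), 1: list(range(10, 20))}  # hypothetical class -> indices map
    print(balanced_sample(toy, 3, seed=42))
```

In the notebook, the returned list would be passed to `Subset(train_dataset, ...)` in place of `train_sampled_indices`. DataLoader shuffling uses PyTorch's own random state, so it is not controlled by this seed.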


In [None]:
# Initialize an empty list to store the indices of the sampled images for validation
val_sampled_indices = []

# Loop through each class index (from 0 to num_classes - 1)
for class_idx in range(num_classes):
    # Calculate the proportion of images to sample for this class based on the total number of images 't'
    proportion = t / T_per_class[class_idx]  # e.g., 10000 / 100000 = 0.1 for a class with 100,000 training images
    
    # Calculate the number of images to sample for this class based on its proportion
    v = int(V_per_class[class_idx] * proportion)  # Number to sample
    
    # If the calculated number of images to sample exceeds the available images, cap it
    if v > V_per_class[class_idx]:
        v = V_per_class[class_idx]
    # Ensure that at least one image is sampled for each class
    elif v < 1:
        v = 1
    
    # Randomly sample 'v' images from the indices of the current class
    sampled = random.sample(val_class_indices[class_idx], v)
    
    # Add the sampled indices to the list of validation sampled indices
    val_sampled_indices.extend(sampled)

# Create a subset of the validation dataset containing only the sampled indices
val_subset = Subset(val_dataset, val_sampled_indices)

# Verify the number of images per class in the validation subset
print("Validation subset images per class:")
for class_idx in range(num_classes):
    # Count the number of images in val_sampled_indices belonging to the current class
    count = len([i for i in val_sampled_indices if val_targets[i] == class_idx])
    print(f"Class {class_idx}: {count}")


Validation subset images per class:
Class 0: 1875
Class 1: 1741
Class 2: 1588
Class 3: 1862
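For class c, the validation size is v_c = int(V_c * t / T_c), clamped to the range [1, V_c]. The sketch below (hypothetical `proportional_val_sizes`) reproduces that rule. The class counts in the example are made up to illustrate the truncation and clamping; they are not the dataset's real counts.

```python
def proportional_val_sizes(train_counts, val_counts, t):
    """Validation sample size per class: the same fraction t / T_c that training keeps."""
    sizes = []
    for T_c, V_c in zip(train_counts, val_counts):
        v = int(V_c * (t / T_c))  # truncate, as the notebook does
        # Clamp to [1, V_c]; a class with no validation images gets 0 here, whereas the
        # notebook would force v = 1 and random.sample would then fail
        sizes.append(min(max(v, 1), V_c))
    return sizes


if __name__ == "__main__":
    # Made-up class counts, chosen only to illustrate the rounding and clamping
    print(proportional_val_sizes([40000, 20000, 1000000], [8000, 3001, 5], 10000))
```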


In [None]:
# Define the batch size for training and validation
batch_size = 32  # You can adjust this depending on your system's memory capacity

# Create a DataLoader for the training subset
train_loader = DataLoader(
    train_subset,  # The training data subset containing the sampled images
    batch_size=batch_size,  # Number of samples per batch
    shuffle=True,  # Shuffle the data at the start of each epoch to improve generalization
    num_workers=2,  # Number of subprocesses to load the data (adjust based on your machine's capabilities)
    pin_memory=True  # Pin the data in memory to speed up the transfer to GPU (if using a GPU)
)

# Create a DataLoader for the validation subset
val_loader = DataLoader(
    val_subset,  # The validation data subset containing the sampled images
    batch_size=batch_size,  # Number of samples per batch
    shuffle=False,  # Do not shuffle validation data as it's not necessary for evaluation
    num_workers=2,  # Number of subprocesses to load the data
    pin_memory=True  # Pin the data in memory to speed up the transfer to GPU
)
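Each DataLoader step processes up to `batch_size` images. With the default `drop_last=False`, the last partial batch is kept. The sketch below applies this arithmetic to the subset sizes printed above (`steps_per_epoch` is a hypothetical helper). The difference between images and steps matters when per-image FLOPs are scaled up to a whole epoch.

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Batches a DataLoader yields per epoch (the last partial batch is kept unless drop_last)."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


train_images = 4 * 10_000                # 4 classes x 10,000 sampled images
val_images = 1875 + 1741 + 1588 + 1862   # validation subset sizes printed above

if __name__ == "__main__":
    print("Train steps per epoch:", steps_per_epoch(train_images, 32))
    print("Val steps per epoch:", steps_per_epoch(val_images, 32))
```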


In [None]:
# Loss function: CrossEntropyLoss for multi-class classification
criterion = nn.CrossEntropyLoss()

# Mixed precision components (torch.amp replaces the deprecated torch.cuda.amp; GradScaler('cuda') needs PyTorch >= 2.3)
from torch.amp import GradScaler, autocast

# Mixed precision is only enabled on a CUDA device; on CPU both helpers become no-ops
use_amp = device.type == 'cuda'

# Initialize a GradScaler, which scales the loss so small fp16 gradients do not underflow
scaler = GradScaler('cuda', enabled=use_amp)

# Training function
def train(model, device, train_loader, optimizer, criterion):
    model.train()  # Set the model to training mode
    running_loss = 0.0  # Initialize running loss

    # Loop through the training data in batches
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)  # Move data and target to the device (GPU/CPU)

        optimizer.zero_grad()  # Clear previous gradients
        with autocast(device_type=device.type, enabled=use_amp):  # Mixed precision forward pass on GPU
            output = model(data)  # Forward pass through the model
            loss = criterion(output, target)  # Compute the loss

        # Scale the loss and backpropagate
        scaler.scale(loss).backward()

        # Unscale gradients and step (skipped if they contain inf/NaN), then adjust the scale factor
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()  # Accumulate the loss for monitoring
    # Return average loss for the epoch
    return running_loss / len(train_loader)


# Evaluation function (for validation or test set)
def evaluate(model, device, val_loader, criterion):
    model.eval()  # Set the model to evaluation mode (disables dropout, etc.)
    val_loss = 0.0  # Initialize validation loss
    correct = 0  # Initialize correct predictions counter
    total = 0  # Initialize total number of samples

    # No gradients needed during evaluation
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)  # Move data and target to the device

            output = model(data)  # Forward pass through the model
            val_loss += criterion(output, target).item()  # Compute the loss

            # Get the predicted class for each sample
            pred = output.argmax(dim=1, keepdim=True)  # Index of the highest logit (same as the highest softmax probability)
            correct += pred.eq(target.view_as(pred)).sum().item()  # Count correct predictions
            total += target.size(0)  # Count total samples

    # Compute the average validation loss
    val_loss /= len(val_loader)
    
    # Calculate accuracy as a percentage
    accuracy = 100. * correct / total
    
    # Return validation loss and accuracy
    return val_loss, accuracy
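In `train`, `scaler.scale(loss)` multiplies the loss by a scale factor before backpropagation so that small float16 gradients do not underflow to zero. `scaler.step` unscales the gradients and skips the update if they contain inf or NaN. `scaler.update` then adjusts the factor. The class below is a simplified model of that update rule, not PyTorch's implementation. Its defaults are meant to mirror `GradScaler`'s documented defaults: initial scale 2^16, growth factor 2, backoff factor 0.5, and growth interval 2000.

```python
class DynamicLossScaler:
    """Simplified model of dynamic loss scaling, the idea behind torch.amp.GradScaler.

    This illustrates the update rule only; it is not PyTorch's implementation.
    """

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps with finite gradients

    def update(self, found_inf: bool) -> bool:
        """Record one step's outcome; return True if the optimizer step should be applied."""
        if found_inf:
            # Scaled gradients overflowed: skip this update and shrink the scale
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A long run of finite gradients: try a larger scale to use more fp16 range
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True


if __name__ == "__main__":
    toy_scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
    for overflow in [True, False, False, False]:
        applied = toy_scaler.update(overflow)
        print(f"overflow={overflow} applied={applied} scale={toy_scaler.scale}")
```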


In [None]:
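# Stage 1: Train only the classifier (freeze other layers)
# Note: train() calls model.train(), so BatchNorm running statistics in the frozen layers still update
# Freeze all layers in the model by setting 'requires_grad' to False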
for param in model.parameters():
    param.requires_grad = False

# Enable gradient computation for the fully connected (classifier) layer
for param in model._fc.parameters():
    param.requires_grad = True

# Use the AdamW optimizer to train only the classifier with a small learning rate
optimizer = torch.optim.AdamW(model._fc.parameters(), lr=0.0001)

# Loop through the epochs for training and validation
for epoch in range(1, 18):  # Training for 17 epochs (1 to 17)
    # Train the model on the current epoch
    train_loss = train(model, device, train_loader, optimizer, criterion)
    
    # Evaluate the model on the validation set
    val_loss, val_accuracy = evaluate(model, device, val_loader, criterion)
    
    # Print the training and validation results for the current epoch
    print(f'Stage 1, Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')


Stage 1, Epoch 1: Train Loss: 0.3497, Val Loss: 0.2797, Val Accuracy: 89.72%
Stage 1, Epoch 2: Train Loss: 0.3466, Val Loss: 0.2790, Val Accuracy: 89.83%
Stage 1, Epoch 3: Train Loss: 0.3426, Val Loss: 0.2784, Val Accuracy: 89.82%
Stage 1, Epoch 4: Train Loss: 0.3427, Val Loss: 0.2782, Val Accuracy: 89.86%
Stage 1, Epoch 5: Train Loss: 0.3423, Val Loss: 0.2790, Val Accuracy: 89.81%
Stage 1, Epoch 6: Train Loss: 0.3392, Val Loss: 0.2774, Val Accuracy: 89.91%
Stage 1, Epoch 7: Train Loss: 0.3452, Val Loss: 0.2780, Val Accuracy: 89.85%
Stage 1, Epoch 8: Train Loss: 0.3402, Val Loss: 0.2785, Val Accuracy: 89.82%
Stage 1, Epoch 9: Train Loss: 0.3459, Val Loss: 0.2783, Val Accuracy: 89.82%
Stage 1, Epoch 10: Train Loss: 0.3410, Val Loss: 0.2775, Val Accuracy: 89.85%
Stage 1, Epoch 11: Train Loss: 0.3364, Val Loss: 0.2785, Val Accuracy: 89.85%
Stage 1, Epoch 12: Train Loss: 0.3401, Val Loss: 0.2776, Val Accuracy: 89.85%
Stage 1, Epoch 13: Train Loss: 0.3346, Val Loss: 0.2786, Val Accuracy: 89
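In Stage 1, only `model._fc` receives gradients. An `nn.Linear` layer has `in_features * out_features` weights plus `out_features` biases. The sketch below assumes that EfficientNet-B3's `_fc.in_features` is 1536. In the notebook, read the value from `model._fc.in_features` rather than relying on this constant. `linear_param_count` is a hypothetical helper.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Trainable parameters in an nn.Linear layer: a weight matrix plus an optional bias."""
    return in_features * out_features + (out_features if bias else 0)


B3_FC_IN_FEATURES = 1536  # assumption: EfficientNet-B3's _fc.in_features; check model._fc.in_features
NUM_CLASSES = 4

head_params = linear_param_count(B3_FC_IN_FEATURES, NUM_CLASSES)

if __name__ == "__main__":
    print("Trainable parameters in Stage 1:", head_params)
```

After the freezing loop, `sum(p.numel() for p in model.parameters() if p.requires_grad)` should report the same count. That is only a few thousand parameters, a small fraction of the full network, which is one plausible reason Stage 1 accuracy changes little between epochs in the log.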

In [None]:
# Stage 2: Fine-tune the entire model (unfreeze all layers)
# Unfreeze all layers by setting 'requires_grad' to True
for param in model.parameters():
    param.requires_grad = True

# Use SGD with momentum; the learning rate (0.0001) matches Stage 1, but all parameters are now updated
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.8496)

# Initialize variable to track the best validation accuracy
best_accuracy = 0.0

# Loop through the epochs for training and validation
for epoch in range(1, 11):  # Fine-tuning for 10 epochs (1 to 10)
    # Train the model on the current epoch
    train_loss = train(model, device, train_loader, optimizer, criterion)
    
    # Evaluate the model on the validation set
    val_loss, val_accuracy = evaluate(model, device, val_loader, criterion)
    
    # Print the training and validation results for the current epoch
    print(f'Stage 2, Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

    # If the current validation accuracy is better than the best seen so far, save the model
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')

# Print the final result
print(f"Training complete. Best validation accuracy: {best_accuracy:.2f}%")


Stage 2, Epoch 1: Train Loss: 0.3317, Val Loss: 0.2703, Val Accuracy: 90.07%
Stage 2, Epoch 2: Train Loss: 0.3291, Val Loss: 0.2642, Val Accuracy: 90.29%
Stage 2, Epoch 3: Train Loss: 0.3156, Val Loss: 0.2590, Val Accuracy: 90.52%
Stage 2, Epoch 4: Train Loss: 0.3079, Val Loss: 0.2530, Val Accuracy: 90.74%
Stage 2, Epoch 5: Train Loss: 0.2981, Val Loss: 0.2497, Val Accuracy: 90.87%
Stage 2, Epoch 6: Train Loss: 0.2943, Val Loss: 0.2447, Val Accuracy: 91.01%
Stage 2, Epoch 7: Train Loss: 0.2833, Val Loss: 0.2401, Val Accuracy: 91.16%
Stage 2, Epoch 8: Train Loss: 0.2823, Val Loss: 0.2367, Val Accuracy: 91.32%
Stage 2, Epoch 9: Train Loss: 0.2769, Val Loss: 0.2331, Val Accuracy: 91.46%
Stage 2, Epoch 10: Train Loss: 0.2705, Val Loss: 0.2294, Val Accuracy: 91.59%
Training complete. Best validation accuracy: 91.59%
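The checkpoint rule `val_accuracy > best_accuracy` uses a strict comparison, so a later epoch that only ties the best accuracy does not overwrite `best_model.pth`. The sketch below (hypothetical `checkpoint_epochs`) replays the rule on the Iteration 1 Stage 2 accuracies from the training log at the end of the notebook.

```python
def checkpoint_epochs(val_accuracies):
    """Replay the `val_accuracy > best_accuracy` rule: return (epochs that save, best accuracy)."""
    best = 0.0
    saved = []
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best:  # strict comparison: a tie does not overwrite the earlier checkpoint
            best = acc
            saved.append(epoch)
    return saved, best


# Iteration 1, Stage 2 validation accuracies (%) from the training log
iteration1_stage2 = [77.85, 78.92, 79.77, 79.89, 80.45, 81.08, 81.76, 82.44, 82.21, 82.44]

if __name__ == "__main__":
    saved, best = checkpoint_epochs(iteration1_stage2)
    print("Checkpoint written after epochs:", saved, "best:", best)
```

Epoch 10 ties epoch 8 at 82.44%, so the weights kept on disk are the ones from epoch 8.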


In [None]:
# 1. Load your trained model
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
model.eval()  # Set to evaluation mode

# 2. Create a data loader for your validation/test set
# (we are using the above subset of validation/test dataset)

# 3. Run inference on validation/test data
all_preds = []
all_probs = []  # For storing probability outputs
all_targets = []

with torch.no_grad():
    for images, targets in val_loader:  # Use your validation or test dataloader
        images = images.to(device)
        outputs = model(images)

        # Get predicted class
        _, preds = torch.max(outputs, 1)

        # Get probabilities
        probs = torch.nn.functional.softmax(outputs, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_probs.append(probs.cpu().numpy())
        all_targets.extend(targets.numpy())

# Convert to numpy arrays
all_preds = np.array(all_preds)
all_probs = np.vstack(all_probs)  # Stack all probability outputs
all_targets = np.array(all_targets)

# 4. Calculate metrics
# F1 score (macro - unweighted average across classes)
f1_macro = f1_score(all_targets, all_preds, average='macro')
# F1 score (weighted by support - number of true instances for each class)
f1_weighted = f1_score(all_targets, all_preds, average='weighted')

# Calculate per-class precision, recall, F1
precision, recall, f1_per_class, support = precision_recall_fscore_support(all_targets, all_preds)

# 5. Calculate mAP (Mean Average Precision)
# For multi-class problems, calculate AP for each class and then average
num_classes = all_probs.shape[1]
average_precisions = []

for i in range(num_classes):
    # Create binary labels for this class (one-vs-rest approach)
    binary_targets = (all_targets == i).astype(int)
    # Get predicted probabilities for this class
    class_probs = all_probs[:, i]
    # Calculate average precision
    ap = average_precision_score(binary_targets, class_probs)
    average_precisions.append(ap)

# Calculate mAP
mAP = np.mean(average_precisions)

# 6. Print results
print(f"F1 Score (macro): {f1_macro:.4f}")
print(f"F1 Score (weighted): {f1_weighted:.4f}")
print(f"mAP Score: {mAP:.4f}")

# 7. Print per-class metrics
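# ImageFolder assigns label indices in sorted folder-name order (see val_dataset.classes);
# these display names must be listed in that same order
class_names = ['Coccidiosis', 'Healthy', 'New Castle Disease', 'Salmonella']  # Replace with your actual class names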
print("\nPer-class metrics:")
for i in range(num_classes):
    print(f"{class_names[i]}: Precision: {precision[i]:.4f}, Recall: {recall[i]:.4f}, F1: {f1_per_class[i]:.4f}, Support: {support[i]}")

# 8. Generate confusion matrix
cm = confusion_matrix(all_targets, all_preds)
print("\nConfusion Matrix:")
print(cm)
# 9. Calculate FLOPs
# Get a sample input from train_loader (adjust shape as needed)
sample_input, _ = next(iter(train_loader))
sample_input = sample_input.to(device)

# thop.profile returns multiply-accumulate operations (MACs) for one image; 1 MAC = 2 FLOPs
macs_forward, _ = profile(model, inputs=(sample_input[:1],))  # Use batch size 1 for simplicity
flops_forward = 2 * macs_forward

# Estimate FLOPs per training image (forward + backward);
# rule of thumb: backward ≈ 2x forward, so one training pass ≈ 3x forward
flops_per_train_sample = 3 * flops_forward

# One run = Stage 1 (17 epochs) + Stage 2 (10 epochs). Treating every epoch as full
# fine-tuning gives an upper bound, since Stage 1 only backpropagates through _fc
num_epochs = 17 + 10
# Count images, not batches: each DataLoader step processes batch_size images
num_train_samples = len(train_loader.dataset)
num_val_samples = len(val_loader.dataset)
total_train_flops = num_epochs * num_train_samples * flops_per_train_sample

# Total FLOPs for evaluation (one validation pass per epoch)
total_val_flops = num_epochs * num_val_samples * flops_forward

# Total FLOPs for one training run
total_flops = total_train_flops + total_val_flops

print(f"Total FLOPs: {total_flops:.2e}")

total_tflops = total_flops / 1e12
print(f"Total TFLOPs: {total_tflops:.4f}")

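# 10. Calculate Inference Speed
# Warm up the model so one-time CUDA/cuDNN setup is not included in the timing
with torch.no_grad():
    _ = model(sample_input[:1])

total_inference_time = 0.0
total_samples = 0

with torch.no_grad():
    for images, _ in val_loader:
        images = images.to(device)
        n_images = images.size(0)  # Separate name so the global batch_size is not overwritten
        if device.type == 'cuda':
            torch.cuda.synchronize()  # Finish pending GPU work (e.g. the copy) before starting the clock
        start_time = time.perf_counter()
        _ = model(images)
        if device.type == 'cuda':
            torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait for them before stopping the clock
        total_inference_time += time.perf_counter() - start_time
        total_samples += n_images

# Compute average inference time per sample (batched throughput, not single-image latency)
avg_inference_time_per_sample = total_inference_time / total_samples
print(f"Average inference time per sample: {avg_inference_time_per_sample:.6f} seconds")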
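The mAP above is the unweighted mean of one-vs-rest average precision. For each class, AP = sum over n of (R_n - R_{n-1}) * P_n, where R_n and P_n are the recall and precision at the n-th score threshold, taken in decreasing order. The plain-Python sketch below (hypothetical `average_precision` and `mean_average_precision`) follows that definition and treats tied scores as a single threshold. It should agree with `sklearn.metrics.average_precision_score` on the same inputs, but it is only an illustration.

```python
def average_precision(binary_targets, scores):
    """AP = sum over thresholds of (R_n - R_{n-1}) * P_n, scanning scores from high to low."""
    total_pos = sum(binary_targets)
    if total_pos == 0:
        raise ValueError("average precision is undefined without positive samples")
    pairs = sorted(zip(scores, binary_targets), key=lambda p: p[0], reverse=True)
    tp = fp = 0
    prev_recall = ap = 0.0
    for idx, (score, label) in enumerate(pairs):
        if label:
            tp += 1
        else:
            fp += 1
        # Close a threshold only after every sample tied at this score has been counted
        if idx + 1 < len(pairs) and pairs[idx + 1][0] == score:
            continue
        recall = tp / total_pos
        precision = tp / (tp + fp)
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


def mean_average_precision(targets, probs):
    """One-vs-rest AP for each class, then the unweighted mean (as in the evaluation cell)."""
    num_classes = len(probs[0])
    aps = []
    for c in range(num_classes):
        binary = [1 if t == c else 0 for t in targets]
        aps.append(average_precision(binary, [row[c] for row in probs]))
    return sum(aps) / num_classes


if __name__ == "__main__":
    print(average_precision([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```

For targets `[0, 0, 1, 1]` and scores `[0.1, 0.4, 0.35, 0.8]`, the recall steps are 0.5 at precision 1 and 0.5 at precision 2/3. This gives an AP of about 0.833.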

In [None]:
# Load pre-trained EfficientNet B3
model = EfficientNet.from_pretrained('efficientnet-b3').to(device)
model._fc = nn.Linear(model._fc.in_features, num_classes).to(device)
model_path = 'best_model.pth'
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval() # Set to evaluation mode
# Calculate FLOPs
# Get a sample input from train_loader (adjust shape as needed)
sample_input, _ = next(iter(train_loader))
sample_input = sample_input.to(device)

# thop.profile returns multiply-accumulate operations (MACs) per image; 1 MAC = 2 FLOPs
macs_forward, _ = profile(model, inputs=(sample_input[:1],))  # Use batch size 1 for simplicity
flops_forward = 2 * macs_forward

# Stage 1: Classifier only (17 epochs)
# Freezing the backbone (requires_grad=False) does not change the forward computation,
# so Stage 1's forward cost equals flops_forward. Only model._fc needs gradients, so the
# backward pass is approximated as 5% of the forward cost (a rough allowance, not a measurement).
flops_per_train_step_classifier = flops_forward * 1.05  # per image

# Stage 2: Full model (10 epochs)
# Rule of thumb: backward ≈ 2x forward when all layers are trained, so a step ≈ 3x forward
flops_per_train_step_full = flops_forward * 3  # per image

# Calculate total FLOPs for both stages (per iteration)
# The per-step costs above are per image (profiled with batch size 1), so multiply by images, not batches
num_train_samples = len(train_loader.dataset)
num_val_samples = len(val_loader.dataset)

# Stage 1: 17 epochs of classifier-only training
num_epochs_stage1 = 17
total_train_flops_stage1 = num_epochs_stage1 * num_train_samples * flops_per_train_step_classifier
total_val_flops_stage1 = num_epochs_stage1 * num_val_samples * flops_forward

# Stage 2: 10 epochs of full model training
num_epochs_stage2 = 10
total_train_flops_stage2 = num_epochs_stage2 * num_train_samples * flops_per_train_step_full
total_val_flops_stage2 = num_epochs_stage2 * num_val_samples * flops_forward

# Total FLOPs for one complete iteration (Stage 1 + Stage 2)
total_flops_per_iteration = (total_train_flops_stage1 + total_val_flops_stage1 +
                             total_train_flops_stage2 + total_val_flops_stage2)

# Total FLOPs for all iterations
num_iterations = 3  # Original + 2 more iterations = 3 total
total_flops = num_iterations * total_flops_per_iteration

print(f"FLOPs per iteration (Stage 1 + Stage 2): {total_flops_per_iteration:.2e}")
print(f"Total FLOPs for {num_iterations} iterations: {total_flops:.2e}")

total_tflops = total_flops / 1e12
print(f"Total TFLOPs: {total_tflops:.4f}")

# Breakdown by stage and iteration
total_train_flops_stage1_all = num_iterations * total_train_flops_stage1
total_train_flops_stage2_all = num_iterations * total_train_flops_stage2
total_train_tflops_stage1 = total_train_flops_stage1_all / 1e12
total_train_tflops_stage2 = total_train_flops_stage2_all / 1e12

print(f"Stage 1 (Classifier) training TFLOPs ({num_iterations} iterations): {total_train_tflops_stage1:.4f}")
print(f"Stage 2 (Full model) training TFLOPs ({num_iterations} iterations): {total_train_tflops_stage2:.4f}")
print(f"Combined training TFLOPs ({num_iterations} iterations): {(total_train_tflops_stage1 + total_train_tflops_stage2):.4f}")
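The estimate in this cell reduces to a few multiplications. The sketch below (hypothetical `estimate_run_flops`) makes the assumptions explicit:

*   MACs from `thop` are doubled to get FLOPs.
*   Full fine-tuning costs about 3x the forward pass (backward ≈ 2x forward, a common rule of thumb).
*   The classifier-only backward pass uses the notebook's 5% allowance.
*   Every count is per image.

```python
def estimate_run_flops(macs_per_image, n_train, n_val, epochs_stage1, epochs_stage2,
                       stage1_backward_frac=0.05, full_backward_factor=2.0):
    """Rough FLOPs for one Stage 1 + Stage 2 run; every count is per image."""
    fwd = 2 * macs_per_image  # 1 multiply-accumulate = 2 floating-point operations
    stage1_train = epochs_stage1 * n_train * fwd * (1 + stage1_backward_frac)  # backward through _fc only
    stage2_train = epochs_stage2 * n_train * fwd * (1 + full_backward_factor)  # backward ≈ 2x forward
    validation = (epochs_stage1 + epochs_stage2) * n_val * fwd                 # one forward per val image per epoch
    return stage1_train + stage2_train + validation


if __name__ == "__main__":
    # Placeholder MAC count, not a measured value; substitute the number returned by thop.profile
    per_run = estimate_run_flops(macs_per_image=1e9, n_train=40_000, n_val=7_066,
                                 epochs_stage1=17, epochs_stage2=10)
    num_iterations = 3
    print(f"Per run: {per_run / 1e12:.2f} TFLOPs; "
          f"{num_iterations} runs: {num_iterations * per_run / 1e12:.2f} TFLOPs")
```

Multiplying by `num_iterations` assumes every iteration uses the same subset sizes and epoch counts.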

# Neural Network Training Log Summary

## Iteration 1
### Stage 1

| Epoch | Train Loss | Val Loss | Val Accuracy |
|-------|------------|----------|-------------|
| 1     | 1.2405     | 1.0812   | 68.67%      |
| 2     | 1.0297     | 0.9346   | 70.76%      |
| 3     | 0.9250     | 0.8526   | 71.78%      |
| 4     | 0.8644     | 0.8003   | 73.14%      |
| 5     | 0.8236     | 0.7646   | 74.16%      |
| 6     | 0.7912     | 0.7384   | 74.50%      |
| 7     | 0.7747     | 0.7171   | 75.07%      |
| 8     | 0.7557     | 0.7012   | 75.35%      |
| 9     | 0.7486     | 0.6900   | 76.03%      |
| 10    | 0.7248     | 0.6788   | 75.86%      |
| 11    | 0.7291     | 0.6698   | 75.98%      |
| 12    | 0.7120     | 0.6636   | 75.86%      |
| 13    | 0.7082     | 0.6560   | 75.98%      |
| 14    | 0.7063     | 0.6507   | 76.20%      |
| 15    | 0.6973     | 0.6465   | 76.54%      |
| 16    | 0.6969     | 0.6404   | 76.37%      |
| 17    | 0.6843     | 0.6380   | 76.49%      |

### Stage 2

| Epoch | Train Loss | Val Loss | Val Accuracy |
|-------|------------|----------|-------------|
| 1     | 0.6763     | 0.6054   | 77.85%      |
| 2     | 0.6431     | 0.5789   | 78.92%      |
| 3     | 0.6226     | 0.5537   | 79.77%      |
| 4     | 0.5870     | 0.5331   | 79.89%      |
| 5     | 0.5710     | 0.5161   | 80.45%      |
| 6     | 0.5565     | 0.5019   | 81.08%      |
| 7     | 0.5241     | 0.4871   | 81.76%      |
| 8     | 0.5173     | 0.4745   | 82.44%      |
| 9     | 0.5096     | 0.4623   | 82.21%      |
| 10    | 0.4844     | 0.4533   | 82.44%      |

**Best validation accuracy: 82.44%**

## Iteration 2
### Stage 1

| Epoch | Train Loss | Val Loss | Val Accuracy |
|-------|------------|----------|-------------|
| 1     | 0.5280     | 0.4486   | 83.92%      |
| 2     | 0.5166     | 0.4442   | 84.01%      |
| 3     | 0.5107     | 0.4411   | 84.09%      |
| 4     | 0.5071     | 0.4371   | 84.18%      |
| 5     | 0.5045     | 0.4346   | 84.04%      |
| 6     | 0.5032     | 0.4352   | 83.99%      |
| 7     | 0.5011     | 0.4327   | 84.15%      |
| 8     | 0.4982     | 0.4298   | 84.32%      |
| 9     | 0.4984     | 0.4282   | 84.38%      |
| 10    | 0.4981     | 0.4276   | 84.42%      |
| 11    | 0.4985     | 0.4279   | 84.49%      |
| 12    | 0.4981     | 0.4273   | 84.50%      |
| 13    | 0.4947     | 0.4254   | 84.66%      |
| 14    | 0.4943     | 0.4249   | 84.67%      |
| 15    | 0.4935     | 0.4216   | 84.67%      |
| 16    | 0.4945     | 0.4225   | 84.45%      |
| 17    | 0.4936     | 0.4224   | 84.31%      |

### Stage 2

| Epoch | Train Loss | Val Loss | Val Accuracy |
|-------|------------|----------|-------------|
| 1     | 0.4783     | 0.3891   | 85.79%      |
| 2     | 0.4450     | 0.3650   | 86.54%      |
| 3     | 0.4213     | 0.3470   | 87.22%      |
| 4     | 0.3990     | 0.3333   | 87.73%      |
| 5     | 0.3849     | 0.3183   | 88.30%      |
| 6     | 0.3702     | 0.3091   | 88.72%      |
| 7     | 0.3550     | 0.2983   | 89.03%      |
| 8     | 0.3481     | 0.2901   | 89.27%      |
| 9     | 0.3330     | 0.2826   | 89.53%      |
| 10    | 0.3278     | 0.2757   | 89.74%      |

**Best validation accuracy: 89.74%**

## Iteration 3
### Stage 1

| Epoch | Train Loss | Val Loss | Val Accuracy |
|-------|------------|----------|-------------|
| 1     | 0.3497     | 0.2797   | 89.72%      |
| 2     | 0.3466     | 0.2790   | 89.83%      |
| 3     | 0.3426     | 0.2784   | 89.82%      |
| 4     | 0.3427     | 0.2782   | 89.86%      |
| 5     | 0.3423     | 0.2790   | 89.81%      |
| 6     | 0.3392     | 0.2774   | 89.91%      |
| 7     | 0.3452     | 0.2780   | 89.85%      |
| 8     | 0.3402     | 0.2785   | 89.82%      |
| 9     | 0.3459     | 0.2783   | 89.82%      |
| 10    | 0.3410     | 0.2775   | 89.85%      |
| 11    | 0.3364     | 0.2785   | 89.85%      |
| 12    | 0.3401     | 0.2776   | 89.85%      |
| 13    | 0.3346     | 0.2786   | 89.83%      |
| 14    | 0.3372     | 0.2773   | 89.80%      |
| 15    | 0.3404     | 0.2774   | 89.88%      |
| 16    | 0.3368     | 0.2774   | 89.82%      |
| 17    | 0.3386     | 0.2773   | 89.86%      |

### Stage 2

| Epoch | Train Loss | Val Loss | Val Accuracy |
|-------|------------|----------|-------------|
| 1     | 0.2765     | 0.1854   | 93.15%      |
| 2     | 0.1544     | 0.1546   | 94.51%      |
| 3     | 0.1023     | 0.1485   | 95.09%      |
| 4     | 0.0720     | 0.1415   | 95.08%      |
| 5     | 0.0564     | 0.1552   | 95.09%      |
| 6     | 0.0431     | 0.1788   | 94.88%      |
| 7     | 0.0389     | 0.1643   | 94.92%      |
| 8     | 0.0297     | 0.1760   | 95.29%      |
| 9     | 0.0285     | 0.1839   | 95.16%      |
| 10    | 0.0277     | 0.1729   | 95.56%      |

**Best validation accuracy: 95.56%**

### Progress Summary

| Iteration | Final Val Accuracy | Improvement (percentage points) |
|-----------|-------------------|---------------------------------|
| 1         | 82.44%            | -                               |
| 2         | 89.74%            | +7.30                           |
| 3         | 95.56%            | +5.82                           |
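The Improvement column is the difference between consecutive iterations' best validation accuracies, measured in percentage points. A quick check, using a hypothetical `improvements` helper:

```python
def improvements(accuracies):
    """Change between consecutive iterations, in percentage points, rounded to 2 decimals."""
    return [round(later - earlier, 2) for earlier, later in zip(accuracies, accuracies[1:])]


best_per_iteration = [82.44, 89.74, 95.56]  # Final Val Accuracy column above

if __name__ == "__main__":
    print(improvements(best_per_iteration))
```

This gives +7.30 and +5.82 percentage points.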