# In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchmetrics import Accuracy, Precision, Recall, F1Score
import matplotlib.pyplot as plt
import numpy as np
import time
from tqdm import tqdm
import os
from collections import defaultdict
import logging
from datetime import datetime

# ===== Logging Setup =====
# Each run gets its own timestamped log file; the same records are also sent
# to a StreamHandler (which writes to stderr by default).
_log_file_handler = logging.FileHandler(
    f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
)
_console_handler = logging.StreamHandler()

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[_log_file_handler, _console_handler],
)
logger = logging.getLogger(__name__)

# ===== Configuration =====
# Filesystem locations
DATA_DIR = "resnet_dataset_corrected"
MODEL_SAVE_PATH = "resnet50_wildlife.pth"

# Training hyper-parameters
BATCH_SIZE = 8
EPOCHS = 15
LR = 0.001
NUM_WORKERS = 4

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _cuda_available else "cpu")

# Emit the run banner followed by one line per setting.
for _setup_line in (
    f"\n{'='*50}",
    "ResNet50 Training Setup",
    f"{'='*50}",
    f"Device: {DEVICE}",
    f"Batch size: {BATCH_SIZE}",
    f"Epochs: {EPOCHS}",
    f"Learning rate: {LR}",
    f"Data directory: {DATA_DIR}",
):
    logger.info(_setup_line)

# ===== Data Augmentation =====
# Per-channel statistics passed to Normalize in both pipelines
# (presumably the ImageNet mean/std the pretrained weights expect — confirm).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps that end both pipelines."""
    return [transforms.ToTensor(), transforms.Normalize(_NORM_MEAN, _NORM_STD)]


# Training pipeline: random crop/flip/colour/rotation, then tensor + normalize.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomRotation(15),
]
_train_steps.extend(_tensor_and_normalize())
train_transform = transforms.Compose(_train_steps)

# Validation pipeline: deterministic resize + centre crop, then tensor + normalize.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
]
_val_steps.extend(_tensor_and_normalize())
val_transform = transforms.Compose(_val_steps)

# ===== Dataset Preparation =====
logger.info("\n[1/4] Loading datasets...")
train_dataset = ImageFolder(f"{DATA_DIR}/train", transform=train_transform)
val_dataset = ImageFolder(f"{DATA_DIR}/val", transform=val_transform)

# Each ImageFolder builds its own class_to_idx from the sub-folders of its
# root. If a validation folder maps to a different index than the same-named
# training folder (or has no training counterpart), labels would be silently
# misaligned during evaluation, so fail fast instead.
_misaligned_classes = sorted(
    name
    for name, idx in val_dataset.class_to_idx.items()
    if train_dataset.class_to_idx.get(name) != idx
)
if _misaligned_classes:
    raise ValueError(
        "Validation class folders do not map to the same label indices as "
        f"the training set: {_misaligned_classes}"
    )

# Class weights for imbalance
# Convert the target list to a tensor once and reuse it for both the
# per-class counts and the per-sample weight lookup.
_train_targets = torch.as_tensor(train_dataset.targets, dtype=torch.long)
class_counts = torch.bincount(_train_targets)
# Inverse-frequency weight per class, scaled by num_classes / 2.
class_weights = (1. / class_counts.float()) * len(class_counts) / 2.0
# One weight per training sample, looked up by that sample's class index.
weights = class_weights[_train_targets]
# Draws len(train_dataset) samples per epoch with probability proportional
# to the per-sample weight.
sampler = WeightedRandomSampler(weights, len(train_dataset))

logger.info(f"Train samples: {len(train_dataset)}")
logger.info(f"Validation samples: {len(val_dataset)}")
logger.info(f"Classes: {train_dataset.classes}")

# ===== Data Loaders =====
# Settings common to both loaders.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Training batches are drawn through the weighted sampler; validation is
# iterated in dataset order.
train_loader = DataLoader(train_dataset, sampler=sampler, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

# ===== Model Setup =====
logger.info("\n[2/4] Initializing model...")
_pretrained_weights = models.ResNet50_Weights.IMAGENET1K_V2
model = models.resnet50(weights=_pretrained_weights)

# Swap the stock classifier for a two-layer head: hidden layer of 512 units,
# ReLU, dropout, then one output per training class.
num_features = model.fc.in_features
_head_layers = [
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, len(train_dataset.classes)),
]
model.fc = nn.Sequential(*_head_layers)
model = model.to(DEVICE)

# ===== Training Setup =====
_num_classes = len(train_dataset.classes)

# NOTE(review): class_weights also drive the WeightedRandomSampler that feeds
# the training loader, so minority classes are up-weighted twice (once by
# sampling, once in the loss). Confirm this double compensation is intended.
criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# mode='max' because the scheduler is stepped with validation accuracy.
# The `verbose` argument is intentionally omitted: it is deprecated in recent
# PyTorch releases (warns, and may be rejected by newer ones). The current
# learning rate is already logged every epoch by the training loop.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=2
)

# Metrics
# Precision, recall and F1 are macro-averaged over classes; all metric
# objects live on the same device as the model outputs.
metrics = {
    'accuracy': Accuracy(task='multiclass', num_classes=_num_classes).to(DEVICE),
    'precision': Precision(task='multiclass', average='macro', num_classes=_num_classes).to(DEVICE),
    'recall': Recall(task='multiclass', average='macro', num_classes=_num_classes).to(DEVICE),
    'f1': F1Score(task='multiclass', average='macro', num_classes=_num_classes).to(DEVICE),
}

# ===== Training Loop =====
# For each epoch: one pass over the training loader, one pass over the
# validation loader, metric computation, LR scheduling on validation
# accuracy, and a checkpoint whenever validation accuracy improves.
logger.info("\n[3/4] Starting training...")
best_val_acc = 0.0
history = defaultdict(list)
start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start = time.time()

    # Training phase
    model.train()
    train_loss = 0.0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch average is
        # correct even when the last batch is smaller.
        train_loss += loss.item() * inputs.size(0)

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_outputs = []
    val_targets = []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]"):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * inputs.size(0)
            val_outputs.append(outputs)
            val_targets.append(labels)

    # Calculate metrics
    train_loss = train_loss / len(train_dataset)
    val_loss = val_loss / len(val_dataset)

    val_outputs = torch.cat(val_outputs)
    val_targets = torch.cat(val_targets)

    # Convert every metric to a plain Python float once. This keeps
    # best_val_acc a float for the whole run (it used to be rebound to a
    # device tensor) and gives the scheduler, history and log lines one
    # consistent scalar type.
    val_acc = metrics['accuracy'](val_outputs, val_targets).item()
    val_precision = metrics['precision'](val_outputs, val_targets).item()
    val_recall = metrics['recall'](val_outputs, val_targets).item()
    val_f1 = metrics['f1'](val_outputs, val_targets).item()

    scheduler.step(val_acc)

    # Save history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_precision'].append(val_precision)
    history['val_recall'].append(val_recall)
    history['val_f1'].append(val_f1)

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        logger.info(f"✓ New best model saved with val_acc: {val_acc:.4f}")

    # Epoch summary
    epoch_time = time.time() - epoch_start
    logger.info(f"\nEpoch {epoch+1} Summary:")
    logger.info(f"Time: {epoch_time:.1f}s | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    logger.info(f"Val Accuracy: {val_acc:.4f} | Precision: {val_precision:.4f} | Recall: {val_recall:.4f} | F1: {val_f1:.4f}")
    logger.info(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")

# ===== Training Complete =====
# Closing summary: completion marker, best validation accuracy seen during
# training, and the checkpoint path.
for _done_line in (
    "\n[4/4] Training complete!",
    f"Best validation accuracy: {best_val_acc:.4f}",
    f"Model saved to: {MODEL_SAVE_PATH}",
):
    logger.info(_done_line)

# ===== Plotting =====
# Two side-by-side panels: loss curves on the left, validation metrics on
# the right. The x axis is the epoch index of each history entry.
fig, (loss_ax, metrics_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Loss plot
for _key, _label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    loss_ax.plot(history[_key], label=_label)
loss_ax.set_title('Training and Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

# Accuracy plot
for _key, _label in (
    ('val_acc', 'Accuracy'),
    ('val_precision', 'Precision'),
    ('val_recall', 'Recall'),
    ('val_f1', 'F1 Score'),
):
    metrics_ax.plot(history[_key], label=_label)
metrics_ax.set_title('Validation Metrics')
metrics_ax.set_xlabel('Epoch')
metrics_ax.set_ylabel('Score')
metrics_ax.legend()

plt.tight_layout()
plt.savefig('training_metrics.png')
logger.info("✓ Training metrics plot saved to training_metrics.png")

# ===== Final Report =====
# Wall-clock time since start_time, the best validation accuracy, and the
# number of training samples per class.
logger.info("\nFinal Report:")
logger.info(f"Total training time: {time.time() - start_time:.1f} seconds")
logger.info(f"Best validation accuracy: {best_val_acc:.4f}")
logger.info("\nClass distribution:")
# Count all classes in a single pass instead of rebuilding the target tensor
# and rescanning it once per class. minlength guarantees one entry per class
# name so the zip below covers every class.
_samples_per_class = torch.bincount(
    torch.as_tensor(train_dataset.targets, dtype=torch.long),
    minlength=len(train_dataset.classes),
).tolist()
for class_name, _n_samples in zip(train_dataset.classes, _samples_per_class):
    logger.info(f"{class_name}: {_n_samples} samples")
