In [None]:
# region 1. Setup Environment & Helper Functions
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split, Subset, WeightedRandomSampler
from torchvision import datasets, models, transforms
from torchvision.models import EfficientNet_V2_L_Weights
from torch.utils.checkpoint import checkpoint
import matplotlib.pyplot as plt
import time
import copy
from tqdm.auto import tqdm
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
import sys
import warnings
import os

warnings.filterwarnings("ignore")

# --- HARDWARE SETUP (Silent Optimization) ---
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The precision choice is made silently in the background (nothing is printed).
# CUDA autocast only accepts float16 and bfloat16 as its target dtype, so float8
# is deliberately not a candidate here: selecting it would leave autocast
# unusable, and the blanket warning filter above would hide that fact.
# The bf16 capability query is only made when a CUDA device is actually present.
if DEVICE.type == 'cuda' and torch.cuda.is_bf16_supported():
    AMP_DTYPE = torch.bfloat16
    USE_SCALER = False  # bfloat16 keeps float32's exponent range; no loss scaling needed
else:
    AMP_DTYPE = torch.float16
    # Loss scaling only matters when float16 autocast really runs, i.e. on CUDA.
    USE_SCALER = DEVICE.type == 'cuda'

print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

class CheckpointWrapper(nn.Module):
    """Route a module's forward pass through ``torch.utils.checkpoint``.

    The wrapped module is stored as the ``module`` attribute, so its
    parameters show up in ``state_dict()`` under a ``module.`` prefix.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """Call the wrapped module on ``x`` via non-reentrant checkpointing."""
        wrapped = self.module
        return checkpoint(wrapped, x, use_reentrant=False)

class TransformedSubset(torch.utils.data.Dataset):
    """Dataset view that applies ``transform`` to the sample of each item.

    ``subset`` must support ``len()`` and integer indexing that yields
    ``(sample, target)`` pairs. Only the sample is transformed; the target is
    passed through untouched. A falsy ``transform`` disables transformation.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

def train_model(model, criterion, optimizer, scheduler, dataloaders, device, num_epochs, dataset_sizes, phase_name="Training"): 
    """Run a train/validation loop and return the model with its best weights.

    Each epoch runs a 'train' pass followed by a 'val' pass. Whenever the
    validation accuracy improves, the weights are snapshotted in memory and
    also written to ``best_model_<phase_name>_temp.pth`` (spaces replaced by
    underscores, lower-cased) in the current working directory.

    Reads the module-level ``AMP_DTYPE`` and ``USE_SCALER`` settings.

    Args:
        model: Network to optimise; already placed on ``device``.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch, or a falsy value to
            disable scheduling.
        dataloaders: Mapping with 'train' and 'val' iterables yielding
            ``(inputs, labels)`` batches.
        device: ``torch.device`` the batches are moved to; autocast is only
            enabled when its type is 'cuda'.
        num_epochs: Number of epochs to run.
        dataset_sizes: Mapping with the sample count for 'train' and 'val',
            used to average loss and accuracy.
        phase_name: Label used in log output and in the checkpoint filename.

    Returns:
        ``(model, history)`` where ``history`` maps 'train_loss', 'train_acc',
        'val_loss' and 'val_acc' to per-epoch lists of floats.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    use_amp = device.type == 'cuda'
    # A CUDA gradient scaler is pointless (and noisy) when training on CPU.
    scaler = torch.amp.GradScaler('cuda') if (USE_SCALER and use_amp) else None
    checkpoint_path = f'best_model_{phase_name.replace(" ", "_").lower()}_temp.pth'

    print(f"\n--- Memulai {phase_name} ({num_epochs} Epochs) ---")

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is equivalent to eval()
            running_loss, running_corrects = 0.0, 0

            # Transient progress bar (removed once the phase finishes)
            pbar = tqdm(dataloaders[phase], desc=f"{phase.capitalize()}", leave=False)

            for inputs, labels in pbar:
                inputs, labels = inputs.to(device), labels.to(device)
                if is_train:
                    optimizer.zero_grad()

                # Autocast is applied silently, according to hardware capability
                with torch.set_grad_enabled(is_train), \
                     torch.autocast(device_type=device.type, dtype=AMP_DTYPE, enabled=use_amp):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                if is_train:
                    if scaler is not None:
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        loss.backward()
                        optimizer.step()

                # Pull the batch statistics to Python scalars once and reuse them
                # for both the running totals and the progress bar.
                batch_size = inputs.size(0)
                batch_loss = loss.item()
                batch_corrects = (preds == labels).sum().item()
                running_loss += batch_loss * batch_size
                running_corrects += batch_corrects

                # Live per-batch metrics on the progress bar
                pbar.set_postfix({'loss': f'{batch_loss:.4f}', 'acc': f'{batch_corrects / batch_size:.4f}'})

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(f"  {phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            if is_train:
                history['train_loss'].append(epoch_loss)
                history['train_acc'].append(epoch_acc)
            else:
                history['val_loss'].append(epoch_loss)
                history['val_acc'].append(epoch_acc)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(model.state_dict(), checkpoint_path)

        if scheduler: scheduler.step()

    time_elapsed = time.time() - since
    print(f"Selesai: {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s | Best Acc: {best_acc:.4f}")
    try:
        model.load_state_dict(best_model_wts)
    except RuntimeError as err:
        # Stay best-effort (the caller still gets a usable model), but never
        # hide the fact that the returned weights are not the best ones.
        print(f"Peringatan: gagal memuat bobot terbaik ({err}); model memakai bobot epoch terakhir.")
    return model, history
# endregion

In [None]:
# region 2. Config & Data Loading
DATASET_PATH = 'Dataset'
# NOTE(review): IMG_SIZE is no longer consumed by the transforms below (see the
# note on train_transforms); kept defined in case other cells reference it.
IMG_SIZE = 224

# Batch size (lower these if you hit out-of-memory errors)
BATCH_SIZE_EXTRACT = 64
BATCH_SIZE_TUNE = 8

EPOCHS_FEATURE_EXTRACT = 10
EPOCHS_FINE_TUNE = 20

VALIDATION_SPLIT = 0.2
LR_FEATURE_EXTRACT = 1e-2
LR_FINE_TUNE = 1e-3
WEIGHT_DECAY = 1e-2
LABEL_SMOOTHING = 0.1

PLOT_FILENAME = 'training_results.png'
CONFUSION_MATRIX_FILENAME = 'confusion_matrix.png'
BEST_MODEL_EXTRACT_PATH = 'best_model_extract.pth'
MODEL_SAVE_PATH = 'citrus_efficientnetv2l_final.pth'

print("Mempersiapkan Data...")
weights = EfficientNet_V2_L_Weights.DEFAULT
preprocess = weights.transforms(antialias=True)

# The weights' bundled preset performs its own resize + centre crop (480 px for
# EfficientNetV2-L according to the torchvision docs -- TODO confirm for the
# installed version). Shrinking and cropping the image first, only for the
# training split, threw away detail and border content that the validation
# split kept, so both splits now go through exactly the same preprocessing.
train_transforms = preprocess
val_transforms = preprocess

# Only a missing/unreadable/empty dataset folder is turned into a friendly
# message + exit; any other failure below surfaces with its full traceback.
try:
    full_dataset = datasets.ImageFolder(DATASET_PATH, transform=None)
except (OSError, RuntimeError) as e:
    print(f"Error Dataset: {e}")
    sys.exit(1)

CLASSES = sorted(full_dataset.classes)
NUM_CLASSES = len(CLASSES)
print(f"Dataset: {len(full_dataset)} images | {NUM_CLASSES} classes")

# Inverse-frequency class weights, used to rebalance training batches.
class_counts = np.bincount(full_dataset.targets)
class_weights = [len(full_dataset) / c for c in class_counts]

train_size = int((1 - VALIDATION_SPLIT) * len(full_dataset))
val_size = len(full_dataset) - train_size
# Split the dataset directly; the fixed seed keeps the split reproducible.
train_subset, val_subset = random_split(
    full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)
train_indices, val_indices = train_subset.indices, val_subset.indices

train_targets = [full_dataset.targets[i] for i in train_indices]
sample_weights = [class_weights[t] for t in train_targets]
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(train_indices), replacement=True)

train_dataset = TransformedSubset(train_subset, train_transforms)
val_dataset = TransformedSubset(val_subset, val_transforms)

num_workers = 0
# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only machines.
pin_memory = torch.cuda.is_available()


def _build_loaders(batch_size):
    """Return the {'train', 'val'} DataLoader pair for the given batch size."""
    return {
        'train': DataLoader(train_dataset, batch_size=batch_size, sampler=sampler,
                            num_workers=num_workers, pin_memory=pin_memory),
        'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, pin_memory=pin_memory),
    }


dataloaders_extract = _build_loaders(BATCH_SIZE_EXTRACT)
dataloaders_tune = _build_loaders(BATCH_SIZE_TUNE)
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = CLASSES
# endregion

In [None]:
# region 3. Feature Extraction
print("\n=== TAHAP 1: FEATURE EXTRACTION ===")
model_extract = models.efficientnet_v2_l(weights=weights)

# Manual gradient checkpointing: wrap each top-level entry of the backbone.
# Iterating over a snapshot so the container is not mutated while being walked.
for stage_idx, stage in enumerate(list(model_extract.features)):
    model_extract.features[stage_idx] = CheckpointWrapper(stage)

# Freeze the backbone so that only the new classifier head is trained.
model_extract.features.requires_grad_(False)

# Replace the classifier with a fresh head sized for this dataset's classes.
num_ftrs = model_extract.classifier[1].in_features
model_extract.classifier = nn.Sequential(
    nn.Dropout(p=0.4, inplace=True),
    nn.Linear(num_ftrs, NUM_CLASSES),
)
model_extract = model_extract.to(DEVICE)

criterion_extract = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
# Only the classifier parameters are handed to the optimizer.
optimizer_extract = optim.AdamW(
    model_extract.classifier.parameters(),
    lr=LR_FEATURE_EXTRACT,
    weight_decay=WEIGHT_DECAY,
)
scheduler_extract = optim.lr_scheduler.CosineAnnealingLR(
    optimizer_extract, T_max=EPOCHS_FEATURE_EXTRACT
)

model_extract, history_extract = train_model(
    model_extract,
    criterion_extract,
    optimizer_extract,
    scheduler_extract,
    dataloaders_extract,
    DEVICE,
    EPOCHS_FEATURE_EXTRACT,
    dataset_sizes,
    phase_name="Feature Extraction",
)

# Persist the stage-1 weights for the fine-tuning stage.
torch.save(model_extract.state_dict(), BEST_MODEL_EXTRACT_PATH)
# endregion

In [None]:
# region 4. Fine Tuning
print("\n=== TAHAP 2: FINE-TUNING ===")
model_tune = models.efficientnet_v2_l(weights=None)
num_ftrs_tune = model_tune.classifier[1].in_features
model_tune.classifier = nn.Sequential(
    nn.Dropout(p=0.4, inplace=True),
    nn.Linear(num_ftrs_tune, NUM_CLASSES)
)

# Re-apply the checkpoint wrappers before loading: the saved state dict was
# produced by a wrapped model, so parameter names only match once the same
# wrapping is in place here.
for stage_idx, stage in enumerate(list(model_tune.features)):
    model_tune.features[stage_idx] = CheckpointWrapper(stage)

# Prefer the stage-1 checkpoint on disk; fall back to the in-memory stage-1
# model only when the file cannot be read. A key/shape mismatch is a real bug
# and is allowed to raise instead of being silently swallowed.
try:
    stage1_state = torch.load(BEST_MODEL_EXTRACT_PATH, map_location='cpu', weights_only=True)
except OSError as err:
    print(f"Gagal membaca {BEST_MODEL_EXTRACT_PATH} ({err}); memakai bobot model_extract dari memori.")
    model_tune.load_state_dict(model_extract.state_dict())
else:
    model_tune.load_state_dict(stage1_state)
    print("Bobot dimuat.")

# Unfreeze all layers
model_tune.requires_grad_(True)

model_tune = model_tune.to(DEVICE)

optimizer_tune = optim.AdamW(model_tune.parameters(), lr=LR_FINE_TUNE, weight_decay=WEIGHT_DECAY)
scheduler_tune = optim.lr_scheduler.CosineAnnealingLR(optimizer_tune, T_max=EPOCHS_FINE_TUNE)

model_final, history_tune = train_model(
    model_tune, criterion_extract, optimizer_tune, scheduler_tune, dataloaders_tune, DEVICE, 
    EPOCHS_FINE_TUNE, dataset_sizes, phase_name="Fine-Tuning"
)

# Merge history: stage-1 curves (when that stage was run in this session)
# followed by the fine-tuning curves. Without stage-1 history the fine-tuning
# curves are still kept rather than dropped.
stage1_history = globals().get('history_extract', {})
combined_history = {
    key: stage1_history.get(key, []) + history_tune.get(key, [])
    for key in ('train_loss', 'train_acc', 'val_loss', 'val_acc')
}

torch.save(model_final.state_dict(), MODEL_SAVE_PATH)
print(f"Model tersimpan: {MODEL_SAVE_PATH}")
# endregion

In [None]:
# region 5. Visualization & Evaluation
print("\nMembuat Visualisasi...")
if combined_history and combined_history.get('train_acc'):
    acc = combined_history['train_acc']
    val_acc = combined_history['val_acc']
    loss = combined_history['train_loss']
    val_loss = combined_history['val_loss']
    epochs_range = range(1, len(acc) + 1)

    # x position halfway between the last stage-1 epoch and the first
    # fine-tuning epoch; only drawn when stage-1 history is available.
    stage_boundary = None
    if 'history_extract' in globals():
        stage_boundary = len(history_extract['train_acc']) + 0.5

    plt.figure(figsize=(14, 5))

    # (title, legend location, [(series, label), ...]) for each subplot
    panels = [
        ('Accuracy', 'lower right', [(acc, 'Train Acc'), (val_acc, 'Val Acc')]),
        ('Loss', 'upper right', [(loss, 'Train Loss'), (val_loss, 'Val Loss')]),
    ]
    for position, (title, legend_loc, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for series, label in curves:
            plt.plot(epochs_range, series, label=label)
        if stage_boundary is not None:
            plt.axvline(stage_boundary, color='grey', linestyle='--', label='Fine-Tuning Start')
        plt.legend(loc=legend_loc)
        plt.title(title)

    plt.savefig(PLOT_FILENAME)
    plt.show()

print("\nMembuat Confusion Matrix...")
if 'model_final' in globals():
    model_final.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloaders_tune['val'], desc="Evaluasi"):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            with torch.autocast(device_type=DEVICE.type, dtype=AMP_DTYPE, enabled=(DEVICE.type == 'cuda')):
                outputs = model_final(inputs)

            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Pin the label set explicitly: without it the matrix only covers classes
    # that actually occur in the validation data, and the class_names tick
    # labels would no longer line up with the rows/columns.
    cf_matrix = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(CONFUSION_MATRIX_FILENAME)
    plt.show()

print("\n--- Selesai ---")
# endregion