# In [None]:
# region 1. Setup & Training Function
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split, Subset, WeightedRandomSampler
from torchvision import datasets, models, transforms
from torchvision.models import EfficientNet_V2_S_Weights
from torch.utils.checkpoint import checkpoint
import matplotlib.pyplot as plt
import time
import copy
from tqdm.auto import tqdm
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
import sys
import warnings
import os
import gc

warnings.filterwarnings("ignore")

# Pick the compute device and the autocast dtype used throughout the script.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if DEVICE.type == 'cuda':
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    # Prefer bfloat16; fall back to float16 on GPUs without bf16 support.
    AMP_DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    DEVICE_NAME = torch.cuda.get_device_name(0)
else:
    # AMP_DTYPE is referenced unconditionally later (training and evaluation
    # autocast), so it must exist on CPU too; CPU autocast supports bfloat16.
    AMP_DTYPE = torch.bfloat16
    DEVICE_NAME = "CPU"

print(f"Device: {DEVICE_NAME} | Prec: {AMP_DTYPE}")

class CheckpointWrapper(nn.Module):
    """Run a wrapped module's forward pass under activation checkpointing.

    Uses the non-reentrant ``torch.utils.checkpoint.checkpoint`` variant: the
    wrapped module's intermediate activations are recomputed for the backward
    pass instead of being stored, trading compute for memory. The child is
    registered under the attribute name ``module``, so its parameters show up
    in ``state_dict()`` as ``module.<param>``.
    """

    def __init__(self, module):
        super(CheckpointWrapper, self).__init__()
        self.module = module

    def forward(self, x):
        inner = self.module
        return checkpoint(inner, x, use_reentrant=False)

class TransformedSubset(torch.utils.data.Dataset):
    """Dataset view that applies ``transform`` to each sample of ``subset``.

    Lets two splits of the same base dataset (here: train and val) use
    different transforms. Items of ``subset`` must be ``(sample, target)``
    pairs; the target is passed through untouched. A falsy ``transform``
    returns samples unchanged.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        return (self.transform(sample) if self.transform else sample), target

def _cpu_state_dict(model):
    """Return an independent CPU copy of ``model.state_dict()``."""
    return {k: v.to('cpu', copy=True) for k, v in model.state_dict().items()}


def train_model(model, criterion, optimizer, scheduler, dataloaders, device, num_epochs, dataset_sizes,
                phase_name="Training", accumulation_steps=1, amp_dtype=None):
    """Train ``model`` with per-epoch validation and return the best-validation weights.

    Each epoch runs a ``'train'`` pass followed by a ``'val'`` pass. Training uses
    autocast, gradient accumulation over ``accumulation_steps`` micro-batches and
    gradient-norm clipping at 1.0. Batches whose loss is not finite are skipped
    (which also discards gradients accumulated so far in the current window).
    Whenever validation accuracy improves, a CPU snapshot of the weights is kept
    and a copy is written to ``best_<first 4 chars of phase_name, lowercased>.pth``.

    Args:
        model: Network to train (already on ``device``).
        criterion: Loss callable ``criterion(outputs, labels)``.
        optimizer: Optimizer over the trainable parameters.
        scheduler: LR scheduler stepped once per epoch, or ``None``.
        dataloaders: Dict with ``'train'``/``'val'`` loaders yielding
            ``(inputs, labels)``. Inputs are moved with ``channels_last``
            memory format, so they must be 4-D.
        device: ``torch.device`` to run on.
        num_epochs: Number of epochs.
        dataset_sizes: Dict with ``'train'``/``'val'`` sample counts used to
            average loss and accuracy.
        phase_name: Label for logging and the checkpoint filename.
        accumulation_steps: Micro-batches per optimizer step.
        amp_dtype: Autocast dtype; ``None`` uses the module-level ``AMP_DTYPE``.

    Returns:
        ``(model, history)``. ``model`` holds the best-validation weights.
        ``history`` has per-epoch float lists ``train_loss``, ``train_acc``,
        ``val_loss``, ``val_acc`` and ``lr``.
    """
    if amp_dtype is None:
        amp_dtype = AMP_DTYPE

    # float16 autocast without loss scaling lets small gradients underflow, so
    # use a GradScaler for that case only. When disabled, scale/unscale_/step/
    # update are pass-throughs to a plain backward() / optimizer.step().
    use_scaler = device.type == 'cuda' and amp_dtype == torch.float16
    if hasattr(torch.amp, 'GradScaler'):
        scaler = torch.amp.GradScaler('cuda', enabled=use_scaler)
    else:  # older PyTorch releases
        scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)

    since = time.time()
    # Keep the best-weights snapshot on the CPU so it does not double GPU memory.
    best_model_wts = _cpu_state_dict(model)
    best_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': []}
    ckpt_path = f'best_{phase_name.lower()[:4]}.pth'

    print(f"\n--- {phase_name} ({num_epochs} Epochs) ---")

    for epoch in range(num_epochs):
        current_lr = optimizer.param_groups[0]['lr']
        history['lr'].append(current_lr)
        print(f"Epoch {epoch+1}/{num_epochs} [LR: {current_lr:.1e}]", end="")

        for phase in ('train', 'val'):
            is_train = phase == 'train'
            model.train(is_train)
            loader = dataloaders[phase]
            num_batches = len(loader)
            running_loss, running_corrects = 0.0, 0

            optimizer.zero_grad(set_to_none=True)

            for i, (inputs, labels) in enumerate(loader):
                inputs = inputs.to(device, memory_format=torch.channels_last, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                with torch.set_grad_enabled(is_train), \
                     torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=True):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                # Skip NaN *and* inf losses: either would poison the gradients.
                if not torch.isfinite(loss):
                    optimizer.zero_grad(set_to_none=True)
                    continue

                if is_train:
                    scaler.scale(loss / accumulation_steps).backward()
                    if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
                        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad(set_to_none=True)

                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            # max(..., 1) keeps an empty split from raising ZeroDivisionError.
            n_samples = max(dataset_sizes[phase], 1)
            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects / n_samples

            if is_train:
                print(f" | T-Loss: {epoch_loss:.3f} Acc: {epoch_acc:.3f}", end="")
                history['train_loss'].append(epoch_loss)
                history['train_acc'].append(epoch_acc)
            else:
                print(f" | V-Loss: {epoch_loss:.3f} Acc: {epoch_acc:.3f}")
                history['val_loss'].append(epoch_loss)
                history['val_acc'].append(epoch_acc)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = _cpu_state_dict(model)
                    try:
                        torch.save(model.state_dict(), ckpt_path)
                    except (OSError, RuntimeError) as err:
                        # Best-effort on-disk copy; the in-memory snapshot is what gets restored.
                        print(f"[warn] could not save {ckpt_path}: {err}")

        if scheduler is not None:
            scheduler.step()
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    time_elapsed = time.time() - since
    print(f"Done in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s. Best Acc: {best_acc:.4f}")
    try:
        model.load_state_dict(best_model_wts)
    except RuntimeError as err:
        print(f"[warn] could not restore best weights, returning last-epoch model: {err}")
    return model, history
# endregion

# In [None]:
# region 2. Config & Data Loading
# Data source and input size.
DATASET_PATH = 'Dataset'  # root directory handed to datasets.ImageFolder in prepare_data
IMG_SIZE = 224            # crop size used by the input transforms in prepare_data

# Stage switches.
RUN_FEATURE_EXTRACTION, RUN_FINE_TUNING = True, True

# Micro-batching: train_model accumulates gradients over ACCUM_STEPS_* micro-batches
# of MICRO_BATCH_SIZE images before each optimizer step.
MICRO_BATCH_SIZE = 8
ACCUM_STEPS_EXTRACT = ACCUM_STEPS_TUNE = 4

# Epochs per stage.
EPOCHS_FEATURE_EXTRACT, EPOCHS_FINE_TUNE = 20, 30

# Split and optimisation hyper-parameters.
VALIDATION_SPLIT = 0.2    # fraction of the images held out for validation
LR_FEATURE_EXTRACT = 1e-2
LR_FINE_TUNE = 1e-4
WEIGHT_DECAY = 1e-4
LABEL_SMOOTHING = 0.1

# Output files.
PLOT_FILENAME = 'training_results.png'
CONF_MATRIX_FILE = 'confusion_matrix.png'
BEST_FE_PATH = 'best_feat.pth'             # written by feature extraction, read by fine-tuning
FINAL_MODEL_PATH = 'citrus_v2s_final.pth'

def prepare_data():
    """Load the ImageFolder dataset, split it, and build class-balanced DataLoaders.

    Images are read from ``DATASET_PATH`` and split into train/val by
    ``VALIDATION_SPLIT`` with a fixed seed (42). Training batches are drawn by a
    ``WeightedRandomSampler`` whose per-sample weight is inversely proportional
    to the sample's class frequency. Both splits end as ``IMG_SIZE`` x
    ``IMG_SIZE`` crops normalised with the pretrained weights' preset.

    Returns:
        tuple: ``(dataloaders_extract, dataloaders_tune, dataset_sizes, CLASSES,
        NUM_CLASSES, weights)``. The two loader dicts have ``'train'``/``'val'``
        keys and are separate but identically configured DataLoader objects.
        ``dataset_sizes`` maps the same keys to sample counts. ``weights`` is
        ``EfficientNet_V2_S_Weights.DEFAULT``.

    Exits the process with status 1, after printing the error and traceback,
    if dataset loading fails.
    """
    import traceback

    weights = EfficientNet_V2_S_Weights.DEFAULT
    # The preset ``weights.transforms()`` resizes and crops to the weights' own
    # evaluation resolution (384 for EfficientNet_V2_S per the torchvision docs),
    # which would silently override IMG_SIZE and give train and val different
    # fields of view. Override both sizes so that both splits are "resize to
    # IMG_SIZE + 32, centre-crop IMG_SIZE".
    normalize_tail = weights.transforms(crop_size=IMG_SIZE, resize_size=IMG_SIZE, antialias=True)
    val_transforms = weights.transforms(crop_size=IMG_SIZE, resize_size=IMG_SIZE + 32, antialias=True)

    train_transforms = transforms.Compose([
        transforms.Resize(IMG_SIZE + 32),
        transforms.CenterCrop(IMG_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        # The image is already IMG_SIZE x IMG_SIZE here, so this preset's
        # resize/crop keep the size; it contributes tensor conversion and
        # normalisation.
        normalize_tail,
    ])

    try:
        full_dataset = datasets.ImageFolder(DATASET_PATH, transform=None)
        CLASSES = sorted(full_dataset.classes)
        NUM_CLASSES = len(CLASSES)

        print(f"Data: {len(full_dataset)} imgs | Classes: {NUM_CLASSES}")
        print(f"Batch: {MICRO_BATCH_SIZE} (Accum: {ACCUM_STEPS_EXTRACT}/{ACCUM_STEPS_TUNE})")

        # Inverse-frequency class weights for balanced sampling of the train split.
        class_counts = np.bincount(full_dataset.targets, minlength=NUM_CLASSES)
        class_weights = [len(full_dataset) / c if c else 0.0 for c in class_counts]

        train_size = int((1 - VALIDATION_SPLIT) * len(full_dataset))
        val_size = len(full_dataset) - train_size
        train_indices, val_indices = random_split(
            range(len(full_dataset)), [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )

        sample_weights = [class_weights[full_dataset.targets[i]] for i in train_indices]
        sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(train_indices), replacement=True)

        train_dataset = TransformedSubset(Subset(full_dataset, train_indices), train_transforms)
        val_dataset = TransformedSubset(Subset(full_dataset, val_indices), val_transforms)

        pin = torch.cuda.is_available()  # pinned memory only helps host-to-GPU copies

        def make_loaders():
            # Fresh DataLoader objects on every call; the train loaders of both
            # stages share the same sampler instance.
            return {
                'train': DataLoader(train_dataset, batch_size=MICRO_BATCH_SIZE, sampler=sampler,
                                    num_workers=0, pin_memory=pin),
                'val': DataLoader(val_dataset, batch_size=MICRO_BATCH_SIZE, shuffle=False,
                                  num_workers=0, pin_memory=pin),
            }

        dataloaders_extract = make_loaders()
        dataloaders_tune = make_loaders()
        dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
        return dataloaders_extract, dataloaders_tune, dataset_sizes, CLASSES, NUM_CLASSES, weights

    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == '__main__':
    dl_extract, dl_tune, ds_sizes, class_names, num_classes, weights = prepare_data()
# endregion

# In [None]:
# region 3. Feature Extraction
if __name__ == '__main__':
    # Load data only if an earlier cell has not already done so
    # (at module level, locals() is the module namespace).
    if 'dl_extract' not in locals():
        dl_extract, dl_tune, ds_sizes, class_names, num_classes, weights = prepare_data()

    if RUN_FEATURE_EXTRACTION:
        print("\n>>> START FEATURE EXTRACTION")
        # Pretrained backbone; only the new classifier head is optimised in this stage.
        model = models.efficientnet_v2_s(weights=weights)

        for param in model.features.parameters():
            param.requires_grad = False

        # Checkpointing buys little while the stages are frozen, but wrapping keeps
        # the state_dict keys ("features.<i>.module.*") identical to the
        # fine-tuning model, which loads BEST_FE_PATH. Iterate over a snapshot
        # because the container is modified in place.
        for i, stage in enumerate(list(model.features)):
            model.features[i] = CheckpointWrapper(stage)

        model.classifier = nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(model.classifier[1].in_features, num_classes)
        )
        model = model.to(DEVICE, memory_format=torch.channels_last)

        criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
        optimizer = optim.SGD(model.classifier.parameters(), lr=LR_FEATURE_EXTRACT, momentum=0.9, weight_decay=WEIGHT_DECAY)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_FEATURE_EXTRACT)

        # NOTE(review): train_model puts the whole model in train() mode, so any
        # normalisation layers inside the frozen backbone still update their
        # running statistics during this stage — confirm this is intended.
        model, hist_extract = train_model(
            model, criterion, optimizer, scheduler, dl_extract, DEVICE,
            EPOCHS_FEATURE_EXTRACT, ds_sizes, "FeatExtract", ACCUM_STEPS_EXTRACT
        )
        torch.save(model.state_dict(), BEST_FE_PATH)

        # Drop every reference (the scheduler also holds the optimizer), then
        # collect garbage *before* emptying the CUDA cache so the freed blocks
        # can actually be released.
        del model, optimizer, scheduler
        gc.collect()
        torch.cuda.empty_cache()
    else:
        print("\n>>> FEATURE EXTRACTION SKIPPED")
# endregion

# In [None]:
# region 4. Fine Tuning
if __name__ == '__main__':
    if RUN_FINE_TUNING:
        print("\n>>> START FINE TUNING")
        # Collect first so that empty_cache() can release the freed blocks.
        gc.collect()
        torch.cuda.empty_cache()

        if 'dl_tune' not in locals():
            dl_extract, dl_tune, ds_sizes, class_names, num_classes, weights = prepare_data()

        have_fe_checkpoint = os.path.exists(BEST_FE_PATH)
        # With a feature-extraction checkpoint, its weights replace everything below.
        # Without one, start from the pretrained backbone instead of random init.
        model = models.efficientnet_v2_s(weights=None if have_fe_checkpoint else weights)
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(model.classifier[1].in_features, num_classes)
        )

        # Same wrapping as the feature-extraction stage so the saved state_dict keys match.
        for i, stage in enumerate(list(model.features)):
            model.features[i] = CheckpointWrapper(stage)

        if have_fe_checkpoint:
            model.load_state_dict(torch.load(BEST_FE_PATH, map_location='cpu'))
        else:
            print(f"'{BEST_FE_PATH}' not found; fine-tuning from the pretrained weights with a new classifier head")

        for param in model.parameters():
            param.requires_grad = True

        model = model.to(DEVICE, memory_format=torch.channels_last)
        criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
        optimizer = optim.SGD(model.parameters(), lr=LR_FINE_TUNE, momentum=0.9, weight_decay=WEIGHT_DECAY)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_FINE_TUNE)

        model, hist_tune = train_model(
            model, criterion, optimizer, scheduler, dl_tune, DEVICE,
            EPOCHS_FINE_TUNE, ds_sizes, "FineTuning", ACCUM_STEPS_TUNE
        )
        torch.save(model.state_dict(), FINAL_MODEL_PATH)
    else:
        print("\n>>> FINE TUNING SKIPPED")
# endregion

# In [None]:
# region 5. Visualization
from sklearn.metrics import classification_report
if __name__ == '__main__':
    print("\n>>> VISUALIZATION")

    # Histories of whichever stages actually ran (either may have been skipped).
    stage_histories = []
    if 'hist_extract' in locals():
        stage_histories.append(hist_extract)
    if 'hist_tune' in locals():
        stage_histories.append(hist_tune)

    if stage_histories:
        # Validation curves of all stages, concatenated in stage order.
        acc = [v for h in stage_histories for v in h.get('val_acc', [])]
        loss = [v for h in stage_histories for v in h.get('val_loss', [])]

        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1); plt.plot(acc); plt.title('Acc'); plt.grid(True, alpha=0.3)
        plt.subplot(1, 2, 2); plt.plot(loss, color='red'); plt.title('Loss'); plt.grid(True, alpha=0.3)
        plt.savefig(PLOT_FILENAME); plt.show()

    if 'model' in locals():
        model.eval()
        preds, labels = [], []
        with torch.no_grad():
            for x, y in tqdm(dl_tune['val'], desc="Eval"):
                x = x.to(DEVICE, memory_format=torch.channels_last)
                # Use the same autocast dtype as training rather than a hard-coded bfloat16.
                with torch.autocast(device_type=DEVICE.type, dtype=AMP_DTYPE, enabled=True):
                    out = model(x)
                preds.extend(torch.max(out, 1)[1].cpu().tolist())
                labels.extend(y.tolist())

        # Pass the full label set so the report and matrix keep one row per class
        # (and stay aligned with target_names) even if a class is absent from the split.
        class_ids = list(range(len(class_names)))
        print(classification_report(labels, preds, labels=class_ids, target_names=class_names,
                                    digits=4, zero_division=0))

        cm = confusion_matrix(labels, preds, labels=class_ids)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix'); plt.savefig(CONF_MATRIX_FILE); plt.show()
# endregion