In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

# ========================================
# CONFIG
# ========================================
# NOTE(review): absolute, machine-specific Windows path — consider making this configurable.
DATA_DIR = r"D:/viot/Data__Split"
BATCH_SIZE = 32
NUM_CLASSES = 3            # size of the classifier head; should match the class folders in DATA_DIR/train
NUM_EPOCHS = 25            # also used as T_max for the cosine LR schedule
DEVICE = torch.device("cpu")  # CPU only
LR = 1e-3
PATIENCE = 4               # epochs without a val-acc improvement before early stopping
CHECKPOINT_PATH = "checkpoint.pth"
SEED = 42                  # seeds PyTorch's global RNG so fresh runs are repeatable

# Seeding the global generator makes sampler shuffling, random augmentations and
# dropout repeatable across fresh (non-resumed) runs.
torch.manual_seed(SEED)
torch.set_num_threads(8)
torch.set_float32_matmul_precision("medium")

# ========================================
# DATA
# ========================================
# Training images get light geometric/colour augmentation; validation images are only resized.
# NOTE(review): no transforms.Normalize(ImageNet mean/std) although the backbone is loaded
# with IMAGENET1K_V1 weights. Adding it now would change the input distribution that the
# existing checkpoints were trained on, so it is intentionally left out — revisit on a fresh run.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
    transforms.ToTensor(),
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = datasets.ImageFolder(f"{DATA_DIR}/train", transform=train_transform)
val_dataset   = datasets.ImageFolder(f"{DATA_DIR}/val", transform=val_transform)

# ImageFolder derives label indices from each split's own sorted sub-folder names, so a
# missing or misspelled class folder in one split would silently shift every label after it.
assert train_dataset.class_to_idx == val_dataset.class_to_idx, (
    f"train/val class folders differ: {train_dataset.classes} vs {val_dataset.classes}"
)
# The classifier head is sized from the hard-coded NUM_CLASSES, so it must match the data.
assert len(train_dataset.classes) == NUM_CLASSES, (
    f"NUM_CLASSES={NUM_CLASSES} but found {len(train_dataset.classes)} classes: "
    f"{train_dataset.classes}"
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# ========================================
# MODEL
# ========================================
# Pretrained EfficientNet-B0. Every feature stage except the last is frozen, so gradient
# updates reach only model.features[-1] and the replacement classifier head.
model = models.efficientnet_b0(weights="IMAGENET1K_V1")
model.features[:-1].requires_grad_(False)

# Replace the pretrained classifier with a NUM_CLASSES-way head (dropout, then linear).
in_features = model.classifier[1].in_features
model.classifier = nn.Sequential(nn.Dropout(p=0.4), nn.Linear(in_features, NUM_CLASSES))
model = model.to(DEVICE)

# Loss, optimiser and LR schedule. Adam receives the full parameter list (frozen
# parameters never get a .grad, so Adam leaves them untouched); keeping the full list
# also keeps the param-group layout that saved optimizer checkpoints expect.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=NUM_EPOCHS)

# ========================================
# LOAD CHECKPOINT (if exists)
# ========================================
# Training state for a fresh run; an existing checkpoint overrides it below.
start_epoch = 0
best_val_acc = 0.0
no_improve = 0  # consecutive epochs without a val-acc improvement (early-stopping counter)
if os.path.exists(CHECKPOINT_PATH):
    print(f" Loading checkpoint from {CHECKPOINT_PATH}...")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_acc = checkpoint['best_val_acc']
    # Checkpoints written by older versions of this cell have no 'no_improve' entry.
    no_improve = checkpoint.get('no_improve', 0)
    print(f" Resumed from epoch {start_epoch}, best val acc {best_val_acc:.4f}")

# ========================================
# TRAIN LOOP
# ========================================
if no_improve >= PATIENCE:
    # The resumed run had already early-stopped; don't train any further epochs.
    print(" Early stopping triggered.")
    start_epoch = NUM_EPOCHS

for epoch in range(start_epoch, NUM_EPOCHS):
    print(f"\n=== Epoch {epoch+1}/{NUM_EPOCHS} ===")

    # ---------- TRAIN ----------
    model.train()
    # requires_grad=False does not stop BatchNorm from updating its running statistics
    # in train mode; keep the frozen stages in eval mode so they really stay frozen.
    model.features[:-1].eval()
    running_loss, correct, total = 0.0, 0, 0
    loop = tqdm(train_loader, desc="Training", leave=False)
    for imgs, labels in loop:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch average is per-sample.
        running_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        loop.set_postfix(loss=loss.item(), acc=f"{correct/total:.4f}")

    train_acc = correct / total
    train_loss = running_loss / total

    # ---------- VALIDATION ----------
    model.eval()
    correct_val, total_val = 0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            outputs = model(imgs)
            preds = outputs.argmax(dim=1)
            correct_val += (preds == labels).sum().item()
            total_val += labels.size(0)
    val_acc = correct_val / total_val
    scheduler.step()  # one cosine-schedule step per epoch

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

    # ---------- BEST MODEL / EARLY-STOPPING BOOKKEEPING ----------
    # Done *before* writing the checkpoint: previously the checkpoint stored the stale
    # best_val_acc, so after a resume a worse epoch could overwrite best_cpu_model.pth.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        no_improve = 0
        torch.save(model.state_dict(), "best_cpu_model.pth")
        print(f" New best model saved with Val Acc: {val_acc:.4f}")
    else:
        no_improve += 1

    # ---------- SAVE CHECKPOINT ----------
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_acc': best_val_acc,
        'no_improve': no_improve,  # persisted so early stopping survives a resume
    }, CHECKPOINT_PATH)
    print(" Checkpoint saved!")

    if no_improve >= PATIENCE:
        print(" Early stopping triggered.")
        break

print("\n Training complete!")
print(f" Best Validation Accuracy: {best_val_acc:.4f}")


 Loading checkpoint from checkpoint.pth...
 Resumed from epoch 23, best val acc 0.9917

=== Epoch 24/25 ===


                                                                                     

Train Acc: 0.9850 | Val Acc: 0.9915
 Checkpoint saved!

=== Epoch 25/25 ===


                                                                                     

Train Acc: 0.9846 | Val Acc: 0.9917
 Checkpoint saved!

 Training complete!
 Best Validation Accuracy: 0.9917


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm

# ==============================
# CONFIG
# ==============================
# NOTE(review): absolute, machine-specific Windows path — consider making this configurable.
DATA_DIR = r"D:/viot/Data_Split_Stage2"
BATCH_SIZE = 16
NUM_EPOCHS = 20            # also used as T_max for the cosine LR schedule
DEVICE = torch.device("cpu")
LR = 1e-3
PATIENCE = 6               # epochs without a val-acc improvement before early stopping
CHECKPOINT_PATH = "disease_stage2_checkpoint.pth"
BEST_MODEL_PATH = "disease_stage2_best_model.pth"
SEED = 42                  # seeds PyTorch's global RNG so fresh runs are repeatable

# Seeding the global generator makes sampler shuffling, random augmentations and
# dropout repeatable across fresh (non-resumed) runs.
torch.manual_seed(SEED)
torch.set_num_threads(8)
torch.set_float32_matmul_precision("medium")

# ==============================
# DATASET AND TRANSFORMS
# ==============================
# Training images get light geometric/colour augmentation; validation images are only resized.
# NOTE(review): no transforms.Normalize(ImageNet mean/std) although the backbone is loaded
# with IMAGENET1K_V1 weights. Adding it would change the input distribution existing
# checkpoints were trained on, so it is intentionally left out — revisit on a fresh run.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=train_transform)
val_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "val"), transform=val_transform)

# NUM_CLASSES is taken from the train split only, and ImageFolder derives label indices
# from each split's own sorted sub-folder names — a missing or misspelled class folder in
# val would silently shift its labels, so require both splits to agree.
assert train_dataset.class_to_idx == val_dataset.class_to_idx, (
    f"train/val class folders differ: {train_dataset.classes} vs {val_dataset.classes}"
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

NUM_CLASSES = len(train_dataset.classes)
print(f" Number of specific disease classes: {NUM_CLASSES}")
print(" Disease class names:")
for cls in train_dataset.classes:
    print("-", cls)

# ==============================
# MODEL
# ==============================
# Pretrained EfficientNet-B0 used as a fixed feature extractor: the entire `features`
# trunk is frozen for faster training, so only the new classifier head is learned.
model = models.efficientnet_b0(weights="IMAGENET1K_V1")
model.features.requires_grad_(False)

# Replace the pretrained classifier with a NUM_CLASSES-way head (dropout, then linear).
in_features = model.classifier[1].in_features
model.classifier = nn.Sequential(nn.Dropout(p=0.4), nn.Linear(in_features, NUM_CLASSES))
model = model.to(DEVICE)

# ==============================
# LOSS, OPTIMIZER, SCHEDULER
# ==============================
# Adam receives the full parameter list (frozen parameters never get a .grad, so Adam
# leaves them untouched); keeping the full list also keeps the param-group layout that
# saved optimizer checkpoints expect.
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=NUM_EPOCHS)

# ==============================
# RESUME CHECKPOINT IF EXISTS
# ==============================
# Training state for a fresh run; an existing checkpoint overrides it below.
start_epoch = 0
best_val_acc = 0.0
no_improve = 0  # consecutive epochs without a val-acc improvement (early-stopping counter)
if os.path.exists(CHECKPOINT_PATH):
    print(f" Loading checkpoint from {CHECKPOINT_PATH}...")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    best_val_acc = checkpoint["best_val_acc"]
    # Checkpoints written by older versions of this cell have no "no_improve" entry.
    no_improve = checkpoint.get("no_improve", 0)
    print(f" Resumed at epoch {start_epoch}, Best Val Acc: {best_val_acc:.4f}")

# ==============================
# TRAINING LOOP
# ==============================
if no_improve >= PATIENCE:
    # The resumed run had already early-stopped; don't train any further epochs.
    print(" Early stopping triggered.")
    start_epoch = NUM_EPOCHS

for epoch in range(start_epoch, NUM_EPOCHS):
    print(f"\n=== Epoch {epoch+1}/{NUM_EPOCHS} ===")

    # ---------- TRAIN ----------
    model.train()
    # The whole feature trunk is frozen, but requires_grad=False does not stop BatchNorm
    # from updating its running statistics in train mode; keep the trunk in eval mode so
    # it really is a fixed feature extractor (the classifier head's dropout stays active).
    model.features.eval()
    running_loss, correct, total = 0.0, 0, 0
    loop = tqdm(train_loader, desc="Training", leave=False)
    for imgs, labels in loop:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch average is per-sample.
        running_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        loop.set_postfix(loss=loss.item(), acc=f"{correct/total:.4f}")

    train_acc = correct / total
    train_loss = running_loss / total

    # ---------- VALIDATION ----------
    model.eval()
    correct_val, total_val = 0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            outputs = model(imgs)
            preds = outputs.argmax(dim=1)
            correct_val += (preds == labels).sum().item()
            total_val += labels.size(0)

    val_acc = correct_val / total_val
    scheduler.step()  # one cosine-schedule step per epoch

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

    # ---------- BEST MODEL / EARLY-STOPPING BOOKKEEPING ----------
    # Done *before* writing the checkpoint: previously the checkpoint stored the stale
    # best_val_acc, so after a resume a worse epoch could overwrite BEST_MODEL_PATH.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        no_improve = 0
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print(f" New best model saved (Val Acc: {best_val_acc:.4f})")
    else:
        no_improve += 1

    # ---------- SAVE CHECKPOINT ----------
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "best_val_acc": best_val_acc,
        "no_improve": no_improve,  # persisted so early stopping survives a resume
    }, CHECKPOINT_PATH)

    if no_improve >= PATIENCE:
        print(" Early stopping triggered.")
        break

print("\n Training complete!")
print(f"Best Validation Accuracy: {best_val_acc:.4f}")
