In [None]:
# ============================
# 1. Install Dependencies
# ============================
# NOTE: the leading "!" is an IPython/Jupyter shell escape, so this line only
# works inside a notebook cell (e.g. Colab); it is not valid plain Python.
# No versions are pinned, so the releases installed depend on when it is run.
!pip install torch torchvision torchmetrics tqdm timm

# ============================
# 2. Imports
# ============================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchmetrics.classification import Accuracy
from tqdm import tqdm
from timm import create_model
from google.colab import files

# ============================
# 3. Configurations
# ============================
# Hyperparameters shared by the data loaders, model head, and training loop.
NUM_CLASSES = 6        # hypodontia, Tooth Discoloration, Data caries, Gingivitis, Mouth Ulcer, Calculus
NUM_EPOCHS = 10
BATCH_SIZE = 64        # Increase if memory allows
LEARNING_RATE = 1e-4

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# ============================
# 4. Data Augmentation and Preprocessing
# ============================
# Training pipeline: fixed-size resize, light random augmentation (horizontal
# flip, rotation of up to 10 degrees), then conversion to a tensor normalised
# with mean 0.5 / std 0.5 on each of the three channels.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Evaluation pipeline: same resize and normalisation, but without the random
# flip/rotation, so validation metrics are reproducible and are measured on
# unaugmented images.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Kept under its original name for any code that still refers to `transform`.
transform = train_transform

# ImageFolder derives the class list from the sub-directory names of `root`.
train_dataset = datasets.ImageFolder(root="/content/split_oral_diseases/train", transform=train_transform)
val_dataset = datasets.ImageFolder(root="/content/split_oral_diseases/val", transform=val_transform)

# Only the training data is shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

print("Classes:", train_dataset.classes)

# ============================
# 5. Model Setup: Vision Transformer (ViT Tiny)
# ============================
# Load the pretrained ViT-Tiny backbone, swap its classification head for a
# fresh linear layer with one output per target class, and move it to DEVICE.
model = create_model(model_name='vit_tiny_patch16_224', pretrained=True)
model.head = nn.Linear(
    in_features=model.head.in_features,
    out_features=NUM_CLASSES,
)
model = model.to(device=DEVICE)

# ============================
# 6. Loss Function and Optimizer
# ============================
# Cross-entropy over the model's raw class scores.
criterion = torch.nn.CrossEntropyLoss()
# Adam over every parameter of the model (backbone and new head alike).
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# ============================
# 7. Metrics
# ============================
# Two independent multiclass accuracy trackers (training and validation),
# each created with the same settings and placed on DEVICE.
train_acc, val_acc = (
    Accuracy(task="multiclass", num_classes=NUM_CLASSES).to(DEVICE)
    for _ in range(2)
)

# ============================
# 8. Mixed Precision Setup
# ============================
# Gradient scaler for mixed-precision training. Scaling is enabled only when
# CUDA is available (the same condition DEVICE is chosen by); when disabled,
# the scaler acts as a pass-through, so the training loop's
# scale()/step()/update() calls still work on a CPU-only runtime.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# ============================
# 9. Training Loop
# ============================
# Mixed precision is only enabled on the GPU; on the CPU fallback autocast is
# switched off explicitly.
use_amp = DEVICE.type == "cuda"

for epoch in range(NUM_EPOCHS):
    # ---- Training phase ----
    model.train()
    total_loss = 0.0
    train_acc.reset()

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} - Training"):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        # Forward pass and loss are computed under autocast.
        with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)
        # Backward on the scaled loss, then let the scaler drive the
        # optimizer step and adjust its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        train_acc.update(preds, labels)

    # Mean of the per-batch losses, and accuracy accumulated over the epoch.
    avg_train_loss = total_loss / len(train_loader)
    train_accuracy = train_acc.compute().item()

    # ---- Validation phase ----
    model.eval()
    total_val_loss = 0.0
    val_acc.reset()

    # No gradients are needed while evaluating.
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} - Validation"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            total_val_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            val_acc.update(preds, labels)

    avg_val_loss = total_val_loss / len(val_loader)
    val_accuracy = val_acc.compute().item()

    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Train Acc={train_accuracy:.4f}, "
          f"Val Loss={avg_val_loss:.4f}, Val Acc={val_accuracy:.4f}")

# ============================
# 10. Save Model
# ============================
# Write only the model's state dict to disk, then report the file name.
torch.save(obj=model.state_dict(), f="vit_tiny_dental.pth")
print("✅ Model saved as vit_tiny_dental.pth")

# ============================
# 11. Download Model
# ============================
# Pass the saved checkpoint to the Colab `files` helper for download.
files.download("vit_tiny_dental.pth")

# ============================
# 12. Final Evaluation
# ============================
# One more pass over the validation loader with the trained model, starting
# from a freshly reset accuracy metric and loss accumulator.
val_acc.reset()
model.eval()
total_val_loss = 0.0

with torch.no_grad():
    for batch_images, batch_labels in tqdm(val_loader, desc="Final Evaluation"):
        batch_images = batch_images.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        with torch.cuda.amp.autocast():
            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)

        # The predicted class is the index of the largest score in each row.
        val_acc.update(logits.argmax(dim=1), batch_labels)
        total_val_loss += batch_loss.item()

# Accuracy accumulated over all batches, and the mean of per-batch losses.
final_val_accuracy = val_acc.compute().item()
avg_val_loss = total_val_loss / len(val_loader)

print(f"\n✅ Final Evaluation Results:")
print(f"Validation Loss: {avg_val_loss:.4f}")
print(f"Validation Accuracy: {final_val_accuracy:.4f}")
