# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np
from sklearn.metrics import classification_report
import timm  # Using timm for pretrained ViT

# Configuration: every tunable of the run lives in this one dictionary.
CFG = dict(
    image_size=224,
    batch_size=32,
    num_workers=4,
    num_classes=7,  # 7 oral diseases in the dataset
    lr=3e-4,
    epochs=20,
    # Use the GPU when one is visible to torch, otherwise run on the CPU.
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_name="vit_base_patch16_224",
    pretrained=True,
    model_path="best_vit_oral.pth",
)

# Data Transforms
# Both pipelines resize to the same square size and normalise with the same
# per-channel statistics; only the training pipeline adds random augmentation.
_target_size = (CFG['image_size'], CFG['image_size'])
_norm_stats = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

_train_steps = [
    transforms.Resize(_target_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(**_norm_stats),
]
train_transform = transforms.Compose(_train_steps)

_val_steps = [
    transforms.Resize(_target_size),
    transforms.ToTensor(),
    transforms.Normalize(**_norm_stats),
]
val_transform = transforms.Compose(_val_steps)

# Load Dataset
def create_dataloaders(data_dir):
    """Build the training and validation loaders for an ImageFolder dataset.

    The folder is indexed twice, once with ``train_transform`` and once with
    ``val_transform``, so that each split gets its own preprocessing. Both
    subsets returned by ``random_split`` wrap the *same* dataset object, so
    assigning ``val_transform`` to it would also replace the training
    transform and silently disable augmentation.

    Args:
        data_dir: Root directory passed to ``datasets.ImageFolder``
            (one sub-directory per class).

    Returns:
        Tuple ``(train_loader, val_loader, class_weights)``.
        ``class_weights`` is a float32 tensor on ``CFG['device']`` holding the
        inverse sample count of each class, computed over the whole dataset.
    """
    # Function-scope import keeps this block self-contained.
    from torch.utils.data import Subset

    # Two views over the same files; they differ only in the transform applied.
    # NOTE: this relies on ImageFolder producing the same sample order on both
    # scans, i.e. the directory must not change between the two calls.
    train_view = datasets.ImageFolder(root=data_dir, transform=train_transform)
    val_view = datasets.ImageFolder(root=data_dir, transform=val_transform)

    # Split dataset (80% train, 20% val)
    train_size = int(0.8 * len(train_view))
    val_size = len(train_view) - train_size
    train_dataset, val_split = random_split(train_view, [train_size, val_size])

    # Same held-out indices, but read through the augmentation-free view.
    val_dataset = Subset(val_view, val_split.indices)

    # Inverse-frequency class weights. `minlength` keeps one entry per class
    # even if trailing classes have no samples, and the clamp avoids an
    # infinite weight for an empty class.
    class_counts = np.bincount(
        train_view.targets, minlength=len(train_view.classes)
    )
    class_weights = 1.0 / np.maximum(class_counts, 1)
    class_weights = torch.tensor(class_weights, dtype=torch.float32).to(CFG['device'])

    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG['batch_size'],
        shuffle=True,
        num_workers=CFG['num_workers']
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=CFG['batch_size'],
        shuffle=False,
        num_workers=CFG['num_workers']
    )

    return train_loader, val_loader, class_weights

def create_model(class_weights):
    """Instantiate the network together with its loss, optimizer and scheduler.

    Args:
        class_weights: Per-class weight tensor forwarded to
            ``nn.CrossEntropyLoss`` as ``weight``.

    Returns:
        Tuple ``(model, criterion, optimizer, scheduler)``. The model is the
        timm architecture named by ``CFG['model_name']``, moved to
        ``CFG['device']``.
    """
    network = timm.create_model(
        CFG['model_name'],
        pretrained=CFG['pretrained'],
        num_classes=CFG['num_classes'],
    )
    network = network.to(CFG['device'])

    # Weighted cross-entropy, AdamW, and a cosine learning-rate schedule whose
    # period is the configured number of epochs.
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)
    adamw = optim.AdamW(network.parameters(), lr=CFG['lr'], weight_decay=0.01)
    cosine = optim.lr_scheduler.CosineAnnealingLR(adamw, T_max=CFG['epochs'])

    return network, loss_fn, adamw, cosine
# Training Function
def train_model(model, criterion, optimizer, scheduler):
    """Train ``model`` and print validation metrics after every epoch.

    Reads the module-level ``train_loader``, ``val_loader`` and
    ``class_names`` in addition to ``CFG``. Whenever validation accuracy
    improves, the model's ``state_dict`` is saved to ``CFG['model_path']``.
    The closing classification report is built from the predictions of the
    last epoch, which is not necessarily the epoch that was checkpointed.

    Args:
        model: Network to train; called as ``model(images)``.
        criterion: Loss, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
    """
    best_accuracy = 0.0
    num_train = len(train_loader.dataset)
    num_val = len(val_loader.dataset)

    # Bound before the loop so the final report cannot hit a NameError when
    # CFG['epochs'] is 0.
    all_preds = []
    all_labels = []

    for epoch in range(CFG['epochs']):
        model.train()
        running_loss = 0.0

        for images, labels in train_loader:
            images = images.to(CFG['device'])
            labels = labels.to(CFG['device'])

            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The loss is a batch mean; weight it by batch size so the epoch
            # figure is a per-sample average.
            running_loss += loss.item() * images.size(0)

        scheduler.step()
        epoch_loss = running_loss / num_train if num_train else 0.0

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(CFG['device'])
                labels = labels.to(CFG['device'])

                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)

                preds = outputs.argmax(dim=1)
                # Accumulate as a Python int so the accuracy below is a plain
                # float and an empty validation loader does not break it.
                correct += (preds == labels).sum().item()
                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

        val_loss = val_loss / num_val if num_val else 0.0
        val_accuracy = correct / num_val if num_val else 0.0

        print(f"Epoch {epoch+1}/{CFG['epochs']}")
        print(f"Train Loss: {epoch_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Val Accuracy: {val_accuracy:.4f}")

        # Save best model
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), CFG['model_path'])
            print(f"Saved new best model with accuracy {val_accuracy:.4f}")

    # Final evaluation
    if not all_labels:
        print("\nNo validation predictions were collected; skipping classification report.")
        return

    print("\nClassification Report:")
    # Passing `labels` explicitly keeps `target_names` aligned even when some
    # class never appears in the validation split or in the predictions.
    print(classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(class_names))),
        target_names=class_names,
    ))

# Main Execution
if __name__ == "__main__":
    import os

    # Ask for the dataset path. The known location is offered as a default
    # (press Enter to accept) instead of being used as the prompt text itself.
    default_data_dir = r"E:\PROJECT_VI-SEM\Dataset_types"
    data_dir = input(f"Dataset directory [{default_data_dir}]: ").strip() or default_data_dir
    data_dir = os.path.expanduser(data_dir)  # Handles ~ in paths

    # Verify path exists
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Path '{data_dir}' does not exist!")
    if not os.path.isdir(data_dir):
        raise NotADirectoryError(f"'{data_dir}' is not a directory!")

    # Create dataloaders
    train_loader, val_loader, class_weights = create_dataloaders(data_dir)
    class_names = train_loader.dataset.dataset.classes

    # The classifier head and the per-class loss weights must have the same
    # length, so follow the dataset when the configured count disagrees.
    if len(class_names) != CFG['num_classes']:
        print(
            f"Warning: CFG['num_classes']={CFG['num_classes']} but the dataset "
            f"has {len(class_names)} classes; using {len(class_names)}."
        )
        CFG['num_classes'] = len(class_names)

    # Initialize model
    model, criterion, optimizer, scheduler = create_model(class_weights)

    print("\nDataset Info:")
    print(f"Class names: {class_names}")
    print(f"Total classes: {len(class_names)}")
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Training on {CFG['device']}\n")

    # Start training
    train_model(model, criterion, optimizer, scheduler)