### training 


In [None]:
# ---------------------- 1. Imports ----------------------
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from torch import nn
from torch.utils.data import Dataset
from fastai.vision.all import *
from albumentations.pytorch import ToTensorV2
import albumentations as A

# ---------------------- 2. Configuration ----------------------
# Training hyper-parameters.
BATCH_SIZE = 16
IMG_SIZE = 224  # target height and width (pixels) used by the Resize transforms
EPOCHS = 20
LEARNING_RATE = 2e-5

# Root folder of the training data: one sub-folder per class.
dataset_root = r"D:\Submitted Matrial (conference&journal)\Sagittal Data Artical\V0.47 Dataset analysis\dataspitting\Cropped_ROIs_V0.47_Split\train_by_class"

# Auto-detect Device: use the GPU when CUDA is available, otherwise the CPU.
# This was previously hard-coded to the CPU even though it was labelled as
# auto-detection. To force CPU training, use torch.device("cpu") instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

# ---------------------- 3. Transforms ----------------------
# Training pipeline: random flips, rotation and brightness/contrast jitter,
# followed by the same resize / normalize / to-tensor steps as validation.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=30, p=0.7),
    A.RandomBrightnessContrast(p=0.7),
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=(0.5,), std=(0.5,)),
    ToTensorV2(),
])

# Validation pipeline: no random augmentation, only resize / normalize /
# to-tensor, so evaluation is deterministic.
val_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=(0.5,), std=(0.5,)),
    ToTensorV2(),
])

# ---------------------- 4. Dataset Classes ----------------------
class ImageFolderDataset(Dataset):
    """Dataset over a folder that contains one sub-folder of images per class.

    The sorted sub-folder names form ``classes``; a sample's label is the
    index of its folder in that list. ``samples`` is a list of
    ``(file_path, label)`` tuples. Indexing returns ``(image, label)`` where
    the image is a 3-channel ``IMG_SIZE x IMG_SIZE`` NumPy array (no tensor
    conversion or normalization is done here).
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, root_dir):
        """Scan ``root_dir`` and collect every image path with its label.

        Args:
            root_dir: Directory whose immediate sub-directories are the classes.
        """
        self.samples = []
        self.classes = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        for label_str in self.classes:
            label_dir = os.path.join(root_dir, label_str)
            label = self.class_to_idx[label_str]
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample indices stable, so a seeded split over them is reproducible.
            for fname in sorted(os.listdir(label_dir)):
                if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(label_dir, fname), label))

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, label)``.

        The file is read as grayscale, resized to ``IMG_SIZE x IMG_SIZE``,
        histogram-equalized and replicated into three identical channels.
        If the file cannot be read, a warning is issued and an all-zero
        image of the same shape is returned so iteration can continue.
        """
        path, label = self.samples[idx]
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None; make that visible
            # instead of silently feeding a blank image to the model.
            import warnings
            warnings.warn(f"Could not read image, substituting a blank one: {path}")
            return np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8), label
        img = cv2.equalizeHist(cv2.resize(img, (IMG_SIZE, IMG_SIZE)))
        # Stack the single channel three times: H x W -> H x W x 3.
        img = np.stack([img] * 3, axis=-1)
        return img, label

class AlbumentationsDataset(Dataset):
    """Wrap an indexable ``(image, label)`` dataset and transform its images.

    ``transform`` is invoked as ``transform(image=img)`` and the value stored
    under the ``'image'`` key of its result replaces the image. When the
    transform is falsy (e.g. ``None``) items are passed through unchanged.
    """

    def __init__(self, subset, transform=None):
        self.subset, self.transform = subset, transform

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.subset)

    def __getitem__(self, idx):
        """Return item ``idx`` of the wrapped dataset, transformed if configured."""
        image, target = self.subset[idx]
        if not self.transform:
            return image, target
        augmented = self.transform(image=image)
        return augmented['image'], target

# ---------------------- 5. Setup Data ----------------------
full_ds_raw = ImageFolderDataset(dataset_root)
if len(full_ds_raw) == 0:
    # Fail early with a clear message instead of an obscure split/division error.
    raise RuntimeError(f"No images found under: {dataset_root}")

n_classes = len(full_ds_raw.classes)
all_labels = [label for _, label in full_ds_raw.samples]

# Stratified 80/20 split over sample indices so both subsets keep the class ratios.
train_idx, val_idx = train_test_split(
    list(range(len(full_ds_raw))),
    test_size=0.2,
    random_state=42,
    stratify=all_labels,
)

# Both subsets share the same raw dataset; only the transform differs.
train_dataset = AlbumentationsDataset(torch.utils.data.Subset(full_ds_raw, train_idx), transform=train_transform)
val_dataset = AlbumentationsDataset(torch.utils.data.Subset(full_ds_raw, val_idx), transform=val_transform)

dls = DataLoaders.from_dsets(
    train_dataset, val_dataset,
    bs=BATCH_SIZE,
    num_workers=0,
    pin_memory=False,
    device=device,
)

# ---------------------- 6. Handle Class Imbalance (NEW) ----------------------
# 1. Count samples in each class
counts = Counter(all_labels)
print(f"Class Counts: {dict(counts)}")

# 2. Calculate inverse-frequency weights: 1 / Class_Count
# Rare classes get HIGH weights, common classes get LOW weights.
# A class folder with no images contributes no samples, so it gets weight 0
# rather than a placeholder that would dwarf the real weights once normalized.
weights = torch.tensor(
    [1.0 / counts[i] if counts[i] > 0 else 0.0 for i in range(n_classes)],
    dtype=torch.float,
)

# 3. Normalize weights so they sum to the number of classes (optional but standard)
weights = weights / weights.sum() * n_classes
weights = weights.to(device)

print(f"Calculated Class Weights: {weights}")

# ---------------------- 7. Model Setup ----------------------
print("Loading Model...")
model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
# Replace the classification head with dropout + a linear layer sized to our classes.
model.head = nn.Sequential(nn.Dropout(0.25), nn.Linear(192, n_classes))
model = model.to(device)

# Apply the WEIGHTS to the Loss Function here
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(weight=weights), metrics=accuracy)

# ---------------------- 8. Training ----------------------
print("Starting Training...")
with learn.no_bar():
    learn.fit_one_cycle(EPOCHS, lr_max=LEARNING_RATE)

# ---------------------- 9. Evaluation ----------------------
print("Saving Model...")
save_path = os.path.join(dataset_root, 'deit_balanced.pth')
torch.save(learn.model.state_dict(), save_path)  # Save pure state dict for safety

print("\nEvaluation...")
learn.model.eval()
preds, targs = learn.get_preds()
preds_classes = preds.argmax(dim=1)

# Hand scikit-learn plain NumPy arrays and the full list of class ids.
# Without `labels`, a class that is absent from both the targets and the
# predictions makes `target_names` / `display_labels` mismatch and raises.
y_true = targs.cpu().numpy()
y_pred = preds_classes.cpu().numpy()
class_ids = list(range(n_classes))

print("\nClassification Report:")
print(classification_report(
    y_true, y_pred,
    labels=class_ids,
    target_names=full_ds_raw.classes,
    zero_division=0,  # report 0 instead of warning for classes with no samples
))

print("\nConfusion Matrix:")
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
fig, ax = plt.subplots(figsize=(8, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=full_ds_raw.classes)
disp.plot(cmap='Blues', values_format='d', ax=ax)
plt.title("Confusion Matrix (Balanced Training)")
plt.show()

#### testing 

In [None]:
# ---------------------- 1. Imports ----------------------
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, ConfusionMatrixDisplay
from fastai.vision.all import *
from albumentations.pytorch import ToTensorV2
import albumentations as A

# ---------------------- 2. Configuration ----------------------
# Batch size and input resolution used for inference.
BATCH_SIZE, IMG_SIZE = 16, 224
# Size of the classification head; must match training (classes 0, 1, 2, 3).
NUM_CLASSES = 4

# Folder with the held-out test images, one sub-folder per class.
test_dataset_root = r"D:\Submitted Matrial (conference&journal)\Sagittal Data Artical\V0.47 Dataset analysis\dataspitting\Cropped_ROIs_V0.47_Split\test_by_class"

# Checkpoint (.pth state dict) produced by the training step.
model_path = r"D:\Submitted Matrial (conference&journal)\Sagittal Data Artical\V0.47 Dataset analysis\dataspitting\Cropped_ROIs_V0.47_Split\train_by_class\deit_balanced.pth"

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Running on: {device}")

# ---------------------- 3. Transforms (Validation only) ----------------------
# No random augmentation at test time: resize, normalize, convert to tensor.
test_transform = A.Compose(
    [
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.Normalize(mean=(0.5,), std=(0.5,)),
        ToTensorV2(),
    ]
)

# ---------------------- 4. Dataset Class (Test Mode) ----------------------
class TestImageFolderDataset(Dataset):
    """Test-time dataset over a folder with one sub-folder per class.

    The class list is fixed to ``['0', '1', '2', '3']`` so label indices stay
    aligned with training even when a class folder is missing from the test
    set. Indexing returns ``(transformed_image, label)`` where the image has
    been passed through the module-level ``test_transform``.
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, root_dir):
        """Collect ``(file_path, label)`` pairs for every image under ``root_dir``.

        Args:
            root_dir: Directory expected to contain the class sub-folders.
        """
        self.samples = []
        # We explicitly define classes 0, 1, 2, 3 to match training
        # This ensures that even if class "3" is missing in test, the order stays correct.
        self.classes = ['0', '1', '2', '3']
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        print(f"Class Mapping: {self.class_to_idx}")

        for label_str in self.classes:
            label_dir = os.path.join(root_dir, label_str)
            # isdir (not exists): a plain file with a class name must not be listed.
            if not os.path.isdir(label_dir):
                print(f"Warning: Folder '{label_str}' not found in test path.")
                continue

            label = self.class_to_idx[label_str]
            # Sort for a stable sample order; os.listdir order is arbitrary.
            for fname in sorted(os.listdir(label_dir)):
                if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(label_dir, fname), label))

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(transformed_image, label)``.

        The file is read as grayscale, resized to ``IMG_SIZE x IMG_SIZE``,
        histogram-equalized and replicated into three channels. An unreadable
        file is reported and replaced by an all-zero image. Either way the
        image goes through ``test_transform`` so every item has the same
        type and layout and batches can be collated.
        """
        path, label = self.samples[idx]
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)

        if img is None:
            # cv2.imread returns None on failure. The blank placeholder used to
            # be returned untransformed, unlike every other item.
            print(f"Warning: Could not read image '{path}'; using a blank image.")
            img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
        else:
            img = cv2.equalizeHist(cv2.resize(img, (IMG_SIZE, IMG_SIZE)))
            # Stack the single channel three times: H x W -> H x W x 3.
            img = np.stack([img] * 3, axis=-1)

        # Apply Transform
        img = test_transform(image=img)['image']
        return img, label

# ---------------------- 5. Load Data ----------------------
print("\nLoading Test Data...")
test_ds = TestImageFolderDataset(test_dataset_root)

# Create DataLoader
test_dl = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep a fixed order; shuffling is pointless for evaluation
    num_workers=0,  # Critical for Windows
    pin_memory=False,
)
print(f"Found {len(test_ds)} test images.")
if len(test_ds) == 0:
    # Nothing to evaluate: stop with a clear message instead of computing
    # metrics on empty inputs.
    raise RuntimeError(f"No test images found under: {test_dataset_root}")

# ---------------------- 6. Load Model ----------------------
print("\nLoading Model Architecture...")
# Rebuild the same architecture as in training; the weights come from the checkpoint.
model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=False)
model.head = nn.Sequential(nn.Dropout(0.25), nn.Linear(192, NUM_CLASSES))

print(f"Loading Weights from: {model_path}")
# Load the weights we saved earlier
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

model.to(device)
model.eval()  # Set to evaluation mode

# ---------------------- 7. Run Inference ----------------------
print("\nRunning Prediction...")
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_dl:
        outputs = model(images.to(device))
        preds = outputs.argmax(dim=1)

        # Labels are only needed for the metrics, so they are collected on the CPU.
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

# ---------------------- 8. Calculate Metrics ----------------------
print("\n" + "="*40)
print("             TEST REPORT")
print("="*40)

# Accuracy
acc = accuracy_score(all_labels, all_preds)
print(f"Total Accuracy: {acc:.4f}")

# Always report over the full class list. The dataset tolerates a missing
# class folder, but without `labels` scikit-learn infers the classes from the
# data, and `target_names` / `display_labels` then mismatch and raise.
class_ids = list(range(len(test_ds.classes)))

# Detailed Report (Precision, Recall, F1 per class)
report = classification_report(
    all_labels, all_preds,
    labels=class_ids,
    target_names=test_ds.classes,
    zero_division=0,  # report 0 instead of warning for classes with no samples
)
print("\nDetailed Metrics:")
print(report)

# Confusion Matrix
print("\nConfusion Matrix:")
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
fig, ax = plt.subplots(figsize=(8, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=test_ds.classes)
disp.plot(cmap='Blues', values_format='d', ax=ax)
plt.title(f"Test Set Confusion Matrix\nAccuracy: {acc:.2f}")
plt.show()