In [15]:
#%% -------- 1. Import Dependencies --------
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import cv2
from pathlib import Path
from skimage.filters import frangi, hessian
from skimage.morphology import erosion, dilation, disk
from skimage.feature import hessian_matrix_eigvals
import os
from tqdm import tqdm

#%% -------- 0. Configuration & Paths --------
DATA_ROOT = Path("Radiographs")
SPLITS = {
    'train': DATA_ROOT / "train",
    'val': DATA_ROOT / "val",
    'test': DATA_ROOT / "test"
}

In [16]:
#%% -------- 2. Custom Preprocessing Pipeline --------
class DentalPreprocessor:
    def __init__(self, clip_limit=2.0, grid_size=(8,8)):
        self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=grid_size)
        
    def __call__(self, img):
        # Convert to numpy array
        img_np = np.array(img)
        
        # CLAHE for contrast enhancement
        clahe_img = self.clahe.apply(img_np)
        
        # Frangi vesselness filter
        frangi_img = frangi(clahe_img, sigmas=range(1, 3, 1))
        
        # Entropy-driven dynamic morphology
        entropy_img = self.calculate_entropy(clahe_img)
        processed_img = self.dynamic_morphology(frangi_img, entropy_img)
        
        return torch.from_numpy(processed_img).float().unsqueeze(0)
    
    def calculate_entropy(self, img, kernel_size=7):
        from skimage.filters.rank import entropy
        return entropy(img, disk(kernel_size))
    
    def dynamic_morphology(self, img, entropy_img, threshold=0.5):
        selem = disk(2)
        normalized_entropy = (entropy_img - entropy_img.min()) / (entropy_img.max() - entropy_img.min())
        mask = normalized_entropy > threshold
        processed = np.zeros_like(img)
        processed[mask] = dilation(img[mask], selem)
        processed[~mask] = erosion(img[~mask], selem)
        return 
    
    #%% -------- 2.5 Custom Dataset Class --------
class DentalDataset(Dataset):
    def __init__(self, transform=None, mode='train'):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        
        # Get the appropriate split directory
        split_dir = SPLITS[mode]
        
        # Load male samples (class 0)
        male_dir = split_dir / "male"
        male_images = list(male_dir.glob("*.png")) + list(male_dir.glob("*.jpg"))
        self.image_paths.extend(male_images)
        self.labels.extend([0] * len(male_images))
        
        # Load female samples (class 1)
        female_dir = split_dir / "female"
        female_images = list(female_dir.glob("*.png")) + list(female_dir.glob("*.jpg"))
        self.image_paths.extend(female_images)
        self.labels.extend([1] * len(female_images))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('L')  # Convert to grayscale
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

In [17]:
#%% -------- 3. Data Visualization Cell --------
class_names = ['Male', 'Female']  # Updated to male/female

def plot_class_distribution():
    counts = np.bincount(train_dataset.targets)
    plt.figure(figsize=(10, 5))
    plt.bar(class_names, counts)
    plt.title('Class Distribution in Training Set')
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.xticks(rotation=45)
    plt.show()

def show_sample_images(dataset, num_images=6):
    plt.figure(figsize=(15, 10))
    for i in range(num_images):
        idx = np.random.randint(len(dataset))
        img, label = dataset[idx]
        plt.subplot(2, 3, i+1)
        plt.imshow(img.squeeze(), cmap='gray')
        plt.title(f"Class: {class_names[label]}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Call after dataset creation
plot_class_distribution()
show_sample_images(train_dataset)

NameError: name 'train_dataset' is not defined

In [None]:
#%% -------- 4. Hybrid CNN-Transformer Model --------
class DentalGenderClassifier(nn.Module):
    def __init__(self, patch_size=7, embed_dim=256, num_heads=8, num_layers=6):
        super().__init__()
        
        # CNN Backbone
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((14, 14)))
        
        # Transformer Encoder
        self.patch_embed = nn.Linear(128*patch_size**2, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        
        # Classifier
        self.classifier = nn.Linear(embed_dim, 1)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
    def forward(self, x):
        # CNN feature extraction
        features = self.cnn(x)
        b, c, h, w = features.shape
        
        # Convert to patches
        patches = features.unfold(2, 7, 7).unfold(3, 7, 7)
        patches = patches.contiguous().view(b, -1, 128*7*7)
        
        # Transformer processing
        embeddings = self.patch_embed(patches)
        encoded = self.transformer(embeddings)
        
        # Classification
        pooled = encoded.mean(dim=1)
        return torch.sigmoid(self.classifier(pooled))

In [None]:
#%% -------- 5. Training Configuration --------
# Define transforms
train_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.RandomRotation(15),
    transforms.RandomHorizontalFlip(),
    DentalPreprocessor(),
])

val_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    DentalPreprocessor(),
])

# Initialize datasets
train_dataset = DentalDataset(transform=train_transform, mode='train')
val_dataset = DentalDataset(transform=val_transform, mode='val')
test_dataset = DentalDataset(transform=val_transform, mode='test')

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

# Initialize model and optimizer
model = DentalGenderClassifier().to(model.device)
optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
criterion = nn.BCELoss()

In [None]:
#%% -------- 6. Training Loop --------
best_accuracy = 0
for epoch in range(50):
    model.train()
    train_loss = 0
    for batch, (images, labels) in enumerate(tqdm(train_loader)):
        images = images.to(model.device)
        labels = labels.float().to(model.device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs.squeeze(), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(model.device)
            labels = labels.to(model.device)
            outputs = model(images)
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted.squeeze() == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}: Train Loss {train_loss/len(train_loader):.4f}, Val Acc {accuracy:.2f}%")
    
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), 'best_model.pth')

In [None]:
#%% -------- 7. Model Evaluation --------
test_dataset = DentalDataset(transform=val_transform, mode='test')
test_loader = DataLoader(test_dataset, batch_size=16)

model.load_state_dict(torch.load('best_model.pth'))
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(model.device)
        outputs = model(images)
        preds = (outputs > 0.5).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())

print(classification_report(all_labels, all_preds, target_names=class_names))
print("Confusion Matrix:")
print(confusion_matrix(all_labels, all_preds))

In [None]:
#%% -------- 8. Model Interpretation --------
def visualize_attention(model, img_tensor):
    # Implementation would use hooks to extract attention maps
    pass  # Add attention visualization implementation

In [None]:
#%% -------- 9. Inference Visualization Cell --------
def predict_and_visualize(model, dataset, num_images=9):
    model.eval()
    plt.figure(figsize=(15, 15))
    
    for i in range(num_images):
        idx = np.random.randint(len(dataset))
        img, true_label = dataset[idx]
        img_tensor = img.unsqueeze(0).to(model.device)
        
        with torch.no_grad():
            prob = model(img_tensor).item()
            pred_label = int(prob > 0.5)
        
        plt.subplot(3, 3, i+1)
        plt.imshow(img.squeeze(), cmap='gray')
        plt.title(f"True: {class_names[true_label]}\nPred: {class_names[pred_label]}\nConf: {prob:.2f}")
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()

predict_and_visualize(model, test_dataset)

print("Full pipeline execution completed!")