In [50]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/updated-medical-image-dataset/new_dataset/Bone Break Classification/val/Avulsion fracture/389382_jpg.rf.8cb98ee798766a2d3f6a76279ba5d0d9.jpg
/kaggle/input/updated-medical-image-dataset/new_dataset/Bone Break Classification/val/Avulsion fracture/avulsionlessertrochanter01jakpelvis15yo_jpg.rf.075c733923f27bd4611acf064a41a4e3.jpg
/kaggle/input/updated-medical-image-dataset/new_dataset/Bone Break Classification/val/Avulsion fracture/000002_png.rf.c3e00ebc2db78bc94e644c3f6605dad0.jpg
/kaggle/input/updated-medical-image-dataset/new_dataset/Bone Break Classification/val/Avulsion fracture/1b62e6fbfbc5a2f70c6af413189cfc82_jpg.rf.2761e933cc6d206308877cccaf0642b2.jpg
/kaggle/input/updated-medical-image-dataset/new_dataset/Bone Break Classification/val/Avulsion fracture/13256_2019_2325_Fig1_HTML_png.rf.09368fddb2da3979a3e1e25a0cac6f45.jpg
/kaggle/input/updated-medical-image-dataset/new_dataset/Bone Break Classification/val/Avulsion fracture/60683ca7a8a5848feda86d15_acl-avulsion-fract

In [67]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split, ConcatDataset
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
from vit_pytorch import ViT

In [69]:
# Dataset Class
class MedicalImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        
        self.load_images()

    def load_images(self):
        """
        Load images from the dataset directory.
        Assumes directory structure: root_dir/train_or_val/class/image.jpg
        """
        for subset in ['train', 'val', 'test']:  # Add more subsets if needed
            subset_path = os.path.join(self.root_dir, subset)
            
            if os.path.exists(subset_path):
                for class_name in os.listdir(subset_path):
                    class_path = os.path.join(subset_path, class_name)
                    
                    if os.path.isdir(class_path):
                        # Add class to mapping if not already exists
                        if class_name not in self.class_to_idx:
                            self.class_to_idx[class_name] = len(self.class_to_idx)
                        
                        # Load images for this class
                        for img_name in os.listdir(class_path):
                            img_path = os.path.join(class_path, img_name)
                            
                            if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tif')):
                                self.image_paths.append(img_path)
                                self.labels.append(self.class_to_idx[class_name])

        print(f"Found {len(self.image_paths)} images across {len(self.class_to_idx)} classes")
        print("Class mapping:", self.class_to_idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a dummy image to prevent training interruption
            image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

        if self.transform:
            image = self.transform(image)

        return image, label

In [70]:
# Data Augmentation and Normalization
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [71]:
# Model Creation
def create_model(num_classes):
    model = ViT(
        image_size=224,
        patch_size=16,
        num_classes=num_classes,
        dim=512,
        depth=6,
        heads=8,
        mlp_dim=1024,
        dropout=0.1,
        emb_dropout=0.1
    )
    
    # Wrap the model for multi-GPU support
    if torch.cuda.device_count() > 1:
        print(f"🚀 Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)
    
    # Move model to GPU
    model = model.to(DEVICE)
    
    return model

In [72]:
def train_model(model, train_loader, valid_loader, num_epochs=15, learning_rate=1e-4):
    criterion = nn.CrossEntropyLoss().to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_val_accuracy = 0.0

    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for images, labels in progress_bar:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

            progress_bar.set_postfix({
                'Loss': f'{loss.item():.4f}', 
                'Accuracy': f'{100 * train_correct / train_total:.2f}%'
            })

        # Validation Phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        train_accuracy = 100 * train_correct / train_total
        val_accuracy = 100 * val_correct / val_total

        print(f"Epoch {epoch+1}: "
              f"Train Loss: {train_loss/len(train_loader):.4f}, "
              f"Train Accuracy: {train_accuracy:.2f}%, "
              f"Val Loss: {val_loss/len(valid_loader):.4f}, "
              f"Val Accuracy: {val_accuracy:.2f}%")

        scheduler.step()

        # Save best model
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), "best_medical_vit_model.pth")

    return model

In [74]:
# Main Execution
def main():
    # Configuration
    BASE_PATH = r"/kaggle/input/updated-medical-image-dataset/new_dataset"  # Update with your path
    BATCH_SIZE = 32
    EPOCHS = 15
    LEARNING_RATE = 1e-4

    # Set the device globally
    global DEVICE
    DEVICE = get_device()

    # Initialize datasets for each task
    task_datasets = []
    all_class_to_idx = {}

    # Iteratively load datasets for each task
    for task_folder in os.listdir(BASE_PATH):
        task_path = os.path.join(BASE_PATH, task_folder)
        if os.path.isdir(task_path):
            task_dataset = MedicalImageDataset(task_path, transform=transform)
            
            # Only add non-empty datasets
            if len(task_dataset) > 0:
                task_datasets.append(task_dataset)
                all_class_to_idx.update(task_dataset.class_to_idx)

    # Combine all datasets
    full_dataset = ConcatDataset(task_datasets)

    # Print total dataset size
    total_size = len(full_dataset)
    print(f"Total dataset size: {total_size}")

    if total_size == 0:
        raise ValueError("No images found in the dataset. Please check your dataset directory structure.")

    # Split dataset
    train_size = max(1, int(0.7 * total_size))
    valid_size = max(1, int(0.2 * total_size))
    test_size = total_size - train_size - valid_size

    train_dataset, valid_dataset, test_dataset = random_split(
        full_dataset, 
        [train_size, valid_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    # DataLoaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=True, 
        num_workers=4,
        pin_memory=True,
        persistent_workers=True
    )

    valid_loader = DataLoader(
        valid_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=False, 
        num_workers=4,
        pin_memory=True,
        persistent_workers=True
    )

    num_classes = len(all_class_to_idx)
    print(f"Number of classes detected: {num_classes}")

    # Create and train model
    model = create_model(num_classes)
    trained_model = train_model(model, train_loader, valid_loader, num_epochs=EPOCHS, learning_rate=LEARNING_RATE)

    # Final model save
    torch.save(trained_model.state_dict(), "final_medical_vit_model.pth")
    print("✅ Model training complete!")

In [75]:
if __name__ == "__main__":
    main()

✅ GPU Available: Tesla P100-PCIE-16GB
   CUDA Version: 12.1
   cuDNN Version: 90100
   Total GPU Memory: 15.89 GB
Found 1129 images across 10 classes
Class mapping: {'Avulsion fracture': 0, 'Spiral Fracture': 1, 'Impacted fracture': 2, 'Hairline Fracture': 3, 'Greenstick fracture': 4, 'Pathological fracture': 5, 'Oblique fracture': 6, 'Fracture Dislocation': 7, 'Longitudinal fracture': 8, 'Comminuted fracture': 9}
Found 387 images across 7 classes
Class mapping: {'squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa': 0, 'normal': 1, 'large.cell.carcinoma_left.hilum_T2_N2_M0_IIIa': 2, 'adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib': 3, 'squamous.cell.carcinoma': 4, 'adenocarcinoma': 5, 'large.cell.carcinoma': 6}
Found 3297 images across 2 classes
Class mapping: {'benign': 0, 'malignant': 1}
Found 7023 images across 4 classes
Class mapping: {'pituitary': 0, 'notumor': 1, 'meningioma': 2, 'glioma': 3}
Total dataset size: 11836
Number of classes detected: 23


  with torch.cuda.amp.autocast():
Epoch 1/15: 100%|██████████| 259/259 [00:31<00:00,  8.30it/s, Loss=0.6994, Accuracy=51.29%]
  with torch.cuda.amp.autocast():


Epoch 1: Train Loss: 1.2416, Train Accuracy: 51.29%, Val Loss: 0.9039, Val Accuracy: 64.93%


Epoch 2/15: 100%|██████████| 259/259 [00:30<00:00,  8.41it/s, Loss=1.0434, Accuracy=65.03%]


Epoch 2: Train Loss: 0.8850, Train Accuracy: 65.03%, Val Loss: 0.7581, Val Accuracy: 71.65%


Epoch 3/15: 100%|██████████| 259/259 [00:30<00:00,  8.43it/s, Loss=0.6284, Accuracy=70.33%]


Epoch 3: Train Loss: 0.7622, Train Accuracy: 70.33%, Val Loss: 0.7122, Val Accuracy: 73.38%


Epoch 4/15: 100%|██████████| 259/259 [00:30<00:00,  8.41it/s, Loss=0.3240, Accuracy=71.99%]


Epoch 4: Train Loss: 0.7135, Train Accuracy: 71.99%, Val Loss: 0.6742, Val Accuracy: 72.96%


Epoch 5/15: 100%|██████████| 259/259 [00:30<00:00,  8.42it/s, Loss=0.7716, Accuracy=74.25%]


Epoch 5: Train Loss: 0.6580, Train Accuracy: 74.25%, Val Loss: 0.7164, Val Accuracy: 71.36%


Epoch 6/15: 100%|██████████| 259/259 [00:30<00:00,  8.43it/s, Loss=0.5874, Accuracy=74.79%]


Epoch 6: Train Loss: 0.6445, Train Accuracy: 74.79%, Val Loss: 0.6459, Val Accuracy: 76.00%


Epoch 7/15: 100%|██████████| 259/259 [00:30<00:00,  8.43it/s, Loss=0.8248, Accuracy=75.82%]


Epoch 7: Train Loss: 0.6049, Train Accuracy: 75.82%, Val Loss: 0.6293, Val Accuracy: 75.67%


Epoch 8/15: 100%|██████████| 259/259 [00:30<00:00,  8.42it/s, Loss=0.7656, Accuracy=77.45%]


Epoch 8: Train Loss: 0.5804, Train Accuracy: 77.45%, Val Loss: 0.5946, Val Accuracy: 76.17%


Epoch 9/15: 100%|██████████| 259/259 [00:30<00:00,  8.44it/s, Loss=0.9229, Accuracy=78.27%]


Epoch 9: Train Loss: 0.5536, Train Accuracy: 78.27%, Val Loss: 0.5981, Val Accuracy: 77.02%


Epoch 10/15: 100%|██████████| 259/259 [00:30<00:00,  8.42it/s, Loss=0.5004, Accuracy=79.53%]


Epoch 10: Train Loss: 0.5244, Train Accuracy: 79.53%, Val Loss: 0.5726, Val Accuracy: 78.20%


Epoch 11/15: 100%|██████████| 259/259 [00:30<00:00,  8.41it/s, Loss=0.5787, Accuracy=80.42%]


Epoch 11: Train Loss: 0.5107, Train Accuracy: 80.42%, Val Loss: 0.5720, Val Accuracy: 78.16%


Epoch 12/15: 100%|██████████| 259/259 [00:30<00:00,  8.44it/s, Loss=0.5964, Accuracy=80.74%]


Epoch 12: Train Loss: 0.4877, Train Accuracy: 80.74%, Val Loss: 0.5679, Val Accuracy: 79.17%


Epoch 13/15: 100%|██████████| 259/259 [00:30<00:00,  8.43it/s, Loss=0.4353, Accuracy=81.73%]


Epoch 13: Train Loss: 0.4687, Train Accuracy: 81.73%, Val Loss: 0.5422, Val Accuracy: 78.92%


Epoch 14/15: 100%|██████████| 259/259 [00:30<00:00,  8.44it/s, Loss=0.5057, Accuracy=82.32%]


Epoch 14: Train Loss: 0.4593, Train Accuracy: 82.32%, Val Loss: 0.5468, Val Accuracy: 79.38%


Epoch 15/15: 100%|██████████| 259/259 [00:30<00:00,  8.41it/s, Loss=0.4811, Accuracy=82.03%]


Epoch 15: Train Loss: 0.4570, Train Accuracy: 82.03%, Val Loss: 0.5557, Val Accuracy: 79.17%
✅ Model training complete!


# Model Validation and Testing Script

In [77]:
# Data Loading and Preprocessing (Similar to Training Script)
class MedicalImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        
        self.load_images()

    def load_images(self):
        for subset in ['test', 'val']:  # Focus on test and validation sets
            subset_path = os.path.join(self.root_dir, subset)
            
            if os.path.exists(subset_path):
                for class_name in os.listdir(subset_path):
                    class_path = os.path.join(subset_path, class_name)
                    
                    if os.path.isdir(class_path):
                        if class_name not in self.class_to_idx:
                            self.class_to_idx[class_name] = len(self.class_to_idx)
                        
                        for img_name in os.listdir(class_path):
                            img_path = os.path.join(class_path, img_name)
                            
                            if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tif')):
                                self.image_paths.append(img_path)
                                self.labels.append(self.class_to_idx[class_name])

        print(f"Found {len(self.image_paths)} test images across {len(self.class_to_idx)} classes")
        print("Class mapping:", self.class_to_idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

        if self.transform:
            image = self.transform(image)

        return image, label

In [76]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, 
    classification_report, 
    precision_recall_fscore_support
)
from vit_pytorch import ViT

In [78]:
# Transforms (Same as Training Script)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [79]:
def create_model(num_classes):
    model = ViT(
        image_size=224,
        patch_size=16,
        num_classes=num_classes,
        dim=512,
        depth=6,
        heads=8,
        mlp_dim=1024,
        dropout=0.1,
        emb_dropout=0.1
    )
    return model


In [87]:
def validate_model(model, test_loader, device):
    """
    Comprehensive model validation function
    
    Args:
        model (nn.Module): Trained model
        test_loader (DataLoader): Test data loader
        device (torch.device): Computing device
    
    Returns:
        dict: Validation metrics
    """
    model.eval()
    all_preds = []
    all_labels = []
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            
            total_predictions += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Compute metrics
    accuracy = 100 * correct_predictions / total_predictions
    
    # Detailed Classification Report
    report = classification_report(
        all_labels, 
        all_preds, 
        target_names=list(test_loader.dataset.class_to_idx.keys())

    )
    
    # Precision, Recall, F1-Score
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted'
    )
    
    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_preds)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'classification_report': report,
        'confusion_matrix': cm
    }

In [88]:
def plot_confusion_matrix(cm, classes, title='Confusion Matrix'):
    """
    Plot confusion matrix using seaborn
    """
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=classes, 
                yticklabels=classes)
    plt.title(title)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.close()

In [94]:
def main():
    # Configuration
    BASE_PATH = r"/kaggle/input/updated-medical-image-dataset/new_dataset"  # Update with your path
    MODEL_PATH = "/kaggle/working/final_medical_vit_model.pth"
    BATCH_SIZE = 32

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Test Dataset
    task_datasets = []
    all_class_to_idx = {}

    for task_folder in os.listdir(BASE_PATH):
        task_path = os.path.join(BASE_PATH, task_folder)
        if os.path.isdir(task_path):
            task_dataset = MedicalImageDataset(task_path, transform=transform)
            
            if len(task_dataset) > 0:
                task_datasets.append(task_dataset)
                all_class_to_idx.update(task_dataset.class_to_idx)

    # Test DataLoader
    test_loader = DataLoader(
        task_datasets[0],  # Use first dataset for testing
        batch_size=BATCH_SIZE, 
        shuffle=False, 
        num_workers=4
    )

    # Model
    num_classes = len(all_class_to_idx)
    model = create_model(num_classes)
    
    # Load trained weights
    model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))
    model = model.to(device)

    # Validation
    validation_results = validate_model(model, test_loader, device)

    # Print Results
    print("\n🔍 Model Validation Results:")
    print(f"Accuracy: {validation_results['accuracy']:.2f}%")
    print(f"Precision: {validation_results['precision']:.4f}")
    print(f"Recall: {validation_results['recall']:.4f}")
    print(f"F1 Score: {validation_results['f1_score']:.4f}")

    # Print Detailed Classification Report
    print("\n📊 Classification Report:")
    print(validation_results['classification_report'])

    # Plot Confusion Matrix
    plot_confusion_matrix(
        validation_results['confusion_matrix'], 
        list(test_loader.dataset.class_to_idx.keys())
    )


In [95]:
if __name__ == "__main__":
    main()

Using device: cuda
Found 140 test images across 10 classes
Class mapping: {'Avulsion fracture': 0, 'Spiral Fracture': 1, 'Impacted fracture': 2, 'Hairline Fracture': 3, 'Greenstick fracture': 4, 'Pathological fracture': 5, 'Oblique fracture': 6, 'Fracture Dislocation': 7, 'Longitudinal fracture': 8, 'Comminuted fracture': 9}
Found 387 test images across 7 classes
Class mapping: {'squamous.cell.carcinoma': 0, 'normal': 1, 'adenocarcinoma': 2, 'large.cell.carcinoma': 3, 'squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa': 4, 'large.cell.carcinoma_left.hilum_T2_N2_M0_IIIa': 5, 'adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib': 6}
Found 660 test images across 2 classes
Class mapping: {'benign': 0, 'malignant': 1}
Found 1311 test images across 4 classes
Class mapping: {'pituitary': 0, 'notumor': 1, 'meningioma': 2, 'glioma': 3}

🔍 Model Validation Results:
Accuracy: 22.86%
Precision: 0.3096
Recall: 0.2286
F1 Score: 0.2197

📊 Classification Report:
                       precision    recall  f1-sc