In [1]:
import os

base_folder_path = r'C:\Users\Meetp\OneDrive\Documents\ComputerVision\FinalProject\bonebreak\Bone Break Classification'
image_types = ['jpg', 'png']

for fracture_type in os.listdir(base_folder_path):
    fracture_path = os.path.join(base_folder_path, fracture_type)
    if os.path.isdir(fracture_path):
        count = 0
        for root, _, files in os.walk(fracture_path):
            for file in files:
                if any(file.lower().endswith(ext) for ext in image_types):
                    count += 1
        print(f"{fracture_type}: {count}")


Avulsion fracture: 123
Comminuted fracture: 148
Fracture Dislocation: 156
Greenstick fracture: 122
Hairline Fracture: 111
Impacted fracture: 84
Longitudinal fracture: 80
Not Fractured: 4908
Oblique fracture: 85
Pathological fracture: 134
Spiral Fracture: 86


In [2]:
# Stratified/Balanced Sampling (During Training)

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(base_folder_path, transform=transform)

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler
from collections import Counter
import os

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(base_folder_path, transform=transform)

# Recalculate class counts
targets = [label for _, label in dataset]
class_counts = Counter(targets)
num_samples = len(dataset)
class_weights = {cls: num_samples / count for cls, count in class_counts.items()}
sample_weights = [class_weights[label] for label in targets]

sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)


In [4]:
from collections import Counter

label_counts = Counter()

# Iterate through the full dataloader once
for _, labels in dataloader:
    label_counts.update(labels.tolist())

print("Sampled class distribution:")
for class_idx, count in sorted(label_counts.items()):
    print(f"Class {class_idx} ({dataset.classes[class_idx]}): {count} samples")


Sampled class distribution:
Class 0 (Avulsion fracture): 525 samples
Class 1 (Comminuted fracture): 545 samples
Class 2 (Fracture Dislocation): 591 samples
Class 3 (Greenstick fracture): 532 samples
Class 4 (Hairline Fracture): 570 samples
Class 5 (Impacted fracture): 565 samples
Class 6 (Longitudinal fracture): 523 samples
Class 7 (Not Fractured): 548 samples
Class 8 (Oblique fracture): 542 samples
Class 9 (Pathological fracture): 562 samples
Class 10 (Spiral Fracture): 534 samples


In [5]:
import torch.nn as nn
import torchvision.models as models
from torchvision.models import resnet18, ResNet18_Weights

class FractureClassifier(nn.Module):
    def __init__(self, num_fracture_types=10):  # not counting "not fractured"
        super(FractureClassifier, self).__init__()
        self.base = resnet18(weights=ResNet18_Weights.DEFAULT)
        in_features = self.base.fc.in_features
        self.base.fc = nn.Identity()

        # Head 1: Fracture or not (binary)
        self.fracture_detect = nn.Linear(in_features, 2)

        # Head 2: Type of fracture (multiclass)
        self.fracture_type = nn.Linear(in_features, num_fracture_types)

    def forward(self, x):
        features = self.base(x)
        out1 = self.fracture_detect(features)  # Binary output
        out2 = self.fracture_type(features)    # Multiclass output
        return out1, out2


In [6]:
all_labels = []  # original labels from dataset (0 = Not Fractured, 1-10 = fracture types)

binary_labels = [0 if label == 0 else 1 for label in all_labels]
fracture_type_labels = [label - 1 if label > 0 else -1 for label in all_labels]


In [7]:
import torch.nn.functional as F

loss_fn_binary = nn.CrossEntropyLoss()
loss_fn_type = nn.CrossEntropyLoss()


In [8]:
import torch
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = FractureClassifier()
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 25

for epoch in range(num_epochs):
    running_loss = 0.0
    correct_binary = 0
    total_binary = 0
    correct_type = 0
    total_type = 0

    model.train()
    loop = tqdm(dataloader, leave=True)

    for images, original_labels in loop:
        images = images.to(device)
        original_labels = original_labels.to(device)

        # Your original label conversions
        binary_labels = (original_labels != 0).long().to(device)    # Not Fractured = 0, Fractured = 1
        type_labels = (original_labels - 1).long().to(device)       # Shift classes for type prediction

        output1, output2 = model(images)

        loss_binary = loss_fn_binary(output1, binary_labels)

        # Mask: Only fractured images (binary = 1)
        mask = binary_labels == 1
        if mask.sum() > 0:
            loss_type = loss_fn_type(output2[mask], type_labels[mask])
            loss = loss_binary + loss_type
        else:
            loss = loss_binary

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Accuracy tracking
        pred_binary = torch.argmax(output1, dim=1)
        correct_binary += (pred_binary == binary_labels).sum().item()
        total_binary += binary_labels.size(0)

        if mask.sum() > 0:
            pred_type = torch.argmax(output2[mask], dim=1)
            correct_type += (pred_type == type_labels[mask]).sum().item()
            total_type += mask.sum().item()

        loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
        loop.set_postfix(loss=loss.item())

    binary_acc = 100 * correct_binary / total_binary
    type_acc = 100 * correct_type / total_type if total_type > 0 else 0

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss:.4f}, Binary Accuracy: {binary_acc:.2f}%, Type Accuracy: {type_acc:.2f}%")

    # Save model after each epoch
    torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pt")


Epoch [1/25]: 100%|██████████| 189/189 [03:50<00:00,  1.22s/it, loss=0.174] 


Epoch 1/25, Loss: 156.7177, Binary Accuracy: 96.36%, Type Accuracy: 81.26%


Epoch [2/25]: 100%|██████████| 189/189 [03:48<00:00,  1.21s/it, loss=0.061] 


Epoch 2/25, Loss: 11.0435, Binary Accuracy: 99.88%, Type Accuracy: 99.40%


Epoch [3/25]: 100%|██████████| 189/189 [03:54<00:00,  1.24s/it, loss=0.00693]


Epoch 3/25, Loss: 3.3406, Binary Accuracy: 100.00%, Type Accuracy: 99.84%


Epoch [4/25]: 100%|██████████| 189/189 [03:46<00:00,  1.20s/it, loss=0.0104] 


Epoch 4/25, Loss: 2.1572, Binary Accuracy: 100.00%, Type Accuracy: 99.87%


Epoch [5/25]: 100%|██████████| 189/189 [03:46<00:00,  1.20s/it, loss=0.00305]


Epoch 5/25, Loss: 0.9147, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [6/25]: 100%|██████████| 189/189 [03:44<00:00,  1.19s/it, loss=0.00408] 


Epoch 6/25, Loss: 0.5423, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [7/25]: 100%|██████████| 189/189 [03:43<00:00,  1.18s/it, loss=0.0125]  


Epoch 7/25, Loss: 0.4173, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [8/25]: 100%|██████████| 189/189 [03:49<00:00,  1.21s/it, loss=0.00184] 


Epoch 8/25, Loss: 0.3516, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [9/25]: 100%|██████████| 189/189 [03:44<00:00,  1.19s/it, loss=0.00245] 


Epoch 9/25, Loss: 0.3303, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [10/25]: 100%|██████████| 189/189 [03:44<00:00,  1.19s/it, loss=0.00156] 


Epoch 10/25, Loss: 0.2066, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [11/25]: 100%|██████████| 189/189 [03:46<00:00,  1.20s/it, loss=0.00159] 


Epoch 11/25, Loss: 0.1645, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [12/25]: 100%|██████████| 189/189 [03:53<00:00,  1.24s/it, loss=0.0143] 


Epoch 12/25, Loss: 46.9369, Binary Accuracy: 98.53%, Type Accuracy: 93.72%


Epoch [13/25]: 100%|██████████| 189/189 [03:47<00:00,  1.20s/it, loss=0.00698]


Epoch 13/25, Loss: 9.1098, Binary Accuracy: 99.85%, Type Accuracy: 99.00%


Epoch [14/25]: 100%|██████████| 189/189 [03:48<00:00,  1.21s/it, loss=0.00232] 


Epoch 14/25, Loss: 1.2435, Binary Accuracy: 100.00%, Type Accuracy: 99.95%


Epoch [15/25]: 100%|██████████| 189/189 [03:45<00:00,  1.19s/it, loss=0.0028]  


Epoch 15/25, Loss: 0.3998, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [16/25]: 100%|██████████| 189/189 [03:50<00:00,  1.22s/it, loss=0.000835]


Epoch 16/25, Loss: 0.2677, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [17/25]: 100%|██████████| 189/189 [03:45<00:00,  1.19s/it, loss=0.000678]


Epoch 17/25, Loss: 0.1866, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [18/25]: 100%|██████████| 189/189 [03:46<00:00,  1.20s/it, loss=0.00117] 


Epoch 18/25, Loss: 0.1578, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [19/25]: 100%|██████████| 189/189 [03:46<00:00,  1.20s/it, loss=0.0017]  


Epoch 19/25, Loss: 0.1232, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [20/25]: 100%|██████████| 189/189 [03:45<00:00,  1.19s/it, loss=0.00206] 


Epoch 20/25, Loss: 0.1198, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [21/25]: 100%|██████████| 189/189 [03:46<00:00,  1.20s/it, loss=0.000884]


Epoch 21/25, Loss: 0.0866, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [22/25]: 100%|██████████| 189/189 [03:45<00:00,  1.19s/it, loss=0.000369]


Epoch 22/25, Loss: 0.0787, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [23/25]: 100%|██████████| 189/189 [03:45<00:00,  1.19s/it, loss=0.00106] 


Epoch 23/25, Loss: 0.1210, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [24/25]: 100%|██████████| 189/189 [03:46<00:00,  1.20s/it, loss=0.000578]


Epoch 24/25, Loss: 0.0665, Binary Accuracy: 100.00%, Type Accuracy: 100.00%


Epoch [25/25]: 100%|██████████| 189/189 [03:45<00:00,  1.19s/it, loss=0.000297]

Epoch 25/25, Loss: 0.0572, Binary Accuracy: 100.00%, Type Accuracy: 100.00%





In [None]:
# Plots for loss and accuracy
import matplotlib.pyplot as plt
import numpy as np

train_losses = [
    156.7177, 11.0435, 3.3406, 2.1572, 0.9147, 0.5423, 0.4173, 0.3516, 0.3303, 0.2066,
    0.1645, 46.9369, 9.1098, 1.2435, 0.3998, 0.2677, 0.1866, 0.1578, 0.1232, 0.1198,
    0.0866, 0.0787, 0.1210, 0.0665, 0.0572
]

train_binary_accuracies = [
    96.36, 99.88, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0,
    100.0, 98.53, 99.85, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0,
    100.0, 100.0, 100.0, 100.0, 100.0
]

train_type_accuracies = [
    81.26, 99.40, 99.84, 99.87, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0,
    100.0, 93.72, 99.0, 99.95, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0,
    100.0, 100.0, 100.0, 100.0, 100.0
]

epochs = list(range(1, 26))

# --- Plot Loss ---
plt.figure(figsize=(8,6))
plt.plot(epochs, train_losses, label='Train Loss')
plt.title('Training Loss vs. Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid()
plt.show()

# --- Plot Binary Accuracy ---
plt.figure(figsize=(8,6))
plt.plot(epochs, train_binary_accuracies, label='Binary Accuracy', color='green')
plt.title('Binary Fracture Detection Accuracy vs. Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid()
plt.show()

# --- Plot Fracture Type Accuracy ---
plt.figure(figsize=(8,6))
plt.plot(epochs, train_type_accuracies, label='Fracture Type Accuracy', color='orange')
plt.title('Fracture Type Classification Accuracy vs. Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid()
plt.show()