In [1]:
import pandas as pd
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
from PIL import Image, ImageFile

In [2]:
# Load the FracAtlas metadata table. The path is assembled with os.path.join:
# the earlier literal relied on the unrecognised escapes "\F" and "\d" (which
# current Python versions warn about) and only resolved on Windows.
data = pd.read_csv(os.path.join('datasets', 'FracAtlas', 'dataset.csv'))

In [3]:
# Preview the first two rows of the metadata table.
data.head(2)

Unnamed: 0,image_id,hand,leg,hip,shoulder,mixed,hardware,multiscan,fractured,fracture_count,frontal,lateral,oblique
0,IMG0000000.jpg,0,1,0,0,0,0,1,0,0,1,1,0
1,IMG0000001.jpg,0,1,0,0,0,0,1,0,0,1,1,0


In [4]:
# Image root and its two class sub-folders. Every path is built with
# os.path.join so no backslash escapes ("\F", "\i") end up inside a string
# literal and the notebook also runs on non-Windows systems.
base_path = os.path.join('datasets', 'FracAtlas', 'images')
fractured_path = os.path.join(base_path, 'Fractured')
non_fractured_path = os.path.join(base_path, 'Non_fractured')
non_fractured_path

'datasets\\FracAtlas\\images\\Non_fractured'

In [5]:
# Map every file name to its full path, walking the fractured folder first and
# the non-fractured folder second (a repeated name keeps the later path).
image_paths = {
    file_name: os.path.join(class_dir, file_name)
    for class_dir in (fractured_path, non_fractured_path)
    for file_name in os.listdir(class_dir)
}

Preprocessing: image transforms and dataset loading

In [6]:
# Root folder that contains the class sub-directories. Built with os.path.join
# so the string holds no backslash escapes ("\F", "\i") and is OS-independent.
data_directory = os.path.join('datasets', 'FracAtlas', 'images')

# Per-image pipeline: resize to 224x224, convert to a tensor, then normalise
# each channel with the given mean/std (these are the commonly used ImageNet
# statistics -- presumably chosen to match the pretrained backbones).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [7]:
# Let PIL load image files whose data ends early instead of raising an error
# (the part of the image that is missing is simply not decoded).
ImageFile.LOAD_TRUNCATED_IMAGES = True

def custom_loader(path):
    """Open the image file at `path` with PIL and return it converted to RGB.

    The file is opened in binary mode and closed again once the converted
    image has been produced.
    """
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image
        
# Build the labelled image dataset from the folders under data_directory;
# files are decoded by custom_loader and then passed through `transform`.
dataset = datasets.ImageFolder(root=data_directory, transform=transform, loader=custom_loader)


In [8]:
# Reserve 80% of the samples (rounded down) for training.
train_size = int(0.8 * len(dataset))
train_size

3266

In [9]:
# Everything not used for training becomes the validation set.
val_size = len(dataset) - train_size
val_size

817

In [10]:
# Split with a fixed-seed generator so the same train/validation partition is
# produced on every run; without it the split (and therefore every reported
# accuracy) changes each time the kernel is restarted.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)
train_dataset

<torch.utils.data.dataset.Subset at 0x17a179e4dc0>

In [11]:
# Both loaders yield batches of 32; only the training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
)
train_loader

<torch.utils.data.dataloader.DataLoader at 0x17a179e4ee0>

Transfer learning: helper that replaces a model's classification head

In [12]:
def modify_model(model, num_classes=2):
    """Replace the model's final linear layer with one that has `num_classes` outputs.

    Two head layouts are recognised:

    * an ``fc`` attribute (e.g. ResNet): it is replaced by a new ``nn.Linear``
      with the same ``in_features``;
    * a ``classifier`` attribute (e.g. DenseNet): if it is an ``nn.Sequential``
      only its last child is replaced and the earlier children are kept,
      otherwise the whole attribute is replaced by a new ``nn.Linear``.

    The model is modified in place and also returned.

    Raises:
        Exception: if the model has neither an ``fc`` nor a ``classifier``
            attribute.
    """
    # ResNet-style head.
    if hasattr(model, 'fc'):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    if not hasattr(model, 'classifier'):
        raise Exception("Unknown model architecture")

    head = model.classifier
    if isinstance(head, nn.Sequential):
        # Keep every child except the last, which is swapped for the new layer.
        *kept_layers, final_layer = head.children()
        replacement = nn.Linear(final_layer.in_features, num_classes)
        model.classifier = nn.Sequential(*kept_layers, replacement)
    else:
        model.classifier = nn.Linear(head.in_features, num_classes)
    return model

Training function

In [13]:
def train_model(model, criterion, optimizer, num_epochs=1,
                train_data_loader=None, val_data_loader=None):
    """Train `model` and report training/validation accuracy after each epoch.

    Args:
        model: network to train; it is updated in place.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        num_epochs: number of passes over the training data.
        train_data_loader: iterable of ``(inputs, labels)`` training batches.
            When omitted, the notebook-level ``train_loader`` is used.
        val_data_loader: iterable of ``(inputs, labels)`` validation batches.
            When omitted, the notebook-level ``val_loader`` is used.

    Returns:
        The same `model` object, left in evaluation mode when at least one
        epoch was run.

    For every epoch one line is printed with the mean training loss per batch,
    the training accuracy and the validation accuracy. Accuracies are computed
    over the samples actually seen, so they stay correct when a loader drops
    its last batch, and an empty loader yields 0.0 instead of a division error.
    """
    # Fall back to the notebook-level loaders so existing calls keep working.
    if train_data_loader is None:
        train_data_loader = train_loader
    if val_data_loader is None:
        val_data_loader = val_loader

    # Move each batch to wherever the model lives (a no-op for a CPU model).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(num_epochs):
        model.train()
        loss_sum, num_batches = 0.0, 0
        train_correct, train_seen = 0, 0

        for inputs, labels in train_data_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            num_batches += 1
            train_correct += (outputs.argmax(1) == labels).sum().item()
            train_seen += labels.size(0)

        # Mean per-batch loss: comparable across dataset and batch sizes,
        # unlike a raw sum of batch losses.
        train_loss = loss_sum / max(num_batches, 1)
        train_accuracy = train_correct / max(train_seen, 1)

        model.eval()
        val_correct, val_seen = 0, 0
        with torch.no_grad():
            for inputs, labels in val_data_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_correct += (outputs.argmax(1) == labels).sum().item()
                val_seen += labels.size(0)

        val_accuracy = val_correct / max(val_seen, 1)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")

    return model

In [14]:
# Directory that receives the trained weights. exist_ok=True makes the call
# safe to re-run and avoids the check-then-create race of testing
# os.path.exists() first.
model_save_dir = 'models'
os.makedirs(model_save_dir, exist_ok=True)

In [15]:
# Architectures to fine-tune; each entry is a torchvision constructor.
models_to_try = [models.efficientnet_b0, models.resnet18, models.densenet121]

for base_model_func in models_to_try:
    print(f"Training {base_model_func.__name__}")
    # NOTE(review): newer torchvision releases deprecate `pretrained=True` in
    # favour of the `weights=` argument; kept as-is because the installed
    # version is not visible from here -- confirm before upgrading.
    base_model = base_model_func(pretrained=True)

    # Only the classification head is handed to the optimizer below, so freeze
    # the pretrained backbone: this skips computing gradients that would never
    # be applied, without changing which parameters are trained.
    for param in base_model.parameters():
        param.requires_grad = False

    model = modify_model(base_model)
    criterion = nn.CrossEntropyLoss()

    # Locate the head that modify_model adapted.
    if hasattr(model, 'fc'):
        head = model.fc
    elif hasattr(model, 'classifier'):
        head = model.classifier
    else:
        raise Exception("Unknown model architecture")

    # Every parameter of the head stays trainable, exactly the set that was
    # optimised before the backbone was frozen.
    for param in head.parameters():
        param.requires_grad = True
    optimizer = optim.Adam(head.parameters(), lr=0.001)

    trained_model = train_model(model, criterion, optimizer)

    # Save the trained weights, one file per architecture.
    model_path = os.path.join(model_save_dir, f"{base_model_func.__name__}_model.pth")
    torch.save(trained_model.state_dict(), model_path)


Training efficientnet_b0


