In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torchvision.models.resnet import ResNet50_Weights
from torchvision.models.densenet import DenseNet121_Weights
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
# cuda availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# transforming to imagenet format
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [3]:
# Loading the dataset
train_dataset = datasets.ImageFolder("/kaggle/input/fractatlas/train", transform=transform) #Update the path to dataset
val_dataset = datasets.ImageFolder("/kaggle/input/fractatlas/val", transform=transform) #Update the path to dataset

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

## Resnet

In [4]:
# Pretrained, SOTA resnet model for imagenet
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# freezing all layers in the model
for param in model.parameters():
    param.requires_grad = False

# Replacing number of classes to classify into, to our desired number
num_classes = 2 # Binary Classification - Fractured or Non_Fractured
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
model = model.to(device)


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 178MB/s] 


In [5]:
# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

In [6]:
# fine tuning step

num_epochs = 3
i=0
for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_loader:
        try:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            i+=1

        except OSError as e:
            print("Skipping a corrupted image", i)

    # Validation Accuracy
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    print(f"Epoch {epoch + 1}/{num_epochs}, Validation Accuracy: {100 * correct / total:.2f}%")


Epoch 1/3, Validation Accuracy: 82.38%
Epoch 2/3, Validation Accuracy: 82.38%
Epoch 3/3, Validation Accuracy: 82.54%


In [7]:
# Saving the model
torch.save(model, '/kaggle/working/model_resnet50.pth')

## DenseNet

In [8]:
# Pretrained, densenet model for imagenet
densenet_model = models.densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)


# freezing all layers in the model
for param in densenet_model.parameters():
    param.requires_grad = False

# Replacing number of classes to classify into, to our desired number
num_classes = 2  # Binary Classification - Fractured or Non_Fractured
num_ftrs = densenet_model.classifier.in_features
densenet_model.classifier = nn.Linear(num_ftrs, num_classes)
densenet_model = densenet_model.to(device)

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 149MB/s] 


In [9]:
# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(densenet_model.parameters(), lr=0.001, momentum=0.9)

In [10]:
# fine tuning step

num_epochs = 3
i=0
for epoch in range(num_epochs):
    densenet_model.train()
    for inputs, labels in train_loader:
        try:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = densenet_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            i+=1

        except OSError as e:
            print("Skipping a corrupted image", i)

    # Validation Accuracy
    densenet_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = densenet_model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    print(f"Epoch {epoch + 1}/{num_epochs}, Validation Accuracy: {100 * correct / total:.2f}%")


Epoch 1/3, Validation Accuracy: 83.03%
Epoch 2/3, Validation Accuracy: 86.30%
Epoch 3/3, Validation Accuracy: 86.30%


In [11]:
# Saving the model
torch.save(model, '/kaggle/working/model_densenet121.pth')