In [13]:
from fastai.vision.all import *
import timm

In [15]:
# Step 1: Load and Prepare Your Data
# Root folder of the image dataset; this value is later handed to
# `dblock.dataloaders(...)` when the DataLoaders are built.
path = "./chest_xray/"  # Replace with the actual path to your dataset


In [17]:
# Describe how raw files become (image, category) training samples.
# Per-item step: resize every image to 460 before batching.
item_resize = Resize(460)
# Per-batch steps: augmentations producing 224-sized images, followed by
# normalisation with the ImageNet statistics.
batch_pipeline = [*aug_transforms(size=224), Normalize.from_stats(*imagenet_stats)]

dblock = DataBlock(
    # Inputs are images, targets are single categories.
    blocks=(ImageBlock, CategoryBlock),
    # Items come from `get_image_files`; the label is derived by `parent_label`.
    get_items=get_image_files,
    get_y=parent_label,
    # Hold out 20% of the items for validation; the fixed seed keeps the split stable.
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    item_tfms=item_resize,
    batch_tfms=batch_pipeline,
)

In [20]:
# Create the DataLoaders
# Applies the `dblock` recipe to the dataset located at `path`, using a batch size of 64.
dls = dblock.dataloaders(path, bs=64)

In [23]:
# Step 2: Create the Learner with Xception
# Checkpoint identifier used further down when loading/saving the trained model.
# NOTE(review): fastai's Learner.save/load appear to treat this value as a name
# relative to the learner's model directory and append ".pth" themselves --
# confirm where the file actually lands before relying on this literal path.
model_path = "./chest_xray_model.pth"

In [25]:
# Custom classifier: timm "xception" feature extractor + a fresh linear head.
class XceptionModel(nn.Module):
    """Xception backbone from timm followed by a new linear classification layer.

    Args:
        num_classes: Number of outputs produced by the linear head.
        pretrained: Forwarded to ``timm.create_model`` as its ``pretrained`` flag.
    """

    def __init__(self, num_classes=3, pretrained=True):
        super().__init__()
        backbone = timm.create_model("xception", pretrained=pretrained)
        self.base_model = backbone
        # New head sized from the backbone's reported feature width.
        self.fc = nn.Linear(backbone.num_features, num_classes)

    def forward(self, x):
        """Run the backbone's feature extractor, pool to 1x1, flatten, classify."""
        features = self.base_model.forward_features(x)
        # Global average pool over the spatial dimensions, then one row per sample.
        pooled = F.adaptive_avg_pool2d(features, (1, 1))
        flat = pooled.reshape(features.size(0), -1)
        return self.fc(flat)

In [27]:
# Build the Learner around the custom Xception model.
# The head size is read from the data (`dls.c`, the number of categories found
# by the DataLoaders) instead of a hard-coded 2, so the classifier cannot drift
# out of sync with the class folders on disk.
# NOTE(review): no `splitter` is passed, so the model presumably forms a single
# parameter group and the frozen phase of `fine_tune` may not actually freeze
# the backbone -- confirm against the fastai Learner documentation.
learn = Learner(dls, XceptionModel(num_classes=dls.c), metrics=accuracy)

In [28]:
# Reuse a previously trained checkpoint when one exists; otherwise train and save.
#
# fastai's `Learner.save(name)` / `Learner.load(name)` do not use `name` as a
# literal file path: they resolve it to
# `learn.path / learn.model_dir / f"{name}.pth"`.
# The existence check therefore has to look at that resolved location. Testing
# `Path(model_path)` directly never finds the file, so every run retrained the
# model from scratch.
checkpoint_file = learn.path / learn.model_dir / f"{model_path}.pth"

if checkpoint_file.exists():
    # A checkpoint written by an earlier run is available: restore it.
    learn.load(model_path)
else:
    # No checkpoint yet: train for 5 epochs with fastai's fine_tune schedule ...
    learn.fine_tune(5)
    # ... and persist the result so the next run can skip training.
    learn.save(model_path)

epoch,train_loss,valid_loss,accuracy,time
0,0.20358,2.460959,0.711358,2:56:14


epoch,train_loss,valid_loss,accuracy,time
0,0.099475,0.127784,0.962425,2:58:15
1,0.101834,0.082557,0.965841,2:30:55
2,0.091238,0.086764,0.969257,2:30:34
3,0.0721,0.093992,0.972673,2:30:09
4,0.049818,0.069485,0.978651,2:28:50


In [31]:
# Step 3: Make Predictions on a New Image
# Point `image_path` at the image file to classify.
# The path is assembled from components instead of a backslash-separated string
# so the same cell works on Windows, Linux and macOS.
# NOTE(review): this file lives under the dataset root that was randomly split
# for training, so it may have been seen during training -- confirm it is
# really held out before treating the result as a test prediction.
image_path = Path("chest_xray") / "val" / "NORMAL" / "NORMAL2-IM-1427-0001.jpeg"

img = PILImage.create(image_path)

# Make a prediction; only the first element of the returned tuple (the decoded
# label) is needed. Indexing avoids assigning to `_`, which IPython uses for the
# last cell output.
prediction = learn.predict(img)[0]
print(f"Predicted label: {prediction}")

  return getattr(torch, 'has_mps', False)


Predicted label: PNEUMONIA
