<a href="https://colab.research.google.com/github/SujayKrish03/Medical-image-classifier/blob/main/medicalmain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Medical and Non-Medical Image Classifier**




This notebook trains a small CNN in PyTorch to classify images as medical or non-medical. Chest X-rays from Kaggle serve as the medical class and CIFAR-10 photos as the non-medical class. The first cell imports PyTorch and torchvision for preprocessing, dataset loading, and training. It also imports `requests`, `BeautifulSoup`, and `urllib` for scraping images from web pages, and `os`, `shutil`, and PIL for file handling and image loading.

In [1]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import requests
from bs4 import BeautifulSoup
import urllib.parse  # provides urljoin, used by download_images_from_url
import urllib.request
import os
from PIL import Image
import shutil

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
!pip install kaggle



In [4]:


# Upload kaggle.json (Kaggle API credentials)
from google.colab import files
uploaded = files.upload()
if not os.path.exists('/content/kaggle.json'):
    # Raise instead of calling exit() so the run stops with a clear error
    raise FileNotFoundError("kaggle.json was not uploaded correctly.")

Saving kaggle.json to kaggle.json


In [5]:
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [6]:
!kaggle datasets list || echo "Kaggle API setup failed. Check kaggle.json."

ref                                                        title                                                    size  lastUpdated                 downloadCount  voteCount  usabilityRating  
---------------------------------------------------------  -------------------------------------------------  ----------  --------------------------  -------------  ---------  ---------------  
rohitgrewal/airlines-flights-data                          Airlines Flights Data                                 2440299  2025-07-29 09:16:00.463000           8978        161  1.0              
wasiqaliyasir/breast-cancer-dataset                        Breast cancer dataset                                   49830  2025-07-30 12:52:44.057000           5326        179  1.0              
kunshbhatia/delhi-air-quality-dataset                      Delhi Air Quality Dataset                               30430  2025-07-28 14:00:14.247000           3758         82  1.0              
abdulmalik1518/cars-datasets-2

In [7]:
os.makedirs('/content/data/chest_xray', exist_ok=True)

Downloading the medical dataset (chest X-ray pneumonia images) from Kaggle

In [8]:
!kaggle datasets download -d paultimothymooney/chest-xray-pneumonia
!unzip -q chest-xray-pneumonia.zip -d /content/data/chest_xray/

Dataset URL: https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia
License(s): other
Downloading chest-xray-pneumonia.zip to /content
100% 2.29G/2.29G [00:16<00:00, 174MB/s]
100% 2.29G/2.29G [00:16<00:00, 149MB/s]


In [9]:
from torchvision.datasets import CIFAR10
cifar10_dataset = CIFAR10(root='/content/data/cifar10', train=True, download=True)

100%|██████████| 170M/170M [00:05<00:00, 31.1MB/s]


In [10]:
os.makedirs('/content/data/train/medical', exist_ok=True)
os.makedirs('/content/data/train/non-medical', exist_ok=True)

In [11]:
!mv /content/data/chest_xray/chest_xray/train/NORMAL/* /content/data/train/medical/ 2>/dev/null
!mv /content/data/chest_xray/chest_xray/train/PNEUMONIA/* /content/data/train/medical/ 2>/dev/null


In [12]:
for i, (image, _) in enumerate(cifar10_dataset):
    image.save(f'/content/data/train/non-medical/image_{i}.jpg')

In [13]:
!ls /content/data/train/medical | wc -l  # Count medical images
!ls /content/data/train/non-medical | wc -l  # Count non-medical images

5216
50002
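The counts above show an imbalance of almost 10:1 (50,002 non-medical vs. 5,216 medical images), while the notebook trains with an unweighted `nn.CrossEntropyLoss`. One common remedy, sketched here as an assumption rather than part of the original pipeline, is to weight each class by inverse frequency, `w_c = N / (K * n_c)`, so that both classes contribute the same total weight to the loss. The counts are copied from the output above; `class_weights` and `weighted_criterion` are illustrative names.

```python
import torch
import torch.nn as nn

# Image counts reported by the `ls | wc -l` cell above. Index order matches
# ImageFolder's sorted classes: 0 = 'medical', 1 = 'non-medical'.
class_counts = torch.tensor([5216, 50002], dtype=torch.float)

# Inverse-frequency weighting: w_c = N / (K * n_c), so each class
# contributes the same total weight N / K to the loss.
num_samples = class_counts.sum()
num_classes = len(class_counts)
class_weights = num_samples / (num_classes * class_counts)
print(class_weights)  # roughly [5.29, 0.55]

# Hypothetical replacement for the unweighted criterion used later.
weighted_criterion = nn.CrossEntropyLoss(weight=class_weights)
```

In the notebook the weights would need to be on the same device as the model outputs, e.g. `nn.CrossEntropyLoss(weight=class_weights.to(device))`.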


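This transform resizes every image to 224x224, converts it to a PyTorch tensor with values in [0, 1], and normalizes each channel as (x - mean) / std with mean 0.485 and standard deviation 0.229. These are the ImageNet statistics for the red channel, applied here to all three channels.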

In [14]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.485, 0.485], std=[0.229, 0.229, 0.229])
])

In [15]:
try:
    train_dataset = datasets.ImageFolder(root="/content/data/train", transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    print(f"Dataset loaded with classes: {train_dataset.classes}")
except Exception as e:
    # Re-raise so later cells do not run without a dataset
    raise RuntimeError(f"Error loading dataset: {e}") from e

Dataset loaded with classes: ['medical', 'non-medical']


This cell defines `SimpleCNN`, a minimal convolutional network built on PyTorch's `nn.Module`. In `__init__`, `conv1` maps the 3-channel input to 16 feature maps with 3x3 kernels and `padding=1`, so a 224x224 input stays 224x224. `pool` is a 2x2 max-pool that halves each spatial dimension to 112x112, and `fc1` maps the flattened 16 x 112 x 112 = 200,704 features to two logits, one per class (`medical`, `non-medical`). The `forward` method applies convolution, ReLU, pooling, flattening, and the linear layer in that order. Because the input size of `fc1` is fixed, the model only accepts 224x224 images, which the `Resize((224, 224))` transform guarantees. Grayscale X-rays still arrive with 3 channels because `ImageFolder`'s default loader converts every image to RGB.

In [16]:
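class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)  # 3x224x224 -> 16x224x224
        self.pool = nn.MaxPool2d(2, 2)                # 16x224x224 -> 16x112x112
        self.fc1 = nn.Linear(16 * 112 * 112, 2)       # 2 classes: medical, non-medical

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        # Flatten everything except the batch dimension; unlike view(-1, ...),
        # a wrong input size raises an error in fc1 instead of changing the batch size
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x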
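The layer sizes in `SimpleCNN` can be verified with a dummy batch. The sketch below repeats the class definition so it runs on its own (`model_check` and the random input are illustrative) and confirms that a 224x224 input reaches `fc1` as 16 x 112 x 112 features. Almost all of the roughly 402k parameters sit in `fc1` (200,704 x 2 weights plus 2 biases); `conv1` has only 448.

```python
import torch
import torch.nn as nn

# Same architecture as the notebook's SimpleCNN, repeated so this cell is self-contained.
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 112 * 112, 2)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = torch.flatten(x, 1)
        return self.fc1(x)

model_check = SimpleCNN()
dummy = torch.randn(4, 3, 224, 224)  # a fake batch of four 224x224 RGB images

with torch.no_grad():
    features = model_check.pool(torch.relu(model_check.conv1(dummy)))
    logits = model_check(dummy)

num_params = sum(p.numel() for p in model_check.parameters())
print(features.shape)  # torch.Size([4, 16, 112, 112])
print(logits.shape)    # torch.Size([4, 2])
print(num_params)      # 401858
```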

In [17]:
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

This cell trains `SimpleCNN` for 5 epochs. At the start of each epoch it switches the model to training mode. Then, for every batch from `train_loader`, it moves images and labels to `device` (GPU or CPU), runs the forward pass, computes the cross-entropy loss with `criterion`, backpropagates, and updates the weights with the Adam `optimizer`. Batch losses are summed and divided by the number of batches, so each printed value is that epoch's mean training loss.

In [None]:
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

Epoch [1/5], Loss: 0.0960
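Average loss alone says little when one class outnumbers the other roughly 10:1, and the notebook does not evaluate the model. A minimal sketch of a per-class accuracy check, assuming a labelled `DataLoader` like `train_loader` (ideally built from held-out images), could look like this. `per_class_accuracy` and `AlwaysNonMedical` are hypothetical names: the toy model always predicts class 1 (`non-medical`) and still reaches 90% overall accuracy on a 1:9 toy set while never recognizing a medical image.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def per_class_accuracy(model, loader, device, num_classes=2):
    """Return (overall accuracy, list of per-class accuracies) on a labelled loader."""
    model.eval()
    correct = torch.zeros(num_classes)
    total = torch.zeros(num_classes)
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            for c in range(num_classes):
                mask = labels == c
                total[c] += mask.sum().item()
                correct[c] += (preds[mask] == c).sum().item()
    overall = (correct.sum() / total.sum()).item()
    per_class = [(correct[c] / total[c]).item() if total[c] > 0 else float("nan")
                 for c in range(num_classes)]
    return overall, per_class

class AlwaysNonMedical(nn.Module):
    """Toy model that ignores its input and always predicts class 1."""
    def forward(self, x):
        logits = x.new_zeros(x.shape[0], 2)
        logits[:, 1] = 1.0
        return logits

# Toy data: 1 'medical' (label 0) and 9 'non-medical' (label 1) samples.
toy_images = torch.randn(10, 3, 8, 8)
toy_labels = torch.tensor([0] + [1] * 9)
toy_loader = DataLoader(TensorDataset(toy_images, toy_labels), batch_size=4)

overall, per_class = per_class_accuracy(AlwaysNonMedical(), toy_loader, torch.device("cpu"))
print(f"overall={overall:.2f}, per-class={per_class}")  # overall=0.90, per-class=[0.0, 1.0]
```

In the notebook it could be called as `per_class_accuracy(model, train_loader, device)`. It leaves the model in eval mode, which the training loop resets with `model.train()` at the start of each epoch.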


In [None]:
torch.save(model.state_dict(), 'medical_non_medical_model.pth')
print("Model saved as medical_non_medical_model.pth")

In [None]:
def download_images_from_url(url, save_dir="web_images"):
    os.makedirs(save_dir, exist_ok=True)

    # Browser-like User-Agent: some sites reject requests from default Python clients
    headers = {'User-Agent': 'Mozilla/5.0'}
    try:
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')
    except requests.RequestException as e:
        print(f"Error fetching webpage {url}: {e}")
        return []

    # Collect <img> sources that look like JPEG/PNG files and make them absolute
    img_urls = []
    for img in soup.find_all('img'):
        src = img.get('src')
        if src and src.lower().endswith(('.jpg', '.jpeg', '.png')):
            # urljoin resolves relative and //host/... sources and leaves absolute URLs unchanged
            img_urls.append(urllib.parse.urljoin(url, src))

    downloaded_paths = []
    for i, img_url in enumerate(img_urls):
        try:
            # PIL detects the real format from the file contents, so the .jpg suffix is cosmetic
            img_name = os.path.join(save_dir, f'image_{i}.jpg')
            # Reuse the same headers (urlretrieve would send urllib's default User-Agent)
            img_response = requests.get(img_url, headers=headers, timeout=10)
            img_response.raise_for_status()
            with open(img_name, 'wb') as f:
                f.write(img_response.content)
            downloaded_paths.append((img_name, img_url))  # Store path and URL
            print(f"Downloaded: {img_name}")
        except Exception as e:
            print(f"Failed to download {img_url}: {e}")

    return downloaded_paths

In [None]:
def predict_image(image_path, model, transform, device):
    try:
        image = Image.open(image_path).convert('RGB')
        image = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(image)
            probabilities = torch.softmax(output, dim=1)
            _, predicted = torch.max(output, 1)
            class_names = ['medical', 'non-medical']
            predicted_class = class_names[predicted.item()]
            confidence = probabilities[0][predicted.item()].item()
        return predicted_class, confidence
    except Exception as e:
        return None, f"Error processing {image_path}: {e}"

In [None]:
model.load_state_dict(torch.load('medical_non_medical_model.pth', map_location=device))
model.eval()

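The `download_images_from_url` function scrapes a web page and downloads its images so the classifier can label them. It creates `save_dir` (default `"web_images"`) if needed, fetches the page with `requests`, and collects the `src` of every `<img>` tag ending in `.jpg`, `.jpeg`, or `.png`. It then resolves relative links against the page URL, saves each image locally, and returns a list of `(local_path, image_url)` pairs. The cells below run it on a Wikipedia article and classify each downloaded image with `predict_image`.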
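The filtering step inside `download_images_from_url` can be pulled into a small helper and tested without network access. In the sketch below, `resolve_image_url` is a hypothetical name, not part of the notebook. It resolves the `src` against the page URL with `urljoin` and checks the extension on the URL path only, so a source such as `photo.JPG?width=300` is still recognized as a JPEG while `.svg` icons are skipped.

```python
from urllib.parse import urljoin, urlparse

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

def resolve_image_url(page_url, src):
    """Return an absolute URL for an <img> src if it points to a JPEG/PNG file, else None."""
    if not src:
        return None
    # urljoin handles relative paths ("img/a.png"), root-relative paths ("/a.png"),
    # protocol-relative URLs ("//host/a.png") and leaves absolute URLs unchanged.
    absolute = urljoin(page_url, src)
    # Check the extension on the path only (ignoring ?query and #fragment), case-insensitively.
    path = urlparse(absolute).path.lower()
    return absolute if path.endswith(IMAGE_EXTENSIONS) else None

page = "https://en.wikipedia.org/wiki/Sachin_Tendulkar"
for src in ["//upload.wikimedia.org/a/b/Photo.JPG", "/static/logo.png?v=2", "icon.svg", None]:
    print(src, "->", resolve_image_url(page, src))
```

Inside the scraper, the loop body could then call `resolve_image_url(url, img.get('src'))` and append the result whenever it is not `None`.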

In [None]:
url = "https://en.wikipedia.org/wiki/Sachin_Tendulkar"
image_paths_urls = download_images_from_url(url)

In [None]:
for image_path, image_url in image_paths_urls:
    predicted_class, confidence = predict_image(image_path, model, transform, device)
    if predicted_class:
        print(f"Image: {image_path}, URL: {image_url}, Predicted: {predicted_class}, Confidence: {confidence:.4f}")
    else:
        print(f"Image: {image_path}, URL: {image_url}, Error: {confidence}")

In [None]:
!pip install pdf2image
!apt-get install -y poppler-utils


In [None]:
from pdf2image import convert_from_path

def extract_images_from_pdf(pdf_files, save_dir="pdf_images"):
    os.makedirs(save_dir, exist_ok=True)

    extracted_paths = []
    for pdf_file in pdf_files:
        try:
            # Render each PDF page as one image (embedded pictures are not extracted separately)
            images = convert_from_path(pdf_file)
            pdf_name = os.path.basename(pdf_file)
            for i, image in enumerate(images):
                img_name = os.path.join(save_dir, f'{pdf_name}_page_{i+1}.jpg')
                image.save(img_name, 'JPEG')
                extracted_paths.append((img_name, None, pdf_name))  # None for URL
                print(f"Extracted: {img_name} from {pdf_name}")
        except Exception as e:
            print(f"Error processing PDF {pdf_file}: {e}")

    return extracted_paths


In [None]:
from google.colab import files

# Upload PDF files
uploaded_pdfs = files.upload()
pdf_paths = list(uploaded_pdfs.keys())

# Extract and classify
extracted_images = extract_images_from_pdf(pdf_paths)

for image_path, _, source_pdf in extracted_images:
    predicted_class, confidence = predict_image(image_path, model, transform, device)
    if predicted_class:
        print(f"Image: {image_path}, Source PDF: {source_pdf}, Predicted: {predicted_class}, Confidence: {confidence:.4f}")
    else:
        print(f"Image: {image_path}, Source PDF: {source_pdf}, Error: {confidence}")


In [None]:
!pip install gradio pdf2image
!apt-get install -y poppler-utils


Implementing a simple Gradio UI for the classifier, with one tab each for a single uploaded image, a web page URL, and a PDF

In [None]:
import gradio as gr
from pdf2image import convert_from_path
import os
import tempfile
from PIL import Image

# Assume these are already defined and imported:
# - predict_image(image_path, model, transform, device) → (class_name, confidence) or (None, error_message)
# - download_images_from_url(url) → returns List[(img_path, img_url)]
# - extract_images_from_pdf(pdf_paths) → returns List[(img_path, None, pdf_name)]

def format_prediction(pred, conf):
    # On failure predict_image returns (None, error_message), so only format conf as a number on success
    return f"{pred} ({conf:.2f})" if pred else str(conf)

# 🔹 Predict for Single Image Upload
def predict_single_image(image):
    if image is None:
        return []

    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        temp_path = temp_file.name
        image.save(temp_path)

    pred, conf = predict_image(temp_path, model, transform, device)
    return [(temp_path, format_prediction(pred, conf))]

# 🔹 Predict from URL (Scrape and classify multiple images)
def predict_from_url(url):
    results = []
    image_paths = download_images_from_url(url)

    for img_path, img_url in image_paths:
        pred, conf = predict_image(img_path, model, transform, device)
        results.append((img_path, format_prediction(pred, conf)))
    return results

# 🔹 Predict from PDF (Extract images and classify)
def predict_from_pdf(pdf):
    if pdf is None:
        return []

    # Depending on the Gradio version, gr.File passes either a file path string or a
    # temp-file object whose .name is the path; use the path directly in both cases
    pdf_path = getattr(pdf, "name", pdf)

    extracted = extract_images_from_pdf([pdf_path])
    results = []
    for img_path, _, _ in extracted:
        pred, conf = predict_image(img_path, model, transform, device)
        results.append((img_path, format_prediction(pred, conf)))
    return results


In [None]:
with gr.Blocks() as demo:
    gr.Markdown("### 🩻 Medical vs Non-Medical Image Classifier")

    with gr.Tab("Single Image Upload"):
        img_input = gr.Image(type="pil")
        img_btn = gr.Button("Classify Image")
        img_out = gr.Gallery(label="Prediction", show_label=True, columns=3, height=300)


    with gr.Tab("From URL"):
        url_input = gr.Textbox(label="Enter Web Page URL")
        url_btn = gr.Button("Fetch & Classify")
        url_out = gr.Gallery(label="Predictions", show_label=True, columns=3, height=300)


    with gr.Tab("From PDF"):
        pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF")
        pdf_btn = gr.Button("Extract & Classify")
        pdf_out = gr.Gallery(label="Predictions", show_label=True)

    img_btn.click(fn=predict_single_image, inputs=img_input, outputs=img_out)
    url_btn.click(fn=predict_from_url, inputs=url_input, outputs=url_out)
    pdf_btn.click(fn=predict_from_pdf, inputs=pdf_input, outputs=pdf_out)

demo.launch()
