In [4]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import sys
import os
sys.path.append(os.path.abspath("../src/"))  
from data_loader import get_dataloaders
from model import build_model
from train import train_model
from evaluate import test_model



ImportError: cannot import name 'build_model' from 'model' (c:\Users\User\Desktop\Medical-Image-Classification\src\model.py)

In [46]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Using device:", device)

Using device: cpu


In [None]:
# Data-loading configuration
data_dir = "data"
batch_size = 32
img_size = 224

# get_dataloaders returns four values, unpacked below in the order it yields them
loaders = get_dataloaders(data_dir=data_dir, batch_size=batch_size, img_size=img_size)
train_loader, val_loader, test_loader, class_names = loaders

print("Classes:", class_names)

ImportError: cannot import name 'get_dataloaders' from 'src.data_loader' (c:\Users\User\Desktop\Medical-Image-Classification\src\data_loader.py)

In [None]:
# Build the model for len(class_names) classes and move it to the selected device
num_classes = len(class_names)
model = build_model(num_classes=num_classes, pretrained=True).to(device)

print(model)

In [None]:
# Training hyperparameters
epochs = 10
lr = 1e-3

# train_model hands back the model together with a history object
model, history = train_model(model, train_loader, val_loader, device, epochs=epochs, lr=lr)

In [None]:
# Run the project's test_model helper on the test loader with the trained model.
# NOTE(review): presumably reports per-class test metrics — confirm in src/evaluate.py.
test_model(model, test_loader, device, class_names)

In [None]:
# Persist only the model's state_dict (its parameters and buffers), not the module object.
checkpoint_path = "chest_xray_model.pth"
torch.save(model.state_dict(), checkpoint_path)
# The check-mark emoji was mis-decoded (UTF-8 bytes read as cp1252); restored here.
print(f"✅ Model saved as {checkpoint_path}")

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torchvision import transforms

def generate_gradcam(model, image, target_class, device, layer_name="layer4",
                     output_size=(224, 224)):
    """
    Generate a Grad-CAM heatmap for a single image and class.

    The output of ``layer_name`` is weighted channel-wise by the spatial mean
    of the gradient of the ``target_class`` score with respect to it, summed
    over channels, passed through ReLU, bilinearly resized and rescaled to
    ``[0, 1]``.

    Args:
        model: ``torch.nn.Module`` to explain. It is switched to eval mode and
            left that way. Parameter ``.grad`` attributes are not modified.
        image: Un-batched input tensor; a batch dimension is prepended here.
            Assumed to be ``(C, H, W)`` — TODO confirm against callers.
        target_class: Index into the model's output scores to explain.
        device: Device the image is moved to before the forward pass. The
            model itself is not moved by this function.
        layer_name: Key in ``model.named_modules()`` of the layer to inspect.
            Its output is assumed to be a 4-D ``(N, C, h, w)`` feature map.
        output_size: ``(height, width)`` of the returned heatmap, or ``None``
            to use the last two dimensions of ``image``.

    Returns:
        tuple: ``(cam, pred_class)`` where ``cam`` is a 2-D ``float32`` NumPy
        array with values in ``[0, 1]`` (all zeros if nothing contributes
        positively to the class) and ``pred_class`` is the arg-max class index
        of the model output as a Python ``int``.

    Raises:
        KeyError: If ``layer_name`` is not a module of ``model``.
    """
    model.eval()

    modules = dict(model.named_modules())
    if layer_name not in modules:
        raise KeyError(f"Layer {layer_name!r} not found in model.named_modules()")

    captured = []

    def forward_hook(module, inputs, output):
        # If everything upstream is frozen the feature map is not part of the
        # autograd graph; opt it in so the class score can be differentiated
        # with respect to it.
        if not output.requires_grad:
            output.requires_grad_(True)
        captured.append(output)

    handle = modules[layer_name].register_forward_hook(forward_hook)
    try:
        # enable_grad makes this work even when the caller is inside no_grad().
        with torch.enable_grad():
            output = model(image.unsqueeze(0).to(device))
            pred_class = output.argmax(dim=1).item()

            class_score = output[0, target_class]
            feature_map = captured[0]
            # autograd.grad returns the gradient directly instead of
            # accumulating into the parameters' .grad like backward() would.
            (grads,) = torch.autograd.grad(class_score, feature_map)
    finally:
        # Always detach the hook, even if the forward/backward pass raised.
        handle.remove()

    acts = feature_map.detach()[0].float()   # (C, h, w)
    grads = grads[0].float()                 # (C, h, w)

    # Channel weights: spatial mean of the gradients.
    weights = grads.mean(dim=(1, 2))
    cam = torch.relu((weights[:, None, None] * acts).sum(dim=0))

    if output_size is None:
        size = tuple(image.shape[-2:])
    else:
        size = tuple(int(s) for s in output_size)
    cam = torch.nn.functional.interpolate(
        cam[None, None], size=size, mode="bilinear", align_corners=False
    )[0, 0]

    # Rescale to [0, 1]; skip the division for an all-zero map to avoid NaNs.
    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:
        cam = cam / peak

    return cam.cpu().numpy(), pred_class


def show_gradcam(model, img_path, class_names, device, img_size=224):
    """
    Plot an image, its Grad-CAM heatmap and the overlay side by side.

    The image is read as grayscale, replicated to three channels, resized and
    normalised for the model, and the model's own predicted class is used as
    the Grad-CAM target.

    Args:
        model: ``torch.nn.Module`` to explain; it is switched to eval mode.
        img_path: Path of the image file to read with OpenCV.
        class_names: Sequence indexed by class id, used for the plot title.
        device: Device the input tensor is moved to for the forward pass.
        img_size: Side length the image is resized to before entering the model.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # NOTE(review): a single mean/std pair is applied to all three
        # channels — confirm this matches the preprocessing used in training.
        transforms.Normalize([0.485], [0.229]),
    ])
    input_tensor = transform(transforms.ToPILImage()(img_rgb))

    # Predict in eval mode and without autograd so the chosen class matches
    # what Grad-CAM explains and BatchNorm running statistics are not updated.
    model.eval()
    with torch.no_grad():
        logits = model(input_tensor.unsqueeze(0).to(device))
    pred_class = int(logits.argmax(dim=1).item())

    cam, _ = generate_gradcam(model, input_tensor, target_class=pred_class, device=device)

    # Defensive: replace any non-finite values so the uint8 conversion below
    # is well-defined.
    cam = np.nan_to_num(np.asarray(cam, dtype=np.float32))

    # Bring the heatmap to the original image resolution so the two can be
    # added; cv2.resize takes (width, height).
    height, width = img_rgb.shape[:2]
    cam = cv2.resize(cam, (width, height))

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # applyColorMap returns BGR, while matplotlib expects RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    img_norm = np.float32(img_rgb) / 255
    overlay = heatmap + img_norm
    overlay = overlay / np.max(overlay)

    fig, axes = plt.subplots(1, 3, figsize=(12, 6))
    axes[0].imshow(img_rgb)
    axes[0].set_title("Original X-ray")
    axes[1].imshow(cam, cmap="jet")
    axes[1].set_title("Grad-CAM Heatmap")
    axes[2].imshow(overlay)
    axes[2].set_title(f"Prediction: {class_names[pred_class]}")
    for ax in axes:
        ax.axis("off")
    plt.show()
