<a href="https://colab.research.google.com/github/AaryanAnand10/Implant/blob/main/Untitled5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image
import time
import copy

# Select the compute device: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Path to the knee implant folder.
# On Colab, mount Google Drive first:
#   from google.colab import drive
#   drive.mount('/content/drive')
base_path = "/content/drive/MyDrive/Knee"  # Adjust if your path is different

# Basic parameters used by the cells below.
img_size, batch_size = 224, 32  # input side length in pixels; samples per batch
learning_rate = 0.001
num_epochs, patience = 30, 7  # epoch budget; early-stopping patience

print("Step 1: Setup complete")

In [None]:
# Preprocessing pipelines for the training and validation sets.
# Normalization uses the standard ImageNet channel statistics as a starting point.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)

# Both pipelines resize to the same square input.
target_size = (img_size, img_size)

# Augmentations applied to training images only.
train_augmentations = [
    transforms.RandomRotation(10),  # limited rotation for medical images
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),  # slight color adjustments
]

# Training: resize -> augment -> tensor -> normalize.
train_transform = transforms.Compose(
    [transforms.Resize(target_size), *train_augmentations, transforms.ToTensor(), normalize]
)

# Validation: resize -> tensor -> normalize (no augmentation).
val_transform = transforms.Compose(
    [transforms.Resize(target_size), transforms.ToTensor(), normalize]
)

# Build the datasets and DataLoaders from the image folder tree.
# Any failure is reported with a message; the names defined inside the `try`
# will then be missing for the later cells.
try:
    full_dataset = datasets.ImageFolder(base_path)
    class_names = full_dataset.classes
    num_classes = len(class_names)
    print(f"Found {len(full_dataset)} images in {num_classes} classes:")
    for i, class_name in enumerate(class_names):
        print(f"- {class_name} (Index: {i})")

    # 80/20 train/validation split with a fixed seed for reproducibility.
    val_split = 0.2
    total_images = len(full_dataset)
    val_size = int(val_split * total_images)
    train_size = total_images - val_size
    generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
        generator=generator,
    )

    # The training subset wraps `full_dataset`, so the augmenting transform is
    # set on that shared underlying dataset.
    train_dataset.dataset.transform = train_transform

    # Validation needs a different transform, so the same indices are applied
    # to a second ImageFolder that carries `val_transform`.
    # NOTE(review): this relies on both ImageFolder scans listing the files in
    # the same order - confirm for the storage backend in use.
    val_indices = val_dataset.indices
    val_source = datasets.ImageFolder(base_path, transform=val_transform)
    val_dataset_transformed = Subset(val_source, val_indices)

    # Loader settings shared by both phases; only shuffling differs.
    loader_options = {"batch_size": batch_size, "num_workers": 2, "pin_memory": True}
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset_transformed, shuffle=False, **loader_options)

    dataloaders = {'train': train_loader, 'val': val_loader}
    dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset_transformed)}

    print(f"Training samples: {dataset_sizes['train']}, Validation samples: {dataset_sizes['val']}")
    print("Step 2: Data loading and preprocessing complete")

except FileNotFoundError:
    print(f"Error: Dataset path not found at {base_path}. Please check the path.")
except Exception as e:
    print(f"An error occurred during data loading: {e}")


In [None]:
# Define the CNN architecture
class KneeImplantCNN(nn.Module):
    """Small three-block CNN classifier.

    Each convolutional block is Conv(3x3, padding=1) -> BatchNorm -> ReLU ->
    MaxPool(2x2), so every block halves the spatial size. The flattened
    features go through one hidden dense layer (with BatchNorm and Dropout)
    and a final linear layer that returns raw logits.

    Args:
        num_classes: Number of output classes (width of the logit vector).
        img_size: Side length in pixels of the square input images. Defaults
            to 224, which reproduces the original fixed 128 * 28 * 28
            flattened size.

    Raises:
        ValueError: If ``img_size`` is smaller than 8, because three 2x2
            poolings would leave an empty feature map.
    """

    def __init__(self, num_classes, img_size=224):
        super().__init__()
        if img_size < 8:
            raise ValueError(f"img_size must be at least 8, got {img_size}")

        # Convolutional Block 1
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)  # BN after Conv, before ReLU
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # img_size -> img_size // 2

        # Convolutional Block 2
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> img_size // 4

        # Convolutional Block 3
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> img_size // 8

        # A fourth block (128 -> 256 channels) was deliberately left out to
        # keep the model small.

        # Flatten layer
        self.flatten = nn.Flatten()

        # Fully Connected Layers
        # The padded 3x3 convolutions keep the spatial size and each of the
        # three poolings halves it (rounding down), so the final feature map
        # is (img_size // 8) pixels on a side with 128 channels.
        feature_side = img_size // 8
        self.fc1 = nn.Linear(128 * feature_side * feature_side, 512)
        self.bn_fc1 = nn.BatchNorm1d(512)  # BN for dense layer
        self.relu_fc1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)  # Dropout for regularization

        self.fc2 = nn.Linear(512, num_classes)  # Output layer

    def forward(self, x):
        """Return raw class logits for a batch of 3-channel images.

        Args:
            x: Tensor of shape (batch, 3, img_size, img_size).

        Returns:
            Tensor of shape (batch, num_classes); no softmax is applied.
        """
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
        x = self.pool3(self.relu3(self.bn3(self.conv3(x))))

        x = self.flatten(x)
        x = self.dropout1(self.relu_fc1(self.bn_fc1(self.fc1(x))))
        x = self.fc2(x)  # Raw logits output
        return x

# Instantiate the model and move it to the device
# (`num_classes` comes from the data-loading cell, `device` from the setup cell).
model = KneeImplantCNN(num_classes).to(device)
print(model)  # Print the module structure for a quick visual check
print("Step 3: CNN model defined")

In [None]:
# Loss function - CrossEntropyLoss includes Softmax, so the model outputs raw logits
criterion = nn.CrossEntropyLoss()

# Optimizer - Adam is a good default
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Learning Rate Scheduler - Reduce LR on plateau
# Monitors validation loss and multiplies the LR by 0.1 after 3 epochs without improvement.
# `verbose=True` is intentionally not passed: the argument is deprecated since
# PyTorch 2.2 and newer releases may reject it. Inspect the current learning rate
# with `scheduler.get_last_lr()` or `optimizer.param_groups[0]['lr']` instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

print("Step 4: Loss, optimizer, and scheduler defined")

In [None]:
def _run_phase(model, criterion, optimizer, loader, device, is_train):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)`` as floats."""
    model.train(is_train)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    correct = 0
    seen = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Gradients only need clearing when a backward pass follows.
        if is_train:
            optimizer.zero_grad()

        # Track autograd history only in the training phase.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if is_train:
                loss.backward()
                optimizer.step()

        # Weight the batch-mean loss by the batch size so the epoch value is a
        # per-sample mean even when the last batch is smaller.
        batch_count = inputs.size(0)
        loss_sum += loss.item() * batch_count
        correct += (preds == labels).sum().item()
        seen += batch_count

    if seen == 0:
        raise ValueError("The data loader produced no samples; cannot compute epoch metrics.")

    return loss_sum / seen, correct / seen


def train_model(model, criterion, optimizer, scheduler, num_epochs=25, patience=7,
                dataloaders=None, device=None,
                checkpoint_path='knee_implant_cnn_best.pth'):
    """Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Every epoch runs a training pass and a validation pass. The validation
    loss is passed to ``scheduler.step``; the validation accuracy decides
    which weights count as "best" and drives early stopping.

    Args:
        model: The network to train; it is updated in place.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that owns the model parameters.
        scheduler: Scheduler whose ``step`` accepts the validation loss.
        num_epochs: Maximum number of epochs.
        patience: Stop after this many consecutive epochs without a new best
            validation accuracy.
        dataloaders: Mapping with ``'train'`` and ``'val'`` loaders. When
            omitted, the module-level ``dataloaders`` dict is used.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if it has none).
        checkpoint_path: File that receives the best weights.

    Returns:
        ``(model, history)``: the model with the best weights loaded, and a
        dict with the lists ``train_loss``, ``train_acc``, ``val_loss`` and
        ``val_acc`` (one float per completed epoch).

    Raises:
        ValueError: If no loaders are available or a loader yields no samples.
    """
    if dataloaders is None:
        # Backward-compatible default: the loaders built by the data cell.
        dataloaders = globals().get('dataloaders')
        if dataloaders is None:
            raise ValueError("No dataloaders given and no module-level 'dataloaders' is defined.")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    epochs_no_improve = 0  # Counter for early stopping
    checkpoint_written = False

    # Store history
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            epoch_loss, epoch_acc = _run_phase(
                model, criterion, optimizer, dataloaders[phase], device, is_train
            )

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            if is_train:
                continue

            # Adjust learning rate based on validation loss
            scheduler.step(epoch_loss)

            # Check for improvement for early stopping and best model saving
            if epoch_acc > best_acc:
                print(f"Validation accuracy improved ({best_acc:.4f} --> {epoch_acc:.4f}). Saving model...")
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                epochs_no_improve = 0
                # Save the best model weights
                torch.save(best_model_wts, checkpoint_path)
                checkpoint_written = True
            else:
                epochs_no_improve += 1
                print(f"Validation accuracy did not improve for {epochs_no_improve} epochs.")

        # Early stopping check
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break
        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # If validation accuracy never exceeded 0.0 (or no epoch ran), nothing was
    # saved above. Write the retained weights so code that later loads
    # `checkpoint_path` finds this run's file rather than a missing or stale one.
    if not checkpoint_written:
        torch.save(best_model_wts, checkpoint_path)

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, history

# Start training
# Uses the epoch budget and early-stopping patience from the setup cell.
# train_model returns the model with its best weights loaded, plus the per-epoch history.
print("Step 5: Starting model training...")
model, history = train_model(model, criterion, optimizer, scheduler, num_epochs=num_epochs, patience=patience)
print("Training finished.")

In [None]:
# Plot training history
def plot_history(history):
    """Draw accuracy and loss curves side by side, save them as a PNG, and show them.

    Args:
        history: Dict with the per-epoch lists ``train_acc``, ``val_acc``,
            ``train_loss`` and ``val_loss``.
    """
    # One row per panel:
    # (subplot position, train key, val key, train label, val label, title, y-axis label)
    panels = (
        (1, 'train_acc', 'val_acc', 'Train Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
        (2, 'train_loss', 'val_loss', 'Train Loss', 'Validation Loss', 'Model Loss', 'Loss'),
    )

    plt.figure(figsize=(12, 5))

    for position, train_key, val_key, train_label, val_label, title, y_label in panels:
        plt.subplot(1, 2, position)
        plt.plot(history[train_key], label=train_label)
        plt.plot(history[val_key], label=val_label)
        plt.title(title)
        plt.ylabel(y_label)
        plt.xlabel('Epoch')
        plt.legend()

    plt.tight_layout()
    plt.savefig('pytorch_training_history.png')
    plt.show()

# Render and save the accuracy/loss curves collected during training.
print("Step 6: Plotting training history...")
plot_history(history)

In [None]:
# Evaluate the model
def evaluate_model(model, dataloader):
    """Report accuracy, a per-class report and a confusion matrix for ``model``.

    Predictions are collected over ``dataloader`` with gradients disabled.
    The class names come from the module-level ``class_names`` list and the
    batches are moved to the module-level ``device``. The confusion matrix is
    saved to ``pytorch_confusion_matrix.png`` and shown.

    Args:
        model: Trained network that returns class logits.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        None. Results are printed and plotted.
    """
    model.eval()  # Set model to evaluation mode
    all_preds = []
    all_labels = []

    with torch.no_grad():  # Disable gradient calculations
        for inputs, labels in dataloader:
            outputs = model(inputs.to(device))
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if not all_labels:
        # Nothing to score; avoid dividing by zero and plotting an empty matrix.
        print("\nNo samples to evaluate.")
        return

    y_true = np.asarray(all_labels)
    y_pred = np.asarray(all_preds)

    # Calculate overall accuracy
    accuracy = float(np.mean(y_true == y_pred))
    print(f"\nOverall Validation Accuracy: {accuracy:.4f}")

    # Pin the label set to every known class index. Without it, a class that
    # is absent from this data makes `target_names` disagree with the labels
    # sklearn infers, and the heatmap tick labels no longer line up with the
    # matrix rows and columns.
    label_ids = list(range(len(class_names)))

    # Classification report
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, labels=label_ids,
                                target_names=class_names, zero_division=0))

    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    plt.figure(figsize=(15, 12))

    # Use shorter names for readability
    short_names = [name[:12] + '...' if len(name) > 12 else name for name in class_names]

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=short_names, yticklabels=short_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig('pytorch_confusion_matrix.png')
    plt.show()

print("Step 7: Evaluating the best model...")
# Load the best model weights before evaluation.
# map_location remaps the saved tensors onto the current device, so a checkpoint
# written on a GPU still loads in a CPU-only session.
model.load_state_dict(torch.load('knee_implant_cnn_best.pth', map_location=device))
evaluate_model(model, val_loader)

In [None]:
# Prediction function
def predict_implant(image_path, model, class_names, transform, top_k=3):
    """Classify one image and display it with its most likely classes.

    Args:
        image_path: Path of the image file to classify.
        model: Trained network that returns class logits.
        class_names: Sequence mapping class index to display name.
        transform: Preprocessing applied to the RGB image; it must return a
            tensor (use the validation transform, not the augmenting one).
        top_k: How many predictions to return. It is clamped to the range
            1..number of model outputs, so models with fewer than three
            classes work as well.

    Returns:
        A list of ``(class_name, confidence_percent)`` tuples ordered from
        most to least likely, or ``None`` when the image file does not exist.
    """
    model.eval()  # Set model to evaluation mode
    try:
        # The context manager releases the file handle once the RGB copy exists.
        with Image.open(image_path) as source_image:
            img = source_image.convert('RGB')
    except FileNotFoundError:
        print(f"Error: Image file not found at {image_path}")
        return None

    # Add the batch dimension the model expects and move to the module-level device.
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.softmax(outputs, dim=1)[0]
        # topk raises if k exceeds the number of classes, so clamp it.
        k = max(1, min(top_k, probabilities.numel()))
        top_p, top_class_idx = probabilities.topk(k, dim=0)

    top_predictions = [
        (class_names[idx], prob * 100)
        for prob, idx in zip(top_p.tolist(), top_class_idx.tolist())
    ]

    # Display results
    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.title(f"Predicted: {top_predictions[0][0]}\nConfidence: {top_predictions[0][1]:.1f}%")
    plt.axis('off')

    # Show other predictions as text
    result_text = "\n".join([f"{i+1}. {name}: {conf:.1f}%" for i, (name, conf) in enumerate(top_predictions)])
    plt.figtext(0.5, 0.01, result_text, ha='center', fontsize=10)

    plt.tight_layout()
    plt.show()

    return top_predictions

print("Step 8: Prediction function ready.")
# Example usage (uncomment and point the path at a real image file):
# Load the best model first
# model.load_state_dict(torch.load('knee_implant_cnn_best.pth'))
# predict_implant('/content/drive/MyDrive/Knee/Depuy AMK/example.jpg', model, class_names, val_transform)