In [None]:
# Segmentation Training

In [1]:
!pip install opencv-python numpy





In [3]:
import cv2
import numpy as np
import os

def create_mask(image_path, threshold=127):
    # Read the image
    image = cv2.imread(image_path)
    
    # Convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    
    # Apply thresholding
    _, binary = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
    
    # Find contours
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # Create mask (black background)
    mask = np.zeros(image.shape[:2], dtype=np.uint8)
    
    # Draw contours on mask (white for polyps)
    cv2.drawContours(mask, contours, -1, (255), thickness=cv2.FILLED)
    
    return mask

def process_images(input_folder, output_folder):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    
    for filename in os.listdir(input_folder):
        if filename.endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(input_folder, filename)
            mask = create_mask(image_path)
            
            # Save the mask
            base_name = os.path.splitext(filename)[0]
            cv2.imwrite(os.path.join(output_folder, f"{base_name}_mask.png"), mask)

# Example usage
input_folder = "images"
output_folder = "mask"
process_images(input_folder, output_folder)

In [7]:
!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch

Collecting git+https://github.com/qubvel/segmentation_models.pytorch
  Cloning https://github.com/qubvel/segmentation_models.pytorch to c:\users\sage0\appdata\local\temp\pip-req-build-6g9dk29v


  ERROR: Error [WinError 2] The system cannot find the file specified while executing command git version
ERROR: Cannot find command 'git' - do you have 'git' installed and in your PATH?


In [2]:
import os
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.metrics import jaccard_score, f1_score, precision_score, recall_score
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import segmentation_models_pytorch as smp
from tqdm import tqdm
import logging
import matplotlib.pyplot as plt

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

def load_images(folder_path):
    images = []
    filenames = []
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            try:
                img_path = os.path.join(folder_path, filename)
                img = Image.open(img_path).convert('RGB')
                img = transforms.Resize((256, 256))(img)
                img = transforms.ToTensor()(img)
                images.append(img)
                filenames.append(filename)
            except Exception as e:
                logging.error(f"Error loading image {filename}: {str(e)}")
    return images, filenames

def get_models():
    models = [
        ('FCN ResNet50', torchvision.models.segmentation.fcn_resnet50(pretrained=True)),
        ('FCN ResNet101', torchvision.models.segmentation.fcn_resnet101(pretrained=True)),
        ('DeepLabV3 ResNet50', torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)),
        ('DeepLabV3 ResNet101', torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)),
        ('DeepLabV3 MobileNet', torchvision.models.segmentation.deeplabv3_mobilenet_v3_large(pretrained=True)),
        ('U-Net', smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=2)),
        ('PSPNet', smp.PSPNet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=2)),
        ('FPN', smp.FPN(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=2)),
        ('LinkNet', smp.Linknet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=2)),
        ('MANet', smp.MAnet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=2))
    ]
    return models

def segment_image(model, image):
    model.eval()
    with torch.no_grad():
        image = image.to(device)
        if isinstance(model, (smp.Unet, smp.PSPNet, smp.FPN, smp.Linknet, smp.MAnet)):
            output = model(image.unsqueeze(0))
            output = torch.softmax(output, dim=1)
            output = output[:, 1, :, :]  # Take the probability of the foreground class
        else:
            output = model(image.unsqueeze(0))['out']
            output = torch.softmax(output, dim=1)
            output = output[:, 1, :, :]  # Take the probability of the foreground class
        output = output.squeeze().cpu()
    return (output > 0.5).byte()

def save_segmentation(segmentation, original_image, filename, output_folder, model_name):
    os.makedirs(os.path.join(output_folder, model_name), exist_ok=True)
    output_path = os.path.join(output_folder, model_name, f"seg_{filename}")
    
    # Convert segmentation to RGB for visualization
    seg_rgb = torch.zeros(3, 256, 256, dtype=torch.uint8)
    seg_rgb[0] = segmentation * 255  # Red channel
    seg_rgb = seg_rgb.permute(1, 2, 0).numpy()
    
    # Overlay segmentation on original image
    original_np = (original_image.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    overlay = (original_np * 0.7 + seg_rgb * 0.3).astype(np.uint8)
    
    # Create a figure with original, segmentation, and overlay
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
    ax1.imshow(original_np)
    ax1.set_title('Original')
    ax1.axis('off')
    ax2.imshow(seg_rgb)
    ax2.set_title('Segmentation')
    ax2.axis('off')
    ax3.imshow(overlay)
    ax3.set_title('Overlay')
    ax3.axis('off')
    
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

def calculate_metrics(pred, target):
    pred = pred.numpy()
    target = target.numpy()
    
    pred_binary = pred.astype(int)
    target_binary = (target[0] > 0.5).astype(int)  # Assuming the first channel represents the class
    
    iou = jaccard_score(target_binary.flatten(), pred_binary.flatten())
    f1 = f1_score(target_binary.flatten(), pred_binary.flatten())
    precision = precision_score(target_binary.flatten(), pred_binary.flatten())
    recall = recall_score(target_binary.flatten(), pred_binary.flatten())
    psnr = peak_signal_noise_ratio(target_binary, pred_binary)
    ssim = structural_similarity(target_binary, pred_binary)
    dice = np.sum(pred_binary[target_binary==1])*2.0 / (np.sum(pred_binary) + np.sum(target_binary))
    
    return {
        'IoU': iou,
        'F1 Score': f1,
        'Precision': precision,
        'Recall': recall,
        'PSNR': psnr,
        'SSIM': ssim,
        'Dice Coefficient': dice
    }

def evaluate_models(models, images, filenames, output_folder):
    results = {}
    for model_name, model in tqdm(models, desc="Evaluating models"):
        model = model.to(device)
        model_results = []
        for image, filename in tqdm(zip(images, filenames), desc=f"Processing images for {model_name}", leave=False):
            try:
                segmentation = segment_image(model, image)
                save_segmentation(segmentation, image, filename, output_folder, model_name)
                metrics = calculate_metrics(segmentation, image)
                model_results.append(metrics)
            except Exception as e:
                logging.error(f"Error processing image {filename} with {model_name}: {str(e)}")
        results[model_name] = model_results
    return results

def analyze_results(results):
    for model, model_results in results.items():
        print(f"Results for {model}:")
        if model_results:
            avg_metrics = {metric: np.mean([r[metric] for r in model_results]) for metric in model_results[0]}
            for metric, value in avg_metrics.items():
                print(f"  Average {metric}: {value:.4f}")
        else:
            print("  No valid results for this model.")
        print()

def main(input_folder, output_folder):
    try:
        logging.info("Loading images...")
        images, filenames = load_images(input_folder)
        if not images:
            raise ValueError("No valid images found in the specified folder.")
        
        logging.info("Loading models...")
        models = get_models()
        
        logging.info("Evaluating models and saving segmentations...")
        results = evaluate_models(models, images, filenames, output_folder)
        
        logging.info("Analyzing results...")
        analyze_results(results)
        
    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")

        logging.error(f"An error occurred: {str(e)}")

if __name__ == "__main__":
    input_folder = "test"
    output_folder = "result"
    main(input_folder, output_folder)

2024-08-23 16:26:28,616 - INFO - Using device: cpu
2024-08-23 16:26:28,617 - INFO - Loading images...
2024-08-23 16:26:28,650 - INFO - Loading models...
2024-08-23 16:26:33,733 - INFO - Evaluating models and saving segmentations...
Evaluating models:   0%|                                                                        | 0/10 [00:00<?, ?it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  ssim = structural_similarity(target_binary, pred_binary)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  ssim = structural_similarity(target_binary, pred_binary)

Processing images for FCN ResNet50: 2it [00:01,  1.57it/s][A
Evaluating models:  10%|██████▍                                                         | 1/10 [00:01<00:11,  1.28s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  ssim = structural_similarity(target_binary, pred_binary)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(res

Processing images for DeepLabV3 ResNet101: 2it [00:01,  1.13it/s][A
Evaluating models:  40%|█████████████████████████▌                                      | 4/10 [00:06<00:09,  1.63s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  ssim = structural_similarity(target_binary, pred_binary)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  ssim = structural_similarity(target_binary, pred_binary)

Processing images for DeepLabV3 MobileNet: 2it [00:00,  2.64it/s][A
Evaluating models:  50%|████████████████████████████████                                | 5/10 [00:07<00:06,  1.32s/it]
  ssim = structural_similarity(target_binary, pred_binary)

  ssim = structural_similarity(target_binary, pred_binary)

Processing images for U-Net: 2it [00:00,  2.34it/s][A
Evaluating models:  60%|██████████████████████████████████████▍                         | 6/10 [00:08<00:04,  1.17s/it]
  ssim = structural_similarity(target_binary, pred_binary)

 

Results for FCN ResNet50:
  Average IoU: 0.0000
  Average F1 Score: 0.0000
  Average Precision: 0.0000
  Average Recall: 0.0000
  Average PSNR: 188.3452
  Average SSIM: 1.0000
  Average Dice Coefficient: 0.0000

Results for FCN ResNet101:
  Average IoU: 0.0000
  Average F1 Score: 0.0000
  Average Precision: 0.0000
  Average Recall: 0.0000
  Average PSNR: 188.3452
  Average SSIM: 1.0000
  Average Dice Coefficient: 0.0000

Results for DeepLabV3 ResNet50:
  Average IoU: 0.0000
  Average F1 Score: 0.0000
  Average Precision: 0.0000
  Average Recall: 0.0000
  Average PSNR: 188.3452
  Average SSIM: 1.0000
  Average Dice Coefficient: 0.0000

Results for DeepLabV3 ResNet101:
  Average IoU: 0.0000
  Average F1 Score: 0.0000
  Average Precision: 0.0000
  Average Recall: 0.0000
  Average PSNR: 188.3452
  Average SSIM: 1.0000
  Average Dice Coefficient: 0.0000

Results for DeepLabV3 MobileNet:
  Average IoU: 0.0000
  Average F1 Score: 0.0000
  Average Precision: 0.0000
  Average Recall: 0.0000
  A