In [None]:
# ===================================================================
# BLOCK 1: INSTALLATION AND IMPORTS
# ===================================================================
# Install necessary libraries
!pip install -q segmentation-models-pytorch tqdm albumentations ultralytics

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from ultralytics import YOLO
from tqdm import tqdm
import warnings
import urllib.request
warnings.filterwarnings('ignore')

# Best-effort reproducibility: seed every RNG the notebook relies on (Python,
# NumPy, PyTorch CPU + all CUDA devices) so augmentation, weight init and
# shuffling repeat across "Restart & Run All". cuDNN kernels and DataLoader
# worker processes may still introduce some nondeterminism.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.device_count() > 1:
    # Training and loading cells wrap the U-Net in nn.DataParallel in this case.
    print(f"Using {torch.cuda.device_count()} GPUs")

---
## Part 1: U-Net Model Training
---

In [None]:
# ===================================================================
# 1.1: Prepare Dataset and DataLoader for U-Net
# ===================================================================
class CellSegDataset(Dataset):
    """Binary cell-segmentation dataset pairing each image with ``MASK_<stem>.png``.

    Images are read with OpenCV and converted BGR -> RGB. Masks are read as
    grayscale and binarised (any non-zero pixel is foreground). Each item is
    returned as ``(image, mask)``, where ``mask`` is an ``int64`` tensor with
    a leading channel axis of size 1.

    Args:
        images_dir: Directory with the input images. Every directory entry is
            treated as an image (no extension filtering).
        masks_dir: Directory with masks named ``MASK_<image stem>.png``.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``"image"`` and ``"mask"`` keys.
    """

    def __init__(self, images_dir, masks_dir, transform=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        # Sorted so a given index maps to the same file on every run.
        self.image_files = sorted(os.listdir(images_dir))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _binarize_mask(mask):
        """Convert a mask (NumPy array or tensor) to an int64 {0, 1} tensor with a channel axis.

        A 2-D ``(H, W)`` input becomes ``(1, H, W)``; inputs that already have
        three dimensions keep their shape.

        Raises:
            TypeError: if ``mask`` is neither a NumPy array nor a torch tensor.
        """
        if isinstance(mask, np.ndarray):
            # Must be the dtype object np.int64: the string "np.int64" is not a
            # valid NumPy dtype name and makes astype raise TypeError.
            mask = torch.from_numpy((mask > 0).astype(np.int64))
        elif isinstance(mask, torch.Tensor):
            mask = (mask > 0).long()
        else:
            raise TypeError(f"Unsupported mask type: {type(mask)}")

        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        return mask

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        mask_path = os.path.join(self.masks_dir, f"MASK_{os.path.splitext(img_name)[0]}.png")

        # cv2.imread returns None instead of raising on a missing/unreadable
        # file; fail here with the offending path rather than deep inside cv2.
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        return image, self._binarize_mask(mask)

# Side length (pixels) every image and mask is resized to before the U-Net.
IMG_SIZE = 256

# Training-time augmentation. albumentations applies geometric transforms
# jointly to image and mask. Normalize with mean=0 / std=1 only divides by
# max_pixel_value (255 by default), i.e. maps uint8 pixels to [0, 1].
train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=15, p=0.5),
    A.RandomResizedCrop(size=(IMG_SIZE, IMG_SIZE), scale=(0.9, 1.0), p=0.3),
    # A tuple translate_percent=(a, b) is a sampling range; (0.02, 0.02) would
    # always shift by exactly +2 %. The per-axis dict gives an independent
    # random shift of up to ±2 % on x and y.
    A.Affine(
        scale=(0.95, 1.05),
        translate_percent={"x": (-0.02, 0.02), "y": (-0.02, 0.02)},
        rotate=(-10, 10),
        p=0.3,
    ),
    A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)),
    ToTensorV2(),
])

# Deterministic pipeline for validation, test and inference: resize + scale only.
val_test_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)),
    ToTensorV2(),
])

# Root of the Kaggle dataset: each split keeps its images under
# "<split>/images" and its binary masks under "<split>/masks_binary".
base_dir = "/kaggle/input/cell-counting-roboflow-segmentation-masks"

# The training split gets the augmenting pipeline; validation uses the
# deterministic one so its loss is comparable from epoch to epoch.
train_dataset = CellSegDataset(
    os.path.join(base_dir, "train/images"),
    os.path.join(base_dir, "train/masks_binary"),
    train_transform,
)
val_dataset = CellSegDataset(
    os.path.join(base_dir, "valid/images"),
    os.path.join(base_dir, "valid/masks_binary"),
    val_test_transform,
)

# Batches of 8 with two worker processes; only the training order is shuffled.
train_loader = DataLoader(dataset=train_dataset, num_workers=2, shuffle=True, batch_size=8)
val_loader = DataLoader(dataset=val_dataset, num_workers=2, shuffle=False, batch_size=8)

print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

In [None]:
# ===================================================================
# 1.2: Define U-Net Model, Loss, and Training Function
# ===================================================================
# Two-class (background / cell) U-Net on an ImageNet-pretrained ResNet-152
# encoder. activation=None makes the network emit raw logits; the losses
# defined below take logits (CrossEntropyLoss) or apply softmax themselves.
model_unet_train = smp.Unet(
    activation=None,
    classes=2,
    encoder_weights="imagenet",
    encoder_name="resnet152",
)
print(f"U-Net Model parameters: {sum(param.numel() for param in model_unet_train.parameters()):,}")

# --- Define Loss Functions ---
class DiceLoss(nn.Module):
    """Soft Dice loss on the foreground channel of two-class segmentation logits.

    Dice is computed over the whole batch at once (all pixels of all samples
    are pooled together), not averaged per sample.

    Args:
        smooth: Additive term in numerator and denominator that keeps the loss
            finite (and close to 0) when prediction and target are both empty.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        """Return ``1 - Dice`` between the softmax foreground probability and ``targets``.

        Args:
            predictions: Raw logits ``(B, C, H, W)`` with ``C >= 2``; channel 1
                is treated as foreground.
            targets: Binary masks ``(B, H, W)`` or ``(B, 1, H, W)``.

        Returns:
            Scalar tensor; 0 means perfect overlap.
        """
        # Foreground-class probabilities, shape (B, H, W).
        probs = torch.softmax(predictions, dim=1)[:, 1, :, :]
        # Drop a singleton channel axis: (B, H, W) * (B, 1, H, W) would
        # otherwise broadcast to (B, B, H, W) and silently give a wrong loss.
        if targets.dim() == 4 and targets.size(1) == 1:
            targets = targets.squeeze(1)
        targets = targets.float()

        intersection = (probs * targets).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
        return 1 - dice

class CombinedLoss(nn.Module):
    """Weighted sum of pixel-wise cross-entropy and soft Dice loss.

    Args:
        weight_ce: Multiplier for the ``nn.CrossEntropyLoss`` term.
        weight_dice: Multiplier for the ``DiceLoss`` term.
    """

    def __init__(self, weight_ce=0.5, weight_dice=0.5):
        super().__init__()
        self.weight_ce, self.weight_dice = weight_ce, weight_dice
        self.ce_loss = nn.CrossEntropyLoss()
        self.dice_loss = DiceLoss()

    def forward(self, predictions, targets):
        """Return ``weight_ce * CE + weight_dice * Dice``.

        Args:
            predictions: Logits ``(B, C, H, W)``.
            targets: Class-index masks ``(B, H, W)``, or 4-D with a size-1
                channel axis at position 1, which is removed before both losses.
        """
        # Collapse the optional channel axis once, then cast per loss:
        # CrossEntropyLoss wants int64 class indices, DiceLoss takes floats.
        labels = targets.squeeze(1) if targets.dim() == 4 else targets
        ce_term = self.ce_loss(predictions, labels.long())
        dice_term = self.dice_loss(predictions, labels.float())
        return self.weight_ce * ce_term + self.weight_dice * dice_term

# --- Training Function ---
def _run_unet_epoch(model, loader, criterion, desc, optimizer=None):
    """Run one pass over a non-empty ``loader`` and return the mean per-batch loss.

    With an ``optimizer`` the pass trains (train mode, backward, step).
    Without one it evaluates (eval mode, no gradient tracking).
    """
    is_train = optimizer is not None
    model.train(is_train)  # train(False) is equivalent to eval()
    total_loss = 0.0
    pbar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(is_train):
        for images, masks in pbar:
            images, masks = images.to(device), masks.to(device)
            if is_train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            if is_train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})
    return total_loss / len(loader)


def train_unet_model(model, train_loader, val_loader, num_epochs=100, learning_rate=1e-3,
                     save_path='best_unet_model.pth'):
    """Train a segmentation model with CombinedLoss, Adam and ReduceLROnPlateau.

    Whenever the validation loss improves, ``model.state_dict()`` is saved to
    ``save_path``. With more than one visible GPU the model is wrapped in
    ``nn.DataParallel`` first, so the saved keys then carry a ``module.`` prefix.

    Args:
        model: Network mapping image batches to ``(B, 2, H, W)`` logits.
        train_loader: Yields ``(images, masks)`` training batches.
        val_loader: Yields ``(images, masks)`` validation batches.
        num_epochs: Number of full passes over ``train_loader``.
        learning_rate: Initial Adam learning rate.
        save_path: Where the best checkpoint is written.

    Returns:
        ``(model, train_losses, val_losses)``. ``model`` is the (possibly
        DataParallel-wrapped) model after the *last* epoch, not the best
        checkpoint. The lists hold each epoch's mean per-batch loss.

    Raises:
        ValueError: if either loader yields no batches (the per-epoch mean would
            otherwise divide by zero after a wasted epoch).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    criterion = CombinedLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Halve the learning rate after 5 epochs without validation improvement.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

    train_losses, val_losses = [], []
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        tag = f'Epoch {epoch+1}/{num_epochs}'
        train_loss = _run_unet_epoch(model, train_loader, criterion, f'{tag} - Train', optimizer)
        train_losses.append(train_loss)

        val_loss = _run_unet_epoch(model, val_loader, criterion, f'{tag} - Validation')
        val_losses.append(val_loss)
        scheduler.step(val_loss)

        print(f'{tag}: Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print(f'==> Best model saved with validation loss: {val_loss:.4f}')

    return model, train_losses, val_losses

In [None]:
# ===================================================================
# 1.3: Start U-Net Training
# ===================================================================
print("Starting U-Net model training...")
trained_model, train_losses, val_losses = train_unet_model(
    model_unet_train, train_loader, val_loader, num_epochs=100, learning_rate=1e-3
)

# Loss curves: one point per epoch for each split.
fig, ax = plt.subplots(figsize=(10, 5))
for series, series_label in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
    ax.plot(series, label=series_label)
ax.set_title('U-Net Training History')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True)
plt.show()

---
## Part 2: YOLOv8 Model Training
---

In [None]:
# ===================================================================
# 2.1: Start YOLOv8 Training
# ===================================================================
# Fine-tune the pre-trained YOLOv8-nano checkpoint on the cell dataset.
model_yolo_train = YOLO('yolov8n.pt')

print("Starting YOLOv8 model training...")
results = model_yolo_train.train(
    data='/kaggle/input/cell-counting-roboflow-segmentation-masks/Cell_Counting_dataset_from_roboflow/data.yaml',
    epochs=100,
    imgsz=640,
    name='cell_detection_yolo_model',  # results are written to runs/detect/<name>
    # Without exist_ok, re-running this cell writes to "<name>2", "<name>3", ...
    # while the inference cell loads .../cell_detection_yolo_model/weights/best.pt,
    # i.e. it would silently use weights from the first run.
    exist_ok=True,
)
print("YOLOv8 training completed successfully!")

---
## Part 3: Inference on Local Test Images
---

In [None]:
# ===================================================================
# 3.1: Load Trained Models
# ===================================================================

# --- Load U-Net Model ---
# Same architecture as in training. encoder_weights=None skips the ImageNet
# download because every weight is overwritten by the checkpoint below.
model_unet_eval = smp.Unet(
    encoder_name="resnet152",
    encoder_weights=None,
    classes=2,
    activation=None,
)

# If training ran on >1 GPU, the saved state_dict came from an nn.DataParallel
# wrapper and every key is prefixed with "module.". Strip that prefix and load
# into the bare network, so the checkpoint works whatever GPU count is visible now.
unet_state_dict = torch.load("best_unet_model.pth", map_location=device)
unet_state_dict = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in unet_state_dict.items()
}
model_unet_eval.load_state_dict(unet_state_dict)
if torch.cuda.device_count() > 1:
    model_unet_eval = nn.DataParallel(model_unet_eval)
model_unet_eval.to(device)
model_unet_eval.eval()
print("Trained U-Net model loaded successfully for inference.")

# --- Load YOLO Model ---
model_path_yolo = '/kaggle/working/runs/detect/cell_detection_yolo_model/weights/best.pt'
if not os.path.exists(model_path_yolo):
    raise FileNotFoundError(
        f"YOLO weights not found at {model_path_yolo}; run the YOLOv8 training cell first."
    )
model_yolo_eval = YOLO(model_path_yolo)
print("Trained YOLOv8 model loaded successfully for inference.")

In [None]:
# ===================================================================
# 3.2: Define Helper and Pipeline Functions for Inference
# ===================================================================
def estimate_average_cell_diameter(image_path, min_area=10, max_area=50000):
    """Roughly estimate the mean cell diameter (in pixels) of an image.

    The grayscale image is Otsu-thresholded and its 8-connected foreground
    blobs are labelled. Blobs whose area lies strictly between ``min_area``
    and ``max_area`` are kept, and the result is the diameter of a circle with
    their mean area, ``2 * sqrt(mean_area / pi)``.

    THRESH_BINARY marks pixels *above* the Otsu threshold as foreground, so
    this assumes cells appear brighter than their surroundings.

    Args:
        image_path: Path to an image file, or an already-loaded 8-bit image
            array (single-channel, or 3-channel BGR as returned by cv2.imread).
            Passing the array avoids reading the file a second time.
        min_area: Exclusive lower bound on blob area; filters speckle noise.
        max_area: Exclusive upper bound on blob area; filters very large regions.

    Returns:
        The estimated diameter, or ``None`` if the image cannot be read or no
        blob passes the area filter.
    """
    if isinstance(image_path, np.ndarray):
        img = image_path
        if img.ndim == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return None

    _, thresh = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(thresh, connectivity=8)
    # Row 0 of ``stats`` describes the background label; skip it.
    areas = stats[1:num_labels, cv2.CC_STAT_AREA]
    areas = areas[(areas > min_area) & (areas < max_area)]
    if areas.size == 0:
        return None
    return np.sqrt(areas.mean() / np.pi) * 2

def preprocess_image(image):
    """Contrast-enhance a BGR image with CLAHE and return it as 3-channel RGB.

    The image is reduced to grayscale and equalised with CLAHE (clip limit
    2.0, 8x8 tiles). The result is then replicated into three identical
    channels, so it has the 3-channel layout the U-Net input expects.

    NOTE(review): CellSegDataset trains on raw RGB colour images, not
    CLAHE-equalised grayscale; this may be a train/inference domain shift.
    Confirm it is intended.

    Args:
        image: 8-bit BGR image (e.g. as returned by ``cv2.imread``).

    Returns:
        ``(H, W, 3)`` array whose three channels are identical.
    """
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    gray_equalized = equalizer.apply(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))
    return cv2.cvtColor(gray_equalized, cv2.COLOR_GRAY2RGB)

def post_process_mask(mask, min_area=25):
    """Clean a binary mask: morphological opening, then removal of small components.

    Args:
        mask: Single-channel 8-bit mask; any non-zero pixel is foreground.
        min_area: 8-connected components with fewer pixels than this (after
            the opening) are removed.

    Returns:
        Array with the same shape and dtype as ``mask``: 255 for kept
        foreground pixels, 0 elsewhere.
    """
    # A 3x3 elliptical opening removes foreground structures narrower than the
    # kernel (isolated specks, thin bridges between blobs).
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    opened_mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(opened_mask, connectivity=8)

    # Per-label keep table, indexed by the label image in one vectorised pass.
    # This replaces a full-image scan per component (O(labels * pixels)).
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background
    cleaned_mask = np.zeros_like(opened_mask)
    cleaned_mask[keep[labels]] = 255
    return cleaned_mask

def count_objects(binary_mask):
    """Return the number of outermost contours in a binary mask.

    Only external contours are counted (RETR_EXTERNAL), so holes inside a blob
    are ignored. Touching cells that form one connected blob count as one.

    Args:
        binary_mask: Single-channel 8-bit mask; non-zero pixels are foreground.
    """
    outer_contours, _hierarchy = cv2.findContours(
        binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    n_found = len(outer_contours)
    return n_found

def predict_unet_mask(model, device, image_tensor, threshold=0.5):
    """Predict binary foreground masks for a batch of images.

    Puts ``model`` in eval mode and runs it without gradient tracking. A
    softmax is applied over the class axis, and a pixel is marked foreground
    when its class-1 probability is strictly greater than ``threshold``.

    Args:
        model: Network returning logits of shape ``(B, C, H, W)`` with ``C >= 2``.
        device: Device to run on; ``image_tensor`` is moved there.
        image_tensor: Input batch already shaped for ``model``.
        threshold: Foreground probability cut-off. With two classes, the
            default 0.5 equals an argmax where ties go to background.

    Returns:
        ``(B, H, W)`` uint8 NumPy array: 255 for foreground, 0 elsewhere.
    """
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        fg_prob = torch.softmax(logits, dim=1)[:, 1, :, :]
        return (fg_prob > threshold).cpu().numpy().astype(np.uint8) * 255

def unet_counting_pipeline(model, device, image_path, transform, min_area=25,
                           target_cell_diameter=60.0):
    """Count cells in a single image with the U-Net.

    Steps:
    1. Load the image.
    2. Rescale it so the estimated mean cell diameter becomes
       ``target_cell_diameter`` pixels (skipped if no estimate is available).
    3. CLAHE-preprocess it and run ``transform`` plus the model.
    4. Resize the predicted mask back to the rescaled image size.
    5. Clean the mask and count its external contours.

    Args:
        model: Segmentation model passed to ``predict_unet_mask``.
        device: Torch device for inference.
        image_path: Path of the image file.
        transform: Albumentations-style callable used as
            ``transform(image=...)["image"]``; must return a ``(C, H, W)`` tensor.
        min_area: Minimum component area (pixels) kept by ``post_process_mask``.
        target_cell_diameter: Cell diameter in pixels the image is rescaled to.

    Returns:
        ``(count, cleaned_mask, image)``. ``image`` is the *rescaled* BGR image
        that ``cleaned_mask`` is aligned with. Returns ``(0, None, None)`` if
        the file cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        return 0, None, None

    current_diameter = estimate_average_cell_diameter(image_path)
    if current_diameter is not None and current_diameter > 0:
        resize_factor = target_cell_diameter / current_diameter
        # Clamp each side to >= 1 px: a very large diameter estimate would
        # otherwise truncate a side to 0 and make cv2.resize raise.
        new_size = (max(1, int(image.shape[1] * resize_factor)),
                    max(1, int(image.shape[0] * resize_factor)))
        # INTER_AREA suits shrinking; for enlarging, OpenCV recommends
        # INTER_LINEAR/INTER_CUBIC instead.
        interpolation = cv2.INTER_AREA if resize_factor < 1 else cv2.INTER_LINEAR
        image = cv2.resize(image, new_size, interpolation=interpolation)

    orig_h, orig_w = image.shape[:2]
    preprocessed_rgb = preprocess_image(image)
    input_tensor = transform(image=preprocessed_rgb)["image"].unsqueeze(0)

    pred_mask_batch = predict_unet_mask(model, device, input_tensor)
    # Nearest-neighbour keeps the resized mask strictly binary (0 / 255).
    mask_resized = cv2.resize(pred_mask_batch[0], (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    cleaned_mask = post_process_mask(mask_resized, min_area=min_area)
    return count_objects(cleaned_mask), cleaned_mask, image

def count_cells_yolo(model, image_path, conf_threshold=0.45):
    """Count the objects a YOLO model detects in one image.

    Args:
        model: Callable Ultralytics model, invoked as
            ``model(img, verbose=False, conf=...)``.
        image_path: Path of the image file.
        conf_threshold: Minimum detection confidence, forwarded as ``conf``.

    Returns:
        ``(count, image, results)``. ``count`` is the number of boxes in the
        first result, ``image`` is the BGR array fed to the model, and
        ``results`` is the model's raw output. Returns ``(0, None, None)`` if
        the image cannot be read.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        return 0, None, None
    predictions = model(bgr, verbose=False, conf=conf_threshold)
    box_count = len(predictions[0].boxes)
    return box_count, bgr, predictions

def visualize_unet_prediction(original_bgr, cleaned_mask, num_objects):
    """Show the image beside a copy with the mask's outer contours drawn in green.

    Args:
        original_bgr: BGR image that ``cleaned_mask`` is aligned with.
        cleaned_mask: Single-channel binary mask with the same height/width.
        num_objects: Count shown in the right-hand panel's title.
    """
    outlines, _ = cv2.findContours(cleaned_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # (0, 255, 0) is green in BGR; contours are drawn 2 px wide on a copy.
    annotated = cv2.drawContours(original_bgr.copy(), outlines, -1, (0, 255, 0), 2)

    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        (ax_left, original_bgr, 'Original Image'),
        (ax_right, annotated, f'U-Net Count: {num_objects}'),
    )
    for ax, bgr_img, title in panels:
        ax.imshow(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB))
        ax.set_title(title)
        ax.axis('off')
    plt.show()

def visualize_yolo_prediction(original_bgr, yolo_results, num_objects):
    """Show the image beside a copy with the detected boxes drawn in green.

    Args:
        original_bgr: BGR image the detections refer to.
        yolo_results: Model output whose first element exposes ``boxes.xyxy``.
        num_objects: Count shown in the right-hand panel's title.
    """
    annotated = original_bgr.copy()
    corner_rows = yolo_results[0].boxes.xyxy.cpu().numpy()
    for row in corner_rows:
        # int() truncates the float pixel coordinates toward zero.
        left, top, right, bottom = (int(v) for v in row[:4])
        cv2.rectangle(annotated, (left, top), (right, bottom), (0, 255, 0), 2)  # green in BGR

    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        (ax_left, original_bgr, 'Original Image'),
        (ax_right, annotated, f'YOLOv8 Count: {num_objects}'),
    )
    for ax, bgr_img, title in panels:
        ax.imshow(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB))
        ax.set_title(title)
        ax.axis('off')
    plt.show()

In [None]:
# ===================================================================
# 3.3: Run Inference on a List of Local Image Paths
# ===================================================================
# Test-split images to run both models on.
image_paths = [
    "/kaggle/input/cell-counting-roboflow-segmentation-masks/test/images/Screenshot-2024-08-20-at-3-40-57-PM_png.rf.51490a3f822ef799797a83f5462ccc9a.jpg",
    "/kaggle/input/cell-counting-roboflow-segmentation-masks/test/images/Screenshot-2024-08-20-at-3-41-22-PM_png.rf.44ac6ca63f784c0490e88c39d69692ed.jpg",
    "/kaggle/input/cell-counting-roboflow-segmentation-masks/test/images/Screenshot-2024-08-20-at-6-10-00-PM_png.rf.988e622803e078ad613cd30a92b68a20.jpg",
    "/kaggle/input/cell-counting-roboflow-segmentation-masks/test/images/Screenshot-2024-08-20-at-6-10-27-PM_png.rf.cf3e3aac50ec0abae4b45cb33971be50.jpg"
]

# Inference knobs: minimum U-Net component area (px) and YOLO confidence cut-off.
UNET_MIN_CELL_AREA = 20
YOLO_CONF_THRESHOLD = 0.5

count_records = []  # one row per processed image, summarised after the loop
for i, path in enumerate(image_paths):
    print(f"\n{'='*20} PROCESSING IMAGE {i+1} {'='*20}")
    print(f"Path: {path}")

    # Check if the file exists before processing
    if not os.path.exists(path):
        print(f"Error: Image not found at path: {path}")
        continue

    record = {"image": os.path.basename(path), "unet_count": None, "yolo_count": None}

    # Each model gets its own try block, so a failure in one still lets the
    # other run on the same image. The exception type is printed for debugging.
    try:
        pred_unet, final_mask, unet_img = unet_counting_pipeline(
            model_unet_eval, device, path, val_test_transform, min_area=UNET_MIN_CELL_AREA
        )
        record["unet_count"] = pred_unet
        print(f"U-Net Predicted Count: {pred_unet}")
        if unet_img is not None:
            visualize_unet_prediction(unet_img, final_mask, pred_unet)
    except Exception as e:
        print(f"U-Net could not process image. Error ({type(e).__name__}): {e}")

    try:
        pred_yolo, yolo_img, yolo_results = count_cells_yolo(
            model_yolo_eval, path, conf_threshold=YOLO_CONF_THRESHOLD
        )
        record["yolo_count"] = pred_yolo
        print(f"YOLOv8 Predicted Count: {pred_yolo}")
        if yolo_img is not None:
            visualize_yolo_prediction(yolo_img, yolo_results, pred_yolo)
    except Exception as e:
        print(f"YOLOv8 could not process image. Error ({type(e).__name__}): {e}")

    count_records.append(record)

# Side-by-side comparison of both models' counts (None = that model failed).
count_summary = pd.DataFrame(count_records, columns=["image", "unet_count", "yolo_count"])
count_summary