# Blood Cell Detection with WGAN Data Augmentation + YOLOv8
Complete pipeline: Extract dataset ZIP from Google Drive → Convert VOC XML annotations to YOLO format → Train a WGAN-GP → Generate synthetic images → Train YOLOv8

## Cell 0: Mount Google Drive and Extract the Dataset ZIP

In [None]:
from google.colab import drive
import os
import shutil
from glob import glob

# Mount Google Drive
drive.mount('/content/drive')

# TODO: Update this path to your Drive dataset location
# Example: '/content/drive/MyDrive/archive (3).zip' or '/content/drive/MyDrive/blood-cell-dataset.zip'
DRIVE_ZIP_PATH = '/content/drive/MyDrive/archive (3).zip'  # CHANGE THIS PATH

# Extract ZIP
if os.path.exists(DRIVE_ZIP_PATH):
    print(f"✓ Found ZIP at: {DRIVE_ZIP_PATH}")
    !unzip -q '{DRIVE_ZIP_PATH}' -d /content/
    print("✓ ZIP extracted")
    
    # List extracted contents
    !ls -la /content/
else:
    print(f"✗ ZIP not found at {DRIVE_ZIP_PATH}")
    print("Please update DRIVE_ZIP_PATH to your dataset location")

## Cell 1: Convert BCCD XML Annotations to YOLO Format

In [None]:
import xml.etree.ElementTree as ET
import os
from glob import glob

# Class names in class-id order: a name's index in this list is the id written to
# the YOLO label files, so the order must match `names` in data.yaml.
classes = "WBC RBC Platelets".split()

def convert_bbox(size, box):
    """Convert a Pascal VOC pixel box to YOLO's normalized (x_center, y_center, w, h).

    Args:
        size: (image_width, image_height) in pixels.
        box: (xmin, xmax, ymin, ymax) in pixels -- note the x, x, y, y order.

    Returns:
        Tuple (x, y, w, h), each expressed as a fraction of the image width/height.

    The box is clipped to the image rectangle first. YOLO labels must be normalized
    to [0, 1], and a box that sticks out of the image would otherwise produce values
    outside that range, which Ultralytics' dataset checks treat as a corrupt label.
    Boxes that lie fully inside the image are converted exactly as before.
    """
    img_w, img_h = size[0], size[1]
    xmin, xmax = (min(max(v, 0.0), img_w) for v in (box[0], box[1]))
    ymin, ymax = (min(max(v, 0.0), img_h) for v in (box[2], box[3]))

    dw = 1. / img_w
    dh = 1. / img_h
    x = (xmin + xmax) / 2.0 * dw
    y = (ymin + ymax) / 2.0 * dh
    w = (xmax - xmin) * dw
    h = (ymax - ymin) * dh
    return (x, y, w, h)

def _find_dirs(name, root='/content', skip=('drive',)):
    """Return the sorted paths of every directory named ``name`` below ``root``.

    Behaves like ``glob(f'{root}/**/{name}', recursive=True)``, but prunes hidden
    directories and the top-level entries in ``skip``. By default that excludes the
    Google Drive mount at /content/drive, which a recursive glob would crawl in full:
    slow, and it may hold other copies of the dataset.
    """
    found = []
    for dirpath, dirnames, _ in os.walk(root):
        dirnames[:] = [
            d for d in dirnames
            if not d.startswith('.') and not (dirpath == root and d in skip)
        ]
        found.extend(os.path.join(dirpath, d) for d in dirnames if d == name)
    return sorted(found)


# Find BCCD directory (the mounted Drive is not searched)
bccd_dirs = _find_dirs('BCCD')
if bccd_dirs:
    in_dir = bccd_dirs[0]
else:
    # No folder literally named BCCD: accept any VOC-style folder that holds both
    # Annotations/ and JPEGImages/, which is the layout the code below expects.
    voc_dirs = [
        os.path.dirname(ann)
        for ann in _find_dirs('Annotations')
        if os.path.isdir(os.path.join(os.path.dirname(ann), 'JPEGImages'))
    ]
    in_dir = voc_dirs[0] if voc_dirs else '/content/BCCD'  # fallback

print(f"Looking for BCCD in: {in_dir}")

xml_dir = os.path.join(in_dir, "Annotations")
img_dir = os.path.join(in_dir, "JPEGImages")
label_dir = os.path.join(in_dir, "labels")
os.makedirs(label_dir, exist_ok=True)

if os.path.exists(xml_dir):
    # Only .xml files are converted, so only they are listed (and counted below).
    image_list = sorted(f for f in os.listdir(xml_dir) if f.endswith('.xml'))

    for xml_filename in image_list:
        root = ET.parse(os.path.join(xml_dir, xml_filename)).getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)
        yolo_lines = []
        for obj in root.iter('object'):
            cls = obj.find('name').text
            if cls not in classes:
                continue  # ignore annotations outside the target classes
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # convert_bbox expects (xmin, xmax, ymin, ymax)
            b = tuple(float(xmlbox.find(tag).text) for tag in ('xmin', 'xmax', 'ymin', 'ymax'))
            bb = convert_bbox((w, h), b)
            yolo_lines.append(f"{cls_id} {' '.join(str(round(a, 6)) for a in bb)}")
        # One label file per annotation file; no matching objects -> empty file.
        label_path = os.path.join(label_dir, os.path.splitext(xml_filename)[0] + ".txt")
        with open(label_path, "w") as f:
            f.write('\n'.join(yolo_lines))

    print(f"✓ Converted {len(image_list)} XML annotations to YOLO format")
else:
    print(f"✗ Annotations directory not found at {xml_dir}")
    print("Check your extracted directory structure")

## Cell 2: Organize Data into train/val/test splits for YOLO

In [None]:
import shutil
from glob import glob
import random
import os

def _find_dirs(name, root='/content', skip=('drive',)):
    """Return the sorted paths of every directory named ``name`` below ``root``.

    Behaves like ``glob(f'{root}/**/{name}', recursive=True)``, but prunes hidden
    directories and the top-level entries in ``skip``. By default that excludes the
    Google Drive mount at /content/drive, which a recursive glob would crawl in full.
    """
    found = []
    for dirpath, dirnames, _ in os.walk(root):
        dirnames[:] = [
            d for d in dirnames
            if not d.startswith('.') and not (dirpath == root and d in skip)
        ]
        found.extend(os.path.join(dirpath, d) for d in dirnames if d == name)
    return sorted(found)


# Create YOLO dataset directory structure
for kind in ("images", "labels"):
    for part in ("train", "val", "test"):
        os.makedirs(f"yolo_dataset/{kind}/{part}", exist_ok=True)

# Find BCCD JPEGImages directory. Prefer one with a sibling labels/ folder (the
# folder Cell 1 converted), so that images and labels come from the same place.
jpeg_dirs = _find_dirs('JPEGImages')
if not jpeg_dirs:
    # Everything below needs jpeg_dir, so stop here with a clear message.
    raise FileNotFoundError("JPEGImages directory not found under /content; run Cells 0-1 first")
labelled_dirs = [d for d in jpeg_dirs if os.path.isdir(os.path.join(os.path.dirname(d), 'labels'))]
jpeg_dir = (labelled_dirs or jpeg_dirs)[0]
labels_dir = os.path.join(os.path.dirname(jpeg_dir), 'labels')

# Get all images (sorted first so the seeded shuffle is reproducible)
all_imgs = sorted(glob(os.path.join(jpeg_dir, "*.jpg")))
random.seed(42)
random.shuffle(all_imgs)

# Split: 80% train, 10% val, 10% test
n = len(all_imgs)
train_imgs = all_imgs[:int(n*0.8)]
val_imgs = all_imgs[int(n*0.8):int(n*0.9)]
test_imgs = all_imgs[int(n*0.9):]

# Copy images and labels. An image without a label file is still copied, so it
# ends up with no objects in the YOLO dataset; such images are counted and reported.
missing_labels = 0
for split, split_imgs in zip(['train', 'val', 'test'], [train_imgs, val_imgs, test_imgs]):
    for img in split_imgs:
        img_base = os.path.basename(img)
        lbl_base = os.path.splitext(img_base)[0] + '.txt'
        shutil.copy(img, f"yolo_dataset/images/{split}/{img_base}")
        lbl_src = os.path.join(labels_dir, lbl_base)
        if os.path.exists(lbl_src):
            shutil.copy(lbl_src, f"yolo_dataset/labels/{split}/{lbl_base}")
        else:
            missing_labels += 1

print(f"✓ Split dataset:")
print(f"  Train: {len(train_imgs)} images")
print(f"  Val:   {len(val_imgs)} images")
print(f"  Test:  {len(test_imgs)} images")
if missing_labels:
    print(f"  ⚠ {missing_labels} images have no label file in {labels_dir}")

## Cell 3: Write data.yaml for YOLO

In [None]:
import os

# Ultralytics dataset config: dataset root, the image folder of each split, and the
# class list (same order as the class ids written by the annotation converter).
yaml_lines = [
    f"path: {os.path.abspath('yolo_dataset')}",
    "train: images/train",
    "val: images/val",
    "test: images/test",
    "nc: 3",
    "names: ['WBC','RBC','Platelets']",
]
yaml_content = "\n".join(yaml_lines) + "\n"

with open('yolo_dataset/data.yaml', 'w') as yaml_file:
    yaml_file.write(yaml_content)

for message in ("✓ data.yaml created for YOLO training", "\ndata.yaml content:", yaml_content):
    print(message)

## Cell 4: Reorganize BCCD Images for WGAN (ImageFolder Format)

In [None]:
import shutil
from glob import glob
import xml.etree.ElementTree as ET
import os

# Create ImageFolder-style directory for GAN
gan_data_dir = "/content/data"
os.makedirs(gan_data_dir, exist_ok=True)


def _find_dirs(name, root='/content', skip=('drive',)):
    """Return the sorted paths of every directory named ``name`` below ``root``.

    Behaves like ``glob(f'{root}/**/{name}', recursive=True)``, but prunes hidden
    directories and the top-level entries in ``skip``. By default that excludes the
    Google Drive mount at /content/drive, which a recursive glob would crawl in full.
    """
    found = []
    for dirpath, dirnames, _ in os.walk(root):
        dirnames[:] = [
            d for d in dirnames
            if not d.startswith('.') and not (dirpath == root and d in skip)
        ]
        found.extend(os.path.join(dirpath, d) for d in dirnames if d == name)
    return sorted(found)


# Extract class labels from XML files (one directory walk, Drive excluded)
xml_files = sorted(
    xml_path
    for ann_dir in _find_dirs('Annotations')
    for xml_path in glob(os.path.join(ann_dir, '*.xml'))
)
class_names = {"WBC": 0, "RBC": 1, "Platelets": 2}

# Create class subdirectories
for class_name in class_names:
    os.makedirs(os.path.join(gan_data_dir, class_name), exist_ok=True)

# Copy each image into the folder of every class it contains.
# NOTE(review): an image with several classes lands in several folders; the WGAN
# training loop discards the ImageFolder labels, so such images are simply seen
# more often per epoch.
for xml_file in xml_files:
    root = ET.parse(xml_file).getroot()
    img_name = root.find('filename').text

    # Get all classes in this image
    classes_in_image = {
        obj.find('name').text
        for obj in root.iter('object')
        if obj.find('name').text in class_names
    }

    # The image sits in the JPEGImages/ folder next to this file's Annotations/
    # folder (the same VOC layout the annotation converter uses). Resolving it
    # directly avoids a filesystem walk for every XML file.
    src_img = os.path.join(os.path.dirname(os.path.dirname(xml_file)), 'JPEGImages', img_name)
    if not os.path.exists(src_img):
        continue
    for cls in classes_in_image:
        dst_img = os.path.join(gan_data_dir, cls, img_name)
        if not os.path.exists(dst_img):
            shutil.copy(src_img, dst_img)

print(f"✓ Organized BCCD images into ImageFolder format at {gan_data_dir}")
print(f"  Classes: {list(class_names.keys())}")
for class_name in class_names.keys():
    count = len(glob(os.path.join(gan_data_dir, class_name, "*.jpg")))
    print(f"  - {class_name}: {count} images")

## Cell 5: WGAN - Imports and Config

In [None]:
import os
import time
import random
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, utils
from torch.utils.data import DataLoader

# --------------------
# Config / Hyperparams
# --------------------
# Paths
DATASET_PATH = "/content/data"   # ImageFolder tree written by Cell 4
OUT_DIR = "gan_outputs"          # sample grids
MODEL_DIR = "models"             # generator / critic checkpoints

# Data loading
IMG_SIZE = 128                   # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32                  # smaller for stability
NUM_WORKERS = 2

# Model / optimisation
Z_DIM = 100                      # latent vector length
NUM_EPOCHS = 200                 # more epochs for WGAN
LR = 5e-5                        # lower LR for WGAN
BETA1 = 0.5
LAMBDA_GP = 10                   # gradient penalty weight
SAMPLE_EVERY = 10                # epochs between checkpoints and sample grids

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for _folder in (OUT_DIR, MODEL_DIR):
    os.makedirs(_folder, exist_ok=True)

print(f"Using device: {DEVICE}")

## Cell 6: WGAN - Data Preparation

In [None]:
# --------------------
# Data preparation
# --------------------
# Resize to a square IMG_SIZE x IMG_SIZE (aspect ratio is not preserved) and scale
# pixels from [0, 1] to [-1, 1], the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),  # normalized to [-1,1]
])

# NOTE(review): Cell 4 copies an image into every class folder it contains, so
# ImageFolder sees many images more than once; the training loop ignores labels.
dataset = datasets.ImageFolder(root=DATASET_PATH, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only runtime
    # it has no effect and PyTorch emits a warning.
    pin_memory=DEVICE.type == "cuda",
)

print("Dataset classes:", dataset.classes)
print("Dataset size:", len(dataset))
print(f"Number of batches: {len(dataloader)}")

## Cell 7: WGAN - Models (Generator & Discriminator)

In [None]:
# --------------------
# WGAN Models
# --------------------

def weights_init(m):
    """DCGAN-style initialisation, intended for ``Module.apply``.

    Modules whose class name contains "Conv" get weights drawn from N(0, 0.02).
    Modules whose class name contains "BatchNorm" get weights from N(1, 0.02) and
    a zero bias. Every other module is left untouched.
    """
    kind = type(m).__name__
    if "Conv" in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm" in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

class Generator(nn.Module):
    """DCGAN-style generator: latent batch (B, z_dim, 1, 1) -> images (B, nc, 128, 128).

    The output passes through Tanh, so pixel values lie in [-1, 1].
    """

    def __init__(self, z_dim=100, ngf=64, nc=3):
        super().__init__()
        # Hidden widths for the 4x4, 8x8, 16x16, 32x32 and 64x64 feature maps.
        widths = [ngf * 16, ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = []
        in_ch = z_dim
        for depth, out_ch in enumerate(widths):
            # The first layer projects 1x1 -> 4x4; every later one doubles the resolution.
            stride, pad = (1, 0) if depth == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, pad, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
            in_ch = out_ch
        # Final upsample 64x64 -> 128x128 to the image channels.
        layers += [nn.ConvTranspose2d(in_ch, nc, 4, 2, 1, bias=False), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    """WGAN-GP critic: image batch (B, nc, 128, 128) -> raw scores of shape (B,).

    There is no sigmoid. The critic outputs an unbounded score, and the gap between
    its mean on real and on fake batches is the Wasserstein estimate used as the loss.

    Normalization: the gradient penalty is a per-sample quantity, so the critic must
    not couple the samples of a batch. BatchNorm does couple them, because every
    output depends on the whole batch's statistics. The WGAN-GP paper advises against
    BatchNorm in the critic and suggests layer normalization instead.
    ``nn.GroupNorm(1, C)`` normalizes each sample over (C, H, W), which is layer norm
    with a per-channel affine, and it works for any spatial size.

    ``weights_init`` leaves GroupNorm at PyTorch's default affine init (ones/zeros).
    Layer indices inside ``self.net`` are unchanged, but checkpoints saved from the
    BatchNorm critic will not load because their running-stat buffers are absent.
    """

    def __init__(self, nc=3, ndf=64):
        super().__init__()
        self.net = nn.Sequential(
            # 128x128 -> 64x64 (no normalization on the first layer)
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # 64x64 -> 32x32
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.GroupNorm(1, ndf*2),
            nn.LeakyReLU(0.2, inplace=True),

            # 32x32 -> 16x16
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.GroupNorm(1, ndf*4),
            nn.LeakyReLU(0.2, inplace=True),

            # 16x16 -> 8x8
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
            nn.GroupNorm(1, ndf*8),
            nn.LeakyReLU(0.2, inplace=True),

            # 8x8 -> 4x4
            nn.Conv2d(ndf*8, ndf*16, 4, 2, 1, bias=False),
            nn.GroupNorm(1, ndf*16),
            nn.LeakyReLU(0.2, inplace=True),

            # 4x4 -> 1x1 raw score (NO sigmoid)
            nn.Conv2d(ndf*16, 1, 4, 1, 0, bias=False)
        )

    def forward(self, x):
        # (B, 1, 1, 1) -> (B,)
        return self.net(x).view(-1)

# Instantiate models
netG = Generator(Z_DIM).to(DEVICE)
netD = Discriminator().to(DEVICE)
netG.apply(weights_init)
netD.apply(weights_init)

# WGAN-GP pairs the gradient penalty with Adam (Gulrajani et al. 2017, Algorithm 1,
# beta2 = 0.9). RMSprop is the recommendation for the original weight-clipping WGAN,
# which this notebook does not use. BETA1 comes from the config cell.
optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(BETA1, 0.9))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(BETA1, 0.9))

# Fixed latent batch so the periodic sample grids are comparable across epochs.
fixed_noise = torch.randn(64, Z_DIM, 1, 1, device=DEVICE)

print("✓ Generator and Discriminator initialized (WGAN)")

## Cell 8: WGAN - Gradient Penalty Function

In [None]:
# --------------------
# Gradient Penalty for WGAN-GP
# --------------------

def compute_gradient_penalty(discriminator, real_data, fake_data):
    """Return the WGAN-GP gradient penalty, mean((||grad D(x_hat)||_2 - 1)^2).

    ``x_hat`` is a random per-sample interpolation between ``real_data`` and
    ``fake_data``. Penalizing the critic's input-gradient norm away from 1 at those
    points is a soft form of the 1-Lipschitz constraint.

    Args:
        discriminator: callable mapping a batch to one score per sample, shape (B,).
        real_data: real batch, shape (B, ...). Any number of trailing dimensions.
        fake_data: generated batch of the same shape (the training loop passes it
            without gradient history).

    Returns:
        Scalar tensor built with ``create_graph=True``, so it can be back-propagated
        into the critic's parameters.
    """
    batch_size = real_data.size(0)

    # One interpolation coefficient per sample, broadcast over all other dimensions,
    # created on the inputs' device/dtype instead of a global device.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device, dtype=real_data.dtype)

    # Interpolate between real and fake data
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    d_interpolates = discriminator(interpolates)

    # With grad_outputs of ones, each sample's gradient is that of its own score,
    # provided the critic does not mix samples across the batch.
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,   # the penalty must itself be differentiable
        retain_graph=True,
    )[0]

    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

print("✓ Gradient penalty function defined")

## Cell 9: WGAN - Training Loop

In [None]:
# --------------------
# WGAN-GP Training loop
# --------------------

iters = 0
n_critic = 5  # Train discriminator 5x more than generator

# Set training mode explicitly. The synthetic-image cell calls netG.eval(), so
# re-running this cell afterwards would otherwise train the generator with frozen
# (running) BatchNorm statistics.
netG.train()
netD.train()

for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start = time.time()
    running_D_loss = 0.0
    running_G_loss = 0.0
    
    for i, (data, _) in enumerate(dataloader):  # ImageFolder labels are unused
        real_data = data.to(DEVICE)
        batch_size = real_data.size(0)
        
        # --------------------
        # Train Discriminator (Critic)
        # --------------------
        # NOTE(review): all n_critic steps reuse this one real batch, whereas
        # Algorithm 1 of the WGAN-GP paper samples a fresh real batch per step.
        for _ in range(n_critic):
            netD.zero_grad()
            
            # Real data
            real_output = netD(real_data)
            
            # Fake data. The generator is not updated in this phase, so skip
            # building its autograd graph; fake_data has no gradient history.
            noise = torch.randn(batch_size, Z_DIM, 1, 1, device=DEVICE)
            with torch.no_grad():
                fake_data = netG(noise)
            fake_output = netD(fake_data)
            
            # Critic loss: negative Wasserstein estimate
            d_loss = -torch.mean(real_output) + torch.mean(fake_output)
            
            # Gradient penalty
            gp = compute_gradient_penalty(netD, real_data, fake_data)
            d_loss_total = d_loss + LAMBDA_GP * gp
            
            d_loss_total.backward()
            optimizerD.step()
            
            running_D_loss += d_loss_total.item()  # logged D_loss includes the penalty
        
        # --------------------
        # Train Generator
        # --------------------
        netG.zero_grad()
        
        noise = torch.randn(batch_size, Z_DIM, 1, 1, device=DEVICE)
        fake_data = netG(noise)
        fake_output = netD(fake_data)
        
        # Generator loss (raise the critic's score on fakes)
        g_loss = -torch.mean(fake_output)
        
        g_loss.backward()
        optimizerG.step()
        
        running_G_loss += g_loss.item()
        iters += 1
    
    # Epoch stats
    avg_D = running_D_loss / (len(dataloader) * n_critic)
    avg_G = running_G_loss / len(dataloader)
    epoch_time = time.time() - epoch_start
    
    if epoch % 10 == 0 or epoch == 1:
        print(f"Epoch [{epoch}/{NUM_EPOCHS}]  D_loss: {avg_D:.4f}  G_loss: {avg_G:.4f}  time: {epoch_time:.1f}s")
    
    # Save models & samples periodically
    if epoch % SAMPLE_EVERY == 0 or epoch == 1:
        g_path = os.path.join(MODEL_DIR, f"netG_epoch{epoch}.pth")
        d_path = os.path.join(MODEL_DIR, f"netD_epoch{epoch}.pth")
        torch.save(netG.state_dict(), g_path)
        torch.save(netD.state_dict(), d_path)
        
        # Render the grid in eval mode, as the synthetic-image cell does, so that the
        # fixed-noise batch does not update the generator's BatchNorm running stats.
        netG.eval()
        with torch.no_grad():
            fake_samples = netG(fixed_noise).cpu()
        netG.train()
        grid = utils.make_grid(fake_samples, padding=2, normalize=True)
        sample_path = os.path.join(OUT_DIR, f"sample_epoch{epoch}.png")
        utils.save_image(grid, sample_path)

print("✓ WGAN Training finished!")

## Cell 10: Generate Synthetic Images for Augmentation

In [None]:
# --------------------
# Generate synthetic images
# --------------------

aug_out_dir = "synthetic_blood_cells"
NUM_SYNTHETIC = 1000  # Generate 1000 synthetic images
os.makedirs(aug_out_dir, exist_ok=True)

# One image per forward pass, with BatchNorm using its running statistics (eval mode).
netG.eval()
with torch.no_grad():
    for i in range(NUM_SYNTHETIC):
        z = torch.randn(1, Z_DIM, 1, 1, device=DEVICE)
        # Tanh output is in [-1, 1]; shift and scale to [0, 1] for save_image.
        fake_img = netG(z).detach().cpu().add(1).div(2)
        out_name = f"synthetic_{i:05d}.png"
        utils.save_image(fake_img, os.path.join(aug_out_dir, out_name))
        if (i + 1) % 100 == 0:
            print(f"Generated {i+1}/{NUM_SYNTHETIC} synthetic images")

print(f"✓ Generated {NUM_SYNTHETIC} synthetic images in '{aug_out_dir}'")

## Cell 11: Add Synthetic Images to the YOLO Training Set (as unlabeled background images)

In [None]:
import shutil
from glob import glob
import os

# WARNING(review): the GAN was trained on whole blood-smear images, so the synthetic
# images contain cell-like content. They are still written with EMPTY label files,
# which makes them background images for YOLO: the detector learns that what they
# contain is not a WBC/RBC/Platelet. Adding many of them suppresses recall on real
# cells. Ultralytics recommends about 0-10% background images, so the number added
# is capped relative to the real training images. Set MAX_SYNTHETIC_FRACTION = 0 to
# skip them; to use them as positives they would need (pseudo-)labels instead.
MAX_SYNTHETIC_FRACTION = 0.1

train_img_dir = "yolo_dataset/images/train"
train_lbl_dir = "yolo_dataset/labels/train"

# Remove synthetic copies left by an earlier run so re-running this cell is idempotent.
for stale_img in glob(os.path.join(train_img_dir, "synthetic_*.png")):
    os.remove(stale_img)
    stale_lbl = os.path.join(
        train_lbl_dir, os.path.splitext(os.path.basename(stale_img))[0] + ".txt"
    )
    if os.path.exists(stale_lbl):
        os.remove(stale_lbl)

num_real_train = len(glob(os.path.join(train_img_dir, "*.jpg")))
max_synthetic = int(num_real_train * MAX_SYNTHETIC_FRACTION)

# Copy synthetic images to YOLO train folder
synthetic_imgs = sorted(glob("synthetic_blood_cells/*.png"))[:max_synthetic]
for syn_img in synthetic_imgs:
    img_base = os.path.basename(syn_img)
    shutil.copy(syn_img, os.path.join(train_img_dir, img_base))

    # Create empty label file (synthetic images have no annotations)
    lbl_base = os.path.splitext(img_base)[0] + ".txt"
    with open(os.path.join(train_lbl_dir, lbl_base), 'w'):
        pass

print(f"✓ Added {len(synthetic_imgs)} synthetic images to YOLO training set")
total_train = len(glob('yolo_dataset/images/train/*.jpg')) + len(glob('yolo_dataset/images/train/*.png'))
print(f"  Total training images: {total_train}")

## Cell 12: Install YOLOv8

In [None]:
# Install YOLOv8 package: upgrades `ultralytics` to the latest release, quietly.
# NOTE(review): if `ultralytics` was already imported in this kernel, a runtime
# restart is needed before the upgraded version is actually used.
!pip install ultralytics --upgrade -q

import ultralytics
print(f"✓ YOLOv8 version: {ultralytics.__version__}")

## Cell 13: Train YOLOv8 Model

In [None]:
from ultralytics import YOLO
import os

import torch

# Load pretrained YOLOv8 nano model
model = YOLO('yolov8n.pt')

# Use the first GPU when one is present. A hard-coded device=0 makes training
# fail on a CPU-only runtime.
train_device = 0 if torch.cuda.is_available() else 'cpu'

# Train on BCCD dataset (with synthetic augmentation)
results = model.train(
    data='yolo_dataset/data.yaml',
    epochs=100,
    imgsz=640,
    batch=16,
    patience=20,  # early-stop after 20 epochs without validation improvement
    device=train_device,
    project='blood_cell_detection',
    name='yolov8n_bccd'
)

print("✓ YOLO training completed!")

## Cell 14: Validate YOLO Model

In [None]:
# Run Ultralytics validation on the trained model with its default settings.
metrics = model.val()

box_metrics = metrics.box  # box-detection metrics
print(f"mAP50: {box_metrics.map50:.3f}")
print(f"mAP50-95: {box_metrics.map:.3f}")

## Cell 15: Test and Visualize Predictions

In [None]:
# Import the glob *function* and alias IPython's Image. `import glob` and a bare
# `Image` would shadow the `from glob import glob` function and PIL's `Image` that
# earlier cells rely on, so re-running those cells would break.
from IPython.display import Image as IPyImage, display
from glob import glob
import os

# Predict on test images (sorted so "first 5" is reproducible)
test_images = sorted(glob('yolo_dataset/images/test/*.jpg'))

print(f"Testing on {len(test_images)} images...\n")

for i, img_path in enumerate(test_images[:5]):  # Show first 5
    # A separate name avoids overwriting `results` from the training cell.
    preds = model.predict(img_path, conf=0.5)
    out_path = f'test_pred_{i}.jpg'
    preds[0].save(filename=out_path)
    display(IPyImage(filename=out_path))
    print(f"Image {i+1}: {os.path.basename(img_path)}")

## Cell 16: Export Model (Optional)

In [None]:
# Export model to ONNX format. export() returns the path of the written file, which
# sits next to the loaded weights, i.e. under the training run's project/name folder.
# It is not under runs/detect/, because the training cell sets project='blood_cell_detection'.
onnx_path = model.export(format='onnx')

print("✓ Model exported to ONNX format")
print(f"  File: {onnx_path}")