# Detection: Train-test split

In [None]:
import pandas as pd
import numpy as np
import xml.etree.ElementTree as ET
import os
from sklearn.model_selection import StratifiedKFold
import warnings
import matplotlib.pyplot as plt

In [None]:
# A blanket 'ignore' would also hide scikit-learn's UserWarning when a
# stratify group has fewer members than n_splits (20 below), which directly
# degrades the fold assignment; only silence FutureWarning noise.
warnings.filterwarnings('ignore', category=FutureWarning)

def extract_info_from_xml(xml_file):
    """Read the object class names from one XML annotation file.

    Args:
        xml_file: Path to the ``.xml`` annotation.

    Returns:
        ``(image_id, class_names, bbox_count)``: ``image_id`` is the file name
        without its final extension, ``class_names`` lists the ``<name>`` of
        every ``<object>`` in document order, and ``bbox_count`` is the number
        of objects.
    """
    root = ET.parse(xml_file).getroot()
    # splitext strips only the final extension ("a.b.xml" -> "a.b", not "a");
    # file names are later rebuilt as image_id + '.png', so truncating at the
    # first dot would point at a non-existent image.
    image_id = os.path.splitext(os.path.basename(xml_file))[0]
    class_names = [obj.find('name').text for obj in root.iter('object')]
    return image_id, class_names, len(class_names)

def create_dataframe(xml_dir):
    """Build a long-format table with one row per annotated object.

    Args:
        xml_dir: Directory scanned (non-recursively) for ``.xml`` files.

    Returns:
        DataFrame with columns ``image_id``, ``class_name`` and ``bbox_count``
        (total number of objects in that row's image). Images without any
        ``<object>`` contribute no rows.
    """
    rows = []
    # os.listdir order depends on the filesystem; sorting keeps the row order,
    # and therefore the shuffled StratifiedKFold assignment built from it,
    # reproducible across machines.
    for xml_file in sorted(os.listdir(xml_dir)):
        if not xml_file.endswith('.xml'):
            continue
        image_id, class_names, bbox_count = extract_info_from_xml(os.path.join(xml_dir, xml_file))
        rows.extend(
            {'image_id': image_id, 'class_name': class_name, 'bbox_count': bbox_count}
            for class_name in class_names
        )
    # Explicit columns so an empty directory still yields the expected schema
    return pd.DataFrame(rows, columns=['image_id', 'class_name', 'bbox_count'])

# Modify this path to your XML directory
xml_dir = 'CRIC_db/images'
df = create_dataframe(xml_dir)
df = df[df['class_name'] != 'No finding']

# Stratified K-Fold: 20 folds, grouped 16 / 2 / 2 into train / valid / test
N_SPLITS = 20
SPLIT_SEED = 12
TRAIN_FOLDS = list(range(16))
VALID_FOLDS = [16, 17]
TEST_FOLDS = [18, 19]

skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SPLIT_SEED)

# One row per image carrying image-level statistics
df_folds = df[['image_id']].copy()
df_folds['bbox_count'] = df.groupby('image_id')['bbox_count'].transform('max')
# Number of *distinct* classes in the image, not the number of boxes
df_folds['object_count'] = df.groupby('image_id')['class_name'].transform('nunique')
df_folds = df_folds.drop_duplicates()

# Stratify group "<distinct classes>_<bbox_count // 15>" balances folds by image complexity
df_folds['stratify_group'] = (
    df_folds['object_count'].astype(str) + '_' + (df_folds['bbox_count'] // 15).astype(str)
)

df_folds['fold'] = 0
for fold_number, (_, val_index) in enumerate(skf.split(X=df_folds.index, y=df_folds['stratify_group'])):
    df_folds.loc[df_folds.index[val_index], 'fold'] = fold_number

# Attach fold labels to every box row. Join on image_id only, after dropping
# bbox_count (already present in df), so all three splits share the same
# columns and no bbox_count_x / bbox_count_y suffixes appear.
fold_info = df_folds.drop(columns='bbox_count')
df_train = df.merge(fold_info[fold_info['fold'].isin(TRAIN_FOLDS)], on='image_id')
df_valid = df.merge(fold_info[fold_info['fold'].isin(VALID_FOLDS)], on='image_id')
df_test = df.merge(fold_info[fold_info['fold'].isin(TEST_FOLDS)], on='image_id')

print(df_train['class_name'].value_counts(normalize=False))
print(df_valid['class_name'].value_counts(normalize=False))
print(df_test['class_name'].value_counts(normalize=False))

In [None]:
# A split holds one row per box, so count distinct image ids per split
print(f"Train size: {df_train['image_id'].nunique()}")
print(f"Validation size: {df_valid['image_id'].nunique()}")
print(f"Test size: {df_test['image_id'].nunique()}")

In [None]:
# Series.hist() on a string column maps categories to numbers and bins them
# into 10 numeric bins, so bars do not line up with classes; overlaying
# absolute counts also hides the comparison because train has 8x the folds
# of test. Plot each split's class proportions side by side instead.
class_share = pd.DataFrame({
    'train': df_train['class_name'].value_counts(normalize=True),
    'valid': df_valid['class_name'].value_counts(normalize=True),
    'test': df_test['class_name'].value_counts(normalize=True),
}).fillna(0)
class_share.plot(kind='bar', figsize=(30, 8))
plt.ylabel('Fraction of boxes')
plt.xticks(rotation=0)
plt.tight_layout()
plt.show()

In [None]:
# Absolute number of boxes per class in the training split
df_train['class_name'].value_counts()

In [None]:
# Number of SCC boxes in validation; 0 (instead of KeyError) if the split has none
df_valid['class_name'].value_counts(normalize=False).get('SCC', 0)

In [None]:
# Absolute number of boxes per class in the test split
df_test['class_name'].value_counts()

In [None]:
import os
import shutil

def copy_images(df, source_dir, destination_dir, extensions=('.png', '.xml')):
    """Copy each image of ``df`` together with its annotation into ``destination_dir``.

    Args:
        df: DataFrame with an ``image_id`` column; repeated ids are copied once.
        source_dir: Directory holding ``<image_id><ext>`` files.
        destination_dir: Target directory, created (with parents) if missing.
        extensions: Suffixes copied for every image id, in this order. The
            default copies the PNG image and its XML annotation.

    Missing source files are reported with a ``Warning:`` line and skipped.
    """
    os.makedirs(destination_dir, exist_ok=True)

    for image_id in df['image_id'].unique():
        for ext in extensions:
            source_file = os.path.join(source_dir, image_id + ext)
            if os.path.exists(source_file):
                shutil.copy(source_file, destination_dir)
            else:
                print(f"Warning: {source_file} does not exist.")

# Set your source directory where all images are currently stored
source_dir = 'CRIC_db/images'

# Destination directories for the train, validation and test sets. They are
# placed under 'detection_multi' because the single-class conversion and the
# cropping cells below read detection_multi/{train,valid,test}_images.
detection_root = 'detection_multi'
train_dir = os.path.join(detection_root, 'train_images')
valid_dir = os.path.join(detection_root, 'valid_images')
test_dir = os.path.join(detection_root, 'test_images')

# Copy the images
copy_images(df_train, source_dir, train_dir)
copy_images(df_valid, source_dir, valid_dir)
copy_images(df_test, source_dir, test_dir)

# Detection: Single-class split

In [None]:
import os
import shutil
import xml.etree.ElementTree as ET

def modify_xml(source_xml_path, dest_xml_path,
               negative_class='Negative for intraepithelial lesion',
               anomaly_label='Anomaly'):
    """Rewrite a multi-class annotation as a single-class annotation.

    Every ``<object>`` directly under the root whose ``<name>`` equals
    ``negative_class`` is removed; every other object is relabelled
    ``anomaly_label``. The result is written to ``dest_xml_path``.

    The output is always written, even when the file has no objects or all of
    them were removed, so every PNG copied next to it keeps an annotation file.
    """
    tree = ET.parse(source_xml_path)
    root = tree.getroot()

    # findall() returns a list snapshot, so removing children inside the loop is safe
    for obj in root.findall('object'):
        name_el = obj.find('name')
        if name_el.text == negative_class:
            root.remove(obj)
        else:
            name_el.text = anomaly_label

    tree.write(dest_xml_path)

def process_directory(source_subdir, dest_subdir):
    """Mirror one split directory into the single-class dataset.

    ``.png`` files are copied unchanged, ``.xml`` annotations are rewritten
    through ``modify_xml``, and every other entry is ignored.
    """
    if not os.path.exists(dest_subdir):
        os.makedirs(dest_subdir)

    for entry in os.listdir(source_subdir):
        src_path = os.path.join(source_subdir, entry)
        dst_path = os.path.join(dest_subdir, entry)
        if entry.endswith('.xml'):
            modify_xml(src_path, dst_path)
        elif entry.endswith('.png'):
            shutil.copy(src_path, dst_path)

def main(source_dir, dest_dir):
    """Convert the train/valid/test splits of ``source_dir`` into ``dest_dir``.

    NOTE(review): ``process_directory`` is resolved when ``main`` runs. A
    later cell redefines that name with the signature
    ``(subdir, source_dir, dest_dir)``, so calling ``main`` after that cell
    has run raises TypeError. Re-run the single-class cell first.
    """
    splits = ('train_images', 'valid_images', 'test_images')
    for split in splits:
        process_directory(os.path.join(source_dir, split), os.path.join(dest_dir, split))

# Source: the multi-class detection splits; destination: single-class ('Anomaly') copy
source_dir = 'detection_multi'
dest_dir = 'detection_single'

main(source_dir=source_dir, dest_dir=dest_dir)

# Classification: Cropping and split

In [None]:
import os
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt

def extract_box_dimensions(xml_file):
    """Return the ``(width, height)`` of every bounding box in an annotation file.

    Args:
        xml_file: Path to an XML annotation with ``<object><bndbox>`` entries.

    Returns:
        List of integer ``(xmax - xmin, ymax - ymin)`` tuples, one per
        ``<object>`` in document order.
    """
    def coord(bbox, tag):
        # int("12.0") raises ValueError; going through float() accepts
        # decimal coordinate strings as well as plain integers.
        return int(float(bbox.find(tag).text))

    box_dimensions = []
    for obj in ET.parse(xml_file).getroot().iter('object'):
        bbox = obj.find('bndbox')
        xmin, ymin, xmax, ymax = (coord(bbox, tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        box_dimensions.append((xmax - xmin, ymax - ymin))
    return box_dimensions

def analyze_dimensions(xml_dir):
    """Gather box widths and heights from every ``.xml`` file directly in ``xml_dir``.

    Returns:
        ``(widths, heights)``: two parallel lists with one entry per box.
    """
    dims = [
        wh
        for name in os.listdir(xml_dir)
        if name.endswith('.xml')
        for wh in extract_box_dimensions(os.path.join(xml_dir, name))
    ]
    widths = [w for w, _ in dims]
    heights = [h for _, h in dims]
    return widths, heights

def plot_distribution(widths, heights):
    """Show side-by-side 50-bin histograms of box widths (left) and heights (right)."""
    plt.figure(figsize=(12, 6))
    panels = (
        (widths, 'blue', 'Distribution of Box Widths', 'Width'),
        (heights, 'green', 'Distribution of Box Heights', 'Height'),
    )
    for position, (values, color, title, axis_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.hist(values, bins=50, color=color, alpha=0.7)
        plt.title(title)
        plt.xlabel(axis_label)
        plt.ylabel('Frequency')
    plt.tight_layout()
    plt.show()

# Directories holding the multi-class detection splits (images + XML annotations)
train_xml_dir, valid_xml_dir, test_xml_dir = (
    'detection_multi/train_images',
    'detection_multi/valid_images',
    'detection_multi/test_images',
)

In [None]:
# Box-size distribution of the training split
train_widths, train_heights = analyze_dimensions(train_xml_dir)
plot_distribution(widths=train_widths, heights=train_heights)

In [None]:
# Box-size distribution of the validation split
valid_widths, valid_heights = analyze_dimensions(valid_xml_dir)
plot_distribution(widths=valid_widths, heights=valid_heights)

In [None]:
import os
import xml.etree.ElementTree as ET
from PIL import Image

def get_boxes_and_classes(xml_file):
    """Return the integer centre and class name of every box in an annotation.

    Args:
        xml_file: Path to an XML annotation with ``<object>`` entries holding
            ``<name>`` and ``<bndbox>``.

    Returns:
        List of ``((center_x, center_y), class_name)`` tuples in document
        order; centres are floor-divided midpoints of the box corners.
    """
    boxes = []
    for obj in ET.parse(xml_file).getroot().iter('object'):
        bbox = obj.find('bndbox')
        class_name = obj.find('name').text
        # float() first so decimal coordinate strings ("12.0") don't make int() raise
        xmin, ymin, xmax, ymax = (
            int(float(bbox.find(tag).text)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        boxes.append((((xmin + xmax) // 2, (ymin + ymax) // 2), class_name))
    return boxes

def crop_and_save(image, boxes, dest_dir, half_size=150):
    """Save a square crop around each box centre into a per-class folder.

    Args:
        image: Image object opened from a file (``Image.open`` in the caller);
            its ``filename`` is used to name the crops.
        boxes: ``((center_x, center_y), class_name)`` tuples.
        dest_dir: Root directory; crops go to ``dest_dir/<class_name>/``.
        half_size: Half the side of the crop window (default 150 -> 300x300).

    Crops are written as ``<image stem>_<cx>_<cy>.png``. Windows are clamped to
    the image bounds, so boxes near an edge yield crops smaller than
    ``2 * half_size`` on that side.
    """
    img_width, img_height = image.size
    # splitext instead of str.replace('.png', ...): for a non-.png source the
    # replace was a no-op, so every crop of that image got the same name and
    # overwrote the previous one.
    stem = os.path.splitext(os.path.basename(image.filename))[0]

    for (center_x, center_y), class_name in boxes:
        window = (
            max(center_x - half_size, 0),
            max(center_y - half_size, 0),
            min(center_x + half_size, img_width),
            min(center_y + half_size, img_height),
        )
        class_dir = os.path.join(dest_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)
        crop_path = os.path.join(class_dir, f'{stem}_{center_x}_{center_y}.png')
        image.crop(window).save(crop_path)

def process_directory(subdir, source_dir, dest_dir):
    """Crop every annotated image of ``source_dir/subdir`` into ``dest_dir/subdir``.

    Each ``.xml`` annotation is paired with the ``.png`` of the same stem;
    annotations without a matching image are skipped silently.

    NOTE: this reuses the name of the single-class conversion helper defined in
    an earlier cell and replaces it once this cell runs.
    """
    split_dir = os.path.join(source_dir, subdir)
    out_dir = os.path.join(dest_dir, subdir)

    for file_name in os.listdir(split_dir):
        if not file_name.endswith('.xml'):
            continue
        # Swap only the final extension (str.replace would hit every '.xml' in the name)
        image_file_path = os.path.join(split_dir, os.path.splitext(file_name)[0] + '.png')
        if not os.path.exists(image_file_path):
            continue

        boxes = get_boxes_and_classes(os.path.join(split_dir, file_name))
        # Close the image deterministically instead of relying on garbage
        # collection; Image.open opens the file lazily.
        with Image.open(image_file_path) as image:
            crop_and_save(image, boxes, out_dir)

# Crop boxes from the multi-class detection splits into per-class folders
source_dir = 'detection_multi'
dest_dir = 'class_post'

for subdir in ('train_images', 'valid_images', 'test_images'):
    process_directory(subdir, source_dir, dest_dir)

# Class counts

In [None]:
import os

def count_crops_in_class_folders(root_dir):
    """Map each sub-directory name of ``root_dir`` to the number of ``.png`` entries in it.

    Plain files directly in ``root_dir`` are ignored, and counting is not recursive.
    """
    def png_count(folder):
        return sum(1 for entry in os.listdir(folder) if entry.endswith('.png'))

    return {
        name: png_count(os.path.join(root_dir, name))
        for name in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, name))
    }

def print_crop_counts(dest_dir, set_types=('train_images', 'valid_images', 'test_images')):
    """Print per-class crop counts for each split directory under ``dest_dir``.

    Args:
        dest_dir: Root laid out as ``<dest_dir>/<split>/<class_name>/*.png``.
        set_types: Split sub-directories to report, in order.

    Splits missing on disk are reported and skipped instead of raising.
    Classes are printed in sorted order.
    """
    # The cropping cell writes train_images / valid_images / test_images; the
    # names 'train' / 'val' do not exist there and made os.listdir raise.
    for set_type in set_types:
        print(f"\nCounts for {set_type}:")
        set_dir = os.path.join(dest_dir, set_type)
        if not os.path.isdir(set_dir):
            print(f"Directory '{set_dir}' not found, skipping.")
            continue
        counts = count_crops_in_class_folders(set_dir)
        for class_name in sorted(counts):
            print(f"Class '{class_name}': {counts[class_name]} crops")

# Root of the per-split, per-class crop folders written by the cropping cell
dest_dir = 'class_post'

print_crop_counts(dest_dir=dest_dir)