# classification_dataset

### Imports

In [12]:
import os
import sys
import json

import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from shapely.geometry import Polygon

import keras
import cv2 as cv

In [13]:
import tensorflow as tf
# List the GPUs TensorFlow can see and print them as a quick sanity check.
gpus = tf.config.list_physical_devices('GPU')
print("GPU Available:", gpus)
# NOTE(review): the label says cuDNN, but is_built_with_cuda() reports whether
# this TensorFlow build has CUDA support -- confirm which check is intended.
print("cuDNN Enabled:", tf.test.is_built_with_cuda())

if gpus:
    try:
        # Enable on-demand GPU memory growth for every visible GPU.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Raised when the devices were already initialized (e.g. when this
        # cell is re-run); the setting cannot be changed then, so the error
        # is only reported.
        print(e)

GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
cuDNN Enabled: True
Physical devices cannot be modified after being initialized


### Definitions

In [14]:
# Put the directory two levels up on the import path so that the project's
# `config` module can be imported from this notebook.
sys.path.insert(0, "../../")
from config import MEDIA_PATH, CROPPED_PATH, MODELS_PATH

# Configuration
BATCH_SIZE = 32  # number of crops handed to the model per prediction call

# Paths
# (IMAGES_PATH used to be assigned a hard-coded relative path before this
# point; that value was always overwritten below, so it has been dropped.)
CSV_PATH = os.path.join(CROPPED_PATH, 'ina', 'data')
INA_CROPS_PATH = os.path.join(CROPPED_PATH, 'ina', 'images')
RF_CROPS_PATH = os.path.join(CROPPED_PATH, 'onion_cell_merged', 'images')
MODEL_PATH = os.path.join(MODELS_PATH, 'supervised', 'supervised_Encoder_SSIM+MAE3.keras')
JSON_PATH = os.path.join(MEDIA_PATH, 'images', 'ina', 'tagged_images', 'corte-27-02-2024.json')
OUTPUT_CLASSES_PATH = os.path.join(MEDIA_PATH, 'cropped_images', 'classification')
IMAGES_PATH = os.path.join(MEDIA_PATH, 'images', 'ina', 'tagged_images', 'input')


### Functions

In [15]:
def get_coco_bbox(data, id):
  """
  Collect the bounding boxes and phase labels annotated on one image.

  Args:
    data: dict with an 'annotations' list. Every annotation is read for its
      'image_id'; matching ones are also read for 'bbox' and
      attributes['Fase'].
    id: image id whose annotations are wanted.

  Returns:
    A tuple (bboxes, classes) of two parallel lists: bboxes[i] is the 'bbox'
    entry and classes[i] the 'Fase' attribute of the i-th annotation that
    belongs to the image. Both lists are empty when nothing matches.
  """
  bboxes = []
  classes = []
  # Single pass keeps the two lists aligned by construction.
  for cell in data['annotations']:
    if cell['image_id'] != id:
      continue
    bboxes.append(cell['bbox'])
    classes.append(cell['attributes']['Fase'])
  return bboxes, classes

def bb_overlap_percentage(box1, box2):
    """
    Return the fraction of box1's area that is covered by box2.

    This is NOT a symmetric Intersection over Union: the intersection area is
    divided by the area of box1 only, so a small box lying completely inside a
    large one scores 1.0.

    Both boxes are axis-aligned rectangles, so the overlap is computed with
    plain arithmetic instead of building polygons.

    Args:
        box1: A tuple or list containing (x1, y1, width, height) of the first bounding box.
        box2: A tuple or list containing (x1, y1, width, height) of the second bounding box.

    Returns:
        A float between 0 and 1. 0.0 is returned when the boxes do not
        overlap or when box1 has no area (which previously divided by zero).
    """
    x1, y1, width1, height1 = box1
    x2, y2, width2, height2 = box2

    # Sort each pair of edges so a negative width/height still describes the
    # same rectangle, as it did when the corners were fed to a polygon.
    left1, right1 = sorted((x1, x1 + width1))
    top1, bottom1 = sorted((y1, y1 + height1))
    left2, right2 = sorted((x2, x2 + width2))
    top2, bottom2 = sorted((y2, y2 + height2))

    area1 = (right1 - left1) * (bottom1 - top1)
    if area1 <= 0:
        # Degenerate reference box: nothing to cover.
        return 0.0

    overlap_w = min(right1, right2) - max(left1, left2)
    overlap_h = min(bottom1, bottom2) - max(top1, top2)
    if overlap_w <= 0 or overlap_h <= 0:
        return 0.0

    return (overlap_w * overlap_h) / area1

def bbox_fully_contained(bbox1, bbox2):
    """
    Tell whether the first bounding box lies entirely inside the second one.

    Args:
        bbox1: (x, y, width, height) of the box expected to be the inner one.
        bbox2: (x, y, width, height) of the box expected to be the outer one.

    Returns:
        The result of Shapely's `within` test between the two rectangles:
        True when bbox1 is within bbox2, False otherwise.
    """

    def _as_polygon(bbox):
        # Corners are listed top-left, top-right, bottom-right, bottom-left.
        left, top, width, height = bbox
        right = left + width
        bottom = top + height
        return Polygon([(left, top), (right, top), (right, bottom), (left, bottom)])

    inner = _as_polygon(bbox1)
    outer = _as_polygon(bbox2)
    return inner.within(outer)

def bbox_intercept(bbox, bbox_list, threshold=0):
  """
  Find the first box in bbox_list that matches bbox.

  Args:
    bbox: (x, y, width, height) of the box to look up.
    bbox_list: candidate boxes, each (x, y, width, height).
    threshold: when greater than 0, a candidate matches if
      bb_overlap_percentage(bbox, candidate) reaches it; otherwise a
      candidate matches only if bbox is fully contained in it.

  Returns:
    (True, index) for the first matching candidate, or (False, -1) when no
    candidate matches.
  """

  def _matches(candidate):
    if threshold > 0:
      return bb_overlap_percentage(bbox, candidate) >= threshold
    return bbox_fully_contained(bbox, candidate)

  for position, candidate in enumerate(bbox_list):
    if _matches(candidate):
      return True, position

  return False, -1

def find_image_name_and_id(data, image_name):
  """
  Resolve an image reference to its (name, id) pair.

  Args:
    data: dict with an 'images' list whose entries carry 'id' and
      'file_name'.
    image_name: string holding either an image id (e.g. "331") or a file
      name / file-name fragment (e.g. "004_000001").

  Returns:
    A tuple (real_name, id):
      * for an id, real_name is the matching 'file_name' without extension;
      * for a name, real_name is image_name itself and id belongs to the
        first image whose 'file_name' contains it.
    None when no image matches.
  """
  # isdecimal() (not isnumeric()) so that only strings int() can actually
  # parse take the id branch; e.g. "²" is numeric but int("²") raises.
  if image_name.isdecimal():
    image_id = int(image_name)
    for img in data['images']:
      if img['id'] == image_id:
        real_name, _ = os.path.splitext(img['file_name'])
        return real_name, image_id
  else:
    # Substring match: the first image whose file name contains the text wins.
    for img in data['images']:
      if image_name in img['file_name']:
        return image_name, int(img['id'])

  # Nothing matched; callers that unpack the result must handle this.
  return None

def process_images_in_batches(string_list, batch_size=10):
  """
  Walk a list in consecutive slices of at most batch_size items.

  Args:
    string_list: The list to split.
    batch_size: Maximum number of items per slice.

  Yields:
    Tuples (batch_num, batch) where batch_num counts from 1 and batch is the
    corresponding slice of string_list.
  """
  slice_starts = range(0, len(string_list), batch_size)
  for batch_num, start in enumerate(slice_starts, start=1):
    stop = start + batch_size
    yield batch_num, string_list[start:stop]

def predict_cell(model, image_path, images_batch, color_type):
  """
  Given an image batch it returns the predictions of the batch with the given model.

  Each image is read with OpenCV, resized to 128x128 and scaled to [0, 1]
  before being stacked into a single batch.

  Args:
    model: keras model to use.
    image_path: path to the folder where the images are.
    images_batch: list of the image names to include in the batch; each one
      is joined onto image_path.
    color_type: OpenCV imread flag (e.g. cv.IMREAD_GRAYSCALE or
      cv.IMREAD_COLOR) used to load every image.

  Returns:
    The model output for the batch after a softmax over the last axis.
    NOTE(review): this assumes the model emits unnormalised scores -- confirm.

  Raises:
    OSError: if an image cannot be read from disk.
  """

  images = []
  for image in images_batch:
    full_path = os.path.join(image_path, image)
    img = cv.imread(full_path, color_type)
    if img is None:
      # imread signals failure by returning None; fail here with the path
      # instead of with an opaque error inside cv.resize.
      raise OSError(f"Could not read image: {full_path}")
    img = cv.resize(img, (128, 128))
    images.append(img / 255.0)

  batch = np.stack(images)
  if color_type == cv.IMREAD_GRAYSCALE:
    # Grayscale images are 2-D, so add the missing channel axis.
    batch = np.expand_dims(batch, axis=-1).astype(np.float32)

  prediction = model.predict(batch, verbose=0)
  prediction = tf.nn.softmax(prediction, axis=-1)
  return prediction

def extract_cell_id(filename):
    """Return the integer after the last underscore of the extension-less name.

    Expected filename format: <prefix>_<cell_id>.png
    """
    stem, _extension = os.path.splitext(filename)
    _, _, cell_id = stem.rpartition('_')
    return int(cell_id)

def list_files(directory):
    """Return the path of every file found under `directory`, recursively.

    Paths are built by joining each walked folder with the file name, in the
    order os.walk visits them.
    """
    return [
        os.path.join(folder, filename)
        for folder, _subfolders, filenames in os.walk(directory)
        for filename in filenames
    ]

### List of elements to use

In [16]:
# CSV files with the SAM detections of each image
csvs = sorted(os.listdir(CSV_PATH))

# Crop paths, gathered recursively from the INA and the onion-cell crop folders
ina_crops = list_files(INA_CROPS_PATH)
rf_crops = list_files(RF_CROPS_PATH)

# JSON with the file names and annotations of the tagged images
with open(JSON_PATH, 'r') as f:
    data = json.load(f)

model = keras.models.load_model(MODEL_PATH)

# Read the crops the way the model expects them: single-channel input means
# grayscale, anything else means colour.
color_type = cv.IMREAD_GRAYSCALE if model.input.shape[-1] == 1 else cv.IMREAD_COLOR

### Generate dataset

In [17]:
# Ids of every image that has at least one annotation, and their names
tagged_ids = {annotation["image_id"] for annotation in data['annotations']}
tagged_images = [find_image_name_and_id(data, str(image_id))[0] for image_id in tagged_ids]

In [18]:
all_crops = ina_crops + rf_crops

all_cell_crops = []

# Ceiling division so the progress total also counts the final partial batch
# (floor division printed e.g. "Batch 3032/3031").
total_batches = (len(all_crops) + BATCH_SIZE - 1) // BATCH_SIZE

for batch_num, batch in process_images_in_batches(all_crops, batch_size=BATCH_SIZE): #Read the images in batch_size batches
    print(f"Batch {batch_num}/{total_batches}", end='\r')

    # NOTE(review): INA_CROPS_PATH is passed for every crop, including the
    # ones from RF_CROPS_PATH. The entries are already joined paths coming
    # from list_files, so this presumably relies on them being absolute -- confirm.
    batch_prediction = predict_cell(model, image_path=INA_CROPS_PATH, images_batch=batch, color_type=color_type)

    # A crop is kept as a cell when the highest-scoring class is index 0.
    is_cell = np.argmax(batch_prediction, axis=1) == 0

    all_cell_crops.extend(crop for crop, keep in zip(batch, is_cell) if keep)


Batch 3032/3031

In [19]:
# Crops whose path contains the name of a tagged image form the test set;
# everything else stays available for train/validation.
# Both lists are built fresh so all_cell_crops is left untouched (it used to
# be aliased and emptied in place) and no repeated list.remove() is needed.
test_crops = []
untagged_crops = []

for cell_crop in all_cell_crops:
    if any(tagged_image in cell_crop for tagged_image in tagged_images):
        test_crops.append(cell_crop)
    else:
        untagged_crops.append(cell_crop)


In [20]:
import random
random.seed(42)  # fixed seed so the validation sample is reproducible
from collections import defaultdict

# Group the untagged crops by the text before the first underscore of their
# file name, so the validation sample is taken from every group.
prefix_to_files = defaultdict(list)

for path in untagged_crops:
    prefix = os.path.basename(path).split('_')[0]
    prefix_to_files[prefix].append(path)

validation_crops = []

# Select 10% from each group
for prefix, files in prefix_to_files.items():
    n = max(1, int(len(files) * 0.1))  # At least 1 file if group is small
    chosen = random.sample(files, n)
    validation_crops.extend(chosen)

# Train is everything not sampled for validation. It is built as a new list
# so untagged_crops is not mutated (it used to be aliased, which made
# re-running this cell shrink the pool again) and each lookup is O(1).
validation_set = set(validation_crops)
train_crops = [path for path in untagged_crops if path not in validation_set]


In [21]:
# One output folder per subset, all under OUTPUT_CLASSES_PATH
train_output, validation_output, test_output = (
    os.path.join(OUTPUT_CLASSES_PATH, subset)
    for subset in ('train', 'validation', 'test')
)

for subset_dir in (train_output, validation_output, test_output):
    os.makedirs(subset_dir, exist_ok=True)

def copy_images_to_folder(image_paths, output):
    """
    Copy every image in image_paths into the folder `output`.

    The folder (and any missing parent) is created first. Without that,
    shutil.copy would treat a non-existent `output` as a file name and each
    image would overwrite the previous one.

    Args:
        image_paths: iterable of source file paths.
        output: destination folder; files keep their base name, so two
            sources with the same base name overwrite each other.
    """
    os.makedirs(output, exist_ok=True)
    for img_path in image_paths:
        shutil.copy(img_path, output)

# Train and validation crops are copied flat into their subset folders.
for subset_crops, subset_output in (
    (train_crops, train_output),
    (validation_crops, validation_output),
):
    copy_images_to_folder(subset_crops, subset_output)

In [22]:
# Many crops come from the same source image, so its detections CSV and its
# COCO annotations are loaded once per image instead of once per crop.
detections_by_image = {}
coco_by_image = {}

for test_crop in test_crops:
    # File names look like <image part 1>_<image part 2>_..._<cell_id>.<ext>
    base = os.path.basename(test_crop)
    cell_id = extract_cell_id(base)
    image_name = '_'.join(base.split('_')[:2])

    if image_name not in detections_by_image:
        _, image_number = find_image_name_and_id(data, image_name=image_name)
        detections_by_image[image_name] = pd.read_csv(os.path.join(CSV_PATH, f"{image_name}.csv"))
        coco_by_image[image_name] = get_coco_bbox(data, image_number)

    df = detections_by_image[image_name]
    bboxes_coco, classes_coco = coco_by_image[image_name]

    df_bbox = df[df['cell_id'] == cell_id]
    if len(df_bbox) != 1:
        # Zero or several rows cannot be turned into one (x, y, w, h) box;
        # skip the crop instead of failing while unpacking the box later.
        print(f"Skipping {base}: expected one detection for cell {cell_id}, found {len(df_bbox)}")
        continue

    bbox = df_bbox[['x', 'y', 'w', 'h']].values.flatten().tolist()

    # The crop takes the class of the first annotation covering at least 75%
    # of its detection box; crops with no such annotation are not copied.
    intercept, coco_idx = bbox_intercept(bbox, bboxes_coco, 0.75)

    if intercept:
        output_dir = os.path.join(test_output, classes_coco[coco_idx])
        os.makedirs(output_dir, exist_ok=True)
        shutil.copy(test_crop, output_dir)
