In [None]:
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import cv2 
import os
import sys
import re
import random
import xml.etree.ElementTree as ET
import albumentations as A
import pandas as pd

In [None]:
def count_class_distribution(path):
    """Count the object labels found in XML annotation files and plot them.

    Walks ``path`` recursively (it is joined onto the current working
    directory, so an absolute path is used as-is), parses every ``*.xml`` file
    and tallies the text of each ``<object>/<name>`` element. The tally is
    displayed as a bar chart.

    Args:
        path: Directory that contains the XML annotation files.

    Returns:
        None. A message is printed instead of a plot when the directory does
        not exist or when no labelled object was found.
    """
    classes = {}
    labels_dir = os.path.join(os.getcwd(), path)

    def plot(data):
        """Draw one bar per class and write the count above each bar."""
        labels, counts = zip(*data.items())
        plt.figure(figsize=(6, 4))
        bars = plt.bar(labels, counts, color='green')
        plt.xlabel('Class Labels')
        plt.ylabel('Count')
        plt.title('Class Distribution')
        for bar, count in zip(bars, counts):
            plt.text(bar.get_x() + bar.get_width() / 2 - 0.1, bar.get_height(), str(count), ha='center', va='bottom')
        plt.tight_layout()
        plt.show()

    if not os.path.exists(labels_dir):
        print("Path does not exist")
        return

    for root_dir, _dirs, files in os.walk(labels_dir):
        for filename in files:
            if not filename.endswith(".xml"):
                continue
            root_xml = ET.parse(os.path.join(root_dir, filename)).getroot()
            for obj in root_xml.findall('.//object'):
                class_label = obj.findtext('name')
                if class_label is None:
                    # Malformed <object> without a <name>: nothing to count.
                    continue
                classes[class_label] = classes.get(class_label, 0) + 1

    if not classes:
        # zip(*{}.items()) cannot be unpacked into two names, so an empty
        # tally must never reach plot().
        print("No labelled objects found")
        return
    plot(classes)

# Show how many objects of each class the training annotations contain.
count_class_distribution(path="data/labels/train")

In [None]:
# Collect every XML annotation file below the training labels directory.
# The directory is assembled with os.path.join so the separator matches the
# host OS (a literal backslash path only resolves on Windows), and the result
# is sorted because the order in which os.walk lists files is not guaranteed.
xml_files = sorted(
    os.path.join(dir_path, file_name)
    for dir_path, _dir_names, file_names in os.walk(os.path.join("data", "labels", "train"))
    for file_name in file_names
    if file_name.endswith(".xml")
)
xml_files

In [None]:
def get_class(xml_files, class_name):
    """Select annotation files holding exactly one object, of class ``class_name``.

    Args:
        xml_files: Iterable of paths to XML annotation files.
        class_name: The object label a file must contain to be selected.

    Returns:
        pandas.DataFrame with one row per selected file and the columns
        ``image`` (text of the XML ``<path>`` element), ``annot`` (the XML
        path joined onto the current working directory) and ``Class``.
        The three columns are present even when no file qualifies.
    """
    columns = ["image", "annot", "Class"]
    valid_files = []
    for file_path in xml_files:
        root = ET.parse(file_path).getroot()
        objects = root.findall('object')
        # findtext yields None for an <object> without <name>; None never
        # equals class_name, so such files are rejected instead of crashing.
        labels = {obj.findtext('name') for obj in objects}
        # Keep only files with a single object of the requested class.
        if len(objects) != 1 or labels != {class_name}:
            continue

        image_path = root.findtext("path")
        if not image_path:
            # Without <path> there is no image to pair the annotation with.
            print(f"Skipping {file_path}: no <path> element")
            continue

        valid_files.append({
            "image": image_path,
            'annot': os.path.join(os.getcwd(), file_path),
            'Class': class_name,
        })

    # Passing the columns keeps the schema stable for an empty selection.
    return pd.DataFrame(valid_files, columns=columns)

In [None]:
# One frame of single-object annotation files per class label.
df_tm, df_tnb, df_other = (
    get_class(xml_files, label) for label in ("tm", "tnb", "other")
)

In [None]:
# Preview the first rows selected for the "tm" class.
df_tm.head(n=5)

In [None]:
# Preview the first rows selected for the "tnb" class.
df_tnb.head(n=5)

In [None]:
# Preview the first rows selected for the "other" class.
df_other.head(n=5)

In [None]:
# Report the number of selected annotation files per class. The label says
# "length", so print the row count rather than the (rows, columns) shape tuple.
print(f"tm length = {len(df_tm)}")
print(f"tnb length = {len(df_tnb)}")
print(f"other length = {len(df_other)}")

In [None]:

def augment_image_with_xml(image_path: str, xml_path: str, ip: str, ap: str, c: int, class_name: str, rm: list):
    """Write one randomly flipped/rotated copy of an image and its annotation.

    The annotation is read as Pascal VOC style XML (``<object>`` elements with
    ``<name>`` and ``<bndbox>``); its boxes are transformed together with the
    image and the result is saved under new file names of the form
    ``<stem>_AUG_set1<class_name><c><ext>``.

    Args:
        image_path: Path of the source image.
        xml_path: Path of the source XML annotation.
        ip: Directory the augmented image is written to.
        ap: Directory the augmented XML is written to.
        c: Counter appended to the output names so repeated calls for the
            same source produce distinct files.
        class_name: Label embedded in the output names.
        rm: Object names to delete from the annotation before augmenting.

    Raises:
        FileNotFoundError: If the image is missing or cannot be decoded.
        OSError: If the augmented image could not be written.
    """
    box_tags = ('xmin', 'ymin', 'xmax', 'ymax')

    def augmented_name(source_path):
        """Return '<stem>_AUG_set1<class_name><c><ext>' for the file in source_path."""
        # Split on both separators so a Windows-style path stored in the XML
        # still yields the bare file name on a POSIX host.
        base_name = re.split(r"[\\/]", source_path)[-1]
        # splitext copes with names containing spaces, several dots, or no
        # extension at all.
        stem, extension = os.path.splitext(base_name)
        return f"{stem}_AUG_set1{class_name}{c}{extension}"

    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(image_path) if os.path.isfile(image_path) else None
    if image is None:
        raise FileNotFoundError(f"Cannot read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    tree = ET.parse(xml_path)
    root = tree.getroot()

    # Drop the objects whose label is listed in rm.
    for obj in root.findall('object'):
        if obj.findtext('name') in rm:
            root.remove(obj)

    objects = root.findall('.//object')
    bounding_boxes = [
        # int(float(...)) also accepts coordinates stored as "12.0".
        [int(float(obj.find(f'bndbox/{tag}').text)) for tag in box_tags]
        for obj in objects
    ]

    bbox_params = A.BboxParams(
        format='pascal_voc', min_area=0, min_visibility=0, label_fields=['category_id'])
    rotation = random.choice([(-5, 5), (-10, 10)])
    augmentation = A.Compose([
        A.HorizontalFlip(p=.8),
        # BORDER_REFLECT_101 is the constant behind the former magic number 4.
        A.Rotate(limit=rotation, border_mode=cv2.BORDER_REFLECT_101, p=.8)
    ], bbox_params=bbox_params)

    # Each box carries its own index as label: the transform may drop boxes
    # that leave the frame, and the index tells which objects survived.
    augmented = augmentation(image=image, bboxes=bounding_boxes,
                             category_id=list(range(len(bounding_boxes))))
    surviving = {
        int(index): box
        for index, box in zip(augmented['category_id'], augmented['bboxes'])
    }

    parent_of = {child: parent for parent in root.iter() for child in parent}
    for index, obj in enumerate(objects):
        new_bbox = surviving.get(index)
        if new_bbox is None:
            # Keeping the object would leave stale, pre-augmentation
            # coordinates in the written annotation.
            parent_of[obj].remove(obj)
            continue
        for tag, value in zip(box_tags, new_bbox):
            obj.find(f'bndbox/{tag}').text = str(int(value))

    image_file = augmented_name(image_path)
    output_image_path = os.path.join(ip, image_file)
    output_xml_path = os.path.join(ap, augmented_name(xml_path))

    # Point the annotation at the augmented image instead of the source image.
    filename_element = root.find('filename')
    if filename_element is not None:
        filename_element.text = image_file
    path_element = root.find('path')
    if path_element is not None:
        path_element.text = os.path.abspath(output_image_path)

    # cv2.imwrite returns False on failure instead of raising.
    if not cv2.imwrite(output_image_path, cv2.cvtColor(
            augmented['image'], cv2.COLOR_RGB2BGR)):
        raise OSError(f"Could not write image: {output_image_path}")
    tree.write(output_xml_path)


def apply_augment(df, image_save_path, annot_save_path, loop, class_name, rm_class):
    """Create ``loop`` augmented copies of every image/annotation pair in ``df``.

    Args:
        df: DataFrame with the columns ``image`` and ``annot`` (file paths),
            one row per image/annotation pair.
        image_save_path: Existing directory for the augmented images; it is
            joined onto the current working directory for the existence check.
        annot_save_path: Existing directory for the augmented XML files,
            checked the same way.
        loop: Number of augmented copies produced per row.
        class_name: Label embedded in the output file names.
        rm_class: Object names stripped from each annotation before augmenting.

    Returns:
        None. Prints which directory is missing and does nothing else when
        either save directory does not exist.
    """
    root_path = os.getcwd()
    missing = [
        directory for directory in (image_save_path, annot_save_path)
        if not os.path.exists(os.path.join(root_path, directory))
    ]
    if missing:
        # Returning without a word here used to look like a successful run.
        print(f"Path does not exist: {', '.join(map(str, missing))}")
        return

    # An empty frame may lack the columns altogether, so only index it when
    # it actually has rows.
    pairs = list(zip(df['image'], df['annot'])) if len(df) else []
    for image_path, xml_path in pairs:
        for copy_index in range(loop):
            augment_image_with_xml(
                image_path, xml_path, image_save_path, annot_save_path, copy_index, class_name, rm_class)

    print("annotation done")

In [None]:
# Each entry: (frame, copies per image, label, labels stripped from the XML).
for frame, copies, label, removed in (
    (df_tm, 3, "tm", ["tnb", "other"]),
    (df_other, 5, "other", ["tm", "tnb"]),
    (df_tnb, 7, "tnb", ["tm", "other"]),
):
    apply_augment(frame, 'data/images/train', 'data/labels/train',
                  copies, label, removed)

In [None]:
def remove_augmentaed_data(path, marker="_AUG_set1"):
    """Delete the augmented files found below ``path``.

    Args:
        path: Directory to clean; it is joined onto the current working
            directory and walked recursively.
        marker: Substring that identifies an augmented file name. The default
            is the tag augment_image_with_xml inserts into the names it writes.
    """
    root_path = os.path.join(os.getcwd(), path)
    for root, _dirs, files in os.walk(root_path):
        for file_name in files:
            # Matching on a bare "AUG" would also delete original files such
            # as "AUGUST_01.jpg".
            if marker in file_name:
                os.remove(os.path.join(root, file_name))
# Cleanup is destructive and undoes the augmentation cell, so it must not run
# as part of "Restart & Run All". Set the flag to True only to discard the
# augmented files.
REMOVE_AUGMENTED = False
if REMOVE_AUGMENTED:
    remove_augmentaed_data('data/images/train')
    remove_augmentaed_data('data/labels/train')

In [None]:
def change_xml_filename(path, image_path):
    '''
        Rewrite each annotation's <filename> to "<xml stem><image extension>".

        path: xml folder location (joined onto the current working directory)
        image_path: image folder location; it is searched recursively for the
            image whose stem equals the XML file's stem, to learn the extension

        XML files without a <filename> element or without a matching image
        are reported and left untouched.
    '''
    # Map every image stem to its extension. Exact stems are used because a
    # substring test would let "img1" match "img10.jpg". The first file seen
    # wins when two images share a stem.
    extension_by_stem = {}
    for _root, _dirs, files in os.walk(os.path.join(os.getcwd(), image_path)):
        for image_name in files:
            stem, extension = os.path.splitext(image_name)
            if extension.lower() == ".xml":
                # Annotations stored next to the images are not images.
                continue
            extension_by_stem.setdefault(stem, extension)

    xml_dir = os.path.join(os.getcwd(), path)
    if not os.path.exists(xml_dir):
        print(f"Path does not exist: {xml_dir}")
        return

    for root_dir, _dirs, files in os.walk(xml_dir):
        for xml_name in files:
            if not xml_name.endswith(".xml"):
                continue
            xml_file_path = os.path.join(root_dir, xml_name)
            tree = ET.parse(xml_file_path)
            filename_element = tree.getroot().find('filename')
            if filename_element is None:
                print("Element 'filename' not found in the XML file.")
                continue

            stem = os.path.splitext(xml_name)[0]
            extension = extension_by_stem.get(stem)
            if extension is None:
                print(f"No image found for {xml_file_path}")
                continue

            filename_element.text = stem + extension
            tree.write(xml_file_path)

In [None]:
# Align <filename> in the training annotations with the image files on disk.
change_xml_filename(path="data/labels/train", image_path="data/images/train")

In [None]:
def change_xml_path(path, image_path):
    '''
        Rewrite each annotation's <path> to "<cwd>/<image_path>/<filename>".

        path: xml folder location (joined onto the current working directory)
        image_path: image folder, relative to the current working directory;
            it is combined with the text of the XML's <filename> element

        XML files lacking a <path> element or a usable <filename> are
        reported and left untouched.
    '''
    xml_dir = os.path.join(os.getcwd(), path)
    if not os.path.exists(xml_dir):
        print(f"Path does not exist: {xml_dir}")
        return

    for root_dir, _dirs, files in os.walk(xml_dir):
        for xml_name in files:
            if not xml_name.endswith(".xml"):
                continue
            xml_file_path = os.path.join(root_dir, xml_name)
            tree = ET.parse(xml_file_path)
            root_xml = tree.getroot()

            path_element = root_xml.find('path')
            if path_element is None:
                print("Element 'path' not found in the XML file.")
                continue

            # findtext gives None for a missing element and '' for an empty
            # one; neither can be joined into an image path.
            image_name = root_xml.findtext('filename')
            if not image_name:
                print("Element 'filename' not found in the XML file.")
                continue

            path_element.text = os.path.join(os.getcwd(), image_path, image_name)
            tree.write(xml_file_path)

In [None]:
# Refresh <path> in the annotations of every dataset split.
for labels_dir, images_dir in (
    ("data/labels/train", "data/images/train"),
    ("data/labels/test", "data/images/test"),
    ("data/labels/val", "data/images/val"),
):
    change_xml_path(labels_dir, images_dir)