# Import and Loads

In [1]:
from ultralytics import YOLO
import os 
import sys
sys.path.append("/mnt/RAID/projects/FjordVision")
from models.probability_tree import ProbabilityTree
import torch
from anytree.importer import JsonImporter
from preprocessing.preprocessing import load_ground_truth_mask_xyn, convert_polygon_to_mask, calculate_binary_mask_iou
from utils.metrics import calculate_hierarchical_precision_recall, calculate_weighted_f1_score, hierarchical_similarity
torch.cuda.empty_cache()

# Function to divide the data into chunks of size n
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

# Define the file path
weights_path = '/mnt/RAID/projects/FjordVision/runs/segment/Yolov8n-seg-train/weights/best.pt'

# Load the YOLO model weights
model = YOLO(weights_path)

importer = JsonImporter()
with open('data/ontology.json', 'r') as f:
    root = importer.read(f)

classes_file = '/mnt/RAID/datasets/label-studio/fjord/classes.txt'

species_names = []
with open(classes_file, 'r') as file:
    species_names = [line.strip() for line in file]

genus_names, class_names, binary_names = [], [], []
for node in root.descendants:
    if node.rank == 'genus':
        genus_names.append(node.name)
    elif node.rank == 'class':
        class_names.append(node.name)
    elif node.rank == 'binary':
        binary_names.append(node.name)

taxonomies = [species_names, genus_names, class_names, binary_names]

# Construct Probability Tree

In [2]:
# Usage example
ontology_path = 'data/ontology.json'  # Update this path as necessary
prob_tree = ProbabilityTree(ontology_path)

# Training Loop

In [4]:
# Define the image folder path
image_folder_path = '/mnt/RAID/datasets/The Fjord Dataset/fjord/images/test/'
frames = os.listdir(image_folder_path)
image_files_full_path = [image_folder_path + f for f in frames]

# Define the label folder path
label_folder_path = '/mnt/RAID/datasets/The Fjord Dataset/fjord/labels/test/'

classes = '/mnt/RAID/datasets/The Fjord Dataset/fjord/classes.txt'

class_index = []
with open(classes, 'r') as file:
    for line_number, line in enumerate(file, start=1):
        class_name = line.strip()
        class_index.append(class_name)

Y = []
Yhat = []
confidences = []
batch_size = 50

# Loop through batches of images
for image_batch in chunks(image_files_full_path, batch_size):

    with torch.no_grad():
        predictions = model(image_batch, stream=True)

    # Loop through the files in the image folder
    for file_name, prediction in zip(image_batch, predictions):
        # Check if the file is an image file
        if file_name.endswith('.jpg') or file_name.endswith('.png'):
            # Construct the corresponding label file name
            shape = prediction.orig_img.shape[:2]
            base_file_name = file_name.split('/')[-1].replace('.jpg', '.txt')
            label_file_path = label_folder_path + base_file_name
        
            # check if predictions are empty
            if len(prediction.boxes.cls) == 0:
                continue

            GT = load_ground_truth_mask_xyn(label_file_path)
            visited = len(GT)*[None]

            for cls, mask, conf in zip(prediction.boxes.cls, prediction.masks.xyn, prediction.boxes.conf):
                confidences.append(conf.item())
                m = convert_polygon_to_mask(mask, shape)
                best_iou = 0

                # calculate iou and find the best mask
                for idx, (gcls, gmsk) in enumerate(GT):
                    g = convert_polygon_to_mask(gmsk, shape)
                    iou = calculate_binary_mask_iou(m, g)

                    if iou > best_iou and iou > 0.5:
                        best_iou = iou
                        best_g = g
                        best_gcls = gcls
                        visited[idx] = True
                        best_idx = idx

                if best_idx is not None and best_iou > 0.5:
                    visited[best_idx] = True

                if best_g is None:
                    Y.append(None)
                    Yhat.append(int(cls.item()))
                else:
                    Y.append(best_gcls)
                    Yhat.append(int(cls.item()))

            for vis in visited:
                if vis is None:
                    Y.append(GT[idx][0])
                    Yhat.append(None)

    # After processing each batch, clear unused memory from CUDA
    torch.cuda.empty_cache()






































# Calculate scores without reclassification

In [5]:
# Calculate weighted precision, recall, and F1
precision, recall = calculate_hierarchical_precision_recall(Y, Yhat, confidences, taxonomies, prob_tree, threshold=0)
weighted_f1_score = calculate_weighted_f1_score(precision, recall)

In [6]:
precision

0.9625112107623318

In [7]:
recall

0.8988274706867672

In [8]:
weighted_f1_score

0.9295799047206583

# Update Predictions with uniform probability tree

In [27]:
# Calculate weighted precision, recall, and F1
precision, recall = calculate_hierarchical_precision_recall(Y, Yhat, confidences, taxonomies, prob_tree, threshold=0.30)
weighted_f1_score = calculate_weighted_f1_score(precision, recall)

# Scores with reclassifiation using uniform tree

In [28]:
precision

0.9602612955906369

In [29]:
recall

0.8975576662143826

In [30]:
weighted_f1_score

0.927851319365302

# Predict Using Probability Tree

In [31]:
label_folder_path = '/mnt/RAID/datasets/The Fjord Dataset/fjord/labels/train/'
frames = os.listdir(label_folder_path)
label_files_full_path = [label_folder_path + f for f in frames]

class_indexes = []

for label_file in label_files_full_path:
    with open(label_file, 'r') as file:
        lines = file.readlines()
        for line in lines:
            if line.strip():
                class_index = int(line.split()[0])
                class_indexes.append(class_index)

class_index_counts = {}
for class_index in class_indexes:
    if class_index in class_index_counts:
        class_index_counts[class_index] += 1
    else:
        class_index_counts[class_index] = 1

sorted_class_index_counts = dict(sorted(class_index_counts.items()))
sorted_class_index_counts
renamed_dict = {species_names[key]: value for key, value in sorted_class_index_counts.items()}
renamed_dict

{'asterias rubens': 1321,
 'asteroidea': 1399,
 'fucus vesiculosus': 2324,
 'henrica': 1210,
 'mytilus edulis': 3098,
 'myxine glurinosa': 1376,
 'pipe': 1519,
 'rock': 1081,
 'saccharina latissima': 1292,
 'tree': 2113,
 'ulva intestinalis': 1202,
 'urospora': 3314,
 'zostera marina': 4837}

In [32]:
ontology_path = 'data/ontology.json'  # Update this path as necessary
new_prob_tree = ProbabilityTree(ontology_path)

In [33]:
new_prob_tree.update_probabilities_with_instance_counts(renamed_dict)
new_prob_tree.print_tree()

object (Rank: root, Probability: 1.0000000000000002)
├── marine life (Rank: binary, Probability: 0.7692307692307694)
│   ├── Asteroidea (Rank: class, Probability: 0.23076923076923078)
│   │   ├── asterias (Rank: genus, Probability: 0.1304330214282821)
│   │   │   ├── asterias rubens (Rank: species, Probability: 0.053509944505205166)
│   │   │   └── asteroidea (Rank: species, Probability: 0.0)
│   │   │       └── asterias (Rank: species, Probability: 0.0)
│   │   └── Henrica (Rank: genus, Probability: 0.0490136509093855)
│   │       └── henrica (Rank: species, Probability: 0.0490136509093855)
│   ├── phaeophyceae (Rank: class, Probability: 0.15384615384615385)
│   │   ├── fucus (Rank: genus, Probability: 0.09413861546562968)
│   │   │   └── fucus vesiculosus (Rank: species, Probability: 0.09413861546562968)
│   │   └── saccharina (Rank: genus, Probability: 0.052335237169360393)
│   │       └── saccharina latissima (Rank: species, Probability: 0.052335237169360393)
│   ├── bivalia (Rank:

In [38]:
# Calculate weighted precision, recall, and F1
precision, recall = calculate_hierarchical_precision_recall(Y, Yhat, confidences, taxonomies, new_prob_tree, threshold=0.30)
weighted_f1_score = calculate_weighted_f1_score(precision, recall)

In [39]:
precision

0.9602612955906369

In [40]:
recall

0.8975576662143826

In [41]:
weighted_f1_score

0.927851319365302