In [None]:
###
# This notebook creates training data for semi-supervised learning by generating pseudo-labels
# with an ensemble of pre-trained YOLO models. The prediction confidence threshold is set to 0.5.
###

In [None]:
from __future__ import print_function

import glob, os, random, torch, timm, shutil, pickle, time, yaml, json, gc, cv2, torchvision
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
import itertools
import matplotlib.patches as patches

from PIL import Image, ImageEnhance, ImageOps, ImageDraw
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from pathlib import Path
from pprint import pprint
from tempfile import TemporaryDirectory
from concurrent.futures import ThreadPoolExecutor
import concurrent.futures
from collections import defaultdict

from ultralytics import YOLO

# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

# Notebook-wide random seed, passed to seed_everything() below.
seed = 0

# Disable Pillow's decompression-bomb pixel limit so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None
%matplotlib inline


In [None]:
# Report the PyTorch version, whether CUDA is usable, and how many CUDA devices are visible.
for env_fact in (torch.__version__, torch.cuda.is_available(), torch.cuda.device_count()):
    print(env_fact)

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic behavior.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Hash randomization is fixed at interpreter start-up, so this only
    # affects processes spawned after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick a different convolution algorithm from run to run;
    # switch it off so the deterministic flag above is actually effective.
    torch.backends.cudnn.benchmark = False

# Apply the notebook-wide seed defined in the setup cell.
seed_everything(seed)

In [None]:
##########

In [None]:
# Inference and post-processing settings.
DEVICE = 0  # NOTE(review): not referenced by the code visible in this notebook - confirm it is still needed
BATCH = 16  # NOTE(review): not referenced by the code visible in this notebook - confirm it is still needed
IOU = 0.5  # default IoU threshold for process_image() and perform_nms()
CONF = 0.5  # default minimum detection confidence for process_image()
WORKERS = 24  # number of threads in the ThreadPoolExecutor used by process_subdirectory()

In [None]:
def perform_nms(boxes, scores, iou_thr=IOU):
    """Apply class-agnostic non-maximum suppression to a set of boxes.

    Args:
        boxes: Array-like of shape (N, 4) holding ``x_min, y_min, x_max, y_max``
            per row, as expected by ``torchvision.ops.nms``.
        scores: Array-like of N confidence scores, one per box.
        iou_thr: Boxes overlapping a higher-scoring box by more than this IoU
            are discarded.

    Returns:
        Tuple ``(kept_boxes, kept_scores)`` of NumPy arrays, in the order of
        the indices returned by ``torchvision.ops.nms``. Two empty 1-D arrays
        are returned when there are no input boxes.
    """
    # Coerce up front so plain lists/tuples can be fancy-indexed below.
    boxes = np.asarray(boxes)
    scores = np.asarray(scores)

    if len(boxes) == 0:
        return np.array([]), np.array([])

    boxes_tensor = torch.tensor(boxes)
    scores_tensor = torch.tensor(scores)
    keep = torchvision.ops.nms(boxes_tensor, scores_tensor, iou_threshold=iou_thr).numpy()

    return boxes[keep], scores[keep]

def process_image(img_path, label_dir, conf=CONF, iou=IOU):
    """Run the model ensemble on one image and write its pseudo-label file.

    Every model in the global ``models`` list predicts on ``img_path``. The
    detections are pooled, filtered by confidence, merged with NMS and written
    to ``<label_dir>/<image stem>.txt``, one detection per line as
    ``class x_center y_center width height score`` with coordinates divided by
    the image width/height. An empty file is written when nothing is detected.

    Args:
        img_path: Path of the image to run inference on.
        label_dir: Directory that receives the label file.
        conf: Minimum confidence, passed to each model and re-applied to the
            pooled detections.
        iou: IoU threshold, passed to each model and used for the cross-model NMS.

    Raises:
        ValueError: If the global ``models`` list is empty.
        FileNotFoundError: If OpenCV cannot read ``img_path``.

    NOTE(review): process_subdirectory() calls this from a thread pool while
    all threads share the same model instances - confirm the installed
    ultralytics version supports concurrent calls on one model.
    """
    label_file = os.path.splitext(os.path.basename(img_path))[0] + ".txt"
    label_path = os.path.join(label_dir, label_file)

    # Pool the detections for this image from every ensemble member.
    combined_boxes = []
    combined_scores = []
    combined_classes = []
    for model in models:
        results = model(img_path, conf=conf, iou=iou, imgsz=640, max_det=100, augment=False, stream=False)
        detections = results[0].boxes
        combined_boxes.append(detections.xyxy.cpu().numpy())
        combined_scores.append(detections.conf.cpu().numpy())
        combined_classes.append(detections.cls.cpu().numpy())

    if not combined_boxes:
        raise ValueError("No models loaded: the 'models' ensemble is empty")

    combined_boxes = np.concatenate(combined_boxes)
    combined_scores = np.concatenate(combined_scores)
    combined_classes = np.concatenate(combined_classes)

    high_conf_indices = combined_scores >= conf
    combined_boxes = combined_boxes[high_conf_indices]
    combined_scores = combined_scores[high_conf_indices]
    combined_classes = combined_classes[high_conf_indices]

    if len(combined_boxes) == 0:
        open(label_path, 'w').close()
        return

    # Class-agnostic NMS across the pooled detections. The returned indices are
    # applied to boxes, scores AND classes so each kept box keeps its own class.
    keep = torchvision.ops.nms(
        torch.tensor(combined_boxes),
        torch.tensor(combined_scores),
        iou_threshold=iou,
    ).numpy()
    final_boxes = combined_boxes[keep]
    final_scores = combined_scores[keep]
    final_classes = combined_classes[keep]

    orig_img = cv2.imread(img_path)
    if orig_img is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img_height, img_width = orig_img.shape[:2]

    with open(label_path, "w") as f:
        for bbox, cls_id, score in zip(final_boxes, final_classes, final_scores):
            x_min, y_min, x_max, y_max = bbox

            # Convert corner coordinates in pixels to normalized center/size.
            x_center = (x_min + x_max) / 2 / img_width
            y_center = (y_min + y_max) / 2 / img_height
            width = (x_max - x_min) / img_width
            height = (y_max - y_min) / img_height
            f.write(f"{int(cls_id)} {x_center} {y_center} {width} {height} {score}\n")

def process_subdirectory(subdir):
    """Create pseudo-label files for every PNG image in ``<subdir>/img``.

    Labels are written to ``<output_base_dir>/labels_<subdir name>/``. Images
    are processed concurrently with ``WORKERS`` threads. An image whose
    processing raises gets an empty label file, so every image ends up with
    exactly one label file.

    Args:
        subdir: Directory that contains an ``img`` directory of ``.png`` files.
    """
    img_dir = os.path.join(subdir, "img")
    if not os.path.isdir(img_dir):
        print(f"Image directory not found in {subdir}")
        return

    # normpath drops a trailing separator, which would otherwise make
    # basename() return '' and send the labels to a directory named "labels_".
    subdir_name = os.path.basename(os.path.normpath(subdir))
    label_dir = os.path.join(output_base_dir, f"labels_{subdir_name}")
    os.makedirs(label_dir, exist_ok=True)

    # Sorted so the processing order does not depend on the filesystem.
    image_files = sorted(f for f in os.listdir(img_dir) if f.endswith(".png"))

    def label_path_for(img_file):
        # Label file that belongs to the given image filename.
        return os.path.join(label_dir, os.path.splitext(img_file)[0] + ".txt")

    def safe_process_image(img_file):
        # Best effort: report the failure and leave an empty label behind so
        # one bad image does not abort the whole subdirectory.
        try:
            process_image(os.path.join(img_dir, img_file), label_dir)
        except Exception as e:
            print(f"Error processing file {img_file}: {e}")
            open(label_path_for(img_file), 'w').close()

    with ThreadPoolExecutor(max_workers=WORKERS) as executor:
        # list() drains the lazy map so all work finishes inside the with-block.
        list(tqdm(executor.map(safe_process_image, image_files), total=len(image_files)))

    # Safety net for images that still have no label file.
    for img_file in image_files:
        label_path = label_path_for(img_file)
        if not os.path.exists(label_path):
            print(f"Creating missing label file for {img_file}")
            open(label_path, 'w').close()

    # Counts every .txt in label_dir, so label files left over from an earlier
    # run also trigger the mismatch warning.
    label_files = [f for f in os.listdir(label_dir) if f.endswith(".txt")]
    print(f"Processed {subdir_name}: {len(image_files)} images, {len(label_files)} labels created")
    if len(image_files) != len(label_files):
        print(f"Warning: Number of images ({len(image_files)}) does not match number of labels ({len(label_files)})")



In [None]:
# Load models: every *.pt file in models_d becomes one member of the ensemble.
models_d = '/PATH/TO/YOUR/MODELS'
models_f = sorted(glob.glob(models_d + '/*.pt'))
# Fail fast: with an empty ensemble every image would raise inside
# process_image() and silently end up with an empty label file.
if not models_f:
    raise FileNotFoundError(f"No .pt model files found in {models_d}")
models = [YOLO(model_f) for model_f in models_f]
models_f

In [None]:
# Directory configuration
# base_dir holds one subdirectory per image set; each of them has an "img"
# directory containing the PNG images.

base_dir = "/PATH/TO/YOUR/DIRECTORY"
output_base_dir = "/PATH/TO/YOUR/OUTPUT/DIRECTORY"

# Keep only the directories among base_dir's direct children, then label each one.
subdirs = list(filter(os.path.isdir, glob.glob(os.path.join(base_dir, "*"))))
for subdir in subdirs:
    process_subdirectory(subdir)

In [None]:
### file selection ###

In [None]:
# Same two locations as in the pseudo-label generation cell
# (presumably repeated so the selection cells can be run on their own - confirm).
base_dir = "/PATH/TO/YOUR/DIRECTORY"
output_base_dir = "/PATH/TO/YOUR/OUTPUT/DIRECTORY"

In [None]:
# Every pseudo-label file located one directory below output_base_dir, in sorted order.
pseudo_labels = glob.glob(output_base_dir + '/*/*.txt')
pseudo_labels.sort()
print(pseudo_labels[0], len(pseudo_labels), sep='\n')

In [None]:
# Every PNG image located two directories below base_dir, in sorted order.
imgs = glob.glob(base_dir + '/*/*/*.png')
imgs.sort()
print(imgs[0], len(imgs), sep='\n')

In [None]:
# Sanity check: the sorted label and image lists should pair up by file stem.
# zip() stops at the shorter list, so a difference in length is reported separately.
if len(pseudo_labels) != len(imgs):
    print(f"Warning: {len(pseudo_labels)} label files but {len(imgs)} images")

for label_file, img_file in zip(pseudo_labels, imgs):
    label_stem = os.path.splitext(os.path.basename(label_file))[0]
    img_stem = os.path.splitext(os.path.basename(img_file))[0]
    if label_stem != img_stem:
        print(label_file)


In [None]:
####

In [None]:
# process_label_file() rejects a label file as soon as one detection scores below this value.
CONF_THRESHOLD = 0.50  
# str(0.50) is '0.5', so the directory name ends in "_threshold_0.5_".
output_dir = output_base_dir + '_threshold_' + str(CONF_THRESHOLD) + '_/labels'  # output directory path
# A label file is kept only if it contains at least one of these class ids.
classes_of_interest = [0, 1, 2, 3]  # target classes
os.makedirs(output_dir, exist_ok=True)

In [None]:
def process_label_file(input_path, output_path, conf_threshold=None, classes=None):
    """Filter one pseudo-label file and write it without the confidence column.

    The file is accepted only if every well-formed detection line has a
    confidence of at least ``conf_threshold`` and at least one detection
    belongs to ``classes``. Lines that do not have exactly six
    whitespace-separated fields (``class x y w h conf``) are ignored.

    Args:
        input_path: Pseudo-label file to read.
        output_path: Destination file; written only when the file is accepted.
        conf_threshold: Minimum confidence required of every detection.
            Defaults to the global ``CONF_THRESHOLD``.
        classes: Container of class ids of interest. Defaults to the global
            ``classes_of_interest``.

    Returns:
        True if the filtered file was written, False otherwise.
    """
    # Resolve the notebook-level defaults at call time.
    if conf_threshold is None:
        conf_threshold = CONF_THRESHOLD
    if classes is None:
        classes = classes_of_interest

    contains_class_of_interest = False
    new_lines = []

    with open(input_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) != 6:  # expected: class x y w h conf
                continue
            cls, x, y, w, h, conf = parts
            cls = int(cls)
            conf = float(conf)

            # A single low-confidence detection rejects the whole file.
            if conf < conf_threshold:
                return False

            if cls in classes:
                contains_class_of_interest = True
            new_lines.append(f"{cls} {x} {y} {w} {h}\n")

    if not contains_class_of_interest:
        return False

    with open(output_path, 'w') as f:
        f.writelines(new_lines)
    return True


# Filter every raw pseudo-label file into output_dir, mirroring its location
# relative to output_base_dir.
copied_count = 0
for label_path in tqdm(pseudo_labels):
    relative_path = os.path.relpath(label_path, output_base_dir)
    output_path = os.path.join(output_dir, relative_path)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # The return value is a bool: True adds 1, False adds 0.
    copied_count += process_label_file(label_path, output_path)

print(f"Processed {len(pseudo_labels)} label files.")
print(f"Copied {copied_count} label files to {output_dir}")



In [None]:
### Copy Corresponding Images ###

In [None]:
# Label files that passed the threshold filter, plus every source image (both sorted).
selected_labels = sorted(glob.glob(f"{output_base_dir}_threshold_{CONF_THRESHOLD}_/labels/*/*.txt"))
imgs = sorted(glob.glob(f"{base_dir}/*/*/*.png"))

In [None]:
# Quick look at what was collected: count and first path of the selected labels,
# then first path and count of the source images.
print(len(selected_labels))
print(selected_labels[0])
print('###')
print(imgs[0])
print(len(imgs))

In [None]:
# Filenames (without directories) of all source images, for fast membership tests.
img_basenames = {os.path.basename(img) for img in imgs}

# Images are copied next to the filtered labels, under ".../images".
output_img_dir = f"{output_base_dir}_threshold_{CONF_THRESHOLD}_/images"
os.makedirs(output_img_dir, exist_ok=True)


In [None]:
# Copy images corresponding to labels
labels_root = output_base_dir + '_threshold_' + str(CONF_THRESHOLD) + '_/labels'

# Index the images once instead of scanning the whole list for every label.
# Keyed by (source subdirectory, filename) so that equal filenames in different
# subdirectories are not mixed up; the filename-only index is a fallback and
# keeps the first match in sorted order.
imgs_by_subdir = {}
imgs_by_name = {}
for img in imgs:
    img_name = os.path.basename(img)
    # Images are globbed as base_dir/<subdirectory>/<directory>/<name>.png
    img_subdir = os.path.basename(os.path.dirname(os.path.dirname(img)))
    imgs_by_subdir.setdefault((img_subdir, img_name), img)
    imgs_by_name.setdefault(img_name, img)

copied_count = 0
for label_path in tqdm(selected_labels, desc="Copying images"):
    label_basename = os.path.basename(label_path)
    # Swap only the extension; str.replace would also touch '.txt' inside the stem.
    img_basename = os.path.splitext(label_basename)[0] + '.png'

    # Label directories are named "labels_<subdirectory>" by process_subdirectory().
    relative_dir = os.path.relpath(os.path.dirname(label_path), labels_root)
    if relative_dir.startswith('labels_'):
        source_subdir = relative_dir[len('labels_'):]
    else:
        source_subdir = relative_dir

    original_img_path = imgs_by_subdir.get((source_subdir, img_basename))
    if original_img_path is None:
        original_img_path = imgs_by_name.get(img_basename)
    if original_img_path is None:
        print(f"Warning: Corresponding image not found for label {label_path}")
        continue

    output_img_path = os.path.join(output_img_dir, relative_dir, img_basename)
    os.makedirs(os.path.dirname(output_img_path), exist_ok=True)

    shutil.copy2(original_img_path, output_img_path)
    copied_count += 1

print(f"Processed {len(selected_labels)} label files.")
print(f"Copied {copied_count} images to {output_img_dir}")