# Process images and create labels for YOLO

In [1]:
import shutil
from io import BytesIO
from pathlib import Path

import cv2
import numpy as np
import pylidc as pl
from azure.storage.blob import BlobServiceClient
from pydicom import dcmread

# Connect to Azure Blob Storage. keys.txt is read line by line:
# line 1 = storage account name, line 2 = account key, line 3 = container name.
try:
    with open('/home/andrew/ITRI-LungCancer/keys.txt', 'r') as file:
        data = file.read().splitlines()
        account_name    = data[0]
        account_key     = data[1]
        container_name  = data[2]

    blob_service_client = BlobServiceClient(account_url=f"https://{account_name}.blob.core.windows.net", credential=account_key)
    container_client = blob_service_client.get_container_client(container_name)
    # Kept as one module-level listing: create_dataset() pulls names from it with
    # next(), so each call continues where the previous one stopped.
    blob_name_list = container_client.list_blob_names()
except Exception as ex:
    # NOTE(review): the error is only printed. If anything above fails,
    # container_client / blob_name_list stay undefined and the later cells
    # fail with NameError instead of the real cause.
    print('Exception:')
    print(ex)

# Organize file system

In [2]:
# Clean folders
# Pure-Python replacement for `!rm -rf` / `!mkdir -p .../{train,val,test}`:
# the shell version only works when the kernel's shell performs brace
# expansion and when IPython's own `{expr}` interpolation of `!` lines fails.
dataset_root = Path('/home/andrew/ITRI-LungCancer/dataset_classify_rgb')
shutil.rmtree(dataset_root, ignore_errors=True)  # a missing folder is fine, like `rm -rf`

# Recreate dataset structure: <root>/{images,labels}/{train,val,test}
for kind in ('images', 'labels'):
    for split in ('train', 'val', 'test'):
        (dataset_root / kind / split).mkdir(parents=True, exist_ok=True)

# Helper Functions for Creating Dataset

In [3]:
def window_img(img, window_center, window_width):
    """Apply an intensity window to `img` and return it as 8-bit data.

    Values are clipped to ``window_center +/- window_width / 2``, mapped
    linearly so the lower bound becomes 0 and the upper bound 255, and then
    converted with ``np.uint8`` (fractions are truncated).
    """
    half_width = window_width / 2.0
    lower = window_center - half_width
    upper = window_center + half_width
    clipped = np.clip(img, lower, upper)
    scaled = (clipped - lower) / (upper - lower)
    return np.uint8(scaled * 255)

def rescale_img(ds, img):
    """Apply the dataset's linear rescale to `img`.

    Returns ``img * ds.RescaleSlope + ds.RescaleIntercept`` when both
    'RescaleIntercept' and 'RescaleSlope' are present in `ds`; otherwise the
    input is returned untouched.
    """
    required = ('RescaleIntercept', 'RescaleSlope')
    if not all(keyword in ds for keyword in required):
        return img
    return img * ds.RescaleSlope + ds.RescaleIntercept

def change_file_num(blob_name, val):
    """Return `blob_name` with its trailing file number shifted by `val`.

    The name is expected to end in three digits followed by a four-character
    extension (e.g. ``.../1-005.dcm``). The digits are parsed, `val` is added,
    and the result is zero-padded to three characters and given a '.dcm'
    extension.
    """
    prefix, digits = blob_name[:-7], blob_name[-7:-4]
    shifted = str(int(digits) + val).zfill(3)
    return f"{prefix}{shifted}.dcm"
    
def get_dicom(blob_name):
    """Download `blob_name` from the container and parse it with `dcmread`.

    The blob is read fully into memory and handed to pydicom through an
    in-memory stream; nothing is written to disk.
    """
    download = container_client.get_blob_client(blob_name).download_blob()
    return dcmread(BytesIO(download.readall()))

def get_image(blob_name):
    """Fetch a DICOM blob and return its pixels as a windowed 8-bit image.

    The pixel data is rescaled with the dataset's slope/intercept and then
    windowed with center -300 and width 2000.
    """
    dicom = get_dicom(blob_name)
    rescaled = rescale_img(dicom, dicom.pixel_array)
    return window_img(rescaled, -300, 2000)

def _yolo_box(bbox, rows, columns):
    """Convert an annotation bounding box to normalised YOLO ``(x_center, y_center, width, height)``.

    ``bbox[0]`` is scaled by `rows` and ``bbox[1]`` by `columns`; both are
    slice objects whose ``start``/``stop`` are pixel indices.
    """
    row_span, col_span = bbox[0], bbox[1]
    x_center = (col_span.start + col_span.stop) / columns / 2
    y_center = (row_span.start + row_span.stop) / rows / 2
    width = (col_span.stop - col_span.start) / columns
    height = (row_span.stop - row_span.start) / rows
    return x_center, y_center, width, height


def _neighbor_image(blob_name, offset, fallback):
    """Return the windowed slice `offset` file numbers away, or `fallback` if that blob does not exist."""
    neighbor_name = change_file_num(blob_name, offset)
    # The first/last slices of a series have no neighbour on one side (the
    # shifted number can even be negative). Downloading a missing blob would
    # raise and abort the whole dataset build, so check for it first.
    if container_client.get_blob_client(neighbor_name).exists():
        return get_image(neighbor_name)
    return fallback


def save_if_annotated(scan, ds, slice_location, blob_name, data_string):
    """Save one slice as a YOLO sample if an annotation covers it.

    A slice qualifies when some annotation of `scan` has a contour whose
    ``image_z_position`` is closer to `slice_location` than
    ``scan.slice_spacing`` and whose boolean mask has more than 300 set
    entries. Only the first qualifying annotation is used, so exactly one
    label line is written per image.

    The saved PNG has three channels: the slice two file numbers before, the
    slice itself, and the slice two file numbers after, all windowed the same
    way as `get_image`. A neighbour that does not exist is replaced by the
    slice itself.

    Args:
        scan: pylidc scan the slice belongs to, or None if none was found.
        ds: pydicom dataset of the slice (pixel data, Rows, Columns).
        slice_location: z position of the slice.
        blob_name: blob path; component 0 is used as the scan name and
            component 3 (split on '-') as the slice identifier.
        data_string: dataset split folder ("train", "val" or "test").

    Returns:
        True if an image/label pair was written, False otherwise.

    Raises:
        IOError: if OpenCV reports that the image could not be written.
    """
    if scan is None:
        # No matching scan means there is nothing that could annotate this slice.
        return False

    name_parts = blob_name.split('/')
    scan_name = name_parts[0]
    # NOTE(review): presumably component 3 is the file name ('1-005.dcm'), so
    # slice_num keeps the '.dcm' suffix and outputs are named
    # '<scan>_005.dcm.png' / '.txt'. Left unchanged so existing datasets and
    # their labels keep matching.
    slice_num = name_parts[3].split('-')[1]

    for ann in scan.annotations:
        on_slice = any(
            abs(contour.image_z_position - slice_location) < scan.slice_spacing
            for contour in ann.contours
        )
        # The mask is only built for annotations that actually touch the slice.
        if not on_slice or ann.boolean_mask().sum() <= 300:
            continue

        x_center, y_center, width, height = _yolo_box(ann.bbox(), ds.Rows, ds.Columns)

        # The current slice is already downloaded; window it like get_image does.
        image_base = window_img(rescale_img(ds, ds.pixel_array), -300, 2000)
        image_prev = _neighbor_image(blob_name, -2, image_base)
        image_next = _neighbor_image(blob_name, 2, image_base)
        image = np.stack([image_prev, image_base, image_next], axis=-1)

        filename = f"{scan_name}_{slice_num}"

        image_path = f'/home/andrew/ITRI-LungCancer/dataset_classify_rgb/images/{data_string}/{filename}.png'
        # cv2.imwrite signals failure through its return value; without this
        # check a label could be written for an image that was never saved.
        if not cv2.imwrite(image_path, image):
            raise IOError(f"Could not write image to {image_path}")

        label_path = f'/home/andrew/ITRI-LungCancer/dataset_classify_rgb/labels/{data_string}/{filename}.txt'
        label_txt = f"0 {x_center} {y_center} {width} {height}"
        with open(label_path, 'w') as file:
            file.write(label_txt)
        return True
    return False

def create_dataset(count, data_string):
    """Save `count` annotated slices into the `data_string` split.

    Blob names are taken one by one from the shared `blob_name_list`, so
    consecutive calls continue where the previous call stopped. Each blob is
    downloaded, matched to its pylidc scan and passed to `save_if_annotated`;
    only slices that were actually saved count towards `count`.

    Args:
        count: number of annotated slices to save; nothing is fetched when
            it is not positive.
        data_string: dataset split folder ("train", "val" or "test").

    Raises:
        RuntimeError: if the blob listing runs out before `count` slices
            were saved.
    """
    while count > 0:
        blob_name = next(blob_name_list, None)
        if blob_name is None:
            # Report the shortfall instead of leaking a bare StopIteration.
            raise RuntimeError(
                f"{data_string}: blob listing exhausted with {count} annotated slices still missing"
            )
        ds = get_dicom(blob_name)

        # Files without a patient position cannot be matched to a contour.
        if 'ImagePositionPatient' not in ds:
            continue

        # Match on the series as well as the patient: a patient can have more
        # than one series, and .first() on the patient alone may return a scan
        # whose annotations belong to different images.
        scan = (
            pl.query(pl.Scan)
            .filter(pl.Scan.patient_id == ds.PatientID)
            .filter(pl.Scan.series_instance_uid == str(ds.SeriesInstanceUID).strip())
            .first()
        )
        if scan is None:
            continue

        slice_location = ds.ImagePositionPatient[2]

        if save_if_annotated(scan, ds, slice_location, blob_name, data_string):
            count -= 1
            if count > 0:
                # '\r' keeps the remaining-count display on a single line.
                print(f"{data_string}: {count}    ", end='\r', flush=True)
    print(f"{data_string} done!")

# Create Datasets

In [4]:
total = 8000

# 5 % of the samples are held out for testing; the remaining 95 % are split
# 80/20 into train and val.
train_size, val_size, test_size = (
    int(total*0.95*0.8),
    int(total*0.95*0.2),
    int(total*0.05),
)

print(train_size, val_size, test_size)

# The splits are filled one after another in this order.
for split_name, split_size in (("train", train_size), ("val", val_size), ("test", test_size)):
    create_dataset(split_size, split_name)

6080 1520 400
train: 5777    

In [None]:
from ultralytics import YOLO

# !rm -rf runs/

# Load the yolov8m.pt checkpoint, train it on the dataset described by the
# YAML file, then store the resulting model next to it.
model = YOLO("/home/andrew/ITRI-LungCancer/YOLO/yolov8m.pt")
results = model.train(
    data="/home/andrew/ITRI-LungCancer/YOLO/dataset_classify_rgb.yaml",
    epochs=300,
    patience=20,
    cache=True,
    lr0=1E-2,
    imgsz=512,
)
model.save('/home/andrew/ITRI-LungCancer/YOLO/model_classify_rgb.pt')