In [1]:
from ct_model_class import Model

import numpy as np

import cv2

import pydicom
import matplotlib.pyplot as plt
import os
from PIL import Image
from torchvision.transforms.v2 import Resize
import math
from numba import njit, prange
from tqdm import tqdm
import torchvision
from torchvision.transforms import ToPILImage
import glob
import torchvision.transforms



In [2]:
def apply_model(arr, model, threshold=9.5):
    """Run the segmentation model on one 2-D slice and keep confident results.

    The slice is min-max scaled to 0..255, cast to uint8 and replicated into
    three identical channels (H, W, 3) before calling ``model.process_images``.

    Args:
        arr: 2-D pixel array (e.g. one slice of a DICOM ``pixel_array``).
        model: object whose ``process_images(img)`` returns ``(aorta, arterial)``;
            ``aorta[1]`` and ``arterial[1]`` must support ``.item()`` and are
            compared against ``threshold`` (presumably confidence scores — TODO confirm).
        threshold: both scores must be strictly greater than this value
            (default 9.5, the previously hard-coded value).

    Returns:
        ``(aorta, arterial, True)`` when both scores exceed ``threshold``,
        otherwise ``(None, None, False)``.
    """
    img = np.asarray(arr, dtype=float)
    lo, hi = img.min(), img.max()
    if hi > lo:
        img = (img - lo) / (hi - lo) * 255.0
    else:
        # Constant slice: the old 0/0 produced NaN, whose uint8 cast is undefined.
        img = np.zeros_like(img)
    gray = img.astype(np.uint8)
    # The model expects three channels; replicate the grayscale slice.
    img = np.dstack((gray, gray, gray))

    aorta, arterial = model.process_images(img)

    if aorta[1].item() > threshold and arterial[1].item() > threshold:
        return aorta, arterial, True
    return None, None, False

In [3]:
def take_aorta_diameter(image, show=True):
    """Equivalent-circle diameter (in pixels) of a segmentation mask.

    Counts pixels brighter than 50, treats that count as the area of a circle
    and returns the circle's diameter, ``2 * sqrt(area / pi)``.

    Args:
        image: 2-D mask array (presumably the model's aorta mask; dtype not
            visible here — TODO confirm).
        show: when True (default, as before) the tightly cropped mask is shown
            inline with IPython ``display``.

    Returns:
        Diameter in pixels as a float (0.0 for an empty mask).
    """
    mask = np.asarray(image)
    area_in_pixels = int(np.count_nonzero(mask > 50))
    diameter = 2 * math.sqrt(area_in_pixels / math.pi)

    if show:
        cropped = crop_image(image)
        # An all-dark mask crops to a 0x0 array; nothing useful to display.
        if cropped.size:
            display(Image.fromarray(cropped))

    return diameter

@njit(parallel=True, fastmath=True)
def check_mask_width_numba(image):
    """Largest number of non-zero pixels found in any single row of ``image``.

    The pixels counted in a row need not be contiguous. ``image`` is indexed as
    ``image[i][j]``, so it is expected to be a single 2-D channel (the masks are
    black-and-white, so one channel is enough).

    Rows are processed in parallel with ``prange``.
    """
    max_width = 0
    for i in prange(len(image)):
        cur_width = 0
        for j in range(len(image[0])):
            if image[i][j] != 0:
                cur_width += 1
        # Written as ``acc = max(acc, x)`` so Numba recognises a parallel
        # max-reduction; the former ``if cur > acc: acc = cur`` shared the
        # scalar across threads without a reduction (a data race).
        max_width = max(max_width, cur_width)
    return max_width


def _first_run(flags):
    """Return ``(start, stop)`` of the first contiguous True run in a 1-D bool array, or None."""
    hits = np.flatnonzero(flags)
    if hits.size == 0:
        return None
    start = int(hits[0])
    gaps = np.flatnonzero(~flags[start:])
    # If the run reaches the end of the axis there is no gap: stop at the edge.
    stop = start + int(gaps[0]) if gaps.size else flags.size
    return start, stop


def crop_image(image, threshold=20):
    """Crop a 2-D image to the first band of rows and columns containing bright pixels.

    A row (column) is "bright" when any of its pixels is strictly greater than
    ``threshold``. The crop spans from the first bright row up to (not
    including) the next non-bright row, and likewise for columns; only the
    first contiguous band is kept on each axis.

    Fixes the former off-by-one: when the bright band reached the last row or
    column, the stop index stayed -1 and that row/column was dropped.

    Args:
        image: 2-D numpy array.
        threshold: brightness cut-off (default 20, as before).

    Returns:
        A slice (view) of ``image``; shape ``(0, 0)`` when nothing is bright.
    """
    bright = np.asarray(image) > threshold
    row_run = _first_run(bright.any(axis=1))
    col_run = _first_run(bright.any(axis=0))
    if row_run is None or col_run is None:
        return image[0:0, 0:0]
    (top, bottom), (left, right) = row_run, col_run
    return image[top:bottom, left:right]
                
def find_highest_black_peak(image):
    """Row index at which the last column's bottom-up dark run ends.

    Each column is scanned upwards from the bottom row; its dark run (pixels
    ``< 20``) ends at its lowest non-dark pixel. The result is the row where
    the *last* column to finish does so, i.e. the minimum over columns of each
    column's lowest non-dark row.

    Row 0 is never scanned, so 0 is returned when any column is dark
    throughout rows 1..H-1, and also for images with fewer than two rows or
    no columns.
    """
    n_cols = len(image[0])  # raises IndexError on an empty image, as before
    pixels = np.asarray(image)
    if n_cols == 0 or pixels.shape[0] < 2:
        return 0

    not_dark = ~(pixels[1:] < 20)  # covers rows 1..H-1 only
    if not not_dark.any(axis=0).all():
        return 0

    n_scanned = not_dark.shape[0]
    # argmax on the row-reversed mask finds each column's bottom-most hit;
    # n_scanned - m converts it back to a row index of ``image``.
    lowest_hit = n_scanned - np.argmax(not_dark[::-1], axis=0)
    return int(lowest_hit.min())

                

def take_diameter(image, height):
    """Length of the first bright horizontal run on a row just above ``height``.

    The scan row is ``height - 10`` when that is positive, else ``height - 5``
    when that is positive, else ``height`` itself. Scanning left to right,
    pixels ``> 100`` extend the run and the first pixel ``< 100`` after the run
    has started ends it.

    NOTE(review): pixels exactly equal to 100 are skipped (neither counted nor
    ending the run) — confirm this is intended.

    Returns:
        The run length in pixels (0 when no pixel exceeds 100).
    """
    if height - 10 > 0:
        scan_row = height - 10
    elif height - 5 > 0:
        scan_row = height - 5
    else:
        scan_row = height

    row = image[scan_row]
    run_length = 0
    for col in range(len(image[0])):
        value = row[col]
        if value > 100:
            run_length += 1
        elif value < 100 and run_length:
            break
    return run_length


def take_arterial_diameter(image, show=True):
    """Estimate a diameter (in pixels) from the arterial segmentation mask.

    Steps:
      1. Rotate the mask by 0..29 degrees and keep the angle whose rotated mask
         has the largest per-row count of non-zero pixels
         (``check_mask_width_numba``).
      2. Rotate by that angle, crop to the bright band (``crop_image``) and
         locate a reference row with ``find_highest_black_peak``.
      3. Measure the bright run near that row with ``take_diameter``.

    Args:
        image: 2-D mask array accepted by ``PIL.Image.fromarray`` (dtype not
            visible here — TODO confirm).
        show: when True (default, as before), mark the measured row and show
            the cropped mask inline with IPython ``display``.

    Returns:
        The width returned by ``take_diameter``, or 0 for an all-dark mask.
    """
    pil_image = Image.fromarray(image)

    max_width = 0
    final_angle = 0  # stays defined even if every rotation is empty
    for rotation_angle in range(30):
        rotated = torchvision.transforms.functional.rotate(img=pil_image, angle=rotation_angle)
        white_width = check_mask_width_numba(np.array(rotated))
        if white_width > max_width:
            max_width = white_width
            final_angle = rotation_angle

    # Use the best angle (previously the loop variable was used, i.e. always 29).
    best_rotated = torchvision.transforms.functional.rotate(img=pil_image, angle=final_angle)
    image_cropped = crop_image(np.array(best_rotated))
    if image_cropped.size == 0:
        # Nothing bright to measure; find_highest_black_peak would raise.
        return 0

    height = find_highest_black_peak(image_cropped)
    diameter = take_diameter(image_cropped, height)

    if show:
        # Mark the row take_diameter actually scanned (same selection rule);
        # the old fixed ``height - 10`` could index past the top of small crops.
        if height - 10 > 0:
            marker_row = height - 10
        elif height - 5 > 0:
            marker_row = height - 5
        else:
            marker_row = height
        image_cropped[marker_row, :] = 1
        display(Image.fromarray(image_cropped))
    return diameter

In [5]:
# Build {case-record id: diagnosis} from the tab-separated, cp1251-encoded list1.txt.
# Columns 2 and 3 (0-based) hold the record id and the diagnosis; the header row
# is recognised by its "История болезни" ("case history") column title.
with open('list1.txt', 'rb') as f:
    raw_lines = f.readlines()
patient_disease_pairs = [
    line.decode('cp1251').rstrip('\r\n').split('\t')[2:4] for line in raw_lines
]

diseases = {}
for patient_disease in patient_disease_pairs:
    # Skip short/blank rows (they used to raise IndexError) and the header row.
    if len(patient_disease) < 2 or patient_disease[0] == "История болезни":
        continue
    diseases[patient_disease[0]] = patient_disease[1]
diseases

{'76619-22': 'ЛАГ-СЗСТ',
 '89328-22': 'ИЛАГ',
 '92878-22': 'ЛАГ-СЗСТ',
 '90126-22': 'Болезнь Рандю-Ослера',
 '14459-23': 'Болезнь Рандю-Ослера',
 '320-23': 'Болезнь Рандю-Ослера',
 '66332-22': 'ХТЭЛГ',
 '47931-22': 'ИЛАГ',
 '72688-22': 'ХТЭЛГ',
 '74710-22': 'ЛАГ-СЗСТ',
 '74340-22': 'ХТЭЛГ',
 '77580-22': 'ЛАГ-СЗСТ',
 '79358-22': 'ХТЭЛГ',
 '76633-22': 'ХТЭЛГ',
 '81046-22': 'ИЛАГ',
 '88812-22': 'ИЛАГ',
 '92597-22': 'ХТЭЛГ',
 '90398-22': 'ИЛАГ',
 '92285-22': 'ХТЭЛГ',
 '97867-22': 'ХТЭЛГ',
 '4633-23': 'ИЛАГ',
 '105671-22': 'ИЛАГ',
 '7234-22': 'ХТЭЛГ',
 '11374-23': 'ЛАГ-СЗСТ',
 '12831-23': 'ЛАГ-СЗСТ',
 '11393-23': 'ХТЭЛГ',
 '17397-23': 'ИЛАГ',
 '23072-23': 'ИЛАГ',
 '21460-23': 'Болезнь Рандю-Ослера',
 '22625-23': 'ХТЭЛГ',
 '36639-23': 'ИЛАГ',
 '37113-23': 'ИЛАГ',
 '44704-23': 'ХТЭЛГ',
 '47704-23': 'ИЛАГ',
 '39311-23': 'ХТЭЛГ',
 '13274-23': 'ХТЭЛГ'}

In [None]:
transform = ToPILImage()
model = Model()


def artery_to_aorta_ratio(slice_arr, seg_model):
    """Arterial / aortic diameter ratio for one 2-D slice, or None.

    Returns None when ``apply_model`` is not confident or when the aortic
    diameter is 0 (which previously raised ZeroDivisionError). The arterial
    measurement is taken first so inline images appear in the same order as before.
    """
    aorta, arterial, confident = apply_model(slice_arr, seg_model)
    if not confident:
        return None
    arterial_diameter = take_arterial_diameter(arterial[0])
    aorta_diameter = take_aorta_diameter(aorta[0])
    if aorta_diameter == 0:
        return None
    return arterial_diameter / aorta_diameter


# For each patient print every per-slice ratio and flag the patient when the
# largest ratio exceeds 1.0.
for patient, diagnosis in diseases.items():
    print(diagnosis)
    max_ratio = 0
    patient_dir = os.path.join('med1', patient)
    for file in tqdm(os.listdir(patient_dir), position=0, leave=True):
        arr = pydicom.dcmread(os.path.join(patient_dir, file)).pixel_array
        if arr.ndim == 2:
            slices = [arr]
        elif arr.ndim == 3:
            # NOTE(review): assumes axis 0 indexes frames; a single RGB frame
            # (H, W, 3) would be split into rows instead — confirm.
            slices = tqdm(arr, position=0, leave=True)
        else:
            continue
        for slice_arr in slices:
            ratio = artery_to_aorta_ratio(slice_arr, model)
            if ratio is None:
                continue
            print(ratio)
            # Previously assigned to a misspelled `max_ration` (and skipped for
            # 2-D files), so max_ratio stayed 0 and nobody was ever flagged.
            max_ratio = max(max_ratio, ratio)
    if max_ratio > 1.0:
        print('pulmonary hypertension detected')