In [1]:
import io
import os

import cv2
import numpy as np
import pydicom
import pylibjpeg
import torch
import xnat
from PIL import Image as im
from PIL import ImageDraw
from tqdm import tqdm
from ipywidgets import *

import net as nn

running on GPU


In [2]:
# Instantiate the project's network (net.Net) and load its weights from
# trained_data.npy. input_image() later calls it on batches shaped
# (N, 1, 128, 64) and counts an output whose argmax is 0 as a "text" vote.
net = nn.Net()
net.load_weights('trained_data.npy')

In [3]:
def get_data(project, url):
    """Collect the first file of every DICOM resource in every session of `project`.

    Drives the `experiment_info` / `scan_info` progress bars while walking the
    project. Only experiments whose XSI type contains 'SessionData' are visited.

    Args:
        project: xnat project object (its `.experiments` and `.id` are read).
        url: unused; kept so existing callers keep working.

    Returns:
        list of dicts with keys 'file' (the resource's first file object),
        'experimentID', 'projectID' and 'scanID'. An experiment whose scan
        listing raises XNATResponseError is skipped, as is any single scan whose
        resources raise it; other scans of the same experiment are still read.
    """
    file_list = []

    experiments = project.experiments
    # IntProgress rejects max < min, so clamp when the listing is empty.
    experiment_info.max = max(len(experiments) - 1, 0)
    experiment_info.layout = Layout(display = 'inherit')

    for experiment_counter, experiment in enumerate(experiments.values()):
        experiment_info.value = experiment_counter

        if 'SessionData' not in experiment._XSI_TYPE:
            continue

        try:
            scans = experiment.scans
            scan_info.max = max(len(scans) - 1, 0)
            scan_info.layout = Layout(display = 'inherit')
            scan_items = list(scans.values())
        except xnat.exceptions.XNATResponseError:
            continue

        for scan_counter, scan in enumerate(scan_items):
            scan_info.value = scan_counter
            try:
                for resource in scan.resources.values():
                    # `format` may be unset on some resources; treat that as non-DICOM.
                    if (resource.format or '').upper() != 'DICOM':
                        continue
                    resource_files = resource.files
                    if len(resource_files) == 0:
                        continue
                    file_list.append({
                        'file': resource_files[0],
                        'experimentID': experiment.id,
                        'projectID': project.id,
                        'scanID': scan.id
                    })
            except xnat.exceptions.XNATResponseError:
                # Skip only this scan; keep collecting the rest of the experiment.
                continue

    return file_list



In [4]:
def _load_rgb_images(imdir):
    """Open every file in `imdir` (sorted by name) and return RGB copies of them."""
    images = []
    for name in sorted(os.listdir(imdir)):
        with im.open(os.path.join(imdir, name)) as source:
            rgb = im.new("RGB", source.size)
            # paste() converts the source to RGB and forces it to load before close.
            rgb.paste(source)
        images.append(rgb)
    return images


def _tile_tensor(image, x, y):
    """Return the 64x32 tile of `image` at (x, y) as a (1, 1, 128, 64) tensor in [0, 1].

    The tile goes through the same JPEG encode/decode it previously got by being
    saved to disk and re-read with cv2.imread, but entirely in memory.
    """
    buffer = io.BytesIO()
    image.crop((x, y, x + 64, y + 32)).save(buffer, format = 'JPEG')
    encoded = np.frombuffer(buffer.getvalue(), dtype = np.uint8)
    tile = cv2.imdecode(encoded, cv2.IMREAD_GRAYSCALE)
    tile = cv2.resize(tile, (128, 64))
    # NOTE(review): cv2.resize(..., (128, 64)) yields a (64, 128) array, so this
    # view() reinterprets the buffer rather than transposing it. Kept unchanged
    # because the trained weights presumably expect this exact layout — TODO confirm.
    return torch.Tensor(tile).to(nn.device).view(1, 1, 128, 64) / 255.0


def _detect_text_tiles(images, width, height):
    """Return the [x, y] grid tiles where more than TEXT_PCT of the images vote 'text'."""
    hits = []
    progress.max = width - 1
    progress.value = 0
    progress.layout = Layout(display = 'inherit')
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for x in range(0, width, 32):
            progress.value = x
            for y in range(0, height, 16):
                batch = torch.cat([_tile_tensor(img, x, y) for img in images], dim = 0)
                net_out = net(batch)
                # An output whose argmax is class 0 counts as a text vote.
                text_votes = sum(1 for output in net_out if torch.argmax(output) == 0)
                if text_votes / len(net_out) > TEXT_PCT:
                    hits.append([x, y])
    progress.layout = Layout(display = 'none')
    return hits


def _smooth_text_tiles(hits, width, height, draw):
    """Keep grid corners with >= 3 of their 4 surrounding tiles in `hits`; outline each in red."""
    hit_set = {(x, y) for x, y in hits}  # O(1) lookups instead of list scans
    kept = []
    for x in range(0, width, 32):
        for y in range(0, height, 16):
            neighbours = ((x, y), (x - 32, y), (x, y - 16), (x - 32, y - 16))
            if sum(cell in hit_set for cell in neighbours) >= 3:
                kept.append([x, y])
                draw.rectangle((x - 32, y - 16, x, y), outline = 'red')
    return kept


def input_image(imdir, project_data):
    """Detect text regions common to the images in `imdir` and display the result.

    Each image is scanned with 64x32 tiles on a 32 px x 16 px grid. A tile is
    "text" when more than TEXT_PCT of the images' network outputs have argmax 0.
    Grid corners where at least 3 of the 4 surrounding tiles are text are kept
    and outlined in red on the last image (by filename). That image is saved to
    TempImages/temp_image.jpg and shown in `image_display`, and the Good/Bad
    buttons are revealed.

    Args:
        imdir: directory of image files; presumably all the same size, since
            the grid is laid out from the last image's dimensions.
        project_data: unused; accepted for caller compatibility.

    Returns:
        list of [x, y] grid corners judged to be text.

    Raises:
        ValueError: if `imdir` contains no files.
    """
    images = _load_rgb_images(imdir)
    if not images:
        raise ValueError('No images found in {}'.format(imdir))

    image = images[-1]
    hits = _detect_text_tiles(images, image.width, image.height)
    text = _smooth_text_tiles(hits, image.width, image.height, ImageDraw.Draw(image))

    annotated_path = os.path.join('TempImages', 'temp_image.jpg')
    image.save(annotated_path)

    image_display.layout = Layout(border_width = '2px', display = 'block')
    with open(annotated_path, 'rb') as handle:
        image_display.value = handle.read()
    image_display.width = image.width
    image_display.height = image.height
    image_display.format = 'jpg'

    button_frame.layout = Layout(display = 'block')

    return text
                
    
            
            

In [5]:
def _advance_scan(project):
    """Move the global (current_experiment, current_scan) cursor to the next reachable scan.

    Only experiments whose XSI type contains 'SessionData' are considered.
    Scans or experiments that raise XNATResponseError are reported and skipped.

    Returns:
        (experiment, scan), or None once every experiment has been visited.
    """
    global current_experiment, current_scan

    experiments = project.experiments
    current_scan += 1
    while current_experiment < len(experiments):
        experiment = experiments[current_experiment]
        if 'SessionData' in experiment._XSI_TYPE:
            try:
                scan_count = len(experiment.scans)
            except xnat.exceptions.XNATResponseError:
                print('Skipping experiment because of xnat response error', 'Experiment: {}'.format(current_experiment), experiment)
                scan_count = 0
            while current_scan < scan_count:
                try:
                    return experiment, experiment.scans[current_scan]
                except xnat.exceptions.XNATResponseError:
                    print('Skipping scan because of xnat response error', 'Scan: {}'.format(current_scan), 'Experiment: {}'.format(current_experiment), experiment)
                    current_scan += 1
        current_experiment += 1
        current_scan = 0
    return None


def _download_scan_files(scan, directory = 'TempImages'):
    """Empty `directory`, then download up to FILES_PER_SCAN files of `scan` into it as N.dcm.

    Returns the list of downloaded xnat file objects.
    """
    os.makedirs(directory, exist_ok = True)
    for name in os.listdir(directory):
        os.remove(os.path.join(directory, name))

    progress.max = max(FILES_PER_SCAN - 1, 0)
    progress.value = 0
    progress.layout = Layout(display = 'inherit')

    downloaded = []
    for file_index, scan_file in enumerate(scan.files.values()):
        if file_index >= FILES_PER_SCAN:
            break
        progress.value = file_index
        scan_file.download(os.path.join(directory, str(file_index) + '.dcm'), verbose = False)
        downloaded.append(scan_file)

    progress.layout = Layout(display = 'none')
    return downloaded


def _convert_dicoms_to_jpegs(directory = 'TempImages'):
    """Replace each DICOM file in `directory` with a JPEG of its pixel data.

    2-D pixel arrays are min-max scaled to 0..255. For any other rank, only
    index 0 of the first axis is written, unscaled.
    """
    names = os.listdir(directory)
    # IntProgress rejects max < min, so clamp when the directory is empty.
    progress.max = max(len(names) - 1, 0)
    progress.value = 0
    progress.layout = Layout(display = 'inherit')

    for counter, name in enumerate(names):
        progress.value = counter
        filename = os.path.join(directory, name)

        ds = pydicom.dcmread(filename, force = True)
        ds.PhotometricInterpretation = 'YBR_FULL'
        pixels = ds.pixel_array

        if pixels.ndim == 2:
            low = float(np.amin(pixels))
            span = float(np.amax(pixels)) - low
            if span:
                frame = (pixels.astype(np.float64) - low) / (span / 255.0)
            else:
                # Constant image: avoid dividing 0 by 0.
                frame = np.zeros(pixels.shape)
        else:
            # NOTE(review): keeps only index 0 of the first axis (first frame of
            # a multi-frame stack?) without rescaling — TODO confirm for colour data.
            frame = pixels[0]

        os.remove(filename)
        cv2.imwrite(filename[:-4] + '.jpg', frame)

    progress.layout = Layout(display = 'none')


def check_file(project):
    """Show the next unreviewed scan of `project`, or build the report when none remain.

    Advances the global (current_experiment, current_scan) cursor and downloads
    up to FILES_PER_SCAN DICOM files of that scan into TempImages. It converts
    them to JPEG and runs input_image() on them, which displays the annotated
    image and the Good/Bad buttons.

    The result is stored in the global `current_file` as ``[record, text]`` for
    good_click/bad_click to file away. The record keys are 'file',
    'experimentID', 'projectID' and 'scanID', the same schema get_data uses.

    A scan that fails anywhere in this pipeline is reported with print() and
    skipped, iteratively rather than by recursion.
    """
    global current_file, current_file_index

    while True:
        target = _advance_scan(project)
        if target is None:
            print('finished')
            generate_report()
            return None
        experiment, scan = target

        try:
            print('Experiment progress: ', current_scan + 1, '/', len(experiment.scans))
            print('Project progress: ', current_experiment + 1, '/', len(project.experiments))
            print('Scan URI: ', scan.url)

            info_text.value = 'Downloading files...'
            downloaded = _download_scan_files(scan)

            info_text.value = 'Extracting pixel data...'
            _convert_dicoms_to_jpegs()

            record = {
                'file': downloaded[0] if downloaded else None,
                'experimentID': experiment.id,
                'projectID': project.id,
                'scanID': scan.id
            }

            info_text.value = 'Checking for text...'
            text = input_image('TempImages', record)
            info_text.value = 'Getting user input...'
        except Exception as e:
            # Skip this scan instead of leaving the review stuck; loop to the next one.
            print(e)
            progress.layout = Layout(display = 'none')
            image_display.layout = Layout(border_width = '2px', display = 'none')
            button_frame.layout = Layout(display = 'none')
            continue

        current_file = [record, text]
        current_file_index += 1
        return None
        

In [6]:
def submit_url():
    """Store the URL text box's current contents in the global `url`."""
    global url
    entered_url = url_entry.value
    url = entered_url
    
def submit_project():
    """Store the project text box's current contents in the global `project`."""
    global project
    entered_project = project_entry.value
    project = entered_project
    
def submit_username():
    """Write the username text box's contents into slot 0 of the global `auth` list."""
    # `auth` is mutated in place, never rebound, so no `global` statement is required.
    entered_username = username_entry.value
    auth[0] = entered_username
    
def submit_password():
    """Write the password box's contents into slot 1 of the global `auth` list."""
    # `auth` is mutated in place, never rebound, so no `global` statement is required.
    entered_password = password_entry.value
    auth[1] = entered_password
    


In [7]:
def start(b):
    """Start-button handler: read the form, connect to XNAT and begin the review.

    Does nothing unless the URL, project, username and password are all
    non-empty. Resets all review state (including the experiment/scan cursor)
    and opens the project. If connecting or looking up the project fails, the
    error is shown in `info_text` and the button is re-enabled so the user can
    correct the form.

    Args:
        b: the clicked Button widget.
    """
    global files, good, bad, current_file, current_file_index, project, auth, url, session, current_experiment, current_scan, sproject
    submit_url()
    submit_project()
    submit_username()
    submit_password()
    if not (url and project and auth[0] and auth[1]):
        return

    b.disabled = True

    files = []
    good = []
    bad = []

    current_file = None
    current_file_index = 0

    # Rewind the cursor. check_file increments current_scan before use, so the
    # first scan visited is index 0.
    current_experiment = 0
    current_scan = -1

    try:
        session = xnat.connect(url, user = auth[0], password = auth[1])
        sproject = session.projects[project]
    except Exception as e:
        # Keep the form usable instead of leaving Start disabled forever.
        info_text.value = 'Could not open project: {}'.format(e)
        b.disabled = False
        return

    check_file(sproject)
                    
                    
                    

In [8]:
def _record_verdict(bucket):
    """Append the scan currently on screen to `bucket`, hide the review widgets, and load the next scan."""
    bucket.append(current_file)

    image_display.layout = Layout(border_width = '2px', display = 'none')
    button_frame.layout = Layout(display = 'none')

    check_file(sproject)


def good_click(b):
    """'Good' button handler: file the current scan under the global `good` list."""
    _record_verdict(good)


def bad_click(b):
    """'Bad' button handler: file the current scan under the global `bad` list."""
    _record_verdict(bad)

In [9]:
def _fetch_dicom(source, path):
    """Save the DICOM referenced by `source` to `path`.

    Objects with a ``download`` method (the xnat file objects stored under
    'file') are downloaded through their XNAT session. Anything else is treated
    as a URL and fetched with HTTP basic auth from the global `auth`.
    """
    if hasattr(source, 'download'):
        source.download(path, verbose = False)
        return
    import requests  # only needed for plain-URL sources; was never imported before
    response = requests.get(source, auth = tuple(auth), allow_redirects = True)
    response.raise_for_status()
    with open(path, 'wb') as handle:
        handle.write(response.content)


def generate_report():
    """Replace the app's widgets with one panel per entry in the global `good` list.

    Each entry is expected to be ``[record, text]``. ``record`` has 'projectID',
    'experimentID', 'scanID' and 'file' keys. ``text`` is a list of ``[x, y]``
    grid corners, each drawn as a red 32x16 box ending at (x, y) on the scan's
    image. Working files file.dcm and tempImage.jpg are written to the current
    directory.
    """
    frames = []

    for entry in good:
        record, text_tiles = entry[0], entry[1]
        project_label = Label(
            value = 'Project: ' + str(record['projectID'])
        )
        experiment_label = Label(
            value = 'Experiment: ' + str(record['experimentID'])
        )
        scan_label = Label(
            value = 'Scan: ' + str(record['scanID'])
        )

        _fetch_dicom(record['file'], 'file.dcm')

        # force=True matches check_file, which also reads possibly preamble-less files.
        ds = pydicom.dcmread('file.dcm', force = True)
        ds.PhotometricInterpretation = 'YBR_FULL'
        pixels = ds.pixel_array

        if pixels.ndim == 2:
            low = float(np.amin(pixels))
            span = float(np.amax(pixels)) - low
            if span:
                frame = (pixels.astype(np.float64) - low) / (span / 255.0)
            else:
                # Constant image: avoid dividing 0 by 0.
                frame = np.zeros(pixels.shape)
        else:
            # NOTE(review): only index 0 of the first axis, unscaled — TODO confirm intent.
            frame = pixels[0]
        cv2.imwrite('tempImage.jpg', frame)

        with im.open('tempImage.jpg') as gray:
            picture = im.new("RGB", gray.size)
            picture.paste(gray)
        draw = ImageDraw.Draw(picture)
        for x, y in text_tiles:
            draw.rectangle((x - 32, y - 16, x, y), outline = 'red')
        picture.save('tempImage.jpg')

        with open('tempImage.jpg', 'rb') as handle:
            jpeg_bytes = handle.read()

        preview = Image(
            value = jpeg_bytes,
            width = picture.width,
            height = picture.height
        )

        frames.append(VBox(
            children = (project_label, experiment_label, scan_label, preview),
            layout = Layout(padding = '10px 10px 10px 10px')
        ))

    app.children = frames

In [10]:
# Connection settings; start() fills these in from the form.
url = project = None
auth = [None, None]  # [username, password]

# Review state shared by start/check_file/good_click/bad_click/generate_report.
files = []
good, bad = [], []
current_file = None
current_file_index = 0
current_experiment, current_scan = 0, 0

# Tunables.
FILES_PER_SCAN = 10  # at most this many files are downloaded from each scan
TEXT_PCT = .50  # share of images that must vote "text" (strictly more than this) for a tile to count

In [11]:
# Build the review UI: connection form, Start button + status label, progress
# bars, image preview and Good/Bad buttons, then display it.
url_entry = Text(
    value = 'https://xnat-demo.radiologics.com',
    placeholder = 'ex: https://xnat-demo.radiologics.com',
    description = 'URL:',
)

project_entry = Text(
    value = 'RADVAL',
    placeholder = '',
    description = 'Project:',
)

# Credentials are taken from the environment (XNAT_USER / XNAT_PASSWORD)
# instead of being hardcoded into the notebook and every saved copy of it.
username_entry = Text(
    value = os.environ.get('XNAT_USER', ''),
    placeholder = '',
    description = 'Username: '
)

password_entry = Password(
    value = os.environ.get('XNAT_PASSWORD', ''),
    placeholder = '',
    description = 'Password:',
)

entry_frame = VBox(
    children = (url_entry, project_entry, username_entry, password_entry),
)

start_button = Button(
    description = 'Start',
    button_style = 'success',
    icon = 'play'
)

start_button.on_click(start)

info_text = Label(
    layout = Layout(padding = '0 0 0 25px')
)

start_info_frame = HBox(
    children = (start_button, info_text),
    layout = Layout(padding = '10px 10px 10px 10px')
)

experiment_info = IntProgress(
    value = 0,
    min = 0,
    max = 0,
    step = 1, 
    description = 'Experiments: ',
    orientation = 'horizontal',
    layout = Layout(display = 'none')
)
scan_info = IntProgress(
    value = 0,
    min = 0,
    max = 0,
    step = 1, 
    description = 'Scans: ',
    orientation = 'horizontal',
    layout = Layout(display = 'none')
)

files_info = IntProgress(
    value = 0,
    min = 0, 
    max = 0,
    step = 1, 
    description = 'Files: ',
    orientation = 'horizontal',
    layout = Layout(display = 'none')
)

progress = IntProgress(
    orientation = 'horizontal',
    layout = Layout(display = 'none')
)

image_display = Image(
    layout = Layout(border_width = '2px', display = 'none')
)

good_button = Button(
    description = 'Good',
    button_style = 'success',
    icon = 'thumbs-up'
)

good_button.on_click(good_click)

bad_button = Button(
    description = 'Bad',
    button_style = 'danger',
    icon = 'thumbs-down'
)

bad_button.on_click(bad_click)

button_frame = HBox(
    children = (good_button, bad_button),
    layout = Layout(display = 'none')
)

app = VBox(
    children = (entry_frame, start_info_frame, progress, image_display, button_frame)
)

display(app)

VBox(children=(VBox(children=(Text(value='https://xnat-demo.radiologics.com', description='URL:', placeholder=…

Experiment progress:  1 / 3
Project progress:  1 / 14
Scan URI:  /data/archive/projects/RADVAL/subjects/XNAT02_S00001/experiments/XNAT02_E00001/scans/4
