In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import os, sys
import glob
import math
import cv2
from PIL import Image
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from pathlib import Path
import matplotlib.pyplot as plt

  from pandas.core import (


In [2]:
# Use the GPU when one is available; inputs are moved to `device` in later cells too.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Rebuild the architecture the checkpoint was trained with. No pretrained weights
# are requested for either the detector or its backbone: every weight comes from
# `detection_model.pt` below.
# NOTE: `pretrained=` / `pretrained_backbone=` are deprecated in newer torchvision
# in favour of `weights=None, weights_backbone=None`.
detection_model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, pretrained_backbone=False)
# torchvision detectors reserve label 0 for background; labels 1-5 are the five
# disc levels decoded through `i_to_level` in the cropping cell.
num_classes = 6

# Replace the default box-classification head with one sized for `num_classes`
# so its parameter shapes match the fine-tuned checkpoint.
in_features = detection_model.roi_heads.box_predictor.cls_score.in_features
detection_model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# map_location lets a checkpoint that was saved from a GPU load on a CPU-only
# machine (without it, torch.load tries to restore tensors onto CUDA and fails).
detection_model.load_state_dict(torch.load('detection_model.pt', map_location=device))
detection_model = detection_model.to(device)



In [11]:
# Output roots for the cropping cell:
#   cropped_images/           -> crops numbered per <si>/<typ> series
#   cropped_images_detailed/  -> the same crops split into one folder per level
cropped_images_dir = 'cropped_images'
cropped_images_detailed_dir = 'cropped_images_detailed'

for output_root in (cropped_images_dir, cropped_images_detailed_dir):
    # Create each root only when it is missing (both live directly in the CWD).
    if not Path(output_root).exists():
        os.mkdir(output_root)

In [12]:
# Converted PNG slices, laid out as cvt_png/<si>/<typ>/<NNN>.png
# (e.g. 'cvt_png/100206310/Axial T2/000.png'); sorted so slices stay in order.
train_images_dir = 'cvt_png'
png_pattern = f'{train_images_dir}/*/*/*.png'
all_files = sorted(glob.glob(png_pattern))
all_files[:5]

['cvt_png/100206310/Axial T2/000.png',
 'cvt_png/100206310/Axial T2/001.png',
 'cvt_png/100206310/Axial T2/002.png',
 'cvt_png/100206310/Axial T2/003.png',
 'cvt_png/100206310/Axial T2/004.png']

In [13]:
# Detect intervertebral levels on every Sagittal T1 slice and save the crops.
#
# For every detection we
#   * record its box and level (one row per detection) in boxes.csv, and
#   * save a CROP_SIZE crop twice: once as cropped_images/<si>/<typ>/NNN.png
#     (numbered per series) and once as
#     cropped_images_detailed/<si>/<typ>/<level>/NNN.png (numbered per level).
#
# Re-running is safe: directories are created with exist_ok=True and every
# counter restarts from 0, so files with the same index are overwritten.
detection_model.eval()

CROP_SIZE = (192, 192)  # (width, height) as expected by cv2.resize
i_to_level = {1: 'L1/L2', 2: 'L2/L3', 3: 'L3/L4', 4: 'L4/L5', 5: 'L5/S1'}


def _indexed_png_path(directory, index):
    """Return ``directory/NNN.png`` with ``index`` zero-padded to three digits."""
    return os.path.join(directory, f'{index:03d}.png')


# Per-series crop counters keyed "<si>_<typ>". Path.parts splits on the OS
# separator, so this also works when glob returns backslashes (Windows).
counters = {'_'.join(Path(path).parts[1:3]): 0 for path in all_files}
dir_levels_cnts = {}  # crops written per level directory during this run
bbx1, bbx2, bby1, bby2, fns, all_levels = [], [], [], [], [], []

for image_path in tqdm(all_files):
    si, typ = Path(image_path).parts[1:3]
    if typ != 'Sagittal T1':
        continue
    series_key = f'{si}_{typ}'
    series_dir = os.path.join(cropped_images_dir, si, typ)
    os.makedirs(series_dir, exist_ok=True)
    os.makedirs(os.path.join(cropped_images_detailed_dir, si, typ), exist_ok=True)

    image_u8 = cv2.imread(image_path)
    if image_u8 is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'cv2 could not read {image_path}')
    # Detector input: float CHW tensor in [0, 1] with a batch dimension of one.
    # Channel order is cv2's BGR, as before (presumably matches training - TODO confirm).
    image = torch.from_numpy(image_u8.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = detection_model(image)[0]  # single-image batch
    boxes = prediction['boxes'].cpu().numpy()
    levels = [i_to_level[label] for label in prediction['labels'].tolist()]

    for box, level_name in zip(boxes, levels):
        # Truncate to int exactly as before; boxes are (x1, y1, x2, y2).
        x1, y1, x2, y2 = (int(coord) for coord in box)
        bbx1.append(x1)
        bbx2.append(x2)
        bby1.append(y1)
        bby2.append(y2)
        fns.append(image_path)
        all_levels.append(level_name)

        # Crop the original uint8 image so cv2.imwrite gets 8-bit data directly
        # instead of relying on its implicit float -> uint8 conversion.
        image_cropped = image_u8[y1:y2, x1:x2]
        if image_cropped.size == 0:
            # Sub-pixel boxes truncate to an empty slice, which cv2.resize rejects.
            # The detection is still recorded in boxes.csv above.
            continue
        image_cropped = cv2.resize(image_cropped, CROP_SIZE)

        level = level_name.replace('/', '_')  # 'L1/L2' -> 'L1_L2' (safe dir name)
        dir_level = os.path.join(cropped_images_detailed_dir, si, typ, level)
        os.makedirs(dir_level, exist_ok=True)
        # setdefault (not "only when the dir is new") so re-runs don't KeyError.
        dir_levels_cnts.setdefault(dir_level, 0)

        new_level_image_path = _indexed_png_path(dir_level, dir_levels_cnts[dir_level])
        new_image_path = _indexed_png_path(series_dir, counters[series_key])
        counters[series_key] += 1
        dir_levels_cnts[dir_level] += 1

        cv2.imwrite(new_image_path, image_cropped)
        cv2.imwrite(new_level_image_path, image_cropped)

df_boxes = pd.DataFrame({
    'filename': fns,
    'x1': bbx1,
    'y1': bby1,
    'x2': bbx2,
    'y2': bby2,
    'level': all_levels,
})
df_boxes.to_csv('boxes.csv', index=False)
    
            

  0%|          | 0/147218 [00:00<?, ?it/s]

  0%|          | 0/147218 [00:00<?, ?it/s]

In [14]:
# Number of crops written into each per-level directory
# (cropped_images_detailed/<si>/<typ>/<level>) by the cropping cell.
dir_levels_cnts

{'cropped_images_detailed/100206310/Sagittal T1/L5_S1': 12,
 'cropped_images_detailed/100206310/Sagittal T1/L4_L5': 10,
 'cropped_images_detailed/100206310/Sagittal T1/L1_L2': 8,
 'cropped_images_detailed/100206310/Sagittal T1/L2_L3': 8,
 'cropped_images_detailed/100206310/Sagittal T1/L3_L4': 6,
 'cropped_images_detailed/1002894806/Sagittal T1/L5_S1': 19,
 'cropped_images_detailed/1002894806/Sagittal T1/L4_L5': 14,
 'cropped_images_detailed/1002894806/Sagittal T1/L3_L4': 7,
 'cropped_images_detailed/1002894806/Sagittal T1/L1_L2': 6,
 'cropped_images_detailed/1002894806/Sagittal T1/L2_L3': 6,
 'cropped_images_detailed/1004726367/Sagittal T1/L5_S1': 15,
 'cropped_images_detailed/1004726367/Sagittal T1/L4_L5': 8,
 'cropped_images_detailed/1004726367/Sagittal T1/L2_L3': 8,
 'cropped_images_detailed/1004726367/Sagittal T1/L3_L4': 5,
 'cropped_images_detailed/1004726367/Sagittal T1/L1_L2': 7,
 'cropped_images_detailed/1008446160/Sagittal T1/L5_S1': 7,
 'cropped_images_detailed/1008446160/Sag