In [1]:
import yaml
import time
import os
import datetime
import warnings
import sys

warnings.simplefilter(action='ignore', category=FutureWarning)
sys.path.append("../pytorch3dunet/unet3d")
sys.path.append("../src/utils/")

import torch
import numpy as np


from torch import nn
from box import Box
from loguru import logger
from tqdm import tqdm
from torch.cuda import amp
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup


import losses
from base_scripts import set_seed, create_train_test_loader, create_fold_from_data
from my_losses import MyDiceLoss, CategoricalCrossEntropyLoss
from my_metrics import print_metric_MeanIoU

from model import ResidualUNet3D

In [2]:
import os 
import pydicom
import numpy as np 
import matplotlib.pyplot as plt
import ipyvolume as ipv
import nibabel as nib

In [37]:
# Folder holding one series of per-slice DICOM files. Set the DICOM_FOLDER
# environment variable to point elsewhere instead of editing this machine-specific default.
dicom_folder = os.environ.get("DICOM_FOLDER", r"C:\KAGGLE\MEDICINE\data\train_images\55567\55583")

In [38]:
# Names of the files in the series folder (process() lists the folder again on its own).
lst_files = os.listdir(path=dicom_folder)

In [39]:
def _first_dicom_value(value) -> float:
    """Return `value` as a float, taking its first entry if it is multi-valued."""
    # np.atleast_1d turns a scalar into a 1-element array and leaves sequences as-is,
    # so both single-valued and multi-valued attributes are handled the same way.
    return float(np.atleast_1d(value)[0])


def standardize_pixel_array(dcm: "pydicom.dataset.FileDataset") -> np.ndarray:
    """Convert a DICOM slice to rescaled, window-clipped pixel values.

    Source : https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/427217

    Steps:
      1. For signed data (PixelRepresentation == 1), sign-extend values stored in
         fewer bits than allocated.
      2. Apply the linear rescale ``value * RescaleSlope + RescaleIntercept``.
      3. Clip to the window ``[center - width/2, center + width/2]`` given by
         WindowCenter/WindowWidth. If either attribute is multi-valued, the first
         window is used.

    Args:
        dcm: Dataset exposing ``pixel_array``, ``PixelRepresentation``,
            ``BitsAllocated``, ``BitsStored``, ``RescaleIntercept``,
            ``RescaleSlope``, ``WindowCenter`` and ``WindowWidth``.

    Returns:
        Float array with the same shape as ``dcm.pixel_array``, clipped to the window.
    """
    pixel_array = dcm.pixel_array
    if dcm.PixelRepresentation == 1:
        # Shift the stored sign bit up to the container's top bit, then shift back
        # down so the sign is extended. This assumes pixel_array has a signed dtype,
        # which makes `>>` arithmetic rather than logical.
        bit_shift = dcm.BitsAllocated - dcm.BitsStored
        dtype = pixel_array.dtype
        pixel_array = (pixel_array << bit_shift).astype(dtype) >> bit_shift

    intercept = float(dcm.RescaleIntercept)
    slope = float(dcm.RescaleSlope)
    # Window attributes may hold several values; int() on those would raise.
    # Keep them as floats so fractional centers/widths are not truncated.
    center = _first_dicom_value(dcm.WindowCenter)
    width = _first_dicom_value(dcm.WindowWidth)
    low = center - width / 2
    high = center + width / 2

    pixel_array = (pixel_array * slope) + intercept
    return np.clip(pixel_array, low, high)

In [40]:
def process(data_path="", size=512):
    """Load a folder of per-slice DICOM files into one (size, size, n_slices) volume.

    Slice files are expected to be named ``<number>.dcm``. They are stacked in
    ascending numeric order of that name; other files in the folder are ignored.
    For each slice:
      * pixel values are standardized with ``standardize_pixel_array``;
      * the slice is min-max scaled to [0, 1] on its own;
      * MONOCHROME1 slices are inverted;
      * the slice is resized to ``size`` x ``size`` if it is not already that shape.

    Args:
        data_path: Folder containing the ``.dcm`` slice files.
        size: Target in-plane height and width of every slice.

    Returns:
        np.ndarray of shape (size, size, n_slices).

    Raises:
        ValueError: If the folder contains no ``<number>.dcm`` files.
    """
    # Pair each numeric slice id with its actual file name. Entries that are not
    # "<number>.dcm" (hidden files, metadata, ...) are skipped rather than crashing int().
    slice_files = []
    for name in os.listdir(data_path):
        stem, ext = os.path.splitext(name)
        if ext.lower() == ".dcm" and stem.isdigit():
            slice_files.append((int(stem), name))
    if not slice_files:
        raise ValueError(f"No '<number>.dcm' files found in {data_path!r}")

    imgs = []
    # Walk only the files that exist, in numeric order, so gaps in the numbering
    # do not raise FileNotFoundError.
    for _, name in sorted(slice_files):
        dicom = pydicom.dcmread(os.path.join(data_path, name))

        img = standardize_pixel_array(dicom)
        # Per-slice min-max scaling; the epsilon avoids division by zero on constant slices.
        img = (img - img.min()) / (img.max() - img.min() + 1e-6)

        if dicom.PhotometricInterpretation == "MONOCHROME1":
            img = 1 - img

        # Compare against the requested `size` (not a fixed 512) so the parameter is honored.
        if img.shape != (size, size):
            import cv2  # imported lazily: only needed when a slice must be resized
            img = cv2.resize(img, (size, size))

        imgs.append(img)

    # Stack the slices along the last axis -> (size, size, n_slices).
    return np.stack(imgs, axis=-1)


In [41]:
# Load the selected series as a (512, 512, n_slices) volume.
img = process(dicom_folder, 512)

In [42]:
# Volume dimensions: (height, width, n_slices).
np.shape(img)

(512, 512, 47)

In [43]:
import volumentations as V
# Resample the volume to a fixed (256, 256, 96) grid before it goes through the network.
# NOTE(review): interpolation=3 is presumably the spline order used by
# volumentations' resize. Confirm this against the installed library version.
aug = V.Compose(
    [
        V.Resize(
            (256, 256, 96),
            always_apply=True,
            p=1.0,
            interpolation=3,
            resize_type=0,
        ),
    ],
    p=1.0,
)

In [44]:
# The Compose pipeline is called with keyword targets and returns a dict keyed by target name.
data_for_aug = dict(image=img)
aug_data = aug(**data_for_aug)
img_new = aug_data["image"]

In [45]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [46]:
# Build the 3D residual U-Net (1 input channel, 5 output channels) directly on `device`.
model = ResidualUNet3D(in_channels=1,
                       out_channels=5,
                       f_maps=[32, 64, 128, 256, 512],
                       final_sigmoid=False).to(device)

# load_state_dict copies the weights into parameters that already live on `device`,
# so no second .to(device) is needed.
# NOTE: torch.load unpickles the file. Only load checkpoints from a trusted source;
# torch >= 1.13 also accepts weights_only=True for plain state dicts.
model.load_state_dict(torch.load("ResidualUNet3D_fold_1_last_epochs.pt", map_location=device))
device

device(type='cuda')

In [47]:
model.eval()
with torch.no_grad():
    # Cast explicitly to float32 to match the module's default parameter dtype.
    # process() builds the volume with float64 arithmetic, and the resize may keep
    # that dtype. [None, None] adds the batch and channel axes -> (1, 1, D0, D1, D2).
    # Calling the module (not .forward) lets any registered hooks run.
    pred_mask = model(
        torch.from_numpy(np.ascontiguousarray(img_new, dtype=np.float32))[None, None].to(device)
    )
    # Drop the batch axis -> (channels, D0, D1, D2).
    pred_mask = pred_mask[0]

In [48]:
import torch.nn.functional as F

# Inspect the raw network output. As the cell's last expression, it is actually displayed.
pred_mask.shape

In [49]:
# Upsample the per-channel scores back to the original volume size. The extra leading
# axis ([None]) is the batch dimension that 5-D trilinear interpolation expects.
pred_mask = F.interpolate(pred_mask[None], size=img.shape, mode="trilinear", align_corners=False)

In [50]:
# Drop the batch axis, then take the highest-scoring channel per voxel to get a label volume.
pred_mask = pred_mask[0].argmax(dim=0)

In [51]:
# Move the label volume to host memory as a NumPy array.
pred_mask = pred_mask.cpu().numpy()

In [52]:
# Create an interactive 3D rendering of the loaded volume.
fig = ipv.figure()

# Render the volumetric data.

# Alternative kept for reference: render a single label as a binary volume.
# NOTE(review): `real_mask` is not defined in the visible notebook; `pred_mask` may be what is intended.
# vol = ipv.volshow(np.where(real_mask == 3, 1, 0) , lighting=True)
vol = ipv.volshow(img , lighting=True)
# Figure styling: hide the bounding box and axes, then apply the light theme.
# NOTE(review): confirm set_style_light() does not re-enable the box/axes hidden just above.
ipv.style.box_off()
ipv.style.axes_off()
ipv.style.set_style_light()


# Display the figure.
ipv.show()

Container(children=[VBox(children=(HBox(children=(Label(value='levels:'), FloatSlider(value=0.1, max=1.0, step…