In [1]:
import os
import pydicom
import matplotlib.pyplot as plt
import cv2
import numpy as np
import nibabel as nib
from PIL import Image


In [4]:
import pandas as pd

df = pd.DataFrame(pd.read_excel("NNI_Data_valid/data_batch_2_labels.xlsx"))
# Renaming the columns
df.rename(columns={"Filename": "filename",
          "Labels": "label_name"}, inplace=True)

# Adding the new column based on conditions
df['labels'] = df['label_name'].apply(lambda x: 0 if x == 'Normal' else 1)
# Remove the row where "filename" is "M20047"
df = df[df["filename"] != "M20047"]
print(df)
df.to_csv("labels.csv")


   filename label_name  labels
0    M20026     Normal       0
1    M20027     Normal       0
2    M20028     Normal       0
3    M20029   Abnormal       1
4    M20030     Normal       0
5    M20031   Abnormal       1
6    M20032     Normal       0
7    M20033     Normal       0
8    M20034     Normal       0
9    M20035     Normal       0
10   M20036     Normal       0
11   M20037     Normal       0
12   M20038     Normal       0
13   M20039     Normal       0
14   M20040     Normal       0
15   M20041   Abnormal       1
16   M20042   Abnormal       1
17   M20043     Normal       0
18   M20044   Abnormal       1
19   M20045     Normal       0
20   M20046   Abnormal       1
22   M20048     Normal       0
23   M20049     Normal       0
24   M20050   Abnormal       1
25   M30021     Normal       0
26   M30022     Normal       0
27   M30023     Normal       0
28   M30024     Normal       0
29   M30025     Normal       0
30   M30026     Normal       0
31   M30027   Abnormal       1
32   M30

# Preprocessing
## Converting dicom to png

In [3]:
def dcm_to_png(input_folders):
    for folder in input_folders:
        # Create output folder
        output_folder = folder + "_png"
        os.makedirs(output_folder, exist_ok=True)

        # Convert each .dcm file to .png
        for filename in os.listdir(folder):
            if filename.endswith(".dcm"):
                # Load DICOM file
                # dicom = pydicom.dcmread(os.path.join(folder, filename))
                dicom = pydicom.dcmread(os.path.join(folder, filename))

                # Normalize pixel array
                normalized_array = cv2.normalize(
                    dicom.pixel_array, None, 0, 255, cv2.NORM_MINMAX)

                # Convert normalized pixel array to uint8 type
                img_uint8 = normalized_array.astype(np.uint8)

                # Convert to PNG and save
                new_filename = filename.split(
                    '.')[0] + '_' + folder.split('_')[-1] + '.png'
                cv2.imwrite(os.path.join(
                    output_folder, new_filename), img_uint8)


# dcm_to_png(input_folders)


In [4]:
input_folders = ["NNI_Data/2D_projection_AP", "NNI_Data/2D_projection_LR", "NNI_Data/2D_projection_SI"]
dcm_to_png(input_folders)

In [12]:
import torch
png_img = torch.tensor(np.array((Image.open(
    "NNI_Data/2D_projection_AP_png/M30012_AP.png")))).unsqueeze(0).repeat(3, 1, 1)
print(png_img)
print(png_img.shape)
print(np.min(png_img))
print(np.max(png_img))

tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8)
torch.Size([3, 512, 512])


TypeError: min() received an invalid combination of arguments - got (out=NoneType, axis=NoneType, ), but expected one of:
 * ()
 * (Tensor other)
 * (int dim, bool keepdim)
      didn't match because some of the keywords were incorrect: out, axis
 * (name dim, bool keepdim)
      didn't match because some of the keywords were incorrect: out, axis


In [5]:
def dicom2nifti(input_folder, output_folder):
    os.makedirs(output_folder, exist_ok=True)

    dicom_files = [os.path.join(input_folder, f) for f in os.listdir(
        input_folder) if f.endswith('.dcm')]

    # Read metadata and pixel data
    dicoms = [pydicom.read_file(dcm) for dcm in dicom_files]

    # Sort dicom slices based on Instance Number
    dicoms.sort(key=lambda x: int(x.InstanceNumber))

    # Stack slices into 3D array
    img = np.stack([dcm.pixel_array for dcm in dicoms])
    img = img.astype(np.int16)
    img = np.transpose(img, (2, 1, 0))  # transpose to get the desired shape

    # Construct affine matrix from DICOM metadata
    affine = np.eye(4)

    # Create the NIfTI image
    nifti_img = nib.Nifti1Image(img, affine)
    nib.save(nifti_img, os.path.join(output_folder,
             input_folder.split('/')[-1] + '.nii.gz'))

In [6]:
input_folder = "NNI_Data/3D_stack"
output_folder = "NNI_Data/3D_stack_nii"

for subfolder in os.listdir(input_folder):
    print(f"Processing {subfolder}")
    try: 
        dicom2nifti(os.path.join(input_folder, subfolder), output_folder)
    except Exception as e:
        # Do something
        print(f"Error processing {subfolder}. Please see below:")
        print(e)
        


Processing M40006
Processing M10047
Processing M10041
Processing M20001
Processing M40022
Processing M20010
Processing M10049
Processing M40050
Processing M40014
Processing M10017
Processing M40031
Processing M20008
Processing M40042
Processing M30003
Processing M40044
Processing M10010
Processing M20025
Processing M20019
Processing M20005
Processing M50015
Processing M40009
Processing M40023
Processing M40048
Processing M10019
Processing M20002
Processing M10014
Processing M40034
Processing M40015
Processing M30016
Processing M30018
Processing M20012
Processing M50007
Processing M10022
Processing M40047
Processing M10002
Processing M10045
Processing M40016
Processing M50002
Processing M10030
Processing M10046
Processing M30013
Processing M30004
Processing M40041
Processing M40037
Processing M10035
Processing M40012
Processing M40030
Processing M40019
Processing M50005
Processing M10033
Processing M10050
Processing M40036
Processing M40024
Processing M10016
Processing M20009
Processing

In [7]:
# Define path to NIfTI files
path = 'NNI_Data/3D_stack_nii'

# Initialize variables to store min and max dimensions
min_width, min_height, min_depth = float('inf'), float('inf'), float('inf')
max_width, max_height, max_depth = float('-inf'), float('-inf'), float('-inf')

# Loop over all NIfTI files in the directory
for filename in os.listdir(path):
    if filename.endswith('.nii.gz'):
        file_path = os.path.join(path, filename)
        img = nib.load(file_path)
        data = img.get_fdata()
        print(f"File: {filename}, Shape: {data.shape}")

        width, height, depth = data.shape
        min_width, max_width = min(min_width, width), max(max_width, width)
        min_height, max_height = min(
            min_height, height), max(max_height, height)
        min_depth, max_depth = min(min_depth, depth), max(max_depth, depth)

print("\nDimension statistics:")
print(f"Min Width: {min_width}, Max Width: {max_width}")
print(f"Min Height: {min_height}, Max Height: {max_height}")
print(f"Min Depth: {min_depth}, Max Depth: {max_depth}")


File: M40006.nii.gz, Shape: (512, 512, 216)
File: M30015.nii.gz, Shape: (512, 512, 156)
File: M10014.nii.gz, Shape: (512, 512, 164)
File: M50015.nii.gz, Shape: (512, 512, 150)
File: M30006.nii.gz, Shape: (512, 512, 160)
File: M10005.nii.gz, Shape: (512, 512, 196)
File: M50010.nii.gz, Shape: (512, 512, 164)
File: M10022.nii.gz, Shape: (512, 512, 164)
File: M40030.nii.gz, Shape: (512, 512, 120)
File: M30019.nii.gz, Shape: (528, 528, 168)
File: M10049.nii.gz, Shape: (512, 512, 164)
File: M20015.nii.gz, Shape: (720, 720, 160)
File: M30020.nii.gz, Shape: (512, 512, 156)
File: M40044.nii.gz, Shape: (512, 512, 164)
File: M10027.nii.gz, Shape: (512, 512, 128)
File: M40013.nii.gz, Shape: (512, 512, 148)
File: M40007.nii.gz, Shape: (512, 512, 184)
File: M10021.nii.gz, Shape: (512, 512, 164)
File: M10046.nii.gz, Shape: (512, 512, 164)
File: M10047.nii.gz, Shape: (512, 512, 192)
File: M30010.nii.gz, Shape: (512, 512, 168)
File: M10043.nii.gz, Shape: (512, 512, 164)
File: M30009.nii.gz, Shape: (512

### View how the data looks like

In [41]:
path = 'NNI_Data/3D_stack_nii'


# Loop over all NIfTI files in the directory
for filename in os.listdir(path):
    if filename.endswith('.nii.gz'):
        file_path = os.path.join(path, filename)
        img = nib.load(file_path)
        data = img.get_fdata()
        print(data)
        print(f"File: {filename}, Shape: {data.shape}")
    break

[[[ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  ...
  [ 8.  6.  3. ...  2.  5.  7.]
  [ 9.  7.  4. ...  3.  6.  9.]
  [12. 12.  5. ...  4.  4.  8.]]

 [[ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  ...
  [ 5.  5.  7. ...  5.  6.  7.]
  [ 5.  5.  7. ...  7.  9.  9.]
  [ 9. 10.  8. ...  5.  8.  8.]]

 [[ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  ...
  [ 9. 10.  8. ...  6.  7.  7.]
  [ 9.  6.  7. ...  8. 10.  8.]
  [ 7.  5.  6. ...  8. 10.  6.]]

 ...

 [[ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  ...
  [ 3.  4.  4. ...  6. 13. 15.]
  [ 4.  6.  6. ...  4. 12. 15.]
  [ 4.  6.  5. ...  4.  9. 12.]]

 [[ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  ...
  [ 3.  4.  4. ...  3.  9. 11.]
  [ 4.  5.  3. ...  2.  7.  8.]
  [ 4.  5.  3. ...  5.  4.  

In [None]:
print(16*10)

# Modifying and loading model checkpoint

In [4]:
import torch.nn as nn
import torch.nn.functional as F

# from __future__ import annotations

from collections.abc import Sequence

import torch.nn as nn
import torch

from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
from monai.networks.nets.vit import ViT
from monai.utils import ensure_tuple_rep


In [8]:
from monai.networks.nets import UNETR
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the checkpoint
checkpoint = torch.load('checkpoints/UNETR_model_best_acc.pth')
model = UNETR(
    in_channels=1,
    out_channels=14,
    img_size=(96, 96, 96),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
).to(device)

# model.load_state_dict(checkpoint)


In [9]:
print(checkpoint.keys())

odict_keys(['vit.patch_embedding.position_embeddings', 'vit.patch_embedding.cls_token', 'vit.patch_embedding.patch_embeddings.1.weight', 'vit.patch_embedding.patch_embeddings.1.bias', 'vit.blocks.0.mlp.linear1.weight', 'vit.blocks.0.mlp.linear1.bias', 'vit.blocks.0.mlp.linear2.weight', 'vit.blocks.0.mlp.linear2.bias', 'vit.blocks.0.norm1.weight', 'vit.blocks.0.norm1.bias', 'vit.blocks.0.attn.out_proj.weight', 'vit.blocks.0.attn.out_proj.bias', 'vit.blocks.0.attn.qkv.weight', 'vit.blocks.0.norm2.weight', 'vit.blocks.0.norm2.bias', 'vit.blocks.1.mlp.linear1.weight', 'vit.blocks.1.mlp.linear1.bias', 'vit.blocks.1.mlp.linear2.weight', 'vit.blocks.1.mlp.linear2.bias', 'vit.blocks.1.norm1.weight', 'vit.blocks.1.norm1.bias', 'vit.blocks.1.attn.out_proj.weight', 'vit.blocks.1.attn.out_proj.bias', 'vit.blocks.1.attn.qkv.weight', 'vit.blocks.1.norm2.weight', 'vit.blocks.1.norm2.bias', 'vit.blocks.2.mlp.linear1.weight', 'vit.blocks.2.mlp.linear1.bias', 'vit.blocks.2.mlp.linear2.weight', 'vit.bloc

In [10]:
# def remove_unexpected_keys(model, state_dict):
# state_dict = checkpoint
model_dict = model.state_dict()

# filter out unnecessary keys
state_dict = {k: v for k, v in checkpoint.items() if k in model_dict}

# overwrite entries in the existing state dict
model_dict.update(state_dict)

# load the new state dict
model.load_state_dict(model_dict)


# # assume that `checkpoint` is the loaded original state_dict
# remove_unexpected_keys(model=model, state_dict=checkpoint)


<All keys matched successfully>

In [11]:
model

UNETR(
  (vit): ViT(
    (patch_embedding): PatchEmbeddingBlock(
      (patch_embeddings): Sequential(
        (0): Rearrange('b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)', p1=16, p2=16, p3=16)
        (1): Linear(in_features=4096, out_features=768, bias=True)
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (blocks): ModuleList(
      (0): TransformerBlock(
        (mlp): MLPBlock(
          (linear1): Linear(in_features=768, out_features=3072, bias=True)
          (linear2): Linear(in_features=3072, out_features=768, bias=True)
          (fn): GELU()
          (drop1): Dropout(p=0.0, inplace=False)
          (drop2): Dropout(p=0.0, inplace=False)
        )
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): SABlock(
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (qkv): Linear(in_features=768, out_features=2304, bias=False)
          (input_rearrange): Rearrange('b h (qkv l d) -> qkv b l h 

In [12]:

# Define a new model that only uses the encoder part of UNETR
class EncoderModel(nn.Module):
    def __init__(self, unetr_model, out_channels=512):
        super(EncoderModel, self).__init__()
        self.vit = unetr_model.vit
        self.fc = nn.Linear(768, out_channels)  # fully connected layer

    def forward(self, x):
        _, embedding = self.vit(x)
        print(len(embedding))
        print(embedding[0].shape)
        print(embedding[1].shape)
        # Pass the embedding through the fully connected layer before returning
        x = self.fc(embedding)
        return x


# Then instantiate the model

# del model


class EncoderModel(nn.Module):
    def __init__(self, unetr_model, num_heads=12):
        super(EncoderModel, self).__init__()
        # Extract the transformer (encoder) part
        self.vit = unetr_model.vit
        self.num_heads = num_heads

        # Add a fully connected layer
        self.fc = nn.Linear(768, 512)


    def forward(self, x):
        # vit outputs a list of length num_heads
        _, outputs = self.vit(x)

        # Convert list of tensors to tensor
        # Shape: (num_heads, batch_size, sequence_len, embed_dim)
        outputs_tensor = torch.stack(outputs, dim=0)

        # Perform mean pooling across all heads and sequence length
        # Shape: (batch_size, embed_dim)
        mean_pooled_output = torch.mean(outputs_tensor, dim=[0, 2])

        # Pass the mean-pooled output through the fully connected layer
        out = self.fc(mean_pooled_output)
        return out


encoder_model = EncoderModel(model).to(device)

In [14]:
from torch.utils.data import DataLoader
import torch
from torch.utils.data import Dataset


class DummyDataset(Dataset):
    def __init__(self, num_samples=100, image_size=(1, 96, 96, 96)):
        self.num_samples = num_samples
        self.image_size = image_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generates a random tensor for each image
        sample = torch.randn(self.image_size).to(device)
        # Dummy label (not used in your model, but usually part of a dataset)
        label = torch.tensor(1).to(device)
        return sample, label


dummy_dataset = DummyDataset(image_size=(1,96,96,96))

# Instantiate DataLoader
data_loader = DataLoader(dummy_dataset, batch_size=32, shuffle=True)

# Now, you can iterate over this data loader and feed the input to your model
for batch in data_loader:
    inputs, _ = batch
    print(inputs.shape)
    outputs = encoder_model(inputs)
    print(outputs.shape)  # Check the output shape
    break  # For this test, we just need one batch


torch.Size([32, 1, 96, 96, 96])
torch.Size([32, 512])


In [38]:
import h5py
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import plotly.graph_objects as go

# Open the file in read mode
with h5py.File('modelnet40_hdf5_2048/modelnet40_hdf5_2048/test0.h5', 'r') as f:
    # Print all keys/datasets in the file
    print("Keys: %s" % f.keys())
    # Get the dataset from the file
    dataset = f['data']
    # Check the shape of the dataset
    print(dataset.shape)
    # If you want to check the length of the first dimension
    print(len(dataset))
    print(dataset[0].shape)
    print(dataset[0][0])
    sample_index = 0
    sample = dataset[sample_index]

    # Plot the point cloud
    # Create a scatter3d plot
    fig = go.Figure(data=[go.Scatter3d(
        x=sample[:, 0],
        y=sample[:, 1],
        z=sample[:, 2],
        mode='markers',
        marker=dict(
            size=2,
            opacity=0.8
        )
    )])

    # Update layout for better view
    fig.update_layout(
        scene=dict(
            xaxis=dict(range=[-1, 1], autorange=False),
            yaxis=dict(range=[-1, 1], autorange=False),
            zaxis=dict(range=[-1, 1], autorange=False),
            aspectratio=dict(x=1, y=1, z=1),
        ),
        margin=dict(r=20, l=10, b=10, t=10),
    )

    fig.show()


Keys: <KeysViewHDF5 ['data', 'label']>
(2048, 2048, 3)
2048
(2048, 3)
[ 0.7979659  -0.01555308  0.08455718]


In [31]:

print(dataset[0])


ValueError: Dset_id is not a dataset id (dset_id is not a dataset ID)

ValueError: Dset_id is not a dataset id (dset_id is not a dataset ID)