In [1]:
!nvidia-smi

Mon Mar  6 11:48:22 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 526.98       Driver Version: 526.98       CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A5000   WDDM  | 00000000:01:00.0  On |                  Off |
| 30%   42C    P8    15W / 230W |    532MiB / 24564MiB |      3%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A5000   WDDM  | 00000000:61:00.0 Off |                  Off |
| 30%   31C    P8     4W / 230W |     49MiB / 24564MiB |      0%      Default |
|       

# import stuff

In [2]:
import os
import time
from matplotlib import pylab as plt
import numpy as np

import torch
from torch.utils.data import DataLoader

import monai
from monai.data import (
    list_data_collate,
    ITKReader,
    NumpyReader,
)
from monai.networks.nets import DenseNet121
from monai.transforms import (
    Compose,
    EnsureTyped,
    EnsureChannelFirstd,
    Lambdad,
    LoadImaged,
    Resized,
    Rotate90d,
    ScaleIntensityd,
    ToNumpy,
    ToTensord,
)
from monai.utils import first

npc = ToNumpy()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [3]:
def gray2rgb(x):
    #print(x.shape)
    if x.shape[0]==1:
        x = x.repeat(3, 1, 1)
        x.meta['original_channel_dim'] = -1 # THIS is the important line! 
    #print(x.shape)
    return x

def clean_tiff_meta(x):
    for key in ['DocumentName', 'ImageDescription', 'Software']:
        if key in x.meta.keys():
            del x.meta[key]
    return x

# load data

In [4]:
# get the folder location
images_path = r"../data/torsion/P01/framesES/"

# get image locations in an array
images = []
for path in os.listdir(images_path):
    image_location = os.path.join(images_path, path)
    if os.path.isfile(image_location):
        images.append(image_location)
test_files = [{"img":  fn_img} for fn_img in images]

# store the time at the beginning
timeBeginning = time.time()

In [5]:
print(test_files[0]["img"])

../data/torsion/P01/framesES/P01_ES_GVS_0.jpg


# 1st Model

## necessary transformations

In [6]:
# apply necessary transformations

test_transforms = Compose(
    [
        LoadImaged(keys=["img"], reader= ITKReader, image_only = True),
        EnsureChannelFirstd(keys=["img"]),
        Lambdad(keys=["img"], func=lambda x: gray2rgb(x)), # gray to rgb conversion
        ScaleIntensityd(keys="img"),
        Rotate90d(keys=["img"]),
        Rotate90d(keys=["img"]),
        Rotate90d(keys=["img"]),
        Resized(keys=["img"], spatial_size=(240,320)),
        EnsureTyped(keys="img"),
        Lambdad(keys=["img"], func=lambda x: clean_tiff_meta(x)), # clean weird keys in TIFF metadata
        ToTensord(keys=["img"]),
    ]
)

batch_size = 128

# get them to monai data loader
test_ds = monai.data.Dataset(data=test_files, transform=test_transforms)
test_loader = DataLoader(
    test_ds, 
    batch_size=batch_size, 
    num_workers=0,
    collate_fn=list_data_collate
)

In [7]:
def draw_image(frame_rgb, ax=None):
    # Plot segmented area
    #fig, ax = plt.subplots(figsize=(3.2,2.4))
    alpha=0.4
    fig_created = False
    if ax is None:
        fig = plt.figure(figsize=(3.2,2.4))
        #canvas = FigureCanvas(fig)
        ax = fig.subplots()
        fig_created = True
    ax.imshow(frame_rgb)#, vmax=1, vmin=0, cmap="gray")
    ax.axis('off')
    if fig_created:
        fig.tight_layout()
        fig.canvas.draw()
        plt.show()

if False:
    check_data = first(test_loader) #check_data = monai.utils.misc.first(check_lo
    for i in range(32):
        # channel first versions
        img_cf = np.squeeze(npc(check_data["img"])[i,:,:,:])    
        # channel last versions for plotting
        img = np.moveaxis(img_cf, [0,1,2], [-1,-3,-2])
        draw_image(img)


## load model, infer

In [8]:
# define model
model = DenseNet121(spatial_dims=2, 
                    in_channels=3,
                    out_channels=30).to(device)

# load weights
model.load_state_dict(torch.load("models/best_metric_model_oc_multi_densenet121.pth"))
timeFirstModelLoaded = time.time()

In [9]:
# infer images, get open close predictions
# store the time at the end of 1st model

soft_act = torch.nn.Softmax()

list_imgs  = []
list_segs  = []
list_preds = []

counter = 0 
timePassedInLoop = 0
with torch.no_grad():
    for test_data in test_loader:
        timeLoopStarted = time.time()
        
        test_images = (
            test_data["img"].to(device)
        )
        test_outputs = model(test_images)
        
        # it is faster with the code below, but there is not enough memory.
        # list_imgs.append(test_images)
        # list_preds.append(test_outputs)
        
        # so change the device cpu and save it to numpy array
        tmp_imgs = [np.squeeze(img.cpu().numpy()).transpose([1,2,0]) for img in test_images]
        list_imgs.append(tmp_imgs)
        tmp_preds = [(np.squeeze(soft_act(pred).cpu().numpy())).round() for pred in test_outputs]
        list_preds.append(tmp_preds)
        
        timeOfLoop = time.time() - timeLoopStarted
        timePassedInLoop += timeOfLoop
timeFirstModelPredicted = time.time()
print("done predicting")
print("took " + str(timePassedInLoop) + "seconds")
        
#list_imgs = [np.squeeze(img.cpu().numpy()).transpose([1,2,0]) for img in list_imgs]
list_imgs = [np.squeeze(img) for imgs in list_imgs for img in imgs]
#list_preds_act = [(np.squeeze(soft_act(pred).cpu().numpy())).round() for pred in list_preds]
list_preds_act = [np.squeeze(pred) for preds in list_preds for pred in preds]

indexes = range(len(test_files))
indexesToRemove = [idx for idx in indexes if list_preds_act[idx][0] == 1]

timeFirstResultsInOrder = time.time()
print("flattening batch sizes")
print("took " + str(timeFirstResultsInOrder - timeFirstModelPredicted) + "seconds")

Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.


done predicting
took 47.61705446243286seconds
flattening batch sizes
took 0.07629942893981934seconds


# Some cells to check the results of the model (for now)

In [None]:
idx = 0


In [None]:
idxLimit = idx + 10
while idx < idxLimit and idx < len(test_files):
    img_cf  = list_imgs[idx]
    pred = list_preds_act[idx]
    
    # channel last versions for plotting
    img = np.moveaxis(img_cf, [0,1,2], [-1,-3,-2])
    draw_image(img_cf)
    print('Image index: %d'%idx)
    print("prediction\t", 'close' if pred[0] == 1 else 'open')
    print("prediction", pred)
    idx += 1

In [None]:
indexes = range(len(test_files))
counterClosedPrediction = 0
for idx in indexes:
    img_cf  = list_imgs[idx]
    pred = list_preds_act[idx]
    
    if (pred[0] == 1):
        # channel last versions for plotting
        img = np.moveaxis(img_cf, [0,1,2], [-1,-3,-2])
        draw_image(img_cf)
        print('Image index: %d'%idx)
        print("prediction\t", 'close' if pred[0] == 1 else 'open')
        print("prediction", pred)
        counterClosedPrediction += 1
print(counterClosedPrediction)
        
# ES1 14625 is false positive.

In [None]:
indexes = range(len(test_files))
counterAllZeros = 0
for idx in indexes:
    img_cf  = list_imgs[idx]
    pred = list_preds_act[idx]
    
    if (not pred.any()):
        # channel last versions for plotting
        img = np.moveaxis(img_cf, [0,1,2], [-1,-3,-2])
        draw_image(img_cf)
        print('Image index: %d'%idx)
        print("prediction\t", 'close' if pred[0] == 1 else 'open')
        print("prediction", pred)
        counterAllZeros += 1
print(counterAllZeros)


In [None]:
indexes = range(len(test_files))
counterAllZeros = 0
for idx in indexes:
    img_cf  = list_imgs[idx]
    pred = list_preds_act[idx]
    
    if (pred[1:].any()):
        # channel last versions for plotting
        img = np.moveaxis(img_cf, [0,1,2], [-1,-3,-2])
        draw_image(img_cf)
        print('Image index: %d'%idx)
        print("prediction\t", 'close' if pred[0] == 1 else 'open')
        print("prediction", pred)
        counterAllZeros += 1
print(counterAllZeros)


# 2nd model

## necessary transformations 

In [None]:
# apply necessary transformations, if not the same

# get them to monai data loader

## load model, infer

In [None]:
# define model

# load weights

In [None]:
# infer images, get segmentations
# store the time at the end of 2nd model

# 3rd model

## Crop iris (?)

In [None]:
# crop and store only the iris part of previous inferred images

## necessary transformations 

In [None]:
# apply necessary transformations
# to monai data loader

## load model, infer

In [None]:
# define model
# load weights

In [None]:
# infer images, get torsion predictions
# store the time at the end of 3rd model

# combine findings

In [None]:
# combine x,y position of pupil and torsion for each image
# store them as txt, npt etc.