Find ROI from fish scan using NN
====

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
"""
Directory path for RDSF fish data

"""
import os

# Where I've mounted it
rdsf_dir = os.path.expanduser("~/zebrafish_rdsf/")

data_dir = os.path.join(rdsf_dir, "DATABASE/uCT/Wahab_clean_dataset/low_res_clean_v3/")

print(data_dir)
assert os.path.isdir(data_dir)

/home/mh19137/zebrafish_rdsf/DATABASE/uCT/Wahab_clean_dataset/low_res_clean_v3/


In [3]:
"""
Read a subset of the images

"""

import numpy as np
from tqdm import tqdm

from dev import image_io

good_imgs = [41, 42, 44, 47, 49, 61, 62]

images = [
    (
        image_io.read_tiffstack(n)
        if not os.path.exists(f"data_cache/{n}.npy")
        else np.load(f"data_cache/{n}.npy")
    )
    for n in tqdm(good_imgs)
]

100%|██████████| 7/7 [00:26<00:00,  3.78s/it]


In [4]:
"""
Cache them because it took ages to read

"""

if not os.path.exists("data_cache"):
    os.mkdir("data_cache")

for i, img in zip(good_imgs, images):
    if not os.path.exists(f"data_cache/{i}.npy"):
        np.save(f"data_cache/{i}.npy", img)

In [5]:
"""
Downsample the arrays

"""

from skimage.measure import block_reduce

downsampled_images = []
for image in images:
    print(image.size, end=" -> ")
    downsampled_image = block_reduce(image, (10, 10, 10))
    print(downsampled_image.size)
    downsampled_images.append(downsampled_image)

del images

829472000 -> 845000
803912000 -> 819200
755728000 -> 768000
793403100 -> 793800
803510044 -> 819200
803510044 -> 819200
829057264 -> 845000


In [6]:
from dev import plot

fig, _ = plot.plot_arr(images[4])
fig.savefig("downsampled.png")

NameError: name 'images' is not defined

In [None]:
"""
Find the jaw location in these images from the metadata

"""

from dev import metadata

# Get the mastersheet
mastersheet = metadata.mastersheet()

# Choose the right rows
mastersheet = mastersheet[mastersheet["old_n"].isin(good_imgs)]
assert len(mastersheet) == len(good_imgs)

# Convert to dict
jaw_locs = dict(zip(mastersheet["old_n"], mastersheet["jaw_center"]))

# Convert values in the dict to tuples of [Z X Y]
jaw_locs = {k: image_io.parse_roi(v) for k, v in jaw_locs.items()}
jaw_locs

{41: (1739, 296, 308),
 42: (1694, 330, 470),
 44: (1759, 298, 376),
 47: (1645, 466, 157),
 49: (1561, 399, 482),
 61: (1459, 309, 424),
 62: (1495, 354, 341)}

: 

In [None]:
# Scale the locations to be in the downsampled space

downsampled_jaw_locs = [[i / 10] for v in jaw_locs.values() for i in v]

In [None]:
"""
Convert them to the right format for the model

"""

import torch

# Normalise, convert to tensor and reshape
tensors = [image_io.img2pytorch(img) for img in tqdm(downsampled_images)]

 71%|███████▏  | 5/7 [00:11<00:04,  2.40s/it]

In [None]:
# Convert to a dataset
from torch.utils.data import Dataset


class JawLocationDataset(Dataset):
    def __init__(
        self,
        tensors: list[torch.tensor],
        locations: list[tuple[int, int, int]],
    ):
        self.tensors = tensors
        self.locations = torch.tensor(locations, dtype=torch.float32)

        assert len(tensors) == len(locations)

    def __len__(self):
        return len(self.tensors)

    def __getitem__(self, idx):
        """
        Returns a tuple of (image, location)

        """
        return self.tensors[idx], self.locations[idx]

In [None]:
# Pad the images
from torch.utils.data import DataLoader
from torch.nn.functional import pad


def collate(batch):
    max_depth = 0
    max_height = 0
    max_width = 0
    processed_batch = []

    # Find the max dimensions
    for item in batch:
        data, target = item
        depth, height, width = (
            data.size(1),
            data.size(2),
            data.size(3),
        )  # Assuming data shape is [C, D, H, W]
        if depth > max_depth:
            max_depth = depth
        if height > max_height:
            max_height = height
        if width > max_width:
            max_width = width

    # Pad the images
    for item in batch:
        data, target = item
        depth, height, width = data.size(1), data.size(2), data.size(3)
        # Calculate padding
        pad_left = (max_width - width) // 2
        pad_right = max_width - width - pad_left
        pad_top = (max_height - height) // 2
        pad_bottom = max_height - height - pad_top
        pad_front = (max_depth - depth) // 2
        pad_back = max_depth - depth - pad_front
        # Apply padding
        padded_data = pad(
            data,
            (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back),
            "constant",
            0,
        )
        processed_batch.append((padded_data, target))

    # Use the default collate function here for the padded batch
    return torch.utils.data.dataloader.default_collate(processed_batch)


dataset = JawLocationDataset(tensors, downsampled_jaw_locs)
dataloader = DataLoader(
    dataset, batch_size=1, shuffle=True, collate_fn=collate
)  # TODO multiple workers

In [None]:
# Create a model for learning the jaw centre
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        # Run the Module constructor to make
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm3d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv3d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm3d(out_channels),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [None]:
"""
Define the model architecture

"""


class ResNet3D(nn.Module):
    def __init__(self):
        super(ResNet3D, self).__init__()

        self.conv1 = nn.Conv3d(
            in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1
        )
        self.bn1 = nn.BatchNorm3d(64)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        # Residual blocks
        self.resblock1 = ResidualBlock(64, 128)
        self.resblock2 = ResidualBlock(128, 256)
        self.resblock3 = ResidualBlock(256, 512)

        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(0.5)

        self.fc1 = nn.Linear(256 * 20 * 20 * 20, 512)
        self.fc2 = nn.Linear(256, 3)  # Output layer for 3 coordinates

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.resblock1(x)))
        x = self.pool(F.relu(self.resblock2(x)))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


model = ResNet3D()
print(model)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

ResNet3D(
  (conv1): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (resblock1): ResidualBlock(
    (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Sequential(
      (0): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (resblock2): ResidualBlock(
    (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), pad



In [None]:
torch.cuda.is_available()

True

In [None]:
"""
Train and dump the model

"""
device = "cuda"

model.to(device)

loss = []
epoch = []
n_epochs = 5
for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0.0
    for image, location in dataloader:
        image, location = image.to(device), location.to(device)

        optimizer.zero_grad()
        outputs = model(image)
        loss = criterion(outputs, location)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= len(dataloader)
    loss.append(epoch_loss)

RuntimeError: NVML_SUCCESS == DriverAPI::get()->nvmlInit_v2_() INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":813, please report a bug to PyTorch. 

In [None]:
"""
Plot the training loss

"""
import matplotlib.pyplot as plt
plt.plot(range(1, n_epochs + 1), loss)

ResNet3D(
  (conv1): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=2048000, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=3, bias=True)
)


In [None]:
# Make a diagram of the architecture
from torchviz import make_dot

dummy_input = torch.randn(1, 1, 160, 160, 160)
dummy_output = model(dummy_input)

graph = make_dot(dummy_output, params=dict(model.named_parameters()))

graph.render("model", format="png")

'model.png'

In [None]:
# Convert some arrays into numpy arrays
import pathlib
from multiprocessing import Pool

array_dir = pathlib.Path("arrays/")
array_dir.mkdir(exist_ok=True)

# Choose some arrays to save
array_ns = [44, 68, 414]

In [None]:
def save_array(n: int) -> None:
    if not (array_dir / f"{n}.npy").exists():
        array = tifs2array(n2dir(n), progress=False)
        np.save(array_dir / f"{n}.npy", array)


# Will need to change this as array_ns gets bigger
with Pool(processes=len(array_ns)) as pool:
    pool.map(
        save_array,
        array_ns,
    )

In [None]:
# Check them
import numpy as np

for n in array_ns:
    arr = np.load(array_dir / f"{n}.npy")
    roi = get_roi(n)
    cropped_array = crop_image(arr, roi)
    plot_slices(cropped_array)

NameError: name 'crop_image' is not defined

In [None]:
# Reshape the arrays into a format that the model can use
train_arrs = [np.expand_dims(np.load(array_dir / f"{n}.npy"), axis=0) for n in array_ns]
print(train_arrs[0].shape)

(1, 2000, 596, 634)


In [None]:
# Get the jaw centres from the metadata
centres = [get_roi(n) for n in array_ns]
centres

NameError: name 'array_ns' is not defined

In [None]:
# Parse the jaw centres to get the right numbers out

In [None]:
# Train the model on a few jaws, look at training loss etc
