Find the jaw ROI in zebrafish uCT scans using a neural network
====

In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
"""
Directory path for RDSF fish data

"""
import os

# Where I've mounted it
rdsf_dir = os.path.expanduser("~/zebrafish_rdsf/")

# Location of the scans, relative to the mount point
data_dir = os.path.join(rdsf_dir, "DATABASE/uCT/Wahab_clean_dataset/low_res_clean_v3/")

print(data_dir)

# Fail early, with an actionable message, if the share is not mounted
assert os.path.isdir(data_dir), f"{data_dir} not found - is the RDSF mounted at {rdsf_dir}?"

/home/mh19137/zebrafish_rdsf/DATABASE/uCT/Wahab_clean_dataset/low_res_clean_v3/


In [4]:
"""
Read a subset of the images

"""

import numpy as np
from tqdm import tqdm

from dev import image_io

# Scan numbers to read
good_imgs = [41, 42, 44, 47, 49, 61, 62]


def _load_scan(n):
    """Return scan `n` from its cached .npy file if present, else from the TIFF stack."""
    cached = f"data_cache/{n}.npy"
    if os.path.exists(cached):
        return np.load(cached)
    return image_io.read_tiffstack(n)


images = [_load_scan(n) for n in tqdm(good_imgs)]

100%|██████████| 7/7 [00:19<00:00,  2.79s/it]


In [6]:
"""
Cache them because it took ages to read

"""

# Create the cache directory; exist_ok avoids the check-then-create race and
# keeps the cell safe to re-run
os.makedirs("data_cache", exist_ok=True)

# Write each image to the cache once; files already present are left untouched
for i, img in zip(good_imgs, images):
    cache_path = f"data_cache/{i}.npy"
    if not os.path.exists(cache_path):
        np.save(cache_path, img)

In [7]:
"""
Find the jaw location in these images from the metadata

"""

from dev import metadata

# Load the mastersheet and keep only the rows for the scans we read
mastersheet = metadata.mastersheet()
selected_rows = mastersheet["old_n"].isin(good_imgs)
mastersheet = mastersheet[selected_rows]
assert len(good_imgs) == len(mastersheet)

# Map scan number -> parsed jaw centre
# (the original note gives the tuple order as [Z X Y])
jaw_locs = {
    scan_n: image_io.parse_roi(raw_centre)
    for scan_n, raw_centre in zip(mastersheet["old_n"], mastersheet["jaw_center"])
}
jaw_locs

{41: (1739, 296, 308),
 42: (1694, 330, 470),
 44: (1759, 298, 376),
 47: (1645, 466, 157),
 49: (1561, 399, 482),
 61: (1459, 309, 424),
 62: (1495, 354, 341)}

In [8]:
"""
Convert them to the right format for the model

"""

import torch

# Normalise, convert to tensor and reshape
# NOTE(review): the collate error recorded in the training cell reports items
# of size [1, 2000, 644, 644], so each tensor here is 4-D - presumably
# (channel, Z, X, Y); confirm against image_io.img2pytorch
tensors = [image_io.img2pytorch(img) for img in tqdm(images)]

100%|██████████| 7/7 [00:11<00:00,  1.65s/it]


In [10]:
# Convert to a dataset
from torch.utils.data import Dataset


class JawLocationDataset(Dataset):
    """
    Pairs each image tensor with its jaw-centre location.

    :param tensors: one tensor per image
    :param locations: one (int, int, int) location per image, in the same
        order as `tensors`; stored as a float32 tensor with one row per image
    :raises AssertionError: if the two lists differ in length

    """

    def __init__(
        self,
        tensors: list[torch.Tensor],
        locations: list[tuple[int, int, int]],
    ):
        # Validate before doing any conversion work
        assert len(tensors) == len(locations), (
            f"Got {len(tensors)} tensors but {len(locations)} locations"
        )

        self.tensors = tensors
        self.locations = torch.tensor(locations, dtype=torch.float32)

    def __len__(self):
        """
        Number of (image, location) pairs

        """
        return len(self.tensors)

    def __getitem__(self, idx):
        """
        Returns a tuple of (image, location)

        """
        return self.tensors[idx], self.locations[idx]

In [18]:
from torch.utils.data import DataLoader

# NOTE(review): with batch_size=4 the items of a batch are stacked into one
# tensor, which needs them all to be the same size. The training cell's
# recorded output fails with "stack expects each tensor to be equal size"
# ([1, 2000, 644, 644] vs [1, 2000, 634, 634]), so the volumes must be cropped
# or resized to a common shape (or a custom collate_fn supplied) first.
dataset = JawLocationDataset(tensors, [jaw_locs[n] for n in good_imgs])
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [15]:
"""
Define the model architecture

"""

# Create a model for learning the jaw centre
import torch.nn as nn
import torch.nn.functional as F


class ResNet3D(nn.Module):
    """
    Small 3D CNN that regresses three coordinates from a single-channel volume.

    Three conv -> batch-norm -> ReLU -> max-pool stages (64, 128 and 256
    channels) are followed by two fully connected layers. Despite the name
    there are no residual (skip) connections.

    :param input_shape: spatial size of the input volumes. Every conv keeps
        the spatial size (kernel 3, stride 1, padding 1) and each of the three
        pools halves it, rounding down, so fc1 receives
        256 * (d0 // 8) * (d1 // 8) * (d2 // 8) features. The default
        reproduces the previously hard-coded 256 * 20 * 20 * 20.
    :raises ValueError: if input_shape is not three sizes that are each >= 8

    """

    def __init__(self, input_shape: tuple[int, int, int] = (160, 160, 160)):
        super().__init__()

        # Anything smaller than 8 would be pooled away to nothing
        if len(input_shape) != 3 or any(d < 8 for d in input_shape):
            raise ValueError(
                f"input_shape must be three sizes of at least 8, got {input_shape}"
            )
        self.input_shape = tuple(input_shape)

        # Convolutional layers
        self.conv1 = nn.Conv3d(
            in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1
        )
        self.bn1 = nn.BatchNorm3d(64)

        self.conv2 = nn.Conv3d(
            in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1
        )
        self.bn2 = nn.BatchNorm3d(128)

        self.conv3 = nn.Conv3d(
            in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1
        )
        self.bn3 = nn.BatchNorm3d(256)

        # Shared by all three stages; halves each spatial dimension
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        self.flatten = nn.Flatten()

        self.dropout = nn.Dropout(0.5)

        # Spatial size left after the three poolings
        pooled_d0, pooled_d1, pooled_d2 = (d // 8 for d in self.input_shape)

        # Fully connected layers
        self.fc1 = nn.Linear(256 * pooled_d0 * pooled_d1 * pooled_d2, 256)
        self.fc2 = nn.Linear(256, 3)  # Output layer for 3 coordinates

    def forward(self, x):
        """
        Map a batch of volumes to a batch of 3 coordinates.

        :param x: tensor of shape (batch, 1, *input_shape)
        :returns: tensor of shape (batch, 3)

        """
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


# Instantiate the network and show its layers
# NOTE: fc1 alone holds 2,048,000 x 256 weights (see the printed repr), i.e.
# about 2 GB if stored as float32
model = ResNet3D()
print(model)

# Define loss function and optimizer
# Mean squared error between predicted and true coordinates; Adam is used
# with its default hyperparameters
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

ResNet3D(
  (conv1): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=2048000, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=3, bias=True)
)


In [16]:
# NOTE(review): this result is only displayed - none of the visible cells move
# the model or the batches to a device
torch.cuda.is_available()

True

In [19]:
"""
Train and dump the model

"""

# Mean training loss of each epoch, in order; read by the plotting cell
loss = []
n_epochs = 5
for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0.0
    for data, targets in dataloader:
        optimizer.zero_grad()
        outputs = model(data)
        # Deliberately not called `loss`: rebinding that name to this tensor
        # would replace the list above and make `loss.append` fail
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()
    # Average over the number of batches in the epoch
    epoch_loss /= len(dataloader)
    loss.append(epoch_loss)

RuntimeError: stack expects each tensor to be equal size, but got [1, 2000, 644, 644] at entry 0 and [1, 2000, 634, 634] at entry 1

In [None]:
"""
Plot the training loss

"""
import matplotlib.pyplot as plt

# `loss` is expected to hold one value per epoch; epochs are numbered from 1
fig, ax = plt.subplots()
ax.plot(range(1, n_epochs + 1), loss)
ax.set(title="Training loss", xlabel="Epoch", ylabel="Mean MSE loss");

ResNet3D(
  (conv1): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=2048000, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=3, bias=True)
)


In [125]:
# Make a diagram of the architecture
from torchviz import make_dot

# Random batch of one single-channel 160 x 160 x 160 volume; three 2x poolings
# reduce 160 to 20, matching fc1's 256 * 20 * 20 * 20 inputs
dummy_input = torch.randn(1, 1, 160, 160, 160)
dummy_output = model(dummy_input)

# Build a graph from the output, passing the model's named parameters
graph = make_dot(dummy_output, params=dict(model.named_parameters()))

# Render the graph to a PNG; the recorded output shows the result is 'model.png'
graph.render("model", format="png")

'model.png'

In [11]:
# Convert some arrays into numpy arrays
import pathlib
from multiprocessing import Pool

# Directory that holds the saved .npy arrays; created if it does not exist
array_dir = pathlib.Path("arrays/")
array_dir.mkdir(exist_ok=True)

# Choose some arrays to save
array_ns = [44, 68, 414]

In [12]:
def save_array(n: int) -> None:
    """
    Save scan `n` into `array_dir` as `<n>.npy`, unless that file already exists.

    The array is produced by `tifs2array(n2dir(n), progress=False)`.

    """
    out_path = array_dir / f"{n}.npy"
    if out_path.exists():
        return

    np.save(out_path, tifs2array(n2dir(n), progress=False))


# One worker per array, capped at the CPU count so a longer array_ns does not
# oversubscribe the machine; at least one worker so that an empty list does
# not make Pool raise ValueError
n_workers = max(1, min(len(array_ns), os.cpu_count() or 1))
with Pool(processes=n_workers) as pool:
    pool.map(save_array, array_ns)

In [14]:
# Check them
import numpy as np

# For each saved array: load it, look up its ROI, crop to it and plot slices
# NOTE(review): get_roi, crop_image and plot_slices are not imported or defined
# in the cells visible here; the recorded run of this cell stops with
# "NameError: name 'crop_image' is not defined". Import them before relying
# on this check.
for n in array_ns:
    arr = np.load(array_dir / f"{n}.npy")
    roi = get_roi(n)
    cropped_array = crop_image(arr, roi)
    plot_slices(cropped_array)

NameError: name 'crop_image' is not defined

In [148]:
# Reshape the arrays into a format that the model can use
# Load every saved volume and prepend a singleton axis
train_arrs = [np.load(array_dir / f"{n}.npy")[np.newaxis, ...] for n in array_ns]
print(train_arrs[0].shape)

(1, 2000, 596, 634)


In [1]:
# Get the jaw centres from the metadata
# NOTE(review): the recorded output is a NameError for array_ns - this cell
# has execution count 1, so it was run before the cell that defines array_ns.
# get_roi is also not defined in the cells visible here.
centres = [get_roi(n) for n in array_ns]
centres

NameError: name 'array_ns' is not defined

In [None]:
# TODO: parse the jaw centres into numeric coordinates (cf. image_io.parse_roi above)

In [None]:
# TODO: train the model on a few jaws and inspect the training loss
