### 1. Imports

In [None]:
import copy
import math

import numpy as np
import matplotlib.pyplot as plt
import cv2

import torch
from torch import Generator
from torch.utils.data import DataLoader, Subset, random_split

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, DeviceStatsMonitor, EarlyStopping, RichProgressBar
from lightning.pytorch.profilers import PyTorchProfiler
from lightning.pytorch.loggers import TensorBoardLogger

from skimage.morphology import skeletonize

from data import DanforthDataset
from models import TrainingModel, UNet

### 2. Configuration

In [None]:
# Allow reduced-precision internal matmul arithmetic for float32 tensors (faster on supporting GPUs).
torch.set_float32_matmul_precision('medium')
# Seed the RNGs Lightning manages so runs are reproducible.
L.seed_everything(0)

# Only used for reporting here; Lightning chooses the accelerator itself.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Run name used for both the checkpoint directory and the log directories.
model_directory = "danforth-raw-unet"

### 3. Definitions

#### Dataset

In [None]:
dataset = DanforthDataset('data/danforth/images', 'data/danforth/masks')
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [0.80, 0.15, 0.05], generator=Generator().manual_seed(0)
)

# random_split returns Subsets that all wrap the *same* `dataset` object, so
# clearing the transform on that shared object would also disable it for the
# training and validation splits. Instead, point the test split at a shallow
# copy whose transform is disabled, reusing the same indices so the split
# membership is unchanged.
untransformed_dataset = copy.copy(dataset)
untransformed_dataset.transform = None
test_dataset = Subset(untransformed_dataset, test_dataset.indices)

##### Display data

In [None]:
# Show `num_images` random (image, mask) pairs: one row per pair, image left, mask right.
figure = plt.figure(figsize=(8, 8))
num_images = 4
for row in range(num_images):
    sample_index = torch.randint(len(dataset), size=(1,)).item()
    image, mask = dataset[sample_index].values()

    # (title, array to draw, colormap) for each column of this row.
    panels = (
        ('Image', image.permute(1, 2, 0).squeeze(), None),
        ('Mask', mask.squeeze(), 'gray'),
    )
    for column, (caption, picture, colormap) in enumerate(panels, start=1):
        figure.add_subplot(num_images, 2, 2 * row + column)
        plt.title(caption)
        plt.imshow(picture, cmap=colormap)

plt.show()

#### Model

In [None]:
# TrainingModel is given the UNet class itself rather than an instance.
# NOTE(review): presumably it instantiates the network internally — confirm in models.py.
network_class = UNet
model = TrainingModel(network_class)

#### Trainer

In [None]:
# Keep the five checkpoints with the lowest val_loss, plus a rolling "last" checkpoint.
checkpoint_callback = ModelCheckpoint(
    dirpath=f'checkpoints/{model_directory}',
    monitor='val_loss',
    mode='min',
    save_top_k=5,
    save_last=True,
)

In [None]:
# TensorBoard logs and profiler traces both live under logs/<model_directory>.
run_logger = TensorBoardLogger(save_dir=f'logs/{model_directory}')
# NOTE(review): profiling a run of up to 5000 epochs may produce very large traces.
run_profiler = PyTorchProfiler(dirpath=f'logs/{model_directory}/profiler', filename='perf_logs')

trainer_callbacks = [
    DeviceStatsMonitor(cpu_stats=True),
    # Stop after 100 validation checks without a val_loss improvement.
    EarlyStopping(monitor='val_loss', patience=100, mode='min'),
    checkpoint_callback,
    RichProgressBar(),
]

trainer = L.Trainer(
    max_epochs=5000,
    log_every_n_steps=1,
    precision='bf16-mixed',
    logger=run_logger,
    profiler=run_profiler,
    callbacks=trainer_callbacks,
)

### 4. Train Model

In [None]:
# Options shared by the training and validation loaders; only shuffling differs.
loader_options = dict(batch_size=4, num_workers=2, prefetch_factor=2)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_options)

In [None]:
# Train with periodic validation on the validation split.
model.train()
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

### 5. Test Model

In [None]:
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=1, prefetch_factor=1)
# `test_dataloader.dataset` is a torch Subset, whose __getitem__ just forwards to
# the wrapped dataset; setting `.transform` on the Subset itself would only add an
# attribute that is never read. The test split's transform is configured where
# the split is created, so no assignment is made here.

In [None]:
# Evaluate the 'last' checkpoint on the held-out test split.
# NOTE(review): with EarlyStopping(patience=100) the last checkpoint is usually well
# past the best val_loss; consider ckpt_path='best' if that is the intent.
model.eval()
trainer.test(model, dataloaders=test_dataloader, ckpt_path='last')

##### Display predictions

In [None]:
# One entry per test batch, produced from the 'last' checkpoint.
predictions = trainer.predict(model, dataloaders=test_dataloader, ckpt_path='last')

In [None]:
# Split every batched prediction along dim 0 and flatten into one per-sample list,
# preserving test-set order.
samples = [prediction for batch in predictions for prediction in batch.unbind(0)]

In [None]:
# One row per test sample: input image, ground-truth mask, and model prediction.
figure = plt.figure(figsize=(8, 8))
rows = len(samples)
for row, (sample, pred) in enumerate(zip(test_dataset, samples)):
    image, mask = sample.values()

    # (title, array to draw, colormap); the prediction is cast to float for imshow.
    panels = (
        ('Image', image.permute(1, 2, 0).squeeze(), None),
        ('Mask', mask.squeeze(), 'gray'),
        ('Prediction', pred.float().squeeze(), 'gray'),
    )
    for column, (caption, picture, colormap) in enumerate(panels, start=1):
        figure.add_subplot(rows, 3, row * 3 + column)
        plt.title(caption)
        plt.imshow(picture, cmap=colormap)

plt.show()

### 6. Compute Root Measurements

In [None]:
# Inches per pixel.
# NOTE(review): assumes the 400-pixel frame spans 8 inches — confirm against the imaging setup.
frame_inches, frame_pixels = 8, 400
scaling = frame_inches / frame_pixels

In [None]:
# Which test prediction (position in `samples`) to measure.
sample_index = 1

# Threshold the prediction once and reuse the binary mask for every derived array.
root_mask = (samples[sample_index] > 0.5).squeeze().numpy()
image = root_mask.astype(np.uint8)
skeleton = skeletonize(root_mask).astype(np.uint8)

# Outer contours of the mask's connected regions and of its skeleton.
image_contours, _ = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
skeleton_contours, _ = cv2.findContours(skeleton, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

In [None]:
# Side-by-side view of the skeleton and the outlined mask contours.
figure = plt.figure(figsize=(8, 8))

figure.add_subplot(1, 2, 1)
plt.imshow(skeleton, cmap='gray')
plt.title('Skeleton')

figure.add_subplot(1, 2, 2)
# Size the canvas from the actual mask instead of assuming a 400x400 image.
contoured = np.zeros_like(skeleton, dtype=np.uint8)
# Single-channel canvas, so a scalar colour is used.
cv2.drawContours(contoured, image_contours, -1, 255, 1)
plt.imshow(contoured, cmap='gray')
plt.title('Contours')

plt.show()

#### Root Count

In [None]:
# Every outer contour of the thresholded mask is counted as one root.
root_count = len(image_contours)
print(f'Number of roots: {root_count}')

#### Root Length

In [None]:
# A skeleton is ~1 px wide, so cv2.findContours traces its border out along the
# centreline and back again. Measuring that contour as an open curve counts most
# of the skeleton twice; half of the closed perimeter approximates the centreline
# length instead.
total_length = sum(cv2.arcLength(contour, True) for contour in skeleton_contours) / 2

# Convert pixels to inches.
total_length = total_length * scaling
print(f'Total root length: {total_length:.5f} inches')
if root_count:
    print(f'Average root length: {total_length / root_count:.2f} inches')
else:
    print('Average root length: n/a (no roots detected)')

#### Root Area

In [None]:
# Area = number of foreground pixels in the thresholded prediction x area of one pixel.
pixel_area = scaling ** 2
foreground_pixels = (samples[1] > 0.5).float().sum()
total_area = foreground_pixels * pixel_area
print(f'Total root area: {total_area:.2f} square inches')
print(f'Average root area: {total_area / root_count:.2f} square inches')

#### Root Diameter

In [None]:
# Estimate the local root diameter at each skeleton-contour vertex. The distance
# from a centreline point to the nearest point on the mask's outer boundary is the
# local radius, so the diameter is twice that (converted to inches).
diameters = []

# Dense outer boundary of the mask (`image` is the uint8 mask). CHAIN_APPROX_NONE
# keeps every boundary pixel; CHAIN_APPROX_SIMPLE contours keep only the end points
# of straight runs, which would make nearest-point distances along straight root
# edges far too large.
boundary_contours, _ = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
if boundary_contours:
    # reshape (rather than squeeze) keeps an (N, 2) array even for a single point.
    boundary_points = np.vstack(boundary_contours).reshape(-1, 2)

    for skeleton_contour in skeleton_contours:
        for point in skeleton_contour[:, 0, :]:
            distances = np.linalg.norm(boundary_points - point, axis=1)
            diameters.append(2 * distances.min() * scaling)

if diameters:
    print(f'Average root diameter: {np.mean(diameters):.2f} inches')
else:
    print('Average root diameter: n/a (no roots detected)')

#### Root Volume

In [None]:
# Approximate the roots as a chain of cylinders. Each diameter sample stands for
# (roughly) an equal share of the total centreline length, so
#     volume ~= sum(pi * r_i**2 * total_length / n) = mean(pi * r**2) * total_length.
# Summing pi * r**2 on its own would give square inches, not a volume.
if diameters:
    cross_section_areas = [math.pi * (diameter / 2) ** 2 for diameter in diameters]
    total_volume = sum(cross_section_areas) / len(cross_section_areas) * total_length
else:
    total_volume = 0.0

print(f'Total root volume: {total_volume:.2f} cubic inches')
if root_count:
    print(f'Average root volume: {total_volume / root_count:.2f} cubic inches')
else:
    print('Average root volume: n/a (no roots detected)')