# Tests for the Models module

## Libraries

In [None]:
import os, sys

# Put the parent of the current working directory first on the import path so
# the project's `src` package is importable below. This must run before the
# `src` import. Presumably the notebook is launched from a directory one level
# below the repository root — confirm if the notebook is moved.
sys.path.insert(0, os.pardir)

import torch
from torch import nn
import numpy as np

from src.models.unet3d import UNet3D

## Model

### Try Half precision

In [None]:
# Smoke test: can a Conv3d layer run in float16 on the default device?
# Any exception (possibly an operator lacking a Half implementation) is
# printed rather than raised, so the notebook keeps running.
try:
    c = nn.Conv3d(1, 32, 3).half()
    inputs = torch.randn(8, 1, 8, 256, 256).half()
    outputs = c(inputs)
    display(inputs.shape, outputs.shape)
except Exception as e:
    print(e)

### Try on GPU

In [None]:
# Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Run the same Conv3d smoke test on the selected device; failures are printed.
try:
    c = nn.Conv3d(1, 32, 3).to(device=device)
    inputs = torch.randn(8, 1, 8, 256, 256).to(device=device)
    outputs = c(inputs)
    display(inputs.shape, outputs.shape)
except Exception as e:
    print(e)

### From 3D Images To 2D Mask

#### Max Pooling 3D

In [None]:
# Non-cubic pooling window (8, 1, 1): with stride defaulting to the kernel size,
# the depth axis of length 8 collapses to 1 while height and width are kept,
# i.e. (8, 1, 8, 256, 256) -> (8, 1, 1, 256, 256).
m = nn.MaxPool3d(kernel_size=(8, 1, 1))
inputs = torch.randn(8, 1, 8, 256, 256)
outputs = m(inputs)

display(inputs.shape, outputs.shape)

#### Linear and Flatten

In [None]:
# nn.Linear(20, 30) maps the last dimension from 20 to 30 features:
# (128, 20) -> (128, 30).
# The sample tensor is named `x` rather than `input` so the builtin `input()`
# is not shadowed for the rest of the kernel session.
m = nn.Linear(20, 30)
x = torch.randn(128, 20)
output = m(x)
display(x.shape)
display(output.shape)

In [None]:
# nn.Flatten demo. The sample tensor is named `x` rather than `input` so the
# builtin `input()` is not shadowed.
x = torch.randn(32, 1, 5, 5)
display(x.shape)
# With default parameters (start_dim=1, end_dim=-1): (32, 1, 5, 5) -> (32, 25)
m = nn.Flatten()
output = m(x)
display(output.shape)
# With non-default parameters (dims 0..2 merged): (32, 1, 5, 5) -> (160, 5)
m = nn.Flatten(0, 2)
output = m(x)
display(output.shape)

In [None]:
# Random sample batch in Conv3d layout (N, C, D, H, W) used by the UNet3D cell.
inputs = torch.randn((8, 1, 8, 256, 256))

### UNet3D

In [None]:
# Build the 3-D U-Net and check the output shape on the sample batch `inputs`.
# list_channels presumably gives the channel width per level, starting with the
# single input channel — confirm against src.models.unet3d.
# The instance is named `model` so the imported `UNet3D` class is not
# overwritten. Rebinding `UNet3D` to an instance made this cell fail on re-run.
list_channels = [1, 32, 64, 128]
model = UNet3D(list_channels, depth=8)
model(inputs).shape

### H-DenseUNet

# PyTorch Lightning

## Metrics

In [None]:
from src.data.make_dataset import CustomDataset
from src.models.metrics import F05Score
from constant import TRAIN_FRAGMENTS
from torch.utils.data import DataLoader
from src.utils import get_device

dataset = CustomDataset(TRAIN_FRAGMENTS)
dataloader = DataLoader(dataset=dataset, batch_size=16)
# TODO: replace this placeholder ("A compléter" is French for "to be completed")
# with the real image sizes F05Score expects.
image_sizes = ["A compléter"]
metric = F05Score(image_sizes)
metric_noise = F05Score(image_sizes)
device = get_device()

for inputs, masks, coords, indexes in dataloader:
    # Predictions identical to the targets: this metric should give a score of 1.
    outputs = masks
    metric.update(outputs, masks, coords, indexes)

    # Add random {0, 1} noise to the targets. The noise is created on
    # `masks.device` so the addition never mixes devices. Previously it was moved
    # to `device` while `masks` stayed wherever the DataLoader put it, which
    # fails on CUDA/MPS. torch.randint also replaces the deprecated
    # np.random.random_integers, with the same inclusive {0, 1} int64 values.
    noise = torch.randint(0, 2, masks.shape, device=masks.device)
    outputs = masks + noise
    # Expected to give a score of roughly 0.5.
    metric_noise.update(outputs, masks, coords, indexes)

print('Perfect F05Score:', metric.compute())
print('Noisy F05Score:', metric_noise.compute())

## Losses

In [None]:
from src.models.losses import CombinedLoss