In [None]:
from pathlib import Path
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# NOTE(review): absolute, machine-specific path -- consider making it configurable.
data_root = Path("/home/haim/code/tumors/data/")
# Derive the volume path from data_root instead of repeating the absolute prefix.
filepath = data_root / "volumes" / "volume-0.nii"

img = nib.load(filepath)
# get_fdata() returns the voxel array as a floating-point NumPy array.
img_data = img.get_fdata()
img_data.shape

In [None]:
idx = 73
# Plot from the NumPy volume: on a fresh kernel `img` is still the nibabel image
# here (it has no `.to` method); it is only rebound to a CUDA tensor in a later cell.
plt.imshow(img_data[:, :, idx].T, cmap="gray")
plt.title(f"raw volume, slice {idx}")
plt.axis("off")
plt.show()

In [None]:
idx = 834
# NOTE(review): `img2` is created by the resize cell further down -- run that first.
# Slice on the device and transfer only the 2-D plane; `img2.to("cpu")` copied the
# whole resized volume to host memory just to display one slice.
plt.imshow(img2[:, :, idx].T.cpu(), cmap="gray")
plt.title(f"resized volume, slice {idx}")
plt.axis("off")
plt.show()

In [None]:
import torch
import torch.nn.functional as F
# Copy the voxel array to the GPU as float32; note this rebinds `img`
# (until now the nibabel image object) to a torch tensor.
img = torch.tensor(img_data, device="cuda", dtype=torch.float32)

def resize_3d(image, target_depth, mode="trilinear"):
    """Resample a 3-D volume along its last axis, keeping the first two axes.

    Args:
        image: 3-D tensor laid out as ``(height, width, depth)``. Non-floating
            tensors are cast to float32 first, because trilinear interpolation
            in ``F.interpolate`` requires floating-point input.
        target_depth: size of the last axis in the result.
        mode: ``"trilinear"`` (default) for intensity volumes; ``"nearest"``
            is the usual choice for label masks, where blending neighbouring
            labels would produce invalid label values.

    Returns:
        Tensor of shape ``(height, width, target_depth)`` on ``image``'s device.

    Raises:
        ValueError: if ``image`` is not 3-D.
    """
    if image.dim() != 3:
        raise ValueError(
            f"resize_3d expects a 3-D (height, width, depth) tensor, got shape {tuple(image.shape)}"
        )
    if not image.is_floating_point():
        image = image.float()
    height, width, _ = image.shape
    # align_corners may only be passed for the linear-family modes; F.interpolate
    # raises if it is combined with "nearest" or "area".
    align_corners = False if mode == "trilinear" else None
    # F.interpolate works on (batch, channel, *spatial): add both axes, then drop them.
    resized = F.interpolate(
        image[None, None],
        size=(height, width, target_depth),
        mode=mode,
        align_corners=align_corners,
    )
    return resized[0, 0]


# Depth (last axis) the volume is resampled to; the random-input smoke test
# later in the notebook uses the same depth.
TARGET_DEPTH = 864
img2 = resize_3d(img, target_depth=TARGET_DEPTH)

In [None]:
from monai.networks.nets import UNet

# Exploratory 3-D UNet: 1 input channel -> 1 output channel, three feature
# levels (16/32/64) joined by two stride-2 down-sampling steps.
net = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64),  # feature width per level
    strides=(2, 2),         # one stride per down-sampling step
    kernel_size=3,
    up_kernel_size=3,
    dropout=0.1,
)
net = net.to("cuda")

In [None]:
def count_parameters(model, trainable_only=True):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: anything exposing ``.parameters()``, e.g. a ``torch.nn.Module``.
        trainable_only: if True (default), count only parameters with
            ``requires_grad=True``; if False, count every parameter.

    Returns:
        Total element count as an int (0 for a model with no parameters).
    """
    return sum(
        p.numel() for p in model.parameters() if p.requires_grad or not trainable_only
    )

# Trainable parameters of `net`, expressed in millions.
count_parameters(net) / 1e6

In [None]:
# Shape after adding a leading batch axis. `img2` is derived from the CUDA
# tensor `img`, so it is already on the GPU and no `.to("cuda")` is needed.
img2.unsqueeze(0).shape

In [None]:
import torch
# Smoke test: push a random volume through `net` without building an autograd
# graph, only to check the output shape.
with torch.no_grad():
    # Allocate directly on the model's device: `torch.randn(...).to("cuda")`
    # first materialised the ~0.9 GB float32 tensor in host memory, then copied it.
    input_tensor = torch.randn(
        1, 1, 512, 512, 864, device=next(net.parameters()).device
    )  # Adjust input size if needed
    output_tensor = net(input_tensor)
print("Output tensor shape:", output_tensor.shape)

In [None]:
# Inference only: without no_grad the forward pass keeps every intermediate
# activation of the full-resolution volume alive for a backward pass that never runs.
with torch.no_grad():
    # Add batch and channel axes: (H, W, D) -> (1, 1, H, W, D).
    result = net(img2.to("cuda")[None, None])

In [None]:
# Display the raw network output tensor (torch prints a summarised view of large tensors).
result

In [None]:
# Output shape of the forward pass on the resized volume.
result.size()

In [2]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import mlflow.pytorch
from monai.networks.nets import UNet
from torchmetrics import Accuracy, JaccardIndex, MetricCollection
from torchmetrics.detection.iou import IntersectionOverUnion
# from mlflow import log_metric


class SegmentationModel(pl.LightningModule):
    """Lightning module wrapping a 3-D MONAI ``UNet`` for volumetric segmentation.

    Loss, predictions and metrics depend on the number of output channels:

    * ``out_channels == 1``: one logit per voxel, trained as a binary segmenter
      with ``binary_cross_entropy_with_logits``. A voxel is predicted as
      foreground when its logit is > 0 (equivalently sigmoid > 0.5).
    * ``out_channels >= 2``: one logit per class, trained with
      ``cross_entropy``. The prediction is the ``argmax`` over channels.

    ``save_hyperparameters()`` stores every constructor argument in
    ``self.hparams``, and the network is built from those same arguments.

    Args:
        in_channels: channels of the input volume.
        out_channels: output channels (1 = binary, >= 2 = multiclass).
        channels: feature width of each UNet level.
        strides: one down-sampling stride per transition between levels
            (see the MONAI ``UNet`` documentation).
        lr: Adam learning rate.
        kernel_size: convolution kernel size, forwarded to ``UNet``.
        up_kernel_size: up-sampling kernel size, forwarded to ``UNet``.
        dropout: dropout probability, forwarded to ``UNet``.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128),
        strides=(2, 2, 2),
        lr=1e-3,
        kernel_size=3,
        up_kernel_size=3,
        dropout=0.1,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Build the UNet from the constructor arguments so that self.hparams
        # describes the network that is actually trained.
        self.model = UNet(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=out_channels,
            channels=channels,
            strides=strides,
            kernel_size=kernel_size,
            up_kernel_size=up_kernel_size,
            dropout=dropout,
        )
        # torchmetrics classification metrics need an explicit `task`. The
        # detection IntersectionOverUnion scores bounding boxes; the per-voxel
        # IoU is JaccardIndex.
        if out_channels == 1:
            task_kwargs = {"task": "binary"}
        else:
            task_kwargs = {"task": "multiclass", "num_classes": out_channels}
        metrics = MetricCollection(
            {"IoU": JaccardIndex(**task_kwargs), "Accuracy": Accuracy(**task_kwargs)}
        )
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")

    def forward(self, x):
        """Return the raw UNet output for input batch ``x``."""
        return self.model(x)

    def _loss_and_preds(self, y_hat, y):
        """Return ``(loss, preds, target)`` for network output ``y_hat`` and labels ``y``.

        With a single output channel, ``cross_entropy`` is identically zero and
        ``argmax`` over dim 1 is always 0. Binary outputs therefore use
        BCE-with-logits and a logit > 0 threshold instead.
        ``preds`` and ``target`` share ``y``'s shape and are integer labels,
        as the torchmetrics classification metrics expect.
        """
        if y_hat.shape[1] == 1:
            # NOTE(review): accepts y as (N, ...) or (N, 1, ...) -- confirm the
            # mask layout produced by the data loader.
            loss = F.binary_cross_entropy_with_logits(y_hat, y.reshape(y_hat.shape).float())
            preds = (y_hat > 0).long().reshape(y.shape)
        else:
            loss = F.cross_entropy(y_hat, y.long())
            preds = torch.argmax(y_hat, dim=1)
        return loss, preds, y.long()

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and metrics for one batch."""
        x, y = batch  # assumes batches are (volume, label) pairs -- TODO confirm
        y_hat = self(x)
        loss, preds, target = self._loss_and_preds(y_hat, y)
        self.log("train_loss", loss)
        self.train_metrics(preds, target)
        # Lightning computes and resets logged Metric objects at epoch end,
        # so no manual on_*_epoch_end hooks are needed.
        self.log_dict(self.train_metrics, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and metrics for one batch."""
        x, y = batch  # assumes batches are (volume, label) pairs -- TODO confirm
        y_hat = self(x)
        loss, preds, target = self._loss_and_preds(y_hat, y)
        self.log("val_loss", loss, prog_bar=True)
        self.val_metrics(preds, target)
        self.log_dict(self.val_metrics, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the learning rate stored in hparams."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
