In [None]:
import os
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
from sklearn.metrics import accuracy_score

# GPU optimizations
# NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA caching
# allocator initializes, so this presumably only takes effect if no CUDA memory
# has been allocated earlier in the session — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Let cuDNN benchmark and cache the fastest convolution algorithm per input shape.
cudnn.benchmark = True
cudnn.enabled = True

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the model and load its trained weights.
# NOTE(review): UNet3D_Attention is not defined in this cell; it must be
# provided by an earlier cell or import.
model = UNet3D_Attention().to(device).half()  # Use float16 weights to reduce memory use
# NOTE(review): the model is converted to float16 even when device is the CPU,
# where float16 3-D convolutions may be unsupported — confirm before running
# without a GPU.
model.load_state_dict(torch.load('/content/drive/MyDrive/best_model_fold_5.pt', map_location=device))
# Switch to inference mode (affects layers such as dropout / batch norm).
model.eval()
def dice_coef(pred, target, smooth=1e-6):
    """Return the Dice similarity coefficient between two masks.

    Computed as ``(2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``.
    ``smooth`` keeps the ratio defined when both masks are empty (the result is
    then 1.0). The element-wise product ``pred * target`` must be valid, so the
    inputs are expected to be NumPy arrays of matching shape.
    """
    overlap = np.sum(pred * target)
    numerator = 2. * overlap + smooth
    denominator = np.sum(pred) + np.sum(target) + smooth
    return numerator / denominator
# Dataset class that loads each volume lazily (on access) and returns it as float16.
class Nii3DDataset(Dataset):
    """Lazily load 3-D NIfTI volumes from a directory for inference.

    Each item is ``(image, filename)``. ``image`` is a float16 tensor of shape
    ``(1, *volume_shape)`` (a leading channel axis is added). Its intensities
    are clipped to the volume's 1st-99th percentile range and rescaled to [0, 1].

    Args:
        image_dir: Directory containing the volumes.
        extensions: Filename suffix (or iterable of suffixes) to include.
            Defaults to ``('.nii',)``; pass ``('.nii', '.nii.gz')`` to also
            pick up gzip-compressed volumes.
    """

    def __init__(self, image_dir, extensions=('.nii',)):
        self.image_dir = image_dir
        suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
        # Sorted so the iteration order (and therefore the output order) is
        # deterministic; os.listdir returns entries in arbitrary order.
        self.images = sorted(
            f for f in os.listdir(self.image_dir) if f.endswith(suffixes)
        )

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _normalize(image, lower=1, upper=99):
        """Clip ``image`` to its [lower, upper] percentiles and rescale to [0, 1].

        A volume whose percentile range is empty (e.g. a constant image) is
        mapped to all zeros instead of the NaN/inf a plain division produces.
        """
        lo, hi = np.percentile(image, (lower, upper))
        clipped = np.clip(image, lo, hi)
        span = hi - lo
        if span <= 0:
            return np.zeros_like(clipped)
        return (clipped - lo) / span

    def __getitem__(self, idx):
        filename = self.images[idx]
        image = nib.load(os.path.join(self.image_dir, filename)).get_fdata()
        image = self._normalize(image)
        # float16 halves the size of the tensor; unsqueeze(0) adds the channel axis.
        return torch.tensor(image, dtype=torch.float16).unsqueeze(0), filename

# Directory containing the test volumes to segment.
test_image_dir = '/content/drive/MyDrive/Algoritmi_An_III/gztest/image/'

test_dataset = Nii3DDataset(test_image_dir)
# One volume per batch to limit memory; pinned host memory for faster
# host-to-device copies; loading happens in the main process (no workers).
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
    persistent_workers=False,
)

# Make sure the predictions output directory exists.
# NOTE(review): nothing in this cell writes to this directory — confirm it is used elsewhere.
os.makedirs('/content/predictions/', exist_ok=True)

def _prepare_ground_truth(ground_truth, target_shape):
    """Binarize a ground-truth array at 0.5 and align it with a prediction of ``target_shape``.

    Trailing singleton axes beyond the prediction's rank are dropped, and axis 0
    is cropped to the prediction's length along that axis.

    Raises:
        ValueError: if the aligned mask still does not have ``target_shape``.
    """
    target_shape = tuple(int(s) for s in target_shape)
    mask = (np.asarray(ground_truth) > 0.5).astype(np.float32)
    # A mask stored as (X, Y, Z, 1) would otherwise broadcast against the 3-D
    # prediction in dice_coef: for cubic volumes it yields a wrong score,
    # otherwise it raises an opaque error.
    while mask.ndim > len(target_shape) and mask.shape[-1] == 1:
        mask = mask[..., 0]
    mask = mask[:target_shape[0]]
    if mask.shape != target_shape:
        raise ValueError(
            f"Ground truth shape {mask.shape} does not match prediction shape {target_shape}"
        )
    return mask


def _predict_volume(image, z_depth):
    """Run the global ``model`` over ``image`` chunk by chunk and return its raw output.

    ``image`` is indexed as (batch, channel, D, H, W) and split into chunks of at
    most ``z_depth`` along D, so only one chunk is on ``device`` at a time.
    Returns a float32 NumPy array of shape (D, H, W).
    """
    D, H, W = image.shape[2:]
    prediction_full = torch.zeros((D, H, W), dtype=torch.float32, device='cpu')  # full result kept on CPU
    # Feed inputs in the same dtype as the weights (float16 when the model was
    # .half()-ed), so this also works with a float32 model.
    param_dtype = next(model.parameters()).dtype
    # Mixed precision only on the GPU (torch.cuda.amp.autocast is deprecated in
    # recent PyTorch in favour of torch.autocast).
    use_amp = device.type == 'cuda'

    with torch.no_grad():
        for z in range(0, D, z_depth):
            chunk = image[:, :, z:z + z_depth].to(device, dtype=param_dtype, non_blocking=True)
            # The last chunk may be shorter than z_depth.
            # NOTE(review): it must still be a valid input size for the network
            # (e.g. divisible by its pooling factor) — confirm.
            depth = chunk.shape[2]
            with torch.autocast(device_type=device.type, enabled=use_amp):
                output = model(chunk)
            # Reshape instead of squeeze(): squeeze() would also drop the depth
            # axis of a 1-slice chunk (or any spatial axis of size 1).
            # Assumes a single output channel per voxel.
            prediction_full[z:z + depth] = output.float().cpu().reshape(depth, H, W)

    return prediction_full.numpy()


def _show_slices(prediction, ground_truth):
    """Show an interactive slider comparing prediction and ground truth slice by slice along axis 0."""
    def plot_slice(slice_idx):
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(prediction[slice_idx, :, :], cmap='gray')
        plt.title(f"Predicted Mask - Slice {slice_idx}")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(ground_truth[slice_idx, :, :], cmap='gray')
        plt.title(f"Ground Truth - Slice {slice_idx}")
        plt.axis('off')
        plt.show()

    interact(plot_slice, slice_idx=IntSlider(min=0, max=prediction.shape[0] - 1, step=1, value=0))


def process_image(image, ground_truth_path, z_depth=64, threshold=0.1):
    """Segment one volume, score it against its ground-truth mask and show a slice viewer.

    Prints the prediction / ground-truth shapes, lesion voxel counts, voxel
    accuracy and Dice score, then displays an interactive slice comparison.

    Args:
        image: Input tensor indexed as (batch, channel, D, H, W), e.g. one batch
            from ``test_loader`` (batch_size=1).
        ground_truth_path: Path to the NIfTI ground-truth mask; voxels > 0.5
            count as foreground.
        z_depth: Maximum number of slices along D sent to the model per forward
            pass (bounds device memory use).
        threshold: Model outputs strictly above this value are counted as lesion.

    Raises:
        ValueError: if ``z_depth`` < 1 or the ground truth cannot be aligned with
            the prediction.
    """
    if z_depth < 1:
        raise ValueError(f"z_depth must be a positive integer, got {z_depth}")

    # Load and validate the ground truth first so a missing or misaligned mask
    # fails before the (expensive) inference runs.
    ground_truth = _prepare_ground_truth(
        nib.load(ground_truth_path).get_fdata(), tuple(image.shape[2:])
    )

    # NOTE(review): whether the model outputs probabilities or logits is not
    # visible here; that decides what `threshold` means — confirm.
    prediction = (_predict_volume(image, z_depth) > threshold).astype(np.float32)

    print("Prediction shape:", prediction.shape)
    print("Ground truth shape:", ground_truth.shape)
    pred_voxels = np.sum(prediction == 1)
    gt_voxels = np.sum(ground_truth == 1)

    print(f"Voxeli leziune în prediction: {pred_voxels}")
    print(f"Voxeli leziune în ground truth: {gt_voxels}")

    accuracy = accuracy_score(ground_truth.flatten(), prediction.flatten())
    print(f"Accuracy: {accuracy:.4f}")
    dice = dice_coef(prediction, ground_truth)
    # Report the threshold that was actually applied.
    print(f"Dice score (t={threshold}): {dice:.4f}")

    _show_slices(prediction, ground_truth)

    torch.cuda.empty_cache()  # release cached GPU blocks left over from inference



# Loop over the test volumes: segment and evaluate each one.
# NOTE(review): every volume is compared against the same ground-truth mask
# (patient 016 below). This is only correct if test_image_dir holds exactly that
# patient's image — confirm, or map each filename to its own mask.
ground_truth_path = "/content/drive/MyDrive/Algoritmi_An_III/gztest/016/mask_time02_registered_to_time01.nii.gz"
for image, filename in test_loader:
    # With batch_size=1, the collated filename is a one-element batch; [0] is the name.
    print(f"Processing {filename[0]}")
    process_image(image, ground_truth_path)

    del image  # Drop the reference to this volume before loading the next one
    torch.cuda.empty_cache()

print("Processing Done")