In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab-only; prompts for authorization).
drive.mount('/content/drive')

In [None]:
pip install monai

In [None]:
import torch
import numpy as np
import random
from typing import List, Callable
from scipy.ndimage import map_coordinates, gaussian_filter

# MONAI imports
from monai.transforms import (
    RandAdjustContrast,
    RandGaussianNoise,
    RandGaussianSmooth,
    RandScaleIntensity,
    RandShiftIntensity,
    RandGibbsNoise,
    RandKSpaceSpikeNoise,
    RandRicianNoise,
    RandHistogramShift,
    RandCoarseDropout,
    RandCoarseShuffle,
)


class CBCTRandAugment:
    """
    RandAugment implementation for preprocessed CBCT volumes with MONAI augmentations.

    Input: torch.Tensor of shape [1, D, H, W], dtype=float32, values in [0, 1].
    Designed for cubic volumes; every operation also accepts non-cubic shapes.
    Each call picks ``n`` distinct operations at random, applies them in order,
    and clamps the result to [0, 1] after every operation.
    """

    def __init__(self, n: int = 2, m: int = 6, use_monai: bool = True):
        """
        Args:
            n: Number of distinct transformations applied sequentially per call
               (capped at the number of available operations).
            m: Magnitude on a 0-10 scale. Each magnitude-dependent operation
               draws an integer strength uniformly from 1..m and uses
               strength / 10 of its maximum; m < 1 makes that strength 0.
            use_monai: Whether to include MONAI transformations.
        """
        self.n = n
        self.m = m
        self.use_monai = use_monai

        # Custom (NumPy / SciPy / torch) operations.
        self.operations: List[Callable] = [
            self._random_rotation,
            self._random_flip,
            self._random_translation,
            self._random_noise,
            self._random_gamma,
            self._random_contrast,
            self._random_gaussian_blur,
            self._elastic_deformation,
            self._random_zoom,
            self._identity,
        ]

        # MONAI-backed operations, appended after the custom ones.
        if use_monai:
            self._init_monai_transforms()
            self.operations.extend([
                self._monai_adjust_contrast,
                self._monai_gaussian_noise,
                self._monai_gaussian_smooth,
                self._monai_scale_intensity,
                self._monai_shift_intensity,
                self._monai_gibbs_noise,
                self._monai_kspace_spike,
                self._monai_rician_noise,
                self._monai_histogram_shift,
                self._monai_coarse_dropout,
                self._monai_coarse_shuffle,
            ])

    def _init_monai_transforms(self):
        """No-op hook kept for compatibility.

        MONAI transforms are constructed inside each ``_monai_*`` method on
        every call, with parameters scaled by a freshly drawn magnitude.
        """
        pass

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Apply RandAugment to a preprocessed CBCT volume.

        Args:
            image: Tensor of shape [1, D, H, W], dtype float32, range [0, 1].
        Returns:
            Augmented tensor of the same shape, clamped to [0, 1].
        Raises:
            AssertionError: If the input is not a 4D single-channel float32 tensor.
        """
        assert len(image.shape) == 4, f"Expected 4D tensor, got {len(image.shape)}D"
        assert image.shape[0] == 1, f"Expected channel dimension = 1, got {image.shape[0]}"
        assert image.dtype == torch.float32, f"Expected float32, got {image.dtype}"

        # Sampling without replacement cannot pick more ops than exist.
        num_ops = min(self.n, len(self.operations))
        # Draw indices rather than handing the bound methods to NumPy, which
        # would otherwise wrap them in an object array.
        for op_index in np.random.choice(len(self.operations), num_ops, replace=False):
            image = self.operations[op_index](image)
            # Keep values in [0, 1] between operations.
            image = torch.clamp(image, 0.0, 1.0)

        return image

    def _scale_magnitude(self, max_val: float) -> float:
        """Return ``k / 10 * max_val`` with k drawn uniformly from 1..m.

        Returns 0.0 when m < 1 (previously ``random.randint(1, 0)`` raised).
        """
        if self.m < 1:
            return 0.0
        random_multiplier = random.randint(1, self.m)
        return (random_multiplier / 10.0) * max_val

    def _get_prob(self) -> float:
        """Application probability passed to MONAI transforms: m / 10, capped at 0.9."""
        return min(0.9, self.m / 10.0)

    # ==================== Helpers ====================

    @staticmethod
    def _to_volume(image: torch.Tensor) -> np.ndarray:
        """Drop the channel dim of a [1, D, H, W] tensor and return it as a CPU array."""
        return image.detach().cpu().squeeze(0).numpy()

    @staticmethod
    def _to_image(volume: np.ndarray, like: torch.Tensor) -> torch.Tensor:
        """Wrap a (D, H, W) array as a float32 [1, D, H, W] tensor on ``like``'s device."""
        return torch.tensor(volume, dtype=torch.float32, device=like.device).unsqueeze(0)

    @staticmethod
    def _hole_size(fraction: float, image: torch.Tensor) -> tuple:
        """Per-axis region size as ``fraction`` of each spatial dim, at least 1 voxel.

        MONAI's coarse transforms replace non-positive ``spatial_size``
        components with the full image size, so a size that truncated to 0
        would make a single region cover the entire volume.
        """
        return tuple(max(1, int(fraction * size)) for size in image.shape[1:])

    @staticmethod
    def _random_offsets(outer_shape, inner_shape) -> list:
        """Random per-axis start index that places ``inner_shape`` inside ``outer_shape``."""
        return [
            np.random.randint(0, outer - inner + 1) if outer > inner else 0
            for outer, inner in zip(outer_shape, inner_shape)
        ]

    # ==================== MONAI-based Transformations ====================

    def _monai_adjust_contrast(self, image: torch.Tensor) -> torch.Tensor:
        """MONAI gamma contrast adjustment with gamma range (1 - g, 1 + g), g <= 0.5."""
        gamma_range = self._scale_magnitude(0.5)
        gamma = (1.0 - gamma_range, 1.0 + gamma_range)

        transform = RandAdjustContrast(prob=self._get_prob(), gamma=gamma)
        return transform(image)

    def _monai_gaussian_noise(self, image: torch.Tensor) -> torch.Tensor:
        """MONAI Gaussian noise with magnitude-scaled standard deviation (<= 0.15)."""
        max_std = 0.15
        std = self._scale_magnitude(max_std)

        transform = RandGaussianNoise(prob=self._get_prob(), mean=0.0, std=std)
        return transform(image)

    def _monai_gaussian_smooth(self, image: torch.Tensor) -> torch.Tensor:
        """MONAI Gaussian smoothing with per-axis sigma range scaled by magnitude."""
        min_sigma, max_sigma = 0.25, 2.0
        sigma_scale = self._scale_magnitude(1.0)
        sigma_range = (min_sigma * sigma_scale, max_sigma * sigma_scale)

        transform = RandGaussianSmooth(
            prob=self._get_prob(),
            sigma_x=sigma_range,
            sigma_y=sigma_range,
            sigma_z=sigma_range
        )
        return transform(image)

    def _monai_scale_intensity(self, image: torch.Tensor) -> torch.Tensor:
        """MONAI intensity scaling with factors in (1 - f, 1 + f), f <= 0.3."""
        max_factor = 0.3
        factor = self._scale_magnitude(max_factor)
        factors = (1.0 - factor, 1.0 + factor)

        transform = RandScaleIntensity(prob=self._get_prob(), factors=factors)
        return transform(image)

    def _monai_shift_intensity(self, image: torch.Tensor) -> torch.Tensor:
        """MONAI intensity shift with offsets in (-o, o), o <= 0.15."""
        max_offset = 0.15
        offset = self._scale_magnitude(max_offset)

        transform = RandShiftIntensity(prob=self._get_prob(), offsets=(-offset, offset))
        return transform(image)

    def _monai_gibbs_noise(self, image: torch.Tensor) -> torch.Tensor:
        """MONAI Gibbs ringing artifact simulation with alpha in (0, a), a <= 0.8."""
        max_alpha = 0.8
        alpha = self._scale_magnitude(max_alpha)

        transform = RandGibbsNoise(prob=self._get_prob(), alpha=(0.0, alpha))
        return transform(image)

    def _monai_kspace_spike(self, image: torch.Tensor) -> torch.Tensor:
        """MONAI k-space spike noise (MRI artifact simulation)."""
        min_intensity, max_intensity = 0.1, 1.5
        intensity_scale = self._scale_magnitude(1.0)

        transform = RandKSpaceSpikeNoise(
            prob=self._get_prob(),
            intensity_range=(min_intensity * intensity_scale, max_intensity * intensity_scale)
        )
        return transform(image)

    def _monai_rician_noise(self, image: torch.Tensor) -> torch.Tensor:
        """MONAI Rician noise with magnitude-scaled standard deviation (<= 0.15)."""
        max_std = 0.15
        std = self._scale_magnitude(max_std)

        transform = RandRicianNoise(prob=self._get_prob(), mean=0.0, std=std)
        return transform(image)

    def _monai_histogram_shift(self, image: torch.Tensor) -> torch.Tensor:
        """MONAI histogram shift with 3 to 15 control points depending on magnitude."""
        max_num_points = 15
        num_points = max(3, int(self._scale_magnitude(max_num_points)))

        transform = RandHistogramShift(prob=self._get_prob(), num_control_points=num_points)
        return transform(image)

    def _monai_coarse_dropout(self, image: torch.Tensor) -> torch.Tensor:
        """MONAI coarse dropout: fill randomly placed cuboid regions with 0.

        Region count scales with magnitude (up to 8); region edge lengths are a
        magnitude-scaled fraction (up to 15%) of each spatial dimension, never
        less than one voxel.
        """
        max_holes = 8
        num_holes = max(1, int(self._scale_magnitude(max_holes)))

        max_spatial_size = 0.15  # 15% of each dimension
        hole_size = self._hole_size(self._scale_magnitude(max_spatial_size), image)

        transform = RandCoarseDropout(
            prob=self._get_prob(),
            holes=num_holes,
            spatial_size=hole_size,
            fill_value=0
        )
        return transform(image)

    def _monai_coarse_shuffle(self, image: torch.Tensor) -> torch.Tensor:
        """MONAI coarse shuffle: shuffle voxels inside randomly placed cuboid regions.

        Region count scales with magnitude (up to 6); region edge lengths are a
        magnitude-scaled fraction (up to 12%) of each spatial dimension, never
        less than one voxel.
        """
        max_holes = 6
        num_holes = max(1, int(self._scale_magnitude(max_holes)))

        max_spatial_size = 0.12  # 12% of each dimension
        hole_size = self._hole_size(self._scale_magnitude(max_spatial_size), image)

        transform = RandCoarseShuffle(
            prob=self._get_prob(),
            holes=num_holes,
            spatial_size=hole_size
        )
        return transform(image)

    # ==================== Original Custom Transformations ====================

    def _identity(self, image: torch.Tensor) -> torch.Tensor:
        """Identity transformation"""
        return image

    def _random_rotation(self, image: torch.Tensor) -> torch.Tensor:
        """Rotate about a random spatial axis by an angle in [-b, b] degrees.

        The bound b is magnitude-scaled (up to 30 degrees); bounds under 5
        degrees are treated as identity.
        """
        max_angle = 30
        angle_bound = self._scale_magnitude(max_angle)

        # Axis and angle are drawn before the early exit so every call consumes
        # the NumPy RNG the same way.
        axis = np.random.randint(0, 3)
        angle_rad = np.random.uniform(-angle_bound, angle_bound) * np.pi / 180

        if angle_bound < 5:
            return image
        # A lossless rot90 shortcut for multiples of 90 degrees is unnecessary:
        # |angle| <= 30 degrees here.
        return self._apply_affine_rotation(image, angle_rad, axis)

    def _apply_affine_rotation(self, image: torch.Tensor, angle: float, axis: int) -> torch.Tensor:
        """Rotate the volume by ``angle`` radians about spatial ``axis`` (0, 1 or 2).

        Each output voxel samples the input at the rotated position (linear
        interpolation, 0 outside the volume). Rotation is about the centre of
        the voxel grid, (shape - 1) / 2, so a 90 degree turn lands on grid
        points instead of shifting the volume by half a voxel.
        """
        volume = self._to_volume(image)
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        center = ((np.array(volume.shape) - 1) / 2.0).reshape(3, 1)

        if axis == 0:
            matrix = np.array([[1, 0, 0], [0, cos_a, -sin_a], [0, sin_a, cos_a]])
        elif axis == 1:
            matrix = np.array([[cos_a, 0, sin_a], [0, 1, 0], [-sin_a, 0, cos_a]])
        else:
            matrix = np.array([[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]])

        grid = np.mgrid[0:volume.shape[0], 0:volume.shape[1], 0:volume.shape[2]].reshape(3, -1)
        source_coords = matrix @ (grid - center) + center

        rotated = map_coordinates(volume, source_coords.reshape(3, *volume.shape),
                                  order=1, mode='constant', cval=0.0)
        return self._to_image(rotated, image)

    def _random_flip(self, image: torch.Tensor) -> torch.Tensor:
        """Flip each spatial axis independently with probability 0.5."""
        axes_to_flip = []
        for dim in [1, 2, 3]:
            if torch.rand(1) < 0.5:
                axes_to_flip.append(dim)

        for dim in axes_to_flip:
            image = torch.flip(image, dims=[dim])

        return image

    def _random_translation(self, image: torch.Tensor) -> torch.Tensor:
        """Shift each spatial axis by up to a magnitude-scaled fraction (<= 20%) of its size.

        Uses ``torch.roll``, so voxels pushed past one face wrap around to the
        opposite face (no zero padding).
        """
        max_translation = 0.2
        translation = self._scale_magnitude(max_translation)

        shifts = []
        for dim_size in image.shape[1:]:
            max_shift = int(translation * dim_size)
            shifts.append(np.random.randint(-max_shift, max_shift + 1))

        for i, shift in enumerate(shifts):
            if shift != 0:
                image = torch.roll(image, shifts=shift, dims=i + 1)

        return image

    def _random_noise(self, image: torch.Tensor) -> torch.Tensor:
        """Add zero-mean Gaussian noise with magnitude-scaled std (<= 0.1)."""
        max_std = 0.1
        noise_std = self._scale_magnitude(max_std)

        if noise_std > 0:
            image = image + torch.randn_like(image) * noise_std

        return image

    def _random_gamma(self, image: torch.Tensor) -> torch.Tensor:
        """Gamma correction with gamma = 1 + u, |u| <= magnitude-scaled 0.2."""
        max_gamma_change = 0.4
        gamma_change = self._scale_magnitude(max_gamma_change)
        gamma = 1.0 + np.random.uniform(-gamma_change / 2, gamma_change / 2)

        return torch.pow(image, gamma)

    def _random_contrast(self, image: torch.Tensor) -> torch.Tensor:
        """Scale deviations from the global mean by 1 + u, |u| <= magnitude-scaled 0.2."""
        max_contrast_change = 0.4
        contrast_change = self._scale_magnitude(max_contrast_change)
        contrast_factor = 1.0 + np.random.uniform(-contrast_change / 2, contrast_change / 2)

        mean_val = torch.mean(image)
        return (image - mean_val) * contrast_factor + mean_val

    def _random_gaussian_blur(self, image: torch.Tensor) -> torch.Tensor:
        """Gaussian blur with magnitude-scaled sigma (<= 2.0); skipped for sigma <= 0.1."""
        max_sigma = 2.0
        sigma = self._scale_magnitude(max_sigma)

        if sigma > 0.1:
            blurred = gaussian_filter(self._to_volume(image), sigma=sigma)
            image = self._to_image(blurred, image)

        return image

    def _elastic_deformation(self, image: torch.Tensor) -> torch.Tensor:
        """Elastic deformation from smoothed random displacement fields.

        Per-axis displacements are drawn uniformly from [-d, d] (d
        magnitude-scaled, <= 10 voxels), smoothed with a Gaussian of sigma d / 3,
        and used to resample the volume (linear interpolation, reflect border).
        Skipped when d < 1.
        """
        max_displacement = 10
        displacement = self._scale_magnitude(max_displacement)

        if displacement < 1:
            return image

        shape = tuple(image.shape[1:])
        # All three raw fields are drawn before any smoothing.
        fields = [np.random.uniform(-displacement, displacement, shape) for _ in range(3)]
        sigma = displacement / 3
        fields = [gaussian_filter(field, sigma=sigma) for field in fields]

        coords = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]].astype(np.float32)
        for axis, field in enumerate(fields):
            coords[axis] += field

        deformed = map_coordinates(self._to_volume(image), coords, order=1, mode='reflect')
        return self._to_image(deformed, image)

    def _random_zoom(self, image: torch.Tensor) -> torch.Tensor:
        """Zoom the volume, then randomly crop or pad back to the original shape.

        A factor > 1.0 zooms in (objects get larger; a random window of the
        upsampled volume is cropped out). A factor <= 1.0 zooms out (the
        downsampled volume is placed at a random position in a zero canvas).
        """
        max_zoom = 0.3  # +/- 30% zoom
        zoom_range = self._scale_magnitude(max_zoom)
        zoom_factor = 1.0 + np.random.uniform(-zoom_range, zoom_range)
        # Safety clamp; with max_zoom = 0.3 the factor already lies in [0.7, 1.3].
        zoom_factor = np.clip(zoom_factor, 0.7, 1.5)

        volume = self._to_volume(image)
        original_shape = volume.shape
        new_shape = tuple(int(s * zoom_factor) for s in original_shape)

        # Resample the full original extent onto a grid of size new_shape.
        grid = np.array(np.meshgrid(
            *(np.linspace(0, orig - 1, new) for orig, new in zip(original_shape, new_shape)),
            indexing='ij'
        ))
        zoomed = map_coordinates(volume, grid, order=3, mode='constant', cval=0.0)

        if zoom_factor > 1.0:
            # Zoomed in: random crop back to the original size.
            starts = self._random_offsets(new_shape, original_shape)
            window = tuple(slice(s, s + size) for s, size in zip(starts, original_shape))
            return self._to_image(zoomed[window], image)

        # Zoomed out: random placement inside a zero canvas of the original size.
        canvas = np.zeros(original_shape, dtype=np.float32)
        starts = self._random_offsets(original_shape, new_shape)
        window = tuple(slice(s, s + size) for s, size in zip(starts, new_shape))
        canvas[window] = zoomed
        return self._to_image(canvas, image)


# # Example usage
# if __name__ == "__main__":
#     # Create augmenter with MONAI transforms
#     augmenter = CBCTRandAugment(n=3, m=7, use_monai=True)

#     # Create dummy CBCT volume
#     dummy_volume = torch.rand(1, 128, 128, 128, dtype=torch.float32)

#     # Apply augmentation
#     augmented = augmenter(dummy_volume)

#     print(f"Original shape: {dummy_volume.shape}")
#     print(f"Augmented shape: {augmented.shape}")
#     print(f"Value range: [{augmented.min():.3f}, {augmented.max():.3f}]")
#     print(f"Total operations available: {len(augmenter.operations)}")

In [None]:
import torch.nn as nn

# Modified 3D ResNet for larger input size and binary classification
class ModifiedWideResNet3D(nn.Module):
    """Modified WideResNet3D for larger input and binary classification.

    Architecture: a 7x7x7 stride-2 stem convolution + BatchNorm + LeakyReLU and
    a stride-2 max-pool, four stages of ``BasicBlock3D`` residual blocks
    (stages 2-4 downsample by 2), global average pooling, dropout, and a
    linear classifier. Because the final pooling is adaptive, the classifier
    does not depend on the input's spatial size.

    Args:
        width: Channel multiplier; stage widths are (16, 16, 32, 64, 128) * width.
        depth: Number of residual blocks in each of the four stages. A tuple
            default avoids a shared mutable default; lists are still accepted.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability used inside every block and before
            the classifier.
    """

    def __init__(self, width=2, depth=(2, 2, 2, 2), num_classes=2, dropout_rate=0.3):
        super().__init__()

        nChannels = [16 * width, 16 * width, 32 * width, 64 * width, 128 * width]

        # Initial convolution - downsample immediately for large inputs
        self.conv1 = nn.Conv3d(1, nChannels[0], kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(nChannels[0])
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        # Residual stages with progressive downsampling
        self.block1 = self._make_layer(nChannels[0], nChannels[1], depth[0], stride=1, dropout_rate=dropout_rate)
        self.block2 = self._make_layer(nChannels[1], nChannels[2], depth[1], stride=2, dropout_rate=dropout_rate)
        self.block3 = self._make_layer(nChannels[2], nChannels[3], depth[2], stride=2, dropout_rate=dropout_rate)
        self.block4 = self._make_layer(nChannels[3], nChannels[4], depth[3], stride=2, dropout_rate=dropout_rate)

        # Head
        self.adaptive_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(nChannels[4], num_classes)

        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels, num_blocks, stride, dropout_rate):
        """Build one stage: a (possibly strided) first block, then stride-1 blocks."""
        layers = [BasicBlock3D(in_channels, out_channels, stride, dropout_rate)]
        layers.extend(
            BasicBlock3D(out_channels, out_channels, 1, dropout_rate)
            for _ in range(1, num_blocks)
        )
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero BatchNorm, small-normal linear weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # a=0.01 matches the LeakyReLU negative slope used in forward.
                nn.init.kaiming_normal_(m.weight, a=0.01, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a (batch, 1, D, H, W) volume batch to (batch, num_classes) logits."""
        # nn.functional is used instead of a module-level ``F`` alias, which is
        # only imported in a later notebook cell.
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.functional.leaky_relu(x, negative_slope=0.01)
        x = self.maxpool(x)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)

        x = self.adaptive_pool(x).flatten(1)
        x = self.dropout(x)
        return self.fc(x)


class BasicBlock3D(nn.Module):
    """Pre-activation 3D residual block.

    Residual branch: BN -> LeakyReLU -> 3x3x3 conv (carries ``stride``) ->
    BN -> LeakyReLU -> optional Dropout3d -> 3x3x3 conv. The shortcut is the
    identity when stride is 1 and channel counts match; otherwise it is a
    strided 1x1x1 conv + BN projection of the block input.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        stride: Stride of the first conv and of the projection shortcut.
        dropout_rate: Dropout3d probability between the two convs; 0 disables it.
    """

    def __init__(self, in_channels, out_channels, stride=1, dropout_rate=0.0):
        super().__init__()

        self.bn1 = nn.BatchNorm3d(in_channels)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)

        self.dropout = nn.Dropout3d(dropout_rate) if dropout_rate > 0 else None

        # Identity shortcut unless the spatial size or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm3d(out_channels)
            )

    def forward(self, x):
        """Return residual branch output plus shortcut(x)."""
        # nn.functional is used instead of a module-level ``F`` alias, which is
        # only imported in a later notebook cell.
        out = nn.functional.leaky_relu(self.bn1(x), negative_slope=0.01)
        out = self.conv1(out)
        out = nn.functional.leaky_relu(self.bn2(out), negative_slope=0.01)

        if self.dropout is not None:
            out = self.dropout(out)

        out = self.conv2(out)
        return out + self.shortcut(x)

In [None]:
import os
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
import pandas as pd
from datetime import datetime
import torch.nn.functional as F
from pathlib import Path
# from augmentation_3d import CBCTRandAugment

class NiftiDataset(Dataset):
    """Dataset of preprocessed NIfTI volumes listed in a labels CSV.

    Every CSV row whose ``split`` column equals the requested split yields one
    sample. Its file is ``<CSV directory>/<first column>_preprocessed.nii.gz``,
    and its label vector is taken from columns 1-6 of the row. Rows whose file
    is missing are reported and skipped. Samples are shuffled once, with
    ``random.shuffle``, at construction.
    """

    def __init__(self, data_dir, split='train', transform=None, target_size=(64, 64, 64)):
        """
        Args:
            data_dir: Path to the labels CSV (a file, despite the name). It must
                have a ``split`` column; volumes are looked up in its directory.
            split: Value of the ``split`` column to keep.
            transform: Optional callable applied to the (1, D, H, W) image tensor.
            target_size: Expected (depth, height, width); see ``_resize_image``.
        """
        self.data_dir = data_dir
        self.transform = transform
        self.target_size = target_size
        self.samples = []

        df = pd.read_csv(data_dir)
        df = df[df['split'] == split]
        volume_dir = Path(data_dir).parent

        # Rows are read positionally: column 0 is the sample ID and columns 1-6
        # the labels — NOTE(review): confirm against the CSV's column order.
        for row in df.values.tolist():
            data_path = volume_dir / f"{row[0]}_preprocessed.nii.gz"
            labels = row[1:7]

            if not data_path.exists():
                print(f"Missing file: {data_path}")
                continue

            self.samples.append((str(data_path), labels))

        random.shuffle(self.samples)

        print(f"Total samples: {len(self.samples)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``: a float32 (1, D, H, W) tensor and a float32 label vector."""
        filepath, label = self.samples[idx]

        label = torch.tensor(label, dtype=torch.float32)

        image = nib.load(filepath).get_fdata()

        if image.ndim == 4:
            # 4D file: keep only the first volume along the last axis.
            image = image[:, :, :, 0]

        image = self._normalize_image(image)

        if image.shape != self.target_size:
            image = self._resize_image(image, self.target_size)

        image = torch.from_numpy(image).float().unsqueeze(0)  # Shape: (1, D, H, W)

        if self.transform:
            image = self.transform(image)

        return image, label

    def _normalize_image(self, image):
        """Currently a pass-through: returns ``image`` unchanged.

        Min-max scaling to [0, 1] is disabled. NOTE(review): this presumably
        relies on the ``*_preprocessed`` volumes already being in [0, 1] (which
        CBCTRandAugment expects) — confirm. The disabled implementation was:

            image = np.nan_to_num(image, nan=0.0, posinf=0.0, neginf=0.0)
            lo, hi = image.min(), image.max()
            image = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)
        """
        return image

    def _resize_image(self, image, target_size):
        """Currently a pass-through: returns ``image`` unchanged, even when its
        shape differs from ``target_size``.

        NOTE(review): mismatched volumes are therefore returned at their native
        size. The disabled implementation resized with
        ``F.interpolate(..., size=target_size, mode='trilinear', align_corners=False)``
        on a (1, 1, D, H, W) tensor.
        """
        return image


In [None]:
class MultiLabelTrainer:
    """Training class for multilabel classification.

    Wraps a model that outputs one logit per label. Logits are turned into
    0/1 predictions with ``sigmoid(logits) > threshold``. Per-epoch metrics are
    kept in memory (for ``plot_training_curves``) and appended to
    ``<save_dir>/training_metrics.csv``. ``save_checkpoint`` writes
    checkpoints into ``save_dir`` and ``load_checkpoint`` restores them.
    """

    def __init__(self, model, device, num_classes, save_dir='checkpoints', mixed_precision=False, threshold=0.5):
        """Move ``model`` to ``device`` and set up metric history and CSV logging.

        Args:
            model: ``nn.Module`` returning logits, presumably of shape
                (batch, num_classes) -- TODO confirm against the model.
            device: ``torch.device`` or device string (e.g. ``'cuda'``).
            num_classes: Expected number of labels; used to name classes in
                the test report when it matches the label width.
            save_dir: Output directory (created if missing) for the metrics
                CSV, checkpoints and plots.
            mixed_precision: Request CUDA automatic mixed precision. It is
                ignored (with a message) when ``device`` is not CUDA.
            threshold: Probability cut-off for a positive prediction.
        """
        # Accept device strings too: ``self.device.type`` is read below.
        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.num_classes = num_classes
        self.threshold = threshold
        self.save_dir = save_dir
        self.mixed_precision = mixed_precision
        os.makedirs(save_dir, exist_ok=True)

        # Per-epoch history: one entry appended per completed epoch.
        self.train_losses = []
        self.val_losses = []
        self.train_f1_scores = []
        self.val_f1_scores = []
        self.train_precision_scores = []
        self.val_precision_scores = []
        self.train_recall_scores = []
        self.val_recall_scores = []
        self.learning_rates = []

        # A GradScaler only makes sense on CUDA. Its presence (scaler is not
        # None) doubles as the "use autocast" flag throughout this class.
        if self.mixed_precision and self.device.type == 'cuda':
            self.scaler = torch.amp.GradScaler('cuda')
            print("Mixed precision training enabled")
        else:
            self.scaler = None
            if self.mixed_precision:
                print("Mixed precision requested but CUDA not available")

        # Fresh CSV with a header row; epochs are appended by save_metrics_to_csv.
        self.csv_path = os.path.join(save_dir, 'training_metrics.csv')
        self.init_csv()

    def init_csv(self):
        """Create (or overwrite) the metrics CSV with only its header row."""
        df = pd.DataFrame(columns=['epoch', 'train_loss', 'train_f1', 'train_precision', 'train_recall',
                                  'val_loss', 'val_f1', 'val_precision', 'val_recall', 'lr'])
        df.to_csv(self.csv_path, index=False)

    def save_metrics_to_csv(self, epoch, train_loss, train_f1, train_prec, train_rec,
                           val_loss, val_f1, val_prec, val_rec, lr):
        """Append one row of metrics for zero-based ``epoch`` (stored 1-based)."""
        new_row = {
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_f1': train_f1,
            'train_precision': train_prec,
            'train_recall': train_rec,
            'val_loss': val_loss,
            'val_f1': val_f1,
            'val_precision': val_prec,
            'val_recall': val_rec,
            'lr': lr
        }
        df = pd.DataFrame([new_row])
        df.to_csv(self.csv_path, mode='a', header=False, index=False)

    def save_checkpoint(self, epoch, val_f1, filename):
        """Save the model weights plus bookkeeping to ``<save_dir>/<filename>``.

        Only tensors and plain Python scalars are stored, so the file can be
        read back with ``torch.load(..., weights_only=True)``. ``val_f1``
        often arrives as a NumPy scalar from sklearn, so it is cast to float.

        Args:
            epoch: Zero-based epoch index the weights come from.
            val_f1: Validation macro-F1 at that epoch.
            filename: File name relative to ``save_dir`` (an absolute path is
                used as-is by ``os.path.join``).

        Returns:
            The path the checkpoint was written to.
        """
        path = os.path.join(self.save_dir, filename)
        torch.save({
            'epoch': int(epoch),
            'val_f1': float(val_f1),
            'num_classes': int(self.num_classes),
            'threshold': float(self.threshold),
            'model_state_dict': self.model.state_dict(),
        }, path)
        return path

    def load_checkpoint(self, filename):
        """Load weights written by ``save_checkpoint`` into ``self.model``.

        Args:
            filename: File name relative to ``save_dir``, or an absolute path.

        Returns:
            The loaded checkpoint dict (``'epoch'``, ``'val_f1'``,
            ``'num_classes'``, ``'threshold'``, ``'model_state_dict'``).
        """
        path = os.path.join(self.save_dir, filename)
        checkpoint = torch.load(path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        return checkpoint

    def train(self, train_loader, val_loader, num_epochs=50, lr=0.001, weight_decay=1e-4):
        """Run the training loop and return the best validation macro-F1.

        Each epoch trains, validates, steps ``ReduceLROnPlateau`` on the
        validation loss, logs metrics (memory + CSV), and saves:
          * ``best_model_F1.pth`` whenever val F1 ties or beats the best so far,
          * ``best_model_Loss.pth`` whenever val loss strictly improves,
          * ``checkpoint_epoch_<k>.pth`` every 25 epochs.
        """
        # Independent sigmoid per label -> BCE on raw logits.
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                        factor=0.8, patience=5)

        best_val_f1 = 0.0
        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            print(f'\nEpoch {epoch+1}/{num_epochs}')
            print('-' * 50)

            train_metrics = self._train_epoch(train_loader, criterion, optimizer)
            val_metrics = self._validate_epoch(val_loader, criterion)

            # LR used during this epoch; the scheduler may lower it for the next.
            current_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_metrics['loss'])
            new_lr = optimizer.param_groups[0]['lr']
            if new_lr != current_lr:
                print(f'Learning rate reduced from {current_lr:.6f} to {new_lr:.6f}')

            self.train_losses.append(train_metrics['loss'])
            self.val_losses.append(val_metrics['loss'])
            self.train_f1_scores.append(train_metrics['f1'])
            self.val_f1_scores.append(val_metrics['f1'])
            self.train_precision_scores.append(train_metrics['precision'])
            self.val_precision_scores.append(val_metrics['precision'])
            self.train_recall_scores.append(train_metrics['recall'])
            self.val_recall_scores.append(val_metrics['recall'])
            self.learning_rates.append(current_lr)

            self.save_metrics_to_csv(
                epoch, train_metrics['loss'], train_metrics['f1'], train_metrics['precision'], train_metrics['recall'],
                val_metrics['loss'], val_metrics['f1'], val_metrics['precision'], val_metrics['recall'], current_lr
            )

            print(f'Train Loss: {train_metrics["loss"]:.4f}, F1: {train_metrics["f1"]:.4f}, '
                  f'Precision: {train_metrics["precision"]:.4f}, Recall: {train_metrics["recall"]:.4f}')
            print(f'Val Loss: {val_metrics["loss"]:.4f}, F1: {val_metrics["f1"]:.4f}, '
                  f'Precision: {val_metrics["precision"]:.4f}, Recall: {val_metrics["recall"]:.4f}')
            print(f'Learning Rate: {current_lr:.6f}')

            # ">=" so the first epoch is saved even when its F1 is 0.
            if val_metrics['f1'] >= best_val_f1:
                best_val_f1 = val_metrics['f1']
                self.save_checkpoint(epoch, val_metrics['f1'], 'best_model_F1.pth')
                print(f'New best model saved with Val F1: {val_metrics["f1"]:.4f}')

            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                self.save_checkpoint(epoch, val_metrics['f1'], 'best_model_Loss.pth')
                print(f'New best model saved with Val Loss: {val_metrics["loss"]:.4f}')

            if (epoch + 1) % 25 == 0:
                self.save_checkpoint(epoch, val_metrics['f1'], f'checkpoint_epoch_{epoch+1}.pth')

        print(f'\nTraining completed. Best validation F1: {best_val_f1:.4f}')
        print(f'Training metrics saved to: {self.csv_path}')
        return best_val_f1

    @staticmethod
    def _gpu_mem_str():
        """CUDA memory currently allocated by tensors, formatted for the progress bar."""
        if torch.cuda.is_available():
            return f'{torch.cuda.memory_allocated()/1e9:.1f}GB'
        return 'N/A'

    def _summarize_epoch(self, all_predictions, all_targets, running_loss):
        """Stack per-batch arrays into epoch metrics; ``'loss'`` is the mean batch loss.

        Raises:
            ValueError: if no batches were collected (empty DataLoader).
        """
        if not all_predictions:
            raise ValueError('DataLoader yielded no batches; cannot compute epoch metrics')
        metrics = self._calculate_multilabel_metrics(np.vstack(all_predictions), np.vstack(all_targets))
        metrics['loss'] = running_loss / len(all_predictions)
        return metrics

    def _train_epoch(self, train_loader, criterion, optimizer):
        """Train for one epoch and return its metrics dict (including ``'loss'``)."""
        self.model.train()
        running_loss = 0.0
        all_predictions = []
        all_targets = []

        progress_bar = tqdm(train_loader, desc='Training')

        for batch_idx, (data, target) in enumerate(progress_bar):
            data = data.to(self.device, non_blocking=True)
            # BCEWithLogitsLoss requires float targets.
            target = target.to(self.device, non_blocking=True).float()

            optimizer.zero_grad()

            if self.scaler is not None:
                # AMP: forward/loss in autocast, scaled backward to avoid fp16 underflow.
                with torch.amp.autocast('cuda'):
                    output = self.model(data)
                    loss = criterion(output, target)
                self.scaler.scale(loss).backward()
                self.scaler.step(optimizer)
                self.scaler.update()
            else:
                output = self.model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

            running_loss += loss.item()

            with torch.no_grad():
                predictions = (torch.sigmoid(output) > self.threshold).float()
            all_predictions.append(predictions.cpu().numpy())
            all_targets.append(target.cpu().numpy())

            batch_f1 = self._calculate_batch_metrics(predictions, target)
            progress_bar.set_postfix({
                'Loss': f'{running_loss/(batch_idx+1):.4f}',
                'F1': f'{batch_f1:.3f}',
                'GPU_Mem': self._gpu_mem_str(),
            })

        return self._summarize_epoch(all_predictions, all_targets, running_loss)

    def _validate_epoch(self, val_loader, criterion):
        """Evaluate one pass over ``val_loader`` without gradients; return metrics."""
        self.model.eval()
        running_loss = 0.0
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            progress_bar = tqdm(val_loader, desc='Validation')

            for batch_idx, (data, target) in enumerate(progress_bar):
                data = data.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True).float()

                if self.scaler is not None:
                    with torch.amp.autocast('cuda'):
                        output = self.model(data)
                        loss = criterion(output, target)
                else:
                    output = self.model(data)
                    loss = criterion(output, target)

                running_loss += loss.item()

                predictions = (torch.sigmoid(output) > self.threshold).float()
                all_predictions.append(predictions.cpu().numpy())
                all_targets.append(target.cpu().numpy())

                batch_f1 = self._calculate_batch_metrics(predictions, target)
                progress_bar.set_postfix({
                    'Loss': f'{running_loss/(batch_idx+1):.4f}',
                    'F1': f'{batch_f1:.3f}',
                    'GPU_Mem': self._gpu_mem_str(),
                })

        return self._summarize_epoch(all_predictions, all_targets, running_loss)

    def _calculate_batch_metrics(self, predictions, targets):
        """Return the macro-F1 of one batch (tensors or arrays accepted)."""
        from sklearn.metrics import f1_score
        pred_np = predictions.cpu().numpy() if torch.is_tensor(predictions) else predictions
        target_np = targets.cpu().numpy() if torch.is_tensor(targets) else targets
        return f1_score(target_np, pred_np, average='macro', zero_division=0)

    def _calculate_multilabel_metrics(self, predictions, targets):
        """Return macro/micro/sample F1, macro precision/recall and Hamming loss."""
        from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss

        metrics = {
            'f1': f1_score(targets, predictions, average='macro', zero_division=0),
            'precision': precision_score(targets, predictions, average='macro', zero_division=0),
            'recall': recall_score(targets, predictions, average='macro', zero_division=0),
            'f1_micro': f1_score(targets, predictions, average='micro', zero_division=0),
            'f1_samples': f1_score(targets, predictions, average='samples', zero_division=0),
            'hamming_loss': hamming_loss(targets, predictions)
        }
        return metrics

    def evaluate(self, test_loader):
        """Evaluate ``self.model`` on ``test_loader`` and print a summary.

        Returns:
            (metrics, report, cm, all_probs): the dict from
            ``_calculate_multilabel_metrics``, sklearn's per-class text report,
            the per-label 2x2 matrices from ``multilabel_confusion_matrix``,
            and the stacked sigmoid probabilities as a float32 array.

        Raises:
            ValueError: if ``test_loader`` yields no batches.
        """
        from sklearn.metrics import classification_report, multilabel_confusion_matrix

        self.model.eval()
        all_preds = []
        all_targets = []
        all_probs = []

        with torch.no_grad():
            for data, target in tqdm(test_loader, desc='Testing'):
                data = data.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True).float()

                # Same precision mode as validation, so test memory use matches.
                if self.scaler is not None:
                    with torch.amp.autocast('cuda'):
                        output = self.model(data)
                else:
                    output = self.model(data)

                # Upcast before sigmoid so the stored probabilities are float32.
                probs = torch.sigmoid(output.float())
                predictions = (probs > self.threshold).float()

                all_preds.append(predictions.cpu().numpy())
                all_targets.append(target.cpu().numpy())
                all_probs.append(probs.cpu().numpy())

        if not all_preds:
            raise ValueError('test_loader yielded no batches; nothing to evaluate')

        all_preds = np.vstack(all_preds)
        all_targets = np.vstack(all_targets)
        all_probs = np.vstack(all_probs)

        metrics = self._calculate_multilabel_metrics(all_preds, all_targets)

        print(f'\nTest Metrics:')
        print(f'F1 Score (macro): {metrics["f1"]:.4f}')
        print(f'F1 Score (micro): {metrics["f1_micro"]:.4f}')
        print(f'F1 Score (samples): {metrics["f1_samples"]:.4f}')
        print(f'Precision (macro): {metrics["precision"]:.4f}')
        print(f'Recall (macro): {metrics["recall"]:.4f}')
        print(f'Hamming Loss: {metrics["hamming_loss"]:.4f}')

        # Name classes from the actual label width: classification_report
        # raises ValueError when target_names and the label count disagree.
        num_labels = all_targets.shape[1]
        if num_labels != self.num_classes:
            print(f'Warning: num_classes={self.num_classes} but targets have {num_labels} labels; '
                  f'reporting {num_labels} classes')
        report = classification_report(
            all_targets,
            all_preds,
            target_names=[f'Class_{i}' for i in range(num_labels)],
            zero_division=0,
        )
        print('\nPer-Class Classification Report:')
        print(report)

        # One 2x2 confusion matrix per label.
        cm = multilabel_confusion_matrix(all_targets, all_preds)

        return metrics, report, cm, all_probs

    def plot_training_curves(self):
        """Plot loss/F1/precision/recall/LR history and save it as training_curves.png.

        The figure is written to ``<save_dir>/training_curves.png``, shown, and
        then closed so repeated calls do not accumulate open figures.
        """
        fig, axes = plt.subplots(2, 3, figsize=(18, 10))
        axes = axes.flatten()
        epochs = range(1, len(self.train_losses) + 1)

        # (title, y-label, train series, val series, train label, val label)
        paired_curves = [
            ('Training and Validation Loss', 'Loss',
             self.train_losses, self.val_losses, 'Train Loss', 'Val Loss'),
            ('F1 Score', 'F1 Score',
             self.train_f1_scores, self.val_f1_scores, 'Train F1', 'Val F1'),
            ('Precision', 'Precision',
             self.train_precision_scores, self.val_precision_scores, 'Train Precision', 'Val Precision'),
            ('Recall', 'Recall',
             self.train_recall_scores, self.val_recall_scores, 'Train Recall', 'Val Recall'),
        ]
        for ax, (title, ylabel, train_vals, val_vals, train_label, val_label) in zip(axes, paired_curves):
            ax.plot(epochs, train_vals, label=train_label)
            ax.plot(epochs, val_vals, label=val_label)
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True)

        lr_ax = axes[4]
        lr_ax.plot(epochs, self.learning_rates, label='Learning Rate', color='orange')
        lr_ax.set_title('Learning Rate Schedule')
        lr_ax.set_xlabel('Epoch')
        lr_ax.set_ylabel('Learning Rate')
        lr_ax.set_yscale('log')
        lr_ax.legend()
        lr_ax.grid(True)

        # Sixth panel is unused.
        axes[5].axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, 'training_curves.png'))
        plt.show()
        plt.close(fig)

In [None]:
def main():
    """Train, validate and test the 3-D multilabel classifier end to end.

    Steps:
    1. Build the model, datasets and loaders from ``config``.
    2. Train with ``MultiLabelTrainer`` and plot the curves.
    3. Reload the lowest-validation-loss checkpoint and evaluate it on the test split.
    4. Write ``results.json`` into ``config['save_dir']``.
    """
    import json
    from datetime import datetime

    # Configuration
    config = {
        'data_dir': r"d:\Kananat\Data\Last0\3_Preprocessed\labels.csv",  # Update this path
        'save_dir': r'/content/drive/MyDrive/TMJOA_data/3D_model/hopefully_last/ResNet/OA_noMONAI',
        'target_size': (256, 256, 256),   # volumes with any other shape are resized to this
        # Single source of truth for the model head AND the trainer.
        # NOTE(review): confirm this matches the label vector length in labels.csv.
        'num_classes': 5,
        'batch_size': 8,                  # Small batch size due to large images
        'num_epochs': 200,
        'learning_rate': 0.001,
        'weight_decay': 1e-4,
        'num_workers': 0,                 # Set to 0 for Windows compatibility
        'mixed_precision': True,          # Enable mixed precision training
    }

    # Create model
    model = ModifiedWideResNet3D(
        width=2,
        depth=[2, 2, 2, 2],
        num_classes=config['num_classes'],
        dropout_rate=0.1
    )

    # To continue training from saved weights instead of starting fresh:
    #   checkpoint = torch.load('<path>.pth', map_location='cuda')
    #   model.load_state_dict(checkpoint['model_state_dict'])

    # Device configuration
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f'Using device: {device}')

    if use_cuda:
        print(f'GPU: {torch.cuda.get_device_name(0)}')
        print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
        print(f'CUDA Version: {torch.version.cuda}')

        # Input sizes are fixed, so let cuDNN pick the fastest kernels.
        torch.backends.cudnn.benchmark = True
    else:
        print('CUDA not available, using CPU')
        config['mixed_precision'] = False

    # Data transforms
    # NOTE(review): save_dir ends in "OA_noMONAI" but use_monai=True here --
    # confirm which one is intended.
    train_transform = CBCTRandAugment(n=2, m=6, use_monai=True)

    # Create datasets
    print("Loading datasets...")
    train_dataset = NiftiDataset(config['data_dir'], 'train',
        transform=train_transform,
        target_size=config['target_size']
    )

    val_dataset = NiftiDataset(config['data_dir'], 'val',
        transform=None,
        target_size=config['target_size']
    )

    test_dataset = NiftiDataset(config['data_dir'], 'test',
        transform=None,
        target_size=config['target_size']
    )

    # Create data loaders.
    # persistent_workers / prefetch_factor are only valid with worker
    # processes: DataLoader raises ValueError if they are passed while
    # num_workers == 0, so only add them when workers are actually used.
    common_loader_args = {
        'batch_size': config['batch_size'],
        'num_workers': config['num_workers'],
        'pin_memory': use_cuda,
    }
    if config['num_workers'] > 0:
        worker_args = {'persistent_workers': use_cuda, 'prefetch_factor': 2}
    else:
        worker_args = {}

    train_loader = DataLoader(train_dataset, shuffle=True, **common_loader_args, **worker_args)
    val_loader = DataLoader(val_dataset, shuffle=False, **common_loader_args, **worker_args)
    test_loader = DataLoader(test_dataset, shuffle=False, **common_loader_args)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model created with {total_params:,} parameters")

    trainer = MultiLabelTrainer(
        model=model,
        device=device,
        save_dir=config['save_dir'],
        mixed_precision=config['mixed_precision'],
        num_classes=config['num_classes'],  # must match the model head
        threshold=0.5,  # can be tuned
    )

    # Train model
    print("Starting training...")
    best_val_f1 = trainer.train(
        train_loader,
        val_loader,
        num_epochs=config['num_epochs'],
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    # Plot training curves
    trainer.plot_training_curves()

    # Load best model and evaluate
    print("Loading best model for final evaluation...")
    trainer.load_checkpoint('best_model_Loss.pth')
    # evaluate() returns (metrics, report, confusion_matrices, probabilities).
    test_metrics, _test_report, _test_cm, _test_probs = trainer.evaluate(test_loader)

    # Save final results (cast NumPy scalars to float for JSON).
    results = {
        'best_val_f1': float(best_val_f1),
        'test_metrics': {name: float(value) for name, value in test_metrics.items()},
        'config': config,
        'model_parameters': total_params,
        'timestamp': datetime.now().isoformat()
    }

    with open(os.path.join(trainer.save_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)

    print(f"Training completed! Final test macro-F1: {test_metrics['f1']:.4f}")

In [None]:
if __name__ == "__main__":
    # Entry point when this notebook/script is executed directly: runs the
    # train -> plot -> test-evaluation pipeline defined in main().
    main()