In [None]:
# !gdown 15DnB8MFbW8JaZteDv6EE02AAWN0tM_eB

In [None]:
# !unzip /content/drive/MyDrive/Datasets/ACDC/ACDC.zip -d /content
# !unzip /content/ACDC.zip -d /content

In [3]:
!pip install nibabel

Looking in links: /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/x86-64-v3, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/generic, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic
Collecting nibabel
  Using cached nibabel-5.3.2-py3-none-any.whl.metadata (9.1 kB)
Processing /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic/importlib_resources-6.5.2+computecanada-py3-none-any.whl (from nibabel)
Using cached nibabel-5.3.2-py3-none-any.whl (3.3 MB)
Installing collected packages: importlib-resources, nibabel
Successfully installed importlib-resources-6.5.2+computecanada nibabel-5.3.2


In [5]:
!pip install tqdm

Looking in links: /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/x86-64-v3, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/generic, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic
Processing /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic/tqdm-4.67.1+computecanada-py3-none-any.whl
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1+computecanada


In [7]:
!pip install torch

Looking in links: /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/x86-64-v3, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/generic, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic
Processing /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/x86-64-v3/torch-2.6.0+computecanada-cp310-cp310-linux_x86_64.whl
Processing /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic/filelock-3.17.0+computecanada-py3-none-any.whl (from torch)
Processing /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic/networkx-3.4.2+computecanada-py3-none-any.whl (from torch)
Processing /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic/fsspec-2025.2.0+computecanada-py3-none-any.whl (from torch)
Processing /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic/sympy-1.13.1+computecanada-py3-none-any.whl (from torch)
Installing collected packages: sympy, networkx, fsspec, filelock, torch
  Attempting uninstall: sympy

# Load Data

In [1]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
import nibabel as nib
from tqdm import tqdm
import random
from torch.utils.data import DataLoader, Dataset
import os
import random
import pandas as pd

In [9]:
def load_nifti(file_path):
    """
    Load a NIfTI file and return its data as a NumPy array.
    """
    nii = nib.load(file_path)
    return nii.get_fdata()

def extract_2d_slices(volume, slice_axis=2):
    """
    Extract 2D slices from a 3D volume along the specified axis.
    Args:
        volume: 3D numpy array (H x W x D)
        slice_axis: Axis along which to extract slices (default: 2)
    Returns:
        List of 2D slices.
    """
    image_slices = []

    for slice_idx in range(volume.shape[slice_axis]):
        slice_data = volume[:, :, slice_idx]
        slice_data = np.uint8(np.interp(slice_data, (slice_data.min(), slice_data.max()), (0, 255)))
        img_pil = Image.fromarray(slice_data)
        image_slices.append(Image.fromarray(slice_data))

    return image_slices

def Read_ACDC(data_dir, output_dir, data_set,slice_axis=2):

    os.makedirs(output_dir, exist_ok=True)

    patients = [ i for i in os.listdir(data_dir) if i != 'MANDATORY_CITATION.md']

    for patient_id in tqdm(sorted(patients)):
        patient_dir = os.path.join(data_dir, patient_id)
        if not os.path.isdir(patient_dir):
            print("Patient Directory Not Found")
            print(patient_dir)
            continue

        patient_files =[i for i in os.listdir(patient_dir) if
         (i.endswith(".nii.gz")) and (f'{patient_id}_frame' in i)]

        image_data = sorted([i for i in patient_files if 'gt' not in i])
        gt_data = sorted([i for i in patient_files if 'gt' in i])

        ed_file = os.path.join(patient_dir, image_data[0])
        es_file = os.path.join(patient_dir, image_data[1])
        ed_file_gt = os.path.join(patient_dir, gt_data[0])
        es_file_gt = os.path.join(patient_dir, gt_data[1])
        print('\n',patient_id)
        print('ed:', ed_file)
        print('es:', es_file)
        print('ed gt:', ed_file_gt)
        print('es gt:', es_file_gt)
        print("#"*73)
        print("\n")

        if os.path.exists(ed_file) and os.path.exists(es_file) and os.path.exists(es_file_gt) and os.path.exists(es_file_gt):
            ed_volume = load_nifti(ed_file)
            es_volume = load_nifti(es_file)
            ed_volume_gt = load_nifti(ed_file_gt)
            es_volume_gt = load_nifti(es_file_gt)

            # Extract 2D slices
            ed_slices = extract_2d_slices(ed_volume, slice_axis)
            es_slices = extract_2d_slices(es_volume, slice_axis)
            ed_slices_gt = extract_2d_slices(ed_volume_gt, slice_axis)
            es_slices_gt = extract_2d_slices(es_volume_gt, slice_axis)

            os.makedirs(f"{output_dir}/{data_set}/ed/images", exist_ok=True)
            os.makedirs(f"{output_dir}/{data_set}/es/images", exist_ok=True)
            os.makedirs(f"{output_dir}/{data_set}/ed/masks", exist_ok=True)
            os.makedirs(f"{output_dir}/{data_set}/es/masks", exist_ok=True)

            for idx, slice_d in enumerate(ed_slices):
              slice_name = ed_file.split("/")[-1].replace(".nii.gz",'') + f'_slice_{idx}.png'
              slice_d.save(f"{output_dir}/{data_set}/ed/images/{slice_name}")

            for idx, slice_d in enumerate(es_slices):
              slice_name = es_file.split("/")[-1].replace(".nii.gz",'') + f'_slice_{idx}.png'
              slice_d.save(f"{output_dir}/{data_set}/es/images/{slice_name}")

            for idx, slice_d in enumerate(ed_slices_gt):
              slice_name = ed_file_gt.split("/")[-1].replace(".nii.gz",'') + f'_slice_{idx}_gt.png'
              slice_d.save(f"{output_dir}/{data_set}/ed/masks/{slice_name}")

            for idx, slice_d in enumerate(es_slices_gt):
              slice_name = es_file_gt.split("/")[-1].replace(".nii.gz",'') + f'_slice_{idx}_gt.png'
              slice_d.save(f"{output_dir}/{data_set}/es/masks/{slice_name}")

    ed_slices_num = len(os.listdir(f"{output_dir}/{data_set}/ed/images"))
    es_slices_num = len(os.listdir(f"{output_dir}/{data_set}/es/images"))
    ed_slices_num_gt = len(os.listdir(f"{output_dir}/{data_set}/ed/masks"))
    es_slices_num_gt = len(os.listdir(f"{output_dir}/{data_set}/es/masks"))

    print("\n")
    print(f"Number of ed slices: {ed_slices_num}\nNumber of es slices: {es_slices_num}")
    print(f"Number of ed slices gt: {ed_slices_num_gt}\nNumber of es slices gt: {es_slices_num_gt}")

In [None]:
data_dir = "ACDC/database/training/"  # Path to ACDC dataset
output_dir = "output_data"  # Path to save slices and pairs
Read_ACDC(data_dir,output_dir,"Training")

In [None]:
data_dir = "ACDC/database/testing/"  # Path to ACDC dataset
output_dir = "output_data"  # Path to save slices and pairs
Read_ACDC(data_dir,output_dir,"Testing")

## Model

In [22]:
!module load StdEnv/2020

intel/2020.1.217:
The software listed above is available for non-commercial usage only. By
continuing, you
accept that you will not use the software for commercial purposes.

Le logiciel listé ci-dessus est disponible pour usage non commercial
seulement. En
continuant, vous acceptez de ne pas l'utiliser pour un usage commercial.
	 

Inactive Modules:
  1) arrow         3) gdrcopy/2.3.1     5) nccl/2.18.3        7) ucc/1.2.0
  2) cuda/12.2     4) hwloc/2.9.1       6) ucc-cuda/1.2.0     8) ucx-cuda/1.14.1

Due to MODULEPATH changes, the following have been reloaded:
  1) cudacore/.12.2.2     2) mii/1.1.2

The following have been reloaded with a version change:
  1) StdEnv/2023 => StdEnv/2020
  2) clang/17.0.6 => clang/15.0.2
  3) cmake/3.31.0 => cmake/3.27.7
  4) gcccore/.12.3 => gcccore/.9.3.0
  5) gentoo/2023 => gentoo/2020
  6) imkl/2023.2.0 => imkl/2020.1.217
  7) ipykernel/2024b => ipykernel/2020a
  8) libfabric/1.18.0 => libfabric/1.10.1
  9) opencv/4.10.0-2 => opencv/4.5.1
 10) op

In [23]:
!avail_wheels kornia

name    version    python    arch
------  ---------  --------  -------
kornia  0.7.2      py2,py3   generic


In [None]:
!pip install --no-index --find-links=$PYTHON_WHEELHOUSE kornia

In [None]:
!pip install contrastive-learner

In [None]:
!pip install torchvision

In [2]:
import torch
from contrastive_learner import ContrastiveLearner
from torchvision import models

In [3]:
resnet = models.resnet101(pretrained=True)



In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import glob
# from tqdm.notebook import trange, tqdm
from tqdm import tqdm
class ACDCDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx])
        if self.transform:
            img = self.transform(img)
        else:
            img = transforms.ToTensor()(img)

        if img.shape[0] == 1:
          img = img.repeat(3, 1, 1)

        return img

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

ed_image_paths = glob.glob('output_data/Training/ed/images/*.png')
es_image_paths = glob.glob('output_data/Training/es/images/*.png')

image_paths = ed_image_paths + es_image_paths
dataset = ACDCDataset(image_paths, transform=transform)
dataloader = DataLoader(dataset, batch_size=5, shuffle=True)

In [5]:
dataset[1].shape

torch.Size([3, 256, 256])

In [6]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [7]:
learner = ContrastiveLearner(
    resnet,
    image_size=256,
    hidden_layer='avgpool',
    project_hidden=True,
    project_dim=128,
    use_nt_xent_loss=True,
    temperature=0.1,
    augment_both=True
).to(device)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

In [8]:
for epoch in range(10):
    for images in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        images = images.to(device)
        loss = learner(images)
        opt.zero_grad()
        loss.backward()
        opt.step()
    print(f"loss: {loss}")

Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [01:11<00:00,  5.36it/s]


loss: 0.14805971086025238


Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [01:10<00:00,  5.38it/s]


loss: 0.009666907601058483


Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [01:05<00:00,  5.81it/s]


loss: 3.0646700859069824


Epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [01:14<00:00,  5.08it/s]


loss: 1.3880677223205566


Epoch 5: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [01:16<00:00,  5.01it/s]


loss: 0.007869720458984375


Epoch 6: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [01:14<00:00,  5.15it/s]


loss: 0.3327084183692932


Epoch 7: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [01:14<00:00,  5.12it/s]


loss: 0.06518714874982834


Epoch 8: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [01:14<00:00,  5.11it/s]


loss: 0.20197877287864685


Epoch 9: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [01:15<00:00,  5.04it/s]


loss: 0.0038962827529758215


Epoch 10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [01:14<00:00,  5.10it/s]

loss: 0.037380367517471313



