## Project: Cell image segmentation

Contact: Elena Casiraghi (University of Milan, elena.casiraghi@unimi.it)

Cell segmentation is usually the first step of downstream single-cell analysis in microscopy-based biology and biomedical research, and deep learning is now widely used for it.
The CellSeg competition aims to benchmark cell segmentation methods that can be applied to microscopy images from multiple imaging platforms and tissue types. The challenge organizers provide both labeled and unlabeled images.
The “2018 Data Science Bowl” Kaggle competition also provides cell images and their masks for training cell/nuclei segmentation models.

In 2022, another [cell segmentation challenge was held at NeurIPS](https://neurips22-cellseg.grand-challenge.org/).
For interested readers, the competition proceedings have been published in [PMLR](https://proceedings.mlr.press/v212/).

### Project Description

In the field of (bio-medical) image processing, segmentation of images is typically performed via U-Nets [1,2].

A U-Net consists of an encoder, a series of convolution and pooling layers that reduce the spatial resolution of the feature maps, followed by a decoder, a series of transposed-convolution (or upsampling) and convolution layers that restore the spatial resolution. The encoder and decoder are joined by a bottleneck, the block with the lowest spatial resolution and, in the original architecture [1], the largest number of channels.
The key innovation of U-Net is the addition of skip connections that connect the contracting path to the corresponding layers in the expanding path, allowing the network to recover fine-grained details lost during downsampling.
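The size bookkeeping behind the skip connections can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: padded 3x3 convolutions that keep the spatial size, 2x2 max pooling with stride 2, and 2x2 transposed convolutions with stride 2. The helper names (`pool_out`, `upconv_out`, `trace_sizes`) are hypothetical and not part of any library.

```python
def pool_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Spatial size after max pooling without padding (floor division)."""
    return (size - kernel) // stride + 1


def upconv_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Spatial size after a transposed convolution without padding."""
    return (size - 1) * stride + kernel


def trace_sizes(size: int, depth: int = 4):
    """Return (encoder sizes, decoder sizes) for a U-Net with `depth` poolings.

    Assumes the convolutions themselves keep the spatial size (padding=1)
    and that nothing crops or resizes the decoder feature maps.
    """
    encoder = [size]
    for _ in range(depth):
        encoder.append(pool_out(encoder[-1]))
    decoder = [encoder[-1]]
    for _ in range(depth):
        decoder.append(upconv_out(decoder[-1]))
    return encoder, decoder


for size in (512, 143):
    enc, dec = trace_sizes(size)
    # Each skip connection pairs an encoder level with the decoder output
    # that is supposed to have the same spatial size.
    pairs = list(zip(enc[-2::-1], dec[1:]))
    print(size, pairs)
```

For an input of 512 every (encoder, decoder) pair matches, while for 143 the pairs are (17, 16), (35, 32), (71, 64) and (143, 128). This is why implementations either crop the encoder feature map or resize the upsampled one before concatenating.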

<img src='https://production-media.paperswithcode.com/methods/Screen_Shot_2020-07-07_at_9.08.00_PM_rpNArED.png' width="400"/>


At this [link](https://rpubs.com/eR_ic/unet) you can find an R implementation of a basic U-Net, and at this [link](https://github.com/zhixuhao/unet) a Keras implementation.
Implementations of more advanced U-Nets are also made available in [2], for example [UNet++](https://github.com/MrGiovanni/UNetPlusPlus),
and by the CellSeg organizers as baseline models: [https://neurips22-cellseg.grand-challenge.org/baseline-and-tutorial/](https://neurips22-cellseg.grand-challenge.org/baseline-and-tutorial/)


### Project aim

The aim of the project is to download the *gray-level* cell images (.tiff or .tif files) from the [CellSeg](https://neurips22-cellseg.grand-challenge.org/dataset/) competition and assess the performance of a U-Net, or any other deep model, for cell segmentation.
We suggest using only gray-level images so that the model specializes on one subclass of images.
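The notebook selects files by extension only. If a check on the pixel data is also wanted, a minimal sketch could look like the one below. It assumes that an image counts as gray-level when it is a 2D array, or a multi-channel array whose channels are all identical; that definition and the helper name `is_gray_level` are assumptions made for illustration, not something the challenge specifies.

```python
import numpy as np


def is_gray_level(image: np.ndarray) -> bool:
    """True for (H, W) arrays, or (H, W, C) arrays whose channels are identical."""
    if image.ndim == 2:
        return True
    if image.ndim == 3:
        first_channel = image[..., :1]  # shape (H, W, 1), broadcasts over C
        return bool(np.all(image == first_channel))
    return False
```

The array can come from any loader, for example `np.array(Image.open(path))` with PIL or `cv2.imread(path, cv2.IMREAD_UNCHANGED)` with OpenCV.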

Students are not restricted to U-Nets: any other model is welcome. For example, the transformer-based models in the [leaderboard](https://neurips22-cellseg.grand-challenge.org/evaluation/testing/leaderboard/) may be tested.
Students are free to choose any model, as long as they can explain its rationale, architecture, strengths, and weaknesses.



### References

[1] Ronneberger, O., Fischer, P., Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. In: Navab, N., Hornegger, J., Wells, W., Frangi, A. (eds) Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015. Lecture Notes in Computer Science, vol 9351. Springer, Cham. https://doi.org/10.1007/978-3-319-24574-4_28

[2] Long, F. Microscopy cell nuclei segmentation with enhanced U-Net. BMC Bioinformatics 21, 8 (2020). https://doi.org/10.1186/s12859-019-3332-1


## Initialization

In [38]:
!pip install --upgrade gdown
!pip install torch
!pip install torchvision
!pip install opencv-contrib-python
!pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [17]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import tempfile
from typing import Callable, List, Tuple
from torchvision import transforms
from PIL import Image

In [18]:
TRAIN_PATH = 'data_train/Training-labeled/images/'
TRAIN_LABELS_PATH = 'data_train/Training-labeled/labels/'

TEST_PATH = 'data_test/Testing/Public/images/'
TEST_LABELS_PATH = 'data_test/Testing/Public/labels/'

VAL_PATH = 'data_val/Tuning/images/'
VAL_LABELS_PATH = 'data_val/Tuning/labels/'

use_cuda = torch.cuda.is_available()
if use_cuda:
  device = torch.device("cuda")
  dataloader_kwargs = {"batch_size": 32, "shuffle": True, "pin_memory": True}
else:
  device = torch.device("cpu")
  dataloader_kwargs = {"batch_size": 64}


### Data preparation
[Browse the data](https://drive.google.com/drive/folders/1MaJibsHYitCPOltxVzYjr3rm5s9Vpjpv)

In [26]:
!curl -o data_test.zip "https://zenodo.org/records/10719375/files/Testing.zip?download=1"
!unzip -d data_test data_test.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0 2793M    0 7354k    0     0  1900k      0  0:25:05  0:00:03  0:25:02 1899kArchive:  data_test.zip
  End-of-central-directory signature not found.  Either this file is not
  a zipfile, or it constitutes one disk of a multi-part archive.  In the
  latter case the central directory and zipfile comment will be found on
  the last disk(s) of this archive.
unzip:  cannot find zipfile directory in one of data_test.zip or
        data_test.zip.zip, and cannot find data_test.zip.ZIP, period.


In [None]:
!curl -o data_train.zip "https://zenodo.org/records/10719375/files/Training-labeled.zip?download=1"
!unzip -d data_train data_train.zip

zsh:1: no matches found: https://zenodo.org/records/10719375/files/Training-labeled.zip?download=1


unzip:  cannot find or open data_train.zip, data_train.zip.zip or data_train.zip.ZIP.


In [None]:
!curl -o data_val.zip "https://zenodo.org/records/10719375/files/Tuning.zip?download=1"
!unzip -d data_val data_val.zip

zsh:1: no matches found: https://zenodo.org/records/10719375/files/Tuning.zip?download=1
unzip:  cannot find or open data_val.zip, data_val.zip.zip or data_val.zip.ZIP.
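The logs above suggest that `unzip` ran on an incomplete or missing archive. One way to guard against that is to validate the archive before extracting it. The following is a minimal sketch using only the Python standard library; the helper names `is_complete_zip` and `download_and_extract` are hypothetical, and the sketch makes no attempt to resume interrupted downloads.

```python
import urllib.request
import zipfile
from pathlib import Path


def is_complete_zip(path) -> bool:
    """True if `path` exists and contains a readable ZIP end-of-central-directory record."""
    path = Path(path)
    return path.is_file() and zipfile.is_zipfile(str(path))


def download_and_extract(url: str, archive: str, out_dir: str) -> None:
    """Download `url` to `archive` (skipped if already complete) and unpack it into `out_dir`."""
    if not is_complete_zip(archive):
        urllib.request.urlretrieve(url, archive)
    if not is_complete_zip(archive):
        raise RuntimeError(f"{archive} is not a complete ZIP file; download it again")
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(out_dir)
```

With the URLs used in the cells above, a call would look like `download_and_extract("https://zenodo.org/records/10719375/files/Testing.zip?download=1", "data_test.zip", "data_test")`. Passing the URL as a Python string also avoids any shell globbing of the `?` character.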


In [53]:
# Partially adapted from https://colab.research.google.com/github/mim-ml-teaching/public-dnn-2024-25/blob/master/docs/DNN-Lab-7-UNet-in-Pytorch-student-version.ipynb
class ImageTiffDataset(torch.utils.data.Dataset):
 def __init__(self,
              image_dir: str,
              target_dir: str,
              cache_dir: str,
              filenames: List[str],
              transform: torch.nn.Module = transforms.ToTensor(),
              target_transform: torch.nn.Module = transforms.ToTensor()):
   self.image_dir = image_dir
   self.target_dir = target_dir
   self.cache_dir = cache_dir
   self.filenames = filenames
   self.transform = transform
   self.target_transform = target_transform

   # Create the cache layout even when cache_dir already exists
   # (e.g. a directory returned by tempfile.mkdtemp()).
   os.makedirs(os.path.join(self.cache_dir, "images"), exist_ok=True)
   os.makedirs(os.path.join(self.cache_dir, "target"), exist_ok=True)

 def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
   img_filename = self.filenames[idx]
   # The label is looked up under the same file name as the image. Appending
   # '.tiff' produced paths such as 'cell_00490.tif.tiff'. The pairing code in
   # later cells expects '<name>_label.tif(f)', so adjust this line if needed.
   target_filename = os.path.basename(img_filename)
   img_path = os.path.join(self.image_dir, img_filename)
   target_path = os.path.join(self.target_dir, target_filename)
   # Cache images and targets in separate subdirectories so that files sharing
   # a name do not overwrite each other.
   img_cache = os.path.join(self.cache_dir, "images", img_filename)
   target_cache = os.path.join(self.cache_dir, "target", target_filename)

   if not os.path.exists(img_cache):
     # PIL has no Image.load; Image.open is the correct call.
     img = Image.open(img_path)
     img = self.transform(img)
     torch.save(img, img_cache)
   else:
     img = torch.load(img_cache)

   if not os.path.exists(target_cache):
     target = Image.open(target_path)
     target = self.target_transform(target)
     torch.save(target, target_cache)
   else:
     target = torch.load(target_cache)

   return img, target

 def __len__(self) -> int:
   return len(self.filenames)

def make_tiff_dataset(image_dir: str, target_dir: str, cache_dir: str):
 filenames = []
 with os.scandir(image_dir) as it:
   for entry in it:
     if not entry.is_file():
       continue
     lower_name = entry.name.lower()
     if lower_name.endswith('.tiff') or lower_name.endswith('.tif'):
       filenames.append(entry.name)
 return ImageTiffDataset(image_dir, target_dir, cache_dir, filenames)

In [54]:
train_dataset = make_tiff_dataset(TRAIN_PATH, TRAIN_LABELS_PATH, tempfile.mkdtemp())
test_dataset = make_tiff_dataset(TEST_PATH, TEST_LABELS_PATH, tempfile.mkdtemp())
val_dataset = make_tiff_dataset(VAL_PATH, VAL_LABELS_PATH, tempfile.mkdtemp())
dataloader_train = torch.utils.data.DataLoader(train_dataset, **dataloader_kwargs)
dataloader_test = torch.utils.data.DataLoader(test_dataset, **dataloader_kwargs)
dataloader_val = torch.utils.data.DataLoader(val_dataset, **dataloader_kwargs)

print('Train:', len(train_dataset))
print('Test:', len(test_dataset))
print('Val:', len(val_dataset))

Train: 491
Test: 30
Val: 58


In [55]:
def train_unet(
   model: torch.nn.Module,
   device: torch.device,
   train_loader: torch.utils.data.DataLoader,
   optimizer: torch.optim.Optimizer,
   epoch: int,
   log_interval: int):
 model.train()
 correct = 0
 total_pixels = 0
 for batch_idx, (data, target) in enumerate(train_loader):
   data, target = data.to(device), target.to(device)
   # nll_loss expects integer class indices of shape (N, H, W); this assumes the
   # target already holds class indices, possibly with a channel axis of size 1.
   if target.dim() == 4:
     target = target.squeeze(1)
   target = target.long()
   optimizer.zero_grad()
   output = model(data)
   log_probs = F.log_softmax(output, dim=1)
   loss = F.nll_loss(log_probs, target)
   pred = log_probs.argmax(dim=1)  # index of the max log-probability per pixel
   correct += pred.eq(target).sum().item()
   total_pixels += target.numel()
   loss.backward()
   optimizer.step()
   if batch_idx % log_interval == 0:
     print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
         epoch,
         batch_idx * len(data),
         len(train_loader.dataset),
         100.0 * batch_idx / len(train_loader),
         loss.item(),
       ))
 # Count pixels as they are seen, so the accuracy does not depend on variables
 # that were only defined inside the logging branch.
 print(
   "Train accuracy: {}/{} ({:.0f}%)".format(
       correct,
       total_pixels,
       100.0 * correct / max(total_pixels, 1),
     )
 )
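Pixel accuracy, as printed by the training loop, can look high simply because most pixels are background. Overlap measures such as the Dice coefficient and intersection over union (IoU) are a common complement for binary segmentation. The sketch below is a NumPy illustration on binary masks; the function name `dice_and_iou` and the smoothing term `eps` are choices made here, not part of the notebook.

```python
import numpy as np


def dice_and_iou(pred, target, eps: float = 1e-7):
    """Dice coefficient and IoU between two binary masks of the same shape.

    Any non-zero value counts as foreground. `eps` avoids division by zero
    and makes two empty masks score 1.0.
    """
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    total = pred.sum() + target.sum()
    dice = (2.0 * intersection + eps) / (total + eps)
    iou = (intersection + eps) / (union + eps)
    return float(dice), float(iou)
```

For a model output of shape `(N, C, H, W)`, the predicted mask would be `output.argmax(dim=1).cpu().numpy()`, compared against the ground-truth mask binarized with `mask > 0`.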

## Basic U-Nets

In [56]:
class UNetConvBlock(nn.Module):
 def __init__(self, in_channels, out_channels):
   super().__init__()  # required before registering submodules
   self.layer = nn.Sequential(
       nn.Conv2d(
           in_channels,
           out_channels,
           kernel_size=3,
           padding=1,  # keeps the spatial size; dilation must be >= 1, so the default (1) is used
           padding_mode='reflect'),
       nn.ReLU() # Leaky ReLU?
   )

 def forward(self, x):
   return self.layer(x)

class UNetEncoderBlock(nn.Module):
 def __init__(self, in_channels: int, out_channels: int, maxpool: bool = True):
   super().__init__()
   assert out_channels > in_channels
   if maxpool:
     self.layer = nn.Sequential(
         UNetConvBlock(in_channels, out_channels),
         UNetConvBlock(out_channels, out_channels),
         nn.MaxPool2d(2)  # halves height and width
     )
   else:
     self.layer = nn.Sequential(
         UNetConvBlock(in_channels, out_channels),
         UNetConvBlock(out_channels, out_channels)
     )

 def forward(self, x):
   return self.layer(x)

class UNetDecoderBlock(nn.Module):
 def __init__(self, in_channels: int, out_channels: int, unmaxpool: bool = True):
   super().__init__()
   assert out_channels < in_channels

   if unmaxpool:
     # TODO: the upsampling ("unmaxpool") path is not implemented yet.
     raise NotImplementedError("UNetDecoderBlock with unmaxpool=True is still a TODO")
   self.layer = nn.Sequential(
       UNetConvBlock(in_channels, out_channels),
       UNetConvBlock(out_channels, out_channels)
   )

 def forward(self, x):
   return self.layer(x)

class UNet(nn.Module):
 def __init__(self, encoder_channels: List[int], decoder_channels: List[int]):
   super().__init__()
   assert len(encoder_channels) > 0
   self.encoder = nn.ModuleList()
   self.decoder = nn.ModuleList()

    #Expected:
    #UNetEncoderBlock(64, 128),
    #UNetEncoderBlock(128, 256),
    #UNetEncoderBlock(256, 512),
    #UNetEncoderBlock(512, 1024, False)
   in_channels = 1
   for out_channels in encoder_channels[:-1]:
     self.encoder.append(UNetEncoderBlock(in_channels, out_channels))
     in_channels = out_channels
   self.encoder.append(UNetEncoderBlock(in_channels, encoder_channels[-1], maxpool=False))
   in_channels = encoder_channels[-1]

   for out_channels in decoder_channels[:-1]:
     # Append a module, not the channel count (an int cannot be added to a ModuleList).
     self.decoder.append(UNetDecoderBlock(in_channels, out_channels, unmaxpool=False))
     in_channels = out_channels
   # TODO: the last decoder block, the skip connections and forward() are still missing.
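The decoder block above leaves its upsampling path as a TODO. For orientation, here is one common way such a block is written in PyTorch: a transposed convolution doubles the resolution, the result is concatenated with the skip feature map, and two 3x3 convolutions follow. This is a sketch under assumptions, not the design intended in this notebook; the class name `SketchDecoderBlock` and its `skip_channels` argument are hypothetical.

```python
import torch
import torch.nn as nn


class SketchDecoderBlock(nn.Module):
    """Hypothetical decoder block: upsample, concatenate the skip, then two 3x3 convs."""

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.convs = nn.Sequential(
            nn.Conv2d(out_channels + skip_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = self.up(x)
        # With odd input sizes the upsampled map is slightly smaller than the
        # skip feature map, so resize it before concatenating.
        if x.shape[2:] != skip.shape[2:]:
            x = nn.functional.interpolate(x, size=skip.shape[2:], mode="nearest")
        return self.convs(torch.cat([x, skip], dim=1))


block = SketchDecoderBlock(in_channels=128, skip_channels=64, out_channels=64)
out = block(torch.randn(1, 128, 17, 17), torch.randn(1, 64, 35, 35))
print(out.shape)
```

Using this pattern in an encoder requires returning the feature map before pooling, so that it can be passed as `skip`; an `nn.Sequential` that ends in `MaxPool2d` does not expose it.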



In [61]:
import os
import cv2
import torch
from torch.utils.data import Dataset

def find_image_mask_pairs(images_dir, masks_dir):
    image_files = [f for f in os.listdir(images_dir) if f.endswith('.tiff') or f.endswith('.tif')]
    mask_files = set(f for f in os.listdir(masks_dir) if f.endswith('.tiff') or f.endswith('.tif'))

    pairs = []
    for img_file in image_files:
        if img_file.endswith('.tiff'):
            base_name = img_file[:-5]
        else:
            base_name = img_file[:-4]

        candidate_mask_tiff = base_name + '_label.tiff'
        candidate_mask_tif = base_name + '_label.tif'

        if candidate_mask_tiff in mask_files:
            pairs.append((img_file, candidate_mask_tiff))
        elif candidate_mask_tif in mask_files:
            pairs.append((img_file, candidate_mask_tif))
        else:
            print(f"No mask found for image {img_file}")

    return pairs

class CellSegDataset(Dataset):
    def __init__(self, images_dir, masks_dir, target_size=(128,128)):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.target_size = target_size
        self.pairs = find_image_mask_pairs(images_dir, masks_dir)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_name, mask_name = self.pairs[idx]
        img_path = os.path.join(self.images_dir, img_name)
        mask_path = os.path.join(self.masks_dir, mask_name)

        # Load both files as 8-bit grayscale
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # Raise instead of returning (None, None), which the default DataLoader
        # collate function cannot batch
        if image is None:
            raise ValueError(f"Failed to load image: {img_path}")
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")

        # Resize
        image = cv2.resize(image, self.target_size, interpolation=cv2.INTER_AREA)
        mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)

        # Normalize the image to the range [0, 1]
        image = image.astype('float32') / 255.0
        # The mask can be left in its original scale (0/255) or normalized as well
        mask = mask.astype('float32') / 255.0

        # Add a channel axis (C=1) so the arrays are (C, H, W), as the network expects
        image = image[None, :, :]
        mask = mask[None, :, :]

        # Convert to torch tensors
        image = torch.from_numpy(image)
        mask = torch.from_numpy(mask)

        return image, mask

In [47]:
# Revised dataset with more robust TIFF handling (still to be verified)


import cv2
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
from PIL import Image  # Alternative for handling complex TIFFs

class CellSegmentationDatasetCV2(Dataset):
    """Robust dataset for TIFF files with better error handling"""
    
    def __init__(self, image_dir, mask_dir, target_size=(512, 512), use_pil=False):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.target_size = target_size
        self.use_pil = use_pil  # Option to use PIL instead of OpenCV
        
        # Only process TIFF files
        valid_exts = ('.tif', '.tiff')
        
        # Collect valid image-mask pairs
        self.image_mask_pairs = []
        for fname in os.listdir(image_dir):
            if fname.lower().endswith(valid_exts):
                name, _ = os.path.splitext(fname)
                # Look for corresponding mask
                for ext in valid_exts:
                    mask_fname = f"{name}_label{ext}"
                    mask_path = os.path.join(mask_dir, mask_fname)
                    if os.path.isfile(mask_path):
                        self.image_mask_pairs.append((fname, mask_fname))
                        break
        
        if len(self.image_mask_pairs) == 0:
            raise ValueError(f"No valid TIFF image-mask pairs found")
        
        print(f"Found {len(self.image_mask_pairs)} TIFF image-mask pairs")
        print(f"Using {'PIL' if use_pil else 'OpenCV'} for image loading")
    
    def __len__(self):
        return len(self.image_mask_pairs)
    
    def _load_and_resize_with_pil(self, path, is_mask=False):
        """Load and resize using PIL (more robust for complex TIFFs)"""
        img = Image.open(path)
        
        # Convert to numpy array
        img_array = np.array(img)
        
        # Resize using PIL
        if is_mask:
            img = img.resize(self.target_size, Image.NEAREST)
        else:
            img = img.resize(self.target_size, Image.BILINEAR)
        
        return np.array(img), img_array.shape, img_array.dtype
    
    def _load_and_resize_with_cv2(self, path, is_mask=False):
        """Load and resize using OpenCV with error handling"""
        # Load image
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise ValueError(f"Failed to load image: {path}")
        
        original_shape = img.shape
        original_dtype = img.dtype
        
        # Handle special cases before resize
        if len(img.shape) > 3:
            # Multi-page TIFF or 4D array - take first slice
            img = img[0] if img.shape[0] > 1 else img.squeeze(0)
        
        # Ensure image is 2D or 3D
        if len(img.shape) == 1:
            raise ValueError(f"1D image found at {path}")
        
        # Dtypes that cv2.resize handles directly; anything else (float64,
        # int32, bool, ...) is cast to float32 first.
        # Compare via img.dtype.type: a np.dtype instance does not hash like the
        # NumPy scalar type, so looking it up among scalar-type keys never matches.
        resize_safe_dtypes = (np.uint8, np.uint16, np.int16, np.float32)
        if img.dtype.type not in resize_safe_dtypes:
            img = img.astype(np.float32)
        
        # Resize
        try:
            interpolation = cv2.INTER_NEAREST if is_mask else cv2.INTER_LINEAR
            img_resized = cv2.resize(img, self.target_size, interpolation=interpolation)
        except cv2.error as e:
            print(f"OpenCV resize failed for {path}")
            print(f"Shape: {img.shape}, dtype: {img.dtype}")
            raise e
        
        return img_resized, original_shape, original_dtype
    
    def __getitem__(self, idx):
        img_fname, mask_fname = self.image_mask_pairs[idx]
        
        img_path = os.path.join(self.image_dir, img_fname)
        mask_path = os.path.join(self.mask_dir, mask_fname)
        
        try:
            # Load and resize images
            if self.use_pil:
                image, orig_img_shape, orig_img_dtype = self._load_and_resize_with_pil(img_path, is_mask=False)
                mask, orig_mask_shape, orig_mask_dtype = self._load_and_resize_with_pil(mask_path, is_mask=True)
            else:
                image, orig_img_shape, orig_img_dtype = self._load_and_resize_with_cv2(img_path, is_mask=False)
                mask, orig_mask_shape, orig_mask_dtype = self._load_and_resize_with_cv2(mask_path, is_mask=True)
            
            # Debug info for first image
            if idx == 0:
                print(f"\n[First TIFF Image Info]")
                print(f"Image path: {img_path}")
                print(f"Original shape: {orig_img_shape}, dtype: {orig_img_dtype}")
                print(f"After resize: {image.shape}, dtype: {image.dtype}")
                print(f"Value range: [{image.min()}, {image.max()}]")
                print(f"Mask unique values: {np.unique(mask)}")
            
        except Exception as e:
            print(f"\nError processing image {idx}: {img_fname}")
            print(f"Error details: {str(e)}")
            # Try with PIL if OpenCV failed
            if not self.use_pil:
                print("Retrying with PIL...")
                self.use_pil = True
                return self.__getitem__(idx)
            else:
                raise
        
        # Ensure mask is 2D
        if len(mask.shape) == 3:
            mask = mask[:, :, 0]
        elif len(mask.shape) != 2:
            raise ValueError(f"Unexpected mask shape: {mask.shape}")
        
        # Normalize image based on data type
        image = image.astype(np.float32)
        if orig_img_dtype == np.uint16 or image.max() > 255:
            # 16-bit image
            image = image / 65535.0
        elif orig_img_dtype == np.uint8 or (image.max() <= 255 and image.min() >= 0):
            # 8-bit image
            image = image / 255.0
        else:
            # Other types - normalize to [0, 1]
            img_min, img_max = image.min(), image.max()
            if img_max > img_min:
                image = (image - img_min) / (img_max - img_min)
        
        # Ensure values are in [0, 1]
        image = np.clip(image, 0, 1)
        
        # Convert mask to binary
        mask = (mask > 0).astype(np.int64)
        
        # Handle channel dimension
        if len(image.shape) == 2:
            # Grayscale: (H, W) -> (1, H, W)
            image = np.expand_dims(image, axis=0)
        else:
            # Multi-channel: (H, W, C) -> (C, H, W)
            image = image.transpose(2, 0, 1)
        
        # Convert to tensors
        image_tensor = torch.from_numpy(image.copy()).float()  # .copy() ensures contiguous memory
        mask_tensor = torch.from_numpy(mask.copy()).long()
        
        return image_tensor, mask_tensor


# Utility function to diagnose TIFF files
def diagnose_tiff_files(image_dir, mask_dir, num_samples=5):
    """Diagnose TIFF files to understand their format"""
    print("="*60)
    print("TIFF File Diagnosis")
    print("="*60)
    
    valid_exts = ('.tif', '.tiff')
    files = [f for f in os.listdir(image_dir) if f.lower().endswith(valid_exts)]
    
    for i, fname in enumerate(files[:num_samples]):
        print(f"\nFile {i+1}: {fname}")
        img_path = os.path.join(image_dir, fname)
        
        # Try OpenCV
        try:
            img_cv = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            if img_cv is not None:
                print(f"  OpenCV: shape={img_cv.shape}, dtype={img_cv.dtype}, "
                      f"range=[{img_cv.min()}, {img_cv.max()}]")
            else:
                print("  OpenCV: Failed to load")
        except Exception as e:
            print(f"  OpenCV: Error - {str(e)}")
        
        # Try PIL
        try:
            img_pil = Image.open(img_path)
            img_array = np.array(img_pil)
            print(f"  PIL: mode={img_pil.mode}, size={img_pil.size}, "
                  f"array_shape={img_array.shape}, dtype={img_array.dtype}")
        except Exception as e:
            print(f"  PIL: Error - {str(e)}")


# Test function
def test_dataset(image_dir, mask_dir, use_pil=False):
    """Test the dataset with both OpenCV and PIL"""
    print(f"\nTesting dataset with {'PIL' if use_pil else 'OpenCV'}...")
    
    dataset = CellSegmentationDatasetCV2(
        image_dir, mask_dir, 
        target_size=(512, 512),
        use_pil=use_pil
    )
    
    # Test first few samples
    for i in range(min(3, len(dataset))):
        try:
            img, mask = dataset[i]
            print(f"Sample {i}: Image {img.shape}, Mask {mask.shape} ✓")
        except Exception as e:
            print(f"Sample {i}: Failed - {str(e)}")
            return False
    
    return True


if __name__ == "__main__":
    # Example usage
    image_dir = TRAIN_PATH
    mask_dir = TRAIN_LABELS_PATH
    
    # First, diagnose the TIFF files
    diagnose_tiff_files(image_dir, mask_dir)
    
    # Test with OpenCV
    if not test_dataset(image_dir, mask_dir, use_pil=False):
        print("\nOpenCV failed, trying PIL...")
        test_dataset(image_dir, mask_dir, use_pil=True)

# import cv2
# import torch
# from torch.utils.data import Dataset
# import os

# class CellSegmentationDatasetCV2(Dataset):
#     def __init__(self, image_dir, mask_dir, target_size=(128, 128)):
#         self.image_dir = image_dir
#         self.mask_dir = mask_dir
#         self.target_size = target_size

#         # Only consider .tif and .tiff files
#         valid_exts = ('.tif', '.tiff')

#         # Collect the available images and check that the matching masks exist
#         self.image_mask_pairs = []
#         for fname in os.listdir(image_dir):
#             if fname.lower().endswith(valid_exts):
#                 name, _ = os.path.splitext(fname)
#                 # We assume the mask name has "_label" appended
#                 for ext in valid_exts:
#                     mask_fname = f"{name}_label{ext}"
#                     mask_path = os.path.join(mask_dir, mask_fname)
#                     if os.path.isfile(mask_path):
#                         self.image_mask_pairs.append((fname, mask_fname))
#                         break  # matching mask found, no need to check further

#     def __len__(self):
#         return len(self.image_mask_pairs)

#     def __getitem__(self, idx):
#         img_path, mask_path = self.image_mask_pairs[idx]

#         img_path = os.path.join(self.image_dir, img_path)
#         mask_path = os.path.join(self.mask_dir, mask_path)

#         image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
#         mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)

#         if image is None:
#             raise ValueError(f"Failed to load image: {img_path}")
#         if mask is None:
#             raise ValueError(f"Failed to load mask: {mask_path}")

#         print(f"Image shape before resize: {image.shape}, dtype: {image.dtype}")
#         print(f"Mask shape before resize: {mask.shape}, dtype: {mask.dtype}")

#         return image, mask

TIFF File Diagnosis

File 1: cell_00279.tif
  OpenCV: shape=(1608, 1608), dtype=uint16, range=[100, 276]
  PIL: mode=I;16, size=(1608, 1608), array_shape=(1608, 1608), dtype=uint16

File 2: cell_00379.tif
  OpenCV: shape=(970, 970), dtype=uint8, range=[6, 245]
  PIL: mode=L, size=(970, 970), array_shape=(970, 970), dtype=uint8

File 3: cell_00644.tif
  OpenCV: shape=(420, 350), dtype=uint16, range=[328, 6859]
  PIL: mode=I;16, size=(350, 420), array_shape=(420, 350), dtype=uint16

File 4: cell_00352.tif
  OpenCV: shape=(970, 970), dtype=uint8, range=[4, 255]
  PIL: mode=L, size=(970, 970), array_shape=(970, 970), dtype=uint8

File 5: cell_00263.tif
  OpenCV: shape=(1608, 1608), dtype=uint16, range=[99, 309]
  PIL: mode=I;16, size=(1608, 1608), array_shape=(1608, 1608), dtype=uint16

Testing dataset with OpenCV...
Found 491 TIFF image-mask pairs
Using OpenCV for image loading

[First TIFF Image Info]
Image path: data_train/Training-labeled/images/cell_00279.tif
Original shape: (1608, 16
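The diagnosis above shows 16-bit images whose values occupy only a small part of the range (for example [100, 276]). Dividing such an image by 65535 leaves it almost entirely black. A per-image percentile rescaling is one common alternative. The sketch below is an illustration, not the normalization used by the dataset class; the function name `percentile_normalize` and the 1st/99th percentile defaults are assumptions.

```python
import numpy as np


def percentile_normalize(image, low: float = 1.0, high: float = 99.0) -> np.ndarray:
    """Rescale an image to [0, 1] using its own low/high intensity percentiles."""
    image = np.asarray(image, dtype=np.float32)
    lo, hi = np.percentile(image, [low, high])
    if hi <= lo:
        # Constant (or nearly constant) image: nothing to stretch.
        return np.zeros_like(image)
    scaled = (image - lo) / (hi - lo)
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)
```

Values below the low percentile map to 0 and values above the high percentile map to 1, which also limits the influence of a few very bright pixels.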

In [None]:
def print_all_shapes(name, dataset):
    print(f"\n{name}: number of pairs: {len(dataset)}")
    for idx in range(len(dataset)):
        img, mask = dataset[idx]
        print(f"[{idx}] Image shape: {img.shape}, Mask shape: {mask.shape}")

# Load the data. ImageTiffDataset also requires cache_dir and filenames, so the
# datasets are built with make_tiff_dataset, as in the earlier cell.
train_dataset = make_tiff_dataset(TRAIN_PATH, TRAIN_LABELS_PATH, tempfile.mkdtemp())
#print(len(train_dataset))
val_dataset = make_tiff_dataset(VAL_PATH, VAL_LABELS_PATH, tempfile.mkdtemp())
test_dataset = make_tiff_dataset(TEST_PATH, TEST_LABELS_PATH, tempfile.mkdtemp())

# Print all shapes
print_all_shapes("Train", train_dataset)
print_all_shapes("Validation", val_dataset)
print_all_shapes("Test", test_dataset)

import numpy as np

def analyze_mask_values(dataset, name):
    print(f"\nMask analysis for dataset: {name}")
    unique_values = set()
    for idx in range(len(dataset)):
        _, mask = dataset[idx]
        vals = np.unique(mask)
        unique_values.update(vals.tolist())
        if idx < 5:  # show a few examples first
            print(f"[{idx}] Unique values: {vals}")
    print(f"All unique mask values ({name}): {sorted(unique_values)}")

analyze_mask_values(train_dataset, "Train")
analyze_mask_values(val_dataset, "Validation")
analyze_mask_values(test_dataset, "Test")

In [36]:
import cv2

img_path = "data_train/Training-labeled/images/cell_00302.tiff"
image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
print(image)

[[43. 42. 41. ... 44. 42. 43.]
 [43. 40. 40. ... 42. 43. 42.]
 [43. 41. 40. ... 42. 43. 41.]
 ...
 [51. 48. 46. ... 42. 43. 42.]
 [50. 48. 47. ... 43. 44. 42.]
 [49. 49. 50. ... 45. 46. 46.]]


## Basic UNet (UNet_1) from https://rpubs.com/eR_ic/unet

In [49]:
# This code implements a U-Net model for semantic segmentation from the paper U-Net: Convolutional Networks for Biomedical Image Segmentation
import torch
import torch.nn as nn
import torchvision.transforms.functional

# Implement the double 3X3 convolution blocks
# The original paper did not use padding, but we will use padding to keep the image size the same

class double_convolution(nn.Module):
    """
    This class implements the double convolution block which consists of two 3X3 convolution layers,
    each followed by a ReLU activation function.

    """
    def __init__(self, in_channels, out_channels): # Initialize the class
        super().__init__() # Initialize the parent class

        # First 3X3 convolution layer
        self.first_cnn = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.act1 = nn.ReLU()

        # Second 3X3 convolution layer
        self.second_cnn = nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.act2 = nn.ReLU()

    # Pass the input through the double convolution block
    def forward(self, x):
        x = self.first_cnn(x)
        x = self.act1(x)
        x = self.act2(self.second_cnn(x))
        return x


# Implement the Downsample block that occurs after each double convolution block
class down_sample(nn.Module):
    """
    This class implements the downsample block which consists of a Max Pooling layer with a kernel size of 2.
    The Max Pooling layer halves the image size reducing the spatial resolution of the feature maps
    while retaining the most important features.
    """
    def __init__(self):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)

    # Pass the input through the downsample block
    def forward(self, x):
        x = self.max_pool(x)
        return x

# Implement the UpSample block that occurs in the decoder part of the network
class up_sample(nn.Module):
    """
    This class implements the upsample block which consists of a convolution transpose layer with a kernel size of 2.
    The convolution transpose layer doubles the image size increasing the spatial resolution of the feature maps
    while retaining the learned features.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Convolution transpose layer
        self.up_sample = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 2, stride = 2)

    # Pass the input through the upsample block
    def forward(self, x):
        x = self.up_sample(x)
        return x

# Implement the crop and concatenate block that occurs in the decoder part of the network
# This block concatenates the output of the upsample block with the output of the corresponding downsample block
# The output of the crop and concatenate block is then passed through a double convolution block
class crop_and_concatenate(nn.Module):
    """
    This class implements the crop and concatenate block which combines the output of the upsample block
    with the corresponding features from the contracting path through skip connections,
    allowing the network to recover the fine-grained details lost during downsampling
    and produce a high-resolution output segmentation map.
    """
    # def forward(self, upsampled, bypass):
    #     # Crop the feature map from the contracting path to match the size of the upsampled feature map
    #     bypass = torchvision.transforms.functional.center_crop(img = bypass, output_size = [upsampled.shape[2], upsampled.shape[3]])
    #     # Concatenate the upsampled feature map with the cropped feature map from the contracting path
    #     x = torch.cat([upsampled, bypass], dim = 1) # Concatenate along the channel dimension
    #     return x
    # Alternatively, resize the upsampled feature map to match the size of the feature map from the contracting path
    def forward(self, upsampled, bypass):
        upsampled = torchvision.transforms.functional.resize(img = upsampled, size = bypass.shape[2:], antialias=True)
        x = torch.cat([upsampled, bypass], dim = 1) # Concatenate along the channel dimension
        return x

# m = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
# input = torch.randn(1, 1024, 28, 28)
# m(input).shape

# m = nn.MaxPool2d(kernel_size = 2, stride = 2)
# xx = torch.randn(1, 1, 143, 143)
# m(xx).shape

## Implement the UNet architecture
class UNet(nn.Module):
    # in_channels: number of channels in the input image
    # out_channels: number of channels in the output image
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Define the contracting path: convolution blocks followed by downsample blocks
        self.down_conv = nn.ModuleList(double_convolution(in_chans, out_chans) for in_chans, out_chans in
                                       [(in_channels, 64), (64, 128), (128, 256), (256, 512)]) # List of convolution blocks

        self.down_samples = nn.ModuleList(down_sample() for _ in range(4))

        # Define the bottleneck layer
        self.bottleneck = double_convolution(in_channels = 512, out_channels = 1024)

        # Define the expanding path: upsample blocks followed by convolution blocks
        self.up_samples = nn.ModuleList(up_sample(in_chans, out_chans) for in_chans, out_chans in
                                        [(1024, 512), (512, 256), (256, 128), (128, 64)]) # List of upsample blocks

        self.concat = nn.ModuleList(crop_and_concatenate() for _ in range(4))

        self.up_conv = nn.ModuleList(double_convolution(in_chans, out_chans) for in_chans, out_chans in
                                        [(1024, 512), (512, 256), (256, 128), (128, 64)]) # List of convolution blocks

        # Final 1X1 convolution layer to produce the output segmentation map:
        # The primary purpose of 1x1 convolutions is to transform the channel dimension of the feature map,
        # while leaving the spatial dimensions unchanged.
        self.final_conv = nn.Conv2d(in_channels = 64, out_channels = out_channels, kernel_size = 1)

    # Pass the input through the UNet architecture
    def forward(self, x):
        # Pass the input through the contracting path
        skip_connections = [] # List to store the output of each convolution block before downsampling
        for down_conv, down_sample in zip(self.down_conv, self.down_samples):
            x = down_conv(x)
            skip_connections.append(x)
            x = down_sample(x)

        # Pass the output of the contracting path through the bottleneck layer
        x = self.bottleneck(x)

        # Pass the output of the bottleneck layer through the expanding path
        skip_connections = skip_connections[::-1] # Reverse the list of skip connections
        for up_sample, concat, up_conv in zip(self.up_samples, self.concat, self.up_conv):
            x = up_sample(x)
            x = concat(x, skip_connections.pop(0)) # Remove the first element from the list of skip connections
            x = up_conv(x)

        # Pass the output of the expanding path through the final convolution layer
        x = self.final_conv(x)
        return x

# Sanity check for the model
import torchsummary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels = 3, out_channels = 1).to(device)
dummy_input = torch.randn((1, 3, 32, 32)).to(device)
mask = model(dummy_input)
mask.shape

# See how data flows through the network
torchsummary.summary(model, input_size = (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
              ReLU-2           [-1, 64, 32, 32]               0
            Conv2d-3           [-1, 64, 32, 32]          36,928
              ReLU-4           [-1, 64, 32, 32]               0
double_convolution-5           [-1, 64, 32, 32]               0
         MaxPool2d-6           [-1, 64, 16, 16]               0
       down_sample-7           [-1, 64, 16, 16]               0
            Conv2d-8          [-1, 128, 16, 16]          73,856
              ReLU-9          [-1, 128, 16, 16]               0
           Conv2d-10          [-1, 128, 16, 16]         147,584
             ReLU-11          [-1, 128, 16, 16]               0
double_convolution-12          [-1, 128, 16, 16]               0
        MaxPool2d-13            [-1, 128, 8, 8]               0
      down_sample-14            [-1, 1
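
The resize step in `crop_and_concatenate` matters whenever the input size is not divisible by 16. The sketch below is not part of the original notebook: it uses plain Python and a hypothetical helper name, `unet_spatial_sizes`, to trace one spatial dimension through four `MaxPool2d(kernel_size = 2, stride = 2)` steps (each floors the size to half) and four `ConvTranspose2d(kernel_size = 2, stride = 2)` steps (each doubles it exactly).

```python
def unet_spatial_sizes(size, depth=4):
    """Trace one spatial dimension through the encoder and decoder.

    Hypothetical helper for illustration only; it mirrors the layer
    arithmetic, not the layers themselves.
    """
    encoder_sizes = [size]
    for _ in range(depth):
        size = size // 2              # 2x2 max pooling with stride 2 floors odd sizes
        encoder_sizes.append(size)

    decoder_pairs = []                # (size after upsampling, size of the matching skip)
    for skip_size in reversed(encoder_sizes[:-1]):
        upsampled = size * 2          # 2x2 transposed convolution with stride 2 doubles the size
        decoder_pairs.append((upsampled, skip_size))
        size = skip_size              # the resize step brings the map to the skip's size
    return encoder_sizes, decoder_pairs


print(unet_spatial_sizes(32))
print(unet_spatial_sizes(143))
```

For a 32-pixel side the two numbers in each pair agree at every level. For a 143-pixel side (the size used in the commented-out pooling check above) the upsampled map is one pixel short at each level, which is the gap the resize closes.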

## Training UNet_1

In [50]:
from torch.nn import functional as F

def train_one_epoch(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    loss_fn: nn.Module,
    epoch: int = 0,
    log_interval: int = 10
):
    """
    Trains the model for one epoch.

    Args:
        model: the neural network (e.g., UNet)
        dataloader: DataLoader providing training batches
        optimizer: optimizer used for weight updates (e.g., Adam)
        device: device to run computations on (CPU or GPU)
        loss_fn: loss function (e.g., CrossEntropyLoss)
        epoch: current epoch number (used for logging)
        log_interval: print status every N batches

    Returns:
        avg_loss: average training loss over the epoch
        accuracy: pixel-level classification accuracy
    """

    model.train()
    running_loss = 0.0
    correct_pixels = 0
    total_pixels = 0

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)           # Input shape: (B, C, H, W)
        targets = targets.to(device)         # Target shape: (B, H, W), where each pixel has a class label

        optimizer.zero_grad()                # Clear gradients from previous step
        outputs = model(inputs)              # Output shape: (B, num_classes, H, W)

        loss = loss_fn(outputs, targets)     # Calculate loss
        loss.backward()                      # Backpropagation
        optimizer.step()                     # Update model weights

        running_loss += loss.item()          # Accumulate batch loss

        # Compute pixel-wise accuracy
        preds = outputs.argmax(dim=1)        # Get predicted class per pixel, shape: (B, H, W)
        correct_pixels += (preds == targets).sum().item()  # Count correct predictions
        total_pixels += targets.numel()      # Total number of pixels

        if batch_idx % log_interval == 0:
            print(
                f"[Epoch {epoch}] Batch {batch_idx}/{len(dataloader)} - "
                f"Loss: {loss.item():.4f}"
            )

    # Average loss and overall accuracy for the epoch
    avg_loss = running_loss / len(dataloader)
    accuracy = 100.0 * correct_pixels / total_pixels

    print(f"Epoch {epoch} finished. Avg Loss: {avg_loss:.4f} - Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy

@torch.no_grad()
def validate_one_epoch(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    device: torch.device,
    loss_fn: nn.Module
):
    """
    Evaluates the model on the validation set.

    Args:
        model: the neural network
        dataloader: DataLoader for validation data
        device: device to run computations on
        loss_fn: same loss function as used in training

    Returns:
        avg_loss: average validation loss
        accuracy: pixel-wise classification accuracy
    """

    model.eval() # Set model to evaluation mode (disables dropout, etc.)
    running_loss = 0.0
    correct_pixels = 0
    total_pixels = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = model(inputs)              # Forward pass
        loss = loss_fn(outputs, targets)     # Compute loss
        running_loss += loss.item()

        preds = outputs.argmax(dim=1)        # Get predicted class per pixel
        correct_pixels += (preds == targets).sum().item()
        total_pixels += targets.numel()

    avg_loss = running_loss / len(dataloader)
    accuracy = 100.0 * correct_pixels / total_pixels
    print(f"Validation - Avg Loss: {avg_loss:.4f} - Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy
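
To make the bookkeeping in these two functions concrete, the following sketch recomputes the same quantities by hand for a tiny two-class example. It is an illustration in plain Python, not part of the training code: `pixel_stats` is a hypothetical name, logits are laid out as nested lists `[class][row][col]`, and the loss is assumed to be cross-entropy averaged over all pixels, which is the default reduction of `nn.CrossEntropyLoss`.

```python
import math


def pixel_stats(logits, targets):
    """Return (mean cross-entropy, accuracy in percent) for one image.

    logits:  nested lists indexed as [class][row][col] (raw scores)
    targets: nested lists indexed as [row][col] (integer class labels)
    Hypothetical helper for illustration only.
    """
    n_classes = len(logits)
    total_loss, correct, total = 0.0, 0, 0
    for r, row in enumerate(targets):
        for c, label in enumerate(row):
            scores = [logits[k][r][c] for k in range(n_classes)]
            # Numerically stable log-sum-exp of the class scores
            top = max(scores)
            log_sum_exp = top + math.log(sum(math.exp(s - top) for s in scores))
            total_loss += log_sum_exp - scores[label]   # -log softmax(scores)[label]
            # The index of the largest score is the predicted label (first index on ties)
            predicted = scores.index(top)
            correct += int(predicted == label)
            total += 1
    return total_loss / total, 100.0 * correct / total


# A 1x2 image: the first pixel is classified correctly,
# the second pixel has equal scores for both classes.
logits = [[[2.0, 0.0]],    # scores for class 0
          [[0.0, 0.0]]]    # scores for class 1
targets = [[0, 1]]
loss, acc = pixel_stats(logits, targets)
print(f"loss={loss:.4f}, accuracy={acc:.1f}%")
```

The second pixel contributes `log(2)` to the loss because the two classes are scored equally, and it counts as an error because the tie resolves to class 0 in this sketch.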

In [51]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# in_channels=3 for RGB images, out_channels=2 for binary segmentation with 2 classes (background vs object)
model = UNet(in_channels=3, out_channels=2).to(device)

# Use the Adam optimizer with a small learning rate (good starting point for U-Net training)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Use CrossEntropyLoss, which is suitable for multi-class (including binary) classification problems
# It expects raw logits from the model and class indices (not one-hot) as targets
loss_fn = nn.CrossEntropyLoss()

# Set the number of training epochs
num_epochs = 30

# Training loop over all epochs
for epoch in range(1, num_epochs + 1):
    # Train the model for one epoch on the training dataset
    train_one_epoch(model, dataloader_train_cv, optimizer, device, loss_fn, epoch)

    # Evaluate the model on the validation dataset after each epoch
    validate_one_epoch(model, dataloader_val_cv, device, loss_fn)

ValueError: Input and output must have the same number of spatial dimensions, but got input with spatial dimensions of [1608] and output size of [512, 512]. Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.
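
The message above reports a single spatial dimension (`[1608]`), which suggests that the tensor handed to the resize call had only three dimensions, so that everything after the first two was treated as spatial. The preprocessing code is not shown here, so the following is only a sketch of that reading. The hypothetical pure-Python helper `add_leading_dims` shows which shape a resize to a 2-D `(height, width)` target expects, and how leading singleton dimensions (what `tensor[None]` or `unsqueeze(0)` would add) bring an unbatched sample to the `(N, C, H, W)` layout.

```python
def spatial_dims(shape):
    """Dimensions that an (N, C, d1, ..., dK) layout treats as spatial."""
    return list(shape[2:])


def add_leading_dims(shape, target_size):
    """Prepend singleton dimensions until the shape has len(target_size) + 2 entries.

    Hypothetical helper; it only manipulates shape tuples.
    """
    wanted_rank = len(target_size) + 2
    if len(shape) > wanted_rank:
        raise ValueError(f"shape {shape} has too many dimensions for target {target_size}")
    return (1,) * (wanted_rank - len(shape)) + tuple(shape)


unbatched = (1, 1200, 1608)           # example (C, H, W) sample; the height 1200 is made up
print(spatial_dims(unbatched))        # only the last entry counts as spatial
fixed = add_leading_dims(unbatched, (512, 512))
print(fixed, spatial_dims(fixed))     # now both height and width are spatial
```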

In [None]:
import os

path = "data_train/Training-labeled/images/cell_00311.tiff"
print(f"File exists: {os.path.exists(path)}")
print(f"File size: {os.path.getsize(path)} bytes")

import cv2

img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
print(type(img), img.shape if img is not None else "Can't load image")

File exists: True
File size: 8000272 bytes
<class 'numpy.ndarray'> (1000, 1000)


In [None]:
import os
from PIL import Image, UnidentifiedImageError

def validate_tiff_dataset(image_dir, label_dir):
    broken_images = []
    broken_labels = []

    image_files = sorted(os.listdir(image_dir))
    label_files = sorted(os.listdir(label_dir))

    print(f"Checking {len(image_files)} images and {len(label_files)} masks...")

    for filename in image_files:
        img_path = os.path.join(image_dir, filename)
        try:
            with Image.open(img_path) as img:
                img.verify()  # only verifies the file (does not load it fully)
        except (UnidentifiedImageError, OSError, FileNotFoundError) as e:
            print(f"[IMAGE ERROR] {filename}: {e}")
            broken_images.append(filename)

    for filename in label_files:
        label_path = os.path.join(label_dir, filename)
        try:
            with Image.open(label_path) as img:
                img.verify()
        except (UnidentifiedImageError, OSError, FileNotFoundError) as e:
            print(f"[MASK ERROR] {filename}: {e}")
            broken_labels.append(filename)

    print("\nSummary:")
    print(f"- Broken images: {len(broken_images)}")
    print(f"- Broken masks: {len(broken_labels)}")

    return broken_images, broken_labels

bad_imgs, bad_labels = validate_tiff_dataset(TRAIN_PATH, TRAIN_LABELS_PATH)

Checking 1000 images and 1000 masks...
[IMAGE ERROR] cell_00301.tiff: cannot identify image file 'data_train/Training-labeled/images/cell_00301.tiff'
[IMAGE ERROR] cell_00302.tiff: cannot identify image file 'data_train/Training-labeled/images/cell_00302.tiff'
[IMAGE ERROR] cell_00303.tiff: cannot identify image file 'data_train/Training-labeled/images/cell_00303.tiff'
[IMAGE ERROR] cell_00304.tiff: cannot identify image file 'data_train/Training-labeled/images/cell_00304.tiff'
[IMAGE ERROR] cell_00305.tiff: cannot identify image file 'data_train/Training-labeled/images/cell_00305.tiff'
[IMAGE ERROR] cell_00306.tiff: cannot identify image file 'data_train/Training-labeled/images/cell_00306.tiff'
[IMAGE ERROR] cell_00307.tiff: cannot identify image file 'data_train/Training-labeled/images/cell_00307.tiff'
[IMAGE ERROR] cell_00308.tiff: cannot identify image file 'data_train/Training-labeled/images/cell_00308.tiff'
[IMAGE ERROR] cell_00309.tiff: cannot identify image file 'data_train/Tra

In [None]:
Image.open('data_train/Training-labeled/images/cell_00501.tif').verify()

UnidentifiedImageError: cannot identify image file 'data_train/Training-labeled/images/cell_00501.tif'

In [None]:
target_np = cv2.imread('data_train/Training-labeled/labels/cell_00490_label.tiff', cv2.IMREAD_UNCHANGED)
print(target_np)

[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]


In [None]:
import cv2

for img in bad_imgs:
    path = os.path.join(TRAIN_PATH, img)
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        print(f"{img} cannot be opened with OpenCV")
    else:
        print(f"{img} OK: can use OpenCV instead of PIL")

cell_00301.tiff OK: can use OpenCV instead of PIL
cell_00302.tiff OK: can use OpenCV instead of PIL
cell_00303.tiff OK: can use OpenCV instead of PIL
cell_00304.tiff OK: can use OpenCV instead of PIL
cell_00305.tiff OK: can use OpenCV instead of PIL
cell_00306.tiff OK: can use OpenCV instead of PIL
cell_00307.tiff OK: can use OpenCV instead of PIL
cell_00308.tiff OK: can use OpenCV instead of PIL
cell_00309.tiff OK: can use OpenCV instead of PIL
cell_00310.tiff OK: can use OpenCV instead of PIL
cell_00311.tiff OK: can use OpenCV instead of PIL
cell_00312.tiff OK: can use OpenCV instead of PIL
cell_00313.tiff OK: can use OpenCV instead of PIL
cell_00314.tiff OK: can use OpenCV instead of PIL
cell_00315.tiff OK: can use OpenCV instead of PIL
cell_00316.tiff OK: can use OpenCV instead of PIL
cell_00501.tif OK: can use OpenCV instead of PIL
cell_00502.tif OK: can use OpenCV instead of PIL
cell_00503.tif OK: can use OpenCV instead of PIL
cell_00504.tif OK: can use Ope

## Testing UNet_1

In [None]:
@torch.no_grad()
def test_unet(model, dataloader, device):
    model.eval()
    total_correct = 0
    total_pixels = 0

    for images, masks in dataloader:
        images = images.to(device)             # (B, C, H, W)
        masks = masks.to(device)               # (B, H, W), ground truth with class labels

        outputs = model(images)                # (B, num_classes, H, W)
        preds = outputs.argmax(dim=1)          # (B, H, W)

        total_correct += (preds == masks).sum().item()
        total_pixels += masks.numel()

    accuracy = 100.0 * total_correct / total_pixels
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

import matplotlib.pyplot as plt

@torch.no_grad()
def visualize_predictions(model, dataloader, device, num_examples=3):
    model.eval()
    for images, masks in dataloader:
        images = images.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)

        for i in range(min(num_examples, images.size(0))):
            img = images[i].cpu().permute(1, 2, 0).numpy()
            mask = masks[i].cpu().numpy()
            pred = preds[i].cpu().numpy()

            fig, axs = plt.subplots(1, 3, figsize=(12, 4))
            axs[0].imshow(img)
            axs[0].set_title("Input Image")
            axs[1].imshow(mask)
            axs[1].set_title("Ground Truth")
            axs[2].imshow(pred)
            axs[2].set_title("Prediction")
            plt.show()
        break  # only one batch

In [None]:
test_unet(model, dataloader_test, device)
visualize_predictions(model, dataloader_test, device)

RuntimeError: stack expects each tensor to be equal size, but got [1, 566, 630] at entry 0 and [3, 944, 1266] at entry 1
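
The default collate function stacks samples into one tensor, so every sample in a batch needs the same channel count and spatial size. Here a `[1, 566, 630]` grayscale image meets a `[3, 944, 1266]` RGB one. One possible remedy, sketched below under assumptions, is to bring each sample to three channels and a fixed size inside the dataset before batching. The helper names are hypothetical and the sketch uses nested lists to stay dependency-free. In the notebook the same two steps would more likely be done on tensors, for example with `Tensor.repeat` and `torch.nn.functional.interpolate`. Nearest-neighbour sampling is used because it is also safe for label masks, where interpolating between class indices would invent labels.

```python
def to_three_channels(image):
    """Repeat a single-channel [C][H][W] image three times; pass RGB through."""
    if len(image) == 1:
        return [image[0], image[0], image[0]]
    if len(image) == 3:
        return image
    raise ValueError(f"unsupported channel count: {len(image)}")


def resize_nearest(channel, out_h, out_w):
    """Nearest-neighbour resize of one [H][W] channel."""
    in_h, in_w = len(channel), len(channel[0])
    return [
        [channel[r * in_h // out_h][c * in_w // out_w] for c in range(out_w)]
        for r in range(out_h)
    ]


def harmonize(image, out_h, out_w):
    """Bring a [C][H][W] image to shape [3][out_h][out_w]."""
    return [resize_nearest(ch, out_h, out_w) for ch in to_three_channels(image)]


gray = [[[1, 2, 3],
         [4, 5, 6]]]                                     # shape [1][2][3]
rgb = [[[0] * 4 for _ in range(4)] for _ in range(3)]    # shape [3][4][4]

batch = [harmonize(sample, 2, 2) for sample in (gray, rgb)]
print([(len(s), len(s[0]), len(s[0][0])) for s in batch])
```

After this step both samples share one shape and can be stacked. The target size is a free choice. A side length divisible by 16 keeps the four pooling steps of the U-Net exact.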

## ResidualAttentionUnet (UNet_2) from https://rpubs.com/eR_ic/unet

In [None]:
import torchvision

# Define a Residual block
class residual_block(nn.Module):
    """
    This class implements a residual block which consists of two convolution layers with group normalization
    """
    def __init__(self, in_channels, out_channels, n_groups = 8):
        super().__init__()
        # First convolution layer
        self.first_conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.first_norm = nn.GroupNorm(num_groups = n_groups, num_channels = out_channels)
        self.act1 = nn.SiLU() # Swish activation function

        # Second convolution layer
        self.second_conv = nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.second_norm = nn.GroupNorm(num_groups = n_groups, num_channels = out_channels)
        self.act2 = nn.SiLU() # Swish activation function

        # If the number of input channels is not equal to the number of output channels,
        # then use a 1X1 convolution layer to compensate for the difference in dimensions
        # This allows the input to have the same dimensions as the output of the residual block
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1)
        else:
            # Pass the input as is
            self.shortcut = nn.Identity()

    # Pass the input through the residual block
    def forward(self, x):
        # Store the input for the skip connection
        residual = x

        # Pass input through the first convolution layer and its own normalization layer
        x = self.act1(self.first_norm(self.first_conv(x)))

        # Pass the output of the first convolution layer through the second convolution layer
        x = self.act2(self.second_norm(self.second_conv(x)))

        # Add the (possibly projected) input to the output of the second convolution layer
        # This is the skip connection
        x = x + self.shortcut(residual)
        return x

# Implement the DownSample block that occurs after each residual block
class down_sample(nn.Module):
    def __init__(self):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)

    # Pass the input through the downsample block
    def forward(self, x):
        x = self.max_pool(x)
        return x

# Implement the UpSample block that occurs in the decoder path/expanding path
class up_sample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Convolution transpose layer to upsample the input
        self.up_sample = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 2, stride = 2)

    # Pass the input through the upsample block
    def forward(self, x):
        x = self.up_sample(x)
        return x

# Implement the crop and concatenate layer
class crop_and_concatenate(nn.Module):
    def forward(self, upsampled, bypass):
        # Resize the upsampled feature map to match the dimensions of the bypass feature map
        upsampled = torchvision.transforms.functional.resize(upsampled, size = bypass.shape[2:], antialias=True)
        x = torch.cat([upsampled, bypass], dim = 1) # Concatenate along the channel dimension
        return x

# Implement an attention block
class attention_block(nn.Module):
    def __init__(self, skip_channels, gate_channels, inter_channels = None, n_groups = 8):
        super().__init__()

        if inter_channels is None:
            inter_channels = skip_channels // 2

        # Implement W_g i.e the convolution layer that operates on the gate signal
        # Upsample gate signal to be the same size as the skip connection
        self.W_g = up_sample(in_channels = gate_channels, out_channels = skip_channels)
        #self.W_g_norm = nn.GroupNorm(num_groups = n_groups, num_channels = skip_channels)
        #self.W_g_act = nn.SiLU() # Swish activation function

        # Implement W_x i.e the convolution layer that operates on the skip connection
        self.W_x = nn.Conv2d(in_channels = skip_channels, out_channels = inter_channels, kernel_size = 1)
        #self.W_x_norm = nn.GroupNorm(num_groups = n_groups, num_channels = inter_channels)
        #self.W_x_act = nn.SiLU() # Swish activation function

        # Implement phi i.e the convolution layer that operates on the output of W_x + W_g
        self.phi = nn.Conv2d(in_channels = inter_channels, out_channels = 1, kernel_size = 1)
        #self.phi_norm = nn.GroupNorm(num_groups = n_groups, num_channels = 1)
        #self.phi_act = nn.SiLU() # Swish activation function

        # Implement the sigmoid activation function
        self.sigmoid = nn.Sigmoid()
        # Implement the Swish activation function
        self.act = nn.SiLU()

        # 1x1 convolution applied to the gated skip connection.
        # It is registered here so that its weights are trained and moved to the model's device.
        self.out_conv = nn.Conv2d(in_channels = skip_channels, out_channels = skip_channels, kernel_size = 1)

        # Implement final group normalization layer
        self.final_norm = nn.GroupNorm(num_groups = n_groups, num_channels = skip_channels)

    # Pass the input through the attention block
    def forward(self, skip_connection, gate_signal):
        # Upsample the gate signal and map it to the number of channels of the skip connection
        gate_signal = self.W_g(gate_signal)
        # Ensure that the sizes of the skip connection and the gate signal match before addition
        if gate_signal.shape[2:] != skip_connection.shape[2:]:
            gate_signal = torchvision.transforms.functional.resize(gate_signal, size = skip_connection.shape[2:], antialias=True)
        # Project to the intermediate channels (W_x is reused, so both signals share the same projection weights)
        gate_signal = self.W_x(gate_signal)

        # Project the skip connection to the intermediate channels
        skip_signal = self.W_x(skip_connection)

        # Add the skip connection and the gate signal
        add_xg = gate_signal + skip_signal

        # Pass the output of the addition through the activation function
        add_xg = self.act(add_xg)

        # Pass the output of attention through a 1x1 convolution layer to obtain the attention map
        attention_map = self.sigmoid(self.phi(add_xg))

        # Multiply the skip connection with the attention map
        # Perform element-wise multiplication (the single-channel map is broadcast over all channels)
        skip_connection = torch.mul(skip_connection, attention_map)

        # Apply the registered 1x1 convolution instead of creating a new, untrained layer on every call
        skip_connection = self.out_conv(skip_connection)
        skip_connection = self.act(self.final_norm(skip_connection))

        return skip_connection


## Implement a residual attention U-Net
class ResidualAttentionUnet(nn.Module):
    def __init__(self, in_channels, out_channels, n_groups = 4, n_channels = [64, 128, 256, 512, 1024]):
        super().__init__()

        # Define the contracting path: residual blocks followed by downsampling
        self.down_conv = nn.ModuleList(residual_block(in_chans, out_chans) for in_chans, out_chans in
                                       [(in_channels, n_channels[0]), (n_channels[0], n_channels[1]), (n_channels[1], n_channels[2]), (n_channels[2], n_channels[3])])
        self.down_samples = nn.ModuleList(down_sample() for _ in range(4))

        # Define the bottleneck residual block
        self.bottleneck = residual_block(n_channels[3], n_channels[4])

        # Define the attention blocks
        self.attention_blocks = nn.ModuleList(attention_block(skip_channels = residuals_chans, gate_channels = gate_chans) for gate_chans, residuals_chans in
                                              [(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])

        # Define the expanding path: upsample blocks, followed by crop and concatenate, followed by residual blocks
        self.upsamples = nn.ModuleList(up_sample(in_chans, out_chans) for in_chans, out_chans in
                                       [(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])

        self.concat = nn.ModuleList(crop_and_concatenate() for _ in range(4))

        self.up_conv = nn.ModuleList(residual_block(in_chans, out_chans) for in_chans, out_chans in
                                     [(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])

        # Final 1X1 convolution layer to produce the output segmentation map:
        # The primary purpose of 1x1 convolutions is to transform the channel dimension of the feature map,
        # while leaving the spatial dimensions unchanged.
        self.final_conv = nn.Conv2d(in_channels = n_channels[0] , out_channels = out_channels, kernel_size = 1)

    # Pass the input through the residual attention U-Net
    def forward(self, x):
        # Store the skip connections
        skip_connections = []
        # # Store the gate signals
        # gate_signals = []

        # Pass the input through the contracting path
        for down_conv, down_sample in zip(self.down_conv, self.down_samples):
            x = down_conv(x)
            skip_connections.append(x)
            #gate_signals.append(x)
            x = down_sample(x)

        # Pass the output of the contracting path through the bottleneck
        x = self.bottleneck(x)
        skip_connections.append(x)

        # Attention on the residual connections
        #skip_connections = skip_connections[::-1]
        n = len(skip_connections)
        indices = [(n - 1 - i, n - 2 - i) for i in range(n - 1)]
        attentions = []
        for i, g_x in enumerate(indices):
            g_gate = g_x[0]
            x_residual = g_x[1]
            attn = self.attention_blocks[i](skip_connections[x_residual], skip_connections[g_gate])
            attentions.append(attn)

        #attentions = attentions[::-1]

        # Pass the output of the attention blocks through the expanding path
        for up_sample, concat, up_conv in zip(self.upsamples, self.concat, self.up_conv):
            x = up_sample(x)
            x = concat(x, attentions.pop(0))
            x = up_conv(x)

        # Pass the output of the expanding path through the final convolution layer
        x = self.final_conv(x)
        return x

## Sanity check
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResidualAttentionUnet(in_channels = 3, out_channels = 1).to(device)
x = torch.randn((1, 3, 32, 32)).to(device)
mask = model(x)
mask.shape

# See how data flows through the network
torchsummary.summary(model, input_size = (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
         GroupNorm-2           [-1, 64, 32, 32]             128
              SiLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
         GroupNorm-5           [-1, 64, 32, 32]             128
              SiLU-6           [-1, 64, 32, 32]               0
            Conv2d-7           [-1, 64, 32, 32]             256
    residual_block-8           [-1, 64, 32, 32]               0
         MaxPool2d-9           [-1, 64, 16, 16]               0
      down_sample-10           [-1, 64, 16, 16]               0
           Conv2d-11          [-1, 128, 16, 16]          73,856
        GroupNorm-12          [-1, 128, 16, 16]             256
             SiLU-13          [-1, 128, 16, 16]               0
           Conv2d-14          [-1, 128,
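
In equation form, the gating computed by `attention_block.forward` can be summarised as follows. This is a paraphrase of the code above rather than a formula taken from the source. Here $x$ is the skip connection, $g$ is the gate signal, $\uparrow$ is the transposed-convolution upsampling `W_g` (followed by a resize where the sizes differ), $\mathrm{GN}$ is group normalization, and $W_o$ is a symbol introduced here for the 1x1 convolution applied after gating.

```latex
\begin{aligned}
q       &= \mathrm{SiLU}\big(W_x(\uparrow g) + W_x\,x\big) \\
\alpha  &= \sigma\big(\phi(q)\big) \\
\hat{x} &= \mathrm{SiLU}\big(\mathrm{GN}\big(W_o(\alpha \odot x)\big)\big)
\end{aligned}
```

The attention map $\alpha$ has a single channel, so the product $\alpha \odot x$ broadcasts it across all channels of $x$. In `ResidualAttentionUnet.forward`, the gate $g$ for each level is the encoder output one level deeper, with the bottleneck output gating the deepest skip connection.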

## Training UNet_2

## Testing UNet_2