<a href="https://colab.research.google.com/github/Navaneeth272001/GAN-Based-X-Ray-Artifact-Detection-and-Removal/blob/main/GAN_Based_X_Ray_Artifact_Removal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Notes
Let's start with a simple dataset to verify that the model works and that this approach is a viable solution. After that, we can move on to a more complex dataset with more variation.

## Simple Dataset
- `Thickness` of the grid lines is constant throughout the dataset
- `Distance` of the grid lines is constant throughout the dataset
- `Intensity` of the grid lines is constant throughout the dataset
- `Angle` of the grid lines is either 0° (`horizontal`) or 90° (`vertical`)

## Complex Dataset
- `Thickness` of the grid lines varies across the dataset but is constant within a single X-ray
- `Distance` between the grid lines varies across the dataset but is constant within a single X-ray
- `Intensity` of the grid lines varies across the dataset but is constant within a single X-ray
- `Angle` of the grid lines varies between 0° and 90° across the dataset but is constant within a single X-ray

In [None]:
# Mount Google Drive so the "GAN hack" folder (source images, dataset zip,
# checkpoint zip) is reachable under /content/drive/MyDrive.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

Mounted at /content/drive


# Simple Dataset Preparation

In [None]:
"""
PARAMETERS FOR THE DATASET CREATION
"""

THICKNESS = 2  # in opencv scale (can be thought of pixels)
DISTANCE = 5  # in pixels
INTENSITY = 15  # this value will be subracted to the R, G, B channels
ANGLES = ["0", "90"]  # angle for one X-ray is randomly sampled from this list

In [None]:
"""
import neccessary packages
"""

# built-in packages
import glob
import random
from typing import Tuple

# third-party packages
import cv2
import numpy as np
from google.colab.patches import cv2_imshow

In [None]:
"""
Return a binary image containing horizontal grid lines
"""

def get_horizontal_grid(nrows: int, ncols: int) -> np.ndarray:
    binary_mask = np.zeros((nrows, ncols, 1), np.uint8)

    for row_idx in range(nrows):
        if row_idx % DISTANCE == 0:
            cv2.line(binary_mask, (0, row_idx), (ncols, row_idx), 255, THICKNESS)

    return binary_mask

In [None]:
"""
Return a binary image containing vertical grid lines
"""

def get_vertical_grid(nrows: int, ncols: int) -> np.ndarray:
    binary_mask = np.zeros((nrows, ncols, 1), np.uint8)

    for col_idx in range(ncols):
        if col_idx % DISTANCE == 0:
            cv2.line(binary_mask, (col_idx, 0), (col_idx, nrows), 255, THICKNESS)

    return binary_mask

In [None]:
"""
ONLY FOR VERTICAL AND HORIZONTAL GRIDS

Draw grid lines on the given image.

Return a new image with grid lines and it's corresponding binary grid.
"""

def draw_grid_lines(image: np.ndarray):
    nrows, ncols, nchannels = image.shape

    # generate the binary grid mask
    angle = random.choice(ANGLES)
    if angle == "0":
        binary_grid = get_horizontal_grid(nrows, ncols)
    elif angle == "90":
        binary_grid = get_vertical_grid(nrows, ncols)
    else:
        raise Exception("Grid lines with angle other than 0° and 90° are not implemented")

    # apply the grid mask on the given image
    intensities = np.zeros((image.shape)) - INTENSITY
    grid_masked_image = cv2.add(image, intensities, mask=binary_grid, dtype=cv2.CV_8U)
    bg_image = cv2.bitwise_and(image, image, mask=cv2.bitwise_not(binary_grid))

    grid_image = cv2.add(bg_image, grid_masked_image)

    return grid_image, binary_grid

In [None]:
# Quick visual check of the simple grid on one image.
# data/trainB is only populated by the "Generate Dataset" cells (or by unzipping
# complex_data.zip), so on a fresh run this file may not exist yet.
orig_image = cv2.imread("data/trainB/0.png")

if orig_image is None:
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    print("data/trainB/0.png not found - run the 'Generate Dataset' cells first; skipping preview.")
else:
    out, _ = draw_grid_lines(orig_image)
    cv2_imshow(out)

AttributeError: ignored

# Complex Dataset Preparation

In [None]:
"""
PARAMETERS FOR THE DATASET CREATION
"""

THICKNESS_RANGE = [1, 3]  # in opencv scale (can be thought of pixels)
DISTANCE_RANGE = [7, 9]  # in pixels
INTENSITY_RANGE = [10, 20]  # this value will be subracted to the R, G, B channels

In [None]:
"""
import neccessary packages
"""

# built-in packages
import glob
import random
from typing import Tuple

# third-party packages
import cv2
import numpy as np
from google.colab.patches import cv2_imshow

In [None]:
"""
Return a binary image containing grid lines in random angle
"""

def get_random_grid(nrows: int, ncols: int) -> np.ndarray:
    DISTANCE = np.random.randint(DISTANCE_RANGE[0], DISTANCE_RANGE[1] + 1)
    THICKNESS = np.random.randint(THICKNESS_RANGE[0], THICKNESS_RANGE[1] + 1)

    binary_mask = np.zeros((nrows, ncols, 1), np.uint8)

    i = np.random.randint(0, nrows)
    j = np.random.randint(0, ncols)

    count = 0
    while i - count >= 0 or j - count >= 0:
        if count % DISTANCE != 0:
            count += 1
            continue

        p1 = (0, i - count)
        p2 = (ncols, j - count)

        cv2.line(binary_mask, p1, p2, 255, THICKNESS)

        count += 1

    count = 0
    while i + count <= nrows or j + count <= nrows:
        if count % DISTANCE != 0:
            count += 1
            continue

        p1 = (0, i + count)
        p2 = (ncols, j + count)

        cv2.line(binary_mask, p1, p2, 255, THICKNESS)

        count += 1

    return binary_mask

In [None]:
"""
Draw grid lines on the given image.

Return a new image with grid lines and it's corresponding binary grid.
"""

def draw_grid_lines(image: np.ndarray):
    INTENSITY = np.random.randint(INTENSITY_RANGE[0], INTENSITY_RANGE[1] + 1)
    nrows, ncols, nchannels = image.shape

    # generate the binary grid mask
    binary_grid = get_random_grid(nrows, ncols)

    # apply the grid mask on the given image
    intensities = np.zeros((image.shape)) - INTENSITY
    grid_masked_image = cv2.add(image, intensities, mask=binary_grid, dtype=cv2.CV_8U)
    bg_image = cv2.bitwise_and(image, image, mask=cv2.bitwise_not(binary_grid))

    grid_image = cv2.add(bg_image, grid_masked_image)

    return grid_image, binary_grid

In [None]:
# Preview the complex grid on one image resized to the dataset resolution.
# NOTE(review): 0.png is expected in the working directory; this notebook does not create it.
orig_image = cv2.imread("0.png")

if orig_image is None:
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    print("0.png not found in the working directory - skipping preview.")
else:
    orig_image = cv2.resize(orig_image, (256, 256))
    out, mask = draw_grid_lines(orig_image)
    cv2_imshow(out)

# Generate Dataset

In [None]:
# Copy the source X-ray images from Google Drive into ./images, which the
# dataset-generation cells below read with glob("images/*").
!cp -R "/content/drive/MyDrive/GAN hack/images" .

In [None]:
"""
Utility function to draw grids on all images in the given source directory
and save it in the given out directory.
"""

def draw_grids(src_dir: str, out_dir: str) -> None:
    # iterate through the file names in the src dir
    image_paths = glob.glob(src_dir + "/*")
    for idx, image_path in enumerate(image_paths):
        orig_image = cv2.imread(image_path)
        grid_image, grid = draw_grid_lines(orig_image)

        file_name = f"{out_dir}/{idx}.png"
        cv2.imwrite(file_name, grid_image)

In [None]:
# Generate B domain images (without grid) for train & test
import os

SPLIT_SEED = 42  # fixed seed so the sampled train/test split is reproducible
N_TRAIN, N_TEST = 3000, 300
IMAGE_SIZE = (256, 256)

# Sort before sampling: glob() order depends on the filesystem, so a seed alone
# would not pin down which files are picked. A private Random instance leaves
# the global `random` state untouched.
# NOTE: the A-domain split must be drawn from the same sorted file list with the
# same seed; otherwise a test image of one domain can appear in the other
# domain's training set.
images = random.Random(SPLIT_SEED).sample(sorted(glob.glob("images/*")), N_TRAIN + N_TEST)

train_images = images[:N_TRAIN]
test_images = images[N_TRAIN:]

for split_name, split_images in (("trainB", train_images), ("testB", test_images)):
    out_dir = f"data/{split_name}"
    # cv2.imwrite does not create missing directories; it silently returns False.
    os.makedirs(out_dir, exist_ok=True)
    for idx, image_path in enumerate(split_images):
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        image = cv2.resize(image, IMAGE_SIZE)
        out_path = f"{out_dir}/{idx}.png"
        if not cv2.imwrite(out_path, image):
            raise OSError(f"could not write image: {out_path}")

In [None]:
# Generate A domain images (with grid) for train & test
import os

SPLIT_SEED = 42  # fixed seed so the sampled train/test split is reproducible
N_TRAIN, N_TEST = 3000, 300
IMAGE_SIZE = (256, 256)

# Sort before sampling: glob() order depends on the filesystem, so a seed alone
# would not pin down which files are picked. A private Random instance leaves
# the global `random` state untouched.
# NOTE: the B-domain split must be drawn from the same sorted file list with the
# same seed; otherwise a test image of one domain can appear in the other
# domain's training set.
images = random.Random(SPLIT_SEED).sample(sorted(glob.glob("images/*")), N_TRAIN + N_TEST)

train_images = images[:N_TRAIN]
test_images = images[N_TRAIN:]

for split_name, split_images in (("trainA", train_images), ("testA", test_images)):
    out_dir = f"data/{split_name}"
    # cv2.imwrite does not create missing directories; it silently returns False.
    os.makedirs(out_dir, exist_ok=True)
    for idx, image_path in enumerate(split_images):
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        image = cv2.resize(image, IMAGE_SIZE)
        grid_image, _ = draw_grid_lines(image)
        out_path = f"{out_dir}/{idx}.png"
        if not cv2.imwrite(out_path, grid_image):
            raise OSError(f"could not write image: {out_path}")

In [None]:
# Bundle the generated data/ tree (trainA, trainB, testA, testB) into one archive.
!zip -rq complex_data.zip data

In [None]:
# Persist the dataset archive on Google Drive; the Training section restores it from there.
!cp complex_data.zip "/content/drive/MyDrive/GAN hack/"

# Training

In [None]:
# load the data from drive
# -o overwrites existing files without prompting: without it, unzip asks
# "replace ...?" when data/ already exists and the cell blocks waiting for input.
!cp "/content/drive/MyDrive/GAN hack/complex_data.zip" .
!unzip -oq complex_data.zip

In [None]:
# Restore previously saved CycleGAN checkpoints from Drive so training can resume.
# -o overwrites existing files without prompting (otherwise unzip blocks on a
# "replace ...?" question when checkpoints/ already exists).
!cp "/content/drive/MyDrive/GAN hack/checkpoints.zip" .
!unzip -oq checkpoints.zip

In [None]:
# Clone the CycleGAN/pix2pix reference implementation. The clone is skipped if the
# directory already exists, so re-running the cell does not fail.
# %pip (unlike !pip) installs into the environment of the running kernel.
![ -d pytorch-CycleGAN-and-pix2pix ] || git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git
%pip install -r pytorch-CycleGAN-and-pix2pix/requirements.txt

Cloning into 'pytorch-CycleGAN-and-pix2pix'...
remote: Enumerating objects: 2513, done.[K
remote: Total 2513 (delta 0), reused 0 (delta 0), pack-reused 2513[K
Receiving objects: 100% (2513/2513), 8.20 MiB | 7.07 MiB/s, done.
Resolving deltas: 100% (1575/1575), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dominate>=2.4.0
  Downloading dominate-2.7.0-py2.py3-none-any.whl (29 kB)
Collecting visdom>=0.1.8.8
  Downloading visdom-0.2.4.tar.gz (1.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m38.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting wandb
  Downloading wandb-0.15.0-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m76.0 MB/s[0m eta [36m0:00:00[0m
Collecting jsonpatch
  Downloading jsonpatch-1.32-py2.py3-none-any.whl (12 kB)
Collecting pathtools
  Download

In [None]:
# TRAINING
# --continue_train resumes from saved weights; the options printout shows `epoch: latest`,
# i.e. the latest_net_*.pth files restored from checkpoints.zip.
# --epoch_count is hard-coded (89): it only sets the starting epoch number and presumably must be
# bumped to (last completed epoch + 1) before every resume - TODO confirm against loss_log.txt.
# --save_epoch_freq 1 writes a full <epoch>_net_*.pth set every epoch, so checkpoints.zip keeps growing.
# --display_id -1 presumably turns off the visdom display, making --display_port/--display_server unused.
!python pytorch-CycleGAN-and-pix2pix/train.py --dataroot ./data --name xray_cyclegan --model cycle_gan --display_id -1  --display_port 8097 --display_server http://localhost --save_epoch_freq 1 --epoch_count 89 --continue_train

----------------- Options ---------------
               batch_size: 1                             
                    beta1: 0.5                           
          checkpoints_dir: ./checkpoints                 
           continue_train: True                          	[default: False]
                crop_size: 256                           
                 dataroot: ./data                        	[default: None]
             dataset_mode: unaligned                     
                direction: AtoB                          
              display_env: main                          
             display_freq: 400                           
               display_id: -1                            	[default: 1]
            display_ncols: 4                             
             display_port: 8097                          
           display_server: http://localhost              
          display_winsize: 256                           
                    epoch: latest         

In [None]:
# SAVE THE MODEL WEIGHTS
# Re-zips the whole checkpoints/ tree (every saved epoch, latest_* weights, loss_log.txt and any
# .ipynb_checkpoints folder) and copies it to Drive; the Training section restores it from there.

!zip -r checkpoints.zip checkpoints
!cp checkpoints.zip "/content/drive/MyDrive/GAN hack/"

updating: checkpoints/xray_cyclegan/ (stored 0%)
updating: checkpoints/xray_cyclegan/.ipynb_checkpoints/ (stored 0%)
updating: checkpoints/xray_cyclegan/75_net_D_A.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/75_net_D_B.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/75_net_G_A.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/75_net_G_B.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/76_net_D_A.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/76_net_D_B.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/76_net_G_A.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/76_net_G_B.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/latest_net_D_A.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/latest_net_D_B.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/latest_net_G_A.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/latest_net_G_B.pth (deflated 7%)
updating: checkpoints/xray_cyclegan/loss_log.txt (deflated 83%)
updating: checkpoints/x

In [None]:
# TESTING
# Runs the `latest` xray_cyclegan weights on CPU (--gpu_ids -1) for a single image (--num_test 1).
# NOTE(review): with dataset_mode unaligned this presumably reads test/testA (and test/testB), which this
# notebook never creates - the generated test split lives in data/testA and data/testB. Confirm the
# intended --dataroot.

!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot test --name xray_cyclegan --model cycle_gan --no_dropout --num_test 1 --gpu_ids -1 --eval

----------------- Options ---------------
             aspect_ratio: 1.0                           
               batch_size: 1                             
          checkpoints_dir: ./checkpoints                 
                crop_size: 256                           
                 dataroot: test                          	[default: None]
             dataset_mode: unaligned                     
                direction: AtoB                          
          display_winsize: 256                           
                    epoch: latest                        
                     eval: True                          	[default: False]
                  gpu_ids: -1                            	[default: 0]
                init_gain: 0.02                          
                init_type: normal                        
                 input_nc: 3                             
                  isTrain: False                         	[default: None]
                load_iter: