In [2]:
import numpy as np
import tensorflow as tf
import os
import pydicom
import cv2
import matplotlib.pyplot as plt
from pydicom.dataset import Dataset, FileDataset
from pydicom.uid import generate_uid
from datetime import datetime
import time

In [3]:
def load_dicom_images(dicom_dir, image_size=(128, 128)):
   
    images = []
    for filename in os.listdir(dicom_dir):
        if filename.endswith(".dcm"):
            filepath = os.path.join(dicom_dir, filename)
            dicom = pydicom.dcmread(filepath)
            image = dicom.pixel_array
            image = cv2.resize(image, image_size)
            image = (image - np.min(image)) / (np.max(image) - np.min(image))
            images.append(image)

    images = np.array(images)
    images = np.expand_dims(images, axis=-1)
    return images

# Load images
dicom_dir = '/home/matrix/Downloads/DIcom gans/Data/raw'
images = load_dicom_images(dicom_dir)
print(f"Loaded {len(images)} DICOM images.")

Loaded 450 DICOM images.


In [4]:
from tensorflow.keras import layers

In [5]:
def build_generator(input_shape=(100,)):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128 * 32 * 32, input_dim=input_shape[0]),
        tf.keras.layers.Reshape((32, 32, 128)),
        tf.keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Conv2D(1, (7, 7), activation='tanh', padding='same')
    ])
    return model

In [6]:

def build_discriminator(input_shape=(128, 128, 1)):
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(64, (3, 3), padding='same', input_shape=input_shape),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    return model

In [7]:
# Cell 3: Build GAN

def build_gan(generator, discriminator):
    discriminator.trainable = False
    gan_input = tf.keras.layers.Input(shape=(100,))
    generated_image = generator(gan_input)
    gan_output = discriminator(generated_image)
    gan = tf.keras.Model(gan_input, gan_output)
    return gan

In [8]:
# Initialize models
generator = build_generator()
discriminator = build_discriminator()
gan = build_gan(generator, discriminator)

# Compile models
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
gan.compile(optimizer='adam', loss='binary_crossentropy')


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [9]:
def save_dicom_image(image, filepath, original_dicom):
    """
    Save a single image as a DICOM file, preserving metadata from the original DICOM.
    """
    ds = original_dicom.copy()
    
    # Update necessary attributes
    ds.SOPInstanceUID = generate_uid()
    ds.file_meta.MediaStorageSOPInstanceUID = ds.SOPInstanceUID
    ds.PixelData = (image * 255).astype(np.uint8).tobytes()
    ds.Rows, ds.Columns = image.shape
    
    # Save the DICOM file
    ds.save_as(filepath)
    print(f"DICOM file saved: {filepath}")


In [9]:
# # Compile models
# discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# gan.compile(optimizer='adam', loss='binary_crossentropy')

In [10]:
# # Training loop
# real_labels = np.ones((batch_size, 1))
# fake_labels = np.zeros((batch_size, 1))

# for epoch in range(epochs):
#     # Train discriminator
#     idx = np.random.randint(0, images.shape[0], batch_size)
#     real_images = images[idx]
#     noise = np.random.normal(0, 1, (batch_size, 100))
#     fake_images = generator.predict(noise)
    
#     d_loss_real = discriminator.train_on_batch(real_images, real_labels)
#     d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
    
#     # Train generator
#     noise = np.random.normal(0, 1, (batch_size, 100))
#     g_loss = gan.train_on_batch(noise, real_labels)
    
#     if epoch % 100 == 0:
#         print(f"{epoch} [D loss: {0.5 * np.add(d_loss_real, d_loss_fake)}] [G loss: {g_loss}]")
#         # Optionally save images and model checkpoints here


In [11]:
def save_images_as_dicom(generated_images, epoch, original_dicom, save_dir='outputs/generated_dicom'):
    """
    Save generated images as DICOM files.
    """
    os.makedirs(save_dir, exist_ok=True)
    for i, img in enumerate(generated_images):
        filepath = os.path.join(save_dir, f"epoch{epoch}_img{i}.dcm")
        save_dicom_image(img.squeeze(), filepath, original_dicom)

In [14]:
batch_size = 32
epochs = 1000  # Increase the number of epochs for better results
save_interval = 100  # Save generated images every 100 epochs

# Get a sample original DICOM for metadata
sample_dicom = pydicom.dcmread(os.path.join(dicom_dir, os.listdir(dicom_dir)[0]))

real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))

for epoch in range(epochs):
    start_time = time.time()

    # Train discriminator
    idx = np.random.randint(0, images.shape[0], batch_size)
    real_images = images[idx]
    noise = np.random.normal(0, 1, (batch_size, 100))
    fake_images = generator.predict(noise)

    d_loss_real = discriminator.train_on_batch(real_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)

    # Train generator
    noise = np.random.normal(0, 1, (batch_size, 100))
    g_loss = gan.train_on_batch(noise, real_labels)

    # End of training step
    end_time = time.time()
    epoch_time = end_time - start_time

    # Print progress
    if epoch % 10 == 0:
        print(f"Epoch {epoch + 1}/{epochs}, Time: {epoch_time:.2f}s, "
              f"D loss: {0.5 * (d_loss_real[0] + d_loss_fake[0]):.4f}, "
              f"G loss: {g_loss[0]:.4f}")

    # Save generated images as DICOM
    if epoch % save_interval == 0:
        save_images_as_dicom(fake_images, epoch, sample_dicom)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 693ms/step
Epoch 1/1000, Time: 8.05s, D loss: 0.6887, G loss: 0.6897
DICOM file saved: outputs/generated_dicom/epoch0_img0.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img1.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img2.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img3.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img4.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img5.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img6.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img7.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img8.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img9.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img10.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img11.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img12.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img13.dcm
DICOM file saved: outputs/generated_dicom/epoch0_img1

In [None]:
# Add this after the training loop

num_to_generate = 450  # Set this to how many new images you want
noise = np.random.normal(0, 1, (num_to_generate, 100))
generated_images = generator.predict(noise)
save_images_as_dicom(generated_images, "final_batch", sample_dicom)

print(f"Generated {num_to_generate} new DICOM images.")