In [1]:
import tensorflow as tf

# Report the installed TensorFlow build and whether it can see at least one GPU.
print("TensorFlow version:", tf.__version__)

gpu_devices = tf.config.list_physical_devices('GPU')
print("GPU is available." if gpu_devices else "GPU is not available.")


TensorFlow version: 2.10.1
GPU is available.


In [2]:
!pip show tqdm


Name: tqdm
Version: 4.66.5
Summary: Fast, Extensible Progress Meter
Home-page: https://tqdm.github.io
Author: 
Author-email: 
License: MPL-2.0 AND MIT
Location: c:\users\zap\.conda\envs\tf_cuda\lib\site-packages
Requires: colorama
Required-by: 


: 

In [3]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tqdm import tqdm  # For the progress bar
import os
import pandas as pd
from PIL import Image
import numpy as np

class ChestXrayDataset:
    """Factory that builds a ``tf.data.Dataset`` of NIH ChestX-ray14 samples.

    ``ChestXrayDataset(csv_file, root_dir, transform)`` does NOT return an instance of
    this class: ``__new__`` returns a ``tf.data.Dataset`` whose elements are
    ``(image, multi_hot_label)`` pairs of ``tf.float32`` tensors, one per CSV row.
    Images are decoded lazily (inside ``tf.py_function``) while the dataset is iterated.

    The CSV's first column is read as the image file name and its second column as a
    ``'|'``-separated list of finding labels (e.g. ``'Effusion|Nodule'``).
    """

    # Shape of the black placeholder image returned when a file is missing/unreadable.
    FALLBACK_IMAGE_SHAPE = (256, 256, 3)

    # Finding labels that are folded into a combined class of get_classes().
    # In the NIH Data_Entry_2017.csv, 'Nodule' and 'Mass' are separate findings, so an
    # exact string match against 'Nodule Mass' would never fire.
    _LABEL_ALIASES = {'Nodule': 'Nodule Mass', 'Mass': 'Nodule Mass'}

    # Sentinel meaning "argument not given: use the class-level default".
    _UNSET = object()

    def __new__(cls, csv_file, root_dir, transform=None):
        """Read ``csv_file`` and return a lazily-loading ``tf.data.Dataset``.

        Args:
            csv_file: path to the labels CSV (column 0: image file name,
                column 1: '|'-separated finding labels).
            root_dir: directory that contains the ``images_NNN/`` folders
                (see ``find_image``).
            transform: optional callable applied to each decoded RGB image (a uint8
                NumPy array); assumed to keep the (H, W, C) rank.

        Returns:
            ``tf.data.Dataset`` of ``(image, label)`` float32 pairs, where ``label``
            has ``len(get_classes())`` entries.
        """
        # Still published as class attributes for backward compatibility with code that
        # calls find_image / _load_image_and_label without explicit arguments.
        cls.root_dir = root_dir
        cls.transform = transform

        labels_frame = pd.read_csv(csv_file)
        img_names = labels_frame.iloc[:, 0].values
        labels = labels_frame.iloc[:, 1].values
        dataset = tf.data.Dataset.from_tensor_slices((img_names, labels))
        num_classes = len(cls.get_classes())

        def _load_numpy(img_name, label):
            # root_dir/transform are bound here instead of being read from the class at
            # iteration time: the map below runs lazily, so reading class attributes
            # would let a later ChestXrayDataset(...) call redirect this dataset.
            return cls._load_image_and_label(img_name, label, root_dir=root_dir, transform=transform)

        def _load(img_name, label):
            image, one_hot_label = tf.py_function(
                func=_load_numpy,
                inp=[img_name, label],  # only the tensors; config is captured above
                Tout=(tf.float32, tf.float32),
            )
            # py_function erases static shapes; restore what is known so batching and
            # Keras shape checks see a rank-3 image and a fixed-length label vector.
            image.set_shape([None, None, None])
            one_hot_label.set_shape([num_classes])
            return image, one_hot_label

        return dataset.map(_load, num_parallel_calls=tf.data.AUTOTUNE)

    @staticmethod
    def _load_image_and_label(img_name, label, root_dir=None, transform=_UNSET):
        """Decode one image and multi-hot encode its labels (runs eagerly in tf.py_function).

        Args:
            img_name: scalar string tensor holding the image file name.
            label: scalar string tensor holding '|'-separated finding labels.
            root_dir: search root; ``None`` uses the class attribute set by the last
                ``ChestXrayDataset(...)`` call.
            transform: image transform; when omitted, the class attribute is used.

        Returns:
            ``(image, one_hot_label)``. If the image is not found, or any error occurs,
            a black ``FALLBACK_IMAGE_SHAPE`` uint8 image and an all-zero label are
            returned instead (best effort: one bad row must not abort an epoch). The
            placeholder image is not passed through ``transform``.
        """
        if transform is ChestXrayDataset._UNSET:
            transform = getattr(ChestXrayDataset, 'transform', None)
        num_classes = len(ChestXrayDataset.get_classes())

        def _fallback():
            return (np.zeros(ChestXrayDataset.FALLBACK_IMAGE_SHAPE, dtype=np.uint8),
                    np.zeros(num_classes, dtype=np.float32))

        try:
            img_name = img_name.numpy().decode("utf-8")
            label = label.numpy().decode("utf-8")

            img_path = ChestXrayDataset.find_image(img_name, root_dir=root_dir)
            if img_path is None:
                return _fallback()

            # Context manager releases the file handle PIL keeps open after Image.open.
            with Image.open(img_path) as pil_image:
                image = np.array(pil_image.convert('RGB'))

            if transform:
                image = transform(image)

            return image, ChestXrayDataset._encode_labels(label)

        except Exception as e:
            print(f"Error processing {img_name}: {e}")
            return _fallback()

    @staticmethod
    def _encode_labels(label):
        """Multi-hot encode a '|'-separated finding string over ``get_classes()``.

        Labels listed in ``_LABEL_ALIASES`` are mapped onto their combined class first.
        Returns a float32 vector with 1.0 for every class present and 0.0 elsewhere.
        """
        findings = {
            ChestXrayDataset._LABEL_ALIASES.get(part.strip(), part.strip())
            for part in label.split('|')
        }
        return np.array([c in findings for c in ChestXrayDataset.get_classes()], dtype=np.float32)

    @staticmethod
    def find_image(img_name, root_dir=None, num_folders=12):
        """Return the path of ``img_name`` under ``root_dir``, or ``None`` if absent.

        Searches ``<root_dir>/images_001/images/`` through
        ``<root_dir>/images_{num_folders:03d}/images/`` in order (the NIH release is split
        into 12 such folders), then ``<root_dir>/images/`` so a single extracted folder
        can also be used as ``root_dir``.

        Args:
            img_name: image file name, e.g. as read from the CSV.
            root_dir: search root; ``None`` uses the class attribute set by the last
                ``ChestXrayDataset(...)`` call.
            num_folders: how many ``images_NNN`` folders to search.
        """
        if root_dir is None:
            root_dir = ChestXrayDataset.root_dir
        candidates = [
            os.path.join(root_dir, f'images_{i:03d}', 'images', img_name)
            for i in range(1, num_folders + 1)
        ]
        candidates.append(os.path.join(root_dir, 'images', img_name))
        return next((path for path in candidates if os.path.exists(path)), None)

    @staticmethod
    def get_classes():
        """Class names in label-vector order (14 entries).

        'Nodule Mass' is a merged class covering both the 'Nodule' and 'Mass' findings;
        'Pleural_Thickening' matches the spelling used in the NIH CSV.
        """
        return ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema',
                'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening',
                'Cardiomegaly', 'Nodule Mass', 'Hernia', 'No Finding']

# Set up the CNN model using TensorFlow/Keras
def create_model(input_shape, num_classes):
    """Build the baseline multi-label CNN classifier (uncompiled).

    Three conv blocks with 16, 32 and 64 filters (3x3 'same' conv + ReLU, each followed
    by 2x2 max-pooling), then Flatten -> Dense(512, ReLU) -> Dropout(0.5) -> one sigmoid
    unit per class.

    Args:
        input_shape: shape of a single input image, e.g. ``(256, 256, 3)``.
        num_classes: number of independent sigmoid outputs.

    Returns:
        A ``tf.keras`` ``Sequential`` model.
    """
    stack = []
    for block_idx, n_filters in enumerate((16, 32, 64)):
        # Only the first layer declares the input shape.
        first_layer_kwargs = {'input_shape': input_shape} if block_idx == 0 else {}
        stack.append(layers.Conv2D(n_filters, (3, 3), padding='same', activation='relu',
                                   **first_layer_kwargs))
        stack.append(layers.MaxPooling2D((2, 2)))

    stack.extend([
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dropout(0.5),
        # Sigmoid (not softmax): each class is predicted independently (multi-label).
        layers.Dense(num_classes, activation='sigmoid'),
    ])
    return models.Sequential(stack)

# Preprocessing function
def preprocess(image):
    """Resize an image to 256x256 and map its pixel values from [0, 255] onto [-1, 1].

    Args:
        image: an HxWxC image (NumPy array or tensor).

    Returns:
        A 256x256xC float tensor (``tf.image.resize`` with its default bilinear method
        returns float32).
    """
    resized = tf.image.resize(image, (256, 256))
    unit_range = resized / 255.0     # [0, 255] -> [0, 1]
    return unit_range * 2 - 1        # [0, 1]   -> [-1, 1]

# --- Configuration -----------------------------------------------------------------
csv_file = './MED_IMG_DATA/NIH_Chest_X-rays/Data_Entry_2017.csv'
# Directory that CONTAINS the images_NNN/ folders: ChestXrayDataset.find_image builds
# <data_dir>/images_NNN/images/<file>, so passing .../images_001 itself would search
# images_001/images_001/... and silently yield black placeholder images.
data_dir = './MED_IMG_DATA/NIH_Chest_X-rays'
batch_size = 16
input_shape = (256, 256, 3)  # must match the size produced by preprocess()
num_classes = len(ChestXrayDataset.get_classes())
num_epochs = 5
shuffle_buffer = 256  # decoded images buffered for shuffling (0.75 MiB each at 256x256x3 float32)
seed = 42

tf.random.set_seed(seed)

# --- Data --------------------------------------------------------------------------
# Lazily decoded dataset: images are only read from disk while iterating.
dataset = ChestXrayDataset(csv_file, data_dir, transform=preprocess)

# 80% train / 20% validation, split in CSV order with take/skip.
# Materialising with list(dataset) would decode every image into memory up front, and
# from_tensor_slices() cannot rebuild a dataset from a Python list of (image, label) tuples.
num_examples = len(dataset)
train_size = int(0.8 * num_examples)
train_dataset = (dataset.take(train_size)
                 .shuffle(shuffle_buffer, seed=seed)  # reshuffled every epoch by default
                 .batch(batch_size)
                 .prefetch(tf.data.AUTOTUNE))
val_dataset = dataset.skip(train_size).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# --- Model -------------------------------------------------------------------------
model = create_model(input_shape, num_classes)

model.compile(optimizer='adam',
              loss='binary_crossentropy',  # one independent sigmoid per class
              metrics=['accuracy'])

# --- Custom training loop with progress display using tqdm -------------------------
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Training loop
    prog_bar = tqdm(train_dataset, total=len(train_dataset), desc="Training", unit="batch")
    for images, labels in prog_bar:
        loss_value, accuracy_value = model.train_on_batch(images, labels)
        prog_bar.set_postfix({"loss": loss_value, "accuracy": accuracy_value})

    # Validation loop
    print("\nValidating...")
    prog_bar_val = tqdm(val_dataset, total=len(val_dataset), desc="Validation", unit="batch")
    val_loss_sum, val_acc_sum, val_seen = 0.0, 0.0, 0
    for images, labels in prog_bar_val:
        val_loss_value, val_acc_value = model.test_on_batch(images, labels)
        # Weight each batch by its size so a short final batch does not skew the mean.
        n_batch = int(images.shape[0])
        val_loss_sum += val_loss_value * n_batch
        val_acc_sum += val_acc_value * n_batch
        val_seen += n_batch
        prog_bar_val.set_postfix({"val_loss": val_loss_value, "val_accuracy": val_acc_value})

    if val_seen == 0:
        print(f"\nEpoch {epoch+1} Summary: no validation samples")
    else:
        avg_val_loss = val_loss_sum / val_seen
        avg_val_acc = val_acc_sum / val_seen
        print(f"\nEpoch {epoch+1} Summary: Val Loss: {avg_val_loss:.4f}, Val Accuracy: {avg_val_acc:.4f}")

print('Finished Training')
