In [None]:
import requests
import zipfile
import io
import os

def download_and_extract_dataset(url, extract_to='./dataset', timeout=60):
    """Download a ZIP archive from ``url`` and extract it into ``extract_to``.

    Args:
        url: HTTP(S) URL of a ``.zip`` archive.
        extract_to: Destination directory; ``extractall`` creates it (and any
            sub-directories) as needed.
        timeout: Seconds passed to ``requests.get``. It bounds connecting and
            each socket read (not the whole download); without it a stalled
            server would block this call forever.

    Raises:
        requests.HTTPError: On a 4xx/5xx response, instead of a confusing
            ``BadZipFile`` from trying to unzip an HTML error page.
        zipfile.BadZipFile: If the response body is not a valid ZIP archive.
    """
    print("Downloading dataset...")
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    print("Download complete. Extracting...")
    # The whole archive is held in memory before extraction.
    with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
        # extractall strips absolute paths and '..' components from member names.
        archive.extractall(extract_to)
    print(f"Dataset extracted to {extract_to}")

if __name__ == "__main__":
    # Fetch the public archive and unpack it into ./dataset, the directory the
    # data-loader cells below read from.
    dataset_url = "https://dicom5c.blob.core.windows.net/public/Data.zip"
    download_and_extract_dataset(url=dataset_url, extract_to='./dataset')

Downloading dataset...


In [None]:
import os
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
import albumentations as A

class BrainMRIDataLoader:
    """Load paired brain-MRI slices and masks and wrap them in ``tf.data`` pipelines.

    Files are read from a single flat directory (no recursion). An image
    ``<stem>_<k>.tif`` whose name ends with one of ``IMAGE_SUFFIXES`` is paired
    with ``<stem>_<k>_mask.tif``; images without a mask are skipped.

    Args:
        data_dir: Directory containing the ``.tif`` slices and their masks.
        img_size: Target ``(height, width)`` every image and mask is resized to.
        batch_size: Batch size used by :meth:`create_dataset`.
        channels: Image channels to produce: 1 (grayscale, matching the default
            ``attention_unet`` input of ``(256, 256, 1)``) or 3 (RGB).
    """

    # Only slices whose file name ends with one of these suffixes are loaded.
    # NOTE(review): this selects slices 1-5 of each scan only; confirm intended.
    IMAGE_SUFFIXES = ('_1.tif', '_2.tif', '_3.tif', '_4.tif', '_5.tif')
    MASK_SUFFIX = '_mask.tif'

    def __init__(self, data_dir, img_size=(256, 256), batch_size=32, channels=1):
        if channels not in (1, 3):
            raise ValueError(f"channels must be 1 or 3, got {channels!r}")
        self.data_dir = data_dir
        self.img_size = img_size
        self.batch_size = batch_size
        self.channels = channels
        # Albumentations pipeline, built on first use by preprocess() and reused.
        self._augmentor = None

    def load_data(self):
        """Load every image/mask pair found in ``data_dir``.

        Files are visited in sorted order, so the returned arrays (and any seeded
        ``train_test_split`` on them) are identical across runs and across cells.

        Returns:
            ``(images, masks)``: float32 arrays of shape ``(N, H, W, channels)``
            with values in [0, 1], and ``(N, H, W, 1)`` with values in {0, 1}.

        Raises:
            ValueError: If no image/mask pair is found.
        """
        images = []
        masks = []
        # os.listdir order is arbitrary; sorting keeps splits reproducible.
        for filename in sorted(os.listdir(self.data_dir)):
            if not filename.endswith(self.IMAGE_SUFFIXES):
                continue
            img_path = os.path.join(self.data_dir, filename)
            mask_path = os.path.join(self.data_dir, filename[:-len('.tif')] + self.MASK_SUFFIX)
            if not os.path.exists(mask_path):
                continue
            images.append(self._load_img(img_path))
            masks.append(self._load_img(mask_path, is_mask=True))

        if not images:
            raise ValueError(f"No image/mask pairs found in {self.data_dir!r}")
        return np.stack(images), np.stack(masks)

    def _load_img(self, path, is_mask=False):
        """Read one TIFF, resize it to ``img_size`` and scale it to float32 in [0, 1].

        Core TensorFlow has no TIFF decoder, so Pillow is used for decoding.

        Args:
            path: Path to the ``.tif`` file.
            is_mask: If True, read as single-channel, resize with nearest-neighbour
                interpolation and binarise, so labels stay exactly {0, 1}.
                Otherwise convert to ``self.channels`` channels and resize bilinearly.

        Returns:
            float32 array of shape ``(H, W, C)``.
        """
        from PIL import Image  # local import: Pillow is only needed to read files

        height, width = self.img_size
        with Image.open(path) as pil_img:
            # NOTE(review): scaling by 255 below assumes 8-bit TIFFs; confirm.
            if is_mask:
                pil_img = pil_img.convert('L')
                resample = Image.NEAREST
            else:
                pil_img = pil_img.convert('L' if self.channels == 1 else 'RGB')
                resample = Image.BILINEAR
            # PIL sizes are (width, height); img_size is (height, width).
            pil_img = pil_img.resize((width, height), resample=resample)
            arr = np.asarray(pil_img, dtype=np.float32) / 255.0

        if is_mask:
            # Guarantee hard labels; soft values would break binary metrics later.
            arr = (arr > 0.5).astype(np.float32)
        if arr.ndim == 2:
            arr = arr[..., np.newaxis]
        return arr

    def _build_augmentor(self):
        """Return the albumentations pipeline applied by :meth:`preprocess`.

        Spatial transforms act on image and mask alike; CLAHE, brightness/contrast
        and gamma only modify the image.
        """
        # NOTE(review): `alpha_affine` (ElasticTransform) and `shift_limit`
        # (OpticalDistortion) are deprecated or rejected by recent albumentations
        # releases; confirm against the installed version.
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=0.5),
            A.OneOf([
                A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),
                A.GridDistortion(p=0.5),
                A.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=0.5),
            ], p=0.3),
            A.CLAHE(clip_limit=2.0, p=0.8),
            A.RandomBrightnessContrast(p=0.8),
            A.RandomGamma(p=0.8),
        ])

    def preprocess(self, image, mask):
        """Randomly augment one image/mask pair (run through ``tf.numpy_function``).

        The pair is converted to uint8 for augmentation, because albumentations'
        CLAHE only accepts uint8 input. It is converted back afterwards.

        Args:
            image: float array in [0, 1], shape ``(H, W, C)``.
            mask: array with values in {0, 1}, shape ``(H, W, 1)``.

        Returns:
            ``(image, mask)`` as float32 arrays; the image is back in [0, 1] and
            the mask is binary.
        """
        augmentor = getattr(self, '_augmentor', None)
        if augmentor is None:
            # Built once instead of per sample.
            augmentor = self._augmentor = self._build_augmentor()

        image_u8 = np.clip(np.rint(np.asarray(image) * 255.0), 0, 255).astype(np.uint8)
        mask_u8 = (np.asarray(mask) > 0.5).astype(np.uint8)
        augmented = augmentor(image=image_u8, mask=mask_u8)

        out_image = np.asarray(augmented['image']).astype(np.float32) / 255.0
        out_mask = np.asarray(augmented['mask']).astype(np.float32)
        # Some transforms may drop a trailing singleton channel; restore it.
        if out_image.ndim == image_u8.ndim - 1:
            out_image = out_image[..., np.newaxis]
        if out_mask.ndim == mask_u8.ndim - 1:
            out_mask = out_mask[..., np.newaxis]
        return out_image, out_mask

    def create_dataset(self, X, y, augment=True, shuffle=False):
        """Wrap arrays in a batched, prefetched ``tf.data.Dataset``.

        Args:
            X: Images, shape ``(N, H, W, C)``.
            y: Masks, shape ``(N, H, W, 1)``.
            augment: Apply :meth:`preprocess` to every sample. Pass False for
                validation/test data so metrics are computed on unmodified images.
            shuffle: Shuffle samples (full-size buffer) before batching; useful
                for training data.

        Returns:
            ``tf.data.Dataset`` yielding ``(images, masks)`` batches of ``batch_size``.
        """
        dataset = tf.data.Dataset.from_tensor_slices((X, y))
        if shuffle:
            dataset = dataset.shuffle(buffer_size=len(X))
        if augment:
            image_shape = np.shape(X)[1:]
            mask_shape = np.shape(y)[1:]

            def _augment(image, mask):
                image, mask = tf.numpy_function(
                    self.preprocess, [image, mask], [tf.float32, tf.float32])
                # numpy_function drops static shapes; restore them for Keras.
                image.set_shape(image_shape)
                mask.set_shape(mask_shape)
                return image, mask

            dataset = dataset.map(_augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.batch(self.batch_size).prefetch(tf.data.experimental.AUTOTUNE)
        return dataset

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model

class AttentionGate(layers.Layer):
    """Additive attention gate for U-Net skip connections.

    Computes ``sigmoid(psi(relu(W_g(g) + W_x(x)))) * x``. The gating signal ``g``
    from the decoder produces a one-channel weight map that rescales the skip
    feature map ``x``.

    Args:
        filters: Channels of the intermediate 1x1 projections of ``g`` and ``x``.
        **kwargs: Standard ``Layer`` arguments (``name``, ``dtype``, ``trainable``,
            ...). Accepting them is required so Keras can re-create the layer
            from its saved config in ``load_model``.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.W_g = layers.Conv2D(filters, 1, padding='same')
        self.W_x = layers.Conv2D(filters, 1, padding='same')
        self.psi = layers.Conv2D(1, 1, padding='same')

    def call(self, g, x):
        """Gate ``x`` with ``g``.

        The projections of ``g`` and ``x`` are added elementwise, so their spatial
        dimensions must match (they do at every call site in ``attention_unet``).
        Returns a tensor with the same shape as ``x``.
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = tf.keras.activations.relu(g1 + x1)
        psi = self.psi(psi)
        return tf.keras.activations.sigmoid(psi) * x

    def get_config(self):
        """Serialise constructor arguments so saved models can rebuild this layer."""
        config = super().get_config()
        config.update({'filters': self.filters})
        return config

def attention_unet(input_size=(256, 256, 1), num_classes=1):
    """Build an Attention U-Net segmentation model.

    Architecture:
    - four encoder levels with 64/128/256/512 filters (two 3x3 ReLU convs, then
      2x2 max-pooling; the 512 level also applies 50% dropout);
    - a 1024-filter bridge with 50% dropout;
    - four decoder levels, each doing 2x up-sampling, a 2x2 conv, an
      ``AttentionGate`` on the matching skip connection, concatenation and two
      3x3 convs;
    - a 1x1 sigmoid head with ``num_classes`` channels.

    Height and width must be divisible by 16 (four 2x poolings) for the skip
    connections to line up.
    """
    def double_conv(tensor, n_filters):
        # Two 3x3 same-padded ReLU convolutions.
        for _ in range(2):
            tensor = layers.Conv2D(n_filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    inputs = layers.Input(input_size)

    # Encoder: keep each level's output as a skip connection.
    skips = []
    features = inputs
    for n_filters in (64, 128, 256, 512):
        features = double_conv(features, n_filters)
        if n_filters == 512:
            features = layers.Dropout(0.5)(features)
        skips.append(features)
        features = layers.MaxPooling2D(pool_size=(2, 2))(features)

    # Bridge.
    features = double_conv(features, 1024)
    features = layers.Dropout(0.5)(features)

    # Decoder: walk the skips from deepest to shallowest.
    for n_filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        upsampled = layers.UpSampling2D(size=(2, 2))(features)
        upsampled = layers.Conv2D(n_filters, 2, activation='relu', padding='same')(upsampled)
        gated_skip = AttentionGate(n_filters)(g=upsampled, x=skip)
        features = layers.concatenate([upsampled, gated_skip])
        features = double_conv(features, n_filters)

    outputs = layers.Conv2D(num_classes, 1, activation='sigmoid')(features)
    return Model(inputs=inputs, outputs=outputs)

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from data_loader import BrainMRIDataLoader
from models import attention_unet
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

def train_model(model, train_dataset, val_dataset, epochs=100):
    """Fit ``model`` with checkpointing, LR scheduling and early stopping.

    - ``ModelCheckpoint`` saves the model to ``best_model.h5`` whenever
      ``val_loss`` improves.
    - ``ReduceLROnPlateau`` multiplies the learning rate by 0.1 after 5 epochs
      without improvement (floor 1e-6).
    - ``EarlyStopping`` halts training after 10 epochs without improvement.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    checkpoint = ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_loss')
    lr_schedule = ReduceLROnPlateau(factor=0.1, patience=5, min_lr=1e-6, verbose=1)
    early_stop = EarlyStopping(patience=10, verbose=1)

    return model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint, lr_schedule, early_stop],
    )

def plot_training_history(history, save_path='training_history.png'):
    """Plot loss and accuracy curves from a Keras ``History`` and save them.

    Left panel: ``loss`` / ``val_loss``. Right panel: ``accuracy`` /
    ``val_accuracy``. A series missing from ``history.history`` (e.g. the model
    was compiled without the ``'accuracy'`` metric) is skipped instead of
    raising ``KeyError``.

    Args:
        history: Object with a ``history`` dict, as returned by ``model.fit``.
        save_path: File the figure is written to.

    Returns:
        The matplotlib ``Figure``. It is closed after saving so it does not
        linger in pyplot's figure registry.
    """
    metrics = history.history
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (loss_ax, 'loss', 'Loss', 'Training and Validation Loss'),
        (acc_ax, 'accuracy', 'Accuracy', 'Training and Validation Accuracy'),
    )
    for ax, key, quantity, title in panels:
        for prefix, split in (('', 'Train'), ('val_', 'Validation')):
            series = metrics.get(prefix + key)
            if series is not None:
                ax.plot(series, label=f'{split} {quantity}')
        ax.set(title=title, xlabel='Epoch', ylabel=quantity)
        if len(ax.lines):
            ax.legend()

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)
    return fig

if __name__ == "__main__":
    data_loader = BrainMRIDataLoader('./dataset')
    X, y = data_loader.load_data()
    # Hold out 20% as the test set. The evaluation cell re-creates this exact
    # split (same test_size and random_state), so it must not influence training.
    X_train_full, X_test, y_train_full, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42)
    # Validation comes from the training portion, so checkpoint selection,
    # LR scheduling and early stopping never see test images.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_full, y_train_full, test_size=0.1, random_state=42)

    train_dataset = data_loader.create_dataset(X_train, y_train)
    # Validation data is only batched: create_dataset(X, y) randomly augments
    # every sample, which would make val_loss noisy and non-reproducible.
    val_dataset = (
        tf.data.Dataset.from_tensor_slices((X_val, y_val))
        .batch(data_loader.batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

    model = attention_unet()
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    history = train_model(model, train_dataset, val_dataset)
    plot_training_history(history)

In [None]:
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import tensorflow as tf
from data_loader import BrainMRIDataLoader
from models import attention_unet

def dice_coefficient(y_true, y_pred, smooth=1):
    """Dice overlap between two binary masks.

    Computes ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` on the
    flattened inputs. With ``smooth > 0`` the score is defined, and equals 1,
    when both masks are empty.

    Args:
        y_true: Array-like ground truth (numpy array, nested list, or anything
            ``np.asarray`` accepts), values expected in {0, 1}.
        y_pred: Array-like prediction with the same number of elements.
        smooth: Additive smoothing constant.

    Returns:
        Dice score as a float; 1.0 means perfect overlap.

    Raises:
        ValueError: If the inputs have different numbers of elements (instead
            of silently broadcasting a size-1 input).
    """
    y_true_f = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred_f = np.asarray(y_pred, dtype=np.float64).ravel()
    if y_true_f.size != y_pred_f.size:
        raise ValueError(
            f"y_true and y_pred must have the same size, got {y_true_f.size} and {y_pred_f.size}")
    intersection = np.dot(y_true_f, y_pred_f)
    return float((2.0 * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth))

def evaluate_model(model, test_dataset, threshold=0.5):
    """Compute pixel-wise segmentation metrics of ``model`` on ``test_dataset``.

    Prints Dice, precision, recall, F1 and the 2x2 confusion matrix, and returns
    them.

    Args:
        model: Keras model mapping image batches to per-pixel probabilities.
        test_dataset: Iterable of ``(images, masks)`` batches, e.g. a
            ``tf.data.Dataset``.
        threshold: Cut-off used to binarise both predictions and masks. Masks are
            binarised too, because sklearn's classification metrics reject
            non-binary (e.g. interpolated) labels.

    Returns:
        dict with keys ``'dice'``, ``'precision'``, ``'recall'``, ``'f1'`` and
        ``'confusion_matrix'``.

    Raises:
        ValueError: If ``test_dataset`` yields no batches.
    """
    true_chunks = []
    pred_chunks = []
    for images, masks in test_dataset:
        # predict_on_batch avoids predict()'s per-call setup overhead in a loop.
        predictions = model.predict_on_batch(images)
        true_chunks.append(np.asarray(masks).ravel() > threshold)
        pred_chunks.append(np.asarray(predictions).ravel() > threshold)

    if not true_chunks:
        raise ValueError("test_dataset yielded no batches")

    # Accumulate numpy arrays, not Python lists of per-pixel floats, which cost
    # tens of bytes per pixel.
    y_true = np.concatenate(true_chunks).astype(int)
    y_pred_binary = np.concatenate(pred_chunks).astype(int)

    dice = dice_coefficient(y_true, y_pred_binary)
    # zero_division=0: report 0 instead of warning when there are no positives.
    precision = precision_score(y_true, y_pred_binary, zero_division=0)
    recall = recall_score(y_true, y_pred_binary, zero_division=0)
    f1 = f1_score(y_true, y_pred_binary, zero_division=0)

    print(f"Dice Coefficient: {dice:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    # labels=[0, 1] keeps the matrix 2x2 even if one class is absent.
    cm = confusion_matrix(y_true, y_pred_binary, labels=[0, 1])
    print("Confusion Matrix:")
    print(cm)

    return {
        'dice': dice,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm,
    }

if __name__ == "__main__":
    # Neither name is imported at the top of this cell; importing them here lets
    # the cell run on a fresh kernel instead of relying on earlier cells' state.
    from sklearn.model_selection import train_test_split
    from models import AttentionGate

    data_loader = BrainMRIDataLoader('./dataset')
    X, y = data_loader.load_data()
    # Same test_size/random_state as the training cell -> same held-out test set,
    # assuming load_data returns samples in the same order on every run.
    _, X_test, _, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    # Evaluate on unmodified images: create_dataset(X, y) randomly augments
    # every sample, which would make the reported metrics non-reproducible.
    test_dataset = (
        tf.data.Dataset.from_tensor_slices((X_test, y_test))
        .batch(data_loader.batch_size)
    )

    model = tf.keras.models.load_model('best_model.h5', custom_objects={'AttentionGate': AttentionGate})
    evaluate_model(model, test_dataset)

In [None]:
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import numpy as np
import tensorflow as tf
import cv2
import io
from PIL import Image
import base64
from models import AttentionGate

# HTTP application exposing the segmentation model.
app = FastAPI()

# The checkpoint is loaded once, when this module is imported, and shared by all
# requests. AttentionGate is a custom layer, so Keras must be given the class.
model = tf.keras.models.load_model(
    'best_model.h5',
    custom_objects={'AttentionGate': AttentionGate},
)

def preprocess_image(image, size=(256, 256)):
    """Convert a decoded OpenCV image into a model-ready batch of one.

    Args:
        image: Array from ``cv2.imdecode``: BGR ``(H, W, 3)``, BGRA ``(H, W, 4)``
            or grayscale ``(H, W)`` / ``(H, W, 1)``. Assumed 8-bit, since values
            are divided by 255.
        size: Output ``(width, height)`` in OpenCV order. The default matches
            the default ``attention_unet`` input size.

    Returns:
        float32 array of shape ``(1, height, width, 1)`` with values in [0, 1].

    Raises:
        ValueError: If ``image`` is None (``cv2.imdecode`` returns None for bytes
            it cannot decode) or has an unsupported shape.
    """
    if image is None:
        raise ValueError("Could not decode image.")
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]

    if image.ndim == 3:
        if image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
        else:
            raise ValueError(f"Unsupported number of channels: {image.shape[2]}")
    elif image.ndim != 2:
        raise ValueError(f"Unsupported image shape: {image.shape}")

    image = cv2.resize(image, size)
    image = image.astype(np.float32) / 255.0
    # Add batch and channel axes: (H, W) -> (1, H, W, 1).
    return image[np.newaxis, :, :, np.newaxis]

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    if file.content_type.split('/')[0] != 'image':
        raise HTTPException(status_code=400, detail="File is not an image.")
    
    contents = await file.read()
    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    
    preprocessed_img = preprocess_image(img)
    prediction = model.predict(preprocessed_img)
    
    binary_mask = (prediction > 0.5).astype(np.uint8)
    binary_mask = np.squeeze(binary_mask) * 255
    
    # Create a color overlay
    color_mask = cv2.applyColorMap(binary_mask,