In [1]:
from google.colab import drive

# Mount Google Drive so the trained checkpoint stored under MyDrive is reachable.
drive.mount("/content/drive")


Mounted at /content/drive


In [2]:
# Copy the trained checkpoint out of Drive into /content.
# shutil.copy raises immediately if the source is missing. `!cp` would only print
# a shell error and let the model-loading cell fail later with a less obvious message.
import shutil

checkpoint_path = shutil.copy(
    "/content/drive/MyDrive/ECHO_NET_DYNAMIC/checkpoints/best_ef_model.pth",
    "/content/",
)

In [None]:
import gradio as gr
import cv2
import numpy as np
import torch
import torch.nn as nn

class EFNet(nn.Module):
    """Ejection-fraction regressor: 3D-CNN encoder -> 2-layer LSTM -> MLP head.

    Input is presumably a clip tensor shaped (N, 3, T, H, W), channels first as
    ``nn.Conv3d(3, ...)`` requires. The inference code in this notebook feeds
    (1, 3, 20, 112, 112). The result is ``regressor(...).squeeze()``: shape (N,)
    for N > 1 and a 0-d tensor for N == 1.

    NOTE(review): ``AdaptiveAvgPool3d((1, 4, 4))`` collapses the time axis to a
    single step, so the LSTM only ever sees a length-1 sequence. Changing that
    would alter what the trained checkpoint computes, so it is left as-is.

    The attribute names (``cnn``, ``lstm``, ``regressor``) and the layer order
    inside each ``nn.Sequential`` define the ``state_dict`` keys. Do not change
    them, or saved checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(c_in, c_out):
            # 3x3x3 conv (padding=1 keeps T/H/W unchanged) followed by BN and ReLU.
            return [
                nn.Conv3d(c_in, c_out, kernel_size=(3, 3, 3), padding=1),
                nn.BatchNorm3d(c_out),
                nn.ReLU(),
            ]

        # Arguments are evaluated left to right. Layers are therefore created, and
        # indexed 0..11, exactly as a flat Conv/BN/ReLU/Pool listing would be.
        self.cnn = nn.Sequential(
            *conv_stage(3, 16),
            nn.MaxPool3d((1, 2, 2)),           # halve H and W, keep T
            *conv_stage(16, 32),
            nn.MaxPool3d((2, 2, 2)),           # halve T, H and W
            *conv_stage(32, 64),
            nn.AdaptiveAvgPool3d((1, 4, 4)),   # -> (N, 64, 1, 4, 4)
        )
        feature_dim = 64 * 4 * 4
        self.lstm = nn.LSTM(input_size=feature_dim, hidden_size=128, num_layers=2, batch_first=True)
        self.regressor = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        n = x.size(0)
        volume = self.cnn(x)
        # Put the (pooled) time axis next to the batch axis, then flatten each
        # step's 64x4x4 feature map into one LSTM input vector.
        steps = volume.permute(0, 2, 1, 3, 4).reshape(n, -1, 64 * 4 * 4)
        hidden_seq, _ = self.lstm(steps)
        # Only the final time step's hidden state feeds the regression head.
        return self.regressor(hidden_seq[:, -1, :]).squeeze()

# Build the network on the available device and restore the trained weights.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EFNet().to(device)
# The checkpoint is consumed by load_state_dict, i.e. it is a plain state_dict of
# tensors. weights_only=True is therefore sufficient. It also avoids unpickling
# arbitrary objects and the FutureWarning newer torch releases emit without it.
state_dict = torch.load('/content/best_ef_model.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# Inference mode: BatchNorm uses its running statistics and Dropout is disabled.
model.eval()


EFNet(
  (cnn): Sequential(
    (0): Conv3d(3, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0, dilation=1, ceil_mode=False)
    (4): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (5): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    (8): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (9): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): AdaptiveAvgPool3d(output_size=(1, 4, 4))
  )
  (lstm): LSTM(1024, 128, num_layers=2, batch_first=True)
  (regressor): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    

In [None]:
import gradio as gr
import cv2
import numpy as np
import torch
import time
import pandas as pd
import os
from datetime import datetime
import json
import plotly.graph_objects as go
import plotly.express as px
from PIL import Image

# Same device rule as the model-loading cell: GPU when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Folder receiving per-prediction JSON records and exported history CSVs.
os.makedirs("ef_results", exist_ok=True)

# Prediction records made in this session, appended oldest first.
prediction_history = []


def _error_html(message):
    """Render ``message`` as the red, bold error banner shown in the result panel.

    The text is HTML-escaped because it may contain exception text or names
    derived from the uploaded file.
    """
    import html as _html
    return f"<span style='color:red;font-weight:bold;'>Error: {_html.escape(str(message))}</span>"


def _read_clip(video_path, num_frames=20, size=112):
    """Uniformly sample ``num_frames`` frames from ``video_path``.

    Each frame is resized to ``size`` x ``size`` and converted from OpenCV's BGR
    order to RGB. Frames that fail to decode are replaced by black frames.

    Returns:
        ``(indices, frames, total_frames, fps)``, where ``frames[j]`` was read at
        frame index ``indices[j]``. Returns ``None`` when OpenCV reports no
        positive frame count.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # OpenCV may report 0 (or a negative value) when it cannot read the file.
        if total_frames <= 0:
            return None
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if not ok:
                frame = np.zeros((size, size, 3), dtype=np.uint8)
            frame = cv2.resize(frame, (size, size))
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return indices, frames, total_frames, fps
    finally:
        # Release the capture on every path, including early returns and errors.
        cap.release()


def _pick_sample_frames(indices, frames):
    """Return ``[(frame_index, frame), ...]`` for the first, middle and last samples.

    Videos shorter than the sample count make ``np.linspace`` repeat indices, so
    each frame index is kept only once.
    """
    picked = []
    seen = set()
    for pos in (0, len(indices) // 2, len(indices) - 1):
        idx = int(indices[pos])
        if idx not in seen:
            seen.add(idx)
            picked.append((idx, frames[pos]))
    return picked


def _ef_category(ef):
    """Map an EF percentage to ``(bar colour, status label, card gradient)``."""
    if ef >= 55:
        return "#16a34a", "Normal EF", "linear-gradient(135deg, #4ade80, #16a34a)"        # green
    if ef >= 40:
        return "#f97316", "Mildly Reduced EF", "linear-gradient(135deg, #fb923c, #f97316)"  # orange
    return "#dc2626", "Severely Reduced EF", "linear-gradient(135deg, #f87171, #dc2626)"    # red


def _result_card_html(ef, status, gradient, duration, total_frames, fps):
    """Build the coloured result card showing EF, status and basic video metadata."""
    return f"""
    <div style="
        background: {gradient};
        padding: 25px;
        border-radius: 18px;
        color: white;
        font-family: Arial, sans-serif;
        width: 85%;
        margin: auto;
        box-shadow: 0 4px 15px rgba(0,0,0,0.2);
    ">
        <h2 style="margin: 0; font-size: 32px; text-align:center;">EF Prediction</h2>
        <p style="font-size: 50px; font-weight:bold; text-align:center; margin: 10px 0;">
            {ef}%
        </p>
        <p style="text-align:center; font-size: 20px; opacity: 0.9;">
            Status: <b>{status}</b>
        </p>
        <div style="display: flex; justify-content: space-between; margin-top: 15px;">
            <div style="text-align: center;">
                <p style="font-size: 14px; margin: 0;">Video Duration</p>
                <p style="font-size: 16px; font-weight: bold;">{duration:.2f} sec</p>
            </div>
            <div style="text-align: center;">
                <p style="font-size: 14px; margin: 0;">Total Frames</p>
                <p style="font-size: 16px; font-weight: bold;">{total_frames}</p>
            </div>
            <div style="text-align: center;">
                <p style="font-size: 14px; margin: 0;">FPS</p>
                <p style="font-size: 16px; font-weight: bold;">{fps:.2f}</p>
            </div>
        </div>
    </div>
    """


def _ef_gauge_figure(ef, color):
    """Plotly gauge of ``ef`` (0-100) with the 55% reference threshold marked in red."""
    fig = go.Figure(go.Indicator(
        mode="gauge+number+delta",
        value=ef,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "Ejection Fraction (%)"},
        delta={'reference': 55},
        gauge={
            'axis': {'range': [None, 100]},
            'bar': {'color': color},
            'steps': [
                {'range': [0, 40], 'color': "lightgray"},
                {'range': [40, 55], 'color': "gray"},
                {'range': [55, 100], 'color': "lightgray"}
            ],
            'threshold': {
                'line': {'color': "red", 'width': 4},
                'thickness': 0.75,
                'value': 55
            }
        }
    ))
    fig.update_layout(height=300, font={'color': "darkblue", 'family': "Arial"})
    return fig


def _sample_frames_figure(extracted_frames):
    """Facet plot of the sampled frames, each titled with its source frame index."""
    fig = px.imshow(
        np.array([frame for _, frame in extracted_frames]),
        facet_col=0,
        binary_string=False,
        labels={'facet_col': 'Sample Frames'},
        title='Sample Frames from Video'
    )
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    for annotation, (frame_index, _) in zip(fig.layout.annotations, extracted_frames):
        annotation['text'] = f"Frame {frame_index}"
    return fig


def _highlight_status(val):
    """Background colour CSS for one Status cell of the history table."""
    if val == 'Normal EF':
        return 'background-color: #d4edda'
    elif val == 'Mildly Reduced EF':
        return 'background-color: #fff3cd'
    return 'background-color: #f8d7da'


def _history_table_html(history, limit=5):
    """Render the last ``limit`` history records, newest first, as a colour-coded HTML table."""
    if not history:
        return "<p>No prediction history available</p>"
    df = pd.DataFrame(history[-limit:])
    df['timestamp'] = pd.to_datetime(df['timestamp'], format='%Y%m%d_%H%M%S')
    df = df.sort_values('timestamp', ascending=False)
    df['timestamp'] = df['timestamp'].dt.strftime('%Y-%m-%d %H:%M:%S')
    df = df[['timestamp', 'ef', 'status']]
    df.columns = ['Timestamp', 'EF (%)', 'Status']
    styler = df.style
    # Styler.applymap is deprecated since pandas 2.1 in favour of Styler.map.
    # Fall back to applymap on older pandas versions that lack map.
    elementwise = getattr(styler, "map", None) or styler.applymap
    return elementwise(_highlight_status, subset=['Status']).to_html()


def predict_ef(video_file, save_result=True):
    """Predict the ejection fraction (EF) for an uploaded video.

    Args:
        video_file: The upload from the Gradio File component. Either an object
            whose ``.name`` is the file path, or the path itself (str).
        save_result: When true, append the result to ``prediction_history``
            and write it to ``ef_results/result_<timestamp>.json``.

    Returns:
        ``(result_html, gauge_figure, frames_figure, history_html)``. On any
        error the first item is an error banner and the other three are ``None``.
        Errors are reported this way and are never raised to the caller.
    """
    try:
        if video_file is None:
            return _error_html("Please upload a video file."), None, None, None

        # Accept both file-like wrappers (path in ``.name``) and plain path strings.
        video_path = str(getattr(video_file, "name", video_file))

        if not video_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv")):
            return _error_html("Unsupported video format."), None, None, None

        clip = _read_clip(video_path)
        if clip is None:
            return _error_html("Cannot read video."), None, None, None
        indices, frames, total_frames, fps = clip
        duration = total_frames / fps if fps > 0 else 0

        extracted_frames = _pick_sample_frames(indices, frames)

        # (T, H, W, C) uint8 -> (C, T, H, W) in [0, 1], plus a batch axis.
        clip_array = np.stack(frames).transpose(3, 0, 1, 2) / 255.0
        video_tensor = torch.as_tensor(clip_array, dtype=torch.float32).unsqueeze(0).to(device)

        with torch.no_grad():
            ef = round(model(video_tensor).item(), 2)

        color, status, gradient = _ef_category(ef)

        html_output = _result_card_html(ef, status, gradient, duration, total_frames, fps)
        fig_ef = _ef_gauge_figure(ef, color)
        fig_frames = _sample_frames_figure(extracted_frames)

        if save_result:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            result = {
                "timestamp": timestamp,
                "ef": ef,
                "status": status,
                "video_path": video_path,
                "duration": duration,
                "total_frames": total_frames,
                "fps": fps
            }
            prediction_history.append(result)
            with open(f"ef_results/result_{timestamp}.json", "w") as f:
                json.dump(result, f)

        history_table = _history_table_html(prediction_history)

        return html_output, fig_ef, fig_frames, history_table

    except Exception as e:
        import traceback
        # Full traceback for the notebook log; the UI gets a short, escaped message.
        traceback.print_exc()
        return _error_html(e), None, None, None

def clear_history():
    """Forget every prediction recorded in this session.

    Rebinds the module-level ``prediction_history`` to a new empty list. JSON
    files already written to ``ef_results/`` stay on disk. Returns the HTML
    shown in the history panel.
    """
    global prediction_history
    prediction_history = list()
    cleared_message = "<p>History cleared</p>"
    return cleared_message

def download_history():
    """Export every recorded prediction as a timestamped CSV in ``ef_results/``.

    Rows are ordered newest first, with human-readable timestamps. Returns the
    CSV path for the download component, or ``None`` when nothing has been
    recorded yet.
    """
    if not prediction_history:
        return None

    export = pd.DataFrame(prediction_history)
    export = export.assign(
        timestamp=pd.to_datetime(export["timestamp"], format="%Y%m%d_%H%M%S")
    ).sort_values("timestamp", ascending=False)
    export = export.assign(timestamp=export["timestamp"].dt.strftime("%Y-%m-%d %H:%M:%S"))

    csv_path = "ef_results/history_{}.csv".format(datetime.now().strftime("%Y%m%d_%H%M%S"))
    export.to_csv(csv_path, index=False)
    return csv_path

#  Web Interface
# Layout: the left column holds upload/actions and reference notes, the right
# column the result card; plots and the history table sit underneath.
with gr.Blocks(title="EF Prediction (Deep Learning Model)", theme=gr.themes.Soft()) as iface:
    gr.Markdown("""
    # Ejection Fraction (EF) Prediction System

    Upload an echocardiography video and the deep learning model will predict the **Ejection Fraction (EF)** with color-coded severity levels.

    ## How to use:
    1. Upload a video file (MP4, AVI, MOV, MKV)
    2. Click "Predict EF" to analyze the video
    3. View the results and sample frames
    4. Check the prediction history below
    """)

    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.File(file_types=["video"], label="Upload Echocardiography Video")
            predict_btn = gr.Button("Predict EF", variant="primary")

            with gr.Row():
                clear_btn = gr.Button("Clear History")
                download_btn = gr.Button("Download History")

            # Thresholds here mirror the ones used to colour the prediction result.
            gr.Markdown("""
            ### About EF Values:
            - **Normal EF**: ≥55% (Green)
            - **Mildly Reduced EF**: 40-54% (Orange)
            - **Severely Reduced EF**: <40% (Red)
            """)

            # EFNet is a 3D-CNN encoder followed by an LSTM and an MLP head.
            gr.Markdown("""
            ### Model Information:
            - Architecture: 3D CNN + LSTM
            - Input: 20 frames (112x112 RGB)
            - Training Dataset: EchoNet-Dynamic
            """)

            # Hidden until the first export; revealed by the download handler chain below.
            download_output = gr.File(label="Download History CSV", visible=False)

        with gr.Column(scale=2):
            html_output = gr.HTML(label="Prediction Result")

    with gr.Row():
        with gr.Column():
            fig_ef = gr.Plot(label="EF Value Visualization")

        with gr.Column():
            fig_frames = gr.Plot(label="Sample Frames from Video")

    with gr.Row():
        history_table = gr.HTML(label="Prediction History")

    # predict_ef's save_result parameter keeps its default (True) here.
    predict_btn.click(
        fn=predict_ef,
        inputs=[video_input],
        outputs=[html_output, fig_ef, fig_frames, history_table]
    )

    clear_btn.click(
        fn=clear_history,
        outputs=[history_table]
    )

    download_btn.click(
        fn=download_history,
        outputs=[download_output]
    ).then(
        lambda: gr.File(visible=True),
        outputs=[download_output]
    )

iface.launch()

  with gr.Blocks(title="EF Prediction (Deep Learning Model)", theme=gr.themes.Soft()) as iface:


It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://ba97a0998234105cc2.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


