<a href="https://colab.research.google.com/github/OluwafemiOlasupo/Coding-Contest/blob/main/Problem_3_Web_App.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install streamlit pyngrok

Collecting streamlit
  Downloading streamlit-1.46.0-py3-none-any.whl.metadata (9.0 kB)
Collecting pyngrok
  Downloading pyngrok-7.2.11-py3-none-any.whl.metadata (9.4 kB)
Collecting watchdog<7,>=2.1.5 (from streamlit)
  Downloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Downloading streamlit-1.46.0-py3-none-any.whl (10.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m82.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyngrok-7.2.11-py3-none-any.whl (25 kB)
Downloading pydeck-0.9.1-py2.py3-none-any.whl (6.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m76.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.wh

In [None]:
# Do not hard-code the ngrok authtoken in the notebook (saved notebooks are easily
# shared). Read it from the NGROK_AUTHTOKEN environment variable, or prompt for it;
# getpass does not echo the value, so it never appears in the cell output.
import os
from getpass import getpass

NGROK_TOKEN = os.environ.get("NGROK_AUTHTOKEN", "")
if not NGROK_TOKEN:
    NGROK_TOKEN = getpass("Enter your ngrok authtoken: ")

In [None]:
from pyngrok import ngrok, conf
import os

In [None]:
# Register the ngrok authtoken with pyngrok before any tunnel is opened.
# Fail fast on an empty or placeholder value (e.g. the "....." template) instead of
# failing later, with a less obvious ngrok error, when the tunnel is opened.
if not (NGROK_TOKEN or "").strip(". "):
    raise ValueError("NGROK_TOKEN is empty or still a placeholder; set your ngrok authtoken first.")
ngrok.set_auth_token(NGROK_TOKEN)



In [None]:
streamlit_app_code = '''
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io

# Page settings: browser-tab title and icon, and a full-width layout.
_PAGE_SETTINGS = dict(
    page_title="MNIST Digit Generator",
    page_icon="🔢",
    layout="wide",
)
st.set_page_config(**_PAGE_SETTINGS)

# Generator Model Class (same as training script)
class Generator(nn.Module):
    """Conditional MLP generator: (noise, digit label) -> 1x28x28 image in [-1, 1].

    Attribute names and layer order must stay as they are: the checkpoint's
    state_dict keys (label_embedding.*, model.<index>.*) are derived from them.
    """

    def __init__(self, latent_dim=100, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # One learned latent_dim-sized vector per class; it is concatenated with the noise.
        self.label_embedding = nn.Embedding(num_classes, latent_dim)

        # Linear+ReLU stack 2*latent_dim -> 256 -> 512 -> 1024, then a Tanh output layer.
        layers = []
        in_features = latent_dim * 2
        for width in (256, 512, 1024):
            layers += [nn.Linear(in_features, width), nn.ReLU(True)]
            in_features = width
        layers += [nn.Linear(in_features, 28 * 28), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Map noise of shape (batch, latent_dim) and integer labels (batch,) to images (batch, 1, 28, 28)."""
        conditioned = torch.cat((noise, self.label_embedding(labels)), dim=1)
        flat = self.model(conditioned)
        return flat.view(flat.size(0), 1, 28, 28)

@st.cache_resource
def _load_generator(model_path):
    """Build the Generator and load its trained weights from model_path.

    Cached by Streamlit per model_path for the lifetime of the server process.
    Failures are raised rather than returned, so Streamlit does not cache them.
    A failed load (e.g. the checkpoint has not been uploaded yet) is therefore
    retried on the next rerun.
    """
    generator = Generator(latent_dim=100, num_classes=10)
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location='cpu')
    generator.load_state_dict(checkpoint['generator_state_dict'])
    generator.eval()
    return generator


def load_model(model_path='digit_generator_model.pth'):
    """Load the trained generator model.

    Args:
        model_path: checkpoint file containing a 'generator_state_dict' entry,
            resolved relative to the directory Streamlit is started from.

    Returns:
        The Generator in eval mode. On failure, returns None after showing the
        error in the app.
    """
    try:
        return _load_generator(model_path)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.info(f"Make sure the model file '{model_path}' is in the same directory!")
        return None

def generate_digit_images(generator, digit, num_images=5):
    """Generate num_images samples of one digit with the conditional generator.

    Args:
        generator: callable taking (noise, labels) and returning a tensor of
            shape (num_images, 1, 28, 28); assumed to be in [-1, 1] (Tanh output).
        digit: class label to condition on (0-9 for the trained model).
        num_images: number of samples to draw.

    Returns:
        numpy array of shape (num_images, 28, 28), rescaled from [-1, 1] to [0, 1].
    """
    # The noise size must match the model; fall back to 100, the trained model's value.
    latent_dim = getattr(generator, 'latent_dim', 100)
    with torch.no_grad():
        noise = torch.randn(num_images, latent_dim)
        labels = torch.full((num_images,), digit, dtype=torch.long)

        generated_imgs = generator(noise, labels)

        images = generated_imgs.cpu().numpy()
        images = (images + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
        # Drop only the channel axis. np.squeeze would also drop the batch axis
        # when num_images == 1 and return a bare (28, 28) image.
        return images[:, 0]

def create_image_grid(images, digit):
    """Lay out generated images side by side in a single-row matplotlib figure.

    Args:
        images: sequence of 2-D grayscale arrays with values in [0, 1]
            (as returned by generate_digit_images).
        digit: digit shown in the figure title.

    Returns:
        The matplotlib Figure (not closed here).
    """
    num_panels = max(len(images), 1)
    # squeeze=False keeps `axes` 2-D even for a single panel, so indexing is uniform.
    fig, axes = plt.subplots(1, num_panels, figsize=(3 * num_panels, 3), squeeze=False)
    fig.suptitle(f'Generated Handwritten Digit: {digit}', fontsize=18, fontweight='bold')

    for i, (ax, img) in enumerate(zip(axes[0], images)):
        # Fixed [0, 1] range: without vmin/vmax, imshow stretches each sample's
        # own min/max to full black/white, hiding faint or low-contrast outputs.
        ax.imshow(img, cmap='gray', vmin=0, vmax=1)
        ax.set_title(f'Sample {i+1}', fontsize=12)
        ax.axis('off')

    fig.tight_layout()
    return fig

# Number of samples produced per generation request.
NUM_SAMPLES = 5


def _render_intro():
    """Render the page title and the description of the app."""
    st.title("🔢 MNIST Handwritten Digit Generator")
    st.markdown("""
    ### Generate Realistic Handwritten Digits with AI

    This web application uses a **Generative Adversarial Network (GAN)** trained on the MNIST dataset
    to generate realistic handwritten digits. Select any digit (0-9) and the AI will create 5 unique
    samples that look like they were written by hand!

    **✨ Features:**
    - Generate digits 0-9 on demand
    - Each generation produces unique variations
    - 28x28 pixel format (same as MNIST dataset)
    - Trained using PyTorch on Google Colab
    """)


def _render_controls():
    """Render the digit selector, the generate button and the About box.

    Returns:
        (selected_digit, generate_clicked) from the selectbox and the button.
    """
    st.markdown("### 🎮 Controls")
    selected_digit = st.selectbox(
        "**Choose digit to generate:**",
        options=list(range(10)),
        index=0,
        help="Select any digit from 0 to 9"
    )

    generate_clicked = st.button(
        "🎲 Generate New Images",
        type="primary",
        help=f"Click to generate {NUM_SAMPLES} new variations of the selected digit"
    )

    st.markdown("---")
    st.markdown("### ℹ️ About")
    # A Markdown list keeps each item on its own line; plain consecutive lines
    # would be merged into a single paragraph.
    st.markdown("""
    - **Model:** Conditional GAN
    - **Dataset:** MNIST
    - **Framework:** PyTorch
    - **Training:** Google Colab T4 GPU
    - **Image Size:** 28×28 pixels
    """)
    return selected_digit, generate_clicked


def _render_results(generator, selected_digit, generate_clicked):
    """Generate samples when needed and display the most recent batch.

    The latest batch is kept in st.session_state, so reruns that neither click
    the button nor change the digit redisplay the same samples.
    """
    st.markdown(f"### Generated Images for Digit: **{selected_digit}**")

    # Initialize session state
    if 'generated_images' not in st.session_state:
        st.session_state.generated_images = None
        st.session_state.last_digit = None

    # Generate when the button is clicked, the digit changed, or nothing exists yet.
    needs_generation = (
        generate_clicked
        or st.session_state.last_digit != selected_digit
        or st.session_state.generated_images is None
    )
    if needs_generation:
        with st.spinner("🤖 AI is generating handwritten digits..."):
            try:
                images = generate_digit_images(generator, selected_digit, NUM_SAMPLES)
                st.session_state.generated_images = images
                st.session_state.last_digit = selected_digit
                st.success(f"✅ Successfully generated {NUM_SAMPLES} unique samples of digit {selected_digit}!")
            except Exception as e:
                st.error(f"❌ Error generating images: {e}")

    if st.session_state.generated_images is None:
        return

    # Label the figure with the digit the stored images were generated for. If
    # generating for a newly selected digit failed, the previous batch is still shown.
    shown_digit = st.session_state.last_digit
    fig = create_image_grid(st.session_state.generated_images, shown_digit)
    st.pyplot(fig, use_container_width=True)
    plt.close(fig)

    st.markdown("---")
    col_a, col_b, col_c = st.columns(3)
    with col_a:
        st.metric("Images Generated", str(len(st.session_state.generated_images)))
    with col_b:
        st.metric("Image Resolution", "28×28")
    with col_c:
        st.metric("Current Digit", shown_digit)


def _render_footer():
    """Render the closing divider and the credits line."""
    st.markdown("---")
    st.markdown("""
    <div style='text-align: center; color: #666;'>
        <p>🚀 Built with Streamlit • 🧠 Powered by PyTorch • 🎯 Trained on Google Colab</p>
    </div>
    """, unsafe_allow_html=True)


def main():
    """Streamlit entry point: intro, model loading, controls/results columns, footer."""
    _render_intro()

    generator = load_model()
    if generator is None:
        st.stop()

    # Wide results column on the left, narrow controls column on the right.
    col1, col2 = st.columns([3, 1])
    with col2:
        selected_digit, generate_clicked = _render_controls()
    with col1:
        _render_results(generator, selected_digit, generate_clicked)

    _render_footer()


if __name__ == "__main__":
    main()
'''

In [None]:
# Write the Streamlit app to a file. The app source contains non-ASCII characters
# (emoji, "×"), so the encoding is given explicitly instead of relying on the
# platform default (which is not UTF-8 on every OS).
APP_PATH = 'digit_generator_app.py'
with open(APP_PATH, 'w', encoding='utf-8') as f:
    f.write(streamlit_app_code)

print(f"✅ Streamlit app file created: {APP_PATH}")

✅ Streamlit app file created: digit_generator_app.py


In [None]:
# ============================================================================
# STEP 4: Function to run the Streamlit app with ngrok
# ============================================================================
def run_streamlit_app(app_file="digit_generator_app.py", port=8501, startup_timeout=60):
    """Start the Streamlit app in a background process and expose it through ngrok.

    Args:
        app_file: path of the Streamlit script to run.
        port: local port Streamlit listens on and ngrok forwards to.
        startup_timeout: seconds to wait for Streamlit to accept connections.

    Returns:
        The pyngrok NgrokTunnel; its public_url attribute holds the public HTTPS URL.

    Raises:
        RuntimeError: if Streamlit exits early or does not start listening in
            time. Its output is written to streamlit.log.
    """
    import socket
    import subprocess
    import sys
    import time

    # Close tunnels from a previous call and stop the Streamlit server this
    # function started last time; otherwise the new server cannot bind the port.
    ngrok.kill()
    previous = getattr(run_streamlit_app, "_process", None)
    if previous is not None and previous.poll() is None:
        previous.terminate()
        try:
            previous.wait(timeout=10)
        except subprocess.TimeoutExpired:
            previous.kill()

    command = [
        sys.executable, "-m", "streamlit", "run", app_file,
        "--server.port", str(port),
        "--server.headless", "true",
        "--server.fileWatcherType", "none",
        "--browser.gatherUsageStats", "false",
    ]
    # Popen returns immediately, so no helper thread is needed. Server output goes
    # to a log file; the child keeps its own copy of the handle after ours is closed.
    with open("streamlit.log", "w") as log:
        process = subprocess.Popen(command, stdout=log, stderr=subprocess.STDOUT)
    run_streamlit_app._process = process

    # Wait until the server actually accepts connections rather than sleeping
    # for a fixed time and hoping it is up.
    deadline = time.monotonic() + startup_timeout
    while True:
        if process.poll() is not None:
            raise RuntimeError(
                f"Streamlit exited with code {process.returncode}; see streamlit.log"
            )
        try:
            with socket.create_connection(("localhost", port), timeout=1):
                break
        except OSError:
            if time.monotonic() > deadline:
                raise RuntimeError(
                    f"Streamlit did not start listening on port {port} "
                    f"within {startup_timeout}s; see streamlit.log"
                )
            time.sleep(0.5)

    # Create ngrok tunnel
    tunnel = ngrok.connect(port, proto="http", bind_tls=True)

    print("🌐" + "="*60)
    print("🚀 STREAMLIT APP IS NOW RUNNING!")
    print("="*62)
    print(f"📱 Public URL: {tunnel.public_url}")
    print("="*62)
    print("📋 Instructions:")
    print("1. Click the URL above to access your web app")
    print("2. The app will be accessible to anyone with the link")
    print("3. Keep this Colab session running to maintain access")
    print("4. The app will auto-sleep if inactive (normal behavior)")
    print("="*62)

    return tunnel