In [None]:
from google.colab import drive
# Grant access to your Drive
drive.mount('/content/drive')

# Move into working directory and clone StyleGAN2-ADA
%cd /content
!git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git


In [None]:
!pip install torch torchvision ninja pyspng imageio-ffmpeg==0.4.3 scipy


In [None]:
%cd /content/stylegan2-ada-pytorch

# Fix version compatibility
!pip install imageio==2.9.0 imageio-ffmpeg==0.4.2

In [None]:
!pip install streamlit

In [None]:
%%bash
# Write the Streamlit app to /content. The quoted 'EOF' delimiter disables shell
# expansion, so `$`, backticks and backslashes in the Python source are copied verbatim.
cat > /content/improved_streamlit_stylegan2_ui.py << 'EOF'
import os
import sys
import pickle
import subprocess
import tempfile
import shutil
import numpy as np
import torch
from torchvision.utils import save_image
import streamlit as st
from PIL import Image

# Make the stylegan2-ada-pytorch checkout importable. Unpickling the network checkpoint
# (see generate_mixed) must resolve classes that presumably live in that repo
# (dnnlib / torch_utils) — confirm if the checkpoint source changes.
NETWORK_DIR = '/content/stylegan2-ada-pytorch'
# Streamlit re-executes this whole script on every widget interaction inside the same
# Python process. An unconditional insert would add one duplicate sys.path entry per rerun.
if NETWORK_DIR not in sys.path:
    sys.path.insert(0, NETWORK_DIR)

# --- CONFIG ---
# Generator checkpoint (.pkl), used both by projector.py and for synthesis.
G_PKL = '/content/drive/MyDrive/stylegan2ada_checkpoints/00008-fashion256-auto1-kimg1360-batch32-noaug-resumecustom-20250604T121719Z-1-001/00008-fashion256-auto1-kimg1360-batch32-noaug-resumecustom/network-snapshot-000160.pkl'
# Root folder (on the mounted Drive) for mixed PNGs; one sub-folder per image pair.
OUT_BASE = '/content/drive/MyDrive/Filtered Fashion images/men_generated_outfit'
PROJECTOR_SCRIPT = f"{NETWORK_DIR}/projector.py"
# projector.py output directories for image 1 and image 2, respectively.
PROJECTION_DIRS = ['/content/invert1', '/content/invert2']

# --- Streamlit UI ---
# Page config comes before any other Streamlit call: older Streamlit releases raise
# StreamlitAPIException unless set_page_config is the first Streamlit command.
st.set_page_config(
    page_title="Fashion Latent Mixer",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Inject custom CSS for larger, centered buttons
st.markdown("""
<style>
/* Global button styling */
div.stButton > button {
    width: 200px;
    height: 50px;
    font-size: 18px;
    display: block;
    margin: 20px auto;
}
</style>
""", unsafe_allow_html=True)


def project_image(src, outdir, n_steps, save_video=True):
    """Project an image into the generator's W space by running projector.py.

    Args:
        src: path of the target image file.
        outdir: directory projector.py writes into (created if missing). The mixing
            step later reads ``projected_w.npz`` from here.
        n_steps: number of optimisation steps, passed as ``--num-steps``.
        save_video: when False, ``--save-video false`` is passed so projector.py skips
            its progress video. The default leaves the flag unset, so projector.py's
            own default applies (same command line as before).

    Raises:
        subprocess.CalledProcessError: if projector.py exits with a non-zero status.
    """
    os.makedirs(outdir, exist_ok=True)
    cmd = [
        # Use the interpreter running this app, not whatever `python` is first on PATH.
        sys.executable, PROJECTOR_SCRIPT,
        '--outdir', outdir,
        '--network', G_PKL,
        '--num-steps', str(n_steps),
        '--target', src,
    ]
    if not save_video:
        cmd += ['--save-video', 'false']
    subprocess.run(cmd, cwd=NETWORK_DIR, check=True)


def generate_mixed(w1_npz, w2_npz, name, truncation_psi):
    """Style-mix two projected latents over several layer ranges and save the results.

    For each layer range, the latent from ``w1_npz`` is copied and the layers in that
    range are replaced with the corresponding layers from ``w2_npz``. The mix is then
    pulled toward ``G.mapping.w_avg`` by ``truncation_psi`` (this step is skipped when
    the network exposes no ``w_avg``), synthesised, and written as a PNG.

    Args:
        w1_npz, w2_npz: .npz files holding the projected latent under key ``'w'``.
        name: sub-directory of OUT_BASE that receives the PNGs.
        truncation_psi: 1.0 leaves each mix unchanged; 0.0 collapses it to ``w_avg``.

    Returns:
        dict mapping each layer-range label to the path of its saved PNG.
    """
    out_dir = os.path.join(OUT_BASE, name)
    os.makedirs(out_dir, exist_ok=True)

    # NOTE(review): pickle.load can execute code embedded in the file; only point
    # G_PKL at trusted checkpoints.
    with open(G_PKL, 'rb') as f:
        G = pickle.load(f)['G_ema'].cuda()

    try:
        w_avg = G.mapping.w_avg.unsqueeze(0)
    except AttributeError:
        # No running average of W available: synthesise without truncation.
        w_avg = None

    # Read each archive inside a context manager so its file handle is closed.
    with np.load(w1_npz) as npz:
        w1 = torch.from_numpy(npz['w']).cuda()
    with np.load(w2_npz) as npz:
        w2 = torch.from_numpy(npz['w']).cuda()

    # [start, end) slices along the per-layer latent axis.
    # NOTE(review): end index 14 assumes a 14-layer latent (256px generator) — confirm.
    layers = {'Coarse (0–3)': (0,4), 'Middle–Fine (4–13)': (4,14), 'Fine (8–13)': (8,14)}
    results = {}
    with torch.no_grad():
        for label, (start, end) in layers.items():
            w = w1.clone()
            w[:, start:end] = w2[:, start:end]
            if w_avg is not None:
                w = w_avg + truncation_psi * (w - w_avg)
            # Generator output is mapped from [-1, 1] to [0, 1] for save_image.
            img_tensor = (G.synthesis(w, noise_mode='const') + 1) / 2
            path = os.path.join(out_dir, f"mixed_{label.replace(' ', '_')}.png")
            save_image(img_tensor, path)
            results[label] = path
    return results

# Reset callback
def reset_all():
    """on_click callback for "Start Over": drop projection outputs and session state.

    Deletes every existing directory in PROJECTION_DIRS. Then removes the workflow
    flags, the stored mix results and the two uploader widgets' state, so the next
    rerun starts from the initial flags.
    """
    # Remove projection directories left over from the previous run
    for proj_dir in filter(os.path.isdir, PROJECTION_DIRS):
        shutil.rmtree(proj_dir)
    # Clear session state (flags, results, uploader values)
    for state_key in ('run', 'projected', 'mixed', 'imgs', 'img1', 'img2'):
        if state_key in st.session_state:
            del st.session_state[state_key]

# Initialize session state once per browser session: three workflow flags (all False)
# plus a slot for the generated-mix results (None until mixing has run).
_STATE_DEFAULTS = {'run': False, 'projected': False, 'mixed': False, 'imgs': None}
for state_key, default_value in _STATE_DEFAULTS.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value

# Sidebar controls: truncation strength for mixing and the projector's step count
with st.sidebar:
    st.header("Configuration")
    truncation_psi = st.slider("Truncation ψ", min_value=0.0, max_value=1.0, value=0.7, step=0.05)
    n_steps = st.number_input("Projection Steps", min_value=1000, max_value=10000, value=5500, step=500)

# Main UI: title and instructions, then one uploader per column. The widget keys
# ('img1' / 'img2') also expose the uploads via st.session_state for later steps.
st.title("🎨 Fashion Outfit Mixer")
st.markdown("Upload two fashion images, then mix their latent codes at different levels to generate new style blends.")

left_col, right_col = st.columns(2)
with left_col:
    img1 = st.file_uploader("First image", type=['jpg', 'png'], key='img1')
with right_col:
    img2 = st.file_uploader("Second image", type=['jpg', 'png'], key='img2')

# Preview both uploads and offer the Generate button, only until a run has started
if img1 and img2 and not st.session_state['run']:
    previews = zip(st.columns(2), (img1, img2), ("Image 1", "Image 2"))
    for outer_col, upload, caption in previews:
        with outer_col:
            # The middle slot of a 1:2:1 split centres the thumbnail in its column
            _, centre_col, _ = st.columns([1, 2, 1])
            with centre_col:
                thumbnail = Image.open(upload).resize((300, 300))
                st.image(thumbnail, caption=caption)
    if st.button("🔄 Generate Mix", key='generate_mix'):
        st.session_state['run'] = True

# Projection: only after clicking Generate
if st.session_state.get('run') and not st.session_state.get('projected'):
    uploads = [st.session_state.get('img1'), st.session_state.get('img2')]
    if not all(uploads):
        # An upload was cleared after "Generate Mix" was clicked: return to the preview step.
        st.session_state['run'] = False
        st.warning("Please upload both images before generating a mix.")
        st.stop()

    # projector.py reads its target from disk, so each upload is written to a temp file.
    # The temp files are always deleted afterwards, even if projection fails.
    tmp_paths = []
    try:
        for upload in uploads:
            suffix = os.path.splitext(upload.name)[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(upload.getbuffer())
            tmp_paths.append(tmp.name)
        st.info("🔄 Projecting images into latent space...")
        for src, outdir in zip(tmp_paths, PROJECTION_DIRS):
            project_image(src, outdir, n_steps)
    except subprocess.CalledProcessError as exc:
        # Clear the run flag; otherwise every rerun would restart the long projection.
        st.session_state['run'] = False
        st.error(f"Projection failed: projector.py exited with code {exc.returncode}.")
        st.stop()
    finally:
        for tmp_path in tmp_paths:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    st.session_state['projected'] = True
    st.success("✅ Projection complete!")

# Mixing: runs once, after projection has finished
if st.session_state.get('projected') and not st.session_state.get('mixed'):
    st.info("🖌️ Generating mixed images...")
    w1, w2 = (os.path.join(proj_dir, 'projected_w.npz') for proj_dir in PROJECTION_DIRS[:2])
    # Output folder name: "<stem of image 1>_<stem of image 2>"
    stems = [os.path.splitext(st.session_state[upload_key].name)[0] for upload_key in ('img1', 'img2')]
    tag = '_'.join(stems)
    st.session_state['imgs'] = generate_mixed(w1, w2, tag, truncation_psi)
    st.session_state['mixed'] = True
    st.success("🎉 Mix generation done!")

# Display the generated mixes (each with a centred download button) and Start Over
if st.session_state.get('mixed') and st.session_state.get('imgs'):
    imgs = st.session_state['imgs']
    st.subheader("Generated Mixes")
    cols = st.columns(len(imgs))
    for idx, (col, (lbl, path)) in enumerate(zip(cols, imgs.items())):
        with col:
            st.image(path, caption=lbl, width=300)
            # Read the PNG through a context manager so its file handle is closed
            # promptly instead of being left to the garbage collector.
            with open(path, 'rb') as fh:
                png_bytes = fh.read()
            # The middle slot of a 1:2:1 split centres the button under the image.
            with st.columns([1, 2, 1])[1]:
                st.download_button(
                    label="Download",
                    data=png_bytes,
                    file_name=os.path.basename(path),
                    mime="image/png",
                    key=f'download_{idx}'
                )
    # Centered Start Over; reset_all runs as the button's on_click callback.
    with st.columns([1, 2, 1])[1]:
        st.button("🔄 Start Over", on_click=reset_all)



EOF

In [None]:
# Print this Colab VM's public IPv4 address.
# NOTE(review): presumably needed as the localtunnel "tunnel password" (the loca.lt
# landing page asks for the public IP of the machine running the tunnel) — confirm.
!wget -q -O - ipv4.icanhazip.com

In [None]:
!streamlit run /content/improved_streamlit_stylegan2_ui.py & npx localtunnel --port 8501

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
step 3015/4000: dist 0.14 loss 5.66 
step 3016/4000: dist 0.14 loss 12.06
step 3017/4000: dist 0.14 loss 8.83 
step 3018/4000: dist 0.14 loss 2.50 
step 3019/4000: dist 0.14 loss 2.22 
step 3020/4000: dist 0.14 loss 4.77 
step 3021/4000: dist 0.14 loss 4.71 
step 3022/4000: dist 0.14 loss 3.59 
step 3023/4000: dist 0.14 loss 2.80 
step 3024/4000: dist 0.14 loss 1.62 
step 3025/4000: dist 0.14 loss 1.89 
step 3026/4000: dist 0.14 loss 3.40 
step 3027/4000: dist 0.14 loss 2.42 
step 3028/4000: dist 0.14 loss 0.33 
step 3029/4000: dist 0.14 loss 1.19 
step 3030/4000: dist 0.14 loss 2.62 
step 3031/4000: dist 0.14 loss 1.37 
step 3032/4000: dist 0.14 loss 0.24 
step 3033/4000: dist 0.14 loss 1.10 
step 3034/4000: dist 0.14 loss 1.48 
step 3035/4000: dist 0.14 loss 0.80 
step 3036/4000: dist 0.14 loss 0.57 
step 3037/4000: dist 0.14 loss 0.69 
step 3038/4000: dist 0.14 loss 0.73 
step 3039/4000: dist 0.14 loss 0.80 
step 3040/