# core

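> A FastHTML app that turns user sketches into AI-generated GIF animations, with Google sign-in, a credit system, and a voting leaderboard.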

In [None]:
#| default_exp core

In [None]:
#| export
import uuid,asyncio,shutil,subprocess,base64,time
from PIL import Image
from fasthtml.common import *
from gaspard import *
from plash_cli.auth import *
from fastlite import database
from fastcore.utils import *

In [None]:
from fasthtml.jupyter import *

In [None]:
#| export
# Directory holding the project's settings.ini (presumably located by fastcore's
# Config.find search from the working directory — confirm). The SQLite databases
# below are resolved relative to it.
ROOT_DIR = Config.find('settings.ini').config_path

In [None]:
#| export
# FastHTML application and its route decorator.
app, rt = fast_app()
# Progress of background animation jobs, keyed by task id. Written by
# create_animation_async and read by the /status route. It is a plain in-process
# dict, so its contents do not survive a restart.
task_status: dict[str, dict] = {}

In [None]:
#| export
def utc_now() -> int:
    "Return the current Unix time in whole seconds (stored in the `ts` columns)."
    return int(time.time())

In [None]:
#| export
class User:
    # One row per signed-in user; `id` is the value returned by goog_id_from_signin_reply.
    id: str
class Animation:
    # One row per submitted drawing; `id` is the job's uuid4 string.
    id: str
    # Owner's user id. It is a string like User.id and Transaction.uid. Declaring it int gave
    # the column INTEGER affinity, which can coerce or mangle long numeric Google ids.
    uid: str
    # 'in_progress' until the background job marks the row 'completed'.
    status: str = 'in_progress'
class Transaction:
    # Credit ledger entry: a user's balance is the sum of `amount` over their rows.
    id: int
    uid: str
    amount: int
    ts: int  # Unix seconds
class Vote:
    # One upvote: user `uid` voted for animation `aid`.
    id: int
    uid: str
    aid: str
    ts: int  # Unix seconds

In [None]:
#| export
# The production database is used when PLASH_PRODUCTION is set, the dev database otherwise.
db = database(ROOT_DIR / 'data' / ('prod.db' if os.getenv('PLASH_PRODUCTION') else 'dev.db'))
# transform=True: if a table already exists, alter it to match the class's current fields.
users, animations, transactions, votes = (
    db.create(model, transform=True) for model in (User, Animation, Transaction, Vote)
)

In [None]:
#| export
def toolbar(content=None, uid=None):
    """Wrap `content` in the app chrome.

    The top bar shows the title, the user's credit balance and navigation links (Logout only
    when `uid` is set). Below it, `content` is placed in a centered container; if `content`
    is falsy, that container is omitted.
    """
    def link_style(bg, spaced=True):
        # Shared pill-button look for the nav links; all but the last get a right margin.
        gap = "margin-right:10px; " if spaced else ""
        return f"{gap}padding:6px 12px; background:{bg}; color:white; text-decoration:none; border-radius:4px"

    credits = sum(t.amount for t in transactions("uid=?", (uid,))) if uid else 0
    nav = Div(
        Span(f"Credits: {credits}", style="margin-right:15px; font-weight:bold"),
        A("Sketch!", href="/", style=link_style("#007bff")),
        A("Leaderboard", href="/leaderboard", style=link_style("#28a745")),
        A("Logout", href="/logout", style=link_style("#dc3545", spaced=False)) if uid else None,
        style="display:flex; align-items:center",
    )
    header = Div(
        Div("Sketch Star", style="font-weight:bold; font-size:1.2em"),
        nav,
        style="display:flex; justify-content:space-between; align-items:center; padding:10px 20px; background:#f8f9fa; border-bottom:1px solid #dee2e6",
    )
    body = Div(content, style="max-width:1200px; margin:0 auto; padding:20px") if content else None
    return Div(header, body)

In [None]:
#| export
@rt('/login')
def login(session):
    """Signed-out landing page.

    Shows a short pitch, up to three randomly chosen completed animations as a showcase, and
    a Google sign-in link.
    """
    picks = db.q("SELECT id FROM animation WHERE status = 'completed' ORDER BY RANDOM() LIMIT 3")
    thumb_style = "width:200px; height:200px; object-fit:contain; border-radius:8px; margin:10px"
    gallery = [Img(src=f"/data/animations/{p['id']}/animation.gif", style=thumb_style) for p in picks]

    showcase = Div(
        H3("✨ See what others have created:", style="text-align:center"),
        Div(*gallery, style="text-align:center; margin:20px 0"),
        style="margin:30px 0",
    )
    call_to_action = Div(
        P("Ready to create your own animation?", style="font-size:1.1em; margin-bottom:15px"),
        A("🚀 Sign in with Google to get started!", href=mk_signin_url(session),
          style="padding:15px 30px; background:#4285f4; color:white; text-decoration:none; border-radius:8px; font-size:1.1em; display:inline-block"),
        style="text-align:center; margin:40px 0",
    )
    page = Div(
        H1("🎨 Welcome to Sketch Star!", style="text-align:center; color:#007bff"),
        P("Turn your sketches into magical animations! Draw on our canvas and watch AI bring your creativity to life.",
          style="text-align:center; font-size:1.2em; margin:20px 0"),
        showcase,
        call_to_action,
    )
    return toolbar(uid=None, content=page)

In [None]:
#| export
@rt(signin_completed_rt)
def signin_completed(session, signin_reply: str):
    """Google sign-in callback.

    Stores the user's id in the session, then redirects home. On first sign-in the user is
    created. If the user table then has fewer than 100 rows (this user included), the user
    also gets 3 starter credits. If the reply cannot be resolved, an error page with a retry
    link is returned.
    """
    try:
        uid = goog_id_from_signin_reply(session, signin_reply)
        session['uid'] = uid
    except Exception as e:
        return Div(H2("Login Failed"), P(f"Error: {e}"), A("Try Again", href="/login"))

    home = RedirectResponse('/', status_code=303)
    if users.get(uid, default=None):
        return home  # returning user: nothing to set up
    users.insert(User(id=uid))
    if users.count < 100:
        transactions.insert(Transaction(uid=uid, amount=3, ts=utc_now()))
    return home

In [None]:
#| export
@rt('/logout')
def logout(session):
    "Forget the signed-in user (if any) and redirect to the home page."
    if 'uid' in session:
        del session['uid']
    return RedirectResponse('/', status_code=303)

In [None]:
#| export
@rt("/animation.gif")
async def get_gif():
    "Serve `animation.gif` from the process's working directory."
    gif_file = "animation.gif"
    return FileResponse(gif_file)

@rt("/animations/{task_id}/animation.gif")
async def get_animation_gif(task_id: str):
    """Serve the finished GIF for animation `task_id`.

    `task_id` comes straight from the URL and is joined into a filesystem path. It must
    therefore parse as a UUID, which is how ids are minted in submit_drawing. Malformed ids
    and missing files get a 404 instead of failing inside FileResponse.
    """
    try:
        uuid.UUID(task_id)
    except ValueError:
        return Response("Not found", status_code=404)
    gif = Path('./data/animations') / task_id / 'animation.gif'
    if not gif.is_file():
        return Response("Not found", status_code=404)
    return FileResponse(gif)

@rt("/status/{task_id}")
async def get_status(task_id: str):
    """Render the progress fragment for `task_id`.

    While the task is running, the returned fragment carries hx_get/hx_trigger so htmx polls
    again in 2s, and it replaces itself via outerHTML. Terminal states ('complete', 'error',
    'not_found') return a fragment without polling attributes, which ends the loop.
    """
    task = task_status.get(task_id)
    if task is None:
        # task_status lives in process memory; after a restart, fall back to the database row.
        row = animations.get(task_id, default=None)
        if row is not None and row.status == 'completed':
            task = dict(progress=100, status='complete', step='Animation complete!')
        else:
            task = dict(progress=0, status='not_found', step='Task not found')

    status = task['status']
    if status == 'complete':
        return Div(
            P("✅ Animation complete!"),
            Img(src=f"/data/animations/{task_id}/animation.gif", style="max-width:100%; border-radius:8px; box-shadow:0 4px 8px rgba(0,0,0,0.1)")
        )
    if status == 'not_found':
        return Div(P("❌ Task not found"))
    if status == 'error':
        # A failed job must not keep polling forever.
        return Div(P(f"❌ {task.get('step', 'Animation failed')}"))
    return Div(
        P(f"🔄 {task['step']}"),
        Progress(value=task['progress'], max=100, style="width:100%; margin:10px 0"),
        P(f"Progress: {task['progress']}%", style="font-size:0.9em; color:#666"),
        id="progress",
        hx_get=f"/status/{task_id}",
        hx_trigger="every 2s",
        hx_swap="outerHTML"
    )

In [None]:
#| export
@rt
async def index(session):
    """Drawing page; signed-out visitors are redirected to /login.

    The page has:
    - colour buttons;
    - a 1024x1024 canvas pre-filled with /data/frame0.png;
    - a submit button that POSTs the canvas as a PNG data URL to /submit_drawing.

    The response to the submit replaces the contents of #result, which holds the canvas.
    """
    if not (uid := session.get('uid')): return RedirectResponse('/login', status_code=303)
    palette = Div(
        Button("Red", onclick="setColor('red')", style="background:red;color:white;margin:2px"),
        Button("Blue", onclick="setColor('blue')", style="background:blue;color:white;margin:2px"),
        Button("Black", onclick="setColor('black')", style="background:black;color:white;margin:2px"),
    )
    canvas_area = Div(id="result", style="margin-top:20px")(
        Canvas(id="canvas", width="1024", height="1024", style="border:1px solid black"),
        Script("""
            const canvas = document.getElementById('canvas');
            const ctx = canvas.getContext('2d');
            let drawing = false;
            let currentColor = 'black';

            const img = new Image();
            img.onload = function() { ctx.drawImage(img, 0, 0, canvas.width, canvas.height); };
            img.src = '/data/frame0.png';

            function setColor(color) { currentColor = color; }

            // Map a mouse event to canvas pixel coordinates. getBoundingClientRect is
            // viewport-relative like clientX/clientY, so this stays correct when the page
            // is scrolled (offsetLeft/offsetTop are not). clientLeft/clientTop skip the
            // border, and the width/height ratio handles any CSS scaling.
            function canvasPoint(e) {
                const r = canvas.getBoundingClientRect();
                return [(e.clientX - r.left - canvas.clientLeft) * canvas.width / canvas.clientWidth,
                        (e.clientY - r.top - canvas.clientTop) * canvas.height / canvas.clientHeight];
            }

            function stopDrawing() { drawing = false; ctx.beginPath(); }

            canvas.addEventListener('mousedown', e => { drawing = true; draw(e); });
            canvas.addEventListener('mousemove', draw);
            // Listen on window so releasing the button outside the canvas also ends the stroke.
            window.addEventListener('mouseup', stopDrawing);

            function draw(e) {
                if (!drawing) return;
                const [x, y] = canvasPoint(e);
                ctx.lineWidth = 3; ctx.lineCap = 'round'; ctx.strokeStyle = currentColor;
                ctx.lineTo(x, y);
                ctx.stroke(); ctx.beginPath();
                ctx.moveTo(x, y);
            }

            function submitCanvas() {
                const dataURL = canvas.toDataURL('image/png');
                htmx.ajax('POST', '/submit_drawing', {values: {canvas_data: dataURL}, target: '#result'});
            }
        """),
    )
    return toolbar(uid=uid, content=Div(
        palette,
        Br(),
        Button("Submit Drawing", onclick="submitCanvas()", id="submit-btn"),
        canvas_area,
    ))

In [None]:
#| export
# Instructions sent with the base frame and the user's drawing. The model is then asked for
# frames 1-6 one at a time.
_ANIMATION_PROMPT = """\
You're evaluating what the user drew on the original image and will animate it into a gif.
First decide if what the user drew should enable the character reach the star.
If the drawing of the user is not helpful the character should fall in the pit or fly off screen or just now shrug. Fails should be funny!
          
Second, you describe what the succesfull or unsuccesful action looks like. You ALWAYS do so in 6 steps.
Describe in 6 frames how the user's addition enables the character to reach the star or how they fail (one of these NOT both). 
Describe each frame in a single sentence and return them as a numbered list. 
We will generate those frames and combine them into a gif. 

It's important you don't add any other attributes besides what the user has drawn, but you can make creative use of it and add some pzazz. 
For example, if the character gets a trampoline instead of just jumping you can describe a fancy flip. 

It's important that you do not modify the original landscape.
"""


def _first_image_bytes(res):
    "Return the bytes of the first part carrying inline image data in a Gemini chat response."
    # NOTE(review): image models may emit a text part before the image, so parts[0] is not
    # guaranteed to be the image; scan for the first part that has inline_data.
    for part in res.candidates[0].content.parts:
        if getattr(part, 'inline_data', None) is not None:
            return part.inline_data.data
    raise ValueError("Model response contained no image")


def _build_gif(session_dir: Path):
    "Resize the base frame and the drawing to 1024x1024, then join them with frames 1-6 into animation.gif."
    for name in ('frame0', 'canvas_drawing'):
        with Image.open(session_dir / f'{name}.png') as im:
            im.resize((1024, 1024)).save(session_dir / f'{name}_resized.png')
    frames = [session_dir / 'frame0_resized.png', session_dir / 'canvas_drawing_resized.png']
    frames += [session_dir / f'frame{i}.png' for i in range(1, 7)]
    # check=True: a failed conversion must surface as an error, not as a "complete" task with no GIF.
    subprocess.run(['magick', '-delay', '20', *map(str, frames), str(session_dir / 'animation.gif')], check=True)


async def create_animation_async(task_id: str, canvas_data: str):
    """Background job that turns the user's drawing into ./data/animations/<task_id>/animation.gif.

    Progress is reported in `task_status[task_id]`, which the /status route polls.
    - On success the animation row is marked 'completed' and the task ends with status 'complete'.
    - On any failure the task ends with status 'error' (so the client stops polling) and the
      row is marked 'failed'.

    task_id: id of an existing `animation` row (inserted by submit_drawing).
    canvas_data: the PNG data URL produced by `canvas.toDataURL('image/png')`.
    """
    def report(**changes):
        task_status[task_id].update(**changes)
        print(f"Task {task_id}: {task_status[task_id]}")

    task_status[task_id] = dict(progress=0, status='running', step='Initializing animation...')
    print(f"Task {task_id}: {task_status[task_id]}")
    session_dir = Path(f'./data/animations/{task_id}')
    try:
        session_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy(Path('data/frame0.png'), session_dir / 'frame0.png')
        report(progress=10, step='Setting up session...')

        # The base64 payload of the data URL follows the first comma.
        (session_dir / 'canvas_drawing.png').write_bytes(base64.b64decode(canvas_data.split(',', 1)[1]))
        report(progress=20, step='Analyzing drawing...')

        chat = Chat("gemini-2.5-flash-image-preview")
        # Chat calls are blocking network requests. Running them in a worker thread keeps the
        # event loop free to serve other requests, including the /status polls.
        await asyncio.to_thread(chat, [session_dir / "frame0.png", session_dir / "canvas_drawing.png", _ANIMATION_PROMPT])
        for i in range(1, 7):
            report(progress=20 + (i - 1) * 10, step=f'Generating frame {i}/6...')
            res = await asyncio.to_thread(chat, f"Ok now generate frame {i}")
            (session_dir / f'frame{i}.png').write_bytes(_first_image_bytes(res))
            report(progress=20 + i * 10, step=f'Frame {i}/6 completed')

        report(progress=90, step='Creating animation...')
        await asyncio.to_thread(_build_gif, session_dir)  # PIL + ImageMagick are blocking too

        a = animations[task_id]
        a.status = 'completed'
        animations.update(a)
    except Exception as e:
        # Top-level boundary of a fire-and-forget task: log the failure and record a terminal
        # state; otherwise the task would stay 'running' and the client would poll forever.
        print(f"Task {task_id} failed: {e!r}")
        task_status[task_id].update(status='error', step='Animation failed. Please try again.')
        # TODO(review): the credit charged in submit_drawing is not refunded on failure.
        a = animations[task_id]
        a.status = 'failed'
        animations.update(a)
        return
    task_status[task_id] = dict(progress=100, status='complete', step='Animation complete!', gif_path=str(session_dir / 'animation.gif'))


In [None]:
#| export
# Strong references to in-flight animation jobs. The event loop keeps only weak references
# to tasks, so an unreferenced task can be garbage-collected before it finishes.
_background_tasks = set()


@rt("/submit_drawing", methods=["POST"])
async def submit_drawing(canvas_data: str, session):
    """Charge one credit and start a background animation job for the submitted drawing.

    canvas_data: PNG data URL from the canvas. It is rejected before any credit is charged if
    it is not a base64 PNG data URL.
    Returns an htmx fragment that polls /status/<task_id> every 2s, or an error message if the
    user is signed out, the payload is malformed, or the balance is not positive.
    """
    if not (uid := session.get('uid')): return Div(P("❌ You must be logged in to submit a drawing."))
    if not canvas_data.startswith('data:image/png;base64,'): return Div(P("❌ Invalid drawing data."))

    # Parenthesised so `balance` holds the sum, not the result of the `<=` comparison.
    if (balance := sum(t.amount for t in transactions("uid=?", (uid,)))) <= 0:
        return Div(P("❌ Insufficient credits to create animation."))

    transactions.insert(Transaction(uid=uid, amount=-1, ts=utc_now()))
    task_id = str(uuid.uuid4())
    animations.insert(Animation(id=task_id, uid=uid))
    task = asyncio.create_task(create_animation_async(task_id, canvas_data))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return Div(P("🔄 Animation started..."), Progress(value=0, max=100), id="progress", hx_get=f"/status/{task_id}", hx_trigger="every 2s", hx_swap="outerHTML")

In [None]:
#| export
# Leaderboard query: up to 100 completed animations with their vote counts, most-voted first.
# LEFT JOIN keeps animations with no votes; their COUNT(v.aid) is 0. The order among tied
# vote counts is not specified by the query.
query = """
SELECT a.*, COALESCE(COUNT(v.aid), 0) as vote_count 
FROM animation a 
LEFT JOIN vote v ON a.id = v.aid 
WHERE a.status = 'completed'
GROUP BY a.id 
ORDER BY vote_count DESC 
LIMIT 100
"""

In [None]:
#| export
@rt('/leaderboard')
async def leaderboard(session):
    """Show up to 100 completed animations ranked by votes, each with a vote toggle.

    Requires a signed-in user; otherwise redirects to /login. The viewer's own animations get
    a disabled "Yours!" badge instead of a vote button.
    """
    uid = session.get('uid')
    if not uid: return RedirectResponse('/login', status_code=303)

    # Fetch this user's votes once instead of issuing one query per animation.
    voted_aids = {v.aid for v in votes("uid=?", (uid,))}
    leaderboard_items = []
    for row in db.q(query):
        aid = row['id']
        user_voted = aid in voted_aids
        if row['uid'] == uid:
            # Ownership is decided by the animation's owner uid, not its id.
            vote_button = Button("Yours!", disabled=True,
                                 style="padding:8px 16px; border:none; border-radius:4px; background:#6c757d; color:white")
        else:
            vote_button = Button(
                "❤️ Undo" if user_voted else "🤍 Upvote",
                hx_post=f"/vote/{aid}",
                hx_target=f"#vote-{aid}",
                style="padding:8px 16px; border:none; border-radius:4px; cursor:pointer; " +
                      ("background:#dc3545; color:white" if user_voted else "background:#28a745; color:white")
            )

        leaderboard_items.append(
            Div(
                Img(src=f"/data/animations/{aid}/animation.gif",
                    style="width:300px; height:300px; object-fit:contain; border-radius:8px"),
                Div(
                    P(f"Votes: {row['vote_count']}", style="font-weight:bold; margin:10px 0"),
                    Div(vote_button, style="text-align:center; padding:10px"),
                    id=f"vote-{aid}"
                ),
                style="display:inline-block; margin:20px; border:1px solid #ddd; border-radius:12px; padding:15px"
            )
        )

    return toolbar(uid=uid, content=Div(
        H1("🏆 Animation Leaderboard"),
        P("Vote for your favorite animations!"),
        Div(*leaderboard_items, style="text-align:center")
    ))

In [None]:
#| export
@rt("/vote/{animation_id}", methods=["POST"])
async def toggle_vote(animation_id: str, session):
    """Toggle the signed-in user's vote on `animation_id` and return the refreshed vote widget.

    Only existing, completed animations owned by someone else can be voted on. The response
    is swapped into the leaderboard's `#vote-{animation_id}` container (htmx's default
    innerHTML swap). It therefore returns that container's contents without repeating the
    `vote-…` id, which would otherwise create a duplicate id in the page.
    """
    if not (uid := session.get('uid')): return P("❌ Must be logged in")
    anim = animations.get(animation_id, default=None)
    if anim is None or anim.status != 'completed': return P("❌ Animation not found")
    if anim.uid == uid: return P("❌ You can't vote for your own animation")

    if existing_vote := votes("uid=? AND aid=?", (uid, animation_id)):
        votes.delete(existing_vote[0].id)
        user_voted = False
    else:
        votes.insert(Vote(uid=uid, aid=animation_id, ts=utc_now()))
        user_voted = True

    new_count = len(votes("aid=?", (animation_id,)))
    vote_button = Button(
        "❤️ Undo" if user_voted else "🤍 Upvote",
        hx_post=f"/vote/{animation_id}",
        hx_target=f"#vote-{animation_id}",
        style="padding:8px 16px; border:none; border-radius:4px; cursor:pointer; " +
              ("background:#dc3545; color:white" if user_voted else "background:#28a745; color:white")
    )
    return Div(
        P(f"Votes: {new_count}", style="font-weight:bold; margin:10px 0"),
        Div(vote_button, style="text-align:center; padding:10px")
    )