## Notebook Link
https://colab.research.google.com/drive/1uNygzDR4hISwLOgmDS31hRHfr6KAF7Ib?usp=sharing

In [None]:
from IPython.display import clear_output
clear_output(wait=False)

!wget https://raw.githubusercontent.com/Unknown-Geek/Story-Generator/main/requirements.txt
!wget https://raw.githubusercontent.com/Unknown-Geek/Story-Generator/main/setup.py
!python setup.py

clear_output(wait=False)
print("Setup complete!")

In [None]:
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image
import io
import base64
from pyngrok import ngrok

# Configure ngrok
from pyngrok import ngrok
# Never commit the ngrok authtoken. The token that used to be hard-coded here
# is exposed in version control and should be revoked in the ngrok dashboard.
# Supply it via the NGROK_AUTHTOKEN environment variable, or via a Colab
# secret of the same name.
import os

_ngrok_token = os.environ.get('NGROK_AUTHTOKEN')
if not _ngrok_token:
    # Best-effort lookup in Colab secrets; any failure (not running in Colab,
    # secret missing, access not granted) simply leaves the token unset.
    try:
        from google.colab import userdata
        _ngrok_token = userdata.get('NGROK_AUTHTOKEN')
    except Exception:
        _ngrok_token = None

if _ngrok_token:
    ngrok.set_auth_token(_ngrok_token)
else:
    print("Warning: NGROK_AUTHTOKEN is not set; opening the ngrok tunnel will likely fail.")

app = Flask(__name__)
# Enable CORS so browser clients served from another origin can call the API.
CORS(app)

def load_models(caption_model_name="Salesforce/blip-image-captioning-base",
                story_model_name="gpt2"):
    """Load the image-captioning pipeline and the story-generation model.

    Args:
        caption_model_name: Hugging Face model id for the "image-to-text"
            pipeline. Defaults to the BLIP base captioning model.
        story_model_name: Hugging Face model id for the causal language model
            used to write the story. Defaults to "gpt2".

    Returns:
        Tuple ``(image_captioner, story_tokenizer, story_model, device)``:
        the captioning pipeline, the story tokenizer, the story model (already
        moved to ``device``), and the ``torch.device`` chosen (CUDA when
        available, otherwise CPU).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Pass the device explicitly so the captioner is guaranteed to run on the
    # same device as the story model, instead of relying on the pipeline's
    # default placement.
    image_captioner = pipeline("image-to-text", model=caption_model_name, device=device)

    story_tokenizer = AutoTokenizer.from_pretrained(story_model_name)
    story_model = AutoModelForCausalLM.from_pretrained(story_model_name).to(device)

    return image_captioner, story_tokenizer, story_model, device

def _decode_image(image_field):
    """Decode a base64 image (bare, or a ``data:<mime>;base64,`` URL) into an RGB PIL image."""
    # Accept both forms; the previous split(',')[1] raised IndexError when
    # the client sent bare base64 without the data-URL prefix.
    _, sep, payload = image_field.partition(',')
    if not sep:
        payload = image_field
    image = Image.open(io.BytesIO(base64.b64decode(payload)))
    # convert() forces the lazy decode (so corrupt data fails here) and
    # normalises RGBA / palette / greyscale uploads to 3-channel RGB.
    return image.convert('RGB')


@app.route('/generate_story', methods=['POST'])
def generate_story():
    """Caption an uploaded image and continue it into a short story.

    Expects a JSON body with string fields ``image`` (base64 image data,
    optionally as a data URL) and ``genre``. Uses the module globals
    ``image_captioner``, ``story_tokenizer``, ``story_model`` and ``device``,
    which the ``__main__`` block assigns from ``load_models()``.

    Returns:
        ``{"success": True, "story": ...}`` on success. On failure it returns
        ``{"success": False, "error": ...}``: with status 400 for malformed
        input, or status 500 for errors during captioning or generation.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    image_field = data.get('image')
    genre = data.get('genre')
    if not isinstance(image_field, str) or not isinstance(genre, str):
        return jsonify({
            'success': False,
            'error': "Request JSON must include string fields 'image' and 'genre'"
        }), 400

    try:
        image = _decode_image(image_field)
    except (ValueError, OSError) as e:
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError is an OSError.
        return jsonify({
            'success': False,
            'error': f'Invalid image data: {e}'
        }), 400

    try:
        caption = image_captioner(image)[0]['generated_text']

        prompt = f"Write a {genre.lower()} story about: {caption}\n\nStory:"
        inputs = story_tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs['input_ids'].shape[1]

        output_sequences = story_model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=200,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            do_sample=True,
            num_return_sequences=1,
            # GPT-2 has no pad token; reuse EOS so generate() need not guess.
            pad_token_id=story_tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens. Removing the prompt with
        # str.replace fails when decoding does not reproduce the prompt
        # text exactly, which leaked the prompt into the story.
        story = story_tokenizer.decode(output_sequences[0][prompt_length:], skip_special_tokens=True)

        return jsonify({
            'success': True,
            'story': story.strip()
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500

from IPython.display import clear_output
# Clear the cell's output now that the app and its route are defined, so the
# model-loading and server status messages printed below start from a clean display.
clear_output(wait=False)

def start_server(port=5000):
    """Expose the Flask app through an ngrok tunnel and serve it (blocking).

    Opens a tunnel to ``port``, prints the public URL, records it via
    ``url_store.save_url``, then runs the Flask development server on the same
    port until it stops. The ngrok tunnel/process is always torn down on exit,
    including on KeyboardInterrupt (e.g. interrupting the notebook cell).

    Args:
        port: Local port for both the Flask server and the ngrok tunnel.
            Defaults to 5000.
    """
    try:
        # Drop any ngrok process/tunnel left over from a previous run.
        ngrok.kill()
        tunnel = ngrok.connect(port)
        # ngrok.connect returns an NgrokTunnel object, whose str() is a
        # descriptive label rather than the URL. Read .public_url explicitly,
        # falling back to str() for pyngrok versions that return a plain string.
        public_url = getattr(tunnel, 'public_url', None) or str(tunnel)
        clear_output(wait=False)
        print(f'Server running at: {public_url}')

        from url_store import save_url
        save_url(public_url)
        app.run(port=port)
    except Exception as e:
        print(f'Error starting server: {str(e)}')
    finally:
        # Previously only run on Exception; KeyboardInterrupt left the tunnel up.
        ngrok.kill()

if __name__ == '__main__':
    # Script entry point. The four loaded objects are bound as module globals
    # because the /generate_story handler looks them up by name.
    print("Loading AI models...")
    try:
        loaded = load_models()
        image_captioner, story_tokenizer, story_model, device = loaded
        clear_output(wait=False)
        print("Models loaded successfully!")
        print("Starting the server...")
        start_server()
    except Exception as err:
        # Report the failure and make sure no ngrok process is left running.
        print("Error: " + str(err))
        ngrok.kill()
        
# Intentionally no final clear_output() here. This point is reached only
# after the server has stopped, and clearing the cell then erased the
# "Error starting server: ..." / "Error: ..." messages printed above.