In [0]:
# Install required packages
!pip install flask transformers torch pillow flask-cors accelerate pyngrok

import io
import os

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
from pyngrok import ngrok, conf
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

app = Flask(__name__)

# Browser origins permitted to call this API cross-origin; the allowlist is
# applied to every route via the r"/*" resource pattern below.
# NOTE: flask-cors does not touch a response that already carries an
# Access-Control-Allow-Origin header, so no other hook should set that header
# or this allowlist is bypassed.
ALLOWED_ORIGINS = [
    "http://127.0.0.1:5000",  # Frontend
    "http://localhost:5000",  # Frontend alternate
    "http://127.0.0.1:5001",  # Backend (presumably a separate local service)
    "http://localhost:5001",  # Backend alternate
]
CORS(app, resources={r"/*": {"origins": ALLOWED_ORIGINS}})

# ngrok credentials come from the environment instead of being hard-coded:
# a token committed to a notebook/source file is exposed to anyone who can
# read it (any token previously embedded here should be revoked).
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN")
if NGROK_AUTH_TOKEN:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)
else:
    print("Warning: NGROK_AUTH_TOKEN is not set; the ngrok tunnel may fail to start.")

# Local port the Flask app listens on and that the ngrok tunnel forwards to.
PORT = 5002

try:
    print('Loading AI models...')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 halves GPU memory use, but half-precision support on CPU is
    # limited (and slow) in many PyTorch builds, so use float32 without a GPU.
    model_dtype = torch.float16 if device == "cuda" else torch.float32

    # Image captioner: turns the uploaded picture into a short text caption.
    image_captioner = pipeline(
        "image-to-text",
        model="microsoft/git-large-coco",
        device=device,
        max_new_tokens=50,  # Limit caption length
    )

    # Story generator. GPT-2 has no pad token of its own, so EOS doubles as one.
    story_model_name = "gpt2-large"
    story_tokenizer = AutoTokenizer.from_pretrained(story_model_name, padding_side='left')
    story_tokenizer.pad_token = story_tokenizer.eos_token

    # Place the whole model on `device` explicitly (no accelerate device_map),
    # matching where generate_horror_story puts its input tensors.
    story_model = AutoModelForCausalLM.from_pretrained(
        story_model_name,
        torch_dtype=model_dtype,
        pad_token_id=story_tokenizer.eos_token_id,
    ).to(device)

    print(f'Models loaded successfully on {device}!')

    def generate_horror_story(caption):
        """Generate a structured horror story inspired by a scene caption.

        Args:
            caption: Scene description (the image captioner's output),
                interpolated verbatim into the prompt.

        Returns:
            Only the generated continuation with surrounding whitespace
            stripped; the prompt itself is not included.
        """
        prompt = f"""Scene: {caption}

Write a terrifying horror story with the following structure:
1. Set the dark, eerie atmosphere
2. Introduce the trapped characters
3. Build tension through mysterious sounds and events
4. Create a climactic confrontation
5. End with a chilling revelation

Story:

"""

        # The tokenizer returns input_ids together with a matching attention_mask.
        inputs = story_tokenizer(prompt, return_tensors="pt").to(device)
        prompt_token_count = inputs["input_ids"].shape[-1]

        outputs = story_model.generate(
            **inputs,
            # Both length limits include the prompt tokens.
            max_length=800,  # Longer story
            min_length=400,  # Ensure minimum length
            # do_sample together with num_beams > 1 selects beam-search
            # multinomial sampling.
            do_sample=True,
            temperature=0.85,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
            num_beams=5,
            early_stopping=True,
            pad_token_id=story_tokenizer.eos_token_id,
        )

        # Drop the prompt at the token level. Slicing the decoded text by
        # len(prompt) breaks whenever decoding does not reproduce the prompt
        # character-for-character (e.g. tokenization-space cleanup).
        generated_ids = outputs[0][prompt_token_count:]
        return story_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    @app.after_request
    def after_request(response):
        """Ask the client to close the connection after every response.

        CORS headers are deliberately not set here: flask-cors (configured
        with an origin allowlist when the app is created) adds them itself,
        and it skips any response that already carries
        Access-Control-Allow-Origin, so setting a wildcard here would
        silently disable that allowlist.
        """
        response.headers['Connection'] = 'close'
        return response

    @app.route('/generate_story', methods=['POST', 'OPTIONS'])
    def generate_story():
        """Caption an uploaded image and write a horror story about it.

        Expects a multipart upload with the file in the ``image`` field.
        Returns JSON with ``status``, ``caption`` and ``story`` on success;
        400 when no image or an unrecognised image file is sent; 500 with the
        exception text on any other failure.
        """
        if request.method == 'OPTIONS':
            # CORS preflight: empty body, no content.
            return '', 204
        try:
            if 'image' not in request.files:
                return jsonify({'status': 'error', 'message': 'No image provided'}), 400

            image_file = request.files['image']
            try:
                # convert("RGB") normalises palette/greyscale/RGBA/CMYK uploads
                # and forces the pixel data to be decoded right away.
                image = Image.open(io.BytesIO(image_file.read())).convert("RGB")
            except Image.UnidentifiedImageError as e:
                # Not an image format Pillow recognises: a client error.
                return jsonify({'status': 'error', 'message': str(e)}), 400

            caption = image_captioner(image)[0]['generated_text']
            story = generate_horror_story(caption)

            return jsonify({
                'status': 'success',
                'caption': caption,
                'story': story
            })

        except Exception as e:
            return jsonify({'status': 'error', 'message': str(e)}), 500

except Exception as e:
    # Model download/loading or route setup failed: report it, then re-raise
    # with the original traceback so the cell/script still fails loudly.
    print(f"Error setting up server: {str(e)}")
    raise

if __name__ == '__main__':
    # Stop any ngrok agent (and its tunnels) left over from a previous run,
    # e.g. when the notebook cell is re-executed.
    ngrok.kill()
    # Forward to the exact interface Flask binds below; "localhost" may
    # resolve to the IPv6 loopback, where nothing is listening.
    public_url = ngrok.connect(
        addr=f"127.0.0.1:{PORT}",
        proto="http",  # Force HTTP protocol
    )
    # ngrok.connect returns an NgrokTunnel object; print its URL string.
    print(f'Ngrok tunnel established! Public URL: {public_url.public_url}')
    try:
        # Blocks until the development server is stopped.
        app.run(host='127.0.0.1', port=PORT)
    finally:
        # Close the public tunnel even if the server exits with an error.
        ngrok.disconnect(public_url.public_url)
