In [30]:
!pip install diffusers
!pip install nltk
import nltk
nltk.download('punkt')
import nltk
nltk.download('averaged_perceptron_tagger')



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

In [31]:
!pip install streamlit



In [32]:
%%writefile app.py
import streamlit as st
import torch
from diffusers import SemanticStableDiffusionPipeline
import matplotlib.pyplot as plt
import nltk
from nltk.tokenize import word_tokenize
from nltk.tag import pos_tag
import io
import networkx as nx
import spacy

# Fetch the NLTK resources that word_tokenize and pos_tag rely on.
for _resource in ('punkt', 'averaged_perceptron_tagger'):
    nltk.download(_resource)

# Load the spaCy "en_core_web_sm" pipeline once at import time;
# preprocess_text() calls it for every piece of text it parses.
nlp = spacy.load("en_core_web_sm")

# Raw CSS that main() injects into the page inside a <style> tag via
# st.markdown(..., unsafe_allow_html=True). It makes the header background
# transparent and gives the .stApp element a 45-degree linear gradient whose
# background-position is animated by the `my_animation` keyframes (20s loop).
wave_css = """
.stApp > header {
    background-color: transparent;
}

.stApp {
    background: linear-gradient(45deg,  #E7717D 40%, #FAEBD7 45%, #C2B9B0 55%, #91BAD6 10%);
    animation: my_animation 20s ease infinite;
    background-size: 200% 200%;
    background-attachment: fixed;
}

@keyframes my_animation {
    0% {background-position: 0% 0%;}
    50% {background-position: 100% 100%;}
    100% {background-position: 0% 0%;}
}
"""

def preprocess_text(text):
    """Parse ``text`` with the module-level spaCy pipeline ``nlp``.

    Args:
        text: The string to analyse.

    Returns:
        Whatever ``nlp(text)`` returns (the parsed document).
    """
    parsed = nlp(text)
    return parsed

def text_to_prompts(text):
    """Split a paragraph into prompts, one per full-stop-delimited fragment.

    Args:
        text: The input paragraph.

    Returns:
        A list of the fragments between "." characters, each stripped of
        surrounding whitespace; fragments that are empty after stripping
        (e.g. from consecutive or trailing full stops) are left out.
    """
    prompts = []
    for fragment in text.split("."):
        cleaned = fragment.strip()
        if cleaned:
            prompts.append(cleaned)
    return prompts

def find_adjective_noun_in_prompts(prompts):
    """Collect adjacent (adjective, noun) word pairs from every prompt.

    Each prompt is tokenized with ``word_tokenize`` and tagged with
    ``pos_tag``. A pair is recorded whenever a token whose tag starts with
    'JJ' is immediately followed by a token whose tag starts with 'NN'.

    Args:
        prompts: Iterable of prompt strings.

    Returns:
        One flat list of ``(adjective, noun)`` tuples, in prompt order and
        then in order of appearance within each prompt.
    """
    pairs = []
    for sentence in prompts:
        tagged = pos_tag(word_tokenize(sentence))
        # Pair every tagged token with the one that directly follows it.
        pairs += [
            (current_word, following_word)
            for (current_word, current_tag), (following_word, following_tag)
            in zip(tagged, tagged[1:])
            if current_tag.startswith('JJ') and following_tag.startswith('NN')
        ]
    return pairs

def generate_scene_graph(text):
    """Build a directed dependency graph ("scene graph") for ``text``.

    The text is parsed with ``preprocess_text``. Every token becomes a node
    keyed by its surface text and carrying a ``pos`` attribute; every
    head -> child dependency becomes an edge carrying a ``dep`` attribute.

    Args:
        text: The sentence or prompt to parse.

    Returns:
        A ``networkx.DiGraph`` with one node per distinct token text and one
        edge per head/child relation.
    """
    doc = preprocess_text(text)

    graph = nx.DiGraph()

    # Nodes are keyed by token text, so a word that occurs more than once in
    # the text maps onto a single node (the last occurrence's POS wins).
    for token in doc:
        graph.add_node(token.text, pos=token.pos_)

    for token in doc:
        for child in token.children:
            # The relation that links a head to its child is stored on the
            # child (child.dep_). token.dep_ would instead be the head's
            # relation to its *own* head and mislabel the edge.
            graph.add_edge(token.text, child.text, dep=child.dep_)

    return graph

def visualize_scene_graph(graph):
    """Render ``graph`` as a PNG image held in memory.

    Args:
        graph: A graph whose nodes all carry a ``pos`` attribute (as produced
            by ``generate_scene_graph``).

    Returns:
        An ``io.BytesIO`` positioned at offset 0 that contains the PNG bytes.

    Raises:
        KeyError: If a node has no ``pos`` attribute.
    """
    fig = plt.figure()
    try:
        # Compute node positions with a spring layout.
        pos = nx.spring_layout(graph)

        # Label each node with its word and, on a second line, its POS tag.
        node_labels = {node: f"{node}\n{data['pos']}" for node, data in graph.nodes(data=True)}

        # No explicit axes are passed, so networkx draws on the current
        # figure, which is the one just created above.
        nx.draw(graph, pos, with_labels=True, labels=node_labels, font_weight='bold',
                node_color='skyblue', node_size=1500, font_size=7, edge_color='gray', linewidths=0.5)

        # Save the plot to a buffer, explicitly as PNG.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png')
        buffer.seek(0)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each call would leave another open figure behind.
        plt.close(fig)

    return buffer

def main():
    """Streamlit entry point: turn a paragraph into images plus scene graphs.

    The paragraph is split into prompts with ``text_to_prompts``. For each
    prompt, the adjective-noun pairs found *in that prompt* are used as
    semantic-guidance editing prompts for the diffusion pipeline, and a scene
    graph of the prompt is built with ``generate_scene_graph``. Every
    generated image is shown together with an expander holding its graph.

    Failures while generating an image or its graph are reported with
    ``st.error`` and that prompt is skipped; the remaining prompts still run.
    """
    st.title("GraphPix: Sequential Scene Synthesis For Objects")

    st.markdown(f"<style>{wave_css}</style>", unsafe_allow_html=True)

    # Text input box
    text_input = st.text_area("Enter the text paragraph:")

    if st.button("Generate Images"):
        # Load the Diffusion Model pipeline.
        # NOTE(review): this reloads the full pipeline on every button click;
        # consider caching the loaded pipeline across Streamlit reruns.
        pipe = SemanticStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device='cuda')

        # Define the prompts
        prompts = text_to_prompts(text_input)

        # Set up the generator (seeded once, shared by all prompts)
        gen = torch.Generator(device='cuda')
        gen.manual_seed(21)

        # (image, scene_graph) tuples; keeping them together guarantees that
        # an image is never displayed next to another prompt's graph.
        results = []
        for prompt in prompts:
            # Look for adjective-noun pairs in THIS prompt only, so the edit
            # concepts always belong to the sentence being rendered and a
            # prompt without any pair is still generated.
            pairs = find_adjective_noun_in_prompts([prompt])
            editing_prompts = [", ".join(pair) for pair in pairs]

            edit_kwargs = {}
            if editing_prompts:
                # The per-concept settings are lists indexed by concept, so
                # they must be as long as the list of editing prompts.
                concept_count = len(editing_prompts)
                edit_kwargs = dict(
                    editing_prompt=editing_prompts,
                    reverse_editing_direction=[False] * concept_count,
                    edit_warmup_steps=[10] * concept_count,
                    edit_guidance_scale=[4] * concept_count,
                    edit_threshold=[0.99] * concept_count,
                    edit_momentum_scale=0.3,
                    edit_mom_beta=0.6,
                    edit_weights=[1] * concept_count,
                )

            try:
                out = pipe(
                    prompt=prompt,
                    generator=gen,
                    num_images_per_prompt=1,
                    guidance_scale=7,
                    **edit_kwargs
                )

                # Generate scene graph for the current prompt
                scene_graph = generate_scene_graph(prompt)
            except Exception as e:
                st.error(f"Error generating images: {e}")
                continue

            # Record results only after both steps succeeded.
            results.extend((image, scene_graph) for image in out.images)

        # Display all generated images with option to show scene graph
        num_images = len(results)
        st.text(f"Number of generated images: {num_images}")

        for i, (image, scene_graph) in enumerate(results):
            st.image(image, caption=f"Generated Image {i+1}", use_column_width=True, width=400)

            # Collapsed expander holding the scene graph for this image
            icon_expander = st.empty()
            with icon_expander:
                with st.expander(f"Show Scene Graph {i+1}", expanded=False):
                    st.image(visualize_scene_graph(scene_graph), caption="Scene Graph", use_column_width=True)

# Run the app only when this file is executed as the main script
# (e.g. by `streamlit run app.py`), not when it is imported.
if __name__ == "__main__":
    main()

Overwriting app.py


In [33]:
# Start app.py with Streamlit as a background shell job so the notebook stays
# usable. Output is redirected to /dev/null, so no server log is kept here.
!streamlit run app.py &>/dev/null&

In [34]:
!pip install pyngrok



In [35]:
!ngrok authtoken 2cM6tQosgyOMLeO5fgtuaayoTXJ_txjdzgiQnRspg5omXTfw

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [36]:
from pyngrok import ngrok
import streamlit as st

# Read Streamlit's "server.port" option from within this notebook process.
# NOTE(review): this is the kernel's own Streamlit configuration, not a query
# of the app launched in the background — confirm both resolve to the same port.
port = st.get_option("server.port")

# Open an ngrok tunnel to that port; the bare expression on the last line
# displays the resulting tunnel object (including its public URL).
public_url = ngrok.connect(port)
public_url

<NgrokTunnel: "https://9613-35-239-246-11.ngrok-free.app" -> "http://localhost:8501">