In [20]:
import torch
from kemsekov_torch.train import load_best_checkpoint, load_last_checkpoint
from kemsekov_torch.flow_matching import FlowMatching
import torchvision.transforms as T
# Flow-matching sampler used by `sample` further down.
fm = FlowMatching()
# Run directory holding the exported TorchScript model and its checkpoints.
path = 'runs/vae-natural/'

# Load the exported TorchScript module, restore the best checkpoint's weights
# (swap in load_last_checkpoint to use the most recent checkpoint instead),
# then switch to eval mode and move the model to the GPU.
model = (
    load_best_checkpoint(torch.jit.load(path + "/model.pt"), path)
    .eval()
    .cuda()
)

loading runs/vae-natural/checkpoints/epoch-30/state


In [21]:
import random
from matplotlib import pyplot as plt
from vae import decode

def sample(guidance,cls,images_count,image_shape, steps = 32,churn_scale=0.001,random_state=None,device='cuda'):
    """Generate images with classifier-free guidance and decode them from latent space.

    Args:
        guidance: Classifier-free guidance scale. ``0`` uses only the unconditional
            prediction; otherwise the prediction is
            ``uncond + guidance * (cond - uncond)``.
        cls: One conditioning entry per image (``len(cls)`` must equal
            ``images_count``). Either a list of integer class indices, or a list of
            per-image class-weight vectors. Each weight vector is normalized on its
            own so its entries sum to (approximately) 1.
        images_count: Number of images to generate.
        image_shape: Output size in pixels as ``(height, width)``. The first entry
            becomes tensor dim 2, which ``ToPILImage`` treats as height. The latent
            has 4 channels and is ``image_shape // 8`` spatially.
        steps: Number of sampler steps forwarded to ``fm.sample``.
        churn_scale: Forwarded to ``fm.sample``.
        random_state: Optional integer seed for the initial latent noise. Any
            randomness inside ``fm.sample`` is not controlled here (TODO confirm).
        device: Device for the latent noise, the conditioning and the model.

    Returns:
        The output of ``decode`` clipped to ``[0, 1]``.

    Raises:
        AssertionError: If ``len(cls) != images_count``.
    """
    import numbers

    assert len(cls)==images_count,"images count must equal length of cls"
    latent_shape = [v // 8 for v in image_shape]

    generator = None
    if random_state is not None:
        # A CPU generator keeps the seeded noise independent of the target device.
        generator = torch.Generator()
        generator.manual_seed(random_state)
    x0 = torch.randn([images_count, 4] + latent_shape, generator=generator).to(device)

    if isinstance(cls[0], numbers.Integral):
        cls = torch.tensor(cls, dtype=torch.long, device=device)
    else:
        cls = torch.tensor(cls, dtype=torch.float32, device=device)
        # Normalize every sample's weight vector independently. Summing over the
        # whole batch would shrink each row to 1/images_count.
        cls = cls / (cls.sum(dim=-1, keepdim=True) + 1e-6)

    # -1 everywhere is the "no class" conditioning used for the unconditional branch.
    uncond = torch.full_like(cls, -1)
    model.to(device)

    def run_model(xt, t):
        pred_uncond = model(xt, t, uncond)
        if guidance == 0:
            return pred_uncond
        pred_cond = model(xt, t, cls)
        return pred_uncond + guidance * (pred_cond - pred_uncond)

    # Pure inference: avoid building an autograd graph across all sampler steps.
    with torch.no_grad():
        latents = fm.sample(run_model, x0, steps, churn_scale=churn_scale, device=device, return_intermediates=False)
        return decode(latents, device).clip(0, 1)

# sample_dec = sample(
#     guidance=4,
#     cls=[random.randint(0,len(model.classes)-1)]*2,
#     images_count=2,
#     image_shape=(256,256),
#     steps=24,
#     # random_state=0
# )
# for v in sample_dec:
#     display(T.ToPILImage()(v))

In [22]:
import os
import tkinter as tk
from tkinter import ttk, filedialog
from tkinter import Canvas
from PIL import Image, ImageTk

class ImageGeneratorApp:
    """Tkinter front-end for class-conditioned image generation.

    The window is split into two areas:

    - Left: a horizontally scrollable strip of generated images.
    - Right: a control panel, with fixed generation settings on top and, below
      them, a vertically scrollable list holding one 0..1 weight slider per class.

    "Generate" calls the module-level ``sample`` function and displays the
    results. "Save" writes them as PNG files into a user-selected folder.
    """

    # Requested width/height are snapped down to a multiple of this value.
    SIZE_MULTIPLE = 32
    # Pixel length of every horizontal slider.
    SLIDER_LENGTH = 300

    def __init__(self, root, classes):
        """Build the whole UI inside ``root``.

        Args:
            root: The Tk root window the app draws into.
            classes: Class names (strings), one weight slider each, in this order.
        """
        self.root = root
        self.root.title("Image Generator")

        # Main frame split: left (canvas) and right (controls)
        self.main_frame = tk.Frame(root)
        self.main_frame.pack(fill=tk.BOTH, expand=True)

        self._build_image_area()
        self._build_controls(classes)

        # Keep references to images so they are not garbage-collected
        self.generated_images = []
        self.pil_images = []

    def _build_image_area(self):
        """Create the left-hand canvas (with horizontal scrollbar) that hosts images."""
        self.canvas_frame = tk.Frame(self.main_frame)
        self.canvas_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        self.canvas = Canvas(self.canvas_frame, bg="white")
        self.scrollbar = ttk.Scrollbar(self.canvas_frame, orient=tk.HORIZONTAL, command=self.canvas.xview)
        self.canvas.configure(xscrollcommand=self.scrollbar.set)

        self.scrollbar.pack(side=tk.BOTTOM, fill=tk.X)
        self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        # Image labels are gridded into this frame, embedded in the canvas so it scrolls.
        self.images_container = tk.Frame(self.canvas)
        self.image_container = self.canvas.create_window((0, 0), window=self.images_container, anchor="nw")

    def _build_controls(self, classes):
        """Create the right-hand panel: fixed settings on top, class sliders below."""
        self.controls_frame = tk.Frame(self.main_frame, padx=10, pady=10)
        self.controls_frame.pack(side=tk.RIGHT, fill=tk.Y)

        self._build_fixed_controls()
        self._build_class_sliders(classes)

    def _add_spinbox(self, label, from_, to, increment=1, width=7, initial=None):
        """Add a labelled Spinbox to the fixed-controls frame and return it.

        If ``initial`` is given, it replaces whatever text the Spinbox starts with.
        """
        tk.Label(self.fixed_controls, text=label).pack(anchor="w")
        spinbox = tk.Spinbox(self.fixed_controls, from_=from_, to=to, increment=increment, width=width)
        spinbox.pack(anchor="w")
        if initial is not None:
            spinbox.delete(0, tk.END)
            spinbox.insert(0, initial)
        return spinbox

    def _build_fixed_controls(self):
        """Create the guidance/count/size/steps/smoothness inputs and the action buttons."""
        self.fixed_controls = tk.Frame(self.controls_frame, padx=5, pady=5)
        self.fixed_controls.pack(side=tk.TOP, fill=tk.X)

        tk.Label(self.fixed_controls, text="Guidance").pack(anchor="w")
        self.guidance_slider = tk.Scale(
            self.fixed_controls, from_=0, to=5, resolution=0.1, orient=tk.HORIZONTAL,
            length=self.SLIDER_LENGTH,
        )
        self.guidance_slider.pack(anchor="w")
        self.guidance_slider.set(1)

        self.images_count = self._add_spinbox("Images count", from_=1, to=20, width=5)
        # Width/height are snapped to a multiple of SIZE_MULTIPLE when generating.
        self.width_entry = self._add_spinbox("Width", from_=64, to=2048, increment=32, initial=256)
        self.height_entry = self._add_spinbox("Height", from_=64, to=2048, increment=32, initial=256)
        self.steps_entry = self._add_spinbox("Steps", from_=1, to=64, increment=8, initial=24)
        # "Smoothness" is forwarded to `sample` as churn_scale.
        self.smoothness_entry = self._add_spinbox("Smoothness", from_=0, to=0.1, increment=0.001, initial=0.001)

        self.generate_btn = tk.Button(self.fixed_controls, text="Generate", command=self.generate_image)
        self.generate_btn.pack(pady=5, fill=tk.X)

        self.save_btn = tk.Button(self.fixed_controls, text="Save", command=self.save_images)
        self.save_btn.pack(pady=5, fill=tk.X)

    def _build_class_sliders(self, classes):
        """Create one 0..1 weight slider per class inside a vertically scrollable canvas."""
        self.scroll_frame = tk.Frame(self.controls_frame)
        self.scroll_frame.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True)

        self.controls_canvas = tk.Canvas(self.scroll_frame, width=350)
        self.controls_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        self.controls_scrollbar = ttk.Scrollbar(
            self.scroll_frame, orient="vertical", command=self.controls_canvas.yview
        )
        self.controls_scrollbar.pack(side=tk.RIGHT, fill="y")
        self.controls_canvas.configure(yscrollcommand=self.controls_scrollbar.set)

        self.classes_container = tk.Frame(self.controls_canvas, padx=10, pady=10)
        self.controls_canvas.create_window((0, 0), window=self.classes_container, anchor="nw")

        # Keep the scrollable region in sync with the slider list's size.
        self.classes_container.bind(
            "<Configure>",
            lambda event: self.controls_canvas.configure(scrollregion=self.controls_canvas.bbox("all")),
        )

        self.classes = classes
        self.class_sliders = {}
        for cls in classes:
            tk.Label(self.classes_container, text=cls.capitalize()).pack(anchor="w")
            slider = tk.Scale(
                self.classes_container, from_=0, to=1, resolution=0.01, orient=tk.HORIZONTAL,
                length=self.SLIDER_LENGTH,
            )
            slider.pack(anchor="w")
            self.class_sliders[cls] = slider

    def _snap_size(self, width, height):
        """Snap width/height down to a multiple of SIZE_MULTIPLE, but never below one multiple.

        If either value changes, a warning is printed and the spinboxes are updated.
        Returns the snapped ``(width, height)``.
        """
        step = self.SIZE_MULTIPLE
        snapped_width = max(step, width - width % step)
        snapped_height = max(step, height - height % step)
        if (snapped_width, snapped_height) != (width, height):
            print("⚠️ Width and height must be divisible by 32. Adjusting automatically.")
            self.width_entry.delete(0, tk.END)
            self.width_entry.insert(0, str(snapped_width))
            self.height_entry.delete(0, tk.END)
            self.height_entry.insert(0, str(snapped_height))
        return snapped_width, snapped_height

    def generate_image(self):
        """Read the controls, generate images with ``sample`` and show them in one row.

        Invalid numeric input is reported and leaves the currently shown images untouched.
        """
        try:
            count = int(self.images_count.get())
            steps = int(self.steps_entry.get())
            churn_scale = float(self.smoothness_entry.get())
            width = int(self.width_entry.get())
            height = int(self.height_entry.get())
        except ValueError:
            print("⚠️ Images count, steps, width and height must be integers; smoothness must be a number.")
            return
        if count < 1 or steps < 1:
            print("⚠️ Images count and steps must be at least 1.")
            return

        width, height = self._snap_size(width, height)
        guidance = self.guidance_slider.get()
        classes_vector = [self.class_sliders[c].get() for c in self.classes]

        # Only drop the previous results once the new request is known to be valid.
        for widget in self.images_container.winfo_children():
            widget.destroy()
        self.generated_images.clear()
        self.pil_images.clear()

        # `sample` expects (height, width): its first entry becomes tensor dim 2, which
        # ToPILImage interprets as image height. Passing (width, height) would transpose.
        samples = sample(guidance, [classes_vector] * count, count, (height, width), steps, churn_scale)

        to_pil = T.ToPILImage()
        for column, tensor in enumerate(samples):
            img = to_pil(tensor)
            self.pil_images.append(img)

            # Tk only draws PhotoImage objects that are still referenced from Python.
            tk_img = ImageTk.PhotoImage(img)
            self.generated_images.append(tk_img)

            tk.Label(self.images_container, image=tk_img).grid(row=0, column=column, padx=5, pady=5)

        # Recompute the scrollable area once the new labels have been laid out.
        self.canvas.update_idletasks()
        self.canvas.config(scrollregion=self.canvas.bbox("all"))

    def save_images(self):
        """Save the current images as ``image_<n>.png`` (1-based) into a chosen folder.

        Existing files with the same names are overwritten. Does nothing if there are
        no images or the folder dialog is cancelled.
        """
        if not self.pil_images:
            print("No images to save!")
            return

        folder = filedialog.askdirectory(title="Select folder to save images")
        if not folder:
            return

        for idx, img in enumerate(self.pil_images):
            filepath = os.path.join(folder, f"image_{idx+1}.png")
            img.save(filepath)
            print(f"Saved {filepath}")


if __name__ == "__main__":
    # Launch the generator GUI with every class known to the loaded model.
    classes, root = model.classes, tk.Tk()
    root.wm_geometry("1600x1200")
    app = ImageGeneratorApp(root=root, classes=classes)
    root.mainloop()
