In [None]:
from kemsekov_torch.common_modules import Residual
from kemsekov_torch.residual import ResidualBlock
from kemsekov_torch.attention import LinearSelfAttentionBlock,LinearCrossAttentionBlock, EfficientSpatialChannelAttention
from kemsekov_torch.attention import MultiHeadLinearAttention
import torch.nn as nn
import torch

class CrossAttention(nn.Module):
    """Self-attention, context cross-attention and an MLP around a skip path.

    Features are projected to ``internal_dim``, attend to themselves, then
    attend to a time-shifted context embedding, and are finally mapped back
    to ``in_channels`` and added to the unmodified input.

    Args:
        in_channels: size of the channel axis (dim 1) of the input features.
        context_channels: size of the last axis of the context tensor.
        internal_dim: width used by every attention/MLP layer in the block.
    """

    def __init__(self, in_channels, context_channels, internal_dim=128):
        super().__init__()
        # Both attention modules use the same head count: at least 4,
        # otherwise one head per 16 internal channels.
        heads = max(4, internal_dim // 16)

        # NOTE: attributes are created in the same order as before so that
        # parameter ordering and random initialisation stay reproducible.
        self.input_2_internal = Residual([nn.Linear(in_channels, internal_dim)])

        self.context_2_internal = nn.Linear(context_channels, internal_dim)
        # Embeds the single scalar time feature into the internal width.
        self.time = nn.Sequential(
            nn.Linear(1, internal_dim),
            nn.ReLU(),
            nn.Linear(internal_dim, internal_dim),
        )
        self.context_norm = nn.RMSNorm(internal_dim)

        # Self-attention: one projection yields Q, K and V stacked on the last axis.
        self.sa_QKV = nn.Sequential(nn.Linear(internal_dim, internal_dim * 3))
        self.sa_norm = nn.RMSNorm(internal_dim)
        self.lsa = MultiHeadLinearAttention(
            internal_dim,
            n_heads=heads,
            dropout=0,
            use_classic_attention=True,
            add_rotary_emb=True,
        )

        # Cross-attention: queries come from the features, keys/values from the context.
        self.cross_norm = nn.RMSNorm(internal_dim)
        self.lca = MultiHeadLinearAttention(
            internal_dim,
            n_heads=heads,
            dropout=0,
            use_classic_attention=True,
        )
        self.cross_Q = nn.Sequential(nn.Linear(internal_dim, internal_dim))
        self.cross_KV = nn.Sequential(nn.Linear(internal_dim, internal_dim * 2))

        # Output MLP maps back to `in_channels`.
        self.mlp_norm = nn.RMSNorm(internal_dim)
        self.mlp = Residual(
            [
                nn.Linear(internal_dim, 4 * internal_dim),
                nn.GELU(),
                nn.Linear(4 * internal_dim, in_channels),
            ],
            init_at_zero=True,
        )

    def forward(self, x, context, time):
        """Refine ``x`` using ``context`` and ``time``.

        ``x`` and ``context`` have dim 1 swapped with their last dim before
        the linear layers run, and the result is swapped back before it is
        added to the original ``x``.
        # NOTE(review): presumably x is (batch, channels, H, W) and time is
        # (batch, 1) -- confirm against the caller.
        """
        skip = x

        # Channels-last layout so nn.Linear acts on the channel axis.
        tokens = self.input_2_internal(x.transpose(1, -1))
        cond = self.context_2_internal(context.transpose(1, -1)) + self.time(time)

        # Self-attention with a residual connection.
        query, key, value = self.sa_QKV(self.sa_norm(tokens)).chunk(3, -1)
        tokens = self.lsa(query, key, value)[0] + tokens

        # Cross-attention against the conditioned context, again residual.
        query = self.cross_Q(self.cross_norm(tokens))
        key, value = self.cross_KV(self.context_norm(cond)).chunk(2, -1)
        tokens = self.lca(query, key, value)[0] + tokens

        projected = self.mlp(self.mlp_norm(tokens))
        return projected.transpose(1, -1) + skip

class FlowMatchingModel(nn.Module):
    """U-shaped convolutional network with context cross-attention at the two deepest levels.

    Five stride-2 residual stages double the channel count each time
    (``expand_dim`` up to ``expand_dim * 32``); five transposed stages mirror
    them and are summed with the matching encoder outputs.

    Args:
        in_channels: channels of the input; the final block outputs the same count.
        context_dim: last-axis size of the context tensor given to the attention blocks.
        expand_dim: channel width after the initial 1x1 convolution.
        residual_block_repeats: how many times each stage's output width is
            repeated in the list handed to ``ResidualBlock``.
    """

    def __init__(
        self,
        in_channels,
        context_dim,
        expand_dim=128,
        residual_block_repeats=1,
    ):
        super().__init__()
        norm = 'batch'
        width = expand_dim
        self.context_dim = context_dim

        def stage(c_in, c_out, upsample=False):
            # One stride-2 residual stage; `.transpose()` is applied for decoder stages.
            block = ResidualBlock(
                c_in,
                residual_block_repeats * [c_out],
                4,
                stride=2,
                normalization=norm,
            )
            if upsample:
                block = block.transpose()
            return nn.Sequential(block)

        # NOTE: attribute names and creation order are kept so that existing
        # checkpoints and parameter ordering remain valid.
        self.expand = nn.Conv2d(in_channels, width, 1)

        # Encoder.
        self.down1 = stage(width, width * 2)
        self.down2 = stage(width * 2, width * 4)
        self.down3 = stage(width * 4, width * 8)
        self.down4 = stage(width * 8, width * 16)
        self.attn4 = CrossAttention(width * 16, context_dim, width * 16)
        self.down5 = stage(width * 16, width * 32)
        self.attn5 = CrossAttention(width * 32, context_dim, width * 32)

        # Decoder.
        self.up1 = stage(width * 32, width * 16, upsample=True)
        self.up2 = stage(width * 16, width * 8, upsample=True)
        self.up3 = stage(width * 8, width * 4, upsample=True)
        self.up4 = stage(width * 4, width * 2, upsample=True)
        self.up5 = stage(width * 2, width, upsample=True)

        self.final = ResidualBlock(
            width,
            [width, in_channels],
            3,
            normalization=norm,
        )

    def forward(self, x, context: torch.Tensor, time):
        """Run the network on ``x`` conditioned on ``context`` and ``time``.

        ``time`` gets a trailing axis when it has fewer than two dims and is
        rescaled as ``time * 5 - 2.5`` before reaching the attention blocks.
        # NOTE(review): presumably time lies in [0, 1], giving [-2.5, 2.5] -- confirm.
        """
        if time.dim() < 2:
            time = time[:, None]
        # Widen the range of the time input.
        time = time * 5 - 2.5

        stem = self.expand(x)

        d1 = self.down1(stem)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.attn4(self.down4(d3), context, time)
        d5 = self.attn5(self.down5(d4), context, time)

        # Walk back up, adding the encoder output of matching depth after each stage.
        hidden = d5
        decoder = (self.up1, self.up2, self.up3, self.up4, self.up5)
        skips = (d4, d3, d2, d1, stem)
        for up, skip in zip(decoder, skips):
            hidden = up(hidden) + skip

        return self.final(hidden)

In [None]:
import torch
from kemsekov_torch.train import load_best_checkpoint, load_last_checkpoint
from kemsekov_torch.flow_matching import FlowMatching
from kemsekov_torch.common_modules import wrap_submodules,CheapSequential
import torchvision.transforms as T


# Sampler used by `sample()` below.
fm = FlowMatching()

# Network that gets restored from the checkpoint directory in `path`.
model = FlowMatchingModel(
    in_channels=3,
    context_dim=512,
    expand_dim=128,
    residual_block_repeats=1,
)

# Directory that `load_last_checkpoint` reads from.
path = 'runs/vae-natural/'


In [None]:
from clip_emb import CLIPEmbedder
# CLIP embedder created on the CPU; `sample()` moves its output to the sampling device.
c = CLIPEmbedder(device='cpu')
# Maps a reference image to its embedding so repeated calls skip the CLIP pass.
embedds_cache = dict()

In [None]:
import random
import PIL.Image
from matplotlib import pyplot as plt

def sample(
    guidance,
    reference_image,
    images_count,
    image_shape, 
    steps = 32,
    churn_scale=0.001,
    random_state=None,
    device='cuda',
    dtype=torch.bfloat16
):
    """Generate images conditioned on the CLIP embedding of a reference image.

    Args:
        guidance: weight of the conditional direction. ``0`` returns the
            prediction made with a zeroed context; otherwise the result is
            ``uncond + guidance * (cond - uncond)``.
        reference_image: value handed to ``c.image_to_embedding``; it is also
            used as the key of ``embedds_cache`` and must therefore be hashable.
        images_count: number of images generated in one batch.
        image_shape: spatial dims appended after the 3 channel axis of the
            noise tensor, i.e. ``(rows, columns)``.
        steps: number of steps passed to ``fm.sample``.
        churn_scale: forwarded unchanged to ``fm.sample``.
        random_state: optional integer seed for the initial noise.
        device: device (string or ``torch.device``) used for noise and model.
        dtype: autocast dtype used while sampling.

    Returns:
        The sampler output passed through a sigmoid, on ``device``.
    """
    if random_state is not None:
        g = torch.Generator(device)
        g.manual_seed(random_state)
    else:
        g = None
        
    x0 = torch.randn([images_count, 3, *image_shape], generator=g, device=device)
    
    # The embedding is cached on the CPU and moved to `device` on every use.
    if reference_image in embedds_cache:
        print("cache")
        clip_emb = embedds_cache[reference_image].to(device)
    else:
        print("compute")
        clip_emb = c.image_to_embedding(reference_image)[None,:].to(device)
        embedds_cache[reference_image] = clip_emb.cpu()
        
    # Repeat the single embedding once per generated image.
    context = clip_emb[[0]*images_count]
    # The zeroed context never changes, so build it once instead of per model call.
    null_context = (context*0).to(dtype)
    
    m = model.to(device)

    def run_model(xt, t):
        # Both branches return float32 so the extrapolation below (and the
        # sampler's own arithmetic) is not done in the reduced autocast dtype.
        pred_no_cls = m(xt, null_context, t).float()
        if guidance == 0:
            return pred_no_cls
        pred_cls = m(xt, context, t).float()
        return pred_no_cls + guidance*(pred_cls - pred_no_cls)

    # torch.autocast expects a device *type* ("cuda"), not "cuda:0" or a torch.device.
    device_type = torch.device(device).type
    # Nothing here is back-propagated, so gradient tracking is switched off.
    with torch.no_grad(), torch.autocast(device_type, dtype=dtype):
        generated = fm.sample(
            run_model,
            x0,
            steps,
            churn_scale=churn_scale,
            device=device,
        )
        sample_dec = generated.sigmoid().to(device)
    return sample_dec

# Restore the latest weights from `path` and switch to inference mode.
model = load_last_checkpoint(model,path).eval()
images_count=4
sample_dec = sample(
    guidance=10,
    reference_image='/home/vlad/Documents/data/cat/60.jpeg',
    images_count=images_count,
    image_shape=(256,256),
    steps=32,
    device='cuda',
    # churn_scale=0.01,
    # random_state=113
)

# Smallest near-square grid that holds every image. Using floor(sqrt(n)) for
# both sides breaks plt.subplot whenever n is not a perfect square.
cols = int(images_count**0.5)
if cols*cols < images_count:
    cols += 1
rows = -(-images_count//cols)  # ceiling division

plt.figure(figsize=(6,6))
for i,v in enumerate(sample_dec):
    plt.subplot(rows,cols,1+i)
    plt.imshow(T.ToPILImage()(v))
    plt.axis('off')

In [None]:
import os
import tkinter as tk
from tkinter import ttk, filedialog
from tkinter import Canvas
from PIL import Image, ImageTk

class ImageGeneratorApp:
    """Tkinter window that drives ``sample()`` and shows/saves the results.

    The left side is a horizontally scrollable strip of generated images; the
    right side holds the generation settings on top and one 0..1 slider per
    class in a vertically scrollable area below.

    Args:
        root: the Tk root (or toplevel) window to populate.
        classes: iterable of class names; one slider is created for each.
    """

    # Requested sizes are snapped down to a multiple of this value.
    SIZE_MULTIPLE = 32

    def __init__(self, root, classes):
        self.root = root
        self.root.title("Image Generator")

        # Main frame split: left (canvas) and right (controls)
        self.main_frame = tk.Frame(root)
        self.main_frame.pack(fill=tk.BOTH, expand=True)

        # Left: image canvas with scrollbar
        self.canvas_frame = tk.Frame(self.main_frame)
        self.canvas_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        self.canvas = Canvas(self.canvas_frame, bg="white")
        self.scrollbar = ttk.Scrollbar(self.canvas_frame, orient=tk.HORIZONTAL, command=self.canvas.xview)
        self.canvas.configure(xscrollcommand=self.scrollbar.set)

        self.scrollbar.pack(side=tk.BOTTOM, fill=tk.X)
        self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        # Frame embedded in the canvas; the image labels are gridded into it.
        self.images_container = tk.Frame(self.canvas)
        self.image_container = self.canvas.create_window((0, 0), window=self.images_container, anchor="nw")

        # --- Right side (split into top controls and bottom scrollable classes) ---
        self.controls_frame = tk.Frame(self.main_frame, padx=10, pady=10)
        self.controls_frame.pack(side=tk.RIGHT, fill=tk.Y)

        # Top fixed controls (guidance, count, size, buttons)
        self.fixed_controls = tk.Frame(self.controls_frame, padx=5, pady=5)
        self.fixed_controls.pack(side=tk.TOP, fill=tk.X)

        length = 300

        # Guidance
        tk.Label(self.fixed_controls, text="Guidance").pack(anchor="w")
        self.guidance_slider = tk.Scale(
            self.fixed_controls, from_=0, to=5, resolution=0.1, orient=tk.HORIZONTAL, length=length
        )
        self.guidance_slider.pack(anchor="w")
        self.guidance_slider.set(1)

        # Numeric settings; each is a label followed by a spinbox.
        self.images_count = self._add_spinbox("Images count", from_=1, to=20, width=5)
        self.width_entry = self._add_spinbox(
            "Width", initial=256, from_=64, to=2048, increment=32, width=7
        )
        self.height_entry = self._add_spinbox(
            "Height", initial=256, from_=64, to=2048, increment=32, width=7
        )
        self.steps_entry = self._add_spinbox(
            "Steps", initial=24, from_=1, to=64, increment=8, width=7
        )
        self.smoothness_entry = self._add_spinbox(
            "Smoothness", initial=0.001, from_=0, to=0.1, increment=0.001, width=7
        )

        # Generate / Save buttons
        self.generate_btn = tk.Button(self.fixed_controls, text="Generate", command=self.generate_image)
        self.generate_btn.pack(pady=5, fill=tk.X)

        self.save_btn = tk.Button(self.fixed_controls, text="Save", command=self.save_images)
        self.save_btn.pack(pady=5, fill=tk.X)

        # Scrollable class sliders
        self.scroll_frame = tk.Frame(self.controls_frame)
        self.scroll_frame.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True)

        self.controls_canvas = tk.Canvas(self.scroll_frame, width=350)
        self.controls_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        self.controls_scrollbar = ttk.Scrollbar(
            self.scroll_frame, orient="vertical", command=self.controls_canvas.yview
        )
        self.controls_scrollbar.pack(side=tk.RIGHT, fill="y")

        self.controls_canvas.configure(yscrollcommand=self.controls_scrollbar.set)

        self.classes_container = tk.Frame(self.controls_canvas, padx=10, pady=10)
        self.controls_canvas.create_window((0, 0), window=self.classes_container, anchor="nw")

        def update_scrollregion(event):
            # Keep the scrollable area in sync with the size of the slider frame.
            self.controls_canvas.configure(scrollregion=self.controls_canvas.bbox("all"))

        self.classes_container.bind("<Configure>", update_scrollregion)

        # Add class sliders into scrollable frame
        self.classes = classes
        self.class_sliders = {}
        for cls in classes:
            label = tk.Label(self.classes_container, text=cls.capitalize())
            label.pack(anchor="w")
            slider = tk.Scale(
                self.classes_container, from_=0, to=1, resolution=0.01, orient=tk.HORIZONTAL, length=length
            )
            slider.pack(anchor="w")
            self.class_sliders[cls] = slider

        # Keep references to images so they are not garbage-collected
        self.generated_images = []
        self.pil_images = []

    def _add_spinbox(self, label, initial=None, **spinbox_options):
        """Pack a caption and a spinbox into the fixed controls and return the spinbox.

        When ``initial`` is given, the spinbox text is replaced with it;
        otherwise the widget keeps its own starting value.
        """
        tk.Label(self.fixed_controls, text=label).pack(anchor="w")
        box = tk.Spinbox(self.fixed_controls, **spinbox_options)
        box.pack(anchor="w")
        if initial is not None:
            box.delete(0, "end")
            box.insert(0, initial)
        return box

    @staticmethod
    def _snap_to_multiple(value, multiple=32):
        """Round ``value`` down to a multiple of ``multiple``, but never below ``multiple``."""
        return max(multiple, value - value % multiple)

    def generate_image(self):
        """Read the controls, call ``sample()`` and display the returned images."""
        # Drop everything shown by the previous run.
        for widget in self.images_container.winfo_children():
            widget.destroy()
        self.generated_images.clear()
        self.pil_images.clear()

        classes_vector = [self.class_sliders[cls].get() for cls in self.classes]
        guidance = self.guidance_slider.get()
        count = int(self.images_count.get())
        churn_scale = float(self.smoothness_entry.get())
        steps = int(self.steps_entry.get())

        # Width and Height (must be divisible by 32). Values typed below 32 are
        # raised to 32 instead of collapsing to a zero-sized image.
        requested = (int(self.width_entry.get()), int(self.height_entry.get()))
        width, height = (
            self._snap_to_multiple(value, self.SIZE_MULTIPLE) for value in requested
        )
        if (width, height) != requested:
            print("⚠️ Width and height must be divisible by 32. Adjusting automatically.")
            for entry, value in ((self.width_entry, width), (self.height_entry, height)):
                entry.delete(0, tk.END)
                entry.insert(0, str(value))

        # `sample` appends image_shape after the channel axis, so rows (height)
        # come first; passing (width, height) produced transposed image sizes.
        # NOTE(review): the second argument is a list of class weights, while
        # `sample` in this file uses that argument as a dict key and as a CLIP
        # reference image -- confirm which `sample` this UI is meant to call.
        samples = sample(guidance, [classes_vector]*count, count, (height, width), steps, churn_scale)
        for column, tensor in enumerate(samples):
            img = T.ToPILImage()(tensor)
            self.pil_images.append(img)

            # Tk only keeps a weak link to PhotoImage objects, so hold a reference.
            tk_img = ImageTk.PhotoImage(img)
            self.generated_images.append(tk_img)

            lbl = tk.Label(self.images_container, image=tk_img)
            lbl.grid(row=0, column=column, padx=5, pady=5)

        # Let Tk lay out the new labels before measuring the scrollable area.
        self.canvas.update_idletasks()
        self.canvas.config(scrollregion=self.canvas.bbox("all"))

    def save_images(self):
        """Ask for a folder and write the current images there as ``image_<n>.png``."""
        if not self.pil_images:
            print("No images to save!")
            return

        folder = filedialog.askdirectory(title="Select folder to save images")
        if not folder:
            # Dialog was cancelled.
            return

        for idx, img in enumerate(self.pil_images):
            filepath = os.path.join(folder, f"image_{idx+1}.png")
            img.save(filepath)
            print(f"Saved {filepath}")


if __name__ == "__main__":
    # Class names shown as sliders in the UI.
    # NOTE(review): FlowMatchingModel as defined in this file sets no `classes`
    # attribute -- confirm the loaded model provides one, otherwise this raises.
    classes = model.classes

    # Build the main window, attach the app and enter the Tk event loop.
    root = tk.Tk()
    root.geometry("1600x1200")
    app = ImageGeneratorApp(root, classes)
    root.mainloop()
