In [None]:
import os, json

root = "../datasets/circles/png"  
with open(os.path.join(root, "labels.json"), "r") as f:
    meta = json.load(f)


svgs = []
for fname, x, y, r in meta:
    x, y, r = int(x), int(y), int(r)
    svgs.append(f'<circle cx="{x}" cy="{y}" r="{r}" fill="black"/>')

# Собираем уникальные символы
chars = sorted(set("".join(svgs)))

PAD, SOS, EOS = "<PAD>", "<SOS>", "<EOS>"
vocab = [PAD, SOS, EOS] + chars
char2idx = {c:i for i,c in enumerate(vocab)}
idx2char = {i:c for c,i in char2idx.items()}
vocab_size = len(vocab)
print("vocab_size =", vocab_size, "chars:", chars)
print("PAD,SOS,EOS idx:", char2idx[PAD], char2idx[SOS], char2idx[EOS])


vocab_size = 30 chars: [' ', '"', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '<', '=', '>', 'a', 'b', 'c', 'e', 'f', 'i', 'k', 'l', 'r', 'x', 'y']
PAD,SOS,EOS idx: 0 1 2


In [None]:
from torch.utils.data import Dataset
from PIL import Image
import torch

class CircleSVGDataset(Dataset):
    def __init__(self, root, clip_preprocess, char2idx):
        self.root = root
        self.prep = clip_preprocess
        self.char2idx = char2idx
        with open(os.path.join(root, "labels.json"), "r") as f:
            self.meta = json.load(f)

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        fname, x, y, r = self.meta[idx]

        img = Image.open(os.path.join(self.root, fname)).convert("RGB")
        img_t = self.prep(img)  # 3×224×224
        s = f'<circle cx="{int(x)}" cy="{int(y)}" r="{int(r)}" fill="black"/>'
        seq = [self.char2idx[SOS]] + [self.char2idx[c] for c in s] + [self.char2idx[EOS]]
        return img_t, torch.tensor(seq, dtype=torch.long)


In [None]:
def collate_fn(batch):
    imgs, seqs = zip(*batch)
    imgs = torch.stack(imgs)   # B×3×224×224

    lengths = [len(s) for s in seqs]
    max_len = max(lengths)


    batch_seq = torch.full((len(seqs), max_len), fill_value=char2idx[PAD], dtype=torch.long)
    for i, s in enumerate(seqs):
        batch_seq[i, :lengths[i]] = s


    input_seq  = batch_seq[:, :-1]
    target_seq = batch_seq[:, 1:]
    return imgs, input_seq, target_seq, lengths


In [None]:
import torch.nn as nn
import torch

class ClipSVGGenerator(nn.Module):
    def __init__(self, clip_model, vocab_size,
                 embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        # frozen CLIP visual
        self.visual = clip_model.visual
        for p in self.visual.parameters():
            p.requires_grad = False

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # инициализация h0, c0
        self.fc_h = nn.Linear(512, hidden_dim * num_layers)
        self.fc_c = nn.Linear(512, hidden_dim * num_layers)

        # декодер
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=char2idx[PAD])
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, images, input_seq):
        B, T = input_seq.shape

        # encode
        feat = self.visual(images)       # B×512

        # init hidden/cell
        h0 = torch.tanh(self.fc_h(feat)) # B×(H×num_layers)
        c0 = torch.tanh(self.fc_c(feat))
        # переупакуем в форму (num_layers, B, hidden_dim)
        h0 = h0.view(self.num_layers, B, self.hidden_dim)
        c0 = c0.view(self.num_layers, B, self.hidden_dim)

        # векторизуем входную последовательность
        emb = self.embedding(input_seq)  # B×T×embed_dim

        # прогоняем через LSTM
        out, _ = self.lstm(emb, (h0, c0)) # out: B×T×hidden_dim

        # в логиты
        logits = self.out(out)           # B×T×vocab_size
        return logits


In [8]:
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

# загружаем CLIP
import clip
model_clip, preprocess = clip.load("ViT-B/32", device=device)

# датасет и загрузчик
ds = CircleSVGDataset(root, preprocess, char2idx)
loader = DataLoader(ds, batch_size=32, shuffle=True,
                    num_workers=2, collate_fn=collate_fn)

# модель, оптимизатор, loss
model = ClipSVGGenerator(model_clip, vocab_size).to(device).float()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=char2idx[PAD])

n_epochs = 10
for epoch in range(n_epochs):
    model.train()
    total_loss = 0
    for imgs, input_seq, target_seq, lengths in loader:
        imgs = imgs.to(device)
        input_seq  = input_seq.to(device)
        target_seq = target_seq.to(device)

        logits = model(imgs, input_seq)  # B×T×V
        B, T, V = logits.shape

        loss = criterion(logits.view(-1, V), target_seq.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * B

    print(f"Epoch {epoch+1}/{n_epochs}, loss = {total_loss/len(ds):.6f}")


Epoch 1/10, loss = 0.735000
Epoch 2/10, loss = 0.312722
Epoch 3/10, loss = 0.310051
Epoch 4/10, loss = 0.309134
Epoch 5/10, loss = 0.308493
Epoch 6/10, loss = 0.307973
Epoch 7/10, loss = 0.307744
Epoch 8/10, loss = 0.306972
Epoch 9/10, loss = 0.304880
Epoch 10/10, loss = 0.302720


In [None]:
import torch

def generate_svg(model, image, max_len=100):
    model.eval()
    with torch.no_grad():
        img_t = preprocess(image).unsqueeze(0).to(device)

        feat = model.visual(img_t)
        h = torch.tanh(model.fc_h(feat)).view(model.num_layers, 1, model.hidden_dim)
        c = torch.tanh(model.fc_c(feat)).view(model.num_layers, 1, model.hidden_dim)

        input_id = torch.tensor([[char2idx[SOS]]], device=device)
        result = []

        for _ in range(max_len):
            emb = model.embedding(input_id)
            out, (h, c) = model.lstm(emb, (h, c))
            logits = model.out(out.squeeze(1))
            next_id = logits.argmax(dim=-1)

            ch = idx2char[next_id.item()]
            if ch == EOS:
                break
            result.append(ch)
            input_id = next_id.unsqueeze(0)

        return "".join(result)


img = Image.open("../datasets/circles/png/00010.png").convert("RGB")
print(generate_svg(model, img))

<circle cx="100" cy="100" r="72" fill="black"/>


In [None]:
from PIL import Image


In [None]:
# конвертация изображения по строке

def convert_circle(path, model, out_dir = ""):
    width = height = 224
    img = Image.open(path).convert("RGB")

    circle_content = generate_svg(model, img)
    print(circle_content)

    svg_content = f'''<?xml version="1.0" encoding="utf-8"?>
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
    width="{width}px" height="{height}px" viewBox="0 0 {width} {height}" style="enable-background:new 0 0 {width} {height};" xml:space="preserve">
    <rect width="{width}" height="{height}" fill="white"/>
    <g>
        {circle_content}
    </g>
</svg>'''
       
    with open(os.path.join(out_dir, "converted.svg"), 'w') as file:
        file.write(svg_content)

model.eval()

convert_circle("../datasets/circles/png/00010.png", model)


<circle cx="100" cy="100" r="72" fill="black"/>
