In [1]:
# !pip install -r requirements.txt

In [None]:
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import numpy as np
import cv2
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image

import importlib
import config
import os
import datetime

importlib.reload(config)
# from datasets import load_dataset


In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
# if device != "cuda":
#     raise Exception("Can't launch CUDA")

In [4]:
# try:
#     dataset = load_dataset('DonkeySmall/OCR-English-Printed-12', split='train[:1%]')
#     print("Данные загружены успешно!")
# except Exception as e:
#     print(f"Ошибка при загрузке набора данных: {e}")

In [5]:
# Training hyperparameters shared by the cells below.
image_size = 32  # image height in pixels; generated images are image_size * 4 wide
batch_size = 32  # samples per DataLoader batch
epochs = 200     # number of passes over the dataset
lr = 2e-4        # learning rate passed to both Adam optimizers

In [6]:
def cv2_show(img, key='q', time=0, window_name='cv2'):
    """Show ``img`` in an OpenCV window and close it on timeout or key press.

    Args:
        img: Image array handed straight to ``cv2.imshow``.
        key: Key that closes the window when ``time`` is falsy. Either a
            single-character string (default ``'q'``) or an integer key code.
        time: Display duration in milliseconds. When falsy (default ``0``)
            the function waits for a single key press instead.
        window_name: Title of the OpenCV window.

    With a timeout the window is always closed afterwards. Without one, the
    window is closed only if the pressed key matches ``key``; any other key
    returns and leaves the window open.
    """
    cv2.imshow(window_name, img)

    if time:
        cv2.waitKey(time)
        cv2.destroyAllWindows()
        return

    # cv2.waitKey returns an integer key code, so a one-character string has
    # to be converted with ord() before it can ever compare equal.
    if isinstance(key, str) and len(key) == 1:
        expected_code = ord(key)
    else:
        expected_code = key

    pressed_code = cv2.waitKey(0)
    if pressed_code == -1:
        # No key was reported (e.g. the window was closed another way).
        return

    # Some backends set high bits on the returned code; the low byte holds
    # the character, so accept a match on either form.
    if pressed_code == expected_code or (pressed_code & 0xFF) == expected_code:
        cv2.destroyAllWindows()


def stack_images(images, direction='vertical'):
    """Resize images to a common edge length and stack them into one array.

    Args:
        images: Non-empty sequence of image arrays (none of them ``None``).
        direction: ``'horizontal'`` scales every image to the smallest height
            and joins them side by side; ``'vertical'`` scales every image to
            the smallest width and joins them top to bottom.

    Returns:
        The stacked image produced by ``np.hstack`` / ``np.vstack``.

    Raises:
        ValueError: If ``images`` is empty, contains ``None``, or
            ``direction`` is not one of the two supported values.
    """
    invalid_input_msg = "Список изображений пуст или содержит недопустимые значения"
    if not images:
        raise ValueError(invalid_input_msg)
    if any(item is None for item in images):
        raise ValueError(invalid_input_msg)

    if direction == 'horizontal':
        target_height = min(item.shape[0] for item in images)
        scaled = []
        for item in images:
            # Keep the aspect ratio: scale the width by the same factor as the height.
            new_width = int(item.shape[1] * (target_height / item.shape[0]))
            scaled.append(cv2.resize(item, (new_width, target_height)))
        return np.hstack(scaled)

    if direction == 'vertical':
        target_width = min(item.shape[1] for item in images)
        scaled = []
        for item in images:
            # Keep the aspect ratio: scale the height by the same factor as the width.
            new_height = int(item.shape[0] * (target_width / item.shape[1]))
            scaled.append(cv2.resize(item, (target_width, new_height)))
        return np.vstack(scaled)

    raise ValueError("direction должен быть 'horizontal' или 'vertical'")


def add_noise_and_distortion(img):
    """Return a copy of ``img`` with additive uniform noise and random black lines.

    Noise values are drawn uniformly from [0, 50), cast to ``uint8`` and
    combined with the image through ``cv2.add``. Afterwards between 0 and 9
    one-pixel-wide black line segments with random endpoints are drawn on the
    result.

    Args:
        img: Image array; its ``shape`` determines the noise shape and the
            range of the line endpoints.

    Returns:
        The noisy image produced by ``cv2.add`` with the lines drawn on it.
    """
    height, width = img.shape[0], img.shape[1]

    speckle = np.random.uniform(0, 50, img.shape).astype(np.uint8)
    distorted = cv2.add(img, speckle)

    # Distort with random line segments; each endpoint is drawn as x, then y.
    line_count = np.random.randint(0, 10)
    for _ in range(line_count):
        start = (np.random.randint(0, width), np.random.randint(0, height))
        end = (np.random.randint(0, width), np.random.randint(0, height))
        cv2.line(distorted, start, end, (0, 0, 0), 1)

    return distorted

In [7]:
def generate_text_image(text, img=None, noise=False, max_char_y_offset=0.025, max_char_x_offset=0.025):
    """Render ``text`` in black on a grayscale canvas and return it as an array.

    The text is drawn character by character with the first font listed in
    ``config.fonts_paths``. The font size starts at a random value in
    [20, 35] and is reduced until the text fits the canvas (but never below
    size 1). Each character may be shifted by a small random offset.

    Args:
        text: String to render.
        img: Optional single-channel ``uint8`` canvas to draw on. When
            ``None`` a white canvas of ``image_size`` x ``image_size * 4``
            (height x width) is created.
        noise: When true, the result is passed through
            ``add_noise_and_distortion`` before being returned.
        max_char_y_offset: Maximum vertical jitter as a fraction of the
            character's height.
        max_char_x_offset: Maximum horizontal jitter as a fraction of the
            character's width.

    Returns:
        The rendered image as a NumPy array converted from a mode ``'L'``
        PIL image.

    Note:
        The jitter bounds are truncated with ``int()``, so for glyphs smaller
        than ``1 / max_char_*_offset`` pixels the jitter range collapses to 0.
    """
    if img is None:
        # White background.
        img = np.ones((image_size, image_size * 4), dtype=np.uint8) * 255

    canvas_height, canvas_width = img.shape[0], img.shape[1]

    # Always use the first font from the configured list.
    font_path = config.fonts_paths[0]
    font_size = random.randint(20, 35)
    font = ImageFont.truetype(font_path, font_size)

    # Mode 'L' = single-channel 8-bit grayscale.
    pil_img = Image.fromarray(img, mode='L')
    draw = ImageDraw.Draw(pil_img)

    # Measure the whole string when drawn at the origin.
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]

    # Shrink the font until the text fits. Stop at size 1: a size of 0 is
    # rejected by ImageFont.truetype, so overly long text is drawn clipped
    # instead of crashing.
    while (text_width > canvas_width or text_height > canvas_height) and font_size > 1:
        font_size -= 1
        font = ImageFont.truetype(font_path, font_size)
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]

    # textbbox((0, 0), ...) does not start at (0, 0): its left/top values are
    # the gap between the drawing origin and the actual ink. Subtracting them
    # centres the ink itself rather than shifting it down/right off-centre
    # (which could clip the bottom of the text on a tight canvas).
    text_x = (canvas_width - text_width) // 2 - bbox[0]
    text_y = (canvas_height - text_height) // 2 - bbox[1]

    # Draw each character with a small random displacement.
    for char in text:
        char_bbox = draw.textbbox((text_x, text_y), char, font=font)
        char_width = char_bbox[2] - char_bbox[0]
        char_height = char_bbox[3] - char_bbox[1]

        x_offset = random.randint(int(-char_width * max_char_x_offset), int(char_width * max_char_x_offset))
        y_offset = random.randint(int(-char_height * max_char_y_offset), int(char_height * max_char_y_offset))
        draw.text((text_x + x_offset, text_y + y_offset), char, font=font, fill=0)  # black text

        # Advance the pen by the width of the character just drawn.
        text_x += char_width

    # Back to a NumPy array so OpenCV / torchvision can consume it.
    rendered = np.array(pil_img)

    return add_noise_and_distortion(rendered) if noise else rendered

In [8]:
# images = []

# for _ in range(6):
#     text = random.choice(config.texts)
#     result = generate_text_image(text, noise=False)
#     images.append(result[0])

# for _ in range(6):
#     text = random.choice(config.texts)
#     result = generate_text_image(text, noise=True)
#     images.append(result[0])

# stacked_images = stack_images(images)
# cv2_show(stacked_images, time=5000)


# for _ in range(1):
#     img = generate_text_image(random.choice(config.texts), noise=False)
#     cv2_show(img, time=5000)

In [9]:
from torchvision import transforms
from PIL import Image


# Module-level preprocessing pipeline: resize to image_size x (image_size * 4)
# (height x width), then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(image_size, image_size * 4)),
        transforms.ToTensor(),
    ]
)


class TextDataset(Dataset):
    """Dataset that renders each text string to an image tensor on access.

    Images are produced by ``generate_text_image`` at lookup time, so the
    same index can yield a different rendering on every access (random font
    size and jitter).

    Args:
        texts: Indexable collection of strings; one sample per entry.
    """

    def __init__(self, texts):
        self.texts = texts
        self.num_samples = len(texts)
        self.transform = transforms.Compose([
            # uint8 H x W array -> float tensor of shape (1, H, W) in [0, 1].
            transforms.ToTensor(),
            # Single-channel statistics: (x - 0.5) / 0.5 maps [0, 1] to
            # [-1, 1], the same range a Tanh-terminated generator produces.
            # (Three-element mean/std would not fit a one-channel tensor.)
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def __len__(self):
        """Return the number of text samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Render ``texts[idx]`` and return it as a normalised image tensor."""
        rendered = generate_text_image(self.texts[idx])
        return self.transform(rendered)

In [10]:
class Discriminator(nn.Module):
    """Fully connected real/fake classifier over flattened images.

    The input is flattened to ``image_size * (image_size * 4)`` features per
    sample, passed through one hidden layer of 1024 ReLU units, and reduced
    to a single sigmoid output per sample.
    """

    def __init__(self):
        super().__init__()
        flat_features = image_size * (image_size * 4)
        layers = [
            nn.Linear(flat_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        ]
        # Layer order is kept so existing state_dict keys stay valid.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` per sample and return a tensor of shape (batch, 1)."""
        flattened = x.view(x.size(0), image_size * (image_size * 4))
        return self.model(flattened)


class Generator(nn.Module):
    """Fully connected generator producing single-channel images.

    Maps an input with ``image_size * (image_size * 4)`` features through a
    hidden layer of 32 768 ReLU units back to the same number of features,
    squashes them with ``Tanh`` (range [-1, 1]) and reshapes the result to
    ``(batch, 1, image_size, image_size * 4)``.
    """

    def __init__(self):
        super().__init__()
        flat_features = image_size * (image_size * 4)
        hidden_features = 32_768
        layers = [
            nn.Linear(flat_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, flat_features),
            nn.Tanh(),
        ]
        # Layer order is kept so existing state_dict keys stay valid.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the network and reshape the output into image form."""
        flat_output = self.model(x)
        return flat_output.view(x.size(0), 1, image_size, image_size * 4)

In [None]:
from IPython.display import clear_output



def _save_sample_grid(images, title, path, max_samples=8):
    """Plot up to ``max_samples`` images in a 4 x 2 grid, save it to ``path`` and show it."""
    # The last batch of an epoch may hold fewer than ``max_samples`` images.
    sample_count = min(max_samples, images.size(0))
    fig = plt.figure(figsize=(10, 5))
    for idx in range(sample_count):
        ax = fig.add_subplot(4, 2, idx + 1)
        ax.imshow(images[idx].reshape(image_size, image_size * 4), cmap="gray_r")
        ax.set_xticks([])
        ax.set_yticks([])
    fig.suptitle(title)
    fig.savefig(path)
    plt.show()
    # Release the figure; otherwise one figure per epoch accumulates in memory.
    plt.close(fig)


dataset = TextDataset(config.texts)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


netG = Generator().to(device)
netD = Discriminator().to(device)

criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

# One timestamped output directory per run.
output_dir = datetime.datetime.now().strftime("TEXT_GAN_%Y-%m-%d_%H-%M-%S")
os.makedirs(output_dir, exist_ok=True)

for epoch in range(epochs):
    for data in dataloader:
        real = data.to(device)
        # Use a local name: rebinding the global ``batch_size`` here would
        # leave it at the size of the last (possibly partial) batch and change
        # the DataLoader batch size the next time this cell is run.
        current_batch_size = real.size(0)

        # ---- Discriminator step: real batch labelled 1, fake batch labelled 0 ----
        netD.zero_grad()
        labels = torch.full((current_batch_size,), 1.0, dtype=torch.float, device=device)
        output = netD(real).view(-1)
        lossD_real = criterion(output, labels)
        lossD_real.backward()

        # NOTE(review): the generator input is the flattened *real* batch, not
        # random noise, so the generator can satisfy the discriminator by
        # copying its input. Kept as in the original; confirm whether
        # torch.randn(...) or a noisy rendering was intended here.
        noise = real.view(current_batch_size, 1, image_size * (image_size * 4))
        fake = netG(noise)
        labels.fill_(0.0)
        # detach(): the discriminator update must not backpropagate into the generator.
        output = netD(fake.detach()).view(-1)
        lossD_fake = criterion(output, labels)
        lossD_fake.backward()
        optimizerD.step()

        # ---- Generator step: try to make the discriminator answer 1 on fakes ----
        netG.zero_grad()
        labels.fill_(1.0)
        output = netD(fake).view(-1)
        lossG = criterion(output, labels)
        lossG.backward()
        optimizerG.step()

    # Losses of the last batch of the epoch, as plain floats.
    lossD_value = (lossD_real + lossD_fake).item()
    print(f"Epoch [{epoch}/{epochs}] Loss D: {lossD_value:.4f}, Loss G: {lossG.item():.4f}")
    save_image(fake.detach(), f"{output_dir}/fake_samples_epoch_{epoch}.png", normalize=True)

    fake = fake.detach().cpu()
    real = real.detach().cpu()

    # Separate file names: previously both figures were written to
    # "{epoch}.png", so the fake grid overwrote the real one.
    _save_sample_grid(real, f"{epoch} epoch", f"{output_dir}/real_{epoch}.png")
    _save_sample_grid(fake, f"{epoch} epoch", f"{output_dir}/fake_{epoch}.png")