<a href="https://colab.research.google.com/github/GandlinAlexandr/ApPyHW1/blob/main/GAN/gan-text-plus-metrics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Библиотеки

In [None]:
!pip install datasets
!python -m spacy download en_core_web_md

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import os
import torch.nn.functional as F

import torch.nn.functional as F
from torchvision.models import inception_v3
from torchvision import transforms
import torchvision.models as models

from scipy.linalg import sqrtm

## Метрики

In [None]:
def inception_score(images, batch_size=32, splits=10, model=None, eps=1e-12):
    """
    Compute the Inception Score (IS) of a set of images.

    Images are resized to 299x299 batch by batch, classified, and the score
    exp(mean_x KL(p(y|x) || p(y))) is computed on ``splits`` disjoint chunks of
    the predictions; the mean and std over chunks are returned.

    Args:
        images (torch.Tensor): Tensor of shape (N, 3, H, W) (a sequence of
            (3, H, W) tensors is also accepted). No normalisation is applied
            here: pixels are expected to already be in the range the
            classifier expects (the default model is built with
            ``transform_input=False``).
        batch_size (int): Batch size used for the classifier forward passes.
        splits (int): Number of chunks for the mean/std estimate. It is
            clamped to N so that no chunk is empty, and every sample is used
            even when N is not divisible by ``splits``.
        model (torch.nn.Module, optional): Classifier returning logits of shape
            (batch, classes). When None, a pretrained torchvision InceptionV3
            is loaded (slow; pass a preloaded model to reuse it). A passed-in
            model is moved to the compute device and switched to eval mode.
        eps (float): Small constant added inside the logarithms so that
            probabilities that underflow to 0 do not produce NaN.

    Returns:
        tuple[float, float]: Mean and standard deviation of the Inception
        Score across the splits.

    Raises:
        ValueError: If ``images`` is empty or ``splits`` < 1.
    """
    n_images = len(images)
    if n_images == 0:
        raise ValueError("inception_score needs at least one image")
    if splits < 1:
        raise ValueError("splits must be >= 1")
    if not torch.is_tensor(images):
        images = torch.stack(list(images))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if model is None:
        model = inception_v3(
            weights=models.Inception_V3_Weights.DEFAULT, transform_input=False
        )
    model = model.to(device)
    model.eval()

    # Detach so generator outputs handed in straight from training do not
    # drag their autograd graph through the resize below.
    images = images.detach()

    probs = []
    with torch.no_grad():
        for start in range(0, n_images, batch_size):
            batch = images[start : start + batch_size].to(device)
            # InceptionV3 works on 299x299 inputs; resizing per batch keeps only
            # one batch of upscaled images in memory at a time.
            batch = F.interpolate(
                batch,
                size=(299, 299),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            probs.append(F.softmax(model(batch), dim=1).cpu())
    preds = torch.cat(probs, dim=0).double().numpy()

    # array_split keeps every sample (remainder included) and, with the
    # clamp, never produces an empty chunk whose mean would be NaN.
    split_scores = []
    for part in np.array_split(preds, min(splits, n_images)):
        p_y = part.mean(axis=0, keepdims=True)
        kl = np.sum(part * (np.log(part + eps) - np.log(p_y + eps)), axis=1)
        split_scores.append(np.exp(kl.mean()))

    return float(np.mean(split_scores)), float(np.std(split_scores))

In [None]:
def calculate_fid(real_images, generated_images, batch_size=32, model=None, eps=1e-6):
    """
    Compute the Frechet Inception Distance (FID) between two image sets.

    FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 * (S_r S_g)^(1/2)), where mu and
    S are the mean and covariance of the feature vectors of each set.

    Args:
        real_images (torch.Tensor): Real images, shape (N, 3, H, W), N >= 2.
        generated_images (torch.Tensor): Generated images, shape (M, 3, H, W),
            M >= 2. Sequences of (3, H, W) tensors are also accepted.
        batch_size (int): Batch size for the feature-extractor forward passes.
        model (torch.nn.Module, optional): Feature extractor returning a
            (batch, features) tensor. When None, a pretrained torchvision
            InceptionV3 is loaded and its final ``fc`` layer is replaced by
            ``Identity``, so the pooled features that feed the classifier
            (rather than the 1000 class logits) are used, as in the standard
            FID definition. A passed-in model is moved to the compute device
            and switched to eval mode, but otherwise used as is.
        eps (float): Diagonal offset added to both covariances when the
            square root of their product is not finite (this happens with few
            samples, where the covariances are singular).

    Returns:
        float: FID (lower is better).

    Raises:
        ValueError: If either set has fewer than two images (no covariance).
    """
    if len(real_images) < 2 or len(generated_images) < 2:
        raise ValueError("calculate_fid needs at least two images in each set")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model is None:
        model = inception_v3(
            weights=models.Inception_V3_Weights.DEFAULT, transform_input=False
        )
        model.fc = torch.nn.Identity()  # expose pooled features, not logits
    model = model.to(device)
    model.eval()

    def get_activations(images):
        """Run ``model`` over ``images`` in batches; return a float64 (N, D) array."""
        if not torch.is_tensor(images):
            images = torch.stack(list(images))
        images = images.detach()  # never keep a caller's autograd graph alive
        activations = []
        with torch.no_grad():
            for start in range(0, len(images), batch_size):
                batch = images[start : start + batch_size].to(device)
                # InceptionV3 works on 299x299 inputs; no normalisation is
                # applied because the inputs are expected to be normalised.
                batch = F.interpolate(
                    batch,
                    size=(299, 299),
                    mode="bilinear",
                    align_corners=False,
                    antialias=True,
                )
                activations.append(model(batch).cpu())
        return torch.cat(activations, dim=0).double().numpy()

    act_real = get_activations(real_images)
    act_gen = get_activations(generated_images)

    mu_real, sigma_real = act_real.mean(axis=0), np.cov(act_real, rowvar=False)
    mu_gen, sigma_gen = act_gen.mean(axis=0), np.cov(act_gen, rowvar=False)

    diff = mu_real - mu_gen
    covmean = sqrtm(sigma_real @ sigma_gen)
    if not np.isfinite(covmean).all():
        # Singular covariances: regularise the diagonal and retry.
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean = sqrtm((sigma_real + offset) @ (sigma_gen + offset))
    # sqrtm can return tiny imaginary parts caused by numerical error.
    covmean = np.real(covmean)

    fid = diff.dot(diff) + np.trace(sigma_real + sigma_gen - 2 * covmean)
    return float(fid)

## Обработка входных данных

In [None]:
import spacy

# Ask spaCy to run on the GPU when possible; the returned flag records
# whether that succeeded (it was previously discarded).
spacy_uses_gpu = spacy.prefer_gpu()

# Medium English pipeline; its token vectors encode the captions.
nlp = spacy.load("en_core_web_md")

In [None]:
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class TransformDataset(Dataset):
    """
    Pairs each logo image with an averaged word-vector embedding of its caption.

    Each item is ``(image_tensor, text_vector)``:
    - the image is converted to RGB, resized, turned into a tensor and
      normalised from [0, 1] to [-1, 1];
    - the text vector is the mean token vector of the caption produced by
      ``nlp_model`` (punctuation and whitespace tokens are skipped), truncated
      to ``vect_size`` entries. A caption with no usable tokens yields a zero
      vector of length ``vect_size``.

    Args:
        df: Indexable collection of rows (e.g. a Hugging Face ``Dataset``);
            each row provides an ``"image"`` (PIL image) and a ``"text"``.
        new_size: Target size passed to ``transforms.Resize``.
        nlp_model: spaCy pipeline used to tokenise and vectorise captions.
        vect_size (int, optional): Length of the returned text vector. When
            None, the notebook-level ``vect_size`` setting is used.
    """

    def __init__(self, df, new_size, nlp_model, vect_size=None):
        self.df = df
        self.new_size = new_size
        self.transform = transforms.Compose(
            [
                transforms.Resize(new_size),
                transforms.ToTensor(),
                # (x - 0.5) / 0.5 maps [0, 1] pixels to [-1, 1].
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        self.nlp_model = nlp_model
        self.vect_size = vect_size

    def __len__(self):  # required by DataLoader
        return len(self.df)

    @staticmethod
    def _to_numpy(vector):
        """Return ``vector`` as a NumPy array (CuPy arrays from spaCy-on-GPU are copied to host via ``.get()``)."""
        return vector.get() if hasattr(vector, "get") else np.asarray(vector)

    def __getitem__(self, idx):
        row = self.df[idx]  # fetch the row once, then read both fields

        # Image: force three channels, then resize / tensorise / normalise.
        image = self.transform(row["image"].convert("RGB"))

        # Text: average the vectors of all non-punctuation, non-space tokens.
        size = self.vect_size if self.vect_size is not None else vect_size
        doc = self.nlp_model(row["text"])
        vectors = [
            self._to_numpy(token.vector)
            for token in doc
            if not token.is_punct and not token.is_space
        ]
        if not vectors:
            # np.mean([]) would be a NaN scalar, which cannot be sliced.
            return image, torch.zeros(size, dtype=torch.float32)

        return image, torch.tensor(np.mean(vectors, axis=0)[:size])

In [None]:
def text_plus_image(text, image):
    """
    Concatenate per-sample text embeddings to an image batch as extra channels.

    Each text vector is broadcast over the image's spatial grid, so every pixel
    position carries the full embedding.

    Args:
        text (torch.Tensor): Text embeddings of shape (B, T).
        image (torch.Tensor): Images of shape (B, C, H, W).

    Returns:
        torch.Tensor: Tensor of shape (B, T + C, H, W), text channels first.
    """
    height, width = image.shape[-2:]
    # Replicating the (B, T, 1, 1) map equals bilinear upsampling of a 1x1
    # map, but follows the image's real size instead of a fixed 28x28.
    text_map = text[:, :, None, None].expand(-1, -1, height, width)
    return torch.cat((text_map, image), dim=1)

In [None]:
# Hugging Face logo dataset; rows are read below via their "image" and "text" fields.
LOGO_DATASET_ID = "iamkaikai/amazing_logos_v4"
dataset = load_dataset(LOGO_DATASET_ID, split="train")

## Экземпляры моделей

In [None]:
# ГЕНЕРАТОР
class Generator(nn.Module):
    """
    Conditional generator: noise + text embedding -> 3x28x28 image in [-1, 1].

    The input vector (noise concatenated with the text embedding) is treated
    as a 1x1 feature map and upsampled by transposed convolutions:
    1x1 -> 7x7 (128 ch) -> 14x14 (64 ch) -> 28x28 (3 ch), followed by Tanh.

    Args:
        z_dim (int): Size of the noise vector.
        text_dim (int, optional): Size of the text embedding appended to the
            noise. Defaults to the notebook-level ``vect_size`` when None.
    """

    def __init__(self, z_dim, text_dim=None):
        super().__init__()
        if text_dim is None:
            text_dim = vect_size  # notebook-level setting
        self.z_dim = z_dim
        self.text_dim = text_dim
        self.model = nn.Sequential(
            # kernel 7, stride 1, no padding: 1x1 -> 7x7.
            # Author's note: the channel expansion here is not quite monotonic,
            # so z_dim could be increased.
            nn.ConvTranspose2d(
                z_dim + text_dim, 128, kernel_size=7, stride=1, padding=0, bias=False
            ),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 14x14 -> 28x28, squashed to [-1, 1]
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map a (B, z_dim + text_dim) input to a (B, 3, 28, 28) image batch."""
        # Treat the flat input vector as a 1x1 spatial map.
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.model(z)

In [None]:
# ДИСКРИМИНАТОР
class Discriminator(nn.Module):
    """
    Conditional discriminator: (text channels + RGB) 28x28 map -> probability.

    Downsamples 28x28 -> 14x14 (64 ch) -> 7x7 (128 ch), collapses to a single
    value with a 7x7 convolution and applies Sigmoid.

    Args:
        text_dim (int, optional): Number of text-embedding channels stacked
            with the 3 image channels. Defaults to the notebook-level
            ``vect_size`` when None.
    """

    def __init__(self, text_dim=None):
        super().__init__()
        if text_dim is None:
            text_dim = vect_size  # notebook-level setting
        self.text_dim = text_dim
        self.model = nn.Sequential(
            # 28x28 -> 14x14
            nn.Conv2d(
                3 + text_dim, 64, kernel_size=4, stride=2, padding=1, bias=False
            ),
            nn.LeakyReLU(0.2, inplace=True),
            # 14x14 -> 7x7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 7x7 -> 1x1 score in (0, 1)
            nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return a (B, 1) tensor of probabilities for a (B, 3 + text_dim, 28, 28) input."""
        validity = self.model(img)
        return validity.view(-1, 1)

## Параметры

In [None]:
img_shape = (3, 28, 28)  # image shape: 3 channels, 28x28 pixels
z_dim = 100  # size of the noise vector fed to the generator
batch_size = 32  # training batch size
epochs = 50  # number of training epochs
vect_size = 300  # size of the text-embedding vector

# Seed the RNGs used by later cells (model weight init, noise, DataLoader
# shuffling, the preview-caption sample) so reruns are comparable.
# GPU kernels may still introduce some nondeterminism.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Binary cross-entropy on the discriminator's sigmoid outputs drives both players.
adversarial_loss = nn.BCELoss()

generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)

# Both networks use identical Adam settings.
adam_kwargs = {"lr": 0.0002, "betas": (0.5, 0.999)}
optimizer_G = optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_kwargs)

In [None]:
# Training subset: rows 296951 .. 397250 of the train split.
# NOTE(review): the reason for this particular slice is not recorded here — confirm.
train_rows = range(296951, 397251)
train_dataset = TransformDataset(
    dataset.select(train_rows),
    new_size=(28, 28),
    nlp_model=nlp,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

## Обучение

In [None]:
def train_step(real_images, real_text):
    """
    Run one conditional-GAN update: a discriminator step, then a generator step.

    Uses the notebook-level ``generator``, ``discriminator``, ``optimizer_G``,
    ``optimizer_D``, ``adversarial_loss``, ``text_plus_image``, ``z_dim`` and
    ``device``.

    Args:
        real_images (torch.Tensor): Batch of real images, (B, 3, H, W).
        real_text (torch.Tensor): Matching text embeddings, (B, T).

    Returns:
        tuple[float, float, torch.Tensor]: Discriminator loss, generator loss,
        and the generated image batch detached from the autograd graph.
    """
    real_images = real_images.to(device)
    real_text = real_text.to(device)
    n = real_images.size(0)

    real_labels = torch.ones(n, 1, device=device)
    fake_labels = torch.zeros(n, 1, device=device)

    # --- Discriminator: real (image, text) pairs -> 1, generated pairs -> 0 ---
    optimizer_D.zero_grad()
    real_pairs = text_plus_image(real_text, real_images)
    real_loss = adversarial_loss(discriminator(real_pairs), real_labels)

    z = torch.cat([torch.randn(n, z_dim).to(device), real_text], 1)
    fake_images = generator(z)
    fake_pairs = text_plus_image(real_text, fake_images)
    # detach(): the discriminator loss must not backpropagate into the generator.
    fake_loss = adversarial_loss(discriminator(fake_pairs.detach()), fake_labels)

    d_loss = real_loss + fake_loss
    d_loss.backward()
    optimizer_D.step()

    # --- Generator: push the just-updated discriminator to call fakes real ---
    optimizer_G.zero_grad()
    g_loss = adversarial_loss(discriminator(fake_pairs), real_labels)
    g_loss.backward()
    optimizer_G.step()

    # Callers keep this batch only for metrics; detach so it does not pin the graph.
    return d_loss.item(), g_loss.item(), fake_images.detach()

In [None]:
# Fixed set of caption embeddings used to preview the generator after each epoch.
n_preview = 16  # number of preview samples (4x4 grid)
indexes_to_generate = np.random.randint(0, len(dataset), n_preview)
text_to_generate = TransformDataset(
    dataset.select(indexes_to_generate), new_size=(28, 28), nlp_model=nlp
)
# Only the text vectors are needed. Index explicitly instead of iterating the
# Dataset (it defines no __iter__); the dataset already truncates to vect_size.
txt_to_generate = torch.stack(
    [text_to_generate[k][1] for k in range(len(text_to_generate))]
)

In [None]:
def generate_and_save_images(model, epoch, output_dir="/kaggle/working"):
    """
    Plot a grid of generator samples for the preview captions and save it as PNG.

    Uses the notebook-level ``txt_to_generate`` captions, ``z_dim`` and
    ``device``; fresh noise is drawn on every call.

    Args:
        model (torch.nn.Module): Generator mapping (B, z_dim + T) -> (B, 3, H, W)
            with values in [-1, 1].
        epoch (int): Epoch number, used in the output file name.
        output_dir (str): Directory for ``image_at_epoch_{epoch}.png``;
            created if missing.
    """
    import math

    n_images = txt_to_generate.size(0)
    z = torch.cat([torch.randn(n_images, z_dim), txt_to_generate], 1).to(device)
    # Inference only, so no autograd graph is built.
    # NOTE(review): the model stays in train mode, so BatchNorm normalises with
    # (and updates its running stats from) this preview batch.
    with torch.no_grad():
        generated_images = model(z).cpu()

    side = math.ceil(math.sqrt(n_images))  # square grid (4x4 for 16 samples)
    plt.figure(figsize=(4, 4))
    for i, img in enumerate(generated_images):
        plt.subplot(side, side, i + 1)
        plt.imshow(img.permute(1, 2, 0).numpy() * 0.5 + 0.5)  # [-1, 1] -> [0, 1]
        plt.axis("off")

    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, f"image_at_epoch_{epoch}.png"))
    plt.show()

In [None]:
import pandas as pd

# Column layout of the per-epoch metrics log (starts empty).
# NOTE(review): train() builds its own local table; this global is not filled by it.
metric_columns = [
    "epoch",
    "g_loss",
    "d_loss",
    "inception_score_mean",
    "inception_score_std",
    "fid_score",
]
metrics_data = pd.DataFrame(columns=metric_columns)

In [None]:
def train(epochs, train_loader, output_dir="/kaggle/working", log_every=10):
    """
    Train the conditional GAN, logging losses and per-epoch IS / FID.

    After every epoch a preview grid is saved via ``generate_and_save_images``.
    The Inception Score and FID are then computed on the epoch's last training
    batch (real vs. generated), and a row is added to the metrics table. Every
    5th epoch and after the final epoch, the generator/discriminator weights
    and ``metrics_data.csv`` are written to ``output_dir``.

    Uses the notebook-level ``generator``, ``discriminator``, ``train_step``,
    ``generate_and_save_images``, ``inception_score``, ``calculate_fid`` and
    ``batch_size``.

    Args:
        epochs (int): Number of epochs.
        train_loader: Iterable of ``(images, text_vectors)`` batches with ``len``.
        output_dir (str): Directory for checkpoints and the metrics CSV;
            created if missing.
        log_every (int): Print losses every ``log_every`` batches (>= 1).

    Returns:
        pandas.DataFrame: One row of metrics per epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    os.makedirs(output_dir, exist_ok=True)
    records = []  # collect rows; building one DataFrame avoids concat in a loop
    n_batches = len(train_loader)

    for epoch in range(epochs):
        image = fake_images = None
        for i, (image, text) in enumerate(train_loader):
            d_loss, g_loss, fake_images = train_step(image, text)
            if i % log_every == 0:
                print(
                    f"Epoch [{epoch}/{epochs - 1}] Batch [{i}/{n_batches - 1}] D_loss: {d_loss:.4f}, G_loss: {g_loss:.4f}"
                )
        if fake_images is None:
            raise ValueError("train_loader yielded no batches")
        # Metrics only need values; never keep the training graph alive.
        fake_images = fake_images.detach()

        generate_and_save_images(generator, epoch)

        # NOTE(review): metrics use only the last batch, which may be smaller
        # than batch_size, so IS/FID are noisy estimates.
        is_mean, is_std = inception_score(
            images=fake_images, batch_size=batch_size, splits=10
        )
        print(f"Epoch {epoch} - Inception Score: {is_mean:.4f} ± {is_std:.4f}")

        fid_score = calculate_fid(
            real_images=image, generated_images=fake_images, batch_size=batch_size
        )
        print(f"Epoch {epoch} - FID Score: {fid_score:.4f}")

        records.append(
            {
                "epoch": epoch,
                "g_loss": g_loss,
                "d_loss": d_loss,
                "inception_score_mean": is_mean,
                "inception_score_std": is_std,
                "fid_score": fid_score,
            }
        )

        if epoch % 5 == 0 or epoch == epochs - 1:
            torch.save(
                generator.state_dict(),
                os.path.join(output_dir, f"generator_{epoch}.pth"),
            )
            torch.save(
                discriminator.state_dict(),
                os.path.join(output_dir, f"discriminator_{epoch}.pth"),
            )
            pd.DataFrame(records).to_csv(
                os.path.join(output_dir, "metrics_data.csv"), index=False
            )

    return pd.DataFrame(records)

In [None]:
# Launch training with the hyper-parameters defined in the parameters cell.
train(epochs=epochs, train_loader=train_loader)