<a href="https://colab.research.google.com/github/Aapng-cmd/ML-s-Neuro/blob/main/face_gan_try.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# !ls
# !rm -rf facerecon_v2.zip faces
# !wget http://upos-repo.ru:4444/facerecon_v2.zip
# !unzip facerecon_v2.zip
# !rm facerecon_v2.zip
# !mv t faces
# !ls

In [2]:
# from google.colab import files
# files.upload()

In [3]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class FaceDataset(Dataset):
    """Face images grouped into per-angle folders, held fully in memory.

    ``root_dir`` is expected to contain sub-directories named
    ``"<angle_x>_<angle_y>"`` (integers, e.g. ``"-30_15"``), each holding
    ``.jpg``/``.jpeg`` images. Every image is decoded eagerly at construction
    time, converted from OpenCV's BGR order to RGB, and resized.

    Args:
        root_dir: Directory containing the angle sub-directories.
        transform: Optional callable applied to each ``(H, W, 3)`` uint8 RGB
            array in ``__getitem__`` (e.g. ``transforms.ToTensor()``).
        image_size: ``(width, height)`` passed to ``cv2.resize``; defaults to
            ``(128, 128)``.

    Items are dicts ``{"image": img, "label": (angle_x, angle_y)}``.
    """

    def __init__(self, root_dir, transform=None, image_size=(128, 128)):
        self.root_dir = root_dir
        self.transform = transform
        self.image_size = image_size
        self.images = []
        self.labels = []
        self.load_dataset()

    @staticmethod
    def _parse_angles(dir_name):
        """Return ``(angle_x, angle_y)`` for a ``"<int>_<int>"`` name, else ``None``."""
        parts = dir_name.split("_")
        if len(parts) != 2:
            return None
        try:
            return int(parts[0]), int(parts[1])
        except ValueError:
            return None

    def load_dataset(self):
        """Decode every image under the angle folders into ``images``/``labels``.

        Entries are visited in sorted order so indices are reproducible
        (``os.listdir`` order is arbitrary). Non-directories, folders whose name
        is not ``"<int>_<int>"`` (e.g. ``.ipynb_checkpoints``) and files OpenCV
        cannot decode are skipped rather than aborting the whole load.
        """
        for dir_name in sorted(os.listdir(self.root_dir)):
            dir_path = os.path.join(self.root_dir, dir_name)
            angles = self._parse_angles(dir_name)
            if angles is None or not os.path.isdir(dir_path):
                continue
            for file_name in sorted(os.listdir(dir_path)):
                if not file_name.lower().endswith((".jpg", ".jpeg")):
                    continue
                img = cv2.imread(os.path.join(dir_path, file_name))
                if img is None:  # cv2.imread signals failure by returning None
                    continue
                # OpenCV decodes as BGR; matplotlib (used to display samples)
                # expects RGB, so convert once here.
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                self.images.append(cv2.resize(img, self.image_size))
                self.labels.append(angles)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{"image": ..., "label": (angle_x, angle_y)}`` for ``idx``."""
        img = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return {
            "image": img,
            "label": label
        }


# ToTensor maps uint8 HWC images to float CHW tensors in [0, 1]; Normalize with
# mean=std=0.5 then shifts them to [-1, 1], the range of the generator's Tanh.
data_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

dataset = FaceDataset(root_dir="faces", transform=data_transform)
data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable


# Define the generator network
class Generator(nn.Module):
    """Fully-connected generator: latent vector -> image.

    Architecture: ``Linear(latent_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, C*H*W) -> Tanh``, reshaped to ``(batch, C, H, W)``.
    The Tanh keeps outputs in ``[-1, 1]``.

    Args:
        latent_dim: Size of the input vector (2 by default).
        hidden_dim: Width of the single hidden layer.
        img_shape: ``(channels, height, width)`` of the generated images.

    Layer names ``fc1``/``fc2`` are unchanged so checkpoints saved with the
    default sizes still load via ``load_state_dict``.
    """

    def __init__(self, latent_dim=2, hidden_dim=128, img_shape=(3, 128, 128)):
        super(Generator, self).__init__()
        self.img_shape = tuple(img_shape)
        channels, height, width = self.img_shape
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, channels * height * width)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map ``x`` of shape ``(batch, latent_dim)`` to ``(batch, C, H, W)``."""
        x = torch.relu(self.fc1(x))
        x = self.tanh(self.fc2(x))
        # -1 lets a single unbatched latent vector come back as a batch of 1.
        return x.view(-1, *self.img_shape)

# Define the discriminator network
class Discriminator(nn.Module):
    """Fully-connected discriminator: image -> probability that it is real.

    Architecture: flatten -> ``Linear(C*H*W, hidden_dim) -> ReLU ->
    Linear(hidden_dim, 1) -> Sigmoid``. The Sigmoid output is what
    ``nn.BCELoss`` expects.

    Args:
        img_shape: ``(channels, height, width)`` of the input images.
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, img_shape=(3, 128, 128), hidden_dim=128):
        super(Discriminator, self).__init__()
        self.img_shape = tuple(img_shape)
        channels, height, width = self.img_shape
        self.fc1 = nn.Linear(channels * height * width, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``(batch, 1)`` probabilities for images ``(batch, C, H, W)``.

        Flattening from dim 1 keeps the batch dimension intact, so an image of
        the wrong size raises a shape error in ``fc1`` instead of being silently
        regrouped into a different number of rows (as ``view(-1, C*H*W)``
        would). ``flatten`` also accepts non-contiguous tensors, unlike ``view``.
        Already-flat ``(batch, C*H*W)`` inputs pass through unchanged.
        """
        if x.dim() == len(self.img_shape):  # single unbatched image
            x = x.unsqueeze(0)
        x = torch.relu(self.fc1(torch.flatten(x, start_dim=1)))
        return self.sigmoid(self.fc2(x))

# Initialize the generator and discriminator.
# Use the first CUDA GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

generator = Generator()
discriminator = Discriminator()

# Resume from earlier checkpoints, but only when both files exist.
checkpoint_paths = ("generator.pth", "discriminator.pth")
if all(os.path.exists(path) for path in checkpoint_paths):
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints (torch>=1.13 supports weights_only=True for plain state dicts).
    generator.load_state_dict(torch.load("generator.pth", map_location=device))
    discriminator.load_state_dict(torch.load("discriminator.pth", map_location=device))

# Loss and optimizers. The discriminator ends in a Sigmoid, so BCE on
# probabilities is used.
criterion = nn.BCELoss()
# NOTE(review): the generator's learning rate is 10x the discriminator's;
# confirm this imbalance is intentional.
optimizer_g = optim.Adam(generator.parameters(), lr=0.01)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.001)




cpu


In [None]:
# Place both networks on the selected device (GPU when available, else CPU).
# Module.to() modifies the module in place, so the existing optimizers keep
# referring to the same parameter objects.
for _net in (generator, discriminator):
    _net.to(device)

def checker(generator, label=None):
    """Display one generator sample with matplotlib.

    Args:
        generator: Module mapping a ``(batch, latent)`` float tensor to images of
            shape ``(batch, 3, H, W)`` with values in ``[-1, 1]`` (Tanh output,
            same range as the normalized training images).
        label: Optional latent vector (list, numpy array or tensor). A 1-D value
            is treated as a single sample. When ``None``, a random ``N(0, 1)``
            vector of size 2 is drawn.
    """
    import matplotlib.pyplot as plt

    # Run on whatever device the generator's weights live on, rather than
    # relying on a notebook-global `device`.
    first_param = next(generator.parameters(), None)
    gen_device = first_param.device if first_param is not None else torch.device("cpu")

    if label is None:  # `is`, not `==`: `array == None` compares elementwise
        z = torch.randn(1, 2, device=gen_device)
    else:
        # nn.Linear needs float input on its own device; an int list such as
        # [30, 15] would otherwise become an int64 CPU tensor.
        z = torch.as_tensor(label, dtype=torch.float32, device=gen_device)
        if z.dim() == 1:
            z = z.unsqueeze(0)

    with torch.no_grad():  # inference only; don't build an autograd graph
        fake = generator(z)

    # First sample, CHW -> HWC for imshow.
    img = fake[0].detach().permute(1, 2, 0).cpu().numpy()
    # Undo the [-1, 1] normalization before clipping; clipping the raw Tanh
    # output to [0, 1] would turn every negative value (below mid-gray) black.
    img = np.clip((img + 1.0) / 2.0, 0.0, 1.0)

    plt.imshow(img)
    plt.show()


# Train the GAN
EPOCHS = 100
LOG_EVERY = 100  # print a progress line every this many batches
num_batches = len(data_loader)
for epoch in range(EPOCHS):
    for i, data in enumerate(data_loader):
        # Real images for this batch; the final batch may be smaller than the
        # DataLoader's batch_size.
        real_images = data["image"].to(device)
        batch_size = real_images.size(0)
        # default_collate turns each sample's (angle_x, angle_y) tuple into a
        # list of two (batch,) tensors; stack them into one (batch, 2) tensor.
        # NOTE(review): the generator is unconditional, so the angles are not
        # used in training yet.
        real_labels = torch.stack(data["label"], dim=1).float().to(device)

        # One latent vector per real image, so D sees as many fakes as reals.
        noise = torch.randn(batch_size, 2, device=device)
        fake_images = generator(noise)

        # --- Discriminator step: real -> 1, fake -> 0 ---
        optimizer_d.zero_grad()
        real_output = discriminator(real_images)
        # detach() so this step does not backpropagate into the generator.
        fake_output = discriminator(fake_images.detach())
        d_loss_real = criterion(real_output, torch.ones_like(real_output))
        d_loss_fake = criterion(fake_output, torch.zeros_like(fake_output))
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # --- Generator step: push D's output on fakes towards 1 ---
        optimizer_g.zero_grad()
        # Re-run D after its update so G trains against the current D.
        fake_output = discriminator(fake_images)
        g_loss = criterion(fake_output, torch.ones_like(fake_output))
        g_loss.backward()
        optimizer_g.step()

        if i % LOG_EVERY == 0:
            print(f"{i} / {num_batches}")

    # Losses reported are those of the epoch's last batch.
    print(f"Epoch {epoch+1} / {EPOCHS}, D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")
    checker(generator)

0 / 5023
100 / 5023
200 / 5023
300 / 5023
400 / 5023
500 / 5023
600 / 5023
700 / 5023


In [None]:
# torch.save(generator.state_dict(), 'generator.pth')
# torch.save(discriminator.state_dict(), 'discriminator.pth')
# from google.colab import files
# files.download("generator.pth")
# files.download("discriminator.pth")