In [None]:
#インポート

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import csv

In [None]:
# Data loading: read every row of the color CSV (header included) as lists of strings.
# NOTE(review): this is an absolute Google Drive path, so the cell presumably only
# runs where that drive is mounted (e.g. Colab) - consider a configurable DATA_DIR.
data = []

filename = '/content/drive/MyDrive/アルバイト/PreApp/Asset/ColorList(1).csv'
with open(filename, encoding='utf8', newline='') as csv_file:
    data.extend(csv.reader(csv_file))

In [None]:
# Data shaping: turn the raw CSV rows into two float arrays scaled to [0, 1].
# Columns 3-5 become the generator input (data_x); columns 0-2 are the true
# color the generator is trained toward (data_y).

# Skip the header by slicing instead of data.pop(0): pop mutated `data`, so
# every re-run of this cell silently discarded one more real data row.
# `if row` skips blank lines, which csv.reader yields as empty lists.
rows = [[int(value) for value in row] for row in data[1:] if row]

data_x = np.array([row[3:6] for row in rows], dtype=np.float64) / 255
data_y = np.array([row[:3] for row in rows], dtype=np.float64) / 255

In [None]:
# Discriminator: scores a 3-value color vector with a probability.

class discriminator(nn.Module):
    """Two-layer MLP critic: 3 -> 512 (LeakyReLU, slope 0.1) -> 1.

    The final layer is passed through a sigmoid, so ``forward`` returns
    probabilities, not logits. The attribute names fc1 / fc2 / activation
    determine the state_dict keys.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 512)
        self.fc2 = nn.Linear(512, 1)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Map ``x`` (last dimension 3) to a probability tensor (last dimension 1)."""
        hidden = self.activation(self.fc1(x))
        return torch.sigmoid(self.fc2(hidden))

In [None]:
# Generator: predicts a 3-value color from a 3-value input color.

class generator(nn.Module):
    """MLP generator 3 -> 1024 -> 2048 -> 3 with ReLU hidden activations.

    The output head is a sigmoid, so every component lies in [0, 1]. That is
    the same range as the training targets (``data_y`` is divided by 255).
    The former ``Tanh`` head produced values in (-1, 1). It could emit negative
    channel values that no target color has, which could also hand the
    discriminator an easy cue for spotting fakes.

    Parameter names (fc1-fc3) are unchanged, so existing state_dicts still
    load. However, weights trained with the Tanh head will now produce
    different outputs.
    """

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(3, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 3)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Map ``x`` (last dimension 3) to a color tensor (last dimension 3) in [0, 1]."""
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        return torch.sigmoid(self.fc3(x))

In [None]:
# Hyperparameters, loss functions, models and optimizers.

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fix the torch RNG so weight initialisation (and any later torch randomness)
# is reproducible across Restart & Run All.
seed = 0
torch.manual_seed(seed)

epochs = 30
lr = 2e-4
batch_size = 64
# discriminator.forward already applies a sigmoid, i.e. D returns probabilities.
# BCEWithLogitsLoss would apply a second sigmoid, squeezing D's scores into
# roughly (0.5, 0.73) and flattening its gradients; the logged discriminator
# loss sat near 2*ln(2) ~ 1.39. Plain BCELoss is the loss that matches
# probability outputs.
bce_loss = nn.BCELoss()
mae_loss = nn.L1Loss()  # L1 term pulling G's output toward the true color

G = generator().to(device)
D = discriminator().to(device)

G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

In [None]:
# Training: adversarial + L1 training of G against D, in shuffled mini-batches.
# Relies on data_x / data_y, G / D, their optimizers, bce_loss, mae_loss,
# epochs, batch_size and device from the earlier cells.

# Convert the whole dataset to tensors once instead of once per sample.
inputs_all = torch.as_tensor(data_x, dtype=torch.float32, device=device)
targets_all = torch.as_tensor(data_y, dtype=torch.float32, device=device)
n_samples = len(inputs_all)
n_batches = (n_samples + batch_size - 1) // batch_size  # ceil division
log_every = 100  # print every `log_every` batches, and always on an epoch's last batch
save_every = 10  # checkpoint interval in epochs (the old `% 100` never fired with epochs=30)

for epoch in range(epochs):
    # Reshuffle every epoch so each mini-batch sees a different mix of samples.
    order = torch.randperm(n_samples, device=device)
    for idx, start in enumerate(range(0, n_samples, batch_size), start=1):
        batch = order[start:start + batch_size]
        input_x = inputs_all[batch]
        real_y = targets_all[batch]  # true colors

        # --- Generator step ---
        fake_color = G(input_x)
        fake_output = D(fake_color)
        # Labels are constants: they need no gradient. ones_like / zeros_like
        # keep them the same shape as D's output for any batch size.
        loss_G_bce = bce_loss(fake_output, torch.ones_like(fake_output))  # G wants D to say "real"
        loss_G_mae = mae_loss(fake_color, real_y)  # ...and to match the true color
        loss_G_sum = loss_G_bce + loss_G_mae

        G_optimizer.zero_grad()
        loss_G_sum.backward()
        G_optimizer.step()

        # --- Discriminator step ---
        real_output = D(real_y)
        # detach(): the discriminator loss must not backpropagate into G.
        fake_output = D(fake_color.detach())
        loss_D_real = bce_loss(real_output, torch.ones_like(real_output))
        loss_D_fake = bce_loss(fake_output, torch.zeros_like(fake_output))
        loss_D_sum = loss_D_real + loss_D_fake

        # Also clears the gradients D accumulated during the generator's backward pass.
        D_optimizer.zero_grad()
        loss_D_sum.backward()
        D_optimizer.step()

        if idx % log_every == 0 or idx == n_batches:
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, loss_D_sum.item(), loss_G_sum.item()))

    if (epoch + 1) % save_every == 0:
        # Save weights only (like the final save cell); the suffix is the
        # number of completed epochs.
        torch.save(G.state_dict(), 'Generator_epoch_{}.pth'.format(epoch + 1))
        print('Model saved.')

Epoch 0 Iteration 100: discriminator_loss 1.397 generator_loss 0.638
Epoch 0 Iteration 200: discriminator_loss 1.399 generator_loss 0.684
Epoch 0 Iteration 300: discriminator_loss 1.394 generator_loss 0.822
Epoch 0 Iteration 400: discriminator_loss 1.393 generator_loss 0.679
Epoch 0 Iteration 500: discriminator_loss 1.393 generator_loss 0.725
Epoch 0 Iteration 588: discriminator_loss 1.400 generator_loss 0.702
Epoch 1 Iteration 100: discriminator_loss 1.410 generator_loss 0.743
Epoch 1 Iteration 200: discriminator_loss 1.386 generator_loss 0.675
Epoch 1 Iteration 300: discriminator_loss 1.393 generator_loss 0.891
Epoch 1 Iteration 400: discriminator_loss 1.388 generator_loss 0.698
Epoch 1 Iteration 500: discriminator_loss 1.408 generator_loss 0.828
Epoch 1 Iteration 588: discriminator_loss 1.433 generator_loss 0.717
Epoch 2 Iteration 100: discriminator_loss 1.403 generator_loss 0.773
Epoch 2 Iteration 200: discriminator_loss 1.388 generator_loss 0.703
Epoch 2 Iteration 300: discriminat

In [None]:
# Save the trained generator: only its parameter tensors (state_dict), not the pickled module.
generator_weights = G.state_dict()
torch.save(generator_weights, 'generator_model.pth')