# 1010 GAN

- paper, implementation : <GAN 첫걸음>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np

In [12]:
BATCH_SIZE = 8192
BATCH_NUM = 1000
EPOCHS = 10
device = "cuda:0"

## Make Training Dataset

In [13]:
labels = np.random.randint(low=0, high=2, size=(BATCH_SIZE * BATCH_NUM,))

examples = np.random.rand(BATCH_SIZE * BATCH_NUM, 4)
examples[labels == 1, 0] = examples[labels == 1, 0] * 0.2
examples[labels == 1, 1] = examples[labels == 1, 1] * 0.2 + 0.8
examples[labels == 1, 2] = examples[labels == 1, 2] * 0.2
examples[labels == 1, 3] = examples[labels == 1, 3] * 0.2 + 0.8

delim = "\t"
with open("train.tsv", "w") as f:
    f.write(f"ex0{delim}ex1{delim}ex2{delim}ex3{delim}label\n")
    f.write('\n'.join([f"{examples[i, 0]}{delim}{examples[i, 1]}{delim}{examples[i, 2]}{delim}{examples[i, 3]}{delim}{labels[i]}" for i in range(BATCH_SIZE * BATCH_NUM)]))

## Read Training Dataset

In [37]:
class Iterator:
    def __init__(self, file, delim, device):
        self.file = open(file, "r")
        self._resetFile()
        
        self.delim = delim
        
        self.device = device
        self.idx = 0
    
    def _resetFile(self):
        self.file.seek(0)
        self.file.readline() # remove header line
        
    def __next__(self):
        lines = [self.file.readline().strip() for x in range(BATCH_SIZE)]
        if len(lines[0]) == 0:
            self._resetFile()
            raise StopIteration
        
        dataset = np.array([[
                float(x.split(self.delim)[0]),
                float(x.split(self.delim)[1]),
                float(x.split(self.delim)[2]),
                float(x.split(self.delim)[3]),
                int(x.split(self.delim)[4])
            ] for x in lines])
        
        examples = torch.tensor(dataset[:, :4], dtype=torch.float32, device=self.device)
        labels = torch.tensor(dataset[:, 4], dtype=torch.float32, device=self.device)
        return [examples, labels]
    
    def __iter__(self):
        return self
    
    def __del__(self):
        self.file.close()
        
iterator = Iterator(file="train.tsv", delim="\t", device=device)

In [38]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(4, 3),
            nn.ReLU(),
            nn.Linear(3, 1),
            nn.ReLU()
        )
        
    def forward(self, x):
        return self.model(x)

In [39]:
def train(model, iterator, criterion, optimizer):
    model.train()
    
    total_loss = 0
    for i, batch in enumerate(iterator):
        print(f"  Batch {i + 1}/{BATCH_NUM}", end="\r")
        optimizer.zero_grad()
        
        examples, labels = batch
        labels = labels.view(-1, 1)
        
        outputs = model.forward(examples)
        
        loss = criterion(outputs, labels)
        loss.backward()
        
        optimizer.step()
        
        total_loss += loss.item()
        
    model.eval()
    #print()
    #print(f"  Real Data : {model.forward(torch.tensor([1, 0, 1, 0], dtype=torch.float32).to(device)).item()}")
    #print(f"  Fake Data : {model.forward(torch.tensor([0, 0, 0, 0], dtype=torch.float32).to(device)).item()}")
    
    return total_loss / BATCH_NUM

In [40]:
D = Discriminator().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(D.parameters(), lr=0.01)

losses = []
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    loss = train(D, iterator, criterion, optimizer)
    print(f"  Loss : {loss}")

Epoch 1/10
  Loss : 0.08115858364105225
Epoch 2/10
  Loss : 0.022485326915979386
Epoch 3/10
  Loss : 0.01797101628035307
Epoch 4/10
  Loss : 0.016887724050320685
Epoch 5/10
  Loss : 0.016489399135112763
Epoch 6/10
  Loss : 0.016300717486068605
Epoch 7/10
  Loss : 0.016194526047445833
Epoch 8/10
  Loss : 0.016127761744894088
Epoch 9/10
  Loss : 0.01608292795624584
Epoch 10/10
  Loss : 0.016051676724106075


In [41]:
D.eval()
D.forward(torch.tensor([1, 0, 1, 0], dtype=torch.float32).to(device)).item()

0.0

In [42]:
D.forward(torch.tensor([0, 0, 0, 0], dtype=torch.float32).to(device)).item()

0.0