In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from model import build_model
import tqdm
from data import load_data
import matplotlib.pyplot as plt



In [2]:
LATENT_DIM = 256
BATCH_SIZE = 8

train_loader, test_loader = load_data(train_batch_size=BATCH_SIZE, test_batch_size=2)
model = build_model(LATENT_DIM, BATCH_SIZE)

In [None]:
def evaluate(model: nn.Module, num_data):
    model.eval()
    progress_bar = tqdm.tqdm(test_loader)
    generator_losses = []
    descriminator_losses = []
    for i, (images, _) in enumerate(progress_bar):
      with torch.no_grad():
        generator_loss, descriminator_loss = model(images)
        generator_losses.append(generator_loss)
        descriminator_losses.append(descriminator_loss)
        if i == num_data:
            break
    
    eval_generator_mean_loss = torch.tensor(generator_losses).mean()
    eval_descriminator_mean_loss = torch.tensor(descriminator_losses).mean()
    
    img = model.generate_image_batch(1).squeeze().detach().cpu()
    print(plt.imshow(img))
    
    model.train()
    
    return eval_descriminator_mean_loss, eval_generator_mean_loss
 
def train(model, num_epochs, eval_epoch):
    descriminator_optimizer = torch.optim.AdamW(model.descriminator.parameters(), lr=1e-3)
    generator_optimizer = torch.optim.AdamW(model.generator.parameters(), lr=1e-3)
    
    progress = tqdm.tqdm(train_loader, dynamic_ncols=True)
    for epoch in range(num_epochs):
        model.train()
        progress.set_description(f'Epoch: {epoch}')
        generator_losses = []
        descriminator_losses = []
        for images, _ in progress:
          generator_loss, descriminator_loss = model(images)
          descriminator_optimizer.zero_grad()
          descriminator_loss.backward()
          descriminator_optimizer.step()
          
          generator_optimizer.zero_grad()
          generator_loss.backward()
          generator_optimizer.step()
          
          generator_losses.append(generator_loss.item())
          descriminator_losses.append(descriminator_loss.item())
      
        generator_mean_loss = torch.tensor(generator_losses).mean()
        descriminator_mean_loss = torch.tensor(descriminator_losses).mean()
       
        if (epoch+1) % eval_epoch == 0:
          eval_descriminator_loss, eval_generator_loss = evaluate(model, num_data=20)
          progress.set_postfix({'generator_loss': generator_mean_loss, 'descriminator_loss': descriminator_mean_loss, 'eval_gen_loss': eval_generator_loss, 'eval_des_loss':eval_descriminator_loss})
        else:
         progress.set_postfix({'generator_loss': generator_mean_loss, 'descriminator_loss': descriminator_mean_loss})
       
      
          
    

In [7]:
model

DCGAN(
  (generator): Generator(
    (seq_pipe): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
      (3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2))
      (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): LeakyReLU(negative_slope=0.01)
      (9): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(1, 1))
      (10): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): LeakyReLU(negative_slope=0.01)
      (12): ConvTranspose2d(16, 1, kernel_size=(4, 4), stride=(1, 1))
      (13): Sigmoid()
    )
  )
  (descriminator): 

In [8]:
train(model, num_epochs=20, eval_epoch=2)

Epoch: 0:   0%|          | 0/7500 [00:00<?, ?it/s]


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [32, 1]], which is output 0 of AsStridedBackward0, is at version 5; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

In [None]:
model

(Generator(
   (seq_pipe): Sequential(
     (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(1, 1))
     (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): LeakyReLU(negative_slope=0.01)
     (3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2))
     (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (5): LeakyReLU(negative_slope=0.01)
     (6): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2))
     (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (8): LeakyReLU(negative_slope=0.01)
     (9): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(1, 1))
     (10): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (11): LeakyReLU(negative_slope=0.01)
     (12): ConvTranspose2d(16, 1, kernel_size=(4, 4), stride=(1, 1))
     (13): Sigmoid()
   )
 ),
 Descriminator(
   (descriminator): Sequential(
     (0)

In [11]:
from model import Generator, Descriminator
generator = Generator(latent_dim=256)
descriminator = Descriminator()

In [None]:
def evaluate(generator: Generator, descriminator: Descriminator, test_loader, num_data, batch_size, device):
    generator.eval()
    descriminator.eval()
    progress_bar = tqdm.tqdm(test_loader)
    generator_losses = []
    descriminator_losses = []
    for i, (images, _) in enumerate(progress_bar):
      with torch.no_grad():
        for images, _ in test_loader:
          real_img_batch = images.to(device)
          fake_img_batch = generator(batch_size)
            
          real_labels = torch.ones(size=(batch_size, ), device=device)
          fake_labels = torch.zeros(size=(batch_size, ), device=device)
            
          # Descriminate
          real_desc_pred = descriminator(real_img_batch).squeeze() # (5, 1)
          fake_desc_pred = descriminator(fake_img_batch).squeeze() # Descrimating Generated image
            
          real_desc_loss = F.binary_cross_entropy(real_desc_pred, real_labels)
          fake_desc_loss = F.binary_cross_entropy(fake_desc_pred, fake_labels)
          descriminator_loss = (real_desc_loss + fake_desc_loss) / 2.0
          
          fake_img_batch = generator(batch_size)
          fake_desc_pred = descriminator(fake_img_batch).squeeze()
          generator_loss = F.binary_cross_entropy(fake_desc_pred, real_labels)
          
          generator_losses.append(generator_loss.item())
          descriminator_losses.append(descriminator_loss.item())
          
        if i == num_data:
            break
    
    eval_generator_mean_loss = torch.tensor(generator_losses).mean()
    eval_descriminator_mean_loss = torch.tensor(descriminator_losses).mean()
    
    img = generator(1).squeeze().detach().cpu()
    plt.imshow(img, cmap='gray')
    plt.title("Generated Image")
    plt.axis("off")
    plt.show()
    
    generator.train()
    descriminator.train()
    
    return eval_descriminator_mean_loss, eval_generator_mean_loss
 
def train(generator, descriminator, num_epochs, eval_epoch, device, batch_size):
    descriminator_optimizer = torch.optim.AdamW(descriminator.parameters(), lr=1e-3)
    generator_optimizer = torch.optim.AdamW(generator.parameters(), lr=1e-3)
    train_loader, test_loader = load_data(train_batch_size=BATCH_SIZE, test_batch_size=2)
    
    progress = tqdm.tqdm(train_loader, dynamic_ncols=True)
    for epoch in range(num_epochs):
        generator.train()
        descriminator.train()
        progress.set_description(f'Epoch: {epoch}')
        generator_losses = []
        descriminator_losses = []
        
        for images, _ in progress:
            real_img_batch = images.to(device)
            fake_img_batch = generator(batch_size)
            
            real_labels = torch.ones(size=(batch_size, ), device=device)
            fake_labels = torch.zeros(size=(batch_size, ), device=device)
            
            # Descriminate
            real_desc_pred = descriminator(real_img_batch).squeeze() # (5, 1)
            fake_desc_pred = descriminator(fake_img_batch).squeeze() # Descrimating Generated image
            
            real_desc_loss = F.binary_cross_entropy(real_desc_pred, real_labels)
            fake_desc_loss = F.binary_cross_entropy(fake_desc_pred, fake_labels)
            descriminator_loss = (real_desc_loss + fake_desc_loss) / 2.0
        
            descriminator_loss.backward()
            descriminator_optimizer.step()
            descriminator_optimizer.zero_grad()
            
            fake_img_batch = generator(batch_size)
            fake_desc_pred = descriminator(fake_img_batch).squeeze()
            generator_loss = F.binary_cross_entropy(fake_desc_pred, real_labels)
            generator_loss.backward()
            generator_optimizer.step()
            generator_optimizer.zero_grad()
            
            generator_losses.append(generator_loss.item())
            descriminator_losses.append(descriminator_loss.item())
            
            progress.set_postfix({'generator_loss': generator_loss.item(), 'descriminator_loss': descriminator_loss.item()})
            
      
        generator_mean_loss = torch.tensor(generator_losses).mean()
        descriminator_mean_loss = torch.tensor(descriminator_losses).mean()
       
        if epoch > 1 :
          eval_descriminator_loss, eval_generator_loss = evaluate(generator, descriminator,test_loader, num_data=20)
          progress.set_postfix({'generator_loss': generator_mean_loss, 'descriminator_loss': descriminator_mean_loss, 'eval_gen_loss': eval_generator_loss, 'eval_des_loss':eval_descriminator_loss})
        else:
         progress.set_postfix({'generator_loss': generator_mean_loss, 'descriminator_loss': descriminator_mean_loss})
       
      
          
    

In [13]:
train(generator=generator, descriminator=descriminator, num_epochs=100, eval_epoch=20, device='cpu', batch_size=8)

Epoch: 0:  21%|██        | 1586/7500 [00:32<02:00, 48.99it/s]


KeyboardInterrupt: 