In [1]:
from util import *
from preprocess import *
import os
from tqdm.notebook import tqdm

WORD_IMG_SHAPE = (64, 64 * 4)
LATENT_SIZE = np.prod(WORD_IMG_SHAPE)
BATCH_SIZE = 32

In [2]:
generator = nn.Sequential(
    # in: latent_size x 1 x 1

    nn.ConvTranspose2d(LATENT_SIZE, 1024, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(1024),
    nn.ReLU(True),
    # out: 512 x 4 x 4

    nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),
    # out: 256 x 8 x 8

    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
    # out: 128 x 16 x 16

    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    # out: 64 x 32 x 32
    
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    # out: 64 x 32 x 32

    nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),
    # out: 1 x 64 x 64
    
    Reshape('0', 1, '2*3')
)

In [3]:
sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)

def save_samples(index, latent_tensors, show=True, num_samples=3*8):
    fake_images = generator(latent_tensors)
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(fake_images.reshape((num_samples, 1, *WORD_IMG_SHAPE)), os.path.join(sample_dir, fake_fname), nrow=3, pad_value=1)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(fake_images.cpu().detach(), nrow=12).permute(1, 2, 0))

In [4]:


def train_generator(opt_g, base_imgs, crpt_imgs):
    batch_size = base_imgs.shape[0]
    # Clear generator gradients
    opt_g.zero_grad()
    
    # Generate fake images
    # latent = torch.randn(batch_size, latent_size, 1, 1)
    # corrupted, base, labels = next(cwg)
    # print(np.shape(crpt_imgs))
    # print(np.shape(crpt_imgs[0]))
    # print(type(crpt_imgs))
    # print(crpt_imgs.reshape((BATCH_SIZE, LATENT_SIZE, 1, 1)).shape)
    rest_imgs = generator(crpt_imgs.reshape((batch_size, LATENT_SIZE, 1, 1)))
    # print(f"rest img shape: {rest_imgs.shape}")
    # Try to fool the discriminator
    # fool_preds = discriminator(fake_images)
    # fool_targets = torch.ones(batch_size, 1)
    # fool_loss = F.mse_loss(fool_preds, fool_targets)
    
    sim_loss = F.mse_loss((rest_imgs / 2), base_imgs.reshape((batch_size, 1, LATENT_SIZE)) / 2)
    print(f"loss shape: {sim_loss.shape}")
    sim_loss += 0.2 * (-torch.flatten(rest_imgs) + 1)
    
    # Update generator weights
    # loss = 2.5 * fool_loss + 150 * sim_loss
    loss = 150 * sim_loss
    loss.backward()
    opt_g.step()
    
    print(f"generator loss: {(torch.mean(loss).item())}")
    
    return loss.item()

In [5]:
num_examples = 3*8

data_loader = DataLoader(load_dataset('char', equal_shapes=False), shuffle=True, num_workers=0, pin_memory=True)
cwg = CorruptWordGen(data_loader, batch_size=num_examples, img_shape=WORD_IMG_SHAPE)

fixed_base, fixed_corrupted, fixed_labels = next(cwg)
fixed_corrupted = fixed_corrupted.reshape((num_examples, LATENT_SIZE, 1, 1))

def fit(epochs, lr, start_idx=1):
  torch.cuda.empty_cache()
    
  # Losses & scores
  losses_g = []
  losses_d = []
  real_scores = []
  fake_scores = []
  gen_scores = []
  
  # Create optimizers
  # opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
  opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
  
  for epoch in range(epochs):
    data_loader = DataLoader(load_dataset('char', equal_shapes=False), shuffle=True, num_workers=0, pin_memory=True)
    cwg = PrettyCorruptWordGen(data_loader, batch_size=32, img_shape=WORD_IMG_SHAPE)
    
    for base_imgs, crpt_imgs, labels in tqdm(cwg):
      
      # # Train discriminator
      # loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
      # # Train generator
      loss_g = train_generator(opt_g, base_imgs, crpt_imgs)
        
    # Record losses & scores
    losses_g.append(loss_g)
    # losses_d.append(loss_d)
    # real_scores.append(real_score)
    # fake_scores.append(fake_score)
    
    # Log losses & scores (last batch)
    print("Epoch [{}/{}], loss_g: {:.4f}".format(
      epoch+1, epochs, loss_g))

    # Save generated images
    save_samples(epoch+start_idx, fixed_corrupted, show=False, num_samples = num_examples)
  
  return losses_g, losses_d, real_scores, fake_scores

In [6]:

# print(torch.cat(fixed_base).shape)
# print(np.shape(fixed_base[0]))
# print(np.shape(fixed_corrupted[0]))
save_image(fixed_base.reshape((num_examples, 1, *WORD_IMG_SHAPE)), os.path.join('generated', "base.png"), nrow=3, pad_value=1)
save_image(fixed_corrupted.reshape((num_examples, 1, *WORD_IMG_SHAPE)), os.path.join('generated', "corrupted.png"), nrow=3, pad_value=1)
print(fixed_labels)

[[11, 23, 21, 4, 11, 20], [16, 16, 16], [21, 6, 9, 20], [21, 13, 10, 4, 11], [7, 6, 20, 12, 11, 17, 19, 16], [18, 1, 1, 1, 6, 21, 23, 0], [4, 9, 19], [14, 10, 7, 2, 23, 20], [1, 1, 18, 24, 19, 12, 7, 2], [5, 10, 10, 19, 11, 2], [0, 5, 2, 6, 0, 20, 12], [19, 16, 23, 16, 10, 10, 20, 13], [6, 6, 23, 11], [7, 6], [18, 1], [0, 16, 6, 5, 13, 18, 9, 15, 1], [5, 11, 2, 6, 16, 4, 20], [22, 9, 18, 1, 3, 23], [21, 6, 21], [6, 9, 20, 13], [21, 16, 21, 20, 18], [6, 10, 11, 20, 18, 9], [16, 6, 18, 20, 1], [4, 2, 6, 7, 18, 0, 6, 4, 1]]


In [7]:
save_samples(0, fixed_corrupted, show=False, num_samples = num_examples)
history = fit(100, 1e-2)

Saving generated-images-0000.png


  0%|          | 0/35 [00:00<?, ?it/s]

loss shape: torch.Size([])


RuntimeError: output with shape [] doesn't match the broadcast shape [524288]