In [7]:
# Optional cleanup (disabled): uncomment to delete the `sample_data` directory from the
# working directory; ignore_errors=True makes it a no-op if the folder is already gone.
# import shutil
# shutil.rmtree('sample_data', ignore_errors=True)

In [8]:
# Mount Google Drive at /content/drive: the project code, dataset and output folders used
# below all live under /content/drive/MyDrive. Re-running is harmless — an already-mounted
# drive is reported rather than remounted.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
pip install html4vision



In [10]:
import sys

# Put the project folder (on Google Drive) on the import path so the `outpainting` module
# can be imported. The membership check keeps re-runs of this cell from stacking
# duplicate sys.path entries.
project_path = '/content/drive/MyDrive/Image Outpainting and Harmonization using GANs - PyTorch Implementation/'
if project_path not in sys.path:
    sys.path.append(project_path)

In [11]:
import os
import pickle

import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms

# Project module (importable once the project folder is on sys.path). It is expected to
# provide CEImageDataset, CEGenerator, CEDiscriminator, weights_init_normal, train_CE,
# input_size and output_size — none of these are defined in this notebook.
# TODO: replace `import *` with an explicit name list once the module's exports are confirmed.
from outpainting import *

print("PyTorch version: ", torch.__version__)
print("Torchvision version: ", torchvision.__version__)

# Seed torch's global RNG so shuffling, random flips and (assuming weights_init_normal draws
# from torch's RNG — confirm) weight init are repeatable. CUDA kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

# Define paths (everything lives in the project folder on Google Drive)
PROJECT_DIR = '/content/drive/MyDrive/Image Outpainting and Harmonization using GANs - PyTorch Implementation'
model_save_path = os.path.join(PROJECT_DIR, 'outpaint_models')
html_save_path = os.path.join(PROJECT_DIR, 'outpaint_html')
train_dir = os.path.join(PROJECT_DIR, 'cat2dog', 'trainA')
val_dir = os.path.join(PROJECT_DIR, 'cat2dog', 'valA')
test_dir = os.path.join(PROJECT_DIR, 'cat2dog', 'testA')

# Define datasets & transforms.
# Training: resize + center-crop to output_size, with a random horizontal flip as augmentation.
my_tf = transforms.Compose([
        transforms.Resize(output_size),
        transforms.CenterCrop(output_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])
# Val/test: same geometry but no random flip, so evaluation inputs are identical every epoch.
eval_tf = transforms.Compose([
        transforms.Resize(output_size),
        transforms.CenterCrop(output_size),
        transforms.ToTensor()])
batch_size = 4
train_data = CEImageDataset(train_dir, my_tf, output_size, input_size, outpaint=True)
val_data = CEImageDataset(val_dir, eval_tf, output_size, input_size, outpaint=True)
test_data = CEImageDataset(test_dir, eval_tf, output_size, input_size, outpaint=True)
# Only the training set is shuffled; a fixed val/test order keeps per-epoch evaluation comparable.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=1)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=1)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=1)
print('train:', len(train_data), 'val:', len(val_data), 'test:', len(test_data))

# Define model & device; fall back to CPU when no GPU runtime is attached.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
G_net = CEGenerator(extra_upsample=True)
D_net = CEDiscriminator()
G_net.apply(weights_init_normal)
D_net.apply(weights_init_normal)
# For multi-GPU training, wrap the nets with nn.DataParallel(...) at this point.
G_net.to(device)
D_net.to(device)
print('device:', device)

# Define losses & optimizers: an L1 pixel criterion and an MSE discriminator criterion,
# both consumed by train_CE; G and D share the same Adam settings.
learning_rate = 3e-4
adam_betas = (0.5, 0.999)
criterion_pxl = nn.L1Loss()
criterion_D = nn.MSELoss()
optimizer_G = optim.Adam(G_net.parameters(), lr=learning_rate, betas=adam_betas)
optimizer_D = optim.Adam(D_net.parameters(), lr=learning_rate, betas=adam_betas)
criterion_pxl.to(device)
criterion_D.to(device)

# Start training
data_loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}  # NOTE: test is evidently not used by the train method
n_epochs = 150
# Adversarial-loss weight schedule; per the original note it maps to epochs 1-10, 10-30,
# 30-60 and 60 onwards (the exact boundaries are applied inside train_CE — confirm there).
adv_weight = [0.001, 0.005, 0.015, 0.040]
hist_loss = train_CE(G_net, D_net, device, criterion_pxl, criterion_D, optimizer_G, optimizer_D,
                     data_loaders, model_save_path, html_save_path, n_epochs=n_epochs, outpaint=True, adv_weight=adv_weight)

# Save loss history and final generator.
# NOTE(review): these relative paths resolve to the current working directory (not Drive),
# which does not survive a Colab runtime reset; consider writing them under model_save_path.
with open('hist_loss.p', 'wb') as hist_file:
    pickle.dump(hist_loss, hist_file)
torch.save(G_net.state_dict(), 'generator_final.pt')

PyTorch version:  1.8.1+cu101
Torchvision version:  0.9.1+cu101
train: 591 val: 180 test: 100
device: cuda:0
Batch 1/148  loss_pxl 0.6134  loss_adv 1.3501  loss_D 1.0640
Batch 2/148  loss_pxl 0.5709  loss_adv 9.8089  loss_D 7.1726
Batch 4/148  loss_pxl 0.6904  loss_adv 1.9788  loss_D 1.3428
Batch 8/148  loss_pxl 0.4719  loss_adv 0.5312  loss_D 0.2052
Batch 16/148  loss_pxl 0.2578  loss_adv 0.5994  loss_D 0.2364
Batch 32/148  loss_pxl 0.1645  loss_adv 0.5948  loss_D 0.1513
Batch 64/148  loss_pxl 0.1789  loss_adv 0.9968  loss_D 0.1242
Batch 128/148  loss_pxl 0.1254  loss_adv 0.7714  loss_D 0.0929
Generated image table at: /content/drive/MyDrive/Image Outpainting and Harmonization using GANs - PyTorch Implementation/outpaint_html/0/index.html
Epoch 1/150  train  loss_pxl 0.1959  loss_adv 0.9171  loss_D 0.2258
Epoch 1/150  val  loss_pxl 0.1945  loss_adv 1.2251  loss_D 0.1505

Batch 1/148  loss_pxl 0.1761  loss_adv 1.1973  loss_D 0.1597
Batch 2/148  loss_pxl 0.1930  loss_adv 1.2238  loss_D 