In [1]:
import torch
import os
import numpy as np
from tqdm.notebook import tqdm
import pickle
import torchvision.transforms as T
import torchvision
from models import Generator, Classifier
import time
from datetime import datetime
import wandb
# from random_erase import RandomErasing

In [2]:
# Always erase (p=1.0) one random rectangle, filled with 0, per call.
# The remaining keyword arguments spell out torchvision's defaults explicitly.
# NOTE: called on an (N, C, H, W) batch, RandomErasing samples a single rectangle
# that is applied to every image in the batch.
random_erase = T.RandomErasing(p=1.0, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)

def resizeOddSizes(data, multiple=4):
    """Center-crop the spatial dims of ``data`` down to multiples of ``multiple``.

    Only the last two dimensions (H, W) are cropped. Both single images of shape
    (C, H, W) and batches of shape (N, C, H, W) are therefore handled correctly.

    Args:
        data: Tensor whose last two dims are height and width.
        multiple: H and W are reduced to the largest multiple of this value that
            does not exceed them. Defaults to 4.

    Returns:
        A view of ``data`` with shape ``(..., H - H % multiple, W - W % multiple)``.
    """
    height, width = data.shape[-2], data.shape[-1]
    new_height = height - height % multiple
    new_width = width - width % multiple
    # Same offset rule torchvision's CenterCrop uses: int(round(excess / 2)).
    top = int(round((height - new_height) / 2.0))
    left = int(round((width - new_width) / 2.0))
    return data[..., top:top + new_height, left:left + new_width]
    
    

IMAGE_SIZE = 256   # side length of the square training crops
BATCH_SIZE = 8
DATA_DIR = 'data/'

# Resize the shorter side to IMAGE_SIZE, then take a random square crop.
# Resizing/cropping happen on the PIL image *before* ToTensor so that PIL's
# antialiased resampling is used. Tensor-based Resize does not antialias by
# default in torchvision releases before 0.17, which aliases downscaled images.
data_transform = T.Compose([
    T.Resize(size=IMAGE_SIZE),
    T.RandomCrop(size=(IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),
])

# ImageFolder expects DATA_DIR/<class_name>/<image files>. The class labels it
# yields are not used; the training loop only reads batch[0].
image_dataset = torchvision.datasets.ImageFolder(DATA_DIR, transform=data_transform)

dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so that weight initialisation (assuming the models use
# torch's default initialisers), DataLoader shuffling and RandomErasing repeat
# between runs. cuDNN kernels may still be nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)

# Both networks share the same Adam hyperparameters.
LEARNING_RATE = 0.0002
ADAM_BETAS = (0.5, 0.999)

generator = Generator(resnet_blocks=15, features=64).to(device)
classifier = Classifier().to(device)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
optimizer_classifier = torch.optim.Adam(classifier.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
# MSE against all-ones / all-zeros targets, i.e. a least-squares adversarial loss;
# also reused for the identity (reconstruction) term.
criterion = torch.nn.MSELoss()
EPOCHS = 100_000_000_000  # effectively "train until the kernel is interrupted"
losses = []  # NOTE(review): never appended to anywhere in this notebook

In [4]:
def displayImage(tensor, name, out_dir='results', size=(768, 256)):
    """Save an image tensor to disk as ``<out_dir>/<name>.jpg``.

    Args:
        tensor: C x H x W image tensor accepted by ``ToPILImage``. Float tensors are
            expected in [0, 1]. The tensor may live on any device and may carry
            autograd history.
        name: File stem for the saved JPEG.
        out_dir: Directory to write into; it is created if missing.
            Defaults to 'results'.
        size: (width, height) passed to ``PIL.Image.resize``. The default fits a
            strip of three 256 x 256 images placed side by side.
    """
    image_tensor = tensor.detach().cpu()
    if image_tensor.is_floating_point():
        # ToPILImage scales floats by 255 and casts to uint8. Clamp first so
        # slightly out-of-range values saturate instead of wrapping around.
        image_tensor = image_tensor.clamp(0, 1)
    image = torchvision.transforms.ToPILImage()(image_tensor)
    image = image.resize(size)
    os.makedirs(out_dir, exist_ok=True)
    image.save(os.path.join(out_dir, f'{name}.jpg'))

In [5]:
def evaluate_criterion(output, is_real, loss_fn=None):
    """Score classifier outputs against an all-real or all-fake target.

    Args:
        output: Tensor of classifier predictions.
        is_real: If True, the target is all ones ("real"); otherwise it is all
            zeros ("fake").
        loss_fn: Loss called as ``loss_fn(output, target)``. Defaults to the
            notebook-level ``criterion``.

    Returns:
        Whatever ``loss_fn`` returns. This is a scalar tensor for the default
        MSELoss.
    """
    if loss_fn is None:
        loss_fn = criterion
    # *_like builds the target directly on output's device with output's dtype.
    # This avoids a CPU allocation plus a host-to-device copy on every call.
    target = torch.ones_like(output) if is_real else torch.zeros_like(output)
    return loss_fn(output, target)
    

In [6]:
def backward_generator(generated_batch):
    """Adversarial loss for the generator.

    Scores ``generated_batch`` with the classifier and penalises every output
    that is not judged real. Despite the name, no ``.backward()`` is called
    here; the caller does that.
    """
    verdict = classifier(generated_batch)
    return evaluate_criterion(verdict, True)
    

In [7]:
def identity_generator(original):
    """Reconstruction loss: how far the generator moves its input.

    Runs ``original`` through the generator and compares the result with the
    input itself using ``criterion``. The training step calls this with the
    un-erased batch.
    """
    reconstructed = generator(original)
    return criterion(reconstructed, original)

In [8]:
def backward_classifier(batch, is_real):
    """Classifier loss for one batch that is entirely real or entirely fake.

    No ``.backward()`` is called here; the caller does that.
    """
    predictions = classifier(batch)
    return evaluate_criterion(predictions, is_real)

In [9]:
def trainOnBatch(batch):
    """Run one generator update followed by one classifier update.

    Args:
        batch: Image tensor taken from a dataloader batch (``batch[0]`` of an
            ImageFolder batch). Presumably (N, C, H, W) in [0, 1]; TODO confirm.

    Returns:
        ``(generator_loss, classifier_loss)`` as Python floats.
    """
    # RandomErasing samples its rectangle from the last three dims only. Calling
    # it on the whole batch therefore blanks the *same* region in every image.
    # Erase each image independently so one step sees N different holes.
    removed = torch.stack([random_erase(image) for image in batch]).to(device)
    original = batch.to(device)
    generated = generator(removed)

    # Generator step: fool the classifier, and leave un-erased images unchanged
    # (identity term, weighted x5).
    generator.zero_grad()
    generator_loss = backward_generator(generated) + (identity_generator(original) * 5)
    generator_loss.backward()
    optimizer_generator.step()

    # Classifier step. generator_loss.backward() above also wrote gradients into
    # the classifier's parameters; zero_grad() discards them before this update.
    classifier.zero_grad()
    classifier_loss = backward_classifier(generated.detach(), False) + backward_classifier(original, True)
    classifier_loss.backward()
    optimizer_classifier.step()

    return generator_loss.item(), classifier_loss.item()

In [10]:
# NOTE: an identical `showResult` definition used to live in this cell. The
# definition in the next cell immediately shadowed it, so this copy was removed
# to leave a single source of truth.

In [11]:
def showResult(batch, name):
    """Save an [original | erased | generated] strip for the first image in ``batch``.

    The three images are concatenated along dim 2 (the width axis, assuming
    ``batch`` is (N, C, H, W)) and written with ``displayImage`` under ``name``.
    """
    # Visualisation only: don't build an autograd graph through the generator.
    with torch.no_grad():
        removed = random_erase(batch).to(device)
        original = batch.to(device)
        out = generator(removed)
        result = torch.cat([original[0], removed[0], out[0]], dim=2)

    displayImage(result, name)

In [12]:
# Start a wandb run. The timestamp suffix keeps repeated runs distinguishable,
# and the config records the hyperparameters actually used so runs can be
# compared later.
wandb.init(
    project='Filler',
    name=f"cycle style {datetime.now():%Y-%m-%d %H:%M:%S}",
    config={
        'batch_size': dataloader.batch_size,
        'lr_generator': optimizer_generator.param_groups[0]['lr'],
        'lr_classifier': optimizer_classifier.param_groups[0]['lr'],
        'betas': optimizer_generator.param_groups[0]['betas'],
    },
)

def log_loss(epoch, generator_loss, classifier_loss):
    """Log the epoch number and both training losses to wandb as one step.

    Args:
        epoch: Current epoch index.
        generator_loss: Generator loss for the batch (float).
        classifier_loss: Classifier loss for the batch (float).
    """
    wandb.log({
        # Use the argument, not a global loop variable that may be unbound.
        'epoch': epoch,
        'gen_loss': generator_loss,
        'clas_loss': classifier_loss
    })
    
def log_image(batch):
    """Log an [original | erased | generated] strip for the first image in ``batch`` to wandb.

    The three images are concatenated along dim 2 (the width axis, assuming
    ``batch`` is (N, C, H, W)).
    """
    # Visualisation only: don't build an autograd graph through the generator.
    with torch.no_grad():
        removed = random_erase(batch).to(device)
        original = batch.to(device)
        out = generator(removed)
        result = torch.cat([original[0], removed[0], out[0]], dim=2)

    # ToPILImage scales floats by 255 and casts to uint8. Clamp so out-of-range
    # generator outputs saturate instead of wrapping around.
    image = torchvision.transforms.ToPILImage()(result.clamp(0, 1))

    wandb.log({
        'image': wandb.Image(image),
    })

    

[34m[1mwandb[0m: Currently logged in as: [33mbarisimre[0m ([33mcpl57[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
LOG_IMAGE_EVERY = 50  # log a sample strip to wandb every N batches

generator.train()
classifier.train()
# The epoch counter keeps the name `e` in case other cells read it as a global.
for e in range(EPOCHS):
    for i, batch in enumerate(tqdm(dataloader, desc=f'epoch {e}'), start=1):
        try:
            gen_loss, classifier_loss = trainOnBatch(batch[0])
            log_loss(e, gen_loss, classifier_loss)
        except ValueError as err:
            # Named `err`, not `e`. `except ... as e` would rebind the epoch
            # counter, and Python deletes that name when the handler exits, so
            # the next `log_loss(e, ...)` would raise NameError.
            print(err)
            continue
        if i % LOG_IMAGE_EVERY == 0:
            log_image(batch[0])

  0%|          | 0/540 [00:00<?, ?it/s]