In [None]:
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
from torchvision.models import vgg16_bn
from PIL import Image, ImageDraw, ImageFont
import torchvision.models as tv_models
import pathlib

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def gram_matrix(x):
    """Return the per-sample Gram matrix of a 4-D tensor.

    `x` is unpacked as sizes (n, c, h, w). The last two axes are flattened so
    each sample becomes a (c, h*w) matrix, which is multiplied by its own
    transpose and scaled by 1 / (c*h*w). The result has shape (n, c, c).
    """
    batch, channels, height, width = x.size()
    flat = x.view(batch, channels, -1)
    gram = flat @ flat.transpose(1, 2)
    return gram / (channels * height * width)

# L1 loss used by FeatureLoss.forward for the pixel, feature and Gram terms.
base_loss = F.l1_loss

class FeatureLoss(nn.Module):
    """Perceptual loss: pixel L1 plus feature-activation and Gram-matrix L1 terms.

    Args:
        m_feat: indexable feature-extractor module (it is indexed with `layer_ids`).
        layer_ids: indices into `m_feat` of the layers whose outputs are compared.
        layer_wgts: one weight per hooked layer; applied as `w` to the feature
            term and as `w**2 * 5e3` to the Gram term.

    After each `forward`, `self.feat_losses` holds the individual terms and
    `self.metrics` maps the names in `self.metric_names` to them.
    """

    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # NOTE(review): detach=False presumably keeps the stored activations attached
        # to the autograd graph so the loss can back-propagate — confirm in fastai docs.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]

    def make_features(self, x, clone=False):
        """Run `m_feat` on `x` and return the hooked outputs (cloned when `clone` is true)."""
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the sum of the pixel, feature and Gram losses between `input` and `target`."""
        # The target activations are cloned because the same hooks are reused for the
        # pass on `input` right after (presumably overwriting `hooks.stored`).
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input,target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # `hooks` is missing when __init__ failed before creating it (or never ran);
        # only remove hooks that actually exist instead of raising AttributeError.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None: hooks.remove()

        
def create_data(path_to_source, path_to_target, bs, size, tfms, test_share=0.05):
    """Build the paired image-to-image databunch used to train the generator.

    Args:
        path_to_source: folder with the input images.
        path_to_target: folder with the target images; each target is looked up
            under the same file name as its source image. May be a `str` or a `Path`.
        bs: batch size.
        size: image size passed to `transform`.
        tfms: transforms applied to both input and target (`tfm_y=True`).
        test_share: fraction passed to `split_by_rand_pct` (fixed seed 42).

    Returns:
        The databunch, normalized with `imagenet_stats` on x and y, with `c` set to 3.
    """
    # The label function below joins with `/`, which a plain string does not support.
    path_to_target = pathlib.Path(path_to_target)
    src = ImageImageList.from_folder(path_to_source).split_by_rand_pct(test_share, seed=42)
    data = (src.label_from_func(lambda x: path_to_target / x.name).transform(tfms, size=size, tfm_y=True)
            .databunch(bs=bs).normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data


def create_crit_data(ds_path, classes, bs, size, test_share=0.05, tfms=None):
    """Build the classification databunch used to train the critic.

    Args:
        ds_path: root folder; only the sub-folders named in `classes` are included.
        classes: folder names used as class labels.
        bs: batch size.
        size: image size passed to `transform`.
        test_share: fraction passed to `split_by_rand_pct` (fixed seed 42).
        tfms: optional transforms; when omitted, the flip/rotate/zoom/warp set
            built below is used.

    Returns:
        The databunch, normalized with `imagenet_stats`, with `c` set to 3.
    """
    if tfms is None:
        tfms = get_transforms(do_flip=True, max_rotate=45, max_zoom=4., max_warp=0.2)
    src = ImageList.from_folder(ds_path, include=classes).split_by_rand_pct(test_share, seed=42)
    ll = src.label_from_folder(classes=classes)
    data = (ll.transform(tfms, size=size)
           .databunch(bs=bs).normalize(imagenet_stats))
    data.c = 3
    return data

def create_gen_learner(data_gen, arch, wd, y_range, loss_gen):
    """Create the U-Net generator learner for `data_gen` with encoder `arch`.

    Blur, weight normalization and self-attention are always enabled; `wd`,
    `y_range` and `loss_gen` are forwarded as weight decay, output range and
    loss function.
    """
    unet_kwargs = dict(
        wd=wd,
        blur=True,
        norm_type=NormType.Weight,
        self_attention=True,
        y_range=y_range,
        loss_func=loss_gen,
    )
    return unet_learner(data_gen, arch, **unet_kwargs)

def create_critic_learner(data, wd, loss_func, metrics):
    """Create a learner that wraps a fresh `gan_critic()` model for `data`."""
    critic_model = gan_critic()
    learner = Learner(data, critic_model, metrics=metrics, loss_func=loss_func, wd=wd)
    return learner

def save_preds(dl, learn=None, out_path=None):
    """Predict every batch of `dl` and save each prediction as an image file.

    Args:
        dl: data loader; `dl.dataset.items` supplies the file names, taken in
            iteration order, so `dl` must not be shuffled.
        learn: learner used for prediction; defaults to the notebook-level `learn_gen`.
        out_path: destination folder; defaults to the notebook-level `path_gen`.

    Each prediction is saved under the file name of the corresponding dataset item.
    """
    # Fall back to the notebook globals only when no explicit object is given, so
    # existing `save_preds(dl)` calls keep working.
    learn = learn_gen if learn is None else learn
    out_path = path_gen if out_path is None else out_path

    names = dl.dataset.items
    i = 0
    for b in dl:
        preds = learn.pred_batch(batch=b, reconstruct=True)
        for o in preds:
            o.save(out_path/names[i].name)
            i += 1
       
def num_params(model):
    """Return the total number of elements across all tensors in `model.parameters()`."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

# Paths and Hyperparameters

In [None]:
# Root folder of the dataset; replace the placeholder with your own path.
DS_PATH = pathlib.Path("path/to/dataset/")
# Source-domain images: sub-folder name and its full path under DS_PATH.
DIR_A = "trainA"
PATH_A = DS_PATH / DIR_A
# Target-domain images: sub-folder name and its full path under DS_PATH.
DIR_B = "trainB"
PATH_B = DS_PATH / DIR_B
# Prefix for every saved artifact: pretrained generator, pretrained critic,
# trained GAN, generated-images folder and the final generator state dict.
SAVE_PREFIX = "my-custom-model"

In [None]:
# Batch size and image size shared by the generator and critic databunches.
BATCH_SIZE = 8
IMG_SIZE = 256
# Encoder architecture handed to the U-Net generator.
ENCODER_ARCH = models.resnet34
# Instantiates the encoder once just to report its parameter count in millions.
print(f"Encoder #params: {num_params(ENCODER_ARCH()) / 1e6} M")
# Augmentations for the generator data (applied to input and target alike by create_data).
TRAIN_TFMS = get_transforms(do_flip=True, max_rotate=45, max_zoom=4., max_warp=0.2)
# Fraction of images held out by split_by_rand_pct.
TEST_SHARE = 0.05

# Weight decay used for the generator, the critic and the GAN learner.
WEIGHT_DECAY = 1e-3
# Output range passed to unet_learner as y_range.
Y_RANGE = (-3, 3)

# epochs of the first generator phase, run before learn_gen.unfreeze()
# NOTE(review): despite the name, this phase presumably trains only the layers that
# unet_learner leaves unfrozen (the pretrained encoder staying frozen) — confirm.
NUM_EPOCHS_PRETRAIN_ENCODER = 2
# pretrain whole generator for this many more epochs
NUM_EPOCHS_PRETRAIN_WHOLE_GEN = 3
# pretrain whole generator with this lr
LR_PRETRAIN_WHOLE_GEN = slice(1e-6, 1e-3)
# pretrain critic for this many epochs
NUM_EPOCHS_PRETRAIN_CRIT = 12
# pretrain critic with this lr
LR_PRETRAIN_CRIT = 1e-3

# train entire GAN for this many epochs
NUM_EPOCHS_TRAIN_GAN = 50
# train entire GAN with this lr
LR_TRAIN_GAN = 1e-4
# passed as critic_thresh to AdaptiveGANSwitcher
CRITIC_THRESHOLD = 0.65

# Feature loss, can be modified or left as is
# Feature extractor: the `features` part of VGG16-BN on DEVICE in eval mode.
# NOTE(review): the positional True is presumably the `pretrained` flag — confirm
# for your torchvision version.
vgg_m = vgg16_bn(True).features.to(DEVICE).eval()
# NOTE(review): presumably freezes the VGG weights so only the generator is trained — confirm.
requires_grad(vgg_m, False)
# Index of the layer directly preceding each MaxPool2d in vgg_m.
blocks = [i - 1 for i, o in enumerate(children(vgg_m)) if isinstance(o, nn.MaxPool2d)]
# Compare activations at blocks[2:5], with per-layer weights 5, 15 and 2.
FEAT_LOSS = FeatureLoss(vgg_m, blocks[2:5], [5, 15, 2])

# Pretrain the generator

In [None]:
# create_data performs the folder scan and the seeded train/validation split itself,
# so no separate item list has to be built here.
data_gen = create_data(PATH_A, PATH_B, BATCH_SIZE, IMG_SIZE, TRAIN_TFMS, TEST_SHARE)

In [None]:
# Display a sample batch from the generator databunch as a sanity check.
data_gen.show_batch(8)

In [None]:
# Build the U-Net generator with the feature loss and report its size in millions of parameters.
learn_gen = create_gen_learner(data_gen, ENCODER_ARCH, wd=WEIGHT_DECAY, y_range=Y_RANGE, loss_gen=FEAT_LOSS)
print(f"Generator #params: {num_params(learn_gen.model) / 1e6} M")

In [None]:
# First generator phase, run before unfreeze(); no learning rate is passed, so
# fit_one_cycle's default is used.
learn_gen.fit_one_cycle(NUM_EPOCHS_PRETRAIN_ENCODER, pct_start=0.8)

In [None]:
# Unfreeze the learner so the next phase trains the whole generator.
learn_gen.unfreeze()

In [None]:
# Second generator phase: train the unfrozen model with the configured learning-rate slice.
learn_gen.fit_one_cycle(NUM_EPOCHS_PRETRAIN_WHOLE_GEN, LR_PRETRAIN_WHOLE_GEN)

In [None]:
# Inspect generator outputs after pretraining.
learn_gen.show_results(rows=10)

In [None]:
# Checkpoint the pretrained generator; it is reloaded below and again when the GAN is assembled.
learn_gen.save(SAVE_PREFIX + "-pretrained-gen")

# Make generated images for critic

In [None]:
# Reload the pretrained-generator checkpoint; the trailing `;` suppresses the cell output.
learn_gen.load(SAVE_PREFIX + "-pretrained-gen");

In [None]:
# Folder under DS_PATH that receives the generator's predictions; its name is also
# used as the "generated" class when the critic dataset is built.
name_gen = SAVE_PREFIX + "_images_gen"
path_gen = DS_PATH / name_gen

In [None]:
# uncomment to delete the output folder if it already exists and a clean run is needed
# shutil.rmtree(path_gen)

In [None]:
# Create the output folder; exist_ok=True makes re-running this cell harmless.
path_gen.mkdir(exist_ok=True)

In [None]:
# Run the generator over data_gen.fix_dl and save each prediction into path_gen
# under the file name of the corresponding source image.
save_preds(data_gen.fix_dl)

In [None]:
# Open the first file listed in path_gen to eyeball one generated image.
PIL.Image.open(path_gen.ls()[0])

# Pretrain critic

In [None]:
# Drop the reference to the generator learner and force a garbage-collection pass
# (presumably to free memory before the critic is built).
learn_gen=None
gc.collect()

In [None]:
# Critic pretraining data: two classes taken from folders under DS_PATH —
# the generated images (name_gen) and the target-domain images (DIR_B).
data_crit = create_crit_data(DS_PATH, classes=[name_gen, DIR_B],
                             bs=BATCH_SIZE, size=IMG_SIZE, test_share=TEST_SHARE)

In [None]:
# Display a sample of the critic's training set.
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)

In [None]:
# Show the databunch summary as the cell output.
data_crit

In [None]:
# Critic loss: binary cross-entropy on logits wrapped in AdaptiveLoss
# (NOTE(review): presumably expands the target to the critic's output shape — confirm).
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())

In [None]:
# Critic learner for pretraining, with accuracy_thresh_expand as the metric.
learn_critic = create_critic_learner(data_crit, wd=WEIGHT_DECAY, loss_func=loss_critic, metrics=accuracy_thresh_expand)

In [None]:
# Pretrain the critic to tell generated images from target images.
learn_critic.fit_one_cycle(NUM_EPOCHS_PRETRAIN_CRIT, LR_PRETRAIN_CRIT)

In [None]:
# Checkpoint the pretrained critic; it is reloaded when the GAN is assembled.
learn_critic.save(SAVE_PREFIX + "-pretrained-crit")

## GAN

In [None]:
# Drop both learner references and collect garbage before rebuilding them for GAN training.
learn_critic=None
learn_gen=None
gc.collect()

In [None]:
# Rebuild the critic databunch for GAN training, this time with the source (DIR_A)
# and target (DIR_B) folders as the two classes.
data_crit = create_crit_data(DS_PATH, [DIR_A, DIR_B], bs=BATCH_SIZE, size=IMG_SIZE, test_share=TEST_SHARE)

In [None]:
# Fresh critic learner (no metrics) initialised from the pretrained-critic checkpoint.
learn_crit = create_critic_learner(data_crit, wd=WEIGHT_DECAY,
                                   loss_func=loss_critic, metrics=None).load(SAVE_PREFIX + "-pretrained-crit")

In [None]:
# Fresh generator learner initialised from the pretrained-generator checkpoint.
learn_gen = create_gen_learner(data_gen, ENCODER_ARCH, WEIGHT_DECAY, Y_RANGE, FEAT_LOSS).load(SAVE_PREFIX + "-pretrained-gen")

In [None]:
# Switcher between generator and critic training, parameterised by CRITIC_THRESHOLD.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=CRITIC_THRESHOLD)
# Combine both learners into one GAN learner. Adam runs with betas (0., 0.99);
# NOTE(review): weights_gen=(1., 50.) presumably weights the critic-based term and the
# generator's own loss respectively — confirm in fastai docs.
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,50.), show_img=True, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=WEIGHT_DECAY)
# NOTE(review): presumably scales the learning rate by mult_lr while the critic trains — confirm.
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

In [None]:
# Adversarial training of generator and critic together.
learn.fit(NUM_EPOCHS_TRAIN_GAN, LR_TRAIN_GAN)

In [None]:
# Checkpoint the trained GAN learner.
learn.save(SAVE_PREFIX + "-trained-gan")

In [None]:
# train for more epochs if needed
# learn.fit(10, LR_TRAIN_GAN / 2)

In [None]:
# Inspect generator outputs after GAN training.
learn_gen.show_results(rows=16)

In [None]:
# Export only the generator's state dict, for use outside fastai.
# NOTE(review): DS_PATH / "models" must already exist (torch.save does not create
# folders); presumably it was created by the critic checkpoint saves — confirm.
torch.save(learn_gen.model.state_dict(), DS_PATH / "models" / (SAVE_PREFIX + "-gen-state-dict.pth"))