# Fine tune

In [1]:
from google.colab import drive
import glob
from pathlib import Path
from itertools import islice
 
drive.mount('/content/drive',force_remount=True)

Mounted at /content/drive


In [1]:
cd drive/MyDrive/DeOldify

/content/drive/MyDrive/DeOldify


In [None]:
!pip install -r colab_requirements.txt

In [2]:
import os
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks.tensorboard import *
from fastai.vision.gan import *
from deoldify.generators import *
from deoldify.critics import *
from deoldify.dataset import *
from deoldify.loss import *
from deoldify.save import *
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile

In [3]:
path = Path('data/')
path_hr = path
path_lr = path/'bandw'

proj_id = 'StableModel'

gen_name = proj_id + '_gen'
pre_gen_name = gen_name + '_0'
crit_name = proj_id + '_crit'

name_gen = proj_id + '_image_gen'
path_gen = path/name_gen

TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)

nf_factor = 2
pct_start = 1e-8

In [4]:

def get_data( bs:int, sz:int, keep_pct:float):

    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, 
                             random_seed=None, keep_pct=keep_pct)

def get_crit_data(classes, bs, sz):
    src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)
    
    ll = src.label_from_folder(classes=classes)
    
    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)
           .databunch(bs=bs).normalize(imagenet_stats))
    
    return data

def create_training_images(fn,i):
    dest = path_lr/fn.relative_to(path_hr)
    
    dest.parent.mkdir(parents=True, exist_ok=True)
    
    img = PIL.Image.open(fn).convert('LA').convert('RGB')
    
    img.save(dest)  
    
def save_preds(dl):
    i=0
    names = dl.dataset.items
    
    for b in dl:
        preds = learn_gen.pred_batch(batch=b, reconstruct=True)
        for o in preds:
            o.save(path_gen/names[i].name)
            i += 1
    
def save_gen_images( keep_pct=0.085):
    if path_gen.exists(): shutil.rmtree(path_gen)
    path_gen.mkdir(exist_ok=True)
    
    data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)
    
    save_preds(data_gen.fix_dl)
    
    PIL.Image.open(path_gen.ls()[0])

In [5]:
if not path_lr.exists():
    il = ImageList.from_folder(path_hr)
    parallel(create_training_images, il.items)

In [15]:
pretrain_learner_path = 'ColorizeStable_gen' 
pretrain_critic_path = 'ColorizeStable_crit'

In [7]:
learn_crit=None
learn_gen=None
gc.collect()

357

In [33]:
bs = 1
sz = 192
keep_pct = 1.0
lr=2e-5

In [9]:
data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)

In [10]:
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)

Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.torch/models/vgg16_bn-6c64b313.pth
553507836it [00:05, 95171857.49it/s]
Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /root/.torch/models/resnet101-5d3b4d8f.pth
178728960it [00:05, 34596978.42it/s]


In [20]:
cd ..

/content/drive/MyDrive/DeOldify


In [13]:
!wget https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0 -O ColorizeStable_gen.pth

--2022-06-11 13:46:57--  https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:6018:18::a27d:312
Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/usf7uifrctqw9rl/ColorizeStable_gen.pth [following]
--2022-06-11 13:46:57--  https://www.dropbox.com/s/raw/usf7uifrctqw9rl/ColorizeStable_gen.pth
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc6521307fa295c2fecb5bf762b9.dl.dropboxusercontent.com/cd/0/inline/BnB8ex3sFcOCU58EEgwLTCiDJGQMfnYjNoO3Rpj0ODumIQAuCcKJjMxgxokje85cQIUglIY1OHSU4cOn123xrdFUooo2nRKoci_xq-kXcXppYA_kdRHzGxVxZ-MDMscbpoVx3_oQp7aQwqmlKt61zA_lmDjJeKLEYdGwnCojLTvO5EGO2pCypOgdtVSlcFwB1Zk/file# [following]
--2022-06-11 13:46:58--  https://uc6521307fa295c2fecb5bf762b9.dl.dropboxusercontent.com/cd/0/inline/BnB8ex3sFcOC

In [22]:
!wget https://www.dropbox.com/s/wlqu6w88qwzcvfn/ColorizeStable_crit.pth?dl=0 -O ColorizeStable_crit.pth

--2022-06-11 13:51:23--  https://www.dropbox.com/s/wlqu6w88qwzcvfn/ColorizeStable_crit.pth?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:6018:18::a27d:312
Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/wlqu6w88qwzcvfn/ColorizeStable_crit.pth [following]
--2022-06-11 13:51:24--  https://www.dropbox.com/s/raw/wlqu6w88qwzcvfn/ColorizeStable_crit.pth
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uceb3cb5c44626c5b1c8fa09c0bf.dl.dropboxusercontent.com/cd/0/inline/BnCLvT7Eh8o5NVqjxctZL4toQa2_wITf0eL5sANaon5I6Td-ptsJH9vbfCza6b7LiXyxMJ5YZMF-Dnf7mb8kfXOft5rKchbTEIWpkWVQ7xsoMZt3SqhZfXJrtCzzrkxGAUD7-iqadvayuVRQvshkeTEn_opGErscpYcRVtOrKXYzhMKn1JpSRY2zClPXbfEY23g/file# [following]
--2022-06-11 13:51:24--  https://uceb3cb5c44626c5b1c8fa09c0bf.dl.dropboxusercontent.com/cd/0/inline/BnCLvT7Eh

In [21]:
learn_gen.load(pretrain_learner_path, with_opt=False)

Learner(data=ImageDataBunch;

Train: LabelList (900 items)
x: ImageImageList
Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192)
y: ImageList
Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192)
Path: data/bandw;

Valid: LabelList (100 items)
x: ImageImageList
Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192)
y: ImageList
Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192),Image (3, 192, 192)
Path: data/bandw;

Test: None, model=DynamicUnetWide(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottle

In [25]:
save_gen_images(1.00)

In [26]:
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [28]:
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(pretrain_critic_path , with_opt=False)

In [29]:
learn_gen.freeze_to(-1)
learn_crit.freeze_to(-1)

In [30]:
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,
                                 
                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)

learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))

learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=pre_gen_name, save_iters=100))

In [31]:
learn.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)

In [38]:
learn.fit(1,lr)

epoch,train_loss,valid_loss,gen_loss,disc_loss,time


  .format(op_name, op_name))
  .format(op_name, op_name))
  .format(op_name, op_name))


KeyboardInterrupt: ignored

In [None]:
learn.lr_find()