In [1]:
from fastai.vision.all import array, nn, os, PILImage, torch, TensorImage
from torchvision.transforms import ToPILImage
from typing import List
import time
from IPython.core.display import HTML
from ipywidgets import Label, Button, FileUpload, Output, VBox, AppLayout, Layout, Dropdown
import warnings
warnings.filterwarnings('ignore')

# Set a Simpsons-like font in the next cell.

In [2]:
%%html
<!-- Widget styling: `out_style` (yellow background, black Gochi Hand text) is added to the header
     button and the output panel; `gochihand` applies only the font to the remaining controls. -->
<style>
@import url('https://fonts.googleapis.com/css2?family=Gochi+Hand&display=swap');
.out_style{
    color: black;
    background-color:yellow;
    font-family: 'Gochi Hand', cursive;
}
.gochihand {
    font-family: 'Gochi Hand', cursive;
}
</style>

In [3]:
#define functions used in the model
def pad_conv_norm_relu(ch_in:int, ch_out:int, pad_mode:str, norm_layer:nn.Module, ks:int=3, bias:bool=True,
                       pad=1, stride:int=1, activ:bool=True, init=nn.init.kaiming_normal_)->List[nn.Module]:
    """Return the layer list `[padding?] + conv + norm + [ReLU?]`.

    Args:
        ch_in, ch_out: input / output channel counts of the convolution.
        pad_mode: 'reflection' or 'border' prepend an explicit padding layer of width `pad`;
            'zeros' lets the convolution pad itself; any other value results in no padding at all.
        norm_layer: callable invoked as `norm_layer(ch_out)` to build the normalisation layer.
        ks, stride, bias: forwarded to `nn.Conv2d`.
        pad: padding width.
        activ: append an in-place ReLU when true.
        init: weight initialiser applied to the conv weight; when truthy the conv bias
            (if there is one) is also zeroed.

    Returns:
        A plain Python list of modules, ready to be spliced into `nn.Sequential`.
    """
    explicit_padding = {'reflection': nn.ReflectionPad2d, 'border': nn.ReplicationPad2d}
    block = [explicit_padding[pad_mode](pad)] if pad_mode in explicit_padding else []

    # Only the 'zeros' mode relies on the convolution's own (zero) padding.
    conv_padding = pad if pad_mode == 'zeros' else 0
    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=conv_padding, stride=stride, bias=bias)
    if init:
        init(conv.weight)
        bias_param = getattr(conv, 'bias', None)
        # `bias_param` is None when the conv was built with bias=False.
        if hasattr(bias_param, 'data'): bias_param.data.fill_(0.)

    block.extend([conv, norm_layer(ch_out)])
    if activ: block.append(nn.ReLU(inplace=True))
    return block

def convT_norm_relu(ch_in:int, ch_out:int, norm_layer:nn.Module, ks:int=3, stride:int=2, bias:bool=True):
    """Return `[ConvTranspose2d, norm, ReLU]` as a list of layers.

    The transposed convolution always uses `padding=1, output_padding=1`; with the default
    `ks=3, stride=2` this doubles the spatial size of its input.

    Args:
        ch_in, ch_out: input / output channel counts.
        norm_layer: callable invoked as `norm_layer(ch_out)` to build the normalisation layer.
        ks, stride, bias: forwarded to `nn.ConvTranspose2d`.
    """
    # This function used to be defined twice back-to-back with identical bodies; one definition is kept.
    return [nn.ConvTranspose2d(ch_in, ch_out, kernel_size=ks, stride=stride, padding=1, output_padding=1, bias=bias),
            norm_layer(ch_out), nn.ReLU(True)]

class ResnetBlock(nn.Module):
    """Residual block: `x + conv_block(x)` with two pad-conv-norm stages of constant width.

    Args:
        dim: channel count (input and output are the same).
        pad_mode: one of 'zeros', 'reflection', 'border'; forwarded to `pad_conv_norm_relu`.
        norm_layer: callable building a normalisation layer from a channel count;
            `nn.InstanceNorm2d` when None.
        dropout: when non-zero, an `nn.Dropout(dropout)` is inserted between the two stages.
        bias: whether the convolutions carry a bias.
    """
    def __init__(self, dim:int, pad_mode:str='reflection', norm_layer:nn.Module=None, dropout:float=0., bias:bool=True):
        super().__init__()
        assert pad_mode in ['zeros', 'reflection', 'border'], f'padding {pad_mode} not implemented.'
        # Resolve the default inline: the notebook's import cell does not import fastai's `ifnone`,
        # so calling it here raised NameError.
        if norm_layer is None: norm_layer = nn.InstanceNorm2d
        layers = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias)
        if dropout != 0: layers.append(nn.Dropout(dropout))
        # Second stage has no activation; the skip connection is added to its raw output.
        layers += pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias, activ=False)
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x): return x + self.conv_block(x)

    
def resnet_generator(ch_in:int, ch_out:int, n_ftrs:int=64, norm_layer:nn.Module=None,
                     dropout:float=0., n_blocks:int=6, pad_mode:str='reflection')->nn.Module:
    """Build the ResNet-style generator as an `nn.Sequential`.

    Layout: 7x7 reflection-padded stem -> two stride-2 convs (channels doubled each time)
    -> `n_blocks` `ResnetBlock`s -> two transposed convs (channels halved each time)
    -> 7x7 reflection-padded conv to `ch_out` -> Tanh.

    Args:
        ch_in, ch_out: channel counts of the input and generated images.
        n_ftrs: channel count after the stem.
        norm_layer: callable building a normalisation layer from a channel count;
            `nn.InstanceNorm2d` when None.
        dropout: dropout probability forwarded to each `ResnetBlock`.
        n_blocks: number of residual blocks.
        pad_mode: padding mode used inside the residual blocks.
    """
    # Resolve the default inline: the notebook's import cell does not import fastai's `ifnone`,
    # so calling it here raised NameError.
    if norm_layer is None: norm_layer = nn.InstanceNorm2d
    # Convolutions carry a bias only when the norm layer is InstanceNorm2d.
    bias = (norm_layer == nn.InstanceNorm2d)
    layers = pad_conv_norm_relu(ch_in, n_ftrs, 'reflection', norm_layer, pad=3, ks=7, bias=bias)
    for _ in range(2):
        layers += pad_conv_norm_relu(n_ftrs, n_ftrs *2, 'zeros', norm_layer, stride=2, bias=bias)
        n_ftrs *= 2
    layers += [ResnetBlock(n_ftrs, pad_mode, norm_layer, dropout, bias) for _ in range(n_blocks)]
    for _ in range(2):
        layers += convT_norm_relu(n_ftrs, n_ftrs//2, norm_layer, bias=bias)
        n_ftrs //= 2
    layers += [nn.ReflectionPad2d(3), nn.Conv2d(n_ftrs, ch_out, kernel_size=7, padding=0), nn.Tanh()]
    return nn.Sequential(*layers)

def conv_norm_lr(ch_in:int, ch_out:int, norm_layer:nn.Module=None, ks:int=3, bias:bool=True, pad:int=1, stride:int=1,
                 activ:bool=True, slope:float=0.2, init=nn.init.kaiming_normal_)->List[nn.Module]:
    """Return the layer list `conv + [norm?] + [LeakyReLU?]`.

    Args:
        ch_in, ch_out: input / output channel counts of the convolution.
        norm_layer: callable invoked as `norm_layer(ch_out)`; no norm layer is added when None.
        ks, pad, stride, bias: forwarded to `nn.Conv2d`.
        activ: append an in-place `nn.LeakyReLU(slope)` when true.
        slope: negative slope of the LeakyReLU.
        init: weight initialiser applied to the conv weight; when truthy the conv bias
            (if there is one) is also zeroed.
    """
    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=pad, stride=stride, bias=bias)
    if init:
        init(conv.weight)
        bias_param = getattr(conv, 'bias', None)
        # `bias_param` is None when the conv was built with bias=False.
        if hasattr(bias_param, 'data'): bias_param.data.fill_(0.)

    trailing = []
    if norm_layer is not None: trailing.append(norm_layer(ch_out))
    if activ: trailing.append(nn.LeakyReLU(slope, inplace=True))
    return [conv, *trailing]



def discriminator(ch_in:int, n_ftrs:int=64, n_layers:int=3, norm_layer:nn.Module=None, sigmoid:bool=False)->nn.Module:
    """Build the convolutional discriminator as an `nn.Sequential`.

    Layout: one stride-2 conv without normalisation, `n_layers - 1` further stride-2 conv blocks,
    one stride-1 conv block, and a final 4x4 conv down to a single output channel
    (optionally followed by a Sigmoid).

    Args:
        ch_in: channel count of the input images.
        n_ftrs: channel count after the first convolution.
        n_layers: number of stride-2 convolutions (including the first one).
        norm_layer: callable building a normalisation layer from a channel count;
            `nn.InstanceNorm2d` when None.
        sigmoid: append `nn.Sigmoid()` to the output when true.
    """
    # Resolve the default inline: the notebook's import cell does not import fastai's `ifnone`,
    # so calling it here raised NameError.
    if norm_layer is None: norm_layer = nn.InstanceNorm2d
    # Convolutions carry a bias only when the norm layer is InstanceNorm2d.
    bias = (norm_layer == nn.InstanceNorm2d)
    layers = conv_norm_lr(ch_in, n_ftrs, ks=4, stride=2, pad=1)
    for i in range(n_layers-1):
        # Channel count doubles per layer, but stops growing once i exceeds 3.
        new_ftrs = 2*n_ftrs if i <= 3 else n_ftrs
        layers += conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=2, pad=1, bias=bias)
        n_ftrs = new_ftrs
    new_ftrs = 2*n_ftrs if n_layers <=3 else n_ftrs
    layers += conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=1, pad=1, bias=bias)
    layers.append(nn.Conv2d(new_ftrs, 1, kernel_size=4, stride=1, padding=1))
    if sigmoid: layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)



class CycleGAN(nn.Module):
    """Container for the two generators and two discriminators of a CycleGAN.

    Sub-modules:
        G_A: takes real input B and generates fake input A.
        G_B: takes real input A and generates fake input B.
        D_A: trained to tell real input A from fake input A.
        D_B: trained to tell real input B from fake input B.

    Args:
        ch_in, ch_out: channel counts handed to the generators (`ch_in` is also used for both discriminators).
        n_features: base feature count for all four networks.
        disc_layers: `n_layers` of each discriminator.
        gen_blocks: number of residual blocks in each generator.
        lsgan: when False the discriminators end in a Sigmoid.
        drop: dropout probability inside the generators.
        norm_layer: normalisation layer factory forwarded to all networks.
    """

    def __init__(self, ch_in:int, ch_out:int, n_features:int=64, disc_layers:int=3, gen_blocks:int=6, lsgan:bool=True,
                 drop:float=0., norm_layer:nn.Module=None):
        super().__init__()
        # Creation order (D_A, D_B, G_A, G_B) is kept so parameter registration order is unchanged.
        for disc_name in ('D_A', 'D_B'):
            setattr(self, disc_name, discriminator(ch_in, n_features, disc_layers, norm_layer, sigmoid=not lsgan))
        for gen_name in ('G_A', 'G_B'):
            setattr(self, gen_name, resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks))

    def forward(self, x):
        """Translate the `(real_A, real_B)` pair in both directions.

        In eval mode returns one tensor with `fake_A` and `fake_B` stacked along a new dim 1.
        In training mode returns `[fake_A, fake_B, idt_A, idt_B]`, where the `idt_*` outputs
        are needed for the identity loss.
        """
        real_A, real_B = x
        fake_A = self.G_A(real_B)
        fake_B = self.G_B(real_A)
        if self.training:
            idt_A = self.G_A(real_A)
            idt_B = self.G_B(real_B)
            return [fake_A, fake_B, idt_A, idt_B]
        return torch.cat([fake_A[:,None],fake_B[:,None]], 1)



In [4]:
# load model
# NOTE(security): torch.load unpickles the file, which can execute arbitrary code; only load
# checkpoints that come from a trusted source.
# NOTE(review): presumably these are whole pickled generator modules (they are called directly in
# `transform`), which would require the model classes defined above to exist at load time - confirm.
h2s = torch.load('h2s.pkl')
s2h = torch.load('s2h.pkl')

def transform(img, kind='simpsonize'):
    """Run an image through one of the loaded generators and return the result as a PIL image.

    Args:
        img: image convertible by `array(img)` into an array that is permuted with `(2, 0, 1)`,
            i.e. indexed as height x width x channels.
        kind: 'simpsonize' uses `h2s`; any other value uses `s2h`.

    Returns:
        The PIL image built from the first (only) element of the generated batch.
    """
    # Scale pixel values to [0, 1] and move channels first.
    timg = TensorImage(array(img)).permute(2,0,1).float()/255.
    # NOTE(review): the generator output is mapped to [0, 1] via `/2 + 0.5` below, while the input
    # normalisation to [-1, 1] is disabled - confirm which range the models were trained with.
    #timg = tfms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(TensorImage(array(img)).permute(2,0,1).float()/255.)
    xb = TensorImage(timg[None].clone())  # batch of one
    generator = h2s if kind=='simpsonize' else s2h
    # Inference only: skip autograd bookkeeping so no graph is kept alive for the forward pass.
    with torch.no_grad():
        preds = (generator(xb)/2 + 0.5)
    return ToPILImage()(preds[0])


# When True, the handlers print image sizes and timing measurements.
VERBOSE = False

# User-facing strings, keyed by widget/message id and then by language code ("en"/"de").
# Exception: "dpd_kind" is keyed by the conversion kind ("simpsonize"/"humanize") instead.
TEXTS = {"btn_header" : {"en" : "Simpsonizer", "de" : "Simpsonizer"},
         "dpd_lang" : {"en" : "Language", "de" : "Sprache"},
         "dpd_kind" : {"simpsonize" : "Simpsonize", "humanize" : "Humanize"},
         "btn_doc" : {"en" : "Show Info", "de" : "Info Anzeigen"},
         "btn_upload" : {"en" : "Upload Image", "de" : "Bild hochladen"},
         "btn_status_init" : {"en" : "", "de" : ""},
         "btn_status_progress" : {"en" : "Please Wait... - ", "de" : "Bitte Warten... - "},
         "btn_status_default" : {"en" : "Please Wait", "de" : "Bitte Warten"},
         "btn_status_load" : {"en" : "Loading Image", "de" : "Bild wird geladen"},
         "btn_status_detect" : {"en" : "Generating simpsonized image", "de" : "Simpsonierte Version wird generiert"},
         "btn_status_ready" : {"en" : "Ready For User Input", "de" : "Bereit für Benutzereingabe"},
         "btn_status_error" : {"en" : "Error Loading Image", "de" : "Fehler beim Abruf des Bildes"},
         "select_image" : {"en" : "Please Select An Valid Image File", "de" : "Bitte wählen Sie eine gültige Bilddatei"},
         "prob" : {"en" : "Probability", "de" : "Wahrscheinlichkeit"}
        }

# Module-level UI state: current language and conversion direction.
# Rebound by on_lang_select / on_kind_select when the dropdowns change.
lang="en"
kind="simpsonize"

# English info text, rendered as HTML in the output panel when the info button is clicked.
HTML_EN = """<div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h1 id="simpsonizer">Simpsonizer</h1>
<p>Upload a picture and see what the simpsonized version of that image looks like.</p>
<p>This is a small project as a result of Tanishq Abraham&#39;s <a href="https://github.com/tmabraham/UPIT">UPIT</a> code and the book <a href="https://www.amazon.de/Deep-Learning-Coders-Fastai-Pytorch/dp/1492045527">Deep Learning for Coders with fastai and PyTorch: AI Applications Without a PhD</a>.</p>
<p>An artificial intelligence model (CycleGAN) was trained with images of humans and Simpsons to turn a picture of a human into a Simpsons character and vice versa. The results ranged from &quot;it&#39;s something&quot; to &quot;abstract art&quot; to &quot;nightmare&quot;.</p>
</div>"""

# German info text (same purpose as HTML_EN).
HTML_DE = """<div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h1 id="simpsonizer">Simpsonizer</h1>
<p>Lade ein Bild hoch und finde heraus, wie die Simpson-Version davon aussieht. Nach circa 1,5 Minuten erscheint das generierte Bild. </p>
<p>Eine Künstliche Intelligenz Modell (CycleGAN) wurde mit Bilden von Menschen und Simpsons trainiert, um aus einem Bild eines Menschen ein Simpsons Character zu machen und umgekehrt. Die Bandbreite der Ergebnisse reicht von &quot;schön ist anders&quot;, über &quot;sieht aus wie abstrakte Kunst&quot; bis hin zu &quot;Albtraum&quot;.</p>
<p>Dies welches auf Tanishq Abraham&#39;s <a href="https://github.com/tmabraham/UPIT">UPIT</a> und dem Buch <a href="https://www.amazon.de/Deep-Learning-Coders-Fastai-Pytorch/dp/1492045527">Deep Learning for Coders with fastai and PyTorch: AI Applications Without a PhD</a> basiert.</p>
</div>"""

# Info text lookup by language code; indexed with the module-level `lang`.
DOC = {"en" : HTML_EN,
       "de" : HTML_DE}

# defining widgets
# Every control shares the same auto-sized layout; labels come from TEXTS for the current language.
dpd_lang = Dropdown(options=['en', 'de'], value='en',
                    description=TEXTS["dpd_lang"][lang], layout=Layout(height='auto', width='auto'))
dpd_kind = Dropdown(options=['simpsonize', 'humanize'], value='simpsonize',
                    description=TEXTS["dpd_kind"][kind], layout=Layout(height='auto', width='auto'))
btn_doc = Button(description=TEXTS["btn_doc"][lang], layout=Layout(height='auto', width='auto'))
btn_upload = FileUpload(description=TEXTS["btn_upload"][lang], multiple=False, layout=Layout(height='auto', width='auto'))
# Header and status are disabled buttons used purely as styled labels.
btn_header = Button(description=TEXTS["btn_header"][lang], disabled=True, layout=Layout(height='auto', width='auto'))
btn_status = Button(description=TEXTS["btn_status_init"][lang], disabled=True, layout=Layout(height='auto', width='auto'))
# `clear_output` is a method of Output, not a constructor option: the former `clear_output=True`
# keyword was an unrecognised argument (traitlets only tolerates those with a deprecation warning).
input_img = Output()
output = Output()

# styling
for styled_widget in [btn_header, output]:
    styled_widget.add_class('out_style')
for styled_widget in [btn_status, btn_doc, btn_upload, dpd_lang, dpd_kind]:
    styled_widget.add_class('gochihand')


# defining event functions
def displayWaitMessage(message=None):
    """Show a "please wait" status with `message` appended, on an orange background.

    Args:
        message: text appended to the localized progress prefix. When None, the localized
            default wait text is used.
    """
    # Look the default up at call time: a default argument expression is evaluated only once,
    # when the function is defined, so it would stay in the start-up language forever.
    if message is None:
        message = TEXTS["btn_status_default"][lang]
    btn_status.description = f'{TEXTS["btn_status_progress"][lang]} {message}'
    btn_status.style.button_color = 'orange'

def displayReadyness():
    """Switch the status button to the localized "ready" text on a light-green background."""
    ready_text = TEXTS["btn_status_ready"][lang]
    btn_status.description = ready_text
    btn_status.style.button_color = 'lightgreen'

def outputImage(img):
    """Show a thumbnail of `img` in the input panel and its generated counterpart in the output panel.

    The status button reports progress while this runs and is switched back to "ready"
    afterwards, even when generation raises.

    Args:
        img: image object providing `.size` and `.to_thumb(...)`.
    """
    with input_img:
        output.clear_output()
        displayWaitMessage(TEXTS["btn_status_load"][lang])
        start = time.time()
        if VERBOSE: print(img.size)
        # Build the 500px thumbnail once; it is both displayed here and fed to the model below.
        thumb = img.to_thumb(500)
        display(thumb)
        displayWaitMessage(TEXTS["btn_status_detect"][lang])
        end = time.time()
        if VERBOSE: print('took', end-start, 'for displaying the image')
        start = time.time()
    with output:
        try:
            gen_img = transform(thumb, kind=kind)
            display(gen_img)
            end = time.time()
            if VERBOSE: print('took', end-start, 'for displaying the preds')
        finally:
            # Always leave the "please wait" state, otherwise a failed generation would
            # leave the status stuck on the progress message.
            displayReadyness()


def on_btn_doc_clicked(b):
    """Click handler of the info button: clear both panels and show the info text for the current language."""
    for panel in (output, input_img):
        panel.clear_output()
    with output:
        info_html = HTML(DOC[lang])
        display(info_html)
        
def on_data_change(change):
    """Observer for the upload widget's `data` trait: load the newest upload and process it.

    Clears both panels, decodes the last uploaded file with `PILImage.create` and hands it to
    `outputImage`. If there is no upload, nothing further happens; if the file cannot be
    decoded, the status button shows the localized error text and the output panel asks for
    a valid image file.

    Args:
        change: change notification passed by `observe` (not used; the widget is read directly).
    """
    output.clear_output()
    input_img.clear_output()
    # `data` may be empty; indexing [-1] would then raise IndexError.
    if not btn_upload.data:
        return
    start = time.time()
    try:
        img = PILImage.create(btn_upload.data[-1])
    except Exception as err:
        # Surface the failure in the UI instead of letting it vanish inside the widget callback.
        if VERBOSE: print('could not load image:', repr(err))
        btn_status.description = TEXTS["btn_status_error"][lang]
        btn_status.style.button_color = 'red'
        with output:
            print(TEXTS["select_image"][lang])
        btn_upload._counter = 0
        return
    end = time.time()
    if VERBOSE: print('took', end-start, 'for loading the image')
    outputImage(img)
    # Reset the upload counter shown on the button.
    btn_upload._counter = 0
    
def on_lang_select(change):
    """Observer for the language dropdown: store the selection in `lang` and relabel the widgets.

    The status button is reset to the (empty) initial status text.
    """
    global lang
    lang = dpd_lang.value
    # (widget, TEXTS key) pairs, relabelled in this order.
    relabel_plan = (
        (dpd_lang, "dpd_lang"),
        (btn_upload, "btn_upload"),
        (btn_header, "btn_header"),
        (btn_status, "btn_status_init"),
        (btn_doc, "btn_doc"),
    )
    for widget, text_key in relabel_plan:
        widget.description = TEXTS[text_key][lang]
    
def on_kind_select(change):
    """Observer for the kind dropdown: store the selection in `kind` and update the dropdown label."""
    global kind
    kind = dpd_kind.value
    kind_label = TEXTS["dpd_kind"][kind]
    dpd_kind.description = kind_label

# adding events
btn_doc.on_click(on_btn_doc_clicked)
btn_upload.observe(on_data_change, names=['data'])
# Observe only `value`: without a `names` filter the handlers run for every trait change of the
# dropdown, including the `description` updates that the handlers themselves perform.
dpd_lang.observe(on_lang_select, names='value')
dpd_kind.observe(on_kind_select, names='value')

# Layout and Style
# Top row: info button | conversion kind | language.
applayout1 = AppLayout(left_sidebar=btn_doc,
                       center=dpd_kind,
                       right_sidebar=dpd_lang)
# Main area: status on top, input and generated image side by side, upload button at the bottom.
applayout2 = AppLayout(header=btn_status,
                       left_sidebar=input_img,
                       right_sidebar=output,
                       footer=btn_upload)

displayReadyness()

display(VBox([btn_header, applayout1, applayout2]))

VBox(children=(Button(description='Simpsonizer', disabled=True, layout=Layout(height='auto', width='auto'), st…

In [5]:
# Leftover inspection cell: echoes the current value of the module-level `kind` selection.
kind

'simpsonize'