In [22]:
from functools import partial
from io import BytesIO

import numpy as np
import torch
from captum.attr import Saliency, DeepLift
from captum.attr import visualization as viz
from PIL import Image
from torchvision.io import read_image
from torchvision.io import ImageReadMode
from torchvision.models.efficientnet import efficientnet_b0, EfficientNet_B0_Weights
#from torchvision.models.vgg import vgg13, VGG13_Weights
from torchvision.transforms import Normalize
from torchvision.transforms._presets import ImageClassification, InterpolationMode

%matplotlib inline

# Network builder and the pretrained weights it is instantiated with.
# Bind a concrete weights member (`.DEFAULT`) rather than the enum class itself:
# the builder's `weights` argument expects a member, and handing it the class
# only works through torchvision's deprecated legacy "pretrained" handling.
net_weights = EfficientNet_B0_Weights.DEFAULT  # alternative: EfficientNet_B6_Weights.DEFAULT
net = efficientnet_b0  # alternative: efficientnet_b6

class EfficientNet():
    '''Pretrained image classifier with a Captum saliency overlay.

    Usage: call ``process_image(path)`` first, then ``pass_image_to_net()``
    and/or ``get_predictions()``.

    Attributes:
        preprocess: torchvision ``ImageClassification`` preset (resize 256, crop 256, bicubic).
        model: network built by the module-level ``net`` builder, switched to eval mode.
        saliency: Captum ``Saliency`` instance wrapping ``model``.
        preprocessed_image: batched input tensor of the last processed image, or None.
        original_image: PIL rendering of the last processed image, or None.
        prediction: class index of the last ``pass_image_to_net()`` call, or None.
    '''

    def __init__(self) -> None:
        # Use pretrained efficient net as default.
        #TODO: ASYNC
        # `net_weights` may be a weights enum class or a single member of it;
        # the builder wants a member, so pick the class' DEFAULT when needed.
        weights = net_weights.DEFAULT if isinstance(net_weights, type) else net_weights
        # Custom preset instead of weights.transforms(), to avoid zooming and cropping.
        # Reference transform: https://pytorch.org/vision/main/models/generated/torchvision.models.efficientnet_b0.html#torchvision.models.EfficientNet_B0_Weights
        self.preprocess = ImageClassification(
            crop_size=256,
            resize_size=256,
            interpolation=InterpolationMode.BICUBIC,
        )
        model = net(weights=weights)
        model.eval()
        self.model = model
        self.saliency = Saliency(self.model)
        # Filled in by process_image() / pass_image_to_net().
        self.preprocessed_image = None
        self.original_image = None
        self.prediction = None

    def pass_image_to_net(self):
        '''Classify the processed image and build its explanation overlay.

        Stores the predicted class index in ``self.prediction``.

        :return: PIL Image of the original image with the saliency marks composited on top
        :raises RuntimeError: if ``process_image`` has not been called yet
        '''
        # Make Prediction
        probabilities = self._predict_probabilities()
        class_id = probabilities.argmax().item()
        self.prediction = class_id

        # Make images
        #TODO: ASYNC
        attribution = self._get_explanation(class_id, self.preprocessed_image)
        ki_image = self._get_attribution_image(self.preprocessed_image, attribution, self.original_image)

        return ki_image

    def get_predictions(self):
        '''Return the (up to) three most probable classes of the processed image.

        :return: tuple ``(indices, probabilities)``; ``indices`` is a list of class ids
                 ordered by ascending probability (the last entry is the best guess) and
                 ``probabilities`` is a tensor with the matching softmax values
        :raises RuntimeError: if ``process_image`` has not been called yet
        '''
        probabilities = self._predict_probabilities()
        top_values, top_indices = probabilities.topk(min(3, probabilities.numel()))
        # topk() yields best-first; flip so the best class stays at the last position.
        return top_indices.flip(0).tolist(), top_values.flip(0)

    def process_image(self, image_path:str):
        '''Load an image file, preprocess it for the network and render a display copy.

        :param image_path: path of the image file to read
        :return: PIL Image of the preprocessed (resized/cropped) picture
        '''
        # Force three channels so grayscale or RGBA files match the 3-channel normalisation.
        img = read_image(image_path, mode=ImageReadMode.RGB)
        self.preprocessed_image = self.preprocess(img).unsqueeze(0)

        self.original_image = self._get_original_image(self.preprocessed_image)
        return self.original_image

    def _predict_probabilities(self):
        '''Run the model on the processed image and return the softmax over classes (1-D tensor).'''
        if getattr(self, 'preprocessed_image', None) is None:
            raise RuntimeError("No image loaded: call process_image() first.")
        # Plain inference, no autograd graph needed (Saliency runs its own pass).
        with torch.no_grad():
            return self.model(self.preprocessed_image).squeeze(0).softmax(0)

    @staticmethod
    def _to_hwc_numpy(batched_tensor):
        '''Drop the batch axis of a (1, C, H, W) tensor and return it as an (H, W, C) numpy array.'''
        return np.transpose(batched_tensor.squeeze(0).cpu().detach().numpy(), (1,2,0))

    def _get_original_image(self, preprocessed_image)-> Image.Image:
        '''Private function that returns the original but scaled image'''
        # Undo the normalisation so the colours are displayable
        scaled_image = self._scale_image(preprocessed_image)

        # Use Captum for visualisation (no attribution needed for 'original_image')
        fig, _ = viz.visualize_image_attr(None,
                                    self._to_hwc_numpy(scaled_image),
                                    method='original_image',
                                    show_colorbar=False,
                                    use_pyplot=False)

        # Get a PIL image
        return self._fig2img(fig)

    def _get_attribution_image(self, preprocessed_image, attribution, original_image)-> Image.Image:
        '''Private function that returns the explanation image.

        Renders the attribution as a Captum heat map, reduces it to a binary blue
        mask and composites that mask over ``original_image``.
        '''
        # Use Captum for visualisation
        fig, _ = viz.visualize_image_attr(attribution,
                             self._to_hwc_numpy(preprocessed_image),
                             method='heat_map',
                             show_colorbar=False,
                             sign='absolute_value',
                             alpha_overlay=0.5,
                             outlier_perc=1,
                             use_pyplot=False)

        overlay = self._fig2img(fig)
        overlay = self._threshold_blue_channel(overlay) #looks cleaner, accuracy is not relevant for this purpose here.

        # Image.composite needs both images in the same mode and size.
        background = original_image.convert("RGBA")
        if overlay.size != background.size:
            overlay = overlay.resize(background.size)

        t_image = Image.composite(overlay, background, mask=overlay) #combine with original image
        t_image.show()  # PIL's Image.show(): displays the result in an external viewer

        return t_image

    def _get_explanation(self, id, preprocessed_image):
        '''Get the attribution map for class index ``id`` as an (H, W, C) numpy array'''
        attribution = self.saliency.attribute(preprocessed_image, target=id)
        return self._to_hwc_numpy(attribution)

    def _threshold_blue_channel(self,image, threshold = 230)->Image.Image:
        '''Calculate the blue marks for the AI recognition.
        This is not accurate, but the outcome looks way nicer and cleaner.
        For the purpose of visualisation it is perfectly fine
        (Saliency is not reliable anyways)

        :param image: rendered heat map (at least three bands)
        :param threshold: blue-channel value above which a pixel is left unmarked
        :return: RGBA image that is opaque pure blue where the input's blue
                 channel is <= threshold and fully transparent black elsewhere
        '''
        bands = image.split()
        # only blue pixels thus set green and red to 0
        red_band = bands[0].point(lambda _: 0)
        green_band = bands[1].point(lambda _: 0)
        # mark a pixel (255) when its blue value is at or below the threshold,
        # clear it (0) when it is above
        blue_band = bands[2].point(lambda value: 0 if value > threshold else 255)

        # The blue band doubles as alpha, so unmarked pixels become transparent.
        return Image.merge("RGBA", (red_band, green_band, blue_band, blue_band))

    def _scale_image(self,preprocessed_image):
        '''Make image look correctly by inverting the input normalisation.

        Normalize computes (x - mean) / std; with mean = -m/s and std = 1/s this
        yields x * s + m, i.e. the inverse of normalising with mean m and std s.
        '''
        inv_normalize = Normalize(
            mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
            std=[1/0.229, 1/0.224, 1/0.225]
        )
        return inv_normalize(preprocessed_image)

    def _fig2img(self, figure, use_buf = True):
        '''Convert a matplotlib figure into a PIL image.

        TODO: Remove, for testing purpose only. Decide for a method!

        :param figure: matplotlib figure to render
        :param use_buf: if True, save the figure as a tightly cropped PNG into memory and
                        open that; otherwise read the pixels straight from the canvas
        '''
        if use_buf:
            buf = BytesIO()
            figure.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            buf.seek(0)
            return Image.open(buf)

        canvas = figure.canvas
        canvas.draw()  # the pixel buffer is only valid after a draw
        # NOTE(review): this path needs a raster (Agg) canvas; tostring_rgb() is
        # gone from newer Matplotlib releases, so prefer buffer_rgba() when present.
        if hasattr(canvas, 'buffer_rgba'):
            rgba = Image.frombytes('RGBA', canvas.get_width_height(), bytes(canvas.buffer_rgba()))
            return rgba.convert('RGB')
        return Image.frombytes('RGB', canvas.get_width_height(), canvas.tostring_rgb())

In [2]:
# German display names for the classifier's output classes; later cells look a
# name up by class index (`cat[model.prediction]`, `cat[pred[i]]`).
# NOTE(review): presumably the ImageNet-1k labels in torchvision's class order —
# TODO confirm length and ordering against the weights' meta["categories"].
# NOTE(review): in this export several string literals are broken across physical
# lines (e.g. "Amerikanischer" / "Hummer"); verify the literal is intact in the notebook.
cat = ["Schleie","Goldfisch","Weißer Hai","Tigerhai","Hammerhai","Zitterrochen","Stachelrochen","Hahn","Henne","Strauß","Bergfink","Stieglitz","Hausfink","Wacholderdrossel","Indigofink","Rotkehlchen","Bulle","Eichelhäher","Elster","Meise","Wasseramsel","Milan","Weißkopfseeadler","Geier","Steinkauz","Feuersalamander","Molch","Eidechsenmolch","Tüpfelsalamander","Axolotl","Ochsenfrosch","Laubfrosch","Schwanzfrosch","Unechter Kammmolch","Lederschildkröte","Sumpfschildkröte","Sumpfschildkröte","Dosenschildkröte","Bändergecko","Gewöhnlicher Leguan","Amerikanisches Chamäleon","Peitschenschwanz","Agama","Kragenechse","Alligatorenechse","Gila-Monster","Grüne Echse","Afrikanisches Chamäleon","Komodowaran","Afrikanisches Krokodil","Amerikanischer Alligator","Triceratops","Donnerschlange","Ringelnatter","Hognoseschlange","Grüne Schlange","Königsnatter","Strumpfbandnatter","Wasserschlange","Weinschlange","Nachtschlange","Boa constrictor","Felsenpython","Indische Kobra","Grüne Mamba","Seeschlange","Hornviper","Diamantrücken","Sidewinder","Trilobit","Weberknecht","Skorpion","Schwarzgoldspinne","Scheunenspinne","Gartenkreuzspinne","Schwarze Witwe","Vogelspinne","Wolfsspinne","Zecke","Tausendfüßler","Birkhuhn","Schneehuhn","Kragenhuhn","Präriehuhn","Pfau","Wachtel","Rebhuhn","Graupapagei","Ara","Gelbhaubenkakadu","Lories","Kuckuck","Bienenfresser","Nashornvogel","Kolibri","Jacamar","Tukan","Erpel","Mittelsäger","Gans","Schwarzer Schwan","Stoßzahn","Schnabeligel","Schnabeltier","Wallaby","Koala","Wombat","Qualle","Seeanemone","Koralle","Plattwurm","Fadenwurm","Muschel","Schnecke","Nacktschnecke","Meeresschnecke","Käfer","gekammerte Nautilus","Dungeness Crab","Felsenkrabbe","Geißelkrebs","Königskrabbe","Amerikanischer 
Hummer","Languste","Flusskrebs","Einsiedlerkrebs","Assel","Weißstorch","Schwarzstorch","Löffler","Flamingo","Blaureiher","Silberreiher","Rohrdommel","Kranich","Limikolen","Blässhühner","Trappen","Steinwälzer","Bruchwasserläufer","Rotrückenstrandläufer","Rotschenkel","Wasserläufer","Austernfischer","Pelikan","Königspinguin","Albatros","Grauwal","Schwertwal","Seekuh","Seelöwe","Chihuahua","Japanischer Spaniel","Malteser","Pekinese","Shih-Tzu","Blenheim Spaniel","Papillon","Toy Terrier","Rhodesian Ridgeback","Afghanischer Windhund","Basset","Beagle","Bloodhound","Blaugetickter Coonhound","Schwarz und Tan Coonhound","Walker Coonhound","English foxhound","redbone","borzoi","Irischer Wolfshund","Italienischer Windhund","whippet","Ibizan hound","Norwegian elkhound","Otterhound","Saluki","Schottischer Hirschhund","Weimaraner","Staffordshire Bullterrier","American Staffordshire terrier","Bedlington terrier","Border terrier","Kerry Blue Terrier","Irischer Terrier","Norfolk Terrier","Norwich Terrier","Yorkshire Terrier","Drahthaar-Foxterrier","Lakeland Terrier","Sealyham Terrier","Airedale","Cairn","Australian Terrier","Dandie Dinmont","Boston Bull","Zwergschnauzer","Riesenschnauzer","Standardschnauzer","Schottischer Terrier","Tibet Terrier","Silky Terrier","Soft Coated Wheaten Terrier","West Highland White Terrier","Lhasa","Flat-Coated Retriever","Gelockter Retriever","Golden Retriever","Labrador Retriever","Chesapeake Bay Retriever","Deutscher Kurzhaarvorsteher","Vizsla","English Setter","Irish Setter","Gordon Setter","Bretonischer Spaniel","Clumber","English Springer","Welsh Springer Spaniel","Cocker Spaniel","Sussex Spaniel","Irischer Wasserspaniel","Kuvasz","Schipperke","Groenendael","Malinois","Briard","Kelpie","Komondor","Old English Shepherd Dog","Shetland-Schäferhund","Collie","Border Collie","Bouvier des Flandres","Rottweiler","Deutscher Schäferhund","Dobermann","Zwergpinscher","Großer Schweizer Sennenhund","Berner 
Sennenhund","Appenzeller","EntleBucher","Boxer","Bullmastiff","Tibetischer Mastiff","Französische Bulldogge","Deutsche Dogge","Bernhardiner","Eskimohund","Malamute","Sibirischer Husky","Dalmatiner","Affenpinscher","Basenji","Mops","Leonberg","Neufundländer","Großer Pyrenäenhund","Samojede","Zwergspitz","Chow","Keeshond","Brabancon griffon","Pembroke","Cardigan","Zwergpudel","Zwergpudel","Standardpudel","Mexikanischer Nackthund","Timberwolf","Weißer Wolf","Roter Wolf","Kojote","Dingo","Dhole","Afrikanischer Jagdhund","Hyäne","Rotfuchs","Rotfuchs","Polarfuchs","Graufuchs","getigerte Katze","Tigerkatze","Perserkatze","Siamkatze","Ägyptische Katze","Puma","Luchs","Leopard","Schneeleopard","Jaguar","Löwe","Tiger","Gepard","Braunbär","Amerikanischer Schwarzbär","Eisbär","Faultierbär","Mungo","Erdmännchen","Tigerkäfer","Marienkäfer","Laufkäfer","Bockkäfer","Blattkäfer","Mistkäfer","Nashornkäfer","Rüsselkäfer","Fliege","Biene","Ameise","Heuschrecke","Grille","Stabheuschrecke","Kakerlake","Gottesanbeterin","Zikade","Blattlaus","Florfliege","Libelle","Kleinlibelle","Admiral","Ringeltaube","Monarch","Kohlweißling","Schwebfalter","Bläuling","Seestern","Seeigel","Seegurke","Waldkaninchen","Hase","Angorakaninchen","Hamster","Stachelschwein","Fuchshörnchen","Murmeltier","Biber","Meerschweinchen","Sauerampfer","Zebra","Schwein","Wildschwein","Warzenschwein","Nilpferd","Ochse","Wasserbüffel","Bison","Widder","Dickhorn","Steinbock","Antilope","Impala","Gazelle","arabisches Kamel","Lama","Wiesel","Nerz","Iltis","Schwarzfußfrettchen","Otter","Stinktier","Dachs","Gürteltier","Dreifingerfaultier","Orang-Utan","Gorilla","Schimpanse","Gibbon","Siamang","Meerkatze","Patasaffe","Pavian","Makake","Langur","Colobus","Nasenaffe","Totenkopfäffchen","Kapuzineraffe","Brüllaffe","Titi","Klammeraffe","Totenkopfäffchen","Madagaskarkatze","Indri","Indischer Elefant","Afrikanischer Elefant","Kleiner 
Panda","Riesenpanda","Barrakuda","Aal","Coho","Riffbarsch","Anemonenfisch","Stör","Hornhecht","Rotfeuerfisch","Kugelfisch","Abakus","Abaya","Doktorkittel","Akkordeon","Akustikgitarre","Flugzeugträger","Verkehrsflugzeug","Luftschiff","Altar","Krankenwagen","Schwimmfahrzeug","Analoguhr","Bienenhaus","Schürze","Aschenbecher","Sturmgewehr","Rucksack","Bäckerei","Schwebebalken","Luftballon","Kugelschreiber","Pflaster","Banjo","Treppengeländer","Langhantel","Friseurstuhl","Friseursalon","Scheune","Barometer","Fass","Schubkarre","Baseball","Basketball","Wiege","Fagott","Badekappe","Badetuch","Badewanne","Strandwagen","Leuchtturm","Becher","Bärenfell","Bierflasche","Bierglas","Glockenstuhl","Lätzchen","Zweirad","Bikini","Ordner","Fernglas","Vogelhaus","Bootshaus","Bob","Bolo-Krawatte","Haube","Bücherregal","Buchhandlung","Flaschenverschluss","Bogen","Fliege","Messing","Büstenhalter","Wellenbrecher","Brustpanzer","Besen","Eimer","Schnalle","kugelsichere Weste","Schnellzug","Metzgerei","Taxi","Kessel","Kerze","Kanone","Kanu","Dosenöffner","Strickjacke","Autospiegel","Karussell","Tischlerausrüstung","Karton","Autorad","Geldautomat","Kassette","Kassettenspieler","Schloss","Katamaran","CD-Spieler","Cello","Mobiltelefon","Kette","Maschendrahtzaun","Kettenhemd","Kettensäge","Truhe","Wäscheschrank","Glockenspiel","Porzellanschrank","Weihnachtsstrumpf","Kirche","Kino","Hackbeil","Felsenwohnung","Mantel","Klotz","Cocktailshaker","Kaffeebecher","Kaffeekanne","Spule","Kombinationsschloss","Computertastatur","Süßwaren","Containerschiff","Cabrio","Korkenzieher","Kornett","Cowboystiefel","Cowboyhut","Wiege","Kran","Sturzhelm","Kiste","Krippe","Crock Pot (elektrischer Kochtopf)","Krocketball","Krücke","Kürass","Damm","Schreibtisch","Tischcomputer","Wählscheibentelefon","Windel","Digitaluhr","Digitalwecker","Esstisch","Geschirrtuch","Spülmaschine","Scheibenbremse","Dock","Hundeschlitten","Kuppel","Fußmatte","Bohrinsel","Trommel","Trommelstock","Hantel","Holländischer Ofen","elektrischer 
Ventilator","elektrische Gitarre","elektrische Lokomotive","Unterhaltungszentrum","Briefumschlag","Espressomaschine","Gesichtspuder","Federboa","Datei","Feuerschiff","Feuerwehrauto","Feuerschirm","Fahnenmast","Flöte","Klappstuhl","Fußballhelm","Gabelstapler","Brunnen","Füllfederhalter","Himmelbett","Güterwaggon","Waldhorn","Bratpfanne","Pelzmantel","Müllwagen","Gasmaske","Zapfsäule","Pokal","Go-Kart","Golfball","Golfwagen","Gondel","Gong","Kleid","Flügel","Gewächshaus","Gitter","Lebensmittelgeschäft","Guillotine","Haarnadel","Haarspray","Halbspur","Hammer","Wäschekorb","Handgebläse","Handcomputer","Taschentuch","Festplatte","Mundharmonika","Harfe","Mähdrescher","Beil","Holster","Heimkino","Honigwabe","Haken","Reifrock","Reck","Pferdewagen","Sanduhr","iPod","Bügeleisen","Kürbislaterne","Jeans","Jeep","Trikot","Puzzle","Rikscha","Joystick","Kimono","Knieschoner","Knoten","Laborkittel","Schöpfkelle","Lampenschirm","Laptop","Rasenmäher","Linsenkappe","Brieföffner","Bibliothek","Rettungsboot","Feuerzeug","Limousine","Liner","Lippenstift","Slipper","Lotion","Lautsprecher","Lupe","Holzmühle","Magnetkompass","Postsack","Briefkasten","Badeanzug","Badeanzug","Gullydeckel","Maracas","Marimba","Maske","Streichholz","Maibaum","Labyrinth","Messbecher","Hausapotheke","Megalith","Mikrofon","Mikrowelle","Militäruniform","Milchkanne","Minibus","Minirock","Minivan","Rakete","Fäustling","Rührschüssel","Wohnmobil","Model T","Modem","Kloster","Monitor","Moped","Mörser","Mörtelbrett","Moschee","Moskitonetz","Motorroller","Mountainbike","Bergzelt","Maus","Mausefalle","Umzugswagen","Schnauze","Nagel","Halskrause","Halskette","Brustwarze","Notizbuch","Obelisk","Oboe","Okarina","Tachometer","Ölfilter","Orgel","Oszilloskop","Überrock","Ochsenkarren","Sauerstoffmaske","Paket","Paddel","Schaufelrad","Vorhängeschloss","Pinsel","Schlafanzug","Palast","Panflöte","Papierhandtuch","Fallschirm","Barren 
(Turngerät)","Parkbank","Parkuhr","Pkw","Terrasse","Münztelefon","Sockel","Bleistiftkasten","Bleistiftspitzer","Parfüm","Petrischale","Fotokopierer","Plektrum","Pickelhaube","Lattenzaun","Pickup","Pier","Sparschwein","Pillenflasche","Kissen","Tischtennisball","Windrad","Pirat","Krug","Flugzeug","Planetarium","Plastiktüte","Geschirrablage","Pflug","Saugglocke","Polaroidkamera","Mast","Polizeiwagen","Poncho","Billardtisch","Flasche","Topf","Töpferscheibe","Bohrmaschine","Gebetsteppich","Drucker","Gefängnis","Projektil","Projektor","Puck","Boxsack","Geldbörse","Feder","Steppdecke","Rennwagen","Schläger","Heizkörper","Radio","Radioteleskop","Regentonne","Wohnmobil","Spule","Spiegelreflexkamera","Kühlschrank","Fernbedienung","Restaurant","Revolver","Gewehr","Schaukelstuhl","Drehspieß","Radiergummi","Rugbyball","Regel","Laufschuh","Safe","Sicherheitsnadel","Salzstreuer","Sandale","Sarong","Saxophon","Schwertscheide","Waage","Schulbus","Schoner","Anzeigetafel","Bildschirm","Schraube","Schraubenzieher","Sitzgurt","Nähmaschine","Schild","Schuhgeschäft","Shoji","Einkaufskorb","Einkaufswagen","Schaufel","Duschhaube","Duschvorhang","Ski","Skimaske","Schlafsack","Rechenschieber","Schiebetür","Schlitz","Schnorchel","Schneemobil","Schneepflug","Seifenspender","Fußball","Socke","Solarschüssel","Sombrero","Suppenschüssel","Theke","Heizlüfter","Raumschiff","Spachtel","Schnellboot","Spinnennetz","Spindel","Sportwagen","Scheinwerfer","Bühne","Dampflokomotive","Stahlbogenbrücke","Stahltrommel","Stethoskop","Stola","Steinmauer","Stoppuhr","Herd","Sieb","Straßenbahn","Trage","Sofa","Stupa","U-Boot","Anzug","Sonnenuhr","Sonnenbrille","Sonnenbrille","Sonnencreme","Hängebrücke","Tupfer","Sweatshirt","Badehose","Schaukel","Schalter","Spritze","Tischlampe","Panzer","Kassettenspieler","Teekanne","Teddy","Fernseher","Tennisball","Strohdach","Theatervorhang","Fingerhut","Mähdrescher","Thron","Ziegeldach","Toaster","Tabakladen","Toilettensitz","Fackel","Totempfahl","Abschleppwagen","Spielzeugladen
","Traktor","Sattelschlepper","Tablett","Trenchcoat","Dreirad","Trimaran","Dreibein","Triumphbogen","Trolleybus","Posaune","Wanne","Drehkreuz","Tastatur","Regenschirm","Einrad","Ständer","Staubsauger","Vase","Tresor","Samt","Verkaufsautomat","Gewand","Viadukt","Geige","Volleyball","Waffeleisen","Wanduhr","Brieftasche","Kleiderschrank","Kampfflugzeug","Waschbecken","Waschmaschine","Wasserflasche","Wasserkanne","Wasserturm","Whiskeykanne","Pfeife","Perücke","Fensterscheibe","Fensterladen","Windsor-Krawatte","Weinflasche","Flügel","Wok","Holzlöffel","Wolle","Wurmzaun","Wrack","Jolle","Jurte","Website","Comic","Kreuzworträtsel","Straßenschild","Ampel","Buchumschlag","Speisekarte","Teller","Guacamole","Consomme","Hot Pot","Trifle","Eiscreme","Eislutscher","Baguette","Bagel","Brezel","Cheeseburger","Hotdog","Kartoffelpüree","Kopfkohl","Brokkoli","Blumenkohl","Zucchini","Spaghettikürbis","Eichelkürbis","Butternusskürbis","Gurke","Artischocke","Paprika","Kardone","Pilz","Granny Smith","Erdbeere","Orange","Zitrone","Feige","Ananas","Banane","Jackfrucht","Zimtapfel","Granatapfel","Heu","Carbonara","Schokoladensauce","Teig","Hackbraten","Pizza","Auflauf","Burrito","Rotwein","Espresso","Tasse","Eierlikör","Alm","Blase","Klippe","Korallenriff","Geysir","Seeufer","Landzunge","Sandbank","Küste","Tal","Vulkan","Ballspieler","Bräutigam","Taucher","Raps","Gänseblümchen","Gelber Frauenschuh (Orchidee)","Mais","Eichel","Hüfte","Rosskastanie","Korallenpilz","Fliegenpilz","Bischofsmütze (Pilz)","Stinkmorchel","Erdstern"," Schwefelporlinge","Steinpilz"," Ohr","Toilettenpapier"]

In [29]:
# Classify one sample image: print the top-1 label, then the three most
# probable classes (best first) together with their probability in percent.
image="bulb2"
path = f"../assets/images/{image}.jpg"

model = EfficientNet()
model.process_image(path)
model.pass_image_to_net()
print(cat[model.prediction])

pred,p = model.get_predictions()
# get_predictions() orders its result by ascending probability, so index 2 is
# the most likely class; walk the indices backwards to list the best one first.
print("\n".join(f"{cat[pred[i]]}, {p[i]*100}%" for i in (2, 1, 0)))

Goldfisch
Goldfisch, 96.4022445678711%
Schleie, 0.15431037545204163%
Riffbarsch, 0.11850397288799286%


In [24]:
# Run the full pipeline on `path` and print the German label of the top-1 class.
# NOTE(review): `path` is not assigned in this cell; it comes from whichever cell
# set it last. The execution count (In [24]) is lower than that of the cell
# defining `path` above (In [29]), so the stored output presumably belongs to a
# different image — re-run top to bottom to confirm.
model = EfficientNet()
model.process_image(path)  # must come first: pass_image_to_net() reads the preprocessed image
model.pass_image_to_net()  # stores the predicted class index in model.prediction
print(cat[model.prediction])

Granny Smith
