<a href="https://colab.research.google.com/github/MiraPurkrabek/vs3-cnn-labs/blob/main/Mira_Beer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install kornia timm

In [None]:
from fastai.vision.all import *

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import cv2
import seaborn as sns
import torch
torch.set_num_threads(1)
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
import kornia as K
from tqdm import tqdm_notebook as tqdm
from time import time

def imshow_torch(tensor, figsize=(8, 6), *args, **kwargs):
    """Show an image tensor in a new matplotlib figure.

    The tensor is converted with ``K.tensor_to_image`` and then drawn with
    ``plt.imshow``.

    Args:
        tensor: image tensor accepted by ``K.tensor_to_image``
            (presumably (C, H, W) or a single-image batch — TODO confirm).
        figsize: figure size passed to ``plt.figure``.
        *args: extra positional arguments forwarded to ``plt.imshow``.
        **kwargs: extra keyword arguments forwarded to ``plt.imshow``
            (e.g. ``cmap='gray'``, ``vmin``, ``vmax``).
    """
    plt.figure(figsize=figsize)
    # Forward both positional and keyword options; a lone ``*name`` would
    # reject keyword arguments such as ``cmap=...`` with a TypeError.
    plt.imshow(K.tensor_to_image(tensor), *args, **kwargs)

def imshow_torch_channels(tensor, dim=1, *args, **kwargs):
    """Plot every slice of ``tensor`` along ``dim`` side by side.

    The tensor is split into size-1 slices along ``dim``. Each slice is
    squeezed, converted with ``K.tensor_to_image`` and drawn in its own
    subplot of a single-row figure that is 5 inches wide per slice.

    Args:
        tensor: tensor to visualise (presumably (B, C, H, W) when splitting
            channels with the default ``dim=1`` — TODO confirm).
        dim: dimension to split along; one subplot is drawn per index.
        *args: extra positional arguments forwarded to ``plt.imshow``.
        **kwargs: extra keyword arguments forwarded to ``plt.imshow``
            (e.g. ``cmap='gray'``).
    """
    num_ch = tensor.size(dim)
    fig = plt.figure(figsize=(num_ch * 5, 5))
    # Subplot indices are 1-based, hence start=1.
    for idx, channel in enumerate(torch.split(tensor, 1, dim=dim), start=1):
        fig.add_subplot(1, num_ch, idx)
        plt.imshow(K.tensor_to_image(channel.squeeze(dim)), *args, **kwargs)

In [None]:
!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz
!tar -xzf imagenette2-160.tgz

In [None]:
# Build DataLoaders from the extracted folder; images under 'val' form the
# validation set. Each item is randomly resize-cropped to 128 px (keeping at
# least 35% of the image), and batches are normalised with ImageNet statistics.
dls = ImageDataLoaders.from_folder('imagenette2-160/',
                                   valid='val', 
                                   item_tfms=RandomResizedCrop(128, min_scale=0.35),
                                   batch_tfms=Normalize.from_stats(*imagenet_stats))

# Passing the architecture as a string makes fastai build it through timm
# (when timm is installed); error_rate is reported after every epoch.
learn = vision_learner(dls, 'resnet18',  metrics=error_rate)


In [None]:
# Preview a batch of training images together with their labels.
dls.show_batch()

In [None]:
# Run fastai's learning-rate finder and plot loss vs. learning rate.
# NOTE(review): the suggested value is not captured; the next cell hard-codes
# base_lr=2e-4 — presumably chosen by reading this plot, confirm.
learn.lr_find()

In [None]:
# Fine-tune: 2 epochs with the pretrained body frozen (freeze_epochs=2), then
# 5 epochs with all layers trainable, starting from base_lr=2e-4.
# ShowGraphCallback plots the training/validation loss live during training.
learn.fine_tune(5,base_lr = 2e-4, freeze_epochs=2, cbs=[ShowGraphCallback()])

# Task

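Beer vs. trdelník recognition: train a binary image classifier on the images under `data/`. Each image is labelled by the name of the sub-folder it is stored in.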

In [None]:
from fastai.data.all import *
from fastai.vision.all import *

# List every image file under data/ (one sub-folder per class is expected).
# NOTE(review): fnames is not referenced by the later cells shown; the DataBlock
# below calls get_image_files itself — kept only for interactive inspection?
fnames = get_image_files('data/')

def label_func(fname):
    """Return the class label of an image: the name of its parent folder.

    The dataset is laid out as ``data/<label>/<image>``, so
    ``data/beer/img.jpg`` gives ``'beer'``. Taking the parent directory's name
    (the same rule as fastai's ``parent_label``) instead of splitting the
    string on ``'/'`` and picking index 1 also works when the data root is an
    absolute or deeper path (e.g. ``/content/data/beer/img.jpg``), and it
    works with backslash separators on Windows.

    Args:
        fname: path to an image file, as ``str`` or ``pathlib.Path``.

    Returns:
        Name of the directory that directly contains ``fname``.
    """
    from pathlib import PurePath  # local import keeps this cell self-contained

    return PurePath(fname).parent.name

In [None]:
# Describe the beer/trdelník dataset: image inputs, one category per image.
# Items are all image files under the root passed to datasets()/dataloaders(),
# labels come from label_func, and every image is resized to 224 px.
# The random 80/20 train/valid split is seeded. Without a seed, `dsets` below
# and the DataLoaders later built from this same DataBlock would get different
# partitions, and reruns would not be reproducible.
dblock = DataBlock(blocks    = (ImageBlock, CategoryBlock),
                   get_items = get_image_files,
                   get_y     = label_func,
                   splitter  = RandomSplitter(valid_pct=0.2, seed=42),
                   item_tfms = Resize(224))

dsets = dblock.datasets('data')
# Show the first training sample as an (image, category) pair.
dsets.train[0]

In [None]:
# Class names (category vocabulary) collected from the labels returned by label_func.
dsets.vocab

In [None]:
# Build train/valid DataLoaders from the DataBlock with a batch size of 4.
# NOTE(review): bs=4 is very small — presumably because the dataset is tiny; confirm.
dls = dblock.dataloaders("data", bs=4)
# Preview a batch. NOTE(review): with bs=4 presumably at most 4 images appear,
# even though max_n=9.
dls.show_batch(max_n=9, figsize=(4,4))

In [None]:
# Classifier built from an xresnet34 created without pretrained=True, so it is
# trained from scratch. The output size comes from the data's vocabulary
# (dls.c) instead of a hard-coded 2, so the head always matches the number of
# labels label_func produces.
learn = Learner(dls, xresnet34(n_out=dls.c), metrics=accuracy)

In [None]:
# Train with fastai's fine_tune: a freeze phase, then 10 epochs with all layers trainable.
# NOTE(review): the model above is not pretrained, so the freeze phase presumably
# brings no transfer-learning benefit; consider learn.fit_one_cycle(...) or a
# pretrained vision_learner — confirm intent.
learn.fine_tune(10)

In [None]:
# Collect predictions, targets and per-item losses (validation set by default).
# NOTE(review): preds, y and losses are not used in the cells shown, and
# from_learner below presumably runs its own inference pass — this call may be redundant.
preds, y, losses = learn.get_preds(with_loss=True)
# Build the interpretation object and plot the confusion matrix.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

In [None]:
# Show sample images with their target and predicted labels.
learn.show_results()

In [None]:
# Classify a single external image with the trained model.
# PILImage.create opens the file and converts it to RGB (PILImage's declared
# mode), so grayscale or RGBA files (e.g. PNGs with an alpha channel) still
# reach the model with 3 channels. load_image called without a mode keeps the
# file's own mode.
img = PILImage.create('tr1.jpg')
# learn.predict returns (decoded label, predicted class index, class probabilities).
pred = learn.predict(img)