In this kernel, I use the Tez library to build a custom model based on EfficientNet-B4. Special thanks to [Abhishek Thakur](http://www.kaggle.com/abhishek) for the tutorial. I still need to understand quite a lot of it, though 😊

In [None]:
#installing libraries: tez (used below for the dataset, model base class and callbacks)
#and efficientnet_pytorch (used below for the pretrained EfficientNet backbone)
!pip install tez
!pip install efficientnet_pytorch

In [None]:
#importing packages  
import os
import albumentations
import matplotlib.pyplot as plt
import pandas as pd

import tez
from tez.datasets import ImageDataset
from tez.callbacks import EarlyStopping

import torch 
import torch.nn as nn
import torchvision
from torch.nn import functional as F

from sklearn import metrics , model_selection
from efficientnet_pytorch import EfficientNet

%matplotlib inline

In [None]:
#reading train.csv; its image_id and label columns are used in the cells below
dfx = pd.read_csv("../input/cassava-leaf-disease-classification/train.csv")

In [None]:
#showing the first five rows of the train data file
dfx.head()

In [None]:
#number of rows for each unique label, to see how balanced the classes are
dfx.label.value_counts()

In [None]:
#stratified 90/10 split of the data into train and valid (fixed seed);
#both parts get a fresh 0..n-1 index
df_train, df_valid = (
    part.reset_index(drop=True)
    for part in model_selection.train_test_split(
        dfx,
        test_size=0.1,
        random_state=2020,
        stratify=dfx.label.values,
    )
)

In [None]:
#(rows, columns) of the train and valid splits
df_train.shape , df_valid.shape

In [None]:
#building the full file path of every train / valid image from its image_id
image_path = "../input/cassava-leaf-disease-classification/train_images"

train_image_paths = [
    os.path.join(image_path, file_name) for file_name in df_train.image_id.values
]
valid_image_paths = [
    os.path.join(image_path, file_name) for file_name in df_valid.image_id.values
]

In [None]:
#peek at the first five train image paths
train_image_paths[:5]

In [None]:
#storing the label values, in the same row order as the image paths built above
train_targets = df_train.label.values
valid_targets = df_valid.label.values

In [None]:
#wrapping the train images and labels in a tez ImageDataset,
#without any augmentation for now
train_dataset = ImageDataset(
    image_paths=train_image_paths,
    targets=train_targets,
    augmentations=None,
)

In [None]:
#function to plot one dataset sample
def plot_image(image_dict):
    """Print a sample's target and draw its image on a new 5x5 figure.

    image_dict is a dict with an "image" tensor and a "targets" entry.
    NOTE(review): assumes the tensor is channel-first with 0-255 values,
    as the permute and the division by 255 suggest; confirm against ImageDataset.
    """
    pixels, label = image_dict["image"], image_dict["targets"]
    plt.figure(figsize=(5, 5))
    print(label)
    # move the first axis to the end and rescale before handing over to imshow
    plt.imshow(pixels.permute(1, 2, 0) / 255)

In [None]:
#plot the sample at index 3 before any augmentation is applied
plot_image(train_dataset[3])

In [None]:
#applying augmentation to train and valid data

#train: a random crop plus random transpose / flips to diversify the training images
train_aug = albumentations.Compose(
    [
        # RandomResizedCrop already outputs 256x256, so a following
        # Resize(256, 256) would be a no-op and is left out
        albumentations.RandomResizedCrop(256, 256),
        albumentations.Transpose(p=0.5),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
    ]
)

#valid: deterministic transforms only. Random flips / transposes here would make
#the validation loss change from run to run on the same weights, and that loss is
#what decides early stopping and which checkpoint is kept.
valid_aug = albumentations.Compose(
    [
        albumentations.CenterCrop(256, 256, p=1.0),
        albumentations.Resize(256, 256),
    ]
)

train_dataset = ImageDataset(
    image_paths=train_image_paths,
    targets=train_targets,
    augmentations=train_aug,
)
valid_dataset = ImageDataset(
    image_paths=valid_image_paths,
    targets=valid_targets,
    augmentations=valid_aug,
)

In [None]:
#plot the sample at index 3 again, now with the (random) train augmentation applied
plot_image(train_dataset[3])

In [None]:
#custom model
class LeafModel(tez.Model):
    """Pretrained EfficientNet-B4 features followed by dropout and a linear classifier.

    Args:
        num_classes: size of the output layer, one score per class.
    """

    def __init__(self, num_classes):
        super().__init__()

        self.effnet = EfficientNet.from_pretrained("efficientnet-b4")
        self.dropout = nn.Dropout(0.1)
        # the pooled EfficientNet-B4 feature vector fed to this layer has 1792 values
        self.out = nn.Linear(1792, num_classes)
        # tell tez to step the LR scheduler once per epoch
        self.step_scheduler_after = "epoch"

    def monitor_metrics(self, outputs, targets):
        """Return {"accuracy": ...} for a batch, or {} when there are no targets.

        The predicted class is the argmax of ``outputs`` along dim 1.
        """
        if targets is None:
            return {}
        predicted = torch.argmax(outputs, dim=1).cpu().detach().numpy()
        expected = targets.cpu().detach().numpy()
        return {"accuracy": metrics.accuracy_score(expected, predicted)}

    def fetch_optimizer(self):
        """Return an Adam optimizer over all model parameters (lr=1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def fetch_scheduler(self):
        """Return a StepLR scheduler that halves the learning rate every epoch.

        ``step_size`` is the decay period in scheduler steps and must be an
        integer; the previous ``step_size=0.5`` was a float in that slot and
        reads as the intended decay factor, so it is now passed as ``gamma``.
        """
        return torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=0.5
        )

    def forward(self, image, targets=None):
        """Compute class scores and, when targets are given, loss and metrics.

        Returns:
            ``(outputs, loss, metrics)``; ``loss`` and ``metrics`` are ``None``
            when ``targets`` is ``None``.
        """
        batch_size = image.shape[0]

        features = self.effnet.extract_features(image)
        # global average pool to one value per channel, then flatten per sample
        pooled = F.adaptive_avg_pool2d(features, 1).reshape(batch_size, -1)
        outputs = self.out(self.dropout(pooled))

        if targets is None:
            return outputs, None, None

        loss = F.cross_entropy(outputs, targets)
        # named batch_metrics so the imported sklearn `metrics` module is not shadowed
        batch_metrics = self.monitor_metrics(outputs, targets)
        return outputs, loss, batch_metrics

In [None]:
#loading the model and fitting it
model = LeafModel(num_classes=dfx.label.nunique())

#stop when valid_loss has not improved for 3 epochs; the callback is given
#"model.bin" as its model_path
es = EarlyStopping(
    monitor="valid_loss", model_path="model.bin", patience=3, mode="min"
)
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    train_bs=32,
    valid_bs=64,
    # use the GPU when there is one, otherwise fall back to the CPU instead of crashing
    device="cuda" if torch.cuda.is_available() else "cpu",
    epochs=10,
    callbacks=[es],
)

In [None]:
#preparation of test dataset
test_dfx = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv")
image_path = "../input/cassava-leaf-disease-classification/test_images/"
test_image_paths = [os.path.join(image_path, x) for x in test_dfx.image_id.values]
# fake targets: the placeholder labels from the sample submission, passed only
# because ImageDataset is given a targets argument
test_targets = test_dfx.label.values

# Deterministic crop + resize only, matching the first two steps of valid_aug.
# Normalize is deliberately not applied: train_aug and valid_aug feed
# un-normalized pixels to the model, so normalizing only at test time would give
# the network inputs on a scale it never saw during training. If normalization
# is wanted, it has to be added to train_aug and valid_aug as well.
test_aug = albumentations.Compose(
    [
        albumentations.CenterCrop(256, 256, p=1.0),
        albumentations.Resize(256, 256),
    ],
    p=1.0,
)

test_dataset = ImageDataset(
    image_paths=test_image_paths,
    targets=test_targets,
    augmentations=test_aug,
)


In [None]:
#prediction
import numpy as np  # np is used below but was never imported in this notebook

# model.predict is iterated batch by batch; stack the per-batch score arrays
# row-wise, then take the highest-scoring class for every image.
# NOTE(review): this predicts with the weights currently held in `model`, not
# necessarily the checkpoint written to "model.bin"; confirm that is intended.
preds = model.predict(test_dataset, batch_size=32, n_jobs=-1)
final_preds = np.vstack(list(preds)).argmax(axis=1)
test_dfx.label = final_preds
test_dfx.to_csv("submission.csv", index=False)