In [293]:
import torch
import logging
import pickle as pkl
import pandas as pd
from typing import Any, Dict
import torch.nn as nn
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from torchvision.models import resnet
import numpy as np
from sklearn import metrics
import logging
import pandas as pd
import logging
import cv2
import albumentations as A
from torchvision import models as tvmodels
from torch.utils.data import Dataset, DataLoader
from sklearn import model_selection
from typing import Tuple
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import confusion_matrix

In [304]:
# NOTE(review): `read_data` is defined in a later cell (In [296]); on
# Restart & Run All this cell raises NameError unless it is moved below it.
data = read_data(
    path_to_data="../data/external/cassava-leaf-disease-classification-small/train.csv"
)

In [294]:
# Load the whole pickled model (used below via model.parameters() / model(imgs))
# onto the CPU. weights_only=False is required on torch >= 2.6, whose default
# switched to True and refuses to unpickle full nn.Module objects (the argument
# exists since torch 1.13). Unpickling can execute arbitrary code: only load
# checkpoints from a trusted source.
model = torch.load(
    "../models/cassnet_20-2",
    map_location=torch.device('cpu'),
    weights_only=False,
)

In [295]:
def validate_model(model, optimizer, loss, val_dl, device):
    """Run a single validation pass over ``val_dl`` with autograd disabled.

    ``optimizer`` is accepted only so call sites can pass it; it is not used.
    Returns the ``(accuracy, scores)`` pair produced by ``val_one_epoch``.
    """
    with torch.no_grad():
        accuracy, scores = val_one_epoch(model, loss, val_dl, device)
    return accuracy, scores

In [296]:
def read_data(path_to_data: str) -> pd.DataFrame:
    """Load the label CSV found at ``path_to_data``.

    The ``label`` column is converted to the pandas ``"string"`` extension
    dtype; every other column keeps whatever type ``pd.read_csv`` infers.
    """
    raw = pd.read_csv(path_to_data)
    return raw.astype({'label': 'string'})

In [297]:
def train_val_split(
        df: pd.DataFrame,
        train_img_path: str,
        batch_size: int,
        num_workers: int,
        image_size: int,
        train_size: float = 0.1,
        random_state: Any = None,
) -> Tuple[DataLoader, DataLoader]:
    """Split ``df`` into train/validation parts and wrap each in a DataLoader.

    Args:
        df: Frame with an ``image_id`` column and a ``label`` column, as
            consumed by ``GetData``.
        train_img_path: Directory containing the image files.
        batch_size: Batch size for both loaders.
        num_workers: Number of DataLoader worker processes.
        image_size: Side length of the square crop fed to the model.
        train_size: Forwarded to ``sklearn.model_selection.train_test_split``
            (fraction or absolute row count used for training). The default
            0.1 keeps the original behaviour: only 10% of the rows go to
            training and 90% to validation.
        random_state: Seed for the split. ``None`` (default) draws a new split
            on every call, so pass an int for reproducible metrics.

    Returns:
        ``(train_dl, val_dl)``; the training loader shuffles, the validation
        loader does not.
    """
    # Standard ImageNet per-channel mean / std.
    img_mean = [0.485, 0.456, 0.406]
    img_std = [0.229, 0.224, 0.225]

    train_df, valid_df = model_selection.train_test_split(
        df, train_size=train_size, random_state=random_state
    )

    train_trans = A.Compose([
        A.RandomResizedCrop(image_size, image_size),
        A.Transpose(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        A.Normalize(img_mean, img_std),
        ToTensorV2(),
    ])

    # Deterministic preprocessing for evaluation: no random augmentation.
    val_trans = A.Compose([
        A.CenterCrop(image_size, image_size),
        A.Normalize(img_mean, img_std),
        ToTensorV2(),
    ])

    train_ds = GetData(train_df, train_img_path, label_out=True, transform=train_trans)
    valid_ds = GetData(valid_df, train_img_path, label_out=True, transform=val_trans)

    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
    val_dl = DataLoader(
        valid_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    return train_dl, val_dl

In [298]:
class GetData(Dataset):
    """Map-style dataset yielding images (and optionally labels) listed in a frame.

    Args:
        df: Frame with an ``image_id`` column (file name relative to ``dirr``)
            and, when ``label_out`` is true, a ``label`` column whose values
            ``float()`` can parse.
        dirr: Directory the image files live in.
        label_out: If true, ``__getitem__`` returns ``(image, target)``;
            otherwise it returns only the image.
        transform: Optional callable invoked as ``transform(image=img)["image"]``
            (albumentations style). ``None`` returns the image as loaded.
    """

    def __init__(self, df, dirr, label_out=True, transform=None):
        super().__init__()
        self.dirr = dirr
        self.label_out = label_out
        self.transform = transform
        # Re-index 0..n-1 so __getitem__ can look rows up by position.
        self.df = df.reset_index(drop=True).copy()
        if self.label_out:
            self.labels = self.df['label'].values

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index: int):
        img = get_img("{}/{}".format(self.dirr, self.df.at[index, 'image_id']))
        if self.transform is not None:
            img = self.transform(image=img)['image']
        if not self.label_out:
            return img
        # float() also parses labels stored as strings (read_data casts the
        # column to the "string" dtype); val_one_epoch casts targets to long.
        return img, float(self.labels[index])
def get_img(path):
    """Read the image file at ``path`` and return it in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            signals this by returning ``None`` instead of raising.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {path!r}")
    # OpenCV loads BGR. cvtColor returns a contiguous RGB copy, unlike the
    # negative-stride view that [:, :, ::-1] slicing produces.
    return cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)


In [299]:
def val_one_epoch(model, loss, data_loader, device):
    """Evaluate ``model`` on every batch of ``data_loader`` and report metrics.

    Args:
        model: Classifier whose output is argmax-ed over dim 1, presumably
            logits of shape (batch, n_classes) -- TODO confirm.
        loss: Criterion called as ``loss(logits, targets)``; its scalar result
            is assumed to be a per-batch mean (it is rescaled by batch size).
        data_loader: Sized iterable of ``(imgs, targets)`` batches.
        device: Device the batches are moved to.

    Returns:
        ``(accuracy, scores)``. ``accuracy`` is the fraction of correct argmax
        predictions. ``scores`` holds ``"precision_recall"`` (sklearn
        classification report string), ``"accuracy"``, ``"confusion_matrix"``
        (rows = true class, columns = predicted class) and ``"loss"`` (average
        of the per-batch losses weighted by batch size).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    preds_all = []
    targets_all = []
    loss_sum = 0.0
    sample_num = 0
    # no_grad here as well so the function is safe to call on its own.
    with torch.no_grad():
        for imgs, targets in tqdm(data_loader, total=len(data_loader)):
            imgs = imgs.to(device).float()
            targets = targets.to(device).long()

            preds = model(imgs)
            preds_all.append(torch.argmax(preds, 1).cpu().numpy())
            targets_all.append(targets.cpu().numpy())

            cost = loss(preds, targets)
            batch_n = targets.shape[0]
            loss_sum += cost.item() * batch_n
            sample_num += batch_n

    if not preds_all:
        raise ValueError("data_loader yielded no batches; cannot compute metrics")

    preds_all = np.concatenate(preds_all)
    print('preds', preds_all)
    targets_all = np.concatenate(targets_all)
    print('targets', targets_all)

    accuracy = (preds_all == targets_all).mean()
    scores = {
        "precision_recall": metrics.classification_report(
            y_true=targets_all,
            y_pred=preds_all
        ),
        "accuracy": metrics.accuracy_score(
            y_true=targets_all,
            y_pred=preds_all
        ),
        # sklearn's signature is (y_true, y_pred): rows are true classes.
        "confusion_matrix": confusion_matrix(targets_all, preds_all),
        "loss": loss_sum / sample_num,
    }
    print('accuracy = {:.4f}'.format(accuracy))

    return accuracy, scores

In [300]:
device = 'cpu'
# Only passed to validate_model, which does not use it; validation never steps it.
optimizer = torch.optim.Adam(model.parameters())
# Per-class sample counts for class ids 0..4 (total 21397) -- presumably the
# full cassava training set; verify against the dataset actually used.
class_counts = torch.tensor([1087, 2189, 2386, 13158, 2577], dtype=torch.float, device=device)
# Inverse-frequency ("balanced") weights: n_samples / (n_classes * count_c).
# Weights proportional to frequency (count_c / total) would give the majority
# class 3 the largest weight and amplify the imbalance instead of correcting it.
weight = class_counts.sum() / (len(class_counts) * class_counts)
loss = nn.CrossEntropyLoss(weight=weight).to(device)

In [301]:
# NOTE: train_val_split returns (train DataLoader, validation DataLoader);
# the `_df` names are kept because later cells refer to them.
train_df, valid_df = train_val_split(
    df=data,
    train_img_path="../data/external/cassava-leaf-disease-classification-small/train_images/",
    batch_size=20,
    num_workers=0,
    image_size=224,
)

In [307]:
    # `valid_df` holds the validation DataLoader produced by train_val_split.
    acc, metric = validate_model(
        model=model,
        optimizer=optimizer,
        loss=loss,
        val_dl=valid_df,
        device=device,
    )

100%|██████████| 1/1 [00:01<00:00,  1.30s/it]preds [1 3 3 3 1 1 1 3 3 3 3 3 3 3 3 2 3 3]
targets [4 1 1 1 4 3 1 1 1 3 1 1 2 1 2 1 1 1]
accuracy = 0.1111

