In [1]:
import pandas as pd
from tqdm import tqdm

In [2]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
import torch

# Normalisation statistics. NOTE: despite the `ImageNet_*` names these are NOT
# the ImageNet statistics (0.485/0.456/0.406, 0.229/0.224/0.225): every channel
# uses mean 0.5 / std 0.5, so ToTensor's [0, 1] range is mapped to [-1, 1].
ImageNet_mean = [0.5] * 3
ImageNet_std = [0.5] * 3
normalize = Normalize(mean=ImageNet_mean, std=ImageNet_std)

# Training pipeline: random-resized 224px crop plus horizontal-flip augmentation.
train_transforms = Compose([RandomResizedCrop(224), RandomHorizontalFlip(), ToTensor(), normalize])

# Evaluation pipeline: deterministic resize to 224 followed by a 224x224 center crop.
val_transforms = Compose([Resize(224), CenterCrop(224), ToTensor(), normalize])


def preprocess_train(example_batch):
    """Attach augmented training tensors to a batch dict.

    Every entry of ``example_batch["image"]`` is converted to RGB and passed
    through the module-level ``train_transforms``. The results are stored, in
    input order, under ``example_batch["pixel_values"]``. The dict is modified
    in place and also returned, for use as a ``map``/``set_transform`` callback.
    """
    rgb_images = (img.convert("RGB") for img in example_batch["image"])
    example_batch["pixel_values"] = list(map(train_transforms, rgb_images))
    return example_batch


def preprocess_val(example_batch):
    """Attach deterministic evaluation tensors to a batch dict.

    Each image in ``example_batch["image"]`` is converted to RGB and run through
    the module-level ``val_transforms``. The outputs are written, in order, to
    ``example_batch["pixel_values"]``. The same (mutated) dict is returned.
    """
    pixel_values = []
    for image in example_batch["image"]:
        pixel_values.append(val_transforms(image.convert("RGB")))
    example_batch["pixel_values"] = pixel_values
    return example_batch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from loaders.ImageData import ImageDataset
data_dir = "/bucket/npss/CottonPestClassification_v3a_os_lora/"
data_dir_2 = "/bucket/npss/CottonPestClassification_v3a_npss/"

# Evaluation splits only (the train split is not needed here). All of them use
# the deterministic `val_transforms` pipeline.
val_loader = ImageDataset(data_dir, split='val', transform=val_transforms)
test_loader = ImageDataset(data_dir_2, split='reporting', transform=val_transforms)
holdout_loader = ImageDataset(data_dir_2, split='holdout', transform=val_transforms)

# Class <-> index mapping, taken from the val split of the training data root.
label2id = val_loader.labels2id
id2label = val_loader.id2label

# The test/holdout splits come from a different root (`data_dir_2`) and carry
# their own mapping. The metric cells compare their integer labels directly
# against the model's class indices, so every split must use the same mapping.
# Otherwise the metrics are silently wrong.
assert test_loader.labels2id == label2id, (
    f"'reporting' split label mapping {test_loader.labels2id} differs from val mapping {label2id}"
)
assert holdout_loader.labels2id == label2id, (
    f"'holdout' split label mapping {holdout_loader.labels2id} differs from val mapping {label2id}"
)

In [4]:
#load lora model from local
from transformers import AutoModelForImageClassification
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import AutoImageProcessor
import timm
import evaluate
from timm.models import create_model, load_checkpoint
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# config = LoraConfig(
#     r=16,
#     lora_alpha=16,
#     target_modules=["query", "value"],
#     lora_dropout=0.1,
#     bias="none",
#     modules_to_save=["classifier"],
# )
# config = LoraConfig(
#     r=16,
#     lora_alpha=16,
#     target_modules=["qkv"],
#     lora_dropout=0.1,
#     bias="none",
#     modules_to_save=["classifier"],
# )

# Checkpoint of a full timm fine-tuning run. Only the commented-out
# `load_checkpoint` alternative below uses it; the LoRA adapter is what gets evaluated.
model_checkpoint = "output/train/vit_base_patch16_224.orig_in21k-timm-050524-OS/model_best.pth.tar"

# Base ViT-B/16 with the pretrained `orig_in21k` weights and a 3-class head
# (aphids / none / whitefly per the mapping printed later).
model = timm.create_model('timm/vit_base_patch16_224.orig_in21k',
                          pretrained=True, num_classes=3)
# NOTE(review): `transform` (timm's default preprocessing for this model) is not
# used by the evaluation loops, which use `val_transforms` instead. Confirm that
# `val_transforms` matches the preprocessing the adapter was trained with.
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
# load_checkpoint(model, model_checkpoint, strict=False)

# Attach the trained LoRA adapter. This also modifies `model` itself (its head
# becomes a ModulesToSaveWrapper, as the next cell shows).
inference_model = PeftModel.from_pretrained(
    model,
    'output/vit-base-patch16-224-lora-IN21k_NPSS_CheckNew_1004/checkpoint-11',
)
# Make evaluation mode explicit so dropout layers (e.g. LoRA dropout) are
# disabled during inference, instead of relying on from_pretrained's defaults.
# The trailing ';' suppresses the module repr that eval() would otherwise display.
inference_model.eval();

In [5]:
# Inspect the classifier head after the adapter was attached (shown below as a
# ModulesToSaveWrapper holding the original and the trained Linear layer).
model.get_submodule("head")

ModulesToSaveWrapper(
  (original_module): Linear(in_features=768, out_features=3, bias=True)
  (modules_to_save): ModuleDict(
    (default): Linear(in_features=768, out_features=3, bias=True)
  )
)

In [6]:
# image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)

In [7]:
# image_processor

In [8]:
# # image_processor 
# ImageNet_mean = [0.485, 0.456, 0.406]
# ImageNet_std = [0.229, 0.224, 0.225]

In [9]:
# outs = val_loader.__getitem__(0)
# path = outs['image_file_path']
# img = outs['image']
# label = outs['labels']

In [10]:
# encoding = image_processor(img, return_tensors="pt")
# with torch.no_grad():
#     # check = {'pixel_values': img}
#     output = inference_model(**encoding)
#     logits = output.logits
#     probs = logits.softmax(dim=1)
#     print(probs, probs.argmax(dim=1), label)

In [11]:
# encoding['pixel_values'].shape

In [12]:
# Empty results frame for the validation split. Columns: file path, true label,
# predicted label, class probabilities and logits.
val_outs_df = pd.DataFrame(
    columns=[
        'image_file_path',
        'label',
        'pred_label',
        'pred_prob',
        'logits',
    ]
)

In [13]:
# Run the LoRA model over the validation split one image at a time and collect
# per-image predictions.
path_list, label_list, pred_label_list, pred_prob_list, logits_list = [], [], [], [], []
for item in tqdm(val_loader):
    img = item['image']  # presumably a (C, H, W) tensor from val_transforms; unsqueeze(0) adds the batch dim
    with torch.no_grad():
        logits = inference_model(img.unsqueeze(0))  # model output used directly as logits, shape (1, num_classes)
        probs = logits.softmax(dim=1)

    path_list.append(item['image_file_path'])
    label_list.append(item['labels'])
    pred_label_list.append(probs.argmax(dim=1).item())
    # Store plain Python lists instead of (1, C) tensors. Otherwise to_csv
    # writes a truncated, unparseable `tensor([[...]])` repr.
    pred_prob_list.append(probs[0].tolist())
    logits_list.append(logits[0].tolist())

# Build the results frame in one step (same column order as before).
val_outs_df = pd.DataFrame({
    'image_file_path': path_list,
    'label': label_list,
    'pred_label': pred_label_list,
    'pred_prob': pred_prob_list,
    'logits': logits_list,
})

100%|██████████| 1413/1413 [09:54<00:00,  2.38it/s]


In [14]:
# Run the LoRA model over the 'reporting' split of data_dir_2, one image at a time.
path_list_test, label_list_test, pred_label_list_test, pred_prob_list_test, logits_list_test = [], [], [], [], []
for item in tqdm(test_loader):
    img = item['image']  # presumably a (C, H, W) tensor; unsqueeze(0) adds the batch dim
    with torch.no_grad():
        logits = inference_model(img.unsqueeze(0))  # model output used directly as logits, shape (1, num_classes)
        probs = logits.softmax(dim=1)

    path_list_test.append(item['image_file_path'])
    label_list_test.append(item['labels'])
    # Kept as a 1-element tensor; a later cell converts these with `.item()`.
    pred_label_list_test.append(probs.argmax(dim=1))
    # Plain lists rather than (1, C) tensors so the CSV export stays readable.
    pred_prob_list_test.append(probs[0].tolist())
    logits_list_test.append(logits[0].tolist())

test_outs_df = pd.DataFrame({
    'image_file_path': path_list_test,
    'label': label_list_test,
    'pred_label': pred_label_list_test,
    'pred_prob': pred_prob_list_test,
    'logits': logits_list_test,
})

100%|██████████| 1867/1867 [14:07<00:00,  2.20it/s]


In [15]:
# Run the LoRA model over the 'holdout' split of data_dir_2 (the NPSS holdout set).
path_list_npss, label_list_npss, pred_label_list_npss, pred_prob_list_npss, logits_list_npss = [], [], [], [], []
for item in tqdm(holdout_loader):
    img = item['image']  # presumably a (C, H, W) tensor; unsqueeze(0) adds the batch dim
    with torch.no_grad():
        logits = inference_model(img.unsqueeze(0))  # model output used directly as logits, shape (1, num_classes)
        probs = logits.softmax(dim=1)

    path_list_npss.append(item['image_file_path'])
    label_list_npss.append(item['labels'])
    # Kept as a 1-element tensor; a later cell converts these with `.item()`.
    pred_label_list_npss.append(probs.argmax(dim=1))
    # Plain lists rather than (1, C) tensors so the CSV export stays readable.
    pred_prob_list_npss.append(probs[0].tolist())
    logits_list_npss.append(logits[0].tolist())

npss_outs_df = pd.DataFrame({
    'image_file_path': path_list_npss,
    'label': label_list_npss,
    'pred_label': pred_label_list_npss,
    'pred_prob': pred_prob_list_npss,
    'logits': logits_list_npss,
})


100%|██████████| 101/101 [00:40<00:00,  2.48it/s]


In [16]:
# Class distribution of true vs. predicted labels on the validation split.
tuple(val_outs_df[col].value_counts() for col in ("label", "pred_label"))

(label
 1    928
 2    452
 0     33
 Name: count, dtype: int64,
 pred_label
 1    945
 2    273
 0    195
 Name: count, dtype: int64)

In [17]:
# npss_outs_df.label.value_counts(), npss_outs_df.pred_label.value_counts()

In [18]:
# The test/holdout loops store `pred_label` as 1-element tensors; convert them
# to plain ints for sklearn and value_counts. Values without an `.item()` method
# (already-converted ints, e.g. when this cell is re-run) pass through unchanged,
# so the cell is idempotent. `val_outs_df['pred_label']` already holds ints.
for _outs_df in (test_outs_df, npss_outs_df):
    _outs_df['pred_label'] = _outs_df['pred_label'].apply(
        lambda x: x.item() if hasattr(x, 'item') else x
    )

In [19]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def _split_metrics(outs_df):
    """Return (accuracy, per-class precision, per-class recall, per-class F1).

    ``labels=`` pins the per-class arrays to the class ids in ``label2id`` in
    ascending order. Entry i therefore always belongs to class id i, even when a
    class is absent from both the true and the predicted labels of a split.
    Without it, sklearn would return a shorter, misaligned array.
    """
    class_ids = sorted(label2id.values())
    y_true, y_pred = outs_df['label'], outs_df['pred_label']
    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, labels=class_ids, average=None),
        recall_score(y_true, y_pred, labels=class_ids, average=None),
        f1_score(y_true, y_pred, labels=class_ids, average=None),
    )


val_acc, val_prec, val_recall, val_f1 = _split_metrics(val_outs_df)
test_acc, test_prec, test_recall, test_f1 = _split_metrics(test_outs_df)
npss_acc, npss_prec, npss_recall, npss_f1 = _split_metrics(npss_outs_df)

In [20]:
# Per-image flag: did the prediction match the true label?
val_outs_df['correct'] = val_outs_df['label'].eq(val_outs_df['pred_label'])

In [21]:
# Number of correct vs. incorrect validation predictions.
val_outs_df['correct'].value_counts()

correct
True     1014
False     399
Name: count, dtype: int64

In [22]:
# Plain-text report: class mapping, then accuracy and per-class
# precision/recall/F1 (array index == class id) for each split.
report_rows = [
    ("Mapping: ", label2id),
    ("Validation Accuracy: ", round(val_acc, 4)),
    ("Validation Precision: ", val_prec),
    ("Validation Recall: ", val_recall),
    ("Validation F1: ", val_f1),
    ("Test Accuracy: ", round(test_acc, 4)),
    ("Test Precision: ", test_prec),
    ("Test Recall: ", test_recall),
    ("Test F1: ", test_f1),
    ("NPSS Accuracy: ", round(npss_acc, 4)),
    ("NPSS Precision: ", npss_prec),
    ("NPSS Recall: ", npss_recall),
    ("NPSS F1: ", npss_f1),
]
for caption, value in report_rows:
    print(caption, value)

Mapping:  {'aphids': 0, 'none': 1, 'whitefly': 2}
Validation Accuracy:  0.7176
Validation Precision:  [0.04615385 0.82857143 0.81318681]
Validation Recall:  [0.27272727 0.84375    0.49115044]
Validation F1:  [0.07894737 0.83609183 0.61241379]
Test Accuracy:  0.7365
Test Precision:  [0.22222222 0.81646091 0.79186603]
Test Recall:  [0.57777778 0.85223368 0.53996737]
Test F1:  [0.32098765 0.83396385 0.64209505]
NPSS Accuracy:  0.7921
NPSS Precision:  [0.83333333 0.16666667 0.82978723]
NPSS Recall:  [0.8  1.   0.78]
NPSS F1:  [0.81632653 0.28571429 0.80412371]


In [23]:
# Export each split's per-image predictions to its own CSV. Previously all three
# frames were written to the same path, so only the NPSS holdout results
# survived and the val/test results were silently overwritten.
OUTPUT_PREFIX = './output/vit-base-patch16-224-lora-IN21k_NPSS_OS_check'
for split_name, split_df in (
    ('val', val_outs_df),
    ('test', test_outs_df),
    ('npss', npss_outs_df),
):
    split_df.to_csv(f'{OUTPUT_PREFIX}_{split_name}.csv', index=False)

In [24]:
# inference_model.push_to_hub('ashishp-wiai/vit-base-patch16-224-in21k-finetuned-CottonPestClassification_v3a_os')