# Load data

In [1]:
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
import seaborn as sns
import cv2

train_dir = '../input/herbarium-2022-fgvc9/train_images/'
test_dir = '../input/herbarium-2022-fgvc9/test_images/'

with open("../input/herbarium-2022-fgvc9/train_metadata.json") as json_file:
    train_meta = json.load(json_file)
with open("../input/herbarium-2022-fgvc9/test_metadata.json") as json_file:
    test_meta = json.load(json_file)

# JSON -> Dataframe

In [2]:
image_ids = [image["image_id"] for image in train_meta["images"]]
image_dirs = [train_dir + image['file_name'] for image in train_meta["images"]]
category_ids = [annotation['category_id'] for annotation in train_meta['annotations']]
genus_ids = [annotation['genus_id'] for annotation in train_meta['annotations']]

test_ids = [image['image_id'] for image in test_meta]
test_dirs = [test_dir + image['file_name'] for image in test_meta]

#Create the initial training dataframe with the above defined columns
train_df = pd.DataFrame({
    "image_id" : image_ids,
    "image_dir" : image_dirs,
    "category" : category_ids,
    "genus" : genus_ids})

#Create a testing dataframe
test_df = pd.DataFrame({
    "test_id" : test_ids,
    "test_dir" : test_dirs
})

# Mapping genus and family

In [3]:
#Add a genus column to the dataframe
genus_map = {genus['genus_id'] : genus['genus'] for genus in train_meta['genera']}
train_df['genus'] = train_df['genus'].map(genus_map)

##Create a family column in the datagframe based on the genus names
    # Step 1: Create dictionary of genus -> family mapping
genus_family_map = {}
for category in train_meta["categories"]:
    genus = category['genus']
    family = category['family']
    genus_family_map[genus] = family

    # Step 2: Create new column with default value of None™
train_df['family'] = None

    # Step 3: Update values in new column based on genus -> family mapping
for i, row in train_df.iterrows():
    genus = row['genus']
    if genus in genus_family_map:
        family = genus_family_map[genus]
        train_df.at[i, 'family'] = family

train_df

Unnamed: 0,image_id,image_dir,category,genus,family
0,00000__001,../input/herbarium-2022-fgvc9/train_images/000...,0,Abies,Pinaceae
1,00000__002,../input/herbarium-2022-fgvc9/train_images/000...,0,Abies,Pinaceae
2,00000__003,../input/herbarium-2022-fgvc9/train_images/000...,0,Abies,Pinaceae
3,00000__004,../input/herbarium-2022-fgvc9/train_images/000...,0,Abies,Pinaceae
4,00000__005,../input/herbarium-2022-fgvc9/train_images/000...,0,Abies,Pinaceae
...,...,...,...,...,...
839767,15504__032,../input/herbarium-2022-fgvc9/train_images/155...,15504,Zygophyllum,Zygophyllaceae
839768,15504__033,../input/herbarium-2022-fgvc9/train_images/155...,15504,Zygophyllum,Zygophyllaceae
839769,15504__035,../input/herbarium-2022-fgvc9/train_images/155...,15504,Zygophyllum,Zygophyllaceae
839770,15504__036,../input/herbarium-2022-fgvc9/train_images/155...,15504,Zygophyllum,Zygophyllaceae


# Filtering to Poaceae

In [4]:
#Filter only the images of plants that are in the Poaceae family
train_df = train_df.loc[train_df['family'] == 'Poaceae']
#Reset index
train_df = train_df.reset_index(drop=True)
train_df

Unnamed: 0,image_id,image_dir,category,genus,family
0,00333__001,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae
1,00333__002,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae
2,00333__003,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae
3,00333__004,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae
4,00333__005,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae
...,...,...,...,...,...
53542,15501__101,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae
53543,15501__103,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae
53544,15501__105,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae
53545,15501__106,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae


In [5]:
#Add category_id and species column
train_df["species"] = None

# Extract category_id and species values from categories where the family is Poaceae
species_list = []
for category in train_meta["categories"]:
    if category["family"] == "Poaceae":
        species_list.append({
            "category": category["category_id"],
            "species": category["species"]
        })

# loop through data frame and species list to update species column
for i, row in train_df.iterrows():
    for species in species_list:
        if row['category'] == species['category']:
            train_df.at[i, 'species'] = species['species']
            
train_df

Unnamed: 0,image_id,image_dir,category,genus,family,species
0,00333__001,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae,blasdalei
1,00333__002,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae,blasdalei
2,00333__003,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae,blasdalei
3,00333__004,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae,blasdalei
4,00333__005,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae,blasdalei
...,...,...,...,...,...,...
53542,15501__101,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae,bulbosa
53543,15501__103,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae,bulbosa
53544,15501__105,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae,bulbosa
53545,15501__106,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae,bulbosa


In [6]:
print('Top 15 Genus in Poaceae family')
print("")
print(train_df['genus'].value_counts())

Top 15 Genus in Poaceae family

Muhlenbergia       4228
Paspalum           3124
Poa                2608
Dichanthelium      2474
Sporobolus         2304
                   ... 
Ptilagrostiella      14
Rhipidocladum        11
Dupontia             10
Kalinia              10
Barkworthia           8
Name: genus, Length: 158, dtype: int64


# Data visualization

Genus

In [7]:
#genus_data = train_df['genus'].value_counts().head(15)
#genus_data = pd.DataFrame({'Genus' : genus_data.index,
#                     'values' : genus_data.values})
                     
#plt.figure(figsize = (20, 10))
#sns.barplot(x='values', y = 'Genus', data = genus_data , palette='summer_r')
#plt.show()

##From most to least: Muhlenbergia, Paspalum, Poa, Dichanthelium, Sporobolus, Eragrostis etc.

In [8]:
#Muhlenbergia data
muh_pas_df = train_df[(train_df['genus'] == 'Paspalum') | (train_df['genus'] == 'Muhlenbergia')]
muh_pas_df = muh_pas_df.reset_index(drop=True)
muh_pas_df

Unnamed: 0,image_id,image_dir,category,genus,family,species
0,09492__001,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
1,09492__003,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
2,09492__004,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
3,09492__005,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
4,09492__006,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
...,...,...,...,...,...,...
7347,10398__026,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
7348,10398__029,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
7349,10398__030,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
7350,10398__031,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii


**Species**

In [9]:
#data_species = data_muhlenbergia['species'].value_counts().head(15) #data_species = pd.DataFrame({'Species' : data_species.index,
#                     'values' : data_species.values})#plt.figure(figsize = (20, 10))#sns.barplot(x='values', y = 'Species', data = data_species , palette='summer_r')#plt.show()#data_species

# Image displaying

In [10]:
def show_images(genus):
    images = train_df.loc[train_df['genus'] == genus]['image_dir'][:9]
    i = 1
    fig = plt.figure(figsize = (18, 18))
    plt.suptitle(genus, fontsize = '30')
    for image in images:
        img = cv2.imread(image)
        ax = fig.add_subplot(3, 3, i)
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax.set_axis_off()
        i += 1
    plt.show()

# Splitting the dataset into training and validation

**2 Genuses**

In [11]:
muh_df = muh_pas_df[muh_pas_df["genus"] == "Muhlenbergia"] 
pas_df = muh_pas_df[muh_pas_df["genus"] == "Paspalum"]

#15 percent of images will be used for validation
# Muh total: 4228 --> 15% = 634
# Pas total: 3124 --> 15% = 467
muh_valid = muh_df.sample(n=634, random_state=42)
muh_train = muh_df.drop(muh_valid.index)
muh_valid = muh_valid.reset_index(drop=True)
muh_train = muh_train.reset_index(drop=True)

pas_valid = pas_df.sample(n=467, random_state=42)
pas_train = pas_df.drop(pas_valid.index)
pas_valid = pas_valid.reset_index(drop=True)
pas_train = pas_train.reset_index(drop=True)

# Merging the Muhlanbergia and Paspalum databases
muh_pas_train = pd.concat([muh_train, pas_train])
muh_pas_train = muh_pas_train.reset_index(drop=True)

muh_pas_valid = pd.concat([muh_valid, pas_valid])
muh_pas_valid = muh_pas_valid.reset_index(drop=True)

muh_pas_train

Unnamed: 0,image_id,image_dir,category,genus,family,species
0,09492__001,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
1,09492__003,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
2,09492__004,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
3,09492__005,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
4,09492__006,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
...,...,...,...,...,...,...
6246,10398__026,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
6247,10398__029,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
6248,10398__030,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
6249,10398__031,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii


**Whole Poaceae family**

# Creating the model

In [12]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import torch.optim as optim
import numpy as np
from sklearn import preprocessing
import torch

In [13]:
batch_size = 128
epochs = 10
IM_SIZE = 224

X_Train, Y_Train = muh_pas_train["image_dir"].values, muh_pas_train["genus"].values

le = preprocessing.LabelEncoder()
train_labels = le.fit_transform(Y_Train)
train_labels = torch.as_tensor(train_labels)

X_Valid, Y_Valid = muh_pas_valid["image_dir"].values, muh_pas_valid["genus"].values

leV = preprocessing.LabelEncoder()
val_labels = le.fit_transform(Y_Valid)
val_labels = torch.as_tensor(val_labels)

Transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Resize((IM_SIZE, IM_SIZE)),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

# X_Train, train_labels
# X_Valid, val_labels

In [14]:
class GetData(Dataset):
    def __init__(self, FNames, Labels, Transform):
        self.fnames = FNames
        self.transform = Transform
        self.labels = Labels         
        
    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, index):       
        x = Image.open(self.fnames[index])
    
        if "train" in self.fnames[index]:             
            return self.transform(x), self.labels[index]
        elif "test" in self.fnames[index]:            
            return self.transform(x), self.fnames[index]

                
trainset = GetData(X_Train, train_labels, Transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

valset = GetData(X_Valid, val_labels, Transform)
valloader = DataLoader(valset, batch_size=batch_size, shuffle=True)

N_Classes = muh_pas_train['genus'].nunique()
next(iter(trainloader))[0].shape

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = torchvision.models.densenet169(pretrained=True)


  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://download.pytorch.org/models/densenet169-b2777c0a.pth" to /root/.cache/torch/hub/checkpoints/densenet169-b2777c0a.pth


  0%|          | 0.00/54.7M [00:00<?, ?B/s]

In [15]:
print(model.classifier.in_features) 
print(model.classifier.out_features)

for param in model.parameters():
    param.requires_grad = False
    
n_inputs = model.classifier.in_features
last_layer = nn.Linear(n_inputs, N_Classes)
model.classifier = last_layer
if torch.cuda.is_available():
    model.cuda()
print(model.classifier.out_features)    

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters())

1664
1000
2


In [16]:
from tqdm import tqdm

true_labels = []
pred_labels = []

def train(trainloader, model, criterion, optimizer, scaler, device=torch.device("cpu")):
    train_acc = 0.0
    train_loss = 0.0
    for images, labels in tqdm(trainloader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            output = model(images)
            loss = criterion(output, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            acc = ((output.argmax(dim=1) == labels).float().mean())
            train_acc += acc
            train_loss += loss
            
    return train_acc/len(trainloader), train_loss/len(trainloader)

In [17]:
## Normal Evaluation
def evaluate(testloader, model, criterion, device=torch.device("cpu")):
    eval_acc = 0.0
    eval_loss = 0.0
    for images, labels in tqdm(testloader):
        images = images.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            output = model(images)
            loss = criterion(output, labels)
        acc = ((output.argmax(dim=1) == labels).float().mean())
        eval_acc += acc
        eval_loss += loss
  
    return eval_acc/len(testloader), eval_loss/len(testloader)

In [22]:
%%time
from sklearn.metrics import classification_report

scaler = torch.cuda.amp.GradScaler(enabled=True)
for epoch in range(epochs):
    train_acc, train_loss = train(trainloader, model, criterion, optimizer, scaler, device=device)
    eval_acc, eval_loss = evaluate(valloader, model, criterion, device=torch.device("cuda"))

    # calculate F1 score
    with torch.no_grad():
        model.eval()
        preds = []
        targets = []
        for images, labels in tqdm(valloader):
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            preds.append(output.argmax(dim=1).cpu().numpy())
            targets.append(labels.cpu().numpy())
        preds = np.concatenate(preds)
        targets = np.concatenate(targets)
        report = classification_report(targets, preds)
        print(report)

    #print("") 
    #print(f"Epoch {epoch + 1} | Train Acc: {train_acc*100} | Train Loss: {train_loss}")
    #print(f"\t Val Acc: {eval_acc*100} | Val Loss: {eval_loss}")
    #print("===="*8) 

100%|██████████| 49/49 [02:03<00:00,  2.51s/it]
100%|██████████| 9/9 [00:19<00:00,  2.18s/it]
100%|██████████| 9/9 [00:23<00:00,  2.63s/it]


              precision    recall  f1-score   support

           0       0.94      0.88      0.91       634
           1       0.85      0.92      0.88       467

    accuracy                           0.90      1101
   macro avg       0.89      0.90      0.89      1101
weighted avg       0.90      0.90      0.90      1101



100%|██████████| 49/49 [02:04<00:00,  2.54s/it]
100%|██████████| 9/9 [00:19<00:00,  2.12s/it]
100%|██████████| 9/9 [00:23<00:00,  2.63s/it]


              precision    recall  f1-score   support

           0       0.91      0.94      0.92       634
           1       0.91      0.88      0.89       467

    accuracy                           0.91      1101
   macro avg       0.91      0.91      0.91      1101
weighted avg       0.91      0.91      0.91      1101



100%|██████████| 49/49 [02:04<00:00,  2.55s/it]
100%|██████████| 9/9 [00:19<00:00,  2.19s/it]
100%|██████████| 9/9 [00:22<00:00,  2.55s/it]


              precision    recall  f1-score   support

           0       0.93      0.92      0.92       634
           1       0.89      0.90      0.90       467

    accuracy                           0.91      1101
   macro avg       0.91      0.91      0.91      1101
weighted avg       0.91      0.91      0.91      1101



100%|██████████| 49/49 [02:04<00:00,  2.54s/it]
100%|██████████| 9/9 [00:19<00:00,  2.20s/it]
100%|██████████| 9/9 [00:23<00:00,  2.62s/it]


              precision    recall  f1-score   support

           0       0.91      0.95      0.93       634
           1       0.93      0.87      0.90       467

    accuracy                           0.92      1101
   macro avg       0.92      0.91      0.91      1101
weighted avg       0.92      0.92      0.92      1101



100%|██████████| 49/49 [02:04<00:00,  2.53s/it]
100%|██████████| 9/9 [00:19<00:00,  2.11s/it]
100%|██████████| 9/9 [00:23<00:00,  2.61s/it]


              precision    recall  f1-score   support

           0       0.94      0.93      0.93       634
           1       0.90      0.91      0.91       467

    accuracy                           0.92      1101
   macro avg       0.92      0.92      0.92      1101
weighted avg       0.92      0.92      0.92      1101



100%|██████████| 49/49 [02:04<00:00,  2.54s/it]
100%|██████████| 9/9 [00:19<00:00,  2.17s/it]
100%|██████████| 9/9 [00:22<00:00,  2.53s/it]


              precision    recall  f1-score   support

           0       0.92      0.94      0.93       634
           1       0.92      0.89      0.90       467

    accuracy                           0.92      1101
   macro avg       0.92      0.92      0.92      1101
weighted avg       0.92      0.92      0.92      1101



100%|██████████| 49/49 [02:04<00:00,  2.54s/it]
100%|██████████| 9/9 [00:19<00:00,  2.15s/it]
100%|██████████| 9/9 [00:23<00:00,  2.62s/it]


              precision    recall  f1-score   support

           0       0.88      0.96      0.92       634
           1       0.94      0.82      0.88       467

    accuracy                           0.90      1101
   macro avg       0.91      0.89      0.90      1101
weighted avg       0.91      0.90      0.90      1101



100%|██████████| 49/49 [02:04<00:00,  2.54s/it]
100%|██████████| 9/9 [00:19<00:00,  2.12s/it]
100%|██████████| 9/9 [00:23<00:00,  2.62s/it]


              precision    recall  f1-score   support

           0       0.90      0.96      0.93       634
           1       0.93      0.86      0.90       467

    accuracy                           0.92      1101
   macro avg       0.92      0.91      0.91      1101
weighted avg       0.92      0.92      0.91      1101



100%|██████████| 49/49 [02:04<00:00,  2.53s/it]
100%|██████████| 9/9 [00:19<00:00,  2.12s/it]
100%|██████████| 9/9 [00:23<00:00,  2.59s/it]


              precision    recall  f1-score   support

           0       0.90      0.96      0.93       634
           1       0.94      0.86      0.90       467

    accuracy                           0.92      1101
   macro avg       0.92      0.91      0.91      1101
weighted avg       0.92      0.92      0.91      1101



 18%|█▊        | 9/49 [00:24<01:51,  2.78s/it]


TypeError: int() argument must be a string, a bytes-like object or a number, not 'JpegImageFile'