# Load data

In [3]:
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
import seaborn as sns
import cv2

train_dir = '../input/herbarium-2022-fgvc9/train_images/'
test_dir = '../input/herbarium-2022-fgvc9/test_images/'

with open("../input/herbarium-2022-fgvc9/train_metadata.json") as json_file:
    train_meta = json.load(json_file)
with open("../input/herbarium-2022-fgvc9/test_metadata.json") as json_file:
    test_meta = json.load(json_file)

# JSON -> Dataframe

In [4]:
image_ids = [image["image_id"] for image in train_meta["images"]]
image_dirs = [train_dir + image['file_name'] for image in train_meta["images"]]
category_ids = [annotation['category_id'] for annotation in train_meta['annotations']]
genus_ids = [annotation['genus_id'] for annotation in train_meta['annotations']]

test_ids = [image['image_id'] for image in test_meta]
test_dirs = [test_dir + image['file_name'] for image in test_meta]

#Create the initial training dataframe with the above defined columns
train_df = pd.DataFrame({
    "image_id" : image_ids,
    "image_dir" : image_dirs,
    "category" : category_ids,
    "genus" : genus_ids})

#Create a testing dataframe
test_df = pd.DataFrame({
    "test_id" : test_ids,
    "test_dir" : test_dirs
})

# Mapping genus and family

In [5]:
#Add a genus column to the dataframe
genus_map = {genus['genus_id'] : genus['genus'] for genus in train_meta['genera']}
train_df['genus'] = train_df['genus'].map(genus_map)

##Create a family column in the datagframe based on the genus names
    # Step 1: Create dictionary of genus -> family mapping
genus_family_map = {}
for category in train_meta["categories"]:
    genus = category['genus']
    family = category['family']
    genus_family_map[genus] = family

    # Step 2: Create new column with default value of None™
train_df['family'] = None

    # Step 3: Update values in new column based on genus -> family mapping
for i, row in train_df.iterrows():
    genus = row['genus']
    if genus in genus_family_map:
        family = genus_family_map[genus]
        train_df.at[i, 'family'] = family

train_df

Unnamed: 0,image_id,image_dir,category,genus,family
0,00000__001,../input/herbarium-2022-fgvc9/train_images/000...,0,Abies,Pinaceae
1,00000__002,../input/herbarium-2022-fgvc9/train_images/000...,0,Abies,Pinaceae
2,00000__003,../input/herbarium-2022-fgvc9/train_images/000...,0,Abies,Pinaceae
3,00000__004,../input/herbarium-2022-fgvc9/train_images/000...,0,Abies,Pinaceae
4,00000__005,../input/herbarium-2022-fgvc9/train_images/000...,0,Abies,Pinaceae
...,...,...,...,...,...
839767,15504__032,../input/herbarium-2022-fgvc9/train_images/155...,15504,Zygophyllum,Zygophyllaceae
839768,15504__033,../input/herbarium-2022-fgvc9/train_images/155...,15504,Zygophyllum,Zygophyllaceae
839769,15504__035,../input/herbarium-2022-fgvc9/train_images/155...,15504,Zygophyllum,Zygophyllaceae
839770,15504__036,../input/herbarium-2022-fgvc9/train_images/155...,15504,Zygophyllum,Zygophyllaceae


# Filtering to Poaceae

In [6]:
#Filter only the images of plants that are in the Poaceae family
train_df = train_df.loc[train_df['family'] == 'Poaceae']
#Reset index
train_df = train_df.reset_index(drop=True)
train_df

Unnamed: 0,image_id,image_dir,category,genus,family
0,00333__001,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae
1,00333__002,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae
2,00333__003,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae
3,00333__004,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae
4,00333__005,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae
...,...,...,...,...,...
53542,15501__101,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae
53543,15501__103,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae
53544,15501__105,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae
53545,15501__106,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae


In [7]:
#Add category_id and species column
train_df["species"] = None

# Extract category_id and species values from categories where the family is Poaceae
species_list = []
for category in train_meta["categories"]:
    if category["family"] == "Poaceae":
        species_list.append({
            "category": category["category_id"],
            "species": category["species"]
        })

# loop through data frame and species list to update species column
for i, row in train_df.iterrows():
    for species in species_list:
        if row['category'] == species['category']:
            train_df.at[i, 'species'] = species['species']
            
train_df

Unnamed: 0,image_id,image_dir,category,genus,family,species
0,00333__001,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae,blasdalei
1,00333__002,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae,blasdalei
2,00333__003,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae,blasdalei
3,00333__004,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae,blasdalei
4,00333__005,../input/herbarium-2022-fgvc9/train_images/003...,333,Agrostis,Poaceae,blasdalei
...,...,...,...,...,...,...
53542,15501__101,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae,bulbosa
53543,15501__103,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae,bulbosa
53544,15501__105,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae,bulbosa
53545,15501__106,../input/herbarium-2022-fgvc9/train_images/155...,15501,Zuloagaea,Poaceae,bulbosa


In [8]:
print('Top 15 Genus in Poaceae family')
print("")
print(train_df['genus'].value_counts().head(15))

Top 15 Genus in Poaceae family

Muhlenbergia     4228
Paspalum         3124
Poa              2608
Dichanthelium    2474
Sporobolus       2304
Eragrostis       2068
Aristida         1951
Festuca          1469
Bromus           1458
Bouteloua        1427
Panicum          1415
Setaria          1287
Eriocoma         1244
Elymus           1223
Melica            941
Name: genus, dtype: int64


# Data visualization

Genus

In [9]:
#genus_data = train_df['genus'].value_counts().head(15)
#genus_data = pd.DataFrame({'Genus' : genus_data.index,
#                     'values' : genus_data.values})
                     
#plt.figure(figsize = (20, 10))
#sns.barplot(x='values', y = 'Genus', data = genus_data , palette='summer_r')
#plt.show()

##From most to least: Muhlenbergia, Paspalum, Poa, Dichanthelium, Sporobolus, Eragrostis etc.

In [10]:
#Muhlenbergia data
muh_pas_df = train_df[(train_df['genus'] == 'Paspalum') | (train_df['genus'] == 'Muhlenbergia')]
muh_pas_df = muh_pas_df.reset_index(drop=True)
muh_pas_df

Unnamed: 0,image_id,image_dir,category,genus,family,species
0,09492__001,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
1,09492__003,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
2,09492__004,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
3,09492__005,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
4,09492__006,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
...,...,...,...,...,...,...
7347,10398__026,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
7348,10398__029,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
7349,10398__030,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
7350,10398__031,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii


**Species**

In [11]:
#data_species = data_muhlenbergia['species'].value_counts().head(15) #data_species = pd.DataFrame({'Species' : data_species.index,
#                     'values' : data_species.values})#plt.figure(figsize = (20, 10))#sns.barplot(x='values', y = 'Species', data = data_species , palette='summer_r')#plt.show()#data_species

# Image displaying

In [12]:
def show_images(genus):
    images = muh_pas_df.loc[muh_pas_df['genus'] == genus]['image_dir'][:9]
    i = 1
    fig = plt.figure(figsize = (18, 18))
    plt.suptitle(genus, fontsize = '30')
    for image in images:
        img = cv2.imread(image)
        ax = fig.add_subplot(3, 3, i)
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax.set_axis_off()
        i += 1
    plt.show()
#show_images('Muhlenbergia')
#show_images('Paspalum')

# Splitting the dataset into training and validation

In [13]:
muh_df = muh_pas_df[muh_pas_df["genus"] == "Muhlenbergia"] 
pas_df = muh_pas_df[muh_pas_df["genus"] == "Paspalum"]

#15 percent of images will be used for validation
# Muh total: 4228 --> 15% = 634
# Pas total: 3124 --> 15% = 467
muh_valid = muh_df.sample(n=634, random_state=42)
muh_train = muh_df.drop(muh_valid.index)
muh_valid = muh_valid.reset_index(drop=True)
muh_train = muh_train.reset_index(drop=True)

pas_valid = pas_df.sample(n=467, random_state=42)
pas_train = pas_df.drop(pas_valid.index)
pas_valid = pas_valid.reset_index(drop=True)
pas_train = pas_train.reset_index(drop=True)

# Merging the Muhlanbergia and Paspalum databases
muh_pas_train = pd.concat([muh_train, pas_train])
muh_pas_train = muh_pas_train.reset_index(drop=True)

muh_pas_valid = pd.concat([muh_valid, pas_valid])
muh_pas_valid = muh_pas_valid.reset_index(drop=True)

#muh_pas_train['genus'] = muh_pas_train['genus'].replace({'Muhlenbergia': 1, 'Paspalum': 2})
muh_pas_train

Unnamed: 0,image_id,image_dir,category,genus,family,species
0,09492__001,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
1,09492__003,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
2,09492__004,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
3,09492__005,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
4,09492__006,../input/herbarium-2022-fgvc9/train_images/094...,9492,Muhlenbergia,Poaceae,alopecuroides
...,...,...,...,...,...,...
6246,10398__026,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
6247,10398__029,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
6248,10398__030,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii
6249,10398__031,../input/herbarium-2022-fgvc9/train_images/103...,10398,Paspalum,Poaceae,wrightii


# Creating the model

In [33]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import torch.optim as optim
import numpy as np

In [34]:
batch_size = 32
epoch_size = 5

X_Train, Y_Train = muh_pas_train['image_dir'].values, muh_pas_train['genus'].values

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

In [35]:
class GetData(Dataset):
    def __init__(self, FNames, Labels, transform):
        self.fnames = FNames
        self.transform = transform
        self.labels = Labels         
        
    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, index):       
        x = Image.open(self.fnames[index])
    
        if "train" in self.fnames[index]:             
            return self.transform(x), self.labels[index]
        elif "test" in self.fnames[index]:            
            return self.transform(x), self.fnames[index]
                
trainset = GetData(X_Train, Y_Train, transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)


In [36]:
def set_device():
    if torch.cuda.is_available():
        dev = "cuda:0"
    else:
        dev = "cpu"
    return torch.device(dev)

In [37]:
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
n_classes = 2
model.fc == nn.Linear(num_ftrs, n_classes)
device = set_device()
model = model.to(device)
loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.003)

In [49]:
def train(model, trainloader, criterion, optimizer, epochs):
    device = set_device()
    
    for epoch in range(epoch_size):
        print("Epoch number %d" %(epoch + 1))
        model.train()
        running_loss = 0.0
        running_correct = 0.0
        total = 0
        
        for data in trainloader:
            images, labels = data
            images = images.to(device)
            print(type(labels))
            labels = labels.to(device)
            total += labels.size(0)
            
            optimizer.zero_grad()
            outputs = model(images)
            predicted = torch.max(outputs.data, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_correct += (labels==predicted.sum().item())
        
        epoch_loss = running_loss / len(train_loader)    
        epoch_acc = 100.00 * running_correct / total
        
        print("     -Training dataset. Got %d out of %d images correctly (%.3f%%). Epoch loss: %.3f" % (running_correct, total, epoch_acc, epoch_loss))

In [50]:
train(model, trainloader, loss, optimizer, 10)

Epoch number 1
<class 'tuple'>


AttributeError: 'tuple' object has no attribute 'to'