In [None]:
# Inference pipeline:
# 0. Preprocess test_data.
# 1. Get an image from test_data.
# 2. first_model(image) -> family_id and features.
# 3. If family_id belongs to a one-category family -> category_id directly;
#    otherwise get that family's second_model from models_dict.
# 4. second_model(features) -> category_number.
# 5. Map category_number -> category_id.
# 6. Submit results.

In [None]:
# Download the artifacts this notebook needs from the project's GitHub repo:
#   - efficientnet.py: presumably the model definition required to unpickle the
#     checkpoints below (TODO confirm)
#   - checkpoint0_7000.pth: first-stage network, loaded later as `fnetwork`
#   - keys.pkl: pickled lookup, loaded later as `keys`
!wget https://raw.githubusercontent.com/Lavabar/kaggle_Herbarium22/main/efficientnet.py
!wget https://github.com/Lavabar/kaggle_Herbarium22/raw/main/checkpoint0_7000.pth
!wget https://github.com/Lavabar/kaggle_Herbarium22/raw/main/keys.pkl
# gdown is used only to fetch the archive from Google Drive.
!pip install gdown

import gdown
# Google Drive archive with the per-family second-stage checkpoints.
url = 'https://drive.google.com/uc?id=14vnRYKYnFFLK08uX7tq_z4zH2wYbrqET'
output = 'models_cat.zip'
gdown.download(url, output, quiet=False)

# The prediction loop reads checkpoints from /kaggle/working/models_cat/, so the
# archive is assumed to unpack into a models_cat/ directory under the working
# directory (TODO confirm).
!unzip models_cat.zip

In [None]:
import json
import pandas as pd
import torch

import torchvision
import torchvision.transforms as transforms
from torchvision import datasets

from PIL import Image

# Directory prefix that is string-concatenated with each test file name, so the
# trailing slash is required.
images_path = '../input/herbarium-2022-fgvc9/test_images/'

In [None]:
# Load the test-set metadata; json.load accepts a file opened in binary mode.
with open('../input/herbarium-2022-fgvc9/test_metadata.json', 'rb') as f:
    test_meta = json.load(f)
    
# One row per test record; the 'file_name' column is read further down.
# NOTE: `df` is rebound to the training categories in a later cell.
df = pd.DataFrame(test_meta)

In [None]:
# Preprocessing pipeline applied to every test image before inference.
transform = transforms.Compose(
    [
        # Force a fixed 224x224 spatial size (aspect ratio is not preserved).
        transforms.Resize((224, 224)),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Subtract 0.5 and divide by 0.5; the single-element mean/std is
        # broadcast over all channels.
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

class TestDataset(torch.utils.data.Dataset):
    """Dataset yielding one preprocessed test image per file name.

    Args:
        dataframe_X: pandas Series (anything supporting ``len()`` and
            ``.iloc[index]``) of image file names, relative to the
            module-level ``images_path`` prefix.
        transform: callable applied to the loaded PIL image. Its result must
            provide ``.cuda()`` (e.g. a torch tensor).
    """

    def __init__(self, dataframe_X, transform):
        self.dataframe = dataframe_X
        self.transform = transform

    def __len__(self):
        """Return the number of file names in the dataset."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Load, preprocess and move to the GPU the image at position ``index``.

        Returns:
            The output of ``self.transform`` for the image, after ``.cuda()``.
        """
        file_name = self.dataframe.iloc[index]
        # Open inside a context manager so the file handle is released, and
        # convert to RGB so non-RGB files (grayscale, CMYK, ...) still produce
        # three channels; convert() returns a fully loaded, independent image.
        with Image.open(images_path + file_name) as image:
            image = image.convert('RGB')
        # Use the transform passed to the constructor rather than whatever
        # global happens to be named `transform`.
        return self.transform(image).cuda()

In [None]:
# Wrap the test file names in a dataset and an ordered loader.
test_X = df['file_name']
test_dataset = TestDataset(test_X, transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,   # the prediction loop takes one argmax per batch
    shuffle=False,  # keep predictions in the same order as the file names
)

In [None]:
import pickle 

# `keys` is indexed as keys[family][class_index] in the prediction loop, i.e. it
# maps a family index to a sequence of category values for that family.
# NOTE(review): pickle.load executes arbitrary code from the file; keys.pkl is
# downloaded from GitHub above, so only run this against a trusted source.
with open('./keys.pkl', 'rb') as f:
    keys = pickle.load(f)

In [None]:
# Load the first-stage network; it is called on an image batch to predict the
# family and also exposes `feature_extractor`, used by the second stage.
# NOTE(review): no map_location is given, so tensors are restored to the device
# they were saved from — presumably CUDA; confirm.
fnetwork = torch.load('./checkpoint0_7000.pth')
# Switch to inference mode (affects layers such as dropout / batch norm).
fnetwork.eval()

In [None]:
# Build the category_id <-> family_id table from the training metadata.
with open('../input/herbarium-2022-fgvc9/train_metadata.json', 'rb') as f:
    train_meta = json.load(f)

# NOTE: this rebinds `df`, which held the test metadata in an earlier cell.
df = pd.DataFrame(train_meta['categories'])

# Number the families 0..N-1 in order of first appearance.
# presumably the same numbering was used when training `fnetwork` — TODO confirm
maps = {family: family_id for family_id, family in enumerate(df['family'].unique())}
df['family_id'] = df['family'].map(maps)

# Take an explicit copy: assigning into a column slice of `df` is ambiguous
# (view vs. copy) and triggers pandas' SettingWithCopyWarning.
cat_fam = df[['category_id', 'family_id']].copy()
cat_fam['family_id'] = cat_fam['family_id'].astype(int)

In [None]:
# Number of categories per family.
cat_cnt = cat_fam.groupby('family_id').count()
# family_ids of the families that contain exactly one category.
onecat = cat_cnt.loc[cat_cnt['category_id'] == 1].index
# Keep only those families and index the table by family_id.
# NOTE: this rebinds `cat_fam` and drops the 'family_id' column, so the cell
# cannot be run twice without re-running the cell that builds `cat_fam`.
cat_fam = cat_fam.loc[cat_fam['family_id'].isin(onecat)].set_index('family_id')

In [None]:
# Two-stage inference over the test set:
#   1. `fnetwork` predicts the family of each image.
#   2. If the family has an entry in `keys`, that family's second-stage model
#      picks the category from the first-stage features; otherwise the family
#      has a single category, which is looked up in `cat_fam`.
predictions = []
# family index -> second-stage model. Each checkpoint is loaded from disk once
# and reused instead of being re-read for every image. The models stay resident
# for the whole run; clear this dict if memory becomes a concern.
category_models = {}
fnetwork.eval()
with torch.no_grad():  # inference only: do not build autograd graphs
    for inputs in test_loader:
        # The loader uses batch_size=1, so the argmax over the whole output is
        # the predicted family index of the single image in the batch.
        family = int(fnetwork(inputs).argmax())
        if family in keys:
            model = category_models.get(family)
            if model is None:
                torch.cuda.empty_cache()
                model = torch.load(f'/kaggle/working/models_cat/checkpoint_{family}_final.pth')
                model.eval()
                category_models[family] = model
            cat_pred = model(fnetwork.feature_extractor(inputs))
            # Translate the class index within the family to a category value.
            category = keys[family][int(cat_pred.argmax())]
        else:
            # `cat_fam` is indexed by family_id, so look the family up by label
            # (.loc) and take the scalar category_id. Positional .iloc would
            # return a whole row, usually for a different family.
            category = int(cat_fam.loc[family, 'category_id'])
        predictions.append(category)

In [None]:
# Use the competition's sample submission as the template for the output file.
my_submission = pd.read_csv('../input/herbarium-2022-fgvc9/sample_submission.csv')

In [None]:
# Predictions are assigned positionally; assumes the rows of
# sample_submission.csv are in the same order as test_metadata.json — TODO confirm.
my_submission['Predicted'] = predictions

In [None]:
# Write the submission file without the DataFrame index column.
my_submission.to_csv('submission.csv', index=False)