In [1]:
# Download the model-definition module and a pretrained checkpoint from GitHub.
# NOTE(review): efficientnet.py is fetched from the moving `main` branch while the
# checkpoint is pinned to a specific commit, so the two can drift apart over time.
# Consider pinning both URLs to the same revision.
!wget https://raw.githubusercontent.com/Lavabar/kaggle_Herbarium22/main/efficientnet.py
!wget https://github.com/Lavabar/kaggle_Herbarium22/raw/2a66a174217fae8ac69da232f6e6dd5407faeeb0/checkpoint0_7000.pth

--2022-04-16 19:36:14--  https://raw.githubusercontent.com/Lavabar/kaggle_Herbarium22/main/efficientnet.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7517 (7.3K) [text/plain]
Saving to: ‘efficientnet.py’


2022-04-16 19:36:14 (42.3 MB/s) - ‘efficientnet.py’ saved [7517/7517]

--2022-04-16 19:36:15--  https://github.com/Lavabar/kaggle_Herbarium22/raw/2a66a174217fae8ac69da232f6e6dd5407faeeb0/checkpoint0_7000.pth
Resolving github.com (github.com)... 140.82.112.3
Connecting to github.com (github.com)|140.82.112.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/Lavabar/kaggle_Herbarium22/2a66a174217fae8ac69da232f6e6dd5407faeeb0/checkpoint0_7000.pth [following]
--2022-04-16 19:36:

In [2]:
# PyTorch model and training necessities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from efficientnet import EfficientNetB0, EfficientNetB3

from torch.utils.tensorboard import SummaryWriter

# Image datasets and image manipulation
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets

# Image display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from PIL import Image

# Seed PyTorch's global random number generator.
torch.manual_seed(0)
# Path prefix that is concatenated with each row's "file_name" when an image is opened.
images_path = '../input/herbarium-2022-fgvc9/train_images/'

In [3]:
# Pick the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [4]:
class FirstModelFeatureExtractor(object):
    """Callable transform that maps one image tensor to features of a saved model.

    The checkpoint is expected to be a fully pickled model object exposing a
    ``feature_extractor`` callable; that callable is applied to a batch of
    size one built from the incoming sample.
    """

    def __init__(self, model_path):
        """Load the model from ``model_path`` and switch it to evaluation mode.

        Args:
            model_path: path of a checkpoint written with ``torch.save(model, path)``.
        """
        # Use the current CUDA device when there is one, otherwise stay on the CPU
        # so the notebook also runs on machines without a GPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # SECURITY NOTE: torch.load unpickles the file and can therefore execute
        # arbitrary code; only load checkpoints from a source you trust.
        # NOTE(review): recent PyTorch releases default to weights_only=True, which
        # rejects fully pickled models -- pass weights_only=False there if needed.
        self.model = torch.load(model_path, map_location=self.device)
        self.model.eval()

    def __call__(self, sample):
        """Return the extracted features for a single sample.

        Args:
            sample: tensor for one image, without a batch dimension.

        Returns:
            The output of ``model.feature_extractor`` with all size-1 dimensions
            squeezed out. The tensor is detached from autograd.
        """
        # The extractor is frozen: without no_grad every feature tensor would drag
        # an autograd graph along, and the downstream loss.backward() would walk
        # through this model and keep its activations alive.
        with torch.no_grad():
            batch = torch.unsqueeze(sample, 0).to(self.device)
            res = self.model.feature_extractor(batch)
        return torch.squeeze(res)
        

In [5]:
# Preprocessing applied to every image, in order: resize to 224x224, convert to a
# tensor, normalize with mean 0.5 / std 0.5, then run the pretrained feature extractor.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    FirstModelFeatureExtractor('./checkpoint0_7000.pth'),
]
transform = transforms.Compose(_preprocessing_steps)

class CategoryDataset(torch.utils.data.Dataset):
    """Dataset yielding (transformed image, class index) pairs for one family.

    Args:
        dataframe_X: frame that provides the "file_name" column.
        dataframe_Y: frame that provides the "category_id" column; it is merged
            with ``dataframe_X`` on their shared columns.
        transform: callable applied to each opened PIL image.
        keys: mapping from ``category_id`` to the class index used as label.
    """

    def __init__(self, dataframe_X, dataframe_Y, transform, keys):
        self.dataframe = dataframe_X.merge(dataframe_Y)
        self.transform = transform
        self.keys = keys

    def __len__(self):
        """Number of rows in the merged frame."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return ``(features, label)`` for the row at position ``index``."""
        row = self.dataframe.iloc[index]
        # Use the transform handed to the constructor (not the notebook-level
        # global) and close the image file as soon as it has been consumed.
        with Image.open(images_path + row["file_name"]) as image:
            features = self.transform(image)
        label = torch.tensor(self.keys[row["category_id"]])
        # Follow the notebook's `device` so a CPU-only session does not crash.
        return features.to(device), label.to(device)

In [6]:
def do_training(training_loader, net, family_id, n_epoch=1, lr=0.001,
                checkpoint_every=50, checkpoint_dir='.'):
    """Train ``net`` with cross-entropy loss and RMSprop, checkpointing periodically.

    Args:
        training_loader: iterable yielding ``(inputs, labels)`` batches.
        net: the model to optimize; it is updated in place.
        family_id: identifier used in log lines and checkpoint file names.
        n_epoch: number of passes over ``training_loader``.
        lr: learning rate handed to RMSprop.
        checkpoint_every: save the model and report the mean loss every this
            many batches; ``0`` or ``None`` disables periodic checkpoints.
        checkpoint_dir: directory that receives the periodic checkpoints.

    Returns:
        The same ``net`` object after training.
    """
    import os  # local import so the function does not depend on cell order

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(net.parameters(), lr=lr)
    periodic = bool(checkpoint_every) and checkpoint_every > 0

    for epoch in range(n_epoch):  # loop over the dataset multiple times
        running_loss = 0.0
        window_batches = 0  # batches accumulated since the last report

        for i, (inputs, labels) in enumerate(training_loader):
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            window_batches += 1

            if periodic and (i + 1) % checkpoint_every == 0:
                # Report the mean loss of the window instead of discarding it.
                print('{}, epoch {}, batch {}: mean loss {:.4f}'.format(
                    family_id, epoch, i + 1, running_loss / window_batches))
                torch.save(net, os.path.join(
                    checkpoint_dir, f'checkpoint{family_id}_{epoch}_{i+1}.pth'))
                running_loss = 0.0
                window_batches = 0

        if window_batches:
            # Batches after the last checkpoint would otherwise go unreported.
            print('{}, epoch {}, last {} batch(es): mean loss {:.4f}'.format(
                family_id, epoch, window_batches, running_loss / window_batches))

    print('Finished Training')

    return net

In [7]:
import os
# One entry per family in the split dataset. The names are strings, so the sort
# is lexicographic (for example '8' comes after '79').
fam_ids = os.listdir('../input/herbarium22-family-split/dataset_fsplit/')
fam_ids.sort()

In [None]:
# Train one classifier per family and save it as checkpoint_<family_id>_final.pth.
# NOTE(review): the slice starts at index 195 -- presumably the resume point of an
# interrupted run; use fam_ids[0:] to train every family from scratch.
for family_id in fam_ids[195:]:

    family_dir = f'../input/herbarium22-family-split/dataset_fsplit/{family_id}'
    train_X = pd.read_csv(f'{family_dir}/train_X.csv')
    train_Y = pd.read_csv(f'{family_dir}/train_Y.csv')

    n_categories = train_Y['category_id'].nunique()

    # Map each category_id of this family to a contiguous class index.
    keys = dict(zip(sorted(train_Y['category_id'].unique()), range(n_categories)))
    print(keys)

    train_dataset = CategoryDataset(train_X, train_Y, transform, keys)

    net = EfficientNetB0(in_sz=1536, out_sz=n_categories)
    net.to(device)

    n_samples = len(train_dataset)
    if n_samples >= 2:
        # Cap the batch size at the dataset size: with a fixed batch of 64 and
        # drop_last=True, a family with fewer than 64 samples would yield zero
        # batches and its network would be saved without any training.
        # Shuffle so batches are not tied to the row order of the CSV files.
        training_loader = torch.utils.data.DataLoader(train_dataset,
                                                      batch_size=min(64, n_samples),
                                                      shuffle=True,
                                                      drop_last=True)
        net = do_training(training_loader, net, family_id)
    else:
        # Nothing meaningful to train on; still write the checkpoint so every
        # family has a file.
        print(f'{family_id}: only {n_samples} sample(s), saving the untrained network')

    torch.save(net, f'./checkpoint_{family_id}_final.pth')

In [9]:
# Bundle the per-family final checkpoints into one archive at maximum compression (-9).
# The glob only matches files named checkpoint_<id>_final.pth.
!zip -r -9 models_.zip ./checkpoint_*_final.pth

  adding: checkpoint_70_final.pth (deflated 8%)
  adding: checkpoint_71_final.pth (deflated 8%)
  adding: checkpoint_72_final.pth (deflated 8%)
  adding: checkpoint_73_final.pth (deflated 8%)
  adding: checkpoint_74_final.pth (deflated 8%)
  adding: checkpoint_75_final.pth (deflated 8%)
  adding: checkpoint_76_final.pth (deflated 9%)
  adding: checkpoint_77_final.pth (deflated 8%)
  adding: checkpoint_78_final.pth (deflated 8%)
  adding: checkpoint_79_final.pth (deflated 8%)
  adding: checkpoint_81_final.pth (deflated 8%)
  adding: checkpoint_83_final.pth (deflated 8%)
  adding: checkpoint_84_final.pth (deflated 8%)
  adding: checkpoint_85_final.pth (deflated 8%)
  adding: checkpoint_86_final.pth (deflated 8%)
  adding: checkpoint_87_final.pth (deflated 8%)
  adding: checkpoint_88_final.pth (deflated 8%)
  adding: checkpoint_89_final.pth (deflated 8%)
  adding: checkpoint_8_final.pth (deflated 8%)
  adding: checkpoint_90_final.pth (deflated 8%)
  adding: checkpoint_