In [1]:
import os

# Respect an XRT_TPU_CONFIG already provided by the runtime; only fall back to
# this hard-coded TPU worker address when the variable is not set.
# NOTE(review): training below uses a GPU device, so this may be a leftover.
os.environ.setdefault("XRT_TPU_CONFIG", "tpu_worker;0;10.0.0.2:8470")

# Creating a pipeline to load the dataset for training

In [2]:
!pip install timm
!rm -rf GLC
!git clone https://github.com/maximiliense/GLC
import timm
from GLC.metrics import top_30_error_rate
import pytorch_lightning as pl
import torch
import os
from pathlib import Path
import pandas as pd

import torch.nn.functional as F
from torch.utils.data import DataLoader
import albumentations as A
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from skimage.io import imread
import numpy as np


def get_patch_image(sample, image, path="../input/geolifeclef-2022-lifeclef-2022-fgvc9"):
    """Build the file path of one image belonging to an observation's patch.

    ``image`` is the filename suffix (e.g. ``'rgb.jpg'``). The result is
    ``<patch directory>/<sample>_<image>``, where the directory comes from
    ``get_patch(sample, path)``.
    """
    folder = get_patch(sample, path)
    filename = '_'.join((str(sample), image))
    return '/'.join((folder, filename))


def get_patch_rgb(sample, path="../input/geolifeclef-2022-lifeclef-2022-fgvc9"):
    """Return the path of the ``<sample>_rgb.jpg`` file for observation ``sample``."""
    return get_patch_image(sample, image='rgb.jpg', path=path)


def get_patch(sample, path="../input/geolifeclef-2022-lifeclef-2022-fgvc9"):
    """Return the directory that holds the patch files of observation ``sample``.

    Layout: ``<path>/patches-<fr|us>/<last 2 digits>/<digits [-4:-2]>``.
    Ids whose first character is ``'1'`` go to ``patches-fr``; every other
    id goes to ``patches-us``.
    """
    digits = str(sample)
    region = 'fr' if digits[0] == '1' else 'us'
    return '/'.join((path, 'patches-' + region, digits[-2:], digits[-4:-2]))


def get_country(sample, path="../input/geolifeclef-2022-lifeclef-2022-fgvc9"):
    """Return a numeric country code for observation ``sample``.

    0 when the id's first character is ``'1'`` (France), otherwise 1 (US).
    ``path`` is accepted for signature parity with the other helpers and
    is not used.
    """
    first_char = str(sample)[0]
    return int(first_char != '1')


class RGBDataset(torch.utils.data.Dataset):
    """Map-style dataset over RGB patch images referenced by file path.

    Items are ``(image, label)`` when ``labels`` is given. Otherwise they are
    ``(image, observation_id)``, where the id is the filename prefix before
    the first ``'_'``, returned as a string. ``trans`` is called as
    ``trans(image=img)['image']`` (albumentations style).
    """

    def __init__(self, images, labels=None, trans=None):
        self.images = images
        self.labels = labels
        self.trans = trans

    def __len__(self):
        return len(self.images)

    def __getitem__(self, ix):
        file_path = self.images[ix]
        img = imread(file_path)
        if self.trans is not None:
            img = self.trans(image=img)['image']
        if self.labels is None:
            # Unlabelled split: recover the id from "<id>_<suffix>".
            file_name = file_path.split('/')[-1]
            return img, file_name.split('_')[0]
        return img, self.labels[ix]


class RGNirDataset(torch.utils.data.Dataset):
    """Dataset stacking the red and green RGB channels with the near-IR band.

    For each observation id, the ``<id>_rgb.jpg`` and ``<id>_near_ir.jpg`` files
    are read from the directory returned by ``get_patch``. The output image keeps
    the first two RGB channels and appends NIR as the third channel.

    Args:
        observation_ids: sequence of observation ids.
        labels: optional sequence aligned with ``observation_ids``; when
            ``None`` items are ``(image, observation_id)``.
        trans: optional callable used as ``trans(image=img)['image']``.
        path: dataset root forwarded to ``get_patch`` (str or path-like);
            defaults to the same Kaggle input directory as the helpers.
    """

    def __init__(self, observation_ids, labels=None, trans=None,
                 path="../input/geolifeclef-2022-lifeclef-2022-fgvc9"):
        self.observation_ids = observation_ids
        self.labels = labels
        self.trans = trans
        # get_patch concatenates strings, so normalise Path objects here.
        self.path = str(path)

    def __len__(self):
        return len(self.observation_ids)

    def __getitem__(self, ix):
        observation_id = self.observation_ids[ix]
        prefix = get_patch(observation_id, self.path) + '/' + str(observation_id)
        rgb = imread(prefix + '_rgb.jpg')
        nir = imread(prefix + '_near_ir.jpg')
        # Keep only R and G; NIR occupies the third channel slot.
        img = np.concatenate((rgb[..., :2], np.expand_dims(nir, axis=-1)), axis=2)
        if self.trans is not None:
            img = self.trans(image=img)['image']
        if self.labels is not None:
            return img, self.labels[ix]
        return img, observation_id


class RGBDataModule(pl.LightningDataModule):
    """LightningDataModule serving RGB patches for GeoLifeCLEF 2022.

    ``setup`` reads ``<path>/observations/observations_{fr,us}_{train,test}.csv``
    (``;``-separated). It splits the train CSV into train/val on its ``subset``
    column and wraps each split in an ``RGBDataset``.

    Args:
        batch_size: default batch size of every dataloader.
        path: dataset root, used both for the CSVs and for the RGB image paths.
        num_workers: forwarded to ``DataLoader``.
        pin_memory: forwarded to ``DataLoader``.
        train_trans: optional mapping ``{albumentations transform name: kwargs}``,
            composed in iteration order into the training augmentation.
    """

    def __init__(self, batch_size=32, path="../input/geolifeclef-2022-lifeclef-2022-fgvc9", num_workers=0, pin_memory=False, train_trans=None):
        super().__init__()
        self.batch_size = batch_size
        self.path = Path(path)
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.train_trans = train_trans

    def read_data(self, mode="train"):
        """Return the French and US observations of ``mode`` as one DataFrame."""
        frames = [
            pd.read_csv(self.path / 'observations' /
                        f'observations_{country}_{mode}.csv', sep=';')
            for country in ('fr', 'us')
        ]
        # ignore_index yields one unique RangeIndex instead of two overlapping ones.
        return pd.concat(frames, ignore_index=True)

    def split_data(self):
        """Split ``self.data`` into train/val frames using its ``subset`` column."""
        self.data_train = self.data[self.data['subset'] == 'train']
        self.data_val = self.data[self.data['subset'] == 'val']

    def _build_train_transform(self):
        """Compose ``self.train_trans`` into an ``A.Compose`` pipeline, or return None."""
        if self.train_trans is None:
            return None
        return A.Compose([
            getattr(A, name)(**params) for name, params in self.train_trans.items()
        ])

    def generate_datasets(self):
        """Create the augmented train, plain val and unlabelled test datasets."""
        self.ds_train = RGBDataset(
            self.data_train.image.values, self.data_train.species_id.values,
            trans=self._build_train_transform())
        self.ds_val = RGBDataset(
            self.data_val.image.values, self.data_val.species_id.values)
        self.ds_test = RGBDataset(self.data_test.image.values)

    def print_dataset_info(self):
        """Print the size of each split."""
        print('train:', len(self.ds_train))
        print('val:', len(self.ds_val))
        print('test:', len(self.ds_test))

    def setup(self, stage=None):
        """Load the CSVs, attach RGB image paths under ``self.path`` and build datasets."""
        # get_patch_rgb builds paths by string concatenation, so pass a str root.
        root = str(self.path)
        self.data = self.read_data()
        self.data['image'] = self.data['observation_id'].apply(
            lambda obs_id: get_patch_rgb(obs_id, root))
        self.data_test = self.read_data('test')
        self.data_test['image'] = self.data_test['observation_id'].apply(
            lambda obs_id: get_patch_rgb(obs_id, root))
        self.split_data()
        self.generate_datasets()
        self.print_dataset_info()

    def get_dataloader(self, ds, batch_size=None, shuffle=None):
        """Wrap ``ds`` in a DataLoader; ``None`` arguments fall back to defaults."""
        return DataLoader(
            ds,
            batch_size=batch_size if batch_size is not None else self.batch_size,
            shuffle=shuffle if shuffle is not None else True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory
        )

    def train_dataloader(self, batch_size=None, shuffle=True):
        return self.get_dataloader(self.ds_train, batch_size, shuffle)

    def val_dataloader(self, batch_size=None, shuffle=False):
        return self.get_dataloader(self.ds_val, batch_size, shuffle)

    def test_dataloader(self, batch_size=None, shuffle=False):
        return self.get_dataloader(self.ds_test, batch_size, shuffle)


class RGNirDataModule(RGBDataModule):
    """``RGBDataModule`` variant producing RG + NIR images through ``RGNirDataset``.

    Datasets are indexed by ``observation_id`` directly, so ``setup`` does not
    build the per-row ``image`` column that the RGB parent needs.

    NOTE(review): ``self.path`` only affects the CSV reads here. The datasets
    are constructed without a root, so image files are resolved from the
    default root used by ``RGNirDataset``/``get_patch``. Confirm this before
    passing a non-default ``path``.
    """

    def __init__(self, batch_size=32, path="../input/geolifeclef-2022-lifeclef-2022-fgvc9", num_workers=0, pin_memory=False, train_trans=None):
        super().__init__(
            batch_size=batch_size,
            path=path,
            num_workers=num_workers,
            pin_memory=pin_memory,
            train_trans=train_trans,
        )

    def generate_datasets(self):
        """Build the augmented train, plain val and unlabelled test datasets."""
        train_ids = self.data_train.observation_id.values
        train_labels = self.data_train.species_id.values
        if self.train_trans is None:
            augment = None
        else:
            augment = A.Compose([
                getattr(A, name)(**kwargs)
                for name, kwargs in self.train_trans.items()
            ])
        self.ds_train = RGNirDataset(train_ids, train_labels, trans=augment)

        val = self.data_val
        self.ds_val = RGNirDataset(val.observation_id.values, val.species_id.values)
        self.ds_test = RGNirDataset(self.data_test.observation_id.values)

    def setup(self, stage=None):
        """Load the train/test CSVs, split train/val, then build and summarise the datasets."""
        self.data = self.read_data()
        self.data_test = self.read_data('test')
        self.split_data()
        self.generate_datasets()
        self.print_dataset_info()


class RGBNirDataModule(RGNirDataModule):
    """Data module yielding 4-channel RGB + NIR images through ``RGBNirDataset``.

    Reuses the CSV loading, splitting and ``setup`` of ``RGNirDataModule``;
    only the dataset class differs.

    NOTE(review): as in the parent, ``self.path`` is not forwarded to the
    datasets, so images come from the dataset's default root. Confirm this
    before using a non-default ``path``.
    """

    def __init__(self, batch_size=32, path="../input/geolifeclef-2022-lifeclef-2022-fgvc9", num_workers=0, pin_memory=False, train_trans=None):
        super().__init__(
            batch_size=batch_size,
            path=path,
            num_workers=num_workers,
            pin_memory=pin_memory,
            train_trans=train_trans,
        )

    def generate_datasets(self):
        """Build the augmented train, plain val and unlabelled test datasets."""
        train_ids = self.data_train.observation_id.values
        train_labels = self.data_train.species_id.values
        if self.train_trans is None:
            augment = None
        else:
            augment = A.Compose([
                getattr(A, name)(**kwargs)
                for name, kwargs in self.train_trans.items()
            ])
        self.ds_train = RGBNirDataset(train_ids, train_labels, trans=augment)

        val = self.data_val
        self.ds_val = RGBNirDataset(val.observation_id.values, val.species_id.values)
        self.ds_test = RGBNirDataset(self.data_test.observation_id.values)


class RGBNirDataset(torch.utils.data.Dataset):
    """Dataset stacking all RGB channels with the near-IR band (4 channels).

    For each observation id, the ``<id>_rgb.jpg`` and ``<id>_near_ir.jpg`` files
    are read from the directory returned by ``get_patch``. NIR is appended
    after the RGB channels along the last axis.

    Args:
        observation_ids: sequence of observation ids.
        labels: optional sequence aligned with ``observation_ids``; when
            ``None`` items are ``(image, observation_id)``.
        trans: optional callable used as ``trans(image=img)['image']``.
        path: dataset root forwarded to ``get_patch`` (str or path-like);
            defaults to the same Kaggle input directory as the helpers.
    """

    def __init__(self, observation_ids, labels=None, trans=None,
                 path="../input/geolifeclef-2022-lifeclef-2022-fgvc9"):
        self.observation_ids = observation_ids
        self.labels = labels
        self.trans = trans
        # get_patch concatenates strings, so normalise Path objects here.
        self.path = str(path)

    def __len__(self):
        return len(self.observation_ids)

    def __getitem__(self, ix):
        observation_id = self.observation_ids[ix]
        prefix = get_patch(observation_id, self.path) + '/' + str(observation_id)
        rgb = imread(prefix + '_rgb.jpg')
        nir = imread(prefix + '_near_ir.jpg')
        # Single-band NIR gains a trailing channel axis so it can be stacked after RGB.
        img = np.concatenate((rgb, np.expand_dims(nir, axis=-1)), axis=2)
        if self.trans is not None:
            img = self.trans(image=img)['image']
        if self.labels is not None:
            return img, self.labels[ix]
        return img, observation_id

class RGBModule(pl.LightningModule):
    """LightningModule wrapping a timm classifier for GeoLifeCLEF species.

    ``hparams`` is stored via ``save_hyperparameters`` and must provide:
    ``backbone`` (timm model name), ``pretrained``, ``optimizer`` (name of a
    ``torch.optim`` class) and ``optimizer_params`` (its keyword arguments).
    Subclasses change the number of input channels through ``_in_chans``,
    so exactly one backbone is instantiated per module.
    """

    # Size of the classifier head.
    _num_classes = 17037
    # Input channels of the backbone (3 = RGB).
    _in_chans = 3

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.model = timm.create_model(
            self.hparams.backbone,
            pretrained=self.hparams.pretrained,
            num_classes=self._num_classes,
            in_chans=self._in_chans,
        )

    def forward(self, x):
        """Scale ``x`` by 1/255, move channels first, and run the backbone.

        Assumes ``x`` is channels-last (B, H, W, C); the permute below moves C
        to dim 1. Values are presumably 0-255 (uint8 patches) — TODO confirm.
        """
        x = x.float() / 255
        x = x.permute(0, 3, 1, 2)
        return self.model(x)

    def predict(self, x):
        """Return softmax class probabilities for ``x`` without tracking gradients.

        The module is switched to eval mode for the call and its previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                preds = self(x.to(self.device))
                return torch.softmax(preds, dim=1)
        finally:
            self.train(was_training)

    def shared_step(self, batch, batch_idx):
        """Compute cross-entropy loss and top-30 error rate for one batch."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        # The metric is computed on detached CPU tensors, outside the graph.
        error = top_30_error_rate(
            y.cpu(), torch.softmax(y_hat, dim=1).cpu().detach())
        return loss, error

    def training_step(self, batch, batch_idx):
        loss, error = self.shared_step(batch, batch_idx)
        self.log('loss', loss)
        self.log('error', error, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, error = self.shared_step(batch, batch_idx)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_error', error, prog_bar=True)

    def configure_optimizers(self):
        """Instantiate ``torch.optim.<hparams.optimizer>`` with ``optimizer_params``."""
        optimizer_cls = getattr(torch.optim, self.hparams.optimizer)
        return optimizer_cls(self.parameters(), **self.hparams['optimizer_params'])


class RGBNirModule(RGBModule):
    """``RGBModule`` whose backbone is created with 4 input channels (RGB + NIR).

    The constructor is inherited; only the channel count differs, so no
    throw-away 3-channel backbone is built.
    """

    _in_chans = 4

# Loading the Dataset Module

In [3]:
dm = RGBNirDataModule()
dm.setup()  # reads the observation CSVs, builds the splits and prints their sizes

# Checking whether the data has been loaded

In [4]:
# One shuffled training batch of 25 samples, for a quick shape / value-range check.
imgs, labels = next(iter(dm.train_dataloader(batch_size=25)))
(imgs.shape, imgs.dtype, imgs.max(), imgs.min(), labels)

In [5]:
import matplotlib.pyplot as plt

# Top row: RGB channels; bottom row: the NIR band, for the first few samples.
n_show = min(5, len(imgs))
fig, axes = plt.subplots(2, n_show, figsize=(15, 5), squeeze=False)
for i in range(n_show):
    rgb_ax, nir_ax = axes[0, i], axes[1, i]
    rgb_ax.imshow(imgs[i][..., :3].numpy())
    rgb_ax.set_title('rgb')
    rgb_ax.axis('off')
    # Single-band image: a grey colormap avoids misleading false colour.
    nir_ax.imshow(imgs[i][..., 3].numpy(), cmap='gray')
    nir_ax.set_title('nir')
    nir_ax.axis('off')
plt.tight_layout()
plt.show()

# Setting base model parameters

In [6]:
hparams = {
    'backbone': 'resnet18',
    'pretrained': True,
    # Resolved by configure_optimizers via getattr(torch.optim, ...).
    'optimizer': 'Adam',
    'optimizer_params': {
        'lr': 1e-3
    }
}

# Smoke test: forward the sample batch through the 4-channel model.
module = RGBNirModule(hparams)
outputs = module(imgs)
outputs.shape

# Training Steps

In [7]:
hparams = {
    'datamodule': {
        'batch_size': 64,
        'num_workers': 2,
        'pin_memory': False
    },
    'backbone': 'resnet18',
    'pretrained': True,
    'optimizer': 'Adam',
    'optimizer_params': {
        'lr': 1e-3
    }
}

dm = RGBNirDataModule(**hparams['datamodule'])
dm.setup()
module = RGBNirModule(hparams)

# `accelerator`/`devices` replace the `gpus` flag, which was removed in Lightning 2.0.
# `logger=False` disables logging explicitly; in Lightning 2.x `logger=None`
# selects the default logger instead.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=[0],
    max_epochs=10,
    enable_checkpointing=False,
    logger=False,
    overfit_batches=0
)

trainer.fit(module, dm)

In [None]:
# Fresh data module for evaluation; pinned host memory speeds host-to-GPU copies.
dm = RGBNirDataModule(pin_memory=True, batch_size=64)
dm.setup()

# Evaluating Top-30 Error Rate

In [None]:
from tqdm.auto import tqdm

module.cuda(0)
dl = dm.val_dataloader()
# Weight each batch's error rate by its size, so a smaller final batch does not
# skew the result. Assumes top_30_error_rate returns the mean over the batch
# — TODO confirm against GLC.metrics.
error_sum, n_seen = 0.0, 0
for imgs, labels in tqdm(dl, desc='validation'):
    preds = module.predict(imgs)
    error_sum += top_30_error_rate(labels, preds.cpu()) * len(labels)
    n_seen += len(labels)
val_top30_error = error_sum / n_seen if n_seen else float('nan')
val_top30_error