In [1]:
!pip install augmax

Collecting augmax
  Downloading augmax-0.3.1-py3-none-any.whl (21 kB)
Collecting einops>=0.3
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: einops, augmax
Successfully installed augmax-0.3.1 einops-0.6.1
[0m

In [2]:
! pip install orbax

[0m

In [3]:
!pip install nest-asyncio

[0m

In [4]:
import nest_asyncio
nest_asyncio.apply()

In [5]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

In [6]:
## jax
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap
from jax.experimental.pjit import pjit
from jax import lax
from jax import random

from jax import make_jaxpr
from jax import device_put

## Flax
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax.training import train_state
from flax import traverse_util

from orbax.checkpoint import*

# Optimizer for flax
import optax

import augmax as AUX
from PIL import Image
import cv2
import numpy as np
import pandas as pd
from sklearn import datasets
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold, StratifiedKFold

from transformers import FlaxViTModel,FlaxCLIPModel,FlaxBeitModel

import os
import gc
from functools import partial
from tqdm import tqdm
import seaborn as sns

# Import Data

In [7]:
train_df_orig = pd.read_csv('../input/hackathon-online-phuket-landmark-recognition/train.csv')
test_df_orig = pd.read_csv('../input/hackathon-online-phuket-landmark-recognition/test.csv')
train_df = train_df_orig.rename( columns = {'id':'filename'} )
test_df_submit = test_df_orig.rename( columns = {'id':'filename'} )

In [8]:
m_train_path =  '../input/hackathon-online-phuket-landmark-recognition/images/images/train/'
m_test_path = '../input/hackathon-online-phuket-landmark-recognition/images/images/test/'

In [9]:
train_df['filepath'] = [ m_train_path + x for x in train_df['filename'] ]
test_df_submit['filepath'] = [ m_test_path + x for x in test_df_submit['filename'] ]

In [10]:
train_df

Unnamed: 0,filename,label,filepath
0,dc8ca8843cc05c937ae4086da5ad49f1.jpg,7,../input/hackathon-online-phuket-landmark-reco...
1,abf428c748961ce38012c04fb4f67a0a.jpg,10,../input/hackathon-online-phuket-landmark-reco...
2,7bef6daf30000b2bb9e57af7bc87b780.jpg,1,../input/hackathon-online-phuket-landmark-reco...
3,db6f12a84dedc23e3d55320a9149d69a.jpg,7,../input/hackathon-online-phuket-landmark-reco...
4,8d17fcf554881b42a070162c19f73f3a.jpg,8,../input/hackathon-online-phuket-landmark-reco...
...,...,...,...
2720,50db1fe1991abf6aaf36f6f4f1b66cd4.jpg,3,../input/hackathon-online-phuket-landmark-reco...
2721,c0cdb5a4aa8060e763bcaae3cd7cc702.jpg,14,../input/hackathon-online-phuket-landmark-reco...
2722,c953c76faa7580b3b00be8856f6e9d98.jpg,8,../input/hackathon-online-phuket-landmark-reco...
2723,174cae6121652e3527460e8f52672875.jpg,4,../input/hackathon-online-phuket-landmark-reco...


# EDA

In [11]:
# rand_arr = np.random.randint(0,135,5)
# vrows = 5
# fig, axes = plt.subplots(15,vrows,figsize = (12,12))
# for i in tqdm(range(15)):
#     for j,d in enumerate(rand_arr):
#         img_path = train_df[train_df.label == i].iloc[d,0]
#         img_path = m_train_path + img_path
#         img_arr = cv2.imread(img_path)
#         img_arr = cv2.cvtColor(img_arr, cv2.COLOR_BGR2RGB)
        
#         ax = axes[i,j]
#         ax.imshow(img_arr)
#         ax.set_title(f'c : {i}')
    

In [12]:
# sns.countplot(x = train_df['label'])

# Upsampling

In [13]:
# up_df = []
# l_max = train_df['label'].value_counts().max()
# np.random.seed(1995)
# for c in range(train_df['label'].nunique()):
#     df = train_df[train_df['label'] == c].reset_index(drop = True)
#     up_df.append( 
#         df.loc[np.random.randint(0,len(df), l_max - len(df))]
#     )

# print(l_max)

In [14]:
# up_df.append(train_df)
# upsamp_train_df = pd.concat(up_df,ignore_index = True)

In [15]:
# sns.countplot(x = upsamp_train_df['label'])

In [16]:
# upsamp_train_df = upsamp_train_df.sample(frac = 1, random_state = 1995).reset_index(drop = True)

In [17]:
upsamp_train_df = train_df.sample(frac = 1, random_state = 1995).reset_index(drop = True)

In [18]:
# upsamp_train_df

# Preprocessing

In [19]:
transform = AUX.Chain(
    AUX.HorizontalFlip(p = .2),
    AUX.VerticalFlip(p = .2),
    AUX.Rotate(p = .2),
    AUX.RandomContrast(range = (-.5,.5),p = .2),
    AUX.ByteToFloat(),
#     AUX.Normalize(),
)

In [20]:
class Dataload:
    def __init__(self,df,classes = None,img_size = 224,batch_size = 32,transform = None):
        self.classes = classes
        self.df = df
        self.batch_size = batch_size
        self.transform = transform
        self.rng = random.PRNGKey(1995)
        self.datasize = len(self.df)
        self.batch_idx = list(range(0,self.datasize,self.batch_size))
        self.img_size = img_size
        
        ### modify
        
    def get_batch(self,idx,train = True):
#         data_batch = np.zeros((self.batch_size,self.img_size,self.img_size,3))
        data_batch = []
        label_batch = list()
        sliced_df = self.df[ idx:idx + self.batch_size ]
        for i,data in enumerate(sliced_df.iterrows()):
            img = cv2.imread(data[1]['filepath'],cv2.IMREAD_COLOR)
#             img = jnp.array(img,dtype = jnp.float32)
            img = cv2.resize(img,(self.img_size,self.img_size))
            data_batch.append(img)
            
            if train:
                label = data[1]['label']
                z = np.zeros(self.classes)
                z[label] = 1
                label_batch.append(z)
        
        # Random mini Batch        
        if train:
            idx = np.arange(len(data_batch))
            idx = np.random.shuffle(idx)
            label_batch = np.array(label_batch)[idx,:]
            data_batch = np.array(data_batch)
            data_batch = data_batch[idx,:][0]
        
        data_batch = jnp.array(data_batch,dtype = jnp.float32)
        
        if self.transform:
            if train:
                data_batch = jax.jit(jax.vmap(self.transform, in_axes = [None, 0]))\
                                (self.rng, data_batch).block_until_ready()
            else:
                transform = AUX.Chain(
                    AUX.ByteToFloat(),
#                     AUX.Normalize(),
                )
                data_batch = jax.jit(jax.vmap(transform, in_axes = [None, 0]))\
                                (self.rng, data_batch).block_until_ready()
                
            
        return data_batch,jnp.array(label_batch)

# Modeling

In [21]:
class ClassifierHead(nn.Module):
    num_classes: int
    backbone: nn.Module
  

    @nn.compact
    def __call__(self, x):
        x = self.backbone(x).pooler_output
        stk = nn.Dense(512, name='head0', kernel_init= nn.initializers.glorot_uniform())(x)
        stk = nn.activation.relu(stk)
#         stk = nn.Dense(256, name='head1', kernel_init= nn.initializers.glorot_uniform())(stk)
#         stk = nn.Dense(256, name='head2', kernel_init= nn.initializers.glorot_uniform())(stk)
#         stk = nn.activation.relu(stk)
        stk = nn.Dense(15, name='head3', kernel_init= nn.initializers.glorot_uniform())(stk)
        return stk

In [22]:
clipmodel = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224")
model = ClassifierHead(num_classes = train_df['label'].nunique(), backbone = clipmodel.module)

Downloading (…)lve/main/config.json:   0%|          | 0.00/69.9k [00:00<?, ?B/s]

Downloading flax_model.msgpack:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [23]:
x = jnp.empty((1, 224, 224, 3))
variables = model.init(jax.random.PRNGKey(1996), x)
params = variables['params'].unfreeze()
params['backbone'] = clipmodel.params
params = freeze(params)

In [24]:
partition_optimizers = {'trainable': optax.adamw(0.5e-3), 'frozen': optax.set_to_zero()}
traverse = traverse_util.path_aware_map( lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params )
param_partitions = freeze(traverse)
tx = optax.multi_transform(partition_optimizers, param_partitions)

In [25]:
state = train_state.TrainState.create(
  apply_fn=model.apply,
  params=params,
  tx=tx,
)

# Train one step

In [26]:
@jit
def train_step(state, batch):
    """Train for a single step."""
    def loss_fn(params):
        logits = state.apply_fn(
            {'params': params},
            x=batch[0])
        loss = jnp.mean(optax.softmax_cross_entropy(logits, batch[1]))
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)

    metrics = {
    'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(batch[1],-1)),
    }
    return state, metrics


@jit
def eval_step(state, batch):
    """Eval for a single step."""
    logits = state.apply_fn(
        {'params': state.params},
        x=batch[0])
    loss = jnp.mean(optax.softmax_cross_entropy(logits, batch[1]))

    metrics = {
    'loss': loss,
      'accuracy': jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(batch[1],-1)),
    }
    return state, metrics

@jit
def predict_step(state, batch):
    """Eval for a single step."""
    logits = state.apply_fn(
        {'params': state.params},
        x=batch[0])
    return logits

# Train whole dataset

In [27]:
def _fit(state,train_df,batch_size):
    loss = []
    auc = []
    loader = Dataload(train_df,train_df['label'].nunique(),transform = transform,batch_size = batch_size)
    for idx,i in enumerate(tqdm(loader.batch_idx)):
        d = loader.get_batch(i)
        state, mtx = train_step(state,d)
        loss.append(mtx['loss'])
        auc.append(mtx['accuracy'])
        if (idx + 1) % 10 == 0:
            print(f'Loss: {(sum(loss) / (idx+1)):.4f} | AUC:{(sum(auc) / (idx+1)):.4f}')

    return state,loss,auc

def _eval(state,test_df,batch_size):
    loss = []
    auc = []
    transform_ev = AUX.Chain(
                    AUX.ByteToFloat(),
#                     AUX.Normalize(),
                )
    loader = Dataload(test_df,test_df['label'].nunique(),transform = transform_ev,batch_size = batch_size)
    for idx,i in enumerate(tqdm(loader.batch_idx)):
        d = loader.get_batch(i)
        state, mtx = eval_step(state,d)
        loss.append(mtx['loss'])
        auc.append(mtx['accuracy'])
    
    print(f'Loss: {(sum(loss) / len(loader.batch_idx)):.4f} | AUC:{(sum(auc) / len(loader.batch_idx)):.4f}')
    
    return state,loss,auc

def _predict(state,test_df,batch_size):
    logits = []
    loader = Dataload(test_df,0,transform = transform,batch_size = batch_size)
    for idx,i in enumerate(tqdm(loader.batch_idx)):
        d = loader.get_batch(i,train = False)
        logit = predict_step(state,d)
        logits.append(np.array(jnp.argmax(logit,-1)))
    
    return logits

In [28]:
fold = KFold(n_splits=3, random_state=1995, shuffle=True)
# Watch list
train_losses = []
test_losses = []

train_auc = []
test_auc = []
########

#check point manager
options = CheckpointManagerOptions(max_to_keep=5)
mngr = CheckpointManager(
          'zckpt',PyTreeCheckpointer(),options=options)
#####
for _ in range(1):
    for i, (train_idx, test_idx) in enumerate(fold.split( upsamp_train_df['filename'])):
        train_df = upsamp_train_df.loc[train_idx]
        test_df = upsamp_train_df.loc[test_idx]
        print(f'------- Fold {i+1} --------')
        state, loss, auc = _fit(state,train_df,64)
        train_losses.append(loss)
        train_auc.append(auc)
        print('------ Eval --------')
        state, loss, auc = _eval(state,test_df,64)
        test_losses.append(loss)
        test_auc.append(auc)
        mngr.save(i,{'model':state})
    
mngr.wait_until_finished()

------- Fold 1 --------


 34%|███▍      | 10/29 [00:36<00:48,  2.57s/it]

Loss: 1.6466 | AUC:0.5297


 69%|██████▉   | 20/29 [00:59<00:21,  2.43s/it]

Loss: 1.0665 | AUC:0.6969


100%|██████████| 29/29 [01:28<00:00,  3.04s/it]


------ Eval --------


100%|██████████| 15/15 [00:32<00:00,  2.19s/it]


Loss: 0.2862 | AUC:0.9084
------- Fold 2 --------


 34%|███▍      | 10/29 [00:19<00:42,  2.25s/it]

Loss: 0.2238 | AUC:0.9297


 69%|██████▉   | 20/29 [00:39<00:18,  2.04s/it]

Loss: 0.2051 | AUC:0.9391


100%|██████████| 29/29 [01:06<00:00,  2.28s/it]


------ Eval --------


100%|██████████| 15/15 [00:21<00:00,  1.41s/it]


Loss: 0.1073 | AUC:0.9688
------- Fold 3 --------


 34%|███▍      | 10/29 [00:19<00:39,  2.06s/it]

Loss: 0.0933 | AUC:0.9797


 69%|██████▉   | 20/29 [00:38<00:19,  2.14s/it]

Loss: 0.0972 | AUC:0.9742


100%|██████████| 29/29 [00:53<00:00,  1.86s/it]


------ Eval --------


100%|██████████| 15/15 [00:16<00:00,  1.12s/it]


Loss: 0.0498 | AUC:0.9917


In [29]:
#predict
logits = _predict(state,test_df_submit,32)
res = []
for ar in logits:
    for i in ar:
        res.append(i)

100%|██████████| 24/24 [00:42<00:00,  1.76s/it]


In [30]:
sub = pd.read_csv('../input/hackathon-online-phuket-landmark-recognition/submit.csv')
sub.predict = res

In [31]:
sub.to_csv("/kaggle/working/submit.csv", index = False)