In [None]:
! pip install -q timm \
  https://pip.repos.neuron.amazonaws.com/torch-xla/torch_xla-1.13.0%2Btorchneuron3-cp37-cp37m-linux_x86_64.whl

In [None]:
# A `! export ...` line runs in a throw-away subshell, so the variable never
# reaches the kernel process. Set it on the kernel's own environment instead.
# Keep this cell ahead of the torch_xla import so the setting is already in
# place when the library loads.
import os

os.environ["XLA_USE_BF16"] = "1"

In [None]:
import glob
import io

import pandas as pd

import tensorflow as tf
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torchvision.transforms as T

from joblib import Parallel, delayed
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from tqdm.auto import tqdm
from sklearn.metrics import f1_score

In [None]:
class config:
    """Hyper-parameters shared by every cross-validation fold."""

    # Mini-batch size used by both the training and the validation loaders.
    BATCH_SIZE = 32
    # Number of cross-validation folds. train_fn places each fold on its own
    # XLA device, so this must be equal to or less than 8.
    N_SPLITS = 5
    # Number of passes over the training split for each fold.
    MAX_EPOCHS = 5

In [None]:
def tfrecords_to_dataframe(fp, test=False):
    """Read every TFRecord file matching a glob pattern into a DataFrame.

    Parameters
    ----------
    fp : str
        Glob pattern selecting the TFRecord files to read.
    test : bool, default False
        When True the records are parsed without the ``class`` feature and
        the returned frame has no ``lab`` column.

    Returns
    -------
    pandas.DataFrame
        One row per record with the columns ``id`` (UTF-8 decoded string),
        ``img`` (the bytes stored in the ``image`` feature) and, unless
        ``test`` is set, ``lab`` (the value of the ``class`` feature).
    """
    feature_spec = {
        "id": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
    }
    if not test:
        feature_spec["class"] = tf.io.FixedLenFeature([], tf.int64)

    # glob returns matches in arbitrary, filesystem-dependent order. Sorting
    # makes the row order (and therefore any seeded split computed on the
    # resulting frame) reproducible between runs.
    paths = sorted(glob.glob(fp))

    columns = {"id": [], "img": []}
    if not test:
        columns["lab"] = []

    records = tf.data.TFRecordDataset(paths).map(
        lambda pb: tf.io.parse_single_example(pb, feature_spec)
    )
    for sample in records:
        columns["id"].append(sample["id"].numpy().decode("utf-8"))
        columns["img"].append(sample["image"].numpy())
        if not test:
            columns["lab"].append(sample["class"].numpy())
    return pd.DataFrame(columns)

In [None]:
class AverageMeter():
    """Running average tracker that also remembers the latest value."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running mean, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as seen ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count

In [None]:
# Pool the provided train and val shards into a single labelled frame; the
# cross-validation folds assigned below are used instead of that split.
# ignore_index=True already gives the result a fresh 0..n-1 index.
train_df = pd.concat(
    [
        tfrecords_to_dataframe(pattern)
        for pattern in (
            "../input/tpu-getting-started/tfrecords-jpeg-224x224/train/*.tfrec",
            "../input/tpu-getting-started/tfrecords-jpeg-224x224/val/*.tfrec",
        )
    ],
    ignore_index=True,
)

# Stratify on the label so every fold keeps the overall class balance.
cv = StratifiedKFold(n_splits=config.N_SPLITS, shuffle=True, random_state=42)

# -1 means "not assigned yet"; each row is then tagged with the index of the
# fold in which it serves as validation data.
train_df['fold'] = -1
for fold, (train_idx, val_idx) in enumerate(cv.split(train_df, train_df['lab'])):
    train_df.loc[val_idx, 'fold'] = fold

In [None]:
class PetalDataset(Dataset):
    """Dataset over a DataFrame whose ``img`` column holds encoded images.

    Each item is an ``(image_tensor, label)`` pair, or only the image tensor
    when the dataset was built with ``test=True``.
    """

    def __init__(self, df, test=False):
        self.test = test
        self.df = df
        # Converts the decoded PIL image into a tensor.
        self.transform = T.ToTensor()

    def __len__(self):
        n_rows = len(self.df)
        return n_rows

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # The image is stored as raw bytes, so decode it from an in-memory buffer.
        tensor = self.transform(Image.open(io.BytesIO(row["img"])))
        # The label column is only read when labels are expected.
        return tensor if self.test else (tensor, row["lab"])

In [None]:
# Build the pretrained model once up front and immediately discard it.
# NOTE(review): presumably this fetches the pretrained weights ahead of time
# so that the per-fold create_model(..., pretrained=True) calls in train_fn,
# which run concurrently in threads, do not all start the download at once —
# confirm against timm's caching behaviour.
model = timm.create_model('efficientnet_b0', pretrained=True)
del model

In [None]:
# Inverse-frequency class weights, ordered by ascending class id.
# NOTE(review): entry i lines up with label i only if every class id occurs
# at least once in train_df — confirm for the data in use.
counts = train_df['lab'].value_counts()
class_weights = torch.tensor(
    (1 / counts.sort_index()).to_numpy(),
    dtype=torch.float,
)

In [None]:
def train_fn(fold):
    """Train and evaluate one cross-validation fold on its own XLA device.

    Rows of ``train_df`` whose ``fold`` column equals ``fold`` form the
    validation split; all remaining rows are used for training. A pretrained
    EfficientNet-B0 with a 104-way head is trained for ``config.MAX_EPOCHS``
    epochs with Adam and a one-cycle learning-rate schedule, using
    ``class_weights`` in the cross-entropy loss.

    Parameters
    ----------
    fold : int
        Zero-based fold index. The fold runs on XLA device ``fold + 1``.

    Returns
    -------
    float
        Macro-averaged F1 score on the validation split after the last epoch.

    Side effects
    ------------
    Saves the final weights to ``model_fold_{fold}.pt`` via ``xm.save``.
    """
    device = xm.xla_device(fold + 1)

    val_ = train_df.query('fold == @fold')
    train_ = train_df.query('fold != @fold')

    train_ds = PetalDataset(train_)
    val_ds = PetalDataset(val_)

    train_loader = DataLoader(train_ds, batch_size=config.BATCH_SIZE, drop_last=True, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config.BATCH_SIZE)

    model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=104)
    model.to(device)
    optimizer = optim.Adam(model.parameters())
    # The scheduler is stepped once per batch, hence steps_per_epoch.
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-4,
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        epochs=config.MAX_EPOCHS,
    )
    # The weight tensor has to live on the same device as the logits it is
    # combined with, so move it there instead of leaving it on the CPU.
    loss_fn = nn.CrossEntropyLoss(weight=class_weights.to(device))
    meter = AverageMeter()

    for epoch in range(config.MAX_EPOCHS):
        model.train()
        stream = tqdm(train_loader, desc=f"Fold={fold}, Epoch={epoch}")
        for data, target in stream:
            optimizer.zero_grad()
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()

            optimizer.step()
            xm.mark_step()
            scheduler.step()

            meter.update(loss.item())

            stream.set_postfix({
                "train_loss": meter.avg,
            })
        # One meter is shared by both phases; clear it between them.
        meter.reset()

        model.eval()
        with torch.no_grad():
            stream = tqdm(val_loader, desc=f"Fold={fold}, Validating...")
            for data, target in stream:
                data = data.to(device)
                target = target.to(device)
                output = model(data)
                loss = loss_fn(output, target)
                meter.update(loss.item())

                stream.set_postfix({
                    "val_loss": meter.avg,
                })
        meter.reset()

    model.eval()
    xm.save(model.state_dict(), f"model_fold_{fold}.pt")

    y_true = []
    y_pred = []

    with torch.no_grad():
        # Iterate the loader itself rather than the last (already exhausted)
        # progress-bar wrapper, which would not even exist with zero epochs.
        for data, target in val_loader:
            output = model(data.to(device))
            y_pred.extend(output.argmax(dim=1).cpu().numpy())
            # reshape(-1) keeps the labels 1-D even for a final batch of one
            # sample; squeeze() would collapse that to a 0-d array, which
            # list.extend cannot iterate. The labels never leave the CPU.
            y_true.extend(target.reshape(-1).numpy())

    val_score = f1_score(y_true, y_pred, average='macro')
    return val_score

In [None]:
# Launch one training job per fold, each in its own thread; the cell output
# is the list of values returned by train_fn, in fold order.
fold_jobs = (delayed(train_fn)(fold_id) for fold_id in range(config.N_SPLITS))
Parallel(n_jobs=config.N_SPLITS, backend="threading")(fold_jobs)