In [1]:
from torchvision.models import resnet18

In [2]:
!pip install lmdb

Collecting lmdb
  Downloading lmdb-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Downloading lmdb-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (294 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.9/294.9 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: lmdb
Successfully installed lmdb-1.5.1


In [4]:
!pip install pytorch_metric_learning

Collecting pytorch_metric_learning
  Downloading pytorch_metric_learning-2.6.0-py3-none-any.whl.metadata (17 kB)
Downloading pytorch_metric_learning-2.6.0-py3-none-any.whl (119 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.3/119.3 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch_metric_learning
Successfully installed pytorch_metric_learning-2.6.0


In [77]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torch import nn
import cv2
import numpy as np
import os
from pathlib import Path
from pytorch_metric_learning import losses
from matplotlib import pyplot as plt
import torch.nn.functional as F
from tqdm import tqdm

In [78]:
from pathlib import Path

def evaluate(gt_path, pred_path):
    gt = dict()
    with open(gt_path) as gt_f:
        for line in gt_f:
            name, cls = line.strip().split()
            gt[name] = cls
    
    n_good = 0
    n_all = len(gt)
    with open(pred_path) as pred_f:
        for line in pred_f:
            name, cls = line.strip().split()
            if cls == gt[name]:
                n_good += 1
    
    return n_good / n_all

In [79]:
root = '/kaggle/input/chinese'
train_path = os.path.join(root, 'train.lmdb')
test_path = os.path.join(root, 'test.lmdb')
gt_path = '/kaggle/input/directory_gt/gt.txt'
pred_path = '/kaggle/working/pred.txt'

In [82]:
from itertools import permutations
import zipfile
from typing import Optional, List
from pathlib import Path
import numpy as np
from collections import defaultdict, Counter
import lmdb


class Vocabulary:
    def __init__(self, classes):
        self.classes = sorted(set(classes))
        self._class_to_index = dict((cls, idx) for idx, cls in enumerate(self.classes))
    
    def class_by_index(self, idx: int) -> str:
        return self.classes[idx]

    def index_by_class(self, cls: str) -> int:
        return self._class_to_index[cls]
    
    def num_classes(self) -> int:
        return len(self.classes)


class ArchivedHWDBReader:
    def __init__(self, path: Path):
        self.path = path
        self.archive = None
    
    def open(self):
        self.archive = zipfile.ZipFile(self.path)
    
    def namelist(self):
        return self.archive.namelist()
    
    def decode_image(self, name):
        sample = self.archive.read(name)
        buf = np.asarray(bytearray(sample), dtype='uint8')
        return cv2.imdecode(buf, cv2.IMREAD_GRAYSCALE)
    
    def close(self):
        self.archive.close()
    
    def __enter__(self):
        self.open()
        return self
    
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


GB = 2**30
class LMDBReader:
    def __init__(self, path: Path):
        self.path = path
        self.env = None
        self.namelist_ = []
    
    def open(self):
        self.env = lmdb.open(self.path, 
                             map_size=GB * 16,
                             lock=False, 
                             subdir=False, 
                             readonly=True)
        self.namelist_ = []
        with self.env.begin(buffers=True) as txn:
            cursor = txn.cursor()
            for key, _ in cursor:
                key = bytes(key).decode('utf-8')
                self.namelist_.append(key)
    
    def namelist(self):
        return self.namelist_
    
    def decode_image(self, name):
        key = name.encode('utf-8')
        with self.env.begin() as txn:
            sample = txn.get(key)
        buf = np.frombuffer(sample, dtype='uint8')
        return cv2.imdecode(buf, cv2.IMREAD_GRAYSCALE)
    
    def close(self):
        self.env.close()
    
    def __enter__(self):
        self.open()
        return self
    
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


class HWDBDatasetHelper:
    def __init__(self, reader, prefix='Train', vocabulary: Optional[Vocabulary]=None, namelist: Optional[List[str]]=None):
        self.reader = reader
        self.prefix = prefix
        self.index = defaultdict(list)
        self.counter = Counter()
        self.namelist = namelist
        if self.namelist is None:
            self.namelist = list(filter(lambda x: self.prefix in x, self.reader.namelist()))
        self.vocabulary = vocabulary
        self._build_index()
    
    def get_item(self, idx):
        name = self.namelist[idx]
        return self.reader.decode_image(name), \
            self.vocabulary.index_by_class(HWDBDatasetHelper._get_class(name))
    
    def size(self):
        return len(self.namelist)

    def get_all_class_items(self, idx):
        cls = self.vocabulary.class_by_index(idx)
        return self.index[cls]
    
    def most_common_classes(self, n=None):
        return self.counter.most_common(n)
    
    def train_val_split(self, train_part=0.8, seed=42):
        rnd = np.random.default_rng(seed)
        permutation = rnd.permutation(len(self.namelist))
        train_part = int(len(permutation) * train_part)
        train_names = [self.namelist[idx] for idx in permutation[:train_part]]
        val_names = [self.namelist[idx] for idx in permutation[train_part:]]

        return HWDBDatasetHelper(self.reader, self.prefix, self.vocabulary, train_names),\
            HWDBDatasetHelper(self.reader, self.prefix, self.vocabulary, val_names)
    
    @staticmethod
    def _get_class(name):
        return Path(name).parent.name
    
    def _build_index(self):
        classes = set()
        for idx, name in enumerate(self.namelist):
            cls = HWDBDatasetHelper._get_class(name)
            classes.add(cls)
            self.index[cls].append(idx)
            self.counter.update([cls])
        
        if self.vocabulary is None:
            self.vocabulary = Vocabulary(classes)

class HWDBDataset(Dataset):
    def __init__(self, helper: HWDBDatasetHelper):
        self.helper = helper
    
    def __len__(self):
        return self.helper.size()
    
    def __getitem__(self, idx):
        img, label = self.helper.get_item(idx)
        img = cv2.resize(img, (128, 128))
        img = torch.tensor(img, dtype=torch.float32).unsqueeze(0)
        img = (img - 127.5) / 255.
        return img, label


In [83]:
train_reader = LMDBReader(train_path)
train_reader.open()
train_helper = HWDBDatasetHelper(train_reader)
train_helper, val_helper = train_helper.train_val_split()

In [84]:
train_dataset = HWDBDataset(train_helper)
val_dataset = HWDBDataset(val_helper)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=4)

In [97]:
class CustomResNet(nn.Module):
    def __init__(self, embedding_size=512):
        super().__init__()
        self.resnet = resnet18()
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embedding_size)
#         self.resnet.fc = nn.Linear(num_features, embedding_size)
    
    def forward(self, x):
        x = x.expand(-1, 3, -1, -1)
        return self.resnet(x)

In [98]:
from pytorch_metric_learning import losses

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CustomResNet()
model = model.to(device)
optim = torch.optim.AdamW(model.parameters(), lr=0.001)
loss_fn = losses.ArcFaceLoss(num_classes=train_helper.vocabulary.num_classes(), embedding_size=512).to(device)

In [99]:
def run_validation(val_loader: DataLoader, model: nn.Module, n_steps=None):
    model.eval()
    n_good = 0
    n_all = 0
    wrapper = lambda x: x
    if n_steps is None:
        n_steps = len(val_loader)
        wrapper = tqdm
    
    with torch.no_grad():
        for batch, (X, y) in enumerate(wrapper(val_loader)):
            if batch == n_steps:
                break
            logits = model(X.to(torch.float32).to(device))
            X = model(X.to(torch.float32).to(device))
            logits = loss_fn.get_logits(X)
            classes = torch.argmax(logits, dim=1).cpu().numpy()
            n_good += sum(classes == y.cpu().numpy())
            n_all += len(classes)
    
    return n_good / n_all

In [101]:
for epoch in range(3):
    print(f'Epoch {epoch}:')
    for batch, (X, y) in enumerate(tqdm(train_loader)):
        model.train()
        
        logits = model(X.to(torch.float32).to(device))
        loss = loss_fn(logits, y.to(torch.long).to(device))
        
        optim.zero_grad()
        loss.backward()
        optim.step()

    torch.save(model.state_dict(), f'my_epoch{epoch}.pth')
    
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')

Epoch 0:


100%|██████████| 5036/5036 [24:33<00:00,  3.42it/s]
100%|██████████| 315/315 [04:37<00:00,  1.13it/s]


accuracy: 0.9388590602985686
Epoch 1:


100%|██████████| 5036/5036 [24:33<00:00,  3.42it/s]
100%|██████████| 315/315 [04:37<00:00,  1.14it/s]


accuracy: 0.9368159613036741
Epoch 2:


100%|██████████| 5036/5036 [24:33<00:00,  3.42it/s]
100%|██████████| 315/315 [04:37<00:00,  1.13it/s]

accuracy: 0.9577914673856555





In [102]:
test_reader = LMDBReader(test_path)
test_reader.open()
test_helper = HWDBDatasetHelper(test_reader, prefix='Test')
test_dataset = HWDBDataset(test_helper)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, num_workers=0)

In [103]:
preds = []
model.eval()
with torch.no_grad():
    for X, _ in tqdm(test_loader):
        x = model(X.to(torch.float32).to(device))
        logits = loss_fn.get_logits(x)
        classes = torch.argmax(logits, dim=1).cpu().numpy()
        preds.extend(classes)

100%|██████████| 1517/1517 [05:31<00:00,  4.57it/s]


In [106]:
with open(pred_path, 'w', encoding="utf-8") as f_pred:
    for idx, pred in enumerate(preds):
        name = test_helper.namelist[idx]
        cls = train_helper.vocabulary.class_by_index(pred)
        print(name, cls, file=f_pred)

In [107]:
gt_path = '/kaggle/input/gttttt/gt.txt'
evaluate(gt_path, pred_path)

0.9444472346601452