In [54]:
from torchvision.datasets import ImageFolder
from collections import Counter
import torch
import numpy as np
from torch.utils.data import Subset, DataLoader
from torchvision import transforms
from ocr_ensemble.classifiers import ClipEmbedding
from tqdm import tqdm
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
import pickle

# Download and extract the data from
https://drive.google.com/file/d/1Ld4hp9J6j_IwkzxIMmfUX3uxHh_fc_4N/view?usp=share_link

The notebook reads the images from `../data/moe_classifier/`, with one sub-folder per class.

In [55]:
class PILToNumpy:
    """Callable transform that turns its input (e.g. a PIL image) into a NumPy array."""

    def __call__(self, pic):
        """Return a new ``np.ndarray`` built from ``pic`` via ``np.array``."""
        as_array = np.array(pic)
        return as_array

In [56]:
# Instantiate the project's ClipEmbedding wrapper. Later cells use it for its
# preprocessing transform (`get_transform`) and its batch image encoder
# (`encode_image_batch`).
emb = ClipEmbedding()

In [57]:
# Load the image dataset from disk. Every image is first converted to a NumPy
# array and then passed through the embedding model's own preprocessing.
data = ImageFolder(
    root='../data/moe_classifier/',
    transform=transforms.Compose([PILToNumpy(), emb.get_transform()]),
)

In [58]:
# Tally how many images carry each class index.
class_counts = Counter(data.targets)

# Report the tallies by class name, one "<name>: <count>" line per class.
for class_idx in class_counts:
    print(f"{data.classes[class_idx]}: {class_counts[class_idx]}")

handwritten: 2915
printed: 17567
scene: 11435


In [59]:
# Build a class-balanced subset of the dataset: SAMPLES_PER_CLASS images from
# every class. Indices are drawn per class from a seeded generator, without
# replacement, so the subset is reproducible and holds no duplicated images.
SAMPLES_PER_CLASS = 1000
SAMPLING_SEED = 0

rng = np.random.default_rng(SAMPLING_SEED)
all_targets = np.asarray(data.targets)
n = len(data)

indices = []
for class_idx in range(len(data.classes)):
    class_indices = np.flatnonzero(all_targets == class_idx)
    # Only fall back to sampling with replacement when a class is too small to
    # supply SAMPLES_PER_CLASS distinct images; an empty class raises a
    # ValueError from `choice` instead of looping forever.
    with_replacement = len(class_indices) < SAMPLES_PER_CLASS
    chosen = rng.choice(class_indices, size=SAMPLES_PER_CLASS, replace=with_replacement)
    indices.extend(chosen.tolist())

# Mix the classes so that batches and cross-validation folds are not made of
# one class-sorted run after another.
rng.shuffle(indices)

In [60]:
# View of `data` restricted to the sampled `indices`.
subset = Subset(data, indices)

In [61]:
# Iterate the subset in batches of 100 samples. No `shuffle` argument is
# passed, so the batches follow the order of `indices`.
loader = DataLoader(subset, batch_size=100)

In [62]:
# Embed every image of the subset and collect the embeddings together with
# their class indices as NumPy arrays.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): only the inputs are moved to `device`; this assumes the model
# inside `emb` already lives there -- confirm.

feature_batches = []
target_batches = []
# Inference only: disable autograd so no graph is recorded while encoding.
with torch.no_grad():
    for images, labels in tqdm(loader):
        feats = emb.encode_image_batch(images.to(device))
        # `.detach()` is kept as a cheap safeguard before the NumPy conversion.
        feature_batches.append(feats.detach().cpu().numpy())
        target_batches.append(labels.detach().cpu().numpy())

# Stack the per-batch arrays along the sample axis.
features = np.concatenate(feature_batches, axis=0)
targets = np.concatenate(target_batches, axis=0)

100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [00:30<00:00,  1.00s/it]


In [63]:
# Fit a logistic-regression classifier on the image features. The inverse
# regularisation strength C is picked by cross-validation from the listed
# grid; refit=True requests a final fit with the selected value.
cv = LogisticRegressionCV(refit=True, Cs=[0.1, 1, 5, 10, 100])
cv.fit(X=features, y=targets)

In [64]:
# Cross-validation scores, keyed by class index. Per the scikit-learn docs each
# value is a (n_folds, n_Cs) grid whose columns follow the `Cs` list passed to
# LogisticRegressionCV. The grids shown in the output are identical across
# classes -- presumably because one multinomial model is scored per fold.
cv.scores_

{0: array([[0.94166667, 0.96166667, 0.97      , 0.975     , 0.98166667],
        [0.965     , 0.97666667, 0.98166667, 0.98833333, 0.99      ],
        [0.95833333, 0.97      , 0.97833333, 0.98166667, 0.99      ],
        [0.955     , 0.97666667, 0.98166667, 0.98833333, 0.98833333],
        [0.95666667, 0.97333333, 0.98166667, 0.98166667, 0.98833333]]),
 1: array([[0.94166667, 0.96166667, 0.97      , 0.975     , 0.98166667],
        [0.965     , 0.97666667, 0.98166667, 0.98833333, 0.99      ],
        [0.95833333, 0.97      , 0.97833333, 0.98166667, 0.99      ],
        [0.955     , 0.97666667, 0.98166667, 0.98833333, 0.98833333],
        [0.95666667, 0.97333333, 0.98166667, 0.98166667, 0.98833333]]),
 2: array([[0.94166667, 0.96166667, 0.97      , 0.975     , 0.98166667],
        [0.965     , 0.97666667, 0.98166667, 0.98833333, 0.99      ],
        [0.95833333, 0.97      , 0.97833333, 0.98166667, 0.99      ],
        [0.955     , 0.97666667, 0.98166667, 0.98833333, 0.98833333],
       

In [49]:
# Persist the fitted cross-validated classifier so it can be reloaded later.
with open('../models/moe_clf.pkl', 'wb') as model_file:
    pickle.dump(cv, model_file)

In [52]:
# Round-trip check: reload the pickled classifier. Only unpickle files you
# trust -- `pickle.load` can execute arbitrary code.
with open('../models/moe_clf.pkl', 'rb') as model_file:
    clf = pickle.load(model_file)

# Fraction of correct predictions on the same features the model was fitted
# on, i.e. training accuracy rather than a held-out estimate.
train_predictions = clf.predict(features)
np.mean(train_predictions == targets)

0.9973333333333333

In [53]:
import json

# Save the class names in class-index order, so that position i in the JSON
# list names the integer label i the classifier was trained on. The names are
# taken from the dataset rather than hard-coded, so they cannot drift from the
# folder layout that defined the targets.
with open('../models/moe_labels.json', 'w') as f:
    json.dump(list(data.classes), f)