In [6]:
from torchvision.datasets import ImageFolder
from collections import Counter
import torch
import numpy as np
from torch.utils.data import Subset, DataLoader
from torchvision import transforms
from ocr_ensemble.classifiers import ClipEmbedding
from tqdm import tqdm
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
import pickle
from collections import defaultdict

# download and extract the data from
https://drive.google.com/file/d/1BWMCgBJLEu1CDG4vrEwYkr1glQ2sG4x_/view?usp=sharing

In [2]:
class PILToNumpy:
    def __call__(self, pic):
        return np.array(pic)

In [3]:
emb = ClipEmbedding()

In [4]:
data = ImageFolder('../data/moe_classifier/', transform=transforms.Compose([PILToNumpy(),
                                                                            emb.get_transform()]))

In [5]:
# Count the class occurrences
class_counts = Counter(data.targets)

# Print class membership counts
for class_idx, count in class_counts.items():
    class_name = data.classes[class_idx]
    print(f"{class_name}: {count}")

handwritten: 2915
printed: 17567
scene: 11435
stage1: 86379


In [20]:
n = len(data)
indices = []

counters = {key: 0 for key in class_counts.keys()}
for key in counters.keys():
    subset = np.where(np.array(data.targets) == key)[0]
    while counters[key] < 1000:
        idx = np.random.randint(len(subset))
        t = data.targets[subset[idx]]
        counters[key] += 1
        indices += [subset[idx]]
print(counters)
            

{0: 1000, 1: 1000, 2: 1000, 3: 1000}


In [21]:
subset = Subset(data, indices)

In [22]:
loader = DataLoader(subset, batch_size=100)

In [23]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
features = []
targets = []
for batch in tqdm(loader):
    feats = emb.encode_image_batch(batch[0].to(device)).detach().cpu().numpy()
    features += [feats]
    targets += [batch[1].detach().cpu().numpy()]
features = np.concatenate(features, axis=0)
targets = np.concatenate(targets, axis=0)

100%|██████████████████████████████████████████████████████████████████████████████████| 40/40 [01:37<00:00,  2.44s/it]


In [26]:
cv = LogisticRegressionCV(Cs=[0.1, 1, 5, 10, 100, 1000], refit=True)
cv.fit(features, targets)

In [27]:
cv.scores_

{0: array([[0.85875, 0.91875, 0.9375 , 0.94   , 0.945  , 0.94   ],
        [0.8675 , 0.91125, 0.925  , 0.9275 , 0.9325 , 0.93   ],
        [0.87   , 0.91   , 0.93625, 0.9425 , 0.94   , 0.93875],
        [0.86125, 0.9225 , 0.94625, 0.95375, 0.95375, 0.95375],
        [0.87875, 0.91125, 0.935  , 0.9425 , 0.95125, 0.95875]]),
 1: array([[0.85875, 0.91875, 0.9375 , 0.94   , 0.945  , 0.94   ],
        [0.8675 , 0.91125, 0.925  , 0.9275 , 0.9325 , 0.93   ],
        [0.87   , 0.91   , 0.93625, 0.9425 , 0.94   , 0.93875],
        [0.86125, 0.9225 , 0.94625, 0.95375, 0.95375, 0.95375],
        [0.87875, 0.91125, 0.935  , 0.9425 , 0.95125, 0.95875]]),
 2: array([[0.85875, 0.91875, 0.9375 , 0.94   , 0.945  , 0.94   ],
        [0.8675 , 0.91125, 0.925  , 0.9275 , 0.9325 , 0.93   ],
        [0.87   , 0.91   , 0.93625, 0.9425 , 0.94   , 0.93875],
        [0.86125, 0.9225 , 0.94625, 0.95375, 0.95375, 0.95375],
        [0.87875, 0.91125, 0.935  , 0.9425 , 0.95125, 0.95875]]),
 3: array([[0.85875, 0.91

In [31]:
with open('../models/moe_clf_4.pkl', 'wb') as f:
    pickle.dump(cv, f)

In [32]:
with open('../models/moe_clf_4.pkl', 'rb') as f:
    clf = pickle.load(f)

(clf.predict(features) == targets).mean()

0.98475

In [34]:
import json

with open('../models/moe_labels_4.json', 'w') as f:
    json.dump(data.classes, f)