In [2]:
import sys
sys.path.append("..")
from ocr_ensemble.data import load_dataset_1K, load_dataset_10K
from matplotlib import pyplot as plt
import random
from tqdm import tqdm
import numpy as np
import pandas as pd

import clip
import torch
import torchvision.transforms as T
import torch.nn.functional as F
from PIL import Image
from einops import rearrange
import json
from functools import partial
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.metrics import confusion_matrix
import pickle
import os

def identity(x):
    """Return ``x`` unchanged (used as a pass-through in the dataset tuple mapping)."""
    return x

# Load the label lookup used by `key2label`.
# The encoding is pinned because open() otherwise falls back to the platform's
# locale encoding, which is not guaranteed to be UTF-8.
with open('../data/laion2b-en-10K-labels.json', 'r', encoding='utf-8') as f:
    labels = json.load(f)

def key2label(key, labels):
    """Look up and return the entry that ``labels`` stores under ``key``.

    No fallback is applied: a missing key raises whatever ``labels[key]``
    raises (``KeyError`` for a plain dict).
    """
    label = labels[key]
    return label

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
transform = T.ToPILImage()


def _to_clip_input(img):
    """Turn one raw dataset image into a CLIP input tensor.

    The last axis of ``img`` is moved to the front, the result is converted to
    a PIL image and then passed through CLIP's own ``preprocess`` pipeline.
    Assumes ``img`` is a 3-axis array with channels last -- TODO confirm
    against the dataset loader.
    """
    channels_first = rearrange(torch.tensor(img), 'a b c -> c a b')
    return preprocess(transform(channels_first))


clip_preproc = T.Lambda(_to_clip_input)

dataset = load_dataset_10K()
# One function per tuple element: image -> CLIP input, second element passed
# through unchanged, third element looked up in `labels` via key2label.
dataset2 = dataset.map_tuple(
    clip_preproc,
    identity,
    partial(key2label, labels=labels),
)
loader = torch.utils.data.DataLoader(dataset2, batch_size=50)

In [None]:
# Embed every image and its caption with CLIP, gathering the per-batch results
# on the CPU and concatenating them along the batch axis at the end.
features = []
caption_features = []
targets = []
# Inference only: no_grad() stops each forward pass from recording an autograd
# graph that would be thrown away immediately afterwards.
with torch.no_grad():
    # The loop variable is `batch_labels`, not `labels`, so the `labels` dict
    # loaded from JSON is not overwritten by the last batch's tensor.
    for imgs, captions, batch_labels in tqdm(loader):
        feats = model.encode_image(imgs.to(device))
        # Captions are cut to their first 77 characters before tokenization.
        # NOTE(review): CLIP's context limit is counted in tokens, not
        # characters -- confirm this truncation is always sufficient.
        caption_feats = model.encode_text(clip.tokenize([c[:77] for c in captions]).to(device))
        features.append(feats.cpu())
        caption_features.append(caption_feats.cpu())
        targets.append(batch_labels)
features = torch.cat(features, dim=0)
targets = torch.cat(targets, dim=0)
caption_features = torch.cat(caption_features, dim=0)

31it [00:14,  2.76it/s]

In [None]:
# Cast both embedding sets to float32 and scale every row to unit L2 norm.
features_normalized = F.normalize(features.float(), dim=1)
caption_features_normalized = F.normalize(caption_features.float(), dim=1)
# Image and caption embeddings side by side; the fit below does not use this.
combined = torch.cat((features_normalized, caption_features_normalized), dim=1)
# Cross-validated logistic regression over the listed C values, trained on the
# normalised image embeddings only.
cv = LogisticRegressionCV(Cs=[0.1, 1, 5, 10, 100]).fit(features_normalized, targets)
print(cv.scores_)

In [None]:
# Build the two class prototypes used for zero-shot scoring.
# "Text present" prototype: the CLIP text embedding of a fixed prompt (used
# instead of a mean over the captions of positive samples).
with torch.no_grad():  # inference only, no autograd graph needed
    text_feature = model.encode_text(clip.tokenize('contains handwritten or printed text').to(device)).cpu()
text_feature = F.normalize(text_feature.float(), dim=1)
# "No text" prototype: normalised mean caption embedding over the samples whose
# target is zero (`~targets.bool()` selects exactly those rows).
notext_feature = F.normalize(caption_features[~targets.bool()].mean(axis=0, keepdim=True).float(), dim=1)
# Row 0 = no text, row 1 = text, so index 1 downstream means "text present".
label_embs = torch.cat([notext_feature, text_feature], axis=0)

In [None]:
# Count the images whose softmax score for the "text" prototype (column 1 of
# label_embs) exceeds 0.495. With two classes this is the same as requiring
# sim_text - sim_notext > ln(0.495 / 0.505), roughly -0.02.
(
    F.softmax(features_normalized.float() @ label_embs.T, dim=1)[:, 1]
    .gt(0.495)
    .int()
    .sum()
)

In [None]:
# Fraction of samples where the 0.495-thresholded "text" score (column 1 of
# the softmax over prototype similarities) matches the target.
(
    F.softmax(features_normalized.float() @ label_embs.T, dim=1)[:, 1]
    .gt(0.495)
    .int()
    .eq(targets)
    .float()
    .sum()
    / len(targets)
)

In [None]:
# Confusion matrix of the targets against the 0.495-thresholded "text" score
# (targets passed first, predictions second).
confusion_matrix(
    targets,
    F.softmax(features_normalized.float() @ label_embs.T, dim=1)[:, 1].gt(0.495).int(),
)

In [None]:
# Nearest-prototype prediction: for each image take the index of the more
# similar row of label_embs. Computed once and reused, rather than repeating
# the matmul and argmax for each of the three reports.
zero_shot_pred = (features_normalized.float() @ label_embs.T).argmax(axis=1)
# Number of samples predicted as class 1.
print(zero_shot_pred.sum())
# Fraction of predictions that match the targets.
print((zero_shot_pred == targets).float().sum() / len(targets))
print(confusion_matrix(targets, zero_shot_pred))

In [None]:
# captions normalized: tensor(0.6759)
# captions unnormalized: tensor(0.4354)
# presence text, absence caption mean normalized: tensor(0.7191)
# img feats normalized: tensor(0.8250)
# img feats unnormalized: tensor(0.5728)
# presence text, absence img feats mean normalized: tensor(0.5651)

In [None]:
# Positional 80/20 split: the first 80% of rows train, the remainder tests.
# NOTE(review): there is no shuffling, so this relies on the loader order being
# unbiased, and it uses the raw `features` rather than `features_normalized`
# -- confirm both are intended.
n_train = int(0.8 * len(features))
X_train, X_test = features[:n_train], features[n_train:]
y_train, y_test = targets[:n_train], targets[n_train:]
clf = LogisticRegression(C=10).fit(X_train, y_train)
print(confusion_matrix(y_test, clf.predict(X_test)))

In [None]:
# Refit the classifier on every sample (this includes the rows that make up
# X_test / y_test), replacing the `clf` fitted on the training slice.
clf = LogisticRegression(C=10)
# Left as the cell's last expression so the fitted estimator is displayed.
clf.fit(features, targets)

In [None]:
# Persist the current `clf`, creating the output directory if it is missing.
os.makedirs('../models', exist_ok=True)
with open('../models/presence_clf.pkl', 'wb') as model_file:
    pickle.dump(clf, model_file)

In [None]:
# Read the pickled classifier back and score it on X_test / y_test.
# pickle can execute arbitrary code while loading, so only load files that
# this project produced itself.
with open('../models/presence_clf.pkl', 'rb') as model_file:
    clf = pickle.load(model_file)
# NOTE(review): if this pickle is the model fitted on all of `features`, the
# X_test rows were part of its training data, so this matrix checks the
# save/load round trip rather than held-out performance.
print(confusion_matrix(y_test, clf.predict(X_test)))