In [2]:
import open_clip

import torch
from PIL import Image
import random
import glob
import os
import numpy as np
import time
import pickle

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder

import mord


from imblearn.over_sampling import SMOTE

In [3]:
device = 'cuda' if torch.cuda.is_available() else "cpu"

In [4]:
os.makedirs('./models/', exist_ok=True) 

## Load model

In [5]:
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')
tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu')

In [37]:
def make_prediction(image_path):
    eras = ["1930s", "1940s", "1950s", "1960s", "1970s"]
    image = Image.open(image_path)
    true_label = int(image_path.split('/')[6][:-1])
    tensor = preprocess(image).unsqueeze(0).to(device)
    text_inputs = tokenizer([f"a photo from the {c}" for c in eras])
    
    with torch.no_grad():
        image_features = model.encode_image(tensor)
        text_features = model.encode_text(text_inputs)
    
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(1)
    prediction = [eras[indices], np.round(values.item(), 3), true_label]
    return prediction

In [40]:
image_path = imglist[10]

In [41]:
print(make_prediction(image_path))

['1960s', 0.531, 1930]


In [12]:
indices

tensor([3])

## prepare dataset

In [42]:
imglist = glob.glob('./data/HistoricalColor-ECCV2012/data/imgs/decade_database/**/*.jpg')
random.shuffle(imglist) 

In [43]:
predictions = [make_prediction(img) for img in imglist]

In [53]:
y_pred = [int(_[0][:-1]) for _ in predictions]
y_true = [_[2] for _ in predictions]

In [55]:
score =accuracy_score(y_true,y_pred)
print(score)

0.24754716981132074


## Saving model

In [146]:
timestamp = time.strftime("%Y%m%d-%H%M")

with open(f'./models/{timestamp}_ordinal.pkl', 'wb') as f:
    pickle.dump(clf, f)
print('model saved!')

model saved!
