In [1]:
import open_clip

import torch
from PIL import Image
import random
import glob
import os
import numpy as np
import time
import pickle

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [2]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = "cpu"

In [3]:
# Output directory for the pickled classifiers; exist_ok avoids an error when
# the directory is already there, so the cell can be re-run safely.
os.makedirs(name='./models/', exist_ok=True)

## Load model

In [4]:
# Build the 'ViT-B-32-quickgelu' CLIP model with the 'laion400m_e32' pretrained
# weights, together with its image preprocessing transform and tokenizer.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')
# get_image_embedding() moves its input tensor to `device`, so the weights must
# live on the same device; otherwise encoding fails with a device mismatch
# whenever CUDA is selected.
model = model.to(device)
# The model is only used for feature extraction, so put it in inference mode.
model.eval()
tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu')

In [5]:
def get_image_embedding(image_path):
    """Return the CLIP image embedding of the file at ``image_path``.

    The image is passed through the module-level ``preprocess`` transform,
    given a leading batch dimension of one, moved to ``device`` and encoded
    with ``model.encode_image`` under ``torch.no_grad()``.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        The embedding, flattened to one dimension and converted to a NumPy
        array on the CPU.
    """
    # The context manager releases the file handle as soon as preprocessing is
    # done, instead of leaving it open until the Image object is collected.
    with Image.open(image_path) as image:
        tensor = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model.encode_image(tensor)
    return embedding.flatten().cpu().numpy()


def compare_dates(pair):
    """Label an image pair by comparing the numbers in their parent folder names.

    Each path must sit directly inside a directory whose name is an integer
    followed by one extra character (presumably a decade such as ``1930s`` --
    confirm against the dataset layout). That last character is dropped and the
    remainder parsed with ``int``.

    Args:
        pair: Sequence whose first two items are ``/``-separated image paths.

    Returns:
        ``'younger'`` when the first path's number is strictly smaller than the
        second's, otherwise ``'older'`` (equal numbers included).

    Raises:
        ValueError: If a parent directory name, minus its last character, is
            not a valid integer.
    """
    # Take the number from the immediate parent directory instead of a fixed
    # position in the split path, so the result no longer depends on how deep
    # the dataset root happens to be.
    date1, date2 = (
        int(os.path.basename(os.path.dirname(path))[:-1]) for path in pair[:2]
    )

    if date1 < date2:
        return 'younger'
    return 'older'
    

## Prepare dataset

In [62]:
# Seed Python's RNG so the shuffle here and the random pair sampling further
# down give the same result on every fresh run.
random.seed(0)

# glob returns files in an arbitrary, filesystem-dependent order; sorting first
# makes the seeded shuffle start from the same list on every machine.
# Without recursive=True, '**' matches a single directory level, like '*'.
imglist = sorted(glob.glob('./data/HistoricalColor-ECCV2012/data/imgs/decade_database/**/*.jpg'))
random.shuffle(imglist)

In [116]:
num_pairs = 5000

# Each pair needs two different images: comparing a photo with itself has no
# meaningful 'younger'/'older' answer and would only add label noise.
if len(imglist) < 2:
    raise ValueError('need at least two images to build pairs')

# random.sample draws without replacement, so the two members of a pair are
# distinct entries of imglist; an image can still recur across different pairs.
pairs = [tuple(random.sample(imglist, 2)) for _ in range(num_pairs)]

In [117]:
# The same image can occur in several pairs, so each distinct path is encoded
# once and the result reused instead of re-running the encoder per occurrence.
embedding_cache = {}

X = []
labels = []
for pair in pairs:
    for image_path in pair[:2]:
        if image_path not in embedding_cache:
            embedding_cache[image_path] = get_image_embedding(image_path)
    # Feature vector: first image's embedding followed by the second's.
    # np.concatenate allocates a new array, so cached entries are never modified.
    features = np.concatenate((embedding_cache[pair[0]], embedding_cache[pair[1]]))
    X.append(features)
    labels.append(compare_dates(pair))


In [None]:
# Hold out 20% of the pairs for evaluation. Fixing random_state makes the split
# repeatable across runs, in line with the seeded LogisticRegression.
X_train, X_test, y_train, y_test = train_test_split(X, labels, test_size=0.2, random_state=0)

## Train classifier

In [None]:
# Fit a logistic-regression classifier on the training pairs (fit returns the
# estimator itself, so construction and fitting can be chained).
clf = LogisticRegression(random_state=0, max_iter=1000, verbose=0).fit(X_train, y_train)

# Minute-resolution tag; the model-saving cell uses it as the file name.
timestamp = time.strftime("%Y%m%d-%H%M")

# Accuracy on the held-out pairs.
y_pred = clf.predict(X_test)
score = accuracy_score(y_true=y_test, y_pred=y_pred)
print(score)

## Saving model

In [135]:
# Persist the fitted classifier under a timestamped file name.
model_path = f'./models/{timestamp}.pkl'
with open(model_path, 'wb') as model_file:
    pickle.dump(clf, model_file)
print('model saved!')

model saved!
