In [1]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import torch.nn as nn
from torch.nn import CrossEntropyLoss, Sigmoid
from torchvision import transforms, utils
import open_clip
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from datetime import datetime
from multiprocessing import cpu_count
import pickle


# An empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA, so the check below
# is expected to select the CPU (set before any CUDA call is made).
os.environ["CUDA_VISIBLE_DEVICES"] = ""
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
class ArtGraphDataset(Dataset):
    """Multi-label artwork dataset pairing image files with their tags.

    An image is kept only when its file name is found both in ``artworks_dir``
    and in the ``image`` column of ``artworks_csv``. Samples are ordered by
    file name so that runs are reproducible regardless of directory order.

    Args:
        artworks_dir: directory containing the image files.
        artworks_csv: CSV with an ``image`` column (file names) and a ``tag``
            column (tags separated by ``', '``).
        transform: callable applied to every image after conversion to RGB.
        train: if true, ``__getitem__`` returns a multi-hot target vector;
            otherwise it returns the list of target label indices.
        labels_csv: CSV whose ``t.name`` column lists every known tag.

    Attributes:
        labels: list of all tags; a tag's position is its class index.
        n_labels: number of tags.
        artworks: paths of the retained images.
        targets: per-image list of de-duplicated tags (original order kept).
    """

    def __init__(self, artworks_dir, artworks_csv, transform, train,
                 labels_csv=r'data/processed/tag_filtered.csv'):
        self.labels = pd.read_csv(labels_csv)['t.name'].values.tolist() # all tags
        self.n_labels = len(self.labels)
        # tag -> index of its first occurrence: same answer as list.index, in O(1)
        self._label_index = {}
        for position, label in enumerate(self.labels):
            self._label_index.setdefault(label, position)
        self.artworks = []
        self.targets = []
        self.transform = transform
        self.train = train

        artwork_df = pd.read_csv(artworks_csv)
        # file name -> tag string of the first CSV row that mentions it
        tags_by_image = {}
        for image, tags in zip(artwork_df['image'], artwork_df['tag']):
            tags_by_image.setdefault(image, tags)

        # the context manager closes the directory iterator deterministically
        with os.scandir(artworks_dir) as entries:
            names = sorted(entry.name for entry in entries)
        for name in names:
            if name in tags_by_image:
                self.artworks.append(os.path.join(artworks_dir, name))
                # dict.fromkeys drops repeated tags while keeping their order
                self.targets.append(list(dict.fromkeys(tags_by_image[name].split(', '))))

    def __len__(self):
        """Return the number of retained artworks."""
        return len(self.artworks)

    def __getitem__(self, item):
        """Return ``{'artwork', 'title', 'jolly'}`` for the sample at ``item``.

        ``artwork`` is the transformed RGB image, ``title`` the file name and
        ``jolly`` the target: a multi-hot vector when ``self.train`` is true,
        otherwise the list of label indices. Tags absent from ``self.labels``
        are ignored.
        """
        if torch.is_tensor(item):
            item = item.tolist()

        path = self.artworks[item]
        # release the file handle as soon as the pixels have been converted
        with Image.open(path) as image:
            artwork = self.transform(image.convert('RGB'))
        title = os.path.basename(path)
        index = [self._label_index[t] for t in self.targets[item] if t in self._label_index]
        jolly = multi_hot_encoding(index, self.n_labels) if self.train else index

        return {'artwork' : artwork, 'title' : title, 'jolly' : jolly}

def multi_hot_encoding(index, n):
    """Return a length-``n`` float vector holding 1.0 at every position in ``index``.

    Args:
        index: positions to switch on (e.g. a list of class indices).
        n: length of the returned vector.
    """
    encoded = np.zeros(n)
    encoded[index] = 1.0
    return encoded

## training

In [7]:
def train(train_artworks_dir, train_artworks_csv, batch_size=5, grad_acc=10, epochs=2,
          learning_rate=2e-5, checkpoint_path=r'models/base/vit-b-16.pt', checkpoint_every=500):
    """Fine-tune a multi-label tag classifier on top of the open_clip 'ViT-B-16' image encoder.

    The classification head (Linear + Sigmoid) is stored under ``model.ln_final``
    so that it is saved in, and restored from, the model's ``state_dict``.
    Only parameters whose name contains ``'token_embedding'`` or ``'ln_final'``
    stay trainable. The loss is binary cross-entropy: the head emits one
    independent sigmoid probability per tag and the targets are multi-hot.

    Args:
        train_artworks_dir: directory with the training images.
        train_artworks_csv: CSV listing training images and their tags.
        batch_size: samples per forward/backward pass.
        grad_acc: number of batches whose gradients are accumulated before
            each optimizer step.
        epochs: number of passes over the dataset.
        learning_rate: AdamW learning rate.
        checkpoint_path: file receiving the model ``state_dict``.
        checkpoint_every: save a checkpoint every this many batches.
    """
    # fail early, not after hours of training, if the checkpoint cannot be written
    os.makedirs(os.path.dirname(checkpoint_path) or '.', exist_ok=True)
    # model set up
    model, _, transform = open_clip.create_model_and_transforms('ViT-B-16', 'openai', device=device)
    # prepare dataset
    dataset = ArtGraphDataset(artworks_dir=train_artworks_dir, artworks_csv=train_artworks_csv, transform=transform, train=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    n_batches = len(dataloader)
    # becomes a multi-label classifier: one output per known tag, on the model's device
    model.ln_final = nn.Sequential(nn.Linear(in_features=512, out_features=dataset.n_labels), nn.Sigmoid()).to(device)
    # freeze parameters
    for name, param in model.named_parameters():
        if param.requires_grad and not ('token_embedding' in name or 'ln_final' in name):
            param.requires_grad = False
    # optimizer set up
    non_frozen_params = [p for p in model.parameters() if p.requires_grad]
    opt = AdamW(non_frozen_params, lr=learning_rate)
    # loss set up: BCE matches sigmoid outputs with multi-hot targets
    criterion = nn.BCELoss()
    # training
    model.train()
    opt.zero_grad()
    for epoch in range(epochs):
        running_loss = 0.0
        print('[LOG] {} starting training epoch {}'.format(datetime.now(), epoch))
        for checkpoint, batch in enumerate(dataloader):
            prediction = model.ln_final(model.encode_image(batch['artwork'].to(device), normalize=True))
            # align target dtype/device with the prediction (multi-hot vectors arrive as float64)
            target = batch['jolly'].to(prediction)
            loss = criterion(prediction, target)
            # scale only the gradient; the logged loss stays un-scaled
            (loss / grad_acc).backward()
            running_loss += loss.item()
            is_last_batch = (checkpoint + 1) == n_batches
            # optimization step: every grad_acc batches, and always on the last
            # batch so that no accumulated gradient is dropped
            if (checkpoint + 1) % grad_acc == 0 or is_last_batch:
                opt.step()
                opt.zero_grad()
                print('[LOG] {} [EPOCH {}] [CHECKPOINT {}] optimization step'.format(datetime.now(), epoch, checkpoint))
            # saving checkpoint
            if (checkpoint + 1) % checkpoint_every == 0:
                torch.save(model.state_dict(), checkpoint_path)
                avg_loss = running_loss / (checkpoint + 1)
                print('[LOG] {} [EPOCH {}] [CHECKPOINT {}] current average loss: {}'.format(datetime.now(), epoch, checkpoint, avg_loss))
            # last batch, final average loss
            if is_last_batch:
                avg_loss = running_loss / n_batches
                print('[LOG] {} [EPOCH {}] [CHECKPOINT {}] final average loss: {}'.format(datetime.now(), epoch, checkpoint, avg_loss))

    torch.save(model.state_dict(), checkpoint_path)
    print('[LOG] {} finished training!'.format(datetime.now()))

In [8]:
# Training split: image folder and the CSV listing each image with its tags.
train_artworks_dir = r'dataset/train/images-filtered/'
train_artworks_csv = r'dataset/train/train.csv'

train(train_artworks_dir=train_artworks_dir, train_artworks_csv=train_artworks_csv)

[LOG] 2023-09-14 23:31:18.980066 starting training epoch 0
[LOG] 2023-09-14 23:31:31.872797 [EPOCH 0] [CHECKPOINT 9] optimization step
[LOG] 2023-09-14 23:31:45.266718 [EPOCH 0] [CHECKPOINT 19] optimization step
[LOG] 2023-09-14 23:31:59.184977 [EPOCH 0] [CHECKPOINT 29] optimization step
[LOG] 2023-09-14 23:32:13.454034 [EPOCH 0] [CHECKPOINT 39] optimization step
[LOG] 2023-09-14 23:32:27.818873 [EPOCH 0] [CHECKPOINT 49] optimization step
[LOG] 2023-09-14 23:32:41.799936 [EPOCH 0] [CHECKPOINT 59] optimization step
[LOG] 2023-09-14 23:32:55.801817 [EPOCH 0] [CHECKPOINT 69] optimization step
[LOG] 2023-09-14 23:33:11.060523 [EPOCH 0] [CHECKPOINT 79] optimization step
[LOG] 2023-09-14 23:33:29.656018 [EPOCH 0] [CHECKPOINT 89] optimization step
[LOG] 2023-09-14 23:33:48.236026 [EPOCH 0] [CHECKPOINT 99] optimization step
[LOG] 2023-09-14 23:34:06.143041 [EPOCH 0] [CHECKPOINT 109] optimization step
[LOG] 2023-09-14 23:34:23.865068 [EPOCH 0] [CHECKPOINT 119] optimization step
[LOG] 2023-09-14

## testing

In [None]:
# GATHER DATA FOR EVALUATION

def get_groundtruth(truth):
    """Convert an iterable of single-value tensors into a list of Python scalars."""
    return [entry.item() for entry in truth]

def k_accuracy(groundtruth, indexes, k): # computes accuracy on a given denominator k
    """Return the fraction ``hits / k`` of ground-truth labels found in ``indexes``.

    Args:
        groundtruth: iterable of target label indices.
        indexes: container of predicted label indices.
        k: denominator of the ratio.

    Returns:
        Number of ground-truth entries present in ``indexes`` divided by ``k``,
        or 0.0 when ``k`` is 0.
    """
    if k == 0:
        return 0.0
    # count every hit (the counter must accumulate, not be reset to 1)
    correct = sum(1 for g in groundtruth if g in indexes)
    return correct / k

def evaluate_accuracy(groundtruth, prediction):
    """Score ``prediction`` against ``groundtruth``, using the number of ground-truth entries as denominator."""
    denominator = len(groundtruth)
    return k_accuracy(groundtruth, prediction, denominator)

def evaluate_top_k_accuracy(groundtruth, prediction):
    """Return ``{k: k_accuracy}`` for k in 1, 5, 10 and 20.

    For each k the indices of the k largest entries of ``prediction`` are
    compared with ``groundtruth`` through ``k_accuracy``, with k as denominator.
    """
    scores = {}
    for cutoff in (1, 5, 10, 20):
        top_indexes = torch.topk(prediction, cutoff).indices
        scores[cutoff] = k_accuracy(groundtruth, top_indexes.tolist(), cutoff)
    return scores

def evaluate_confusion(groundtruth, prediction, confusion, labels):
    """Update the per-label confusion counts with the outcome of one sample.

    Args:
        groundtruth: target label indices of the sample.
        prediction: predicted label indices of the sample.
        confusion: dict mapping each label name to a dict with the counters
            ``'tp'``, ``'fp'``, ``'tn'`` and ``'fn'``; it is updated in place.
        labels: list of label names; a label's position is its index.

    Returns:
        The same ``confusion`` dict, after the update.
    """
    # sets give O(1) membership tests; enumerate avoids a labels.index() scan per label
    expected = set(groundtruth)
    predicted = set(prediction)
    for i, label in enumerate(labels):
        is_target = i in expected
        is_predicted = i in predicted
        if is_target and is_predicted: # true positive: target and predicted
            outcome = 'tp'
        elif is_target: # false negative: target but not predicted
            outcome = 'fn'
        elif is_predicted: # false positive: not target but predicted
            outcome = 'fp'
        else: # true negative: neither target nor predicted
            outcome = 'tn'
        confusion[label][outcome] += 1
    return confusion

def save_evaluation(accuracy, top_k_accuracy, confusion, dim, path):
    """Pickle the evaluation results into separate files inside ``path``.

    The files are named ``accuracy``, ``top_k_accuracy``, ``confusion`` and
    ``dim``. The directory is created when missing and every file is closed
    as soon as it has been written.

    Args:
        accuracy: object stored in the ``accuracy`` file.
        top_k_accuracy: object stored in the ``top_k_accuracy`` file.
        confusion: object stored in the ``confusion`` file.
        dim: object stored in the ``dim`` file.
        path: destination directory.
    """
    if path:
        os.makedirs(path, exist_ok=True)
    payloads = {
        'accuracy': accuracy,
        'top_k_accuracy': top_k_accuracy,
        'confusion': confusion,
        'dim': dim,
    }
    for filename, payload in payloads.items():
        with open(os.path.join(path, filename), 'wb') as handle:
            pickle.dump(payload, handle)
    print('[LOG] {} evaluation data correctly saved at {}!'.format(datetime.now(), path))

In [81]:
def test(model_dir, test_artwork_dir, test_artwork_csv, evaluation_dir):
    """Evaluate a trained checkpoint on the test split and save the results.

    Each artwork is processed on its own (batch size 1): the model's
    ``len(groundtruth)`` highest-scoring tags are taken as its prediction and
    compared with the ground truth.

    Args:
        model_dir: path of the ``state_dict`` file written by training.
        test_artwork_dir: directory with the test images.
        test_artwork_csv: CSV listing test images and their tags.
        evaluation_dir: directory receiving the pickled evaluation data.

    Returns:
        Tuple ``(accuracy, top_k_accuracy, confusion, n_samples)`` where the
        first two are dicts keyed by artwork file name, ``confusion`` maps each
        label to its tp/fp/tn/fn counters and ``n_samples`` is ``len(dataset)``.
    """
    BATCH_SIZE = 1
    # model set up
    model, _, transform = open_clip.create_model_and_transforms('ViT-B-16', 'openai', device=device)
    # prepare dataset (from the function's own arguments, not notebook globals)
    dataset = ArtGraphDataset(artworks_dir=test_artwork_dir, artworks_csv=test_artwork_csv, transform=transform, train=False)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    # becomes a multi-label classifier: one output per known tag, on the model's device
    model.ln_final = nn.Sequential(nn.Linear(in_features=512, out_features=dataset.n_labels), nn.Sigmoid()).to(device)
    model.load_state_dict(torch.load(model_dir, map_location=device))
    model.eval()
    # set up data structures
    confusion = {}
    for l in dataset.labels:
        confusion.update({l : {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0}})
    accuracy = {}
    top_k_accuracy = {}

    print('[LOG] {} starting test...'.format(datetime.now()))
    with torch.no_grad():
        for checkpoint, batch in enumerate(dataloader):
            groundtruth = get_groundtruth(batch['jolly'])
            prediction = model.ln_final(model.encode_image(batch['artwork'].to(device), normalize=True))
            labels, indexes = torch.topk(prediction, len(groundtruth))
            predicted = indexes[0].tolist()
            title = batch['title'][0]
            print('------------------------------------------------------------------------------------')
            print('Artwork: {}'.format(title))
            print('Groundtruth: {}'.format([dataset.labels[g] for g in groundtruth]))
            print('Prediction: {}'.format([dataset.labels[g] for g in predicted]))
            accuracy.update({title : evaluate_accuracy(groundtruth, predicted)})
            top_k_accuracy.update({title : evaluate_top_k_accuracy(groundtruth, prediction[0])})
            confusion = evaluate_confusion(groundtruth, predicted, confusion, dataset.labels)
            # parentheses are required: `checkpoint+1 % 100` parses as `checkpoint + (1 % 100)`
            if (checkpoint + 1) % 100 == 0:
                avg = sum(accuracy.values()) / (checkpoint+1)
                print('------------------------------------------------------------------------------------')
                print('[LOG] {} average accuracy: {}'.format(datetime.now(), avg))

    print('[LOG] {} finished testing!'.format(datetime.now()))
    # save the results computed in this call
    save_evaluation(accuracy, top_k_accuracy, confusion, len(dataset), evaluation_dir)
    return accuracy, top_k_accuracy, confusion, len(dataset)

In [8]:
# Test split inputs, the trained checkpoint, and the folder receiving the pickled evaluation data.
test_artworks_dir = r'dataset/test/images-filtered/'
test_artworks_csv = r'dataset/test/test.csv'
evaluation_dir = r'models/base/evaluation/'
model_dir = r'models/base/vit-b-16.pt'

accuracy_base, top_k_accuracy_base, confusion_base, dim = test(
    model_dir=model_dir,
    test_artwork_dir=test_artworks_dir,
    test_artwork_csv=test_artworks_csv,
    evaluation_dir=evaluation_dir,
)

[LOG] 2023-09-15 11:06:57.385701 starting test...
------------------------------------------------------------------------------------
Artwork: abbott-handerson-thayer_the-angel-1903.jpg
Groundtruth: ['Lady']
Prediction: ['female-portraits']
------------------------------------------------------------------------------------
Artwork: abbott-handerson-thayer_village-street-dominica.jpg
Groundtruth: ['Sky', 'Sketch']
Prediction: ['Sketch', 'Sky']
------------------------------------------------------------------------------------
Artwork: abdul-mati-klarwein_landing-strip-1984.jpg
Groundtruth: ['Leaf', 'Water', 'Line', 'Botany', 'Pattern', 'Sky', 'Plant', 'sunlight', 'Tree']
Prediction: ['Natural landscape', 'Line', 'female-portraits', 'Sky', 'Tree', 'Rural area', 'Textile', 'Rock', 'Woody plant']
------------------------------------------------------------------------------------
Artwork: abdul-mati-klarwein_landscape-described-1963.jpg
Groundtruth: ['Geology', 'Rock']
Prediction: ['Sky

In [7]:
# COMPUTE EVALUATION METRICS

def compute_precision(label):
    """Return tp / (tp + fp) for one label's confusion counts, or 0.0 when that sum is zero."""
    predicted_positives = label['tp'] + label['fp']
    return label['tp'] / predicted_positives if predicted_positives != 0 else 0.0

def compute_recall(label):
    """Return tp / (tp + fn) for one label's confusion counts, or 0.0 when that sum is zero."""
    actual_positives = label['tp'] + label['fn']
    return label['tp'] / actual_positives if actual_positives != 0 else 0.0

def compute_f(precision, recall, b):
    """Return the F-beta score ``(1 + b²)·p·r / (b²·p + r)``.

    Args:
        precision: precision value.
        recall: recall value.
        b: beta weight (1 for F-1, 2 for F-2, 0.5 for F-0.5).

    Returns:
        The F-beta score, or 0.0 when the denominator ``b²·p + r`` is zero.
    """
    beta_sq = pow(b, 2)
    # guard the denominator that is actually divided by
    den = beta_sq * precision + recall
    if den == 0:
        return 0.0
    return ((1 + beta_sq) * (precision * recall)) / den

def compute_metrics(accuracy, top_k_accuracy, confusion, dim, path):
    """Write a plain-text evaluation report to ``path``.

    The report contains, in order: the average accuracy, the average top-k
    accuracy for k in 1, 5, 10 and 20, one line per label with its precision,
    recall and F-1 / F-2 / F-0.5 scores (sorted by ascending F-1), and the
    macro average of each of those five scores.

    Args:
        accuracy: dict whose values are per-artwork accuracies.
        top_k_accuracy: dict mapping each artwork to a ``{k: accuracy}`` dict.
        confusion: dict mapping each label to its tp/fp/tn/fn counters.
        dim: denominator used for the accuracy averages.
        path: destination text file (overwritten).
    """
    separator = '------------------------------------------------------------------------------------\n'
    print('[LOG] {} computing metrics...'.format(datetime.now()))
    with open(path, 'w+') as f:
        # average accuracy
        f.write('Average accuracy: {}\n'.format(sum(accuracy.values()) / dim))

        # average top-k accuracies
        avg_top_k = {}
        for k in [1, 5, 10, 20]:
            total = 0.0
            for per_artwork in top_k_accuracy.values():
                total += per_artwork[k]
            avg_top_k[k] = total / dim
        f.write('Average top-k accuracy: {}\n'.format(avg_top_k))

        # per-label scores and their running sums
        metrics = {}
        totals = {'p': 0.0, 'r': 0.0, 'f1': 0.0, 'f2': 0.0, 'f0': 0.0}
        for key, counts in confusion.items():
            precision = compute_precision(counts)
            recall = compute_recall(counts)
            scores = {
                'p' : precision,
                'r' : recall,
                'f1' : compute_f(precision, recall, 1),
                'f2' : compute_f(precision, recall, 2),
                'f0' : compute_f(precision, recall, 0.5),
            }
            metrics[key] = scores
            for name, value in scores.items():
                totals[name] += value

        # labels from worst to best F-1
        f.write(separator)
        for entry in sorted(metrics.items(), key=lambda pair: pair[1]['f1']):
            f.write('{}\n'.format(entry))

        # macro scores
        macro = {name: total / len(metrics) for name, total in totals.items()}
        f.write(separator)
        f.write('Macro precision : {}\n'.format(macro['p']))
        f.write('Macro recall : {}\n'.format(macro['r']))
        f.write('Macro F-1 : {}\n'.format(macro['f1']))
        f.write('Macro F-2 : {}\n'.format(macro['f2']))
        f.write('Macro F-0.5 : {}\n'.format(macro['f0']))
    print('[LOG] {} finished computing metrics! Results available at {}'.format(datetime.now(), path))

In [8]:
def _load_pickle(path):
    """Unpickle the object stored at ``path``, closing the file afterwards."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)

# NOTE: unpickling can execute arbitrary code; only load files that this
# notebook's save_evaluation wrote.
accuracy = _load_pickle(r'models/base/evaluation/accuracy')
top_k_accuracy = _load_pickle(r'models/base/evaluation/top_k_accuracy')
confusion = _load_pickle(r'models/base/evaluation/confusion')
dim = _load_pickle(r'models/base/evaluation/dim')

compute_metrics(accuracy, top_k_accuracy, confusion, dim, r'models/base/evaluation/metrics.txt')

[LOG] 2023-09-18 23:30:24.542280 computing metrics...
[LOG] 2023-09-18 23:30:24.557900 finished computing metrics! Results available at models/base/evaluation/metrics.txt
