In [1]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from torchvision import transforms, utils
import open_clip
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from datetime import datetime
from multiprocessing import cpu_count
import pickle
from sklearn.metrics import confusion_matrix

# Expose no CUDA devices to this process, then pick the torch device from
# whatever torch reports as available afterwards.
os.environ.update(CUDA_VISIBLE_DEVICES="")
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ArtGraphDataset(Dataset):
    """Artwork images found on disk, paired with text taken from a CSV file.

    Only files that are present in ``artworks_dir`` AND listed in the
    ``image`` column of ``artworks_csv`` become samples. The text stored for
    each sample depends on ``mode``:

    * ``0``  -> the ``caption`` column
    * ``1``  -> the ``info`` column
    * ``2``  -> ``caption`` immediately followed by ``info`` (the two strings
      are concatenated without any separator)
    * anything else (including ``None``) -> the ``tag`` column split on
      ``', '`` with duplicates removed, order preserved

    Args:
        artworks_dir: directory that is scanned for image files.
        artworks_csv: CSV with at least the ``image`` column plus the
            column(s) required by ``mode``.
        transform: callable applied to the RGB ``PIL.Image`` of each sample.
        mode: see above. ``None`` additionally makes ``__getitem__`` return
            label indexes instead of raw text.
        labels_csv: CSV whose ``t.name`` column lists every known tag.
    """

    def __init__(self, artworks_dir, artworks_csv, transform, mode=None,
                 labels_csv=r'data/processed/tag_filtered.csv'):
        self.labels = pd.read_csv(labels_csv)['t.name'].values.tolist() # all tags
        self.n_labels = len(self.labels)
        # label -> position of its first occurrence (same answer as list.index,
        # without rescanning the list for every tag of every sample)
        self._label_to_index = {}
        for position, label in enumerate(self.labels):
            self._label_to_index.setdefault(label, position)
        self.artworks = []
        self.text = []
        self.transform = transform
        self.mode = mode

        artwork_df = pd.read_csv(artworks_csv)
        # image file name -> index label of the first CSV row that mentions it
        first_row = {}
        for row_label, image_name in artwork_df['image'].items():
            first_row.setdefault(image_name, row_label)

        # scandir holds an OS directory handle; the context manager releases it
        with os.scandir(artworks_dir) as artwork_files:
            for f in artwork_files:
                if f.name not in first_row:
                    continue
                i = first_row[f.name]
                if mode == 0: # caption
                    info = artwork_df.at[i, 'caption']
                elif mode == 1: # info
                    info = artwork_df.at[i, 'info']
                elif mode == 2: # caption + description
                    # NOTE(review): no separator is inserted between the two texts
                    info = artwork_df.at[i, 'caption'] + artwork_df.at[i, 'info']
                else:
                    info = list(dict.fromkeys(artwork_df.at[i, 'tag'].split(', '))) # labels
                self.text.append(info)
                self.artworks.append(os.path.join(artworks_dir, artwork_df.at[i, 'image']))

    def __len__(self):
        """Number of images that were matched against the CSV."""
        return len(self.artworks)

    def __getitem__(self, item):
        """Return ``{'artwork', 'title', 'jolly'}`` for sample ``item``.

        ``artwork`` is the transformed RGB image and ``title`` the image file
        name. ``jolly`` is the stored text when ``mode`` is not ``None``;
        otherwise it is the list of positions (in ``self.labels``) of the
        sample's tags, silently skipping tags that are not known labels.
        """
        if torch.is_tensor(item):
            item = item.tolist()

        artwork = self.transform(Image.open(self.artworks[item]).convert('RGB'))
        title = os.path.basename(self.artworks[item])
        if self.mode is not None: #training
            jolly = self.text[item] # info
        else: #testing
            jolly = [self._label_to_index[t] for t in self.text[item]
                     if t in self._label_to_index] #indexes

        sample = {'artwork' : artwork, 'title' : title, 'jolly': jolly}
        return sample

def multi_hot_encoding(index, n):
    """Return a length-``n`` float vector of zeros with ``1.0`` at ``index``.

    ``index`` may be anything numpy accepts for indexing a 1-D array
    (a single position or a list of positions).
    """
    encoding = np.zeros(n)
    encoding[index] = 1.0
    return encoding

def compute_embeddings(dataset, model):
    """Encode one text prompt per entry of ``dataset.labels``.

    Each label is wrapped in the prompt ``'An artwork of <label>.'``,
    tokenized with the ViT-B-16 tokenizer and passed to
    ``model.encode_text(..., normalize=True)``; that result is returned.
    """
    tokenize = open_clip.get_tokenizer('ViT-B-16')
    prompts = ['An artwork of {}.'.format(name) for name in dataset.labels]
    return model.encode_text(tokenize(prompts), normalize=True)

## training

In [None]:
def train(train_artworks_dir, train_artworks_csv, mode, model_dir,
          batch_size=5, grad_acc=10, epochs=2, learning_rate=2e-5):
    """Fine-tune the OpenAI ViT-B-16 CLIP weights on artwork/text pairs.

    Every batch is scored with an in-batch contrastive objective: image ``j``
    must select text ``j`` among the texts of the same batch. Gradients are
    accumulated over ``grad_acc`` batches before each optimizer step.

    Args:
        train_artworks_dir: directory with the training images.
        train_artworks_csv: CSV describing the training images.
        mode: text source passed to ``ArtGraphDataset`` (0 caption, 1 info,
            2 both); also used in the checkpoint file name.
        model_dir: directory receiving ``finetuned_<mode>.pt``; created if
            missing.
        batch_size: samples per batch.
        grad_acc: number of batches whose gradients are accumulated per
            optimizer step.
        epochs: number of passes over the dataset.
        learning_rate: AdamW learning rate.
    """
    # Create the output directory up front so a missing folder fails (or is
    # fixed) before training starts rather than at the first checkpoint.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    checkpoint_path = os.path.join(model_dir, 'finetuned_{}.pt'.format(mode))

    # model set up
    model, _, transform = open_clip.create_model_and_transforms('ViT-B-16', 'openai', device=device)
    tokenizer = open_clip.get_tokenizer('ViT-B-16')
    # optimizer set up
    opt = AdamW(model.parameters(), lr=learning_rate)
    # loss set up
    criterion = CrossEntropyLoss()
    # prepare dataset
    dataset = ArtGraphDataset(artworks_dir=train_artworks_dir, artworks_csv=train_artworks_csv, mode=mode, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    n_batches = len(dataloader)

    # training
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        opt.zero_grad()
        print('[LOG] {} starting training epoch {}'.format(datetime.now(), epoch))
        for checkpoint, batch in enumerate(dataloader):
            tokens = tokenizer(batch['jolly']).to(device)
            images = batch['artwork'].to(device)
            text_embed = model.encode_text(tokens, normalize=True)
            image_embed = model.encode_image(images, normalize=True)
            # CrossEntropyLoss applies log-softmax itself, so it must receive
            # raw logits: cosine similarities scaled by CLIP's learned
            # temperature, NOT probabilities that were already soft-maxed.
            logits = model.logit_scale.exp() * torch.matmul(image_embed, text_embed.T)
            # row j of the similarity matrix must pick column j
            truth = torch.arange(logits.size(0), device=logits.device)
            loss = criterion(logits, truth)
            # only the back-propagated value is scaled for accumulation; the
            # logged running loss stays on the true per-batch scale
            (loss / grad_acc).backward()
            running_loss += loss.item()

            last_batch = (checkpoint + 1) == n_batches
            # optimization step: every grad_acc batches, and always on the last
            # batch so leftover gradients never leak into the next epoch
            if (checkpoint + 1) % grad_acc == 0 or last_batch:
                opt.step()
                opt.zero_grad()
                print('[LOG] {} [EPOCH {}] [CHECKPOINT {}] optimization step'.format(datetime.now(), epoch, checkpoint))
            # saving checkpoint
            if (checkpoint + 1) % 500 == 0:
                torch.save(model.state_dict(), checkpoint_path)
                avg_loss = running_loss / (checkpoint + 1)
                print('[LOG] {} [EPOCH {}] [CHECKPOINT {}] current average loss: {}'.format(datetime.now(), epoch, checkpoint, avg_loss))
            # last batch, final average loss
            if last_batch:
                avg_loss = running_loss / n_batches
                print('[LOG] {} [EPOCH {}] [CHECKPOINT {}] final average loss: {}'.format(datetime.now(), epoch, checkpoint, avg_loss))

    torch.save(model.state_dict(), checkpoint_path)
    print('[LOG] {} finished training!'.format(datetime.now()))

In [4]:
# Numeric code of each text source accepted by the dataset / train().
mode = dict(caption=0, info=1, joint=2)

# Training split: CSV description and image folder.
train_artworks_csv = r'dataset/train/train.csv'
train_artworks_dir = r'dataset/train/images-filtered/'

In [8]:
# CAPTION MODE
# Fine-tune on the 'caption' text (mode 0); train() saves the weights as
# finetuned_0.pt inside model_dir.
model_dir = r'models/fine-tuned/caption_0'
train(train_artworks_dir, train_artworks_csv, mode['caption'], model_dir)

[LOG] 2023-09-16 05:51:24.952118 starting training epoch 0
[LOG] 2023-09-16 05:52:14.956312 [EPOCH 0] [CHECKPOINT 9] optimization step
[LOG] 2023-09-16 05:53:30.819284 [EPOCH 0] [CHECKPOINT 19] optimization step
[LOG] 2023-09-16 05:54:36.715799 [EPOCH 0] [CHECKPOINT 29] optimization step
[LOG] 2023-09-16 05:55:33.712008 [EPOCH 0] [CHECKPOINT 39] optimization step
[LOG] 2023-09-16 05:56:25.340481 [EPOCH 0] [CHECKPOINT 49] optimization step
[LOG] 2023-09-16 05:57:21.655351 [EPOCH 0] [CHECKPOINT 59] optimization step
[LOG] 2023-09-16 05:58:17.957831 [EPOCH 0] [CHECKPOINT 69] optimization step
[LOG] 2023-09-16 05:59:11.507726 [EPOCH 0] [CHECKPOINT 79] optimization step
[LOG] 2023-09-16 06:00:06.166735 [EPOCH 0] [CHECKPOINT 89] optimization step
[LOG] 2023-09-16 06:01:01.372491 [EPOCH 0] [CHECKPOINT 99] optimization step
[LOG] 2023-09-16 06:01:54.734934 [EPOCH 0] [CHECKPOINT 109] optimization step
[LOG] 2023-09-16 06:02:49.406226 [EPOCH 0] [CHECKPOINT 119] optimization step
[LOG] 2023-09-16

In [5]:
# INFO MODE
# Fine-tune on the 'info' text (mode 1); train() saves the weights as
# finetuned_1.pt inside model_dir.
model_dir = r'models/fine-tuned/info_1'
train(train_artworks_dir, train_artworks_csv, mode['info'], model_dir)

[LOG] 2023-09-17 01:41:55.256970 starting training epoch 0
[LOG] 2023-09-17 01:43:14.394544 [EPOCH 0] [CHECKPOINT 9] optimization step
[LOG] 2023-09-17 01:44:22.019283 [EPOCH 0] [CHECKPOINT 19] optimization step
[LOG] 2023-09-17 01:45:17.374984 [EPOCH 0] [CHECKPOINT 29] optimization step
[LOG] 2023-09-17 01:46:13.439911 [EPOCH 0] [CHECKPOINT 39] optimization step
[LOG] 2023-09-17 01:47:08.645663 [EPOCH 0] [CHECKPOINT 49] optimization step
[LOG] 2023-09-17 01:48:02.519344 [EPOCH 0] [CHECKPOINT 59] optimization step
[LOG] 2023-09-17 01:49:13.300594 [EPOCH 0] [CHECKPOINT 69] optimization step
[LOG] 2023-09-17 01:50:09.584252 [EPOCH 0] [CHECKPOINT 79] optimization step
[LOG] 2023-09-17 01:51:05.077047 [EPOCH 0] [CHECKPOINT 89] optimization step
[LOG] 2023-09-17 01:51:59.064339 [EPOCH 0] [CHECKPOINT 99] optimization step
[LOG] 2023-09-17 01:52:54.051395 [EPOCH 0] [CHECKPOINT 109] optimization step
[LOG] 2023-09-17 01:54:08.364498 [EPOCH 0] [CHECKPOINT 119] optimization step
[LOG] 2023-09-17

In [5]:
# JOINT MODE
# Fine-tune on caption + info (mode 2); train() saves the weights as
# finetuned_2.pt inside model_dir.
model_dir = r'models/fine-tuned/joint_2'
train(train_artworks_dir, train_artworks_csv, mode['joint'], model_dir)

[LOG] 2023-09-18 01:06:42.797184 starting training epoch 0
[LOG] 2023-09-18 01:07:34.316269 [EPOCH 0] [CHECKPOINT 9] optimization step
[LOG] 2023-09-18 01:08:25.304267 [EPOCH 0] [CHECKPOINT 19] optimization step
[LOG] 2023-09-18 01:09:23.967645 [EPOCH 0] [CHECKPOINT 29] optimization step
[LOG] 2023-09-18 01:10:17.397091 [EPOCH 0] [CHECKPOINT 39] optimization step
[LOG] 2023-09-18 01:11:06.796732 [EPOCH 0] [CHECKPOINT 49] optimization step
[LOG] 2023-09-18 01:12:05.173615 [EPOCH 0] [CHECKPOINT 59] optimization step
[LOG] 2023-09-18 01:13:00.764515 [EPOCH 0] [CHECKPOINT 69] optimization step
[LOG] 2023-09-18 01:13:51.621952 [EPOCH 0] [CHECKPOINT 79] optimization step
[LOG] 2023-09-18 01:14:43.675319 [EPOCH 0] [CHECKPOINT 89] optimization step
[LOG] 2023-09-18 01:15:38.570691 [EPOCH 0] [CHECKPOINT 99] optimization step
[LOG] 2023-09-18 01:16:30.269199 [EPOCH 0] [CHECKPOINT 109] optimization step
[LOG] 2023-09-18 01:17:23.827637 [EPOCH 0] [CHECKPOINT 119] optimization step
[LOG] 2023-09-18

## testing

In [3]:
# GATHER EVALUATION DATA

def get_groundtruth(truth):
    """Unwrap every element of ``truth`` with ``.item()`` into a plain list."""
    return [entry.item() for entry in truth]

def k_accuracy(groundtruth, indexes, k): # computes accuracy on a given denominator k
    """Fraction ``hits / k`` where a hit is a ground-truth entry found in ``indexes``.

    Args:
        groundtruth: iterable of target label indexes.
        indexes: container of predicted label indexes.
        k: denominator; ``0`` yields ``0.0`` instead of dividing by zero.

    Returns:
        float in which every ground-truth entry present in ``indexes``
        contributes ``1 / k``.
    """
    if k == 0:
        return 0.0
    # Count every hit (the previous ``correct =+ 1`` re-assigned the counter
    # to +1 on each hit, so it could never exceed 1).
    correct = sum(1 for g in groundtruth if g in indexes)
    return correct / k

def evaluate_accuracy(groundtruth, prediction):
    """Score ``prediction`` with ``k_accuracy``, using the number of targets as denominator."""
    n_targets = len(groundtruth)
    return k_accuracy(groundtruth, prediction, n_targets)

def evaluate_top_k_accuracy(groundtruth, prediction):
    """Map each k in (1, 5, 10, 20) to ``k_accuracy`` over the k best-scoring indexes.

    ``prediction`` is a score tensor handed to ``torch.topk``; the indexes of
    its k largest entries are compared against ``groundtruth`` with k as the
    denominator.
    """
    return {
        k: k_accuracy(groundtruth, torch.topk(prediction, k).indices.tolist(), k)
        for k in (1, 5, 10, 20)
    }

def evaluate_confusion(groundtruth, prediction, confusion, labels):
    """Update the per-label tp/fp/tn/fn counters for one sample.

    Args:
        groundtruth: container of target label indexes.
        prediction: container of predicted label indexes.
        confusion: dict ``label -> {'tp', 'fp', 'tn', 'fn'}``; mutated in place.
        labels: list of label names; position ``i`` is the index used in
            ``groundtruth`` / ``prediction``.

    Returns:
        The same ``confusion`` dict that was passed in.
    """
    # enumerate gives each label its own position directly; looking it up
    # with labels.index() inside the loop rescanned the list for every label.
    for i, label in enumerate(labels):
        is_target = i in groundtruth
        is_predicted = i in prediction
        if is_target:
            outcome = 'tp' if is_predicted else 'fn'
        else:
            outcome = 'fp' if is_predicted else 'tn'
        confusion[label][outcome] += 1
    return confusion

def save_evaluation(accuracy, top_k_accuracy, confusion, dim, path):
    """Pickle the four evaluation artefacts into directory ``path``.

    Files written: ``accuracy``, ``top_k_accuracy``, ``confusion`` and
    ``dim`` (no extension). The directory is created when missing, and each
    file is closed as soon as it has been written.
    """
    if path:
        os.makedirs(path, exist_ok=True)
    artefacts = {
        'accuracy': accuracy,
        'top_k_accuracy': top_k_accuracy,
        'confusion': confusion,
        'dim': dim,
    }
    for filename, payload in artefacts.items():
        # context manager guarantees the handle is flushed and closed
        with open(os.path.join(path, filename), 'wb') as handle:
            pickle.dump(payload, handle)
    print('[LOG] {} evaluation data correctly saved at {}!'.format(datetime.now(), path))

In [4]:
def test(model_dir, test_artworks_dir, test_artworks_csv, evaluation_dir):
    """Evaluate a fine-tuned checkpoint on the tag-prediction task.

    For every test image the labels are ranked by similarity between the
    image embedding and the label-prompt embeddings; as many labels as the
    image has ground-truth tags are taken as the prediction.

    Args:
        model_dir: path of the checkpoint FILE passed to ``torch.load``
            (despite the name, not a directory).
        test_artworks_dir: directory with the test images.
        test_artworks_csv: CSV describing the test images.
        evaluation_dir: directory receiving the pickled results.

    Returns:
        ``(accuracy, top_k_accuracy, confusion, len(dataset))`` where the
        first two are dicts keyed by image file name and ``confusion`` is
        keyed by label.
    """
    # Samples are read one by one; the [0] indexing below relies on it.
    BATCH_SIZE = 1
    # Make sure results can be written before spending time on inference.
    if evaluation_dir:
        os.makedirs(evaluation_dir, exist_ok=True)
    # model set up
    model, _, transform = open_clip.create_model_and_transforms('ViT-B-16', 'openai', device=device)
    model.load_state_dict(torch.load(model_dir, map_location=device))
    model.eval()  # inference mode for any train/eval dependent layers
    # prepare dataset
    dataset = ArtGraphDataset(artworks_dir=test_artworks_dir, artworks_csv=test_artworks_csv, transform=transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    # set up data structures
    confusion = {l: {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0} for l in dataset.labels}
    accuracy = {}
    top_k_accuracy = {}

    print('[LOG] {} starting test of model {}...'.format(datetime.now(), model_dir))
    with torch.no_grad():
        # compute labels embeddings (inside no_grad: no autograd graph is
        # needed for them); transpose once instead of once per image
        labels_embed_t = compute_embeddings(dataset, model).T
        for checkpoint, batch in enumerate(dataloader):
            image_embed = model.encode_image(batch['artwork'].to(device), normalize=True)
            groundtruth = get_groundtruth(batch['jolly'])
            prediction = torch.matmul(image_embed, labels_embed_t)
            _, indexes = torch.topk(prediction, len(groundtruth))
            predicted = indexes[0].tolist()
            title = batch['title'][0]
            print('------------------------------------------------------------------------------------')
            print('Artwork: {}'.format(title))
            print('Groundtruth: {}'.format([dataset.labels[g] for g in groundtruth]))
            print('Prediction: {}'.format([dataset.labels[g] for g in predicted]))
            accuracy.update({title : evaluate_accuracy(groundtruth, predicted)})
            top_k_accuracy.update({title : evaluate_top_k_accuracy(groundtruth, prediction[0])})
            confusion = evaluate_confusion(groundtruth, predicted, confusion, dataset.labels)
            # parentheses matter: `checkpoint+1 % 100` parses as
            # `checkpoint + (1 % 100)`, which is never 0 here
            if (checkpoint + 1) % 100 == 0:
                avg = sum(accuracy.values()) / (checkpoint+1)
                print('------------------------------------------------------------------------------------')
                print('[LOG] {} average accuracy: {}'.format(datetime.now(), avg))

    print('[LOG] {} finished testing!'.format(datetime.now()))
    save_evaluation(accuracy, top_k_accuracy, confusion, len(dataset), evaluation_dir)
    return accuracy, top_k_accuracy, confusion, len(dataset)

In [12]:
# CAPTION MODE
# Evaluate the caption-trained checkpoint on the test split.
test_artworks_dir = r'dataset/test/images-filtered/'
test_artworks_csv = r'dataset/test/test.csv'
evaluation_dir = r'models/fine-tuned/caption_0/evaluation/'
# model_dir is rebound here to the checkpoint FILE that test() loads
model_dir = r'models/fine-tuned/caption_0/finetuned_0.pt'

accuracy_0, top_k_accuracy_0, confusion_0, dim_0 = test(model_dir, test_artworks_dir, test_artworks_csv, evaluation_dir)

[LOG] 2023-09-16 16:10:30.125182 starting test...
------------------------------------------------------------------------------------
Artwork: abbott-handerson-thayer_the-angel-1903.jpg
Groundtruth: ['Lady']
Prediction: ['Lady']
------------------------------------------------------------------------------------
Artwork: abbott-handerson-thayer_village-street-dominica.jpg
Groundtruth: ['Sky', 'Sketch']
Prediction: ['Rural area', 'countryside']
------------------------------------------------------------------------------------
Artwork: abdul-mati-klarwein_landing-strip-1984.jpg
Groundtruth: ['Leaf', 'Water', 'Line', 'Botany', 'Pattern', 'Sky', 'Plant', 'sunlight', 'Tree']
Prediction: ['Geological phenomenon', 'Ecoregion', 'Atmospheric phenomenon', 'Geology', 'Formation', 'Sky', 'Biome', 'Line', 'Rectangle']
------------------------------------------------------------------------------------
Artwork: abdul-mati-klarwein_landscape-described-1963.jpg
Groundtruth: ['Geology', 'Rock']
Pred

In [5]:
# INFO MODE
# Evaluate the info-trained checkpoint on the test split.
test_artworks_dir = r'dataset/test/images-filtered/'
test_artworks_csv = r'dataset/test/test.csv'
evaluation_dir = r'models/fine-tuned/info_1/evaluation/'
# model_dir is rebound here to the checkpoint FILE that test() loads
model_dir = r'models/fine-tuned/info_1/finetuned_1.pt'

accuracy_1, top_k_accuracy_1, confusion_1, dim_1 = test(model_dir, test_artworks_dir, test_artworks_csv, evaluation_dir)

[LOG] 2023-09-18 11:52:21.387537 starting test of model models/fine-tuned/info_1/finetuned_1.pt...
------------------------------------------------------------------------------------
Artwork: abbott-handerson-thayer_the-angel-1903.jpg
Groundtruth: ['Lady']
Prediction: ['female-portraits']
------------------------------------------------------------------------------------
Artwork: abbott-handerson-thayer_village-street-dominica.jpg
Groundtruth: ['Sky', 'Sketch']
Prediction: ['countryside', 'Meadow']
------------------------------------------------------------------------------------
Artwork: abdul-mati-klarwein_landing-strip-1984.jpg
Groundtruth: ['Leaf', 'Water', 'Line', 'Botany', 'Pattern', 'Sky', 'Plant', 'sunlight', 'Tree']
Prediction: ['Monochrome photography', 'birds', 'Bird', 'Biome', 'Organism', 'Textile', 'Pattern', 'Graphic design', 'monochrome']
------------------------------------------------------------------------------------
Artwork: abdul-mati-klarwein_landscape-descri

In [8]:
# JOINT MODE
# Evaluate the checkpoint trained on caption + info on the test split.
test_artworks_dir = r'dataset/test/images-filtered/'
test_artworks_csv = r'dataset/test/test.csv'
evaluation_dir = r'models/fine-tuned/joint_2/evaluation/'
# model_dir is rebound here to the checkpoint FILE that test() loads
model_dir = r'models/fine-tuned/joint_2/finetuned_2.pt'

accuracy_2, top_k_accuracy_2, confusion_2, dim_2 = test(model_dir, test_artworks_dir, test_artworks_csv, evaluation_dir)

[LOG] 2023-09-18 12:33:38.480251 starting test of model models/fine-tuned/joint_2/finetuned_2.pt...
------------------------------------------------------------------------------------
Artwork: abbott-handerson-thayer_the-angel-1903.jpg
Groundtruth: ['Lady']
Prediction: ['female-portraits']
------------------------------------------------------------------------------------
Artwork: abbott-handerson-thayer_village-street-dominica.jpg
Groundtruth: ['Sky', 'Sketch']
Prediction: ['Rural area', 'countryside']
------------------------------------------------------------------------------------
Artwork: abdul-mati-klarwein_landing-strip-1984.jpg
Groundtruth: ['Leaf', 'Water', 'Line', 'Botany', 'Pattern', 'Sky', 'Plant', 'sunlight', 'Tree']
Prediction: ['Pattern', 'Eye', 'Wave', 'Monochrome photography', 'Textile', 'Rectangle', 'monochrome', 'Turquoise', 'birds']
------------------------------------------------------------------------------------
Artwork: abdul-mati-klarwein_landscape-describ

In [36]:
# COMPUTE EVALUATION METRICS

def compute_precision(label):
    """Precision ``tp / (tp + fp)`` of one confusion entry; ``0.0`` when nothing was predicted."""
    true_pos = label['tp']
    predicted_pos = true_pos + label['fp']
    return true_pos / predicted_pos if predicted_pos != 0 else 0.0

def compute_recall(label):
    """Recall ``tp / (tp + fn)`` of one confusion entry; ``0.0`` when the label never occurs."""
    true_pos = label['tp']
    actual_pos = true_pos + label['fn']
    return true_pos / actual_pos if actual_pos != 0 else 0.0

def compute_f(precision, recall, b):
    """F-beta score ``(1 + b^2) * p * r / (b^2 * p + r)``.

    Args:
        precision: precision value.
        recall: recall value.
        b: beta weight (1 for F1, 2 for F2, 0.5 for F0.5).

    Returns:
        The F-beta score, or ``0.0`` when the denominator is zero.
    """
    beta_sq = pow(b, 2)
    # Guard on the value that is actually divided by: ``precision + recall``
    # can be non-zero while ``b^2 * precision + recall`` is zero (b == 0 and
    # recall == 0).
    den = beta_sq * precision + recall
    if den == 0:
        return 0.0
    return ((1 + beta_sq) * (precision * recall)) / den

def compute_metrics(accuracy, top_k_accuracy, confusion, dim, path):
    """Write a plain-text metrics report to ``path``.

    The report contains, in order: the average accuracy (sum of
    ``accuracy`` values divided by ``dim``), the average top-k accuracy for
    k in 1/5/10/20, one line per label with its precision / recall / F1 /
    F2 / F0.5 sorted by ascending F1, and finally the macro average of each
    of those five scores over all labels.
    """
    print('[LOG] {} computing metrics...'.format(datetime.now()))
    separator = '------------------------------------------------------------------------------------\n'
    with open(path, 'w+') as f:
        # average accuracy
        f.write('Average accuracy: {}\n'.format(sum(accuracy.values()) / dim))

        # average top-k accuracies
        avg_top_k = {}
        for k in [1, 5, 10, 20]:
            total = 0.0
            for per_artwork in top_k_accuracy.values():
                total += per_artwork[k]
            avg_top_k[k] = total / dim
        f.write('Average top-k accuracy: {}\n'.format(avg_top_k))

        # per-label scores, with running totals for the macro averages
        metrics = {}
        totals = dict.fromkeys(('p', 'r', 'f1', 'f2', 'f0'), 0.0)
        for name, counts in confusion.items():
            precision = compute_precision(counts)
            recall = compute_recall(counts)
            scores = {'p' : precision,
                      'r' : recall,
                      'f1' : compute_f(precision, recall, 1),
                      'f2' : compute_f(precision, recall, 2),
                      'f0' : compute_f(precision, recall, 0.5)}
            metrics[name] = scores
            for score_name, value in scores.items():
                totals[score_name] += value

        f.write(separator)
        for entry in sorted(metrics.items(), key=lambda pair: pair[1]['f1']):
            f.write('{}\n'.format(entry))

        # macro scores
        n_labels = len(metrics)
        f.write(separator)
        f.write('Macro precision : {}\n'.format(totals['p'] / n_labels))
        f.write('Macro recall : {}\n'.format(totals['r'] / n_labels))
        f.write('Macro F-1 : {}\n'.format(totals['f1'] / n_labels))
        f.write('Macro F-2 : {}\n'.format(totals['f2'] / n_labels))
        f.write('Macro F-0.5 : {}\n'.format(totals['f0'] / n_labels))

    print('[LOG] {} finished computing metrics! Results available at {}'.format(datetime.now(), path))

In [37]:
# Reload the pickled evaluation results of the caption model and write its
# text report.
# NOTE(review): pickle.load must only be used on files produced by this
# notebook, and the handles opened here are never closed explicitly.
accuracy_0 = pickle.load(open(r'models/fine-tuned/caption_0/evaluation/accuracy', 'rb'))
top_k_accuracy_0 = pickle.load(open(r'models/fine-tuned/caption_0/evaluation/top_k_accuracy', 'rb'))
confusion_0 = pickle.load(open(r'models/fine-tuned/caption_0/evaluation/confusion', 'rb'))
dim_0 = pickle.load(open(r'models/fine-tuned/caption_0/evaluation/dim', 'rb'))

compute_metrics(accuracy_0, top_k_accuracy_0, confusion_0, dim_0, r'models/fine-tuned/caption_0/evaluation/metrics_0.txt')

[LOG] 2023-09-18 23:21:01.601187 computing metrics...
[LOG] 2023-09-18 23:21:01.616808 finished computing metrics! Results available at models/fine-tuned/caption_0/evaluation/metrics_0.txt


In [38]:
# Reload the pickled evaluation results of the info model and write its
# text report.
# NOTE(review): pickle.load must only be used on files produced by this
# notebook, and the handles opened here are never closed explicitly.
accuracy_1 = pickle.load(open(r'models/fine-tuned/info_1/evaluation/accuracy', 'rb'))
top_k_accuracy_1 = pickle.load(open(r'models/fine-tuned/info_1/evaluation/top_k_accuracy', 'rb'))
confusion_1 = pickle.load(open(r'models/fine-tuned/info_1/evaluation/confusion', 'rb'))
dim_1= pickle.load(open(r'models/fine-tuned/info_1/evaluation/dim', 'rb'))

compute_metrics(accuracy_1, top_k_accuracy_1, confusion_1, dim_1, r'models/fine-tuned/info_1/evaluation/metrics_1.txt')

[LOG] 2023-09-18 23:21:20.643577 computing metrics...
[LOG] 2023-09-18 23:21:20.643577 finished computing metrics! Results available at models/fine-tuned/info_1/evaluation/metrics_1.txt


In [39]:
# Reload the pickled evaluation results of the joint (caption + info) model
# and write its text report.
# NOTE(review): pickle.load must only be used on files produced by this
# notebook, and the handles opened here are never closed explicitly.
accuracy_2 = pickle.load(open(r'models/fine-tuned/joint_2/evaluation/accuracy', 'rb'))
top_k_accuracy_2 = pickle.load(open(r'models/fine-tuned/joint_2/evaluation/top_k_accuracy', 'rb'))
confusion_2 = pickle.load(open(r'models/fine-tuned/joint_2/evaluation/confusion', 'rb'))
dim_2= pickle.load(open(r'models/fine-tuned/joint_2/evaluation/dim', 'rb'))

compute_metrics(accuracy_2, top_k_accuracy_2, confusion_2, dim_2, r'models/fine-tuned/joint_2/evaluation/metrics_2.txt')

[LOG] 2023-09-18 23:21:32.326306 computing metrics...
[LOG] 2023-09-18 23:21:32.326306 finished computing metrics! Results available at models/fine-tuned/joint_2/evaluation/metrics_2.txt
