In [None]:
import numpy as np
import pandas as pd
import os
import json

# --- Experiment configuration ------------------------------------------------
# taxo_name selects the label taxonomy: "Baseline", "HFTT" or "FFTT".
# taxo_idx selects one label mapping inside that taxonomy:
#   "Baseline": 0-1, "HFTT": 0-3, "FFTT": 0-5.
# FFTT mappings are trained as multi-label; Baseline/HFTT as single-label.
taxo_name = "FFTT"
taxo_idx = 3
dataset_train = "train"
dataset_test = "test"
dataset_extra = "scitsrcomp"
img_path_train = "td4cltabs/train"
img_path_test = "td4cltabs/test"
img_path_extra = "td4cltabs/SciTSRComp"

Baseline_mappings = ["Baseline_I", "Baseline_II"]
HFTT_mappings = ["HFTT_Novel_I", "HFTT_Novel_II", "HFTT_Novel_III", "HFTT_Novel_IV"]
FFTT_mappings = ["FFTT_Novel_I", "FFTT_Novel_II", "FFTT_Novel_III",
                 "FFTT_Novel_IV", "FFTT_Novel_V", "FFTT_Novel_VI"]

# Explicit lookup: a misspelled taxonomy name must fail loudly instead of silently
# falling back to FFTT, and the index check must survive `python -O` (which strips
# `assert` statements).
_taxonomies = {"Baseline": Baseline_mappings, "HFTT": HFTT_mappings, "FFTT": FFTT_mappings}
if taxo_name not in _taxonomies:
    raise ValueError(f"Unknown taxo_name {taxo_name!r}; expected one of {sorted(_taxonomies)}")
mappings = _taxonomies[taxo_name]
if not 0 <= taxo_idx < len(mappings):
    raise ValueError(
        f"taxo_idx={taxo_idx} is out of range for {taxo_name} (valid: 0..{len(mappings) - 1})"
    )

# --- Load label metadata and the three data splits ---------------------------
# labels_metadata.json is indexed by mapping name (e.g. "HFTT_Novel_II"); the entry
# is used below as a {class id: class name} lookup for Series.rename.
# Each split CSV lives in a per-mapping folder and holds a column named after the mapping.
_mapping_name = mappings[taxo_idx]
_metadata_dir = "td4cltabs/metadata"

# Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) can fail on
# or mis-decode non-ASCII class names.
with open(f"{_metadata_dir}/labels_metadata.json", "r", encoding="utf-8") as input_file:
    taxo_id2names = json.load(input_file)[_mapping_name]
if taxo_name != "FFTT":
    # JSON object keys are always strings; the single-label CSV columns presumably hold
    # integer class ids (verify), so convert the keys for the rename() lookups below.
    taxo_id2names = {int(k): v for k, v in taxo_id2names.items()}

train_df = pd.read_csv(f"{_metadata_dir}/{_mapping_name}/{dataset_train}.csv", index_col=[0])
test_df = pd.read_csv(f"{_metadata_dir}/{_mapping_name}/{dataset_test}.csv", index_col=[0])
scitsr_df = pd.read_csv(f"{_metadata_dir}/{_mapping_name}/{dataset_extra}.csv", index_col=[0])

# --- Class distribution of each split -----------------------------------------
# Every split reports its instance count. For single-label taxonomies it also reports,
# per class, the absolute count and the share of the split. All splits use the same
# format, so the numbers are directly comparable. FFTT is multi-label, so only the
# instance count is shown for it.
for split_title, split_df in (("TD4DLTabs train", train_df),
                              ("TD4DLTabs test", test_df),
                              ("Scitsrcomp", scitsr_df)):
    label_values = split_df[mappings[taxo_idx]]
    print("---------------------")
    print("{} No. of instances: {}".format(split_title, len(label_values)))
    if taxo_name != "FFTT":
        taxo_freqs = label_values.value_counts().rename(index=taxo_id2names)
        for freq_name, freq_value in taxo_freqs.items():
            print("\tNo. of {}: {} ({:.1%})".format(freq_name, freq_value,
                                                    freq_value / len(split_df)))

In [None]:
import wandb
# Authenticate with Weights & Biases without putting the API key in the notebook:
# export WANDB_API_KEY in the environment before launching Jupyter. When the variable
# is unset, key=None lets wandb.login use its normal credential lookup (saved login
# or an interactive prompt).
wandb.login(key=os.environ.get("WANDB_API_KEY"))

from sklearn.model_selection import train_test_split
from pathlib import Path
import os
import shutil
from tqdm import tqdm

from sklearn.model_selection import StratifiedKFold
import torch
from transformers import ViTImageProcessor
from torchvision.transforms import v2
from transformers import ViTForImageClassification

from torch.utils.data import Dataset
from PIL import Image

from torchvision.io import read_image
import random
from tqdm.notebook import tqdm
from transformers import TrainingArguments, Trainer

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, hamming_loss, multilabel_confusion_matrix
from datasets import load_dataset as ds_load_dataset, DatasetDict as ds_DatasetDict, Dataset as ds_Dataset, ClassLabel as ds_ClassLabel, Features as ds_Features, Image as ds_Image, Sequence as ds_Sequence

# Seed every RNG the notebook touches (Python `random`, NumPy, PyTorch) so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

def compute_metrics(eval_pred, multi_label=None, threshold=0.5):
    """Compute evaluation metrics for the Hugging Face ``Trainer``.

    Args:
        eval_pred: ``(predictions, labels)`` pair supplied by ``Trainer``.
            ``predictions`` are the model's raw logits (Trainer applies no
            softmax/sigmoid), one row per sample and one column per class.
        multi_label: True for the multi-label (FFTT) setting, False for the
            single-label one. When None, falls back to the notebook config
            (``taxo_name == "FFTT"``), so ``Trainer``'s one-argument call still works.
        threshold: probability cut-off, in (0, 1), above which a class counts as
            predicted in the multi-label case.

    Returns:
        dict of metric name -> value.
        Single-label: accuracy, weighted precision/recall/F1 and the confusion matrix.
        Multi-label: subset accuracy, micro/macro precision/recall/F1, Hamming loss and
        the per-class 2x2 confusion matrices.
    """
    predictions, labels = eval_pred
    if multi_label is None:
        multi_label = taxo_name == "FFTT"

    if not multi_label:
        predictions = np.argmax(predictions, axis=1)
        return dict(accuracy=accuracy_score(labels, predictions),
                    precision=precision_score(labels, predictions, average="weighted", zero_division=0),
                    recall=recall_score(labels, predictions, average="weighted", zero_division=0),
                    f1_score=f1_score(labels, predictions, average="weighted", zero_division=0),
                    confusion_matrix=confusion_matrix(labels, predictions).tolist())

    # The predictions are logits, not probabilities. sigmoid(z) > t  <=>  z > log(t / (1 - t)),
    # so thresholding in logit space is exact and avoids overflow in exp() for large |z|.
    logit_threshold = np.log(threshold / (1.0 - threshold))
    predictions = (np.asarray(predictions) > logit_threshold).astype(int)

    hamming_loss_value = hamming_loss(labels, predictions)
    multilabel_cm = multilabel_confusion_matrix(labels, predictions)

    precision_micro = precision_score(labels, predictions, average='micro', zero_division=0)
    recall_micro = recall_score(labels, predictions, average='micro', zero_division=0)
    f1_micro = f1_score(labels, predictions, average='micro', zero_division=0)

    precision_macro = precision_score(labels, predictions, average='macro', zero_division=0)
    recall_macro = recall_score(labels, predictions, average='macro', zero_division=0)
    f1_macro = f1_score(labels, predictions, average='macro', zero_division=0)

    return dict(
        accuracy=accuracy_score(labels, predictions),  # subset (exact-match) accuracy
        precision_micro=precision_micro,
        recall_micro=recall_micro,
        f1_score_micro=f1_micro,
        precision_macro=precision_macro,
        recall_macro=recall_macro,
        f1_score_macro=f1_macro,
        hamming_loss=hamming_loss_value,
        confusion_matrix=multilabel_cm.tolist()
    )

def collate_fn(examples, multi_label=None):
    """Stack a list of transformed dataset rows into one model batch.

    Args:
        examples: rows produced by the dataset transforms. Each row has a
            ``"pixel_values"`` tensor plus either an integer ``"label"``
            (single-label) or a 0/1 list ``"multi_label"`` (multi-label).
        multi_label: selects which label field to use. When None, falls back to
            the notebook config (``taxo_name == "FFTT"``), so ``Trainer``'s
            one-argument call keeps working.

    Returns:
        ``{"pixel_values": stacked tensor, "labels": tensor}``. Labels are int64
        class ids for single-label rows and float32 multi-hot targets for
        multi-label rows (BCE-with-logits needs float targets).
    """
    if multi_label is None:
        multi_label = taxo_name == "FFTT"

    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    if multi_label:
        labels = torch.tensor([example["multi_label"] for example in examples], dtype=torch.float32)
    else:
        labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

In [None]:
# --- Model / preprocessing configuration ---------------------------------------
metric_name = "eval_loss"      # metric used by load_best_model_at_end to pick the checkpoint
col = mappings[taxo_idx]       # label column of the split CSVs
num_classes = len(taxo_id2names)

# do_rescale=False because the torchvision pipeline below already scales pixels to [0, 1].
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k', do_rescale = False, return_tensors = 'pt')

image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]

normalize = v2.Normalize(mean=image_mean, std=image_std)

# v2.ToTensor is deprecated. v2.ToImage + v2.ToDtype(float32, scale=True) is its
# documented replacement: PIL image -> CHW float tensor with values in [0, 1].
train_transform = v2.Compose([
    v2.Resize((processor.size["height"], processor.size["width"])),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    normalize,
])

# Evaluation pipeline. It is currently identical to training, since no augmentation is used.
test_transform = v2.Compose([
    v2.Resize((processor.size["height"], processor.size["width"])),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    normalize,
])

def train_transforms(examples):
    """Attach training-pipeline tensors to a batch of dataset rows.

    Intended for ``Dataset.set_transform``. Each image in ``examples['image']`` is
    converted to RGB and passed through ``train_transform``. The results are stored,
    in order, under ``examples['pixel_values']``. The batch dict is modified in place
    and returned.
    """
    rgb_batch = (picture.convert("RGB") for picture in examples['image'])
    examples['pixel_values'] = list(map(train_transform, rgb_batch))
    return examples

def test_transforms(examples):
    """Attach evaluation-pipeline tensors to a batch of dataset rows.

    Intended for ``Dataset.set_transform`` on the validation and test splits. Each
    image in ``examples['image']`` is converted to RGB and passed through
    ``test_transform``. The results are stored, in order, under
    ``examples['pixel_values']``. The batch dict is modified in place and returned.
    """
    tensors = []
    for picture in examples['image']:
        rgb_picture = picture.convert("RGB")
        tensors.append(test_transform(rgb_picture))
    examples['pixel_values'] = tensors
    return examples

# --- Cross-validation setup and data loading -----------------------------------
kfold = 4

val_pct = {}   # per-fold metrics of trainer.predict on the *test* split
test_pct = {}  # NOTE(review): not used anywhere in this notebook

skf = StratifiedKFold(n_splits=kfold, shuffle=True, random_state=1)

# Store image *paths*, not opened PIL images. Image.open keeps a file handle open until
# the image is loaded, which can exhaust the OS descriptor limit on large splits. The
# datasets Image feature accepts paths directly and decodes on access.
images = [f"{img_path_train}/{img_id}" for img_id in train_df['id'].values]
test_images = [f"{img_path_test}/{img_id}" for img_id in test_df['id'].values]

# Image.open used to fail immediately on a missing file; keep that fail-fast behaviour.
_missing = [p for p in images + test_images if not os.path.isfile(p)]
if _missing:
    raise FileNotFoundError(f"{len(_missing)} image file(s) not found, e.g. {_missing[0]}")

if taxo_name != "FFTT":
    # Single-label: class ids as strings, used for stratification and ClassLabel names.
    labels = [str(lb) for lb in train_df[col].values]
    test_labels = [str(lb) for lb in test_df[col].values]
else:
    # Multi-label: the raw strings of space-separated, 1-based class ids stratify the
    # folds. Test targets are multi-hot vectors of length num_classes. str() guards
    # against pandas parsing a column of single ids as numbers.
    labels = train_df[col].values
    test_labels = [
        [1 if str(i + 1) in str(lbs).split() else 0 for i in range(num_classes)]
        for lbs in test_df[col].values
    ]

# --- Stratified k-fold training ------------------------------------------------
# For each fold:
#   1. fine-tune a fresh ViT on the fold's training part;
#   2. keep the epoch with the lowest eval_loss on the fold's validation part;
#   3. score the held-out test split.
# The test metrics of every fold are collected in val_pct.

# Label encoding is fold-independent, so it is computed once, before the loop.
if taxo_name != "FFTT":
    # Sorted: the iteration order of a set of str changes between interpreter runs
    # (hash randomisation), which would silently reshuffle label indices.
    class_names = sorted(set(labels))
    # Model label maps must come from the *distinct* class names. Enumerating the
    # per-instance `labels` list would give the classifier one output per training image.
    idx2label = dict(enumerate(class_names))
    label2idx = {name: idx for idx, name in idx2label.items()}
    label_key = "label"
    fold_labels = labels
    features = ds_Features(
        {
            "image": ds_Image(decode=True),
            "label": ds_ClassLabel(names=class_names)
        }
    )
else:
    # Multi-hot targets from space-separated, 1-based class ids.
    train_labels = [
        [1 if str(i + 1) in str(lbs).split() else 0 for i in range(num_classes)]
        for lbs in train_df[col].values
    ]
    label_key = "multi_label"
    fold_labels = train_labels
    features = ds_Features(
        {
            "image": ds_Image(decode=True),
            "multi_label": ds_Sequence(ds_ClassLabel(names=[str(i) for i in range(1, num_classes + 1)]))
        }
    )


def _make_split(split_images, split_labels):
    """Build one HF Dataset of (image, label) rows using the shared `features` schema."""
    return ds_Dataset.from_dict({"image": split_images, label_key: split_labels}, features=features)


for fold, (train_indices, val_indices) in enumerate(skf.split(images, labels)):
    print(f"Fold {fold}: {len(train_indices)} training / {len(val_indices)} validation instances")

    train_dataset = _make_split([images[i] for i in train_indices],
                                [fold_labels[i] for i in train_indices])
    val_dataset = _make_split([images[i] for i in val_indices],
                              [fold_labels[i] for i in val_indices])
    test_dataset = _make_split(test_images, test_labels)

    final_dataset = ds_DatasetDict({
        'train': train_dataset,
        'validation': val_dataset,
        'test': test_dataset
    })

    # Transforms are applied lazily when rows are accessed.
    final_dataset['train'].set_transform(train_transforms)
    final_dataset['validation'].set_transform(test_transforms)
    final_dataset['test'].set_transform(test_transforms)

    # A fresh pre-trained backbone per fold, so folds do not leak into each other.
    if taxo_name != "FFTT":
        model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                          id2label=idx2label,
                                                          label2id=label2idx,
                                                          ignore_mismatched_sizes=True)
    else:
        model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                          num_labels=num_classes,
                                                          problem_type="multi_label_classification")

    # `processor` (configured once in the preprocessing cell) is reused for every fold.
    args = TrainingArguments(
        # One directory per fold: checkpoint-<step> names repeat across folds and would
        # otherwise overwrite each other.
        f"table-classification/fold-{fold}",
        run_name=f"{col}-fold-{fold}",
        use_cpu = False,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=15,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model=metric_name,
        logging_dir='logs',
        remove_unused_columns=False  # keep the "image" column for set_transform
    )

    # Train
    trainer = Trainer(
        model,
        args,
        train_dataset=final_dataset['train'],
        eval_dataset=final_dataset['validation'],
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        tokenizer=processor
    )

    trainer.train()

    outputs = trainer.predict(final_dataset['test'])

    print(outputs.metrics)

    for k, v in outputs.metrics.items():
        val_pct.setdefault(k, []).append(v)

    # Close this fold's W&B run. Otherwise the next fold's Trainer logs into the same,
    # still-active run, and every fold ends up in one run with restarting step counters.
    wandb.finish()

In [None]:
import pickle

# Persist the per-fold metrics so the summary can be recomputed without retraining.
# NOTE: only pickle.load this file from a trusted source (unpickling can execute code).
with open("val_pct.pkl", mode="wb") as pickle_file:
    pickle.dump(val_pct, pickle_file)

In [None]:
# --- Cross-fold summary -------------------------------------------------------
# val_pct holds, for each fold, the metrics of trainer.predict on the held-out *test*
# split, so they are labelled as test results here.
# Scalar metrics are reported as mean and std across folds. Array-valued metrics
# (confusion matrices) are averaged element-wise when every fold produced the same
# shape; flattening them into a single scalar would be meaningless.
for k, v in val_pct.items():
    shapes = {np.shape(x) for x in v}
    if shapes == {()}:
        print(f'{k} Test ({len(v)} folds): \tmean: {np.mean(v)} \tstd: {np.std(v)}')
    elif len(shapes) == 1:
        elementwise_mean = np.mean(np.asarray(v, dtype=float), axis=0)
        print(f'{k} Test ({len(v)} folds): element-wise mean:\n{elementwise_mean}')
    else:
        print(f'{k}: shape differs across folds {sorted(shapes)}; not averaged')