In [None]:
import torch

# Report the CUDA environment. The device-specific queries
# (get_device_name / current_device) raise when no GPU is visible, so they
# are only run when CUDA is available; device_count() simply returns 0.
gpu_available = torch.cuda.is_available()
print("GPU available: ", gpu_available)
print("No. of GPU: ", torch.cuda.device_count())
if gpu_available:
    print("GPU name: ", torch.cuda.get_device_name(0))
    print("Device index: ", torch.cuda.current_device())
print(torch.__version__)

In [2]:
from datasets import load_dataset
import os
from datasets import concatenate_datasets
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
from transformers import AutoImageProcessor
from transformers import TrainingArguments, Trainer
from transformers import EarlyStoppingCallback

import numpy as np
import torch
import evaluate
import torch.nn.functional as F
from scipy.special import softmax

#----img prepro func-----
from torchvision.transforms import (
    RandomCrop,
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    RandomResizedCrop,
    RandomApply,
    RandomChoice,
    RandomRotation,
    Resize,
    RandomErasing,
    ToTensor,
    ColorJitter,
    ToPILImage
)

# Set Model Checkpoint

In [3]:
# Checkpoint id passed to from_pretrained() when the model is built below.
model_checkpoint = "facebook/convnextv2-atto-1k-224"

# Compute Metrics

In [5]:
# Load each `evaluate` metric once at notebook level; compute_metrics() calls
# .compute() on these objects for every evaluation pass.
accuracy, precision, recall, f1, roc_auc = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "precision", "recall", "f1", "roc_auc")
)

def compute_metrics(eval_pred):
    """Compute binary-classification metrics for the Trainer's evaluation loop.

    Args:
        eval_pred: object exposing ``predictions`` (raw logits — assumed shape
            ``(n_samples, 2)``, TODO confirm) and ``label_ids`` (integer class
            ids, 0 = REAL, 1 = FAKE per ``label2id``).

    Returns:
        dict with overall ``accuracy``; per-class ``precision_*``, ``recall_*``
        and ``f1_*`` (suffix = class id); and ``roc_auc`` computed from the
        softmax probability of class 1.

    Raises:
        Whatever the ``roc_auc`` metric raises when ``label_ids`` contains a
        single class (ROC-AUC is undefined there).
    """
    logits = eval_pred.predictions
    references = eval_pred.label_ids
    predictions = np.argmax(logits, axis=1)
    # Score for ROC-AUC: probability assigned to the positive class (id 1).
    positive_prob = softmax(logits, axis=1)[:, 1]

    # Pin the label set so every per-class array has exactly one entry per
    # class. Without `labels`, a batch where only one class appears (in the
    # references or the predictions) yields a length-1 array and indexing
    # [1] below raises IndexError.
    per_class_kwargs = dict(
        predictions=predictions,
        references=references,
        labels=[0, 1],
        average=None,
    )
    precision_per_class = precision.compute(**per_class_kwargs)["precision"]
    recall_per_class = recall.compute(**per_class_kwargs)["recall"]
    f1_per_class = f1.compute(**per_class_kwargs)["f1"]

    return {
        "accuracy": accuracy.compute(predictions=predictions, references=references)["accuracy"],
        "precision_0": precision_per_class[0],
        "precision_1": precision_per_class[1],
        "recall_0": recall_per_class[0],
        "recall_1": recall_per_class[1],
        "f1_0": f1_per_class[0],
        "f1_1": f1_per_class[1],
        "roc_auc": roc_auc.compute(prediction_scores=positive_prob, references=references)["roc_auc"],
    }

def collate_fn(examples):
    """Batch a list of preprocessed examples for the Trainer.

    Each example's ``pixel_values`` tensor is stacked along a new leading
    batch dimension, and the per-example ``label`` values are gathered into a
    single tensor returned under the key ``labels``.
    """
    pixel_list, label_list = [], []
    for ex in examples:
        pixel_list.append(ex["pixel_values"])
        label_list.append(ex["label"])
    return {"pixel_values": torch.stack(pixel_list), "labels": torch.tensor(label_list)}

# Load Dataset

In [6]:
#-----setup----
# Class names in id order: index in `labels` is the integer class id.
labels = ["REAL", "FAKE"]
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = {idx: name for name, idx in label2id.items()}
    
def reformat_dataset(example):
    """Attach an integer ``label`` to a webdataset sample.

    The class name is read from ``example["json"]["label"]``, upper-cased and
    mapped through the module-level ``label2id``; an unknown name raises
    ``KeyError``. The example dict is updated in place and returned.
    """
    class_name = example["json"]["label"].upper()
    example["label"] = label2id[class_name]
    return example

In [None]:
#-----load dataset----
# Every sub-directory of WEBDATASET_ROOT is one webdataset source that provides
# "train" and "validation" splits. Each split is divided into REAL (label 0) and
# FAKE (label 1) subsets, stored in dicts keyed by the source directory name.
WEBDATASET_ROOT = "data/webdataset/"

# Sources whose images are stored under the "jpg" key; every other source uses "png".
JPG_DATASETS = {'ucf_selfie_dataset', 'covid_fmd_dataset', 'FMD_mask_data', 'mask_detection_data'}

# os.listdir order is arbitrary, so sort for a reproducible source order (the
# dicts below follow it). Jupyter checkpoint folders and stray files are not
# data sources, so skip them.
fnl = sorted(
    f for f in os.listdir(WEBDATASET_ROOT)
    if 'ipynb_checkpoints' not in f and os.path.isdir(os.path.join(WEBDATASET_ROOT, f))
)


def _load_split(dataset_name, split):
    """Load one split of one webdataset source and return ``(real, fake)``.

    Integer labels are added with ``reformat_dataset``. Both returned subsets
    have the image column renamed to ``image`` and the raw ``json`` metadata
    column removed.
    """
    data = load_dataset("webdataset", data_dir=os.path.join(WEBDATASET_ROOT, dataset_name), split=split)
    print(dataset_name, len(data))

    data = data.map(reformat_dataset, num_proc=os.cpu_count())
    image_column = "jpg" if dataset_name in JPG_DATASETS else "png"

    def _tidy(subset):
        # Give every source the same column names ("image", "label", ...).
        return subset.rename_column(image_column, "image").remove_columns(["json"])

    real = _tidy(data.filter(lambda x: x["label"] == 0))
    fake = _tidy(data.filter(lambda x: x["label"] == 1))
    return real, fake


train_real_data_l, train_fake_data_l = {}, {}
for dataset in fnl:
    train_real_data_l[dataset], train_fake_data_l[dataset] = _load_split(dataset, "train")

valid_real_data_l, valid_fake_data_l = {}, {}
for dataset in fnl:
    valid_real_data_l[dataset], valid_fake_data_l[dataset] = _load_split(dataset, "validation")

In [13]:
def preprocess_train(example_batch):
    """Add ``pixel_values`` to a batch using the training transforms.

    Every entry of ``example_batch["image"]`` is converted to RGB and passed
    through ``train_transforms``; the batch dict is updated in place and
    returned.

    NOTE(review): ``train_transforms`` is not defined in any visible cell —
    it must already exist in the kernel when this runs. TODO add the cell
    that builds it.
    """
    rgb_images = (img.convert("RGB") for img in example_batch["image"])
    example_batch["pixel_values"] = [train_transforms(rgb) for rgb in rgb_images]
    return example_batch

def preprocess_val(example_batch):
    """Add ``pixel_values`` to a batch using the evaluation transforms.

    Same as ``preprocess_train`` but applies ``val_transforms`` (also not
    defined in any visible cell — TODO confirm where it is built).
    """
    rgb_images = (img.convert("RGB") for img in example_batch["image"])
    example_batch["pixel_values"] = [val_transforms(rgb) for rgb in rgb_images]
    return example_batch

In [None]:
# NOTE(review): final_train_dataset / final_valid_dataset are not built in any
# visible cell (presumably from the *_data_l dicts of the loading cell) — TODO
# add that cell, otherwise Restart & Run All fails here.
# Dataset.with_transform deep-copies the dataset and calls set_transform on
# the copy, so the source datasets are left untouched.
train_ds = final_train_dataset.with_transform(preprocess_train)
val_ds = final_valid_dataset.with_transform(preprocess_val)
print(f'train_ds={len(train_ds)}, val_ds={len(val_ds)}')

# Model Training

In [None]:
#----training---
# Image-classification model built from `model_checkpoint`, using the
# REAL/FAKE label mapping from the setup cell. ignore_mismatched_sizes=True
# lets from_pretrained re-initialize checkpoint weights whose shapes don't
# match (e.g. a classifier head sized for a different number of classes, or
# an already fine-tuned checkpoint) instead of raising.
# NOTE(review): this single instance is the one train() passes to Trainer; if
# train() is called more than once in the same kernel (e.g. a hyperparameter
# sweep), later runs start from the previous run's weights — consider
# Trainer(model_init=...). TODO confirm intended.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

In [20]:
def train(config=None):
    model_name = model_checkpoint.split("/")[-1]
    
    args = TrainingArguments(
        f"{model_output}/{model_name}-{experiment_id}/{config.learning_rate}-{config.batch_size}-{config.epochs}",
        remove_unused_columns=False,
        evaluation_strategy = "steps",
        eval_steps = 1000,
        save_strategy = "steps",
        save_steps = 1000,
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.batch_size,
        gradient_accumulation_steps=1,
        per_device_eval_batch_size=1024,
        num_train_epochs=config.epochs,
        logging_strategy="steps",
        logging_steps=1000,
        load_best_model_at_end=True,
        metric_for_best_model="roc_auc",
        no_cuda=False,
        dataloader_num_workers=8,
        dataloader_prefetch_factor=2,
        dataloader_pin_memory=True
    )
    
    early_stop = EarlyStoppingCallback(early_stopping_patience=5)

    trainer = Trainer(
        model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=image_processor,
        compute_metrics=compute_metrics,
        data_collator=collate_fn,
        callbacks=[early_stop],
        # optimizers=(optimizer, None)
    )
    
    train_results = trainer.train()