In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification, ViTImageProcessor, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score
import os
from torch.utils.data import Dataset
from PIL import Image


# Hugging Face Hub checkpoint to fine-tune; the matching image processor is
# loaded from the same repo so preprocessing agrees with the checkpoint.
# NOTE(review): the dataset below assigns REAL=0 / FAKE=1 — confirm this
# matches the checkpoint's config.id2label before trusting its predictions.
model_name = "dima806/ai_vs_real_image_detection"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name, num_labels=2)

# Freeze the ViT backbone: every parameter under `model.vit` stops receiving
# gradients, so only parameters outside it (the classification head) train.
model.vit.requires_grad_(False)



class imagedataset(Dataset):
    """Image-classification dataset read from one sub-folder per class.

    ``root_dir`` must contain one sub-folder per entry of ``class_names``.
    Every regular, non-hidden file inside a class folder is treated as an
    image of that class. The label is the folder's index in ``class_names``,
    so with the defaults ``REAL`` -> 0 and ``FAKE`` -> 1.

    Args:
        root_dir: Directory containing the class sub-folders.
        processor: Callable invoked as ``processor(image, return_tensors="pt")``
            returning a mapping of tensors with a leading batch dimension of 1
            (e.g. a ``ViTImageProcessor``).
        class_names: Sub-folder names, in label order.
    """

    def __init__(self, root_dir, processor, class_names=("REAL", "FAKE")):
        self.image_path = []
        self.labels = []
        self.processor = processor

        for label, folder in enumerate(class_names):
            folder_path = os.path.join(root_dir, folder)

            # os.listdir order is arbitrary; sort so sample indices are
            # reproducible across runs and machines.
            for img_file in sorted(os.listdir(folder_path)):
                file_path = os.path.join(folder_path, img_file)
                # Skip hidden files (e.g. .DS_Store) and sub-directories
                # (e.g. .ipynb_checkpoints) that are not images.
                if img_file.startswith(".") or not os.path.isfile(file_path):
                    continue
                self.image_path.append(file_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, idx):
        """Return ``{processor outputs without batch dim..., "labels": tensor}``."""
        # The context manager releases the file handle deterministically
        # instead of leaving it to Pillow's lazy loading / garbage collection.
        with Image.open(self.image_path[idx]) as img:
            image = img.convert("RGB")

        inputs = self.processor(image, return_tensors="pt")

        # The processor returns a batch of one; drop that leading dimension so
        # the Trainer's collator can stack individual samples itself.
        sample = {key: inputs[key].squeeze(0) for key in inputs}
        sample["labels"] = torch.tensor(self.labels[idx])
        return sample


train_dataset = imagedataset(root_dir="/content/dataset/train", processor=processor)
test_dataset = imagedataset(root_dir="/content/dataset/test", processor=processor)

# NOTE(review): the test split is also the evaluation set that
# load_best_model_at_end uses to pick the final checkpoint, so the reported
# test accuracy is optimistically biased; consider a separate validation split.
training_args = TrainingArguments(
    output_dir="./vit_finetune",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    learning_rate=1e-4,
    # Evaluate every 500 steps. load_best_model_at_end needs the checkpoint
    # schedule to line up with this; it relies on the default save
    # strategy/steps here — confirm if eval_steps is changed.
    eval_strategy="steps",
    eval_steps=500,
    logging_dir="./logs",
    load_best_model_at_end=True,
    report_to="none",  # disable all logging integrations (wandb included)
)




def compute_metrics(eval_pred):
    """Compute classification accuracy for ``Trainer`` evaluation.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair supplied by ``Trainer``.
            ``predictions`` is the logits array with classes on the last axis,
            or a tuple whose first entry is that array when the model returns
            extra outputs.

    Returns:
        ``{"accuracy": float}`` — fraction of samples whose arg-max class
        equals the label.
    """
    logits, labels = eval_pred
    # Trainer passes a tuple when the model emits more than logits (e.g.
    # hidden states); the logits are expected to be the first entry.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(axis=-1)
    return {"accuracy": accuracy_score(labels, preds)}

# `processing_class` replaces the deprecated `tokenizer` argument. Per the
# Trainer docs, the processor passed here is saved together with the model,
# so checkpoints can be reloaded with matching preprocessing.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    processing_class=processor,
    compute_metrics=compute_metrics,
)

trainer.train()
