In [1]:
import os
from datetime import datetime
from sklearn.metrics import accuracy_score
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from pathlib import Path
import random
from PIL import ImageFilter
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer

  from .autonotebook import tqdm as notebook_tqdm


### Load the model

In [2]:
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224', device_map="auto")
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', device_map="auto")
model

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

### Dataset

In [3]:
def get_idx(fname):
    if "(" not in fname:
        return 0
    return int(fname.split("(")[1].split(")")[0]) - 1

def get_image(fname):
    return Image.open(fname)

In [4]:
# Load training images and labels
DATA_SOURCE = "REAL"
train_path = Path(f"../data/cifake/train/{DATA_SOURCE}")
test_path = Path(f"../data/cifake/test/{DATA_SOURCE}")

train_images = [get_image(train_path / fname) for fname in os.listdir(train_path)]
train_labels = [get_idx(fname) for fname in os.listdir(train_path)]

train_images[0], train_labels[0]

(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=32x32>, 5)

In [5]:
# Load testing images and labels
test_images = [get_image(test_path / fname) for fname in os.listdir(test_path)]
test_labels = [get_idx(fname) for fname in os.listdir(test_path)]

test_images[0], test_labels[0]

(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=32x32>, 0)

In [6]:
# Define dataset class
class ImageDataset(Dataset):
    def __init__(self, images, labels, transform):
        super().__init__()
        self.images = images
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        img = self.images[index]
        label = self.labels[index]
        if self.transform:
            img = self.transform(img)
        return {"pixel_values": img, "label": label}

# Self-define Gaussian blur function for data augmentation    
class GaussianBlur(object):
    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x

In [7]:
# Data augmentation
image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=image_mean, std=image_std)
])

test_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    transforms.Normalize(mean=image_mean, std=image_std),
])

In [8]:
# Prepare training and testing data
train_dataset = ImageDataset(train_images, train_labels, train_transforms)
test_dataset = ImageDataset(test_images, test_labels, test_transforms)

In [9]:
# Define training arguments
timestamp = datetime.now().strftime("%y%m%d_%H%M")
args = TrainingArguments(
    output_dir=f"../results/result_{DATA_SOURCE}_{timestamp}",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    optim="paged_adamw_32bit",
    learning_rate=2e-5,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    lr_scheduler_type="constant",
    logging_dir='logs',
    remove_unused_columns=False,
)

In [10]:
# Use the Trainer object from Huggingface for training
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return dict(accuracy=accuracy_score(predictions, labels))

trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

In [11]:
# Train
trainer.train()

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.616484,0.7946
2,1.323400,0.5202,0.8224
3,0.574100,0.470092,0.8402


TrainOutput(global_step=1173, training_loss=0.8870121987579423, metrics={'train_runtime': 3176.0455, 'train_samples_per_second': 47.229, 'train_steps_per_second': 0.369, 'total_flos': 1.17277705101312e+19, 'train_loss': 0.8870121987579423, 'epoch': 3.0})