In [None]:
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="./inpaintings_dataset")

dataset

In [2]:
from transformers import AutoImageProcessor
import torch
from torch.utils.data import DataLoader
from torchvision import transforms


image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

data_augmentation_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
])

def transform(batch):
    augmented_images = [data_augmentation_transform(image) for image in batch['image']]
    
    inputs = image_processor(augmented_images, return_tensors='pt')

    inputs['label'] = batch['label']
    return inputs

dataset = dataset.with_transform(transform)

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

dataloader = DataLoader(dataset["train"], collate_fn=collate_fn, batch_size=8, shuffle=True)

In [None]:
from transformers import AutoModelForImageClassification

labels = dataset['train'].features['label'].names

model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224",
                                                        num_labels=len(labels),
                                                        id2label={str(i): c for i, c in enumerate(labels)},
                                                        label2id={c: str(i) for i, c in enumerate(labels)},
                                                        ignore_mismatched_sizes=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

In [4]:
from tqdm.notebook import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(4):
  print("Epoch:", epoch)
  correct = 0
  total = 0
  for idx, batch in enumerate(tqdm(dataloader)):
    batch = {k:v.to(device) for k,v in batch.items()}

    optimizer.zero_grad()

    outputs = model(pixel_values=batch["pixel_values"],
                    labels=batch["labels"])
    
    loss, logits = outputs.loss, outputs.logits
    loss.backward()
    optimizer.step()

    total += batch["labels"].shape[0]
    predicted = logits.argmax(-1)
    correct += (predicted == batch["labels"]).sum().item()

    accuracy = correct/total

    if idx % 100 == 0:
      print(f"Loss after {idx} steps:", loss.item())
      print(f"Accuracy after {idx} steps:", accuracy)

Epoch: 0


  0%|          | 0/890 [00:00<?, ?it/s]

Loss after 0 steps: 0.7425731420516968
Accuracy after 0 steps: 0.625
Loss after 100 steps: 0.5238525867462158
Accuracy after 100 steps: 0.6287128712871287
Loss after 200 steps: 0.4505392014980316
Accuracy after 200 steps: 0.6660447761194029
Loss after 300 steps: 0.623492419719696
Accuracy after 300 steps: 0.6955980066445183
Loss after 400 steps: 0.24400413036346436
Accuracy after 400 steps: 0.7119700748129676
Loss after 500 steps: 0.3421115279197693
Accuracy after 500 steps: 0.7285429141716567
Loss after 600 steps: 0.18491344153881073
Accuracy after 600 steps: 0.7414725457570716
Loss after 700 steps: 0.27132880687713623
Accuracy after 700 steps: 0.7496433666191156
Loss after 800 steps: 0.3528473675251007
Accuracy after 800 steps: 0.7571785268414482
Epoch: 1


  0%|          | 0/890 [00:00<?, ?it/s]

Loss after 0 steps: 0.3223284184932709
Accuracy after 0 steps: 0.875
Loss after 100 steps: 0.15442070364952087
Accuracy after 100 steps: 0.8650990099009901
Loss after 200 steps: 0.34044331312179565
Accuracy after 200 steps: 0.8706467661691543
Loss after 300 steps: 0.05778809264302254
Accuracy after 300 steps: 0.8679401993355482
Loss after 400 steps: 0.22335697710514069
Accuracy after 400 steps: 0.8690773067331671
Loss after 500 steps: 0.4026683270931244
Accuracy after 500 steps: 0.8705089820359282
Loss after 600 steps: 0.23129986226558685
Accuracy after 600 steps: 0.8739600665557404
Loss after 700 steps: 0.07520920038223267
Accuracy after 700 steps: 0.8744650499286734
Loss after 800 steps: 0.05939134210348129
Accuracy after 800 steps: 0.8740636704119851
Epoch: 2


  0%|          | 0/890 [00:00<?, ?it/s]

Loss after 0 steps: 0.5150290131568909
Accuracy after 0 steps: 0.75
Loss after 100 steps: 0.026053674519062042
Accuracy after 100 steps: 0.9393564356435643
Loss after 200 steps: 0.024980856105685234
Accuracy after 200 steps: 0.9458955223880597
Loss after 300 steps: 0.800285816192627
Accuracy after 300 steps: 0.9289867109634552
Loss after 400 steps: 0.075662761926651
Accuracy after 400 steps: 0.9248753117206983
Loss after 500 steps: 0.11949513107538223
Accuracy after 500 steps: 0.9281437125748503
Loss after 600 steps: 0.160184845328331
Accuracy after 600 steps: 0.927828618968386
Loss after 700 steps: 0.709733247756958
Accuracy after 700 steps: 0.925641940085592
Loss after 800 steps: 0.035977721214294434
Accuracy after 800 steps: 0.9257178526841449
Epoch: 3


  0%|          | 0/890 [00:00<?, ?it/s]

Loss after 0 steps: 0.054333530366420746
Accuracy after 0 steps: 1.0
Loss after 100 steps: 0.15633687376976013
Accuracy after 100 steps: 0.9616336633663366
Loss after 200 steps: 0.008315585553646088
Accuracy after 200 steps: 0.9595771144278606
Loss after 300 steps: 0.015616951510310173
Accuracy after 300 steps: 0.9580564784053156
Loss after 400 steps: 0.12005084753036499
Accuracy after 400 steps: 0.956359102244389
Loss after 500 steps: 0.044057004153728485
Accuracy after 500 steps: 0.9568363273453094
Loss after 600 steps: 0.10863806307315826
Accuracy after 600 steps: 0.9534109816971714
Loss after 700 steps: 0.005595153197646141
Accuracy after 700 steps: 0.9518544935805991
Loss after 800 steps: 0.014289742335677147
Accuracy after 800 steps: 0.950374531835206


Testing

In [5]:
dataloader_test_inpaintings = DataLoader(dataset["test"], collate_fn=collate_fn, batch_size=8, shuffle=True)

In [6]:
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt


total = 0
correct = 0
all_predicted = []
all_labels = []

for idx, batch in enumerate(tqdm(dataloader_test_inpaintings)):
    batch = {k:v.to(device) for k,v in batch.items()}

    with torch.no_grad():
        outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"])

    loss, logits = outputs.loss, outputs.logits

    predicted = logits.argmax(-1)
    correct += (predicted == batch["labels"]).sum().item()
    total += batch["labels"].shape[0]

    all_predicted.extend(predicted.cpu().numpy())
    all_labels.extend(batch["labels"].cpu().numpy())


accuracy = correct/total

precision = precision_score(all_labels, all_predicted)
recall = recall_score(all_labels, all_predicted)
f1 = f1_score(all_labels, all_predicted)

labels = ['True Negative', 'False Positive', 'False Negative', 'True Positive']
conf_matrix = confusion_matrix(all_labels, all_predicted)

plt.show()
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Confusion Matrix:\n{conf_matrix}")


  0%|          | 0/376 [00:00<?, ?it/s]

Accuracy: 0.8543
Precision: 0.8878
Recall: 0.8110
F1 Score: 0.8477
Confusion Matrix:
[[1349  154]
 [ 284 1219]]
