In [None]:
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="./inpaintings_dataset")

dataset

In [None]:
from transformers import AutoImageProcessor
import torch
from torch.utils.data import DataLoader
from torchvision import transforms


image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")

data_augmentation_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
])

def transform(batch):
    augmented_images = [data_augmentation_transform(image) for image in batch['image']]
    
    inputs = image_processor(augmented_images, return_tensors='pt')

    inputs['label'] = batch['label']
    return inputs

dataset = dataset.with_transform(transform)

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

dataloader = DataLoader(dataset["train"], collate_fn=collate_fn, batch_size=16, shuffle=True)

In [None]:
from transformers import AutoModelForImageClassification

labels = dataset['train'].features['label'].names

model = AutoModelForImageClassification.from_pretrained("facebook/convnext-tiny-224",
                                                        num_labels=len(labels),
                                                        id2label={str(i): c for i, c in enumerate(labels)},
                                                        label2id={c: str(i) for i, c in enumerate(labels)},
                                                        ignore_mismatched_sizes=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

In [4]:
from tqdm.notebook import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(4):
  print("Epoch:", epoch)
  correct = 0
  total = 0
  for idx, batch in enumerate(tqdm(dataloader)):
    batch = {k:v.to(device) for k,v in batch.items()}

    optimizer.zero_grad()

    outputs = model(pixel_values=batch["pixel_values"],
                    labels=batch["labels"])
    
    loss, logits = outputs.loss, outputs.logits
    loss.backward()
    optimizer.step()

    total += batch["labels"].shape[0]
    predicted = logits.argmax(-1)
    correct += (predicted == batch["labels"]).sum().item()

    accuracy = correct/total

    if idx % 100 == 0:
      print(f"Loss after {idx} steps:", loss.item())
      print(f"Accuracy after {idx} steps:", accuracy)

Epoch: 0


  0%|          | 0/445 [00:00<?, ?it/s]

Loss after 0 steps: 0.679779052734375
Accuracy after 0 steps: 0.5625
Loss after 100 steps: 0.6351832151412964
Accuracy after 100 steps: 0.5693069306930693
Loss after 200 steps: 0.6596294641494751
Accuracy after 200 steps: 0.6393034825870647
Loss after 300 steps: 0.6410562992095947
Accuracy after 300 steps: 0.6893687707641196
Loss after 400 steps: 0.28500252962112427
Accuracy after 400 steps: 0.7239713216957606
Epoch: 1


  0%|          | 0/445 [00:00<?, ?it/s]

Loss after 0 steps: 0.22560466825962067
Accuracy after 0 steps: 1.0
Loss after 100 steps: 0.11359579116106033
Accuracy after 100 steps: 0.8978960396039604
Loss after 200 steps: 0.07939431071281433
Accuracy after 200 steps: 0.9051616915422885
Loss after 300 steps: 0.43160301446914673
Accuracy after 300 steps: 0.9078073089700996
Loss after 400 steps: 0.08040226250886917
Accuracy after 400 steps: 0.9097568578553616
Epoch: 2


  0%|          | 0/445 [00:00<?, ?it/s]

Loss after 0 steps: 0.03684880957007408
Accuracy after 0 steps: 1.0
Loss after 100 steps: 0.11592613160610199
Accuracy after 100 steps: 0.9616336633663366
Loss after 200 steps: 0.013061465695500374
Accuracy after 200 steps: 0.9601990049751243
Loss after 300 steps: 0.2720889449119568
Accuracy after 300 steps: 0.9622093023255814
Loss after 400 steps: 0.0074321916326880455
Accuracy after 400 steps: 0.9608790523690773
Epoch: 3


  0%|          | 0/445 [00:00<?, ?it/s]

Loss after 0 steps: 0.1239117905497551
Accuracy after 0 steps: 0.9375
Loss after 100 steps: 0.04488851875066757
Accuracy after 100 steps: 0.9727722772277227
Loss after 200 steps: 0.02526453509926796
Accuracy after 200 steps: 0.9763681592039801
Loss after 300 steps: 0.019022051244974136
Accuracy after 300 steps: 0.9746677740863787
Loss after 400 steps: 0.04427985101938248
Accuracy after 400 steps: 0.9744389027431422


Testing

In [5]:
dataloader_test_inpaintings = DataLoader(dataset["test"], collate_fn=collate_fn, batch_size=8, shuffle=True)

In [6]:
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt


total = 0
correct = 0
all_predicted = []
all_labels = []

for idx, batch in enumerate(tqdm(dataloader_test_inpaintings)):
    batch = {k:v.to(device) for k,v in batch.items()}

    with torch.no_grad():
        outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"])

    loss, logits = outputs.loss, outputs.logits

    predicted = logits.argmax(-1)
    correct += (predicted == batch["labels"]).sum().item()
    total += batch["labels"].shape[0]

    all_predicted.extend(predicted.cpu().numpy())
    all_labels.extend(batch["labels"].cpu().numpy())


accuracy = correct/total

precision = precision_score(all_labels, all_predicted)
recall = recall_score(all_labels, all_predicted)
f1 = f1_score(all_labels, all_predicted)

labels = ['True Negative', 'False Positive', 'False Negative', 'True Positive']
conf_matrix = confusion_matrix(all_labels, all_predicted)

plt.show()
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Confusion Matrix:\n{conf_matrix}")


  0%|          | 0/376 [00:00<?, ?it/s]

Accuracy: 0.9318
Precision: 0.9470
Recall: 0.9148
F1 Score: 0.9306
Confusion Matrix:
[[1426   77]
 [ 128 1375]]
