In [25]:
import datasets
import numpy as np
from tqdm.notebook import tqdm_notebook
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.models.resnet import ResNet50_Weights
from torchvision.transforms import Compose, Resize, ToTensor, RandomCrop, Lambda
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score

class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf
        self.min_model = None

    def early_stop(self, validation_loss, model):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.min_model = model
            self.counter = 0
        elif validation_loss >= (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [2]:
################################# download dataset #################################
test_dataset = datasets.load_dataset("Hanneseh/MPDL_Project_1_custom_data", split="test")

Found cached dataset imagefolder (/home/hakim/.cache/huggingface/datasets/Hanneseh___imagefolder/Hanneseh--MPDL_Project_1_custom_data-b0234636f7e76ba6/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)


In [26]:
################################# define the model and parameters #################################
# Set hyperparameters for training
num_epochs = 24
batch_size = 8
learning_rate = 0.001
num_workers = 2
patience = 5
crop_size = int(625 / 2)
weight_decay = 0.001

def to_3_channels(image):
    if image.shape[0] == 3:
        return image
    elif image.shape[0] == 1:
        return image.repeat(3, 1, 1)
    else:
        # Select the first 3 channels if the input has more than 3 channels
        return image[:3, :, :]

def collate_fn(examples):
    images, labels = [], []

    image_transform = Compose([
        # apply random square crop with max possible size of the current image, then resize to crop_sizexcrop_size
        RandomCrop(crop_size, pad_if_needed=True),
        Resize((crop_size,crop_size)),
        ToTensor(),
        Lambda(to_3_channels),
        ])

    # Iterate through the examples, apply the image transformation, and append the results
    for example in examples:
        image = image_transform(example['image'])
        label = example['label']
        images.append(image)
        labels.append(label)

        pixel_values = torch.stack(images)
    labels = torch.tensor(labels)

    return {"pixel_values": pixel_values, "label": labels}



test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, num_workers=num_workers, shuffle=True)

## load and modify the model
device = "cuda"
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model_path = "resnet50_dataset_3_lr_0_0001_final.pth"
model.load_state_dict(torch.load(model_path, map_location=torch.device('cuda')))
model = model.to(device)

# Define the loss function and optimizer
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
stopper = EarlyStopper(patience=patience, min_delta=0)



In [27]:
class customThreshold(torch.nn.Module):
    def __init__(self, fake_label,th):
        super(customThreshold, self).__init__()
        self.fake_label = fake_label
        self.th = th

    def forward(self, x):
        prob = torch.sigmoid(x[:, self.fake_label]) # extract probability of fake
        pred = (prob >= self.th).float() # threshold probability
        return (prob.view(-1, 1), pred.view(-1, 1))

In [28]:
thresholds = np.arange(0.05,1,0.05)

In [30]:
best_acc = 0
best_th = 0
for threshold in thresholds:
    
    model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    model.fc = torch.nn.Linear(model.fc.in_features, 2)
    model_path = "resnet50_dataset_3_lr_0_0001_final.pth"
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cuda')))
    model = model.to(device)
    model_th = model 
    model_th.fc = torch.nn.Sequential(model.fc, customThreshold(fake_label=1,th=threshold))
    model_th = model_th.to(device)
    model_th.eval()
    correct = 0
    total = 0
    test_loss = 0
    model_th.eval()  # Set the model to evaluation mode

    all_labels = []
    all_predicted = []

    with torch.no_grad():
        for element in tqdm_notebook(test_dataloader):
            # Move input and label tensors to the device
            inputs = element["pixel_values"].to(device)
            labels = element["label"].to(device)

            # Forward pass
            outputs = model_th(inputs)
            predicted = outputs[1]
            
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # Store labels and predictions for metrics calculation
            all_labels.extend(labels.cpu().numpy())
            all_predicted.extend(predicted.cpu().numpy())

    accuracy = 100. * correct / total

    # Calculate F1-score, precision, and recall
    f1 = f1_score(all_labels, all_predicted, average='binary')
    precision = precision_score(all_labels, all_predicted, average='binary')
    recall = recall_score(all_labels, all_predicted, average='binary')

    print("threshold at ", threshold)
    print('Accuracy: %.3f' % ( accuracy))
    print('F1-score: %.3f | Precision: %.3f | Recall: %.3f' % (f1, precision, recall))

    if accuracy > best_acc:
        best_acc = accuracy
        best_th = threshold

  0%|          | 0/2000 [00:00<?, ?it/s]

threshold at  0.05
Accuracy: 406.988
F1-score: 0.702 | Precision: 0.542 | Recall: 0.995


  0%|          | 0/2000 [00:00<?, ?it/s]

threshold at  0.1
Accuracy: 411.012
F1-score: 0.719 | Precision: 0.566 | Recall: 0.986


  0%|          | 0/2000 [00:00<?, ?it/s]

threshold at  0.15000000000000002
Accuracy: 415.238
F1-score: 0.737 | Precision: 0.593 | Recall: 0.975


  0%|          | 0/2000 [00:00<?, ?it/s]

threshold at  0.2
Accuracy: 418.750
F1-score: 0.750 | Precision: 0.615 | Recall: 0.960


  0%|          | 0/2000 [00:00<?, ?it/s]

threshold at  0.25
Accuracy: 419.788
F1-score: 0.764 | Precision: 0.642 | Recall: 0.942


  0%|          | 0/2000 [00:00<?, ?it/s]

threshold at  0.3
Accuracy: 423.413
F1-score: 0.772 | Precision: 0.667 | Recall: 0.918


  0%|          | 0/2000 [00:00<?, ?it/s]

threshold at  0.35000000000000003
Accuracy: 425.962
F1-score: 0.779 | Precision: 0.695 | Recall: 0.887


  0%|          | 0/2000 [00:00<?, ?it/s]

threshold at  0.4
Accuracy: 428.225
F1-score: 0.782 | Precision: 0.726 | Recall: 0.848


  0%|          | 0/2000 [00:00<?, ?it/s]

threshold at  0.45
Accuracy: 427.137
F1-score: 0.782 | Precision: 0.760 | Recall: 0.806


  0%|          | 0/2000 [00:00<?, ?it/s]

threshold at  0.5
Accuracy: 428.788
F1-score: 0.778 | Precision: 0.795 | Recall: 0.761


  0%|          | 0/2000 [00:00<?, ?it/s]

KeyboardInterrupt: 