In [1]:
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from sklearn.metrics import f1_score, precision_score, recall_score
import datasets
import torch
import torchvision
import torchvision.transforms as T

In [2]:
# Set device
device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda_is_available():
    device = "cuda"

device = torch.device(device)

#Set Parameters for creating the Dataset
num_workers = 0
crop_size = 625
resize = 360
batch_size = 32

In [3]:
def to_3_channels(image):
    if image.shape[0] == 3:
        return image
    elif image.shape[0] == 1:
        return image.repeat(3, 1, 1)
    else:
        # Select the first 3 channels if the input has more than 3 channels
        return image[:3, :, :]

def collate_fn(examples):
    images, labels = [], []

    image_transform = T.Compose([
            T.RandomCrop(crop_size, pad_if_needed=True),
            T.Resize((resize,resize)),
            T.ToTensor()
        ])

    # Iterate through the examples, apply the image transformation, and append the results
    for example in examples:
        image = image_transform(example['image'])
        label = example['label']
        images.append(image)
        labels.append(label)

        pixel_values = torch.stack(images)
    labels = torch.tensor(labels)

    return {"pixel_values": pixel_values, "label": labels}

test_dataset = datasets.load_dataset("Hanneseh/MPDL_Project_1_custom_data", split="test")
test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, num_workers=num_workers, shuffle=True)

Found cached dataset imagefolder (/Users/pauladler/.cache/huggingface/datasets/Hanneseh___imagefolder/Hanneseh--MPDL_Project_1_custom_data-b0234636f7e76ba6/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)


In [4]:
PATH = './../models/resnet50_dataset_3_lr_0_0001_final.pth'

model = torchvision.models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load(PATH, map_location=device))
model.eval()
model = model.to(device)



In [5]:
correct, total, test_loss = 0, 0, 0
all_labels, all_predicted = [], []

with torch.no_grad():
    for element in tqdm_notebook(test_dataloader):
        inputs = element["pixel_values"].to(device)
        labels = element["label"].to(device)

        _, predicted = model(inputs).max(1)

        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        all_labels.extend(labels.cpu().numpy())
        all_predicted.extend(predicted.cpu().numpy())

accuracy = 100. * correct / total

# Calculate F1-score, precision, and recall
f1 = f1_score(all_labels, all_predicted, average='binary')
precision = precision_score(all_labels, all_predicted, average='binary')
recall = recall_score(all_labels, all_predicted, average='binary')

print('Accuracy: %.3f | F1-score: %.3f | Precision: %.3f | Recall: %.3f' % (accuracy, f1, precision, recall))

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy: 88.681 | F1-score: 0.895 | Precision: 0.836 | Recall: 0.962


In [6]:
class customThreshold(torch.nn.Module):
    def __init__(self, fake_label):
        super(customThreshold, self).__init__()
        self.fake_label = fake_label

    def forward(self, x):
        prob = torch.sigmoid(x[:, self.fake_label]) # extract probability of fake
        pred = (prob >= 0.5).float() # threshold probability
        return (prob.view(-1, 1), pred.view(-1, 1))
    

model = torchvision.models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load(PATH, map_location=device))
model.eval()
model = model.to(device)
model.fc = torch.nn.Sequential(model.fc, customThreshold(1))



In [7]:
correct, total, test_loss = 0, 0, 0
all_labels, all_predicted = [], []

with torch.no_grad():
    for element in tqdm_notebook(test_dataloader):
        inputs = element["pixel_values"].to(device)
        labels = element["label"].to(device)

        predicted = torch.squeeze(model(inputs)[1])
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        all_labels.extend(labels.cpu().numpy())
        all_predicted.extend(predicted.cpu().numpy())

accuracy = 100. * correct / total

f1 = f1_score(all_labels, all_predicted, average='binary')
precision = precision_score(all_labels, all_predicted, average='binary')
recall = recall_score(all_labels, all_predicted, average='binary')

print('Accuracy: %.3f | F1-score: %.3f | Precision: %.3f | Recall: %.3f' % (accuracy, f1, precision, recall))

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy: 88.625 | F1-score: 0.894 | Precision: 0.835 | Recall: 0.963
