Install necessary dependencies

In [None]:
!pip install opencv-python datasets tensorboard transformers torchvision
!pip install torch --index-url https://download.pytorch.org/whl/cu117 # for cuda

Import all needed packages

In [None]:
import torch
import torch.nn as nn
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import pickle
from torch.utils.data import DataLoader
from torch.nn.modules.loss import BCEWithLogitsLoss
from torchvision import models
from torchvision.models import ResNet50_Weights
from datasets import load_dataset, concatenate_datasets
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, RandomHorizontalFlip, RandomVerticalFlip, ToTensor, Resize, CenterCrop, GaussianBlur, RandomApply
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

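Download the datasets (generated images from DiffusionDB, real images from Imagenette), remove unnecessary columns, and label real images 0 and generated images 1.
Then split each source into train (80%), validation (10%) and test (10%) sets and combine the real and generated parts of each split.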

In [None]:
def get_dataset(seed=42):
    """Build train/validation/test splits of real (label 0) and generated (label 1) images.

    Generated images come from DiffusionDB (``2m_random_10k`` subset) and real
    images from Imagenette (``320px``, train+validation). Each source is split
    80/10/10 on its own before the parts are concatenated, so every split holds
    roughly the same real/generated ratio.

    Args:
        seed: seed passed to every ``train_test_split`` call. A fixed seed keeps
            the splits identical across runs, which matters because the notebook
            reloads the previously saved model and best loss: re-shuffling would
            move earlier training images into the validation/test sets.
            Pass ``None`` for a fresh random split.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``, each with ``image`` and
        ``label`` columns (generated rows first, then real rows).
    """
    # datasets
    fake = load_dataset('poloclub/diffusiondb', '2m_random_10k', split='train', data_dir='./')
    real = load_dataset('frgfm/imagenette', '320px', split='train+validation', data_dir='./')

    # remove unnecessary columns (keep only the images)
    fake = fake.remove_columns(
        ['prompt', 'seed', 'step', 'cfg', 'sampler', 'width', 'height', 'user_name', 'timestamp', 'image_nsfw',
         'prompt_nsfw'])
    real = real.remove_columns('label')

    # add a constant label column: 0 for real images, 1 for generated images.
    # add_column avoids map(), which would decode and re-encode every image.
    fake = fake.add_column('label', [1] * len(fake))
    real = real.add_column('label', [0] * len(real))

    def split_80_10_10(dataset):
        """Split ``dataset`` into (train 80%, validation 10%, test 10%)."""
        train_rest = dataset.train_test_split(test_size=0.2, seed=seed)
        val_test = train_rest['test'].train_test_split(test_size=0.5, seed=seed)
        return train_rest['train'], val_test['train'], val_test['test']

    fake_train, fake_val, fake_test = split_80_10_10(fake)
    real_train, real_val, real_test = split_80_10_10(real)

    # combine fake and real datasets into a single dataset for each split
    train_dataset = concatenate_datasets([fake_train, real_train])
    val_dataset = concatenate_datasets([fake_val, real_val])
    test_dataset = concatenate_datasets([fake_test, real_test])

    return train_dataset, val_dataset, test_dataset

Define the preprocessing and augmentation transforms. They convert each image into a normalized tensor stored under `pixel_values`.

In [None]:
class HighpassFilter:
    """Augmentation that convolves an image with a fixed small kernel (a sharpening filter by default).

    The default 3x3 kernel has centre 5 and -1 on the four direct neighbours,
    i.e. the identity plus a negative Laplacian: it boosts edges / high
    frequencies while keeping the image content. Its entries sum to 1.

    Args:
        kernel: optional 2-D array-like kernel to use instead of the default.
            It is divided by its sum when that sum is non-zero; zero-sum
            kernels are used as given.
    """

    _DEFAULT_KERNEL = ((0.0, -1.0, 0.0),
                       (-1.0, 5.0, -1.0),
                       (0.0, -1.0, 0.0))

    def __init__(self, kernel=None):
        raw = np.asarray(self._DEFAULT_KERNEL if kernel is None else kernel, dtype=np.float64)
        total = raw.sum()
        # built once here rather than on every call; normalising keeps the
        # kernel's response to a flat region unchanged
        self.kernel = raw / total if total != 0 else raw

    def __call__(self, img):
        """Filter ``img`` (PIL image or array) and return the result as a NumPy array.

        ``ddepth=-1`` asks OpenCV for an output of the same depth as the input.
        """
        return cv.filter2D(np.array(img), -1, self.kernel)


highpass = HighpassFilter()

# standard ImageNet channel mean/std, matching the ImageNet-pretrained ResNet50 used later
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# training: random crop/flips plus randomly applied blur and sharpening
train_transforms = Compose([
    RandomResizedCrop(224),
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomApply([GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))], p=0.5),
    RandomApply([highpass], p=0.5),
    ToTensor(),
    Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# validation: deterministic resize of the shorter side + centre crop
val_transforms = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# test: direct resize to 224x224
# NOTE(review): unlike val_transforms this does not keep the aspect ratio, so test
# images are preprocessed differently from the ones used for model selection —
# confirm this is intended.
test_transforms = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(IMAGENET_MEAN, IMAGENET_STD)
])


def _apply_transforms(batch, transforms):
    """Replace ``batch['image']`` (PIL images) with ``batch['pixel_values']`` (transformed tensors)."""
    batch['pixel_values'] = [transforms(img.convert("RGB")) for img in batch['image']]
    del batch['image']
    return batch


def transform_train(batch):
    """``set_transform`` hook for the training split (augmenting transforms)."""
    return _apply_transforms(batch, train_transforms)


def transform_val(batch):
    """``set_transform`` hook for the validation split."""
    return _apply_transforms(batch, val_transforms)


def transform_test(batch):
    """``set_transform`` hook for the test split."""
    return _apply_transforms(batch, test_transforms)


Helper functions

In [None]:
def iterate_dataloader(dataloader, device):
    """
    Fetch the first batch yielded by a fresh iterator over ``dataloader``.

    Returns the batch's ``pixel_values`` and ``label`` entries, both moved to ``device``.
    """
    first_batch = next(iter(dataloader))
    return first_batch['pixel_values'].to(device), first_batch['label'].to(device)

def inv_trans(image):
    """
    Undo the ImageNet ``Normalize`` step so colours are back in display range.

    ``Normalize`` computes ``(x - mean) / std``. The inverse is applied in two
    passes: first scale by ``std`` (normalising with std ``1 / std``), then shift
    by ``mean`` (normalising with mean ``-mean``).
    """
    channel_std = [0.229, 0.224, 0.225]
    channel_mean = [0.485, 0.456, 0.406]

    undo_std = Normalize(mean=[0., 0., 0.], std=[1 / s for s in channel_std])
    undo_mean = Normalize(mean=[-m for m in channel_mean], std=[1., 1., 1.])

    return Compose([undo_std, undo_mean])(image)


def imshow(images, labels, rows=2, columns=4):
    """
    Display up to ``rows * columns`` images of a batch in one figure, titled by label.

    Args:
        images: batch of normalized image tensors, channels first per image.
        labels: matching labels; 0 is shown as 'real', anything else as 'generated'.
        rows, columns: grid layout (defaults give the 2x4 grid of 8 images).
    """
    fig = plt.figure(figsize=(8, 5))
    # show fewer panels when the batch is smaller than the grid instead of raising IndexError
    n_shown = min(len(images), rows * columns)

    for i in range(n_shown):
        # clamp: the inverse normalisation may land marginally outside [0, 1],
        # which matplotlib would otherwise clip with a warning
        unnormalized_img = inv_trans(images[i]).cpu().clamp(0, 1).numpy()
        label = 'real' if int(labels[i]) == 0 else 'generated'

        fig.add_subplot(rows, columns, i + 1)
        plt.imshow(np.transpose(unnormalized_img, (1, 2, 0)))  # CHW -> HWC
        plt.axis('off')
        plt.title(label)

    fig.tight_layout()
    plt.show()


def get_best_loss(path='./model_data/best_loss_ds1.pkl'):
    """
    Get the best validation loss from previous runs if available.

    The default ``path`` is the file that ``save_model_with_best_loss`` writes
    (``model_data/best_loss_ds1.pkl``), so a saved best loss is actually found.

    Returns:
        The stored loss as a float, or ``float('inf')`` when the file does not
        exist, so that the first validation loss always counts as an improvement.
    """
    # NOTE: pickle.load can execute arbitrary code — only load files this notebook wrote.
    try:
        with open(path, 'rb') as file:
            best_loss = float(pickle.load(file))
    except FileNotFoundError:
        return float('inf')

    print(f'\nbest loss: {best_loss}')
    return best_loss


def get_image_label(data, device):
    """
    Move one batch to ``device`` and reshape its labels like the network output.

    Labels (presumably 1-D, one per image) gain a trailing axis and become
    float, giving a ``(batch, 1)`` column that matches the single-logit output.
    """
    pixels = data['pixel_values']
    targets = data['label']
    return pixels.to(device), targets.unsqueeze(1).float().to(device)


def save_model_with_best_loss(model, val_loss, device):
    """
    Save ``model`` as a TorchScript trace and record ``val_loss`` as the best loss.

    Args:
        model: module to save; traced with a random ``(1, 3, 224, 224)`` input.
        val_loss: validation loss (Python number or 0-d tensor).
        device: device on which the example input is created; should match the model.

    Returns:
        ``val_loss`` unchanged, so callers can write ``best_loss = save_model_with_best_loss(...)``.
    """
    import os

    # the output directory may not exist on a fresh checkout
    os.makedirs('./model_data', exist_ok=True)

    with torch.no_grad():
        traced = torch.jit.trace(model, torch.rand(1, 3, 224, 224).to(device))
    torch.jit.save(traced, './model_data/resnet50.pth')

    # save the best loss for the next run as a plain float: pickling a
    # (possibly CUDA) tensor would tie the file to the current device
    with open('model_data/best_loss_ds1.pkl', 'wb') as file:
        pickle.dump(float(val_loss), file)
    print('model updated')

    return val_loss


def plot_loss(epoch_train_losses, epoch_test_losses, epochs):
    """Plot per-epoch training (red) and validation (green) losses and show the figure.

    ``epochs`` is the number of points in each series; the x axis runs 1..epochs.
    """
    x_axis = [epoch_no for epoch_no in range(1, epochs + 1)]
    curves = (
        (epoch_train_losses, 'r', 'Train Loss'),
        (epoch_test_losses, 'g', 'Validation Loss'),
    )
    for values, colour, caption in curves:
        plt.plot(x_axis, values, colour, label=caption)

    plt.title('Train and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


def plot_acc(total_train_acc, total_val_acc, epochs):
    """Plot per-epoch training (red) and validation (green) accuracies and show the figure.

    ``epochs`` is the number of points in each series; the x axis runs 1..epochs.
    """
    epoch_numbers = [n for n in range(1, epochs + 1)]

    plt.plot(epoch_numbers, total_train_acc, 'r', label='Train Accuracy')
    plt.plot(epoch_numbers, total_val_acc, 'g', label='Validation Accuracy')

    for setter, text in ((plt.title, 'Train and Validation Accuracy'),
                         (plt.xlabel, 'Epochs'),
                         (plt.ylabel, 'Accuracy')):
        setter(text)

    plt.legend()
    plt.show()

Train and validate the model. The model is saved whenever the validation loss is at or below the best loss seen so far.
If the model has not improved for several epochs, training stops early.

In [None]:
def _accuracy_percent(correct, total):
    """Return ``correct / total`` as a percentage, or 0.0 when nothing was counted."""
    return float(correct) / float(total) * 100 if total else 0.0


def _train_one_epoch(model, train_loader, device, optimizer, loss_fn):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        ``(mean_loss, correct, total)``: the batch-averaged loss as a Python float
        and the number of correctly classified / seen samples.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0

    for data in tqdm(train_loader):  # iterate over batches
        image, label = get_image_label(data, device)

        optimizer.zero_grad()
        output = model(image)  # raw logits; loss_fn is expected to take logits (BCEWithLogitsLoss here)
        loss = loss_fn(output, label)
        loss.backward()
        optimizer.step()

        # accumulate a Python float: summing the loss tensor itself would keep
        # every batch's autograd graph alive until the end of the epoch
        loss_sum += loss.item()

        predictions = torch.sigmoid(output) > 0.5
        correct += (predictions == label).sum().item()
        total += predictions.size(0)

    return loss_sum / max(len(train_loader), 1), correct, total


def _evaluate(model, val_loader, device, loss_fn):
    """Compute the batch-averaged loss and accuracy counts on ``val_loader`` without gradients.

    Returns:
        ``(mean_loss, correct, total)`` with ``mean_loss`` as a Python float.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for data in val_loader:
            image, label = get_image_label(data, device)
            output = model(image)
            loss_sum += loss_fn(output, label).item()

            predictions = torch.sigmoid(output) > 0.5
            correct += (predictions == label).sum().item()
            total += predictions.size(0)

    return loss_sum / max(len(val_loader), 1), correct, total


def train_model(n_epochs, model, train_loader, device, optimizer, loss_fn, val_loader, best_loss,
                early_stopping_tolerance=5):
    """Train ``model`` for up to ``n_epochs`` epochs, validating after each one.

    After every epoch the model is saved via ``save_model_with_best_loss`` when
    its validation loss is <= ``best_loss``. Training stops early once
    ``early_stopping_tolerance`` consecutive epochs bring no such improvement.
    Losses and accuracies are also logged to TensorBoard under ``model_data/logs``.

    Returns:
        ``(train_losses, val_losses, train_accs, val_accs)``: one Python float per
        completed epoch in each list (accuracies in percent).
    """
    total_train_losses = []
    total_val_losses = []
    total_train_acc = []
    total_val_acc = []
    early_stopping_counter = 0

    writer = SummaryWriter("model_data/logs")
    try:
        for epoch in range(n_epochs):
            print('Starting training...')
            train_loss, correct, total = _train_one_epoch(model, train_loader, device, optimizer, loss_fn)

            total_train_losses.append(train_loss)
            print('\nEpoch : {}, train loss : {}'.format(epoch + 1, train_loss))
            writer.add_scalar('loss/training', train_loss, epoch)

            train_acc = _accuracy_percent(correct, total)
            total_train_acc.append(train_acc)
            print("Got {} / {} with training accuracy {}".format(correct, total, train_acc))
            writer.add_scalar('accuracy/training', train_acc, epoch)

            val_loss, correct, total = _evaluate(model, val_loader, device, loss_fn)

            total_val_losses.append(val_loss)
            print('Epoch : {}, val loss : {}'.format(epoch + 1, val_loss))
            writer.add_scalar('loss/validation', val_loss, epoch)

            val_acc = _accuracy_percent(correct, total)
            total_val_acc.append(val_acc)
            print("Got {} / {} with validation accuracy {}".format(correct, total, val_acc))
            writer.add_scalar('accuracy/validation', val_acc, epoch)

            print('val_loss: ', val_loss)
            print('best_loss ', best_loss)

            if val_loss <= best_loss:
                # save best model; patience counts *consecutive* epochs without improvement
                best_loss = save_model_with_best_loss(model, val_loss, device)
                early_stopping_counter = 0
            else:
                early_stopping_counter += 1
                print(f'early stopping counter {early_stopping_counter} / {early_stopping_tolerance}')

            if early_stopping_counter >= early_stopping_tolerance:
                print("Terminating: early stopping")
                break  # terminate training
    finally:
        # flush pending TensorBoard events even if training is interrupted
        writer.close()

    return total_train_losses, total_val_losses, total_train_acc, total_val_acc

Execute Training.

In [None]:
# fix torch's RNG (new head initialisation, DataLoader shuffling, torchvision
# random augmentations) so runs are repeatable
SEED = 42
torch.manual_seed(SEED)

# used device for computing
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device, 'will be used.')

# get the dataset
train_dataset, val_dataset, test_dataset = get_dataset()

# apply transforms and preprocessing to dataset
train_dataset.set_transform(transform_train)
val_dataset.set_transform(transform_val)
test_dataset.set_transform(transform_test)

# load data; evaluation loaders are not shuffled so the (batch-averaged)
# validation loss does not depend on how samples happen to be grouped
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# showcase some training images
images, labels = iterate_dataloader(train_loader, device)
imshow(images, labels)

# define model: pretrained ResNet50 with a new single-logit head (binary classifier)
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
nr_filters = model.fc.in_features  # number of input features of last layer
model.fc = nn.Linear(nr_filters, 1)

# load trained model if there is one; map_location lets a checkpoint saved on
# a GPU be loaded on a CPU-only machine
try:
    model = torch.jit.load('./model_data/resnet50.pth', map_location=device)
    print(model)
except ValueError:
    print("There was no model to load.")

# freeze the backbone and train only the head. `requires_grad_` is a method and
# must be called; doing this after the load also covers a reloaded model.
for param in model.parameters():
    param.requires_grad_(False)
for param in model.fc.parameters():
    param.requires_grad_(True)

model = model.to(device)

# loss; binary cross entropy with sigmoid, i.e. no need to use sigmoid in model
loss_fn = BCEWithLogitsLoss()

# optimizer (head parameters only)
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-4)

# get best loss
best_loss = get_best_loss()

# start training of model
# NOTE: n_epochs = 0 skips training entirely (the result lists stay empty);
# set a positive value to train
n_epochs = 0
total_train_losses, total_val_losses, total_train_acc, total_val_acc =\
    train_model(n_epochs, model, train_loader, device, optimizer, loss_fn, val_loader, best_loss)

print('evaluating model...')

# float() accepts Python numbers as well as 0-d tensors on any device
train_losses = [float(loss) for loss in total_train_losses]
val_losses = [float(loss) for loss in total_val_losses]

print('train loss: ', torch.Tensor(train_losses))
print('val loss: ', torch.Tensor(val_losses))

# plotting losses, accuracies
duration = len(val_losses)
plot_loss(train_losses, val_losses, duration)
plot_acc(total_train_acc, total_val_acc, duration)

Display tensorboard logs.

In [None]:
# Load the TensorBoard notebook extension and display the scalars that
# train_model logs through SummaryWriter("model_data/logs").
%load_ext tensorboard
%tensorboard --logdir model_data/logs