In [1]:
from datetime import datetime
import numpy
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score, auc, roc_curve, confusion_matrix
import torch
from torchinfo import summary
import os
import PIL
from PIL import Image
import import_ipynb 

In [2]:
from vit import ViT
from vit import training_dataloader, validation_dataloader

importing Jupyter notebook from vit.ipynb


In [3]:
# Batch size
BATCH_SIZE = 32
writer = SummaryWriter("logs")
model = ViT()
if os.path.exists("predict.pth"):
    print("Loading saved model")
    model.load_state_dict(torch.load("predict.pth"))
else:
    print("Saved model weights not found")
optimizer = SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)
loss_fn = CrossEntropyLoss()
summary(
    model=model,
    input_size=(
        BATCH_SIZE,
        3,
        224,
        224,
    ),  # (batch_size, num_patches, embedding_dimension)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
ViT (ViT)                                                    [32, 3, 224, 224]    [32, 2]              --                   True
├─PatchEmbeddingLayer (patch_embedding_layer)                [32, 3, 224, 224]    [32, 197, 768]       175,872              True
│    └─Conv2d (conv_layer)                                   [32, 3, 224, 224]    [32, 768, 14, 14]    590,592              True
│    └─Flatten (flatten_layer)                               [32, 14, 14, 768]    [32, 196, 768]       --                   --
├─Sequential (transformer_encoder)                           [32, 197, 768]       [32, 197, 768]       --                   True
│    └─TransformerBlock (0)                                  [32, 197, 768]       [32, 197, 768]       --                   True
│    │    └─MultiHeadSelfAttentionBlock (msa_block)          [32, 197, 768]       [32, 197, 76

In [4]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.0
    last_loss = 0.0

    for i, data in enumerate(training_dataloader):
        inputs, labels = data

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        predicted_labels = torch.argmax(outputs, dim=1)
        accuracy = accuracy_score(labels, predicted_labels)
        # report every 100 batches
        if i % 100 == 99:
            print(
                f"Batch {i+1}, Loss: {running_loss / 100:.3f}, Training Accuracy: {accuracy}%"
            )
            last_loss = running_loss / 100
            tb_x = epoch_index * len(training_dataloader) + i + 1
            tb_writer.add_scalar("Training loss", last_loss, tb_x)
            tb_writer.add_scalar("Training Accuracy", accuracy, tb_x)
            tb_writer.flush()
            running_loss = 0.0

    return last_loss

In [5]:
NUM_EPOCHS = 5
best_vloss = 1_000_000
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")


## Training and validation loop

In [6]:
for epoch in range(NUM_EPOCHS):
    print("EPOCH {}:".format(epoch + 1))
    model.train(True)
    avg_loss = train_one_epoch(epoch, writer)

    running_vloss = 0.0
    model.eval()

    with torch.no_grad():
        for i, vdata in enumerate(validation_dataloader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss
            predicted_labels = torch.argmax(voutputs, dim=1)

            precision = precision_score(vlabels, predicted_labels, average="macro")
            recall = recall_score(vlabels, predicted_labels, average="macro")
            f1score = f1_score(vlabels, predicted_labels, average="macro")

            accuracy = accuracy_score(vlabels, predicted_labels)

            if i % 100 == 99:
              print("Predicted labels:", predicted_labels)
              print("Actual labels:", vlabels)
              print("Validation Accuracy", accuracy)
              print("Validation Precision", precision)
              print("Validation Recall", recall)
              print("Validation F1 score", f1score)
            writer.add_scalar("Validation Precision", precision, epoch)
            writer.add_scalar("Validation Recall", recall, epoch)
            writer.add_scalar("Validation F1-score", f1score, epoch)
            writer.add_scalar("Validation Accuracy", accuracy, epoch)
    avg_vloss = running_vloss / (i + 1)
    print("LOSS train {} valid {}".format(avg_loss, avg_vloss))

    writer.add_scalars(
        "Training vs. Validation Loss",
        {"Training": avg_loss, "Validaton": avg_vloss},
        epoch + 1,
    )
    writer.flush()

    # save best model
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = "model_{}_{}".format(timestamp, epoch)
        torch.save(model.state_dict(), model_path)

EPOCH 1:


RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

## Prediction

In [None]:
import matplotlib.pyplot as plt


def showImage(path):
    # Load the image
    img = Image.open(path)

    # Create a figure with subplots
    fig, axs = plt.subplots(1, 1, figsize=(6, 6))

    # Display the image in the subplot
    axs.imshow(img)

    # Disable the axis for better visualization
    axs.axis(False)

    # Set the background color of the figure as black
    fig.set_facecolor(color="black")

    # Display the plot
    plt.show()

In [None]:
from torchvision.transforms import Resize, Compose, ToTensor
from vit import ViT
import torch.nn.functional as F

# Initialize
model = ViT()
# Load the entire model (architecture and weights)
model.load_state_dict(torch.load("predict.pth"))

model.eval()  # Set the model to evaluation mode
# Define the test_transform using Compose
test_transform = Compose([Resize((224, 224)), ToTensor()])

# Load and preprocess the test image
path = "../generated-images/animated-char2.jpg"
test_image = Image.open(path)
test_image = test_transform(test_image).unsqueeze(0)  # Add batch dimension

# Replicate the input tensor along the batch dimension to match the expected size (32)
batch_size = 32
test_image_batched = test_image.repeat(batch_size, 1, 1, 1)

# Make predictions
with torch.no_grad():
    predictions = model(test_image_batched)

# Process predictions as needed for each image
probabilities = F.softmax(predictions, dim=-1)
result = "real" if torch.argmax(probabilities) == 0 else "fake"
print(f"Prediction: {result}")
showImage(path)