In [1]:
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms


In [2]:
device = 'mps'
train_dir = '/Users/hiteshgupta/Documents/ML-CV/Vision Transformer/ImageForgery/train'

In [3]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 
                               kernel_size=3, 
                               stride=2, 
                               padding=1, 
                               output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 
                               kernel_size=3, 
                               stride=2, 
                               padding=1, 
                               output_padding=1),
            nn.Sigmoid()
        )
         
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [4]:
# Creating a model with autoencoder and pretrained Vision Transform
model = Autoencoder()


# Get pretrained weights for ViT-Base
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT 

class ViT(nn.Module):
    def __init__(self, pretrained_vit_weights):
        super(ViT, self).__init__()
        self.pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)
        self.autoencoder = Autoencoder().to(device)
        for parameter in self.pretrained_vit.parameters():
            parameter.requires_grad = False
        
        self.pretrained_vit.heads = nn.Linear(in_features=768, out_features=2).to(device)
        
    def forward(self, x):
        x = self.autoencoder(x)
        x = self.pretrained_vit(x)
        return x

# 2. Setup a ViT model instance with pretrained weights
# pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)

# # 3. Freeze the base parameters
# for parameter in pretrained_vit.parameters():
#     parameter.requires_grad = False

pretrained_vit = ViT(pretrained_vit_weights)
    
# 4. Change the classifier head 
class_names =  ['forged', 'real']

# set_seeds()
pretrained_vit.heads = nn.Linear(in_features=768, out_features=len(class_names)).to(device)
pretrained_vit # uncomment for model output 

ViT(
  (pretrained_vit): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.0, inplace=False)
            (3): Linear(in_features=3072, out_features=768, bias=True)
            (4): Dropout(p=0.0, inplace=False)
          )
        )
        (encoder_layer_1): EncoderBlock(
          (ln_1): La

In [5]:
from torchinfo import summary

# Print a summary using torchinfo (uncomment for actual output)
summary(model=pretrained_vit, 
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                                           Input Shape          Output Shape         Param #              Trainable
ViT (ViT)                                                         [32, 3, 224, 224]    [32, 2]              1,538                Partial
├─Autoencoder (autoencoder)                                       [32, 3, 224, 224]    [32, 3, 224, 224]    --                   True
│    └─Sequential (encoder)                                       [32, 3, 224, 224]    [32, 8, 56, 56]      --                   True
│    │    └─Conv2d (0)                                            [32, 3, 224, 224]    [32, 16, 224, 224]   448                  True
│    │    └─ReLU (1)                                              [32, 16, 224, 224]   [32, 16, 224, 224]   --                   --
│    │    └─MaxPool2d (2)                                         [32, 16, 224, 224]   [32, 16, 112, 112]   --                   --
│    │    └─Conv2d (3)                                    

In [6]:
train_dir = '/Users/hiteshgupta/Documents/ML-CV/Vision Transformer/ImageForgery/train'
val_dir = '/Users/hiteshgupta/Documents/ML-CV/Vision Transformer/ImageForgery/val'

In [7]:
# Get automatic transforms from pretrained ViT weights
pretrained_vit_transforms = pretrained_vit_weights.transforms()
print(pretrained_vit_transforms)

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [8]:
import os

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

NUM_WORKERS = os.cpu_count()

def create_dataloaders(
    train_dir: str, 
    test_dir: str, 
    transform: transforms.Compose, 
    batch_size: int, 
    num_workers: int=NUM_WORKERS
):

  # Use ImageFolder to create dataset(s)
  train_data = datasets.ImageFolder(train_dir, transform=transform)
  test_data = datasets.ImageFolder(val_dir, transform=transform)

  # Get class names
  class_names = train_data.classes

  # Turn images into data loaders
  train_dataloader = DataLoader(
      train_data,
      batch_size=batch_size,
      shuffle=True,
      num_workers=num_workers,
      pin_memory=True,
  )
  test_dataloader = DataLoader(
      test_data,
      batch_size=batch_size,
      shuffle=False,
      num_workers=num_workers,
      pin_memory=True,
  )

  return train_dataloader, test_dataloader, class_names

In [9]:
# Setup dataloaders
train_dataloader_pretrained, test_dataloader_pretrained, class_names = create_dataloaders(train_dir=train_dir,
                                                                                                     test_dir=val_dir,
                                                                                                     transform=pretrained_vit_transforms,
                                                                                                     batch_size=32) # Could increase if we had more samples, such as here: https://arxiv.org/abs/2205.01580 (there are other improvements there too...)

In [10]:
from going_modular.going_modular import engine

# Create optimizer and loss function
optimizer = torch.optim.Adam(params=pretrained_vit.parameters(), 
                             lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# Train the classifier head of the pretrained ViT feature extractor model
# set_seeds()
pretrained_vit_results = engine.train(model=pretrained_vit,
                                      train_dataloader=train_dataloader_pretrained,
                                      test_dataloader=test_dataloader_pretrained,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      epochs=15,
                                      device=device)

  0%|          | 0/15 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.6945 | train_acc: 0.5388 | test_loss: 0.6990 | test_acc: 0.5120
Epoch: 2 | train_loss: 0.6986 | train_acc: 0.5302 | test_loss: 0.7277 | test_acc: 0.4522
Epoch: 3 | train_loss: 0.6791 | train_acc: 0.5739 | test_loss: 0.6875 | test_acc: 0.5538
Epoch: 4 | train_loss: 0.6703 | train_acc: 0.5861 | test_loss: 0.7149 | test_acc: 0.5323
Epoch: 5 | train_loss: 0.6671 | train_acc: 0.6000 | test_loss: 0.6025 | test_acc: 0.6767
Epoch: 6 | train_loss: 0.6560 | train_acc: 0.6097 | test_loss: 0.6647 | test_acc: 0.6080
Epoch: 7 | train_loss: 0.6543 | train_acc: 0.6164 | test_loss: 0.7227 | test_acc: 0.5248
Epoch: 8 | train_loss: 0.6468 | train_acc: 0.6195 | test_loss: 0.6737 | test_acc: 0.6218
Epoch: 9 | train_loss: 0.6513 | train_acc: 0.6225 | test_loss: 0.6052 | test_acc: 0.6794
Epoch: 10 | train_loss: 0.6453 | train_acc: 0.6288 | test_loss: 0.6545 | test_acc: 0.5939
Epoch: 11 | train_loss: 0.6404 | train_acc: 0.6319 | test_loss: 0.6251 | test_acc: 0.6555


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x168c90a40>
Traceback (most recent call last):
  File "/Users/hiteshgupta/anaconda3/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/Users/hiteshgupta/anaconda3/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/hiteshgupta/anaconda3/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/hiteshgupta/anaconda3/lib/python3.11/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/hiteshgupta/anaconda3/lib/python3.11/multiprocessing/connection.py", line 947, in wait
    ready = selector.select(timeout)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Use

In [None]:
from helper_functions import plot_loss_curves

plot_loss_curves(pretrained_vit_results) 

In [None]:
# Import function to make predictions on images and plot them 
from going_modular.going_modular.predictions import pred_and_plot_image

# Setup custom image path
custom_image_path = "Testpics/test3.jpg"

# Predict on custom image
pred_and_plot_image(model=pretrained_vit,
                    image_path=custom_image_path,
                    class_names=class_names)

In [None]:
# Save the model
from going_modular.going_modular.utils import save_model
save_model(model=pretrained_vit,
           target_dir='models',
           model_name='ImageForgeryViTEncoderPre_epochs15.pt')