In [1]:
# -------------------------------------------------------------------------------------------------------------
# Imports
# -------------------------------------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as functional
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

In [2]:
# -------------------------------------------------------------------------------------------------------------
# System path for imports
# -------------------------------------------------------------------------------------------------------------
PROJECT_ROOT='./'
import sys
sys.path.append(PROJECT_ROOT)

### Branch 1

In [14]:
# -------------------------------------------------------------------------------------------------------------
# Driver code
# -------------------------------------------------------------------------------------------------------------

from architecture.cnn_architecture import CNN_ARCHITECTURE
from configs.cnn_branch_config import cnn_experiment_1

model = CNN_ARCHITECTURE(cnn_experiment_1['model_args']['input_size'], cnn_experiment_1['model_args']['hidden_layers'], cnn_experiment_1['model_args']['activation'], cnn_experiment_1['model_args']['norm_layer'], cnn_experiment_1['model_args']['drop_prob'])
print(model)
model.eval()

# input_2 = torch.randn(1,3,224,224)
image = Image.open("./Blur.png").convert('RGB')
transform = transforms.Compose([transforms.ToTensor()])
input = transform(image)
print(f"Input shape: {input.shape}")
# print(input_2.shape)

final_features = model(input)

print(f"Final features: {final_features}")
print(f"Final features shape: {final_features.shape}")

# features_in_shape = final_features.view(512, 28, 28)
# print(features_in_shape.shape)

CNN_ARCHITECTURE(
  (features): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU()
    (3): Dropout(p=0.4, inplace=False)
    (4): Conv2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): ReLU()
    (7): Dropout(p=0.4, inplace=False)
    (8): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): ReLU()
    (11): Dropout(p=0.4, inplace=False)
  )
)
Input shape: torch.Size([3, 720, 1280])
Final features: tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0367, 0.0228, 0.0283,  ..., 0.0252, 0.0236, 0.0197],
        [0.0042, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0764, 0.0538, 0.0593,  ..., 0.

### Branch 2

In [15]:
from architecture.vit_architectire import VIT_ARCHITECTURE
from configs.cnn_branch_config import cnn_experiment_1

vit_model = VIT_ARCHITECTURE(cnn_experiment_1['model_args']['model_name'])


vit_features = vit_model.extract_features("./Blur.png")
print(vit_features)
print("CLS token shape:", vit_features['cls'].shape)
print("All token embeddings shape:", vit_features['all'].shape)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'cls': tensor([[ 4.7615e-01, -9.9602e-02,  1.5029e+00, -5.1744e-02, -3.8278e-02,
          2.4302e+00,  4.4923e-01, -8.1642e-01, -5.7224e-01, -8.1522e-01,
         -1.8054e+00, -1.4601e-01,  3.8230e-01,  6.0427e-01, -1.1726e-01,
         -5.0949e-01, -3.7467e-01,  5.5540e-01,  1.0882e-01,  1.2580e+00,
         -6.7239e-01, -7.0905e-01, -6.4983e-01, -1.4001e+00, -7.9707e-02,
         -1.1794e+00, -5.2500e-01, -1.1217e+00,  8.0294e-01,  3.9371e-01,
         -1.1232e+00,  9.5453e-01,  6.8784e-01, -2.0632e+00,  2.1795e-01,
          3.2798e-01, -1.2843e+00,  1.1984e+00,  1.1486e+00, -4.5268e-01,
         -5.7421e-01, -1.4203e+00,  1.3407e+00,  3.2471e-01, -1.5128e+00,
         -3.7624e-01, -1.2511e+00, -1.4204e-01, -1.9368e+00,  9.1642e-01,
          5.8074e-02,  1.0683e+00, -1.2409e+00,  5.7934e-01, -9.9048e-01,
         -1.0262e+00,  5.2281e-01, -2.2733e-01,  1.8387e-01,  8.7814e-01,
          1.9188e+00, -1.4181e+00,  8.1813e-02,  6.3443e-01,  4.7809e-01,
         -1.6845e+00,  1.4446e