In [1]:
import os
image_path = "../dataset/pizza_steak_sushi/data"
# image_path = "../dataset/food101_torch"

train_dir = os.path.join(image_path, "train")
test_dir = os.path.join(image_path, "test")

In [2]:
import torchvision, torch
from torch import nn
# 1. Get pretrained weights for ViT-Base
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT

In [3]:
from data_setup import createDataloader

train_transform = pretrained_vit_weights.transforms()

train_DL, test_DL, _classes = createDataloader(train_dir, test_dir, 32, train_transform, train_transform)

print("total classes:",len(_classes))

total classes: 3


In [4]:


# 2. Setup a ViT model instance with pretrained weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)

# 3. Freeze the base parameters
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

# 4. Change the classifier head (set the seeds to ensure same initialization with linear head)
# set_seeds()
pretrained_vit.heads = nn.Linear(in_features=768, out_features=len(_classes)).to(device)
# pretrained_vit # uncomment for model output

In [5]:
from torchinfo import summary

In [None]:
# # Print a summary using torchinfo (uncomment for actual output)
summary(model=pretrained_vit,
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

In [None]:
from utils import trainit

loss_fn = nn.CrossEntropyLoss().to(device)
optim = torch.optim.Adam(pretrained_vit.parameters(), lr=1e-3)

trainit(pretrained_vit, train_DL, test_DL, loss_fn, optim, 50, True, device)

In [65]:
summary(model=pretrained_vit,
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
		# col_names=["input_size"], # uncomment for smaller output
		col_names=["input_size", "output_size", "num_params", "trainable"],
		col_width=20,
		row_settings=["var_names"]
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 3]              768                  Partial
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 768, 14, 14]    (590,592)            False
├─Encoder (encoder)                                          [32, 197, 768]       [32, 197, 768]       151,296              False
│    └─Dropout (dropout)                                     [32, 197, 768]       [32, 197, 768]       --                   --
│    └─Sequential (layers)                                   [32, 197, 768]       [32, 197, 768]       --                   False
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 768]       [32, 197, 768]       (7,087,872)          False
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 768]       [32, 

# infering best save model

In [70]:

def get_last_model(path):
	files = os.listdir(path)
	files.sort()
	return os.path.join(path, files[-1],"best_model.pth")
bestmodel_path = "{}".format(get_last_model("./model"))
print("best model path:", bestmodel_path)

best model path: ./model/1736257866/best_model.pth


In [71]:
torch.save(pretrained_vit.state_dict(), "best_vits.pth")

In [72]:
new_vit_model = torchvision.models.vit_b_16().to(device)

In [73]:
new_vit_model = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)
new_vit_model.heads = nn.Linear(in_features=768, out_features=len(_classes))

In [74]:
summary(model=new_vit_model,
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
		# col_names=["input_size"], # uncomment for smaller output
		col_names=["input_size", "output_size", "num_params", "trainable"],
		col_width=20,
		row_settings=["var_names"]
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 3]              768                  True
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 768, 14, 14]    590,592              True
├─Encoder (encoder)                                          [32, 197, 768]       [32, 197, 768]       151,296              True
│    └─Dropout (dropout)                                     [32, 197, 768]       [32, 197, 768]       --                   --
│    └─Sequential (layers)                                   [32, 197, 768]       [32, 197, 768]       --                   True
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 768]       [32, 197, 768]       7,087,872            True
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 768]       [32, 197, 76

In [75]:
new_vit_model.load_state_dict(torch.load("best_vits.pth"))

  new_vit_model.load_state_dict(torch.load("best_vits.pth"))


<All keys matched successfully>

In [76]:
new_vit_model

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [77]:
summary(model=new_vit_model,
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
		# col_names=["input_size"], # uncomment for smaller output
		col_names=["input_size", "output_size", "num_params", "trainable"],
		col_width=20,
		row_settings=["var_names"]
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 3]              768                  True
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 768, 14, 14]    590,592              True
├─Encoder (encoder)                                          [32, 197, 768]       [32, 197, 768]       151,296              True
│    └─Dropout (dropout)                                     [32, 197, 768]       [32, 197, 768]       --                   --
│    └─Sequential (layers)                                   [32, 197, 768]       [32, 197, 768]       --                   True
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 768]       [32, 197, 768]       7,087,872            True
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 768]       [32, 197, 76

In [78]:
test_DL.dataset[0][0].unsqueeze(0).shape

torch.Size([1, 3, 224, 224])

In [98]:
len(test_DL.dataset)

75

In [102]:
import requests
with open("pizza.jpg", "wb") as f:
    img = requests.get("https://veenaazmanov.com/wp-content/uploads/2020/07/Mushroom-Pizza-Recipe4.jpg")
    f.write(img.content)

img_path = "pizza.jpg"
# img_path = "./steak.jpg"
img = Image.open(img_path)
img = train_transform(img).unsqueeze(0)
print(img.shape)

torch.Size([1, 3, 224, 224])


In [104]:
import matplotlib.pyplot as plt
import random
with torch.inference_mode():
    new_vit_model.eval()
    # rand_image = random.randint(0, len(test_DL))
    # plt.imshow(test_DL.dataset[rand_image][0].permute(1, 2, 0))
    # plt.title(f"Actual: {_classes[test_DL.dataset[rand_image][1]]}")
    # y_preds = new_vit_model(test_DL.dataset[rand_image][0].unsqueeze(0).to(device))
    # print(_classes[torch.argmax(y_preds)])
    print(_classes[torch.argmax(new_vit_model(img.to(device)))])


pizza
