In [21]:
from torchvision import models
import torch

In [3]:
# Pretrained backbone: ViT-B/16 with the SWAG end-to-end fine-tuned ImageNet-1k weights.
# The inference transforms are available at ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1.transforms
# and perform the following preprocessing operations: accept PIL.Image, batched (B, C, H, W)
# and single (C, H, W) image torch.Tensor objects.
# The images are resized to resize_size=[384] using interpolation=InterpolationMode.BICUBIC,
# followed by a central crop of crop_size=[384].
# Finally the values are first rescaled to [0.0, 1.0] and then normalized using
# mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
# NOTE: this is NOT the IMAGENET1K_V1 preset (resize 256 / crop 224 / bilinear); this
# checkpoint's positional embedding is sized for 384x384 inputs.

# `weights` is keyword-only in the torchvision >= 0.13 API; passing it positionally
# only works through a deprecated compatibility shim.
vit = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)

Downloading: "https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth" to C:\Users\josem/.cache\torch\hub\checkpoints\vit_b_16_swag-9ac1b537.pth
100%|██████████| 331M/331M [00:25<00:00, 13.6MB/s] 


In [9]:
# Freeze every pretrained weight so that backpropagation leaves the backbone untouched.
for weight in vit.parameters():
    weight.requires_grad_(False)

In [25]:
# Show every top-level submodule of the backbone except the last one.
topLevelModules = list(vit.children())
print(*topLevelModules[:-1])

Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16)) Encoder(
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): Sequential(
    (encoder_layer_0): EncoderBlock(
      (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (self_attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (dropout): Dropout(p=0.0, inplace=False)
      (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLPBlock(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.0, inplace=False)
        (3): Linear(in_features=3072, out_features=768, bias=True)
        (4): Dropout(p=0.0, inplace=False)
      )
    )
    (encoder_layer_1): EncoderBlock(
      (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (self_attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_f

In [29]:
class ViT(torch.nn.Module):
    """Frozen torchvision ViT backbone followed by a trainable linear classifier.

    The backbone cannot be rebuilt as ``Sequential(*children()[:-1])``: the
    encoder expects a (batch, tokens, hidden) sequence with the class token
    prepended, whereas ``conv_proj`` emits a 4-D feature map and ``class_token``
    is a parameter, not a child module. The feature extraction therefore
    replays the token pipeline of torchvision's ``VisionTransformer.forward``
    and stops before the original classification head.

    Args:
        visionTransformer: Pretrained backbone exposing ``conv_proj``,
            ``class_token`` and ``encoder``. All of its parameters are frozen
            in place.
        num_classes: Number of outputs of the new linear head (default 10).
    """

    def __init__(self, visionTransformer: "models.VisionTransformer", num_classes: int = 10):
        super().__init__()

        # Only the new linear head is meant to be trained.
        for param in visionTransformer.parameters():
            param.requires_grad = False

        self.ViT = visionTransformer
        # torchvision's VisionTransformer exposes its embedding width as
        # `hidden_dim`; fall back to the ViT-B value when it is not available.
        hiddenDim = getattr(visionTransformer, "hidden_dim", 768)
        self.linear = torch.nn.Linear(hiddenDim, num_classes)
        # Kept for inference-time probabilities; NOT applied in forward().
        self.softmax = torch.nn.Softmax(dim=1)

    def _extractFeatures(self, x):
        """Return the encoded class-token embedding, shape (batch, hidden)."""
        backbone = self.ViT
        patchMap = backbone.conv_proj(x)                  # (B, hidden, H/p, W/p)
        tokens = patchMap.flatten(2).transpose(1, 2)      # (B, num_patches, hidden)
        classToken = backbone.class_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([classToken, tokens], dim=1)   # prepend class token
        encoded = backbone.encoder(tokens)
        # The class token (position 0) summarises the whole image.
        return encoded[:, 0]

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes).

        Logits (not probabilities) are returned because
        ``torch.nn.CrossEntropyLoss`` applies log-softmax itself; feeding it
        softmax outputs would normalise twice and flatten the gradients.
        NOTE(review): the spatial size of ``x`` must match the size the
        backbone's positional embedding was built for - confirm against the
        data pipeline.
        """
        extractedFeature = self._extractFeatures(x)
        return self.linear(extractedFeature)

    def predictProbabilities(self, x):
        """Return softmax class probabilities of shape (batch, num_classes)."""
        return self.softmax(self.forward(x))

In [30]:
# Wrap the pretrained backbone in the ViT classifier module defined in this notebook.
visionTransformer = ViT(visionTransformer=vit)

In [31]:
# Loss and optimizer.
# NOTE(review): CrossEntropyLoss applies log-softmax internally and expects raw
# logits - confirm the model's forward() does not already apply a softmax.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=visionTransformer.parameters(), lr=1e-3)


In [32]:
import CitiesData

In [33]:
# Build the train and test loaders from the local dataset folder.
cityLoaders = CitiesData.getCitiesDataLoader("./Data/")
trainDataLoader, testDataLoader = cityLoaders

In [35]:
# Peek at a single training batch and report how many fields it carries.
firstBatch = next(iter(trainDataLoader), None)
if firstBatch is not None:
    print(len(firstBatch))

4


In [None]:
# Train the classifier for a fixed number of epochs and report the mean
# training loss of each epoch.
num_epochs = 10
for epoch in range(num_epochs):
    runningLoss = 0.0
    batchCount = 0

    for batch in trainDataLoader:
        # Each batch carries four fields; only the first two are used here.
        # NOTE(review): assumes `cities` holds the integer class index of each
        # image, as CrossEntropyLoss requires - confirm against CitiesData.
        images, cities, _, _ = batch

        optimizer.zero_grad()
        outputs = visionTransformer(images)
        # The targets are `cities`; the previous `labels` name was never defined.
        loss = criterion(outputs, cities)
        loss.backward()
        optimizer.step()

        runningLoss += loss.item()
        batchCount += 1

    if batchCount == 0:
        # Fail with a clear message instead of an unbound-variable error below.
        raise RuntimeError("trainDataLoader yielded no batches")

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {runningLoss / batchCount:.4f}')