In [1]:
from torchvision import models
from torchvision.transforms import transforms
import torch
import CitiesData
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import time

In [2]:
# The inference transforms are available at ViT_B_16_Weights.IMAGENET1K_V1.transforms and perform the following preprocessing operations: Accepts PIL.Image, batched (B, C, H, W) and single (C, H, W) image torch.Tensor objects. 
# The images are resized to resize_size=[256] using interpolation=InterpolationMode.BILINEAR, followed by a central crop of crop_size=[224]. 
# Finally the values are first rescaled to [0.0, 1.0] and then normalized using mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].

#models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
vit = models.vit_b_16(pretrained=True)



In [3]:
class ViT(torch.nn.Module):
    def __init__(self, visionTransformer: models.VisionTransformer):
        super(ViT, self).__init__()
        
        self.reference_vit = visionTransformer.to(device)
#         ViTLayers = torch.nn.Sequential(*list(visionTransformer.children())[:-1])
#         for param in ViTLayers.parameters():
#             param.requires_grad = False
            
        for param in self.reference_vit.parameters():
            param.requires_grad = False

        #self.ViT = ViTLayers.to(device)
        self.linear = torch.nn.Linear(1000, 10).to(device)
        self.softmax = torch.nn.Softmax(dim = 1).to(device)

    def forward(self, x):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.reference_vit.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.reference_vit.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x =  self.reference_vit.heads(x)
        
        #extractedFeature = self.ViT(x)
        probabilities = self.linear(x)
        softmax = self.softmax(probabilities)

        return softmax
    
    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.reference_vit.patch_size
        torch._assert(h == self.reference_vit.image_size, f"Wrong image height! Expected {self.reference_vit.image_size} but got {h}!")
        torch._assert(w == self.reference_vit.image_size, f"Wrong image width! Expected {self.reference_vit.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p
        
        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x =  self.reference_vit.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.reference_vit.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

In [4]:
visionTransformer = ViT(vit)
print(*list(visionTransformer.children())[:-1])

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [5]:
# Loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(visionTransformer.parameters(), lr=0.001)


In [6]:
batch_size = 1024
transform = transforms.Compose([transforms.ToTensor(), transforms.RandomResizedCrop(size=(224, 224), antialias=True), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainDataLoader, testDataLoader = CitiesData.getCitiesDataLoader("./Data/", transforms = transform, batchSize=batch_size)

In [7]:
print(len(trainDataLoader))
print(len(testDataLoader))
for i in trainDataLoader:
    image, cities, _, _ = i
    print(image.shape)
    break

196
22
torch.Size([1024, 3, 224, 224])


In [8]:
def city_to_vector(city):
    output = np.zeros(shape=(len(city), 10))
    for i in range(len(city)):
        if city[i] == 'Atlanta':
            output[i][0] = 1
        elif city[i] == 'Austin':
            output[i][1] = 1
        elif city[i] == 'Boston':
            output[i][2] = 1
        elif city[i] == 'Chicago':
            output[i][3] = 1
        elif city[i] == 'LosAngeles':
            output[i][4] = 1
        elif city[i] == 'Miami':
            output[i][5] = 1
        elif city[i] == 'NewYork':
            output[i][6] = 1
        elif city[i] == 'Phoenix':
            output[i][7] = 1
        elif city[i] == 'SanFrancisco':
            output[i][8] = 1
        elif city[i] == 'Seattle':
            output[i][9] = 1
    
    return torch.tensor(output).to(device)
        

In [9]:
num_epochs = 10
count = 0
for epoch in range(num_epochs):
    start = time.time()
    for data in trainDataLoader:
        

        image, city, _, _ = data
        city = city_to_vector(city)
        image = image.to(device)

        optimizer.zero_grad()
        outputs = visionTransformer(image)
        loss = criterion(outputs, city)
        loss.backward()
        optimizer.step()
        
        end = time.time()
        print(end - start)
        count += 1
        print(count * batch_size)

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

6.215665102005005
1024
10.973674058914185
2048
15.169410943984985
3072
19.854658126831055
4096
24.07573103904724
5120
28.804178714752197
6144
33.05536651611328
7168
37.92238664627075
8192
42.367615699768066
9216
46.769951820373535
10240
51.10030436515808
11264
55.54031777381897
12288
59.89974808692932
13312
64.29403424263
14336
68.59121966362
15360


KeyboardInterrupt: 