In [37]:
import torch
import pandas as pd
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import random
import timeit
from tqdm import tqdm

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import (GlobalAveragePooling2D, Activation, MaxPooling2D, Add, Conv2D, MaxPool2D, Dense,
                                     Flatten, InputLayer, BatchNormalization, Input, Embedding, Permute,
                                     Dropout, RandomFlip, RandomRotation, LayerNormalization, MultiHeadAttention,
                                     RandomContrast, Rescaling, Resizing, Reshape)
import tensorflow as tf### models
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from fastai.vision.all import *

In [None]:
RANDOM_SEED = 42
BATCH_SIZE = 512
EPOCHS = 40
LEARNING_RATE = 1e-4
NUM_CLASSES = 10
PATCH_SIZE = 16
STRIDE_SIZE = 4
IMG_SIZE = 256
IN_CHANNELS = 3
NUM_HEADS = 8
DROPOUT = 0.001
HIDDEN_DIM = 768
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)
ACTIVATION="gelu"
NUM_ENCODERS = 4
EMBED_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS # 16
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2 # 49

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class PatchEmbedding(nn.Module):
    def __init__(self, embed_dim, num_patches, dropout, in_channels):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(size=(1, in_channels, embed_dim)), requires_grad=True)
        self.position_embeddings = nn.Parameter(torch.randn(size=(1, num_patches+1, embed_dim)), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.position_embeddings + x 
        x = self.dropout(x)
        return x
    
model = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)
x = torch.randn(512, 1, 28, 28).to(device)
print(model(x).shape)

In [ ]:
class PatchEncoder(Layer):
  def __init__(self, N_PATCHES, HIDDEN_SIZE, PATCH_SIZE):
    super(PatchEncoder, self).__init__(name = 'patch_encoder')
    self.linear_projection = Dense(HIDDEN_SIZE)
    self.N_PATCHES = N_PATCHES

  def call(self, x):
    # test_image = cv2.resize(test_image, (CONFIGURATION["IM_SIZE"] ,CONFIGURATION["IM_SIZE"]))
    patches = tf.image.extract_patches(
        images=x,
        sizes=[1, PATCH_SIZE, PATCH_SIZE, 1],
        strides=[1, PATCH_SIZE, PATCH_SIZE, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')

    patches = tf.reshape(patches, (tf.shape(patches)[0], self.N_PATCHES, patches.shape[-1]))

    output = self.linear_projection(patches) 

    return output

In [None]:
class ViT(nn.Module):
    def __init__(self, num_patches, img_size, num_classes, patch_size, embed_dim, num_encoders, num_heads, hidden_dim, dropout, activation, in_channels):
        super().__init__()
        self.patch_encoder = PatchEncoder(num_patches, hidden_dim, patch_size)
        self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)        
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, activation=activation, batch_first=True, norm_first=True)
        self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes)
        )

    def forward(self, x):
        x = self.patch_encoder(x)
        x = self.embeddings_block(x)
        x = self.encoder_blocks(x)
        x = self.mlp_head(x[:, 0, :])  # Apply MLP on the CLS token only
        return x

In [ ]:
def compute_patch_coordinates(image_shape, patch_size, stride, as_tuple=False):
    """
    Compute the top-left coordinates of the patches extracted from an image.

    Parameters:
    - image_shape: A tuple (height, width) specifying the size of the image.
    - patch_size: The size of the patches.
    - stride: The stride used when extracting patches.

    Returns:
    - A list of tuples where each tuple is the (y, x) coordinate of a patch.
    """
    coordinates = []
    for y in range(0, image_shape[0] - patch_size[0] + 1, stride[0]):
        for x in range(0, image_shape[1] - patch_size[1] + 1, stride[1]):
            coordinates.append((y, x) if as_tuple else [y, x])
    return coordinates

PATCH_COORDS = compute_patch_coordinates(IMG_SIZE, PATCH_SIZE, STRIDE_SIZE)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), betas=ADAM_BETAS, lr=LEARNING_RATE, weight_decay=ADAM_WEIGHT_DECAY)


In [27]:

test_image = cv2.imread("téléchargement.jpeg")
test_image = cv2.resize(test_image, (256 , 256))
patches = tf.image.extract_patches(
    images=np.expand_dims(test_image, axis=0),
    sizes=[1, 64, 64, 1],
    strides=[1, 10, 10, 1],
    rates=[1, 1, 1, 1],
    padding='VALID')

In [16]:
patches[0].shape

TensorShape([20, 20, 12288])

In [ ]:
print(patches.shape)
print(patches.shape)

In [None]:
plt.figure(figsize = (16,16))

for i in range(patches.shape[1]):

    ax = plt.subplot(20,20, i+1)
    plt.imshow(tf.reshape(patches[0,i,:], (64,64,3)))
    plt.axis("off")

In [33]:
patches = tf.reshape(patches, (patches.shape[0], -1, EMBED_DIM))

# Example usage:
image_shape = (256, 256)
patch_size = (64, 64)
stride = (10, 10)
c = 
print(c, len(c))

[[0, 0], [0, 10], [0, 20], [0, 30], [0, 40], [0, 50], [0, 60], [0, 70], [0, 80], [0, 90], [0, 100], [0, 110], [0, 120], [0, 130], [0, 140], [0, 150], [0, 160], [0, 170], [0, 180], [0, 190], [10, 0], [10, 10], [10, 20], [10, 30], [10, 40], [10, 50], [10, 60], [10, 70], [10, 80], [10, 90], [10, 100], [10, 110], [10, 120], [10, 130], [10, 140], [10, 150], [10, 160], [10, 170], [10, 180], [10, 190], [20, 0], [20, 10], [20, 20], [20, 30], [20, 40], [20, 50], [20, 60], [20, 70], [20, 80], [20, 90], [20, 100], [20, 110], [20, 120], [20, 130], [20, 140], [20, 150], [20, 160], [20, 170], [20, 180], [20, 190], [30, 0], [30, 10], [30, 20], [30, 30], [30, 40], [30, 50], [30, 60], [30, 70], [30, 80], [30, 90], [30, 100], [30, 110], [30, 120], [30, 130], [30, 140], [30, 150], [30, 160], [30, 170], [30, 180], [30, 190], [40, 0], [40, 10], [40, 20], [40, 30], [40, 40], [40, 50], [40, 60], [40, 70], [40, 80], [40, 90], [40, 100], [40, 110], [40, 120], [40, 130], [40, 140], [40, 150], [40, 160], [40, 17

In [38]:
source = untar_data(URLs.PETS)/"images"
items = get_image_files(source)

KeyboardInterrupt: 

In [ ]:
pets = DataBlock(
    blocks=(ImageBlock, BBoxBlock), 
     get_items=get_image_files, 
     splitter=RandomSplitter(),
     get_y=[lambda o: compute_patch_coordinates(IMG_SIZE, PATCH_SIZE, STRIDE_SIZE)],
     item_tfms=[Resize(IMG_SIZE)],
     batch_tfms=aug_transforms(),
     n_inp=1
 )

In [ ]:
dls = pets.dataloaders(source)
dls.show_batch(max_n=9, show=False)

In [ ]:

model = ViT(NUM_PATCHES, IMG_SIZE, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, HIDDEN_DIM, DROPOUT, ACTIVATION, IN_CHANNELS).to(device)
x = torch.randn(512, 1, 28, 28).to(device)
print(model(x).shape) # BATCH_SIZE X NUM_CLASSES