https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

In [None]:
import math
import torch
import torch.nn as nn

from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional
from torchvision.models.vision_transformer import ConvStemConfig, WeightsEnum, MLPBlock, EncoderBlock #, Encoder
from torchvision.models import ViT_B_16_Weights  # , vit_b_16
from torchvision.utils import _log_api_usage_once
from torchvision.models._utils import _ovewrite_named_param
from torchvision.ops.misc import Conv2dNormActivation
from torchvision.datasets import Flowers102
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import ToTensor, Compose, Lambda, Resize
from sklearn.model_selection import ParameterGrid
from collections.abc import Iterable
from sklearn.utils import shuffle

In [None]:
class Encoder(nn.Module):
    """Stack of transformer encoder blocks with an auxiliary classifier after every block.

    Besides the usual encoder output, ``forward`` returns one "early
    classification" per block: the class-token embedding produced by that block
    is passed through a dedicated ``Linear -> Softmax`` head.

    Args:
        seq_length: Number of tokens per sequence (class token included).
        num_layers: Number of encoder blocks, and therefore of early heads.
        num_heads: Attention heads per encoder block.
        hidden_dim: Token embedding size.
        mlp_dim: Hidden size of the MLP inside each encoder block.
        dropout: Dropout probability applied to the embeddings and inside the blocks.
        attention_dropout: Dropout probability passed to each encoder block.
        norm_layer: Factory for the normalization layers.
        num_classes: Output size of every early classification head. Defaults
            to 102, the value that used to be hard-coded here.
    """

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        num_classes: int = 102,
    ):
        super().__init__()
        # Learned positional embedding with a leading singleton dim so that it
        # broadcasts over the batch dimension of the input.
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        class_heads: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
            # The head emits probabilities (Softmax), not logits.
            class_heads[f"class_head_{i}"] = nn.Sequential(
                nn.Linear(in_features=hidden_dim, out_features=num_classes, bias=True),
                nn.Softmax(dim=1),
            )

        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)
        # nn.Sequential is used here only as an ordered, iterable container:
        # the heads are applied one per layer in forward(), never chained.
        self.class_heads = nn.Sequential(class_heads)

    def forward(self, input: torch.Tensor):
        """Encode ``input`` and collect the early classification of every block.

        Args:
            input: Tensor of shape (batch_size, seq_length, hidden_dim).

        Returns:
            A tuple ``(encoded, early_classification)`` where ``encoded`` is the
            layer-normalized output of the last block and
            ``early_classification`` is a list with one
            (batch_size, num_classes) probability tensor per block.
        """
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        _input = input + self.pos_embedding
        _input = self.dropout(_input)

        early_classification = []
        for layer, class_head in zip(self.layers, self.class_heads):
            _input = layer(_input)
            # Token 0 is the class token; it is classified before the final LayerNorm.
            early_classification.append(class_head(_input[:, 0]))

        return self.ln(_input), early_classification

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer (https://arxiv.org/abs/2010.11929) that also exposes early classifications.

    ``forward`` returns a pair: the output of ``self.heads`` for the class token
    and the list of per-layer early classifications produced by ``self.encoder``.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")

        # Keep the constructor arguments around as plain attributes.
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        # Patch embedding: either a conv stem or a single strided convolution.
        self.conv_proj: nn.Module = self._build_patch_projection(conv_stem_configs, hidden_dim, patch_size)

        # One token per image patch, plus the prepended class token.
        num_tokens = (image_size // patch_size) ** 2 + 1
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))

        self.encoder = Encoder(
            num_tokens,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = num_tokens

        if representation_size is None:
            head_modules = OrderedDict([("head", nn.Linear(hidden_dim, num_classes))])
        else:
            head_modules = OrderedDict(
                [
                    ("pre_logits", nn.Linear(hidden_dim, representation_size)),
                    ("act", nn.Tanh()),
                    ("head", nn.Linear(representation_size, num_classes)),
                ]
            )
        self.heads = nn.Sequential(head_modules)

        self._init_projection_weights()
        self._init_head_weights()

    @staticmethod
    def _build_patch_projection(
        conv_stem_configs: Optional[List[ConvStemConfig]], hidden_dim: int, patch_size: int
    ) -> nn.Module:
        """Create the module that maps an image to a grid of ``hidden_dim``-channel patch features."""
        if conv_stem_configs is None:
            return nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)

        # Convolutional stem, as per https://arxiv.org/abs/2106.14881
        stem = nn.Sequential()
        channels = 3
        for index, layer_config in enumerate(conv_stem_configs):
            stem.add_module(
                f"conv_bn_relu_{index}",
                Conv2dNormActivation(
                    in_channels=channels,
                    out_channels=layer_config.out_channels,
                    kernel_size=layer_config.kernel_size,
                    stride=layer_config.stride,
                    norm_layer=layer_config.norm_layer,
                    activation_layer=layer_config.activation_layer,
                ),
            )
            channels = layer_config.out_channels
        stem.add_module("conv_last", nn.Conv2d(in_channels=channels, out_channels=hidden_dim, kernel_size=1))
        return stem

    def _init_projection_weights(self) -> None:
        """Initialize the patchify convolution, or the final 1x1 convolution of a conv stem."""
        stem = self.conv_proj
        if isinstance(stem, nn.Conv2d):
            kernel_h, kernel_w = stem.kernel_size
            fan_in = stem.in_channels * kernel_h * kernel_w
            nn.init.trunc_normal_(stem.weight, std=math.sqrt(1 / fan_in))
            if stem.bias is not None:
                nn.init.zeros_(stem.bias)
        elif isinstance(stem.conv_last, nn.Conv2d):
            last_conv = stem.conv_last
            nn.init.normal_(last_conv.weight, mean=0.0, std=math.sqrt(2.0 / last_conv.out_channels))
            if last_conv.bias is not None:
                nn.init.zeros_(last_conv.bias)

    def _init_head_weights(self) -> None:
        """Initialize the optional ``pre_logits`` layer and zero the final ``head`` layer."""
        pre_logits = getattr(self.heads, "pre_logits", None)
        if isinstance(pre_logits, nn.Linear):
            nn.init.trunc_normal_(pre_logits.weight, std=math.sqrt(1 / pre_logits.in_features))
            nn.init.zeros_(pre_logits.bias)

        final_head = self.heads.head
        if isinstance(final_head, nn.Linear):
            nn.init.zeros_(final_head.weight)
            nn.init.zeros_(final_head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        """Turn a batch of images (n, c, h, w) into patch tokens (n, n_h * n_w, hidden_dim)."""
        batch, _, height, width = x.shape
        torch._assert(height == self.image_size, f"Wrong image height! Expected {self.image_size} but got {height}!")
        torch._assert(width == self.image_size, f"Wrong image width! Expected {self.image_size} but got {width}!")
        patches_per_col = height // self.patch_size
        patches_per_row = width // self.patch_size

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        features = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, n_h * n_w)
        features = features.reshape(batch, self.hidden_dim, patches_per_col * patches_per_row)
        # (n, hidden_dim, n_h * n_w) -> (n, n_h * n_w, hidden_dim): the attention
        # layers expect (batch, sequence, embedding).
        return features.permute(0, 2, 1)

    def forward(self, x: torch.Tensor):
        """Return ``(head_output, early_classifications)`` for a batch of images."""
        tokens = self._process_input(x)

        # Prepend one copy of the class token to every sequence in the batch.
        class_tokens = self.class_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([class_tokens, tokens], dim=1)

        encoded, early_classif = self.encoder(tokens)

        # Only the class token (position 0) feeds the classification head.
        return self.heads(encoded[:, 0]), early_classif
        


In [None]:
def _vision_transformer(
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    """Build a ``VisionTransformer`` and optionally load pretrained weights into it.

    When ``weights`` is given, ``num_classes`` and ``image_size`` in ``kwargs``
    are overwritten from the weights' metadata. The state dict is loaded with
    ``strict=False``, so keys that do not match between checkpoint and model
    are tolerated.
    """
    if weights is not None:
        meta = weights.meta
        _ovewrite_named_param(kwargs, "num_classes", len(meta["categories"]))
        # Only square inputs are supported.
        assert meta["min_size"][0] == meta["min_size"][1]
        _ovewrite_named_param(kwargs, "image_size", meta["min_size"][0])

    resolution = kwargs.pop("image_size", 224)
    net = VisionTransformer(
        image_size=resolution,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )

    if weights:
        pretrained_state = weights.get_state_dict(progress=progress)
        net.load_state_dict(pretrained_state, strict=False)

    return net

In [None]:
# @register_model()
# @handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Build the vit_b_16 variant of the Vision Transformer described in
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_B_16_Weights`, optional): Pretrained
            weights to load. ``None`` (the default) loads no pretrained weights.
        progress (bool, optional): Forwarded to the weight download; defaults to True.
        **kwargs: Extra keyword arguments forwarded to the ``VisionTransformer``
            class defined in this notebook.
    """
    verified_weights = ViT_B_16_Weights.verify(weights)

    # Architecture hyper-parameters of the B/16 variant.
    base_16_config = dict(
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
    )
    return _vision_transformer(
        **base_16_config,
        weights=verified_weights,
        progress=progress,
        **kwargs,
    )

In [None]:
# Build the customized ViT-B/16 with the ImageNet-1k weights and print its module
# tree to check that the per-layer class heads were added to the encoder.
model = vit_b_16(weights= ViT_B_16_Weights.IMAGENET1K_V1)
print(model)

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


  0%|          | 0.00/330M [00:00<?, ?B/s]

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate=none)
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_att

In [None]:
# Ask PyTorch to release cached, currently unused GPU memory before the next cells.
torch.cuda.empty_cache()

In [None]:
weights = ViT_B_16_Weights.IMAGENET1K_V1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = vit_b_16(progress=True, weights=weights).to(device)

# Apply the preprocessing pipeline shipped with the pretrained weights after
# converting each image to a tensor.
preprocess = weights.transforms()
transform = Compose([ToTensor(), Lambda(lambda x: preprocess(x))])

# NOTE(review): the official split names are swapped here: the 'test' split is
# used for training/validation and the 'train' split is held out for testing.
# The printed sizes (4919 / 1230 / 1020) suggest this is deliberate, to train on
# the larger split - confirm, and keep it in mind when reporting results.
train_dataset = Flowers102(root='.',
                        split='test',
                        download=True,
                        transform=transform)

test_dataset = Flowers102(root='.',
                       split='train',
                       download=True,
                       transform=transform)

n_train = int(0.8*len(train_dataset))
n_valid = len(train_dataset) - n_train

# Seed the split so the train/validation partition is the same on every run.
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, valid_dataset = random_split(train_dataset, (n_train, n_valid), generator=split_generator)
print(len(train_dataset), len(valid_dataset), len(test_dataset))

4919 1230 1020


In [None]:
def with_freezed_params(model):
    """Turn off gradient tracking for every parameter of ``model`` and return the same object."""
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model


class BasicViT(nn.Module):
    """Pretrained ViT-B/16 with all loaded parameters frozen and a fresh linear classification head.

    The replacement head is created after freezing, so it is the only part left
    trainable. Its output size comes from the ``out_heads`` keyword argument
    (default 102).
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        num_outputs = kwargs.get('out_heads', 102)
        backbone = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        self.vit = with_freezed_params(backbone)
        # Swap the ImageNet head for a new classifier over ``num_outputs`` classes.
        self.vit.heads = nn.Linear(in_features=768, out_features=num_outputs, bias=True)

    def forward(self, x):
        """Delegate to the wrapped ViT; returns ``(output, early_classifications)``."""
        return self.vit(x)

In [None]:
# Instantiate the frozen-backbone model once and store the whole module on disk;
# train_with_params() reloads "BASIC_MODEL.pt" so that every run starts from the
# same initial weights.
basic_model = BasicViT()
print(basic_model)
torch.save(basic_model, "BASIC_MODEL.pt")

BasicViT(
  (vit): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): GELU(approximate=none)
            (2): Dropout(p=0.0, inplace=False)
            (3): Linear(in_features=3072, out_features=768, bias=True)
            (4): Dropout(p=0.0, inplace=False)
          )
        )
        (encoder_layer_1): EncoderBlock(
          (ln_1): LayerNorm(

In [None]:
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False)

# Fetch a single batch to sanity-check the input shape.
x, y = next(iter(valid_loader))
print(x.shape)

# Inspection only: no gradients are needed for this forward pass.
with torch.no_grad():
    output, early_class = basic_model(x)
# print(output)

torch.Size([2, 3, 224, 224])


In [None]:
# Inspect the early classifications: one entry per encoder layer; index 10 picks
# one of those layers and [0] the first image of the batch.
print(len(early_class))
print(early_class[10].shape)
print(early_class[10][0])
# The heads end with a Softmax, so the values of one image should sum to 1.
print(early_class[10][0].sum())

12
torch.Size([2, 102])
tensor([0.0078, 0.0114, 0.0118, 0.0114, 0.0130, 0.0077, 0.0114, 0.0098, 0.0085,
        0.0108, 0.0154, 0.0088, 0.0109, 0.0100, 0.0105, 0.0109, 0.0091, 0.0082,
        0.0099, 0.0090, 0.0092, 0.0096, 0.0085, 0.0091, 0.0100, 0.0074, 0.0091,
        0.0158, 0.0097, 0.0099, 0.0103, 0.0099, 0.0082, 0.0090, 0.0097, 0.0104,
        0.0110, 0.0069, 0.0091, 0.0113, 0.0070, 0.0082, 0.0092, 0.0120, 0.0098,
        0.0104, 0.0086, 0.0067, 0.0082, 0.0101, 0.0088, 0.0108, 0.0094, 0.0114,
        0.0091, 0.0096, 0.0077, 0.0085, 0.0078, 0.0133, 0.0121, 0.0090, 0.0137,
        0.0094, 0.0112, 0.0098, 0.0109, 0.0085, 0.0084, 0.0111, 0.0075, 0.0072,
        0.0075, 0.0085, 0.0124, 0.0110, 0.0106, 0.0083, 0.0121, 0.0095, 0.0126,
        0.0099, 0.0093, 0.0105, 0.0128, 0.0118, 0.0069, 0.0076, 0.0094, 0.0096,
        0.0073, 0.0089, 0.0103, 0.0091, 0.0099, 0.0072, 0.0129, 0.0098, 0.0104,
        0.0105, 0.0108, 0.0065])
tensor(1.0000)


In [None]:
def valid(model, loader):
    """Compute the top-1 accuracy of ``model`` over ``loader``.

    The model is switched to eval mode (and left in it). Batches are moved to
    the device of the model's parameters before the forward pass.

    Args:
        model: Module whose forward returns ``(output, early_classifications)``,
            where ``output`` has one row of class scores per sample.
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Fraction of correctly classified samples as a Python float; ``0.0``
        when ``loader`` yields no samples.
    """
    model.eval()
    # Follow the model's own device instead of relying on a global variable.
    first_param = next(model.parameters(), None)

    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            if first_param is not None:
                x, y = x.to(first_param.device), y.to(first_param.device)

            output, _ = model(x)

            # Vectorized count of matching predictions for the whole batch.
            correct += (output.argmax(dim=1) == y).sum().item()
            total += y.shape[0]

    return correct / total if total else 0.0


def run_epoch(model, optimizer, criterion, loader, optimizer2=None):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Module whose forward returns ``(output, early_classifications)``.
        optimizer: Optimizer stepped after every batch.
        criterion: Loss called as ``criterion(output, target=labels)``.
        loader: Iterable of ``(inputs, labels)`` batches.
        optimizer2: Optional second optimizer, zeroed and stepped alongside
            ``optimizer``.

    Returns:
        ``(early_class, y)`` for the LAST batch only: the detached early
        classifications and the labels of that batch.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    # Follow the model's own device instead of relying on a global variable.
    first_param = next(model.parameters(), None)

    saw_batch = False
    early_class, y = None, None
    for x, y in loader:
        saw_batch = True
        if first_param is not None:
            x, y = x.to(first_param.device), y.to(first_param.device)

        # Do not accumulate gradients across batches.
        optimizer.zero_grad()
        if optimizer2 is not None:
            optimizer2.zero_grad()

        output, early_class = model(x)

        loss: torch.Tensor = criterion(output, target=y)
        loss.backward()

        optimizer.step()
        if optimizer2 is not None:
            optimizer2.step()

    if not saw_batch:
        raise ValueError("loader yielded no batches")

    # Detach so callers can inspect the probabilities (e.g. convert them to
    # numpy) without keeping the autograd graph of the last batch alive.
    return [probs.detach() for probs in early_class], y

def unfreeze_params(x,y,z):
    """Unimplemented stub: accepts three positional arguments and does nothing.

    NOTE(review): presumably meant to switch ``requires_grad`` on or off for
    (part of) a model - confirm the intended semantics before implementing.
    The calls in ``train_with_params`` pass the keyword arguments
    ``unfreeze_params=`` and ``all=``, which this signature does not accept, so
    those calls raise ``TypeError`` as written.
    """
    pass

def train_with_params(params, criterion, datasets, unfreezed = False, at_beginning=False):
    """Train a fresh copy of the saved base model with one hyper-parameter set.

    Args:
        params: Mapping with the keys ``"batch_size"``, ``"lr"`` and ``"epochs_num"``.
        criterion: Loss function forwarded to ``run_epoch``.
        datasets: Mapping with at least the keys ``"train"`` and ``"valid"``.
        unfreezed: If True, one epoch is run after calling ``unfreeze_params``
            with ``all=True``, using a second optimizer as well.
        at_beginning: With ``unfreezed``, selects whether that epoch is the
            first one (True) or the last one (False).

    Returns:
        ``(validation_accuracy, trained_model)``.
    """
    train_loader = DataLoader(datasets["train"], batch_size=params["batch_size"], shuffle=True)
    valid_loader = DataLoader(datasets["valid"], batch_size=params['batch_size'], shuffle=False)

    # Start every run from the same stored initial model.
    # NOTE(review): torch.load unpickles the file - only load checkpoints that
    # were produced by this notebook.
    test_model = torch.load("BASIC_MODEL.pt")
    test_model = test_model.to(device)

    optimizer = torch.optim.Adam([p for p in test_model.parameters() if p.requires_grad], lr=params['lr'])
    optimizer2 = None

    # NOTE(review): the unfreeze_params(...) calls below use keyword arguments
    # that the current unfreeze_params stub does not accept, so unfreezed=True
    # cannot run until that helper is implemented.
    if unfreezed:
        unfreeze_params(test_model, unfreeze_params=True, all=False)
        print("After switching grads ON: ",len([p for p in test_model.parameters() if p.requires_grad]))
        optimizer2 = torch.optim.Adam([p for p in test_model.parameters() if p.requires_grad], lr=params['lr'])
        unfreeze_params(test_model, unfreeze_params=False, all=False)

    last_epoch = params["epochs_num"] - 1
    for epoch in range(params["epochs_num"]):
        if at_beginning and epoch == 0 and unfreezed:
            print("Training with unfreezed params, first epoch")
            unfreeze_params(test_model, unfreeze_params=True, all=True)
            run_epoch(test_model, optimizer, criterion, train_loader, optimizer2=optimizer2)
            unfreeze_params(test_model, unfreeze_params=False)

        elif not at_beginning and epoch == last_epoch and unfreezed:
            print("Training with unfreezed params, last epoch")
            unfreeze_params(test_model, unfreeze_params=True, all=True)
            run_epoch(test_model, optimizer, criterion, train_loader, optimizer2=optimizer2)
        else:
            print(f"Training with freezed params, epoch = {epoch}")
            early_classif, y = run_epoch(test_model, optimizer, criterion, train_loader)

            # Track the last image of the last batch: the largest probability
            # given by the final early head, and the class it points to.
            last_probs = early_classif[-1][-1]
            max_prob, max_idx = torch.max(last_probs, dim=0)
            print(
                f"Last picture - max probability = {max_prob.item():.5f} "
                f"with idx = {max_idx.item()}, actual class = {y[-1]}"
            )

    model_valid_acc = valid(test_model, valid_loader)

    return model_valid_acc, test_model


def make_params_grid(param_grid, max_num_sets=None, randomize=True, random_state=None):
    """Expand ``param_grid`` into the parameter combinations to evaluate.

    Args:
        param_grid: Mapping from parameter name to a single value or to an
            iterable of candidate values. Strings and bytes count as single
            values, not as iterables of characters.
        max_num_sets: Maximum number of combinations to return; ``None`` keeps
            all of them.
        randomize: If True, the combinations are shuffled before truncation.
        random_state: Forwarded to ``sklearn.utils.shuffle`` to make the
            shuffled order reproducible.

    Returns:
        A list of parameter dicts, or the ``ParameterGrid`` itself when
        ``randomize`` is False and ``max_num_sets`` is None.
    """

    def as_candidates(value):
        # A bare string is one candidate value, even though it is iterable.
        if isinstance(value, (str, bytes)) or not isinstance(value, Iterable):
            return [value]
        return value

    params = {name: as_candidates(value) for name, value in param_grid.items()}
    grid = ParameterGrid(params)

    if randomize:
        return shuffle(grid, random_state=random_state)[:max_num_sets]

    if max_num_sets is None:
        return grid
    # Honour the limit in the deterministic case as well.
    return list(grid)[:max_num_sets]


def find_best_params(param_grid, max_num_sets, criterion, datasets, unfreezed=False, at_beginning=False):
    """Random search over ``param_grid``; reports test accuracy of the best model.

    Each sampled parameter set is trained with ``train_with_params``. The model
    with the highest validation accuracy is stored in "BEST_PARAMS_MODEL.pt",
    reloaded at the end and evaluated on ``datasets["test"]``.

    Args:
        param_grid: Mapping from parameter name to candidate value(s).
        max_num_sets: Maximum number of parameter sets to try.
        criterion: Loss function forwarded to ``train_with_params``.
        datasets: Mapping with the keys ``"train"``, ``"valid"`` and ``"test"``.
        unfreezed: Forwarded to ``train_with_params``.
        at_beginning: Forwarded to ``train_with_params``.

    Returns:
        The parameter dict that reached the best validation accuracy.

    Raises:
        ValueError: If no parameter set was evaluated.
    """
    best_params = {}
    # None means "nothing evaluated yet", so the first trained model is always
    # kept, even when its accuracy is 0.
    best_valid_acc = None

    candidates = make_params_grid(param_grid, max_num_sets, randomize=True)

    for i, params in enumerate(candidates):
        model_valid_acc, trained_model = train_with_params(params, criterion, datasets, unfreezed, at_beginning)
        # Accept either a Python number or a 0-dim tensor from the trainer.
        model_valid_acc = float(model_valid_acc)
        print(f'Model: {i} trained, valid accuracy: {model_valid_acc:.4f}')

        if best_valid_acc is None or model_valid_acc > best_valid_acc:
            best_valid_acc = model_valid_acc
            best_params = params
            torch.save(trained_model, "BEST_PARAMS_MODEL.pt")

    if best_valid_acc is None:
        raise ValueError("param_grid produced no parameter sets to evaluate")

    print(f'Best params: {best_params}, best validation accuracy: {best_valid_acc}')
    test_loader = DataLoader(datasets["test"], batch_size=best_params['batch_size'], shuffle=False)
    # NOTE(review): torch.load unpickles the file - only load checkpoints that
    # were produced by this notebook.
    best_model = torch.load("BEST_PARAMS_MODEL.pt")
    print(f'Test accuracy: {valid(best_model, test_loader)}' )

    return best_params

In [None]:
# Loss used for every training run of the search.
criterion = nn.CrossEntropyLoss()

# Candidate values per hyper-parameter, and how many combinations to try.
param_grid = dict(lr=[0.001], epochs_num=[4], batch_size=[32])
max_num_sets = 1

datasets = dict(train=train_dataset, valid=valid_dataset, test=test_dataset)

best_params = find_best_params(param_grid, max_num_sets, criterion, datasets, unfreezed=False)

Training with freezed params, epoch = 0
Last picture - max probability = 0.01650 with idx = 60, actual class = 16
Training with freezed params, epoch = 1
Last picture - max probability = 0.01675 with idx = 55, actual class = 48
Training with freezed params, epoch = 2
Last picture - max probability = 0.01567 with idx = 96, actual class = 50
Training with freezed params, epoch = 3
Last picture - max probability = 0.01628 with idx = 19, actual class = 89
Model: 0 trained, valid accuracy: 0.9301
Best params: {'lr': 0.001, 'epochs_num': 4, 'batch_size': 32}, best validation accuracy: 0.930081307888031
Test accuracy: 0.9117647409439087


In [None]:
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((464,464)),
        transforms.RandomRotation(15,),
        transforms.RandomCrop(448),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
    ]),
'val': transforms.Compose([
        transforms.Resize((448,448)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
    ]),

TODO: dodać augmentacje do zbioru (add augmentations to the dataset).