In [40]:
import torch
import torch.nn as nn
from torchvision.models import ViT_B_16_Weights
from torchvision.datasets import Flowers102
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import ToTensor, Compose, Lambda
from src.early_exits import vit_b_16

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [41]:
# Pretrained ImageNet weights; the weights object also provides its own preprocessing preset.
weights = ViT_B_16_Weights.IMAGENET1K_V1

# Fixed seed so the train/validation split is identical on every run.
SPLIT_SEED = 42

preprocess = weights.transforms()
# ToTensor runs first, then the preprocessing bundled with the weights.
# `preprocess` is passed directly: wrapping it in `Lambda(lambda x: preprocess(x))`
# adds nothing and makes the transform unpicklable (lambdas cannot be pickled).
transform = Compose([ToTensor(), preprocess])

# dataset loading - test set switched with training set - done on purpose
full_train_dataset = Flowers102(root='.',
                                split='test',
                                download=True,
                                transform=transform)

test_dataset = Flowers102(root='.',
                          split='train',
                          download=True,
                          transform=transform)

# 80 / 20 split of the (swapped) training data into train and validation subsets.
n_train = int(0.8 * len(full_train_dataset))
n_valid = len(full_train_dataset) - n_train

# A dedicated, seeded generator keeps the split reproducible without touching
# the global RNG state used elsewhere (e.g. for shuffling or initialisation).
train_dataset, valid_dataset = random_split(
    full_train_dataset,
    (n_train, n_valid),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(len(train_dataset), len(valid_dataset), len(test_dataset))

4919 1230 1020


In [42]:
def freeze_params(model, first_trainable_idx: int = 150):
    """Freeze the leading parameters of ``model`` and leave the rest trainable.

    Parameters are visited in the order produced by ``model.parameters()``.
    Every parameter whose position is below ``first_trainable_idx`` gets
    ``requires_grad = False``; every other parameter gets ``requires_grad = True``.
    The model is modified in place.

    Args:
        model: Module whose parameters are frozen / unfrozen.
        first_trainable_idx: Position of the first parameter that stays
            trainable. The default of 150 reproduces the previously hard-coded
            threshold.
            NOTE(review): which layers position 150 corresponds to depends on
            the parameter order of the concrete model - confirm when the
            architecture changes.

    Returns:
        The same ``model`` object, to allow call chaining.
    """
    for idx, param in enumerate(model.parameters()):
        param.requires_grad = idx >= first_trainable_idx
    return model

# Build the ViT from src.early_exits with pretrained ImageNet weights, freeze
# its leading parameters, and snapshot the whole module to disk so that each
# training run can start from the same state.
basic_model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
basic_model = freeze_params(basic_model)
torch.save(basic_model, "BASIC_MODEL.pt")
print(basic_model)

CustomViT(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): CustomEncoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_att

In [45]:
def valid(model, loader) -> torch.Tensor:
    """Compute the top-1 accuracy of the model's final output over ``loader``.

    The model is switched to eval mode (and left in it). Each batch is moved
    to the notebook-level ``device`` and the class with the highest score in
    the final output is compared with the label.

    Args:
        model: Callable returning ``(output, early_exit_outputs)``; only
            ``output`` is scored here. ``output`` is expected to be
            ``(batch, num_classes)``.
        loader: Iterable of ``(x, y)`` batches.

    Returns:
        A 0-dim tensor holding ``correct / total``. It is a tensor rather
        than a Python float; callers convert it themselves.

    Raises:
        ValueError: If ``loader`` yields no samples, since accuracy is then
            undefined.
    """
    model.eval()  # switch dropouts off

    correct = 0  # becomes a 0-dim integer tensor after the first batch
    total = 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            total += y.shape[0]

            # pass through the network, remember about early exits outputs
            output, _early_class = model(x)

            # vectorised count of correct predictions for the whole batch
            correct += (output.argmax(dim=1) == y).sum()

    if total == 0:
        raise ValueError("valid() received an empty loader; accuracy is undefined")

    return correct / total


def run_epoch(model, optimizer, criterion, loader, optimizer2=None):
    """Train ``model`` for one pass over ``loader``.

    Only the model's final output enters the loss. The early-exit outputs are
    returned for inspection but do not contribute to the loss here.

    Args:
        model: Callable returning ``(output, early_exit_outputs)``.
        optimizer: Optimizer stepped after every batch.
        criterion: Loss called as ``criterion(output, target=y)``.
        loader: Iterable of ``(x, y)`` batches; each batch is moved to the
            notebook-level ``device``.
        optimizer2: Optional second optimizer, zeroed and stepped alongside
            ``optimizer``.

    Returns:
        ``(early_class, y)`` from the last batch: the early-exit outputs
        (still attached to the autograd graph) and the labels on ``device``.

    Raises:
        ValueError: If ``loader`` yields no batches, so there is nothing to return.
    """
    model.train()  # switch on dropouts

    early_class, y = None, None
    saw_batch = False

    for x, y in loader:
        saw_batch = True
        x, y = x.to(device), y.to(device)

        # don't accumulate gradients across batches
        optimizer.zero_grad()
        if optimizer2 is not None:
            optimizer2.zero_grad()

        output, early_class = model(x)

        loss: torch.Tensor = criterion(output, target=y)
        # backward pass through the network
        loss.backward()

        # apply gradients
        optimizer.step()
        if optimizer2 is not None:
            optimizer2.step()

    if not saw_batch:
        raise ValueError("run_epoch() received an empty loader; nothing was trained")

    return early_class, y


def _report_early_exits(early_classif, y):
    """Print, for every early-exit head, the top score of the last sample in the batch."""
    for idx, class_head in enumerate(early_classif):
        # One reduction per head instead of a host transfer per class element.
        # Unlike a scan that starts from 0, this also reports the true maximum
        # when every score is non-positive.
        max_elem, max_idx = torch.max(class_head[-1].detach(), dim=0)
        print(f"Early exit nr {idx}, last picture - max probability = {max_elem.item():.5f} for class = {max_idx.item()}, actual class = {y[-1]}")


def train_with_params(params, criterion, datasets):
    """Train a fresh copy of the saved base model and report validation accuracy.

    The model is reloaded from ``BASIC_MODEL.pt`` on every call, so each run
    starts from the same state. Only parameters with ``requires_grad=True``
    are handed to the Adam optimizer. After each epoch the early-exit outputs
    of the last training batch are printed for inspection.

    Args:
        params: Dict with keys ``"batch_size"``, ``"lr"`` and ``"epochs_num"``.
        criterion: Loss function forwarded to ``run_epoch``.
        datasets: Dict with keys ``"train"`` and ``"valid"``; other keys are
            ignored here.

    Returns:
        ``(model_valid_acc, test_model)``: the validation accuracy as a NumPy
        value and the trained model.
    """
    train_dataset, valid_dataset = datasets["train"], datasets["valid"]
    train_loader = DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=params['batch_size'], shuffle=False)

    # The checkpoint is a fully pickled module written by this notebook, so it
    # is trusted. weights_only=False is required on PyTorch versions where
    # torch.load defaults to weights_only=True. Never load a pickle from an
    # untrusted source this way - unpickling can execute arbitrary code.
    test_model = torch.load("BASIC_MODEL.pt", weights_only=False)
    test_model = test_model.to(device)

    trainable_params = [p for p in test_model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=params['lr'])

    for epoch in range(params["epochs_num"]):
        print(f"Training with freezed params, epoch = {epoch}")
        early_classif, y = run_epoch(test_model, optimizer, criterion, train_loader)

        # let's observe the scores of the last picture of the last batch:
        # how its largest value behaves at each early exit versus the true label
        _report_early_exits(early_classif, y)

        torch.cuda.empty_cache()

    model_valid_acc = valid(test_model, valid_loader)
    # as_tensor makes the conversion work whether valid() returns a tensor or a plain number
    model_valid_acc = torch.as_tensor(model_valid_acc).detach().cpu().numpy()

    return model_valid_acc, test_model

In [46]:
# Loss applied to the model's final output during training.
criterion = nn.CrossEntropyLoss()

# Hyperparameters read by train_with_params.
params = {'lr': 0.001, 'epochs_num': 4, 'batch_size': 32}

# Dataset splits prepared in the data-loading cell.
datasets = {"train": train_dataset, "valid": valid_dataset, "test": test_dataset}

acc, trained_model = train_with_params(params, criterion, datasets)

Training with freezed params, epoch = 0
Early exit nr 0, last picture - max probability = 0.01430 for class = 66, actual class = 46
Early exit nr 1, last picture - max probability = 0.01246 for class = 27, actual class = 46
Early exit nr 2, last picture - max probability = 0.01444 for class = 100, actual class = 46
Early exit nr 3, last picture - max probability = 0.01451 for class = 61, actual class = 46
Early exit nr 4, last picture - max probability = 0.01687 for class = 70, actual class = 46
Early exit nr 5, last picture - max probability = 0.01573 for class = 88, actual class = 46
Early exit nr 6, last picture - max probability = 0.01495 for class = 77, actual class = 46
Early exit nr 7, last picture - max probability = 0.01803 for class = 92, actual class = 46
Early exit nr 8, last picture - max probability = 0.01344 for class = 60, actual class = 46
Early exit nr 9, last picture - max probability = 0.01406 for class = 33, actual class = 46
Early exit nr 10, last picture - max pr

In [47]:
# Report the validation accuracy returned by train_with_params in the previous cell.
print("Accuracy on the validation dataset:", acc)

Accuracy on the validation dataset: 0.94715446
