In [None]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# TODO

> Build the model
> Split the images into 4x4 or 7x7 patches

In [None]:
import torch
import torch.nn as nn
from einops import rearrange

class ViT(nn.Module):
    """Vision Transformer: linear patch embedding, learned 1D positional
    embedding, a transformer encoder and (optionally) a CLS-token linear head.
    """

    def __init__(self,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
        """
        Args:
            img_dim: the spatial image size (images are assumed square)
            in_channels: number of img channels
            patch_dim: side length of one square patch; must divide img_dim
            num_classes: classification task classes
            dim: the linear layer's dim to project the patches for MHSA
            blocks: number of transformer blocks
            heads: number of attention heads of the built-in encoder
            dim_linear_block: inner dim of the transformer linear block
            dim_head: stored on the module as ``self.dim_head`` (defaults to
                dim // heads). The built-in ``nn.TransformerEncoderLayer``
                always uses dim // heads per head, so a custom value only
                matters to a user-supplied ``transformer``.
            dropout: for pos emb and the built-in transformer encoder
            transformer: optional replacement encoder; it is called as
                ``transformer(tokens, mask)`` and must return a tensor of
                shape [batch, tokens, dim]
            classification: creates an extra CLS token and a linear head
        """
        super().__init__()
        assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
        self.p = patch_dim
        self.classification = classification
        tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)
        self.dim = dim
        self.dim_head = (dim // heads) if dim_head is None else dim_head
        self.project_patches = nn.Linear(self.token_dim, dim)

        self.emb_dropout = nn.Dropout(dropout)
        if self.classification:
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
            # One extra position for the CLS token that is prepended in forward().
            self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
            self.mlp_head = nn.Linear(dim, num_classes)
        else:
            self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

        # None means "use the built-in encoder stored in self.mhsa".
        self.transformer = transformer
        if transformer is None:
            # The second positional argument of TransformerEncoderLayer is the
            # number of heads (nhead), not the per-head width, so `heads` goes
            # here. `dropout` is forwarded so the encoder honours the
            # constructor argument instead of the library default.
            # TransformerEncoder works on deep copies of this layer;
            # `mhsa_layer` is kept as an attribute so the module layout and
            # state_dict keys stay as they were.
            self.mhsa_layer = nn.TransformerEncoderLayer(
                dim, heads, dim_feedforward=dim_linear_block, dropout=dropout,
                batch_first=True, activation=nn.GELU())
            self.mhsa = nn.TransformerEncoder(self.mhsa_layer, num_layers=blocks, norm=nn.LayerNorm(dim))

    def expand_cls_to_batch(self, batch):
        """
        Args:
            batch: batch size
        Returns: cls token expanded to the batch size
        """
        return self.cls_token.expand([batch, -1, -1])

    def _to_patches(self, img):
        """Split [b, c, H, W] images into flattened contiguous p x p patches.

        Returns a tensor of shape [b, (H/p)*(W/p), p*p*c]. Tokens are ordered
        row-major over the patch grid; within a token the features are ordered
        (patch row, patch column, channel).
        """
        batch, channels, height, width = img.shape
        p = self.p
        rows, cols = height // p, width // p
        # [b, c, rows, p, cols, p]: the patch-grid index is the OUTER factor of
        # each spatial axis, so every token covers neighbouring pixels.
        grid = img.reshape(batch, channels, rows, p, cols, p)
        # -> [b, rows, cols, p, p, c]
        grid = grid.permute(0, 2, 4, 3, 5, 1)
        return grid.reshape(batch, rows * cols, p * p * channels)

    def forward(self, img, mask=None):
        """Encode a batch of images.

        Args:
            img: tensor of shape [batch, in_channels, img_dim, img_dim]
            mask: optional mask passed as the second argument to the encoder
        Returns:
            class logits [batch, num_classes] when ``classification`` is True,
            otherwise the encoded tokens [batch, tokens, dim]
        """
        batch_size = img.shape[0]
        # project patches with linear layer + add pos emb
        img_patches = self.project_patches(self._to_patches(img))

        if self.classification:
            img_patches = torch.cat(
                (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

        patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

        # Use the user-supplied encoder when one was given, else the built-in one.
        encoder = self.mhsa if self.transformer is None else self.transformer
        # output of transformer. shape: [batch, tokens, dim]
        y = encoder(patch_embeddings, mask)

        if self.classification:
            # only the CLS token (position 0) feeds the classification head
            return self.mlp_head(y[:, 0, :])
        else:
            return y

In [None]:
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader

# Convert images to tensors and normalise every RGB channel with mean 0.5 / std 0.5.
transform = Compose([
    ToTensor(),
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR10 train / test splits, downloaded into ./data when missing.
dataset_train = CIFAR10(root='./data', train=True, transform=transform, download=True)
dataset_test = CIFAR10(root='./data', train=False, transform=transform, download=True)

# Only the training loader shuffles; both use 4 workers and pinned memory.
loader_train = DataLoader(dataset_train, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
loader_test = DataLoader(dataset_test, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

# Phase name -> loader, as consumed by simple_trainer.
loaders = dict(train=loader_train, test=loader_test)

In [None]:
def get_patches_naive(x, step):
    """Cut a batch of images into non-overlapping ``step`` x ``step`` patches.

    Args:
        x: tensor of shape [batch, channels, height, width]
        step: side length of one square patch
    Returns:
        tensor of shape [batch, n_patches, channels * step * step]; patches are
        ordered row-major over the patch grid and each patch is flattened in
        (channel, row, column) order. Rows/columns that do not fill a whole
        patch are dropped.
    """
    # Rows and columns are counted separately so non-square images work too.
    n_rows = x.shape[-2] // step
    n_cols = x.shape[-1] // step
    patches = [
        x[:, :, i * step:(i + 1) * step, j * step:(j + 1) * step].flatten(start_dim=1)
        for i in range(n_rows)
        for j in range(n_cols)
    ]
    # Stack along dim 1 to get [batch, n_patches, features] directly.
    return torch.stack(patches, dim=1)

# patches = get_patches_naive(x, 4)
# patches.shape

In [None]:
import datetime
from tqdm.auto import tqdm
from tensorboard_pytorch import TensorboardPyTorch

def simple_trainer(model, loaders, criterion, optim, writer, epoch_start, epoch_end, phases=('train', 'test')):
    """Run a train/eval loop and log per-epoch accuracy and loss.

    Args:
        model: module mapping an input batch to class scores [batch, classes]
        loaders: mapping phase name -> iterable of (inputs, targets) batches;
            each loader must expose ``.dataset`` with a length
        criterion: loss called as ``criterion(y_pred, y_true)``
        optim: optimizer stepping the model's parameters
        writer: object with ``log_scalar(tag, value, step)``
        epoch_start: first epoch index (inclusive)
        epoch_end: last epoch index (exclusive)
        phases: phase names to run each epoch, in order. A phase whose name
            contains 'train' puts the model in train mode and updates the
            weights; every other phase runs in eval mode without gradients.

    Inputs and targets are moved to the notebook-level ``device``.
    """
    for epoch in tqdm(range(epoch_start, epoch_end)):
        for phase in phases:
            # One flag drives both the module mode and the optimisation step,
            # so they cannot disagree whatever the order of `phases` is.
            is_train = 'train' in phase
            model.train(is_train)
            running_acc = 0.0
            running_loss = 0.0
            for x_true, y_true in loaders[phase]:
                x_true, y_true = x_true.to(device), y_true.to(device)
                # No autograd graph is built during evaluation phases.
                with torch.set_grad_enabled(is_train):
                    y_pred = model(x_true)
                    loss = criterion(y_pred, y_true)
                if is_train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                running_acc += (y_pred.detach().argmax(dim=1) == y_true).sum().item()
                # loss is a batch mean; weight it by the batch size
                running_loss += loss.item() * x_true.size(0)

            epoch_acc = running_acc / len(loaders[phase].dataset)
            epoch_loss = running_loss / len(loaders[phase].dataset)
            writer.log_scalar(f'Acc/{phase}', round(epoch_acc, 4), epoch + 1)
            writer.log_scalar(f'Loss/{phase}', round(epoch_loss, 4), epoch + 1)

In [None]:
%tensorboard --logdir=tensorboard

In [None]:
EPOCHS = 100
import madgrad
from ImageTransformer import ViT

# NOTE(review): `ViT` here is the class imported from ImageTransformer in this
# cell; it shadows the ViT class defined earlier in the notebook.
model = ViT(
    patch_height=16,
    patch_width=16,
    embedding_dims=768,
    dropout=0.1,
    heads=4,
    num_layers=4,
    forward_expansion=4,
    # (32*32) // (16*16) == 4; presumably the number of 16x16 patches in a
    # 32x32 image -- confirm against ImageTransformer.ViT
    max_len=(32 * 32) // (16 * 16),
    layer_norm_eps=1e-5,
    num_classes=10,
)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)
# optim = madgrad.MADGRAD(model.parameters(), lr=1e-2, momentum=0.9)
optim = torch.optim.Adam(params=model.parameters(), lr=1e-2)

# Timestamped run directory.
# NOTE(review): the path says 'sgd' although the optimizer above is Adam.
date = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"
writer = TensorboardPyTorch(f'tensorboard/ViT/cifar10/sgd/ShivamRajSharma_model/{date}', device)

In [None]:
simple_trainer(model, loaders, criterion, optim, writer, epoch_start=0, epoch_end=EPOCHS, phases=['train', 'test'])

In [None]:
def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters.

    Parameters with ``requires_grad == False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
count_parameters(model)

# Setting

# Learned Representation

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Show the model's positional-embedding tensor as an image
# (y axis labelled as position, x axis as embedding dimension).
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(model.pos_emb.detach().cpu().squeeze(0))
ax.set_xlabel('Emb Dim')
ax.set_ylabel('Position')

In [None]:
model.pos_emb.detach().cpu().squeeze(0).shape

In [None]:
mult = 8

# assumes pos_emb is (positions, emb_dim) after the squeeze -- TODO confirm
pos_emb = model.pos_emb.cpu().detach().squeeze(0)

# Derive the panel shape from the embedding width instead of hard-coding
# 28 x 14: take the largest divisor not above sqrt(emb_dim) as the short side.
# A 392-wide embedding still gives 28 x 14; other widths no longer break reshape.
emb_dim = pos_emb.shape[-1]
panel_w = next(d for d in range(int(emb_dim ** 0.5), 0, -1) if emb_dim % d == 0)
panel_h = emb_dim // panel_w

fig, axes = plt.subplots(4, 4, figsize=(10,10))
for i, ax in enumerate(axes.flat):
    if i < pos_emb.shape[0]:
        ax.imshow(pos_emb[i].reshape(panel_h, panel_w))
    else:
        # fewer than 16 positions available: leave the spare panels blank
        ax.axis('off')

# Check Trainer

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small CNN classifier: two conv + max-pool stages, then three linear layers.

    ``fc1`` expects 16 * 5 * 5 features, which is what the two conv/pool stages
    produce for 3 x 32 x 32 inputs. The output has 10 scores per sample.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (5x5 kernels, shared 2x2 max-pool).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected classifier head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Map a batch of images to unnormalised class scores [batch, 10]."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # keep the batch axis, flatten everything else
        x = x.flatten(start_dim=1)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)


model = Net().to(device)

In [None]:
import datetime
from tqdm.auto import tqdm
from tensorboard_pytorch import TensorboardPyTorch

def simple_trainer(model, loaders, criterion, optim, writer, epoch_start, epoch_end, phases=('train', 'test')):
    """Run a train/eval loop and log per-epoch accuracy and loss.

    NOTE(review): this cell redefines ``simple_trainer``, which is already
    defined earlier in the notebook; whichever cell runs last wins.

    Args:
        model: module mapping an input batch to class scores [batch, classes]
        loaders: mapping phase name -> iterable of (inputs, targets) batches;
            each loader must expose ``.dataset`` with a length
        criterion: loss called as ``criterion(y_pred, y_true)``
        optim: optimizer stepping the model's parameters
        writer: object with ``log_scalar(tag, value, step)``
        epoch_start: first epoch index (inclusive)
        epoch_end: last epoch index (exclusive)
        phases: phase names to run each epoch, in order. A phase whose name
            contains 'train' puts the model in train mode and updates the
            weights; every other phase runs in eval mode without gradients.

    Inputs and targets are moved to the notebook-level ``device``.
    """
    for epoch in tqdm(range(epoch_start, epoch_end)):
        for phase in phases:
            # One flag drives both the module mode and the optimisation step,
            # so they cannot disagree whatever the order of `phases` is.
            is_train = 'train' in phase
            model.train(is_train)
            running_acc = 0.0
            running_loss = 0.0
            for x_true, y_true in loaders[phase]:
                x_true, y_true = x_true.to(device), y_true.to(device)
                # No autograd graph is built during evaluation phases.
                with torch.set_grad_enabled(is_train):
                    y_pred = model(x_true)
                    loss = criterion(y_pred, y_true)
                if is_train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                running_acc += (y_pred.detach().argmax(dim=1) == y_true).sum().item()
                # loss is a batch mean; weight it by the batch size
                running_loss += loss.item() * x_true.size(0)

            epoch_acc = running_acc / len(loaders[phase].dataset)
            epoch_loss = running_loss / len(loaders[phase].dataset)
            writer.log_scalar(f'Acc/{phase}', round(epoch_acc, 4), epoch + 1)
            writer.log_scalar(f'Loss/{phase}', round(epoch_loss, 4), epoch + 1)

In [None]:
# Loss, optimizer and a timestamped TensorBoard writer for the trainer check.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)
optim = torch.optim.Adam(params=model.parameters(), lr=1e-2)
date = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"
writer = TensorboardPyTorch(f'tensorboard/check_trainer/{date}', device)

In [None]:
EPOCHS = 100
simple_trainer(model, loaders, criterion, optim, writer, epoch_start=0, epoch_end=EPOCHS, phases=['train', 'test'])