In [None]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# TODO

> Build the model (Zbudować model)
> Split the images into 4x4 or 7x7 patches (podzielić zdjęcia na patche 4x4, albo 7x7)

In [None]:
import numpy as np

class Net1(nn.Module):
    """Vision-Transformer-style classifier operating on pre-extracted image patches.

    ``forward`` expects the output of ``get_patches_naive``: a tensor of shape
    (batch, seq_len, img_dims[0] * patch_dim ** 2). A learnable CLS token is
    prepended to the patch sequence and its final state is fed to the MLP head.

    Args:
        patch_dim: side length (pixels) of each square patch.
        img_dims: image dimensions as (channels, height, width).
        n_layers: number of transformer encoder layers.
        n_heads: number of attention heads; must divide ``mult_hidden * hidden_dim``.
        mult_hidden: width multiplier from the flattened patch size to the
            transformer embedding size.
    """

    def __init__(self, patch_dim, img_dims, n_layers, n_heads, mult_hidden=10):
        super().__init__()
        # sizes: a flattened patch has C * P * P values; seq_len is the number of patches
        hidden_dim = img_dims[0] * patch_dim ** 2
        # np.prod (np.product was deprecated and removed in NumPy 2.0)
        seq_len = int(np.prod(img_dims)) // hidden_dim
        embed_dim = mult_hidden * hidden_dim
        self.patch_dim = patch_dim

        # input projection
        self.init_proj = nn.Linear(hidden_dim, embed_dim)
        self.init_layernorm = nn.LayerNorm(embed_dim)
        self.init_dropout = nn.Dropout(0.1)
        # seq_len + 1 positions: slot 0 belongs to the CLS token prepended in forward()
        self.pos_emb = nn.Parameter(torch.zeros(1, seq_len + 1, embed_dim))

        # transformer encoder + mlp head.
        # The layer is a local on purpose: nn.TransformerEncoder deep-copies it n_layers
        # times, so storing the template as an attribute would register an extra, unused
        # set of parameters.
        mhsa_layer = nn.TransformerEncoderLayer(
            embed_dim, n_heads, dim_feedforward=4 * embed_dim,
            batch_first=True, activation=nn.GELU())
        self.mhsa = nn.TransformerEncoder(mhsa_layer, num_layers=n_layers, norm=nn.LayerNorm(embed_dim))
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 2 * embed_dim), nn.ReLU(), nn.Linear(2 * embed_dim, 10))

        # additional parameters
        self.att1 = nn.Parameter(torch.randn(seq_len, 1) / np.sqrt(seq_len))  # only used by reduction_att
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

    def forward(self, x):
        """Classify a batch of patch sequences.

        Args:
            x: patches of shape (batch, seq_len, hidden_dim).
        Returns:
            Logits of shape (batch, 10), read from the CLS token's final state.
        """
        batch_size = x.shape[0]
        x = self.init_proj(x)
        # Prepend the CLS token so index 0 after the encoder is the CLS state.
        x = torch.cat((self.expand_cls_to_batch(batch_size), x), dim=1)
        x = x + self.pos_emb
        x = self.init_layernorm(x)
        x = self.init_dropout(x)
        x = self.mhsa(x)
        return self.mlp(x[:, 0, :])

    def reduction_att(self, x):
        """Pool the sequence axis into one vector per example with learned weights ``att1``.

        Not called by ``forward``. NOTE(review): assumes x is (batch, seq_len, dim) with
        seq_len equal to the number of rows of ``att1`` — confirm before use.
        """
        att = F.softmax(((x @ x.transpose(-2, -1)) @ self.att1 / x.size(-1)).transpose(-2, -1), dim=-1)
        x = att @ x
        x = x.flatten(-2)
        return x

    def expand_cls_to_batch(self, batch):
        """
        Args:
            batch: batch size
        Returns: cls token expanded to the batch size, shape (batch, 1, embed_dim)
        """
        return self.cls_token.expand([batch, -1, -1])

In [None]:
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader

# Scale each RGB channel from [0, 1] to [-1, 1].
transform = Compose([
    ToTensor(),
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

dataset_train = CIFAR10('./data', train=True, transform=transform, download=True)
dataset_test = CIFAR10('./data', train=False, transform=transform, download=True)

# Shared loader settings; only the training split is shuffled.
loader_kwargs = dict(batch_size=64, pin_memory=True, num_workers=4)
loader_train = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
loader_test = DataLoader(dataset_test, shuffle=False, **loader_kwargs)

loaders = {'train': loader_train, 'test': loader_test}

In [None]:
def get_patches_naive(x, step):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        x: image batch of shape (batch, channels, height, width). Height and width
            may differ; trailing rows/columns that do not fill a whole patch are dropped.
        step: patch side length in pixels.

    Returns:
        Tensor of shape (batch, n_patches, channels * step * step). Patches are ordered
        row-major over the patch grid and each is flattened in (channel, row, col) order.
    """
    batch, channels = x.shape[0], x.shape[1]
    # Unfold height then width: (B, C, H // step, W // step, step, step).
    # Rows come from the height axis (the original loop used the width for both axes).
    grid = x.unfold(2, step, step).unfold(3, step, step)
    n_rows, n_cols = grid.shape[2], grid.shape[3]
    # Put the grid axes before the channel axis so each patch flattens as (C, step, step).
    grid = grid.permute(0, 2, 3, 1, 4, 5)
    return grid.reshape(batch, n_rows * n_cols, channels * step * step)

# patches = get_patches_naive(x, 4)
# patches.shape

In [None]:
import datetime
from tqdm.auto import tqdm
from tensorboard_pytorch import TensorboardPyTorch

def simple_trainer(model, loaders, criterion, optim, writer, epoch_start, epoch_end,
                   phases=('train', 'test'), scheduler=None):
    """Train/evaluate a patch-based model and log per-epoch accuracy and loss.

    Each batch is split with ``get_patches_naive(x, model.patch_dim)`` before being
    passed to ``model``.

    Args:
        model: module exposing ``patch_dim``; called on the patch tensor.
        loaders: mapping from phase name to a DataLoader yielding (inputs, targets).
        criterion: loss function called as ``criterion(logits, targets)``.
        optim: optimizer stepped during training phases.
        writer: logger providing ``log_scalar(tag, value, step)``.
        epoch_start, epoch_end: half-open range of epoch indices to run.
        phases: phase names run each epoch, in order. A phase whose name contains
            'train' puts the model in train mode and updates the weights; every
            other phase is evaluated in eval mode without gradients.
        scheduler: optional LR scheduler, stepped once at the end of every epoch.
    """
    for epoch in tqdm(range(epoch_start, epoch_end)):
        for phase in phases:
            # A single flag drives both the module mode and whether we optimize, so an
            # evaluation phase can never update the weights.
            is_train = 'train' in phase
            model.train(is_train)
            running_acc = 0.0
            running_loss = 0.0
            # Skip building the autograd graph when we will not backpropagate.
            with torch.set_grad_enabled(is_train):
                for x_true, y_true in loaders[phase]:
                    x_true, y_true = x_true.to(device), y_true.to(device)
                    x_true = get_patches_naive(x_true, model.patch_dim)
                    y_pred = model(x_true)
                    loss = criterion(y_pred, y_true)
                    if is_train:
                        optim.zero_grad()
                        loss.backward()
                        optim.step()
                    running_acc += (y_pred.detach().argmax(dim=1) == y_true).sum().item()
                    running_loss += loss.item() * x_true.size(0)

            n_samples = len(loaders[phase].dataset)
            epoch_acc = running_acc / n_samples
            epoch_loss = running_loss / n_samples
            writer.log_scalar(f'Acc/{phase}', round(epoch_acc, 4), epoch + 1)
            writer.log_scalar(f'Loss/{phase}', round(epoch_loss, 4), epoch + 1)

        if scheduler is not None:
            scheduler.step()

In [None]:
%tensorboard --logdir=tensorboard

In [None]:
EPOCHS = 150
import madgrad

model = Net1(patch_dim=8, img_dims=(3, 32, 32), n_layers=8, n_heads=8, mult_hidden=4).to(device)
criterion = nn.CrossEntropyLoss().to(device)
# optim = madgrad.MADGRAD(model.parameters(), lr=1e-2, momentum=0.9)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
# Anneal over the whole run (one scheduler step per epoch). The schedule only has an
# effect if the training loop calls scheduler.step().
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=EPOCHS)

date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
writer = TensorboardPyTorch(f'tensorboard/ViT/cifar10/sgd/cls_token/pos_emb_8x8x8x4_extended/{date}', device)

In [None]:
# Train the ViT on CIFAR-10, evaluating on the test split after every epoch.
simple_trainer(
    model, loaders, criterion, optim, writer,
    epoch_start=0,
    epoch_end=EPOCHS,
    phases=['train', 'test'],
)

In [None]:
def count_parameters(model):
    """Return the total number of scalar values in the trainable parameters of `model`.

    Parameters with ``requires_grad=False`` (frozen) are excluded.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Display the number of trainable parameters in `model`.
count_parameters(model)

# Setting

# Learned Representation

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Heatmap of the learned positional embedding: one row per sequence position.
fig, ax = plt.subplots(figsize=(10, 10))
# aspect='auto' stretches the rows over the figure. The matrix has far fewer positions
# than embedding columns, so with the default square pixels it renders as a thin strip.
im = ax.imshow(model.pos_emb.detach().cpu().squeeze(0), aspect='auto')
fig.colorbar(im, ax=ax)
ax.set(title='Learned positional embedding', xlabel='Emb Dim', ylabel='Position');

In [None]:
# Shape of the positional embedding without its leading singleton batch axis.
model.pos_emb.squeeze(0).shape

In [None]:
# `pos_emb` is only assigned in a later cell, so read the shape from the model itself
# to keep this cell runnable on a fresh kernel.
model.pos_emb.detach().cpu().squeeze(0).shape

In [None]:
# Scratch arithmetic. NOTE(review): 768 appears to equal the embedding width of the
# CIFAR model configured earlier (mult_hidden * channels * patch_dim**2 = 4 * 3 * 64);
# the intended meaning of this quotient is not recorded — TODO confirm or delete.
768 // 64

In [None]:
mult = 4  # presumably the model's mult_hidden — TODO confirm

# Show the first 16 positional-embedding vectors, each reshaped into a 2-D tile.
pos_emb = model.pos_emb.detach().cpu().squeeze(0)
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
# zip stops at the shorter sequence, so a model with fewer than 16 positions leaves the
# remaining panels empty instead of raising IndexError.
for position, (ax, emb) in enumerate(zip(axes.flat, pos_emb)):
    # (8 * mult) x (8 * 3) tile; assumes the embedding width is 8 * mult * 8 * 3 (= 768 here)
    ax.imshow(emb.reshape(8 * mult, 8 * 3))
    ax.set_title(f'pos {position}')
    ax.axis('off')
fig.tight_layout()

# Check Trainer

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small LeNet-style CNN for 3x32x32 images with 10 output logits.

    Two conv(5x5) + ReLU + 2x2 max-pool stages are followed by three fully
    connected layers (400 -> 120 -> 84 -> 10).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Feature extractor: conv -> ReLU -> pool, twice (the pool module is shared).
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Keep the batch axis, collapse (C, H, W) into one feature vector.
        x = x.flatten(1)
        # Classifier: two hidden layers with ReLU, then raw logits.
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)


net = Net()

In [None]:
import datetime
from tqdm.auto import tqdm
from tensorboard_pytorch import TensorboardPyTorch

def simple_trainer(model, loaders, criterion, optim, writer, epoch_start, epoch_end,
                   phases=('train', 'test'), scheduler=None):
    """Train/evaluate a model on raw batches and log per-epoch accuracy and loss.

    Unlike a patch-based trainer, each batch is fed to ``model`` unchanged.
    Running this cell rebinds the name ``simple_trainer`` in the kernel.

    Args:
        model: module called on each input batch.
        loaders: mapping from phase name to a DataLoader yielding (inputs, targets).
        criterion: loss function called as ``criterion(logits, targets)``.
        optim: optimizer stepped during training phases.
        writer: logger providing ``log_scalar(tag, value, step)``.
        epoch_start, epoch_end: half-open range of epoch indices to run.
        phases: phase names run each epoch, in order. A phase whose name contains
            'train' puts the model in train mode and updates the weights; every
            other phase is evaluated in eval mode without gradients.
        scheduler: optional LR scheduler, stepped once at the end of every epoch.
    """
    for epoch in tqdm(range(epoch_start, epoch_end)):
        for phase in phases:
            # One flag drives both the module mode and optimization, so an evaluation
            # phase can never update the weights.
            is_train = 'train' in phase
            model.train(is_train)
            running_acc = 0.0
            running_loss = 0.0
            # Skip building the autograd graph when we will not backpropagate.
            with torch.set_grad_enabled(is_train):
                for x_true, y_true in loaders[phase]:
                    x_true, y_true = x_true.to(device), y_true.to(device)
                    y_pred = model(x_true)
                    loss = criterion(y_pred, y_true)
                    if is_train:
                        optim.zero_grad()
                        loss.backward()
                        optim.step()
                    running_acc += (y_pred.detach().argmax(dim=1) == y_true).sum().item()
                    running_loss += loss.item() * x_true.size(0)

            n_samples = len(loaders[phase].dataset)
            epoch_acc = running_acc / n_samples
            epoch_loss = running_loss / n_samples
            writer.log_scalar(f'Acc/{phase}', round(epoch_acc, 4), epoch + 1)
            writer.log_scalar(f'Loss/{phase}', round(epoch_loss, 4), epoch + 1)

        if scheduler is not None:
            scheduler.step()

In [None]:
criterion = nn.CrossEntropyLoss().to(device)
# This section checks the trainer with the small CNN `net`. Optimize its parameters
# (not the ViT `model`'s), and put it on the device the trainer sends batches to.
net = net.to(device)
optim = torch.optim.Adam(net.parameters(), lr=0.01)
date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
writer = TensorboardPyTorch(f'tensorboard/check_trainer/{date}', device)

# PyTorch Lightning

In [None]:
# Template usage from the PyTorch Lightning docs, kept for reference only.
# MyLightningModule, Trainer, train_dataloader and val_dataloader are placeholders that
# are not defined in this notebook. Executing them would raise NameError (halting
# Run All) and overwrite `model`. A runnable version follows in the next cell.
# model = MyLightningModule()
#
# trainer = Trainer()
# trainer.fit(model, train_dataloader, val_dataloader)

In [None]:

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl


import numpy as np

class Net1(pl.LightningModule):
    """LightningModule version of the patch-based ViT-style classifier.

    Dataloader batches are raw images of shape (batch, channels, height, width);
    ``training_step`` and ``validation_step`` split them into ``patch_dim`` x
    ``patch_dim`` patches before calling ``forward``. A learnable CLS token is
    prepended to the patch sequence and its final state feeds the MLP head.

    Args:
        patch_dim: side length (pixels) of each square patch.
        img_dims: image dimensions as (channels, height, width).
        n_layers: number of transformer encoder layers.
        n_heads: number of attention heads; must divide ``mult_hidden * hidden_dim``.
        mult_hidden: width multiplier from the flattened patch size to the
            transformer embedding size.
    """

    def __init__(self, patch_dim, img_dims, n_layers, n_heads, mult_hidden=10):
        super().__init__()
        # sizes: a flattened patch has C * P * P values; seq_len is the number of patches
        hidden_dim = img_dims[0] * patch_dim ** 2
        # np.prod (np.product was deprecated and removed in NumPy 2.0)
        seq_len = int(np.prod(img_dims)) // hidden_dim
        embed_dim = mult_hidden * hidden_dim
        self.patch_dim = patch_dim

        # input projection
        self.init_proj = nn.Linear(hidden_dim, embed_dim)
        self.init_layernorm = nn.LayerNorm(embed_dim)
        self.init_dropout = nn.Dropout(0.1)
        # seq_len + 1 positions: slot 0 belongs to the CLS token prepended in forward()
        self.pos_emb = nn.Parameter(torch.zeros(1, seq_len + 1, embed_dim))

        # transformer encoder + mlp head.
        # The layer is a local: nn.TransformerEncoder deep-copies it, so storing the
        # template as an attribute would register an extra, unused set of parameters.
        mhsa_layer = nn.TransformerEncoderLayer(
            embed_dim, n_heads, dim_feedforward=4 * embed_dim,
            batch_first=True, activation=nn.GELU())
        self.mhsa = nn.TransformerEncoder(mhsa_layer, num_layers=n_layers, norm=nn.LayerNorm(embed_dim))
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 2 * embed_dim), nn.ReLU(), nn.Linear(2 * embed_dim, 10))

        # additional parameters
        self.att1 = nn.Parameter(torch.randn(seq_len, 1) / np.sqrt(seq_len))  # only used by reduction_att
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Classify a batch of patch sequences of shape (batch, seq_len, hidden_dim).

        Returns logits of shape (batch, 10), read from the CLS token's final state.
        """
        batch_size = x.shape[0]
        x = self.init_proj(x)
        # Prepend the CLS token so index 0 after the encoder is the CLS state.
        x = torch.cat((self.expand_cls_to_batch(batch_size), x), dim=1)
        x = x + self.pos_emb
        x = self.init_layernorm(x)
        x = self.init_dropout(x)
        x = self.mhsa(x)
        return self.mlp(x[:, 0, :])

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        """Patch the image batch, compute cross-entropy and log it as ``train_loss``."""
        x, y = train_batch
        # Use this module's own helper and patch size rather than a hard-coded 7.
        x = self.get_patches_naive(x, self.patch_dim)
        z = self.forward(x)
        loss = self.criterion(z, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Patch the image batch and log ``val_loss`` and ``val_acc``."""
        x, y = val_batch
        x = self.get_patches_naive(x, self.patch_dim)
        z = self.forward(x)
        loss = self.criterion(z, y)
        self.log('val_loss', loss)
        self.log('val_acc', (z.argmax(dim=1) == y).float().mean())

    def reduction_att(self, x):
        """Pool the sequence axis into one vector per example with learned weights ``att1``.

        Not called by ``forward``. NOTE(review): assumes x is (batch, seq_len, dim) with
        seq_len equal to the number of rows of ``att1`` — confirm before use.
        """
        att = F.softmax(((x @ x.transpose(-2, -1)) @ self.att1 / x.size(-1)).transpose(-2, -1), dim=-1)
        x = att @ x
        x = x.flatten(-2)
        return x

    def expand_cls_to_batch(self, batch):
        """
        Args:
            batch: batch size
        Returns: cls token expanded to the batch size, shape (batch, 1, embed_dim)
        """
        return self.cls_token.expand([batch, -1, -1])

    def get_patches_naive(self, x, step):
        """Split images (B, C, H, W) into flattened non-overlapping step x step patches.

        Returns a tensor of shape (B, (H // step) * (W // step), C * step * step), with
        patches ordered row-major over the grid and each flattened as (C, row, col).
        Trailing rows/columns that do not fill a whole patch are dropped.
        """
        batch, channels = x.shape[0], x.shape[1]
        # (B, C, H // step, W // step, step, step); rows come from the height axis
        grid = x.unfold(2, step, step).unfold(3, step, step)
        n_rows, n_cols = grid.shape[2], grid.shape[3]
        grid = grid.permute(0, 2, 3, 1, 4, 5)
        return grid.reshape(batch, n_rows * n_cols, channels * step * step)

# data
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
# Seed the split so the train/validation partition is reproducible across runs.
mnist_train, mnist_val = random_split(dataset, [55000, 5000],
                                      generator=torch.Generator().manual_seed(42))

# Reshuffle the training data every epoch; validation order does not matter.
train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=64)

# model: 28x28 MNIST -> 16 patches of 7x7; embed dim 4 * 49 = 196 is divisible by 7 heads
model = Net1(patch_dim=7, img_dims=(1, 28, 28), n_layers=7, n_heads=7, mult_hidden=4).to(device)

# training
# NOTE(review): the `gpus=` argument was removed in PyTorch Lightning 2.0 (use
# accelerator="gpu", devices=1 there) — confirm against the installed version.
trainer = pl.Trainer(gpus=1, num_nodes=1, precision=16, limit_train_batches=0.5)
trainer.fit(model, train_loader, val_loader)
    


In [None]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/