<a href="https://colab.research.google.com/github/GirfanovOV/Transformer_experiments/blob/main/Transformer_stuff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
!pip install wandb -qU

In [11]:
from IPython.display import clear_output
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import math
import torchvision
from torchvision.utils import make_grid
import torchvision.transforms as tvf
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
from tqdm.autonotebook import tqdm

import gc
import wandb

In [12]:
wandb.login()



True

In [13]:
# RoPE
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

In [14]:
class MultiheadAttention(nn.Module):
    def __init__(self, model_dim, num_heads, dropout=.1):
        super().__init__()
        assert(model_dim % num_heads == 0)
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads

        # self.qkv = nn.Linear(model_dim, model_dim * 3)
        # self.out = nn.Linear(model_dim, model_dim)

        self.qkv = nn.Linear(model_dim, model_dim * 3, bias=False)
        self.out = nn.Linear(model_dim, model_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cis, mask):
        q, k ,v = self.qkv(x).split(self.model_dim, dim=-1)
        B, L, D = q.shape

        q = q.view((B, L, self.num_heads, self.head_dim))
        k = k.view((B, L, self.num_heads, self.head_dim))
        v = v.view((B, L, self.num_heads, self.head_dim)).transpose(1,2)

        q, k = apply_rotary_emb(q, k, freqs_cis)

        q = q.transpose(1,2)
        k = k.transpose(1,2)

        w = q @ k.transpose(-2,-1) / math.sqrt(self.head_dim)

        if mask is not None:
            w = w.masked_fill(mask, float("-inf"))

        w_norm = F.softmax(w, dim=-1)
        w_norm = self.dropout(w_norm)

        res = w_norm @ v
        res = res.transpose(1,2).contiguous().view((B, L, -1))
        return res

class FeedForward(nn.Module):
    def __init__(self, model_dim, ff_dim, dropout=.1):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(model_dim, ff_dim),
            nn.SiLU(),
            # nn.ReLU(),
            nn.Linear(ff_dim, model_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.layer(x)


class TransformerBlock(nn.Module):
    def __init__(self, model_dim, num_heads, ff_dim, dropout=.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(model_dim)
        self.attn = MultiheadAttention(model_dim, num_heads, dropout)
        self.ln2 = nn.LayerNorm(model_dim)
        self.ff = FeedForward(model_dim, ff_dim, dropout)

    def forward(self, x, freqs_cis, mask):
        x = x + self.attn(self.ln1(x), freqs_cis, mask)
        x = x + self.ff(self.ln2(x))
        return x


class TransformerModel(nn.Module):
    def __init__(self, num_img_toks, num_layers, max_seq_len, model_dim, num_heads, ff_dim, dropout=.1):
        super().__init__()

        self.model_in_special_toks = {
            "<BOS>" : num_img_toks + 0
        }

        self.model_in_vocab_size = num_img_toks + len(self.model_in_special_toks)

        self.tok_embed = nn.Embedding(self.model_in_vocab_size, model_dim)

        self.layers = nn.ModuleList(
            [TransformerBlock(model_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )

        self.model_out_vocab_size = num_img_toks

        self.ln_f = nn.LayerNorm(model_dim)
        self.gen_head = nn.Linear(model_dim, self.model_out_vocab_size)
        self.max_seq_len = max_seq_len

        self.freqs_cis = precompute_freqs_cis(model_dim // num_heads, self.max_seq_len + 2 * len(self.model_in_special_toks))

    def forward(self, x):
        # add BOS token
        bsz = x.shape[0]
        device = x.device

        x = torch.cat(
            (
                torch.full((bsz, 1), self.model_in_special_toks["<BOS>"], dtype=torch.long, device=device),
                x
            ),
            dim=1
        )

        _, seqlen = x.shape
        causal_mask = torch.ones((seqlen, seqlen), device = device, dtype = torch.bool).triu(1)

        h = self.tok_embed(x)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[:seqlen]

        for layer in self.layers:
            h = layer(h, freqs_cis, causal_mask)
        final = self.gen_head(self.ln_f(h))
        return final

    @torch.no_grad()
    def generate(self, bsz=16, temperature=1.0, use_max=False):
        device = next(self.parameters()).device
        idx = torch.full((bsz, 1), fill_value=self.model_in_special_toks["<BOS>"], device=device, dtype=torch.long)
        max_new_tokens = 28 * 28;
        for _ in tqdm(range(max_new_tokens), leave=False):
            logits = self(idx)
            logits = logits[:, -1, :] / temperature
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            if use_max:
                idx_next = torch.argmax(probs, dim=-1, keepdim=True)
            else:
                idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx[:, 1:]

def show_gen_img(model_out):
    bsz = model_out.shape[0]
    model_out = model_out.to('cpu').view(bsz, 1, 28, 28)
    img_grid = make_grid(model_out, nrow=4).to(torch.uint8)
    img = tvf.functional.to_pil_image(img_grid)
    plt.imshow(np.asarray(img))

def prep_generated_img(model_out):
    bsz = model_out.shape[0]
    model_out = model_out.to('cpu').view(bsz, 1, 28, 28)
    img_grid = make_grid(model_out, nrow=4).to(torch.uint8)
    img = tvf.functional.to_pil_image(img_grid)
    return np.asarray(img)

In [15]:
device = "cuda" if torch.cuda.is_available() else "cpu"

config = {
    "epochs": 10,
    "lr": 1e-4,
    "batch_size": 32
}

run = wandb.init(
    project="Simple ViT",
    notes="Softmax Attention based vision transformer",
    tags=["GPT", "MNIST", "RoPE"],
    config=config
)

model = TransformerModel(
    num_img_toks=256,
    num_layers=4,
    max_seq_len=28*28,
    model_dim=128,
    num_heads=4,
    ff_dim=512,
    dropout=.1
).to(device)

s = round(sum([p.numel() for p in model.parameters()]) / 10e6, 2)
print(f"Model Size: {s}M params")

VBox(children=(Label(value='0.763 MB of 0.763 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,▁

0,1
loss,2.1231


Model Size: 0.09M params


In [16]:
ds_path = "MNIST/"
transforms = v2.Compose([v2.ToImage(), v2.ToDtype(torch.long)])
mnist_train = torchvision.datasets.MNIST(ds_path, train=True, download=True, transform=transforms)
# train_dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

class MNIST_single_cls(Dataset):
    def __init__(self, mnist_dataset, cls):
        assert(0 <= cls <= 9)
        #single class
        # self.data = mnist_dataset.data[mnist_dataset.targets == cls].to(torch.long)
        # self.targets = torch.zeros((self.data.shape[0]))
        # single batch
        self.sample = mnist_dataset.data[mnist_dataset.targets == cls].to(torch.long)
        self.data = self.sample[:64]
        #single example
        # self.sample = mnist_dataset.data[mnist_dataset.targets == cls].to(torch.long)
        # self.data = self.sample[0] * torch.ones_like(self.sample)

    def __len__(self):
        return self.sample.shape[0]
        # return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx % 64]
        # return self.data[idx]


mnist_zeros = MNIST_single_cls(mnist_train, 0)
train_dataloader = DataLoader(mnist_zeros, batch_size=wandb.config['batch_size'], shuffle=True)
clear_output()

cls_weights = torch.ones((model.model_out_vocab_size))
cls_weights[0] = .01
optimizer = optim.Adam(model.parameters(), lr=wandb.config['lr'])
loss_fn = nn.CrossEntropyLoss(cls_weights).to(device)

loss_hist = []
curr_epoch = 0

In [18]:
gc.collect()
torch.cuda.empty_cache()
for epoch in range(wandb.config['epochs']):
    dloader = tqdm(train_dataloader, total=len(train_dataloader))
    for n, data in enumerate(dloader):
        optimizer.zero_grad()

        x = data
        x = x.to(device).flatten(start_dim=1)

        model_inp = x[:,:-1]
        model_out = model(model_inp)
        model_out = model_out.reshape((-1,model.model_out_vocab_size))

        target = x
        target = target.flatten()

        loss = loss_fn(model_out, target)
        loss.backward()
        optimizer.step()
        loss_val = loss.to('cpu').item()
        loss_hist.append(loss_val)

        if n % 10 == 0:
            # log to wandb
            model_out = model.generate()
            prepd_img = prep_generated_img(model_out)
            log_img = wandb.Image(prepd_img, capiton="Generated images")
            wandb.log({"loss": loss_val, "generated" : log_img})

        dloader.set_postfix_str(f"Loss: {loss_val:.4}")
wandb.finish()

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/784 [00:00<?, ?it/s]

KeyboardInterrupt: 