In [1]:
import torch
from torch import tensor, nn, optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

In [2]:
with open("data/input.txt") as f:
    text = f.read()

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "openai-community/gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [4]:
from datasets import load_dataset

ds = load_dataset("myothiha/jokes")

README.md:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

train.csv:   0%|          | 0.00/19.7M [00:00<?, ?B/s]

validation.csv:   0%|          | 0.00/2.18M [00:00<?, ?B/s]

test.csv:   0%|          | 0.00/2.43M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/187641 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/20850 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/23166 [00:00<?, ? examples/s]

In [5]:
len(set(ds['train']['text']))

187641

In [6]:
text = '<|endoftext|>'.join(ds['train']['text'])

In [7]:
vocab_size = tokenizer.vocab_size
context_length = 32
n_embs = 128
n_heads = 16
n_blocks = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
len(text)

20160500

In [9]:
class TextDataLoader():
    def __init__(self, text, context_length, tokenizer, batch_size=1, device='cpu', mask = False):
        v = tokenizer(text, return_tensors='pt')
        self.tokens = v.input_ids # The attention mask will be handled manually later
        self.masks =  torch.ones(batch_size, context_length) # for finetuning huggingface models
        self.batch_size = batch_size
        self.context_length = context_length
        self.device = device
        
        self.position = 0
        
    def __iter__(self):
        self.reset()
        return self
        
    def __next__(self):
        B, T = self.batch_size, self.context_length
        if self.position + B * T + 1 < len(self.tokens[0]):
            tokens = self.tokens[0][self.position: self.position + B * T + 1]
            self.position += B * T + 1
            x = tokens[:-1].view(B, T)
            y = tokens[1:].view(B, T)
            if self.masks is not None:
                return x.to(self.device), self.masks.to(self.device), y.to(self.device)
            return x.to(self.device), y.to(self.device)
        else:
            raise StopIteration
            
    def __len__(self):
        return len(self.tokens[0]) // (self.context_length + 1) // self.batch_size
    
    def reset(self):
        self.position = 0
        

The `head_size` matches that of the embeddings if it is a single head.

If multi-headed attention is used, then the `head_size` would equal number of embeddings // number of heads.

In [10]:
class SelfAttentionHead(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.k = nn.Linear(n_embs, head_size, bias=False)
        self.q = nn.Linear(n_embs, head_size, bias=False)
        self.v = nn.Linear(n_embs, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))
        
    def forward(self, x):
        B, T, C = x.shape
        k = self.k(x) # (B, T, head_size)
        q = self.q(x) # (B, T, head_size)
        v = self.v(x) # (B, T, head_size)
        
#         attn = q @ k.transpose(-2, -1) * C ** -0.5 # (B, T, T)
        
#         wei = attn.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # Causal masking, blocks future tokens from being seen
#         wei = wei.softmax(dim=-1) # (B, T, T)
        
#         out = wei @ v # (B, T, head_size)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttentionHead(n_embs // n_heads) for i in range(n_heads)])
        self.proj = nn.Linear(n_embs, n_embs)
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

In [12]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embs, n_embs * 4),
            nn.GELU(),
            nn.Linear(n_embs * 4, n_embs)
        )
        
    def forward(self, x):
        out = self.net(x)
        return out

In [13]:
class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embs)
        self.mha = MultiHeadAttention(n_heads)
        self.ln2 = nn.LayerNorm(n_embs)
        self.ffwd = FeedForward()
        
    def forward(self, x):
        x = self.mha(self.ln1(x)) + x
        x = self.ffwd(self.ln2(x)) + x
        
        return x
        

In [14]:
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.tk_emb = nn.Embedding(vocab_size, n_embs)
        self.pos_emb = nn.Embedding(context_length, n_embs)
        self.blocks = nn.Sequential(*[TransformerBlock() for i in range(n_blocks)])
        self.ln_f = nn.LayerNorm(n_embs)
        self.fc = nn.Linear(n_embs, vocab_size)
        
        self.tk_emb.weight = self.fc.weight
        
    def forward(self, x, targets=None):
        B, T = x.shape
        tk_emb = self.tk_emb(x) # (B, T, C)
        pos_tns = torch.arange(T, device=device) # T
        pos_emb = self.pos_emb(pos_tns) # (T, C)
        x = pos_emb + tk_emb # (B, T, C) + (T, C)
        
        x = self.blocks(x)
        
        x = self.ln_f(x)
        logits = self.fc(x) # (B, T, vocab_size)
        if targets is None:
            return logits
        else:
            loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B*T))
            return logits, loss
        
    def generate(self):
        pass

In [15]:
def initialize(mod):
    if isinstance(mod, nn.Linear):
        torch.nn.init.kaiming_normal_(mod.weight)
        if mod.bias is not None:
            torch.nn.init.zeros_(mod.bias)

In [16]:
torch.set_float32_matmul_precision('high')


In [17]:
dl = TextDataLoader(text, context_length, tokenizer, batch_size=16, device=device, mask=True)

Token indices sequence length is longer than the specified maximum sequence length for this model (4673544 > 1024). Running this sequence through the model will result in indexing errors


In [18]:
len(dl)

8851

In [19]:
# model = GPT().to(device)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.lm_head = nn.Linear(768, 50257).to(device)
# model.apply(initialize)
lr = 7e-4
opt = optim.AdamW(model.parameters(), lr)

sum(p.numel() for p in model.parameters())

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

163087441

In [20]:
xb, masks, yb = next(iter(dl))
model(input_ids=xb, attention_mask=masks, labels=yb).loss

tensor(18.7118, device='cuda:0', grad_fn=<NllLossBackward0>)

In [21]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=True)
)

In [22]:
# Training loop for fine-tuning a huggingface model
epochs = 1
sched = CosineAnnealingLR(opt, epochs * len(dl), lr * 0.01)
model.train()
for i in range(epochs):
    for step, (xb, mask, yb) in enumerate(dl):
        opt.zero_grad()
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            output = model(input_ids=xb, attention_mask=mask, labels=yb)
            loss = output.loss
        loss.backward()
        opt.step()
        sched.step()
        if step % (len(dl)//20) == 0 or step == len(dl):
            print(f"Epoch: {i}, Step: {step}, Loss: {loss}")

Epoch: 0, Step: 0, Loss: 18.450794219970703
Epoch: 0, Step: 442, Loss: 6.331844806671143
Epoch: 0, Step: 884, Loss: 5.9788737297058105
Epoch: 0, Step: 1326, Loss: 5.828938007354736
Epoch: 0, Step: 1768, Loss: 5.4010844230651855
Epoch: 0, Step: 2210, Loss: 5.53565788269043
Epoch: 0, Step: 2652, Loss: 5.770430088043213
Epoch: 0, Step: 3094, Loss: 5.153104782104492
Epoch: 0, Step: 3536, Loss: 5.821743488311768
Epoch: 0, Step: 3978, Loss: 5.305891036987305
Epoch: 0, Step: 4420, Loss: 5.067082405090332
Epoch: 0, Step: 4862, Loss: 5.330435752868652
Epoch: 0, Step: 5304, Loss: 5.594406604766846
Epoch: 0, Step: 5746, Loss: 5.049088954925537
Epoch: 0, Step: 6188, Loss: 5.299722671508789
Epoch: 0, Step: 6630, Loss: 5.385701656341553
Epoch: 0, Step: 7072, Loss: 5.208510875701904
Epoch: 0, Step: 7514, Loss: 4.577992916107178
Epoch: 0, Step: 7956, Loss: 4.767080783843994
Epoch: 0, Step: 8398, Loss: 5.327550411224365
Epoch: 0, Step: 8840, Loss: 5.458425998687744
Epoch: 0, Step: 8851, Loss: 5.0539374

In [23]:
def generate(idx, max_tokens):
    model.eval()
    tokens = idx
    for i in range(max_tokens):
        logits = model(tokens[:, -context_length:])
        topk_values, topk_indices = torch.topk(logits[:, -1, :], 50)
        probs = topk_values.softmax(dim=-1)
        sample = torch.multinomial(probs, 1)
        token = torch.gather(topk_indices, 1, sample)
        tokens = torch.cat((tokens, token), dim=-1)
    
    return tokens

In [24]:
def generate_finetune(input_ids, attention_mask, max_tokens=50):
    model.eval()
    tokens = input_ids
    for i in range(max_tokens):
        logits = model(input_ids, attention_mask=attention_mask).logits
        topk_values, topk_indices = torch.topk(logits[:, -1, :], 10)
        probs = topk_values.softmax(dim=-1)
        sample = torch.multinomial(probs, 1)
        token = torch.gather(topk_indices, 1, sample)
        tokens = torch.cat((tokens, token), dim=-1)
        
    return tokens

In [25]:
start_tokens = tokenizer("What do you call", return_tensors='pt').to(device)
print(tokenizer.decode(generate_finetune(**start_tokens, max_tokens=100)[0]))

What do you call black gay black when black guy guy Mexican dog gay when Mexican who Mexican black when when dog group gay when when black black black who when when guy when Mexican gay cow when gay who Mexican when black when black group cow black when Mexican black when cow when group when when gay black when Mexican who black when gay when dog group group who black who black group when who when guy when when Mexican when cow black when when black when when black black gay black who who Mexican guy Mexican when when dog group when when
