# Custom Medical LLM Test

### Goal: 
Train a LLM using data from a selected medical textbook.

### More to Explore:
Could potentially use RAG as an alternative instead, like SRI researchers may pursue.

### TODOs:
- evaluate on Brent & Helms Radiology Textbook for more text and less images
- train LLM and evaluate training loss using SotA, not Bigram Model

### Process:

##### Data Intake - Train text data page by page, including images
1. clean text data
2. image extractor
3. Need to train classifier that associates text with respective image - if we can get coordinates, that'd be huge
5. For now, probably just storing images and text in dict

##### Joint Heads
4. Create a LLM - this will be used as a BERT-esque LLM in the text encoder
5. (Hard) CNN for images

##### Encoder-Decoder Model
6. Train Encoder
7. Train Decoder

# Reading Data

**Reference Max's textbook scraping notebook first!** Courtesy of Max Vogel :)

In [2]:
import os
import json
import regex as re

In [3]:
# IMPORTANT: Run scrape notebook first
PDF_URL = "General - Mandell - Core Radiology (1e).pdf"

assert os.path.exists(f"book-scrape/scrape_out/{PDF_URL.split('.pdf')[0]}")
TEXT_DATA_FOLDER_URL = f"book-scrape/scrape_out/{PDF_URL.split('.pdf')[0]}"

In [4]:
text_data = {}

for fjson in os.listdir(TEXT_DATA_FOLDER_URL):
    ch, ftype = fjson.split('.')[0], fjson.split('.')[-1]
    if ftype != 'json':
        continue
    with open(TEXT_DATA_FOLDER_URL + f'/{fjson}') as f:
        text_data[int(re.search(r'\d+', ch)[0]) - 1] = json.load(f)

text_data[0]

[{'label_range': [0, 0],
  'pg_range': [14, 16],
  'header_font': [['Calibri', 14.0]],
  'header': 'Anatomy',
  'body_font': [['TwCenMT', 11.0]],
  'body': 'Interlobar fi ssures Mechanisms of atelectasis Each of the five lobes tends to collapse in a predictable direction, as shown above.',
  'images': ['scrape_out/General - Mandell - Core Radiology (1e)/images/0.png',
   'scrape_out/General - Mandell - Core Radiology (1e)/images/1.png',
   'scrape_out/General - Mandell - Core Radiology (1e)/images/2.png']},
 {'label_range': [0, 0],
  'pg_range': [17, 18],
  'header_font': [['TwCenMT', 11.0]],
  'header': 'Right upper lobe atelectasis',
  'body_font': [['Calibri', 10.0]],
  'body': 'Case courtesy Ritu R. Gill, MD, MPH, Brigham and Women’s Hospital.',
  'images': ['scrape_out/General - Mandell - Core Radiology (1e)/images/3.png',
   'scrape_out/General - Mandell - Core Radiology (1e)/images/4.png']},
 {'label_range': [1, 0],
  'pg_range': [18, 19],
  'header_font': [['TwCenMT', 11.0]],
 

In [11]:
# merging all text data; lowk make dictionaries instead
def get_text_recursively(obj):
    text = ['']
    
    def helper(obj):
        if isinstance(obj[0], str):
            for s in obj:
                text[0] += s + " [END]"
        else:
            for ls in obj:
                helper(ls)

    helper(obj)
    return text[0]

text = get_text_recursively(text_data[0])
text

'Anatomy [END]Interlobar fi ssuresMechanisms of atelectasisEach of the five lobes tends to collapse in a predictable direction, as shown above.\nCase courtesy Ritu R. Gill, MD, MPH, Brigham and Women’s Hospital.\n [END]Right lower lobe atelectasis [END]Left lower lobe collapse: Frontal and lateral radiographs demonstrate a triangular retrocardiac opacity\nrepresenting the collapsed left lower lobe (red arrows). There is loss of concavity of the left heart\nborder (the flat waist sign; yellow arrow).Case courtesy Ritu R. Gill, MD, MPH, Brigham and Women’s Hospital.\n [END]round atelectasis [END]Right middle lobe atelectasis: Frontal chest radiograph shows an indistinct opacity in the right lung with\nfocal silhouetting of the right heart border (arrow). There is elevation of the right hemidiaphragm due\nto volume loss. The lateral radiograph shows a wedge-shaped opacity (arrow) projecting over the midheart representing the collapsed right middle lobe. Round atelectasis: Noncontrast CT s

In [12]:
chars = sorted(list(set(text)))
vocab_size = len(set(text))
print(vocab_size)
print(chars)

85
['\n', ' ', '"', '&', "'", '(', ')', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '>', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Y', 'Z', '[', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'á', 'ç', 'ö', '–', '’', '“', '”', '•']


# Building an LLM / text transformer model

##### BERT-esque may be more effective due to bidirectionality; currently using a GPT-esque model
##### Additionally, worth exploring RAG in conjunction with an LLM:
https://paperswithcode.com/method/rag

Here's how RAG and LLM can work together in a VQA model:

- When a question is posed to the VQA system, the retriever component (dense retrieval, sparse retrieval, etc.) retrieves relevant textual information or passages from a knowledge base - the textbook data in this case.
- The retrieved passages are then passed to the LLM, which generates an answer based on the context provided by the passages and the question itself.
- The generated answer is returned as the output of the VQA model.

### Encoding / Decoding

In [13]:
!pip install tiktoken


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [14]:
import tiktoken
import torch
enc = tiktoken.get_encoding('gpt2')

In [15]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [16]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:100])

torch.Size([30388]) torch.int64
tensor([24, 64, 51, 70, 65, 63, 75,  1, 49, 28, 37, 27, 50, 32, 64, 70, 55, 68,
        62, 65, 52, 51, 68,  1, 56, 59,  1, 69, 69, 71, 68, 55, 69, 36, 55, 53,
        58, 51, 64, 59, 69, 63, 69,  1, 65, 56,  1, 51, 70, 55, 62, 55, 53, 70,
        51, 69, 59, 69, 28, 51, 53, 58,  1, 65, 56,  1, 70, 58, 55,  1, 56, 59,
        72, 55,  1, 62, 65, 52, 55, 69,  1, 70, 55, 64, 54, 69,  1, 70, 65,  1,
        53, 65, 62, 62, 51, 66, 69, 55,  1, 59])


### Loading / Batching Data

In [17]:
n = int(0.9 * len(data))
train_data = data[:n]
test_data = data[n:]

In [18]:
batch_size = 32
block_size = 8
n_embd = 32

def get_batch(split):
    data = train_data if split == 'train' else test_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [19]:
xb, yb = get_batch("train")

### Simple Bigram Language Model

In [20]:
device = torch.device( cuda if torch.cuda.is_available() else 'cpu' )

In [490]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):

    def __init__(self, n_embd):
        super().__init__()
        # each token directly reads logits for next token off a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        # idx and targets are both (batch_size, block_size) tensor of integers or (B, T)
        token_emb = self.token_embedding_table(idx) # (B, T, C) = batch, time, channel
        x = token_emb
        logits = self.lm_head(x) # (B, T, vocab_size)
        
        if targets == None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_tokens):
        # idx if a (batch_size, block_size) tensor of integers or (B, T)
        for _ in range(max_tokens):
            logits, loss = self.forward(idx) # get predictions
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # attend sample to running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


m = BigramLanguageModel(n_embd)
logits, loss = m.forward(xb, yb)
print(logits.shape)
print(loss)
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long), max_tokens = 100)[0].tolist()))

torch.Size([256, 85])
tensor(4.6054, grad_fn=<NllLossBackward0>)

Zn–3:RzBg4-Y1 6t1tEO0'H,•” Ji8c3(rr
[wuCdDw“:]NxP&r
Wo•.SUbdL.SK0 [’:EsGwdbd"kcmsn2oW 0Dny;L11uf[n08


Training

In [491]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [492]:
for steps in range(10000):
    
    # sample batch of data
    xb, yb = get_batch('train')

    # evaluate loss
    logits, loss = m.forward(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.448796272277832


In [495]:
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long), max_tokens = 100)[0].tolist()))


pe ppa d) c [END, orasthits
PILAla nesurowe ll oungi malisugeretorrr ofim thalaluc s al mmomas.
car 


# Self Attention Model

### Demo:

In [22]:
import torch
import torch.nn as nn
from torch.nn import functional as F

B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# Self-Attention Head
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, head_size) 
q = query(x) # (B, T, head_size)
wei = q @ k.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)

tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T,T)) - no longer all uniform because of attention!
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ x

wei

tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [3.9789e-01, 6.0211e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.9113e-01, 7.0959e-01, 9.9280e-02, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [2.8954e-02, 5.6960e-01, 3.0922e-01, 9.2228e-02, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.1165e-01, 1.3680e-01, 9.3979e-02, 1.6007e-02, 6.4156e-01,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [2.1489e-01, 9.1245e-02, 1.3590e-01, 2.5661e-01, 1.6505e-02,
          2.8484e-01, 0.0000e+00, 0.0000e+00],
         [4.0136e-02, 1.6778e-01, 1.6040e-02, 6.1050e-02, 6.6853e-03,
          6.3306e-01, 7.5253e-02, 0.0000e+00],
         [1.5845e-01, 1.6098e-01, 2.3580e-02, 5.5861e-02, 6.5940e-02,
          5.2553e-02, 1.6164e-01, 3.2100e-01]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.00

### Implementation:

In [23]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0
# ------------

torch.manual_seed(1337)


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [24]:
model = BigramLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))

0.212309 M parameters
step 0: train loss 4.6285, val loss 4.6339
step 100: train loss 2.7345, val loss 3.1479
step 200: train loss 2.5565, val loss 2.9297
step 300: train loss 2.4331, val loss 2.8158
step 400: train loss 2.3015, val loss 2.7195
step 500: train loss 2.1841, val loss 2.6871
step 600: train loss 2.0319, val loss 2.6275
step 700: train loss 1.8879, val loss 2.5053
step 800: train loss 1.7786, val loss 2.4561
step 900: train loss 1.6521, val loss 2.4167
step 1000: train loss 1.5712, val loss 2.3612
step 1100: train loss 1.4920, val loss 2.2799
step 1200: train loss 1.4228, val loss 2.3346
step 1300: train loss 1.3618, val loss 2.2691
step 1400: train loss 1.2983, val loss 2.2332
step 1500: train loss 1.2414, val loss 2.2970
step 1600: train loss 1.2079, val loss 2.2139
step 1700: train loss 1.1548, val loss 2.2378
step 1800: train loss 1.1066, val loss 2.2181
step 1900: train loss 1.0783, val loss 2.2533
step 2000: train loss 1.0534, val loss 2.3094
step 2100: train loss 1.

In [26]:
def generate_text_from_context(text, max_new_tokens=1000):
    context = torch.tensor([encode(text)], dtype=torch.long, device=device)
    return decode(m.generate(context, max_new_tokens=max_new_tokens)[0].tolist())

In [27]:
test_text = 'Focal nodular hyperplasia (FNH) is disorganized liver tissue with no malignant potential. It is primarily seen in asymptomatic women and is not associated with oral contraceptives.'
print(generate_text_from_context(test_text))

Focal nodular hyperplasia (FNH) is disorganized liver tissue with no malignant potential. It is primarily seen in asymptomatic women and is not associated with oral contraceptives. [END]Amyomesochs pretection of a midddle mediastinal mass Detected to the esoplaciety B Signs, sho31, evei an of the colung disease.. Rigial. Case courtesy Ritu R. Gill, MD, MPH, Brigham and Women’s Hospital.
 [END]Pneumatocele, right paired, is an aspergillosis: Axial CT demonstrates for neck shows numeralbascent malignantly peripheral nodules (arrows). This pattermant pulmonary hypertension T3 squamous in th characterized by uperf and sternophology 2, 76-9 (2, [END]IId 1 sighted may be lower lobe atelectasis [END]Pulmonary
hypertension C: Axial CT demonjeft spirgial mucinous BAC HSP, sign of ABNoncontrast [END]Pulmonary hypertension stestablitic demonsterstitial diffuse ICC curves an teithe pulmonary hypertension [END]Overview of pulmonary gangrenedly of Virginial adenopathy
(yellow arrow) pattern, represe

### Saving Our Model

In [500]:
from transformers import PretrainedConfig
import os

In [501]:
# Create a custom configuration

MODEL_FOLDER_PATH = 'config'
MODEL_FILE_PATH = "config/bigram_language_model"

model_config = PretrainedConfig(
    vocab_size=vocab_size,
    hidden_size=n_embd,
    num_hidden_layers=n_layer,
    num_attention_heads=n_head,
    intermediate_size=4 * n_embd,
    hidden_dropout_prob=dropout,
    attention_probs_dropout_prob=dropout,
)

# Save the configuration to a file
if not os.path.exists(MODEL_FOLDER_PATH):
    os.makedirs(MODEL_FOLDER_PATH)
model_config.save_pretrained(MODEL_FILE_PATH)

In [502]:
import json

file_path = f'{MODEL_FILE_PATH}/config.json'

with open(file_path, 'r') as json_file:
    existing_data = json.load(json_file)

new_data = {
    "model_type": "gpt2"
}

# Update the existing data with the new data
existing_data.update(new_data)

with open(file_path, 'w') as json_file:
    json.dump(existing_data, json_file, indent=4)

# Using LLM as a a Text Transformer in a QA model

### Reconciling Custom LLM with Transformers based LLM

In [1]:
import transformers

In [2]:
import torch
import torch.nn as nn
from transformers import PreTrainedModel

class CustomLLMWrapper(PreTrainedModel):
    def __init__(self, config, model):
        super(CustomLLMWrapper, self).__init__(config)
        self.token_embedding_table = model.token_embedding_table
        self.position_embedding_table = model.position_embedding_table
        self.blocks = model.blocks
        self.ln_f = model.ln_f
        self.lm_head = model.lm_head

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=self.device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self.forward(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [3]:
# config load
from transformers import DistilBertTokenizerFast, AutoConfig
config_path = 'config/bigram_language_model'
loaded_config = AutoConfig.from_pretrained(config_path)

custom_llm = CustomLLMWrapper(loaded_config, m)

NameError: name 'm' is not defined

In [506]:
def generate_text_from_context(text, max_new_tokens=1000):
    context = torch.tensor([encode(text)], dtype=torch.long, device=device)
    return decode(custom_llm.generate(context, max_new_tokens=max_new_tokens)[0].tolist())

In [507]:
generate_text_from_context('What is Focal nodular hyperplasia?')

KeyError: '?'

### Using Custom LLM as part of QA Model

In [395]:
from torch.utils.data import DataLoader
from transformers import AdamW # AdamW optimizer with weighted decay reduces chance of overfitting
from tqdm import tqdm # progress bar

In [396]:
model = custom_llm

model.to(device)
model.train()
optim = AdamW(model.parameters(), lr=5e-5)



##### TODO: Form a Dataset
Format: input id, start position, end position, attention mask (?)

In [398]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

NameError: name 'train_dataset' is not defined

In [402]:
# NEEDS MODIFICATION - loss and maybe logits can be mapped from custom architecture

for epoch in range(3):
    # set model to train mode
    model.train()
    # setup loop (we use tqdm for the progress bar)
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        # we calculate gradients at the end of loop; previous iterations shouldn't effect
        optim.zero_grad() 
        # pull all the tensor batches required for training
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        # train model on batch and return outputs (incl. loss)
        outputs = model(input_ids, attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions)
        # extract loss
        loss = outputs[0]
        # calculate loss for every parameter that needs grad update
        loss.backward()
        # update parameters
        optim.step()
        # print relevant info to progress bar
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())

NameError: name 'train_loader' is not defined

In [400]:
model.eval()

CustomLLMWrapper(
  (token_embedding_table): Embedding(152, 64)
  (position_embedding_table): Embedding(32, 64)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-3): 4 x Head(
            (key): Linear(in_features=64, out_features=16, bias=False)
            (query): Linear(in_features=64, out_features=16, bias=False)
            (value): Linear(in_features=64, out_features=16, bias=False)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (proj): Linear(in_features=64, out_features=64, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ffwd): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=64, out_features=256, bias=True)
          (1): ReLU()
          (2): Linear(in_features=256, out_features=64, bias=True)
          (3): Dropout(p=0.0, inplace=False)
        )
      )
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      

# Next Steps

### Priority:
- Explore this: https://github.com/RSNA/AI-Deep-Learning-Lab-2022 - data processing
- Maybe train on more text-heavy textbook (see above)
- Data Cleaning / PDF text ingestion - Check with Max Vogel
- Assigning Photos to relevant text for VQA model training
- Upgrade Bigram Language Model to State of the Art

### Additional Research:
- RAG for increased accuracy of text data on large amounts of text knowledge