In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import json
import random
from tqdm import tqdm
from torchtune.modules import RotaryPositionalEmbeddings

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
batch_size = 16
#block_size = 128
max_iters = 3000
eval_interval = 100
learning_rate = 6e-4
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.1
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


In [3]:
train_data = torch.load("train.pt")
val_data = torch.load("val.pt")

def find_max_len(dataset, name="Dataset"):
    max_len = 0
    total = 0
    count = 0
    for seq in dataset:
        seq_len = len(seq)
        if seq_len > max_len:
            max_len = seq_len
        total += seq_len
        count += 1
    avg_len = total / count if count > 0 else 0
    print(f"{name}: {count} timelines. Max length: {max_len}, Average length: {avg_len:.1f}")

find_max_len(train_data, "Train set")
find_max_len(val_data, "Val set")

with open("D:/CourseworkFolder/DPSynthData/Data Manipulation/token_map.json", "r", encoding="utf-8") as f:
    token_map = json.load(f)
vocab_size = len(token_map)
print(f"Vocab size: {vocab_size}")
 

Train set: 96341 timelines. Max length: 29917, Average length: 200.0
Val set: 10705 timelines. Max length: 25961, Average length: 188.3
Vocab size: 12666


In [None]:
def get_batch(split):
    data = train_data if split == "train" else val_data
    batch = random.sample(data, batch_size)
    x_batch = [torch.tensor(seq[:-1], dtype=torch.long, device=device) for seq in batch]
    y_batch = [torch.tensor(seq[1:], dtype=torch.long, device=device) for seq in batch]
    return x_batch, y_batch

@torch.no_grad()
'''
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x_batch, y_batch = get_batch(split)
            
            batch_losses = []
            for x, y in zip(x_batch, y_batch):
                logits, loss = model([x], [y])  
                batch_losses.append(loss.item())
            
            avg_loss = sum(batch_losses) / len(batch_losses)
            losses[k] = avg_loss
        
        out[split] = losses.mean()
    model.train()
    return out
    '''
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [33]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, n_embd):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = n_embd // num_heads
        assert n_embd % num_heads == 0, "Embedding size must be divisible by num_heads"

        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        
        
        self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=50000)
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape  # (batch_size, sequence_length, embedding_dim)
        
        # Linear projections and reshape to (B, T, n_head, head_dim)
        q = self.query(x).view(B, T, self.num_heads, self.head_dim)
        k = self.key(x).view(B, T, self.num_heads, self.head_dim)
        v = self.value(x).view(B, T, self.num_heads, self.head_dim)
        
        #print(f"q.shape: {q.shape}")
        
        # Apply RoPE to q and k (automatic positional embeddings)
        q = self.rope(q)
        k = self.rope(k)

        # Scaled Dot-Product Attention with flash attention
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=dropout if self.training else 0.0,
            is_causal=True
        )  # (B, T, n_head, head_dim)

        # Merge heads: (B, T, n_head, head_dim) -> (B, T, C)
        out = out.transpose(1, 2).reshape(B, T, C)
        out = self.dropout(self.proj(out))
        return out


class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class TimelineLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        #self.position_embedding_table = nn.Embedding(12000, n_embd)  
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, x_batch, y_batch=None):
        logits_batch = []
        losses = []

        for x in x_batch:
            #T = x.size(0)
            tok_emb = self.token_embedding_table(x)
            #pos_emb = self.position_embedding_table(torch.arange(T, device=x.device))
            x_in = tok_emb 
            x_in = x_in.unsqueeze(0)

            x_out = self.blocks(x_in)
            x_out = self.ln_f(x_out)
            logits = self.lm_head(x_out).squeeze(0)  
            logits_batch.append(logits)

        if y_batch is not None:
            for logits, y in zip(logits_batch, y_batch):
                loss = F.cross_entropy(logits, y)
                losses.append(loss)
            return logits_batch, torch.stack(losses).mean()

        return logits_batch, None

    def generate(self, idx, max_new_tokens):
        min_tokens = 1
        eos_token = 12665

        for i in range(max_new_tokens):
            T = idx.size(1)
            tok_emb = self.token_embedding_table(idx)
            #pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
            x = tok_emb
            x = self.blocks(x)
            x = self.ln_f(x)
            logits = self.lm_head(x)

            probs = F.softmax(logits[:, -1, :], dim=-1)

            if i < min_tokens:
                probs[:, eos_token] = 0
                probs = probs / probs.sum(dim=-1, keepdim=True)

            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
         
            if next_token.item() == eos_token and i >= min_tokens:
                break

        return idx


In [18]:
print(" Starting training")
model = TimelineLanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
print(f" Model has {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters")

for iter in tqdm(range(max_iters), desc="Training Model", unit="steps"):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f" Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)




 Starting training
 Model has 1.83M parameters


Training Model:   0%|          | 0/3000 [00:00<?, ?steps/s]

 Step 0: train loss 9.6205, val loss 9.6189


Training Model:   3%|▎         | 100/3000 [01:27<22:08,  2.18steps/s] 

 Step 100: train loss 4.1366, val loss 4.1614


Training Model:   7%|▋         | 200/3000 [02:48<14:04,  3.32steps/s]   

 Step 200: train loss 3.0425, val loss 3.0348


Training Model:  10%|█         | 300/3000 [04:08<11:25,  3.94steps/s]   

 Step 300: train loss 2.7319, val loss 2.7297


Training Model:  13%|█▎        | 400/3000 [05:22<09:58,  4.35steps/s]   

 Step 400: train loss 2.5987, val loss 2.5974


Training Model:  17%|█▋        | 500/3000 [06:39<11:25,  3.65steps/s]   

 Step 500: train loss 2.5355, val loss 2.5086


Training Model:  20%|██        | 600/3000 [07:53<12:50,  3.11steps/s]   

 Step 600: train loss 2.4747, val loss 2.4821


Training Model:  23%|██▎       | 700/3000 [09:11<12:38,  3.03steps/s]   

 Step 700: train loss 2.4511, val loss 2.4382


Training Model:  27%|██▋       | 800/3000 [10:31<08:08,  4.51steps/s]   

 Step 800: train loss 2.4102, val loss 2.3974


Training Model:  30%|███       | 900/3000 [11:50<17:14,  2.03steps/s]  

 Step 900: train loss 2.3800, val loss 2.3944


Training Model:  33%|███▎      | 1000/3000 [13:05<07:46,  4.29steps/s] 

 Step 1000: train loss 2.3790, val loss 2.3610


Training Model:  37%|███▋      | 1100/3000 [14:21<07:58,  3.97steps/s]  

 Step 1100: train loss 2.3196, val loss 2.3260


Training Model:  40%|████      | 1200/3000 [15:41<06:39,  4.50steps/s]  

 Step 1200: train loss 2.3245, val loss 2.3068


Training Model:  43%|████▎     | 1300/3000 [17:00<07:40,  3.69steps/s]  

 Step 1300: train loss 2.2987, val loss 2.2894


Training Model:  47%|████▋     | 1400/3000 [18:19<06:01,  4.42steps/s]  

 Step 1400: train loss 2.2666, val loss 2.2653


Training Model:  50%|█████     | 1500/3000 [19:39<05:54,  4.23steps/s]  

 Step 1500: train loss 2.2455, val loss 2.2484


Training Model:  53%|█████▎    | 1600/3000 [20:55<05:14,  4.45steps/s]  

 Step 1600: train loss 2.2353, val loss 2.2150


Training Model:  57%|█████▋    | 1700/3000 [22:14<04:58,  4.36steps/s]  

 Step 1700: train loss 2.2020, val loss 2.1891


Training Model:  60%|██████    | 1800/3000 [23:26<04:54,  4.07steps/s]  

 Step 1800: train loss 2.1809, val loss 2.1728


Training Model:  63%|██████▎   | 1900/3000 [24:44<05:27,  3.36steps/s]  

 Step 1900: train loss 2.1722, val loss 2.1625


Training Model:  67%|██████▋   | 2000/3000 [26:02<04:12,  3.96steps/s]  

 Step 2000: train loss 2.1536, val loss 2.1532


Training Model:  70%|███████   | 2100/3000 [27:26<03:54,  3.84steps/s]  

 Step 2100: train loss 2.1296, val loss 2.1407


Training Model:  73%|███████▎  | 2200/3000 [28:44<03:10,  4.20steps/s]  

 Step 2200: train loss 2.1203, val loss 2.1174


Training Model:  77%|███████▋  | 2300/3000 [30:01<02:37,  4.44steps/s]  

 Step 2300: train loss 2.1255, val loss 2.1142


Training Model:  80%|████████  | 2400/3000 [31:20<02:09,  4.63steps/s]  

 Step 2400: train loss 2.0831, val loss 2.1215


Training Model:  83%|████████▎ | 2500/3000 [32:33<01:47,  4.67steps/s]  

 Step 2500: train loss 2.0933, val loss 2.0859


Training Model:  87%|████████▋ | 2600/3000 [33:49<01:32,  4.34steps/s]  

 Step 2600: train loss 2.0725, val loss 2.0479


Training Model:  90%|█████████ | 2700/3000 [35:04<01:05,  4.55steps/s]  

 Step 2700: train loss 2.0645, val loss 2.0559


Training Model:  93%|█████████▎| 2800/3000 [36:23<00:49,  4.03steps/s]  

 Step 2800: train loss 2.0430, val loss 2.0384


Training Model:  97%|█████████▋| 2900/3000 [37:39<00:24,  4.10steps/s]

 Step 2900: train loss 2.0442, val loss 2.0338


Training Model: 100%|██████████| 3000/3000 [38:55<00:00,  1.28steps/s]


In [34]:
with open("D:/CourseworkFolder/DPSynthData/Data Manipulation/token_map.json", "r", encoding="utf-8") as f:
    token_map = json.load(f)
itos = {v: k for k, v in token_map.items()}

def decode(token_ids):
    return [itos.get(i, f"<unk:{i}>") for i in token_ids]


torch.save(model.state_dict(), "timeline_model.pt")
print(" Model saved.")

print(" Generating synthetic timeline...")
context = torch.zeros((1, 1), dtype=torch.long, device=device)
output = model.generate(context, max_new_tokens=300)
decoded = decode(output[0].tolist())

print(" Decoded Tokens:")
for token in decoded:
    print(token)

 Model saved.
 Generating synthetic timeline...
 Decoded Tokens:
('Hospital Admission', 'admittime')
admission_location=PHYSICIAN REFERRAL_discharge_location=HOME HEALTH CARE
Diagnosis
icd_code=470
Diagnosis
icd_code=600
Diagnosis
icd_code=781
Diagnosis
icd_code=V12
icd_code=446
Diagnosis
icd_code=389
__EOS__
