In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import json
import random
from tqdm import tqdm
from torchtune.modules import RotaryPositionalEmbeddings

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 16
#block_size = 128
max_iters = 6000
eval_interval = 100
learning_rate = 6e-4
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.1
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


In [3]:
train_data = torch.load("train.pt")
val_data = torch.load("val.pt")

def find_max_len(dataset, name="Dataset"):
    max_len = 0
    total = 0
    count = 0
    for seq in dataset:
        seq_len = len(seq)
        if seq_len > max_len:
            max_len = seq_len
        total += seq_len
        count += 1
    avg_len = total / count if count > 0 else 0
    print(f"{name}: {count} timelines. Max length: {max_len}, Average length: {avg_len:.1f}")

find_max_len(train_data, "Train set")
find_max_len(val_data, "Val set")

with open("D:/CourseworkFolder/DPSynthData/Data Manipulation/token_map.json", "r", encoding="utf-8") as f:
    token_map = json.load(f)
vocab_size = len(token_map)
print(f"Vocab size: {vocab_size}")
 

Train set: 96377 timelines. Max length: 29941, Average length: 107.5
Val set: 10709 timelines. Max length: 27058, Average length: 109.7
Vocab size: 15567


In [4]:
def get_batch(split):
    data = train_data if split == "train" else val_data
    batch = random.sample(data, batch_size)
    x_batch = [torch.tensor(seq[:-1], dtype=torch.long, device=device) for seq in batch]
    y_batch = [torch.tensor(seq[1:], dtype=torch.long, device=device) for seq in batch]
    return x_batch, y_batch

@torch.no_grad()

# def estimate_loss():
#     out = {}
#     model.eval()
#     for split in ['train', 'val']:
#         losses = torch.zeros(eval_iters)
#         for k in range(eval_iters):
#             x_batch, y_batch = get_batch(split)
            
#             batch_losses = []
#             for x, y in zip(x_batch, y_batch):
#                 logits, loss = model([x], [y])  
#                 batch_losses.append(loss.item())
            
#             avg_loss = sum(batch_losses) / len(batch_losses)
#             losses[k] = avg_loss
        
#         out[split] = losses.mean()
#     model.train()
#     return out

def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, n_embd):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = n_embd // num_heads
        assert n_embd % num_heads == 0, "Embedding size must be divisible by num_heads"

        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        
        
        self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=50000)
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape  # (batch_size, sequence_length, embedding_dim)
        
        # Linear projections and reshape to (B, T, n_head, head_dim)
        q = self.query(x).view(B, T, self.num_heads, self.head_dim)
        k = self.key(x).view(B, T, self.num_heads, self.head_dim)
        v = self.value(x).view(B, T, self.num_heads, self.head_dim)
        
        #print(f"q.shape: {q.shape}")
        
        # Apply RoPE to q and k (automatic positional embeddings)
        q = self.rope(q)
        k = self.rope(k)

        # Scaled Dot-Product Attention with flash attention
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=dropout if self.training else 0.0,
            is_causal=True
        )  # (B, T, n_head, head_dim)

        # Merge heads: (B, T, n_head, head_dim) -> (B, T, C)
        out = out.transpose(1, 2).reshape(B, T, C)
        out = self.dropout(self.proj(out))
        return out


class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class TimelineLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        #self.position_embedding_table = nn.Embedding(12000, n_embd)  
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, x_batch, y_batch=None):
        logits_batch = []
        losses = []

        for x in x_batch:
            #T = x.size(0)
            tok_emb = self.token_embedding_table(x)
            #pos_emb = self.position_embedding_table(torch.arange(T, device=x.device))
            x_in = tok_emb 
            x_in = x_in.unsqueeze(0)

            x_out = self.blocks(x_in)
            x_out = self.ln_f(x_out)
            logits = self.lm_head(x_out).squeeze(0)  
            logits_batch.append(logits)

        if y_batch is not None:
            for logits, y in zip(logits_batch, y_batch):
                loss = F.cross_entropy(logits, y)
                losses.append(loss)
            return logits_batch, torch.stack(losses).mean()

        return logits_batch, None

    def generate(self, idx, max_new_tokens):
        min_tokens = 1
        eos_token = 15566

        for i in range(max_new_tokens):
            T = idx.size(1)
            tok_emb = self.token_embedding_table(idx)
            #pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
            x = tok_emb
            x = self.blocks(x)
            x = self.ln_f(x)
            logits = self.lm_head(x)

            probs = F.softmax(logits[:, -1, :], dim=-1)

            if i < min_tokens:
                probs[:, eos_token] = 0
                probs = probs / probs.sum(dim=-1, keepdim=True)

            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
         
            if next_token.item() == eos_token and i >= min_tokens:
                break

        return idx


In [6]:
print(" Starting training")
model = TimelineLanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
print(f" Model has {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters")

for iter in tqdm(range(max_iters), desc="Training Model", unit="steps"):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f" Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)




 Starting training
 Model has 2.21M parameters


Training Model:   0%|          | 0/6000 [00:00<?, ?steps/s]

 Step 0: train loss 9.7831, val loss 9.7877


Training Model:   2%|▏         | 100/6000 [00:57<23:43,  4.14steps/s] 

 Step 100: train loss 6.1807, val loss 6.1662


Training Model:   3%|▎         | 200/6000 [01:51<22:43,  4.25steps/s]   

 Step 200: train loss 5.6639, val loss 5.6941


Training Model:   5%|▌         | 300/6000 [02:48<20:45,  4.58steps/s]   

 Step 300: train loss 5.2546, val loss 5.2610


Training Model:   7%|▋         | 400/6000 [03:44<20:54,  4.46steps/s]   

 Step 400: train loss 4.8760, val loss 4.8621


Training Model:   8%|▊         | 500/6000 [04:37<20:09,  4.55steps/s]   

 Step 500: train loss 4.6118, val loss 4.6113


Training Model:  10%|█         | 600/6000 [05:30<22:10,  4.06steps/s]   

 Step 600: train loss 4.3786, val loss 4.3714


Training Model:  12%|█▏        | 700/6000 [06:24<21:05,  4.19steps/s]   

 Step 700: train loss 4.2538, val loss 4.2612


Training Model:  13%|█▎        | 800/6000 [07:18<20:05,  4.31steps/s]   

 Step 800: train loss 4.0991, val loss 4.1245


Training Model:  15%|█▌        | 900/6000 [08:15<18:33,  4.58steps/s]   

 Step 900: train loss 4.0215, val loss 4.0306


Training Model:  17%|█▋        | 1000/6000 [09:09<19:07,  4.36steps/s]  

 Step 1000: train loss 3.9716, val loss 3.9454


Training Model:  18%|█▊        | 1100/6000 [10:02<18:14,  4.48steps/s]   

 Step 1100: train loss 3.8695, val loss 3.8910


Training Model:  20%|██        | 1200/6000 [10:56<17:59,  4.45steps/s]   

 Step 1200: train loss 3.7983, val loss 3.8543


Training Model:  22%|██▏       | 1300/6000 [11:50<16:42,  4.69steps/s]   

 Step 1300: train loss 3.7716, val loss 3.7840


Training Model:  23%|██▎       | 1400/6000 [12:49<15:43,  4.87steps/s]   

 Step 1400: train loss 3.7643, val loss 3.7778


Training Model:  25%|██▌       | 1500/6000 [13:43<16:26,  4.56steps/s]   

 Step 1500: train loss 3.7236, val loss 3.7215


Training Model:  27%|██▋       | 1600/6000 [14:36<15:42,  4.67steps/s]   

 Step 1600: train loss 3.6715, val loss 3.6914


Training Model:  28%|██▊       | 1700/6000 [15:31<14:57,  4.79steps/s]   

 Step 1700: train loss 3.6682, val loss 3.6674


Training Model:  30%|███       | 1800/6000 [16:25<14:33,  4.81steps/s]   

 Step 1800: train loss 3.6174, val loss 3.6194


Training Model:  32%|███▏      | 1900/6000 [17:18<15:06,  4.52steps/s]   

 Step 1900: train loss 3.5641, val loss 3.5849


Training Model:  33%|███▎      | 2000/6000 [18:11<14:39,  4.55steps/s]   

 Step 2000: train loss 3.5716, val loss 3.5828


Training Model:  35%|███▌      | 2100/6000 [19:05<14:23,  4.52steps/s]   

 Step 2100: train loss 3.5355, val loss 3.5551


Training Model:  37%|███▋      | 2200/6000 [20:02<14:55,  4.24steps/s]   

 Step 2200: train loss 3.5213, val loss 3.5087


Training Model:  38%|███▊      | 2300/6000 [21:01<13:11,  4.67steps/s]   

 Step 2300: train loss 3.5296, val loss 3.5702


Training Model:  40%|████      | 2400/6000 [21:49<12:41,  4.73steps/s]  

 Step 2400: train loss 3.4998, val loss 3.4853


Training Model:  42%|████▏     | 2500/6000 [22:38<13:57,  4.18steps/s]  

 Step 2500: train loss 3.4920, val loss 3.4754


Training Model:  43%|████▎     | 2600/6000 [23:27<12:02,  4.71steps/s]  

 Step 2600: train loss 3.4568, val loss 3.4684


Training Model:  45%|████▌     | 2700/6000 [24:15<11:34,  4.75steps/s]  

 Step 2700: train loss 3.4446, val loss 3.4655


Training Model:  47%|████▋     | 2800/6000 [25:03<11:14,  4.74steps/s]  

 Step 2800: train loss 3.4475, val loss 3.4394


Training Model:  48%|████▊     | 2900/6000 [25:52<11:59,  4.31steps/s]  

 Step 2900: train loss 3.4192, val loss 3.3922


Training Model:  50%|█████     | 3000/6000 [26:54<11:50,  4.22steps/s]  

 Step 3000: train loss 3.4030, val loss 3.4538


Training Model:  52%|█████▏    | 3100/6000 [27:48<11:11,  4.32steps/s]  

 Step 3100: train loss 3.3486, val loss 3.3753


Training Model:  53%|█████▎    | 3200/6000 [28:43<12:21,  3.78steps/s]  

 Step 3200: train loss 3.3169, val loss 3.3603


Training Model:  55%|█████▌    | 3300/6000 [29:38<10:04,  4.46steps/s]  

 Step 3300: train loss 3.3736, val loss 3.3530


Training Model:  57%|█████▋    | 3400/6000 [30:27<09:11,  4.71steps/s]  

 Step 3400: train loss 3.3190, val loss 3.3731


Training Model:  58%|█████▊    | 3500/6000 [31:15<08:58,  4.64steps/s]  

 Step 3500: train loss 3.3391, val loss 3.3767


Training Model:  60%|██████    | 3600/6000 [32:04<08:37,  4.64steps/s]  

 Step 3600: train loss 3.3220, val loss 3.3280


Training Model:  62%|██████▏   | 3700/6000 [32:52<08:04,  4.74steps/s]  

 Step 3700: train loss 3.2968, val loss 3.3169


Training Model:  63%|██████▎   | 3800/6000 [33:39<07:25,  4.94steps/s]  

 Step 3800: train loss 3.3000, val loss 3.2702


Training Model:  65%|██████▌   | 3900/6000 [34:27<07:08,  4.90steps/s]  

 Step 3900: train loss 3.3038, val loss 3.3028


Training Model:  67%|██████▋   | 4000/6000 [35:15<07:08,  4.67steps/s]  

 Step 4000: train loss 3.2731, val loss 3.3019


Training Model:  68%|██████▊   | 4100/6000 [36:04<06:54,  4.59steps/s]  

 Step 4100: train loss 3.2961, val loss 3.3082


Training Model:  70%|███████   | 4200/6000 [36:51<06:15,  4.80steps/s]  

 Step 4200: train loss 3.2294, val loss 3.2462


Training Model:  72%|███████▏  | 4300/6000 [37:41<06:19,  4.48steps/s]  

 Step 4300: train loss 3.2556, val loss 3.2975


Training Model:  73%|███████▎  | 4400/6000 [38:30<06:30,  4.10steps/s]  

 Step 4400: train loss 3.2624, val loss 3.2517


Training Model:  75%|███████▌  | 4500/6000 [39:20<05:27,  4.58steps/s]  

 Step 4500: train loss 3.2515, val loss 3.2246


Training Model:  77%|███████▋  | 4600/6000 [40:10<05:30,  4.24steps/s]  

 Step 4600: train loss 3.2300, val loss 3.2527


Training Model:  78%|███████▊  | 4700/6000 [41:00<04:44,  4.57steps/s]  

 Step 4700: train loss 3.2072, val loss 3.2352


Training Model:  80%|████████  | 4800/6000 [41:51<04:35,  4.35steps/s]  

 Step 4800: train loss 3.2042, val loss 3.1969


Training Model:  82%|████████▏ | 4900/6000 [42:41<03:52,  4.72steps/s]  

 Step 4900: train loss 3.1774, val loss 3.2103


Training Model:  83%|████████▎ | 5000/6000 [43:30<03:37,  4.59steps/s]  

 Step 5000: train loss 3.1881, val loss 3.2082


Training Model:  85%|████████▌ | 5100/6000 [44:29<03:31,  4.25steps/s]  

 Step 5100: train loss 3.2098, val loss 3.2132


Training Model:  87%|████████▋ | 5200/6000 [45:25<02:55,  4.55steps/s]  

 Step 5200: train loss 3.1894, val loss 3.2479


Training Model:  88%|████████▊ | 5300/6000 [46:26<02:35,  4.49steps/s]  

 Step 5300: train loss 3.1290, val loss 3.1703


Training Model:  90%|█████████ | 5400/6000 [47:29<02:32,  3.93steps/s]  

 Step 5400: train loss 3.1667, val loss 3.1817


Training Model:  92%|█████████▏| 5500/6000 [48:31<01:54,  4.38steps/s]  

 Step 5500: train loss 3.1482, val loss 3.1815


Training Model:  93%|█████████▎| 5600/6000 [49:34<01:23,  4.78steps/s]  

 Step 5600: train loss 3.1480, val loss 3.1643


Training Model:  95%|█████████▌| 5700/6000 [50:34<01:10,  4.24steps/s]  

 Step 5700: train loss 3.1270, val loss 3.1501


Training Model:  97%|█████████▋| 5800/6000 [51:38<01:20,  2.48steps/s]

 Step 5800: train loss 3.1444, val loss 3.1744


Training Model:  98%|█████████▊| 5900/6000 [52:38<00:21,  4.71steps/s]

 Step 5900: train loss 3.1365, val loss 3.1423


Training Model: 100%|██████████| 6000/6000 [53:45<00:00,  1.86steps/s]


In [10]:
with open("D:/CourseworkFolder/DPSynthData/Data Manipulation/token_map.json", "r", encoding="utf-8") as f:
    token_map = json.load(f)
itos = {v: k for k, v in token_map.items()}

def decode(token_ids):
    return [itos.get(i, f"<unk:{i}>") for i in token_ids]


torch.save(model.state_dict(), "timeline_model.pt")
print(" Model saved.")

print(" Generating synthetic timeline...")
context = torch.zeros((1, 1), dtype=torch.long, device=device)
output = model.generate(context, max_new_tokens=300)
decoded = decode(output[0].tolist())

print(" Decoded Tokens:")
for token in decoded:
    print(token)

 Model saved.
 Generating synthetic timeline...
 Decoded Tokens:
('Hospital Admission', 'admittime')_admission_location=AMBULATORY SURGERY TRANSFER_discharge_location=
Diagnosis_icd_code=Z20
Transfer to Location_careunit=Psychiatry
Transfer to Location_careunit=Psychiatry
Transfer to Location_careunit=Psychiatry
Transfer to Location_careunit=UNKNOWN
('Hospital Admission', 'admittime')_admission_location=PHYSICIAN REFERRAL_discharge_location=HOME
__EOS__
__EOS__
__EOS__
Transfer to Location_careunit=UNKNOWN
('Hospital Admission', 'admittime')_admission_location=PHYSICIAN REFERRAL_discharge_location=HOME
__EOS__
__EOS__
Transfer to Location_careunit=UNKNOWN
('Hospital Admission', 'admittime')_admission_location=EMERGENCY ROOM_discharge_location=HOME
__EOS__
__EOS__
Transfer to Location_careunit=UNKNOWN
('Hospital Admission', 'admittime')_admission_location=EMERGENCY ROOM_discharge_location=HOME
__EOS__
__EOS__
Transfer to Location_careunit=UNKNOWN
('Hospital Admission', 'admittime')_admi