In [2]:
import torch
from torch.nn import functional as F
import sys

from nano_medgpt.definitions import *
from nano_medgpt.src.datasets.utils import *
from nano_medgpt.src.models.bigram import BigramLM
from nano_medgpt.src.models.gpt import *
import random
import pickle
import argparse
from azureml.core.run import Run
import glob
import os

torch.manual_seed(456123)

<torch._C.Generator at 0x2679eb14cb0>

In [21]:
# hyperparamters
vocab_size = 37
n_epochs = 100
eval_iters = 10
eval_interval = 10
batch_size = 64
context_length = 256
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# -----------------------------------------------------

In [22]:
def generate_text(m, max_new_tokens=100):
    idx = torch.zeros((1, 1), dtype=torch.long, device=device)
    generated_idx = m.generate(idx, max_new_tokens=max_new_tokens)[0].tolist()
    return decode(generated_idx)

def _get_batch(df):
    ix = torch.randint(len(df)-context_length, (batch_size,))
    x = torch.stack([df[i:i+context_length] for i in ix])
    y = torch.stack([df[i+1:i+context_length+1] for i in ix])
    x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    return x, y


def get_batch(data_folder=None, split="train"):
    part_selected = 0
    if split == "train":
        part_selected = 0 # random.randint(0, 5)
    else:
        part_selected = 6
    # fpath = data_folder + "part_" + str(part_selected) + "_encoded.pt"
    # df = torch.load(enc)
    fpath = data_folder + "part\\" + str(part_selected) + "\\chunk_0.pkl"
    with open(fpath, 'rb') as f:
        enc = pickle.load(f)
    df = torch.tensor(enc, dtype=torch.long)
    return _get_batch(df)

@torch.no_grad()
def estimate_loss(data_folder, model):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(data_folder = data_folder, split=split)
            logits, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

def train_model(data_folder = None, make_encodings=False):
    if make_encodings:
        make_encodings()

    # instantiate the model and the optimizer
    model = GPT()
    m = model.to(device)

    optim = torch.optim.Adam(model.parameters(), lr=1e-3)

    best_val_loss = 1e9

    for i in range(n_epochs):
        # First, let's check to see if we are at an eval step
        if i % eval_interval == 0:
            losses = estimate_loss(data_folder, model)
            print("Step number {}: Training Loss: {}, Validation Loss:{}".format(i, losses['train'], losses['val']))
        
        # Check for model saving condition
        if losses['val'] < best_val_loss:
            best_val_loss = losses['val']
            # torch.save(model.state_dict(), MODEL_SAVE_PATH)
            torch.save(model.state_dict(), './outputs/model.pth')

        # Main training loop
        xb, yb = get_batch(data_folder=data_folder, split="train")
        logits, loss = m(xb, yb)
        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()
    
    print("After training, the best validation loss is {}".format(loss.item()))
    return model

In [23]:
# parser = argparse.ArgumentParser()
# parser.add_argument('--data-folder', type=str, help='dataset')
# args = parser.parse_args()

# os.makedirs('./outputs', exist_ok=True)
# data_folder = args.data_folder

# run = Run.get_context() 
data_folder = "C:\\Users\\Ben\\git\\nano_medgpt\\data\\interim\\encodings\\"

model = train_model(data_folder)

Step number 0: Training Loss: 3.8526923656463623, Validation Loss:3.8889126777648926
Step number 10: Training Loss: 3.8736984729766846, Validation Loss:3.862678050994873
Step number 20: Training Loss: 3.8838753700256348, Validation Loss:3.8779869079589844
Step number 30: Training Loss: 3.842189073562622, Validation Loss:3.8719208240509033
Step number 40: Training Loss: 3.86883282661438, Validation Loss:3.844949722290039
Step number 50: Training Loss: 3.8267223834991455, Validation Loss:3.87621808052063
Step number 60: Training Loss: 3.854963779449463, Validation Loss:3.8019726276397705
Step number 70: Training Loss: 3.7992324829101562, Validation Loss:3.834172010421753
Step number 80: Training Loss: 3.7863292694091797, Validation Loss:3.8002943992614746
Step number 90: Training Loss: 3.7799301147460938, Validation Loss:3.794801712036133
After training, the best validation loss is 3.770153760910034


In [24]:
print(generate_text(model, 100))

0s5zrhbw mbriw88xzug 1ww41iz8bxtnx8acjazhxzrpfn4tp6n37gz cmp8hzxz801u1qlrew4fhw5r8xuugw8oer3rj963edr6
