In [None]:
#run imports
import torch, time, Data, random, sys, json
import matplotlib.pyplot as plt
import torch.nn as nn
from tokenizer import Tokenizer
# Seed every RNG this notebook imports so runs are reproducible.
# `random` is imported above but was previously never seeded; any sampling that
# relies on it (presumably Data's random batch selection — TODO confirm) would
# otherwise change from run to run.
SEED = 1
torch.random.manual_seed(SEED)
random.seed(SEED)

# Choose device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Device set to {device}')

In [None]:
# Model hyper parameters
chunk_size = 128      # context window length, in tokens
embedding_dim = 256   # must be divisible by num_heads (split across attention heads)

# Transformer hyper parameters
num_attention_blocks = 4
num_heads = 8
dropout_rate = 0.0

# Training parameters
batch_size = 32
learning_rate = 0.001
train_time = 60       # wall-clock training budget, in minutes
num_step = 15000      # currently unused: training stops on train_time, not a step count

vocab_file = 'vocab_chars.json'

# Fail fast on an invalid head split here instead of deep inside the model.
assert embedding_dim % num_heads == 0, (
    f'embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})'
)

In [None]:
# Load the train/test text and the character vocabulary.
train_file = 'tbbt_train.txt'
test_file = 'tbbt_test.txt'

# Context manager guarantees the vocab file is closed even if JSON parsing fails.
with open(vocab_file, 'r') as vocab_fp:
    vocab = json.load(vocab_fp)

tokenizer = Tokenizer(vocab)

data = Data.Data(train_file, test_file, tokenizer, chunk_size, sample_data=True)

# Sanity-check a single random training sample and its shapes.
x, y = data.get_random_train_sample(num_samples=1)
print(x)
print(y)
print(f'X: {x.shape}, Y: {y.shape}')
print('X: (batch size, chunk size), Y: (batch size, chunk size, num char)')

In [None]:
from model import Transformer

# To resume from a saved checkpoint instead of training from scratch:
#   model = Transformer.create_from_checkpoint('Transformer.pth.tar').to(device)

# Fresh model built from the hyper parameters in the config cell
# (positional arguments, in the same order as before).
model = Transformer(
    len(tokenizer.vocab_list),
    embedding_dim,
    chunk_size,
    num_heads,
    num_attention_blocks,
    learning_rate,
    dropout_rate,
)
model = model.to(device)

print(model.get_parameter_count(), 'Million Parameters')

# Training bookkeeping used by the training and plotting cells:
# cost_data collects averaged losses, cost_it the raw per-step losses
# waiting to be averaged, and step counts optimisation steps so far.
cost_data = []
cost_it = []
step = 0

In [None]:
# # load saved checkpoint
# model:Transformer = Transformer.create_from_checkpoint('Transformer.pth.tar')
# cost_data = model.training_state['cost_data']

In [None]:
# Train the model for `train_time` minutes of wall-clock time.
# Losses are averaged over every `flatten_const` steps into `cost_data` for
# plotting. `step` continues from its current value, so re-running this cell
# resumes the count instead of restarting it.
model.train()
start = time.time()

flatten_const = 100      # steps per averaged point in cost_data
print_every = 100        # steps between progress prints
checkpoint_every = 5000  # steps between checkpoint saves

while time.time() - start < 60 * train_time:
    step += 1
    x_train, y_train = data.get_random_train_sample(batch_size)
    x_train = x_train.to(device)
    y_train = y_train.to(device)

    # Forward pass
    outputs, loss = model(x_train, y_train)

    # Backward and optimize
    model.optimizer.zero_grad()
    loss.backward()
    model.optimizer.step()

    cost_it.append(loss.item())

    # `step` is incremented at the top of the loop, so it already equals the
    # number of completed steps; each averaged window covers exactly
    # `flatten_const` losses.
    if step % flatten_const == 0:
        cost_data.append(torch.tensor(cost_it).mean().item())
        cost_it = []

    # Print progress; guard in case no averaged point exists yet.
    if step % print_every == 0 and cost_data:
        print(f"Step [{step}], Loss: {cost_data[-1]:.4f}")

    if step % checkpoint_every == 0:
        model.save_checkpoint('Checkpoint.pth.tar', training_state={'cost_data': cost_data})
        print('Saved Checkpoint')

print(step)

torch.cuda.empty_cache()

In [None]:
# Evaluate the model on the test split in fixed-size batches.
model.eval()
test_loss_list = []
test_batch_size = 64
# Number of full test batches. A window starting at index s presumably needs
# tokens s .. s + ctx_size (inputs plus shifted targets), so the usable start
# positions are bounded by the context length, not the vocabulary size.
# TODO confirm against Data.get_test_sample.
num_test_batches = (len(data.test_text) - data.ctx_size - 1) // test_batch_size
start = time.time()
with torch.no_grad():
    for i in range(num_test_batches):
        x, y = data.get_test_sample(i * test_batch_size, test_batch_size)
        x = x.to(device)
        y = y.to(device)
        outputs, loss = model(x, y)
        # Keep a Python float so the loss tensor (and its device memory) is released.
        batch_loss = loss.item()
        test_loss_list.append(batch_loss)
        print(f'Step [{i + 1}/{num_test_batches}], Loss: {batch_loss}')

# Mean over batches (nan if the test split produced no full batch).
val_loss = torch.tensor(test_loss_list).mean().item()
print(f'Validation Loss: {val_loss}, Test Time: {time.time() - start}')

torch.cuda.empty_cache()

In [None]:
# Plot the averaged training-loss curve.
# cost_data[k] is the mean loss over the k-th window of `flatten_const` steps,
# so it is plotted at the step that closes that window, (k + 1) * flatten_const,
# rather than at the window's start (which put the first average at step 0).
fig, ax = plt.subplots()
loss_steps = (torch.arange(len(cost_data)) + 1) * flatten_const
ax.plot(loss_steps, cost_data)
ax.set_xlabel('Number of Steps')
ax.set_ylabel('Average Loss')
ax.set_title('Average Loss vs Step')
plt.show()

In [None]:
# LLM Inference

def sample_with_temp(probs, temperature=1.0, top_k=None):
    """Sample one token index from a 1-D probability vector.

    Args:
        probs: 1-D tensor of non-negative token probabilities.
        temperature: Raising probabilities to ``1 / temperature`` sharpens
            (< 1) or flattens (> 1) the distribution; for softmax outputs this
            equals dividing the logits by ``temperature``. Must be positive.
        top_k: If given, sample only among the ``top_k`` most likely tokens.
            Values at or above the number of tokens keep every token.

    Returns:
        The sampled index as a Python int.

    Raises:
        ValueError: if ``temperature`` is not positive or ``top_k`` < 1.
    """
    if temperature <= 0:
        raise ValueError(f'temperature must be positive, got {temperature}')

    # Apply temperature scaling
    if temperature != 1.0:
        probs = probs ** (1.0 / temperature)
    probs = probs / probs.sum()  # Re-normalize

    if top_k is not None:
        if top_k < 1:
            raise ValueError(f'top_k must be >= 1, got {top_k}')
        # torch.topk raises if k exceeds the number of elements; clamping is
        # safe because keeping every token is the same as no top-k filtering.
        top_k = min(top_k, probs.numel())
        top_probs, top_indices = torch.topk(probs, top_k)
        top_probs = top_probs / top_probs.sum()
        sampled_index = torch.multinomial(top_probs, num_samples=1)
        return top_indices[sampled_index].item()

    return torch.multinomial(probs, num_samples=1).item()

# Begin generation: autoregressively sample NUM_GENERATED_TOKENS tokens,
# conditioned on the last context window of the training data.
import math

model.eval()

NUM_GENERATED_TOKENS = 1000
# Sampling config
TEMPERATURE = 0.8
TOP_K = 40

# Independent list copy so appending new tokens never touches the training data.
tokens = list(data._encoded_train_data[-data.ctx_size:])
# Seed text decoded from the same tokens the model is conditioned on
# (the end of the training data, not its start).
text = ''.join(tokenizer.decode(tokens))
print('"', end='', sep='')

total_log_prob = 0.0
n_tokens = 0

with torch.no_grad():
    for _ in range(NUM_GENERATED_TOKENS):
        # Most recent ctx_size tokens as a batch of one.
        x = torch.tensor(tokens[-data.ctx_size:], device=device).reshape((1, data.ctx_size))

        output = model(x)  # expected shape: (1, T, vocab_size)
        # Next-token logits = prediction at the last position. Reshape against
        # the output's own last dimension instead of the global chunk_size.
        logits = output.reshape(-1, output.shape[-1])[-1]

        # Unmodified model distribution (before temperature / top-k).
        probs = nn.functional.softmax(logits, dim=-1)
        log_probs = nn.functional.log_softmax(logits, dim=-1)

        sampled_token = sample_with_temp(probs, temperature=TEMPERATURE, top_k=TOP_K)
        tokens.append(sampled_token)

        # Accumulate log-probability under the model for the perplexity estimate;
        # log_softmax avoids log(0) if a probability underflows.
        total_log_prob += log_probs[sampled_token].item()
        n_tokens += 1

        # Decode and stream the new character.
        char = tokenizer.decode([sampled_token])[0]
        text += char
        print(char, end='', sep='', flush=True)

print('"\n', sep='')

# Perplexity of the generated text under the model's own distribution.
perplexity = math.exp(-total_log_prob / n_tokens)
print(f"[Perplexity of generated text: {perplexity:.3f}]")


In [None]:
# # Convert old model types to new models...

# # Save model with hyper parameters for new class definition:
# checkpoint = {
#     'model_state': model.state_dict(),
#     'optimizer_state': model.optimizer.state_dict(),
#     'hyperparams': {
#         'vocab_size': len(data.tokenizer.vocab_list),
#         'ctx_window_length': chunk_size,
#         'embedding_dim': embedding_dim,
#         'num_attention_blocks': num_attention_blocks,
#         'num_attention_heads': num_heads,
#         'learning_rate': learning_rate,
#         'dropout_rate': dropout_rate
#     },
#     'training_state':{
#         'loss_data':cost_data,
#         'val_loss':1.1509664058685303,
#         'train_step':78700
#     }
# }
        
# torch.save(checkpoint, 'Transformer.pth.tar')