In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from tqdm import tqdm
import random
import os
import torch.nn.utils.rnn as rnn_utils



In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
def load_all_txt_data(device='cpu', eos_token='\x03', data_dir='.'):  # EOS token (␃)
    """Load every ``.txt`` file in ``data_dir`` as character-level sequences.

    Each file is lower-cased and split into lines. Non-ASCII characters are
    dropped from every line and ``eos_token`` is appended to it. The
    vocabulary is the sorted set of all characters seen, EOS included.

    Args:
        device: Device on which the index tensors are created.
        eos_token: Marker appended to the end of every line.
        data_dir: Directory scanned for ``.txt`` files. Defaults to the
            current working directory.

    Returns:
        A tuple ``(data, char_to_ix, ix_to_char, vocab_size)``:
        ``data`` is a list of 1-D ``torch.long`` tensors, one per line, in
        sorted-file-name order; sequences with a single element (e.g. a blank
        line that only holds a one-character EOS) are skipped.
        ``char_to_ix`` / ``ix_to_char`` map characters to indices and back.
        ``vocab_size`` is the number of distinct characters.
    """
    # Sort so the sequence order does not depend on the OS directory order.
    file_list = sorted(f for f in os.listdir(data_dir) if f.endswith(".txt"))

    # Read and clean every file exactly once, keeping one string per line.
    sequences = []
    for fname in file_list:
        with open(os.path.join(data_dir, fname), 'r', encoding='utf-8') as f:
            for line in f.read().lower().splitlines():
                clean_line = ''.join(ch for ch in line if ch.isascii())
                sequences.append(clean_line + eos_token)

    # Build vocab including EOS token
    vocab = set()
    for seq in sequences:
        vocab.update(seq)
    chars = sorted(vocab)
    data_size, vocab_size = sum(len(seq) for seq in sequences), len(chars)
    print("----------------------------------------")
    print("Data has {} characters, {} unique".format(data_size, vocab_size))
    print("EOS token:", repr(eos_token))
    print("----------------------------------------")

    char_to_ix = {ch: i for i, ch in enumerate(chars)}
    ix_to_char = {i: ch for ch, i in char_to_ix.items()}

    # Encode each line; a sequence of length 1 has no (input, target) pair.
    data = []
    for seq in sequences:
        if len(seq) > 1:
            indices = [char_to_ix[ch] for ch in seq]
            data.append(torch.tensor(indices, dtype=torch.long, device=device))

    return data, char_to_ix, ix_to_char, vocab_size
# Load the corpus; the index tensors are placed on `device`.
data, char_to_ix, ix_to_char, vocab_size = load_all_txt_data(device=device)



----------------------------------------
Data has 6817585 characters, 66 unique
EOS token: '\x03'
----------------------------------------


In [4]:
# Number of sequences returned by load_all_txt_data (one index tensor per kept line).
print(len(data))

62228


In [None]:
class RNN(nn.Module):
    """Character-level language model: embedding -> stacked LSTM -> linear decoder.

    The LSTM is built with ``batch_first=False``, so it reads its input
    with the sequence dimension first.
    """

    def __init__(self, input_size, embedding_size, output_size, hidden_size, num_layers=1):
        """Create the embedding table, the LSTM stack and the output projection.

        Args:
            input_size: Number of distinct input tokens (embedding rows).
            embedding_size: Width of each token embedding.
            output_size: Number of logits produced per time step.
            hidden_size: Width of the LSTM hidden and cell states.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(num_embeddings=input_size, embedding_dim=embedding_size)
        self.rnn = nn.LSTM(
            input_size=embedding_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=False,
        )
        self.decoder = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, input_seq, hidden_state):
        """Return ``(logits, new_hidden_state)`` for a batch of index sequences."""
        lstm_out, next_state = self.rnn(self.embedding(input_seq), hidden_state)
        return self.decoder(lstm_out), next_state

    def init_hidden(self, batch_size, device):
        """Return an all-zero ``(h0, c0)`` pair for ``batch_size`` sequences on ``device``."""
        state_shape = (self.num_layers, batch_size, self.hidden_size)
        h0 = torch.zeros(state_shape, device=device)
        c0 = torch.zeros(state_shape, device=device)
        return (h0, c0)


In [6]:
# Three-layer LSTM; the embedding width is set equal to the vocabulary size.
model = RNN(
    input_size=vocab_size,
    embedding_size=vocab_size,
    output_size=vocab_size,
    hidden_size=512,
    num_layers=3,
)
model = model.to(device)

# Per-token classification loss over the vocabulary.
loss_fn = nn.CrossEntropyLoss()


In [None]:
# Add up the element count of every parameter tensor in the model.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"The model has {total_params} parameters.")

The model has 5428550 parameters.


In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
batch_size = 32
epochs = 10
# Pad targets with the loss function's ignore_index so padded time steps add
# nothing to the loss or its gradient. Padding with 0 would be read as a real
# token id and the model would be trained to predict it on padding.
target_pad_value = loss_fn.ignore_index
model.train()

for i_epoch in range(1, epochs + 1):
    # Shuffle a copy so the global `data` list is not mutated by this cell.
    epoch_data = random.sample(data, len(data))

    n = 0
    running_loss = 0

    for i in tqdm(range(0, len(epoch_data), batch_size), desc=f"Epoch {i_epoch}"):
        batch = epoch_data[i:i + batch_size]

        # Next-character prediction: the target is the input shifted by one.
        input_seqs = [seq[:-1] for seq in batch]
        target_seqs = [seq[1:] for seq in batch]

        # Padded to [max_len, batch]. Padding is appended after the real
        # tokens, so it cannot influence the outputs at real positions.
        input_padded = rnn_utils.pad_sequence(input_seqs, batch_first=False)
        target_padded = rnn_utils.pad_sequence(
            target_seqs, batch_first=False, padding_value=target_pad_value
        )

        input_padded = input_padded.to(device)
        target_padded = target_padded.to(device)

        # Fresh zero state per batch; the last batch may be smaller than batch_size.
        hidden = model.init_hidden(input_padded.size(1), device)

        output, hidden = model(input_padded, hidden)

        # Flatten time and batch so each position is one classification example.
        loss = loss_fn(output.reshape(-1, output.size(-1)), target_padded.reshape(-1))
        running_loss += loss.item()
        n += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # max(n, 1) avoids a ZeroDivisionError when `data` is empty.
    print("Epoch: {0} \t Loss: {1:.8f}".format(i_epoch, running_loss / max(n, 1)))

    # Periodic checkpoint. NOTE: this writes "model.pth", whereas the reload
    # cell further down reads "model_3layer512.pth".
    if i_epoch % 5 == 0:
        torch.save(model.state_dict(), "model.pth")


In [None]:
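# Save model parameters under the file name that the reload cell below expects.
# Uncomment and run once after training to (re)write that checkpoint.
# torch.save(model.state_dict(), "model_3layer512.pth")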


In [8]:
# Rebuild the architecture with the hyper-parameters the checkpoint was
# trained with, then restore the saved weights.
model = RNN(input_size=vocab_size,
        embedding_size=vocab_size,
        output_size=vocab_size,
        hidden_size=512,
        num_layers=3).to(device)
# map_location remaps tensors saved on another device (e.g. a CUDA checkpoint
# opened on a CPU-only machine) onto the current `device`.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load("model_3layer512.pth", map_location=device))
# Last expression of the cell: displays the module summary.
model.to(device)

RNN(
  (embedding): Embedding(66, 66)
  (rnn): LSTM(66, 512, num_layers=3)
  (decoder): Linear(in_features=512, out_features=66, bias=True)
)

In [9]:
def format_named_entities(entities):
    """Render entity names as the ``"(a, b) ||| "`` prefix used to prompt the model.

    A falsy ``entities`` value (``None`` or an empty collection) yields
    ``"() ||| "``.
    """
    joined = ", ".join(entities) if entities else ""
    return "(" + joined + ") ||| "


In [290]:
prompt = ['ai']          # named entities that condition the generation
temperature = 0.5        # logits are divided by this; values < 1 sharpen the distribution
max_new_chars = 1000     # hard cap so sampling stops even if EOS is never drawn
eos_char = '\x03'        # end-of-sequence marker appended to every training line

prompt_text = format_named_entities(prompt).lower()
print(prompt_text)
# Raises KeyError if the prompt contains a character missing from the vocabulary.
prompt_ids = [char_to_ix[ch] for ch in prompt_text]
model.eval()  # Use this for inference mode

with torch.no_grad():
    prompt_tensor = torch.tensor(prompt_ids).to(device).long()
    hidden = model.init_hidden(1, device)  # batch size

    # Prime the model with the whole prompt at once.
    output, hidden = model(prompt_tensor.unsqueeze(1), hidden)  # [seq_len, batch]

    for _ in range(max_new_chars):
        last_logits = output[-1].squeeze()  # Shape: [vocab_size]
        # Sample the next character from the temperature-scaled softmax.
        probs = torch.softmax(last_logits / temperature, dim=0)
        prediction = torch.multinomial(probs, num_samples=1)

        next_char = ix_to_char[int(prediction.item())]
        if next_char == eos_char:
            break
        print(next_char, end="")

        # Feed the sampled character back in, carrying the LSTM state forward.
        inp = prediction.view(1, 1)  # [seq_len=1, batch=1]
        output, hidden = model(inp, hidden)




(ai) ||| 
ai could keep us dependent on natural gas for defying human rights