In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import string
from tqdm import tqdm

# Define the model
class GRUTextGenerator(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(GRUTextGenerator, self).__init__()
        
        self.vocab_size = vocab_size
        self.embed_size = embed_size

        self.embedding_ff = nn.Linear(vocab_size, embed_size, bias=False)
        self.gru1 = nn.GRU(embed_size, hidden_size, 1, batch_first=True)
        self.gru2 = nn.GRU(hidden_size, hidden_size, 1, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        # x: (batch_size, seq_length)
        x = torch.nn.functional.one_hot(x, num_classes=self.vocab_size).float()
        embedded = self.embedding_ff(x)
        x, hidden = self.gru1(embedded, None)  # (batch_size, seq_length, hidden_size)
        x, hidden = self.gru2(x, None)
        logits = self.fc(x)  # (batch_size, seq_length, vocab_size)
        return logits, hidden

token_dict = {letter:i for letter,i in zip(string.ascii_lowercase, range(1,27))}
token_dict[' '] = 0

class TextDataset(Dataset):
    def __init__(self, sentences, idx):
        self.data = []

        self.max_len = max([len(x) for x in sentences])

        for sentence in sentences:
            if len(sentence) > idx:
                x = torch.tensor([token_dict[x] for x in sentence[:idx]], dtype=torch.int64, device=device)
                y = torch.tensor(token_dict[sentence[idx]], dtype=torch.int64, device=device)
                self.data.append((x, y))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

# Training loop
def train(model, dataloaders, vocab_size, device, epochs=10, lr=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    model.to(device)
    model.train()
    
    for epoch in range(epochs):
        for data_loader in dataloaders:
            total_loss = 0
            for _, (inputs, targets) in enumerate(data_loader, 0):
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                
                logits, _ = model(inputs)  # (batch_size, seq_length, vocab_size)
                logits = logits[:, -1, :]
                logits = logits.view(-1, vocab_size)  # Reshape for loss calculation
                targets = targets.view(-1)  # Reshape to match logits
                
                loss = criterion(logits, targets)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(data_loader)}")

        break

# Example hyperparameters
vocab_size = 27  # a-z and a space character
embed_size = 128
hidden_size = 256
batch_size = 64
seq_length = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the model
model = GRUTextGenerator(vocab_size, embed_size, hidden_size)



In [4]:
print(model.embedding_ff.weight[:, 1])

tensor([-1.7557e-01,  6.7922e-02,  9.8388e-02, -1.2501e-01, -6.0013e-02,
        -1.6312e-01, -6.9352e-02,  1.0065e-01, -3.9197e-02,  3.2594e-02,
        -1.5702e-01,  6.6227e-02,  1.4566e-01, -1.1029e-01,  5.9280e-03,
         6.4825e-02, -1.6047e-01, -1.6856e-01, -4.2442e-02, -7.7414e-02,
        -3.4500e-02, -1.2600e-01, -1.1436e-01, -4.8913e-02,  5.6049e-03,
         1.1267e-01,  8.6909e-02, -1.3511e-01, -1.4664e-01,  1.4858e-01,
         1.5899e-01, -1.7651e-01, -1.0319e-02, -1.0715e-01, -7.3224e-02,
        -1.5392e-01,  1.8479e-01,  7.3894e-02,  1.8040e-02,  1.4655e-01,
         8.3047e-02,  1.7492e-01, -7.9131e-02, -1.2953e-01,  1.4146e-01,
         1.1696e-04,  6.7882e-02,  1.7049e-01, -5.6124e-02, -4.6413e-02,
        -1.0923e-01, -1.8666e-02,  9.0332e-02, -7.3868e-04, -3.0353e-02,
        -3.5101e-02, -6.4153e-02, -1.4838e-02,  6.8584e-02, -6.7836e-02,
         1.8344e-01, -1.2265e-01,  9.8370e-02, -1.5883e-01, -5.2605e-02,
        -1.8863e-01,  1.7119e-02, -5.5684e-03, -1.7

In [5]:
# Placeholder: Replace with your own DataLoader
sentences = []
with open("simple_sentences.txt", 'r') as f:
    sentences = f.readlines()
sentences = [x.strip() for x in sentences]

dataloaders = []
for i in range(1, max([len(x) for x in sentences])):
    dataset = TextDataset(sentences, i)
    dataloaders.append(DataLoader(dataset, batch_size))

# Example training call
train(model, dataloaders, vocab_size, device)

Epoch 1/10, Loss: 0.000115703387564281


In [16]:
def vectorise(sentence):
    return torch.tensor([token_dict[x] for x in sentence], dtype=torch.int64, device=device)

def run_inference(model, sentence):
    input = vectorise(sentence)

    output = model(input)

    logits = output[0][-1,:]

    print(logits)

    max_idx = logits.argmax().item()
    for a, i in token_dict.items():
        if max_idx == i:
            sentence += a
            return sentence
    
    print(max_idx)



sentence = 'a dog run'
# sentence = 'a'


In [17]:
for i in range(12):
    sentence = run_inference(model, sentence)
    print(sentence)
    break

tensor([ 2.5538, -7.0962, -2.3684, -2.7878,  4.2614,  0.1217, -3.6715, -3.4830,
        -2.9746, -4.4778,  1.7500, -2.6130, -4.1250, -3.8780, -5.2185, -6.5963,
         6.9670, -3.8502,  4.0441,  6.9086, -3.4119, -7.4640, -5.0035,  1.3210,
        -3.5759, -1.6996, -4.0673], grad_fn=<SliceBackward0>)
a dog runp


In [18]:
import numpy as np
import json
from json import JSONEncoder


In [19]:
class EncodeTensor(JSONEncoder,Dataset):
    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            return obj.cpu().detach().numpy().tolist()
        return super(json.NpEncoder, self).default(obj)


In [None]:
x = vectorise(sentence)
y = model(x)[0]

np.savetxt('test_data/generative_gru_x.csv', x.detach(), delimiter=',')
np.savetxt('test_data/generative_gru_y.csv', y.detach(), delimiter=',')

with open('models/generative_gru.json', 'w') as json_file:
    json.dump(model.state_dict(), json_file,cls=EncodeTensor)
