In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
from generateTrees import generate_random_tree, serialize, deserialize
from torch.utils.data import Dataset, DataLoader
import os

# Define the Transformer model

In [None]:

class DecoderOnlyTransformer(nn.Module):
    """Causal (decoder-only) Transformer language model over token ids.

    Token ids are embedded, run through a stack of ``nn.TransformerDecoderLayer``
    blocks under a causal self-attention mask, and projected to per-position
    vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids (embedding rows / logit columns).
        d_model: Embedding and hidden width.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked decoder layers.

    NOTE(review): no positional encoding is added to the embeddings; the only
    source of order information is the causal mask.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers):
        super(DecoderOnlyTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # batch_first=True: callers pass (batch, seq) ids and index the logits
        # as logits[:, -1, :]. With the default (seq, batch) layout attention
        # would be computed across the batch axis instead of along the sequence.
        self.decoder_layers = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layers, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Compute next-token logits for every position.

        Args:
            x: Integer tensor of token ids with shape (batch, seq).

        Returns:
            Float tensor of logits with shape (batch, seq, vocab_size).
        """
        seq_len = x.size(1)
        x = self.embedding(x)
        # Boolean mask, True = "may not attend": position i only sees positions <= i.
        # Without it the model can read the very tokens it is trained to predict.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        # There is no encoder; the cross-attention blocks are fed an all-zero memory.
        memory = torch.zeros_like(x)
        output = self.transformer_decoder(x, memory, tgt_mask=causal_mask)
        return self.fc(output)

# Function for sequence generation

In [None]:
def generate_sequence(model, start_token, stop_token, max_length=10):
    """Greedily decode a token sequence from an autoregressive model.

    At every step the whole sequence generated so far is fed to the model, and
    the arg-max of the logits at the last position becomes the next token.

    Args:
        model: Module mapping a (1, seq) LongTensor of token ids to logits of
            shape (1, seq, vocab_size).
        start_token: Integer id used as the first element of the sequence.
        stop_token: Integer id that terminates generation once produced.
        max_length: Maximum number of tokens to generate after ``start_token``.

    Returns:
        List of integer token ids beginning with ``start_token``; it ends with
        ``stop_token`` if that id was produced within ``max_length`` steps.

    Note:
        The model is switched to evaluation mode and left in that mode.
    """
    model.eval()  # Set the model to evaluation mode (disables dropout)

    # Build inputs on the same device as the model's weights; a model without
    # parameters falls back to the default device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    generated_sequence = [start_token]
    with torch.no_grad():
        # Generate until the stop token is encountered or max length is reached
        for _ in range(max_length):
            # Feed the full prefix (with a batch dimension), not just the last
            # token, so the prediction is conditioned on everything generated.
            context = torch.tensor([generated_sequence], dtype=torch.long, device=device)
            logits = model(context)

            # Greedy choice: most likely token at the last position
            next_token = torch.argmax(logits[0, -1, :]).item()
            generated_sequence.append(next_token)

            if next_token == stop_token:
                break

    return generated_sequence

# Data Loader

In [None]:
def read_tree(filename, dir):
    """Return the full text content of the file ``dir``/``filename``.

    Args:
        filename: Name of the file inside ``dir``.
        dir: Directory path, joined to ``filename`` with a forward slash.

    Returns:
        The file content as a single string (the file is opened in text mode).
    """
    tree_path = "/".join((dir, filename))
    with open(tree_path, "r") as handle:
        return handle.read()

In [None]:
def my_collate(batch):
    """Identity collate function for ``DataLoader``.

    Returns the batch exactly as it was received, so the samples are neither
    stacked nor converted to tensors.
    """
    return batch


class tDataset(Dataset):
    """In-memory dataset of tree strings read from text files.

    Every file named in ``l`` is read from ``dir`` with ``read_tree`` when the
    dataset is constructed; items are served from memory afterwards.

    Args:
        l: File names to load from ``dir``.
        dir: Directory that contains the files.
        transform: Optional callable applied to a tree string each time it is
            accessed. ``None`` (the default) returns the raw string.
    """

    def __init__(self, l, dir, transform=None):
        self.names = l
        self.transform = transform
        # List with the strings of all the trees, in the same order as the names.
        self.data = [read_tree(file, dir) for file in self.names]

    def __len__(self):
        # Count what was actually loaded rather than the list of names.
        return len(self.data)

    def __getitem__(self, idx):
        tree = self.data[idx]
        # The transform used to be stored but silently ignored.
        if self.transform is not None:
            tree = self.transform(tree)
        return tree

# Number of dataset items per batch; passed to the DataLoader in the training cell.
batch_size = 1

# Training

In [None]:
# Seed weight initialisation, shuffling and dropout so that re-runs are comparable.
SEED = 0
torch.manual_seed(SEED)

# ---- Data ----
folder = "trees"
# Keep only regular files (os.listdir also returns sub-directories such as
# .ipynb_checkpoints) and sort them so the file order is stable.
t_list = sorted(
    name for name in os.listdir(folder)
    if os.path.isfile(os.path.join(folder, name))
)
if not t_list:
    raise FileNotFoundError(f"No tree files found in '{folder}'")
dataset = tDataset(t_list, folder)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=my_collate)

# ---- Vocabulary ----
# The dataset yields raw strings, which cannot be sliced or embedded directly.
# NOTE(review): character-level tokenization is assumed here; if the serialized
# tree format uses multi-character tokens, replace this with the proper tokenizer.
PAD_ID, BOS_ID, EOS_ID = 0, 1, 2
NUM_SPECIAL = 3
chars = sorted({ch for i in range(len(dataset)) for ch in dataset[i]})
stoi = {ch: idx + NUM_SPECIAL for idx, ch in enumerate(chars)}
itos = {idx: ch for ch, idx in stoi.items()}
vocab_size = len(stoi) + NUM_SPECIAL


def encode_batch(trees):
    """Convert a list of tree strings into a padded (batch, seq) LongTensor.

    Each string becomes ``[BOS] + character ids + [EOS]``; shorter rows are
    right-padded with ``PAD_ID`` so the batch is rectangular.
    """
    rows = [[BOS_ID] + [stoi[ch] for ch in tree] + [EOS_ID] for tree in trees]
    width = max(len(row) for row in rows)
    ids = torch.full((len(rows), width), PAD_ID, dtype=torch.long)
    for r, row in enumerate(rows):
        ids[r, :len(row)] = torch.tensor(row, dtype=torch.long)
    return ids


# ---- Model, loss function and optimizer ----
model = DecoderOnlyTransformer(vocab_size=vocab_size, d_model=512, nhead=8, num_layers=4)
# Padding positions carry no information, so they are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ---- Training loop with batches ----
num_epochs = 100
losses = []

model.train()
for epoch in range(num_epochs):
    loss_batch = []
    for batch in data_loader:
        token_ids = encode_batch(batch)
        # Next-token prediction: the target is the input shifted left by one.
        input_sequence = token_ids[:, :-1]
        target_sequence = token_ids[:, 1:]

        optimizer.zero_grad()
        outputs = model(input_sequence)
        loss = criterion(outputs.reshape(-1, vocab_size), target_sequence.reshape(-1))
        loss.backward()
        optimizer.step()
        loss_batch.append(loss.item())

    epoch_loss = np.average(loss_batch)
    losses.append(epoch_loss)
    if epoch % 5 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}')

# ---- Plotting the loss curve ----
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('Epochs')
ax.set_ylabel('Cross-Entropy Loss')
ax.set_title('Training Loss Curve')
plt.show()

# ---- Generation ----
# Start from the begin-of-sequence id and stop at the end-of-sequence id the
# model was trained on. A random stop id has no meaning to the model.
# Kept as 1-element tensors so existing `.item()` call sites keep working.
start_token = torch.tensor([BOS_ID])
stop_token = torch.tensor([EOS_ID])

# Generate a sequence using greedy autoregressive decoding
generated_sequence = generate_sequence(model, start_token.item(), stop_token.item(), max_length=10)

# Print the generated sequence
print("Generated Sequence:")
print(generated_sequence)
# Special ids (PAD/BOS/EOS) have no character and are dropped from the text.
print("Decoded:", "".join(itos.get(tok, "") for tok in generated_sequence))