In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer

class SimpleTransformer(nn.Module):
    """Encoder-decoder ``nn.Transformer`` wrapped as a next-token language model.

    Inputs are batch-first token ids of shape ``(batch, seq)``; the returned logits
    are sequence-first, ``(tgt_seq, batch, vocab_size)``, because the wrapped
    ``nn.Transformer`` keeps its default ``batch_first=False``.

    ``nn.Transformer`` adds no positional information and applies no masks on its
    own. This wrapper therefore adds fixed sinusoidal position encodings to every
    embedding and, by default, causal masks on all three attentions. Without them
    the decoder can attend to the very token it is asked to predict (directly, or
    through the encoder memory when ``src`` is the same sequence), and training
    degenerates into copying the input.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        d_model: embedding / hidden width.
        n_heads: attention heads per layer; must divide ``d_model``.
        n_layers: number of encoder layers and, separately, of decoder layers.
        dropout: dropout probability inside the Transformer layers.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=n_heads, num_encoder_layers=n_layers,
            num_decoder_layers=n_layers, dropout=dropout
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    @staticmethod
    def _sinusoidal_positions(seq_len, d_model, device, dtype):
        """Fixed sine/cosine position encodings of shape ``(seq_len, 1, d_model)``."""
        position = torch.arange(seq_len, device=device, dtype=dtype).unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2, device=device, dtype=dtype)
        angle = position * (10000.0 ** (-even_dims / d_model))  # (seq_len, ceil(d_model / 2))
        encoding = torch.zeros(seq_len, d_model, device=device, dtype=dtype)
        encoding[:, 0::2] = torch.sin(angle)
        # An odd d_model has one fewer odd-indexed column than even-indexed ones.
        encoding[:, 1::2] = torch.cos(angle)[:, : d_model // 2]
        return encoding.unsqueeze(1)  # the size-1 middle axis broadcasts over the batch

    @staticmethod
    def _causal_mask(n_query, n_key, device):
        """Boolean attention mask; ``True`` blocks query ``i`` from seeing key ``j`` when ``j > i``."""
        return torch.triu(
            torch.ones(n_query, n_key, dtype=torch.bool, device=device), diagonal=1
        )

    def _embed(self, token_ids):
        """Map ``(batch, seq)`` ids to ``(seq, batch, d_model)`` embeddings plus positions."""
        emb = self.embedding(token_ids).permute(1, 0, 2)
        return emb + self._sinusoidal_positions(
            emb.size(0), emb.size(2), emb.device, emb.dtype
        )

    def forward(self, src, tgt, causal=True):
        """Return logits of shape ``(tgt_seq, batch, vocab_size)``.

        Args:
            src: encoder token ids, ``(batch, src_seq)``.
            tgt: decoder token ids, ``(batch, tgt_seq)``.
            causal: when True (default), encoder self-attention, decoder
                self-attention and decoder-to-encoder attention are all restricted
                to the current and earlier positions. ``src`` and ``tgt`` may then be
                the same sequence without leaking future tokens. Pass False for the
                plain unmasked ``nn.Transformer`` (positional encodings still added).
        """
        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)
        src_mask = tgt_mask = memory_mask = None
        if causal:
            src_len, tgt_len = src.size(1), tgt.size(1)
            device = src_emb.device
            src_mask = self._causal_mask(src_len, src_len, device)
            tgt_mask = self._causal_mask(tgt_len, tgt_len, device)
            memory_mask = self._causal_mask(tgt_len, src_len, device)
        output = self.transformer(
            src_emb, tgt_emb,
            src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask,
        )
        return self.fc_out(output)

# Prepare the tokenizer. GPT-2 has no padding token, so EOS is reused for padding.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Hyperparameters
SEED = 42       # seeds torch's global RNG (weight init, dropout, default DataLoader shuffling)
vocab_size = tokenizer.vocab_size
d_model = 512   # embedding width
n_heads = 8     # attention heads per layer
n_layers = 6    # number of encoder layers and of decoder layers

# Build the model; seeding first makes the weight initialisation reproducible.
torch.manual_seed(SEED)
model = SimpleTransformer(vocab_size, d_model, n_heads, n_layers)



In [3]:
# Texts are padded to a fixed length with the pad id (== EOS id for GPT-2 here).
# Those positions carry no signal, so exclude them from the loss. Otherwise short
# paragraphs train the model mostly to predict padding. Note this also ignores any
# genuine EOS targets, since the two ids coincide.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

In [4]:
def load_text_file(file_path):
    """Read a UTF-8 text file and return its whole contents as one string."""
    with open(file_path, encoding="utf-8") as handle:
        return handle.read()

# Load the book and split it into blank-line-separated paragraphs, dropping empty ones
file_path = "dom_casmurro.txt"
full_text = load_text_file(file_path)
texts = []
for paragraph in full_text.split("\n\n"):
    cleaned = paragraph.strip()
    if cleaned:
        texts.append(cleaned)

print(f"Total de trechos: {len(texts)}")
print(f"Exemplo de trecho: {texts[0][:200]}")  # preview the first paragraph (up to 200 chars)

Total de trechos: 1648
Exemplo de trecho: Dom Casmurro


In [None]:
from torch.utils.data import DataLoader, Dataset
import torch
import time

class TextDataset(Dataset):
    """Map-style dataset holding one fixed-length token-id tensor per input text.

    Each text is tokenized once, up front, and truncated/padded to ``max_length``
    tokens. An item is the pair ``(input_ids, input_ids)``; both elements are the
    same tensor object.

    Args:
        texts: iterable of strings to tokenize.
        tokenizer: callable tokenizer returning a mapping with ``"input_ids"``.
        max_length: sequence length every item is padded/truncated to.
    """

    def __init__(self, texts, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        encoded = []
        for text in texts:
            batch_encoding = tokenizer(
                text,
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            # The tokenizer returns a batch of one; drop that leading axis.
            encoded.append(batch_encoding["input_ids"].squeeze(0))
        self.tokens = encoded

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        sample = self.tokens[idx]
        return sample, sample

# Build the training dataset from the cleaned paragraphs
dataset = TextDataset(texts=texts, tokenizer=tokenizer)

# Collate function: turn a list of (input, target) samples into two batch tensors
def collate_fn(batch):
    """Stack ``[(input_ids, target_ids), ...]`` into ``(inputs, targets)`` batch tensors.

    ``torch.stack`` needs every sample to have the same shape. That holds here because
    ``TextDataset`` pads/truncates every text to ``max_length``. Variable-length samples
    are NOT handled.
    """
    transposed = list(zip(*batch))
    input_seqs, target_seqs = transposed
    stacked_inputs = torch.stack(input_seqs, dim=0)
    stacked_targets = torch.stack(target_seqs, dim=0)
    return stacked_inputs, stacked_targets

dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn, num_workers=0)

# Training loop: next-token prediction. The model reads tokens[:, :-1] and is scored
# against tokens[:, 1:], so position t is trained to predict token t + 1.
num_epochs = 5

for epoch in range(num_epochs):
    # generate_text() switches the model to eval mode; re-enable dropout for training.
    model.train()
    start_time = time.time()
    epoch_loss, n_batches = 0.0, 0
    for i, (src, tgt) in enumerate(dataloader):
        inputs = src[:, :-1]
        labels = tgt[:, 1:]
        optimizer.zero_grad()
        output = model(inputs, inputs)   # (seq, batch, vocab): SimpleTransformer is sequence-first
        logits = output.transpose(0, 1)  # -> (batch, seq, vocab), aligned with batch-first labels
        loss = criterion(logits.reshape(-1, vocab_size), labels.reshape(-1))
        if not torch.isfinite(loss):
            # e.g. every label in the batch is ignored by the criterion: a mean over zero items is NaN.
            continue
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
        if i % 10 == 0:  # progress report every 10 batches
            print(f"Batch {i}, Time: {time.time() - start_time:.2f}s")
    mean_loss = epoch_loss / max(n_batches, 1)  # average over the epoch, not just the last batch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss:.4f}, Time: {time.time() - start_time:.2f}s")


In [None]:
def generate_text(prompt, model, tokenizer, max_length=50):
    """Greedily extend ``prompt`` by up to ``max_length`` new tokens and decode the result.

    ``model(src, tgt)`` is expected to take batch-first token ids ``(batch, seq)`` and
    return sequence-first logits ``(seq, batch, vocab)``, as ``SimpleTransformer`` does.
    Generation stops early once the tokenizer's EOS token is produced. The model's
    train/eval mode is restored before returning.

    Args:
        prompt: text to continue; must tokenize to at least one token.
        model: the language model.
        tokenizer: Hugging Face-style tokenizer (``__call__``, ``decode``, ``eos_token_id``).
        max_length: maximum number of NEW tokens to generate (prompt not counted).

    Returns:
        The decoded prompt followed by the generated continuation, special tokens removed.

    Raises:
        ValueError: if ``prompt`` produces no tokens.
    """
    was_training = model.training
    model.eval()
    try:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        tokens = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
        if tokens.size(1) == 0:
            raise ValueError("prompt must encode to at least one token")
        eos_id = getattr(tokenizer, "eos_token_id", None)
        with torch.no_grad():
            for _ in range(max_length):
                logits = model(tokens, tokens)  # (seq, batch, vocab)
                # Only the last position's prediction is the next token: shape (batch, 1).
                next_token = logits[-1].argmax(dim=-1, keepdim=True)
                tokens = torch.cat([tokens, next_token], dim=1)
                if eos_id is not None and next_token.item() == eos_id:
                    break
    finally:
        model.train(was_training)

    return tokenizer.decode(tokens[0].tolist(), skip_special_tokens=True)

# Generate a continuation for a sample Portuguese prompt and print it
sample_prompt = "Olá, como você está?"
generated_text = generate_text(sample_prompt, model, tokenizer)
print(generated_text)
