In [6]:
import pandas as pd
import json
import os

# Number of slices each loaded file is split into for the progress preview.
N_CHUNKS = 8
datadir = "/home/m/dev/ai/llm/wiki_en"

# Slices are collected here and concatenated once after the loop. Calling
# pd.concat inside the loop re-copies the accumulated frame on every iteration.
chunks = []

# os.listdir returns entries in arbitrary order; sorting makes the choice of
# the single processed file reproducible from run to run.
for filename in sorted(os.listdir(datadir)):
    if not filename.endswith(".json"):
        continue

    # The text column contains non-ASCII characters, so decode explicitly as
    # UTF-8 instead of relying on the platform's default encoding.
    with open(os.path.join(datadir, filename), "r", encoding="utf-8") as f:
        data = json.load(f)
    df = pd.DataFrame(data)

    # Split the data into N_CHUNKS parts; the last part absorbs the remainder
    # left over by the integer division.
    chunk_size = len(df) // N_CHUNKS
    for i in range(N_CHUNKS):
        start = i * chunk_size
        end = (i + 1) * chunk_size if i < N_CHUNKS - 1 else len(df)
        chunk_df = df.iloc[start:end]

        chunks.append(chunk_df)
        print(f"Loaded {filename} (chunk {i+1} of {N_CHUNKS}):")
        print(chunk_df.head())

    break  # This break is kept to maintain the original behavior of processing only one file

# pd.concat raises on an empty list, so fall back to an empty frame when no
# .json file was found in datadir.
fulldata = pd.concat(chunks, ignore_index=True) if chunks else pd.DataFrame()

fulldata

Loaded 54814a89-cfc6-4429-a44b-ef9a1f256971.json (chunk 1 of 8):
         id                                               text  \
0  30060320  Cédric Gerbehaye (born 1977) is a Belgian jour...   
1  30060327  The West Virginia Capitol Complex is a histori...   
2  30060339  It's Real may refer to: * It's Real (K-Ci & Jo...   
3  30060356  The 2011 Blancpain Endurance Series season was...   
4  30060369  Terra Venture Partners is an Israeli venture c...   

                             title  
0                 Cédric Gerbehaye  
1    West Virginia Capitol Complex  
2                        It's Real  
3  2011 Blancpain Endurance Series  
4           Terra Venture Partners  
Loaded 54814a89-cfc6-4429-a44b-ef9a1f256971.json (chunk 2 of 8):
           id                                               text  \
593  30072986  Jhonatan Longhi (born February 2, 1988) is an ...   
594  30073006  Marko Rudić (born January 17, 1990) is an alpi...   
595  30073009  Shady is an unincorporated commu

Unnamed: 0,id,text,title
0,30060320,Cédric Gerbehaye (born 1977) is a Belgian jour...,Cédric Gerbehaye
1,30060327,The West Virginia Capitol Complex is a histori...,West Virginia Capitol Complex
2,30060339,It's Real may refer to: * It's Real (K-Ci & Jo...,It's Real
3,30060356,The 2011 Blancpain Endurance Series season was...,2011 Blancpain Endurance Series
4,30060369,Terra Venture Partners is an Israeli venture c...,Terra Venture Partners
...,...,...,...
4742,30121786,"Devan Deangelo Downey (born September 28, 1987...",Devan Downey
4743,30121798,"""The Hand That Rocks the Wheelchair"" is the 12...",The Hand That Rocks the Wheelchair
4744,30121817,The 1995 Supercopa Libertadores was the eighth...,1995 Supercopa Libertadores
4745,30121823,"Carl Davis (born November 16, 1973) is an Amer...",Carl Davis (boxer)


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import math

class TransformerModel(nn.Module):
    """Encoder-decoder Transformer over token ids with a shared embedding table.

    Args:
        vocab_size: Number of rows of the embedding table and width of the
            output projection.
        d_model: Embedding / hidden size used by every layer.
        nhead: Number of attention heads per layer.
        num_encoder_layers: Number of stacked encoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
        dropout: Dropout probability for the layers and positional encoding.

    Layout contract of ``forward`` (the one the notebook's training and
    generation code already uses):
        * ``src`` / ``tgt``: integer token ids shaped ``(batch, seq_len)``.
        * ``src_mask`` / ``tgt_mask``: boolean key-padding masks shaped
          ``(seq_len, batch)`` with ``True`` on padded positions, or ``None``.
        * return value: logits shaped ``(batch, seq_len, vocab_size)``.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super(TransformerModel, self).__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # The layers keep PyTorch's default sequence-first layout
        # (seq_len, batch, d_model); forward() converts to and from it.
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)

        decoder_layers = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_decoder_layers)

        self.output_layer = nn.Linear(d_model, vocab_size)

        self.d_model = d_model

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Return next-token logits for ``tgt`` conditioned on ``src``.

        See the class docstring for the expected tensor layouts.
        """
        # Embeddings come out as (batch, seq_len, d_model). The positional
        # encoding and the Transformer layers index the sequence on dim 0, so
        # swap to (seq_len, batch, d_model) first. Without the swap, attention
        # and positions run across the batch axis instead of the sequence.
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src.transpose(0, 1))

        tgt = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt = self.pos_encoder(tgt.transpose(0, 1))

        # Callers hand over (seq_len, batch) masks, while the attention layers
        # expect key-padding masks as (batch, seq_len).
        src_padding = None if src_mask is None else src_mask.transpose(0, 1)
        tgt_padding = None if tgt_mask is None else tgt_mask.transpose(0, 1)

        # Causal mask: True above the diagonal blocks attention to later
        # target positions, so position t cannot read the token it must predict.
        tgt_len = tgt.size(0)
        causal_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )

        memory = self.transformer_encoder(src, src_key_padding_mask=src_padding)

        # NOTE(review): when the same sequence is passed as both src and tgt
        # (as the notebook's training loop does), the encoder memory still
        # covers every position, so later tokens remain reachable through
        # cross-attention. A decoder-only setup would close that path.
        output = self.transformer_decoder(
            tgt,
            memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_padding,
            memory_key_padding_mask=src_padding,
        )

        # Back to batch-first before the projection. Making the small hidden
        # tensor contiguous here keeps the logits contiguous, which callers
        # that flatten them with .view() rely on.
        output = output.transpose(0, 1).contiguous()
        return self.output_layer(output)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a sequence, then apply dropout.

    Args:
        d_model: Size of the last dimension of the tensors to encode. Both
            even and odd sizes are supported.
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions precomputed in the table.

    The table is stored in the non-trainable buffer ``pe`` with shape
    ``(max_len, 1, d_model)``. ``forward`` slices it by ``x.size(0)``, so
    positions are taken along dim 0 of ``x`` and broadcast over dim 1.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # One angle column per even feature index: (max_len, ceil(d_model / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd-indexed column than
        # even-indexed ones, so trim the angle table to fit the cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # (max_len, d_model) -> (max_len, 1, d_model) so the table broadcasts
        # over dim 1 of the input.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class WikiDataset(Dataset):
    """Dataset that tokenizes the ``'text'`` column of a DataFrame on access.

    Args:
        dataframe: Frame indexed positionally; each row must have a ``'text'`` entry.
        tokenizer: Object exposing ``encode_plus`` with the keyword arguments
            used in ``__getitem__``.
        max_length: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        self.data, self.tokenizer, self.max_length = dataframe, tokenizer, max_length

    def __len__(self):
        """Number of rows in the wrapped frame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return flattened ids and attention mask."""
        row = self.data.iloc[idx]

        # Ask for a fixed-size, truncated encoding returned as tensors.
        tokenizer_options = dict(
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        encoded = self.tokenizer.encode_plus(row['text'], **tokenizer_options)

        # Flatten each returned tensor to one dimension.
        return {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}

# Device used by the model and by every batch moved during training.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Tokenizer and the vocabulary size the model is built for.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab_size = tokenizer.vocab_size

# Model, loss and optimizer. The loss ignores targets equal to the
# tokenizer's pad token id.
model = TransformerModel(vocab_size)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
model.to(device)

# Shuffled batches of 32 tokenized rows from the loaded frame.
dataset = WikiDataset(fulldata, tokenizer)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Training budget: epochs to run and batches consumed per epoch.
num_epochs = 1
max_batches = 10


In [None]:


model.train()
for epoch in range(num_epochs):
    # Reset the running statistics every epoch so the reported average only
    # covers the batches of this epoch.
    total_loss = 0.0
    batches_seen = 0

    for batch_idx, batch in enumerate(dataloader):
        # Cap the number of batches consumed per epoch.
        if batch_idx >= max_batches:
            break

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        optimizer.zero_grad()

        # Next-token objective: the token at position t + 1 is the target for
        # the input at position t, so drop the first target and the last input.
        target = input_ids[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()

        # Invert the attention mask (presumably 1 = real token, 0 = padding,
        # per the tokenizer convention) and transpose it to (seq_len, batch),
        # the layout the model is called with here. Source and target are the
        # same sequence, so one mask serves both. It is already on `device`
        # because it is derived from `attention_mask`.
        padding_mask = (~attention_mask[:, :-1].bool()).transpose(0, 1)
        src_key_padding_mask = padding_mask
        tgt_key_padding_mask = padding_mask

        assert src_key_padding_mask.shape == (input_ids.size(1), input_ids.size(0)), f"src_key_padding_mask shape: {src_key_padding_mask.shape}, expected: {(input_ids.size(1), input_ids.size(0))}"
        assert tgt_key_padding_mask.shape == (input_ids.size(1), input_ids.size(0)), f"tgt_key_padding_mask shape: {tgt_key_padding_mask.shape}, expected: {(input_ids.size(1), input_ids.size(0))}"

        outputs = model(input_ids, input_ids, src_mask=src_key_padding_mask, tgt_mask=tgt_key_padding_mask)
        # reshape (rather than view) also works if the logits are not contiguous.
        loss = criterion(outputs.reshape(-1, vocab_size), target.reshape(-1))

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        batches_seen += 1

        print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{max_batches}, Loss: {loss.item():.4f}")

    # Average over the batches that actually ran: the loader may yield fewer
    # than max_batches, and may yield none at all.
    avg_loss = total_loss / max(batches_seen, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

print("Training completed.")


torch.save(model.state_dict(), 'wiki_transformer_model.pth')
print("Model saved as 'wiki_transformer_model.pth'")


In [12]:
# Load the model
print("Step 1: Loading the model")
model = TransformerModel(vocab_size)
# map_location lets a checkpoint saved on another device load here, and
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs.
state_dict = torch.load('wiki_transformer_model.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# generate_text moves its inputs to `device`, so the freshly built model has
# to live there too; otherwise the forward pass mixes devices.
model.to(device)
model.eval()
print("Model loaded successfully")

# Initialize tokenizer
print("\nStep 2: Initializing tokenizer")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print("Tokenizer initialized")

# Function to generate text
def generate_text(prompt, max_length=10):
    """Greedily extend ``prompt`` with up to ``max_length`` generated tokens.

    Uses the module-level ``model``, ``tokenizer`` and ``device``.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of tokens to append to the prompt.

    Returns:
        The decoded prompt plus generated tokens, with special tokens removed.
    """
    print(f"\nStep 3: Generating text for prompt: '{prompt}'")
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # The tokenizer closes the encoded prompt with its separator token. That
    # token is also the stop signal below, so leaving it in would ask the model
    # to continue *after* the end of the sequence. Drop it before generating.
    if input_ids.size(1) > 1 and input_ids[0, -1].item() == tokenizer.sep_token_id:
        input_ids = input_ids[:, :-1]

    with torch.no_grad():
        for i in range(max_length):
            print(f"  Generating token {i+1}/{max_length}")

            # Nothing is padded during generation, so the key-padding mask is
            # all False. It is laid out as (seq_len, batch), matching how the
            # model is called elsewhere in this notebook.
            padding_mask = torch.zeros(
                input_ids.size(1), input_ids.size(0),
                dtype=torch.bool, device=input_ids.device,
            )

            output = model(input_ids, input_ids, src_mask=padding_mask, tgt_mask=padding_mask)

            # Greedy decoding: take the highest-scoring token at the last
            # position of the single sequence in the batch.
            next_token = output[0, -1, :].argmax(dim=-1).view(1, 1)
            input_ids = torch.cat([input_ids, next_token], dim=1)

            if next_token.item() == tokenizer.sep_token_id:
                print("  Reached end of sequence token")
                break

    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    print("Text generation completed")
    return generated_text

# Run a prompt: announce the step, echo the prompt, then generate and show
# the continuation.
prompt = "The"
print("\nStep 4: Running a prompt")
print(f"Prompt: {prompt}")

generated_text = generate_text(prompt)
print(f"Generated text: {generated_text}")

# Final marker so the end of the run is visible in the cell output.
print("\nStep 5: Process completed")


Step 1: Loading the model


  model.load_state_dict(torch.load('wiki_transformer_model.pth'))


Model loaded successfully

Step 2: Initializing tokenizer
Tokenizer initialized

Step 4: Running a prompt
Prompt: The

Step 3: Generating text for prompt: 'The'
  Generating token 1/10
  Generating token 2/10
  Generating token 3/10
  Generating token 4/10
  Generating token 5/10
  Generating token 6/10
  Generating token 7/10
  Generating token 8/10
  Generating token 9/10
  Generating token 10/10
Text generation completed
Generated text: the..........

Step 5: Process completed
