In [99]:
# import torch_xla.core.xla_model as xm

# # List all available TPU devices
# devices = xm.get_xla_supported_devices()
# print(f'Available TPU devices: {devices}')


## Modelo

In [117]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, Dataset

class LSTMTextGenerator(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout=0.3, pretrained_embeddings=None, device=None):
        super(LSTMTextGenerator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        
        # Initialize the embedding layer with pre-trained embeddings if provided
        if pretrained_embeddings is not None:
            self.embedding.weight = nn.Parameter(torch.tensor(pretrained_embeddings))
            self.embedding.weight.requires_grad = False  
        
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=True)
        self.layer_norm = nn.LayerNorm(hidden_size * 2)  
        self.fc = nn.Linear(hidden_size * 2, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.device = device
    
    def forward(self, x, hidden):
        x = self.embedding(x)
        x = self.dropout(x)
        out, hidden = self.lstm(x, hidden)
        out = self.layer_norm(out)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        num_layers = self.lstm.num_layers
        hidden_size = self.lstm.hidden_size
        num_directions = 2  
        if self.device is not None:
            return (torch.zeros(num_layers * num_directions, batch_size, hidden_size).to(self.device),
                    torch.zeros(num_layers * num_directions, batch_size, hidden_size).to(self.device))
        return (torch.zeros(num_layers * num_directions, batch_size, hidden_size),
                torch.zeros(num_layers * num_directions, batch_size, hidden_size))

# Example usage
# device = xm.xla_device()
# model = LSTMTextGenerator(vocab_size=5000, embed_size=300, hidden_size=256, num_layers=2).to(device)


In [None]:
import re
# Load the vocabulary (token -> index) mapping
vocab = {}  # You should populate this with your BPE vocabulary
with open('tokenizadorIskonawa.vocab', 'r', encoding='utf-8') as vocab_file:
    for idx, line in enumerate(vocab_file):
        token, code = re.split(r'\t', line.strip())
        # Save as integer
        vocab[token] = idx

# Load the BPE tokenized dataset
def load_bpe_dataset(file_path, vocab):
    dataset = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Tokenize each line into subwords and convert them to indices
            tokens = line.strip().split()  # Assuming tokens are space-separated
            token_ids = [vocab.get(token, vocab['<unk>']) for token in tokens]  # Handle unknown tokens
            dataset.append(token_ids)
    return dataset

bpe_tokenized_dataset = load_bpe_dataset('tokens.txt', vocab)

# Check dataset example
print(bpe_tokenized_dataset[:1])


In [102]:
import numpy as np

def load_embeddings(embedding_file, vocab):
    with open(embedding_file, 'r', encoding='utf-8') as f:
        # Read the first line to get vocab size and embed size
        first_line = f.readline().strip()
        vocab_size, embed_size = map(int, first_line.split())
        
        # Initialize a dictionary to hold the embeddings
        embeddings = np.zeros((len(vocab), embed_size), dtype=np.float32)
        
        # Read the rest of the file
        for line in f:
            values = line.strip().split()
            subword = values[0].strip()
            vector = np.array(values[1:], dtype=np.float32)
            index = vocab.get(subword, -1)
            if index == -1:
                print(f'Found {subword} in vocab')
            else:
                embeddings[index] = vector


    return embeddings, vocab_size, embed_size

embedding_file = 'isk_anchor_final3.txt'
pretrained_embeddings, vocab_size, embed_size = load_embeddings(embedding_file, vocab)

In [103]:
class BPEDataset(Dataset):
    def __init__(self, tokenized_data, pad_token=0):
        self.tokenized_data = tokenized_data
        self.pad_token = pad_token

    def __len__(self):
        return len(self.tokenized_data)

    def __getitem__(self, idx):
        # Get the tokenized sentence
        sentence = self.tokenized_data[idx]
        
        # Convert to tensor and return
        return torch.tensor(sentence, dtype=torch.long)

def collate_fn(batch):
    max_length = max(len(sentence) for sentence in batch)
    padded_batch = [torch.cat([sentence, torch.tensor([0] * (max_length - len(sentence)))]) for sentence in batch]
    return torch.stack(padded_batch)

In [None]:
dataset = BPEDataset(bpe_tokenized_dataset)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

# Check example
for inputs in dataloader:
    print(inputs.shape)  # Check the shape of the padded input batch
    break

In [105]:
import os
import torch

def save_checkpoint(epoch, model, optimizer, loss, checkpoint_dir='checkpoints'):
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_uni_last.pth')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, checkpoint_path)
    print(f'Checkpoint saved at {checkpoint_path}')
    # Create  or append to a file called "checkpoint_log.txt" to save the epoch and loss
    with open('checkpoint_log.txt', 'a') as f:
        f.write(f'Epoch: {epoch}, Loss: {loss}\n')

## Entrenar

In [106]:
def load_checkpoint(checkpoint_path, model, optimizer):
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f'Checkpoint loaded from {checkpoint_path}, epoch: {epoch}, loss: {loss}')
        return epoch, loss
    else:
        print(f'No checkpoint found at {checkpoint_path}')
        return None, None

In [107]:
import torch
from torch.utils.data import DataLoader
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
# Parameters
vocab_size = len(vocab)
embed_size = 300
hidden_size = 128
num_layers = 2
num_epochs = 20
learning_rate = 0.0001

def train_loop_fn(loader, model, optimizer, criterion, epoch, device):
    total_loss = 0
    model.train()  # Set the model to training mode
    step = 0
    total_batches = len(loader)

    for batch_idx, inputs in enumerate(loader):
        optimizer.zero_grad()
        
        # Prepare inputs and targets for text generation
        inputs_seq = inputs[:, :-1].to(device).long()
        targets_seq = inputs[:, 1:].to(device).long()
        
        # Initialize hidden state
        hidden = model.init_hidden(inputs_seq.size(0))  # Initialize hidden state based on batch size

        # Forward pass
        outputs, hidden = model(inputs_seq, hidden)  # Pass inputs_seq and hidden state
        
        # Calculate loss
        loss = criterion(outputs.view(-1, vocab_size), targets_seq.contiguous().view(-1))
        
        # Backward pass and optimization
        loss.backward()
        xm.optimizer_step(optimizer)
        xm.mark_step()
        
        total_loss += loss.item()  # Accumulate loss
        progress = (batch_idx + 1) / total_batches * 100  # Calculate progress percentage
        print(f"Successfully completed step {step} on device: {device}, Progress: {progress:.2f}%")
        step += 1

    return total_loss / len(loader)  # Average loss over the epoch

def _runa(rank, flags, device):
    # Set up device
    print(f'Training on: {device}\n')

    # Initialize the model, loss function, and optimizer
    model = LSTMTextGenerator(vocab_size, embed_size, hidden_size, num_layers, pretrained_embeddings=pretrained_embeddings).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Load checkpoint if available
    start_epoch, _ = load_checkpoint('checkpoints/checkpoint_last.pth', model, optimizer)
    if start_epoch is None:
        start_epoch = 0

    # Training Loop
    for epoch in range(num_epochs):
        para_loader = pl.MpDeviceLoader(dataloader, device)
        loss = train_loop_fn(para_loader, model, optimizer, criterion, epoch, device)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss:.4f}, Device: {device}')



In [None]:
def train_model():
    # Initialize the model, loss function, and optimizer
    device = xm.xla_device()
    model = LSTMTextGenerator(vocab_size, embed_size, hidden_size, num_layers, pretrained_embeddings=pretrained_embeddings, device=device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Create the DataLoader
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)
    
    model.to(device)

    # Load checkpoint if available
    start_epoch, _ = load_checkpoint('checkpoints/checkpoint_uni_last.pth', model, optimizer)
    if start_epoch is None:
        start_epoch = 0


    # Training Loop
    for epoch in range(num_epochs):
        para_loader = pl.ParallelLoader(dataloader, [device]).per_device_loader(device)
        loss = train_loop_fn(para_loader, model, optimizer, criterion, epoch, device)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss:.4f}, Device: {device}')
        save_checkpoint(epoch, model, optimizer, loss)

train_model()

In [None]:
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW, get_linear_schedule_with_warmup
import gc
import numpy as np
# from sklearn import metrics

embed_size = 300
hidden_size = 128
num_layers = 2
num_epochs = 5
learning_rate = 0.0001
def _runa(rank, flags):
    # Define training params 
    MAX_LEN = 192 # maximum text length in the batch (cannot have too high due to memory constraints)
    BATCH_SIZE = 16 # batch size (cannot have too high due to memory constraints)
    EPOCHS = 2 # number of epochs

    # defining data samplers and loaders 
    train_sampler = torch.utils.data.distributed.DistributedSampler(
          dataset,
          num_replicas=xm.xrt_world_size(), # tell PyTorch how many devices (TPU cores) we are using for training
          rank=xm.get_ordinal(), # tell PyTorch which device (core) we are on currently
          shuffle=True)

    train_data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=0,
    )
        
    # valid_sampler = torch.utils.data.distributed.DistributedSampler(
    #       valid_dataset,
    #       num_replicas=xm.xrt_world_size(),
    #       rank=xm.get_ordinal(),
    #       shuffle=False)

    # valid_data_loader = torch.utils.data.DataLoader(
    #     valid_dataset,
    #     batch_size=BATCH_SIZE,
    #     sampler=valid_sampler,
    #     drop_last=False,
    #     num_workers=0
    # )
    

    device = xm.xla_device() # our device (single TPU core)
    model = LSTMTextGenerator(vocab_size, embed_size, hidden_size, num_layers, pretrained_embeddings=pretrained_embeddings).to(device)
    xm.master_print(f'Training on: {device}\n')
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    param_optimizer = list(model.named_parameters()) # model parameters to optimize
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # apply to weight decay
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    xm.master_print('training on train dataset')
    
    lr = 0.5e-5 * xm.xrt_world_size() # scale the learning rate
    # calculate the total number of training steps
    num_train_steps = int(len(dataset) / BATCH_SIZE / xm.xrt_world_size() * EPOCHS) 
    
    
    # a scheduler can be used if desired
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )
    xm.master_print(f'num_training_steps = {num_train_steps}, world_size={xm.xrt_world_size()}')

    # Let's start training on the train set!
    for epoch in range(EPOCHS):
        gc.collect() # I use a lot of gc.collect() statement to hopefully prevent OOM problems
        # We use ParallelLoader (provided by PyTorch XLA) for TPU-core-specific dataloading:
        para_loader = pl.ParallelLoader(train_data_loader, [device]) 
        xm.master_print('parallel loader created... training now')
        gc.collect()
        # call training loop:
        train_loop_fn(para_loader.per_device_loader(device), model, optimizer, criterion, epoch, device)
        del para_loader
        # para_loader = pl.ParallelLoader(valid_data_loader, [device])
        gc.collect()
        # call evaluation loop
        # o, t = eval_loop_fn(para_loader.per_device_loader(device), model, device)
        # del para_loader
        gc.collect()
        # report AUC at the end
        # auc = metrics.roc_auc_score(np.array(t) >= 0.5, o)
        # auc_reduced = xm.mesh_reduce('auc_reduce',auc,reduce_fn)
        # xm.master_print(f'AUC = {auc_reduced}')
        # gc.collect()
    # save checkpoint
    save_checkpoint(EPOCHS, model, optimizer, "loss", checkpoint_dir='checkpoints')
    
    # We can also repeat the process on the validation set as demonstrated by @xhlulu
    
    # xm.master_print('training on validation set')
    
    # lr = 1.5e-5 * xm.xrt_world_size()
    
    # num_train_steps = int(len(valid_dataset) / BATCH_SIZE / xm.xrt_world_size() * EPOCHS)
    
    # optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    # scheduler = get_linear_schedule_with_warmup(
    #     optimizer,
    #     num_warmup_steps=0.1*num_train_steps,
    #     num_training_steps=num_train_steps
    # )
    # xm.master_print(f'num_training_steps = {num_train_steps}, world_size={xm.xrt_world_size()}')



In [None]:
def main():
    flags = {}
    xmp.spawn(_runa, args=(flags,), nprocs=8, start_method='fork')

if __name__ == '__main__':
    main()

## Generar

In [84]:
def generate_text(model, start_sequence, generation_length, seq_len, device):
    model.eval()
    generated_sequence = start_sequence

    # Initialize the hidden state
    hidden = model.init_hidden(1)
    hidden = (hidden[0].to(device), hidden[1].to(device))  # Move each element to device


    with torch.no_grad():
        for _ in range(generation_length):
            input_seq = torch.tensor(generated_sequence[-(seq_len-1):], dtype=torch.long).unsqueeze(0).to(device)  # Move to device
            output, hidden = model(input_seq, hidden)
            # print(output.shape) 
            next_token = output.argmax(dim=2)[:,-1].item()
            generated_sequence.append(next_token)
    
    return generated_sequence


In [118]:
device = xm.xla_device()
model = LSTMTextGenerator(vocab_size, embed_size, hidden_size, num_layers, pretrained_embeddings=pretrained_embeddings).to(device)
start_sequence = [vocab['▁ma']]  # Start with the token for 'ma'
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
epoch, loss=load_checkpoint('checkpoints/checkpoint_last.pth', model, optimizer)
print("Loaded model from epoch {} with loss {}".format(epoch, loss))

import sentencepiece as spm
tokenizerIsk = spm.SentencePieceProcessor(model_file='tokenizadorIskonawa.model')

Checkpoint loaded from checkpoints/checkpoint_last.pth, epoch: 19, loss: 8.059777471754286
Loaded model from epoch 19 with loss 8.059777471754286


  checkpoint = torch.load(checkpoint_path)


In [126]:
sequence=generate_text(model, start_sequence=[vocab['ke']], generation_length=20, seq_len=10, device=device)


In [127]:
sequenceA = [tokenizerIsk.id_to_piece(idx) for idx in sequence]
sequenceA

['ke',
 'pó',
 '▁pae',
 '▁pae',
 '▁pae',
 '▁pae',
 '▁pae',
 '▁pae',
 '▁pae',
 '▁pae',
 '▁pae',
 'én',
 'én',
 '▁hanawe',
 'én',
 '▁hanawe',
 'én',
 'én',
 '▁hanawe',
 'én',
 '▁hanawe']