In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pickle
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Run on the GPU whenever PyTorch can see one; fall back to the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [3]:
# Load the letter and diacritic lookup tables from disk.
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
with open("utils/letter2idx.pickle", "rb") as file:
    letter2idx = pickle.load(file)
with open("utils/diacritic2id.pickle", "rb") as file:
    diacritic2id = pickle.load(file)

# Reverse lookups: id -> symbol.
idx2letter = dict(zip(letter2idx.values(), letter2idx.keys()))
idx2diacritic = dict(zip(diacritic2id.values(), diacritic2id.keys()))

# Show the four tables, one per line.
print(letter2idx, idx2letter, diacritic2id, idx2diacritic, sep="\n")

In [4]:
# Model dimensions follow directly from the sizes of the two lookup tables.
vocab_size, num_classes = len(letter2idx), len(diacritic2id)
print("Vocab size:", vocab_size)
print("Num classes:", num_classes)

In [5]:
# Load the training and validation splits (inputs X, targets y).
# allow_pickle=True is required for object arrays and carries the same
# trust caveat as pickle.load.
X_train, y_train = (
    np.load(path, allow_pickle=True)
    for path in ("data/X_train.npy", "data/y_train.npy")
)
X_val, y_val = (
    np.load(path, allow_pickle=True)
    for path in ("data/X_val.npy", "data/y_val.npy")
)

In [6]:
def _run_epoch(model, loader, criterion, pad_id, desc, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, token_accuracy)``.

    When ``optimizer`` is given the pass is a training pass (backward, gradient
    clipping to norm 1.0, optimizer step); otherwise gradients are disabled.
    The caller is responsible for putting the model in train/eval mode.

    The loss is averaged over batches (NaN if the loader yields no batch);
    accuracy is computed over every target position that is not ``pad_id``
    (0.0 if there is none).
    """
    training = optimizer is not None
    loss_sum, correct_sum, token_sum, batch_count = 0.0, 0, 0, 0

    with torch.set_grad_enabled(training), tqdm(loader, desc=desc) as pbar:
        # Each batch is (inputs, targets, mask, lengths); the mask is unused here.
        for batch_X, batch_y, _, lengths in pbar:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(batch_X, lengths)

            # Flatten (batch, time) so CrossEntropyLoss sees one row per token.
            # reshape (not view) also works if the model output is non-contiguous.
            loss = criterion(
                outputs.reshape(-1, outputs.size(-1)),
                batch_y.reshape(-1),
            )

            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            # Accuracy ignores padded target positions.
            preds = outputs.argmax(dim=-1)
            mask = batch_y != pad_id
            correct = (preds[mask] == batch_y[mask]).sum().item()
            total = mask.sum().item()

            batch_loss = loss.item()
            loss_sum += batch_loss
            correct_sum += correct
            token_sum += total
            batch_count += 1

            acc = correct / total if total > 0 else 0.0
            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Acc': f'{acc:.4f}'
            })

    mean_loss = loss_sum / batch_count if batch_count else float('nan')
    accuracy = correct_sum / token_sum if token_sum else 0.0
    return mean_loss, accuracy


def train_model_dynamic(model, train_loader, val_loader, epochs=10, learning_rate=0.001):
    """Train ``model`` on variable-length batches and track validation metrics.

    Uses cross-entropy that ignores the ``'<PAD>'`` class, AdamW
    (weight decay 1e-5) and a ReduceLROnPlateau scheduler driven by the
    validation loss. Whenever the validation loss improves, a checkpoint is
    written to ``'best_dynamic_blstm_model.pth'`` in the working directory.

    Args:
        model: module called as ``model(batch_X, lengths)`` returning logits
            of shape (batch, time, classes).
        train_loader: iterable of ``(inputs, targets, mask, lengths)`` batches.
        val_loader: same batch layout, used for validation only.
        epochs: number of passes over ``train_loader``.
        learning_rate: initial AdamW learning rate.

    Returns:
        dict with per-epoch lists under the keys ``'train_loss'``,
        ``'train_accuracy'``, ``'val_loss'`` and ``'val_accuracy'``.
    """
    pad_id = diacritic2id['<PAD>']
    criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)

    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    best_val_loss = float('inf')
    best_model_path = 'best_dynamic_blstm_model.pth'

    for epoch in range(epochs):
        model.train()
        avg_train_loss, avg_train_acc = _run_epoch(
            model, train_loader, criterion, pad_id,
            f'Epoch {epoch+1}/{epochs} [Train]', optimizer,
        )

        model.eval()
        avg_val_loss, avg_val_acc = _run_epoch(
            model, val_loader, criterion, pad_id,
            f'Epoch {epoch+1}/{epochs} [Val]',
        )

        # Keep the checkpoint with the lowest validation loss seen so far.
        # A NaN loss never compares as smaller, so it is never saved as "best".
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss,
                'val_accuracy': avg_val_acc,
                'train_loss': avg_train_loss,
                'train_accuracy': avg_train_acc
            }, best_model_path)
            print(f"  ↳ Best model saved! (val_loss: {best_val_loss:.4f})")

        scheduler.step(avg_val_loss)

        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'  Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}')
        print(f'  Val Loss: {avg_val_loss:.4f}, Val Acc: {avg_val_acc:.4f}')
        print(f'  LR: {optimizer.param_groups[0]["lr"]:.6f}')

        train_losses.append(avg_train_loss)
        train_accuracies.append(avg_train_acc)
        val_losses.append(avg_val_loss)
        val_accuracies.append(avg_val_acc)

    return {
        'train_loss': train_losses,
        'train_accuracy': train_accuracies,
        'val_loss': val_losses,
        'val_accuracy': val_accuracies
    }

In [None]:
def pad_collate_fn(batch):
    """Collate ``(x, y, mask)`` samples of possibly different lengths.

    Args:
        batch: sequence of ``(x, y, mask)`` tuples of 1-D tensors.

    Returns:
        Tuple ``(x_padded, y_padded, mask_padded, lengths)`` where the first
        three are batch-first tensors padded to the longest sample in the batch
        and ``lengths`` is a long tensor holding ``len(x)`` for each sample.
        Inputs are padded with ``letter2idx['<PAD>']``, targets with
        ``diacritic2id['<PAD>']`` and masks with 0.
    """
    x_batch, y_batch, mask_batch = zip(*batch)
    lengths_x = [len(x) for x in x_batch]

    pad_sequence = torch.nn.utils.rnn.pad_sequence
    x_padded = pad_sequence(list(x_batch), batch_first=True, padding_value=letter2idx['<PAD>'])
    y_padded = pad_sequence(list(y_batch), batch_first=True, padding_value=diacritic2id['<PAD>'])
    # Masks must be padded like x and y: torch.stack raises as soon as two
    # samples differ in length. For equal lengths the result is unchanged.
    mask_padded = pad_sequence(list(mask_batch), batch_first=True, padding_value=0)

    return x_padded, y_padded, mask_padded, torch.tensor(lengths_x, dtype=torch.long)

In [8]:
class DiacritizationDataset(Dataset):
    """Dataset yielding ``(x, y, mask)`` long tensors for one sequence.

    ``mask`` has the same length as ``x`` and is 1 at the last character of
    every space-delimited word that appears before the first ``'<PAD>'``
    token (or before the end of the sequence when there is no padding), and
    0 everywhere else.

    Args:
        X: indexable collection of input id sequences.
        y: indexable collection of target id sequences, aligned with ``X``.
        letter2idx: mapping that contains the ids of ``'<PAD>'`` and ``' '``.
    """

    def __init__(self, X, y, letter2idx):
        self.X = X
        self.y = y
        self.pad_idx = letter2idx['<PAD>']
        self.space_idx = letter2idx[' ']

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Coerce to ndarray so the element-wise comparisons below also work
        # when a sequence is stored as a plain Python list (in that case
        # `list == int` is a scalar False and boundaries would be missed).
        x_seq = np.asarray(self.X[idx])
        y_seq = np.asarray(self.y[idx])

        mask = np.zeros_like(x_seq, dtype=np.int64)

        # The sentence ends at the first pad token, if there is one.
        pad_positions = np.flatnonzero(x_seq == self.pad_idx)
        sentence_end = pad_positions[0] if len(pad_positions) > 0 else len(x_seq)

        # Every space, plus the sentence end, closes a (possibly empty) word.
        space_positions = np.flatnonzero(x_seq[:sentence_end] == self.space_idx)

        word_start = 0
        for boundary in [*space_positions, sentence_end]:
            # Skip empty words (leading or consecutive spaces). The character
            # at boundary - 1 is then neither a space nor padding, because it
            # lies inside the sentence and strictly between two boundaries.
            if word_start < boundary:
                mask[boundary - 1] = 1
            word_start = boundary + 1

        return (
            torch.tensor(x_seq, dtype=torch.long),
            torch.tensor(y_seq, dtype=torch.long),
            torch.tensor(mask, dtype=torch.long)
        )

In [9]:
class BiLSTM(nn.Module):
    """Embedding -> LayerNorm -> bidirectional LSTM -> LayerNorm -> linear.

    Produces one logit vector per input position.

    Args:
        vocab_size: number of rows in the embedding table.
        num_classes: size of the output logit vector.
        embedding_dim: embedding width.
        hidden_size: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers; inter-layer dropout of 0.3
            is applied only when this is greater than 1.
        padding_idx: id of the padding token in the input vocabulary. Its
            embedding row is kept at zero and receives no gradient. Defaults
            to 13, the value that was previously hard-coded.
    """

    def __init__(self, vocab_size, num_classes, embedding_dim=256, hidden_size=256, num_layers=2,
                 padding_idx=13):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        nn.init.xavier_uniform_(self.embedding.weight)
        # xavier_uniform_ overwrites the whole table, including the padding
        # row that nn.Embedding had zeroed; restore it.
        if self.embedding.padding_idx is not None:
            with torch.no_grad():
                self.embedding.weight[self.embedding.padding_idx].fill_(0)

        self.bilstm = nn.LSTM(
            embedding_dim,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.3 if num_layers > 1 else 0
        )

        self.emb_norm = nn.LayerNorm(embedding_dim)
        self.lstm_norm = nn.LayerNorm(hidden_size * 2)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x, lengths=None):
        """Return logits of shape (batch, time, num_classes).

        Args:
            x: long tensor of token ids, shape (batch, time).
            lengths: optional per-sample sequence lengths. When given, the
                batch is packed so the LSTM does not run over positions past
                each sample's length; those positions get a zero LSTM output
                before the final LayerNorm and linear layer.
        """
        x_emb = self.emb_norm(self.embedding(x))

        if lengths is None:
            lstm_out, _ = self.bilstm(x_emb)
        else:
            # pack_padded_sequence needs the lengths on the CPU;
            # enforce_sorted=False makes it sort by length internally and
            # pad_packed_sequence then restores the original batch order.
            lengths_cpu = torch.as_tensor(lengths, dtype=torch.long).cpu()
            packed_in = nn.utils.rnn.pack_padded_sequence(
                x_emb,
                lengths_cpu,
                batch_first=True,
                enforce_sorted=False
            )

            packed_out, _ = self.bilstm(packed_in)

            # total_length keeps the time dimension equal to the input's even
            # if the longest sample in the batch is shorter than x.size(1).
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out,
                batch_first=True,
                padding_value=0.0,
                total_length=x_emb.size(1)
            )

        lstm_out = self.lstm_norm(lstm_out)
        out = self.fc(lstm_out)
        return out

In [10]:
batch_size = 32

# Wrap the loaded arrays in datasets, one per split.
train_text_dataset = DiacritizationDataset(X=X_train, y=y_train, letter2idx=letter2idx)
val_text_dataset = DiacritizationDataset(X=X_val, y=y_val, letter2idx=letter2idx)

# Both loaders pad each batch via pad_collate_fn; only the training split
# is shuffled.
train_loader = DataLoader(
    dataset=train_text_dataset,
    collate_fn=pad_collate_fn,
    batch_size=batch_size,
    num_workers=2,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=val_text_dataset,
    collate_fn=pad_collate_fn,
    batch_size=batch_size,
    num_workers=2,
    shuffle=False,
)

In [11]:
# Build the network with its default sizes and move it to the chosen device.
model = BiLSTM(vocab_size=vocab_size, num_classes=num_classes)
model = model.to(device)

print("Model architecture:", model, sep="\n")

# Parameter counts: all parameters, then only those that receive gradients.
total_params = sum(map(torch.numel, model.parameters()))
trainable_params = sum(
    torch.numel(param)
    for param in filter(lambda q: q.requires_grad, model.parameters())
)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [12]:
# Print model summary.
# The original strings used "\\n", which prints a literal backslash followed
# by "n" instead of starting a new line.
print("\nModel Summary:")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train model with dynamic sequences
print("\nStarting training with dynamic sequence lengths...")
history = train_model_dynamic(model, train_loader, val_loader, epochs=10)

# Plot per-epoch accuracy (left) and loss (right) for both splits.
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history['train_accuracy'], label='Training Accuracy')
plt.plot(history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history['train_loss'], label='Training Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.tight_layout()
plt.show()

# Save the weights as they are after the final epoch, together with the
# sizes needed to rebuild the model and the training history.
torch.save({
    'model_state_dict': model.state_dict(),
    'vocab_size': vocab_size,
    'num_classes': num_classes,
    'history': history
}, 'dynamic_blstm_model.pth')

print("Model saved successfully as 'dynamic_blstm_model.pth'!")