In [7]:
import torch
import torch.nn as nn
import torchvision.models as models

import string
import dataset
import config
import data_train
import config
import model
import os

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import CTCLoss
from tqdm import tqdm
import config

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Class index -> character lookup used to render CTC outputs and targets
# (adjust this to match your dataset's alphabet).
# Indices 1..26 map to 'a'..'z'; index 0 is the CTC blank, rendered as "-".
idx2char = {code - 96: chr(code) for code in range(ord("a"), ord("z") + 1)}
idx2char[0] = "-"


def decode_prediction(logits, blank=0, mapping=None):
    """Greedy (best-path) CTC decoding of a batch of network outputs.

    Takes the arg-max class at every time step, collapses runs of identical
    consecutive indices, then drops the blank index.

    Args:
        logits: Time-major tensor ``[T, B, num_classes]``. Raw logits and
            log-probabilities both work, because only the arg-max is used.
        blank: Index of the CTC blank symbol. Defaults to 0, matching the
            ``CTCLoss(blank=0)`` used for training.
        mapping: Optional ``{index: character}`` dict. Defaults to the
            module-level ``idx2char``. Indices missing from it render as "?".

    Returns:
        List of ``B`` decoded strings, one per batch element.
    """
    from itertools import groupby

    if mapping is None:
        mapping = idx2char
    # [T, B] -> [B, T]. Copy to the host once for the whole batch instead of
    # issuing one device->host transfer per sequence.
    best_paths = torch.argmax(logits, dim=2).permute(1, 0).cpu().tolist()
    # groupby collapses consecutive repeats; blanks are dropped afterwards,
    # so "a - a" still decodes to "aa" as CTC requires.
    return [
        "".join(mapping.get(idx, "?") for idx, _ in groupby(path) if idx != blank)
        for path in best_paths
    ]


def train_one_epoch(model, train_loader, optimizer, criterion, epoch, print_every=100):
    """Run one training epoch with CTC loss and return the mean batch loss.

    Each batch is a dict with keys ``"images"``, ``"targets"`` and
    ``"targets_lengths"``. The model output is treated as time-major logits
    ``[T, B, num_classes]``, and every sequence is given the full ``T`` as its
    input length.

    Args:
        model: Network to train. It is switched to train mode here.
        train_loader: Iterable of batches that supports ``len()``.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: CTC loss, called as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``.
        epoch: Epoch number. Used only for the progress-bar label.
        print_every: Log a sample prediction every this many batches
            (batch 0 included). ``0`` or ``None`` disables the sample log.

    Returns:
        Mean of the per-batch losses, or ``nan`` if ``train_loader`` is empty.
    """
    model.train()
    num_batches = len(train_loader)
    if num_batches == 0:
        # Nothing to average; avoid a ZeroDivisionError.
        return float("nan")

    train_loss = 0.0
    pbar = tqdm(enumerate(train_loader), total=num_batches, desc=f"Epoch {epoch}")

    for batch_idx, batch in pbar:
        images = batch["images"].to(device)           # [B, C, H, W]
        targets = batch["targets"].to(device)         # [B, max_len], presumably 0-padded
        targets_lengths = batch["targets_lengths"]    # [B]

        # Forward pass; CTCLoss expects log-probabilities over the class axis.
        logits = model(images)                        # [T, B, num_classes]
        log_probs = F.log_softmax(logits, dim=2)

        # Every sample uses all T output frames.
        input_lengths = torch.full(
            size=(images.size(0),),
            fill_value=logits.size(0),
            dtype=torch.long,
            device=device,
        )

        loss = criterion(log_probs, targets, input_lengths, targets_lengths)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss

        # Show one sample prediction every `print_every` batches.
        if print_every and batch_idx % print_every == 0:
            # Only the first sample is displayed, so only decode that one.
            pred = decode_prediction(log_probs[:, :1].detach())[0]
            # Index 0 (blank/padding) is skipped when rendering the ground truth.
            gt = "".join(idx2char.get(i, "?") for i in targets[0].tolist() if i != 0)
            # pbar.write prints above the progress bar instead of breaking it apart.
            pbar.write(f"Batch {batch_idx} - Pred: {pred} | GT: {gt}")

        pbar.set_postfix({"loss": f"{batch_loss:.4f}"})

    return train_loss / num_batches


def validate(model, val_loader, criterion):
    """Return the mean per-batch CTC loss over ``val_loader``.

    The model is put in eval mode and run under ``torch.no_grad()``, so no
    gradients are recorded. Batches are dicts with ``"images"``,
    ``"targets"`` and ``"targets_lengths"``. The model output is treated as
    time-major ``[T, B, num_classes]`` logits, and every sample is given the
    full ``T`` as its input length.

    Returns ``nan`` when ``val_loader`` yields no batches.
    """
    model.eval()
    val_loss = 0.0
    # Count the batches actually seen, so loaders without len() also work.
    num_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            images = batch["images"].to(device)
            targets = batch["targets"].to(device)
            targets_lengths = batch["targets_lengths"]

            logits = model(images)
            log_probs = F.log_softmax(logits, dim=2)
            # Every sample uses all T output frames.
            input_lengths = torch.full(
                size=(images.size(0),),
                fill_value=logits.size(0),
                dtype=torch.long,
                device=device,
            )

            loss = criterion(log_probs, targets, input_lengths, targets_lengths)
            val_loss += loss.item()
            num_batches += 1

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return val_loss / num_batches if num_batches else float("nan")


def train_model(
    train_loader, 
    val_loader, 
    num_classes=27, 
    hidden_state=128, 
    epochs=10, 
    lr=1e-3, 
    checkpoint_path="checkpoint.pth"
):
    """Build a ``model.model_wordRec`` network and train it with CTC loss.

    If ``checkpoint_path`` exists, model and optimizer state are restored
    from it, and training resumes at the epoch after the saved one. After
    every epoch, the model/optimizer state, epoch number and train/val losses
    are written back to ``checkpoint_path``.

    Args:
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``validate``.
        num_classes: Output classes, including the CTC blank (index 0).
        hidden_state: Hidden size forwarded to ``model.model_wordRec``.
        epochs: Last epoch number to train (1-based, inclusive).
        lr: Adam learning rate.
        checkpoint_path: File used both to resume from and to save to.

    Returns:
        The trained network, moved to ``device``.
    """
    # Init model
    net = model.model_wordRec(num_classes=num_classes, hidden_state=hidden_state).to(device)
    criterion = CTCLoss(blank=0, reduction="mean", zero_infinity=True)
    optimizer = optim.Adam(net.parameters(), lr=lr)

    start_epoch = 1

    # Resume from an existing checkpoint, if any.
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    if os.path.exists(checkpoint_path):
        print(f"🔄 Loading checkpoint from {checkpoint_path} ...")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resumed from epoch {checkpoint['epoch']}")
        if start_epoch > epochs:
            print(f"Checkpoint already covers {epochs} epoch(s); nothing left to train.")

    for epoch in range(start_epoch, epochs + 1):
        train_loss = train_one_epoch(net, train_loader, optimizer, criterion, epoch)
        val_loss = validate(net, val_loader, criterion)

        print(f"Epoch [{epoch}/{epochs}] - Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Save a checkpoint every epoch. Write to a temporary file first and
        # then atomically swap it in. An interruption mid-save (e.g. a
        # KeyboardInterrupt) then cannot leave a truncated checkpoint that
        # would break the next resume.
        tmp_path = f"{checkpoint_path}.tmp"
        torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
        }, tmp_path)
        os.replace(tmp_path, checkpoint_path)
        print(f"💾 Checkpoint saved at epoch {epoch}")

    return net




Using device: cuda


In [9]:
# Report GPU availability. Only query the device name when CUDA is present:
# torch.cuda.get_device_name(0) raises on CPU-only kernels, which would break
# Restart & Run All.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))


True
NVIDIA GeForce RTX 3050 6GB Laptop GPU


In [10]:
# Train the word recognizer: 27 output classes = 26 letters (a-z) + 1 CTC blank.
model_trained = train_model(
    train_loader=data_train.train_loader,
    val_loader=data_train.val_loader,
    num_classes=27,
    epochs=config.EPOCHS,
    lr=1e-3,
)


Epoch 1:   0%|          | 1/2000 [00:01<27:09,  1.23it/s, loss=22.2492]


Batch 0 - Pred: wvxit | GT: dktegqe


Epoch 1:   5%|â–Œ         | 101/2000 [00:22<06:54,  4.58it/s, loss=3.5814]


Batch 100 - Pred:  | GT: pcikhdn


Epoch 1:  10%|â–ˆ         | 201/2000 [00:44<06:32,  4.59it/s, loss=3.6157]


Batch 200 - Pred:  | GT: vugraiiwbn


Epoch 1:  15%|â–ˆâ–Œ        | 301/2000 [01:05<06:13,  4.55it/s, loss=3.6046]


Batch 300 - Pred:  | GT: fipugwrz


Epoch 1:  20%|â–ˆâ–ˆ        | 401/2000 [01:27<05:49,  4.57it/s, loss=3.6218]


Batch 400 - Pred:  | GT: wcrcgamjue


Epoch 1:  25%|â–ˆâ–ˆâ–Œ       | 501/2000 [01:48<05:29,  4.55it/s, loss=3.5842]


Batch 500 - Pred:  | GT: kgfdjjyypw


Epoch 1:  30%|â–ˆâ–ˆâ–ˆ       | 601/2000 [02:10<05:07,  4.54it/s, loss=3.6204]


Batch 600 - Pred:  | GT: dpvrhk


Epoch 1:  35%|â–ˆâ–ˆâ–ˆâ–Œ      | 701/2000 [02:32<04:45,  4.54it/s, loss=3.6053]


Batch 700 - Pred:  | GT: oczjocuq


Epoch 1:  40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 801/2000 [02:54<04:25,  4.52it/s, loss=3.5766]


Batch 800 - Pred:  | GT: lozrerjn


Epoch 1:  45%|â–ˆâ–ˆâ–ˆâ–ˆâ–Œ     | 901/2000 [03:16<04:03,  4.51it/s, loss=3.5816]


Batch 900 - Pred:  | GT: jxxxt


Epoch 1:  50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 1001/2000 [03:38<03:41,  4.52it/s, loss=3.5552]


Batch 1000 - Pred:  | GT: fezmiw


Epoch 1:  55%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ    | 1101/2000 [04:00<03:20,  4.48it/s, loss=3.5490]


Batch 1100 - Pred:  | GT: nzlsyy


Epoch 1:  60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 1201/2000 [04:22<02:58,  4.47it/s, loss=3.5692]


Batch 1200 - Pred:  | GT: xmcrhdg


Epoch 1:  65%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ   | 1301/2000 [04:44<02:36,  4.46it/s, loss=3.5396]


Batch 1300 - Pred:  | GT: qggnybhq


Epoch 1:  70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 1401/2000 [05:06<02:14,  4.46it/s, loss=3.5310]


Batch 1400 - Pred:  | GT: asdpnnkv


Epoch 1:  75%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ  | 1501/2000 [05:28<01:51,  4.48it/s, loss=3.5424]


Batch 1500 - Pred:  | GT: vfgsfmdsr


Epoch 1:  80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 1601/2000 [05:50<01:29,  4.45it/s, loss=3.5363]


Batch 1600 - Pred:  | GT: hyhr


Epoch 1:  85%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ | 1701/2000 [06:12<01:07,  4.46it/s, loss=3.5251]


Batch 1700 - Pred:  | GT: wwbfsj


Epoch 1:  90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 1801/2000 [06:34<00:44,  4.47it/s, loss=3.5516]


Batch 1800 - Pred:  | GT: upbztfbm


Epoch 1:  95%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ| 1901/2000 [06:57<00:22,  4.44it/s, loss=3.5423]


Batch 1900 - Pred:  | GT: vskpkczk


Epoch 1: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2000/2000 [07:19<00:00,  4.55it/s, loss=3.4616]


Epoch [1/5] - Train Loss: 3.6091 | Val Loss: 3.6958
ðŸ’¾ Checkpoint saved at epoch 1


Epoch 2:   0%|          | 1/2000 [00:00<17:01,  1.96it/s, loss=3.4910]


Batch 0 - Pred:  | GT: vjdqt


Epoch 2:   5%|â–Œ         | 101/2000 [00:22<07:00,  4.52it/s, loss=3.2504]


Batch 100 - Pred:  | GT: hvvjpdynl


Epoch 2:  10%|â–ˆ         | 201/2000 [00:43<06:38,  4.51it/s, loss=3.1805]


Batch 200 - Pred:  | GT: kehinb


Epoch 2:  15%|â–ˆâ–Œ        | 301/2000 [01:05<06:21,  4.46it/s, loss=3.1111]


Batch 300 - Pred:  | GT: gapvjmby


Epoch 2:  20%|â–ˆâ–ˆ        | 401/2000 [01:28<05:57,  4.47it/s, loss=2.4895]


Batch 400 - Pred: ejeegwj | GT: qhnucsrvwh


Epoch 2:  25%|â–ˆâ–ˆâ–Œ       | 501/2000 [01:50<05:35,  4.47it/s, loss=1.3595]


Batch 500 - Pred: xkakak | GT: xhghsi


Epoch 2:  30%|â–ˆâ–ˆâ–ˆ       | 601/2000 [02:12<05:13,  4.46it/s, loss=0.5468]


Batch 600 - Pred: wjzrasidg | GT: wjzrasidyg


Epoch 2:  35%|â–ˆâ–ˆâ–ˆâ–Œ      | 701/2000 [02:34<04:50,  4.47it/s, loss=0.1804]


Batch 700 - Pred: lfzyywgg | GT: ifzyywgg


Epoch 2:  40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 801/2000 [02:56<04:31,  4.41it/s, loss=0.0548]


Batch 800 - Pred: uincji | GT: uincji


Epoch 2:  45%|â–ˆâ–ˆâ–ˆâ–ˆâ–Œ     | 901/2000 [03:19<04:08,  4.43it/s, loss=0.0171]


Batch 900 - Pred: kpgulrz | GT: kpgulrz


Epoch 2:  50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 1001/2000 [03:41<03:47,  4.40it/s, loss=0.0658]


Batch 1000 - Pred: tuvfphlobg | GT: tuvfphlobg


Epoch 2:  55%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ    | 1101/2000 [04:03<03:20,  4.48it/s, loss=0.0301]


Batch 1100 - Pred: dogkbc | GT: dogkbc


Epoch 2:  60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 1201/2000 [04:25<03:00,  4.43it/s, loss=0.0058]


Batch 1200 - Pred: igngqo | GT: igngqo


Epoch 2:  65%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ   | 1301/2000 [04:48<02:37,  4.44it/s, loss=0.0299]


Batch 1300 - Pred: bjvwonk | GT: bjvwonk


Epoch 2:  70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 1401/2000 [05:10<02:14,  4.46it/s, loss=0.0465]


Batch 1400 - Pred: zubgqaug | GT: zubgqaug


Epoch 2:  75%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ  | 1501/2000 [05:32<01:51,  4.46it/s, loss=0.0089]


Batch 1500 - Pred: usfv | GT: usfv


Epoch 2:  80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 1601/2000 [05:54<01:29,  4.45it/s, loss=0.0181]


Batch 1600 - Pred: jufmvz | GT: jufmvz


Epoch 2:  85%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ | 1701/2000 [06:17<01:07,  4.43it/s, loss=0.0029]


Batch 1700 - Pred: sscbr | GT: sscbr


Epoch 2:  90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 1801/2000 [06:39<00:44,  4.45it/s, loss=0.0121]


Batch 1800 - Pred: tnrcskytad | GT: tnrcskytad


Epoch 2:  95%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ| 1901/2000 [07:01<00:22,  4.45it/s, loss=0.0123]


Batch 1900 - Pred: csanen | GT: csanen


Epoch 2: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2000/2000 [07:24<00:00,  4.50it/s, loss=0.0616]


Epoch [2/5] - Train Loss: 0.8116 | Val Loss: 0.0189
ðŸ’¾ Checkpoint saved at epoch 2


Epoch 3:   0%|          | 1/2000 [00:00<17:12,  1.94it/s, loss=0.0044]


Batch 0 - Pred: aguammvjk | GT: aguammvjk


Epoch 3:   5%|â–Œ         | 101/2000 [00:22<07:03,  4.48it/s, loss=0.0023]


Batch 100 - Pred: dvyfnx | GT: dvyfnx


Epoch 3:  10%|â–ˆ         | 201/2000 [00:44<06:40,  4.49it/s, loss=0.0081]


Batch 200 - Pred: lhckusojy | GT: lhckusojy


Epoch 3:  15%|â–ˆâ–Œ        | 301/2000 [01:06<06:21,  4.45it/s, loss=0.0155]


Batch 300 - Pred: axvmgf | GT: axvmgf


Epoch 3:  20%|â–ˆâ–ˆ        | 401/2000 [01:28<06:05,  4.38it/s, loss=0.0957]


Batch 400 - Pred: lbihyfpemk | GT: lbihyfpemk


Epoch 3:  23%|â–ˆâ–ˆâ–Ž       | 452/2000 [01:40<05:43,  4.51it/s, loss=0.0051]


KeyboardInterrupt: 