In [1]:
from data.dataset import Synth90kDataset, synth90k_collate_fn
import torch.multiprocessing as mp
import torch
from torch.utils.data import DataLoader
from models.crnn import CRNN, count_parameters
import torch.optim as optim
import torch.nn as nn
from tqdm.notebook import tqdm 
from models.ctc_decoder import ctc_decoder

# Force the 'spawn' start method for worker processes, overriding any
# start method that may already have been selected in this kernel.
mp.set_start_method('spawn', force=True)

# Dataset root directory and the dataset mode names used in this notebook.
dataset_path = './data/mnt/ramdisk/max/90kDICT32px/'
modes = ['train', 'val', 'test']

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
def getting_all_batches(batch, device):
    """Pull the three training tensors out of a collated batch and move them to `device`.

    Args:
        batch: mapping with the keys 'images', 'targets' and 'target_lengths',
            each holding an object that supports `.to(device)`.
        device: target device passed to `.to(...)`.

    Returns:
        Tuple `(images, targets, target_lengths)`, each already on `device`.
    """
    # Look up every key first so a missing key fails before any transfer starts.
    tensors = [batch[key] for key in ('images', 'targets', 'target_lengths')]
    return tuple(tensor.to(device) for tensor in tensors)

In [3]:
def calculate_loss(preds, preds_length, targets, target_lengths, optimizer, criterion,
                   max_grad_norm = 5):
    """Run one optimisation step and return the scalar loss.

    Zeroes the gradients, evaluates `criterion`, back-propagates, clips the
    gradient norm and finally calls `optimizer.step()`.

    Args:
        preds: model output passed to `criterion` as its first argument.
        preds_length: per-sample lengths of `preds`, passed as the third
            argument of `criterion`.
        targets: target labels, passed as the second argument of `criterion`.
        target_lengths: per-sample target lengths, passed as the fourth
            argument of `criterion`.
        optimizer: optimizer whose parameters are updated.
        criterion: callable `criterion(preds, targets, preds_length, target_lengths)`
            returning a scalar tensor that supports `.backward()`.
        max_grad_norm: maximum total gradient norm used for clipping
            (defaults to 5, the value that was previously hard-coded).

    Returns:
        The loss value as a Python float (`loss.item()`).
    """
    optimizer.zero_grad()

    loss = criterion(preds, targets, preds_length, target_lengths)
    loss.backward()

    # Clip the parameters that this optimizer actually updates instead of
    # reaching for a notebook-global model, so the function is self-contained.
    params = [p for group in optimizer.param_groups for p in group['params']]
    torch.nn.utils.clip_grad_norm_(params, max_grad_norm)

    optimizer.step()
    return loss.item()
    

In [4]:
def calculate_accuracy(output, output_lengths, targets, target_lengths, 
                      decode_method = 'beam_search', beam_size = 10):
    """Count how many decoded predictions exactly match their target sequence.

    Args:
        output: model output; it is detached and handed to `ctc_decoder`.
        output_lengths: accepted for interface compatibility; not used here.
        targets: tensor of target labels; after conversion to a list it is
            sliced consecutively, one slice per sample.
        target_lengths: tensor with the length of each sample's slice of `targets`.
        decode_method: forwarded to `ctc_decoder` as `method`.
        beam_size: forwarded to `ctc_decoder` as `beam_size`.

    Returns:
        Number of samples whose decoded prediction equals its target slice.
    """
    decoded = ctc_decoder(output.detach(), method = decode_method, beam_size = beam_size)

    flat_labels = targets.cpu().numpy().tolist()
    lengths = target_lengths.cpu().numpy().tolist()

    matches = 0
    start = 0
    for prediction, length in zip(decoded, lengths):
        # Each sample owns the next `length` entries of the flattened labels.
        stop = start + length
        if prediction == flat_labels[start:stop]:
            matches += 1
        start = stop

    return matches

In [8]:
# Number of samples in the full training dataset.
# NOTE(review): `train_dataset` is first assigned in a cell with a higher
# execution count (In [12]); on a fresh kernel this cell must run after it.
len(train_dataset)

7224612

In [9]:
# Number of samples in the full validation dataset.
# NOTE(review): `valid_dataset` is first assigned in a cell with a higher
# execution count (In [12]); on a fresh kernel this cell must run after it.
len(valid_dataset)

802734

In [12]:
from argparse import Namespace

torch.manual_seed(42)
# Keep `device` a torch.device, matching the setup cell at the top of the notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# All tunable settings for this run live in one namespace.
train_args = Namespace(
    data_dir = './data/mnt/ramdisk/max/90kDICT32px/',
    train_batch_size = 32,
    eval_batch_size = 512,
    epochs = 1000,
    save_interval = 20,
    cpu_workers = 8,
    learning_rate = 0.05,
    reload_checkpoint = None,
    decode_method = 'beam_search',
    beam_size = 10,
    checkpoints_dir = 'checkpoints/',
    img_width = 100,
    img_height= 32,
    map_to_seq = 64,
    lstm_hidden = 256,
    leaky_relu = False
)


# The dataset root comes from the config above so this cell has a single
# source of truth for the path.
train_dataset = Synth90kDataset(train_args.data_dir, mode = 'train', 
                                img_height = train_args.img_height,
                                img_width = train_args.img_width)
valid_dataset = Synth90kDataset(train_args.data_dir, mode = 'val', 
                                img_height = train_args.img_height,
                                img_width = train_args.img_width)
# test_dataset = Synth90kDataset(train_args.data_dir, mode = 'test', 
#                                 img_height = train_args.img_height,
#                                 img_width = train_args.img_width)

# Work on a random 1/100 subset of each split to keep iteration fast.
reduced_train = len(train_dataset) // 100
reduced_indices = torch.randperm(len(train_dataset))[:reduced_train].tolist()
train_dataset_reduced = torch.utils.data.Subset(train_dataset, reduced_indices)

# The validation indices must be drawn from the validation set's own range;
# permuting len(train_dataset) here would produce indices past the end of
# valid_dataset and fail when those samples are fetched.
reduced_val = len(valid_dataset) // 100
reduced_indices = torch.randperm(len(valid_dataset))[:reduced_val].tolist()
val_dataset_reduced = torch.utils.data.Subset(valid_dataset, reduced_indices)


train_loader = DataLoader(train_dataset_reduced, batch_size = train_args.train_batch_size,
                         shuffle = True, num_workers = train_args.cpu_workers,
                        collate_fn = synth90k_collate_fn)
valid_loader = DataLoader(val_dataset_reduced, batch_size = train_args.eval_batch_size,
                         shuffle = True, num_workers = train_args.cpu_workers,
                        collate_fn = synth90k_collate_fn)
# test_loader = DataLoader(test_dataset, batch_size = train_args.eval_batch_size,
#                          shuffle = True, num_workers = train_args.cpu_workers,
#                         collate_fn = synth90k_collate_fn)

Loading Dataset with mode: train
Loading Dataset with mode: val


In [13]:
# Size of the random training subset (len(train_dataset) // 100 samples).
len(train_dataset_reduced)

72246

In [16]:
# One output class per character in the label map, plus one extra class
# (nn.CTCLoss uses index 0 as the blank label by default).
num_classes = len(Synth90kDataset.LABEL2CHAR) + 1
crnn = CRNN(1, train_args.img_height, train_args.img_width, 
            num_classes = num_classes,
            leaky_relu = train_args.leaky_relu, 
            map_to_seq = train_args.map_to_seq,
            lstm_hidden = train_args.lstm_hidden).to(device)
print(f"The number of parameters in this model are: {count_parameters(crnn)}")

# Optionally resume from saved weights. The path lives on `train_args`;
# the bare name `reload_checkpoint` is not defined anywhere in the notebook.
if train_args.reload_checkpoint:
    crnn.load_state_dict(torch.load(train_args.reload_checkpoint, map_location=device))

optimizer = optim.Adadelta(crnn.parameters(), lr = train_args.learning_rate, rho = 0.9)
# reduction='sum' adds the loss over the batch; zero_infinity=True replaces
# infinite losses (and their gradients) with zero.
criterion = nn.CTCLoss(reduction = 'sum',  zero_infinity = True).to(device)

The number of parameters in this model are: 7839077


In [18]:
num_epochs = train_args.epochs

# Per-epoch history. Only the training lists are filled by this cell;
# val_loss / val_acc stay empty because no validation pass is run here.
train_loss, val_loss = [], []
train_acc, val_acc = [], []

# Give each bar its own position so they do not compete for the same row.
epoch_bar = tqdm(desc = 'Epoch',
                 total = num_epochs, position = 0)
train_bar = tqdm(desc = 'Training', total = len(train_loader),
                 position = 1, leave = True)
# TODO: add a validation bar / validation pass over `valid_loader`.

try:
    for epoch in range(num_epochs):
        epoch_bar.set_description(f'Epoch {epoch + 1}/{num_epochs}')

        crnn.train()
        running_loss = 0.0   # running mean of the per-batch loss
        running_acc = 0.0    # running mean of the per-batch accuracy (%)
        total_loss = 0.0     # sum of the per-batch losses
        total_acc = 0.0      # number of exactly matching samples so far
        total = 0            # number of samples seen so far

        for i, batch in enumerate(train_loader):

            images, targets, target_lengths = getting_all_batches(batch, device)
            batch_size = batch['images'].size(0)

            preds = crnn(images)
            preds = preds.permute(1, 0, 2) #(seq_len, batch, num_classes)
            seq_length = preds.size(0)
            # NOTE(review): nn.CTCLoss expects log-probabilities as input.
            # Confirm that CRNN.forward already applies log_softmax over the
            # class dimension; it is not applied anywhere in this notebook.

            # Every sample in the batch uses the full output sequence length.
            preds_lengths = torch.full(size = (batch_size, ), 
                                       fill_value = seq_length, 
                                       dtype = torch.long,
                                       device = device)

            loss_t = calculate_loss(preds, preds_lengths, targets, target_lengths,
                                   optimizer , criterion)

            running_loss += (loss_t - running_loss) / (i + 1)
            total_loss += loss_t 
            total += batch_size 

            # NOTE(review): the logged accuracy stays at 0.00% while the loss
            # falls; check that `ctc_decoder` receives the layout it expects
            # and returns label lists comparable to the integer targets.
            num_correct = calculate_accuracy(preds, preds_lengths, 
                                              targets, target_lengths, 
                                             decode_method = train_args.decode_method,
                                             beam_size = train_args.beam_size)
            acc_t = num_correct / batch_size * 100
            running_acc += (acc_t - running_acc) / (i + 1)
            total_acc += num_correct 

            # refresh=False stops set_postfix from forcing a redraw on every
            # batch; update() then redraws at tqdm's own rate limit, which
            # keeps the notebook below the IOPub message-rate limit.
            train_bar.set_postfix(loss = running_loss,
                                  acc = f"{running_acc:.2f}%",
                                  epoch = epoch + 1,
                                  refresh = False)
            train_bar.update()

        # The criterion uses reduction='sum', so this is the mean of the
        # per-batch summed losses, not a per-sample loss.
        current_loss = total_loss / len(train_loader)
        current_acc = total_acc / total * 100
        train_loss.append(current_loss)
        train_acc.append(current_acc)

        print("========================================")
        print("\033[1;34m" + f"Epoch {epoch + 1}/{num_epochs}" + "\033[0m")
        print(f"Train Loss: {current_loss:.2f}\nTrain Acc: {current_acc:.2f}%")

        # reset() zeroes the counter and the timer and redraws the bar;
        # assigning `n = 0` directly leaves the display and rate stale.
        train_bar.reset()
        epoch_bar.update()
finally:
    # Release the progress bars even when training is interrupted.
    train_bar.close()
    epoch_bar.close()

Epoch:   0%|          | 0/1000 [00:00<?, ?it/s]

Training:   0%|          | 0/2258 [00:00<?, ?it/s]

[1;34mEpoch 1/1000[0m
Train Loss: 119.27
Train Acc: 0.00%
[1;34mEpoch 2/1000[0m
Train Loss: 89.27
Train Acc: 0.00%
[1;34mEpoch 3/1000[0m
Train Loss: 62.16
Train Acc: 0.00%
[1;34mEpoch 4/1000[0m
Train Loss: 43.84
Train Acc: 0.00%
[1;34mEpoch 5/1000[0m
Train Loss: 32.98
Train Acc: 0.00%


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/conda/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/conda/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
KeyboardInterrupt


KeyboardInterrupt: 