In [1]:
!pip install jiwer
!pip install matplotlib
!pip install transformers
# !pip install opencv-python
# !pip install flash-attn --no-build-isolation
# !pip install --upgrade numpy scipy
# !pip install peft

[0m

In [2]:
import torch
from torch import nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torch.amp import autocast, GradScaler
import torchvision
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms.functional as F
import transformers
from transformers import AutoModel

# from tqdm.notebook import tqdm
from tqdm import tqdm
from PIL import Image
import json
import os
import numpy as np
import random
from jiwer import wer
# import peft
# from peft import get_peft_model, LoraConfig, TaskType

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [3]:
def show_sequence(sequence, NUM_FRAMES):
    columns = 4
    rows = (NUM_FRAMES + 1) // (columns)
    fig = plt.figure(figsize=(32, (16 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(rows * columns):
        plt.subplot(gs[j])
        plt.axis("off")
        frames = sequence[j].permute(1,2,0).numpy()
        frames = frames/ frames.max()
        plt.imshow(frames)

    plt.show()

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
random.seed(42)

In [5]:
train_csv = "/workspace/datasets/phoenix/annotations/manual/train.corpus.csv"
test_csv = "/workspace/datasets/phoenix/annotations/manual/test.corpus.csv"
dev_csv = "/workspace/datasets/phoenix/annotations/manual/dev.corpus.csv"

train_paths = "/workspace/datasets/phoenix/fullFrame-210x260px/train"
test_paths = "/workspace/datasets/phoenix/fullFrame-210x260px/test"
dev_paths =  "/workspace/datasets/phoenix/fullFrame-210x260px/dev"

In [6]:
max_frames = 96
num_workers = 5
batch_size = 2
prefetch_factor = 5

In [7]:
word_to_idx = { '<p>':0}
idx_to_word = ['<p>']

arr_train = np.loadtxt(train_csv, delimiter='|', dtype='str')
arr_train = np.delete(arr_train,0,0)
arr_test = np.loadtxt(test_csv, delimiter='|', dtype='str')
arr_test = np.delete(arr_test,0,0)
arr_dev = np.loadtxt(dev_csv, delimiter='|', dtype='str')
arr_dev = np.delete(arr_dev,0,0)

arr = np.concatenate((arr_train, arr_test, arr_dev), axis=0)

for sentence in arr:
    for word in sentence[3].split(' '):
        if word not in idx_to_word:
            idx_to_word.append(word)
            word_to_idx[word] = len(idx_to_word)-1

In [8]:
len(idx_to_word), len(word_to_idx)

(1297, 1297)

In [9]:
class rwth_phoenix(Dataset):
    def __init__(self, csv, data_path, frame_transform, video_transform, input_fps, output_fps, max_frames, stride, word_dict):

        temp = np.loadtxt(csv, delimiter='|', dtype='str')
        self.csv = np.delete(temp, 0, 0)
        self.word_dict = word_dict
        self.data_path = data_path

        self.frame_transform = frame_transform
        self.video_transform = video_transform

        self.input_fps = input_fps
        self.output_fps = output_fps
        self.max_frames = max_frames
        self.stride = stride


    def __len__(self):
        return len(self.csv)


    def __getitem__(self, idx):
        folder = self.csv[idx][1].split('/')
        label = self.csv[idx][3].split(' ')

        words = []
        for word in label:
            words.append(self.word_dict[word])
        label = torch.tensor(words)
        image_folder_path = os.path.join(self.data_path, folder[0], folder[1])

        images = sorted(os.listdir(image_folder_path))
        end = len(images)

        step = self.input_fps/random.choice(self.output_fps)
        image_list = []
        if int(end//step +1) <= self.max_frames:
            frame_num, num = 0, 0
            while frame_num < end:
                num+=1
                if self.stride and num%self.stride == 0:
                    image_list.append(str(int(frame_num))+'a')

                else:
                    img = Image.open(os.path.join(image_folder_path, images[int(frame_num)]))


                    tensor_img = self.frame_transform(img)

                    image_list.append(tensor_img)
                frame_num += step

            c, h, w = image_list[-1].shape
            while len(image_list) < self.max_frames:
                image_list.append(torch.zeros(c,h,w))

            tensor_video = torch.stack(image_list[:self.max_frames])


            if self.video_transform:
                tensor_video = self.video_transform(tensor_video)


        else:
            frame_positions = np.linspace(0, end, self.max_frames, endpoint=False)
            num = 0
            for n in frame_positions:
                num+=1
                if self.stride and num%self.stride == 0:
                    image_list.append(str(int(n))+'a')

                else:
                    img = Image.open(os.path.join(image_folder_path, images[int(n)]))


                    tensor_img= self.frame_transform(img)
                    image_list.append(tensor_img)

            tensor_video = torch.stack(image_list[:self.max_frames])

            if self.video_transform:
                tensor_video = self.video_transform(tensor_video)
            

        
        return tensor_video, label

In [10]:
class TemporalRescale(object):
    def __init__(self, temp_scaling=0.2):
        self.min_len = 32
        self.max_len = 230
        self.L = 1.0 - temp_scaling
        self.U = 1.0 + temp_scaling

    def __call__(self, clip):
        vid_len = len(clip)
        new_len = int(vid_len * (self.L + (self.U - self.L) * np.random.random()))
        if new_len < self.min_len:
            new_len = self.min_len
        if new_len > self.max_len:
            new_len = self.max_len
        if (new_len - 4) % 4 != 0:
            new_len += 4 - (new_len - 4) % 4
        if new_len <= vid_len:
            index = sorted(random.sample(range(vid_len), new_len))
        else:
            index = sorted(random.choices(range(vid_len), k=new_len))
        return clip[index]

In [11]:
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop((224, 224)),
    transforms.ToTensor()
])

image_test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224,224)),
    transforms.ToTensor()
])

video_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    TemporalRescale()
])

video_test_transform = transforms.Compose([
    TemporalRescale(temp_scaling=0)
])

In [12]:
def collate_fn(batch):
    vid, labels = zip(*batch)
    
    # Get video lengths before padding
    vid_lengths = torch.tensor([v.shape[0] for v in vid])
    
    # Pad videos using pad_sequence
    vid = torch.nn.utils.rnn.pad_sequence(
        sequences=vid,
        batch_first=True,
        padding_value=0,
    )
    
    # Pad labels
    labels = torch.nn.utils.rnn.pad_sequence(
        sequences=labels,
        batch_first=True,
        padding_value=0,
    ).long()
    
    return vid, labels, vid_lengths

In [13]:
train_dataset = rwth_phoenix(csv=train_csv,
                       data_path=train_paths,
                        frame_transform=image_transform , video_transform=video_transform, input_fps=25, output_fps=list(range(15,22)), max_frames=max_frames, stride=0, word_dict=word_to_idx)

test_dataset = rwth_phoenix(csv=test_csv,
                       data_path=test_paths,
                        frame_transform=image_test_transform , video_transform=video_test_transform, input_fps=25, output_fps=list(range(15,22)), max_frames=max_frames, stride=0, word_dict=word_to_idx)

dev_dataset = rwth_phoenix(csv=dev_csv,
                       data_path=dev_paths,
                        frame_transform=image_test_transform , video_transform=video_test_transform, input_fps=25, output_fps=list(range(15,22)), max_frames=max_frames, stride=0, word_dict=word_to_idx)

In [14]:
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size, collate_fn=collate_fn, prefetch_factor=prefetch_factor, num_workers=num_workers, pin_memory=True)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size, collate_fn=collate_fn, prefetch_factor=prefetch_factor, num_workers=num_workers, pin_memory=True)
dev_dataloader = DataLoader(dataset=dev_dataset, shuffle=False, batch_size=batch_size, collate_fn=collate_fn, prefetch_factor=prefetch_factor, num_workers=num_workers, pin_memory=True)

In [15]:
# vid, kp_vid, y = next(iter(train_dataloader))

In [16]:
# vid.shape

In [17]:
# show_sequence(vid[0], max_frames)

In [18]:
# show_sequence(kp_vid[0], max_frames)

In [19]:
scaler = GradScaler('cuda')

def update_length(lengths):
    # return ((lengths - 4) // 2 - 4) // 2
    return (lengths - 4) // 2

def decode_pred(tensor, lengths, idx_to_word=idx_to_word):
    n = tensor.shape[0]
    text = []
    for i in range(n):
        st = []
        prev_token = None
        # Only iterate up to the actual length for this sequence
        for j in range(lengths[i]):
            token = tensor[i, j]
            if token != 0 and token != prev_token:
                st.append(idx_to_word[token.item()])
            prev_token = token
        text.append(' '.join(st))
    return text

def decode_target(tensor, idx_to_word=idx_to_word):
    text = []
    for seq in tensor:
        words = [idx_to_word[token.item()] for token in seq if token.item() != 0]
        text.append(' '.join(words))
    return text

def train_step(model, optimizer, dataloader, loss_fn, epoch, device="cuda"):
    model.train()
    train_loss, total_correct_wer = 0, 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}", dynamic_ncols=True)
    
    for batch_idx, (X, y, lengths) in enumerate(progress_bar):
        X, y = X.to(device), y.to(device)
        lengths = update_length(lengths)

        with autocast('cuda'):
            y_logit, auxiliary_logit = model(X, lengths)
            y_logit = y_logit.permute(1, 0, 2)
            auxiliary_logit = auxiliary_logit.permute(1, 0, 2)

            if y_logit.isnan().any():
                print("⚠️ ABORT! NaN found in y_logit!")
                break
            y_length = torch.count_nonzero(y, axis=1)
            
            # Main CTC loss
            ctc_probs = y_logit.log_softmax(-1)
            main_loss = loss_fn(ctc_probs, y, lengths, y_length)
            
            # Auxiliary CTC loss
            aux_ctc_probs = auxiliary_logit.log_softmax(-1)
            aux_loss = loss_fn(aux_ctc_probs, y, lengths, y_length)
            
            # Combined loss
            loss = main_loss + aux_loss
            batch_loss = loss.item()
            train_loss += batch_loss

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        y_pred = torch.argmax(y_logit.permute(1, 0, 2), dim=2)
        total_correct_wer += wer(decode_target(y), decode_pred(y_pred, lengths))

        avg_loss = train_loss / (batch_idx + 1)
        progress_bar.set_postfix(batch_loss=f"{batch_loss:.4f}", avg_loss=f"{avg_loss:.4f}")

        del X, y, y_logit, auxiliary_logit, loss, y_pred

    acc_wer = (total_correct_wer / len(dataloader)) * 100
    avg_loss = train_loss / len(dataloader)

    print(f"Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train WER: {acc_wer:.2f}%")
    return avg_loss, acc_wer



def test_step(model, loss_fn, epoch, dataloader, scheduler, device="cuda"):
    model.eval()
    test_loss, total_correct_wer = 0, 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch} (Eval)", dynamic_ncols=True)

    with torch.no_grad():
        for batch_idx, (X, y, lengths) in enumerate(progress_bar):
            X, y = X.to(device), y.to(device)
            lengths = update_length(lengths)

            with autocast('cuda'):
                y_logit, auxiliary_logit = model(X, lengths)
                y_logit = y_logit.permute(1, 0, 2)
                auxiliary_logit = auxiliary_logit.permute(1, 0, 2)

                if y_logit.isnan().any():
                    print("⚠️ ABORT! NaN found in y_logit!")
                    break

                y_length = torch.count_nonzero(y, axis=1)
                
                # Main CTC loss
                ctc_probs = y_logit.log_softmax(-1)
                main_loss = loss_fn(ctc_probs, y, lengths, y_length)
                
                # Auxiliary CTC loss
                aux_ctc_probs = auxiliary_logit.log_softmax(-1)
                aux_loss = loss_fn(aux_ctc_probs, y, lengths, y_length)
                
                # Combined loss
                loss = main_loss + aux_loss
                batch_loss = loss.item()
                test_loss += batch_loss

            y_pred = torch.argmax(y_logit.permute(1, 0, 2), dim=2)
            total_correct_wer += wer(decode_target(y), decode_pred(y_pred, lengths))

            avg_loss = test_loss / (batch_idx + 1)
            progress_bar.set_postfix(batch_loss=f"{batch_loss:.4f}", avg_loss=f"{avg_loss:.4f}")

            del X, y, y_logit, auxiliary_logit, loss, y_pred

    acc_wer = (total_correct_wer / len(dataloader)) * 100
    avg_loss = test_loss / len(dataloader)

    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step(avg_loss)

    print(f"Epoch {epoch} | Test Loss: {avg_loss:.4f} | Test WER: {acc_wer:.2f}%")
    return avg_loss, acc_wer

In [20]:
class ViViT_SLR(nn.Module):
    def __init__(self,
                 vit_model,
                 vocab_size=len(idx_to_word),
                 batch_first=True,
                 input_size=640,
                 hidden_size=1024,
                 num_layers=2,
                 layer_norm_eps=1e-06,
                 dropout=0.1,
                 frame=max_frames,
                 pad_token=0
                 ):

        super(ViViT_SLR, self).__init__()

        self.vit = vit_model

        self.norm_encoding = nn.LayerNorm(input_size, eps=layer_norm_eps)

        self.temporal_conv = nn.Sequential(
            nn.Conv1d(640, hidden_size, kernel_size=5),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.MaxPool1d(kernel_size=2),

            # nn.Conv1d(hidden_size, hidden_size, kernel_size=5),
            # nn.BatchNorm1d(hidden_size),
            # nn.ReLU(),
            # nn.Dropout(dropout),
            # nn.MaxPool1d(kernel_size=2)
        )

        self.auxiliary_fc =  nn.Linear(hidden_size, vocab_size)

        self.vocab_size = vocab_size
        self.pad_token = pad_token
        self.frame = frame

        self.lstm = nn.LSTM(
            input_size=hidden_size, 
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            batch_first=batch_first, 
            dropout=dropout, 
            bidirectional=True
        )

        self.dropout = nn.Dropout(dropout)

        self.fc_out = nn.Linear(hidden_size*2, vocab_size)
    #     self._init_weights()

    
    # def _init_weights(self):
    #     """Initialize weights using best practices for different layer types"""
        
    #     # Initialize Conv1d layers
    #     for module in self.temporal_conv.modules():
    #         if isinstance(module, nn.Conv1d):
    #             # Kaiming initialization for ReLU activations
    #             nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
    #             if module.bias is not None:
    #                 nn.init.constant_(module.bias, 0)
    #         elif isinstance(module, nn.BatchNorm1d):
    #             nn.init.constant_(module.weight, 1)
    #             nn.init.constant_(module.bias, 0)
        
    #     # Initialize LSTM layers
    #     for name, param in self.lstm.named_parameters():
    #         if 'weight_ih' in name:
    #             # Input-hidden weights: Xavier uniform
    #             nn.init.xavier_uniform_(param.data)
    #         elif 'weight_hh' in name:
    #             # Hidden-hidden weights: Orthogonal
    #             nn.init.orthogonal_(param.data)
    #         elif 'bias' in name:
    #             # Initialize biases to zero, except forget gate bias
    #             nn.init.constant_(param.data, 0)
    #             # Set forget gate bias to 1 (helps with gradient flow)
    #             n = param.size(0)
    #             param.data[n//4:n//2].fill_(1.0)
        
    #     # Initialize LayerNorm
    #     nn.init.constant_(self.norm_encoding.weight, 1)
    #     nn.init.constant_(self.norm_encoding.bias, 0)
        
    #     # Initialize final linear layer
    #     nn.init.xavier_uniform_(self.fc_out.weight)
    #     nn.init.constant_(self.fc_out.bias, 0)

    #     nn.init.xavier_uniform_(self.auxiliary_fc.weight)
    #     nn.init.constant_(self.auxiliary_fc.bias, 0)



    def forward(self, vid, lengths):
        
        B, T, C, H, W = vid.shape
        encoded_output = self.vit(vid.view(B*T, C, H, W)).pooler_output  # (B*T, 768)
        encoded_output = encoded_output.view(B, T, 640)  # (B, T, 768)
        encoded_output = self.norm_encoding(encoded_output)
        encoded_output = self.temporal_conv(encoded_output.permute(0,2,1)).permute(0,2,1)

        auxiliary = self.auxiliary_fc(encoded_output)

        new_T = encoded_output.size(1)

        packed_input = nn.utils.rnn.pack_padded_sequence(
            encoded_output,
            lengths.cpu(), 
            batch_first=True,
            enforce_sorted=False 
        )

        packed_output, _ = self.lstm(packed_input)  # (B, T, hidden_size * 2)

        decoded_output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output,
            batch_first=True,
            total_length=new_T 
        )

        logits = self.fc_out(self.dropout(decoded_output))  # (B, T, vocab_size)
        return logits, auxiliary

In [21]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)
device = "cuda"

vit_model = AutoModel.from_pretrained("apple/mobilevit-small", use_safetensors=True)

# peft_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     target_modules=[
#         "attention.attention.query",
#         "attention.attention.key", 
#         "attention.attention.value",
#         "attention.output.dense"
#     ],
#     lora_dropout=0.1,
#     bias="none"
# )

# vit_model = get_peft_model(vit_model, peft_config)
# vit_model.print_trainable_parameters()

model = ViViT_SLR(vit_model=vit_model).to(device)

# model = nn.DataParallel(model)

# checkpoint = torch.load(f="/home/jovyan/A_folder/Mobile_ViT_LSTM_15_epochs_40.83_wer.pth")
# new_state_dict = checkpoint["model_state_dict"]
# model.load_state_dict(new_state_dict, strict=True)
# print("Loaded checkpoint weights!")

loss_fn = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# param_groups = get_param_groups(model)
# optimizer = torch.optim.AdamW(param_groups)
optimizer = torch.optim.AdamW(params=model.parameters(),
                            lr=1e-4,
                            weight_decay=1e-3)

scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    min_lr=1e-6,
)

# optimizer.load_state_dict(checkpoint["optimizer"])
# scheduler.load_state_dict(checkpoint["scheduler"])

# del checkpoint, new_state_dict



In [None]:
current = 0
num_epochs = 40

start = 1 + current
epochs = num_epochs + current + 1
best = 36

print(f"Running from {start} to {epochs-1}")

torch.manual_seed(42)
torch.cuda.manual_seed(42)

print("Ready to train!!")

log = open('Mobile_ViT_LSTM_readings_aux-loss.txt', 'w')
log.write('Epoch | Train loss | Train WER | Test loss | Test WER\n')

for epoch in range(start, epochs):
    train_loss, train_wer = train_step(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        epoch=epoch,
        dataloader=train_dataloader
    )

    test_loss, test_wer = test_step(
        model=model,
        loss_fn=loss_fn,
        epoch=epoch,
        dataloader=test_dataloader,
        scheduler=scheduler
    )
    log.write(f'{epoch} | {train_loss:.4f} | {train_wer:.2f}% | {test_loss:.4f} | {test_wer:.2f}%\n')
    if epoch%5==0 or test_wer<best:
        best = test_wer
        checkpoint = {
            "model_state_dict" : model.state_dict(),
            "train_wer" : train_wer,
            "test_wer" : test_wer,
            "train_loss" : train_loss,
            "test_loss" : test_loss,
            "epoch" : epoch,
            "optimizer" : optimizer.state_dict(),
            "scheduler" : scheduler.state_dict()
        }
        torch.save(obj=checkpoint, f=f"Mobile_ViT_LSTM_{epoch}_epochs_{test_wer:.2f}_wer.pth")
        del checkpoint

log.close()

Running from 1 to 40
Ready to train!!


Epoch 1: 100%|██████████| 2836/2836 [10:42<00:00,  4.41it/s, avg_loss=9.1674, batch_loss=7.5902]  


Epoch 1 | Train Loss: 9.1674 | Train WER: 91.30%


Epoch 1 (Eval): 100%|██████████| 315/315 [00:42<00:00,  7.41it/s, avg_loss=7.1640, batch_loss=10.2118]


Epoch 1 | Test Loss: 7.1640 | Test WER: 80.68%


Epoch 2: 100%|██████████| 2836/2836 [10:33<00:00,  4.48it/s, avg_loss=6.0467, batch_loss=4.9517] 


Epoch 2 | Train Loss: 6.0467 | Train WER: 69.83%


Epoch 2 (Eval): 100%|██████████| 315/315 [00:43<00:00,  7.22it/s, avg_loss=5.9859, batch_loss=10.2641]


Epoch 2 | Test Loss: 5.9859 | Test WER: 54.30%


Epoch 3: 100%|██████████| 2836/2836 [10:31<00:00,  4.49it/s, avg_loss=5.0919, batch_loss=3.7540] 


Epoch 3 | Train Loss: 5.0919 | Train WER: 57.09%


Epoch 3 (Eval): 100%|██████████| 315/315 [00:43<00:00,  7.26it/s, avg_loss=5.0182, batch_loss=8.4547] 


Epoch 3 | Test Loss: 5.0182 | Test WER: 46.89%


Epoch 4: 100%|██████████| 2836/2836 [10:31<00:00,  4.49it/s, avg_loss=4.1468, batch_loss=3.3789] 


Epoch 4 | Train Loss: 4.1468 | Train WER: 46.79%


Epoch 4 (Eval): 100%|██████████| 315/315 [00:44<00:00,  7.15it/s, avg_loss=4.6934, batch_loss=9.1200] 


Epoch 4 | Test Loss: 4.6934 | Test WER: 42.48%


Epoch 5: 100%|██████████| 2836/2836 [10:31<00:00,  4.49it/s, avg_loss=3.8657, batch_loss=2.5452] 


Epoch 5 | Train Loss: 3.8657 | Train WER: 43.71%


Epoch 5 (Eval): 100%|██████████| 315/315 [00:43<00:00,  7.21it/s, avg_loss=4.4961, batch_loss=9.0725] 


Epoch 5 | Test Loss: 4.4961 | Test WER: 40.85%


Epoch 6: 100%|██████████| 2836/2836 [10:32<00:00,  4.48it/s, avg_loss=3.3770, batch_loss=2.3104]


Epoch 6 | Train Loss: 3.3770 | Train WER: 38.40%


Epoch 6 (Eval): 100%|██████████| 315/315 [00:42<00:00,  7.50it/s, avg_loss=4.4099, batch_loss=9.0800] 


Epoch 6 | Test Loss: 4.4099 | Test WER: 40.11%


Epoch 7: 100%|██████████| 2836/2836 [10:32<00:00,  4.48it/s, avg_loss=3.2880, batch_loss=3.7880] 


Epoch 7 | Train Loss: 3.2880 | Train WER: 37.59%


Epoch 7 (Eval): 100%|██████████| 315/315 [00:41<00:00,  7.55it/s, avg_loss=4.1758, batch_loss=9.0245] 


Epoch 7 | Test Loss: 4.1758 | Test WER: 37.74%


Epoch 8: 100%|██████████| 2836/2836 [10:31<00:00,  4.49it/s, avg_loss=2.9644, batch_loss=3.4560] 


Epoch 8 | Train Loss: 2.9644 | Train WER: 34.28%


Epoch 8 (Eval): 100%|██████████| 315/315 [00:41<00:00,  7.60it/s, avg_loss=4.0418, batch_loss=7.4363] 


Epoch 8 | Test Loss: 4.0418 | Test WER: 37.48%


Epoch 9: 100%|██████████| 2836/2836 [10:33<00:00,  4.48it/s, avg_loss=2.7409, batch_loss=2.2930]


Epoch 9 | Train Loss: 2.7409 | Train WER: 31.45%


Epoch 9 (Eval): 100%|██████████| 315/315 [00:41<00:00,  7.50it/s, avg_loss=3.9435, batch_loss=7.2266] 


Epoch 9 | Test Loss: 3.9435 | Test WER: 36.09%


Epoch 10: 100%|██████████| 2836/2836 [10:34<00:00,  4.47it/s, avg_loss=2.4340, batch_loss=1.9898] 


Epoch 10 | Train Loss: 2.4340 | Train WER: 28.14%


Epoch 10 (Eval): 100%|██████████| 315/315 [00:42<00:00,  7.45it/s, avg_loss=4.1142, batch_loss=8.6746] 


Epoch 10 | Test Loss: 4.1142 | Test WER: 38.07%


Epoch 11: 100%|██████████| 2836/2836 [10:35<00:00,  4.46it/s, avg_loss=2.4500, batch_loss=1.8065] 


Epoch 11 | Train Loss: 2.4500 | Train WER: 28.61%


Epoch 11 (Eval): 100%|██████████| 315/315 [00:40<00:00,  7.74it/s, avg_loss=4.1026, batch_loss=8.3597] 


Epoch 11 | Test Loss: 4.1026 | Test WER: 36.34%


Epoch 12: 100%|██████████| 2836/2836 [10:30<00:00,  4.50it/s, avg_loss=2.1570, batch_loss=2.8706] 


Epoch 12 | Train Loss: 2.1570 | Train WER: 24.93%


Epoch 12 (Eval): 100%|██████████| 315/315 [00:37<00:00,  8.42it/s, avg_loss=4.0430, batch_loss=8.5130] 


Epoch 12 | Test Loss: 4.0430 | Test WER: 35.83%


Epoch 13: 100%|██████████| 2836/2836 [10:33<00:00,  4.48it/s, avg_loss=1.7962, batch_loss=2.1441]


Epoch 13 | Train Loss: 1.7962 | Train WER: 19.95%


Epoch 13 (Eval): 100%|██████████| 315/315 [00:37<00:00,  8.42it/s, avg_loss=4.0109, batch_loss=7.9862] 


Epoch 13 | Test Loss: 4.0109 | Test WER: 35.54%


Epoch 14: 100%|██████████| 2836/2836 [10:30<00:00,  4.50it/s, avg_loss=1.7685, batch_loss=2.1068] 


Epoch 14 | Train Loss: 1.7685 | Train WER: 19.58%


Epoch 14 (Eval): 100%|██████████| 315/315 [00:37<00:00,  8.39it/s, avg_loss=3.9885, batch_loss=7.5240] 


Epoch 14 | Test Loss: 3.9885 | Test WER: 34.90%


Epoch 15: 100%|██████████| 2836/2836 [10:31<00:00,  4.49it/s, avg_loss=1.6041, batch_loss=0.4994] 


Epoch 15 | Train Loss: 1.6041 | Train WER: 17.40%


Epoch 15 (Eval): 100%|██████████| 315/315 [00:38<00:00,  8.25it/s, avg_loss=4.1244, batch_loss=8.4539] 


Epoch 15 | Test Loss: 4.1244 | Test WER: 35.52%


Epoch 16: 100%|██████████| 2836/2836 [10:35<00:00,  4.46it/s, avg_loss=1.5025, batch_loss=1.7602] 


Epoch 16 | Train Loss: 1.5025 | Train WER: 15.92%


Epoch 16 (Eval): 100%|██████████| 315/315 [00:37<00:00,  8.39it/s, avg_loss=4.0013, batch_loss=7.4720] 


Epoch 16 | Test Loss: 4.0013 | Test WER: 34.93%


Epoch 17: 100%|██████████| 2836/2836 [10:32<00:00,  4.49it/s, avg_loss=1.4130, batch_loss=1.8096] 


Epoch 17 | Train Loss: 1.4130 | Train WER: 14.44%


Epoch 17 (Eval): 100%|██████████| 315/315 [00:38<00:00,  8.26it/s, avg_loss=4.1060, batch_loss=9.2641] 


Epoch 17 | Test Loss: 4.1060 | Test WER: 35.40%


Epoch 18: 100%|██████████| 2836/2836 [10:31<00:00,  4.49it/s, avg_loss=1.3652, batch_loss=1.1492] 


Epoch 18 | Train Loss: 1.3652 | Train WER: 13.87%


Epoch 18 (Eval): 100%|██████████| 315/315 [00:37<00:00,  8.36it/s, avg_loss=4.0941, batch_loss=8.2360] 


Epoch 18 | Test Loss: 4.0941 | Test WER: 35.11%


Epoch 19: 100%|██████████| 2836/2836 [10:31<00:00,  4.49it/s, avg_loss=1.2652, batch_loss=0.8660] 


Epoch 19 | Train Loss: 1.2652 | Train WER: 12.47%


Epoch 19 (Eval): 100%|██████████| 315/315 [00:37<00:00,  8.37it/s, avg_loss=4.0748, batch_loss=9.1271] 


Epoch 19 | Test Loss: 4.0748 | Test WER: 34.84%


Epoch 20: 100%|██████████| 2836/2836 [10:31<00:00,  4.49it/s, avg_loss=1.2457, batch_loss=3.0610] 


Epoch 20 | Train Loss: 1.2457 | Train WER: 12.04%


Epoch 20 (Eval): 100%|██████████| 315/315 [00:38<00:00,  8.13it/s, avg_loss=4.0792, batch_loss=8.2187] 


Epoch 20 | Test Loss: 4.0792 | Test WER: 35.14%


Epoch 21:   9%|▉         | 260/2836 [00:58<09:01,  4.75it/s, avg_loss=1.2275, batch_loss=0.7149]

In [None]:
checkpoint = {
    "model_state_dict" : model.state_dict(),
    "train_wer" : train_wer,
    "test_wer" : test_wer,
    "train_loss" : train_loss,
    "test_loss" : test_loss,
    "epoch" : epochs-1,
    "optimizer" : optimizer.state_dict(),
    "scheduler" : scheduler.state_dict()
}

In [None]:
torch.save(obj=checkpoint, f=f"Mobile_ViT_LSTM_{epochs-1}_epochs.pth")