In [None]:
!pip install jiwer
!pip install matplotlib
!pip install transformers
!pip install opencv-python
# !pip install flash-attn --no-build-isolation

In [None]:
import torch
from torch import nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torch.amp import autocast, GradScaler
import torchvision
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms.functional as F

from transformers import ViTModel

# from tqdm.notebook import tqdm
from tqdm import tqdm
from PIL import Image
import json
import os
import numpy as np
import cv2
import random
from jiwer import wer

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [None]:
def show_sequence(sequence, NUM_FRAMES):
    """Display the first ``NUM_FRAMES`` frames of a clip in a 4-column grid.

    Args:
        sequence: indexable of per-frame tensors; each frame goes through
            ``permute(1, 2, 0)`` before plotting, so frames are presumably
            (C, H, W) -- TODO confirm against the caller.
        NUM_FRAMES: how many frames to draw; clipped to ``len(sequence)``.
    """
    columns = 4
    num_frames = min(NUM_FRAMES, len(sequence))
    if num_frames <= 0:
        return
    # Ceiling division: the previous ``(NUM_FRAMES + 1) // columns`` silently
    # dropped a partial last row (and produced a 0-row grid for 1 frame).
    rows = (num_frames + columns - 1) // columns
    plt.figure(figsize=(32, (16 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(num_frames):
        plt.subplot(gs[j])
        plt.axis("off")
        frame = sequence[j].permute(1, 2, 0).numpy()
        peak = frame.max()
        # Zero-padded frames have max 0; dividing would fill them with NaN.
        if peak > 0:
            frame = frame / peak
        plt.imshow(frame)

    plt.show()

In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed Python's `random`, which drives the fps choice and flip augmentation below.
random.seed(42)

In [None]:
# Phoenix dataset layout on the Kaggle input mount.
_phoenix_root = "/kaggle/input/hmmmmmm/phoenix"
_annotation_dir = f"{_phoenix_root}/annotations/manual"
_frame_dir = f"{_phoenix_root}/fullFrame-210x260px"

# Per-split annotation files ('|'-separated corpus CSVs).
train_csv = f"{_annotation_dir}/train.corpus.csv"
test_csv = f"{_annotation_dir}/test.corpus.csv"
dev_csv = f"{_annotation_dir}/dev.corpus.csv"

# Per-split roots of the extracted RGB frame folders.
train_paths = f"{_frame_dir}/train"
test_paths = f"{_frame_dir}/test"
dev_paths = f"{_frame_dir}/dev"

In [None]:
max_frames = 96       # every clip is subsampled or zero-padded to this many frames
num_workers = 4       # DataLoader worker processes per split
batch_size = 2        # clips per batch
prefetch_factor = 4   # batches prefetched per worker

In [None]:
# Gloss vocabulary shared by every split. Index 0 ('<p>') is also the value
# collate_fn pads targets with and the CTC blank id used by the loss.
word_to_idx = {'<p>': 0}
idx_to_word = ['<p>']

# Each corpus CSV is '|'-separated; row 0 is dropped (presumably a header).
arr_train = np.loadtxt(train_csv, delimiter='|', dtype='str')
arr_train = np.delete(arr_train, 0, 0)
arr_test = np.loadtxt(test_csv, delimiter='|', dtype='str')
arr_test = np.delete(arr_test, 0, 0)
arr_dev = np.loadtxt(dev_csv, delimiter='|', dtype='str')
arr_dev = np.delete(arr_dev, 0, 0)

# All splits contribute so every dev/test gloss has an id at lookup time.
arr = np.concatenate((arr_train, arr_test, arr_dev), axis=0)

for sentence in arr:
    # split(' ') rather than split() so tokenisation matches rwth_phoenix exactly.
    for word in sentence[3].split(' '):
        # Dict membership is O(1); the old `in idx_to_word` list scan made the
        # build quadratic in vocabulary size. Ids are assigned identically.
        if word not in word_to_idx:
            word_to_idx[word] = len(idx_to_word)
            idx_to_word.append(word)

In [None]:
# Sanity check: the list and dict views of the vocabulary should agree in size.
vocab_sizes = (len(idx_to_word), len(word_to_idx))
vocab_sizes

In [None]:
class rwth_phoenix(Dataset):
    """RWTH-Phoenix clip dataset yielding ``(video, kp_video, label)``.

    Each item is one row of a '|'-separated corpus CSV (row 0 is dropped,
    presumably a header). Column 1 is split on '/', and its first two parts
    name the frame folder under ``data_path`` (RGB) and under ``kp_path``
    (keypoint images, opened with the same file names as the RGB frames).
    Column 3 is a space-separated gloss sequence mapped through ``word_dict``.

    Frames are sampled every ``input_fps / random.choice(output_fps)`` source
    frames. Clips that fit in ``max_frames`` at that rate are zero-padded at
    the end. Longer clips are subsampled to exactly ``max_frames`` evenly
    spaced frames.

    Args:
        csv: path to the corpus CSV.
        data_path: root of the RGB frame folders.
        kp_path: root of the keypoint frame folders.
        frame_transform: ``f(img, kp) -> (img, kp)`` applied to each PIL frame pair.
        video_transform: ``f(video, kp_video) -> (video, kp_video)`` applied to the stacked tensors.
        input_fps: frame rate of the stored frames.
        output_fps: candidate target rates; one is drawn per item.
        max_frames: temporal length of every returned clip.
        stride: if non-zero, every ``stride``-th sampled frame is blanked (zeroed).
        word_dict: gloss -> index mapping.
    """

    def __init__(self, csv, data_path, kp_path, frame_transform, video_transform, input_fps, output_fps, max_frames, stride, word_dict):

        temp = np.loadtxt(csv, delimiter='|', dtype='str')
        self.csv = np.delete(temp, 0, 0)
        self.word_dict = word_dict
        self.data_path = data_path
        self.kp_path = kp_path

        self.frame_transform = frame_transform
        self.video_transform = video_transform

        self.input_fps = input_fps
        self.output_fps = output_fps
        self.max_frames = max_frames
        self.stride = stride


    def __len__(self):
        return len(self.csv)


    def _sample_positions(self, num_images):
        """Return the (possibly fractional) source-frame positions to load for one item."""
        step = self.input_fps / random.choice(self.output_fps)
        if int(num_images // step + 1) <= self.max_frames:
            # Short clip: walk through it at the drawn output rate; padded later.
            positions = []
            frame_num = 0
            while frame_num < num_images:
                positions.append(frame_num)
                frame_num += step
            return positions
        # Long clip: exactly max_frames evenly spaced positions in [0, num_images).
        return np.linspace(0, num_images, self.max_frames, endpoint=False)


    def _load_frame(self, image_dir, kp_dir, filename):
        """Load one RGB/keypoint frame pair, transform them jointly, and convert to tensors."""
        with Image.open(os.path.join(image_dir, filename)) as img:
            with Image.open(os.path.join(kp_dir, filename)) as kp:
                img_out, kp_out = self.frame_transform(img, kp)
                # Convert while the source files are still open.
                return F.to_tensor(img_out), F.to_tensor(kp_out)


    def __getitem__(self, idx):
        row = self.csv[idx]
        folder = row[1].split('/')
        label = torch.tensor([self.word_dict[word] for word in row[3].split(' ')])

        image_dir = os.path.join(self.data_path, folder[0], folder[1])
        kp_dir = os.path.join(self.kp_path, folder[0], folder[1])
        images = sorted(os.listdir(image_dir))
        if not images:
            # Previously surfaced as an opaque IndexError on image_list[-1].
            raise RuntimeError(f"No frames found in {image_dir}")

        frames, kp_frames = [], []
        for num, pos in enumerate(self._sample_positions(len(images)), start=1):
            img_t, kp_t = self._load_frame(image_dir, kp_dir, images[int(pos)])
            if self.stride and num % self.stride == 0:
                # NOTE(review): the original appended a string placeholder here
                # (and nothing to the keypoint list), so torch.stack failed for
                # any non-zero stride. Interpreted as "blank this frame"; a zero
                # frame keeps the clip length and both streams aligned.
                img_t, kp_t = torch.zeros_like(img_t), torch.zeros_like(kp_t)
            frames.append(img_t)
            kp_frames.append(kp_t)

        # Zero-pad short clips up to max_frames. Each stream is padded with its
        # own frame shape because keypoint images may not share the RGB channel
        # count (the old code sized keypoint padding from the RGB frame).
        while len(frames) < self.max_frames:
            frames.append(torch.zeros_like(frames[-1]))
            kp_frames.append(torch.zeros_like(kp_frames[-1]))

        tensor_video = torch.stack(frames[:self.max_frames])
        tensor_kp_vid = torch.stack(kp_frames[:self.max_frames])
        tensor_video, tensor_kp_vid = self.video_transform(tensor_video, tensor_kp_vid)

        return tensor_video, tensor_kp_vid, label

In [None]:
def image_transform(img, kp_img):
    """Training augmentation applied identically to a frame and its keypoint image.

    Both images receive the same random 248x200 (h x w) crop, a resize to
    224x224, and the same random rotation in [-5, 5] degrees. Sharing the
    parameters keeps the keypoint image registered with the RGB frame.

    NOTE(review): the crop window and angle are re-drawn on every call, and
    the dataset calls this once per frame. Frames of a single clip therefore
    get different crops and rotations, which causes temporal jitter. Sampling
    once per clip would require the dataset to pass shared parameters.
    """
    # Draw the shared parameters first (crop, then angle -- same RNG order as before).
    top, left, height, width = transforms.RandomCrop.get_params(img, output_size=(248, 200))
    angle = transforms.RandomRotation.get_params([-5, 5])

    def _augment(picture):
        cropped = F.crop(picture, top, left, height, width)
        resized = F.resize(cropped, (224, 224))
        return F.rotate(resized, angle)

    return _augment(img), _augment(kp_img)

def video_transform(video, kp_video, p=0.5):
    """Randomly mirror a clip and its keypoint clip together along the width axis.

    With probability ``p`` both tensors are flipped on dim 3, which is the
    width for the (T, C, H, W) stacks built by the dataset. Otherwise both are
    returned unchanged.

    NOTE(review): mirroring swaps the signer's left and right hands; confirm
    this is acceptable for the gloss labels.
    """
    if random.random() >= p:
        return video, kp_video
    return torch.flip(video, dims=[3]), torch.flip(kp_video, dims=[3])

In [None]:
def image_test_transform(img, kp_img):
    """Deterministic eval preprocessing: resize the frame and keypoint image to 224x224."""
    target_size = (224, 224)
    return F.resize(img, target_size), F.resize(kp_img, target_size)

def video_test_transform(video, kp_video, p=0):
    """Eval-time counterpart of the training flip; with the default ``p=0`` it never flips.

    One ``random.random()`` draw is still consumed per call, keeping the RNG
    stream identical to the training-style transform.
    """
    should_flip = random.random() < p
    if not should_flip:
        return video, kp_video
    return torch.flip(video, dims=[3]), torch.flip(kp_video, dims=[3])

In [None]:
def collate_fn(batch):
    """Stack equally shaped clips and right-pad variable-length label sequences.

    Args:
        batch: sequence of ``(video, kp_video, label)`` samples; all videos
            (and all keypoint videos) must share a shape.

    Returns:
        ``(videos, kp_videos, labels)``, where ``labels`` is a ``(B, L_max)``
        long tensor right-padded with 0 (the '<p>' id).
    """
    video_seq, kp_seq, label_seq = zip(*batch)
    stacked_videos = torch.stack(video_seq)
    stacked_kp = torch.stack(kp_seq)
    padded_labels = torch.nn.utils.rnn.pad_sequence(
        label_seq, batch_first=True, padding_value=0
    )
    return stacked_videos, stacked_kp, padded_labels.long()

In [None]:
# Keypoint images live in a separate Kaggle input with the same per-split layout.
kp_root = '/kaggle/input/hmmmmmm-keypoint-vid/keypoints_vid/phoenix/fullFrame-210x260px'

# Training draws a random target frame rate per item (temporal augmentation).
train_output_fps = list(range(15, 22))
# Evaluation used to draw from the same random range, so the test/dev inputs
# (and hence the reported WER) changed between runs for clips short enough to
# be resampled at the drawn rate. A single fixed rate -- the midpoint of the
# training range -- makes evaluation deterministic.
# NOTE(review): WERs measured before this change are not directly comparable.
eval_output_fps = [18]

train_dataset = rwth_phoenix(csv=train_csv,
                             data_path=train_paths,
                             kp_path=f'{kp_root}/train',
                             frame_transform=image_transform, video_transform=video_transform,
                             input_fps=25, output_fps=train_output_fps,
                             max_frames=max_frames, stride=0, word_dict=word_to_idx)

test_dataset = rwth_phoenix(csv=test_csv,
                            data_path=test_paths,
                            kp_path=f'{kp_root}/test',
                            frame_transform=image_test_transform, video_transform=video_test_transform,
                            input_fps=25, output_fps=eval_output_fps,
                            max_frames=max_frames, stride=0, word_dict=word_to_idx)

dev_dataset = rwth_phoenix(csv=dev_csv,
                           data_path=dev_paths,
                           kp_path=f'{kp_root}/dev',
                           frame_transform=image_test_transform, video_transform=video_test_transform,
                           input_fps=25, output_fps=eval_output_fps,
                           max_frames=max_frames, stride=0, word_dict=word_to_idx)

In [None]:
# Settings shared by every split; only training is shuffled.
loader_kwargs = dict(
    batch_size=batch_size,
    collate_fn=collate_fn,
    prefetch_factor=prefetch_factor,
    num_workers=num_workers,
    pin_memory=True,
)
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)
dev_dataloader = DataLoader(dataset=dev_dataset, shuffle=False, **loader_kwargs)

In [None]:
# vid, kp_vid, y = next(iter(train_dataloader))

In [None]:
# vid.shape

In [None]:
# show_sequence(vid[0], max_frames)

In [None]:
# show_sequence(kp_vid[0], max_frames)

In [None]:
scaler = GradScaler('cuda')

def decode_pred(tensor, idx_to_word=idx_to_word):
    """Greedy CTC decoding of frame-level ids into gloss strings.

    In each row, consecutive repeats are collapsed and id 0 (blank/pad) is
    dropped. Because the previous id is tracked across blanks, an
    ``A <blank> A`` pattern yields two ``A`` glosses.

    Args:
        tensor: 2-D integer tensor of frame-level ids, one row per sample
            (callers pass an argmax over the vocabulary axis).
        idx_to_word: id -> gloss lookup.

    Returns:
        List of space-joined gloss strings, one per row.
    """
    text = []
    # One device->host copy up front instead of a GPU sync per token
    # (each tensor comparison / `.item()` on a CUDA tensor blocks).
    for row in tensor.tolist():
        glosses = []
        prev_token = None
        for token in row:
            if token != 0 and token != prev_token:
                glosses.append(idx_to_word[token])
            prev_token = token
        text.append(' '.join(glosses))
    return text

def decode_target(tensor, idx_to_word=idx_to_word):
    """Convert zero-padded target id rows back into gloss strings.

    Id 0 is skipped wherever it appears; no repeat-collapsing is done.

    Args:
        tensor: 2-D integer tensor of target ids, one row per sample.
        idx_to_word: id -> gloss lookup.

    Returns:
        List of space-joined gloss strings, one per row.
    """
    # `.tolist()` copies once instead of syncing on `.item()` per token.
    return [
        ' '.join(idx_to_word[token] for token in row if token != 0)
        for row in tensor.tolist()
    ]

def train_step(model, optimizer, dataloader, loss_fn, epoch, device="cuda"):
    """Run one mixed-precision training epoch with CTC loss.

    Uses the module-level ``scaler``. Per batch, the function:

    * forwards the RGB clip;
    * computes the CTC loss with every sample's input length set to all T
      model outputs;
    * backprops through the scaler, then unscales and clips the global
      gradient norm to 1.0;
    * steps the optimizer and accumulates a greedy-decoded WER.

    The keypoint clip ``X_kp`` is loaded but not used.

    Args:
        model: maps a clip batch to batch-first per-frame vocabulary logits.
        optimizer: optimizer over ``model``'s parameters.
        dataloader: yields ``(X, X_kp, y)`` with zero-padded targets ``y``.
        loss_fn: called as ``loss_fn(log_probs (T, B, V), targets, input_lengths, target_lengths)``.
        epoch: epoch number, used only for logging.
        device: device the batch is moved to.

    Returns:
        ``(avg_loss, wer_percent)`` averaged over the batches actually trained.
        The loop aborts early if NaN logits appear. Both values are NaN if no
        batch completed. WER is the mean of per-batch jiwer scores times 100.
    """
    model.train()
    train_loss, total_wer, num_batches = 0.0, 0.0, 0

    for X, X_kp, y in dataloader:
        X, y = X.to(device), y.to(device)

        with autocast('cuda'):
            y_logit = model(X).permute(1, 0, 2)  # (T, B, vocab), as CTC expects

            if y_logit.isnan().any():
                print("⚠️ ABORT! NaN found in y_logit!")
                break
            y_length = torch.count_nonzero(y, axis=1)
            ctc_probs = y_logit.log_softmax(-1)
            ctc_probs_length = torch.full(
                size=(y.shape[0],),
                fill_value=y_logit.size(0),
                dtype=torch.long
            )

            loss = loss_fn(ctc_probs, y, ctc_probs_length, y_length)
            train_loss += loss.item()

        scaler.scale(loss).backward()
        # The gradients are still multiplied by the loss scale here. Unscale
        # them first so clip_grad_norm_ compares the true norm against 1.0;
        # before, the effective threshold was 1.0 / scale, which is far too
        # tight. scaler.step() will not unscale a second time.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        y_pred = torch.argmax(y_logit.permute(1, 0, 2), dim=2)
        total_wer += wer(decode_target(y), decode_pred(y_pred))
        num_batches += 1

        del X, X_kp, y, y_logit, loss, y_pred

    # Average over processed batches; len(dataloader) under-reported after an abort.
    if num_batches:
        acc_wer = (total_wer / num_batches) * 100
        avg_loss = train_loss / num_batches
    else:
        avg_loss = acc_wer = float('nan')

    print(f"Epoch {epoch} | Train Loss: {avg_loss:.4f} | Train WER: {acc_wer:.2f}%")
    return avg_loss, acc_wer



def test_step(model, loss_fn, epoch, dataloader, scheduler, device="cuda"):
    """Evaluate ``model`` for one epoch without gradients and report CTC loss and WER.

    Uses the same loss and greedy decoding as ``train_step``. If ``scheduler``
    is a ``ReduceLROnPlateau``, it is stepped with the epoch's average loss;
    any other value (e.g. ``None``) is ignored.

    Args:
        model: maps a clip batch to batch-first per-frame vocabulary logits.
        loss_fn: called as ``loss_fn(log_probs (T, B, V), targets, input_lengths, target_lengths)``.
        epoch: epoch number, used only for logging.
        dataloader: yields ``(X, X_kp, y)`` with zero-padded targets ``y``.
        scheduler: optional ReduceLROnPlateau to step on the average loss.
        device: device the batch is moved to.

    Returns:
        ``(avg_loss, wer_percent)`` averaged over the batches actually
        evaluated. Both are NaN if no batch completed (empty loader or NaN
        logits on the first batch).
    """
    model.eval()
    test_loss, total_wer, num_batches = 0.0, 0.0, 0

    with torch.no_grad():
        for X, X_kp, y in dataloader:
            X, y = X.to(device), y.to(device)

            with autocast('cuda'):
                y_logit = model(X).permute(1, 0, 2)  # (T, B, vocab), as CTC expects

                if y_logit.isnan().any():
                    print("⚠️ ABORT! NaN found in y_logit!")
                    break

                y_length = torch.count_nonzero(y, axis=1)
                ctc_probs = y_logit.log_softmax(-1)
                # Every sample uses all T model outputs as its input length.
                ctc_probs_length = torch.full(
                    size=(y.shape[0],),
                    fill_value=y_logit.size(0),
                    dtype=torch.long
                )

                loss = loss_fn(ctc_probs, y, ctc_probs_length, y_length)
                test_loss += loss.item()

            y_pred = torch.argmax(y_logit.permute(1, 0, 2), dim=2)
            total_wer += wer(decode_target(y), decode_pred(y_pred))
            num_batches += 1

            del X, X_kp, y, y_logit, loss, y_pred

    # Divide by the batches actually evaluated. After an early NaN abort, the
    # old len(dataloader) denominator deflated the loss fed to the scheduler,
    # which could register as an "improvement".
    if num_batches:
        acc_wer = (total_wer / num_batches) * 100
        avg_loss = test_loss / num_batches
    else:
        avg_loss = acc_wer = float('nan')

    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step(avg_loss)

    print(f"Epoch {epoch} | Test Loss: {avg_loss:.4f} | Test WER: {acc_wer:.2f}%")
    return avg_loss, acc_wer

In [None]:
class ViViT_SLR(nn.Module):
    """Frame-wise ViT encoder + BiLSTM head producing per-frame gloss logits.

    Each frame is encoded independently by a pretrained ``ViTModel``. The
    first output token of every frame forms a (B, T, input_size) sequence.
    That sequence gets a learned temporal position embedding and LayerNorm,
    then passes through a bidirectional LSTM and a linear projection to
    ``vocab_size`` logits.

    Args:
        vit_weights: checkpoint name/path passed to ``ViTModel.from_pretrained``.
        vocab_size: number of output classes (id 0 is the blank/pad id).
        batch_first: passed to the LSTM; ``forward`` feeds batch-first tensors.
        input_size: ViT hidden size (768 for the default ViT-Base weights).
        hidden_size: LSTM hidden size per direction.
        num_layers: number of LSTM layers.
        layer_norm_eps: epsilon of the post-embedding LayerNorm.
        dropout: LSTM inter-layer dropout and dropout before ``fc_out``.
        frame: maximum clip length T covered by the position embedding.
        pad_token, max_pred: stored on the module but not used here.
    """

    def __init__(self,
                 vit_weights = 'google/vit-base-patch16-224',
                 vocab_size = len(idx_to_word),
                 batch_first=True,
                 input_size=768,
                 hidden_size=1024,
                 num_layers=2,
                 layer_norm_eps=1e-06,
                 dropout=0.1,
                 frame=max_frames,  # max clip length; the dataset always yields max_frames frames
                 pad_token=0,
                 max_pred=64,
                 ):

        super(ViViT_SLR, self).__init__()

        # NOTE: keep submodule/parameter registration order unchanged --
        # optimizer state is matched to parameters by position.
        self.vit_weights = vit_weights
        self.vit = ViTModel.from_pretrained(
            self.vit_weights,
            attn_implementation="sdpa",
            torch_dtype=torch.float32,
            use_safetensors=True
        )

        # Learned temporal position embedding, one vector per frame slot.
        self.pos_embedding = nn.Parameter(torch.rand([1, frame, input_size]))
        nn.init.normal_(self.pos_embedding, mean=0.0, std=0.02)

        self.norm_encoding = nn.LayerNorm(input_size, eps=layer_norm_eps)

        self.vocab_size = vocab_size
        self.pad_token = pad_token
        self.frame = frame

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            dropout=dropout,
            bidirectional=True
        )

        self.dropout = nn.Dropout(dropout)

        self.fc_out = nn.Linear(hidden_size*2, vocab_size)

        self.max_pred = max_pred


    def forward(self, vid):
        """Map a (B, T, C, H, W) clip batch to (B, T, vocab_size) logits; requires T <= ``frame``."""
        B, T, C, H, W = vid.shape
        # First output token of every frame: (B*T, input_size). `reshape`
        # (unlike `view`) also accepts non-contiguous input.
        cls_tokens = self.vit(vid.reshape(B * T, C, H, W)).last_hidden_state[:, 0, :]
        # -1 instead of a hard-coded 768 so other ViT widths work with a matching input_size.
        encoded_output = cls_tokens.reshape(B, T, -1)  # (B, T, input_size)

        # Slice so clips shorter than `frame` also work; identical when T == frame.
        encoded_output = encoded_output + self.pos_embedding[:, :T]
        encoded_output = self.norm_encoding(encoded_output)

        decoded_output, _ = self.lstm(encoded_output)  # (B, T, hidden_size * 2)

        logits = self.fc_out(self.dropout(decoded_output))  # (B, T, vocab_size)
        return logits

In [None]:
# NOTE(review): none of these are passed on below. get_param_groups(model) uses
# its own defaults (5e-5 / 1e-4 / 0.01), ViViT_SLR() uses its default
# dropout=0.1, and the optimizer's lr is then replaced by the loaded
# checkpoint state. Editing these values has no effect unless they are wired in.
lr, dropout, weight_decay = 5e-5, 0.1, 0.01

In [None]:
def get_param_groups(model: "ViViT_SLR", base_lr=5e-5, new_lr=1e-4, weight_decay=0.01, split_by_backbone=False):
    """Split trainable parameters into a "pretrained" and a "new" optimizer group.

    Returns ``[pretrained_group, new_group]`` with learning rates ``base_lr``
    and ``new_lr`` and a shared ``weight_decay``. Parameters with
    ``requires_grad=False`` are left out.

    Classification modes:

    * ``split_by_backbone=False`` (default, legacy) uses substring matching on
      parameter names. It reproduces the original grouping exactly, quirks
      included, because a saved optimizer state can only be loaded into
      groups of identical size. The quirks are:

      - A missing comma fused ``"embedding"`` and ``"early_"`` into the
        single key ``"embeddingearly_"``, so ``pos_embedding`` is treated as
        pretrained.
      - ``"norm_"`` also matches names such as ``layernorm_before``.
        Presumably this includes the transformers ViT block norms, which
        would put pretrained weights in the new group -- verify.
      - The LSTM matches no key, so the freshly initialised LSTM trains at
        ``base_lr``.

    * ``split_by_backbone=True`` treats everything inside the ``vit``
      submodule as pretrained and everything else as new. Use it for fresh
      runs; it is not compatible with optimizer states saved under the legacy
      grouping.
    """
    # Legacy keys with the accidental concatenation made explicit.
    legacy_new_keys = (
        "cross_attn", "norm_", "fc_out", "tf", "embeddingearly_",
        "mid_", "copy_pretrained", "freeze_unused", "adapter",
    )

    pretrained_params = []
    new_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if split_by_backbone:
            # Matching 'vit' as a dotted path component also handles the
            # 'module.' prefix added by nn.DataParallel.
            is_new = "vit" not in name.split(".")
        else:
            is_new = any(key in name for key in legacy_new_keys)

        if is_new:
            new_params.append(param)
        else:
            pretrained_params.append(param)

    print(f"Pretrained params: {len(pretrained_params)}")
    print(f"Newly initialized params: {len(new_params)}")


    return [
        {"params": pretrained_params, "lr": base_lr, "weight_decay": weight_decay},
        {"params": new_params, "lr": new_lr, "weight_decay": weight_decay},
    ]


In [None]:
# Build the model and resume model/optimizer/scheduler state from a checkpoint.
# Statement order matters here: the model is wrapped before the strict load, and
# the optimizer must exist before its saved state is loaded into it.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# NOTE(review): overrides the earlier CPU fallback; this cell assumes a GPU.
device = "cuda"

model = ViViT_SLR().to(device)

# Wrap before loading: strict loading into the wrapped model requires the
# checkpoint keys to carry DataParallel's "module." prefix.
model = nn.DataParallel(model)

# The file name indicates the state after 5 epochs.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (or pass weights_only=True on torch versions where it is not the default).
checkpoint = torch.load(f="/kaggle/input/vit-lstm-epoch5/pytorch/default/1/ViT_LSTM_5_epochs.pth")
new_state_dict = checkpoint["model_state_dict"]
model.load_state_dict(new_state_dict, strict=True)
print("Loaded checkpoint weights!")

# Blank id 0 is '<p>', the same id collate_fn pads targets with. zero_infinity
# zeroes infinite losses (targets that cannot be aligned to the output length).
loss_fn = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# NOTE(review): the grouping must match the one the saved optimizer state was
# built with, otherwise optimizer.load_state_dict below raises.
param_groups = get_param_groups(model)
# Per-group lr/weight_decay come from param_groups, then get replaced by the
# hyperparameters stored in the loaded optimizer state.
optimizer = torch.optim.AdamW(param_groups)


# optimizer = torch.optim.Adam(model.parameters(), lr=checkpoint["learning_rate"])
# print("Loaded learning rate successfully")
# print(f"Current learning rate: {checkpoint['learning_rate']}")

# Halve the LR after 3 epochs without loss improvement (patience=2), floor 1e-6.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    min_lr=1e-6,
)

# Restore moments/step counts and the scheduler's best-loss/patience tracking.
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])

# Free the checkpoint's host/GPU copies now that everything is loaded.
del checkpoint, new_state_dict

In [None]:
current = 5      # epochs already completed (the loaded checkpoint's file name says 5)
num_epochs = 5   # additional epochs to run

start = 1 + current
epochs = num_epochs + current + 1  # exclusive bound for range(); last epoch is epochs-1

print(f"Running from {start} to {epochs-1}")

torch.manual_seed(42)
torch.cuda.manual_seed(42)

print("Ready to train!!")

# NOTE(review): test_step steps the LR scheduler on the *test* loss and
# dev_dataloader is never used; scheduling on the dev split would avoid
# tuning against the test set.
# `with` closes the log even if an epoch raises (previously it leaked open).
with open('/kaggle/working/ViT_LSTM_readings.txt', 'w') as log:
    log.write('Epoch | Train loss | Train WER | Test loss | Test WER\n')

    for epoch in range(start, epochs):
        train_loss, train_wer = train_step(
            model=model,
            optimizer=optimizer,
            loss_fn=loss_fn,
            epoch=epoch,
            dataloader=train_dataloader
        )

        test_loss, test_wer = test_step(
            model=model,
            loss_fn=loss_fn,
            epoch=epoch,
            dataloader=test_dataloader,
            scheduler=scheduler
        )
        log.write(f'{epoch} | {train_loss:.4f} | {train_wer:.2f}% | {test_loss:.4f} | {test_wer:.2f}%\n')
        log.flush()  # keep the readings file current during long runs

        if epoch % 5 == 0:
            # Include optimizer/scheduler state: resuming (as this notebook does)
            # loads both keys, so checkpoints without them could not be resumed.
            checkpoint = {
                "model_state_dict" : model.state_dict(),
                "train_wer" : train_wer,
                "test_wer" : test_wer,
                "train_loss" : train_loss,
                "test_loss" : test_loss,
                "epoch" : epoch,
                "optimizer" : optimizer.state_dict(),
                "scheduler" : scheduler.state_dict(),
            }
            torch.save(obj=checkpoint, f=f"ViT_LSTM_{epoch}_epochs_{test_wer:.2f}_wer.pth")
            del checkpoint

In [None]:
# Final, resumable checkpoint: model weights, optimizer/scheduler state and the
# metrics of the last epoch run by the training loop.
final_epoch = epochs - 1
checkpoint = dict(
    model_state_dict=model.state_dict(),
    train_wer=train_wer,
    test_wer=test_wer,
    train_loss=train_loss,
    test_loss=test_loss,
    epoch=final_epoch,
    optimizer=optimizer.state_dict(),
    scheduler=scheduler.state_dict(),
)

In [None]:
# The file name records the last completed epoch.
torch.save(checkpoint, f"ViT_LSTM_{epochs - 1}_epochs.pth")