### PIP AND IMPORT

In [None]:
# %pip install torch==2.2.2
# %pip install torchtext==0.4

In [None]:
import math

import cv2
import numpy as np
import torch

In [None]:
def _frame_to_numpy(frame_values):
    # Return one frame as a NumPy array; torch tensors (on any device) are
    # detached and copied to host memory first so NumPy can read them.
    if torch.is_tensor(frame_values):
        return frame_values.detach().cpu().numpy()
    return np.asarray(frame_values)


# Plot a video given a tensor of joints, a file path, video name and references/sequence ID
def plot_skeletons_video(joints, file_path, video_name, references=None, skip_frames=1, sequence_ID=None, pad_token: int = 0):
    """Render a sequence of skeleton frames to ``<file_path>/<video_name stem>.mp4``.

    Args:
        joints: Iterable of frames. Each frame must hold 150 values, because it
            is reshaped to (50, 3) and only the first two columns are drawn.
            Frames may be NumPy arrays or torch tensors on any device.
        file_path: Directory the video is written into.
        video_name: Output name; everything after the first "." is dropped and
            ".mp4" is appended.
        references: Optional ground-truth frames, indexed with the same frame
            index as ``joints``. When given, each output frame is twice as wide
            and shows the reference next to the prediction.
        skip_frames: Divisor applied to the 25 FPS base rate.
        sequence_ID: Optional identifier; its last "/"-separated component is
            written onto frames that include a reference.
        pad_token: A predicted frame containing this value anywhere is treated
            as padding and skipped.
    """
    # Create video template. Clamp to 1 FPS so a large skip_frames cannot
    # produce an invalid 0 FPS writer.
    fps = max(1, 25 // skip_frames)
    video_file = file_path + "/{}.mp4".format(video_name.split(".")[0])
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    frame_size = (650, 650) if references is None else (1300, 650)
    video = cv2.VideoWriter(video_file, fourcc, float(fps), frame_size, True)

    # The label is optional: references may be supplied without a sequence ID.
    sequence_label = None
    if sequence_ID is not None:
        sequence_label = "Sequence ID: " + sequence_ID.split("/")[-1]

    try:
        for frame_index, frame_joints in enumerate(joints):
            frame_joints = _frame_to_numpy(frame_joints)

            # Reached padding
            if np.any(frame_joints == pad_token):
                continue

            # Initialise frame of white
            frame = np.full((650, 650, 3), 255, np.uint8)

            # Multiply by 3 to restore joint size, then keep only x/y of each joint
            frame_joints_2d = np.reshape(frame_joints * 3, (50, 3))[:, :2]

            # Draw the frame given 2D joints and add text
            draw_frame_2D(frame, frame_joints_2d)
            cv2.putText(frame, "Predicted Sign Pose", (180, 600), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            # If reference is provided, create and concatenate on the end
            if references is not None:
                # Extract the reference joints for the same frame index
                ref_joints = _frame_to_numpy(references[frame_index])
                # Initialise frame of white
                ref_frame = np.full((650, 650, 3), 255, np.uint8)

                # Multiply each joint by 3 (as was reduced in training files)
                ref_joints_2d = np.reshape(ref_joints * 3, (50, 3))[:, :2]

                # Draw these joints on the frame and add text
                draw_frame_2D(ref_frame, ref_joints_2d, offset=(0, -20))
                cv2.putText(ref_frame, "Ground Truth Pose", (190, 600), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

                # Concatenate the two frames side by side
                frame = np.concatenate((frame, ref_frame), axis=1)

                # Add the sequence ID to the frame when one was supplied
                if sequence_label is not None:
                    cv2.putText(frame, sequence_label, (700, 635), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2)

            # Write the video frame
            video.write(frame)
    finally:
        # Release the video even if drawing a frame failed, so the file handle
        # is not leaked and the container is finalised.
        video.release()

# This is the format of the 3D data, outputted from the Inverse Kinematics model
def getSkeletalModelStructure():
    """Return the skeleton as a tuple of ``(start joint, end joint, bone id)`` triples.

    Example: the simple model

                (0)
                 |
                 0
                 |
        (2)--1--(1)--1--(3)
         |               |
         2               2
         |               |
        (4)             (5)

    is described by ``((0, 1, 0), (1, 2, 1), (1, 3, 1), (2, 4, 2), (3, 5, 2))``.

    Warning 1: the structure has to be a tree.
    Warning 2: the order is not arbitrary; bones are listed from the root
    towards the leaves.

    The wrist joints (4 and 7) and wrist bone (4) are intentionally absent:
    each lower arm connects straight to a hand root joint.
    NOTE(review): the original labels call joints 8-28 the "left hand" and
    29-49 the "right hand", yet the arm labelled "left" ends at joint 29 and
    the arm labelled "right" ends at joint 8 - confirm the side naming.
    """
    # Head, shoulders and arms, written out explicitly because they are irregular.
    body_bones = [
        (0, 1, 0),    # head
        (1, 2, 1),    # shoulder towards joint 2
        (2, 3, 2),    # upper arm
        (3, 29, 3),   # lower arm, skipping the wrist, to hand root 29
        (1, 5, 1),    # shoulder towards joint 5
        (5, 6, 2),    # upper arm
        (6, 8, 3),    # lower arm, skipping the wrist, to hand root 8
    ]

    # Both hands share one layout: five palm bones from the hand root to the
    # base of each finger, then three bones along each finger.
    hand_bones = []
    for hand_root in (8, 29):
        finger_bases = [hand_root + 1 + 4 * finger for finger in range(5)]

        # Palm: root -> finger base, bone ids 5, 9, 13, 17, 21
        hand_bones.extend(
            (hand_root, base, 5 + 4 * finger)
            for finger, base in enumerate(finger_bases)
        )

        # Fingers 1..5: three consecutive segments each, bone ids 6-8, 10-12, ...
        for finger, base in enumerate(finger_bases):
            hand_bones.extend(
                (base + segment, base + segment + 1, 6 + 4 * finger + segment)
                for segment in range(3)
            )

    return tuple(body_bones + hand_bones)

# Draw a bone between two joints as a thin filled ellipse
def draw_line(im, joint1, joint2, c=(0, 0, 255), t=1, width=3):
    """Draw a filled ellipse on ``im`` spanning ``joint1`` to ``joint2``.

    Nothing is drawn unless all four coordinates are strictly greater than
    -100. ``c`` is the fill colour and ``width`` the ellipse's first semi-axis.
    ``t`` is accepted but not used by this function.
    """
    thresh = -100
    x1, y1 = joint1[0], joint1[1]
    x2, y2 = joint2[0], joint2[1]

    # Guard clause: skip joints that lie at or below the threshold.
    if not (x1 > thresh and y1 > thresh and x2 > thresh and y2 > thresh):
        return

    delta_x = x1 - x2
    delta_y = y1 - y2

    # The ellipse is centred on the midpoint, with half the joint distance as
    # its second semi-axis, rotated to line up with the segment.
    midpoint = (int((x1 + x2) / 2), int((y1 + y2) / 2))
    half_length = int(math.sqrt((delta_x ** 2) + (delta_y ** 2)) / 2)
    rotation = math.degrees(math.atan2(delta_x, delta_y))

    cv2.ellipse(im, midpoint, (width, half_length), -rotation, 0.0, 360.0, c, -1)

# Draw the frame given 2D joints that are in the Inverse Kinematics format
def draw_frame_2D(frame, joints, offset: tuple[int, int] = (0, 0)):
    """Draw every bone of the skeleton onto ``frame`` in place.

    Args:
        frame: Image the skeleton is drawn on.
        joints: Array of 2D joint positions, indexed by the joint numbers used
            in ``getSkeletalModelStructure()``.
        offset: ``(x, y)`` shift added to every scaled joint position. This
            was previously ignored because a hard-coded ``(0, 0)`` was added
            instead; it is now applied, and the default leaves positions
            unchanged.
    """
    # Line to be between the stacked frames
    draw_line(frame, [1, 650], [1, 1], c=(0, 0, 0), t=1, width=1)

    # Get the skeleton structure details of each bone
    skeleton = np.array(getSkeletalModelStructure())

    # Increase the size of the joints, then shift them by the requested offset.
    # Broadcasting the offset avoids assuming a fixed number of joints; the
    # float dtype keeps the result dtype the same as before.
    joints = joints * 10 * 12 * 2
    joints = joints + np.asarray(offset, dtype=float)

    # Loop through each of the bone structures, and plot the bone
    for j in range(skeleton.shape[0]):
        start_joint = joints[skeleton[j, 0]]
        end_joint = joints[skeleton[j, 1]]

        draw_line(
            frame,
            [start_joint[0], start_joint[1]],
            [end_joint[0], end_joint[1]],
            c=get_bone_colour(skeleton, j),
            t=1,
            width=1,
        )

# get bone colour given index
def get_bone_colour(skeleton, j):
    """Return the colour tuple for the ``j``-th bone of ``skeleton``.

    Args:
        skeleton: 2D array whose rows are ``(start joint, end joint, bone id)``.
        j: Row index of the bone to colour.

    Returns:
        A 3-tuple of ints. Bones that match none of the known categories get
        black ``(0, 0, 0)``; previously they raised ``UnboundLocalError``
        because no colour was ever assigned.
    """
    start, end, bone = skeleton[j, 0], skeleton[j, 1], skeleton[j, 2]

    # Fallback for bone ids without a dedicated colour (e.g. the unused wrist
    # bone 4), so that drawing never crashes on an unknown bone.
    c = (0, 0, 0)

    if bone == 0:  # head
        c = (0, 153, 0)
    elif bone == 1:  # Shoulder
        c = (0, 0, 255)

    # Arms share bone ids 2 and 3, so the joint index tells the two arms apart.
    elif bone == 2 and end == 3:  # left arm
        c = (0, 102, 204)
    elif bone == 3 and start == 3:  # left lower arm
        c = (0, 204, 204)

    elif bone == 2 and end == 6:  # right arm
        c = (0, 153, 0)
    elif bone == 3 and start == 6:  # right lower arm
        c = (0, 204, 0)

    # Hands: one colour per group of four consecutive bone ids
    elif bone in [5, 6, 7, 8]:
        c = (0, 0, 255)
    elif bone in [9, 10, 11, 12]:
        c = (51, 255, 51)
    elif bone in [13, 14, 15, 16]:
        c = (255, 0, 0)
    elif bone in [17, 18, 19, 20]:
        c = (204, 153, 255)
    elif bone in [21, 22, 23, 24]:
        c = (51, 255, 255)

    return c

# Apply DTW to the produced sequence, so it can be visually compared to the reference sequence
def alter_DTW_timing(pred_seq, ref_seq):
    """Re-time a predicted sequence onto the reference's frame positions via DTW.

    Both inputs are 2D torch tensors (they are indexed ``[:, -1]`` and moved
    with ``.cpu().numpy()``). The last column is treated as a counter: each
    sequence is cut at the index where that column is largest, and the column
    is excluded from the DTW distance.

    Returns:
        A tuple ``(new_pred_seq, ref_seq, d)``: the re-timed prediction as a
        NumPy array shaped like the truncated reference, the truncated
        reference as a NumPy array, and the DTW cost divided by
        ``acc_cost_matrix.shape[0]``.

    NOTE(review): ``dtw`` is not defined or imported in the visible part of
    this file - confirm it is provided elsewhere before calling this.
    """

    # Define a cost function. Despite the name, this is the sum of absolute
    # differences (an L1 distance), not a Euclidean norm.
    euclidean_norm = lambda x, y: np.sum(np.abs(x - y))

    # Cut the reference down to the max count value
    _, ref_max_idx = torch.max(ref_seq[:, -1], 0)
    # Keep at least one frame when the maximum is at index 0
    if ref_max_idx == 0:
        ref_max_idx += 1
    # Cut down frames by counter (the slice excludes the frame at ref_max_idx itself)
    ref_seq = ref_seq[:ref_max_idx, :].cpu().numpy()

    # Cut the hypothesis down to the max count value
    _, hyp_max_idx = torch.max(pred_seq[:, -1], 0)
    # Keep at least one frame when the maximum is at index 0
    if hyp_max_idx == 0:
        hyp_max_idx += 1
    # Cut down frames by counter (the slice excludes the frame at hyp_max_idx itself)
    pred_seq = pred_seq[:hyp_max_idx, :].cpu().numpy()

    # Run DTW on the reference and predicted sequence, ignoring the counter column
    d, cost_matrix, acc_cost_matrix, path = dtw(ref_seq[:, :-1], pred_seq[:, :-1], dist=euclidean_norm)

    # Normalise the dtw cost by sequence length
    d = d / acc_cost_matrix.shape[0]

    # Initialise new sequence; it has the reference's shape, counter column included
    new_pred_seq = np.zeros_like(ref_seq)
    # j is incremented below but never read
    j = 0
    # skips counts steps where path[1] repeated; it offsets the index into pred_seq
    skips = 0
    # Predicted frames collected while path[0] repeats, averaged once the run ends
    squeeze_frames = []
    # presumably path[0] indexes the reference and path[1] the prediction,
    # matching the argument order passed to dtw - TODO confirm
    for i, pred_num in enumerate(path[0]):

        # The last path step has no successor to compare against
        if i == len(path[0]) - 1:
            break

        if path[1][i] == path[1][i + 1]:
            skips += 1

        # If a double coming up
        if path[0][i] == path[0][i + 1]:
            squeeze_frames.append(pred_seq[i - skips])
            j += 1
        # Just finished a double (at i == 0 the index i - 1 is -1, i.e. the last path entry)
        elif path[0][i] == path[0][i - 1]:
            new_pred_seq[pred_num] = avg_frames(squeeze_frames)
            squeeze_frames = []
        else:
            new_pred_seq[pred_num] = pred_seq[i - skips]

    return new_pred_seq, ref_seq, d

# Find the average of the given frames
def avg_frames(frames):
    """Return the element-wise mean of a non-empty sequence of frames.

    The running total takes the shape and dtype of the first frame and is
    accumulated in place; the inputs themselves are left untouched.
    """
    running_total = np.zeros_like(frames[0])
    for current in frames:
        np.add(running_total, current, out=running_total)

    return running_total / len(frames)

### final

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchtext.vocab import build_vocab_from_iterator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SignLanguageDataset(Dataset):
    """Paired skeleton sequences and tokenised text, served as random fixed-size windows.

    Each line of ``skel_file`` is one sequence of whitespace-separated floats,
    ``sign_language_dim + 1`` values per frame; the last value of every frame
    is dropped. Line ``i`` of ``text_file`` is the whitespace-tokenised text
    belonging to sequence ``i``.

    Sequences with fewer than ``window_size`` frames are discarded. The
    remaining skeleton values are min-max normalised using the global minimum
    and maximum, which stay available as ``min_value`` / ``max_value``.

    ``seq_length`` is stored but not otherwise used by this class.
    """

    def __init__(self, skel_file='../T2S-GPT/data/alldata/train.skels', text_file='../T2S-GPT/data/alldata/train.txt', vocab=None, seq_length=200, text_length=30, sign_language_dim=150, window_size=200):
        """Load, filter and normalise the data.

        Raises:
            ValueError: If no sequence has at least ``window_size`` frames.
        """
        self.seq_length = seq_length
        self.text_length = text_length
        self.sign_language_dim = sign_language_dim
        self.window_size = window_size

        self.min_value = float("inf")
        self.max_value = -float("inf")

        # Read and process skeleton data, one sequence per line
        frame_width = self.sign_language_dim + 1
        with open(skel_file, 'r', encoding='utf-8') as f:
            sequences = [self._parse_skeleton_line(line, frame_width) for line in f]

        # Read and process text data
        with open(text_file, 'r', encoding='utf-8') as f:
            texts = [line.strip().split() for line in f]

        # Drop sequences shorter than the window. A sequence of exactly
        # window_size frames is kept: it yields one valid window.
        kept = [
            (sequence, texts[idx])
            for idx, sequence in enumerate(sequences)
            if sequence.shape[0] >= self.window_size
        ]
        if not kept:
            raise ValueError(
                f"No sequence in {skel_file} has at least {self.window_size} frames"
            )
        sequences, self.text_data = zip(*kept)

        # Normalize the skeleton data with the global min/max of the kept sequences
        for sequence in sequences:
            self.min_value = min(self.min_value, sequence.min())
            self.max_value = max(self.max_value, sequence.max())
        value_range = self.max_value - self.min_value
        self.sign_language_data = [(sequence - self.min_value) / value_range for sequence in sequences]

        # Build or use existing vocabulary for text data. Compare against None
        # so that a supplied vocabulary is used even if it evaluates as falsy.
        self.vocab = vocab if vocab is not None else self.build_vocab(self.text_data)

    @staticmethod
    def _parse_skeleton_line(line, frame_width):
        # Turn one text line into a (frames, frame_width - 1) tensor, dropping
        # the last value of every frame. Splitting on any whitespace tolerates
        # repeated spaces; a blank line becomes a zero-frame sequence so that
        # line numbers stay aligned with the text file.
        values = [float(val) for val in line.split()]
        if not values:
            return torch.empty((0, frame_width - 1))
        return torch.tensor(values).reshape(-1, frame_width)[:, :-1]

    def build_vocab(self, text_data):
        """Build a vocabulary from an iterable of token lists."""
        return build_vocab_from_iterator(iter(text_data))

    def preprocess_data(self, X_T):
        """Zero-pad the last dimension of ``X_T`` on the right, by ``512 - X_T.size(-1)``."""
        padded_X_T = F.pad(X_T, (0, 512 - X_T.size(-1)), "constant", 0)
        return padded_X_T

    def __len__(self):
        return len(self.sign_language_data)

    def __getitem__(self, idx):
        """Return a random ``window_size``-frame window of sequence ``idx`` and its text.

        All tensors in the returned dict are moved to the module-level
        ``device``. Several entries are placeholders derived from the window
        and the text rather than real model outputs.

        Raises:
            ValueError: If the stored sequence is shorter than ``window_size``.
        """
        seq_length = self.sign_language_data[idx].shape[0]
        if seq_length < self.window_size:
            raise ValueError(f"Sequence at index {idx} is shorter than window size {self.window_size}")

        # torch.randint excludes the upper bound, so add 1 to make the last
        # valid start reachable and to allow seq_length == window_size.
        start_index = torch.randint(0, seq_length - self.window_size + 1, (1,)).item()
        end_index = start_index + self.window_size
        sign_language_sequence = self.sign_language_data[idx][start_index:end_index]

        spoken_language_text = self.text_data[idx]
        spoken_language_tensor = torch.tensor([self.vocab[word] for word in spoken_language_text[:self.text_length]], dtype=torch.long)

        # Generate placeholders for additional variables
        skel_indices = torch.randint(0, 512, (self.window_size,))  # Example placeholder
        skel_reconstructed = sign_language_sequence.clone().unsqueeze(0)  # Adding batch dimension for 3D
        indices_prediction = skel_indices.clone().unsqueeze(0)            # Adding batch dimension for 2D
        sign_prediction = spoken_language_tensor.clone().unsqueeze(0).unsqueeze(-1)  # Making it 3D: (1, text_length, 1)
        converted_skel = sign_language_sequence.clone()
        converted_skel_reconstructed = skel_reconstructed.clone()

        return {
            'sign_language_sequence': sign_language_sequence.to(device),
            'spoken_language_text': spoken_language_tensor.to(device),
            'skel_indices': skel_indices.to(device),
            'skel_reconstructed': skel_reconstructed.to(device),
            'indices_prediction': indices_prediction.to(device),
            'sign_prediction': sign_prediction.to(device),
            'converted_skel': converted_skel.to(device),
            'converted_skel_reconstructed': converted_skel_reconstructed.to(device)
        }

# Instantiate your dataset to check it works
dataset = SignLanguageDataset()

# Fetch an item to verify
data_index = 8
item = dataset[data_index]

# Bind every entry to a module-level name once, then print it (and its shape where useful).
text = item["spoken_language_text"]
print(f'Text: "{text}"')
print("-" * 50)
skel = item["sign_language_sequence"]
print(f'Skel: "{skel}"')
print(f'Skel: "{skel.shape}"')
print("-" * 50)
skel_indices = item['skel_indices']
print(f"Skeleton indices: {skel_indices}")
print(f"Skeleton indices: {skel_indices.shape}")
print("-" * 50)
skel_reconstructed = item['skel_reconstructed']
print(f"Reconstructed skeleton: {skel_reconstructed}")
print(f"Reconstructed skeleton: {skel_reconstructed.shape}")
print("-" * 50)
indices_prediction = item['indices_prediction']
print(f"Indices prediction: {indices_prediction}")
print(f"Indices prediction: {indices_prediction.shape}")
print("-" * 50)
sign_prediction = item['sign_prediction']
print(f"Sign prediction: {sign_prediction}")
print(f"Sign prediction: {sign_prediction.shape}")
print("-" * 50)
converted_skel = item['converted_skel']
print(f"Converted skeleton: {converted_skel}")
print("-" * 50)
converted_skel_reconstructed = item['converted_skel_reconstructed']
print(f"Converted skeleton reconstructed: {converted_skel_reconstructed}")

# from capstone_utils.plot_skeletons import plot_skeletons_video
# The dataset returns tensors on `device`; the plotting code works through
# NumPy, which cannot read CUDA memory, so hand it a detached CPU copy.
plot_skeletons_video(converted_skel.detach().cpu(), '.', 'rose.mp4')