In [None]:
!pip install -q torch transformers tqdm matplotlib numpy pandas torchmetrics

In [None]:
!apt install tree -y
!rm -rf ./data
!mkdir data
!cp /content/txt2openpose-Data.zip ./file.zip
!unzip ./file.zip -d ./data
!clear

In [None]:
!tree "/content/data/txt2openpose-Data - Copy"

# DataLoader

In [None]:
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer
import torch.nn as nn
from transformers import T5Model

# Dataset Class
class MotionDataset(Dataset):
    """Pairs OpenPose JSON pose files with a text prompt derived from their path.

    Every ``*.json`` file found (recursively) under ``root_dir`` is one sample.
    The prompt is built from the file's location relative to ``root_dir``: each
    directory component followed by the file stem, joined with ``", "``
    (e.g. ``walk/slow/0001.json`` -> ``"walk, slow, 0001"``).

    Args:
        root_dir: Directory searched recursively for ``.json`` files.
        tokenizer: Callable used like a Hugging Face tokenizer; it must return a
            mapping with ``'input_ids'`` and ``'attention_mask'`` tensors that carry
            a leading batch axis of size 1 when called with ``return_tensors="pt"``.
        max_length: Length every prompt is padded/truncated to.

    Each item is ``(input_ids, attention_mask, keypoints)``: the first two have
    shape ``[max_length]``; ``keypoints`` is a float32 tensor of shape
    ``[num_joints_total, 2]`` holding x, y of every joint of every person in file order.
    """

    def __init__(self, root_dir, tokenizer, max_length=128):
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.file_paths = self._get_file_paths()

    def _get_file_paths(self):
        """Return every ``.json`` path under ``root_dir``, sorted.

        ``os.walk`` yields entries in filesystem order, which differs between
        machines; sorting keeps sample indices (and any seeded split) reproducible.
        """
        file_paths = []
        for root, _, files in os.walk(self.root_dir):
            file_paths.extend(os.path.join(root, name) for name in files if name.endswith('.json'))
        return sorted(file_paths)

    def _load_json(self, file_path):
        """Parse one JSON file (read as UTF-8)."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _extract_keypoints(self, data):
        """Concatenate the ``pose_keypoints_2d`` lists of every entry in ``data['people']``.

        The result length depends on how many people the file contains, so files
        with different people counts cannot be stacked into one batch.
        """
        keypoints = []
        for person in data['people']:
            keypoints.extend(person['pose_keypoints_2d'])
        return keypoints

    def _extract_path_info(self, file_path):
        """Build the text prompt from ``file_path`` relative to ``root_dir``.

        Directory components and the file stem are joined with ``", "``; for the
        ``category/subcategory/file.json`` layout this is
        ``"category, subcategory, file"``. Other depths no longer raise
        ``IndexError`` or drop the file name.
        """
        relative_path = os.path.relpath(file_path, self.root_dir)
        *directories, filename = relative_path.split(os.sep)
        stem = os.path.splitext(filename)[0]
        return ", ".join([*directories, stem])

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        data = self._load_json(file_path)
        keypoints = self._extract_keypoints(data)

        # Values come in triplets; keep the first two (x, y) of each and drop the
        # third (in OpenPose output this is the per-joint confidence).
        # float32 explicitly: an all-integer list would otherwise become int64 and
        # fail in MSELoss against the float model output.
        keypoints = torch.tensor(keypoints, dtype=torch.float32).view(-1, 3)[:, :2]

        path_info = self._extract_path_info(file_path)
        encoded_input = self.tokenizer(
            path_info,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # squeeze(0) removes only the batch axis added by the tokenizer; a bare
        # squeeze() would also collapse the sequence axis when max_length == 1.
        input_ids = encoded_input['input_ids'].squeeze(0)  # [max_length]
        attention_mask = encoded_input['attention_mask'].squeeze(0)  # [max_length]
        return input_ids, attention_mask, keypoints

def collate_fn(batch):
    """Merge a list of dataset samples into batch tensors.

    Args:
        batch: Sequence of ``(input_ids, attention_mask, keypoints)`` tuples whose
            corresponding tensors share a shape.

    Returns:
        Tuple ``(input_ids, attention_masks, labels)``, each stacked along a new
        leading batch dimension.
    """
    ids_column, mask_column, label_column = zip(*batch)
    return tuple(torch.stack(column) for column in (ids_column, mask_column, label_column))


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
tokenizer = T5Tokenizer.from_pretrained('t5-small')
dataset = MotionDataset('/content/data/txt2openpose-Data - Copy', tokenizer)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

# Sanity check: pull a single shuffled batch, turn its token ids back into the
# prompt text and show the shape of the target keypoints.
batch = next(iter(dataloader), None)
if batch is not None:
    input_ids, attention_masks, keypoints = batch
    decoded_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
    print(decoded_texts)  # one decoded prompt per sample in the batch
    print(keypoints.shape)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
def plot_fromPerson(person, person_idx):
    """Draw one person's skeleton and a "Person {idx}" label on the current pyplot axes.

    Args:
        person: Array-like of x, y coordinates (flat or ``[num_joints, 2]``); it is
            reshaped to ``[num_joints, 2]``. The limb list references joints 0-13,
            so at least 14 joints are required.
        person_idx: Index written in the label.

    Joints lying exactly at (0, 0) are treated as missing (OpenPose writes
    undetected joints as zeros -- TODO confirm this holds for every file in the
    dataset). Missing joints are not scattered, limbs touching them are not
    drawn, and a person whose joints are all missing (e.g. an empty slot) is
    skipped, so no spurious lines are drawn to the image origin.
    """
    keypoints = np.array(person).reshape(-1, 2)
    visible = np.any(keypoints != 0, axis=1)  # per joint: not the (0, 0) placeholder
    if not visible.any():
        return

    # Plot detected keypoints
    plt.scatter(keypoints[visible, 0], keypoints[visible, 1], s=10, c='r')

    # Connect joint pairs; joints with index > 13 are plotted but not linked
    limbs = [(0, 1), (1, 2), (2, 3), (3, 4), (1, 5), (5, 6), (6, 7), (1, 8),
             (8, 9), (9, 10), (1, 11), (11, 12), (12, 13)]
    for i, j in limbs:
        if visible[i] and visible[j]:
            plt.plot([keypoints[i, 0], keypoints[j, 0]],
                     [keypoints[i, 1], keypoints[j, 1]], 'r')

    # Label the person at its first detected joint (joint 0 whenever it is present)
    label_x, label_y = keypoints[np.argmax(visible)]
    plt.text(label_x, label_y, f'Person {person_idx}', fontsize=10, color='blue')

def plot_openpose(people, canvas_shape=(300, 900), figsize=(8, 8)):
    """Plot every person's skeleton on a black canvas in a new figure and show it.

    Args:
        people: Iterable of per-person keypoint arrays, each accepted by
            ``plot_fromPerson`` (e.g. the rows returned by ``format_keypoints``).
        canvas_shape: ``(height, width)`` of the black background image.
        figsize: Matplotlib figure size in inches.
    """
    plt.figure(figsize=figsize)
    # imshow's default origin='upper' puts y=0 at the top, so the axes already use
    # the image/pixel coordinate system; no extra y-axis inversion is applied.
    plt.imshow(np.zeros((*canvas_shape, 3)))

    for idx, person in enumerate(people):
        plot_fromPerson(person, idx)

    plt.show()

def format_keypoints(keypoints, values_per_person=36):
    """Regroup flat keypoint coordinates into one row per person.

    Args:
        keypoints: Tensor or ndarray holding the coordinates of all people in any
            shape (e.g. ``[1, 90, 2]`` from the dataloader or ``[90, 2]`` from the
            model); only the total number of values matters.
        values_per_person: Coordinates per person; default 36 (18 joints x 2).

    Returns:
        Same array type as the input, shaped ``[num_people, values_per_person]``.

    Raises:
        ValueError: if the number of values is not a multiple of ``values_per_person``.
    """
    flat = keypoints.flatten()
    total = flat.shape[0]
    if total % values_per_person:
        raise ValueError(
            f"expected a multiple of {values_per_person} keypoint values, got {total}"
        )
    return flat.reshape(-1, values_per_person)

# Visual check: draw the ground-truth skeletons of one shuffled sample.
batch = next(iter(dataloader), None)
if batch is not None:
    input_ids, attention_masks, keypoints = batch
    decoded_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
    print(decoded_texts)  # one decoded prompt per sample in the batch
    print(keypoints.shape)
    kp = format_keypoints(keypoints)  # regroup into one row per person
    plot_openpose(kp)

In [None]:
import torch
import torch.nn as nn
from transformers import T5Model, T5Tokenizer
from tqdm import tqdm
from torchmetrics.regression import MeanAbsoluteError
import math

def eval(preds, target):
    """Mean absolute error between ``preds`` and ``target``.

    Note: this definition shadows Python's builtin ``eval`` in the notebook namespace.

    Args:
        preds: Tensor of predictions.
        target: Tensor of ground-truth values with the same shape as ``preds``.

    Returns:
        0-dim tensor: mean of ``|preds - target|`` over all elements.

    Raises:
        ValueError: if the shapes differ (prevents silent broadcasting).
    """
    if preds.shape != target.shape:
        raise ValueError(
            f"preds and target must have the same shape, got {tuple(preds.shape)} "
            f"and {tuple(target.shape)}"
        )
    diff = preds - target
    if not diff.is_floating_point():
        diff = diff.float()  # integer tensors cannot be averaged with mean()
    return diff.abs().mean()

class Text2Motion(nn.Module):
    """Regress a fixed number of 2D points from a text prompt using a T5 encoder.

    The prompt is encoded by a pretrained T5 encoder, the sequence of hidden
    states is pooled to one vector, and a linear head maps it to
    ``output_points`` (x, y) pairs.

    Args:
        t5_model_name: Hugging Face checkpoint whose encoder is used.
        output_points: Number of 2D points predicted per sample (the head emits
            ``output_points * 2`` values).
        pooling: How the encoder sequence is reduced to one vector:
            ``'first'`` (hidden state of the first token -- the default and
            original behaviour) or ``'mean'`` (average over non-padding tokens,
            using ``attention_mask`` when given).

    Raises:
        ValueError: for an unknown ``pooling`` value.
    """

    def __init__(self, t5_model_name='t5-small', output_points=90, pooling='first'):
        super().__init__()
        if pooling not in ('first', 'mean'):
            raise ValueError(f"pooling must be 'first' or 'mean', got {pooling!r}")
        self.output_points = output_points
        self.pooling = pooling

        # Load the full T5Model but keep only its encoder on this module.
        self.t5_encoder = T5Model.from_pretrained(t5_model_name).encoder

        # Linear head: encoder hidden size -> output_points * 2 coordinates.
        self.output_layer = nn.Linear(self.t5_encoder.config.hidden_size, self.output_points * 2)

    def forward(self, input_ids, attention_mask=None):
        """Predict points for a batch of tokenized prompts.

        Args:
            input_ids: ``[batch_size, seq_len]`` token ids.
            attention_mask: Optional ``[batch_size, seq_len]`` mask (1 = real token).

        Returns:
            Tensor of shape ``[batch_size, output_points, 2]``.
        """
        encoder_outputs = self.t5_encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = encoder_outputs.last_hidden_state  # [batch_size, seq_len, d_model]

        if self.pooling == 'mean':
            pooled = self._masked_mean(hidden_state, attention_mask)
        else:
            # T5 has no dedicated [CLS] token; this is the encoder state of the first
            # prompt token, which still attends over the whole (unmasked) prompt.
            pooled = hidden_state[:, 0, :]  # [batch_size, d_model]

        motion_output = self.output_layer(pooled)  # [batch_size, output_points * 2]
        return motion_output.view(-1, self.output_points, 2)

    @staticmethod
    def _masked_mean(hidden_state, attention_mask):
        """Average ``hidden_state`` over the sequence axis, ignoring padded positions."""
        if attention_mask is None:
            return hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden_state.dtype)  # [batch, seq, 1]
        # clamp avoids division by zero for an all-padding row
        return (hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)



def display_batch_keypoints(list_o_input_ids, list_o_attention_mask, list_o_keypoint, tokenizer):
    """Print each prompt of a batch and plot the keypoints belonging to it.

    Args:
        list_o_input_ids: Token-id rows, one per sample, decoded into prompt text.
        list_o_attention_mask: Accepted for interface compatibility; not used.
        list_o_keypoint: Per-sample keypoints, indexable by sample position and
            accepted by ``format_keypoints``.
        tokenizer: Object providing ``decode(ids, skip_special_tokens=True)``.
    """
    # Decode the ids passed in: a module-level ``input_ids`` also exists in this
    # notebook and must not be picked up here, or texts won't match the plots.
    decoded_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in list_o_input_ids]
    for idx, text in enumerate(decoded_texts):
        print(text)
        plot_openpose(format_keypoints(list_o_keypoint[idx]))

loss_logs = []  # loss of every training step, appended across all epochs

current_epoch = 0  # updated by the epoch loop; read by train() to decide when to plot


def train(model, dataloader, optimizer, criterion, device, tokenizer, display_every=50):
    """Run one training epoch and return the mean per-batch loss.

    Side effects: appends each step's loss to the global ``loss_logs``; when the
    global ``current_epoch`` is a multiple of ``display_every``, prints and plots
    the predictions for the epoch's last batch via ``display_batch_keypoints``.

    Args:
        model: Module called as ``model(input_ids, attention_mask=...)``.
        dataloader: Yields ``(input_ids, attention_masks, targets)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function ``criterion(outputs, targets)``.
        device: Device the batch tensors are moved to.
        tokenizer: Used to decode prompts when displaying predictions.
        display_every: Epoch interval for displaying predictions (<= 0 disables).

    Returns:
        float: average of the per-batch losses.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; cannot compute an average loss")

    model.train()
    total_loss = 0.0
    show_predictions = display_every > 0 and current_epoch % display_every == 0

    for idx, (input_ids, attention_masks, targets) in enumerate(tqdm(dataloader, desc="Steps")):
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_masks)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        loss_logs.append(loss_value)

        # Visualise the predictions of the epoch's final batch.
        if show_predictions and idx == num_batches - 1:
            display_batch_keypoints(
                list_o_input_ids=input_ids,
                list_o_attention_mask=attention_masks,
                list_o_keypoint=outputs.detach().cpu().numpy(),
                tokenizer=tokenizer,
            )

    return total_loss / num_batches

# Example usage:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = T5Tokenizer.from_pretrained('t5-small')
dataset = MotionDataset('/content/data/txt2openpose-Data - Copy', tokenizer)

SEED = 42  # makes the train/test split, head initialisation and shuffling repeatable
torch.manual_seed(SEED)

train_size = math.floor(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SEED)
)

train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)
# Evaluation does not need shuffling; a fixed order keeps test runs comparable.
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False, collate_fn=collate_fn)

model = Text2Motion().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Training loop
num_epochs = 2000
for epoch in tqdm(range(num_epochs), desc="- Epocs"):
    current_epoch = epoch  # read by train() to decide whether to plot predictions
    loss = train(model, train_dataloader, optimizer, criterion, device, tokenizer)
    if epoch % 50 == 0:
        # Render the running loss curve now, as its own labeled figure.
        plt.figure()
        plt.plot(loss_logs)
        plt.xlabel("step")
        plt.ylabel("MSE loss")
        plt.title(f"Training loss up to epoch {epoch + 1}")
        plt.show()
    print(f"Epoch {epoch + 1}, Loss: {loss}")

In [None]:
# Persist only the learned weights; rebuild Text2Motion() and call
# model.load_state_dict(torch.load(CHECKPOINT_PATH)) to restore them.
CHECKPOINT_PATH = './2000_cpkt.pt'
torch.save(model.state_dict(), CHECKPOINT_PATH)

In [None]:
# Loss of every optimisation step over the whole run (criterion: MSELoss).
fig, ax = plt.subplots()
ax.plot(loss_logs)
ax.set(title="Training loss per step", xlabel="step", ylabel="MSE loss")
plt.show()

In [None]:
import numpy as np

# Element-wise MAE over the whole test split. Absolute errors are summed and
# divided by the total number of values, so a smaller final batch is weighted
# correctly (a mean of per-batch means would over-weight it).
model.eval()
MAELoss = nn.L1Loss(reduction='sum')
abs_error_sum = 0.0
num_values = 0
mae = []  # per-batch MAE, kept for inspection
with torch.no_grad():
    for batch in test_dataloader:
        input_ids, attention_masks, targets = (t.to(device) for t in batch)
        outputs = model(input_ids, attention_mask=attention_masks)
        batch_abs_error = MAELoss(outputs, targets).item()
        abs_error_sum += batch_abs_error
        num_values += targets.numel()
        mae.append(batch_abs_error / targets.numel())

print("MAE Mean: ", abs_error_sum / num_values if num_values else float('nan'))