In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
class TransformerMultiTask(nn.Module):
    def __init__(self, input_dim, num_heads:int, num_layers:int, hidden_dim:int, 
                 entity_id_cls:int, entity_state_cls:int):
        super(TransformerMultiTask, self).__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        
        transformer_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)
        
        # Separate heads for classification tasks
        # Simpler actionable classification head
        self.actionable_fc = nn.Linear(hidden_dim, 2)
        
        # Deeper head for complex entity ID classification
        self.entity_id_fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.SiLU(),
            nn.Linear(hidden_dim // 4, entity_id_cls)
        )
        
        # Simpler head for entity state classification
        self.entity_state_fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.SiLU(),
            nn.Linear(hidden_dim // 4, entity_state_cls)
        )
        
        # Regression head for action time prediction
        self.action_time_regression = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # x shape: (batch_size, seq_length, input_dim)
        x = self.embedding(x)
        x = self.transformer(x)
        
        # Take the mean across the sequence dimension
        x = x.mean(dim=1)
        # Separate heads for classification and regression
        actionable_output = self.actionable_fc(x)
        entity_id_output = self.entity_id_fc(x)
        entity_state_output = self.entity_state_fc(x)
        time_pred_output = self.action_time_regression(x)
        
        # return actionable_output, entity_id_output, entity_state_output, time_pred_output
        return actionable_output, entity_id_output, entity_state_output, time_pred_output

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import os
import random
import io
import numpy as np
import pandas as pd

# Helper function to convert bytes back to tensor
def bytes_to_tensor(tensor_bytes):
    buffer = io.BytesIO(tensor_bytes)
    loaded_tensor = torch.load(buffer, weights_only=True)
    return loaded_tensor

from torch.utils.data import DataLoader

class MultiPartParquetCASASDataset(IterableDataset):
    def __init__(self, file_list, shuffle_files=True, shuffle_rows=True, transform=None):
        self.file_list = file_list
        self.shuffle_files = shuffle_files
        self.shuffle_rows = shuffle_rows
        self.transform = transform

        if self.shuffle_files:
            random.shuffle(self.file_list)  # Shuffle the order of files

    def _read_parquet_file(self, file_path):
        df = pd.read_parquet(file_path)

        if self.shuffle_rows:
            df = df.sample(frac=1).reset_index(drop=True)  # Shuffle rows within the dataframe

        return df

    def __len__(self):
        total_len = 0
        for file in self.file_list:
            df = self._read_parquet_file(file)
            total_len += len(df)
        return total_len
    
    def __iter__(self):
        for file in self.file_list:
            df = self._read_parquet_file(file)
            for _, row in df.iterrows():
                # shape(batch_size, sequence_length(var))
                input_tensor = bytes_to_tensor(row['sequence']) 
                time_prediction_tensor = torch.tensor(row['secs_from_last'], dtype=torch.float32)
                actionable_tensor = torch.tensor(not row['sensor_change'], dtype=torch.long)
                action_entity_id_tensor = torch.tensor(row['changed_entity_id'], dtype=torch.long)
                action_entity_state_tensor = torch.tensor(float(row['changed_entity_value']), dtype=torch.long)

                if self.transform:
                    input_tensor, time_prediction_tensor, actionable_tensor, action_entity_id_tensor, action_entity_state_tensor = self.transform(input_tensor, time_prediction_tensor, actionable_tensor, action_entity_id_tensor, action_entity_state_tensor)
                yield input_tensor, time_prediction_tensor, actionable_tensor, action_entity_id_tensor, action_entity_state_tensor


def split_file_list(file_list, train_ratio=0.8, test_ratio=0.1, eval_ratio=0.1):
    """
    Split the list of files into train/test/eval based on the given ratios.
    """
    assert train_ratio + test_ratio + eval_ratio == 1.0, "Ratios must sum to 1"
    
    # Shuffle the file list to ensure randomness
    random.shuffle(file_list)
    
    # Split based on the ratio
    total_files = len(file_list)
    train_end = int(train_ratio * total_files)
    test_end = train_end + int(test_ratio * total_files)
    
    train_files = file_list[:train_end]
    test_files = file_list[train_end:test_end]
    eval_files = file_list[test_end:]
    
    return train_files, test_files, eval_files

def _collate_fn(batch):
    # Separate inputs and targets
    input_tensors, time_prediction_tensors, actionable_tensors, action_entity_id_tensors, action_entity_state_tensors = zip(*batch)
    
    # Pad sequences for inputs (batch_first=True makes it [batch_size, seq_len, features])
    input_tensors_padded = pad_sequence(input_tensors, batch_first=True)
    
    # Convert targets to tensors (they should all have the same length as they're scalar values)
    time_prediction_tensors = torch.stack(time_prediction_tensors)
    actionable_tensors = torch.stack(actionable_tensors)
    action_entity_id_tensors = torch.stack(action_entity_id_tensors)
    action_entity_state_tensors = torch.stack(action_entity_state_tensors)
    
    return input_tensors_padded, time_prediction_tensors, actionable_tensors, action_entity_id_tensors, action_entity_state_tensors


In [14]:
import torch
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence

def train_act_model(model, dataloader, num_epochs, 
                actionable_classification_criterion, 
                entity_id_classification_criterion, 
                entity_state_classification_criterion, 
                time_pred_regression_criterion, 
                optimizer, device):
    
    model.to(device)
    model.train()
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        correct_entity_state = 0
        total_entity_state = 0
        correct_entity_id = 0
        total_entity_id = 0
        epoch_time_pred_loss = 0
        
        for batch_idx, (input_tensors_padded, time_prediction_tensors, actionable_tensors, action_entity_id_tensors, action_entity_state_tensors) in tqdm(enumerate(dataloader), desc=f"Epoch {epoch+1}/{num_epochs}"):
            # Assuming the dataloader returns a dictionary-like batch with relevant keys
            inputs = input_tensors_padded.to(device)  # shape: (batch_size, max_seq_length, input_dim)
            actionable_labels = actionable_tensors.to(device)  # shape: (batch_size,)
            entity_id_labels = action_entity_id_tensors.to(device)  # shape: (batch_size,)
            entity_state_labels = action_entity_state_tensors.to(device)  # shape: (batch_size,)
            time_pred_labels = time_prediction_tensors.to(device)  # shape: (batch_size,)

            optimizer.zero_grad()

            # Forward pass
            actionable_output, entity_id_output, entity_state_output, time_pred_output = model(inputs)
            
            # Compute losses
            entity_id_loss = entity_id_classification_criterion(entity_id_output, entity_id_labels)
            entity_state_loss = entity_state_classification_criterion(entity_state_output, entity_state_labels)
            time_pred_loss = time_pred_regression_criterion(time_pred_output.squeeze(), time_pred_labels)

            # Total loss
            total_loss = entity_id_loss * 5 + entity_state_loss * 4 + time_pred_loss * 1
            
            # Backward pass and optimization
            total_loss.backward()
            optimizer.step()
            
            epoch_loss += total_loss.item()
            # accumulate epoch loss
            epoch_time_pred_loss += time_pred_loss.item()
            
            # Calculate entity ID classification accuracy
            _, predicted_entity_id = torch.max(entity_id_output, 1)
            correct_entity_id += (predicted_entity_id == entity_id_labels).sum().item()
            total_entity_id += entity_id_labels.size(0)

            # Calculate entity state classification accuracy
            _, predicted_entity_state = torch.max(entity_state_output, 1)
            correct_entity_state += (predicted_entity_state == entity_state_labels).sum().item()
            total_entity_state += entity_state_labels.size(0)
            
        # Print epoch stats
        avg_loss = epoch_loss / len(dataloader)
        avg_time_pred_loss = epoch_time_pred_loss / len(dataloader)
        entity_id_accuracy = 100 * correct_entity_id / total_entity_id
        entity_state_accuracy = 100 * correct_entity_state / total_entity_state
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Entity ID Acc: {entity_id_accuracy:.4f}%, Entity state Acc: {entity_state_accuracy:.2f}%, Time Pred Loss: {avg_time_pred_loss:.4f}")

    print("Act Training Complete!")

In [15]:
from os import listdir
from os.path import isfile, join
import re
data_root_path = './data/training_act'
data_pattern = r"^training_data_chunk_\d+\.parquet"
file_list = [join(data_root_path, f) for f in listdir(data_root_path) if isfile(join(data_root_path, f)) and re.match(data_pattern, f)][:100]
# Split the file list
train_files, test_files, eval_files = split_file_list(file_list, train_ratio=0.8, test_ratio=0.1, eval_ratio=0.1)

# Create separate datasets
train_dataset = MultiPartParquetCASASDataset(train_files, shuffle_files=True, shuffle_rows=True)
test_dataset = MultiPartParquetCASASDataset(test_files, shuffle_files=False, shuffle_rows=False)
eval_dataset = MultiPartParquetCASASDataset(eval_files, shuffle_files=False, shuffle_rows=False)

# Wrap them in DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False, collate_fn=_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=_collate_fn)
eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False, collate_fn=_collate_fn)


# Actionable Classification Criterion (binary classification)
actionable_classification_criterion = nn.CrossEntropyLoss()

# Entity ID Classification Criterion (multi-class classification)
entity_id_classification_criterion = nn.CrossEntropyLoss()

# Entity State Classification Criterion (multi-class classification)
entity_state_classification_criterion = nn.CrossEntropyLoss()

# Time Prediction Criterion (regression)
time_pred_regression_criterion = nn.MSELoss()



In [16]:
# define the model
model = TransformerMultiTask(input_dim=52, hidden_dim=1024, num_heads=2, num_layers=1, entity_id_cls=45, entity_state_cls=4)
training_epoch = 100
# Using Adam optimizer
import torch.optim as optim
learning_rate = 1e-4  # You can adjust the learning rate as needed
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# find the device for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# start the training
train_act_model(model=model, dataloader=train_loader, num_epochs=training_epoch, 
            actionable_classification_criterion=actionable_classification_criterion, 
            entity_id_classification_criterion=entity_id_classification_criterion, 
            entity_state_classification_criterion=entity_state_classification_criterion, 
            time_pred_regression_criterion=time_pred_regression_criterion, 
            optimizer=optimizer, device=device)

Epoch 1/100: 19it [00:35,  1.89s/it]


Epoch [1/100], Loss: 20.8717, Entity ID Acc: 15.0833%, Entity state Acc: 55.33%, Time Pred Loss: 0.3135


Epoch 2/100: 19it [00:30,  1.60s/it]


Epoch [2/100], Loss: 15.4035, Entity ID Acc: 19.0833%, Entity state Acc: 58.33%, Time Pred Loss: 0.0409


Epoch 3/100: 19it [00:28,  1.49s/it]


Epoch [3/100], Loss: 13.8631, Entity ID Acc: 22.4167%, Entity state Acc: 57.75%, Time Pred Loss: 0.0087


Epoch 4/100: 19it [00:24,  1.29s/it]


Epoch [4/100], Loss: 12.7702, Entity ID Acc: 27.3333%, Entity state Acc: 58.33%, Time Pred Loss: 0.0058


Epoch 5/100: 19it [00:33,  1.75s/it]


Epoch [5/100], Loss: 12.3510, Entity ID Acc: 27.6667%, Entity state Acc: 57.67%, Time Pred Loss: 0.0027


Epoch 6/100: 19it [00:35,  1.85s/it]


Epoch [6/100], Loss: 11.7399, Entity ID Acc: 29.5833%, Entity state Acc: 57.92%, Time Pred Loss: 0.0022


Epoch 7/100: 19it [00:35,  1.87s/it]


Epoch [7/100], Loss: 11.2276, Entity ID Acc: 33.3333%, Entity state Acc: 57.42%, Time Pred Loss: 0.0012


Epoch 8/100: 19it [00:34,  1.83s/it]


Epoch [8/100], Loss: 10.9263, Entity ID Acc: 37.3333%, Entity state Acc: 60.67%, Time Pred Loss: 0.0007


Epoch 9/100: 19it [00:32,  1.74s/it]


Epoch [9/100], Loss: 10.3826, Entity ID Acc: 39.0000%, Entity state Acc: 60.33%, Time Pred Loss: 0.0007


Epoch 10/100: 19it [00:29,  1.56s/it]


Epoch [10/100], Loss: 10.3861, Entity ID Acc: 38.7500%, Entity state Acc: 60.25%, Time Pred Loss: 0.0012


Epoch 11/100: 19it [00:28,  1.48s/it]


Epoch [11/100], Loss: 10.1044, Entity ID Acc: 41.0833%, Entity state Acc: 59.75%, Time Pred Loss: 0.0009


Epoch 12/100: 19it [00:32,  1.71s/it]


Epoch [12/100], Loss: 9.6746, Entity ID Acc: 42.8333%, Entity state Acc: 60.67%, Time Pred Loss: 0.0009


Epoch 13/100: 19it [00:33,  1.78s/it]


Epoch [13/100], Loss: 9.4800, Entity ID Acc: 43.4167%, Entity state Acc: 60.58%, Time Pred Loss: 0.0007


Epoch 14/100: 19it [00:27,  1.47s/it]


Epoch [14/100], Loss: 9.4298, Entity ID Acc: 46.8333%, Entity state Acc: 61.17%, Time Pred Loss: 0.0007


Epoch 15/100: 19it [00:33,  1.79s/it]


Epoch [15/100], Loss: 9.3392, Entity ID Acc: 45.4167%, Entity state Acc: 60.75%, Time Pred Loss: 0.0007


Epoch 16/100: 19it [00:32,  1.71s/it]


Epoch [16/100], Loss: 9.0331, Entity ID Acc: 47.7500%, Entity state Acc: 60.75%, Time Pred Loss: 0.0004


Epoch 17/100: 19it [00:29,  1.55s/it]


Epoch [17/100], Loss: 8.9741, Entity ID Acc: 48.1667%, Entity state Acc: 61.42%, Time Pred Loss: 0.0005


Epoch 18/100: 19it [00:36,  1.90s/it]


Epoch [18/100], Loss: 8.8941, Entity ID Acc: 47.5833%, Entity state Acc: 59.50%, Time Pred Loss: 0.0005


Epoch 19/100: 19it [00:33,  1.79s/it]


Epoch [19/100], Loss: 9.1716, Entity ID Acc: 46.0000%, Entity state Acc: 62.42%, Time Pred Loss: 0.0004


Epoch 20/100: 19it [00:25,  1.33s/it]


Epoch [20/100], Loss: 8.9017, Entity ID Acc: 48.7500%, Entity state Acc: 61.58%, Time Pred Loss: 0.0006


Epoch 21/100: 19it [00:27,  1.47s/it]


Epoch [21/100], Loss: 8.7201, Entity ID Acc: 50.2500%, Entity state Acc: 61.75%, Time Pred Loss: 0.0004


Epoch 22/100: 19it [00:33,  1.77s/it]


Epoch [22/100], Loss: 8.5707, Entity ID Acc: 50.3333%, Entity state Acc: 62.33%, Time Pred Loss: 0.0004


Epoch 23/100: 19it [00:33,  1.79s/it]


Epoch [23/100], Loss: 9.2194, Entity ID Acc: 46.2500%, Entity state Acc: 62.67%, Time Pred Loss: 0.0005


Epoch 24/100: 19it [00:32,  1.68s/it]


Epoch [24/100], Loss: 8.8638, Entity ID Acc: 48.1667%, Entity state Acc: 62.17%, Time Pred Loss: 0.0010


Epoch 25/100: 19it [00:33,  1.76s/it]


Epoch [25/100], Loss: 8.7401, Entity ID Acc: 49.3333%, Entity state Acc: 62.08%, Time Pred Loss: 0.0009


Epoch 26/100: 19it [00:30,  1.59s/it]


Epoch [26/100], Loss: 8.8160, Entity ID Acc: 48.5833%, Entity state Acc: 62.42%, Time Pred Loss: 0.0006


Epoch 27/100: 19it [00:34,  1.82s/it]


Epoch [27/100], Loss: 9.2374, Entity ID Acc: 45.4167%, Entity state Acc: 62.00%, Time Pred Loss: 0.0008


Epoch 28/100: 19it [00:26,  1.37s/it]


Epoch [28/100], Loss: 9.1628, Entity ID Acc: 47.0833%, Entity state Acc: 62.00%, Time Pred Loss: 0.0008


Epoch 29/100: 19it [00:32,  1.68s/it]


Epoch [29/100], Loss: 8.5740, Entity ID Acc: 49.2500%, Entity state Acc: 62.75%, Time Pred Loss: 0.0005


Epoch 30/100: 19it [00:34,  1.80s/it]


Epoch [30/100], Loss: 8.2312, Entity ID Acc: 54.0000%, Entity state Acc: 63.42%, Time Pred Loss: 0.0004


Epoch 31/100: 19it [00:27,  1.44s/it]


Epoch [31/100], Loss: 8.1717, Entity ID Acc: 53.7500%, Entity state Acc: 63.92%, Time Pred Loss: 0.0006


Epoch 32/100: 19it [00:28,  1.47s/it]


Epoch [32/100], Loss: 8.2730, Entity ID Acc: 52.5833%, Entity state Acc: 62.00%, Time Pred Loss: 0.0005


Epoch 33/100: 19it [00:29,  1.57s/it]


Epoch [33/100], Loss: 8.1619, Entity ID Acc: 52.8333%, Entity state Acc: 63.08%, Time Pred Loss: 0.0004


Epoch 34/100: 19it [00:33,  1.77s/it]


Epoch [34/100], Loss: 8.2287, Entity ID Acc: 52.2500%, Entity state Acc: 63.08%, Time Pred Loss: 0.0007


Epoch 35/100: 19it [00:30,  1.59s/it]


Epoch [35/100], Loss: 8.0361, Entity ID Acc: 53.5000%, Entity state Acc: 63.92%, Time Pred Loss: 0.0006


Epoch 36/100: 19it [00:27,  1.47s/it]


Epoch [36/100], Loss: 8.1483, Entity ID Acc: 52.6667%, Entity state Acc: 64.58%, Time Pred Loss: 0.0006


Epoch 37/100: 19it [00:28,  1.49s/it]


Epoch [37/100], Loss: 8.1023, Entity ID Acc: 53.4167%, Entity state Acc: 64.50%, Time Pred Loss: 0.0006


Epoch 38/100: 19it [00:29,  1.57s/it]


Epoch [38/100], Loss: 7.8925, Entity ID Acc: 55.0000%, Entity state Acc: 64.25%, Time Pred Loss: 0.0005


Epoch 39/100: 19it [00:29,  1.57s/it]


Epoch [39/100], Loss: 7.7806, Entity ID Acc: 56.8333%, Entity state Acc: 65.50%, Time Pred Loss: 0.0010


Epoch 40/100: 19it [00:30,  1.59s/it]


Epoch [40/100], Loss: 7.6084, Entity ID Acc: 59.1667%, Entity state Acc: 66.50%, Time Pred Loss: 0.0006


Epoch 41/100: 19it [00:27,  1.44s/it]


Epoch [41/100], Loss: 7.7151, Entity ID Acc: 55.7500%, Entity state Acc: 67.17%, Time Pred Loss: 0.0008


Epoch 42/100: 19it [00:29,  1.56s/it]


Epoch [42/100], Loss: 7.7627, Entity ID Acc: 56.1667%, Entity state Acc: 65.75%, Time Pred Loss: 0.0007


Epoch 43/100: 19it [00:34,  1.81s/it]


Epoch [43/100], Loss: 7.7868, Entity ID Acc: 56.0833%, Entity state Acc: 64.17%, Time Pred Loss: 0.0009


Epoch 44/100: 19it [00:29,  1.58s/it]


Epoch [44/100], Loss: 7.7062, Entity ID Acc: 55.1667%, Entity state Acc: 64.17%, Time Pred Loss: 0.0007


Epoch 45/100: 19it [00:29,  1.54s/it]


Epoch [45/100], Loss: 7.9718, Entity ID Acc: 54.0000%, Entity state Acc: 66.83%, Time Pred Loss: 0.0014


Epoch 46/100: 19it [00:31,  1.68s/it]


Epoch [46/100], Loss: 7.6049, Entity ID Acc: 57.2500%, Entity state Acc: 67.25%, Time Pred Loss: 0.0012


Epoch 47/100: 19it [00:27,  1.47s/it]


Epoch [47/100], Loss: 7.5590, Entity ID Acc: 57.7500%, Entity state Acc: 65.25%, Time Pred Loss: 0.0007


Epoch 48/100: 19it [00:29,  1.54s/it]


Epoch [48/100], Loss: 7.4781, Entity ID Acc: 59.0833%, Entity state Acc: 67.00%, Time Pred Loss: 0.0009


Epoch 49/100: 19it [00:32,  1.74s/it]


Epoch [49/100], Loss: 7.8667, Entity ID Acc: 54.8333%, Entity state Acc: 66.08%, Time Pred Loss: 0.0013


Epoch 50/100: 19it [00:31,  1.64s/it]


Epoch [50/100], Loss: 8.0600, Entity ID Acc: 52.7500%, Entity state Acc: 65.42%, Time Pred Loss: 0.0017


Epoch 51/100: 19it [00:31,  1.66s/it]


Epoch [51/100], Loss: 7.5476, Entity ID Acc: 58.5000%, Entity state Acc: 65.67%, Time Pred Loss: 0.0015


Epoch 52/100: 19it [00:33,  1.77s/it]


Epoch [52/100], Loss: 7.6639, Entity ID Acc: 57.0000%, Entity state Acc: 66.42%, Time Pred Loss: 0.0021


Epoch 53/100: 19it [00:33,  1.77s/it]


Epoch [53/100], Loss: 7.8492, Entity ID Acc: 55.2500%, Entity state Acc: 63.58%, Time Pred Loss: 0.0010


Epoch 54/100: 19it [00:28,  1.51s/it]


Epoch [54/100], Loss: 7.6768, Entity ID Acc: 54.9167%, Entity state Acc: 65.33%, Time Pred Loss: 0.0028


Epoch 55/100: 19it [00:30,  1.61s/it]


Epoch [55/100], Loss: 7.3808, Entity ID Acc: 59.8333%, Entity state Acc: 64.67%, Time Pred Loss: 0.0022


Epoch 56/100: 19it [00:25,  1.34s/it]


Epoch [56/100], Loss: 7.4319, Entity ID Acc: 58.1667%, Entity state Acc: 66.17%, Time Pred Loss: 0.0028


Epoch 57/100: 19it [00:25,  1.36s/it]


Epoch [57/100], Loss: 7.1971, Entity ID Acc: 59.3333%, Entity state Acc: 67.83%, Time Pred Loss: 0.0018


Epoch 58/100: 19it [00:31,  1.66s/it]


Epoch [58/100], Loss: 7.2039, Entity ID Acc: 60.1667%, Entity state Acc: 65.92%, Time Pred Loss: 0.0021


Epoch 59/100: 19it [00:33,  1.75s/it]


Epoch [59/100], Loss: 7.1111, Entity ID Acc: 61.1667%, Entity state Acc: 66.92%, Time Pred Loss: 0.0010


Epoch 60/100: 19it [00:32,  1.69s/it]


Epoch [60/100], Loss: 7.2154, Entity ID Acc: 60.1667%, Entity state Acc: 65.17%, Time Pred Loss: 0.0021


Epoch 61/100: 19it [00:29,  1.55s/it]


Epoch [61/100], Loss: 7.0469, Entity ID Acc: 60.8333%, Entity state Acc: 65.25%, Time Pred Loss: 0.0017


Epoch 62/100: 19it [00:31,  1.67s/it]


Epoch [62/100], Loss: 6.9976, Entity ID Acc: 62.8333%, Entity state Acc: 67.83%, Time Pred Loss: 0.0011


Epoch 63/100: 19it [00:30,  1.63s/it]


Epoch [63/100], Loss: 6.8860, Entity ID Acc: 61.5833%, Entity state Acc: 66.50%, Time Pred Loss: 0.0010


Epoch 64/100: 19it [00:27,  1.45s/it]


Epoch [64/100], Loss: 7.1457, Entity ID Acc: 58.9167%, Entity state Acc: 69.58%, Time Pred Loss: 0.0015


Epoch 65/100: 19it [00:28,  1.52s/it]


Epoch [65/100], Loss: 7.2405, Entity ID Acc: 56.7500%, Entity state Acc: 68.83%, Time Pred Loss: 0.0028


Epoch 66/100: 19it [00:29,  1.57s/it]


Epoch [66/100], Loss: 7.1131, Entity ID Acc: 59.5000%, Entity state Acc: 69.25%, Time Pred Loss: 0.0037


Epoch 67/100: 19it [00:29,  1.54s/it]


Epoch [67/100], Loss: 7.0900, Entity ID Acc: 59.8333%, Entity state Acc: 68.75%, Time Pred Loss: 0.0058


Epoch 68/100: 19it [00:31,  1.68s/it]


Epoch [68/100], Loss: 6.7890, Entity ID Acc: 61.8333%, Entity state Acc: 69.08%, Time Pred Loss: 0.0033


Epoch 69/100: 19it [00:32,  1.73s/it]


Epoch [69/100], Loss: 6.9038, Entity ID Acc: 61.4167%, Entity state Acc: 67.83%, Time Pred Loss: 0.0043


Epoch 70/100: 18it [00:33,  1.86s/it]


KeyboardInterrupt: 