In [7]:
from genie import LatentAction 

In [8]:
from typing import Tuple
from torch import Tensor
import torch.nn as nn
import torch

from math import prod
from torch.nn.functional import mse_loss

from einops.layers.torch import Rearrange

from genie.module import parse_blueprint
from genie.module.quantization import LookupFreeQuantization
from genie.module.video import CausalConv3d, Downsample, Upsample
from genie.utils import Blueprint

REPR_ACT_ENC = (
    ('space-time_attn', {
        'n_repr' : 8,
        'n_heads': 8,
        'd_head': 64,
    }),
)

REPR_ACT_DEC = (
    ('space-time_attn', {
        'n_repr' : 8,
        'n_heads': 8,
        'd_head': 64,
    }),
)

class LatentAction(nn.Module):
    '''Latent Action Model (LAM) used to distill latent actions
    from history of past video frames. The LAM model employs a
    VQ-VAE model to encode video frames into discrete latents.
    Both the encoder and decoder are based on spatial-temporal
    transformers.
    '''
    
    def __init__(
        self,
        enc_desc: Blueprint,
        dec_desc: Blueprint,
        d_codebook: int,
        inp_channels: int = 3,
        inp_shape : int | Tuple[int, int] = (64, 64),
        ker_size : int | Tuple[int, int] = 3,
        n_embd: int = 64,
        n_codebook: int = 2,
        lfq_bias : bool = True,
        lfq_frac_sample : float = 1.,
        lfq_commit_weight : float = 0.25,
        lfq_entropy_weight : float = 0.1,
        lfq_diversity_weight : float = 1.,
        quant_loss_weight : float = 1.,
    ) -> None:
        super().__init__()
        
        if isinstance(inp_shape, int): inp_shape = (inp_shape, inp_shape)
        
        self.proj_in = CausalConv3d(
            inp_channels,
            out_channels=n_embd,
            kernel_size=ker_size
        )
        
        self.proj_out = CausalConv3d(
            n_embd,
            out_channels=inp_channels,
            kernel_size=ker_size
        )
        
        # Build the encoder and decoder based on the blueprint
        self.enc_layers, self.enc_ext = parse_blueprint(enc_desc)
        self.dec_layers, self.dec_ext = parse_blueprint(dec_desc)
        
        # Keep track of space-time up/down factors
        enc_fact = prod(enc.factor for enc in self.enc_layers if isinstance(enc, (Downsample, Upsample)))
        dec_fact = prod(dec.factor for dec in self.dec_layers if isinstance(dec, (Downsample, Upsample)))
        
        assert enc_fact * dec_fact == 1, 'The product of the space-time up/down factors must be 1.'
        
        # Add the projections to the action space
        self.to_act = nn.Sequential(
                Rearrange('b c t ... -> b t (c ...)'),
                nn.Linear(
                    int(n_embd * enc_fact * prod(inp_shape)),
                    d_codebook,
                    bias=False,
                )
        )

        # Build the quantization module
        self.quant = LookupFreeQuantization(
            codebook_dim       = d_codebook,
            num_codebook       = n_codebook,
            input_dim=d_codebook * n_codebook,
            use_bias         = lfq_bias,
            frac_sample      = lfq_frac_sample,
            commit_weight    = lfq_commit_weight,
            entropy_weight   = lfq_entropy_weight,
            diversity_weight = lfq_diversity_weight,
        )
        
        self.d_codebook = d_codebook
        self.n_codebook = n_codebook
        self.quant_loss_weight = quant_loss_weight
        
    def sample(self, idxs : Tensor) -> Tensor:
        '''Sample the action codebook values based on the indices.'''
        return self.quant.codebook[idxs]
        
    def encode(
        self,
        video: Tensor,
        mask : Tensor | None = None,
        transpose : bool = False,
    ) -> Tuple[Tuple[Tensor, Tensor], Tensor]:
        video = self.proj_in(video)
        
        # Encode the video frames into latent actions
        for enc in self.enc_layers:
            video = enc(video, mask=mask)
        
        # Project to latent action space
        act : Tensor = self.to_act(video)

        # Quantize the latent actions
        (act, idxs), q_loss = self.quant(act, transpose=transpose)
        
        return (act, idxs, video), q_loss
    
    def decode(
        self,
        video : Tensor,
        q_act : Tensor,
    ) -> Tensor:        
        # Decode the video frames based on past history and
        # the quantized latent actions
        for dec, has_ext in zip(self.dec_layers, self.dec_ext):
            video = dec(
                video,
                cond=(
                    None, # No space condition
                    q_act if has_ext else None,
                )
            )
            
        recon = self.proj_out(video)
        
        return recon
        
    def forward(
        self,
        video: Tensor,
        mask : Tensor | None = None,
    ) -> Tuple[Tensor, Tensor]:
        
        # Encode the video frames into latent actions
        (act, idxs, enc_video), q_loss = self.encode(video, mask=mask)
        
        # Decode the last video frame based on all the previous
        # frames and the quantized latent actions
        recon = self.decode(enc_video, act)
        
        # Compute the reconstruction loss
        # Reconstruction loss
        rec_loss = mse_loss(recon, video)
        
        # Compute the total loss by combining the individual
        # losses, weighted by the corresponding loss weights
        loss = rec_loss\
            + q_loss * self.quant_loss_weight
        
        return idxs, loss, (
            rec_loss,
            q_loss,
        )

In [9]:
import torch
import torch.nn as nn

class LatentActionClassifier(nn.Module):
    def __init__(self, latent_size: Tuple[int, int], hidden_dim: int = 128, num_classes: int = 25):
        """
        MLP for processing latent representations.
        
        Args:
            latent_size (Tuple[int, int]): Size of the latent tensor (T, num_codebook).
            hidden_dim (int): Dimension of the hidden layers.
            num_classes (int): Number of output classes.
        """
        super(LatentActionClassifier, self).__init__()
        
        T, num_codebook = latent_size  # Extract latent dimensions
        
        input_dim = T * num_codebook  # Flattened input size
        
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),  # Input to hidden
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),  # Hidden to hidden
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)  # Hidden to output
        )
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the MLP.
        
        Args:
            x (torch.Tensor): Latent input tensor of shape (B, T, num_codebook).
        
        Returns:
            torch.Tensor: Logits of shape (B, num_classes).
        """
        # Flatten the latent tensor: (B, T, num_codebook) -> (B, T * num_codebook)
        x = x.view(x.size(0), -1)
        
        # Pass through the MLP
        return self.mlp(x)


In [10]:
LATENT_ACT_ENC = (
    ('space-time_attn', {
        'n_head': 8,
        'd_head': 8,
        'd_inp': 64,
        'd_out': 64,
    }),
    ('spacetime_downsample', {
        'in_channels': 64,
        'out_channels': 64,
        'kernel_size': 3,
        'space_factor': 2,
        'time_factor': 1,
    }),
)

LATENT_ACT_DEC = (
    ('spacetime_upsample', {
        'in_dim': 64,
        'out_dim': 64,
        # 'kernel_size': 3,
        'space_factor': 2,
        'time_factor': 1,
    }),
    ('space-time_attn', {
        'n_head': 8,
        'd_head': 8,
        'd_inp': 64,
        'd_out': 64,
    }),
)

In [11]:
from torch.utils.data import DataLoader

# Example Dataset and DataLoader (Replace with actual dataset)
class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples=100):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        video = torch.randn(3, 16, 64, 64)  # Example video (C, T, H, W)
        label = torch.randn(25) # Example action label
        return video, label

In [12]:
# Video (B, T, C, H, W)
# Action (B, A)

import torch
from torch.optim import Adam
import torch.nn.functional as F
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}, GPUs available: {torch.cuda.device_count()}")

# Instantiate the LatentAction model
latent_action_model = LatentAction(
    enc_desc=LATENT_ACT_ENC,
    dec_desc=LATENT_ACT_DEC,
    d_codebook=10,
    n_codebook=1,
    inp_channels=3,  
    inp_shape=(64, 64),
    n_embd=64,
)

# if torch.cuda.device_count() > 1:
#     latent_action_model = torch.nn.DataParallel(latent_action_model)

latent_action_model = latent_action_model.to(device)

# Instantiate the integrated classifier
classifier = LatentActionClassifier(latent_size=[16, 10]).to(device)

# Combine both models
latent_action_model.classifier = classifier

dataset = ExampleDataset()
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Optimizer
optimizer = Adam(
    list(latent_action_model.parameters()) + list(classifier.parameters()),
    lr=1e-4
)

# Training loop
latent_action_model.train()
for epoch in range(10):  # Number of epochs
    epoch_loss = 0
    epoch_recon_loss = 0
    epoch_mlp_loss = 0

    progress_bar = tqdm(dataloader, desc=f"Training Epoch {epoch + 1}", unit="batch")
    
    for videos, labels in progress_bar:
        # Videos: (B, C, T, H, W)
        # Labels: (B,)
        videos = videos.to(device)
        labels = labels.to(device)

        # Forward pass through the LatentAction model
        idxs, latent_loss, (recon_loss, quant_loss) = latent_action_model(videos)

        # Compute refined embeddings from the quantized indices
        refined_embeddings = latent_action_model.sample(idxs)

        # Forward pass through the classifier
        logits = classifier(refined_embeddings)

        # Classification loss
        mlp_loss = F.mse_loss(logits, labels)

        # Total loss
        total_loss = latent_loss + mlp_loss

        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Update losses for progress bar and epoch tracking
        epoch_loss += total_loss.item()
        epoch_recon_loss += recon_loss.item()
        epoch_mlp_loss += mlp_loss.item()

        # Update progress bar with individual losses
        progress_bar.set_postfix(
            total_loss=total_loss.item(),
            recon_loss=recon_loss.item(),
            mlp_loss=mlp_loss.item(),
        )

    # Print summary of losses at the end of the epoch
    # print(
    #     f"Epoch {epoch + 1} completed. "
    #     f"Total Loss: {epoch_loss:.4f}, "
    #     f"Reconstruction Loss: {epoch_recon_loss:.4f}, "
    #     f"MLP Loss: {epoch_mlp_loss:.4f}"
    # )


  return torch._dynamo.disable(fn, recursive)(*args, **kwargs)


Using device: cuda, GPUs available: 4


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Training Epoch 1:  31%|███       | 4/13 [00:06<00:14,  1.64s/batch, mlp_loss=1.1, recon_loss=2.46, total_loss=4.15]  


KeyboardInterrupt: 

In [26]:
import numpy as np

# Load the .npy file
data = np.load('/workspace/wm-group11-1x-challenge3/combined_action_data/frame_045376.npy')

# Print or process the data
print(data)

[ 0.02116361 -0.00173624 -0.13813464  0.19783622  0.02508382 -0.04306014
  0.29541606  0.00993608 -0.02094108 -1.2919362   0.02760345  0.20123062
 -0.00949717  0.29516238 -0.00792499  0.02227568 -1.2910248  -0.02258378
  0.19646356 -0.01153214  0.01455284  0.          0.          1.
  0.83336455]


In [24]:
10832584/45377

238.7241113339357

In [38]:
import numpy as np
import os

class CustomDataset(nn.Module):
    def __init__(self, action_dir, video_dir):
        super(CustomDataset, self).__init__()

        action_files = [f for f in os.listdir(action_dir) if f.endswith('.npy')]
        video_files = [f for f in os.listdir(video_dir) if f.endswith('.npy')]
        
        self.actions = [
            torch.tensor(np.load(os.path.join(action_dir, action_file)), dtype=torch.float32)
            for action_file in action_files
        ]
        
        self.videos = [
            torch.tensor(np.load(os.path.join(video_dir, video_file)), dtype=torch.float32)
            for video_file in video_files
        ]

        print(len(self.actions))
        print(len(self.videos))
        # assert len(self.actions) == len(self.videos), "Mismatch between actions and videos!"

    def __len__(self):
        return len(self.actions)

    def __getitem__(self, idx):
        return self.videos[idx], self.actions[idx]
        
        

In [39]:
dataset = CustomDataset(
    action_dir="/workspace/wm-group11-1x-challenge3/combined_action_data", 
    video_dir="/workspace/wm-group11-1x-challenge3/combined_action_data"
)

53137
53137


In [32]:
dataset.actions.shape

AttributeError: 'list' object has no attribute 'shape'

In [40]:
action_dir="/workspace/wm-group11-1x-challenge3/combined_action_data"
video_dir="/workspace/wm-group11-1x-challenge3/combined_action_data"

action_files = [f for f in os.listdir(action_dir) if f.endswith('.npy')]
video_files = [f for f in os.listdir(video_dir) if f.endswith('.npy')]

print("Action Count:", len(action_files))
print("Video Count:", len(video_files))


Action Count: 53190
Video Count: 53190
