In [19]:
from src.models.architectures.transformer import Encoder_TRANSFORMER, Decoder_TRANSFORMER
from src.datasets.amass import AMASS
from src.datasets.get_dataset import get_datasets
from pytorch_metric_learning import losses
import torch.nn as nn
import pytorch_lightning as pl
from collections import OrderedDict
from typing import Tuple, Union
import clip
import torch
from transformers import RobertaTokenizer, RobertaModel
from src.utils.tensors import collate
from torch.utils.data import DataLoader
from src.models.tools.losses import get_loss_function
from src.models.rotation2xyz import Rotation2xyz
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from src.utils.action_classifier import evaluate

In [20]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
wandb.login(key='93443c480bfbaa0b19be76d24f2efeb6be3319fd')
wandb.init(project='text2motion', name='MOTIONCLIP_DECOMP')





In [21]:
parameters = {
    'expname': 'exps',
    'folder': './exps/clip',
    'cuda': True,
    'device': torch.device(type='cuda', index=0),
    'batch_size': 80,
    'num_epochs': 500,
    'lr': 0.0002,
    'snapshot': 20,
    'dataset': 'babel',
    'datapath': './data/amass/amass_db/babel_30fps_db.pt',
    'num_frames': 60,
    'sampling': 'conseq',
    'sampling_step': 1,
    'pose_rep': 'rot6d',
    'max_len': -1,
    'min_len': -1,
    'num_seq_max': -1,
    'glob': True,
    'glob_rot': [3.141592653589793, 0, 0],
    'translation': True,
    'debug': False,
    'use_action_cat_as_text_labels': True,
    'only_60_classes': True,
    'use_only_15_classes': False,
    'modelname': 'motionclip_transformer_rc_rcxyz_vel',
    'latent_dim': 512,
    'lambda_rc': 95.0,
    'lambda_rcxyz': 95.0,
    'lambda_vel': 95.0,
    'lambda_velxyz': 1.0,
    'jointstype': 'vertices',
    'vertstrans': False,
    'num_layers': 8,
    'activation': 'gelu',
    'clip_image_losses': ['cosine'],
    'clip_text_losses': ['cosine'],
    'clip_lambda_mse': 1.0,
    'clip_lambda_ce': 1.0,
    'clip_lambda_cosine': 1.0,
    'clip_training': '',
    'clip_layers': 12,
    'modeltype': 'motionclip',
    'archiname': 'transformer',
    'losses': ['rc', 'rcxyz', 'vel'],
    'lambdas': {'rc': 95.0, 'rcxyz': 95.0, 'vel': 95.0},
    'clip_lambdas': {'image': {'cosine': 1.0}, 'text': {'cosine': 1.0}},
    'num_classes': 1,
    'nfeats': 6,
    'njoints': 25,
    'outputxyz': True
}

In [4]:
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False)  # Must set jit=False for trainin)
train_dataset = get_datasets(parameters=parameters, clip_preprocess=clip_preprocess, split='train')['train']
val_dataset = get_datasets(parameters=parameters, 
                           clip_preprocess=clip_preprocess, split='vald')['test']


datapath used by amass is [./data/amass/amass_db/babel_30fps_train.pt]
BROTHER???
BROTHER???
BROTHER???
BROTHER???
BROTHER???
BROTHER???
BROTHER???
BROTHER???
BROTHER???
BROTHER???
datapath used by amass is [./data/amass/amass_db/babel_30fps_vald.pt]
BROTHER???
BROTHER???


In [22]:
import torch.nn.functional as F
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, target):
        # Cosine similarity between the two outputs
        cos_sim = F.cosine_similarity(output1, output2)

        # Contrastive loss calculation
        loss_contrastive = torch.mean((1 - target) * torch.pow(1 - cos_sim, 2) +
                                      (target) * torch.pow(torch.clamp(cos_sim - self.margin, min=0.0), 2))

        return loss_contrastive

In [51]:
class SkelCLIP(pl.LightningModule):
    def __init__(self, encoder, decoder, outputxyz=True, train_dataset=None, val_dataset=None, pose_rep='rot6d',
                 lambdas=None, text_cosine_lambda=1.0, latent_dim=512, glob_rot=[3.141592653589793, 0, 0], glob=True, translation=True,
                 jointstype='vertices', vertstrans=False,
                 image_cosine_lambda=1.0, batch_size=80, lr=1e-3, **kwargs):
        super(SkelCLIP, self).__init__()
        model_name = 'roberta-base'
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.text_encoder = RobertaModel.from_pretrained(model_name)
        for param in self.text_encoder.parameters():
            param.requires_grad = False
        self.econder_to_latent = nn.Linear(768, latent_dim)
        self.encoder = encoder 
        self.decoder = decoder
        self.pose_rep = pose_rep
        self.batch_size = batch_size
        self.lr = lr
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device, jit=False)
        self.clip_model.eval()
        for p in self.clip_model.parameters():
            p.requires_grad = False
        self.outputxyz = outputxyz
        self.clip_lambdas = {
            'text': {'cosine': text_cosine_lambda},
            #'image': {'cosine': text_cosine_lambda},
        }
        
        self.lambdas=lambdas
        self.glob_rot = glob_rot
        self.glob = glob
        self.translation = translation
        self.ae_lambdas = lambdas
        self.cosine_sim = nn.CosineSimilarity(dim=1, eps=1e-6)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.vertstrans = vertstrans
        self.jointstype = jointstype
        self.rotation2xyz = Rotation2xyz(device=torch.device(type='cpu', index=0))
        self.param2xyz = {"pose_rep": self.pose_rep,
                          "glob_rot": self.glob_rot,
                          "glob": self.glob,
                          "jointstype": self.jointstype,
                          "translation": self.translation,
                          "vertstrans": self.vertstrans}
        self.contrastive = ContrastiveLoss()
    
    def rot2xyz(self, x, mask, get_rotations_back=False, **kwargs):
        kargs = self.param2xyz.copy()
        kargs.update(kwargs)
        return self.rotation2xyz(x, mask, get_rotations_back=get_rotations_back, **kargs)    
    
    def forward(self, batch):
        if self.outputxyz:
            batch["x_xyz"] = self.rot2xyz(batch["x"], batch["mask"])
            
        batch.update(self.encoder(batch))
        batch["z"] = batch["mu"]
        # decode
        batch.update(self.decoder(batch))
        # if we want to output xyz
        if self.outputxyz:
            batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"])
        return batch
    
    def compute_ae_loss(self, batch, loop='train'):
        mixed_loss = 0.
        for ltype, lam in self.lambdas.items():
            loss_function = get_loss_function(ltype) 
            loss = loss_function(batch) * lam
            self.log(f'{ltype}_{loop}', loss, sync_dist=True, batch_size=self.batch_size)
            mixed_loss += loss
        return mixed_loss
    
    def compute_clip_loss(self, batch):
        mixed_clip_loss = 0.
        clip_losses = {}
        #target = torch.eye(batch['z'].shape[0]).to(self.device)
        for d in self.clip_lambdas.keys():
            if len(self.clip_lambdas[d].keys()) == 0:
                continue
            with torch.no_grad():
                if d == 'image':
                    features = self.clip_model.encode_image(
                        batch['clip_images']).float()  # preprocess is done in dataloader
                elif d == 'text':
                    print(batch['clip_text'])
                    tokens = self.tokenizer(batch['clip_text'], return_tensors="pt", padding=True, truncation=True)
                    features = self.econder_to_latent(
                        torch.mean(self.text_encoder(**tokens).last_hidden_state, dim=1))
                    print(features.shape)
                else:
                    raise ValueError(f'Invalid clip domain [{d}]')

            # normalized features
            features_norm = features / features.norm(dim=-1, keepdim=True)
            seq_motion_features_norm = batch["z"] / batch["z"].norm(dim=-1, keepdim=True)
            # mixed_clip_loss += self.contrastive(features_norm, seq_motion_features_norm, target)
            cos = self.cosine_sim(features_norm, seq_motion_features_norm)
            cosine_loss = (1 - cos).mean()
            clip_losses[f'{d}_cosine'] = cosine_loss.item()
            mixed_clip_loss += cosine_loss * self.clip_lambdas[d]['cosine']
            
        return mixed_clip_loss, clip_losses
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    
    def training_step(self, batch, batch_idx):
        output = self(batch)
        clip_loss, clip_losses = self.compute_clip_loss(output)
        self.log('clip_loss_train', clip_loss, sync_dist=True, on_step=True, batch_size=self.batch_size)
        # ae_loss = self.compute_ae_loss(output)
        return clip_loss

    def validation_step(self, batch, batch_idx):
        output = self(batch)
        clip_loss, clip_losses = self.compute_clip_loss(output)
        # ae_loss = self.compute_ae_loss(output, loop='val')
        self.log('clip_loss_val', clip_loss, sync_dist=True, batch_size=self.batch_size)
        return clip_loss
    
    def on_validation_epoch_end(self):
        self.eval()
        top_1_acc, top_5_acc = evaluate(self, self.val_dataset, self.val_dataloader(), {})
        self.train()
        self.log('top_1_acc', top_1_acc, sync_dist=True, batch_size=self.batch_size)
        self.log('top_5_acc', top_1_acc, sync_dist=True, batch_size=self.batch_size)
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                                shuffle=True, num_workers=32, collate_fn=collate)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                                shuffle=False, num_workers=32, collate_fn=collate)
    
    def to(self, device):
        self.rotation2xyz.smpl_model = self.rotation2xyz.smpl_model.to(device)
        self = super(SkelCLIP, self).to(device)
        return self
    
        
        

In [52]:
print(len(train_dataset))
print(train_dataset.__getitem__(100))

20112
BROTHER???
{'inp': tensor([[[-4.7670e-01, -5.1837e-01, -5.8202e-01,  ..., -3.9532e-01,
          -4.6517e-01, -5.4573e-01],
         [ 1.0069e-01,  9.9092e-02,  7.9911e-02,  ..., -1.1173e-01,
          -1.1524e-01, -1.1933e-01],
         [-8.7328e-01, -8.4939e-01, -8.0924e-01,  ...,  9.1172e-01,
           8.7769e-01,  8.2942e-01],
         [-8.7826e-01, -8.5454e-01, -8.1267e-01,  ...,  9.1853e-01,
           8.8518e-01,  8.3753e-01],
         [-9.6918e-02, -9.7796e-02, -9.2281e-02,  ..., -5.3032e-02,
          -6.9794e-02, -1.0936e-01],
         [ 4.6825e-01,  5.1010e-01,  5.7538e-01,  ...,  3.9178e-01,
           4.5998e-01,  5.3533e-01]],

        [[ 9.6482e-01,  9.6570e-01,  9.6878e-01,  ...,  8.9487e-01,
           9.2027e-01,  9.4676e-01],
         [-2.5414e-01, -2.4341e-01, -2.2263e-01,  ..., -3.9268e-01,
          -3.5102e-01, -2.8709e-01],
         [ 6.7417e-02,  9.0409e-02,  1.0912e-01,  ...,  2.1217e-01,
           1.7288e-01,  1.4570e-01],
         [ 2.4121e-01,  2.28

In [53]:
encoder = Encoder_TRANSFORMER(**parameters)
decoder = Decoder_TRANSFORMER(**parameters)
model = SkelCLIP(encoder, decoder, train_dataset=train_dataset, val_dataset=val_dataset, **parameters)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [54]:
example = train_dataset.__getitem__(100)
output = model.compute_clip_loss(collate([train_dataset.__getitem__(100)]))

BROTHER???
BROTHER???
['martial art']
torch.Size([1, 512])


KeyError: 'z'

In [13]:
output.keys()

dict_keys(['x', 'y', 'mask', 'lengths', 'clip_images', 'clip_text', 'clip_path', 'all_categories', 'x_xyz', 'mu', 'z', 'output', 'output_xyz'])

In [11]:
output['x'].shape

torch.Size([1, 25, 6, 60])

In [12]:
output['y'].shape

torch.Size([1])

In [13]:
output['mask'].shape

torch.Size([1, 60])

In [14]:
output['mu'].shape

torch.Size([1, 512])

In [15]:
output['z'].shape

torch.Size([1, 512])

In [16]:
# output['output'].shape

In [17]:
wandb_logger = WandbLogger(project="CSNER", name='MOTIONCLIP_DECOMP')

checkpoint_callback = ModelCheckpoint(
    monitor='top_1_acc',
    mode='max',
    save_top_k=2,
    dirpath='./skel_clip_exp'
)

torch.set_float32_matmul_precision('medium')
trainer = pl.Trainer(accelerator="gpu", max_epochs=parameters['num_epochs'], callbacks=[checkpoint_callback], logger=wandb_logger, fast_dev_run=False, accumulate_grad_batches=100)  #,  val_check_interval=100, strategy='ddp_notebook', devices=4,
trainer.fit(model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/conda/envs/motionclip/lib/python3.8/site-packages/pytorch_lightning/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/opt/conda/envs/motionclip/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory ./skel_clip_exp exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type                | Params
----------------------------------------------------
0 | encoder     | Encoder_TRANSFORMER | 16.9 M
1 | decoder     | Decoder_TRANSFORMER | 25.3 M
2 | clip_model  | CLIP                | 151 M 
3 | cosine_sim  | CosineSimilarity    | 0     
4 | contrastive | ContrastiveLoss     | 0     
----------------

Epoch 0:   0%|          | 0/252 [00:00<?, ?it/s]                           

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.73 GiB. GPU 0 has a total capacty of 22.02 GiB of which 121.25 MiB is free. Process 7872 has 14.55 GiB memory in use. Including non-PyTorch memory, this process has 7.35 GiB memory in use. Of the allocated memory 4.74 GiB is allocated by PyTorch, and 1.43 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF