In [None]:
"""
Train an AGCN skeleton-embedding model with a triplet-margin metric-learning loss on the NTU RGB+D one-shot splits.
"""
import logging as log
import os
from argparse import ArgumentParser
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule#, seed_everything

from pytorch_metric_learning import losses, samplers

# seed_everything(42)

In [None]:
import sys
sys.path.append('../')

from model import agcn, msg3d
from graph import ntu_rgb_d
from feeders import feeder

In [None]:
class GCNDMLModel(LightningModule):
    """
    Deep-metric-learning model over skeleton sequences.

    Wraps an ``agcn.Model`` backbone (built with the NTU RGB+D graph) whose
    output is used as an embedding and trained with a triplet-margin loss
    from ``pytorch_metric_learning``. Batches are drawn with an
    ``MPerClassSampler`` (m=4) so that a batch contains several samples of
    the same class for the triplet loss to form positives from.
    """

    # Defaults used when ``hparams`` is None or lacks the matching attribute.
    DEFAULT_BATCH_SIZE = 4
    DEFAULT_LEARNING_RATE = 0.0001
    DEFAULT_DATA_PATH = "/home/raphael/git/graph_metric_learning/"

    # Maps the ``dataset_type`` argument of ``__dataloader`` to the file-name
    # prefix of that split under ``<data_path>/data/ntu/one_shot/``.
    _SPLIT_PREFIX = {"train": "train", "test": "val", "samples": "sample"}

    def __init__(self, hparams):
        """
        Build the backbone and the metric loss.

        :param hparams: None, or a namespace (e.g. from ``argparse``). When
            given, its ``batch_size``, ``learning_rate`` and ``data_path``
            attributes override the class defaults; missing attributes fall
            back to the defaults.
        """
        super(GCNDMLModel, self).__init__()
        # hparams is intentionally not stored as self.hparams; only the
        # individual values below are read from it.
        self.batch_size = getattr(hparams, "batch_size", self.DEFAULT_BATCH_SIZE)
        self.learning_rate = getattr(hparams, "learning_rate", self.DEFAULT_LEARNING_RATE)
        self.data_path = getattr(hparams, "data_path", self.DEFAULT_DATA_PATH)

        self.model = agcn.Model(graph="graph.ntu_rgb_d.Graph")
        self.metric_loss = losses.TripletMarginLoss()

    # ---------------------
    # TRAINING
    # ---------------------
    def forward(self, x):
        """Return the backbone output for ``x``, used as the embedding.

        ``x`` must be in whatever format ``agcn.Model`` expects.
        """
        return self.model(x)

    @staticmethod
    def _unpack_batch(batch):
        """Return ``(data, labels)`` from a collated batch.

        Only the first two elements are used, so loaders that yield extra
        items (e.g. a sample index) also work.
        NOTE(review): confirm what ``feeder.Feeder`` returns per sample.
        """
        return batch[0], batch[1]

    def training_step(self, batch, batch_idx):
        """Compute the triplet-margin loss for one training batch."""
        x, y = self._unpack_batch(batch)
        embeddings = self(x)
        return {"loss": self.metric_loss(embeddings, y)}

    def validation_step(self, batch, batch_idx):
        """Compute the triplet-margin loss for one validation batch.

        :return: ``{"val_loss": loss}``, aggregated by ``validation_end``.
        """
        x, y = self._unpack_batch(batch)
        embeddings = self(x)
        return {"val_loss": self.metric_loss(embeddings, y)}

    def validation_end(self, outputs):
        """Average the per-batch ``val_loss`` values of one validation run.

        :param outputs: list of dicts returned by ``validation_step``.
        :return: ``{"val_loss": mean_loss}``; a zero tensor when no batches ran.
        """
        if not outputs:
            return {"val_loss": torch.tensor(0.0)}
        avg_loss = torch.stack([out["val_loss"] for out in outputs]).mean()
        return {"val_loss": avg_loss}

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        """
        Adam over all parameters with a cosine-annealing LR schedule.

        :return: ``([optimizer], [scheduler])``
        """
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]

    def __dataloader(self, dataset_type="train"):
        """
        Build a ``DataLoader`` for one split of the NTU one-shot data.

        :param dataset_type: ``"train"``, ``"test"`` (reads the ``val_*``
            files) or ``"samples"`` (reads the ``sample_*`` files).
        :return: ``DataLoader`` drawing batches via ``MPerClassSampler``.
        :raises ValueError: if ``dataset_type`` is not one of the above.
        """
        try:
            prefix = self._SPLIT_PREFIX[dataset_type]
        except KeyError:
            raise ValueError(
                "unknown dataset_type %r; expected one of %s"
                % (dataset_type, sorted(self._SPLIT_PREFIX))
            ) from None

        split_dir = os.path.join(self.data_path, "data", "ntu", "one_shot")
        dataset = feeder.Feeder(
            data_path=os.path.join(split_dir, prefix + "_data_joint.npy"),
            label_path=os.path.join(split_dir, prefix + "_label.pkl"),
            debug=False,
        )
        # m=4 samples per class per draw, so each batch has positives for
        # the triplet loss; one "epoch" covers len(dataset) samples.
        sampler = samplers.MPerClassSampler(
            dataset.label, m=4, length_before_new_iter=len(dataset)
        )
        return DataLoader(dataset, batch_size=self.batch_size, sampler=sampler)

    @pl.data_loader
    def train_dataloader(self):
        log.info('Training data loader called.')
        return self.__dataloader(dataset_type="train")

    @pl.data_loader
    def val_dataloader(self):
        log.info('Validation data loader called.')
        return self.__dataloader(dataset_type="test")

    @pl.data_loader
    def sample_dataloader(self):
        log.info('Sample data loader called.')
        return self.__dataloader(dataset_type="samples")

In [None]:
# Build the model with its default settings (no hparams namespace).
model = GCNDMLModel(None)

# ------------------------
# 2 INIT TRAINER
# ------------------------
# All visible GPUs, a single epoch, no automatic mixed precision.
trainer_kwargs = dict(
    gpus=-1,
    max_epochs=1,
    use_amp=False,
)
trainer = pl.Trainer(**trainer_kwargs)

# ------------------------
# 3 START TRAINING
# ------------------------
trainer.fit(model)

In [None]:
# Sanity-check the training loader: show the shape of each element of the
# first few batches instead of dumping full tensors for a whole epoch.
N_DEBUG_BATCHES = 2

modelx = GCNDMLModel(None)
a = modelx.train_dataloader()
for batch_ndx, batch in enumerate(a):
    if batch_ndx >= N_DEBUG_BATCHES:
        break
    print(batch_ndx, [tuple(item.shape) if torch.is_tensor(item) else item for item in batch])