In [1]:
import os,sys,json

sys.path.append("../")

In [2]:
from src.hooks.core import *

In [3]:
from src.utils.checkpoints import *

In [108]:
from src import utils
from src.utils.train_logger import FileTrainLogger
from src.dataio.dataset import DynamicItemDataset
from src.dataio.batch import PaddedBatch #padded_keys
from functools import partial

logger = FileTrainLogger(
    './log.json'
)

In [127]:
from torch.optim import SGD
import torch

class SimpleBrain(Brain):
    
    def compute_forward(self, batch, stage):
#         print(batch.__dict__,) 
        batch = batch.to(self.device)
#         print(batch.x_train.data.size(), batch.y_train.data.size())
        return self.modules.model(batch.x_train.data)
    def compute_objectives(self, predictions, batch, stage):
#         print(batch.id, batch.x_train, batch.y_train)
#         print(batch[1].size(),predictions.size())
        
        self.loss_metric.append(
             batch.id, predictions, batch.y_train.data
        )

        # Compute classification error at test time
        if stage != src.hooks.core.Stage.TRAIN:
            self.error_metrics.append(batch.id, predictions, batch.y_train.data,)
            
        return torch.nn.functional.l1_loss(predictions, batch.y_train.data)
    
    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch.
        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.
        epoch : int
            The currently-starting epoch. This is passed
            `None` during the test stage.
        """

        # Set up statistics trackers for this stage
        self.loss_metric = utils.metric_stats.Metric(
            metric= torch.nn.functional.l1_loss
        )

        # Set up evaluation-only statistics trackers
        if stage != src.hooks.core.Stage.TRAIN:
            self.error_metrics = utils.metric_stats.Metric(
            metric= torch.nn.functional.l1_loss
        )
    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Gets called at the end of an epoch.
        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
        stage_loss : float
            The average loss for all of the data processed in this stage.
        epoch : int
            The currently-starting epoch. This is passed
            `None` during the test stage.
        """

        # Store the train loss until the validation stage.
        if stage ==  src.hooks.core.Stage.TRAIN:
            self.train_loss = stage_loss

        # Summarize the statistics from the stage for record-keeping.
        else:
            stats = {
                "loss": stage_loss,
                "error": self.error_metrics.summarize("average"),
            }

        # At the end of validation...
        if stage ==  src.hooks.core.Stage.VALID:

            # The train_logger writes a summary to stdout and to the logfile.
            logger.log_stats(
                {"Epoch": epoch,},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )

            # Save the current checkpoint and delete previous checkpoints,
            self.checkpointer.save_and_keep_only(meta=stats, min_keys=["error"])

        
            
model = torch.nn.Linear(in_features=10, out_features=10)
checkpoint = Checkpointer("./")
brain = SimpleBrain({"model": model}, opt_class=lambda x: SGD(x, 0.1),checkpointer=checkpoint)
# brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],))

In [128]:
X_train = torch.rand(1000, 10)
y_train  = X_train*2 + 3
X_test = torch.rand(100, 10)
y_test  = X_test*2 + 3

data_set = {
    i:{'x_train':k[0],'y_train':k[1]} for i,k in enumerate(zip(X_train,y_train))
}
val_set = {
    i:{'x_val':k[0],'y_val':k[1]} for i,k in enumerate(zip(X_test,y_test))
}

@utils.data_pipeline.takes("x_train","y_train")
@utils.data_pipeline.provides("x_train1","y_train1","double")
def audio_pipeline(x_train, y_train):
    """Load the signal, and pass it and its length to the corruption class.
    This is done on the CPU in the `collate_fn`."""
#     sig = sb.dataio.dataio.read_audio(wav)
    return x_train+1, y_train+2, y_train *2

# # Define label pipeline:
# @sb.utils.data_pipeline.takes("spk_id")
# @sb.utils.data_pipeline.provides("spk_id", "spk_id_encoded")
# def label_pipeline(spk_id):
#     yield spk_id
#     spk_id_encoded = label_encoder.encode_label_torch(spk_id)
#     yield spk_id_encoded

# Define datasets. We also connect the dataset with the data processing
# functions defined above.
dataset = DynamicItemDataset(data_set, dynamic_items=[audio_pipeline])

In [129]:
dataset.set_output_keys(['id','x_train','y_train','double'])


In [130]:
dataset[0]

{'id': 0,
 'x_train': tensor([0.5618, 0.4003, 0.6037, 0.6277, 0.2039, 0.3526, 0.0206, 0.8732, 0.9231,
         0.0463]),
 'y_train': tensor([4.1237, 3.8005, 4.2074, 4.2553, 3.4078, 3.7052, 3.0412, 4.7464, 4.8461,
         3.0925]),
 'double': tensor([8.2474, 7.6011, 8.4148, 8.5106, 6.8156, 7.4104, 6.0825, 9.4928, 9.6923,
         6.1850])}

In [131]:
brain.fit(range(4), dataset,
          valid_set= dataset,
         train_loader_kwargs={'batch_size':32},
        valid_loader_kwargs={'batch_size':32},)

100%|██████████| 32/32 [00:00<00:00, 215.96it/s, train_loss=3.37]
100%|██████████| 32/32 [00:00<00:00, 513.18it/s]
100%|██████████| 32/32 [00:00<00:00, 183.66it/s, train_loss=2.27]
100%|██████████| 32/32 [00:00<00:00, 390.67it/s]
100%|██████████| 32/32 [00:00<00:00, 217.48it/s, train_loss=1.22]
100%|██████████| 32/32 [00:00<00:00, 459.99it/s]
100%|██████████| 32/32 [00:00<00:00, 236.47it/s, train_loss=0.626]
100%|██████████| 32/32 [00:00<00:00, 462.24it/s]
