In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl 

from config import Config
import sklearn

from models.TransformerEncoder import Transformer

In [2]:
class ModelModule(pl.LightningModule):
    """Lightning module pairing a model with a loss criterion and validation metrics.

    Args:
        model: module called on the input half of every batch.
        criterion: callable invoked as ``criterion(predicted, target)`` to get the loss.
        metrics: mapping of metric name to a callable invoked as
            ``metric(predicted, target)``; each result is logged during validation.
        args: mapping read with the keys ``'lr_scheduler'`` (keyword arguments
            forwarded to ``MultiStepLR``) and ``'optimizer'`` (read with the keys
            ``'lr'`` and ``'weight_decay'``).
    """

    def __init__(self, model, criterion, metrics, args):
        super().__init__()

        self.criterion = criterion
        self.model = model
        self.metrics = metrics

        self.lr_scheduler_parameters = args['lr_scheduler']
        self.optimizer_parameters = args['optimizer']

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Create an AdamW optimizer for the model's parameters and a MultiStepLR scheduler.

        Returns:
            A pair of one-element lists: ``[optimizer], [lr_scheduler]``.
        """
        # Only ``self.model``'s parameters are handed to the optimizer.
        param_groups = [
            {'params': self.model.parameters(), 'lr': self.optimizer_parameters['lr']},
        ]
        adamw = torch.optim.AdamW(
            param_groups,
            weight_decay=self.optimizer_parameters['weight_decay'],
        )
        multistep = torch.optim.lr_scheduler.MultiStepLR(adamw, **self.lr_scheduler_parameters)

        return [adamw], [multistep]

    def _predict_with_loss(self, batch):
        """Unpack ``batch`` into inputs and targets, run the model and compute the loss.

        Returns:
            Tuple ``(predicted, targets, loss)``.
        """
        inputs, targets = batch
        outputs = self.forward(inputs)
        return outputs, targets, self.criterion(outputs, targets)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch, log it as ``train_loss`` and return it."""
        _, _, step_loss = self._predict_with_loss(batch)
        self.log('train_loss', step_loss)
        return step_loss

    def validation_step(self, batch, batch_idx):
        """Log every configured metric and the loss for one validation batch.

        Each metric is logged under ``'val_' + metric_name`` and the loss under
        ``val_loss``; the loss is returned.
        """
        outputs, targets, step_loss = self._predict_with_loss(batch)

        for name, metric_fn in self.metrics.items():
            self.log('val_' + name, metric_fn(outputs, targets))
        self.log('val_loss', step_loss)

        return step_loss
    
    
class DataModule(pl.LightningModule):
    """Builds the training and validation ``DataLoader`` objects.

    Args:
        datasets: mapping read with the keys ``'train_dataset'`` and ``'val_dataset'``.
        samplers: mapping read with the keys ``'train_sampler'`` and ``'val_sampler'``.
            Each value is either ``None`` (no sampler) or a callable that is invoked
            with the matching sampler parameters to create the sampler.
        args: mapping read with the keys ``'train_dataset'``, ``'train_sampler'``,
            ``'train_dataloader'``, ``'val_dataset'``, ``'val_sampler'`` and
            ``'val_dataloader'``.
    """
    # NOTE(review): this class only supplies dataloaders yet derives from
    # ``pl.LightningModule``; ``pl.LightningDataModule`` may be the intended
    # base -- confirm against how the Trainer is invoked before changing it.

    def __init__(self, datasets, samplers, args):
        super(DataModule, self).__init__()

        self.datasets = datasets
        self.samplers = samplers

        # The ``*_dataset_parameters`` values are stored but not read by the
        # methods of this class.
        self.train_dataset_parameters = args['train_dataset']
        self.train_sampler_parameters = args['train_sampler']
        self.train_dataloader_parameters = args['train_dataloader']

        self.val_dataset_parameters = args['val_dataset']
        self.val_sampler_parameters = args['val_sampler']
        self.val_dataloader_parameters = args['val_dataloader']

    @staticmethod
    def _build_dataloader(dataset, sampler_factory, sampler_parameters, dataloader_parameters):
        """Instantiate the optional sampler and wrap ``dataset`` in a ``DataLoader``.

        Args:
            dataset: dataset handed to the ``DataLoader``.
            sampler_factory: ``None`` or a callable returning a sampler.
            sampler_parameters: keyword arguments for ``sampler_factory``.
            dataloader_parameters: extra keyword arguments for the ``DataLoader``.

        Returns:
            The constructed ``DataLoader``.
        """
        sampler = None
        if sampler_factory is not None:
            # NOTE(review): the factory receives only the configured parameters,
            # not the dataset -- confirm they carry everything the sampler needs.
            sampler = sampler_factory(**sampler_parameters)

        return DataLoader(dataset, sampler=sampler, **dataloader_parameters)

    def train_dataloader(self):
        """Return the ``DataLoader`` for the training dataset."""
        train_dataset = self.datasets['train_dataset']
        # Call the factory looked up from ``self.samplers``; the class defines
        # no ``self.train_sampler`` attribute, so calling that would raise
        # AttributeError whenever a training sampler is configured.
        train_sampler = self.samplers['train_sampler']

        return self._build_dataloader(
            train_dataset,
            train_sampler,
            self.train_sampler_parameters,
            self.train_dataloader_parameters,
        )

    def val_dataloader(self):
        """Return the ``DataLoader`` for the validation dataset."""
        val_dataset = self.datasets['val_dataset']
        val_sampler = self.samplers['val_sampler']

        return self._build_dataloader(
            val_dataset,
            val_sampler,
            self.val_sampler_parameters,
            self.val_dataloader_parameters,
        )

In [7]:
# Instantiate the project configuration (Config comes from the local `config` module).
config = Config()

In [8]:
# Build the Transformer from the keyword arguments stored under the 'model' config property.
transformer_model = Transformer(**config.get_property('model'))



In [None]:
# NOTE(review): bare `git` with no subcommand in a never-executed cell -- looks like an unfinished leftover; complete or remove.
!git 