In [None]:
import random
class LSTMHead(nn.Module):
    """Bidirectional, batch-first LSTM applied on top of a feature sequence.

    Args:
        in_features: Size of the last dimension of the input sequence.
        hidden_dim: Hidden size of each LSTM direction.
        n_layers: Number of stacked LSTM layers.
    """

    def __init__(self, in_features, hidden_dim, n_layers):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_features,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.0,
        )
        # NOTE(review): the LSTM is bidirectional, so the last dimension of
        # its output is 2 * hidden_dim, while this attribute stores
        # hidden_dim. Confirm what callers expect before relying on it.
        self.out_features = hidden_dim

    def forward(self, x):
        """Run the LSTM over ``x`` and return its per-step output sequence.

        The final hidden and cell states are discarded.
        """
        # Re-compacts the LSTM weights so the faster fused kernels can be
        # used; per the PyTorch docs this is a no-op off GPU/cuDNN.
        self.lstm.flatten_parameters()
        sequence, _final_state = self.lstm(x)
        return sequence


class PIIModel(pl.LightningModule):
    def __init__(self, config):
        """Build the backbone, the LSTM head, the classifier layer and the loss.

        Args:
            config: Settings object. This method reads ``model_name`` (the
                pretrained checkpoint to load) and ``target_cols`` (its
                length is the number of output classes).
        """
        super().__init__()
        self.cfg = config

        # Overrides applied on top of the checkpoint's own configuration.
        backbone_overrides = {
            "output_hidden_states": True,
            "hidden_dropout_prob": 0.1,
            "layer_norm_eps": 1e-7,
            "add_pooling_layer": False,
        }
        self.model_config = AutoConfig.from_pretrained(config.model_name)
        self.model_config.update(backbone_overrides)

        self.transformers_model = AutoModel.from_pretrained(
            config.model_name,
            config=self.model_config,
        )

        hidden_size = self.model_config.hidden_size
        # Each LSTM direction gets half of the backbone width.
        self.head = LSTMHead(
            in_features=hidden_size,
            hidden_dim=hidden_size // 2,
            n_layers=1,
        )

        n_classes = len(self.cfg.target_cols)
        self.output = nn.Linear(hidden_size, n_classes)

        # Targets equal to -100 are excluded from the mean loss.
        self.loss_function = nn.CrossEntropyLoss(reduction='mean', ignore_index=-100)
        self.validation_step_outputs = []
        


    def forward(self, input_ids, attention_mask,train):
        """Compute classification logits for every position of the input.

        Args:
            input_ids: Token ids passed to the transformer backbone.
            attention_mask: Attention mask passed to the transformer backbone.
            train: Not used by this method; kept so existing call sites that
                pass it keep working.

        Returns:
            A 2-tuple ``(logits, None)``. ``logits`` is the output of the
            final linear layer; the second slot is an explicit placeholder so
            callers that index or unpack a pair keep working.
        """
        transformer_out = self.transformers_model(input_ids, attention_mask=attention_mask)
        sequence_output = self.head(transformer_out.last_hidden_state)
        logits = self.output(sequence_output)
        # Return an explicit None placeholder: `_` is not bound anywhere in
        # this method, so returning it would raise NameError (or pick up an
        # unrelated interactive-session global).
        return (logits, None)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch.

        Args:
            batch: Mapping with ``'input_ids'``, ``'attention_mask'`` and
                ``'labels'`` entries.
            batch_idx: Index of the batch; not used here.

        Returns:
            ``{'loss': loss}`` with the scalar training loss.
        """
        input_ids, attention_mask, labels = (
            batch[key] for key in ('input_ids', 'attention_mask', 'labels')
        )

        logits = self(input_ids, attention_mask, train=True)[0]

        # Flatten to (positions, classes) vs. (positions,) for the loss.
        n_classes = len(self.cfg.target_cols)
        flat_logits = logits.view(-1, n_classes)
        flat_labels = labels.view(-1)
        loss = self.loss_function(flat_logits, flat_labels)

        self.log('train_loss', loss, prog_bar=True)
        return {'loss': loss}
    
    def train_epoch_end(self,outputs):
        """Average the per-step training losses, print and return the mean.

        Args:
            outputs: Iterable of dicts, each holding a scalar loss tensor
                under the ``'loss'`` key.

        Returns:
            ``{'train_loss': avg_loss}`` where ``avg_loss`` is the mean of the
            collected losses.
        """
        # NOTE(review): Lightning's epoch-end hooks are named
        # `training_epoch_end` (1.x) / `on_train_epoch_end` (2.x); confirm
        # that something actually calls this method.
        avg_loss = torch.stack([step['loss'] for step in outputs]).mean()
        # Take the epoch from the module itself instead of depending on a
        # module-level `trainer` variable being defined.
        print(f'epoch {self.current_epoch} training loss {avg_loss}')
        return {'train_loss': avg_loss}
    
    def train_dataloader(self):
        """Return the training dataloader stored in ``self._train_dataloader``.

        The attribute is not assigned in the visible part of this class, so
        it must be set on the instance before this method is called.
        """
        loader = self._train_dataloader
        return loader
    
    def validation_dataloader(self):
        """Return the dataloader stored in ``self._validation_dataloader``.

        NOTE(review): Lightning's validation hook is named ``val_dataloader``;
        confirm that this method is called from somewhere.
        """
        loader = self._validation_dataloader
        return loader

    def get_optimizer_params(self, encoder_lr, decoder_lr, weight_decay=0.0):
        """Build optimizer parameter groups with separate backbone/head settings.

        Args:
            encoder_lr: Learning rate for the transformer backbone parameters.
            decoder_lr: Learning rate for every parameter outside the backbone.
            weight_decay: Weight decay for backbone parameters whose name does
                not match a no-decay pattern.

        Returns:
            A list of three parameter-group dicts, in this order:
            backbone parameters with weight decay, backbone parameters
            without weight decay (biases and LayerNorm parameters), and all
            non-backbone parameters (no weight decay, ``decoder_lr``).
        """
        no_decay = ("bias", "LayerNorm.bias", "LayerNorm.weight")

        # Split the backbone parameters by whether their name matches one of
        # the no-decay patterns.
        decayed, undecayed = [], []
        for name, param in self.transformers_model.named_parameters():
            if any(pattern in name for pattern in no_decay):
                undecayed.append(param)
            else:
                decayed.append(param)

        # Everything registered on this module outside the backbone.
        non_backbone = [
            param
            for name, param in self.named_parameters()
            if "transformers_model" not in name
        ]

        return [
            {'params': decayed, 'lr': encoder_lr, 'weight_decay': weight_decay},
            {'params': undecayed, 'lr': encoder_lr, 'weight_decay': 0.0},
            {'params': non_backbone, 'lr': decoder_lr, 'weight_decay': 0.0},
        ]

    def configure_optimizers(self):
        """Create the AdamW optimizer and a per-step polynomial-decay schedule.

        Reads ``learning_rate``, ``data_length``, ``batch_size`` and
        ``epochs`` from ``self.cfg``.

        Returns:
            A dict with the ``'optimizer'`` and an ``'lr_scheduler'`` config
            that steps the scheduler once per optimizer step.
        """
        # Use the config stored on the module rather than a module-level
        # `config` variable, so the model works wherever it is instantiated.
        optimizer = AdamW(self.parameters(), lr=self.cfg.learning_rate)

        samples_per_epoch = self.cfg.data_length
        batch_size = self.cfg.batch_size

        # Warm up over 10% of one epoch's worth of steps. The expression
        # yields a float, so cast it to the integer step count the scheduler
        # is documented to take.
        warmup_steps = int(0.1 * samples_per_epoch // batch_size)
        training_steps = self.cfg.epochs * samples_per_epoch // batch_size

        scheduler = get_polynomial_decay_schedule_with_warmup(
            optimizer,
            warmup_steps,
            training_steps,
            lr_end=1e-6,
            power=3.0,
        )

        lr_scheduler_config = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1,
        }

        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler_config}