In [1]:
import numpy as np
import pandas as pd
from os.path import join

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import loggers as pl_loggers

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from utils.data import IsRelevantDataset

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
    
from transformers import BertModel
# Hugging Face model identifier of the pretrained BERT checkpoint that is
# loaded with BertModel.from_pretrained(bert_id) when the model is built.
bert_id = "google/bert_uncased_L-2_H-128_A-2"

In [2]:
class RelevantModule(pl.LightningModule):
    """Classifier that combines a BERT encoding with additional input features.

    The first token's final hidden state from ``bert`` is projected to 256
    features, concatenated with the extra features along the feature axis and
    passed through a feed-forward head that emits ``output_size`` logits.
    Training and validation both use cross-entropy on those logits.

    Args:
        bert: Encoder whose ``config.hidden_size`` sizes the projection layer.
        input_size: Number of extra (non-BERT) features concatenated per sample.
        output_size: Number of logits produced by the head.
        start_lr: Initial learning rate handed to Adam.
    """

    def __init__(self, bert: BertModel, input_size: int, output_size: int, start_lr=1e-4):
        super().__init__()
        self.bert = bert
        # Bring the BERT vector to a fixed width before mixing it with the
        # other features, independent of the encoder's hidden size.
        self.linear_after_bert = nn.Linear(bert.config.hidden_size, 256)
        # TODO(idea, untested): an nn.BatchNorm1d(256 + input_size) in front of
        # the first linear layer might help.
        self.feed_forward = nn.Sequential(
            nn.Linear(256 + input_size, 1024),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(64, output_size),
        )

        self.start_lr = start_lr

    def forward(self, x):
        """Return the logits for one batch.

        Args:
            x: Indexable pair. ``x[0]`` is passed to the BERT encoder
                (presumably token ids -- confirm against IsRelevantDataset);
                ``x[1]`` holds the extra features and must be concatenable
                with a ``(batch, 256)`` tensor along dim 1.

        Returns:
            Output of the feed-forward head, ``output_size`` values per sample.
        """
        x_bert, x_other = x[0], x[1]
        # Keep every batch element but only the first sequence position
        # (the classification-token output).
        y_bert = self.bert(x_bert)["last_hidden_state"][:, 0]
        y_bert = self.linear_after_bert(y_bert)
        # dim=1 is the feature dimension (dim 0 is the batch dimension).
        features = torch.cat((y_bert, x_other), dim=1)

        return self.feed_forward(features)

    def configure_optimizers(self):
        """Set up Adam plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.start_lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, verbose=True, factor=.2, patience=1, cooldown=2, min_lr=1e-6
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            # Metric name logged in validation_step that the scheduler watches.
            'monitor': 'val_loss',
        }

    def training_step(self, train_batch, batch_idx):
        """Compute, log and return the cross-entropy loss of a training batch."""
        x, y = train_batch

        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log the epoch-level cross-entropy loss and accuracy of a validation batch."""
        x, y = val_batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        # Fraction of samples whose arg-max logit equals the target class.
        acc = (torch.argmax(y_hat, dim=-1) == y).float().mean()

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)

In [3]:
# Read the preprocessed training and validation tables (HDF5, key "s").
train_joint, validation_join = (
    pd.read_hdf(join("preprocessed_data", file_name), key="s")
    for file_name in ("train_joint.h5", "validation_joint.h5")
)

In [4]:
# The training dataset is built first; the validation dataset is then
# constructed with the training dataset's `dimensions` passed in explicitly.
train_ds = IsRelevantDataset(train_joint, device=device)
validation_ds = IsRelevantDataset(
    validation_join,
    device=device,
    dimensions=train_ds.dimensions,
)

In [5]:
# Training hyperparameters.
# Initial learning rate (same value as RelevantModule's start_lr default).
lr = 1e-4
# Mini-batch size of the training DataLoader.
batch_size=16
# Intended number of training epochs.
epochs = 10
# NOTE(review): RelevantModule sizes its projection from bert.config.hidden_size,
# and the checkpoint named in bert_id appears to use a hidden size of 128, not
# 768 -- confirm whether this constant is still needed.
bert_hidden = 768

In [6]:
# Training batches are reshuffled every epoch.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
# Validation must iterate the held-out dataset (it was previously built from
# train_ds, so val_loss / val_acc were measured on training data). Order does
# not matter for evaluation, and a larger batch is used.
validation_dl = DataLoader(validation_ds, batch_size=64, shuffle=False)

In [7]:
# Model: pretrained BERT encoder wrapped by the RelevantModule head. The extra
# feature width and the number of outputs are read from train_ds.dimensions
# (layout defined by IsRelevantDataset). The configured learning rate is
# passed explicitly so that changing `lr` actually takes effect.
model = RelevantModule(
    BertModel.from_pretrained(bert_id).to(device),
    sum(train_ds.dimensions[0][1]),
    train_ds.dimensions[1],
    start_lr=lr,
)
# Logging: TensorBoard event files under logs/No_Normalization.
tb_logger = pl_loggers.TensorBoardLogger('logs', name="No_Normalization")
# Checkpointing keyed on the validation loss logged by the model.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='-{epoch:02d}-{val_loss:.2f}'
)
# Training: follow the detected device instead of assuming a GPU (16-bit
# precision is only requested on CUDA), and stop after the configured number
# of epochs rather than the Trainer's default limit.
trainer = Trainer(
    gpus=1 if device == "cuda" else 0,
    precision=16 if device == "cuda" else 32,
    max_epochs=epochs,
    logger=tb_logger,
    callbacks=[checkpoint_callback],
    num_sanity_val_steps=5,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


In [None]:
# Start training: fit `model` on `train_dl`, using `validation_dl` for validation.
trainer.fit(model, train_dl, validation_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params
-------------------------------------------------
0 | bert              | BertModel  | 4.4 M 
1 | linear_after_bert | Linear     | 33.0 K
2 | feed_forward      | Sequential | 713 K 
-------------------------------------------------
5.1 M     Trainable params
0         Non-trainable params
5.1 M     Total params
20.528    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]