In [1]:
import pickle
import numpy as np
import os, re
import torch
import pytorch_lightning as pl

In [2]:
# Set CUDA_LAUNCH_BLOCKING=1 in this process's environment.
# NOTE(review): this is normally a debugging aid (synchronous CUDA kernel
# launches give accurate tracebacks but slow training) — confirm it is still
# wanted for regular runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
# Quick check that PyTorch can see a CUDA device before training.
torch.cuda.is_available()

True

In [3]:
def _read_pickle(path):
    """Deserialize and return the object stored in the pickle file at ``path``."""
    # NOTE(review): unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Both objects are later wrapped by DictDataset, which reads their
# "embeddings" and "labels" entries.
train_data = _read_pickle("train_data.pkl")
test_data = _read_pickle("test_data.pkl")

In [93]:
class DictDataset(torch.utils.data.Dataset):
    """Map-style dataset over a mapping with "embeddings" and "labels" entries.

    Item ``i`` is the pair ``(embedding_i as a NumPy array, label_i)``.
    """

    def __init__(self, data):
        # Keep references to the two parallel sequences; nothing is copied here.
        self.embeddings, self.labels = data["embeddings"], data["labels"]

    def __len__(self):
        """Number of samples, taken from the embeddings sequence."""
        n_samples = len(self.embeddings)
        return n_samples

    def __getitem__(self, idx):
        """Return ``(np.ndarray embedding, label)`` for position ``idx``."""
        features = np.array(self.embeddings[idx])
        target = self.labels[idx]
        return features, target

def create_dataloader(d, batch_size=1, shuffle=False, num_workers=4):
    """Wrap a data mapping in a ``DictDataset`` and return a ``DataLoader`` over it.

    Args:
        d: mapping with "embeddings" and "labels" entries (see ``DictDataset``).
        batch_size: number of samples per batch.
        shuffle: reshuffle the samples every epoch. Defaults to ``False`` so
            existing callers are unaffected; pass ``True`` for training data.
        num_workers: number of loader worker processes (``0`` loads in the
            main process). Defaults to the previously hard-coded ``4``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(embeddings, labels)`` batches.
    """
    dataset = DictDataset(d)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

In [94]:
# Despite the `_dataset` suffix, both names hold DataLoaders (batches of 64);
# they are passed to `trainer.fit` as the train and validation loaders.
# NOTE(review): neither loader is shuffled by these calls — consider shuffling
# the training loader so batches are not always seen in the stored order.
train_dataset = create_dataloader(train_data, batch_size=64)
test_dataset = create_dataloader(test_data, batch_size=64)

In [51]:
from sklearn.metrics import roc_auc_score
def AUC(pred, labels):
    """Return the ROC AUC of scores ``pred`` against ground-truth ``labels``.

    Note the argument order: predictions first, labels second (the reverse of
    ``roc_auc_score``, which this delegates to).
    """
    score = roc_auc_score(y_true=labels, y_score=pred)
    return score

In [82]:
class Model(pl.LightningModule):
    """Linear classifier on flattened embeddings, trained with cross-entropy.

    Logged metrics: ``train_loss`` (per training step), ``val_loss`` (epoch
    average) and ``auc`` (ROC AUC over the whole validation set, once per
    validation epoch).

    Args:
        in_features: size of each flattened input vector (default 768).
        num_classes: number of output logits (default 2). The ``auc`` metric
            scores the class-1 probability, so it is meaningful for binary labels.
        lr: learning rate for the Adam optimizer (default 1e-3).
    """

    def __init__(self, in_features=768, num_classes=2, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.l1 = torch.nn.Linear(in_features, num_classes)
        # Buffers filled during a validation epoch so that AUC can be computed
        # once over all samples instead of per batch.
        self._val_scores = []
        self._val_targets = []

    def forward(self, x):
        """Return raw class logits for a batch of flattened inputs."""
        return self.l1(x)

    def _shared_step(self, batch):
        """Flatten the inputs, run the model and return ``(logits, y, loss)``."""
        x, y = batch
        x = x.view(x.size(0), -1)  # collapse every non-batch dimension
        output = self(x)
        loss = torch.nn.functional.cross_entropy(output, y)
        return output, y, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        _, _, loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the batch loss and stash class-1 probabilities for epoch-level AUC."""
        output, y, loss = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True)
        probs = torch.nn.functional.softmax(output, dim=1)
        self._val_scores.append(probs[:, 1].detach().cpu())
        self._val_targets.append(y.detach().cpu())

    def on_validation_epoch_end(self):
        """Compute ROC AUC over everything collected this validation epoch.

        Averaging per-batch AUCs is not the dataset AUC, and ROC AUC is
        undefined for a batch that contains a single class, so the metric is
        computed here from the accumulated scores. If the collected labels
        still contain only one class (e.g. during the short sanity check),
        ``auc`` is simply not logged for that epoch.
        """
        if not self._val_scores:
            return
        scores = torch.cat(self._val_scores).numpy()
        targets = torch.cat(self._val_targets).numpy()
        # Reset before computing so a failure below cannot leak into the next epoch.
        self._val_scores.clear()
        self._val_targets.clear()
        if np.unique(targets).size < 2:
            return
        self.log("auc", float(AUC(scores, targets)))

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [12]:
#dir(pl.LightningModule)

In [100]:
# Fresh, untrained classifier instance with the default constructor arguments.
m = Model()

In [101]:
# TensorBoard logger with save directory "tb_logs" and experiment name "my_model".
logger = pl.loggers.TensorBoardLogger("tb_logs", name="my_model")

In [102]:
# Trainer capped at 20 epochs, writing metrics through `logger` every 20
# training steps. No accelerator/devices are specified here, so Lightning's
# defaults apply.
trainer = pl.Trainer(max_epochs=20,
                    logger=logger,
                    log_every_n_steps=20)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [103]:
# Fit `m` using `train_dataset` as the training loader and `test_dataset` as
# the validation loader.
# NOTE(review): the loader built from test_data is used for validation here, so
# the logged "val_loss"/"auc" are test-set metrics — confirm this is intended
# and that no model selection is done on them.
trainer.fit(m, train_dataset, test_dataset)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 1.5 K 
--------------------------------
1.5 K     Trainable params
0         Non-trainable params
1.5 K     Total params
0.006     Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=20` reached.
