In [1]:
import pickle
import numpy as np
import os, re
import torch
import pytorch_lightning as pl

In [2]:
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
torch.cuda.is_available()

True

In [2]:
with open("train_data.pkl", "rb") as f:
    train_data = pickle.load(f)
with open("test_data.pkl", "rb") as f:
    test_data = pickle.load(f)

In [3]:
class DictDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.embeddings = data["embeddings"]
        self.labels = data["labels"]
    
    def __len__(self):
        return len(self.embeddings)
    
    def __getitem__(self, idx):
        return np.array(self.embeddings[idx]), self.labels[idx]

def create_dataloader(d, batch_size=1):
    dataset = DictDataset(d)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=4)

In [4]:
train_dataset = create_dataloader(train_data, batch_size=64)
test_dataset = create_dataloader(test_data, batch_size=64)

In [5]:
from sklearn.metrics import roc_auc_score
def AUC(pred, labels):
    return roc_auc_score(labels, pred)

In [40]:
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(768,2)
        self.train_step_scores = []
        self.test_step_scores = []
        self.train_labels = []
        self.test_labels = []
        self.predictions = []
    def forward(self, x):
        output = self.l1(x)
        return output
    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0),-1)
        output = self(x)
        logits = torch.nn.functional.softmax(output, dim=1).detach().cpu().numpy()[:,1]
        self.train_step_scores.extend(logits.tolist())
        self.train_labels.extend(y.cpu().numpy().tolist())
        f = torch.nn.CrossEntropyLoss()
        loss = f(output, y)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss
    def on_train_epoch_end(self):
        train_proba = self.train_step_scores
        train_labels = self.train_labels
        auc_epoch = AUC(train_proba, train_labels)
        self.log("train_auc", auc_epoch, on_step=False, on_epoch=True, prog_bar=True)
        self.train_step_scores.clear()
        self.train_labels.clear() 
    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        output = self(x)
        f = torch.nn.CrossEntropyLoss()
        loss = f(output, y)
        logits = torch.nn.functional.softmax(output, dim=1)
        pred_labels = logits.detach().cpu().numpy()[:,1]
        self.test_step_scores.extend(pred_labels.tolist())
        self.test_labels.extend(y.cpu().numpy().tolist())
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
    def on_validation_epoch_end(self):
        test_proba = self.test_step_scores
        test_labels = self.test_labels
        auc_epoch = AUC(test_proba, test_labels)
        self.log("val_auc", auc_epoch, on_step=False, on_epoch=True, prog_bar=True)
        self.test_step_scores.clear()
        self.test_labels.clear() 
    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        return opt
    def predict_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        output = self(x)
        f = torch.nn.CrossEntropyLoss()
        loss = f(output, y)
        logits = torch.nn.functional.softmax(output, dim=1)
        pred_labels = logits.detach().cpu().numpy()[:,1]
        self.predictions.extend(pred_labels.tolist())
        return pred_labels
    def on_predict_end(self):
        return self.predictions

In [12]:
#dir(pl.LightningModule)

In [44]:
m = Model()

In [45]:
logger = pl.loggers.TensorBoardLogger("tb_logs", name="my_model", version="try_1")

In [46]:
trainer = pl.Trainer(max_epochs=10,
                    logger=logger,
                    log_every_n_steps=49
                    )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [47]:
trainer.fit(m, train_dataset, test_dataset)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 1.5 K 
--------------------------------
1.5 K     Trainable params
0         Non-trainable params
1.5 K     Total params
0.006     Total estimated model params size (MB)


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [50]:
m.predictions.clear()

In [51]:
predictions = trainer.predict(m, test_dataset)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |                                             | 0/? [00:00<?, ?it/s]

In [52]:
roc_auc_score(test_data['labels'], m.predictions)

0.8858471760797342

In [48]:
predictions

[array([0.6385173 , 0.6419538 , 0.8837637 , 0.6395603 , 0.9998103 ,
        0.6777087 , 0.4865732 , 0.70311815, 0.9429919 , 0.92614895,
        0.9863599 , 0.76180375, 0.8366164 , 0.993877  , 0.7034826 ,
        0.9768009 , 0.88359547, 0.5777206 , 0.9633439 , 0.9941732 ,
        0.9831661 , 0.9620391 , 0.98803073, 0.99756193, 0.99572396,
        0.83691704, 0.7402928 , 0.8663593 , 0.83974844, 0.9540559 ,
        0.9600333 , 0.39396793, 0.97981447, 0.9753126 , 0.99955946,
        0.5304502 , 0.9966241 , 0.4617092 , 0.9943467 , 0.89638025,
        0.9366684 , 0.7640085 , 0.9824011 , 0.86409354, 0.84360945,
        0.915665  , 0.9695105 , 0.98459226, 0.9495296 , 0.97711825,
        0.97817993, 0.99969995, 0.96337754, 0.9946907 , 0.98789936,
        0.95125383, 0.9542039 , 0.8927416 , 0.8179715 , 0.9848925 ,
        0.99347866, 0.9898822 , 0.7379187 , 0.93041885], dtype=float32),
 array([0.72794676, 0.70939463, 0.82192254, 0.89366937, 0.9953761 ,
        0.78560627, 0.99831533, 0.98770565,