In [1]:
import datasets
from torch.utils.data import DataLoader

# Root directory holding the pre-processed HuggingFace datasets.
# NOTE(review): the value already ends with "/", and the f-strings below add
# another one, so the resolved paths contain a double slash ("tmp//imputed_200/...").
file_dir = "tmp/"


# Load the train / test splits previously written with `Dataset.save_to_disk`.
train = datasets.load_from_disk(f"{file_dir}/imputed_200/train")
test = datasets.load_from_disk(f"{file_dir}/imputed_200/test")

In [2]:
from torch.utils.data import WeightedRandomSampler, DataLoader
from sklearn.utils import class_weight
import numpy as np
import pandas as pd


def createWeightedSampler(labels, class_num=2):
    """Build a ``WeightedRandomSampler`` that over-samples under-represented classes.

    Per-class weights are computed with scikit-learn's ``"balanced"`` heuristic and
    every sample receives the weight of its class. The class-weight mapping is
    printed for inspection.

    Args:
        labels: 1-D collection of class ids taken from ``range(class_num)``. Lists,
            NumPy arrays, ``pd.Series`` and torch tensors are accepted.
        class_num: Number of classes; weights are computed for ``np.arange(class_num)``.

    Returns:
        A ``WeightedRandomSampler`` that draws ``len(labels)`` indices with replacement.

    Raises:
        KeyError: If a label is not one of ``0 .. class_num - 1``.
    """
    if isinstance(labels, pd.Series):
        labels = labels.values
    elif hasattr(labels, "detach"):
        # Torch tensor: iterating it yields 0-d tensors, which hash by identity and
        # would never match the integer keys of the class-weight mapping below.
        labels = labels.detach().cpu().numpy()
    labels = np.asarray(labels)

    class_weights = dict(
        enumerate(
            class_weight.compute_class_weight(
                "balanced",
                classes=np.arange(class_num),
                y=labels,
            )
        )
    )
    print(class_weights)

    # `.tolist()` converts NumPy scalars to plain Python numbers so the dictionary
    # lookup behaves identically for every supported input container.
    train_class_weights = [class_weights[label] for label in labels.tolist()]
    sampler = WeightedRandomSampler(
        train_class_weights, len(train_class_weights), replacement=True
    )
    return sampler

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import list_models, create_model
from collections import defaultdict
import pytorch_lightning as pl
import torchmetrics


class ProteinTransformer(nn.Module):
    """Transformer encoder stack with mean pooling and a linear classification head.

    The model consumes tensors shaped ``(batch, seq_len, input_dim)``, runs them
    through ``num_layers`` independent ``nn.TransformerEncoderLayer`` modules,
    averages over the sequence axis and projects to ``num_classes`` logits.

    Args:
        input_dim: Feature size of every sequence element (``d_model``).
        hidden_dim: Width of the feed-forward sub-layer (``dim_feedforward``).
        num_heads: Number of attention heads; must divide ``input_dim``.
        num_layers: Number of stacked encoder layers.
        num_classes: Number of output logits.
        dropout: Dropout probability used inside every encoder layer.
        **kwargs: Extra keyword arguments forwarded to ``nn.TransformerEncoderLayer``.
            ``batch_first`` defaults to ``True`` here and can still be overridden.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        num_heads,
        num_layers,
        num_classes=2,
        dropout=0.1,
        **kwargs,
    ):
        super().__init__()
        # nn.TransformerEncoderLayer defaults to (seq, batch, feature) inputs. This
        # model is fed (batch, seq, feature) and pools over dim=1, so without
        # batch_first=True attention would be computed across the samples of a
        # batch instead of across the sequence positions of each sample.
        kwargs.setdefault("batch_first", True)
        self.encoder_layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=input_dim,
                    nhead=num_heads,
                    dim_feedforward=hidden_dim,
                    dropout=dropout,
                    **kwargs,
                )
                for _ in range(num_layers)
            ]
        )
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for ``x`` of shape ``(batch, seq_len, input_dim)``."""
        for layer in self.encoder_layers:
            x = layer(x)
        # Mean-pool over the sequence axis to get one vector per sample.
        x = x.mean(dim=1)

        x = self.fc(x)
        return x


class ProteinTransformerPL(pl.LightningModule):
    """Lightning wrapper that trains a ``ProteinTransformer`` with cross-entropy loss.

    Batches are ``(x, y)`` pairs. ``y`` is expected to be one-hot encoded: it is
    passed to the loss as float class probabilities and reduced with ``argmax``
    to obtain the AUROC target. Train / validation AUROC are accumulated over an
    epoch, logged at epoch end and then reset.

    Args:
        input_dim: Feature size of every sequence element.
        hidden_dim: Width of the encoder feed-forward sub-layer.
        num_heads: Number of attention heads.
        num_layers: Number of encoder layers.
        num_classes: Number of target classes.
        dropout: Dropout probability inside the encoder layers.
        lr: Learning rate for Adam.
        weight_decay: Weight decay for Adam.
        weight: Per-class weights for the cross-entropy loss.
        **kwargs: Accepted for call-site compatibility; currently unused.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        num_heads,
        num_layers,
        num_classes=2,
        dropout=0.1,
        lr=1e-3,
        weight_decay=1e-2,
        weight=[1, 1],
        **kwargs,
    ):

        super().__init__()

        self.lr = lr
        self.weight_decay = weight_decay

        # Keep the metrics in an nn.ModuleDict so they are registered as
        # submodules and follow the module across devices. The attribute keeps
        # its historical (misspelled) name so existing accesses keep working.
        self.mertic = nn.ModuleDict(
            {
                "train_auc": torchmetrics.AUROC(
                    num_classes=num_classes, task="multiclass"
                ),
                "val_auc": torchmetrics.AUROC(
                    num_classes=num_classes, task="multiclass"
                ),
            }
        )
        self.history = defaultdict(dict)
        self.loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(weight).float())

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.dropout = dropout

        self.model = ProteinTransformer(
            input_dim=self.input_dim,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            num_classes=num_classes,
            dropout=dropout,
        )

    def forward(self, x):
        """Run the wrapped model; a list/tuple input is unpacked into positional arguments."""
        return self.model(*x) if isinstance(x, (list, tuple)) else self.model(x)

    def _shared_step(self, batch, metric_key):
        """Forward ``batch``, update the AUROC metric stored under ``metric_key`` and return the loss."""
        x, y = batch
        outputs = self.forward(x)
        loss = self.loss_fn(outputs, y.squeeze(-1).float())
        self.mertic[metric_key].update(
            torch.softmax(outputs, dim=-1), torch.argmax(y, dim=1)
        )
        return loss

    def training_step(self, train_batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._shared_step(train_batch, "train_auc")
        self.log("ptl/train_loss", loss, on_epoch=True, prog_bar=True, on_step=False)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._shared_step(val_batch, "val_auc")
        self.log("ptl/val_loss", loss, on_epoch=True, prog_bar=True)

    def on_train_epoch_end(self):
        """Log the epoch-level training AUROC and clear the accumulated state."""
        metric = self.mertic["train_auc"]
        auc = metric.compute()
        self.log("ptl/train_auc", auc, prog_bar=True)
        # Without a reset the metric would keep accumulating across epochs.
        metric.reset()

    def on_validation_epoch_end(self):
        """Log the epoch-level validation AUROC and clear the accumulated state."""
        metric = self.mertic["val_auc"]
        auc = metric.compute()
        self.log("ptl/val_auc", auc, prog_bar=True)
        # Also drops the batches seen during Lightning's sanity check.
        metric.reset()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        return optimizer

    def predict_df(self, df, batch_size=256):
        """Append a ``pred`` column with the softmax probability of class 1.

        Rows with missing values in any feature column are dropped first.

        NOTE(review): this method reads ``self.features`` (names of the input
        columns), which ``__init__`` never sets; it looks carried over from a
        tabular model. The attribute must be assigned before calling.

        Args:
            df: DataFrame containing every column listed in ``self.features``.
            batch_size: Batch size used for inference.

        Returns:
            A copy of ``df`` without NA rows and with an added ``pred`` column.

        Raises:
            AttributeError: If ``self.features`` has not been set.
            AssertionError: If a feature column is missing from ``df``.
        """
        features = getattr(self, "features", None)
        if features is None:
            raise AttributeError(
                "predict_df requires `self.features` (input column names) to be set "
                "on the module before it is called"
            )

        missing = [feature for feature in features if feature not in df.columns]
        assert not missing, f"missing feature columns: {missing}"
        print(f"input df have NA: {df[features].isna().sum(axis=1).sum()}")
        df = df.copy().dropna(subset=features)

        predict_dataloader = DataLoader(
            torch.tensor(df[features].values).float(),
            batch_size=batch_size,
            persistent_workers=True,
            num_workers=4,
        )

        self.eval()
        pred = []
        with torch.no_grad():
            for x in predict_dataloader:
                # Batches are created on CPU; move them to wherever the model lives.
                y_hat = self.forward(x.to(self.device)).cpu().detach()
                y_hat = torch.softmax(y_hat, dim=-1)[:, 1]

                pred.append(y_hat)
        # torch.cat raises on an empty list (e.g. every row was dropped).
        pred = torch.cat(pred).numpy() if pred else torch.empty(0).numpy()
        df["pred"] = pred
        return df

In [18]:
# Smoke test: push a random (batch=32, seq=200, dim=256) tensor through the bare
# model and through its Lightning wrapper and print the resulting logit shapes.
model = ProteinTransformer(
    input_dim=256,
    hidden_dim=64,
    num_heads=4,
    num_layers=2,
)
print(model(torch.rand(32, 200, 256)).shape)

model_pl = ProteinTransformerPL(
    256,
    64,
    4,
    1,
    num_classes=2,
    dropout=0.1,
    lr=1e-3,
    weight_decay=1e-2,
    weight=[1, 1],
)
print(model_pl(torch.rand(32, 200, 256)).shape)

torch.Size([32, 2])
torch.Size([32, 2])


In [19]:
# Instantiate a timm ResNet-18 with a 2-class head and display its architecture.
# NOTE(review): this rebinds `model`, replacing the ProteinTransformer instance
# created in the previous cell.
model = create_model("resnet18", num_classes=2)
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, m

In [20]:
from functools import partial


def hf_collate_fn(batch, x_col, y_col, one_hot=True, num_classes=2):
    """Collate a list of dataset rows into stacked ``(x, y)`` tensors.

    Args:
        batch: Sequence of mappings, each holding tensors under ``x_col`` and ``y_col``.
        x_col: Key of the input tensor in every row.
        y_col: Key of the label tensor in every row.
        one_hot: When true, labels are one-hot encoded along a new last axis.
        num_classes: Width of the one-hot encoding (only used when ``one_hot`` is true).

    Returns:
        Tuple ``(x, y)`` where both tensors are stacked along a new leading batch axis.
    """
    x = torch.stack([row[x_col] for row in batch])
    y = torch.stack([row[y_col] for row in batch])
    if one_hot:
        # F.one_hot only accepts int64 index tensors, so cast other label dtypes.
        y = F.one_hot(y.long(), num_classes=num_classes)
    return x, y


# Sampler that re-weights training rows by the class of `incident_cad`
# (the computed class-weight mapping is printed by createWeightedSampler).
train_sampler = createWeightedSampler(train["incident_cad"])

# Training loader: only the embedding and label columns are kept, rows are
# returned as torch tensors, and batches are drawn through the weighted sampler.
train_loader = DataLoader(
    train.select_columns(column_names=["embeddings", "incident_cad"]).with_format("pt"),
    batch_size=64,
    sampler=train_sampler,
    collate_fn=partial(hf_collate_fn, x_col="embeddings", y_col="incident_cad"),
)
# Evaluation loader: same columns and collation, but no sampler, so rows are
# read in dataset order without re-weighting.
test_loader = DataLoader(
    test.select_columns(column_names=["embeddings", "incident_cad"]).with_format("pt"),
    batch_size=64,
    collate_fn=partial(hf_collate_fn, x_col="embeddings", y_col="incident_cad"),
)

{0: 0.5293901434956481, 1: 9.006253126563282}


In [21]:
from pytorch_lightning import Trainer

# Build the model used for the actual training run (this rebinds `model_pl`).
# weight=[1, 1] leaves the cross-entropy loss unweighted; class imbalance is
# addressed by the weighted sampler attached to `train_loader`.
model_pl = ProteinTransformerPL(
    input_dim=256,
    hidden_dim=128,
    num_heads=4,
    num_layers=4,
    num_classes=2,
    dropout=0.1,
    lr=1e-2,
    weight_decay=1e-2,
    weight=[1, 1],
)

# Train for at most 30 epochs with gradient clipping at 1.
trainer = Trainer(
    max_epochs=30,
    gradient_clip_val=1,
)
# NOTE(review): the test split doubles as the validation set here, so the
# logged validation metrics are not an independent held-out estimate.
trainer.fit(model_pl, train_dataloaders=train_loader, val_dataloaders=test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type               | Params
-----------------------------------------------
0 | loss_fn | CrossEntropyLoss   | 0     
1 | model   | ProteinTransformer | 1.3 M 
-----------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.284     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
import torchmetrics
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

import torch
import pytorch_lightning as pl

import torch.nn as nn
import torch.optim as optim
from collections import defaultdict
from timm import create_model


class LinearTransformerPL(pl.LightningModule):
    """Lightning wrapper that trains a ``LinearTransformer`` with cross-entropy loss.

    Batches are ``(x, y)`` pairs. ``y`` is expected to be one-hot encoded: it is
    passed to the loss as float class probabilities and reduced with ``argmax``
    to obtain the AUROC target. Train / validation AUROC are accumulated over an
    epoch, logged at epoch end and then reset.

    Args:
        features_dict: Forwarded unchanged to ``LinearTransformer``.
        covariates_dict: Forwarded unchanged to ``LinearTransformer``.
        d_ff: Forwarded unchanged to ``LinearTransformer``.
        num_classes: Number of target classes (model head and AUROC metrics).
        num_layers: Forwarded unchanged to ``LinearTransformer``.
        dropout: Forwarded unchanged to ``LinearTransformer``.
        lr: Learning rate for Adam.
        weight_decay: Weight decay for Adam.
        weight: Per-class weights for the cross-entropy loss.
        **kwargs: Accepted for call-site compatibility; currently unused.
    """

    def __init__(
        self,
        features_dict,
        covariates_dict=None,
        d_ff=512,
        num_classes=2,
        num_layers=2,
        dropout=0.1,
        lr=1e-3,
        weight_decay=1e-2,
        weight=[1, 1],
        **kwargs,
    ):

        super().__init__()

        self.lr = lr
        self.weight_decay = weight_decay

        # Keep the metrics in an nn.ModuleDict so they are registered as
        # submodules and follow the module across devices. The attribute keeps
        # its historical (misspelled) name so existing accesses keep working.
        # The metrics use `num_classes` so they agree with the model head.
        self.mertic = nn.ModuleDict(
            {
                "train_auc": torchmetrics.AUROC(
                    num_classes=num_classes, task="multiclass"
                ),
                "val_auc": torchmetrics.AUROC(
                    num_classes=num_classes, task="multiclass"
                ),
            }
        )
        self.history = defaultdict(dict)
        self.loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(weight).float())
        self.model = LinearTransformer(
            features_dict=features_dict,
            covariates_dict=covariates_dict,
            d_ff=d_ff,
            num_classes=num_classes,
            num_layers=num_layers,
            dropout=dropout,
        )

    def forward(self, x):
        """Run the wrapped model; a list/tuple input is unpacked into positional arguments."""
        return self.model(*x) if isinstance(x, (list, tuple)) else self.model(x)

    def _shared_step(self, batch, metric_key):
        """Forward ``batch``, update the AUROC metric stored under ``metric_key`` and return the loss."""
        x, y = batch
        outputs = self.forward(x)
        loss = self.loss_fn(outputs, y.squeeze(-1).float())
        self.mertic[metric_key].update(
            torch.softmax(outputs, dim=-1), torch.argmax(y, dim=1)
        )
        return loss

    def training_step(self, train_batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._shared_step(train_batch, "train_auc")
        self.log("ptl/train_loss", loss, on_epoch=True, prog_bar=True, on_step=False)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._shared_step(val_batch, "val_auc")
        self.log("ptl/val_loss", loss, on_epoch=True, prog_bar=True)

    def on_train_epoch_end(self):
        """Log the epoch-level training AUROC and clear the accumulated state."""
        metric = self.mertic["train_auc"]
        auc = metric.compute()
        self.log("ptl/train_auc", auc, prog_bar=True)
        # Without a reset the metric would keep accumulating across epochs.
        metric.reset()

    def on_validation_epoch_end(self):
        """Log the epoch-level validation AUROC and clear the accumulated state."""
        metric = self.mertic["val_auc"]
        auc = metric.compute()
        self.log("ptl/val_auc", auc, prog_bar=True)
        # Also drops the batches seen during Lightning's sanity check.
        metric.reset()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        return optimizer

    def predict_df(self, df, batch_size=256):
        """Append a ``pred`` column with the softmax probability of class 1.

        Rows with missing values in any feature column are dropped first.

        NOTE(review): this method reads ``self.features`` (names of the input
        columns), which ``__init__`` never sets. The attribute must be assigned
        before calling; whether it should be derived from ``features_dict`` needs
        confirming.

        Args:
            df: DataFrame containing every column listed in ``self.features``.
            batch_size: Batch size used for inference.

        Returns:
            A copy of ``df`` without NA rows and with an added ``pred`` column.

        Raises:
            AttributeError: If ``self.features`` has not been set.
            AssertionError: If a feature column is missing from ``df``.
        """
        features = getattr(self, "features", None)
        if features is None:
            raise AttributeError(
                "predict_df requires `self.features` (input column names) to be set "
                "on the module before it is called"
            )

        missing = [feature for feature in features if feature not in df.columns]
        assert not missing, f"missing feature columns: {missing}"
        print(f"input df have NA: {df[features].isna().sum(axis=1).sum()}")
        df = df.copy().dropna(subset=features)

        predict_dataloader = DataLoader(
            torch.tensor(df[features].values).float(),
            batch_size=batch_size,
            persistent_workers=True,
            num_workers=4,
        )

        self.eval()
        pred = []
        with torch.no_grad():
            for x in predict_dataloader:
                # Batches are created on CPU; move them to wherever the model lives.
                y_hat = self.forward(x.to(self.device)).cpu().detach()
                y_hat = torch.softmax(y_hat, dim=-1)[:, 1]

                pred.append(y_hat)
        # torch.cat raises on an empty list (e.g. every row was dropped).
        pred = torch.cat(pred).numpy() if pred else torch.empty(0).numpy()
        df["pred"] = pred
        return df