In [109]:
import numpy as np
import pandas as pd
from pathlib import Path
import math
import os

from tokenizers import Tokenizer, normalizers, models, pre_tokenizers, decoders, trainers, processors

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import transforms
from torchmetrics import AUROC, Accuracy
from torch.optim.lr_scheduler import LambdaLR
from sklearn.model_selection import train_test_split

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping, StochasticWeightAveraging
from pytorch_lightning.loggers.neptune import NeptuneLogger
import neptune.new as neptune

from data_processing.utils import *

In [110]:
class Rbp24Dataset(Dataset):

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):

        df = self.df
        seq = df['seq'][idx].upper()
        label = df['label'][idx]

        sample = {'seq':seq, 'label':label}

        if self.transform:
            sample = self.transform(sample)
        
        return sample

class ToOHE(object):
    "Convert seq to One Hot Encoding, convert both seq and label to Tensors"

    def __call__(self, sample):

        seq, label = sample['seq'], sample['label']

        nucleotid = {'A':[1,0,0,0],'C':[0,1,0,0],'G':[0,0,1,0],'T':[0,0,0,1], '':[0,0,0,0], 'N':[0.25,0.25,0.25,0.25]}
        seq = np.array([nucleotid[x] for x in seq])

        sample = {'seq': torch.from_numpy(seq).float().permute(1,0),
                  'label': torch.tensor(label)}
        
        return sample
        

In [111]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class MultiheadAttention(pl.LightningModule):

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, embed_dim = x.size()
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # Determine value outputs
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_length, embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        else:
            return o

In [112]:
class AttnCNN(pl.LightningModule):

    def __init__(self, config):
        super(AttnCNN, self).__init__() 

        self.save_hyperparameters(config)

        self.learning_rate = self.hparams.learning_rate
        self.decay_factor = self.hparams.decay_factor
        self.batch_size = self.hparams.batch_size
        
        self.auroc = AUROC(num_classes=1)
        self.acc = Accuracy()

        #Emmbeding
        self.embedd = nn.Embeddgin()

        # conv blocks
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels=4, out_channels=8, kernel_size=self.hparams.CONV1_kernelsize, padding="same"),
            nn.ReLU(), 
            nn.Dropout(0.25))
        
        self.conv2 = nn.Sequential(
            nn.Conv1d(in_channels=8, out_channels=self.hparams.num_channels, kernel_size=8, padding="same"),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.BatchNorm1d(8))
        
        self.lstm = torch.nn.LSTM(
            input_size=128,
            hidden_size=self.hparams.LSTM_kernelsize,
            num_layers=1,
            dropout=0.25,
            bidirectional=True,
            batch_first=True)

        self.multihead_attn = MultiheadAttention(input_dim=self.hparams.num_channels, embed_dim=self.hparams.num_channels, num_heads=4)

        self.conv3 = nn.Conv1d(in_channels=self.hparams.num_channels, out_channels=self.hparams.num_channels // 2, kernel_size=1, padding=0, bias=True)

        self.flatten = nn.Flatten()

        if self.hparams.DIMRED:
            self.hparams.num_channels = self.hparams.num_channels // 2
      
        if self.hparams.LSTM:
            self.linear = nn.Sequential(
                nn.Linear(self.hparams.num_channels*self.hparams.LSTM_kernelsize*2, self.hparams.DENSE_kernelsize),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.BatchNorm1d(self.hparams.DENSE_kernelsize),
                nn.Linear(self.hparams.DENSE_kernelsize, 2))
        else:
            self.linear = nn.Sequential(
                nn.Linear(self.hparams.num_channels*128, self.hparams.DENSE_kernelsize),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.BatchNorm1d(self.hparams.DENSE_kernelsize),
                nn.Linear(self.hparams.DENSE_kernelsize, 2))


    def forward(self, x):
        x = self.conv1(x)
        #print(x.shape)
        if self.hparams.CONV2:
            x = self.conv2(x)
            #print(x.shape)
        if self.hparams.LSTM:
            x,_ = self.lstm(x)
            #print(f"LSTM output: {x.shape}")
        if self.hparams.ATTN:
            x = x.permute(0,2,1)
            #print(f"permute output: {x.shape}")
            x = self.multihead_attn(x)
            #print(f"ATTN output: {x.shape}")
            x = x.permute(0,2,1)
            #print(f"permute output: {x.shape}")
        if self.hparams.DIMRED:
            x = self.conv3(x)
            #print(f"DimReduction output: {x.shape}")
        x = self.flatten(x)
        #print(f"Flatten output: {x.shape}")
        x = self.linear(x)
        #print(x.shape)
  
        return F.log_softmax(x, dim=-1)


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = LambdaLR(optimizer, lambda epoch: self.decay_factor ** epoch)
        return [optimizer], [scheduler] 

    def training_step(self, batch, batch_idx):
        inputs, labels = batch['seq'], batch['label']
        outputs = self(inputs)
        preds = torch.max(outputs, 1)[1]

        criterion = nn.NLLLoss()
        loss = criterion(outputs, labels)

        labels = labels.cpu().detach().int()
        preds = preds.cpu().detach().int()

        train_acc = self.acc(preds, labels)
        train_auroc = self.auroc(preds, labels)
        
        return {"loss": loss,
                "train_acc": train_acc,
                "train_auroc": train_auroc}
                

    def training_epoch_end(self, train_step_outputs):
        loss = torch.stack([x["loss"] for x in train_step_outputs]).mean()
        train_acc_epoch = torch.stack([x["train_acc"] for x in train_step_outputs]).mean()
        train_auroc_epoch = torch.stack([x["train_auroc"] for x in train_step_outputs]).mean()
        
        self.log("train/epoch/loss", loss)
        self.log("train/epoch/acc", train_acc_epoch)
        self.log("train/epoch/auroc", train_auroc_epoch) 
    
    
    def validation_step(self, batch, batch_idx):
        inputs, labels = batch['seq'], batch['label']
        outputs = self(inputs)

        criterion = nn.NLLLoss()
        loss = criterion(outputs, labels)

        labels = labels.cpu().detach()
        preds = torch.max(outputs, 1)[1].cpu().detach()

        val_acc = self.acc(preds, labels)
        val_auroc = self.auroc(preds, labels)
        
        return {"loss": loss,
                "val_acc": val_acc,
                "val_auroc": val_auroc}


    def validation_epoch_end(self, val_step_outputs):
        loss = torch.stack([x["loss"] for x in val_step_outputs]).mean()
        val_acc_epoch = torch.stack([x["val_acc"] for x in val_step_outputs]).mean()
        val_auroc_epoch = torch.stack([x["val_auroc"] for x in val_step_outputs]).mean()
        
        self.log("val/epoch/loss", loss)
        self.log("val/epoch/acc", val_acc_epoch)
        self.log("val/epoch/auroc", val_auroc_epoch)

    
    def test_step(self, batch, batch_idx):
        inputs, labels = batch['seq'], batch['label']
        outputs = self(inputs)

        criterion = nn.NLLLoss()
        loss = criterion(outputs, labels)

        labels = labels.cpu().detach()
        preds = torch.max(outputs, 1)[1].cpu().detach()

        test_acc = self.acc(preds, labels)
        test_auroc = self.auroc(preds, labels)
        
        return {"loss": loss,
                "test_acc": test_acc,
                "test_auroc": test_auroc}

    def test_epoch_end(self, test_step_outputs):
        loss = torch.stack([x["loss"] for x in test_step_outputs]).mean()
        test_acc_epoch = torch.stack([x["test_acc"] for x in test_step_outputs]).mean()
        test_auroc_epoch = torch.stack([x["test_auroc"] for x in test_step_outputs]).mean()
        
        self.log("test/loss", loss)
        self.log("test/acc", test_acc_epoch)
        self.log("test/auroc", test_auroc_epoch)

In [113]:
def make_datasets(train_df, test_df, val_size):

    trainset = Rbp24Dataset(train_df, transform=transforms.Compose([ToOHE()]))
    testset = Rbp24Dataset(test_df, transform=transforms.Compose([ToOHE()]))

    train_labels = [int(trainset[i]['label']) for i in range(len(trainset)-1)]
    train_idx, val_idx= train_test_split(np.arange(len(train_labels)), test_size=val_size, shuffle=True, stratify=train_labels)

    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    return trainset, testset, train_sampler, val_sampler


In [114]:
def make_dataloaders(train_df, test_df, batch_size, num_workers, val_split):

    trainset, testset, train_sampler, val_sampler = make_datasets(train_df, test_df, val_split)

    trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler)
    valloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, sampler=val_sampler, shuffle=False)
    testloader = DataLoader(testset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

    return trainloader, valloader, testloader

In [115]:
def concatanate(dataframes):
    train_df = pd.concat([dataframes[0], dataframes[1]], ignore_index=True).sample(frac=1).reset_index(drop=True)
    test_df = pd.concat([dataframes[2], dataframes[3]], ignore_index=True).sample(frac=1).reset_index(drop=True)
    return train_df, test_df

In [122]:
PARAMS = {
        "CONV1_kernelsize": 16,
        "CONV2": True,
        "num_channels": 8,
        "LSTM": False,
        "LSTM_kernelsize": 16,
        "ATTN": True,
        "DIMRED": False,
        "DENSE_kernelsize":256,
        "batch_size": 64,
        "learning_rate": 0.003,
        "decay_factor": 0.95,
        "max_epochs": 50,
        "num_workers": 16,
        "val_split": 0.1
    }

In [123]:
dataset_path = "/home/mrkvrbl/Diplomka/Data/rbp24/processed" #/home/mrkvrbl/Diplomka/Data/rbp31/
protein = "PARCLIP_MOV10_Sievers"
PARAMS['name'] = protein

train_path = dataset_path + "/" + protein + "/train/original.tsv.gz"
test_path = dataset_path + "/" + protein + "/test/original.tsv.gz"

train_df = pd.read_csv(train_path, delimiter="\t", index_col=0, header=0, compression="gzip")
test_df = pd.read_csv(test_path, delimiter="\t", index_col=0, header=0, compression="gzip")

trainloader, valloader, testloader = make_dataloaders(train_df, test_df, PARAMS["batch_size"], PARAMS["num_workers"], PARAMS["val_split"])

In [124]:
it = iter(trainloader)
batch = next(it)
seq, label = batch['seq'], batch['label']

In [126]:
label

tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1,
        0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1])

In [10]:
early_stopping = EarlyStopping('val/epoch/loss', patience=10, check_on_train_epoch_end=False, )

checkpoints_path = "checkpoints/" + PARAMS["name"]

model_checkpoint = ModelCheckpoint(
        dirpath=checkpoints_path,
        filename="{epoch:02d}",
        save_weights_only=True,
        save_top_k=-1,
        save_last=True,
        monitor="val/epoch/loss",
        every_n_epochs=1)

In [11]:
# create NeptuneLogger
#neptune_logger = NeptuneLogger(
#    api_key="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiIxZDI3YzE1Yy0yYzllLTRjM2YtYjk2MS1jNzNiZmI3MzIyNWEifQ==",  # replace with your own
#    project="mrkvrbl/MasterThesis",  # "<WORKSPACE/PROJECT>"
#    name=PARAMS["name"])

In [13]:
trainer = Trainer(#logger=neptune_logger,
                callbacks=[model_checkpoint, early_stopping],
                max_epochs=PARAMS['max_epochs'],
                accumulate_grad_batches=1,
                gradient_clip_val=0.5,
                stochastic_weight_avg=True,
                gpus=1)
model = AttnCNN(PARAMS)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [70]:
#neptune_logger.log_model_summary(model=model, max_depth=-1)
#neptune_logger.log_hyperparams(params=PARAMS)

https://app.neptune.ai/mrkvrbl/MasterThesis/e/MST-251
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [14]:
trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=valloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params
------------------------------------------------------
0 | auroc          | AUROC              | 0     
1 | acc            | Accuracy           | 0     
2 | conv1          | Sequential         | 520   
3 | conv2          | Sequential         | 536   
4 | lstm           | LSTM               | 18.7 K
5 | multihead_attn | MultiheadAttention | 288   
6 | conv3          | Conv1d             | 36    
7 | flatten        | Flatten            | 0     
8 | linear         | Sequential         | 263 K 
------------------------------------------------------
283 K     Trainable params
0         Non-trainable params
283 K     Total params
1.134     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]



Validation sanity check:  50%|█████     | 1/2 [00:01<00:01,  1.31s/it]

  self.padding, self.dilation, self.groups)


Epoch 0:  82%|████████▏ | 172/210 [00:07<00:01, 23.89it/s, loss=0.593, v_num=0]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


Epoch 0:  82%|████████▏ | 172/210 [00:18<00:04,  9.27it/s, loss=0.593, v_num=0]

In [72]:
checkpoints = sorted(os.listdir(checkpoints_path))

for checkpoint in checkpoints:
    checkpoint_path = str(checkpoints_path) + "/" + checkpoint
    trainer.test(ckpt_path=checkpoint_path, dataloaders=testloader)

Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=00.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=00.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 11.20it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.6753305196762085,
 'test/auroc': 0.6775914430618286,
 'test/loss': 0.6066418290138245}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00,  8.10it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=01.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=01.ckpt


Testing: 100%|██████████| 8/8 [00:01<00:00, 10.56it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7095853090286255,
 'test/auroc': 0.7105029821395874,
 'test/loss': 0.5574076175689697}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.63it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=02.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=02.ckpt


Testing: 100%|██████████| 8/8 [00:00<00:00, 10.77it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7170973420143127,
 'test/auroc': 0.7181735038757324,
 'test/loss': 0.5509724617004395}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00,  8.01it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=03.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=03.ckpt


Testing: 100%|██████████| 8/8 [00:00<00:00,  9.98it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7195762991905212,
 'test/auroc': 0.7217215895652771,
 'test/loss': 0.557011604309082}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.85it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=04.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=04.ckpt


Testing: 100%|██████████| 8/8 [00:01<00:00, 10.44it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7131910920143127,
 'test/auroc': 0.7153535485267639,
 'test/loss': 0.5580605864524841}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.48it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=05.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=05.ckpt


Testing: 100%|██████████| 8/8 [00:00<00:00, 11.26it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7017728090286255,
 'test/auroc': 0.7024372220039368,
 'test/loss': 0.562064528465271}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00,  8.40it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=06.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=06.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 11.91it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7140174508094788,
 'test/auroc': 0.7157325744628906,
 'test/loss': 0.5470075607299805}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00,  8.97it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=07.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=07.ckpt


Testing: 100%|██████████| 8/8 [00:00<00:00, 11.11it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.716871976852417,
 'test/auroc': 0.7191404700279236,
 'test/loss': 0.563056230545044}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00,  8.57it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=08.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=08.ckpt


Testing: 100%|██████████| 8/8 [00:00<00:00, 11.55it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7200270295143127,
 'test/auroc': 0.7203802466392517,
 'test/loss': 0.54436194896698}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00,  8.39it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=09.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=09.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.41it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.73046875,
 'test/auroc': 0.7324057221412659,
 'test/loss': 0.5540257096290588}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.51it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=10.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=10.ckpt


Testing: 100%|██████████| 8/8 [00:00<00:00, 10.66it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7122145295143127,
 'test/auroc': 0.7132553458213806,
 'test/loss': 0.5491184592247009}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.88it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=11.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=11.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.54it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7195011973381042,
 'test/auroc': 0.7206891179084778,
 'test/loss': 0.5452395081520081}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.22it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=12.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=12.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00,  9.77it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7346754670143127,
 'test/auroc': 0.7350885272026062,
 'test/loss': 0.543409526348114}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  6.99it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=13.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=13.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.47it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7331730723381042,
 'test/auroc': 0.7343688607215881,
 'test/loss': 0.5390458106994629}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.62it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=14.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=14.ckpt


Testing: 100%|██████████| 8/8 [00:00<00:00, 11.27it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7368539571762085,
 'test/auroc': 0.7382110357284546,
 'test/loss': 0.5329011678695679}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00,  8.55it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=15.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=15.ckpt


Testing: 100%|██████████| 8/8 [00:01<00:00,  9.61it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7231820821762085,
 'test/auroc': 0.7251813411712646,
 'test/loss': 0.5319498777389526}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.79it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=16.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=16.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.29it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7349008321762085,
 'test/auroc': 0.7366581559181213,
 'test/loss': 0.5384608507156372}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.40it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=17.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=17.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.03it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7280648946762085,
 'test/auroc': 0.7305477857589722,
 'test/loss': 0.5379988551139832}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.19it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=18.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=18.ckpt


Testing: 100%|██████████| 8/8 [00:01<00:00, 10.35it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7251352071762085,
 'test/auroc': 0.7267079949378967,
 'test/loss': 0.5397195219993591}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.11it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=19.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=19.ckpt


Testing: 100%|██████████| 8/8 [00:00<00:00, 10.39it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7314453125,
 'test/auroc': 0.7334899306297302,
 'test/loss': 0.5381419658660889}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.96it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=20.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=20.ckpt


Testing: 100%|██████████| 8/8 [00:00<00:00, 10.85it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7273136973381042,
 'test/auroc': 0.7295531630516052,
 'test/loss': 0.546108067035675}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00,  8.09it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=21.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=21.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.40it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7190504670143127,
 'test/auroc': 0.7215559482574463,
 'test/loss': 0.5409253239631653}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.57it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=22.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=22.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 11.42it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7212289571762085,
 'test/auroc': 0.722809910774231,
 'test/loss': 0.5424394011497498}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00,  8.63it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=23.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=23.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.80it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7333984375,
 'test/auroc': 0.7358294129371643,
 'test/loss': 0.5372716784477234}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00,  8.04it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=24.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=24.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.89it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7158203125,
 'test/auroc': 0.718449056148529,
 'test/loss': 0.5421281456947327}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.57it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=25.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=25.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.58it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.720703125,
 'test/auroc': 0.7235721349716187,
 'test/loss': 0.54537433385849}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.44it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=26.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=26.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.73it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7087590098381042,
 'test/auroc': 0.711237907409668,
 'test/loss': 0.5501580238342285}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.70it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=27.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=27.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.17it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7185246348381042,
 'test/auroc': 0.7220681309700012,
 'test/loss': 0.5412994027137756}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.29it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=28.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=28.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 11.97it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7109375,
 'test/auroc': 0.7141355872154236,
 'test/loss': 0.549114465713501}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00,  9.00it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/epoch=29.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/epoch=29.ckpt


Testing: 100%|██████████| 8/8 [00:00<00:00, 11.13it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7224308848381042,
 'test/auroc': 0.7250097990036011,
 'test/loss': 0.537405252456665}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  7.90it/s]


Restoring states from the checkpoint path at checkpoints/PARCLIP_MOV10_Sievers/last.ckpt
Loaded model weights from checkpoint at checkpoints/PARCLIP_MOV10_Sievers/last.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 10.10it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': 0.7224308848381042,
 'test/auroc': 0.7250097990036011,
 'test/loss': 0.537405252456665}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:01<00:00,  6.75it/s]


In [73]:
import shutil
shutil.rmtree(checkpoints_path)

In [74]:
from sklearn.linear_model import LogisticRegression

def baseline_model_torch_metrics(X_train, X_test, y_train, y_test, max_iter=200):
    baseline = LogisticRegression(max_iter=max_iter, random_state=42)

    # flatten the data
    X_train_flat = X_train.reshape(X_train.shape[0], -1)
    X_test_flat = X_test.reshape(X_test.shape[0], -1)

    baseline.fit(X_train_flat, y_train)

    y_train = torch.tensor(y_train.values.astype(int))
    y_test = torch.tensor(y_test.values.astype(int))


    baseline_pred_train = torch.from_numpy(baseline.predict(X_train_flat)).int()
    baseline_pred_test = torch.from_numpy(baseline.predict(X_test_flat)).int()

    auroc = AUROC(num_classes=1)
    acc = Accuracy()

    train_acc_score = acc(y_train, baseline_pred_train)
    test_acc_score = auroc(y_test, baseline_pred_test)

    train_auc_score = acc(y_train, baseline_pred_train)
    test_auc_score = auroc(y_test, baseline_pred_test)

    print(f"train_auc_score: {train_auc_score}\ntest_auc_score: {test_auc_score}\ntrain_acc_score: {train_acc_score}\ntest_acc_score: {test_acc_score}")

In [75]:
from utils.utils import get_X_y

X_train, y_train = get_X_y(train_df)
X_test, y_test = get_X_y(test_df)

result = baseline_model_torch_metrics(X_train, X_test, y_train, y_test, max_iter=200)

train_auc_score: 0.6849514842033386
test_auc_score: 0.698250412940979
train_acc_score: 0.6849514842033386
test_acc_score: 0.698250412940979




In [76]:
# ADD EMBEDING AND CONFUSION TABLE

Experiencing connection interruptions. Will try to reestablish communication with Neptune. Internal exception was: ReadTimeout
Error occurred during asynchronous operation processing: Cannot upload file /home/mrkvrbl/Diplomka/src/checkpoints/PARCLIP_MOV10_Sievers/epoch=00.ckpt: Path not found or is a not a file.
Error occurred during asynchronous operation processing: Cannot upload file /home/mrkvrbl/Diplomka/src/checkpoints/PARCLIP_MOV10_Sievers/epoch=01.ckpt: Path not found or is a not a file.
Error occurred during asynchronous operation processing: Cannot upload file /home/mrkvrbl/Diplomka/src/checkpoints/PARCLIP_MOV10_Sievers/epoch=02.ckpt: Path not found or is a not a file.
Error occurred during asynchronous operation processing: Cannot upload file /home/mrkvrbl/Diplomka/src/checkpoints/PARCLIP_MOV10_Sievers/epoch=03.ckpt: Path not found or is a not a file.
Error occurred during asynchronous operation processing: Cannot upload file /home/mrkvrbl/Diplomka/src/checkpoints/PARCLIP_MO