In [1]:
import pandas as pd
import numpy as np
import os
import plotly
import plotly.graph_objects as go
import plotly.express as px
import random
import torch
import torch.nn.functional as F

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqUtils import GC
from sklearn.metrics import roc_curve, auc, roc_auc_score, precision_recall_curve, accuracy_score
from torch.utils.data import Dataset, DataLoader
from torch import nn
from random import sample

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import CSVLogger

from sklearn.metrics import accuracy_score, roc_auc_score
import torchmetrics

from dinuc import *

from ambrosini_auc import ambrosini_roc_auc_score
import sarus_wrapper

import sys
import glob
import tempfile

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

from Bio.SeqRecord import SeqRecord

from plotly.subplots import make_subplots

from typing import Optional, List

from tqdm import tqdm

import re



In [2]:
chr_val = 'chr7'
chr_test = 'chr11'
seed_everything(42, workers=True)
torch.set_float32_matmul_precision('high')
intervals = sorted(glob.glob("/home/jakewayd/New_ChIP-Seq/data/CHS.Schmitges/Train_intervals/*.peaks"))
window = 300
genome = SeqIO.to_dict(SeqIO.parse('data/CHS.Schmitges/chrs.fasta', 'fasta'))

Global seed set to 42


In [3]:
def get_fasta(x, shift = 0, shuffle = None, dinuc = None):
    
    chrom = x['#CHROM']
    new_peak = x['abs_summit']+np.random.choice([shift, -shift])
    seq = genome[chrom][new_peak-window//2:new_peak+window//2+1].seq.__str__().upper()
    
    if shift != 0 or shuffle or dinuc: 
        y = 0
    else: 
        y = 1
    
    if set(seq) - set('ATGC'):
        return None, None, None
    
    if shuffle:
        seq = ''.join(random.sample(seq, len(seq)))
        
    if dinuc:
        seq = shuffle_string_dinucl(seq)
    
    return chrom, seq, y

def get_df(df, method, shift = 0):
    
    if method == 'pos':
        seqs = df.apply(lambda x: get_fasta(x), axis = 1)
    
    if method == 'shift':
        seqs = df.apply(lambda x: get_fasta(x, shift = shift), axis = 1)
        
    if method == 'shuffle':
        seqs = df.apply(lambda x: get_fasta(x, shuffle = True), axis = 1)
        
    if method == 'dinuc':
        seqs = df.apply(lambda x: get_fasta(x, dinuc = True), axis = 1)
        
    seqs = pd.DataFrame(seqs.to_list(), columns=['chr', 'seq', 'y']).dropna().reset_index(drop = True)
    return seqs

def get_datasets(peaks, method, shift = 0):
    
    df = pd.read_csv(peaks, sep = '\t')
    df = df[['#CHROM', 'abs_summit']]
    for i in [window//4, -window//4]:
        data = df.copy()
        data['abs_summit'] = data['abs_summit']+i
        df = pd.concat([df, data]).reset_index(drop = True)
        
    df_pos = get_df(df, 'pos')
    df_neg = get_df(df, method, shift)
    
    return pd.concat([df_pos, df_neg]).reset_index(drop = True)

def one_hot_encode(seq):
    
    nuc_d = {
         'A':[1.0,0.0,0.0,0.0],
         'C':[0.0,1.0,0.0,0.0],
         'G':[0.0,0.0,1.0,0.0],
         'T':[0.0,0.0,0.0,1.0]
    }
        
    return np.array([nuc_d[x] for x in seq])

class SeqDatasetOHE(Dataset):
    def __init__(self, df, seq_col='seq', target_col='y'):
        
        self.seqs = list(df[seq_col].values)
        self.seq_len = len(self.seqs[0])
        self.ohe_seqs = torch.stack([torch.tensor(one_hot_encode(x)) for x in self.seqs])
        self.labels = torch.tensor(list(df[target_col].values)).unsqueeze(1)
        
    def __len__(self): 
        return len(self.seqs)
    
    def __getitem__(self,idx):
        seq = self.ohe_seqs[idx]
        label = self.labels[idx]
        
        return seq, label
    
class DataModuleSeqs(pl.LightningDataModule):
    def __init__(self, peaks, method, shift = 0, batch_size = 512):
        super().__init__()
        self.peaks = peaks
        self.method = method
        self.shift = shift
        self.batch_size = batch_size
        
        df = get_datasets(self.peaks, self.method, self.shift)
        df = df.sample(frac = 1)
        
        df_val = df[df['chr'] == chr_val]
        if df_val.shape[0] < 1000:
            return None
        
        df_train = df[(df['chr'] != chr_val) & (df['chr'] != chr_test)]
        df_test = df[df['chr'] == chr_test]
        
        print(df_val['y'].value_counts())
        
        self.train = SeqDatasetOHE(df_train)
        self.val = SeqDatasetOHE(df_val)
        self.test = SeqDatasetOHE(df_test)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, shuffle = True, num_workers = 5)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size, shuffle = False, num_workers = 5)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, shuffle = False, num_workers = 5)

    def predict_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, shuffle = False, num_workers = 5)

### классы для модели

In [4]:
class Block(nn.Module):
    
    def __init__(self, filters):
        super(Block, self).__init__()
        self.conv1 = nn.Conv1d(filters, filters, kernel_size = 3, padding = 1)
        torch.nn.init.xavier_uniform_(self.conv1.weight)
        self.norm1 = nn.BatchNorm1d(filters)
        self.activ = nn.SELU(inplace=True)
        self.conv2 = nn.Conv1d(filters, filters, kernel_size = 3, padding = 1)
        torch.nn.init.xavier_uniform_(self.conv2.weight)
        self.norm2 = nn.BatchNorm1d(filters)
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.activ(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.activ(x + out)
        return out

In [5]:
class Truck(nn.Module):
    def __init__(self, filters, n_blocks):
        super(Truck, self).__init__()
        truck = []
        for i in range(n_blocks):
            truck += [Block(filters)]
            
        self.truck = nn.Sequential(*truck)
        
    def forward(self, x):
        return self.truck(x)

In [6]:
class Layers(nn.Module):
    def __init__(self, filters, n_blocks, n_layers):
        super(Layers, self).__init__()
        layers = []
        for i in range(n_layers):
            layers += [
                nn.Sequential(
                    Truck(pow(2, i)*filters, n_blocks),
                    nn.Conv1d(pow(2, i)*filters, pow(2, i+1)*filters, kernel_size = 3, padding = 1, stride = 2),
                    nn.SELU(inplace=True)
                )
            ]       

        self.layers = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.layers(x)

### модель

In [7]:
class Resnet_like(nn.Module):
    def __init__(self, filters = 8, num_in_block = 2, num_in_truck = 2, learning_rate = 1e-3):
        super(Resnet_like, self).__init__()
        
        self.conv1 = nn.Conv1d(4, filters, kernel_size = 7)
        self.bn1  = nn.BatchNorm1d(filters)
        self.SELU = nn.SELU(inplace=True)
        
        self.layers = Layers(filters, num_in_block, num_in_truck)
        
        self.avg = nn.AdaptiveMaxPool1d(1)
        
        self.linear = nn.Sequential(
            nn.Linear(pow(2, num_in_truck)*filters, pow(2, num_in_truck+1)*filters),
            nn.BatchNorm1d(pow(2, num_in_truck+1)*filters),
            nn.SELU(inplace=True),
            nn.Linear(pow(2, num_in_truck+1)*filters, pow(2, num_in_truck+2)*filters),
            nn.BatchNorm1d(pow(2, num_in_truck+2)*filters),
            nn.SELU(inplace=True),
            nn.Linear(pow(2, num_in_truck+2)*filters, 1)
        )
        
        self.name = f'{type(self).__name__}_{filters}_{num_in_truck}_{num_in_block}'
        
    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.SELU(x)
        x = self.layers(x)
        x = self.avg(x)
        x = x.squeeze(2)
        x = self.linear(x)
        
        return x

In [9]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import AUROC

class MyModel(pl.LightningModule):
    def __init__(self, model, len_train, lr = 1e-3, weight_decay = 1e-4, num_classes = 2):
        super(MyModel, self).__init__()
        self.save_hyperparameters()
        self.len_train = len_train
        self.num_classes = num_classes
        self.lr = lr
        self.weight_decay = weight_decay
        self.model = model
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.train_auc = AUROC(task = 'binary')
        self.val_auc = AUROC(task = 'binary')
        self.name = self.model.name
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x.float())
        loss = self.loss_fn(y_hat, y.float())
        
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_auc', self.train_auc(y_hat, y.float()), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x.float())
        loss = self.loss_fn(y_hat, y.float())
        
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_auc', self.val_auc(y_hat, y.float()), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x.float())
        loss = self.loss_fn(y_hat, y.float())
        
        self.log('test_loss', loss)
        self.log('test_auc', self.val_auc(y_hat, y.float()))
        
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.lr, epochs=self.trainer.max_epochs, steps_per_epoch=self.len_train),
            'interval': 'step',
            'frequency': 1
        }
        
        # reduce_on_plateau_scheduler = {
        #     'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3),
        #     'monitor': 'val_loss',
        #     'interval': 'epoch',
        #     'frequency': 1,
        #     'strict': True
        # }
        
        return [optimizer], [scheduler]


### Приколы с pytorch lightning

In [10]:
def init_weigths(x):
    if 'Conv1d' in str(type(x)) or 'Linear' in str(type(x)):
        torch.nn.init.xavier_uniform_(x.weight)

In [11]:
def plot_all(model, neg, dataset):
    
    df = pd.read_csv(f'logs/{model}/{neg}/{dataset}/metrics.csv')
    
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles = ('Training and validation losses', 'Training and validation AUCs',
                          'Learning rate every batch', 'Result AUCs')
    )

    colors = px.colors.qualitative.Pastel

    fig.add_trace(
        go.Scatter(
            x = df['epoch'].dropna().unique(),
            y = df['train_loss_epoch'].dropna(),
            marker_color = colors[0],
            name = 'train_loss_epoch'
        ), row = 1, col = 1
    )

    fig.add_trace(
        go.Scatter(
            x = df['epoch'].dropna().unique(),
            y = df['val_loss_epoch'].dropna(),
            marker_color = colors[1],
            name = 'val_loss_epoch'
        ), row = 1, col = 1
    )

    fig.add_trace(
        go.Scatter(
            x = df['epoch'].dropna().unique(),
            y = df['train_auc_epoch'].dropna(),
            marker_color = colors[2],
            name = 'train_auc_epoch'
        ), row = 1, col = 2
    )

    fig.add_trace(
        go.Scatter(
            x = df['epoch'].dropna().unique(),
            y = df['val_auc_epoch'].dropna(),
            marker_color = colors[3],
            name = 'val_auc_epoch'
        ), row = 1, col = 2
    )

    fig.add_trace(
        go.Scatter(
            x = [i for i in range(len(df['lr-AdamW'].dropna()))],
            y = df['lr-AdamW'].dropna(),
            marker_color = colors[4],
            name = 'lr-AdamW'
        ), row = 2, col = 1
    )

    fig.add_trace(
        go.Bar(
            x = ['train', 'val', 'test'],
            y = [df['train_auc_epoch'].dropna().to_list()[-1],
                 df['val_auc_epoch'].dropna().to_list()[-1],
                 df['test_auc'].dropna().to_list()[-1]],
            text = [round(df['train_auc_epoch'].dropna().to_list()[-1], 3),
                 round(df['val_auc_epoch'].dropna().to_list()[-1], 3),
                 round(df['test_auc'].dropna().to_list()[-1], 3)],
            marker_color = colors[5],
            name = 'result'
        ), row = 2, col = 2
    )

    fig.update_layout(
        margin = dict(t = 120, b = 5, l = 5, r = 5),
        template = 'plotly_white',
        title = {
            'text': f"<b>Model:</b> {model} <br> <b>Negative:</b> {neg} <br> <b>Dataset</b>: {dataset} <br>",
            'x': 0.01,
            'y': 0.95,
            'yanchor': 'top'
        },
        height = 700,
        width = 1200,
        xaxis1 = dict(title = 'Epoch'),
        yaxis1 = dict(title = 'BCEloss logit'),
        xaxis2 = dict(title = 'Epoch'),
        yaxis2 = dict(title = 'AUC'),
        xaxis3 = dict(title = 'Step'),
        yaxis3 = dict(title = 'Learning rate'),
        xaxis4 = dict(title = 'Dataset'),
        yaxis4 = dict(title = 'AUC'),
    )

    fig.write_image(f'logs/{model}/{neg}/{dataset}/plot.png')

In [12]:
def train_valid_model(num_filters, blocks, divs, intervals, method, shift):
    d = {'name': [], 'neg': [], 'model': [], 'auc': []}
    for name in intervals:
        print(name.split('/')[-1].split('.')[0])
        datasets = DataModuleSeqs(name, method, shift)
        
        try:
            datasets.val
        except:
            continue
        
        model = Resnet_like(num_filters, blocks, divs).apply(init_weigths)
        shell = MyModel(model, len(datasets.train_dataloader()))
        if not os.path.exists(f'logs/{shell.name}/{method}_{shift}/{name.split("/")[-1]}'):
            os.makedirs(f'logs/{shell.name}/{method}_{shift}/{name.split("/")[-1]}')
            
        trainer = pl.Trainer(
                             max_epochs = 20,
                             accelerator='gpu',
                             devices=[1],
                             auto_lr_find = True,
                             callbacks = [EarlyStopping(monitor = "val_loss", mode = "min"), 
                                          LearningRateMonitor(logging_interval='step')],
                             log_every_n_steps = 1,
                             logger = CSVLogger(save_dir = f"logs/{shell.name}", name = f'{method}_{shift}', version = name.split('/')[-1])
                            )

        trainer.fit(model = shell, datamodule = datasets)
        result = trainer.validate(model = shell, datamodule = datasets)
        test = trainer.test(model = shell, datamodule = datasets)
        
        plot_all(shell.name, f'{method}_{shift}', name.split('/')[-1])
        d['auc'].append(test[0]['test_auc'])
        d['name'].append(name)
        d['neg'].append(method)
        d['model'].append(shell.name)
    
    return pd.DataFrame(d)

In [14]:
# train_valid_model(16, 1, 1, intervals[5:6], 'dinuc', 0)

In [15]:
class PWMModel:
    def __init__(self, matrix_path, transpose=False):
        self.matrix_path = matrix_path
        self.transpose_mode = transpose
        self.dir = tempfile.TemporaryDirectory()
    
    def fit(self, sequences: List[str], y: List[int]):
        raise NotImplementedError("fitting PWMModel is not implemented")
    
    def predict(self, sequences: List[str]):
        records = [SeqRecord(Seq(seq), id=f"{i}", description="") for i, seq in enumerate(sequences, 1)]
        file_path = os.path.join(self.dir.name, "predict.fasta")
        with open(file_path, "wt") as outfile:
            SeqIO.write(records, outfile, "fasta")
        pred = self.predict_file(file_path)
        return pred
    
    def predict_file(self, file_path: str):
        sarus_result = sarus_wrapper.launch_sarus_single(file_path, self.matrix_path, transpose=self.transpose_mode)
        scores = sarus_result["val"].values
        return scores

In [16]:
def shuffle_string(s):
    chars = list(s)
    np.random.shuffle(chars)
    return "".join(chars)

In [17]:
def get_motif_records(type_dir, type_, subtype="-"):
    motifs_dir = "/home/arsen_l/imtf/wp04_datachecker/bestpwms/greco-motifs/best-mx"
    all_type = "ALL"
    data_type = "CHS"
    has_subdirs = False
    records = list()
    motif_mask = os.path.join(type_dir, "*.p?m")
    motif_paths = glob.glob(motif_mask)
    motif_paths.sort()
    for motif_path in motif_paths:
        record = dict()
        record["type"] = type_
        record["subtype"] = subtype
        motif_id = os.path.basename(motif_path)
        record["factor"] = motif_id.split("@", maxsplit=1)[0].split(".", maxsplit=1)[0]
        record["motif_id"] = motif_id
        record["extension"] = os.path.splitext(record["motif_id"])[-1]
        record["path"] = motif_path
        records.append(record)
    return records

In [18]:
def get_motif_df():
    motifs_dir = "/home/arsen_l/imtf/wp04_datachecker/bestpwms/greco-motifs/best-mx"
    all_type = "ALL"
    data_type = "CHS"
    has_subdirs = False
    motif_records = list()
    if has_subdirs:
        subtype_mask = os.path.join(motifs_dir, data_type, "*")
        subtype_paths = sorted(glob.glob(subtype_mask))
        for subtype_path in subtype_paths:
            subtype = os.path.basename(subtype_path)
            records = get_motif_records(subtype_path, data_type, subtype)
            motif_records.extend(records)
    else:
        data_type_dir = os.path.join(motifs_dir, data_type)
        records = get_motif_records(data_type_dir, data_type)
        motif_records.extend(records)

    records = get_motif_records(data_type_dir, all_type)
    motif_records.extend(records)

    motif_df = pd.DataFrame(motif_records).sort_values(by=["factor"]).reset_index(drop=True)
    return motif_df

In [19]:
def get_quals_pwm(names, neg, shift = 0, motif_df = None):
    motifs_dir = "/home/arsen_l/imtf/wp04_datachecker/bestpwms/greco-motifs/best-mx"
    all_type = "ALL"
    data_type = "CHS"
    has_subdirs = False
    metrics = {
                "AUROC": (roc_auc_score, dict()),
              }

    results = list()
    for name in tqdm(names, total=len(names), ascii=True):

        df = get_datasets(name, neg, shift)
        df = df.sample(frac = 1)    
        if df[df['chr'] == chr_val].shape[0] < 1000:
            continue

        dataset_id_pos = os.path.basename(name)
        print(dataset_id_pos)
        factor = dataset_id_pos.split('.')[0]
        if factor not in motif_df["factor"].values:
            results.append({"factor": factor, "dataset_positive": dataset_id_pos, "dataset_negative_type": neg})
            print(f"skipped {factor}")
            continue

        positive = df[df['y'] == 1]['seq'].to_list()
        negative = df[df['y'] == 0]['seq'].to_list()
        eval_set = positive + negative
        y_true = np.concatenate([np.ones(shape=len(positive)), np.zeros(shape=len(negative))])

        subdf = motif_df[motif_df["factor"] == factor]
        for _, record in subdf.iterrows():
            pwm_path = record["path"]
            pwm_id = record["motif_id"]

            pwm_model = PWMModel(pwm_path)
            scores = pwm_model.predict(eval_set)

            eval_result = record.to_dict()
            eval_result["dataset_positive"] = dataset_id_pos
            eval_result["dataset_negative"] = neg
            for metric_name, (metric, params) in metrics.items():
                m_v = metric(y_true=y_true, y_score=scores, **params)
            eval_result[metric_name] = m_v
        results.append(eval_result)
        
    return pd.DataFrame(results)

In [20]:
def get_quals(intervals, method, shift):
    motifs_dir = "/home/arsen_l/imtf/wp04_datachecker/bestpwms/greco-motifs/best-mx"
    all_type = "ALL"
    data_type = "CHS"
    has_subdirs = False
    motif_df = get_motif_df()
    return get_quals_pwm(intervals, method, motif_df = motif_df, shift = shift)

In [24]:
def all_in(num_filters, blocks, divs, shifts, intervals, overwrite = False):
    for i in num_filters:
        for j in blocks:
            for k in divs:
                for method in ['shuffle', 'dinuc', 'shift']:
                    for shift in shifts:
                        if not os.path.exists(f'./results/{method}_{shift}_{i}_{j}_{k}.tsv') or overwrite:
                            result = train_valid_model(i, j, k, intervals, method, shift)
                            result.to_csv(f'./results/{method}_{shift}_{i}_{j}_{k}.tsv', sep = '\t')

                        if not os.path.exists(f'./quals/{method}_{shift}.tsv') or overwrite:
                            quals = get_quals(intervals, method, shift)
                            quals.to_csv(f'./quals/{method}_{shift}.tsv', sep = '\t')

In [None]:
all_in([8, 16, 24], [1, 2, 3], [1, 2, 3], [0, 100, 500], intervals, True)

GLI4
SNAI1
SNAI1
SNAI1
ZFP28
ZFP28
0.0    512
1.0    512
Name: y, dtype: int64


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type              | Params
------------------------------------------------
0 | model     | Resnet_like       | 14.6 K
1 | loss_func | BCEWithLogitsLoss | 0     
2 | acc       | BinaryAccuracy    | 0     
------------------------------------------------
14.6 K    Trainable params
0         Non-trainable params
14.6 K    Total params
0.059     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────
          step                      5.0
      val_acc_epoch            0.9208984375
        val_auroc             0.989990234375
     val_loss_epoch         0.17394708096981049
────────────────────────────────────────────────────────────────────────────────
ZFP3
ZIM3
1.0    4888
0.0    4888
Name: y, dtype: int64


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type              | Params
------------------------------------------------
0 | model     | Resnet_like       | 14.6 K
1 | loss_func | BCEWithLogitsLoss | 0     
2 | acc       | BinaryAccuracy    | 0     
------------------------------------------------
14.6 K    Trainable params
0         Non-trainable params
14.6 K    Total params
0.059     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Validation: 0it [00:00, ?it/s]