# Model Implementation for 3D Cell Tracking


In [1]:
# !pip install torchsummary 
# !pip install gunpowder
# !pip install zarr
# !pip install matplotlib
# pip install tensorboard

In [2]:
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch
import torch.nn as nn
import numpy as np
import random 
import matplotlib.pyplot as plt
from torchvision import models
from torchsummary import summary

from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import gunpowder as gp
import zarr
import math
%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter
import skimage
import networkx
import pathlib
from tifffile import imread, imwrite
import tensorboard
import torch.nn.functional as F

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


# Data Importing

## Extract Centroids

In [3]:
# function to convert 'TRA' channel into cell and frame-wise centroid positions
## Function to extract trajectories from data

base_path = pathlib.Path("/mnt/shared/celltracking/data/cho/")


# read parent-child links from file
links = np.loadtxt(base_path / "01_GT/TRA" / "man_track.txt", dtype=int)

# read annotated image stack
centroids = np.stack([imread(xi) for xi in sorted((base_path / "01_GT/TRA").glob("*.tif"))])  # images

# extract centroids from annotated image stacks
centers = skimage.measure.regionprops(centroids[0,0,:,:])
tracks = []
for t, frame in enumerate(centroids):
    centers = skimage.measure.regionprops(frame)
    for c in centers:
        tracks.append([c.label, t, int(c.centroid[1]), int(c.centroid[2])])
        
# constructs graph 
tracks = np.array(tracks)
graph = networkx.DiGraph()
for cell_id, t, x, y in tracks:
    graph.add_node((cell_id,t), x=x, y=y, t=t)
    
for cell_id, t in graph.nodes():
    if (cell_id, t+1) in graph.nodes():
        graph.add_edge((cell_id, t), (cell_id,t+1))

for child_id, child_from, _, child_parent_id in links:
    for parent_id, _, parent_to, _ in links:
        if child_parent_id == parent_id:
            graph.add_edge((parent_id, parent_to), (child_id, child_from))
            
# extract trajectories from graph set
tracks = [graph.subgraph(c) for c in networkx.weakly_connected_components(graph) if len(c)>0]

# remove tracks with 0 edges
tracks = [track for track in tracks if len(track.edges)>0]



## Define function to make image pairs

In [4]:
class getPaired(gp.BatchFilter):

    def __init__(self, raw, raw_shift, tracks, paired=True):
        self.raw = raw
        self.raw_shift = raw_shift
        self.tracks = tracks
        self.paired = paired
    
    # _ref channel array is stored in raw_ref, while second volume in pair will be stored raw_new
    def prepare(self, request):
        # obtain volume coordinates from tracks                
        deps = gp.BatchRequest()
        vol1,vol2 = self.sampler(request)
                
        deps[self.raw] = gp.ArraySpec(roi=gp.Roi(vol1,request[self.raw].roi.get_shape()))
        deps[self.raw_shift] = gp.ArraySpec(roi=gp.Roi(vol2,request[self.raw_shift].roi.get_shape()))

        return deps
    
    # required to inform downstream nodes about new array 
    def process(self, batch, request):
        # create a new batch to hold the new array
        out_batch = gp.Batch()

        # create new array and store it in the batch
        out_batch[self.raw_shift] = batch[self.raw_shift]
        out_batch[self.raw] = batch[self.raw]
        
        #print(f'raw: {batch[self.raw].spec.roi}')
        #print(batch[self.raw_shift].spec.roi)
        
        # make sure that coordinates for batch[raw] and batch[raw_shift] are reset to (0,0,0,0,0)
        out_batch[self.raw].spec.roi = request[self.raw].roi
        out_batch[self.raw_shift].spec.roi = request[self.raw_shift].roi

        # return the new batch
        return out_batch
    
    # select pairs of subvolumes from data
    def sampler(self,request):
        tracks = self.tracks
        paired = self.paired
        # choose connected nodes
        # if self.paired:
        if paired:
            t0 = tracks[np.random.randint(0,len(tracks),1).item()]
            e0 = list(t0.edges)[np.random.randint(len(list(t0.edges)))]
            node0 = t0.nodes[e0[0]]
            node1 = t0.nodes[e0[1]]
            
        # choose random unconnected nodes
        else:
            # randomly choose two tracks and make sure they are not identical
            t0,t1 = np.random.randint(0,len(tracks),2)
            while t0==t1:
                t0,t1 = np.random.randint(0,len(tracks),2)

            #print(f'trackids: {t0,t1}')
            t0 = tracks[t0]
            t1 = tracks[t1]

            # choose random edges from each track
            #print(f'number edges per track{len(list(t0.nodes)),len(list(t1.nodes))}')

            r0 = np.random.randint(0,len(list(t0.nodes))) 
            r1 = np.random.randint(0,len(list(t1.nodes)))

            node0 = t0.nodes[list(t0.nodes)[r0]]
            node1 = t1.nodes[list(t1.nodes)[r1]]
            


        node0_xyt = [node0["x"], node0["y"], node0["t"]]
        node1_xyt = [node1["x"], node1["y"], node1["t"]]

        #print(f'input coord: {node0_xyt,node1_xyt}')

        roi_in = request[self.raw_shift].roi.get_shape()
        #t,z,y,x
        coords_vol0 = (node0_xyt[2],0,node0_xyt[0]-(roi_in[2]/2),node0_xyt[1]-(roi_in[3]/2))
        coords_vol1 = (node1_xyt[2],0,node1_xyt[0]-(roi_in[2]/2),node1_xyt[1]-(roi_in[3]/2))
        #print(f'output coords - vol0: {coords_vol0}, vol1:{coords_vol1}')

        return coords_vol0, coords_vol1
            

# Make batches 

In [5]:
#specify subvolume size and volume source
volSize = (1,5,64, 64)
coord = (0,0,0,0)
batch_size = 8

zarrdir = '/mnt/shared/celltracking/data/cho/01.zarr'
raw = gp.ArrayKey('raw')
raw_shift = gp.ArrayKey('raw_shift')

# chose a random source (i.e., sample) from the above
# pipeline_paired = (gp.ZarrSource(
#     zarrdir,  # the zarr container
#     {raw_shift: 'raw', raw: 'raw'},  # which dataset to associate to the array key
#     {raw_shift: gp.ArraySpec(voxel_size=(1,1,1,1), interpolatable=True), raw:gp.ArraySpec(voxel_size=(1,1,1,1), interpolatable=True)}  # meta-information
#     )+ gp.Normalize(raw)+gp.Normalize(raw_shift)+ 
#     gp.Pad(raw_shift, None) + 
#     gp.Pad(raw, None) + gp.IntensityAugment(
#     raw,
#     scale_min=0.9,
#     scale_max=1.1,
#     shift_min=-0.1,
#     shift_max=0.1,
#     ) + gp.NoiseAugment(raw, mode="gaussian")) + gp.IntensityAugment(
#     raw_shift,
#     scale_min=0.9,
#     scale_max=1.1,
#     shift_min=-0.1,
#     shift_max=0.1,
#     ) + gp.NoiseAugment(raw_shift, mode="gaussian") + getPaired(raw,raw_shift,tracks,paired=True) + gp.ElasticAugment(
#     [2,10,10],
#     [0,2,2],
#     [0,0*math.pi/2.0],
#     prob_slip=0.05,
#     prob_shift=0.05,
#     max_misalign=25) + gp.SimpleAugment(transpose_only=[2, 3], mirror_only=[])

# pipeline_unpaired = (gp.ZarrSource(
#     zarrdir,  # the zarr container
#     {raw_shift: 'raw', raw: 'raw'},  # which dataset to associate to the array key
#     {raw_shift: gp.ArraySpec(voxel_size=(1,1,1,1), interpolatable=True), raw:gp.ArraySpec(voxel_size=(1,1,1,1), interpolatable=True)}  # meta-information
#     )+ gp.Normalize(raw)+gp.Normalize(raw_shift)+ 
#     gp.Pad(raw_shift, None) + 
#     gp.Pad(raw, None) + gp.IntensityAugment(
#     raw,
#     scale_min=0.9,
#     scale_max=1.1,
#     shift_min=-0.1,
#     shift_max=0.1,
#     ) + gp.NoiseAugment(raw, mode="gaussian") + gp.IntensityAugment(
#     raw_shift,
#     scale_min=0.9,
#     scale_max=1.1,
#     shift_min=-0.1,
#     shift_max=0.1,
#     ) + gp.NoiseAugment(raw_shift, mode="gaussian")) + getPaired(raw,raw_shift,tracks,paired=False)  + gp.ElasticAugment(
#     [2,10,10],
#     [0,2,2],
#     [0,0*math.pi/2.0],
#     prob_slip=0.05,
#     prob_shift=0.05,
#     max_misalign=25) + gp.SimpleAugment(transpose_only=[2, 3], mirror_only=[]) 

pipeline_paired = (gp.ZarrSource(
    zarrdir,  # the zarr container
    {raw_shift: 'raw', raw: 'raw'},  # which dataset to associate to the array key
    {raw_shift: gp.ArraySpec(voxel_size=(1,1,1,1), interpolatable=True), raw:gp.ArraySpec(voxel_size=(1,1,1,1), interpolatable=True)}  # meta-information
    ) + gp.Pad(raw_shift, None) + gp.Pad(raw, None) + getPaired(raw,raw_shift,tracks,paired=True))

pipeline_unpaired = (gp.ZarrSource(
    zarrdir,  # the zarr container
    {raw_shift: 'raw', raw: 'raw'},  # which dataset to associate to the array key
    {raw_shift: gp.ArraySpec(voxel_size=(1,1,1,1), interpolatable=True), raw:gp.ArraySpec(voxel_size=(1,1,1,1), interpolatable=True)}  # meta-information
    ) + gp.Pad(raw_shift, None) + gp.Pad(raw, None) + getPaired(raw,raw_shift,tracks,paired=False))


pipeline_paired += gp.PreCache(num_workers=6) 
pipeline_unpaired += gp.PreCache(num_workers=6)
pipeline_paired += gp.Stack(batch_size)
pipeline_unpaired += gp.Stack(batch_size)

# specify request
request = gp.BatchRequest()
request[raw] = gp.Roi(coord, volSize)
request[raw_shift] = gp.Roi(coord, volSize)

gp.ArraySpec()
# #build the pipeline...
# with gp.build(pipeline_paired):

#   #...and request a batch
#   batch = pipeline_paired.request_batch(request)
  
# #show the content of the batch
# print(f"batch returned: {batch}")

# #plot first slice of volume
# fig, axs = plt.subplots(1,2, figsize=(10,5),dpi=300)
# print(batch[raw].data.shape)
# axs[0].imshow(np.flipud(batch[raw].data[0,0,0,:,:]))
# axs[1].imshow(np.flipud(batch[raw_shift].data[0,0,0,:,:]))
# axs[0].set_xticks([])
# axs[1].set_xticks([])
# axs[0].set_yticks([])
# axs[1].set_yticks([])



ROI: None, voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False

In [6]:
# provide path to zarr directory
zarrdir = '/mnt/shared/celltracking/data/cho/01.zarr'

data = zarr.open(zarrdir)
loader = []

# Define the Model

In [7]:
class Vgg3D(torch.nn.Module):

    def __init__(self, input_size, output_classes, downsample_factors, fmaps=12):

        super(Vgg3D, self).__init__()

        self.input_size = input_size
        self.downsample_factors = downsample_factors
        self.output_classes = 2

        current_fmaps, h, w, d = tuple(input_size)
        current_size = (h, w,d)

        features = []
        for i in range(len(downsample_factors)):

            features += [
                torch.nn.Conv3d(current_fmaps,fmaps,kernel_size=3,padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv3d(fmaps,fmaps,kernel_size=3,padding=1),
                torch.nn.BatchNorm3d(fmaps),
                torch.nn.ReLU(inplace=True),
                torch.nn.MaxPool3d(downsample_factors[i])
            ]

            current_fmaps = fmaps
            fmaps *= 2

            size = tuple(
                int(c/d)
                for c, d in zip(current_size, downsample_factors[i]))
            check = (
                s*d == c
                for s, d, c in zip(size, downsample_factors[i], current_size))
            assert all(check), \
                "Can not downsample %s by chosen downsample factor" % \
                (current_size,)
            current_size = size

        self.features = torch.nn.Sequential(*features)

        classifier = [
            torch.nn.Linear(current_size[0] *current_size[1]*current_size[2] *current_fmaps,4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096,4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096,output_classes)
        ]

        self.classifier = torch.nn.Sequential(*classifier)
    
    def forward(self, raw):

        # add a channel dimension to raw
        # shape = tuple(raw.shape)
        # raw = raw.reshape(shape[0], 1, shape[1], shape[2])
        
        # compute features
        f = self.features(raw)
        f = f.view(f.size(0), -1)
        
        # classify
        y = self.classifier(f)

        return y

# Loss Functions

We'll probably need to test some different loss functions. List some here:
Contrastive loss
cosine similarity
triplet loss



In [8]:
class ContrastiveLoss(nn.Module):
    "Contrastive loss function"

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2)
            + (label)
            * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )

        return loss_contrastive

In [9]:
input_size = (1, 64, 64, 5)
downsample_factors =[(2, 2, 1), (2, 2, 1), (2, 2, 1), (2, 2, 1)];
output_classes = 32

# create the model to train
model = Vgg3D(input_size, output_classes,  downsample_factors = downsample_factors)
model = model.to(device)

summary(model, input_size)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1        [-1, 12, 64, 64, 5]             336
       BatchNorm3d-2        [-1, 12, 64, 64, 5]              24
              ReLU-3        [-1, 12, 64, 64, 5]               0
            Conv3d-4        [-1, 12, 64, 64, 5]           3,900
       BatchNorm3d-5        [-1, 12, 64, 64, 5]              24
              ReLU-6        [-1, 12, 64, 64, 5]               0
         MaxPool3d-7        [-1, 12, 32, 32, 5]               0
            Conv3d-8        [-1, 24, 32, 32, 5]           7,800
       BatchNorm3d-9        [-1, 24, 32, 32, 5]              48
             ReLU-10        [-1, 24, 32, 32, 5]               0
           Conv3d-11        [-1, 24, 32, 32, 5]          15,576
      BatchNorm3d-12        [-1, 24, 32, 32, 5]              48
             ReLU-13        [-1, 24, 32, 32, 5]               0
        MaxPool3d-14        [-1, 24, 16

In [10]:
#Training length
epochs = 100

#loss_function = torch.nn.BCELoss()
loss_function = torch.nn.CosineEmbeddingLoss()
#loss_function = ContrastiveLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0005)

# Training Test

# Implementing the Siamese Network

The above training is just to test if the VGG model works for 3D data. Here, the training will take two pairs of images and calculate the loss from both pairs of images.

In [11]:
#%tensorboard --logdir models
logger = SummaryWriter()
#%tensorboard --logdir runs

In [None]:
from tqdm import tqdm

def train(tb_logger = None, log_image_interval = 10):
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        
    loss=[] 
    counter=[]
    with gp.build(pipeline_paired), gp.build(pipeline_unpaired):
        for epoch in tqdm(range(epochs)):
            epoch_loss = 0
            epoch_loss_pos = 0
            epoch_loss_neg = 0
            for x in range(500):
                unpaired = pipeline_unpaired.request_batch(request)
                yu = -1 #zero if using contrastive loss, -1 if using cosine similarity
                
                paired = pipeline_paired.request_batch(request)
                yp = 1
                
                unpaired1 = unpaired[raw].data
                unpaired2 = unpaired[raw_shift].data
                paired1 = paired[raw].data
                paired2 = paired[raw_shift].data
                
                unpaired1 = np.reshape(unpaired1, (batch_size,64, 64, 5))
                unpaired2 = np.reshape(unpaired2, (batch_size,64, 64, 5))
                paired1 = np.reshape(paired1, (batch_size,64, 64, 5))
                paired2 = np.reshape(paired2, (batch_size,64, 64, 5))
                
                # unpaired1 = np.reshape(unpaired1, (batch_size,16, 16, 5))
                # unpaired2 = np.reshape(unpaired2, (batch_size,16, 16, 5))
                # paired1 = np.reshape(paired1, (batch_size,16, 16, 5))
                # paired2 = np.reshape(paired2, (batch_size,16, 16, 5))
                
                unpaired1 = np.expand_dims(unpaired1, axis =1)
                unpaired2 = np.expand_dims(unpaired2, axis=1)
                paired1 = np.expand_dims(paired1, axis =1)
                paired2 = np.expand_dims(paired2, axis=1)

                unpaired1 = torch.from_numpy(unpaired1).to(device).float()
                unpaired2 = torch.from_numpy(unpaired2).to(device).float()
                yu = torch.from_numpy(np.array([yu])).to(device).float()
                
                paired1 = torch.from_numpy(paired1).to(device).float()
                paired2 = torch.from_numpy(paired2).to(device).float() 
                yp = torch.from_numpy(np.array([yp])).to(device).float()

                optimizer.zero_grad()
                
                predp1 = model(paired1)
                predp2 = model(paired2)
                predu1 = model(unpaired1)
                predu2 = model(unpaired2)
                #print(model(unpaired1).shape)
                #print(predp1.shape)

                #loss = loss_function(pred, y)
                
                #print(predp1.shape)
                if x%10==0:
                    print(f'step:{x}')

                loss_contrastivep = loss_function(predp1,predp2,yp)
                loss_contrastiveu = loss_function(predu1,predu2,yu)

                loss_contrastivep.backward()
                loss_contrastiveu.backward()
                optimizer.step()    
                epoch_loss_pos += loss_contrastivep
                epoch_loss_neg += loss_contrastiveu
                epoch_loss += loss_contrastivep + loss_contrastiveu
                
                
                
                if tb_logger is not None:
                    step = epoch * 10 + x
                    tb_logger.add_scalar(
                        tag="positive_loss", scalar_value=loss_contrastivep.item(), global_step=step
                    )
                    tb_logger.add_scalar(
                        tag="negative_loss", scalar_value=loss_contrastiveu.item(), global_step=step
                    )
                    tb_logger.add_scalar(
                        tag="total_loss", scalar_value=(loss_contrastivep.item()+loss_contrastiveu.item()), global_step = step
                    )
                    # check if we log images in this iteration
                    # if step % log_image_interval == 0:
                    #     tb_logger.add_images(
                    #         tag="in_unpaired1", img_tensor=unpaired1.to("cpu"), global_step=step
                    #     )
                    #     tb_logger.add_images(
                    #         tag="in_unpaired2", img_tensor=unpaired2.to("cpu"), global_step=step
                    #     )
                    #     tb_logger.add_images(
                    #         tag="in_paired1", img_tensor=paired1.to("cpu"), global_step=step
                    #     )
                    #     tb_logger.add_images(
                    #         tag="in_paired2", img_tensor=paired2.to("cpu"), global_step=step
                    #     )


            print(f"epoch {epoch}, total_loss = {epoch_loss}, positive_loss={epoch_loss_pos}, negative_loss={epoch_loss_neg}")
            
            if(epoch % 2 == 0):
                baseP = '/mnt/shared/celltracking/modelstates/'
                machine = 'rerun_32output_noaugment2'
                e = epoch 
                saveP = (baseP+'/'+machine+'/'+str(f'epoch_{e}'))
                torch.save(model.state_dict(), saveP)

    return model

model = train(tb_logger = logger)

  0%|          | 0/100 [00:00<?, ?it/s]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 0, total_loss = 267.6469421386719, positive_loss=64.4326171875, negative_loss=203.21424865722656


  1%|          | 1/100 [03:27<5:43:02, 207.90s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


  2%|▏         | 2/100 [06:50<5:34:35, 204.86s/it]

epoch 1, total_loss = 206.3771514892578, positive_loss=49.03589630126953, negative_loss=157.34132385253906
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 2, total_loss = 182.1725616455078, positive_loss=34.91307830810547, negative_loss=147.2594757080078


  3%|▎         | 3/100 [10:16<5:32:18, 205.55s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


  4%|▍         | 4/100 [13:39<5:27:06, 204.44s/it]

epoch 3, total_loss = 178.13720703125, positive_loss=32.917327880859375, negative_loss=145.21990966796875
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 4, total_loss = 155.87277221679688, positive_loss=36.99220657348633, negative_loss=118.88065338134766


  5%|▌         | 5/100 [17:05<5:24:32, 204.98s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


  6%|▌         | 6/100 [20:30<5:20:50, 204.79s/it]

epoch 5, total_loss = 159.43455505371094, positive_loss=40.7009162902832, negative_loss=118.73365020751953
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 6, total_loss = 146.82620239257812, positive_loss=29.079565048217773, negative_loss=117.74665832519531


  7%|▋         | 7/100 [23:54<5:17:10, 204.63s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


  8%|▊         | 8/100 [27:14<5:11:27, 203.13s/it]

epoch 7, total_loss = 143.37930297851562, positive_loss=30.719377517700195, negative_loss=112.66000366210938
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 8, total_loss = 138.49710083007812, positive_loss=27.233919143676758, negative_loss=111.26311492919922


  9%|▉         | 9/100 [30:36<5:07:36, 202.82s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 10%|█         | 10/100 [33:57<5:03:11, 202.13s/it]

epoch 9, total_loss = 127.3759994506836, positive_loss=28.70534324645996, negative_loss=98.6706314086914
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 10, total_loss = 115.88424682617188, positive_loss=24.79194450378418, negative_loss=91.09235382080078


 11%|█         | 11/100 [37:20<5:00:37, 202.66s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 12%|█▏        | 12/100 [40:42<4:56:42, 202.30s/it]

epoch 11, total_loss = 117.07051086425781, positive_loss=27.531469345092773, negative_loss=89.53900146484375
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 12, total_loss = 118.53730773925781, positive_loss=23.237375259399414, negative_loss=95.2999267578125


 13%|█▎        | 13/100 [44:04<4:53:10, 202.19s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 14%|█▍        | 14/100 [47:22<4:47:56, 200.89s/it]

epoch 13, total_loss = 112.61827850341797, positive_loss=25.57949447631836, negative_loss=87.03882598876953
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 14, total_loss = 116.9760513305664, positive_loss=23.34086799621582, negative_loss=93.63517761230469


 15%|█▌        | 15/100 [50:46<4:46:14, 202.05s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 16%|█▌        | 16/100 [54:05<4:41:34, 201.13s/it]

epoch 15, total_loss = 118.84037780761719, positive_loss=24.382349014282227, negative_loss=94.45803833007812
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 16, total_loss = 117.77302551269531, positive_loss=22.131269454956055, negative_loss=95.64173889160156


 17%|█▋        | 17/100 [57:23<4:36:54, 200.18s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 18%|█▊        | 18/100 [1:00:39<4:31:40, 198.79s/it]

epoch 17, total_loss = 120.33572387695312, positive_loss=24.288602828979492, negative_loss=96.04716491699219
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 18, total_loss = 109.9270248413086, positive_loss=22.008520126342773, negative_loss=87.91846466064453


 19%|█▉        | 19/100 [1:04:08<4:32:32, 201.88s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 20%|██        | 20/100 [1:07:35<4:31:22, 203.54s/it]

epoch 19, total_loss = 111.98565673828125, positive_loss=18.914508819580078, negative_loss=93.0711669921875
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 20, total_loss = 112.0408706665039, positive_loss=22.230113983154297, negative_loss=89.8106918334961


 21%|██        | 21/100 [1:11:00<4:28:15, 203.74s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460


Exception in getPaired while processing request
	raw: ROI: [0:1, 0:5, 0:64, 0:64] (1, 5, 64, 64), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False
	raw_shift: ROI: [0:1, 0:5, 0:64, 0:64] (1, 5, 64, 64), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False
 
Batch returned so far:
None
Traceback (most recent call last):
  File "/home/bourquea/miniconda3/envs/celltracking/lib/python3.9/site-packages/gunpowder/nodes/batch_provider.py", line 187, in request_batch
    batch = self.provide(upstream_request)
  File "/home/bourquea/miniconda3/envs/celltracking/lib/python3.9/site-packages/gunpowder/nodes/hdf5like_source_base.py", line 108, in provide
    self.__read(data_file, self.datasets[array_key], dataset_roi),
  File "/home/bourquea/miniconda3/envs/celltracking/lib/python3.9/site-packages/gunpowder/nodes/hdf5like_source_base.py", line 189, in __read
    c = len(data_file[ds_name].shape) - self.ndims
  File "

step:470


Traceback (most recent call last):
  File "/home/bourquea/miniconda3/envs/celltracking/lib/python3.9/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/home/bourquea/miniconda3/envs/celltracking/lib/python3.9/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'subgraph_view.<locals>.reverse_edge'
Traceback (most recent call last):
  File "/home/bourquea/miniconda3/envs/celltracking/lib/python3.9/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/home/bourquea/miniconda3/envs/celltracking/lib/python3.9/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'subgraph_view.<locals>.reverse_edge'


step:480
step:490


 22%|██▏       | 22/100 [1:14:30<4:27:30, 205.77s/it]

epoch 21, total_loss = 109.78729248046875, positive_loss=18.99485969543457, negative_loss=90.79242706298828
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 22, total_loss = 102.9512710571289, positive_loss=17.873737335205078, negative_loss=85.07759094238281


 23%|██▎       | 23/100 [1:18:20<4:33:22, 213.02s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 24%|██▍       | 24/100 [1:21:43<4:26:06, 210.09s/it]

epoch 23, total_loss = 106.94430541992188, positive_loss=19.14613151550293, negative_loss=87.79814910888672
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 24, total_loss = 102.54338073730469, positive_loss=17.048311233520508, negative_loss=85.49504852294922


 25%|██▌       | 25/100 [1:25:14<4:22:50, 210.27s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 26%|██▌       | 26/100 [1:28:43<4:18:53, 209.91s/it]

epoch 25, total_loss = 93.27140808105469, positive_loss=17.118253707885742, negative_loss=76.15310668945312
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 26, total_loss = 104.63494873046875, positive_loss=23.88242530822754, negative_loss=80.75249481201172


 27%|██▋       | 27/100 [1:32:11<4:14:32, 209.21s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 28%|██▊       | 28/100 [1:35:36<4:09:33, 207.97s/it]

epoch 27, total_loss = 97.82340240478516, positive_loss=20.963653564453125, negative_loss=76.85966491699219
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 28, total_loss = 96.49504089355469, positive_loss=22.41478729248047, negative_loss=74.08024597167969


 29%|██▉       | 29/100 [1:39:08<4:07:28, 209.14s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 30%|███       | 30/100 [1:42:22<3:58:49, 204.71s/it]

epoch 29, total_loss = 97.07852172851562, positive_loss=19.40599250793457, negative_loss=77.67253875732422
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 30, total_loss = 90.72106170654297, positive_loss=21.35387420654297, negative_loss=69.36717224121094


 31%|███       | 31/100 [1:45:50<3:56:26, 205.60s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 32%|███▏      | 32/100 [1:49:05<3:49:40, 202.66s/it]

epoch 31, total_loss = 92.47034454345703, positive_loss=19.54529571533203, negative_loss=72.92510986328125
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 32, total_loss = 94.82070922851562, positive_loss=19.00498390197754, negative_loss=75.8156509399414


 33%|███▎      | 33/100 [1:52:25<3:45:15, 201.72s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 34%|███▍      | 34/100 [1:55:39<3:39:26, 199.50s/it]

epoch 33, total_loss = 92.19413757324219, positive_loss=18.900753021240234, negative_loss=73.2933349609375
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 34, total_loss = 94.67361450195312, positive_loss=20.069244384765625, negative_loss=74.60440063476562


 35%|███▌      | 35/100 [1:59:01<3:36:53, 200.20s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 36%|███▌      | 36/100 [2:02:14<3:31:20, 198.14s/it]

epoch 35, total_loss = 85.47857666015625, positive_loss=17.80783462524414, negative_loss=67.67076873779297
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 36, total_loss = 96.94127655029297, positive_loss=19.501110076904297, negative_loss=77.44011688232422


 37%|███▋      | 37/100 [2:05:33<3:28:09, 198.25s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 38%|███▊      | 38/100 [2:09:02<3:28:18, 201.58s/it]

epoch 37, total_loss = 86.97100830078125, positive_loss=17.269248962402344, negative_loss=69.7017593383789
step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490
epoch 38, total_loss = 92.40898132324219, positive_loss=18.381587982177734, negative_loss=74.02738189697266


 39%|███▉      | 39/100 [2:12:27<3:25:58, 202.61s/it]

step:0
step:10
step:20
step:30
step:40
step:50
step:60
step:70
step:80
step:90
step:100
step:110
step:120
step:130
step:140
step:150
step:160
step:170
step:180
step:190
step:200
step:210
step:220
step:230
step:240
step:250
step:260
step:270
step:280
step:290
step:300
step:310
step:320
step:330
step:340
step:350
step:360
step:370
step:380
step:390
step:400
step:410
step:420
step:430
step:440
step:450
step:460
step:470
step:480
step:490


 40%|████      | 40/100 [2:15:39<3:19:17, 199.30s/it]

epoch 39, total_loss = 89.88488006591797, positive_loss=16.4658203125, negative_loss=73.41902923583984
step:0
step:10
step:20
step:30
step:40
step:50
