In [28]:
import os
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.nn.functional as F
import torch

from math import floor
import pandas as pd
import scipy.io as sio
import csv
from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix,f1_score

import torch.nn as nn
from typing import List, Optional, Tuple, Union
import torch.distributed as dist
import torchvision.transforms as T
import pytorch_lightning as pl

In [29]:
def load_hasc_ds(path, window, mode='train'):

    X, lbl = extract_windows_hasc(path, window, 5)

    if mode == "all":
        return X, lbl

    train_size = floor(0.8 * X.shape[0])
    print(train_size)
    if mode =="train":
        trainx = X[0:train_size]
        trainlbl =lbl[0:train_size]
        idx = np.arange(trainx.shape[0])
        np.random.shuffle(idx)
        trainx = trainx[idx,]
        trainlbl = trainlbl[idx]
        print('train samples : ', train_size)
        return trainx, trainlbl

    else:
        testx = X[train_size:]
        testlbl = lbl[train_size:]
        print('test shape {} and number of change points {} '.format(testx.shape, len(np.where(testlbl>0)[0])))

        return testx, testlbl



def extract_windows_hasc(path, window_size, step):

    dataset = sio.loadmat(path+"hasc.mat")
    window_size
    windows = []
    lbl = []
    first = True
    num_cp = 0
    x = np.array(dataset['Y'])
    cp = np.array(dataset['L'])

    ts = np.sqrt(np.power(x[:, 0], 2) + np.power(x[:, 1], 2) + np.power(x[:, 2], 2))
    for i in range(0, ts.shape[0] - window_size, windows_size // 5):
        windows.append(np.array(ts[i:i + window_size]))
        is_cp = np.where(cp[i:i + window_size] == 1)[0]
        if is_cp.size == 0:
            is_cp = [0]
        else:
            num_cp += 1
        lbl.append(is_cp[0])

    print("number of samples : {} /  number of samples with change point : {}".format(len(windows), num_cp))
    windows = np.array(windows)

    return windows, np.array(lbl)

def load_usc_ds(path, window, mode='train'):
    X, lbl = extract_windows_usc(path, window, mode)

    if mode == "all":
        return X, lbl
    train_size = int(floor(0.8 * X.shape[0]))
    if mode == "train":
        trainx = X[0:train_size]
        trainlbl = lbl[0:train_size]
        idx = np.arange(trainx.shape[0])
        np.random.shuffle(idx)
        trainx = trainx[idx,]
        trainlbl = trainlbl[idx]
        print('train samples : ', train_size)
        return trainx, trainlbl

    else:
        testx = X[train_size:]
        testlbl = lbl[train_size:]
        print('test shape {} and number of change points {} '.format(testx.shape, len(np.where(testlbl > 0)[0])))

        return testx, testlbl


def extract_windows_usc(path, window_size, mode="train"):
    windows = []
    lbl = []
    dataset = sio.loadmat(path+"usc.mat")

    ts = np.array(dataset['Y'])
    ts = ts[:,0]
    cp = np.array(dataset['L'])
    cp = cp[:,0]

    num_cp = 0

    for i in range(0, ts.shape[0] - window_size, window_size // 5):
        windows.append(np.array(ts[i:i + window_size]))
        is_cp = np.where(cp[i:i + window_size] == 1)[0]
        if is_cp.size == 0:
            is_cp = [0]
        else:
            num_cp += 1
        lbl.append(is_cp[0])

    print("number of samples : {} /  number of samples with change point : {}".format(len(windows), num_cp))
    windows = np.array(windows)

    return windows, np.array(lbl)

def load_dataset(path, ds_name, win, bs, mode="train"):
    if ds_name == 'HASC':
        trainx, trainlbl = load_hasc_ds(path, window = win, mode=mode)
    elif ds_name == "USC":
        trainx, trainlbl = load_usc_ds(path, window= win, mode=mode)
    else:
        raise ValueError("Undefined Dataset.")

    trainlbl = trainlbl.reshape((trainlbl.shape[0], 1))
    print(trainx.shape, trainlbl.shape)
    dataset = np.concatenate((trainlbl, trainx), 1)

    print("dataset shape : ", dataset.shape)
    if mode == "test":
        return dataset

    train_ds = TensorDataset(torch.from_numpy(dataset))
    return train_ds

def load_future(path, ds_name, win, bs, mode="train"):
    if ds_name == 'HASC':
        trainx, trainlbl = load_hasc_ds(path, window = 2 * win, mode=mode)
    elif ds_name == "USC":
        trainx, trainlbl = load_usc_ds(path, window=2 * win, mode=mode)
    else:
        raise ValueError("Undefined Dataset.")

    trainlbl = trainlbl.reshape((trainlbl.shape[0], 1))
    print(trainx.shape, trainlbl.shape)
    dataset = np.concatenate((trainlbl, trainx), 1)

    print("dataset shape : ", dataset.shape)
    if mode == "test":
        return dataset

    train_ds = TensorDataset(torch.from_numpy(dataset))
    return train_ds

In [4]:
def estimate_CPs(sim, gt, name, train_name, metric='cosine', threshold=0.5):

    est_cp = np.zeros(sim.shape[0])
    est_cp[np.where(sim < threshold)[0]] = 1
    tn, fp, fn, tp = confusion_matrix(gt, est_cp).ravel()
    f1 = f1_score(gt, est_cp)

    gt_id = np.where(gt == 1)[0]
    print("tn {}, fp {}, fn {}, tp {} ----- f1-score {}".format(tn, fp, fn, tp, f1))

    i = 1
    pos, seq_tp, seq_fn, seq_fp = 0, 0, 0, 0

    while i < gt.shape[0]:
        if gt[i] == 1:
            pos += 1
            j = i
            while gt[i] == 1:
                i += 1

            if np.sum(est_cp[j:i]) > 0:
                seq_tp += 1
                est_cp[j:i] = 0
            else:
                seq_fn += 1

        i += 1

    seq_fp = np.where(np.diff(est_cp) == 1)[0].shape[0]
    seq_f1 = (2 * seq_tp) / (2 * seq_tp + seq_fn + seq_fp)

    print("SEQ : Pos {}, fp {}, fn {}, tp {} ----- f1-score {}".format(pos, seq_fp, seq_fn, seq_tp, seq_f1))
    result = "tn, {}, fp, {}, fn, {}, tp, {}, f1-score, {}, Pos, {}, seqfp, {}, seqfn, {}, seqtp, {}, seqf1, {}\n".format(tn, fp, fn, tp, f1, pos, seq_fp, seq_fn, seq_tp, seq_f1)
    return result

In [5]:
DS_NAME = 'USC'
DATA_PATH = './data/'
OUTPUT_PATH = os.path.join('./output/', DS_NAME)
MODEL_PATH = os.path.join('./output/', "model")
LOSS = 'nce'
SIM = 'cosine'
GPU = 0

WIN = 100
CODE_SIZE = 10
BATCH_SIZE = 32
EPOCHS = 40
LR = 1e-4
TEMP = 0.5
TAU = 0.1
BETA = 1
EVALFREQ = 25
decay_steps = 1000


train_name = "CP2_model_" + DS_NAME + "_T" + str(TEMP) + "_WIN" + str(WIN) + \
             "_BS" + str(BATCH_SIZE) + "_CS" + str(CODE_SIZE) + "_lr" + str(LR) + \
             "_LOSS" + LOSS +  "_SIM" + SIM + "_TAU" + str(TAU) + "_BETA" + str(BETA)
print("------------------------------------>>> " + train_name)

train_ds = load_dataset(DATA_PATH, DS_NAME, WIN, BATCH_SIZE, mode = "train")
test_ds = load_dataset(DATA_PATH, DS_NAME, WIN, BATCH_SIZE, mode = "test")


------------------------------------>>> CP2_model_USC_T0.5_WIN100_BS32_CS10_lr0.0001_LOSSnce_SIMcosine_TAU0.1_BETA1
number of samples : 4677 /  number of samples with change point : 175
train samples :  3741
(3741, 100) (3741, 1)
dataset shape :  (3741, 101)
number of samples : 4677 /  number of samples with change point : 175
test shape (936, 100) and number of change points 35 
(936, 100) (936, 1)
dataset shape :  (936, 101)


In [6]:
class ProjectionHead(nn.Module):
    def __init__(
        self, 
        blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
    ):
        super(ProjectionHead, self).__init__()

        layers = []
        for input_dim, output_dim, batch_norm, non_linearity in blocks:
            use_bias = not bool(batch_norm)
            layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
            if batch_norm:
                layers.append(batch_norm)
            if non_linearity:
                layers.append(non_linearity)
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.layers(x)
        
class BarlowTwinsProjectionHead(ProjectionHead):
    def __init__(self,
                 input_dim: int = 2048,
                 hidden_dim: int = 8192,
                 output_dim: int = 8192):
        super(BarlowTwinsProjectionHead, self).__init__([
            (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
            (hidden_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
            (hidden_dim, output_dim, None, None),
        ])

In [7]:
class BarlowTwins(nn.Module):
    def __init__(self,
                 num_ftrs: int = 10,
                 proj_hidden_dim: int = 20,
                 out_dim: int = 40):

        super(BarlowTwins, self).__init__()

        self.num_ftrs = num_ftrs
        self.proj_hidden_dim = proj_hidden_dim
        self.out_dim = out_dim

        self.projection_mlp = BarlowTwinsProjectionHead(
            num_ftrs,
            proj_hidden_dim,
            out_dim
        )


    def forward(self,
                x0: torch.Tensor,
                x1: torch.Tensor = None,
                return_features: bool = False):
        
        f0 = x0.flatten(start_dim=1)
        out0 = self.projection_mlp(f0)

        if return_features:
            out0 = (out0, f0)

        if x1 is None:
            return out0

        f1 = x1.flatten(start_dim=1)
        out1 = self.projection_mlp(f1)

        if return_features:
            out1 = (out1, f1)

        return out0, out1

In [8]:
class BarlowTwinsLoss(torch.nn.Module):
    def __init__(
        self, 
        lambda_param: float = 5e-3, 
        gather_distributed : bool = False
    ):

        super(BarlowTwinsLoss, self).__init__()
        self.lambda_param = lambda_param
        self.gather_distributed = gather_distributed

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:

        device = z_a.device

        z_a_norm = (z_a - z_a.mean(0)) / z_a.std(0)
        z_b_norm = (z_b - z_b.mean(0)) / z_b.std(0)

        N = z_a.size(0)
        D = z_a.size(1)

        c = torch.mm(z_a_norm.T, z_b_norm) / N 

        if self.gather_distributed and dist.is_initialized():
            world_size = dist.get_world_size()
            if world_size > 1:
                c = c / world_size
                dist.all_reduce(c)

        c_diff = (c - torch.eye(D, device=device)).pow(2)
        c_diff[~torch.eye(D, dtype=bool)] *= self.lambda_param
        loss = c_diff.sum()

        return loss

In [22]:
class ResidualBlock(nn.Module):
    
    def __init__(self,
                 in_channels,
                 dilation_rate: int,
                 nb_filters: int,
                 kernel_size: int,
                 dropout_rate: float = 0, 
                 use_batch_norm: bool = False,
                 use_layer_norm: bool = False,
                 use_weight_norm: bool = False, 
                 training: bool = True):
        """Defines the residual block for the WaveNet TCN
        Args:
            dilation_rate: The dilation power of 2 we are using for this residual block
            nb_filters: The number of convolutional filters to use in this block
            kernel_size: The size of the convolutional kernel
            activation: The final activation used in o = Activation(x + F(x))
            dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
        """

        self.dilation_rate = dilation_rate
        self.nb_filters = nb_filters
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        # for causal padding
        self.padding = (self.kernel_size - 1) * self.dilation_rate
        
        self.use_batch_norm = use_batch_norm
        self.use_layer_norm = use_layer_norm
        self.use_weight_norm = use_weight_norm
        
        self.training = training

        super(ResidualBlock, self).__init__()
        
        self.conv_1 = nn.Conv1d(in_channels, self.nb_filters, self.kernel_size, 
                                padding=0, dilation=self.dilation_rate)        
        if self.use_weight_norm:
            weight_norm(self.conv_1) 
        self.bn_1 = nn.BatchNorm1d(self.nb_filters)
        self.ln_1 = nn.LayerNorm(self.nb_filters)              
        self.relu_1 = nn.ReLU()

        self.conv_2 = nn.Conv1d(self.nb_filters, self.nb_filters, self.kernel_size, 
                                padding=0, dilation=self.dilation_rate)        
        if self.use_weight_norm:
            weight_norm(self.conv_1)    
        self.bn_2 = nn.BatchNorm1d(self.nb_filters)
        self.ln_2 = nn.LayerNorm(self.nb_filters)              
        self.relu_2 = nn.ReLU()        
        
        self.conv_block = nn.Sequential()
        self.downsample = nn.Conv1d(in_channels, self.nb_filters, kernel_size=1) if in_channels != self.nb_filters else nn.Identity()
        
        self.relu = nn.ReLU()  
                
        self.init_weights()
        
        
    def init_weights(self):
        # in the realization, they use random normal initialization
        torch.nn.init.normal_(self.conv_1.weight, mean=0, std=0.05)
        torch.nn.init.zeros_(self.conv_1.bias)            
        
        torch.nn.init.normal_(self.conv_2.weight, mean=0, std=0.05)
        torch.nn.init.zeros_(self.conv_2.bias)            
        
        if isinstance(self.downsample, nn.Conv1d):         
            torch.nn.init.normal_(self.downsample.weight, mean=0, std=0.05)
            torch.nn.init.zeros_(self.downsample.bias)                    
            
    def forward(self, inp):
        # inp batch, channels, time
        ######################
        # do causal padding        
        out = F.pad(inp, (self.padding, 0))
        out = self.conv_1(out)
        
        if self.use_batch_norm:
            out = self.bn_1(out)
        elif self.use_layer_norm:
            out = self.ln_1(out)        
        out = self.relu_1(out)
        
        # spatial dropout
        out = out.permute(0, 2, 1)   # convert to [batch, time, channels]
        out = F.dropout2d(out, self.dropout_rate, training=self.training)        
        out = out.permute(0, 2, 1)   # back to [batch, channels, time]    
        
        #######################
        # do causal padding
        out = F.pad(out, (self.padding, 0))
        out = self.conv_2(out)
        if self.use_batch_norm:
            out = self.bn_2(out)
        elif self.use_layer_norm:
            out = self.ln_2(out)
        out = self.relu_2(out)            
        out = self.relu_2(out)    
        # spatial dropout
        # out batch, channels, time 
        
        out = out.permute(0, 2, 1)   # convert to [batch, time, channels]
        out = F.dropout2d(out, self.dropout_rate, training=self.training)
        out = out.permute(0, 2, 1)   # back to [batch, channels, time]            
        
        #######################        
        skip_out = self.downsample(inp)
        #######################
        res = self.relu(out + skip_out)
        return res, skip_out
    
# only causal padding
# only return sequence = True
    
class TCN(nn.Module):        
    def __init__(self,
                 in_channels=1,
                 nb_filters=64,
                 kernel_size=3,
                 nb_stacks=1,
                 dilations=(1, 2, 4, 8, 16, 32),
                 use_skip_connections=True,
                 dropout_rate=0.0, 
                 use_batch_norm: bool = False,
                 use_layer_norm: bool = False, 
                 use_weight_norm: bool = False):

        super(TCN, self).__init__()
        
        self.dropout_rate = dropout_rate
        self.use_skip_connections = use_skip_connections
        self.dilations = dilations
        self.nb_stacks = nb_stacks
        self.kernel_size = kernel_size
        self.nb_filters = nb_filters
        
        self.use_batch_norm = use_batch_norm
        self.use_layer_norm = use_layer_norm
        self.use_weight_norm = use_weight_norm
        self.in_channels = in_channels
        if self.use_batch_norm + self.use_layer_norm + self.use_weight_norm > 1:
            raise ValueError('Only one normalization can be specified at once.')        
        
        self.residual_blocks = []        
        res_block_filters = 0
        for s in range(self.nb_stacks):
            for i, d in enumerate(self.dilations):
                in_channels = self.in_channels if i + s == 0 else res_block_filters                
                res_block_filters = self.nb_filters[i] if isinstance(self.nb_filters, list) else self.nb_filters
                self.residual_blocks.append(ResidualBlock(in_channels=in_channels, 
                                                          dilation_rate=d,
                                                          nb_filters=res_block_filters,
                                                          kernel_size=self.kernel_size,
                                                          dropout_rate=self.dropout_rate, 
                                                          use_batch_norm=self.use_batch_norm,
                                                          use_layer_norm=self.use_layer_norm,
                                                          use_weight_norm=self.use_weight_norm))

        
        self.residual_blocks = nn.ModuleList(self.residual_blocks)
                                            
    def forward(self, inp):
        out = inp
        for layer in self.residual_blocks:
            out, skip_out = layer(out)
        if self.use_skip_connections:
            out = out + skip_out
        return out

########################### model #########################################
class Encoder(nn.Module):
    def __init__(self, c_in=1, nb_filters=64, kernel_size=4, 
                 dilations=[1,2,4,8], nb_stacks=2, n_steps=50, code_size=10, seq_len=100):       
        super(Encoder, self).__init__()        
        
        self.tcn_layer = TCN(in_channels=c_in, nb_filters=nb_filters, 
                             nb_stacks=nb_stacks, dilations=dilations, use_skip_connections=True, dropout_rate=0)
        
        self.fc1 = nn.Linear(nb_filters * seq_len, 2 * n_steps)  
        self.fc2 = nn.Linear(2 * n_steps, n_steps)    
        self.output_layer = nn.Linear(n_steps, code_size)           
        self.relu = nn.ReLU()
        
    def forward(self, x):
        out = x
        if len(out.shape) == 2:
            out = out.unsqueeze(1)
        out = self.tcn_layer(out)     
        out = out.flatten(1, 2)         
        out = self.relu(self.fc1(out)) 
        out = self.relu(self.fc2(out)) 
        out = self.output_layer(out)
        return out
    
########################### loss #########################################
def _cosine_simililarity_dim2(x, y):
    cos = nn.CosineSimilarity(dim=2, eps=1e-6)
    v = cos(x.unsqueeze(1), y.unsqueeze(0))
    return v    


################## PL wrapper ###############################################
class BTSCP_model(pl.LightningModule):
    def __init__(
        self,
        encoder: nn.Module,
        model: nn.Module,     
        train_dataset: Dataset, 
        test_dataset: Dataset, 
        batch_size: int = 64,        
        num_workers: int = 2,        
        temperature: float = 0.1, 
        lr: float = 1e-4,
        decay_steps: int = 1000, 
        window_1: int = 100,
        window_2: int = 100
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        
        self.batch_size = batch_size
        self.num_workers = num_workers        
        
        self.temperature = temperature
        
        self.lr = lr
        self.decay_steps = decay_steps   
        
        self.window = window_1
        self.window_1 = window_1
        self.window_2 = window_2
        self.loss = BarlowTwinsLoss()

    def forward(self, inputs1: torch.Tensor, inputs2=None) -> torch.Tensor:
        inp1 = self.encoder(inputs1)
        inp2 = self.encoder(inputs2)
        return self.model(inp1, inp2)
    
    def pred(self, inputs):
        return self.encoder(inputs)
    
    def training_step(self, batch1, batch2):
       
        emb1, emb2 = self.forward(batch1.float(), batch2.float())

        loss = self.loss.forward(emb1, emb2)

        return loss
        
    def validation_step(self, batch1, batch2):

        emb1, emb2 = self.forward(batch1.float(), batch2.float())

        loss = self.loss.forward(emb1, emb2)

        return loss


    def configure_optimizers(self) -> torch.optim.Optimizer:
        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        return opt

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers
        )
    def get_enc(self):
        return self.encoder
    
def _cosine_simililarity_dim1(x, y):
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    v = cos(x, y)
    return v   

In [13]:
def random_crop(batch, WIN):
    tmp = batch[0][:,1:]
    l = torch.randint(0, WIN+1, (1,))
    l = max(l, int(0.8 * tmp.shape[1]))
    pos = torch.randint(0, WIN // 10, (1,))
    tmp = tmp[:, pos:l+1]
    tmp = F.pad(input=tmp, pad=(WIN - tmp.shape[1], 0), mode='constant', value=0)
    return tmp

In [14]:
def random_noise(batch, sigma=1):
    tmp = batch[0][:,1:]
    noise = torch.normal(0, sigma, size=tmp.shape)
    tmp += noise
    return tmp

In [15]:
enc = Encoder(code_size = CODE_SIZE, seq_len = WIN)
model = BarlowTwins(CODE_SIZE, 2*CODE_SIZE, 4*CODE_SIZE)

model = BTSCP_model(enc, model, train_ds, test_ds, batch_size=BATCH_SIZE, temperature=TEMP, lr=LR, decay_steps=decay_steps, window_1=WIN)
optimizer = model.configure_optimizers()

In [13]:
train_loader = model.train_dataloader()
val_loader = model.val_dataloader()

train_losses = []
val_losses = []

for epoch in tqdm(range(EPOCHS)):
    
    iteration = 0
    train_losses_iters = []

    
    for index, batch in enumerate(train_loader):
        aug1, aug2 = random_crop(batch, WIN), random_crop(batch, WIN)
        loss = model.training_step(aug1, aug2)
        train_losses_iters.append(float(loss))

        print("train_losses_iters", train_losses_iters[-1])

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        iteration += 1

        if not iteration % 50:
            model.eval()
            vall = []
            with torch.no_grad():
                for c, b in enumerate(val_loader):
                    aug1, aug2 = random_crop([b], WIN), random_crop([b], WIN)
                    val_loss = model.validation_step(aug1, aug2)
                    vall.append(float(val_loss.detach()))
                    if c>10:
                        break
                print("val_loss", np.mean(vall))
                val_losses.append(np.mean(vall))
          
            
            model.train()
        
    print("epoch_train_loss", np.mean(train_losses_iters))  
    train_losses.append(np.mean(train_losses_iters))
    if len(val_losses):
        print("epoch_val_loss", np.mean(val_losses))




train_losses_iters 24.186826705932617
train_losses_iters 2.035172462463379
train_losses_iters 1.9536380767822266
train_losses_iters 25.047222137451172
train_losses_iters 2.090066909790039
train_losses_iters 2.0098931789398193
train_losses_iters 2.02724552154541
train_losses_iters 1.8167611360549927
train_losses_iters 2.1229002475738525
train_losses_iters 34.95899200439453
train_losses_iters 2.5307252407073975
train_losses_iters 27.999502182006836
train_losses_iters 12.878658294677734
train_losses_iters 2.2299301624298096
train_losses_iters 2.751626968383789
train_losses_iters 26.711463928222656
train_losses_iters 2.1358089447021484
train_losses_iters 2.435905933380127
train_losses_iters 1.854870319366455
train_losses_iters 3.178393602371216
train_losses_iters 3.9356865882873535
train_losses_iters 12.22683048248291
train_losses_iters 23.78879165649414
train_losses_iters 2.4218716621398926
train_losses_iters 1.478054165840149
train_losses_iters 3.116790294647217
train_losses_iters 1.7452

  2%|█                                           | 1/40 [00:42<27:51, 42.86s/it]

epoch_train_loss 8.545247483457256
epoch_val_loss 19.280677288770676
train_losses_iters 2.06050968170166
train_losses_iters 2.8866240978240967
train_losses_iters 30.311492919921875
train_losses_iters 2.037013530731201
train_losses_iters 2.2524216175079346
train_losses_iters 1.470069169998169
train_losses_iters 1.9341477155685425
train_losses_iters 2.325690269470215
train_losses_iters 2.1072514057159424
train_losses_iters 27.172317504882812
train_losses_iters 24.776836395263672
train_losses_iters 1.9360636472702026
train_losses_iters 21.571063995361328
train_losses_iters 2.3894832134246826
train_losses_iters 2.887561559677124
train_losses_iters 22.764022827148438
train_losses_iters 2.347937822341919
train_losses_iters 2.4910688400268555
train_losses_iters 21.184417724609375
train_losses_iters 1.5194916725158691
train_losses_iters 5.592536449432373
train_losses_iters 2.3087263107299805
train_losses_iters 4.802385330200195
train_losses_iters 3.462223529815674
train_losses_iters 18.9609928

  5%|██▏                                         | 2/40 [01:24<26:50, 42.37s/it]

train_losses_iters 3.4211182594299316
epoch_train_loss 8.440152110197605
epoch_val_loss 16.065648277600605
train_losses_iters 3.2537178993225098
train_losses_iters 3.3900046348571777
train_losses_iters 2.4049487113952637
train_losses_iters 1.87126886844635
train_losses_iters 3.011547088623047
train_losses_iters 2.5354323387145996
train_losses_iters 2.920548915863037
train_losses_iters 3.7344322204589844
train_losses_iters 2.4095873832702637
train_losses_iters 17.168779373168945
train_losses_iters 2.311999559402466
train_losses_iters 15.579673767089844
train_losses_iters 2.4015450477600098
train_losses_iters 24.7304744720459
train_losses_iters 15.92808723449707
train_losses_iters 9.90612506866455
train_losses_iters 2.7399702072143555
train_losses_iters 2.0201637744903564
train_losses_iters 13.13371753692627
train_losses_iters 2.258981704711914
train_losses_iters 11.269511222839355
train_losses_iters 2.2363929748535156
train_losses_iters 27.55014991760254
train_losses_iters 2.99994754791

  8%|███▎                                        | 3/40 [01:56<23:10, 37.59s/it]

epoch_train_loss 7.215987517283513
epoch_val_loss 13.978289299541048
train_losses_iters 9.061591148376465
train_losses_iters 11.86021900177002
train_losses_iters 11.599929809570312
train_losses_iters 2.9801719188690186
train_losses_iters 6.760810375213623
train_losses_iters 3.110388994216919
train_losses_iters 2.745880365371704
train_losses_iters 9.831634521484375
train_losses_iters 2.571290969848633
train_losses_iters 2.8855299949645996
train_losses_iters 2.49623966217041
train_losses_iters 2.6220664978027344
train_losses_iters 2.694453001022339
train_losses_iters 2.626347064971924
train_losses_iters 23.39451789855957
train_losses_iters 3.255992889404297
train_losses_iters 13.220993995666504
train_losses_iters 3.3510489463806152
train_losses_iters 2.7772321701049805
train_losses_iters 22.020488739013672
train_losses_iters 24.94925880432129
train_losses_iters 1.8689969778060913
train_losses_iters 23.54478645324707
train_losses_iters 3.4191553592681885
train_losses_iters 13.576861381530

 10%|████▍                                       | 4/40 [02:24<20:18, 33.86s/it]

epoch_train_loss 7.611511857081682
epoch_val_loss 13.351441577076912
train_losses_iters 1.8598392009735107
train_losses_iters 12.880752563476562
train_losses_iters 3.3306102752685547
train_losses_iters 2.933767318725586
train_losses_iters 2.997030019760132
train_losses_iters 15.032461166381836
train_losses_iters 3.1313891410827637
train_losses_iters 3.107300043106079
train_losses_iters 2.857573986053467
train_losses_iters 14.963922500610352
train_losses_iters 8.608190536499023
train_losses_iters 13.012918472290039
train_losses_iters 8.841434478759766
train_losses_iters 2.978405237197876
train_losses_iters 2.6690351963043213
train_losses_iters 3.6412394046783447
train_losses_iters 2.7409627437591553
train_losses_iters 6.533721446990967
train_losses_iters 10.990957260131836
train_losses_iters 21.09947967529297
train_losses_iters 3.3784584999084473
train_losses_iters 18.930801391601562
train_losses_iters 3.845104217529297
train_losses_iters 9.559860229492188
train_losses_iters 2.492373943

 12%|█████▌                                      | 5/40 [02:52<18:22, 31.51s/it]

train_losses_iters 6.469751358032227
epoch_train_loss 6.733407957941039
epoch_val_loss 14.168441670139632
train_losses_iters 8.429672241210938
train_losses_iters 2.9671802520751953
train_losses_iters 3.352919101715088
train_losses_iters 13.53722095489502
train_losses_iters 3.438016653060913
train_losses_iters 3.642367362976074
train_losses_iters 2.909149408340454
train_losses_iters 2.9081618785858154
train_losses_iters 2.8847551345825195
train_losses_iters 6.095270156860352
train_losses_iters 3.9322056770324707
train_losses_iters 9.002985954284668
train_losses_iters 3.2032065391540527
train_losses_iters 17.493497848510742
train_losses_iters 3.5741846561431885
train_losses_iters 3.4485397338867188
train_losses_iters 16.167034149169922
train_losses_iters 3.4428510665893555
train_losses_iters 2.9095613956451416
train_losses_iters 13.113504409790039
train_losses_iters 15.9176025390625
train_losses_iters 3.2698757648468018
train_losses_iters 2.975961208343506
train_losses_iters 3.0385041236

 15%|██████▌                                     | 6/40 [03:19<16:57, 29.91s/it]

epoch_train_loss 6.347778703412439
epoch_val_loss 14.556501852141487
train_losses_iters 6.501483917236328
train_losses_iters 3.6117634773254395
train_losses_iters 3.70906925201416
train_losses_iters 3.2263917922973633
train_losses_iters 3.2231650352478027
train_losses_iters 3.9014205932617188
train_losses_iters 2.5736773014068604
train_losses_iters 4.83477258682251
train_losses_iters 14.32259750366211
train_losses_iters 2.688373565673828
train_losses_iters 3.358596086502075
train_losses_iters 3.1694579124450684
train_losses_iters 4.1912713050842285
train_losses_iters 3.204782247543335
train_losses_iters 3.438539505004883
train_losses_iters 3.967707633972168
train_losses_iters 4.819936752319336
train_losses_iters 6.339662551879883
train_losses_iters 3.8797621726989746
train_losses_iters 2.5868301391601562
train_losses_iters 2.3716766834259033
train_losses_iters 2.694139242172241
train_losses_iters 3.425144672393799
train_losses_iters 2.6733689308166504
train_losses_iters 3.4630715847015

 18%|███████▋                                    | 7/40 [03:45<15:51, 28.84s/it]

epoch_train_loss 6.0211995585351925
epoch_val_loss 13.509552231147177
train_losses_iters 14.666553497314453
train_losses_iters 3.6088874340057373
train_losses_iters 2.7574973106384277
train_losses_iters 13.378677368164062
train_losses_iters 3.919072151184082
train_losses_iters 14.381632804870605
train_losses_iters 7.42237663269043
train_losses_iters 4.005427360534668
train_losses_iters 3.3271849155426025
train_losses_iters 11.40428352355957
train_losses_iters 12.8233060836792
train_losses_iters 2.807730197906494
train_losses_iters 13.60810661315918
train_losses_iters 3.1838443279266357
train_losses_iters 2.716235637664795
train_losses_iters 5.512539386749268
train_losses_iters 11.671655654907227
train_losses_iters 3.366244316101074
train_losses_iters 3.6097500324249268
train_losses_iters 11.776006698608398
train_losses_iters 3.457758665084839
train_losses_iters 2.776244878768921
train_losses_iters 3.026569366455078
train_losses_iters 12.75052547454834
train_losses_iters 12.594146728515

 20%|████████▊                                   | 8/40 [04:12<15:01, 28.18s/it]

train_losses_iters 7.110174655914307
epoch_train_loss 5.577170017438057
epoch_val_loss 13.930887820820013
train_losses_iters 3.235430955886841
train_losses_iters 3.006453514099121
train_losses_iters 3.573061227798462
train_losses_iters 3.941930055618286
train_losses_iters 3.224532127380371
train_losses_iters 3.3818485736846924
train_losses_iters 3.5840296745300293
train_losses_iters 3.0982398986816406
train_losses_iters 15.543049812316895
train_losses_iters 3.1931774616241455
train_losses_iters 4.234035491943359
train_losses_iters 5.0291571617126465
train_losses_iters 2.9674158096313477
train_losses_iters 11.380228996276855
train_losses_iters 2.9633309841156006
train_losses_iters 2.1205172538757324
train_losses_iters 3.928370475769043
train_losses_iters 3.697472333908081
train_losses_iters 15.162208557128906
train_losses_iters 10.473469734191895
train_losses_iters 2.1466636657714844
train_losses_iters 3.727750778198242
train_losses_iters 3.4001762866973877
train_losses_iters 3.69872212

 22%|█████████▉                                  | 9/40 [04:39<14:24, 27.89s/it]

train_losses_iters 3.8405346870422363
epoch_train_loss 5.805326936591385
epoch_val_loss 14.387142429197276
train_losses_iters 9.518489837646484
train_losses_iters 6.4614691734313965
train_losses_iters 11.276224136352539
train_losses_iters 4.090083122253418
train_losses_iters 3.6678686141967773
train_losses_iters 8.264015197753906
train_losses_iters 2.6657373905181885
train_losses_iters 3.1334025859832764
train_losses_iters 7.812343597412109
train_losses_iters 3.7052602767944336
train_losses_iters 2.9072346687316895
train_losses_iters 2.854555130004883
train_losses_iters 3.0736632347106934
train_losses_iters 3.4134793281555176
train_losses_iters 3.735468864440918
train_losses_iters 11.308523178100586
train_losses_iters 2.961355209350586
train_losses_iters 11.198222160339355
train_losses_iters 3.234215497970581
train_losses_iters 3.1119744777679443
train_losses_iters 2.8371148109436035
train_losses_iters 10.996879577636719
train_losses_iters 6.663771629333496
train_losses_iters 3.7348580

 25%|██████████▊                                | 10/40 [05:09<14:13, 28.46s/it]

train_losses_iters 2.6425411701202393
epoch_train_loss 5.124611414395845
epoch_val_loss 14.833744179705779
train_losses_iters 3.086839437484741
train_losses_iters 2.530989646911621
train_losses_iters 3.2347352504730225
train_losses_iters 2.894251823425293
train_losses_iters 3.9525907039642334
train_losses_iters 2.8634374141693115
train_losses_iters 4.6087541580200195
train_losses_iters 6.989754676818848
train_losses_iters 11.039678573608398
train_losses_iters 2.828420877456665
train_losses_iters 2.997631072998047
train_losses_iters 2.9480361938476562
train_losses_iters 8.414693832397461
train_losses_iters 3.660475730895996
train_losses_iters 2.3239920139312744
train_losses_iters 13.307886123657227
train_losses_iters 2.76639986038208
train_losses_iters 6.757307529449463
train_losses_iters 8.194595336914062
train_losses_iters 2.781430244445801
train_losses_iters 2.6804184913635254
train_losses_iters 3.219823122024536
train_losses_iters 6.392397403717041
train_losses_iters 7.9224567413330

 28%|███████████▊                               | 11/40 [05:49<15:28, 32.01s/it]

epoch_train_loss 5.689258832197923
epoch_val_loss 14.460073981321218
train_losses_iters 3.5526490211486816
train_losses_iters 2.8878421783447266
train_losses_iters 10.897428512573242
train_losses_iters 3.2080445289611816
train_losses_iters 8.015741348266602
train_losses_iters 2.9282426834106445
train_losses_iters 2.895465612411499
train_losses_iters 10.948216438293457
train_losses_iters 3.3735625743865967
train_losses_iters 3.206697463989258
train_losses_iters 3.8886067867279053
train_losses_iters 3.1266441345214844
train_losses_iters 2.920764923095703
train_losses_iters 3.403663158416748
train_losses_iters 2.2762327194213867
train_losses_iters 4.3883585929870605
train_losses_iters 2.7381467819213867
train_losses_iters 3.4479622840881348
train_losses_iters 3.4521329402923584
train_losses_iters 3.3203508853912354
train_losses_iters 3.2901268005371094
train_losses_iters 9.38088607788086
train_losses_iters 2.9562699794769287
train_losses_iters 7.0723958015441895
train_losses_iters 3.86692

 30%|████████████▉                              | 12/40 [06:24<15:22, 32.96s/it]

train_losses_iters 11.559876441955566
epoch_train_loss 4.592583888616318
epoch_val_loss 14.41822835513287
train_losses_iters 2.6356070041656494
train_losses_iters 3.3412773609161377
train_losses_iters 2.7992801666259766
train_losses_iters 3.0856544971466064
train_losses_iters 8.184636116027832
train_losses_iters 2.7009711265563965
train_losses_iters 2.680377721786499
train_losses_iters 3.0798726081848145
train_losses_iters 2.743802547454834
train_losses_iters 2.7986667156219482
train_losses_iters 3.893549680709839
train_losses_iters 3.246898651123047
train_losses_iters 2.991562843322754
train_losses_iters 6.077506065368652
train_losses_iters 2.3268089294433594
train_losses_iters 2.5675950050354004
train_losses_iters 2.8743631839752197
train_losses_iters 7.819921970367432
train_losses_iters 3.0225086212158203
train_losses_iters 8.297307968139648
train_losses_iters 2.7082982063293457
train_losses_iters 2.9332635402679443
train_losses_iters 9.904050827026367
train_losses_iters 2.487014293

 32%|█████████████▉                             | 13/40 [07:05<15:51, 35.25s/it]

train_losses_iters 12.83243465423584
epoch_train_loss 4.505765692800538
epoch_val_loss 14.979897163235222
train_losses_iters 3.2469167709350586
train_losses_iters 8.775976181030273
train_losses_iters 2.6883316040039062
train_losses_iters 2.8697805404663086
train_losses_iters 3.2334423065185547
train_losses_iters 2.7475507259368896
train_losses_iters 8.073580741882324
train_losses_iters 2.360044002532959
train_losses_iters 7.999749660491943
train_losses_iters 2.698025703430176
train_losses_iters 3.295138120651245
train_losses_iters 8.57929801940918
train_losses_iters 2.840116024017334
train_losses_iters 3.0878396034240723
train_losses_iters 2.8383829593658447
train_losses_iters 8.204757690429688
train_losses_iters 3.146122694015503
train_losses_iters 2.6632330417633057
train_losses_iters 2.2313437461853027
train_losses_iters 2.5152926445007324
train_losses_iters 2.386666774749756
train_losses_iters 2.5464720726013184
train_losses_iters 2.514326572418213
train_losses_iters 2.705716371536

 35%|███████████████                            | 14/40 [07:30<14:01, 32.38s/it]

epoch_train_loss 4.810968661919619
epoch_val_loss 15.252328665838352
train_losses_iters 2.6951708793640137
train_losses_iters 2.526123523712158
train_losses_iters 2.845341920852661
train_losses_iters 2.8335280418395996
train_losses_iters 8.318367958068848
train_losses_iters 8.564881324768066
train_losses_iters 2.6527457237243652
train_losses_iters 5.487118244171143
train_losses_iters 5.818456649780273
train_losses_iters 7.102527141571045
train_losses_iters 2.023380994796753
train_losses_iters 7.00362491607666
train_losses_iters 4.573633670806885
train_losses_iters 2.6374990940093994
train_losses_iters 5.228030204772949
train_losses_iters 8.3782377243042
train_losses_iters 3.3224103450775146
train_losses_iters 2.6195101737976074
train_losses_iters 7.230118751525879
train_losses_iters 2.303802967071533
train_losses_iters 3.0095274448394775
train_losses_iters 2.634101390838623
train_losses_iters 2.787052631378174
train_losses_iters 2.6003472805023193
train_losses_iters 2.340956211090088
t

 38%|████████████████▏                          | 15/40 [07:55<12:30, 30.03s/it]

train_losses_iters 2.6929545402526855
epoch_train_loss 4.168673721134153
epoch_val_loss 15.152690418561296
train_losses_iters 2.7932722568511963
train_losses_iters 8.264752388000488
train_losses_iters 2.635615587234497
train_losses_iters 3.069411039352417
train_losses_iters 4.872674465179443
train_losses_iters 6.301648139953613
train_losses_iters 2.302922248840332
train_losses_iters 2.63236403465271
train_losses_iters 4.1691131591796875
train_losses_iters 6.945826530456543
train_losses_iters 2.2742788791656494
train_losses_iters 2.8997573852539062
train_losses_iters 6.215397357940674
train_losses_iters 2.526493787765503
train_losses_iters 2.427689552307129
train_losses_iters 6.1763811111450195
train_losses_iters 8.870132446289062
train_losses_iters 2.681248426437378
train_losses_iters 2.1607465744018555
train_losses_iters 2.188816547393799
train_losses_iters 3.1032519340515137
train_losses_iters 8.54753589630127
train_losses_iters 4.434193134307861
train_losses_iters 2.4778778553009033

 40%|█████████████████▏                         | 16/40 [08:20<11:20, 28.36s/it]

train_losses_iters 2.080919027328491
epoch_train_loss 4.185289998339791
epoch_val_loss 14.956428522244096
train_losses_iters 2.9539172649383545
train_losses_iters 2.4499218463897705
train_losses_iters 2.4666996002197266
train_losses_iters 10.276077270507812
train_losses_iters 2.47053861618042
train_losses_iters 5.877458572387695
train_losses_iters 2.704470634460449
train_losses_iters 2.0519394874572754
train_losses_iters 2.4290733337402344
train_losses_iters 3.940955638885498
train_losses_iters 2.3108620643615723
train_losses_iters 10.958697319030762
train_losses_iters 2.491349697113037
train_losses_iters 2.368586540222168
train_losses_iters 2.851199150085449
train_losses_iters 2.2342231273651123
train_losses_iters 5.948258399963379
train_losses_iters 2.5539162158966064
train_losses_iters 2.6506476402282715
train_losses_iters 2.8340468406677246
train_losses_iters 7.15643310546875
train_losses_iters 2.713263988494873
train_losses_iters 2.3581063747406006
train_losses_iters 2.44870853424

 42%|██████████████████▎                        | 17/40 [08:53<11:29, 29.98s/it]

epoch_train_loss 3.8433782677365165
epoch_val_loss 14.877022398160953
train_losses_iters 6.0901994705200195
train_losses_iters 2.385366201400757
train_losses_iters 2.4274191856384277
train_losses_iters 2.326820135116577
train_losses_iters 2.8264880180358887
train_losses_iters 2.466373920440674
train_losses_iters 2.3993797302246094
train_losses_iters 2.42287015914917
train_losses_iters 2.4916598796844482
train_losses_iters 2.482736110687256
train_losses_iters 6.63259744644165
train_losses_iters 7.251237392425537
train_losses_iters 2.390683650970459
train_losses_iters 5.992015361785889
train_losses_iters 7.782968997955322
train_losses_iters 2.141842842102051
train_losses_iters 2.193706750869751
train_losses_iters 2.0639710426330566
train_losses_iters 2.353806257247925
train_losses_iters 2.590393543243408
train_losses_iters 2.1607933044433594
train_losses_iters 2.2216053009033203
train_losses_iters 2.455505609512329
train_losses_iters 2.605435371398926
train_losses_iters 2.811428070068359

 45%|███████████████████▎                       | 18/40 [09:24<11:03, 30.14s/it]

epoch_train_loss 3.932146658245315
epoch_val_loss 15.096925527409272
train_losses_iters 2.505516767501831
train_losses_iters 3.4542810916900635
train_losses_iters 3.5174312591552734
train_losses_iters 8.421348571777344
train_losses_iters 2.470478057861328
train_losses_iters 7.185315132141113
train_losses_iters 6.878355979919434
train_losses_iters 2.1027324199676514
train_losses_iters 2.511348247528076
train_losses_iters 2.7142956256866455
train_losses_iters 2.8516604900360107
train_losses_iters 2.4931652545928955
train_losses_iters 2.7263388633728027
train_losses_iters 5.316010475158691
train_losses_iters 2.6199259757995605
train_losses_iters 2.543646812438965
train_losses_iters 2.432028293609619
train_losses_iters 2.603299140930176
train_losses_iters 6.426411151885986
train_losses_iters 2.3030612468719482
train_losses_iters 5.5455002784729
train_losses_iters 8.238365173339844
train_losses_iters 2.322678327560425
train_losses_iters 2.2601866722106934
train_losses_iters 3.03468775749206

 48%|████████████████████▍                      | 19/40 [09:55<10:42, 30.61s/it]

train_losses_iters 2.4389877319335938
epoch_train_loss 3.723088060688769
epoch_val_loss 15.025703373708225
train_losses_iters 5.194794654846191
train_losses_iters 6.430496692657471
train_losses_iters 7.326627731323242
train_losses_iters 7.218045234680176
train_losses_iters 2.252312660217285
train_losses_iters 5.896475791931152
train_losses_iters 2.407160758972168
train_losses_iters 6.835750102996826
train_losses_iters 2.392721652984619
train_losses_iters 2.4665651321411133
train_losses_iters 2.3476457595825195
train_losses_iters 2.40236496925354
train_losses_iters 2.8381526470184326
train_losses_iters 2.5762157440185547
train_losses_iters 2.3813564777374268
train_losses_iters 2.0184669494628906
train_losses_iters 2.36643385887146
train_losses_iters 2.5561728477478027
train_losses_iters 2.821063280105591
train_losses_iters 3.9260365962982178
train_losses_iters 2.553637742996216
train_losses_iters 2.3679654598236084
train_losses_iters 2.418030023574829
train_losses_iters 4.08373880386352

 50%|█████████████████████▌                     | 20/40 [10:24<09:57, 29.88s/it]

epoch_train_loss 3.5178093808329005
epoch_val_loss 15.043652017166215
train_losses_iters 2.576042413711548
train_losses_iters 3.341693162918091
train_losses_iters 6.3540544509887695
train_losses_iters 5.9864115715026855
train_losses_iters 2.5227575302124023
train_losses_iters 2.3044052124023438
train_losses_iters 4.730466842651367
train_losses_iters 2.503706932067871
train_losses_iters 6.8132500648498535
train_losses_iters 2.4293577671051025
train_losses_iters 2.657790184020996
train_losses_iters 2.6047213077545166
train_losses_iters 2.5965285301208496
train_losses_iters 3.060734510421753
train_losses_iters 2.856127977371216
train_losses_iters 5.795217990875244
train_losses_iters 7.6050848960876465
train_losses_iters 3.5899085998535156
train_losses_iters 2.727701425552368
train_losses_iters 2.6532328128814697
train_losses_iters 2.610048294067383
train_losses_iters 2.568288803100586
train_losses_iters 4.924057483673096
train_losses_iters 4.804989814758301
train_losses_iters 2.6097040176

 52%|██████████████████████▌                    | 21/40 [10:53<09:22, 29.62s/it]

epoch_train_loss 3.5240565120664415
epoch_val_loss 14.990132501910601
train_losses_iters 3.0365946292877197
train_losses_iters 3.0049242973327637
train_losses_iters 2.3320865631103516
train_losses_iters 2.218050479888916
train_losses_iters 6.6699347496032715
train_losses_iters 7.046797752380371
train_losses_iters 7.797035217285156
train_losses_iters 2.272368907928467
train_losses_iters 2.146350145339966
train_losses_iters 2.316391706466675
train_losses_iters 2.4005565643310547
train_losses_iters 2.5255260467529297
train_losses_iters 6.260786056518555
train_losses_iters 2.532428503036499
train_losses_iters 2.452885389328003
train_losses_iters 2.685178756713867
train_losses_iters 2.2597827911376953
train_losses_iters 2.5543923377990723
train_losses_iters 4.5950469970703125
train_losses_iters 2.2316813468933105
train_losses_iters 2.4025557041168213
train_losses_iters 4.484844207763672
train_losses_iters 2.1470160484313965
train_losses_iters 2.7152953147888184
train_losses_iters 2.37601089

 55%|███████████████████████▋                   | 22/40 [11:43<10:43, 35.74s/it]

train_losses_iters 2.2248823642730713
epoch_train_loss 3.372174613496177
epoch_val_loss 14.92405855700825
train_losses_iters 2.1261889934539795
train_losses_iters 4.203967571258545
train_losses_iters 2.6599786281585693
train_losses_iters 6.434804916381836
train_losses_iters 6.187195301055908
train_losses_iters 2.280076503753662
train_losses_iters 2.6759865283966064
train_losses_iters 2.2430198192596436
train_losses_iters 5.953684329986572
train_losses_iters 2.334686040878296
train_losses_iters 2.795579433441162
train_losses_iters 2.726619243621826
train_losses_iters 3.2766902446746826
train_losses_iters 2.358184337615967
train_losses_iters 4.343007564544678
train_losses_iters 2.7030720710754395
train_losses_iters 2.3626794815063477
train_losses_iters 2.4660959243774414
train_losses_iters 3.8617265224456787
train_losses_iters 4.8765740394592285
train_losses_iters 2.717785358428955
train_losses_iters 2.33042049407959
train_losses_iters 4.721782684326172
train_losses_iters 5.0534925460815

 57%|████████████████████████▋                  | 23/40 [12:23<10:29, 37.04s/it]

epoch_train_loss 3.373297351038354
epoch_val_loss 14.871656402729558
train_losses_iters 2.6659622192382812
train_losses_iters 2.396495819091797
train_losses_iters 5.747098922729492
train_losses_iters 2.379978656768799
train_losses_iters 2.277566432952881
train_losses_iters 2.5122852325439453
train_losses_iters 2.1625359058380127
train_losses_iters 2.4541876316070557
train_losses_iters 2.1780643463134766
train_losses_iters 2.5069706439971924
train_losses_iters 5.500837802886963
train_losses_iters 2.5210604667663574
train_losses_iters 2.348114013671875
train_losses_iters 2.3451149463653564
train_losses_iters 6.618076324462891
train_losses_iters 2.734133720397949
train_losses_iters 3.4532084465026855
train_losses_iters 2.3562982082366943
train_losses_iters 2.1486105918884277
train_losses_iters 2.6175949573516846
train_losses_iters 2.5602216720581055
train_losses_iters 2.5008561611175537
train_losses_iters 2.3115153312683105
train_losses_iters 2.593313694000244
train_losses_iters 2.2387096

 60%|█████████████████████████▊                 | 24/40 [12:52<09:14, 34.65s/it]

epoch_train_loss 3.457403458081759
epoch_val_loss 14.75070361358424
train_losses_iters 6.348213195800781
train_losses_iters 2.219032049179077
train_losses_iters 2.181053400039673
train_losses_iters 7.252335548400879
train_losses_iters 2.329531669616699
train_losses_iters 5.713079452514648
train_losses_iters 2.5715999603271484
train_losses_iters 2.9573042392730713
train_losses_iters 5.107275485992432
train_losses_iters 2.9562489986419678
train_losses_iters 2.4415957927703857
train_losses_iters 4.055820941925049
train_losses_iters 2.44999361038208
train_losses_iters 2.7191379070281982
train_losses_iters 2.4563817977905273
train_losses_iters 8.308878898620605
train_losses_iters 2.56382417678833
train_losses_iters 4.2585368156433105
train_losses_iters 2.2762062549591064
train_losses_iters 2.394926071166992
train_losses_iters 2.3857884407043457
train_losses_iters 2.280158042907715
train_losses_iters 4.84909725189209
train_losses_iters 3.447472095489502
train_losses_iters 10.681707382202148


 62%|██████████████████████████▉                | 25/40 [13:22<08:20, 33.37s/it]

train_losses_iters 5.449333190917969
epoch_train_loss 3.5251844173822646
epoch_val_loss 14.670598098635674
train_losses_iters 2.351339340209961
train_losses_iters 2.3756704330444336
train_losses_iters 3.3257436752319336
train_losses_iters 2.4351253509521484
train_losses_iters 2.8728628158569336
train_losses_iters 2.5257043838500977
train_losses_iters 2.316972494125366
train_losses_iters 2.2041592597961426
train_losses_iters 2.462313175201416
train_losses_iters 2.4361367225646973
train_losses_iters 2.280486822128296
train_losses_iters 5.084585189819336
train_losses_iters 2.336362361907959
train_losses_iters 2.2408957481384277
train_losses_iters 2.7534642219543457
train_losses_iters 2.1729836463928223
train_losses_iters 6.083093643188477
train_losses_iters 5.668883800506592
train_losses_iters 2.2607645988464355
train_losses_iters 2.46268892288208
train_losses_iters 2.2780537605285645
train_losses_iters 2.463080883026123
train_losses_iters 5.555419921875
train_losses_iters 2.3684849739074

 65%|███████████████████████████▉               | 26/40 [13:49<07:19, 31.38s/it]

train_losses_iters 2.4065356254577637
epoch_train_loss 3.3288809397281747
epoch_val_loss 14.68855664630731
train_losses_iters 3.807359457015991
train_losses_iters 2.683380603790283
train_losses_iters 2.3695127964019775
train_losses_iters 2.455735206604004
train_losses_iters 2.5952296257019043
train_losses_iters 2.699660301208496
train_losses_iters 3.209726333618164
train_losses_iters 2.547854423522949
train_losses_iters 3.871863603591919
train_losses_iters 5.995433807373047
train_losses_iters 2.5814929008483887
train_losses_iters 4.37375020980835
train_losses_iters 6.807931423187256
train_losses_iters 2.3513689041137695
train_losses_iters 6.601428508758545
train_losses_iters 4.778072357177734
train_losses_iters 2.2842893600463867
train_losses_iters 2.417651414871216
train_losses_iters 3.9724135398864746
train_losses_iters 6.150218486785889
train_losses_iters 2.263011932373047
train_losses_iters 2.414243459701538
train_losses_iters 9.820579528808594
train_losses_iters 2.6697347164154053

 68%|█████████████████████████████              | 27/40 [14:16<06:31, 30.15s/it]

epoch_train_loss 3.482138919015216
epoch_val_loss 14.71619689685327
train_losses_iters 2.4853439331054688
train_losses_iters 2.1124608516693115
train_losses_iters 2.4600412845611572
train_losses_iters 2.4635603427886963
train_losses_iters 2.3701562881469727
train_losses_iters 2.437040328979492
train_losses_iters 2.6425256729125977
train_losses_iters 2.2476348876953125
train_losses_iters 2.3527493476867676
train_losses_iters 2.1888482570648193
train_losses_iters 2.190396785736084
train_losses_iters 2.318781852722168
train_losses_iters 4.593301296234131
train_losses_iters 2.422867774963379
train_losses_iters 2.4715471267700195
train_losses_iters 2.4433937072753906
train_losses_iters 2.4280929565429688
train_losses_iters 2.630087375640869
train_losses_iters 3.9786484241485596
train_losses_iters 4.217275619506836
train_losses_iters 2.884669303894043
train_losses_iters 2.4287149906158447
train_losses_iters 2.3446593284606934
train_losses_iters 2.30592679977417
train_losses_iters 2.626323938

 70%|██████████████████████████████             | 28/40 [14:43<05:51, 29.25s/it]

epoch_train_loss 2.9846808808481593
epoch_val_loss 14.764757675074394
train_losses_iters 2.623624086380005
train_losses_iters 2.438152551651001
train_losses_iters 6.112155437469482
train_losses_iters 2.282066822052002
train_losses_iters 6.218655586242676
train_losses_iters 2.3177919387817383
train_losses_iters 2.356327772140503
train_losses_iters 5.412981033325195
train_losses_iters 2.3213417530059814
train_losses_iters 2.412240505218506
train_losses_iters 2.361130714416504
train_losses_iters 5.92134952545166
train_losses_iters 4.325628280639648
train_losses_iters 2.4627184867858887
train_losses_iters 4.132284641265869
train_losses_iters 2.235874891281128
train_losses_iters 2.485546350479126
train_losses_iters 2.3245387077331543
train_losses_iters 2.944833755493164
train_losses_iters 4.166208267211914
train_losses_iters 2.300800323486328
train_losses_iters 2.2507026195526123
train_losses_iters 5.169197082519531
train_losses_iters 3.9878125190734863
train_losses_iters 4.534685134887695


 72%|███████████████████████████████▏           | 29/40 [15:10<05:13, 28.50s/it]

epoch_train_loss 3.369787870309292
epoch_val_loss 14.826960898850155
train_losses_iters 2.2341504096984863
train_losses_iters 3.2860562801361084
train_losses_iters 2.3788225650787354
train_losses_iters 2.49377179145813
train_losses_iters 2.135003089904785
train_losses_iters 2.388779640197754
train_losses_iters 2.6403510570526123
train_losses_iters 2.610806941986084
train_losses_iters 2.3259084224700928
train_losses_iters 2.140636682510376
train_losses_iters 2.955446720123291
train_losses_iters 2.3961403369903564
train_losses_iters 2.3546836376190186
train_losses_iters 2.497525930404663
train_losses_iters 3.4434595108032227
train_losses_iters 2.6834447383880615
train_losses_iters 4.511010646820068
train_losses_iters 2.6091113090515137
train_losses_iters 2.6701793670654297
train_losses_iters 2.2055439949035645
train_losses_iters 4.632664680480957
train_losses_iters 6.048986911773682
train_losses_iters 2.2453556060791016
train_losses_iters 6.235884666442871
train_losses_iters 4.1581883430

 75%|████████████████████████████████▎          | 30/40 [15:40<04:48, 28.85s/it]

train_losses_iters 4.256711006164551
epoch_train_loss 3.4549277623494468
epoch_val_loss 14.819535469346574
train_losses_iters 2.3527069091796875
train_losses_iters 2.7304458618164062
train_losses_iters 4.429073810577393
train_losses_iters 2.4861392974853516
train_losses_iters 2.3123154640197754
train_losses_iters 2.4212515354156494
train_losses_iters 2.543231964111328
train_losses_iters 2.5536580085754395
train_losses_iters 5.059260368347168
train_losses_iters 2.421128511428833
train_losses_iters 2.171799898147583
train_losses_iters 2.4345877170562744
train_losses_iters 2.4609248638153076
train_losses_iters 5.089172840118408
train_losses_iters 4.270697593688965
train_losses_iters 2.2905099391937256
train_losses_iters 2.2199535369873047
train_losses_iters 5.073982238769531
train_losses_iters 2.2624826431274414
train_losses_iters 2.4644908905029297
train_losses_iters 5.512456893920898
train_losses_iters 2.757563829421997
train_losses_iters 2.376610040664673
train_losses_iters 2.155503511

 78%|█████████████████████████████████▎         | 31/40 [16:11<04:27, 29.68s/it]

epoch_train_loss 3.250545654541407
epoch_val_loss 14.842896929831912
train_losses_iters 2.6268539428710938
train_losses_iters 5.284271717071533
train_losses_iters 2.8283867835998535
train_losses_iters 4.211141109466553
train_losses_iters 2.307586669921875
train_losses_iters 2.7294118404388428
train_losses_iters 5.364631652832031
train_losses_iters 2.2207372188568115
train_losses_iters 2.26335072517395
train_losses_iters 2.8168535232543945
train_losses_iters 2.39166259765625
train_losses_iters 2.3074228763580322
train_losses_iters 2.567295551300049
train_losses_iters 2.6935036182403564
train_losses_iters 5.265667915344238
train_losses_iters 2.498171091079712
train_losses_iters 2.4909114837646484
train_losses_iters 4.742138862609863
train_losses_iters 2.32794189453125
train_losses_iters 2.3084566593170166
train_losses_iters 2.6535215377807617
train_losses_iters 2.287236452102661
train_losses_iters 4.320128440856934
train_losses_iters 2.7731266021728516
train_losses_iters 2.49497389793396

 80%|██████████████████████████████████▍        | 32/40 [16:39<03:51, 28.93s/it]

epoch_train_loss 3.192756787324563
epoch_val_loss 14.754593407269567
train_losses_iters 2.214766502380371
train_losses_iters 2.100445508956909
train_losses_iters 4.574941635131836
train_losses_iters 2.3538565635681152
train_losses_iters 4.872715950012207
train_losses_iters 2.2103371620178223
train_losses_iters 3.936768054962158
train_losses_iters 2.6295886039733887
train_losses_iters 7.8031816482543945
train_losses_iters 4.678316593170166
train_losses_iters 2.8992607593536377
train_losses_iters 2.4692375659942627
train_losses_iters 7.670259475708008
train_losses_iters 2.211310386657715
train_losses_iters 2.3558356761932373
train_losses_iters 2.2804958820343018
train_losses_iters 2.3254218101501465
train_losses_iters 2.615743398666382
train_losses_iters 2.4894068241119385
train_losses_iters 3.732788324356079
train_losses_iters 2.364293098449707
train_losses_iters 2.298905849456787
train_losses_iters 2.606809139251709
train_losses_iters 2.318082332611084
train_losses_iters 2.220781087875

 82%|███████████████████████████████████▍       | 33/40 [17:06<03:18, 28.38s/it]

train_losses_iters 2.2303783893585205
epoch_train_loss 3.3059295690976658
epoch_val_loss 14.797737123237717
train_losses_iters 2.902459144592285
train_losses_iters 2.4261739253997803
train_losses_iters 6.877543926239014
train_losses_iters 2.6957650184631348
train_losses_iters 2.328899383544922
train_losses_iters 2.2112009525299072
train_losses_iters 2.2332398891448975
train_losses_iters 2.3689517974853516
train_losses_iters 4.213407516479492
train_losses_iters 4.261765480041504
train_losses_iters 2.9391226768493652
train_losses_iters 2.344627618789673
train_losses_iters 5.2580180168151855
train_losses_iters 2.1882472038269043
train_losses_iters 2.787721872329712
train_losses_iters 2.635698080062866
train_losses_iters 2.375462055206299
train_losses_iters 2.4877405166625977
train_losses_iters 4.365407466888428
train_losses_iters 2.7391014099121094
train_losses_iters 5.6516313552856445
train_losses_iters 2.4669480323791504
train_losses_iters 2.4472358226776123
train_losses_iters 2.5405094

 85%|████████████████████████████████████▌      | 34/40 [17:33<02:48, 28.04s/it]

train_losses_iters 2.528709650039673
epoch_train_loss 3.2504405221368513
epoch_val_loss 14.78538486230023
train_losses_iters 2.170586347579956
train_losses_iters 3.4926130771636963
train_losses_iters 2.4619319438934326
train_losses_iters 5.9806647300720215
train_losses_iters 2.9945647716522217
train_losses_iters 2.5028727054595947
train_losses_iters 2.4357285499572754
train_losses_iters 2.391077995300293
train_losses_iters 2.309858560562134
train_losses_iters 2.3638195991516113
train_losses_iters 2.5541930198669434
train_losses_iters 3.247664451599121
train_losses_iters 4.963870525360107
train_losses_iters 2.7067480087280273
train_losses_iters 2.360745668411255
train_losses_iters 2.395530939102173
train_losses_iters 2.4164938926696777
train_losses_iters 3.856013536453247
train_losses_iters 4.067462921142578
train_losses_iters 2.8691229820251465
train_losses_iters 4.481918811798096
train_losses_iters 2.306020498275757
train_losses_iters 2.387672185897827
train_losses_iters 4.81194496154

 88%|█████████████████████████████████████▋     | 35/40 [18:00<02:18, 27.66s/it]

epoch_train_loss 3.2317850528619227
epoch_val_loss 14.868288421772776
train_losses_iters 3.6495397090911865
train_losses_iters 2.7110025882720947
train_losses_iters 2.008718967437744
train_losses_iters 3.4225821495056152
train_losses_iters 2.4047253131866455
train_losses_iters 4.2798848152160645
train_losses_iters 2.621048927307129
train_losses_iters 2.7794189453125
train_losses_iters 5.82081937789917
train_losses_iters 2.417811870574951
train_losses_iters 2.469686269760132
train_losses_iters 2.600654125213623
train_losses_iters 2.308804750442505
train_losses_iters 2.4985387325286865
train_losses_iters 5.635186195373535
train_losses_iters 6.31654691696167
train_losses_iters 2.097379207611084
train_losses_iters 2.2229421138763428
train_losses_iters 2.465902090072632
train_losses_iters 2.1136341094970703
train_losses_iters 2.189725399017334
train_losses_iters 2.1862192153930664
train_losses_iters 2.7413344383239746
train_losses_iters 8.175682067871094
train_losses_iters 4.262322425842285

 90%|██████████████████████████████████████▋    | 36/40 [18:28<01:50, 27.73s/it]

train_losses_iters 2.4550015926361084
epoch_train_loss 3.1648287385956855
epoch_val_loss 14.855842439526763
train_losses_iters 2.6733198165893555
train_losses_iters 6.091665744781494
train_losses_iters 5.383777141571045
train_losses_iters 8.442910194396973
train_losses_iters 2.4215636253356934
train_losses_iters 2.3692474365234375
train_losses_iters 2.3737082481384277
train_losses_iters 2.370816469192505
train_losses_iters 2.1189732551574707
train_losses_iters 6.164483070373535
train_losses_iters 2.7609457969665527
train_losses_iters 7.129321575164795
train_losses_iters 2.449666738510132
train_losses_iters 2.4074244499206543
train_losses_iters 2.2962779998779297
train_losses_iters 2.27933931350708
train_losses_iters 5.067102432250977
train_losses_iters 2.532600164413452
train_losses_iters 6.934568881988525
train_losses_iters 2.4454033374786377
train_losses_iters 3.3057777881622314
train_losses_iters 7.560096263885498
train_losses_iters 2.716290235519409
train_losses_iters 4.60030174255

 92%|███████████████████████████████████████▊   | 37/40 [18:56<01:24, 28.00s/it]

epoch_train_loss 3.260547715374547
epoch_val_loss 14.853297479920561
train_losses_iters 2.2583565711975098
train_losses_iters 3.243543863296509
train_losses_iters 2.1477596759796143
train_losses_iters 3.829738140106201
train_losses_iters 4.826292991638184
train_losses_iters 2.2516732215881348
train_losses_iters 2.356381893157959
train_losses_iters 2.4041504859924316
train_losses_iters 2.658405065536499
train_losses_iters 2.8215317726135254
train_losses_iters 2.1661622524261475
train_losses_iters 5.726412296295166
train_losses_iters 2.105173349380493
train_losses_iters 2.287358522415161
train_losses_iters 2.557464838027954
train_losses_iters 5.062591075897217
train_losses_iters 2.248250722885132
train_losses_iters 2.299365520477295
train_losses_iters 2.490086078643799
train_losses_iters 2.3397812843322754
train_losses_iters 2.610499143600464
train_losses_iters 2.02543306350708
train_losses_iters 3.463836669921875
train_losses_iters 4.381181716918945
train_losses_iters 2.4246039390563965

 95%|████████████████████████████████████████▊  | 38/40 [19:24<00:55, 27.98s/it]

train_losses_iters 2.5502350330352783
epoch_train_loss 2.9370141905597134
epoch_val_loss 14.955759647776162
train_losses_iters 2.9482674598693848
train_losses_iters 4.0173563957214355
train_losses_iters 6.155426979064941
train_losses_iters 2.469275951385498
train_losses_iters 2.3316149711608887
train_losses_iters 1.9488478899002075
train_losses_iters 2.4041671752929688
train_losses_iters 2.2452688217163086
train_losses_iters 2.518098831176758
train_losses_iters 4.132260799407959
train_losses_iters 2.2657222747802734
train_losses_iters 2.208374261856079
train_losses_iters 9.115715026855469
train_losses_iters 2.128159761428833
train_losses_iters 3.2734365463256836
train_losses_iters 4.161949157714844
train_losses_iters 3.942803382873535
train_losses_iters 2.192927837371826
train_losses_iters 2.598845958709717
train_losses_iters 2.3755955696105957
train_losses_iters 3.8021652698516846
train_losses_iters 2.5035274028778076
train_losses_iters 2.358328342437744
train_losses_iters 2.044609785

 98%|█████████████████████████████████████████▉ | 39/40 [19:53<00:28, 28.15s/it]

epoch_train_loss 3.1527431867061515
epoch_val_loss 14.971449659548256
train_losses_iters 2.2827587127685547
train_losses_iters 2.6508185863494873
train_losses_iters 2.4359335899353027
train_losses_iters 2.13018798828125
train_losses_iters 2.2107620239257812
train_losses_iters 2.2299208641052246
train_losses_iters 4.161649703979492
train_losses_iters 2.02751088142395
train_losses_iters 2.198331832885742
train_losses_iters 2.886848211288452
train_losses_iters 4.163015365600586
train_losses_iters 5.631760597229004
train_losses_iters 2.3748879432678223
train_losses_iters 2.444596290588379
train_losses_iters 2.3881640434265137
train_losses_iters 2.491826057434082
train_losses_iters 2.633986473083496
train_losses_iters 2.412400960922241
train_losses_iters 2.2940735816955566
train_losses_iters 2.105245590209961
train_losses_iters 3.462643623352051
train_losses_iters 2.0677101612091064
train_losses_iters 7.6075029373168945
train_losses_iters 2.0583858489990234
train_losses_iters 2.248976945877

100%|███████████████████████████████████████████| 40/40 [20:20<00:00, 30.52s/it]

train_losses_iters 5.358888149261475
epoch_train_loss 2.9911909867555666
epoch_val_loss 14.87304739529888





In [15]:
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, './barlowtscp'+DS_NAME+str(WIN)+str(CODE_SIZE)+str(BATCH_SIZE)+'_opt.pt')

In [16]:
cpt = torch.load('barlowtscpUSC1001032_opt.pt')
model.load_state_dict(cpt['model_state_dict'])

<All keys matched successfully>

In [17]:
enc = model.get_enc()

In [19]:
test_ds = load_future(DATA_PATH, DS_NAME, WIN, BATCH_SIZE, mode = "test")

number of samples : 2336 /  number of samples with change point : 175
test shape (468, 200) and number of change points 35 
(468, 200) (468, 1)
dataset shape :  (468, 201)


In [30]:
x_test, lbl_test = test_ds[:,1:], test_ds[:,0]

num = x_test.shape[0]
lbl_test = np.array(lbl_test).reshape((lbl_test.shape[0], 1))
history = enc(torch.from_numpy(x_test[:, 0:WIN].reshape((num, 1, WIN))).float())
future = enc(torch.from_numpy(x_test[:, WIN:].reshape((num, 1, WIN))).float())
pred_out = np.concatenate((lbl_test, history.detach().numpy(), future.detach().numpy()), 1)
rep_sim = _cosine_simililarity_dim1(history, future)

print('Average similarity for test set : Reps : {}'.format(np.mean(rep_sim.detach().numpy())))
gt = np.zeros(lbl_test.shape[0])
gt[np.where((lbl_test > int(2 * WIN * 0.15)) & (lbl_test < int(2 * WIN * 0.85)))[0]] = 1
result = estimate_CPs(rep_sim.detach().numpy(), gt, os.path.join(OUTPUT_PATH, train_name),
                    os.path.join(OUTPUT_PATH, "Evaluation.txt"),
                    metric='cosine', threshold=0.5)



Average similarity for test set : Reps : 0.999922513961792
tn 442, fp 0, fn 26, tp 0 ----- f1-score 0.0
SEQ : Pos 7, fp 0, fn 7, tp 0 ----- f1-score 0.0


In [36]:
result = estimate_CPs(rep_sim.detach().numpy(), gt, os.path.join(OUTPUT_PATH, train_name),
                    os.path.join(OUTPUT_PATH, "Evaluation.txt"),
                    metric='cosine', threshold=0.9)

tn 442, fp 0, fn 26, tp 0 ----- f1-score 0.0
SEQ : Pos 7, fp 0, fn 7, tp 0 ----- f1-score 0.0
