## Imports

In [1]:
import sys
import os
import argparse
import time
import random
import math
import numpy as np
from scipy.special import logsumexp

import torch
import torchvision
from PIL import Image
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as tfs

from tensorboardX import SummaryWriter


from utils import kNN, AverageMeter, py_softmax

## Training parameters

In [2]:
#data
datadir = "/root/data/Multivariate_arff"

# optimization
lamb = 10      # SK lambda-parameter
nopts = 400    # number of SK-optimizations
epochs = 100   # numbers of epochs
momentum = 0.9 # sgd momentum
exp = './resnet1d_exp' # experiments results dir


# other
devc='0'  # cuda device
batch_size = 100
lr=0.03     #learning rate
alr=0.03    #starting learning rate

knn_dim = 10
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch


In [3]:
device = torch.device('cuda:' + devc) if torch.cuda.is_available() else torch.device('cpu')
print(f"GPU device: {torch.cuda.current_device()}")

GPU device: 0


## Model parameters (AlexNet in that case)

In [4]:
hc=10       # number of heads
ncl=6       # number of clusters

numc = [ncl] * hc
# # (number of filters, kernel size, stride, pad) for AlexNet, two vesions
# CFG = {
#     'big': [(96, 11, 4, 2), 'M', (256, 5, 1, 2), 'M', (384, 3, 1, 1), (384, 3, 1, 1), (256, 3, 1, 1), 'M'],
#     'small': [(64, 11, 4, 2), 'M', (192, 5, 1, 2), 'M', (384, 3, 1, 1), (256, 3, 1, 1), (256, 3, 1, 1), 'M']
# }

## Data Preparation

In [5]:
import pandas as pd
import numpy as np
from scipy.io import arff
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import preprocessing

from tqdm import tqdm
import os
from datetime import datetime

In [6]:
def load_file(filepath):
    data = arff.loadarff(filepath)
    data = pd.DataFrame(data[0])
    X = data.iloc[:, :-1]
    y = data.iloc[:, -1]
    return X.values, y.values

def load_group(prefix, filenames): 
    loaded = []
    for name in filenames: 
        X, y = load_file(prefix + "/" + name) 
        loaded.append(X)
    # stack group so that features are the 3rd dimension 
    loaded = np.dstack(loaded)
    return loaded, y

def load_dataset_group(folder_path, ds_path, dims_num, is_train=True, label_enc=False): 
    filenames = []
    if is_train:
        postfix = "_TRAIN.arff"
    else:
        postfix = "_TEST.arff"
    for dim_num in range(1, dims_num + 1):
        filenames.append(ds_path + str(dim_num) + postfix)

    X, y = load_group(folder_path, filenames)
    X = torch.from_numpy(np.array(X, dtype=np.float64))
    if label_enc:
        le = preprocessing.LabelEncoder()
        y = le.fit_transform(y)
        y = torch.from_numpy(np.array(y, dtype=np.int32))
    else:
        y = torch.from_numpy(np.array(y, dtype=np.int32)) - 1
    X = X.transpose(1, 2)
    return X, y

def load_dataset(folder_path, ds_path, dims_num, label_enc=False): 
    X_train, y_train = load_dataset_group(folder_path, ds_path, dims_num, 
                                          is_train=True, label_enc=label_enc) 
    X_test, y_test = load_dataset_group(folder_path, ds_path, dims_num, 
                                        is_train=False, label_enc=label_enc)
    X_train = F.normalize(X_train, dim=1)
    X_test = F.normalize(X_test, dim=1)
    return X_train, y_train, X_test, y_test

# from tqdm import trange
def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.random.permutation(len(inputs))
    for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        yield inputs[excerpt], targets[excerpt], excerpt

In [7]:
#ds_path = "ERing/ERingDimension"
#ds_path = "SpokenArabicDigits/SpokenArabicDigitsDimension"
# dims_num = 13
# num_classes = 10
# magic_dim = 4500

ds_path = "LSST/LSSTDimension"
dims_num = 6
num_classes = 14
magic_dim = 2304

X_train, y_train, X_test, y_test = load_dataset(datadir, ds_path, dims_num)
# X_train[0], y_train
print("X_train.shape:", X_train.shape, "\ny_train.shape:", y_train.shape)
print("X_test.shape:", X_test.shape, "\ny_test.shape:", y_test.shape)

X_train.shape: torch.Size([2459, 6, 36]) 
y_train.shape: torch.Size([2459])
X_test.shape: torch.Size([2466, 6, 36]) 
y_test.shape: torch.Size([2466])


In [8]:
N = X_train.shape[0]
N

2459

In [9]:
X_train.shape

torch.Size([2459, 6, 36])

## Model, ResNet

In [10]:
import torch.nn as nn
import math

__all__ = ['resnetv1','resnetv1_18']

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class Normalize(nn.Module):
    def __init__(self, power=2):
        super(Normalize, self).__init__()
        self.power = power

    def forward(self, x):
        norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        out = x.div(norm)
        return out

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, in_channel=3, width=1, num_classes=[1000]):
        self.inplanes = 16
        super(ResNet, self).__init__()
        self.headcount = len(num_classes)
        self.base = int(16 * width)
        self.features = nn.Sequential(*[                                                     # [100, 8, 18]
                            nn.Conv1d(in_channel, 16, kernel_size=3, padding=1, bias=False), # [100, 16, 36]
                            nn.BatchNorm1d(16),
                            nn.ReLU(inplace=True),
                            self._make_layer(block, self.base, layers[0]),                   # [100, 16, 36]
                            self._make_layer(block, self.base * 2, layers[1]),               # [100, 32, 36]
                            self._make_layer(block, self.base * 4, layers[2]),               # [100, 64, 36]
                            self._make_layer(block, self.base * 8, layers[3]),               # [100, 128, 36]
                            nn.AvgPool1d(2),                                                 # [100, 128, 18]
        ])
    
        if len(num_classes) == 1:
            self.top_layer = nn.Sequential(nn.Linear(magic_dim, num_classes[0]))
        else:
            for a, i in enumerate(num_classes):
                setattr(self, "top_layer%d" % a, nn.Linear(magic_dim, i))
            self.top_layer = None
        for m in self.features.modules():
            if isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv1d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.features(x.float())
        out = out.view(out.size(0), -1)
        if self.headcount == 1:
            if self.top_layer:
                out = self.top_layer(out)
            return out
        else:
            outp = []
            for i in range(self.headcount):
                outp.append(getattr(self, "top_layer%d" % i)(out))
            return outp

def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model

def resnetv1_18(num_classes=[1000]):
    """Encoder for instance discrimination and MoCo"""
    return resnet18(num_classes=num_classes)

## Sinkhorn-Knopp optimization

In [11]:
def optimize_L_sk(PS):
    N, K = PS.shape
    tt = time.time()
    PS = PS.T  # now it is K x N
    r = np.ones((K, 1)) / K
    c = np.ones((N, 1)) / N
    PS **= lamb  # K x N
    inv_K = 1. / K
    inv_N = 1. / N
    err = 1e3
    _counter = 0
    while err > 1e-2:
        r = inv_K / (PS @ c)  # (KxN)@(N,1) = K x 1
        c_new = inv_N / (r.T @ PS).T  # ((1,K)@(KxN)).t() = N x 1
        if _counter % 10 == 0:
            err = np.nansum(np.abs(c / c_new - 1))
        c = c_new
        _counter += 1
        
    print("error: ", err, 'step ', _counter, flush=True)  # " nonneg: ", sum(I), flush=True)
    # inplace calculations.
    PS *= np.squeeze(c)
    PS = PS.T
    PS *= np.squeeze(r)
    PS = PS.T
    argmaxes = np.nanargmax(PS, 0)  # size N
    newL = torch.LongTensor(argmaxes)
    selflabels = newL.to(device)
    PS = PS.T
    PS /= np.squeeze(r)
    PS = PS.T
    PS /= np.squeeze(c)
    sol = PS[argmaxes, np.arange(N)]
    np.log(sol, sol)
    cost = -(1. / lamb) * np.nansum(sol) / N
    print('cost: ', cost, flush=True)
    print('opt took {0:.2f}min, {1:4d}iters'.format(((time.time() - tt) / 60.), _counter), flush=True)
    return cost, selflabels

def opt_sk(model, selflabels_in, epoch):
    if hc == 1:
        PS = np.zeros((N, ncl))
    else:
        PS_pre = np.zeros((N, magic_dim)) # knn_dim
    
    for batch_idx, (data, _, _selected) in enumerate(iterate_minibatches(X_train, y_train, batch_size, shuffle=True)):
        data = data.to(device)#cuda()
        if hc == 1:
            p = nn.functional.softmax(model(data), 1)
            PS[_selected, :] = p.detach().cpu().numpy()
        else:
            p = model(data.float())
            PS_pre[_selected, :] = p.detach().cpu().numpy()
    if hc == 1:
        cost, selflabels = optimize_L_sk(PS)
        _costs = [cost]
    else:
        _nmis = np.zeros(hc)
        _costs = np.zeros(hc)
        nh = epoch % hc  # np.random.randint(args.hc)
        print("computing head %s " % nh, end="\r", flush=True)
        tl = getattr(model, "top_layer%d" % nh)
        # do the forward pass:
        PS = (PS_pre @ tl.weight.cpu().numpy().T
                   + tl.bias.cpu().numpy())
        PS = py_softmax(PS, 1)
        c, selflabels_ = optimize_L_sk(PS)
        _costs[nh] = c
        selflabels_in[nh] = selflabels_
        selflabels = selflabels_in
    return selflabels

## Training utils

In [12]:
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = alr
    if epochs == 200:
        if epoch >= 80:
            lr = alr * (0.1 ** ((epoch - 80) // 40))  # i.e. 120, 160
            print(lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    elif epochs == 400:
        if epoch >= 160:
            lr = alr * (0.1 ** ((epoch - 160) // 80))  # i.e. 240,320
            print(lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    elif epochs == 800:
        if epoch >= 320:
            lr = alr * (0.1 ** ((epoch - 320) // 160))  # i.e. 480, 640
            print(lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    elif epochs == 1600:
        if epoch >= 640:
            lr = alr * (0.1 ** ((epoch - 640) // 320))
            print(lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

In [13]:
def feature_return_switch(model, bool=True):
    """
    switch between network output or conv5features
        if True: changes switch s.t. forward pass returns post-conv5 features
        if False: changes switch s.t. forward will give full network output
    """
    if bool:
        model.headcount = 1
    else:
        model.headcount = hc
    model.return_feature = bool

In [14]:
def train(epoch, selflabels):
    print('\nEpoch: %d' % epoch)
    print(name)
    adjust_learning_rate(optimizer, epoch)
    train_loss = AverageMeter()
    data_time = AverageMeter()
    batch_time = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    
    for batch_idx, (inputs, targets, indexes) in enumerate(iterate_minibatches(X_train, y_train, batch_size, shuffle=True)):
        inputs = inputs.float().to(device)
        niter = epoch * N + batch_idx
        if len(optimize_times) > 0 and niter * batch_size >= optimize_times[-1]:
            with torch.no_grad():
                _ = optimize_times.pop()
                if hc >1:
                    feature_return_switch(model, True)
                selflabels = opt_sk(model, selflabels, epoch)
                if hc >1:
                    feature_return_switch(model, False)
        data_time.update(time.time() - end)
        inputs, targets = inputs.to(device), targets.to(device)#, indexes.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        if hc == 1:
            loss = criterion(outputs, selflabels[indexes])
        else:
            loss = torch.mean(torch.stack([criterion(outputs[h],
                                                     selflabels[h, indexes]) for h in range(hc)]))

        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if True:
#         if batch_idx % 10 == 0:
            print('Epoch: [{}][{}/{}]'
                  'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                  'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})'.format(
                epoch, batch_idx, N, batch_time=batch_time, data_time=data_time, train_loss=train_loss))
#             writer.add_scalar("loss", loss.item(), batch_idx*512 +epoch*len(trainloader.dataset))
    return selflabels

## Model initialization

In [15]:
model = resnet18(num_classes=numc, in_channel=dims_num)

In [16]:
optimize_times = ((epochs + 1.0001)*N*(np.linspace(0, 1, nopts))[::-1]).tolist()
optimize_times = [(epochs +10)*N] + optimize_times
print('We will optimize L at epochs:', [np.round(1.0*t/N, 2) for t in optimize_times], flush=True)

We will optimize L at epochs: [110.0, 101.0, 100.75, 100.49, 100.24, 99.99, 99.73, 99.48, 99.23, 98.98, 98.72, 98.47, 98.22, 97.96, 97.71, 97.46, 97.2, 96.95, 96.7, 96.44, 96.19, 95.94, 95.68, 95.43, 95.18, 94.92, 94.67, 94.42, 94.17, 93.91, 93.66, 93.41, 93.15, 92.9, 92.65, 92.39, 92.14, 91.89, 91.63, 91.38, 91.13, 90.87, 90.62, 90.37, 90.12, 89.86, 89.61, 89.36, 89.1, 88.85, 88.6, 88.34, 88.09, 87.84, 87.58, 87.33, 87.08, 86.82, 86.57, 86.32, 86.07, 85.81, 85.56, 85.31, 85.05, 84.8, 84.55, 84.29, 84.04, 83.79, 83.53, 83.28, 83.03, 82.77, 82.52, 82.27, 82.02, 81.76, 81.51, 81.26, 81.0, 80.75, 80.5, 80.24, 79.99, 79.74, 79.48, 79.23, 78.98, 78.72, 78.47, 78.22, 77.96, 77.71, 77.46, 77.21, 76.95, 76.7, 76.45, 76.19, 75.94, 75.69, 75.43, 75.18, 74.93, 74.67, 74.42, 74.17, 73.91, 73.66, 73.41, 73.16, 72.9, 72.65, 72.4, 72.14, 71.89, 71.64, 71.38, 71.13, 70.88, 70.62, 70.37, 70.12, 69.86, 69.61, 69.36, 69.11, 68.85, 68.6, 68.35, 68.09, 67.84, 67.59, 67.33, 67.08, 66.83, 66.57, 66.32, 66.07

In [17]:
# init selflabels randomly
if hc == 1:
    selflabels = np.zeros(N, dtype=np.int32)
    for qq in range(N):
        selflabels[qq] = qq % ncl
    selflabels = np.random.permutation(selflabels)
    selflabels = torch.LongTensor(selflabels).to(device)
else:
    selflabels = np.zeros((hc, N), dtype=np.int32)
    for nh in range(hc):
        for _i in range(N):
            selflabels[nh, _i] = _i % numc[nh]
        selflabels[nh] = np.random.permutation(selflabels[nh])
    selflabels = torch.LongTensor(selflabels).to(device)

In [18]:
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=5e-4)
model = model.to(device)
criterion = nn.CrossEntropyLoss()

In [19]:
name = "ResNet1D"
writer = SummaryWriter(f'./runs/ERing/{name}')

## Training! 
Takes a couple of minutes per epoch

In [20]:
def my_kNN(net, K, sigma=0.1, dim=128, use_pca=False):
    net.eval()
    # this part is ugly but made to be backwards-compatible. there was a change in cifar dataset's structure.
    trainLabels = y_train
    LEN = N
    C = trainLabels.max() + 1

    trainFeatures = torch.zeros((magic_dim, LEN))  # , device='cuda:0') # dim
    normalize = Normalize()
    for batch_idx, (inputs, targets, _) in enumerate(iterate_minibatches(X_train, y_train, batch_size, shuffle=False)):
        batchSize = batch_size
        inputs = inputs.cuda()
        features = net(inputs.float())
        if not use_pca:
            features = normalize(features)
        tmp = trainFeatures[:, batch_idx * batchSize:batch_idx * batchSize + batchSize]
        trainFeatures[:, batch_idx * batchSize:batch_idx * batchSize + batchSize] = features.data.t().cpu()
        
    if use_pca:
        comps = 128
        print('doing PCA with %s components'%comps, end=' ')
        from sklearn.decomposition import PCA
        pca = PCA(n_components=comps, whiten=False)
        trainFeatures = pca.fit_transform(trainFeatures.numpy().T)
        trainFeatures = torch.Tensor(trainFeatures)
        trainFeatures = normalize(trainFeatures).t()
        print('..done')
    def eval_k_s(K_,sigma_):
        total = 0
        top1 = 0.
        top5 = 0.

        with torch.no_grad():
            retrieval_one_hot = torch.zeros(K_, C)# .cuda()
            for batch_idx, (inputs, targets, _) in enumerate(iterate_minibatches(X_test, y_test, batch_size, shuffle=True)):
                targets = targets # .cuda(async=True) # or without async for py3.7
                inputs = inputs.cuda()
                batchSize = batch_size
                features = net(inputs)
                if use_pca:
                    features = pca.transform(features.cpu().numpy())
                    features = torch.Tensor(features).cuda()
                features = normalize(features).cpu()

                dist = torch.mm(features, trainFeatures)

                yd, yi = dist.topk(K_, dim=1, largest=True, sorted=True)
                candidates = trainLabels.view(1, -1).expand(batchSize, -1)
                retrieval = torch.gather(candidates, 1, yi).long()

                retrieval_one_hot.resize_(batchSize * K_, C).zero_()
                retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1.)
                
                yd_transform = yd.clone().div_(sigma_).exp_()
                probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1, C),
                                            yd_transform.view(batchSize, -1, 1)),
                                  1)
                _, predictions = probs.sort(1, True)

                # Find which predictions match the target
                correct = predictions.eq(targets.data.view(-1, 1))

                top1 = top1 + correct.narrow(1, 0, 1).sum().item()
                top5 = top5 + correct.narrow(1, 0, 5).sum().item()

                total += targets.size(0)

        print(f"{K_}-NN,s={sigma_}: TOP1: ", top1 * 100. / total)
        return top1 / total

    if isinstance(K, list):
        res = []
        for K_ in K:
            for sigma_ in sigma:
                res.append(eval_k_s(K_, sigma_))
        return res
    else:
        res = eval_k_s(K, sigma)
        return res

In [None]:
import time

start = time.time()
for epoch in range(start_epoch, start_epoch + epochs):
    selflabels = train(epoch, selflabels)
    feature_return_switch(model, True)
    
    acc = my_kNN(model, K=10, sigma=0.1, dim=knn_dim)
    feature_return_switch(model, False)
#     writer.add_scalar("accuracy kNN", acc, epoch)
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'opt': optimizer.state_dict(),
            'L': selflabels,
        }
        if not os.path.isdir(exp):
            os.mkdir(exp)
        torch.save(state, '%s/best_ckpt.t7' % (exp))
        best_acc = acc
    if epoch % 400 == 0:
        print('Saving..')
        state = {
            'net': model.state_dict(),
            'opt': optimizer.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'L': selflabels,
        }
        if not os.path.isdir(exp):
            os.mkdir(exp)
        torch.save(state, '%s/ep%s.t7' % (exp, epoch))
    if epoch % 50 == 0:
        feature_return_switch(model, True)
        acc = my_kNN(model, K=[50, 10], sigma=[0.1, 0.5], dim=knn_dim, use_pca=True)
        i = 0
#         for num_nn in [50, 10]:
#             for sig in [0.1, 0.5]:
#                 writer.add_scalar('knn%s-%s' % (num_nn, sig), acc[i], epoch)
#                 i += 1
        feature_return_switch(model, False)
    print('best accuracy: {:.2f}'.format(best_acc * 100))
end = time.time()

checkpoint = torch.load('%s'%exp+'/best_ckpt.t7' )
model.load_state_dict(checkpoint['net'])
feature_return_switch(model, True)
acc = my_kNN(model, K=10, sigma=0.1, dim=knn_dim, use_pca=True)


Epoch: 0
ResNet1D
error:  0.009233038617996403 step  31
cost:  1.3887578024021057
opt took 0.00min,   31iters
Epoch: [0][0/2459]Time: 0.332 (0.332) Data: 0.259 (0.259) Loss: 1.9297 (1.9297)
Epoch: [0][1/2459]Time: 0.045 (0.189) Data: 0.005 (0.132) Loss: 1.8533 (1.8915)
Epoch: [0][2/2459]Time: 0.027 (0.135) Data: 0.001 (0.088) Loss: 1.9103 (1.8977)
Epoch: [0][3/2459]Time: 0.026 (0.108) Data: 0.001 (0.066) Loss: 1.9746 (1.9169)
Epoch: [0][4/2459]Time: 0.030 (0.092) Data: 0.001 (0.053) Loss: 1.8865 (1.9109)
Epoch: [0][5/2459]Time: 0.026 (0.081) Data: 0.001 (0.045) Loss: 1.8449 (1.8999)
Epoch: [0][6/2459]Time: 0.026 (0.073) Data: 0.001 (0.038) Loss: 1.9727 (1.9103)
error:  0.002349356032679273 step  61
cost:  1.2761258430396636
opt took 0.00min,   61iters
Epoch: [0][7/2459]Time: 0.264 (0.097) Data: 0.232 (0.063) Loss: 1.8981 (1.9088)
Epoch: [0][8/2459]Time: 0.033 (0.090) Data: 0.002 (0.056) Loss: 1.9163 (1.9096)
Epoch: [0][9/2459]Time: 0.036 (0.085) Data: 0.001 (0.050) Loss: 1.8840 (1.907

error:  0.005936685988586099 step  101
cost:  1.0970638845147211
opt took 0.00min,  101iters
Epoch: [2][8/2459]Time: 0.258 (0.296) Data: 0.226 (0.256) Loss: 1.4088 (1.4266)
error:  0.004356519482974108 step  71
cost:  0.9586815654145
opt took 0.00min,   71iters
Epoch: [2][9/2459]Time: 0.258 (0.292) Data: 0.219 (0.252) Loss: 1.4303 (1.4269)
error:  0.008313518381654283 step  111
cost:  0.9684723529938857
opt took 0.00min,  111iters
Epoch: [2][10/2459]Time: 0.277 (0.291) Data: 0.249 (0.252) Loss: 1.4249 (1.4267)
error:  0.004966955497098358 step  131
cost:  1.0221863550058954
opt took 0.00min,  131iters
Epoch: [2][11/2459]Time: 0.254 (0.288) Data: 0.226 (0.250) Loss: 1.3596 (1.4212)
error:  0.006326925245197623 step  151
cost:  0.9075002856469107
opt took 0.00min,  151iters
Epoch: [2][12/2459]Time: 0.344 (0.292) Data: 0.300 (0.254) Loss: 1.4489 (1.4233)
error:  0.009802542560975569 step  151
cost:  0.8150262143605131
opt took 0.00min,  151iters
Epoch: [2][13/2459]Time: 0.421 (0.301) Data

opt took 0.00min,   81iters
Epoch: [4][6/2459]Time: 0.245 (0.261) Data: 0.213 (0.215) Loss: 1.0231 (1.0500)
error:  0.008997757713394838 step  71
cost:  0.99930465774916
opt took 0.00min,   71iters
Epoch: [4][7/2459]Time: 0.273 (0.262) Data: 0.237 (0.218) Loss: 1.0029 (1.0441)
error:  0.004455771013267218 step  91
cost:  1.0385325461155621
opt took 0.00min,   91iters
Epoch: [4][8/2459]Time: 0.233 (0.259) Data: 0.200 (0.216) Loss: 1.0239 (1.0419)
error:  0.006858789826424294 step  101
cost:  1.0506866050301535
opt took 0.00min,  101iters
Epoch: [4][9/2459]Time: 0.222 (0.255) Data: 0.194 (0.214) Loss: 1.0458 (1.0423)
error:  0.005596943693829792 step  101
cost:  0.9419463060532046
opt took 0.00min,  101iters
Epoch: [4][10/2459]Time: 0.336 (0.263) Data: 0.309 (0.222) Loss: 1.0465 (1.0426)
error:  0.0032939098794758648 step  91
cost:  0.8605009498081678
opt took 0.00min,   91iters
Epoch: [4][11/2459]Time: 0.238 (0.261) Data: 0.203 (0.221) Loss: 1.0647 (1.0445)
error:  0.004176937422322613 

error:  0.004755004547699282 step  91
cost:  0.9834487317187092
opt took 0.00min,   91iters
Epoch: [6][5/2459]Time: 0.249 (0.256) Data: 0.222 (0.226) Loss: 0.7701 (0.7547)
error:  0.0035226190621376885 step  81
cost:  0.9447458100014119
opt took 0.00min,   81iters
Epoch: [6][6/2459]Time: 0.243 (0.254) Data: 0.188 (0.220) Loss: 0.7581 (0.7552)
error:  0.005213686955913066 step  81
cost:  1.0165045437873264
opt took 0.00min,   81iters
Epoch: [6][7/2459]Time: 0.243 (0.253) Data: 0.207 (0.219) Loss: 0.7291 (0.7519)
error:  0.007349195795170793 step  81
cost:  1.0204574451763175
opt took 0.00min,   81iters
Epoch: [6][8/2459]Time: 0.279 (0.256) Data: 0.226 (0.219) Loss: 0.7637 (0.7532)
error:  0.005329195152325217 step  91
cost:  0.9136222752699101
opt took 0.00min,   91iters
Epoch: [6][9/2459]Time: 0.293 (0.259) Data: 0.257 (0.223) Loss: 0.7684 (0.7548)
error:  0.008256718692560616 step  91
cost:  0.8994790733731991
opt took 0.00min,   91iters
Epoch: [6][10/2459]Time: 0.346 (0.267) Data: 0.

Epoch: [8][3/2459]Time: 0.271 (0.268) Data: 0.220 (0.230) Loss: 0.5023 (0.4825)
error:  0.005201425937310211 step  81
cost:  0.7946014116208222
opt took 0.00min,   81iters
Epoch: [8][4/2459]Time: 0.257 (0.266) Data: 0.233 (0.231) Loss: 0.5724 (0.5005)
error:  0.008034367313097035 step  111
cost:  0.8100195269284365
opt took 0.00min,  111iters
Epoch: [8][5/2459]Time: 0.227 (0.259) Data: 0.199 (0.226) Loss: 0.5122 (0.5024)
error:  0.007153300921307948 step  121
cost:  0.8418743265353698
opt took 0.00min,  121iters
Epoch: [8][6/2459]Time: 0.322 (0.268) Data: 0.281 (0.233) Loss: 0.5110 (0.5036)
error:  0.006050751035810542 step  121
cost:  0.8537294611242168
opt took 0.00min,  121iters
Epoch: [8][7/2459]Time: 0.282 (0.270) Data: 0.244 (0.235) Loss: 0.5513 (0.5096)
error:  0.003943946528377706 step  121
cost:  0.7971371605609985
opt took 0.00min,  121iters
Epoch: [8][8/2459]Time: 0.287 (0.272) Data: 0.245 (0.236) Loss: 0.5357 (0.5125)
error:  0.007113244186013334 step  121
cost:  0.77168343

error:  0.009059753338954812 step  421
cost:  0.2538927867403954
opt took 0.00min,  421iters
Epoch: [10][2/2459]Time: 0.460 (0.625) Data: 0.421 (0.582) Loss: 0.4272 (0.4022)
error:  0.00834516760358861 step  351
cost:  0.26049631117048794
opt took 0.00min,  351iters
Epoch: [10][3/2459]Time: 0.502 (0.594) Data: 0.460 (0.552) Loss: 0.4380 (0.4112)
error:  0.008786111784765516 step  561
cost:  0.261746784950808
opt took 0.00min,  561iters
Epoch: [10][4/2459]Time: 0.472 (0.570) Data: 0.409 (0.523) Loss: 0.4058 (0.4101)
error:  0.009408713400304802 step  461
cost:  0.2614113445244886
opt took 0.00min,  461iters
Epoch: [10][5/2459]Time: 0.354 (0.534) Data: 0.310 (0.488) Loss: 0.3763 (0.4045)
error:  0.00929643125775137 step  441
cost:  0.26004961461618237
opt took 0.01min,  441iters
Epoch: [10][6/2459]Time: 0.838 (0.577) Data: 0.784 (0.530) Loss: 0.4295 (0.4080)
error:  0.009007748644344282 step  521
cost:  0.24868546818041054
opt took 0.00min,  521iters
Epoch: [10][7/2459]Time: 0.355 (0.549

error:  0.009902684696318387 step  701
cost:  0.18564568937164855
opt took 0.01min,  701iters
Epoch: [12][0/2459]Time: 0.727 (0.727) Data: 0.688 (0.688) Loss: 0.3261 (0.3261)
error:  0.00982441627565156 step  861
cost:  0.18109637010317794
opt took 0.01min,  861iters
Epoch: [12][1/2459]Time: 0.612 (0.669) Data: 0.555 (0.621) Loss: 0.2736 (0.2998)
error:  0.009385973709764128 step  931
cost:  0.17441269459864206
opt took 0.00min,  931iters
Epoch: [12][2/2459]Time: 0.585 (0.641) Data: 0.522 (0.588) Loss: 0.2808 (0.2935)
error:  0.009153853861750605 step  761
cost:  0.17824531988697045
opt took 0.00min,  761iters
Epoch: [12][3/2459]Time: 0.428 (0.588) Data: 0.400 (0.541) Loss: 0.3489 (0.3073)
error:  0.009942316549257924 step  611
cost:  0.17545675058796756
opt took 0.00min,  611iters
Epoch: [12][4/2459]Time: 0.542 (0.579) Data: 0.514 (0.536) Loss: 0.2904 (0.3040)
error:  0.009484231277990252 step  501
cost:  0.17574662691213017
opt took 0.00min,  501iters
Epoch: [12][5/2459]Time: 0.488 (

opt took 0.00min,  471iters
Epoch: [13][22/2459]Time: 0.434 (0.600) Data: 0.395 (0.559) Loss: 0.3472 (0.2805)
error:  0.009364897863807564 step  691
cost:  0.1683071986255244
opt took 0.00min,  691iters
Epoch: [13][23/2459]Time: 0.350 (0.590) Data: 0.273 (0.547) Loss: 0.3073 (0.2816)
10-NN,s=0.1: TOP1:  36.458333333333336
Saving..
best accuracy: 36.46

Epoch: 14
ResNet1D
error:  0.009632543458180431 step  841
cost:  0.15159144868361296
opt took 0.00min,  841iters
Epoch: [14][0/2459]Time: 0.399 (0.399) Data: 0.324 (0.324) Loss: 0.2312 (0.2312)
error:  0.009751416057215856 step  651
cost:  0.16009593488768775
opt took 0.02min,  651iters
Epoch: [14][1/2459]Time: 1.211 (0.805) Data: 1.184 (0.754) Loss: 0.2370 (0.2341)
error:  0.009777427605144329 step  641
cost:  0.16388368115226265
opt took 0.00min,  641iters
Epoch: [14][2/2459]Time: 0.439 (0.683) Data: 0.397 (0.635) Loss: 0.2656 (0.2446)
error:  0.009769990355174119 step  761
cost:  0.16597898112600548
opt took 0.00min,  761iters
Epoch: 

cost:  0.17948380758799562
opt took 0.00min,  461iters
Epoch: [15][20/2459]Time: 0.361 (0.425) Data: 0.316 (0.383) Loss: 0.2548 (0.2290)
error:  0.009028435432200443 step  301
cost:  0.1809573282544537
opt took 0.00min,  301iters
Epoch: [15][21/2459]Time: 0.279 (0.419) Data: 0.237 (0.377) Loss: 0.2123 (0.2282)
error:  0.00955421127465983 step  441
cost:  0.17458664054305248
opt took 0.00min,  441iters
Epoch: [15][22/2459]Time: 0.521 (0.423) Data: 0.494 (0.382) Loss: 0.2041 (0.2271)
error:  0.009580976012403042 step  881
cost:  0.17701471658735826
opt took 0.01min,  881iters
Epoch: [15][23/2459]Time: 1.112 (0.452) Data: 1.064 (0.410) Loss: 0.2770 (0.2292)
10-NN,s=0.1: TOP1:  35.708333333333336
best accuracy: 36.46

Epoch: 16
ResNet1D
error:  0.009843221756787757 step  641
cost:  0.17910370238236087
opt took 0.00min,  641iters
Epoch: [16][0/2459]Time: 0.334 (0.334) Data: 0.294 (0.294) Loss: 0.1803 (0.1803)
error:  0.009832465743447205 step  621
cost:  0.18150319887782249
opt took 0.00min

10-NN,s=0.1: TOP1:  35.833333333333336
best accuracy: 36.46

Epoch: 18
ResNet1D
Epoch: [18][0/2459]Time: 0.026 (0.026) Data: 0.001 (0.001) Loss: 0.1749 (0.1749)
Epoch: [18][1/2459]Time: 0.044 (0.035) Data: 0.001 (0.001) Loss: 0.1864 (0.1807)
Epoch: [18][2/2459]Time: 0.023 (0.031) Data: 0.001 (0.001) Loss: 0.1873 (0.1829)
Epoch: [18][3/2459]Time: 0.068 (0.040) Data: 0.001 (0.001) Loss: 0.2352 (0.1960)
Epoch: [18][4/2459]Time: 0.075 (0.047) Data: 0.001 (0.001) Loss: 0.1815 (0.1931)
Epoch: [18][5/2459]Time: 0.026 (0.044) Data: 0.001 (0.001) Loss: 0.1815 (0.1911)
Epoch: [18][6/2459]Time: 0.025 (0.041) Data: 0.001 (0.001) Loss: 0.1632 (0.1871)
Epoch: [18][7/2459]Time: 0.025 (0.039) Data: 0.001 (0.001) Loss: 0.1493 (0.1824)
Epoch: [18][8/2459]Time: 0.024 (0.037) Data: 0.001 (0.001) Loss: 0.1799 (0.1821)
Epoch: [18][9/2459]Time: 0.023 (0.036) Data: 0.001 (0.001) Loss: 0.1897 (0.1829)
Epoch: [18][10/2459]Time: 0.023 (0.035) Data: 0.001 (0.001) Loss: 0.1812 (0.1827)
Epoch: [18][11/2459]Time: 0.

Epoch: [22][8/2459]Time: 0.023 (0.025) Data: 0.001 (0.001) Loss: 0.0883 (0.0710)
Epoch: [22][9/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.1212 (0.0760)
Epoch: [22][10/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0756 (0.0760)
Epoch: [22][11/2459]Time: 0.046 (0.026) Data: 0.001 (0.001) Loss: 0.0887 (0.0771)
Epoch: [22][12/2459]Time: 0.025 (0.026) Data: 0.001 (0.001) Loss: 0.0823 (0.0775)
Epoch: [22][13/2459]Time: 0.065 (0.029) Data: 0.001 (0.001) Loss: 0.0833 (0.0779)
Epoch: [22][14/2459]Time: 0.076 (0.032) Data: 0.001 (0.001) Loss: 0.0810 (0.0781)
Epoch: [22][15/2459]Time: 0.024 (0.031) Data: 0.001 (0.001) Loss: 0.0742 (0.0778)
Epoch: [22][16/2459]Time: 0.023 (0.031) Data: 0.001 (0.001) Loss: 0.0739 (0.0776)
Epoch: [22][17/2459]Time: 0.022 (0.030) Data: 0.001 (0.001) Loss: 0.0764 (0.0775)
Epoch: [22][18/2459]Time: 0.023 (0.030) Data: 0.001 (0.001) Loss: 0.0857 (0.0780)
Epoch: [22][19/2459]Time: 0.023 (0.030) Data: 0.001 (0.001) Loss: 0.0971 (0.0789)
Epoch: [22][20/245

Epoch: [26][13/2459]Time: 0.028 (0.034) Data: 0.001 (0.001) Loss: 0.0448 (0.0551)
Epoch: [26][14/2459]Time: 0.027 (0.034) Data: 0.001 (0.001) Loss: 0.0454 (0.0545)
Epoch: [26][15/2459]Time: 0.048 (0.035) Data: 0.001 (0.001) Loss: 0.0410 (0.0536)
Epoch: [26][16/2459]Time: 0.028 (0.034) Data: 0.001 (0.001) Loss: 0.0590 (0.0539)
Epoch: [26][17/2459]Time: 0.063 (0.036) Data: 0.001 (0.001) Loss: 0.0449 (0.0534)
Epoch: [26][18/2459]Time: 0.075 (0.038) Data: 0.001 (0.001) Loss: 0.0484 (0.0532)
Epoch: [26][19/2459]Time: 0.033 (0.038) Data: 0.001 (0.001) Loss: 0.0437 (0.0527)
Epoch: [26][20/2459]Time: 0.027 (0.037) Data: 0.001 (0.001) Loss: 0.0430 (0.0522)
Epoch: [26][21/2459]Time: 0.027 (0.037) Data: 0.001 (0.001) Loss: 0.0496 (0.0521)
Epoch: [26][22/2459]Time: 0.026 (0.036) Data: 0.001 (0.001) Loss: 0.0547 (0.0522)
Epoch: [26][23/2459]Time: 0.024 (0.036) Data: 0.001 (0.001) Loss: 0.0522 (0.0522)
10-NN,s=0.1: TOP1:  35.416666666666664
best accuracy: 37.12

Epoch: 27
ResNet1D
Epoch: [27][0/2459

Epoch: [30][20/2459]Time: 0.045 (0.031) Data: 0.001 (0.001) Loss: 0.0454 (0.0378)
Epoch: [30][21/2459]Time: 0.026 (0.031) Data: 0.001 (0.001) Loss: 0.0397 (0.0379)
Epoch: [30][22/2459]Time: 0.064 (0.033) Data: 0.001 (0.001) Loss: 0.0330 (0.0377)
Epoch: [30][23/2459]Time: 0.074 (0.034) Data: 0.001 (0.001) Loss: 0.0445 (0.0380)
10-NN,s=0.1: TOP1:  36.833333333333336
best accuracy: 37.17

Epoch: 31
ResNet1D
Epoch: [31][0/2459]Time: 0.042 (0.042) Data: 0.001 (0.001) Loss: 0.0358 (0.0358)
Epoch: [31][1/2459]Time: 0.065 (0.054) Data: 0.001 (0.001) Loss: 0.0294 (0.0326)
Epoch: [31][2/2459]Time: 0.026 (0.044) Data: 0.001 (0.001) Loss: 0.0435 (0.0362)
Epoch: [31][3/2459]Time: 0.026 (0.040) Data: 0.001 (0.001) Loss: 0.0265 (0.0338)
Epoch: [31][4/2459]Time: 0.026 (0.037) Data: 0.001 (0.001) Loss: 0.0344 (0.0339)
Epoch: [31][5/2459]Time: 0.025 (0.035) Data: 0.001 (0.001) Loss: 0.0293 (0.0332)
Epoch: [31][6/2459]Time: 0.025 (0.034) Data: 0.001 (0.001) Loss: 0.0508 (0.0357)
Epoch: [31][7/2459]Time: 

Epoch: [34][21/2459]Time: 0.034 (0.029) Data: 0.001 (0.001) Loss: 0.0286 (0.0255)
Epoch: [34][22/2459]Time: 0.043 (0.030) Data: 0.001 (0.001) Loss: 0.0262 (0.0256)
Epoch: [34][23/2459]Time: 0.034 (0.030) Data: 0.002 (0.001) Loss: 0.0243 (0.0255)
10-NN,s=0.1: TOP1:  37.0
best accuracy: 37.21

Epoch: 35
ResNet1D
Epoch: [35][0/2459]Time: 0.023 (0.023) Data: 0.002 (0.002) Loss: 0.0210 (0.0210)
Epoch: [35][1/2459]Time: 0.023 (0.023) Data: 0.001 (0.001) Loss: 0.0376 (0.0293)
Epoch: [35][2/2459]Time: 0.023 (0.023) Data: 0.001 (0.001) Loss: 0.0231 (0.0272)
Epoch: [35][3/2459]Time: 0.032 (0.025) Data: 0.001 (0.001) Loss: 0.0273 (0.0272)
Epoch: [35][4/2459]Time: 0.023 (0.025) Data: 0.001 (0.001) Loss: 0.0300 (0.0278)
Epoch: [35][5/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0247 (0.0273)
Epoch: [35][6/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0495 (0.0305)
Epoch: [35][7/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0247 (0.0297)
Epoch: [35][8/2459]Time: 0.023 (0.024) D

10-NN,s=0.1: TOP1:  36.708333333333336
best accuracy: 37.21

Epoch: 39
ResNet1D
Epoch: [39][0/2459]Time: 0.023 (0.023) Data: 0.002 (0.002) Loss: 0.0243 (0.0243)
Epoch: [39][1/2459]Time: 0.022 (0.023) Data: 0.001 (0.001) Loss: 0.0215 (0.0229)
Epoch: [39][2/2459]Time: 0.030 (0.025) Data: 0.001 (0.001) Loss: 0.0209 (0.0222)
Epoch: [39][3/2459]Time: 0.022 (0.025) Data: 0.001 (0.001) Loss: 0.0243 (0.0228)
Epoch: [39][4/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0218 (0.0226)
Epoch: [39][5/2459]Time: 0.024 (0.024) Data: 0.001 (0.001) Loss: 0.0229 (0.0226)
Epoch: [39][6/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0266 (0.0232)
Epoch: [39][7/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0239 (0.0233)
Epoch: [39][8/2459]Time: 0.034 (0.025) Data: 0.001 (0.001) Loss: 0.0228 (0.0232)
Epoch: [39][9/2459]Time: 0.023 (0.025) Data: 0.001 (0.001) Loss: 0.0207 (0.0230)
Epoch: [39][10/2459]Time: 0.022 (0.024) Data: 0.001 (0.001) Loss: 0.0306 (0.0237)
Epoch: [39][11/2459]Time: 0.

Epoch: [43][8/2459]Time: 0.024 (0.025) Data: 0.001 (0.001) Loss: 0.0172 (0.0197)
Epoch: [43][9/2459]Time: 0.024 (0.025) Data: 0.001 (0.001) Loss: 0.0247 (0.0202)
Epoch: [43][10/2459]Time: 0.024 (0.025) Data: 0.001 (0.001) Loss: 0.0171 (0.0199)
Epoch: [43][11/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0173 (0.0197)
Epoch: [43][12/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0187 (0.0196)
Epoch: [43][13/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0191 (0.0196)
Epoch: [43][14/2459]Time: 0.036 (0.025) Data: 0.001 (0.001) Loss: 0.0212 (0.0197)
Epoch: [43][15/2459]Time: 0.066 (0.028) Data: 0.001 (0.001) Loss: 0.0193 (0.0197)
Epoch: [43][16/2459]Time: 0.083 (0.031) Data: 0.026 (0.003) Loss: 0.0358 (0.0206)
Epoch: [43][17/2459]Time: 0.035 (0.031) Data: 0.001 (0.003) Loss: 0.0193 (0.0205)
Epoch: [43][18/2459]Time: 0.053 (0.032) Data: 0.008 (0.003) Loss: 0.0233 (0.0207)
Epoch: [43][19/2459]Time: 0.040 (0.033) Data: 0.009 (0.003) Loss: 0.0281 (0.0211)
Epoch: [43][20/245

Epoch: [47][17/2459]Time: 0.024 (0.024) Data: 0.001 (0.001) Loss: 0.0179 (0.0183)
Epoch: [47][18/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0198 (0.0184)
Epoch: [47][19/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0304 (0.0190)
Epoch: [47][20/2459]Time: 0.031 (0.024) Data: 0.001 (0.001) Loss: 0.0201 (0.0191)
Epoch: [47][21/2459]Time: 0.024 (0.024) Data: 0.001 (0.001) Loss: 0.0178 (0.0190)
Epoch: [47][22/2459]Time: 0.022 (0.024) Data: 0.001 (0.001) Loss: 0.0189 (0.0190)
Epoch: [47][23/2459]Time: 0.022 (0.024) Data: 0.001 (0.001) Loss: 0.0196 (0.0190)
10-NN,s=0.1: TOP1:  36.916666666666664
best accuracy: 37.21

Epoch: 48
ResNet1D
Epoch: [48][0/2459]Time: 0.023 (0.023) Data: 0.002 (0.002) Loss: 0.0334 (0.0334)
Epoch: [48][1/2459]Time: 0.022 (0.023) Data: 0.001 (0.001) Loss: 0.0155 (0.0244)
Epoch: [48][2/2459]Time: 0.022 (0.022) Data: 0.001 (0.001) Loss: 0.0168 (0.0219)
Epoch: [48][3/2459]Time: 0.022 (0.022) Data: 0.001 (0.001) Loss: 0.0199 (0.0214)
Epoch: [48][4/2459]Tim

Epoch: [51][19/2459]Time: 0.025 (0.032) Data: 0.001 (0.001) Loss: 0.0139 (0.0158)
Epoch: [51][20/2459]Time: 0.025 (0.031) Data: 0.001 (0.001) Loss: 0.0136 (0.0157)
Epoch: [51][21/2459]Time: 0.024 (0.031) Data: 0.001 (0.001) Loss: 0.0171 (0.0157)
Epoch: [51][22/2459]Time: 0.030 (0.031) Data: 0.001 (0.001) Loss: 0.0153 (0.0157)
Epoch: [51][23/2459]Time: 0.025 (0.031) Data: 0.001 (0.001) Loss: 0.0212 (0.0159)
10-NN,s=0.1: TOP1:  37.125
best accuracy: 37.21

Epoch: 52
ResNet1D
Epoch: [52][0/2459]Time: 0.023 (0.023) Data: 0.001 (0.001) Loss: 0.0110 (0.0110)
Epoch: [52][1/2459]Time: 0.022 (0.022) Data: 0.001 (0.001) Loss: 0.0142 (0.0126)
Epoch: [52][2/2459]Time: 0.022 (0.022) Data: 0.001 (0.001) Loss: 0.0125 (0.0126)
Epoch: [52][3/2459]Time: 0.022 (0.022) Data: 0.001 (0.001) Loss: 0.0145 (0.0131)
Epoch: [52][4/2459]Time: 0.022 (0.022) Data: 0.001 (0.001) Loss: 0.0149 (0.0134)
Epoch: [52][5/2459]Time: 0.022 (0.022) Data: 0.001 (0.001) Loss: 0.0140 (0.0135)
Epoch: [52][6/2459]Time: 0.036 (0.02

Epoch: [55][20/2459]Time: 0.048 (0.032) Data: 0.001 (0.001) Loss: 0.0143 (0.0147)
Epoch: [55][21/2459]Time: 0.028 (0.032) Data: 0.002 (0.001) Loss: 0.0170 (0.0148)
Epoch: [55][22/2459]Time: 0.065 (0.034) Data: 0.001 (0.001) Loss: 0.0165 (0.0149)
Epoch: [55][23/2459]Time: 0.074 (0.035) Data: 0.001 (0.001) Loss: 0.0142 (0.0149)
10-NN,s=0.1: TOP1:  35.791666666666664
best accuracy: 37.21

Epoch: 56
ResNet1D
Epoch: [56][0/2459]Time: 0.025 (0.025) Data: 0.001 (0.001) Loss: 0.0177 (0.0177)
Epoch: [56][1/2459]Time: 0.023 (0.024) Data: 0.001 (0.001) Loss: 0.0158 (0.0167)
Epoch: [56][2/2459]Time: 0.023 (0.023) Data: 0.001 (0.001) Loss: 0.0150 (0.0162)
Epoch: [56][3/2459]Time: 0.023 (0.023) Data: 0.001 (0.001) Loss: 0.0124 (0.0152)
Epoch: [56][4/2459]Time: 0.023 (0.023) Data: 0.001 (0.001) Loss: 0.0179 (0.0158)
Epoch: [56][5/2459]Time: 0.023 (0.023) Data: 0.001 (0.001) Loss: 0.0149 (0.0156)
Epoch: [56][6/2459]Time: 0.023 (0.023) Data: 0.001 (0.001) Loss: 0.0169 (0.0158)
Epoch: [56][7/2459]Time: 

10-NN,s=0.1: TOP1:  37.125
best accuracy: 37.62

Epoch: 60
ResNet1D
Epoch: [60][0/2459]Time: 0.030 (0.030) Data: 0.002 (0.002) Loss: 0.0188 (0.0188)
Epoch: [60][1/2459]Time: 0.025 (0.027) Data: 0.001 (0.002) Loss: 0.0130 (0.0159)
Epoch: [60][2/2459]Time: 0.025 (0.026) Data: 0.001 (0.001) Loss: 0.0110 (0.0143)
Epoch: [60][3/2459]Time: 0.025 (0.026) Data: 0.001 (0.001) Loss: 0.0125 (0.0138)
Epoch: [60][4/2459]Time: 0.025 (0.026) Data: 0.001 (0.001) Loss: 0.0145 (0.0140)
Epoch: [60][5/2459]Time: 0.024 (0.026) Data: 0.001 (0.001) Loss: 0.0109 (0.0135)
Epoch: [60][6/2459]Time: 0.024 (0.025) Data: 0.001 (0.001) Loss: 0.0112 (0.0131)
Epoch: [60][7/2459]Time: 0.024 (0.025) Data: 0.001 (0.001) Loss: 0.0137 (0.0132)
Epoch: [60][8/2459]Time: 0.024 (0.025) Data: 0.001 (0.001) Loss: 0.0147 (0.0134)
Epoch: [60][9/2459]Time: 0.048 (0.027) Data: 0.001 (0.001) Loss: 0.0111 (0.0131)
Epoch: [60][10/2459]Time: 0.023 (0.027) Data: 0.001 (0.001) Loss: 0.0180 (0.0136)
Epoch: [60][11/2459]Time: 0.070 (0.030) 

In [None]:
print (end-start)