In [1]:
import timm
import torch
import torchvision

torch.manual_seed(199)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x269599501f0>

In [2]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DTYPE = torch.float16

In [3]:
random_affine = torchvision.transforms.RandomAffine(degrees=20,
                                                    scale=(0.8, 1.2),
                                                    translate=(0.2, 0.2),
                                                    interpolation=torchvision.transforms.InterpolationMode.BICUBIC)

In [4]:
#torchvision.transforms.Resize((32,32))
transform1 = torchvision.transforms.Compose([
                                                torchvision.transforms.RandomCrop(32, padding=4),
                                                torchvision.transforms.Resize(224),
                                                torchvision.transforms.RandomHorizontalFlip(),
                                                torchvision.transforms.ToTensor(),
                                                torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                                torchvision.transforms.ConvertImageDtype(DTYPE)
                                             ])

transform2 = torchvision.transforms.Compose([
                                                torchvision.transforms.Resize(224),
                                                torchvision.transforms.ToTensor(),
                                                torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                                torchvision.transforms.ConvertImageDtype(DTYPE)
                                             ])

In [5]:
BATCH_SIZE = 75

epochs = 20

lr = 0.002
momentum = 0.9

In [6]:
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform1)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=8, pin_memory=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform2)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=8, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
net = timm.create_model("vit_base_patch16_224", pretrained=False)
net.head = torch.nn.Linear(net.head.in_features, 10)

net = net.train().to(device=DEVICE, dtype=DTYPE)

optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

loss_func = torch.nn.CrossEntropyLoss()

In [8]:
def tran(epoch):
    net.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()

    scheduler.step()
    print('[%d] loss: %.3f' % (epoch + 1, running_loss))

In [9]:
def test():
    net.eval()
    all_counter=0
    correct_counter=0
    for i, data in enumerate(testloader, 0):
        inputs, labels = data
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        out = net(inputs)
        out = out.detach().cpu().argmax(1)
        t = labels.cpu()
        for m in range(len(t)):
            all_counter += 1
            if t[m] == out[m]:
                correct_counter += 1

    print(correct_counter, all_counter, correct_counter / all_counter)
    return (correct_counter / all_counter)

In [10]:
optimizer.state_dict()['param_groups'][0]['lr']

0.002

In [11]:
correctRate = 0
for i in range(epochs):
    tran(i)
    r = test()
    if(r > correctRate):
        correctRate = r
        print("best: ", r , " in NO: ", i)
        torch.save(net.cpu(),"checkpoint/vit.pth")
        net = net.to(DEVICE)

[1] loss: 1223.480
3815 10000 0.3815
best:  0.3815  in NO:  0
[2] loss: 1058.729
4300 10000 0.43
best:  0.43  in NO:  1
[3] loss: 977.214
4692 10000 0.4692
best:  0.4692  in NO:  2
[4] loss: 933.788
4768 10000 0.4768
best:  0.4768  in NO:  3
[5] loss: 882.909
4629 10000 0.4629
[6] loss: 860.270
5286 10000 0.5286
best:  0.5286  in NO:  5
[7] loss: 828.134
5467 10000 0.5467
best:  0.5467  in NO:  6
[8] loss: 808.576
5256 10000 0.5256
[9] loss: 778.805
5656 10000 0.5656
best:  0.5656  in NO:  8
[10] loss: 767.952
5525 10000 0.5525
[11] loss: 740.688
5695 10000 0.5695
best:  0.5695  in NO:  10
[12] loss: 730.064
5757 10000 0.5757
best:  0.5757  in NO:  11
[13] loss: 715.108
5926 10000 0.5926
best:  0.5926  in NO:  12
[14] loss: 703.548
6103 10000 0.6103
best:  0.6103  in NO:  13
[15] loss: 682.301
6021 10000 0.6021
[16] loss: 674.952
6075 10000 0.6075
[17] loss: 654.547
6123 10000 0.6123
best:  0.6123  in NO:  16
[18] loss: 650.505
6215 10000 0.6215
best:  0.6215  in NO:  17
[19] loss: 636

In [12]:
'''

[1] loss: 115.806
9717 10000 0.9717
best:  0.9717  in NO:  0
[2] loss: 36.799
9803 10000 0.9803
best:  0.9803  in NO:  1
[3] loss: 20.298
9802 10000 0.9802
[4] loss: 13.308
9797 10000 0.9797
[5] loss: 7.838
9827 10000 0.9827
best:  0.9827  in NO:  4
[6] loss: 8.158
9834 10000 0.9834
best:  0.9834  in NO:  5
[7] loss: 5.108
9851 10000 0.9851
best:  0.9851  in NO:  6
[8] loss: 3.506
9844 10000 0.9844
[9] loss: 3.396
9867 10000 0.9867
best:  0.9867  in NO:  8
[10] loss: 2.826
9867 10000 0.9867
[11] loss: 2.288
9860 10000 0.986
[12] loss: 1.877
9881 10000 0.9881
best:  0.9881  in NO:  11
[13] loss: 0.738
9889 10000 0.9889
best:  0.9889  in NO:  12
[14] loss: 0.774
9890 10000 0.989
best:  0.989  in NO:  13
[15] loss: 0.704
9886 10000 0.9886
[16] loss: 0.604
9882 10000 0.9882
[17] loss: 0.461
9884 10000 0.9884
[18] loss: 0.577
9875 10000 0.9875
[19] loss: 0.517
9890 10000 0.989
[20] loss: 0.407
9887 10000 0.9887

'''

'\n\n[1] loss: 115.806\n9717 10000 0.9717\nbest:  0.9717  in NO:  0\n[2] loss: 36.799\n9803 10000 0.9803\nbest:  0.9803  in NO:  1\n[3] loss: 20.298\n9802 10000 0.9802\n[4] loss: 13.308\n9797 10000 0.9797\n[5] loss: 7.838\n9827 10000 0.9827\nbest:  0.9827  in NO:  4\n[6] loss: 8.158\n9834 10000 0.9834\nbest:  0.9834  in NO:  5\n[7] loss: 5.108\n9851 10000 0.9851\nbest:  0.9851  in NO:  6\n[8] loss: 3.506\n9844 10000 0.9844\n[9] loss: 3.396\n9867 10000 0.9867\nbest:  0.9867  in NO:  8\n[10] loss: 2.826\n9867 10000 0.9867\n[11] loss: 2.288\n9860 10000 0.986\n[12] loss: 1.877\n9881 10000 0.9881\nbest:  0.9881  in NO:  11\n[13] loss: 0.738\n9889 10000 0.9889\nbest:  0.9889  in NO:  12\n[14] loss: 0.774\n9890 10000 0.989\nbest:  0.989  in NO:  13\n[15] loss: 0.704\n9886 10000 0.9886\n[16] loss: 0.604\n9882 10000 0.9882\n[17] loss: 0.461\n9884 10000 0.9884\n[18] loss: 0.577\n9875 10000 0.9875\n[19] loss: 0.517\n9890 10000 0.989\n[20] loss: 0.407\n9887 10000 0.9887\n\n'

In [13]:
correctRate = 0.9887
for i in range(epochs):
    tran(i)
    r = test()
    if(r > correctRate):
        correctRate = r
        print("best: ", r , " in NO: ", i)
        torch.save(net.cpu(),"checkpoint/vit.pth")
        net = net.to(DEVICE)

[1] loss: 610.944
6134 10000 0.6134
[2] loss: 603.417
6163 10000 0.6163
[3] loss: 586.596
6356 10000 0.6356
[4] loss: 582.059
6273 10000 0.6273
[5] loss: 569.641
6442 10000 0.6442
[6] loss: 563.017
6404 10000 0.6404
[7] loss: 549.224
6416 10000 0.6416
[8] loss: 544.415
6355 10000 0.6355
[9] loss: 529.503
6422 10000 0.6422
[10] loss: 525.560
6437 10000 0.6437
[11] loss: 510.370
6380 10000 0.638
[12] loss: 508.068
6311 10000 0.6311
[13] loss: 497.575
6562 10000 0.6562
[14] loss: 489.541
6464 10000 0.6464
[15] loss: 474.975
6524 10000 0.6524
[16] loss: 474.075
6491 10000 0.6491
[17] loss: 459.201
6528 10000 0.6528
[18] loss: 457.192
6597 10000 0.6597
[19] loss: 447.351
6523 10000 0.6523
[20] loss: 441.142
6504 10000 0.6504


In [14]:
'''

[1] loss: 0.287
9892 10000 0.9892
best:  0.9892  in NO:  0
[2] loss: 0.350
9889 10000 0.9889
[3] loss: 0.209
9884 10000 0.9884
[4] loss: 0.173
9886 10000 0.9886
[5] loss: 0.237
9888 10000 0.9888
[6] loss: 0.228
9892 10000 0.9892
[7] loss: 0.237
9896 10000 0.9896
best:  0.9896  in NO:  6
[8] loss: 0.193
9890 10000 0.989
[9] loss: 0.188
9892 10000 0.9892
[10] loss: 0.276
9895 10000 0.9895
[11] loss: 0.191
9897 10000 0.9897
best:  0.9897  in NO:  10
[12] loss: 0.311
9894 10000 0.9894
[13] loss: 0.165
9896 10000 0.9896
[14] loss: 0.178
9892 10000 0.9892
[15] loss: 0.142
9893 10000 0.9893
[16] loss: 0.119
9892 10000 0.9892
[17] loss: 0.182
9895 10000 0.9895
[18] loss: 0.154
9895 10000 0.9895
[19] loss: 0.198
9893 10000 0.9893
[20] loss: 0.162
9895 10000 0.9895

'''

'\n\n[1] loss: 0.287\n9892 10000 0.9892\nbest:  0.9892  in NO:  0\n[2] loss: 0.350\n9889 10000 0.9889\n[3] loss: 0.209\n9884 10000 0.9884\n[4] loss: 0.173\n9886 10000 0.9886\n[5] loss: 0.237\n9888 10000 0.9888\n[6] loss: 0.228\n9892 10000 0.9892\n[7] loss: 0.237\n9896 10000 0.9896\nbest:  0.9896  in NO:  6\n[8] loss: 0.193\n9890 10000 0.989\n[9] loss: 0.188\n9892 10000 0.9892\n[10] loss: 0.276\n9895 10000 0.9895\n[11] loss: 0.191\n9897 10000 0.9897\nbest:  0.9897  in NO:  10\n[12] loss: 0.311\n9894 10000 0.9894\n[13] loss: 0.165\n9896 10000 0.9896\n[14] loss: 0.178\n9892 10000 0.9892\n[15] loss: 0.142\n9893 10000 0.9893\n[16] loss: 0.119\n9892 10000 0.9892\n[17] loss: 0.182\n9895 10000 0.9895\n[18] loss: 0.154\n9895 10000 0.9895\n[19] loss: 0.198\n9893 10000 0.9893\n[20] loss: 0.162\n9895 10000 0.9895\n\n'

In [15]:
print(net.cpu())

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop2): Dropout(p=0.0, inplace=F

In [16]:
print(net.state_dict())

OrderedDict([('cls_token', tensor([[[-1.5221e-02,  3.7964e-02,  5.8022e-03,  3.6774e-02,  1.6602e-02,
           2.8564e-02, -1.5335e-02, -6.6986e-03,  1.1978e-02, -2.6001e-02,
           7.3051e-03, -2.5253e-02,  1.2543e-02, -1.7166e-02, -6.7024e-03,
          -2.2614e-02, -3.1952e-02, -2.5139e-03,  4.0466e-02,  1.4107e-02,
          -1.4709e-02, -3.7323e-02, -2.3087e-02, -1.3893e-02, -2.0538e-02,
           8.8043e-03, -3.0502e-02,  1.0376e-02, -3.2867e-02,  1.9058e-02,
          -2.6108e-02,  3.0228e-02,  2.0279e-02,  3.2387e-03,  1.2047e-02,
           3.7323e-02, -3.2104e-02, -7.0686e-03, -1.3588e-02, -3.2501e-02,
           1.3435e-02,  5.5618e-03,  2.9678e-02, -1.6953e-02, -3.8185e-03,
           2.1362e-02,  5.0323e-02,  3.0502e-02, -2.9800e-02,  3.5919e-02,
          -1.1147e-02,  2.3209e-02,  2.1362e-02, -1.9836e-02, -1.1536e-02,
          -2.7786e-02,  2.5345e-02,  3.6438e-02, -1.5686e-02, -2.9068e-02,
          -2.0676e-02, -2.6367e-02,  1.2512e-02, -1.1154e-02,  3.0182e-02

In [17]:
#torch.cuda.empty_cache()

In [1]:
import torch
import torchvision

torch.manual_seed(199)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DTYPE = torch.float32

BATCH_SIZE = 256

#torchvision.transforms.Resize((32,32))
transform1 = torchvision.transforms.Compose([
                                                torchvision.transforms.RandomCrop(32, padding=4),
                                                torchvision.transforms.RandomHorizontalFlip(),
                                                torchvision.transforms.ToTensor(),
                                                torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                                torchvision.transforms.ConvertImageDtype(DTYPE)
                                             ])

transform2 = torchvision.transforms.Compose([
                                                torchvision.transforms.ToTensor(),
                                                torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                                torchvision.transforms.ConvertImageDtype(DTYPE)
                                             ])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform1)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=8, pin_memory=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform2)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=8, pin_memory=True)

  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
Files already downloaded and verified


In [2]:
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    _, h, w, dim, device, dtype = *(patches.shape), patches.device, patches.dtype

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        *_, h, w, dtype = *(img.shape), img.dtype

        x = self.to_patch_embedding(img)
        pe = posemb_sincos_2d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

In [3]:


net2 = ViT(
        image_size=32,
        patch_size=8,
        num_classes=10,
        dim=256,
        depth=8,
        heads=6,
        mlp_dim=128
)

'''

net2 = ViT(
        image_size=32,
        patch_size=8,
        num_classes=10,
        dim=128,
        depth=4,
        heads=4,
        mlp_dim=64
)

'''

'\n\nnet2 = ViT(\n        image_size=32,\n        patch_size=8,\n        num_classes=10,\n        dim=128,\n        depth=4,\n        heads=4,\n        mlp_dim=64\n)\n\n'

In [9]:
epochs = 50

lr = 0.1
momentum = 0.9

In [10]:
net2 = net2.train().to(device=DEVICE, dtype=DTYPE)

optimizer = torch.optim.SGD(net2.parameters(), lr=lr, momentum=momentum)
#optimizer = torch.optim.Adam(net2.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

loss_func = torch.nn.CrossEntropyLoss()

In [11]:
def tran2(epoch):
    net2.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net2(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()

    scheduler.step()
    print('[%d] loss: %.3f' % (epoch + 1, running_loss))

In [12]:
def test2():
    net2.eval()
    all_counter=0
    correct_counter=0
    for i, data in enumerate(testloader, 0):
        inputs, labels = data
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        out = net2(inputs)
        out = out.detach().cpu().argmax(1)
        t = labels.cpu()
        for m in range(len(t)):
            all_counter += 1
            if t[m] == out[m]:
                correct_counter += 1

    print(correct_counter, all_counter, correct_counter / all_counter)
    return (correct_counter / all_counter)

In [13]:
correctRate = 0
for i in range(epochs):
    tran2(i)
    r = test2()
    if(r > correctRate):
        correctRate = r
        print("best: ", r , " in NO: ", i)
        torch.save(net2.cpu(),"checkpoint/vits.pth")
        net2 = net2.to(DEVICE)


[1] loss: 245.179
5699 10000 0.5699
best:  0.5699  in NO:  0
[2] loss: 223.761
5794 10000 0.5794
best:  0.5794  in NO:  1
[3] loss: 209.334
5956 10000 0.5956
best:  0.5956  in NO:  2
[4] loss: 203.727
6125 10000 0.6125
best:  0.6125  in NO:  3
[5] loss: 195.133
6323 10000 0.6323
best:  0.6323  in NO:  4
[6] loss: 189.139
6490 10000 0.649
best:  0.649  in NO:  5
[7] loss: 181.537
6469 10000 0.6469
[8] loss: 179.389
6654 10000 0.6654
best:  0.6654  in NO:  7
[9] loss: 172.475
6684 10000 0.6684
best:  0.6684  in NO:  8
[10] loss: 170.755
6735 10000 0.6735
best:  0.6735  in NO:  9
[11] loss: 165.952
6731 10000 0.6731
[12] loss: 163.411
6867 10000 0.6867
best:  0.6867  in NO:  11
[13] loss: 157.752
6869 10000 0.6869
best:  0.6869  in NO:  12
[14] loss: 156.063
6861 10000 0.6861
[15] loss: 151.729
6943 10000 0.6943
best:  0.6943  in NO:  14
[16] loss: 149.477
6958 10000 0.6958
best:  0.6958  in NO:  15
[17] loss: 146.276
7014 10000 0.7014
best:  0.7014  in NO:  16
[18] loss: 143.296
7051 100

In [None]:
'''

lr = 0.05

"[1] loss: 404.931\n",
      "3198 10000 0.3198\n",
      "best:  0.3198  in NO:  0\n",
      "[2] loss: 341.681\n",
      "3834 10000 0.3834\n",
      "best:  0.3834  in NO:  1\n",
      "[3] loss: 317.256\n",
      "3991 10000 0.3991\n",
      "best:  0.3991  in NO:  2\n",
      "[4] loss: 302.169\n",
      "4393 10000 0.4393\n",
      "best:  0.4393  in NO:  3\n",
      "[5] loss: 289.559\n",
      "4458 10000 0.4458\n",
      "best:  0.4458  in NO:  4\n",
      "[6] loss: 277.654\n",
      "4842 10000 0.4842\n",
      "best:  0.4842  in NO:  5\n",
      "[7] loss: 268.890\n",
      "5020 10000 0.502\n",
      "best:  0.502  in NO:  6\n",
      "[8] loss: 261.052\n",
      "5168 10000 0.5168\n",
      "best:  0.5168  in NO:  7\n",
      "[9] loss: 251.862\n",
      "5387 10000 0.5387\n",
      "best:  0.5387  in NO:  8\n",
      "[10] loss: 248.690\n",
      "5337 10000 0.5337\n",
      "[11] loss: 242.203\n",
      "5463 10000 0.5463\n",
      "best:  0.5463  in NO:  10\n",
      "[12] loss: 234.319\n",
      "5621 10000 0.5621\n",
      "best:  0.5621  in NO:  11\n",
      "[13] loss: 228.763\n",
      "5666 10000 0.5666\n",
      "best:  0.5666  in NO:  12\n",
      "[14] loss: 225.170\n",
      "5672 10000 0.5672\n",
      "best:  0.5672  in NO:  13\n",
      "[15] loss: 220.040\n",
      "5959 10000 0.5959\n",
      "best:  0.5959  in NO:  14\n",
      "[16] loss: 215.442\n",
      "5821 10000 0.5821\n",
      "[17] loss: 210.248\n",
      "5984 10000 0.5984\n",
      "best:  0.5984  in NO:  16\n",
      "[18] loss: 207.026\n",
      "6146 10000 0.6146\n",
      "best:  0.6146  in NO:  17\n",
      "[19] loss: 202.818\n",
      "6047 10000 0.6047\n",
      "[20] loss: 199.656\n",
      "6152 10000 0.6152\n",
      "best:  0.6152  in NO:  19\n",
      "[21] loss: 195.116\n",
      "6346 10000 0.6346\n",
      "best:  0.6346  in NO:  20\n",
      "[22] loss: 194.315\n",
      "6331 10000 0.6331\n",
      "[23] loss: 189.150\n",
      "6385 10000 0.6385\n",
      "best:  0.6385  in NO:  22\n",
      "[24] loss: 186.852\n",
      "6491 10000 0.6491\n",
      "best:  0.6491  in NO:  23\n",
      "[25] loss: 184.331\n",
      "6535 10000 0.6535\n",
      "best:  0.6535  in NO:  24\n",
      "[26] loss: 181.108\n",
      "6446 10000 0.6446\n",
      "[27] loss: 178.737\n",
      "6531 10000 0.6531\n",
      "[28] loss: 176.147\n",
      "6604 10000 0.6604\n",
      "best:  0.6604  in NO:  27\n",
      "[29] loss: 173.203\n",
      "6573 10000 0.6573\n",
      "[30] loss: 172.319\n",
      "6643 10000 0.6643\n",
      "best:  0.6643  in NO:  29\n",
      "[31] loss: 169.811\n",
      "6682 10000 0.6682\n",
      "best:  0.6682  in NO:  30\n",
      "[32] loss: 167.813\n",
      "6588 10000 0.6588\n",
      "[33] loss: 164.900\n",
      "6742 10000 0.6742\n",
      "best:  0.6742  in NO:  32\n",
      "[34] loss: 164.467\n",
      "6671 10000 0.6671\n",
      "[35] loss: 161.650\n",
      "6856 10000 0.6856\n",
      "best:  0.6856  in NO:  34\n",
      "[36] loss: 160.288\n",
      "6820 10000 0.682\n",
      "[37] loss: 158.733\n",
      "6856 10000 0.6856\n",
      "[38] loss: 157.677\n",
      "6917 10000 0.6917\n",
      "best:  0.6917  in NO:  37\n",
      "[39] loss: 154.414\n",
      "6916 10000 0.6916\n",
      "[40] loss: 153.393\n",
      "6876 10000 0.6876\n",
      "[41] loss: 150.992\n",
      "6953 10000 0.6953\n",
      "best:  0.6953  in NO:  40\n",
      "[42] loss: 150.110\n",
      "7029 10000 0.7029\n",
      "best:  0.7029  in NO:  41\n",
      "[43] loss: 149.221\n",
      "6999 10000 0.6999\n",
      "[44] loss: 148.187\n",
      "6978 10000 0.6978\n",
      "[45] loss: 146.554\n",
      "6980 10000 0.698\n",
      "[46] loss: 145.776\n",
      "6992 10000 0.6992\n",
      "[47] loss: 143.981\n",
      "6994 10000 0.6994\n",
      "[48] loss: 143.742\n",
      "7079 10000 0.7079\n",
      "best:  0.7079  in NO:  47\n",
      "[49] loss: 141.786\n",
      "7045 10000 0.7045\n",
      "[50] loss: 140.800\n",
      "7110 10000 0.711\n",
      "best:  0.711  in NO:  49\n"

'''

In [None]:
'''

lr = 0.05
batch = 256

[1] loss: 784.486
3543 10000 0.3543
best:  0.3543  in NO:  0
[2] loss: 664.410
4061 10000 0.4061
best:  0.4061  in NO:  1
[3] loss: 634.330
4104 10000 0.4104
best:  0.4104  in NO:  2
[4] loss: 606.153
4409 10000 0.4409
best:  0.4409  in NO:  3
[5] loss: 586.263
4583 10000 0.4583
best:  0.4583  in NO:  4
[6] loss: 567.869
4751 10000 0.4751
best:  0.4751  in NO:  5
[7] loss: 547.394
4883 10000 0.4883
best:  0.4883  in NO:  6
[8] loss: 535.890
5025 10000 0.5025
best:  0.5025  in NO:  7
[9] loss: 519.164
5076 10000 0.5076
best:  0.5076  in NO:  8
[10] loss: 508.081
5308 10000 0.5308
best:  0.5308  in NO:  9
[11] loss: 494.039
5411 10000 0.5411
best:  0.5411  in NO:  10
[12] loss: 482.274
5507 10000 0.5507
best:  0.5507  in NO:  11
[13] loss: 472.957
5523 10000 0.5523
best:  0.5523  in NO:  12
[14] loss: 467.231
5621 10000 0.5621
best:  0.5621  in NO:  13
[15] loss: 454.950
5792 10000 0.5792
best:  0.5792  in NO:  14
[16] loss: 442.935
5681 10000 0.5681
[17] loss: 433.986
5836 10000 0.5836
best:  0.5836  in NO:  16
[18] loss: 429.796
5871 10000 0.5871
best:  0.5871  in NO:  17
[19] loss: 423.177
5890 10000 0.589
best:  0.589  in NO:  18
[20] loss: 417.535
6057 10000 0.6057
best:  0.6057  in NO:  19
[21] loss: 404.164
6106 10000 0.6106
best:  0.6106  in NO:  20
[22] loss: 400.901
6274 10000 0.6274
best:  0.6274  in NO:  21
[23] loss: 393.700
6148 10000 0.6148
[24] loss: 390.899
6099 10000 0.6099
[25] loss: 383.347
6385 10000 0.6385
best:  0.6385  in NO:  24
[26] loss: 373.090
6273 10000 0.6273
[27] loss: 366.752
6559 10000 0.6559
best:  0.6559  in NO:  26
[28] loss: 363.315
6557 10000 0.6557
[29] loss: 358.751
6490 10000 0.649
[30] loss: 354.392
6457 10000 0.6457
[30] loss: 354.392
6457 10000 0.6457
[31] loss: 346.615
6450 10000 0.645
[32] loss: 340.893
6573 10000 0.6573
best:  0.6573  in NO:  31
[33] loss: 337.954
6638 10000 0.6638
best:  0.6638  in NO:  32
[34] loss: 334.542
6740 10000 0.674
best:  0.674  in NO:  33
[35] loss: 332.343
6743 10000 0.6743
best:  0.6743  in NO:  34
[36] loss: 324.215
6651 10000 0.6651
[37] loss: 319.147
6813 10000 0.6813
best:  0.6813  in NO:  36
[38] loss: 317.386
6851 10000 0.6851
best:  0.6851  in NO:  37
[39] loss: 314.624
6901 10000 0.6901
best:  0.6901  in NO:  38
[40] loss: 314.070
6798 10000 0.6798
[41] loss: 305.120
6921 10000 0.6921
best:  0.6921  in NO:  40
[42] loss: 302.018
6948 10000 0.6948
best:  0.6948  in NO:  41
[43] loss: 299.509
7006 10000 0.7006
best:  0.7006  in NO:  42
[44] loss: 297.726
6943 10000 0.6943
[45] loss: 293.764
6985 10000 0.6985
[46] loss: 288.799
6940 10000 0.694
[47] loss: 283.599
7005 10000 0.7005
[48] loss: 283.725
7040 10000 0.704
best:  0.704  in NO:  47
[49] loss: 280.424
6975 10000 0.6975
[50] loss: 280.512
6976 10000 0.6976
[51] loss: 272.883
7085 10000 0.7085
best:  0.7085  in NO:  50
[52] loss: 269.482
7051 10000 0.7051


'''

In [26]:
torch.cuda.empty_cache()