In [None]:
from functools import partial
from collections import OrderedDict

%config InlineBackend.figure_format = 'retina'

import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision as tv

In [None]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [None]:
import requests
import io

def get_weights(bit_variant, timeout=60):
  """Download a BiT checkpoint (``.npz``) from the public GCS bucket.

  Args:
    bit_variant: checkpoint name, e.g. ``'BiT-M-R50x1'``.
    timeout: seconds ``requests`` may wait on the connection / between bytes
      before raising. Without a timeout ``requests`` can block forever on a
      stalled connection, freezing the notebook cell.

  Returns:
    The result of ``np.load`` on the downloaded bytes (an ``NpzFile`` that
    maps parameter names to NumPy arrays).

  Raises:
    requests.HTTPError: for a non-success HTTP status.
    requests.Timeout: if the server does not respond within ``timeout``.
  """
  url = f'https://storage.googleapis.com/bit_models/{bit_variant}.npz'
  response = requests.get(url, timeout=timeout)
  response.raise_for_status()
  # The whole archive is held in memory; np.load reads it from a BytesIO.
  return np.load(io.BytesIO(response.content))

# Pretrained BiT-M R50x1 checkpoint. In this notebook it is only read by the
# weight-comparison cell further down (model.load_from(weights) is disabled).
weights = get_weights(bit_variant='BiT-M-R50x1')

In [None]:
#weights_cifar100 = get_weights('BiT-M-R50x1-CIFAR100')

In [None]:
def tf2th(conv_weights):
  """Convert a NumPy weight array into a torch tensor.

  4-D arrays are treated as TensorFlow conv kernels in HWIO layout and are
  transposed to PyTorch's OIHW layout. Arrays of any other rank (e.g. the
  GroupNorm gamma/beta vectors) are wrapped unchanged.
  """
  if conv_weights.ndim != 4:
    return torch.from_numpy(conv_weights)
  # (H, W, I, O) -> (O, I, H, W)
  return torch.from_numpy(conv_weights.transpose(3, 2, 0, 1))

class StdConv2d(nn.Conv2d):
  """``nn.Conv2d`` with Weight Standardization.

  On every forward pass each output filter is shifted to zero mean and scaled
  to unit (biased) variance over its (in_channels, kH, kW) entries, with 1e-10
  added to the variance for numerical safety. The stored ``weight`` parameter
  itself is never modified.
  """

  def forward(self, x):
    sigma2, mu = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
    standardized = (self.weight - mu) / torch.sqrt(sigma2 + 1e-10)
    return F.conv2d(
        x, standardized, self.bias, self.stride, self.padding,
        self.dilation, self.groups)

def conv3x3(cin, cout, stride=1, groups=1, bias=False):
  """3x3 weight-standardized conv with padding 1 (H, W preserved when stride == 1)."""
  return StdConv2d(
      in_channels=cin,
      out_channels=cout,
      kernel_size=3,
      stride=stride,
      padding=1,
      groups=groups,
      bias=bias,
  )

def conv1x1(cin, cout, stride=1, bias=False):
  """1x1 weight-standardized conv (channel mixing only, no padding)."""
  return StdConv2d(
      in_channels=cin,
      out_channels=cout,
      kernel_size=1,
      stride=stride,
      padding=0,
      bias=bias,
  )

class PreActConv(nn.Module):
  """Pre-activation bottleneck unit with no internal shortcut.

  Computes ``conv3(relu(gn3(conv2(relu(gn2(conv1(relu(gn1(x)))))))))`` with
  weight-standardized convolutions 1x1 (cin -> cmid), 3x3 (cmid -> cmid,
  carrying ``stride``) and 1x1 (cmid -> cout). Any residual addition is left
  to the caller (``ConvWithLin`` adds one around a stack of these units).

  Args:
    cin: input channels.
    cout: output channels (defaults to ``cin``).
    cmid: bottleneck width (defaults to ``cout // 4``).
    stride: spatial stride of the 3x3 conv.

  Every GroupNorm uses 32 groups, so ``cin`` and ``cmid`` must be divisible by 32.
  """

  def __init__(self, cin, cout=None, cmid=None, stride=1):
    super().__init__()
    cout = cout or cin
    cmid = cmid or cout//4

    self.gn1 = nn.GroupNorm(32, cin)
    self.conv1 = conv1x1(cin, cmid)
    self.gn2 = nn.GroupNorm(32, cmid)
    self.conv2 = conv3x3(cmid, cmid, stride)  # Original code has it on conv1!!
    self.gn3 = nn.GroupNorm(32, cmid)
    self.conv3 = conv1x1(cmid, cout)
    # In-place is safe: it only ever runs on freshly computed GroupNorm
    # outputs, never on the caller's ``x``.
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    out = self.relu(self.gn1(x))

    # Unit's branch
    out = self.conv1(out)
    out = self.conv2(self.relu(self.gn2(out)))
    out = self.conv3(self.relu(self.gn3(out)))

    return out

  def load_from(self, weights, prefix=''):
    """Copy BiT-style checkpoint tensors into this unit, then freeze it.

    Args:
      weights: mapping from TF variable names to NumPy arrays; the keys read
        are ``{prefix}{a,b,c}/standardized_conv2d/kernel`` and
        ``{prefix}{a,b,c}/group_norm/{gamma,beta}``.
      prefix: key prefix of this unit, e.g. ``'resnet/block1/unit01/'``.

    Returns:
      ``self``, with every parameter's ``requires_grad`` set to False.
    """
    with torch.no_grad():
      # Checkpoint sub-layers a/b/c correspond to conv1/gn1, conv2/gn2, conv3/gn3.
      for tag, conv, gn in (('a', self.conv1, self.gn1),
                            ('b', self.conv2, self.gn2),
                            ('c', self.conv3, self.gn3)):
        conv.weight.copy_(tf2th(weights[prefix + tag + '/standardized_conv2d/kernel']))
        gn.weight.copy_(tf2th(weights[prefix + tag + '/group_norm/gamma']))
        gn.bias.copy_(tf2th(weights[prefix + tag + '/group_norm/beta']))
    self.freeze()
    return self

  def _set_trainable(self, flag):
    # The convs are built with bias=False, so parameters() is exactly the three
    # conv weights plus the weight/bias of the three GroupNorms.
    for param in self.parameters():
      param.requires_grad = flag
    return self

  def melt(self):
    """Make every parameter of this unit trainable; returns ``self``."""
    return self._set_trainable(True)

  def freeze(self):
    """Exclude every parameter of this unit from gradient updates; returns ``self``."""
    return self._set_trainable(False)


class ConvWithLin(nn.Module):
  """A stack of PreActConv units wrapped in a (projected) residual connection.

  The input ``x`` (``lcin`` channels) follows two paths that are summed:

  * main path: optional 1x1 ``projin`` (lcin -> cin), then ``block_unit``
    chained PreActConv units (cin -> cout, then cout -> cout), then an
    optional 1x1 ``projout`` (cout -> lcout);
  * shortcut path: identity, or a 1x1 ``linproj`` (lcin -> lcout) whenever
    the channel count or the spatial resolution changes.

  Args:
    lcin: channels of the block input.
    cin: channels fed to the first PreActConv unit.
    cout: channels produced by each unit (defaults to ``cin``).
    cmid: bottleneck width of each unit (defaults to ``cout // 4``).
    lcout: channels of the block output (defaults to ``lcin``).
    stride: spatial stride of the block. It is applied exactly once on each
      path (by the first unit's 3x3 conv, and by ``linproj``), so both paths
      end at the same resolution and can be added.
    block_unit: number of chained PreActConv units (>= 1).
  """
  def __init__(self, lcin, cin, cout = None, cmid = None, lcout = None, stride=1, block_unit = 1):
    super().__init__()
    cout = cout or cin
    cmid = cmid or cout//4
    lcout = lcout or lcin

    if lcin != cin:
      # No stride here: the first unit already downsamples the main path.
      self.projin = conv1x1(lcin, cin)

    # Only the first unit changes channels (cin -> cout) and carries the stride.
    self.unit = nn.Sequential(OrderedDict(
            [('unit01', PreActConv(cin=cin, cout=cout, cmid=cmid, stride=stride))] +
            [(f'unit{i:02d}', PreActConv(cin=cout, cout=cout, cmid=cmid)) for i in range(2, block_unit + 1)],
        ))
    if lcout != cout:
      # No stride here either, for the same reason as projin.
      self.projout = conv1x1(cout, lcout)
    if lcout != lcin or stride != 1:
      # The shortcut must match the main path in channels *and* resolution.
      self.linproj = conv1x1(lcin, lcout, stride)

  def forward(self, x):
    out = self.projin(x) if hasattr(self, 'projin') else x
    out = self.unit(out)
    if hasattr(self, 'projout'):
      out = self.projout(out)

    shortcut = self.linproj(x) if hasattr(self, 'linproj') else x
    return torch.add(out, shortcut)


class LinNet(nn.Module):
  """Pre-activation (v2) ResNet-style model built from ``ConvWithLin`` stages.

  Pipeline: ``root`` (7x7 stride-2 StdConv2d, zero pad, 3x3 stride-2 max-pool)
  -> three ``ConvWithLin`` stages, each followed by the *same* 3x3 ``update1``
  convolution -> ``head`` (GroupNorm, ReLU, global average pool, 1x1 conv
  producing ``head_size`` logits).

  Args:
    block_units: number of PreActConv units per stage. Only the first three
      entries are used because no fourth stage is built.
    width_factor: channel multiplier applied to every layer width.
    head_size: number of output logits.
    zero_head: when True, ``load_from`` zero-initialises the classifier
      instead of copying it from the checkpoint. The flag is only read by
      ``load_from``; without that call the head keeps PyTorch's default init.
  """
  BLOCK_UNITS = {
      'r50': [3, 4, 6, 3],
      'r101': [3, 4, 23, 3],
      'r152': [3, 8, 36, 3],
  }

  def __init__(self, block_units, width_factor, head_size=21843, zero_head=False):
    super().__init__()
    wf = width_factor  # shortcut 'cause we'll use it a lot.

    # A single 3x3 conv whose weights are reused after *every* stage in
    # forward(), so every stage must output 256*wf channels.
    self.update1 = conv3x3(cin=256*wf, cout=256*wf)

    # The following will be unreadable if we split lines.
    # pylint: disable=line-too-long
    self.root = nn.Sequential(OrderedDict([
        ('conv', StdConv2d(3, 32*wf, kernel_size=7, stride=2, padding=3, bias=False)),
        ('pad', nn.ConstantPad2d(1, 0)),
        ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
        # Zero padding + padding=0 pool is subtly different from
        # MaxPool2d(kernel_size=3, stride=2, padding=1), which pads with -inf.
    ]))

    # Only three stages are built; block_units[3] is unused.
    self.body = nn.Sequential(OrderedDict([
        ('block1', ConvWithLin(lcin = 32*wf, lcout=256*wf, cin=32*wf, cout= 64*wf, cmid=32*wf, block_unit = block_units[0])),

        ('block2', ConvWithLin(lcin = 256*wf, lcout=256*wf, cin=64*wf, cout= 128*wf, cmid=64*wf, block_unit = block_units[1])),

        ('block3', ConvWithLin(lcin = 256*wf, lcout=256*wf, cin=128*wf, cout= 256*wf, cmid=128*wf, block_unit = block_units[2])),
    ]))
    # pylint: enable=line-too-long

    self.zero_head = zero_head
    self.head = nn.Sequential(OrderedDict([
        ('gn', nn.GroupNorm(32, 256*wf)),
        ('relu', nn.ReLU(inplace=True)),
        ('avg', nn.AdaptiveAvgPool2d(output_size=1)),
        ('conv', nn.Conv2d(256*wf, head_size, kernel_size=1, bias=True)),
    ]))

  def forward(self, x):
    """Map an image batch (N, 3, H, W) to logits of shape (N, head_size)."""
    x = self.root(x)
    for block in self.body:
      x = self.update1(block(x))
    x = self.head(x)

    assert x.shape[-2:] == (1, 1)  # We should have no spatial shape left.
    return x[..., 0, 0]

  def load_from(self, weights, prefix='resnet/'):
    """Copy BiT checkpoint tensors into the root conv, the head and every unit.

    ``weights`` maps TF variable names to NumPy arrays (e.g. what
    ``get_weights`` returns). ``update1`` and the ConvWithLin projections are
    not touched by this method. Checkpoint shapes are expected to match this
    model's; ``copy_`` raises on incompatible shapes. Each unit's
    ``load_from`` also sets ``requires_grad = False`` on that unit.
    Returns ``self``.
    """
    with torch.no_grad():
      self.root.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
      self.head.gn.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
      self.head.gn.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
      if self.zero_head:
        nn.init.zeros_(self.head.conv.weight)
        nn.init.zeros_(self.head.conv.bias)
      else:
        self.head.conv.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
        self.head.conv.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))

      for bname, block in self.body.named_children():
        for uname, unit in block.unit.named_children():
          unit.load_from(weights, prefix=f'{prefix}{bname}/{uname}/')
    return self


In [None]:
# Fresh 100-way (CIFAR-100) model. zero_head is only applied inside
# LinNet.load_from, which stays disabled here, so the head keeps its default
# random initialisation.
model = LinNet(
    LinNet.BLOCK_UNITS['r50'],
    width_factor=1,
    head_size=100,
    zero_head=True,
)
# model.load_from(weights)
model.to(device)

LinNet(
  (update1): StdConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (root): Sequential(
    (conv): StdConv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (pad): ConstantPad2d(padding=(1, 1, 1, 1), value=0)
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (body): Sequential(
    (block1): ConvWithLin(
      (unit): Sequential(
        (unit01): PreActConv(
          (gn1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): StdConv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (gn2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv2): StdConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (gn3): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv3): StdConv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (relu): ReLU(inplace=True)
        )
        (unit02): PreActConv(
        

Train on CIFAR-100

In [None]:
import torchvision
import torchvision.transforms as T

In [None]:

# Both splits map each channel from [0, 1] to [-1, 1]; only the training split
# gets pad-and-crop plus horizontal-flip augmentation.
transform_train = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_set = torchvision.datasets.CIFAR100(
    './data', train=True, download=True, transform=transform_train)
test_set = torchvision.datasets.CIFAR100(
    './data', train=False, download=True, transform=transform_test)

classes = train_set.classes


Files already downloaded and verified
Files already downloaded and verified


In [None]:
# The learning rate is driven by `schedule` in the training cell. num_epochs
# is only a default; the training call below passes its own value.
batch_size, num_epochs = 32, 10

In [None]:
# Shuffle only the training split; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set, batch_size=batch_size, shuffle=False)



In [None]:
from IPython.display import HTML, display


# Custom IPython progress bar for training
class ProgressMonitor(object):
    """HTML progress bar for notebook training loops.

    Renders a one-row table with the current loss, a ``value / length``
    counter and a ``<progress>`` element. It updates that output in place
    through the IPython display handle created in ``__init__``.
    """

    # Filled by html(); attributes are space-separated (no commas) and the
    # non-breaking-space entities are properly terminated with ';'.
    tmpl = """
        <table style="width: 100%;">
            <tbody>
                <tr>
                    <td style="width: 30%;">
                     <b>Loss: {loss:0.4f}</b> &nbsp;&nbsp;&nbsp; {value} / {length}
                    </td>
                    <td style="width: 70%;">
                        <progress value='{value}' max='{length}' style='width: 100%'>{value}</progress>
                    </td>
                </tr>
            </tbody>
        </table>
        """

    def __init__(self, length):
        """Display an empty bar whose maximum is ``length`` items."""
        self.length = length
        self.count = 0
        self.display = display(self.html(0, 0), display_id=True)

    def html(self, count, loss):
        """Return the rendered template for ``count`` items done at ``loss``."""
        return HTML(self.tmpl.format(length=self.length, value=count, loss=loss))

    def update(self, count, loss):
        """Advance the bar by ``count`` items (an increment, not a total) and show ``loss``."""
        self.count += count
        self.display.update(self.html(self.count, loss))

In [None]:
def stairs(s, v, *svs):
    """Piecewise-constant ("stairs") learning-rate schedule.

    ``svs`` alternates boundaries and values: ``s0, v0, s1, v1, ...``. The
    result is ``v`` while ``s < s0``, ``v0`` while ``s0 <= s < s1``, and so on;
    the last value holds from its boundary onwards. A trailing unpaired
    element of ``svs`` is ignored.

    Example:
        stairs(s, 0.1, 10, 0.01, 20, 0.001)
        -> 0.1 if s < 10, 0.01 if 10 <= s < 20, 0.001 if 20 <= s
    """
    current = v
    for boundary, value in zip(svs[::2], svs[1::2]):
        if s < boundary:
            return current
        current = value
    return current

def rampup(s, peak_s, peak_lr):
  """Linear warm-up: ``s / peak_s * peak_lr`` while ``s < peak_s``, then ``peak_lr``."""
  return s/peak_s * peak_lr if s < peak_s else peak_lr

In [None]:
from statistics import mean
import os

optim = torch.optim.SGD(
    model.parameters(),
    lr=0.003,
    momentum=0.9,
    weight_decay=1e-6,
)

loss_fn = nn.CrossEntropyLoss()

S = 100  # schedule length: from epoch S onwards schedule() returns None
def schedule(s):
  """Learning rate for epoch ``s``.

  Stairs of 3e-3 (s < 35), 3e-4 (s < 70), 3e-5 (s < 90), 3e-6 (s < S) and
  None from S on. The first 2 epochs are scaled linearly by ``s / 2`` (warm-up).
  """
  base_lr = stairs(s, 3e-3, 35, 3e-4, 70, 3e-5, 90, 3e-6, S, None)
  return rampup(s, 2, base_lr)

def _train_one_epoch(optimizer, model):
    """One optimisation pass over the global ``train_loader``; returns (mean batch loss, accuracy)."""
    model.train()
    progress = ProgressMonitor(length=len(train_set))

    correct = 0
    batch_losses = []
    running_loss = 0.0

    for batch, targets in train_loader:
        batch = batch.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(batch)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        running_loss += batch_losses[-1]

        _, preds = torch.max(outputs, 1)
        correct += torch.sum(preds == targets)
        # O(1) running mean; recomputing statistics.mean over the growing
        # list on every batch made each epoch quadratic in its batch count.
        progress.update(batch.shape[0], running_loss / len(batch_losses))

    return mean(batch_losses), int(correct) / train_set.data.shape[0]


def _evaluate(model):
    """Evaluate on the global ``test_loader``; returns (mean batch loss, accuracy, predicted labels)."""
    model.eval()
    y_pred = []
    batch_losses = []
    correct = 0

    with torch.no_grad():
        for batch, targets in test_loader:
            batch = batch.to(device)
            targets = targets.to(device)

            outputs = model(batch)
            batch_losses.append(loss_fn(outputs, targets).item())

            _, preds = torch.max(outputs, 1)
            # Record exactly the predictions that are scored below.
            y_pred.extend(preds.cpu().numpy())
            correct += torch.sum(preds == targets)

    return mean(batch_losses), int(correct) / test_set.data.shape[0], y_pred


def train(optimizer, model, num_epochs = 10, first_epoch = 1, checkpoint_dir = "./train1"):
    """Train ``model`` and evaluate it after every epoch.

    Uses the notebook globals ``train_loader``, ``test_loader``, ``train_set``,
    ``test_set``, ``device``, ``loss_fn``, ``schedule`` and ``ProgressMonitor``.

    At the start of each epoch every param group's ``lr`` is set to
    ``schedule(epoch)``. Training stops early when the schedule returns
    ``None`` (its end-of-schedule marker). After each epoch the model's
    ``state_dict`` is saved to ``{checkpoint_dir}/{epoch}.pt``.

    Args:
        optimizer: optimizer whose param groups get their ``lr`` overwritten.
        model: network to train (already on ``device``).
        num_epochs: maximum number of epochs to run.
        first_epoch: index of the first epoch (fed to ``schedule``), so that a
            run can be resumed.
        checkpoint_dir: directory for per-epoch checkpoints. It is created if
            missing and reused if it already exists.

    Returns:
        ``(train_losses, test_losses, y_pred)``: per-epoch mean training loss,
        per-epoch mean test loss (mean of per-batch losses), and the test-set
        predicted labels from the last completed epoch.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    train_losses = []
    test_losses = []
    y_pred = []

    for epoch in range(first_epoch, first_epoch + num_epochs):
        lr = schedule(epoch)
        if lr is None:
            # Handing lr=None to the optimizer would crash in optimizer.step().
            print('Schedule finished at epoch', epoch, '- stopping.')
            break

        print('Epoch', epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        train_loss, train_acc = _train_one_epoch(optimizer, model)
        test_loss, test_acc, y_pred = _evaluate(model)
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        print('Training accuracy: {:.2f}%'.format(float(train_acc) * 100))
        print('Test accuracy: {:.2f}%'.format(float(test_acc) * 100))
        # Newline-terminated so the next "Epoch N" header starts on its own line.
        print('Learning Rate: {:.7f}'.format(float(lr)))
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"{epoch}.pt"))

    return train_losses, test_losses, y_pred

In [None]:
# train() returns (train_losses, test_losses, y_pred).
a, b, c = train(optim, model=model, num_epochs=100, first_epoch=1)

Epoch 1


0,1
Loss: 3.6649 50000 / 50000,50000


Training accuracy: 14.37%
Test accuracy: 16.71%
Learning Rate: 0.0015000Epoch 2


0,1
Loss: 3.4280 50000 / 50000,50000


Training accuracy: 17.68%
Test accuracy: 19.47%
Learning Rate: 0.0030000Epoch 3


0,1
Loss: 3.2087 50000 / 50000,50000


Training accuracy: 21.63%
Test accuracy: 23.45%
Learning Rate: 0.0030000Epoch 4


0,1
Loss: 3.0709 50000 / 50000,50000


Training accuracy: 23.94%
Test accuracy: 23.98%
Learning Rate: 0.0030000Epoch 5


0,1
Loss: 2.9719 50000 / 50000,50000


Training accuracy: 25.91%
Test accuracy: 26.19%
Learning Rate: 0.0030000Epoch 6


0,1
Loss: 2.8771 50000 / 50000,50000


Training accuracy: 27.77%
Test accuracy: 28.15%
Learning Rate: 0.0030000Epoch 7


0,1
Loss: 2.8056 50000 / 50000,50000


Training accuracy: 29.08%
Test accuracy: 30.30%
Learning Rate: 0.0030000Epoch 8


0,1
Loss: 2.7447 50000 / 50000,50000


Training accuracy: 30.31%
Test accuracy: 31.58%
Learning Rate: 0.0030000Epoch 9


0,1
Loss: 2.6775 50000 / 50000,50000


Training accuracy: 31.71%
Test accuracy: 31.87%
Learning Rate: 0.0030000Epoch 10


0,1
Loss: 2.6237 50000 / 50000,50000


Training accuracy: 32.69%
Test accuracy: 33.04%
Learning Rate: 0.0030000Epoch 11


0,1
Loss: 2.5708 50000 / 50000,50000


Training accuracy: 33.85%
Test accuracy: 34.83%
Learning Rate: 0.0030000Epoch 12


0,1
Loss: 2.5180 50000 / 50000,50000


Training accuracy: 35.12%
Test accuracy: 35.18%
Learning Rate: 0.0030000Epoch 13


0,1
Loss: 2.4734 50000 / 50000,50000


Training accuracy: 36.00%
Test accuracy: 35.95%
Learning Rate: 0.0030000Epoch 14


0,1
Loss: 2.4348 50000 / 50000,50000


Training accuracy: 36.66%
Test accuracy: 36.57%
Learning Rate: 0.0030000Epoch 15


0,1
Loss: 2.3948 50000 / 50000,50000


Training accuracy: 37.62%
Test accuracy: 37.19%
Learning Rate: 0.0030000Epoch 16


0,1
Loss: 2.3574 50000 / 50000,50000


Training accuracy: 38.50%
Test accuracy: 38.48%
Learning Rate: 0.0030000Epoch 17


0,1
Loss: 2.3304 50000 / 50000,50000


Training accuracy: 38.96%
Test accuracy: 38.88%
Learning Rate: 0.0030000Epoch 18


0,1
Loss: 2.2886 50000 / 50000,50000


Training accuracy: 39.92%
Test accuracy: 39.50%
Learning Rate: 0.0030000Epoch 19


0,1
Loss: 2.2617 50000 / 50000,50000


Training accuracy: 40.41%
Test accuracy: 39.62%
Learning Rate: 0.0030000Epoch 20


0,1
Loss: 2.2350 50000 / 50000,50000


Training accuracy: 41.27%
Test accuracy: 40.16%
Learning Rate: 0.0030000Epoch 21


0,1
Loss: 2.2080 50000 / 50000,50000


Training accuracy: 41.81%
Test accuracy: 41.03%
Learning Rate: 0.0030000Epoch 22


0,1
Loss: 2.1731 50000 / 50000,50000


Training accuracy: 42.59%
Test accuracy: 41.75%
Learning Rate: 0.0030000Epoch 23


0,1
Loss: 2.1523 50000 / 50000,50000


Training accuracy: 42.94%
Test accuracy: 41.87%
Learning Rate: 0.0030000Epoch 24


0,1
Loss: 2.1285 50000 / 50000,50000


Training accuracy: 43.52%
Test accuracy: 41.26%
Learning Rate: 0.0030000Epoch 25


0,1
Loss: 2.1083 50000 / 50000,50000


Training accuracy: 43.74%
Test accuracy: 42.77%
Learning Rate: 0.0030000Epoch 26


0,1
Loss: 2.0838 50000 / 50000,50000


Training accuracy: 44.73%
Test accuracy: 42.85%
Learning Rate: 0.0030000Epoch 27


0,1
Loss: 2.0642 50000 / 50000,50000


Training accuracy: 45.07%
Test accuracy: 44.07%
Learning Rate: 0.0030000Epoch 28


0,1
Loss: 2.0485 50000 / 50000,50000


Training accuracy: 45.23%
Test accuracy: 43.77%
Learning Rate: 0.0030000Epoch 29


0,1
Loss: 2.0266 50000 / 50000,50000


Training accuracy: 46.04%
Test accuracy: 44.86%
Learning Rate: 0.0030000Epoch 30


0,1
Loss: 2.0051 50000 / 50000,50000


Training accuracy: 46.19%
Test accuracy: 44.81%
Learning Rate: 0.0030000Epoch 31


0,1
Loss: 1.9932 50000 / 50000,50000


Training accuracy: 46.57%
Test accuracy: 44.41%
Learning Rate: 0.0030000Epoch 32


0,1
Loss: 1.9715 50000 / 50000,50000


Training accuracy: 47.26%
Test accuracy: 45.46%
Learning Rate: 0.0030000Epoch 33


0,1
Loss: 1.9570 50000 / 50000,50000


Training accuracy: 47.15%
Test accuracy: 45.44%
Learning Rate: 0.0030000Epoch 34


0,1
Loss: 1.9375 50000 / 50000,50000


Training accuracy: 47.50%
Test accuracy: 45.50%
Learning Rate: 0.0030000Epoch 35


0,1
Loss: 1.8068 50000 / 50000,50000


Training accuracy: 51.24%
Test accuracy: 48.58%
Learning Rate: 0.0003000Epoch 36


0,1
Loss: 1.7774 50000 / 50000,50000


Training accuracy: 51.98%
Test accuracy: 48.23%
Learning Rate: 0.0003000Epoch 37


0,1
Loss: 1.7711 50000 / 50000,50000


Training accuracy: 52.26%
Test accuracy: 48.70%
Learning Rate: 0.0003000Epoch 38


0,1
Loss: 1.7611 50000 / 50000,50000


Training accuracy: 52.32%
Test accuracy: 48.61%
Learning Rate: 0.0003000Epoch 39


0,1
Loss: 1.7565 50000 / 50000,50000


Training accuracy: 52.42%
Test accuracy: 48.58%
Learning Rate: 0.0003000Epoch 40


0,1
Loss: 1.7497 50000 / 50000,50000


Training accuracy: 52.81%
Test accuracy: 48.55%
Learning Rate: 0.0003000Epoch 41


0,1
Loss: 1.7418 50000 / 50000,50000


Training accuracy: 52.64%
Test accuracy: 48.91%
Learning Rate: 0.0003000Epoch 42


0,1
Loss: 1.7367 22752 / 50000,22752


0,1
Loss: 1.7435 50000 / 50000,50000


Training accuracy: 52.63%
Test accuracy: 48.71%
Learning Rate: 0.0003000Epoch 43


0,1
Loss: 1.7395 50000 / 50000,50000


Training accuracy: 52.71%
Test accuracy: 48.70%
Learning Rate: 0.0003000Epoch 44


0,1
Loss: 1.7375 50000 / 50000,50000


Training accuracy: 52.79%
Test accuracy: 49.18%
Learning Rate: 0.0003000Epoch 45


0,1
Loss: 1.7349 50000 / 50000,50000


Training accuracy: 52.96%
Test accuracy: 48.75%
Learning Rate: 0.0003000Epoch 46


0,1
Loss: 1.7338 50000 / 50000,50000


Training accuracy: 52.95%
Test accuracy: 49.06%
Learning Rate: 0.0003000Epoch 47


0,1
Loss: 1.7293 50000 / 50000,50000


Training accuracy: 53.02%
Test accuracy: 49.23%
Learning Rate: 0.0003000Epoch 48


0,1
Loss: 1.7276 50000 / 50000,50000


Training accuracy: 53.02%
Test accuracy: 49.24%
Learning Rate: 0.0003000Epoch 49


0,1
Loss: 1.7244 50000 / 50000,50000


Training accuracy: 53.31%
Test accuracy: 48.79%
Learning Rate: 0.0003000Epoch 50


0,1
Loss: 1.7221 50000 / 50000,50000


Training accuracy: 53.38%
Test accuracy: 49.28%
Learning Rate: 0.0003000Epoch 51


0,1
Loss: 1.7261 11200 / 50000,11200


KeyboardInterrupt: 

In [None]:
torch.save(obj=model.state_dict(), f="1.pt")  # weights only, not the optimizer state

In [None]:
# map_location places the stored tensors on the active `device`, so a
# checkpoint written on a GPU can also be restored on a CPU-only machine.
model.load_state_dict(torch.load("1.pt", map_location=device))
model.eval()

LinNet(
  (root): Sequential(
    (conv): StdConv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (pad): ConstantPad2d(padding=(1, 1, 1, 1), value=0)
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (body): Sequential(
    (block1): ConvWithLin(
      (unit): Sequential(
        (unit01): PreActConv(
          (gn1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (gn2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (gn3): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv3): StdConv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (relu): ReLU(inplace=True)
        )
        (unit02): PreActConv(
          (gn1): GroupNorm(32, 128, eps=1e-05, affine=True)
          (conv1): StdConv2d(128, 64, kerne

In [None]:
# Make every PreActConv unit of every stage trainable again.
for block in model.body:
  for unit in block.unit:
    unit.melt()

In [None]:
# Exclude every PreActConv unit of every stage from gradient updates.
for block in model.body:
  for unit in block.unit:
    unit.freeze()

In [None]:
# Compare each unit's first 1x1 conv with the matching BiT checkpoint kernel
# and print both shapes; several units differ in shape from the checkpoint.
for bname, block in model.body.named_children():
  for uname, unit in block.unit.named_children():
    key = f'resnet/{bname}/{uname}/' + 'a/standardized_conv2d/kernel'
    t1 = tf2th(weights[key]).to(device)
    print(torch.equal(unit.conv1.weight, t1))
    print(t1.shape)
    print(unit.conv1.weight.shape)

False
torch.Size([64, 64, 1, 1])
torch.Size([64, 64, 1, 1])
False
torch.Size([64, 256, 1, 1])
torch.Size([64, 128, 1, 1])
False
torch.Size([64, 256, 1, 1])
torch.Size([64, 128, 1, 1])
False
torch.Size([128, 256, 1, 1])
torch.Size([128, 256, 1, 1])
False
torch.Size([128, 512, 1, 1])
torch.Size([128, 512, 1, 1])
False
torch.Size([128, 512, 1, 1])
torch.Size([128, 512, 1, 1])
False
torch.Size([128, 512, 1, 1])
torch.Size([128, 512, 1, 1])
False
torch.Size([256, 512, 1, 1])
torch.Size([256, 512, 1, 1])
False
torch.Size([256, 1024, 1, 1])
torch.Size([256, 512, 1, 1])
False
torch.Size([256, 1024, 1, 1])
torch.Size([256, 512, 1, 1])
False
torch.Size([256, 1024, 1, 1])
torch.Size([256, 512, 1, 1])
False
torch.Size([256, 1024, 1, 1])
torch.Size([256, 512, 1, 1])
False
torch.Size([256, 1024, 1, 1])
torch.Size([256, 512, 1, 1])


In [None]:
# Snapshot every unit's conv1 weight so a later cell can check whether it changed.
t1 = []
for bname, block in model.body.named_children():
  for uname, unit in block.unit.named_children():
    t1.append(unit.conv1.weight.clone())


In [None]:
# Compare each unit's conv1 weight with its snapshot in `t1` (consumed from the
# front, in unit order) and replace the snapshot with a fresh clone.
for bname, block in model.body.named_children():
  for uname, unit in block.unit.named_children():
    previous = t1.pop(0)
    print(torch.equal(unit.conv1.weight, previous))
    t1.append(torch.clone(unit.conv1.weight))

True
True
True
True
True
True
True
True
True
True
True
True
True


In [None]:
# Number of scalar parameters that currently receive gradients.
trainable = [param for param in model.parameters() if param.requires_grad]
sum(param.numel() for param in trainable)

In [None]:
# Start over from a freshly initialised network (load_from stays disabled, so
# the weights are random). The optimizer must be rebuilt too: the existing
# `optim` holds the previous model's parameters and would keep updating those
# instead of the new model's.
model = LinNet(LinNet.BLOCK_UNITS['r50'], width_factor=1, head_size=100, zero_head=True)
# model.load_from(weights)
model.to(device)
optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9, weight_decay=1e-6)