In [None]:
from functools import partial
from collections import OrderedDict

%config InlineBackend.figure_format = 'retina'

import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision as tv

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
print(device)

cuda:0


In [None]:
import requests
import io

def get_weights(bit_variant, timeout=60):
  """Download a checkpoint from the `bit_models` bucket and open it with NumPy.

  Args:
    bit_variant: Checkpoint name spliced into the URL, e.g. 'BiT-M-R50x1'.
    timeout: Seconds `requests` may wait for the server to connect or to send
      more data before giving up. Without it a stalled connection would block
      the notebook cell indefinitely.

  Returns:
    Whatever `np.load` yields for the downloaded `.npz` bytes.

  Raises:
    requests.HTTPError: The server answered with an error status.
    requests.Timeout: The server did not respond within `timeout` seconds.
  """
  url = f'https://storage.googleapis.com/bit_models/{bit_variant}.npz'
  response = requests.get(url, timeout=timeout)
  response.raise_for_status()
  # The whole archive is held in memory; np.load reads it from the buffer.
  return np.load(io.BytesIO(response.content))

# Checkpoint that `model.load_from(weights)` consumes further down.
weights = get_weights('BiT-M-R50x1')

In [None]:
# NOTE(review): this second download is not referenced anywhere else in the
# visible notebook — confirm it is still needed.
weights_cifar10 = get_weights('BiT-M-R50x1-CIFAR10')

In [None]:
def tf2th(conv_weights):
  """Wrap a NumPy array as a torch tensor, reordering 4-D kernels HWIO -> OIHW.

  Arrays of any other rank are passed through without reordering.
  """
  is_conv_kernel = conv_weights.ndim == 4
  array = conv_weights.transpose(3, 2, 0, 1) if is_conv_kernel else conv_weights
  return torch.from_numpy(array)

class StdConv2d(nn.Conv2d):
  """nn.Conv2d that standardizes its kernel on every forward pass.

  The stored `weight` is left untouched; a standardized copy is computed on
  the fly and used for the convolution.
  """

  def forward(self, x):
    kernel = self.weight
    # Biased variance and mean over dims 1-3, i.e. separately for each index
    # along dim 0 of the weight tensor.
    sigma2, mu = torch.var_mean(kernel, dim=[1, 2, 3], keepdim=True, unbiased=False)
    centered = kernel - mu
    # The small epsilon keeps the division finite for constant kernels.
    standardized = centered / torch.sqrt(sigma2 + 1e-10)
    return F.conv2d(
        x,
        standardized,
        self.bias,
        self.stride,
        self.padding,
        self.dilation,
        self.groups,
    )

def conv3x3(cin, cout, stride=1, groups=1, bias=False):
  """Build a 3x3 StdConv2d with padding 1 (`cin` -> `cout` channels)."""
  return StdConv2d(
      cin,
      cout,
      kernel_size=3,
      stride=stride,
      padding=1,
      groups=groups,
      bias=bias,
  )

def conv1x1(cin, cout, stride=1, bias=False):
  """Build an unpadded 1x1 StdConv2d (`cin` -> `cout` channels)."""
  return StdConv2d(
      cin,
      cout,
      kernel_size=1,
      stride=stride,
      padding=0,
      bias=bias,
  )

class PreActConv(nn.Module):
  """Pre-activation bottleneck branch: three GroupNorm -> ReLU -> conv stages.

  The stages use a 1x1, a 3x3 and a 1x1 standardized convolution. No residual
  addition happens here; `forward` returns the branch output only.

  Args:
    cin: Channels of the input tensor.
    cout: Channels of the output tensor; defaults to `cin`.
    cmid: Channels between the convolutions; defaults to `cout // 4`.
    stride: Stride of the 3x3 convolution.
  """

  def __init__(self, cin, cout=None, cmid=None, stride=1):
    super().__init__()
    cout = cout or cin
    cmid = cmid or cout//4

    self.gn1 = nn.GroupNorm(32, cin)
    self.conv1 = conv1x1(cin, cmid)
    self.gn2 = nn.GroupNorm(32, cmid)
    self.conv2 = conv3x3(cmid, cmid, stride)  # Original code has it on conv1!!
    self.gn3 = nn.GroupNorm(32, cmid)
    self.conv3 = conv1x1(cmid, cout)
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    out = self.relu(self.gn1(x))

    # Unit's branch
    out = self.conv1(out)
    out = self.conv2(self.relu(self.gn2(out)))
    out = self.conv3(self.relu(self.gn3(out)))

    return out

  def _stages(self):
    """Return (checkpoint letter, conv, group norm) for each of the three stages."""
    return (
        ('a', self.conv1, self.gn1),
        ('b', self.conv2, self.gn2),
        ('c', self.conv3, self.gn3),
    )

  def _set_trainable(self, trainable):
    """Set `requires_grad` on the conv weights and GroupNorm weights/biases."""
    for _, conv, gn in self._stages():
      for param in (conv.weight, gn.weight, gn.bias):
        param.requires_grad = trainable

  def load_from(self, weights, prefix=''):
    """Copy this unit's parameters out of `weights` and freeze them.

    Args:
      weights: Mapping from checkpoint key to NumPy array.
      prefix: Key prefix of this unit; the stage letter ('a', 'b', 'c') and
        the per-layer suffix are appended to it.

    Returns:
      `self`, so calls can be chained.

    Raises:
      KeyError: A required key is missing from `weights`.
    """
    with torch.no_grad():
      for letter, conv, gn in self._stages():
        base = prefix + letter
        conv.weight.copy_(tf2th(weights[base + '/standardized_conv2d/kernel']))
        gn.weight.copy_(tf2th(weights[base + '/group_norm/gamma']))
        gn.bias.copy_(tf2th(weights[base + '/group_norm/beta']))
    # Freshly loaded parameters start out frozen; call melt() to train them.
    self.freeze()
    return self

  def melt(self):
    """Make the unit's parameters trainable (`requires_grad = True`)."""
    self._set_trainable(True)

  def freeze(self):
    """Exclude the unit's parameters from training (`requires_grad = False`)."""
    self._set_trainable(False)


class ConvWithLin(nn.Module):
  """A stack of PreActConv units wrapped by a 1x1-conv skip connection.

  Main branch: optional `projin` (lcin -> cin), `block_unit` PreActConv units,
  optional `projout` (cout -> lcout). Skip branch: the input itself, or
  `linproj` (lcin -> lcout) when the width or the resolution changes. The two
  branches are summed.

  Args:
    lcin: Channels of the block input.
    cin: Channels fed to the first unit.
    cout: Channels produced by the units; defaults to `cin`.
    cmid: Bottleneck channels inside the units; defaults to `cout // 4`.
    lcout: Channels of the block output; defaults to `lcin`.
    stride: Spatial stride of the block. It is applied once on each branch
      (first unit on the main branch, `linproj` on the skip branch) so that
      both summands keep the same spatial size.
    block_unit: Number of PreActConv units.
  """

  def __init__(self, lcin, cin, cout = None, cmid = None, lcout = None, stride=1, block_unit = 1):
    super().__init__()
    cout = cout or cin
    cmid = cmid or cout//4
    lcout = lcout or lcin

    # Width adapters are only created when the widths differ; forward() checks
    # for their presence with hasattr().
    if lcin != cin:
      self.projin = conv1x1(lcin, cin)

    # Only the first unit changes width and resolution; the rest map cout -> cout.
    self.unit = nn.Sequential(OrderedDict(
            [('unit01', PreActConv(cin=cin, cout=cout, cmid=cmid, stride=stride))] +
            [(f'unit{i:02d}', PreActConv(cin=cout, cout=cout, cmid=cmid)) for i in range(2, block_unit + 1)],
        ))
    if lcout != cout:
      self.projout = conv1x1(cout, lcout)
    # The skip branch needs a projection whenever the width OR the resolution
    # changes; otherwise the residual add below would see mismatched shapes.
    if lcout != lcin or stride != 1:
      self.linproj = conv1x1(lcin, lcout, stride)

  def forward(self, x):
    out = self.projin(x) if hasattr(self, 'projin') else x
    out = self.unit(out)
    if hasattr(self, 'projout'):
      out = self.projout(out)

    skip = self.linproj(x) if hasattr(self, 'linproj') else x
    return torch.add(out, skip)


class LinNet(nn.Module):
  """Pre-activation (v2) ResNet-style model whose four stages are ConvWithLin blocks.

  Layout: `root` (7x7 standardized conv, zero pad, max pool) -> `body`
  (block1..block4) -> `head` (GroupNorm, ReLU, global average pool, 1x1 conv).
  """
  BLOCK_UNITS = {
      'r50': [3, 4, 6, 3],
      'r101': [3, 4, 23, 3],
      'r152': [3, 8, 36, 3],
  }

  def __init__(self, block_units, width_factor, head_size=21843, zero_head=False):
    super().__init__()
    wf = width_factor  # every channel count below is multiplied by this

    stem = [
        ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)),
        ('pad', nn.ConstantPad2d(1, 0)),
        ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
        # The following is subtly not the same!
        # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
    ]
    self.root = nn.Sequential(OrderedDict(stem))

    # (lcin, lcout, cin, cout, cmid) of each stage, before scaling by `wf`.
    stage_widths = [
        (64, 512, 64, 256, 64),
        (512, 512, 256, 512, 128),
        (512, 2048, 512, 1024, 256),
        (2048, 2048, 1024, 2048, 512),
    ]
    stages = OrderedDict()
    for idx, (lcin, lcout, cin, cout, cmid) in enumerate(stage_widths):
      stages[f'block{idx + 1}'] = ConvWithLin(
          lcin=lcin*wf,
          lcout=lcout*wf,
          cin=cin*wf,
          cout=cout*wf,
          cmid=cmid*wf,
          block_unit=block_units[idx],
      )
    self.body = nn.Sequential(stages)

    self.zero_head = zero_head
    classifier = [
        ('gn', nn.GroupNorm(32, 2048*wf)),
        ('relu', nn.ReLU(inplace=True)),
        ('avg', nn.AdaptiveAvgPool2d(output_size=1)),
        ('conv', nn.Conv2d(2048*wf, head_size, kernel_size=1, bias=True)),
    ]
    self.head = nn.Sequential(OrderedDict(classifier))

  def forward(self, x):
    features = self.body(self.root(x))
    logits = self.head(features)
    # The global average pool must have collapsed both spatial axes.
    assert logits.shape[-2:] == (1, 1)
    return logits[..., 0, 0]

  def load_from(self, weights, prefix='resnet/'):
    """Copy root, head and per-unit parameters from `weights`; returns `self`.

    With `zero_head` set, the head conv is zero-initialized instead of being
    read from the checkpoint.
    """
    def fetch(key):
      return tf2th(weights[f'{prefix}{key}'])

    with torch.no_grad():
      self.root.conv.weight.copy_(fetch('root_block/standardized_conv2d/kernel'))
      self.head.gn.weight.copy_(fetch('group_norm/gamma'))
      self.head.gn.bias.copy_(fetch('group_norm/beta'))

      if not self.zero_head:
        self.head.conv.weight.copy_(fetch('head/conv2d/kernel'))
        self.head.conv.bias.copy_(fetch('head/conv2d/bias'))
      else:
        nn.init.zeros_(self.head.conv.weight)
        nn.init.zeros_(self.head.conv.bias)

      # Checkpoint keys mirror the module names: <prefix><block>/<unit>/...
      for bname, block in self.body.named_children():
        for uname, unit in block.unit.named_children():
          unit.load_from(weights, prefix=f'{prefix}{bname}/{uname}/')
    return self


In [None]:
model = LinNet(LinNet.BLOCK_UNITS['r50'], width_factor=1, head_size=10,zero_head = True)  # NOTE: fresh 10-way head; zero_head=True makes load_from() zero-initialise it instead of copying the checkpoint head.
model.load_from(weights)
model.to(device);

Train

In [None]:
import torchvision
import torchvision.transforms as T

In [None]:

# Training pipeline: padded random crop and horizontal flip, then tensor
# conversion and per-channel normalisation with mean 0.5 / std 0.5.
transform_train = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Evaluation pipeline: no augmentation, same normalisation.
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_set = torchvision.datasets.CIFAR10(
    './data', train=True, download=True, transform=transform_train)
test_set = torchvision.datasets.CIFAR10(
    './data', train=False, download=True, transform=transform_test)

classes = train_set.classes


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:12<00:00, 13482998.79it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
#learning_rate = 1e-3
# Mini-batch size handed to both DataLoaders.
batch_size = 32
# NOTE(review): the visible train(...) call passes its own num_epochs, so this
# value looks unused — confirm.
num_epochs = 10

In [None]:
# Shuffle only the training split; the test split keeps its stored order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
)



In [None]:
from IPython.display import HTML, display


# Custom IPython progress bar for training
class ProgressMonitor:
    """Inline HTML progress bar that also shows the current loss.

    A display handle is created on construction and refreshed in place by
    `update`, so the bar does not add new output lines as training advances.
    """

    # Well-formed markup: attributes of <progress> are separated by spaces only
    # and every non-breaking space entity is terminated with ';'.
    tmpl = """
        <table style="width: 100%;">
            <tbody>
                <tr>
                    <td style="width: 30%;">
                     <b>Loss: {loss:0.4f}</b> &nbsp;&nbsp;&nbsp; {value} / {length}
                    </td>
                    <td style="width: 70%;">
                        <progress value='{value}' max='{length}' style='width: 100%'>{value}</progress>
                    </td>
                </tr>
            </tbody>
        </table>
        """

    def __init__(self, length):
        """Show an empty bar whose maximum is `length` items."""
        self.length = length
        self.count = 0
        self.display = display(self.html(0, 0), display_id=True)

    def html(self, count, loss):
        """Render the template for `count` processed items and the given loss."""
        return HTML(self.tmpl.format(length=self.length, value=count, loss=loss))

    def update(self, count, loss):
        """Advance the bar by `count` items and redraw it with the new loss."""
        self.count += count
        self.display.update(self.html(self.count, loss))

In [None]:
def stairs(s, v, *svs):
    """Piecewise-constant ("staircase") schedule.

    `svs` is a flat sequence of (threshold, value) pairs. The result is `v`
    until `s` reaches the first threshold, then the value paired with the
    last threshold that `s` has reached. Example:
    stairs(s, 0.1, 10, 0.01, 20, 0.001)
    gives 0.1 for s<10, 0.01 for 10<=s<20 and 0.001 for 20<=s.
    A trailing threshold without a value is ignored.
    """
    current = v
    for idx in range(0, len(svs) - 1, 2):
        threshold, step_value = svs[idx], svs[idx + 1]
        if s < threshold:
            return current
        current = step_value
    return current

def rampup(s, peak_s, peak_lr):
  """Linear warm-up: scale `peak_lr` by s/peak_s while s < peak_s, then return it as is."""
  warming_up = s < peak_s
  return s/peak_s * peak_lr if warming_up else peak_lr

In [None]:
from statistics import mean

# SGD with momentum over every parameter of `model`. The lr given here is
# overwritten at the start of each epoch inside train() from schedule().
optim = torch.optim.SGD(
    model.parameters(),
    lr=0.003,
    momentum=0.9,
    weight_decay=1e-6,
)

loss_fn = nn.CrossEntropyLoss()

# Epoch index from which schedule() returns None (last threshold of the stairs).
S = 30
def schedule(s):
  """Learning rate for epoch `s`: staircase decay from 3e-3 with a 2-epoch warm-up.

  From epoch `S` onwards the staircase value is None, which is returned as is.
  """
  plateaus = (10, 3e-4, 20, 3e-5, 25, 3e-6, S, None)
  return rampup(s, 2, stairs(s, 3e-3, *plateaus))

def train(optimizer, model, num_epochs = 10, first_epoch = 1):
    """Train `model` for `num_epochs` epochs, evaluating on the test split after each.

    Uses the notebook-level `train_loader`, `test_loader`, `loss_fn`,
    `schedule`, `device`, `ProgressMonitor` and `mean`.

    Args:
        optimizer: Optimizer whose parameter groups get their `lr` overwritten
            from `schedule(epoch)` at the start of every epoch.
        model: Network to train; must already live on `device`.
        num_epochs: Number of epochs to run.
        first_epoch: Epoch index of the first iteration; it is what
            `schedule` sees, so it selects the point on the lr schedule.

    Returns:
        (train_losses, test_losses, y_pred): the mean per-batch training loss
        of every epoch, the mean per-batch test loss of every epoch, and the
        test-set predictions of the last evaluated epoch (empty if no epoch ran).

    Side effects:
        Writes `best_model.pt` whenever the test accuracy improves.
    """
    train_losses = []
    test_losses = []
    y_pred = []  # stays empty when no epoch is executed

    best_test_acc = 0

    # Sizes come from the loaders themselves rather than a hard-coded 50000.
    num_train = len(train_loader.dataset)
    num_test = len(test_loader.dataset)
    steps_per_batch = num_train // train_loader.batch_size

    for epoch in range(first_epoch, first_epoch + num_epochs):
        lr = schedule(epoch)
        if lr is None:
            # schedule() yields None past its last threshold; writing that
            # into the optimizer would make optimizer.step() fail.
            print('Schedule exhausted; stopping before epoch', epoch)
            break

        print('Epoch', epoch)

        model.train()
        progress = ProgressMonitor(length=num_train)

        correct_train = 0
        batch_losses = []
        running_loss = 0.0

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        for batch, targets in train_loader:

            batch = batch.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            outputs = model(batch)

            loss = loss_fn(outputs, targets)
            # NOTE(review): gradients are zeroed every step, so nothing is
            # accumulated, yet the loss is scaled by 5/steps_per_batch before
            # backward(). This shrinks every update by that factor — confirm
            # it is intended. Kept as is to preserve the training behaviour.
            lossdiv = 5*loss/steps_per_batch
            lossdiv.backward()

            #torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

            loss_value = loss.item()
            batch_losses.append(loss_value)
            running_loss += loss_value

            preds = outputs.argmax(dim=1)
            correct_train += (preds == targets).sum().item()
            # Running average for the display; avoids re-averaging the whole
            # list on every step.
            progress.update(batch.shape[0], running_loss / len(batch_losses))

        train_losses.append(mean(batch_losses))

        model.eval()

        y_pred = []
        correct_test = 0
        test_batch_losses = []

        with torch.no_grad():

            for batch, targets in test_loader:

                batch = batch.to(device)
                targets = targets.to(device)

                outputs = model(batch)

                test_batch_losses.append(loss_fn(outputs, targets).item())

                preds = outputs.argmax(dim=1)
                y_pred.extend(preds.cpu().numpy())
                correct_test += (preds == targets).sum().item()

        test_losses.append(mean(test_batch_losses))

        train_acc = correct_train / num_train
        test_acc = correct_test / num_test

        print('Training accuracy: {:.2f}%'.format(float(train_acc) * 100))
        print('Test accuracy: {:.2f}%'.format(float(test_acc) * 100))

        # Save the best model
        saved_note = ''
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            torch.save( model.state_dict(), 'best_model.pt' )
            saved_note = ' --> Best model saved!'

        # Always end the line so the next 'Epoch' header starts on its own line.
        print('Learning Rate: {:.7f}{}\n'.format(float(lr), saved_note))

    return train_losses, test_losses, y_pred

In [None]:
# Keep the returned histories in variables instead of echoing the whole tuple
# (including every test prediction) into the cell output.
# NOTE(review): first_epoch=8 presumably resumes an earlier run — confirm.
train_losses, test_losses, y_pred = train(optim, model=model, num_epochs=S - 8, first_epoch=8)

Epoch 8


0,1
Loss: 2.1975 50000 / 50000,50000


Training accuracy: 21.97%
Test accuracy: 27.30%
Learning Rate: 0.0030000 --> Best model saved!

Epoch 9


0,1
Loss: 2.0674 50000 / 50000,50000


Training accuracy: 27.92%
Test accuracy: 30.94%
Learning Rate: 0.0030000 --> Best model saved!

Epoch 10


0,1
Loss: 2.0138 50000 / 50000,50000


Training accuracy: 30.42%
Test accuracy: 31.59%
Learning Rate: 0.0003000 --> Best model saved!

Epoch 11


0,1
Loss: 2.0068 50000 / 50000,50000


Training accuracy: 30.12%
Test accuracy: 31.85%
Learning Rate: 0.0003000 --> Best model saved!

Epoch 12


0,1
Loss: 1.9991 50000 / 50000,50000


Training accuracy: 30.80%
Test accuracy: 32.23%
Learning Rate: 0.0003000 --> Best model saved!

Epoch 13


0,1
Loss: 1.9914 50000 / 50000,50000


Training accuracy: 30.75%
Test accuracy: 32.08%
Learning Rate: 0.0003000Epoch 14


0,1
Loss: 1.9839 50000 / 50000,50000


Training accuracy: 31.35%
Test accuracy: 32.57%
Learning Rate: 0.0003000 --> Best model saved!

Epoch 15


0,1
Loss: 1.9779 50000 / 50000,50000


Training accuracy: 30.94%
Test accuracy: 32.69%
Learning Rate: 0.0003000 --> Best model saved!

Epoch 16


0,1
Loss: 1.9706 50000 / 50000,50000


Training accuracy: 31.32%
Test accuracy: 32.49%
Learning Rate: 0.0003000Epoch 17


0,1
Loss: 1.9650 50000 / 50000,50000


Training accuracy: 31.33%
Test accuracy: 32.66%
Learning Rate: 0.0003000Epoch 18


0,1
Loss: 1.9582 50000 / 50000,50000


Training accuracy: 31.74%
Test accuracy: 33.12%
Learning Rate: 0.0003000 --> Best model saved!

Epoch 19


0,1
Loss: 1.9526 50000 / 50000,50000


Training accuracy: 31.78%
Test accuracy: 32.85%
Learning Rate: 0.0003000Epoch 20


0,1
Loss: 1.9490 50000 / 50000,50000


Training accuracy: 31.67%
Test accuracy: 32.99%
Learning Rate: 0.0000300Epoch 21


0,1
Loss: 1.9488 50000 / 50000,50000


Training accuracy: 31.79%
Test accuracy: 33.06%
Learning Rate: 0.0000300Epoch 22


0,1
Loss: 1.9463 50000 / 50000,50000


Training accuracy: 31.99%
Test accuracy: 33.16%
Learning Rate: 0.0000300 --> Best model saved!

Epoch 23


0,1
Loss: 1.9470 50000 / 50000,50000


Training accuracy: 31.81%
Test accuracy: 33.18%
Learning Rate: 0.0000300 --> Best model saved!

Epoch 24


0,1
Loss: 1.9466 50000 / 50000,50000


Training accuracy: 32.08%
Test accuracy: 33.18%
Learning Rate: 0.0000300Epoch 25


0,1
Loss: 1.9472 50000 / 50000,50000


Training accuracy: 32.03%
Test accuracy: 33.17%
Learning Rate: 0.0000030Epoch 26


0,1
Loss: 1.9458 50000 / 50000,50000


Training accuracy: 31.98%
Test accuracy: 33.20%
Learning Rate: 0.0000030 --> Best model saved!

Epoch 27


0,1
Loss: 1.9443 50000 / 50000,50000


Training accuracy: 32.10%
Test accuracy: 33.18%
Learning Rate: 0.0000030Epoch 28


0,1
Loss: 1.9466 50000 / 50000,50000


Training accuracy: 31.91%
Test accuracy: 33.18%
Learning Rate: 0.0000030Epoch 29


0,1
Loss: 1.9459 50000 / 50000,50000


Training accuracy: 31.73%
Test accuracy: 33.18%
Learning Rate: 0.0000030

([2.1974608042418615,
  2.0673912695715693,
  2.0138413563647184,
  2.0067862925129827,
  1.9990794996534946,
  1.9913544908823757,
  1.983885079641336,
  1.9779379941375304,
  1.9706326097688533,
  1.964976764686277,
  1.9581956427141556,
  1.9526252201254828,
  1.9489795899284397,
  1.9487778221424465,
  1.9462709202647437,
  1.9470464488442556,
  1.946609341632992,
  1.9472332827494225,
  1.9458300917863998,
  1.9442824945950157,
  1.9466296009581132,
  1.9459392308273609],
 [],
 [2,
  0,
  8,
  0,
  4,
  6,
  5,
  6,
  2,
  0,
  2,
  9,
  5,
  7,
  1,
  8,
  5,
  6,
  8,
  6,
  7,
  0,
  8,
  1,
  4,
  4,
  3,
  6,
  9,
  6,
  6,
  2,
  4,
  5,
  8,
  6,
  2,
  9,
  8,
  5,
  1,
  2,
  2,
  2,
  8,
  0,
  5,
  2,
  4,
  4,
  0,
  8,
  5,
  6,
  8,
  4,
  5,
  1,
  6,
  2,
  6,
  6,
  8,
  5,
  6,
  4,
  6,
  0,
  3,
  8,
  2,
  6,
  2,
  0,
  0,
  2,
  1,
  5,
  5,
  8,
  8,
  8,
  2,
  0,
  5,
  2,
  0,
  8,
  9,
  8,
  0,
  2,
  8,
  2,
  2,
  4,
  6,
  2,
  5,
  5,
  4,
  3,
  5

In [None]:
# Snapshot of the model's state_dict only (no optimizer state is saved).
torch.save(model.state_dict(), "PATH.pt")

In [None]:
# NOTE(review): this reads "PATH1.pt", whereas the cell above writes "PATH.pt" —
# confirm which checkpoint is intended.
# map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
model.load_state_dict(torch.load("PATH1.pt", map_location=device))
model.eval()

LinNet(
  (root): Sequential(
    (conv): StdConv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (pad): ConstantPad2d(padding=(1, 1, 1, 1), value=0)
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (body): Sequential(
    (block1): ConvWithLin(
      (unit): Sequential(
        (unit01): PreActConv(
          (gn1): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (gn2): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (gn3): GroupNorm(32, 64, eps=1e-05, affine=True)
          (conv3): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (relu): ReLU(inplace=True)
        )
        (unit02): PreActConv(
          (gn1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (conv1): StdConv2d(256, 64, kerne

In [None]:
# Re-enable gradients for every unit of every body block.
for block in model.body.children():
  for unit in block.unit.children():
    unit.melt()

In [None]:
# For each unit, check whether its first conv kernel still equals the
# checkpoint value, and print both shapes.
for bname, block in model.body.named_children():
  for uname, unit in block.unit.named_children():
    key = f'resnet/{bname}/{uname}/' + 'a/standardized_conv2d/kernel'
    pretrained = tf2th(weights[key]).to(device)
    current = unit.conv1.weight
    print(torch.equal(current, pretrained))
    print(pretrained.shape)
    print(current.shape)

False
torch.Size([64, 64, 1, 1])
torch.Size([64, 64, 1, 1])
False
torch.Size([64, 256, 1, 1])
torch.Size([64, 256, 1, 1])
False
torch.Size([64, 256, 1, 1])
torch.Size([64, 256, 1, 1])
False
torch.Size([128, 256, 1, 1])
torch.Size([128, 256, 1, 1])
False
torch.Size([128, 512, 1, 1])
torch.Size([128, 512, 1, 1])
False
torch.Size([128, 512, 1, 1])
torch.Size([128, 512, 1, 1])
False
torch.Size([128, 512, 1, 1])
torch.Size([128, 512, 1, 1])
False
torch.Size([256, 512, 1, 1])
torch.Size([256, 512, 1, 1])
False
torch.Size([256, 1024, 1, 1])
torch.Size([256, 1024, 1, 1])
False
torch.Size([256, 1024, 1, 1])
torch.Size([256, 1024, 1, 1])
False
torch.Size([256, 1024, 1, 1])
torch.Size([256, 1024, 1, 1])
False
torch.Size([256, 1024, 1, 1])
torch.Size([256, 1024, 1, 1])
False
torch.Size([256, 1024, 1, 1])
torch.Size([256, 1024, 1, 1])
False
torch.Size([512, 1024, 1, 1])
torch.Size([512, 1024, 1, 1])
False
torch.Size([512, 2048, 1, 1])
torch.Size([512, 2048, 1, 1])
False
torch.Size([512, 2048, 1, 1]