In [106]:
import argparse
import os
import sys
import tabulate
import time
import torch
import torch.nn.functional as F

import curves
import data
import models
import utils
import pickle

import copy
import numpy as np

In [107]:
# Pick the model family to inspect (models.VGG16 is the alternative used elsewhere).
architecture = models.Linear


In [108]:
architecture

models.linear.Linear

In [109]:
# Build the base (non-curve) network and restore trained weights from a checkpoint.
checkpoint_name = "checkpoint-100.pt"
checkpoint_path = os.path.join("curves_linear", "curve1", checkpoint_name)
model = architecture.base(num_classes=10, **architecture.kwargs)
# This notebook only runs inference; eval mode keeps the predictions deterministic
# if the architecture contains dropout or batch-norm layers.
model.eval()
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a CPU-only
# machine; the model is later called directly on the batch coming out of the loader.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")['model_state'])

In [110]:
# Build the data loaders through the project-local data.loaders helper.
# NOTE(review): every argument is positional; the meanings noted below are inferred
# from the values and the printed log, not from the helper's signature - confirm
# against data.py before relying on them.
loaders, num_classes = data.loaders(
    "CIFAR10",  # dataset name
    "data",  # presumably the dataset root directory
    128,  # presumably the batch size (the batch sampled below has 128 rows)
    1,  # presumably the number of loader workers - TODO confirm
    "VGG",  # presumably the name of a transform/preprocessing preset - TODO confirm
    False  # presumably "use the test split"; the log reports train + validation - TODO confirm
)

Files already downloaded and verified
Using train (45000) + validation (5000)
Files already downloaded and verified


In [111]:
# Grab a single mini-batch to use as the fixed probe input for the rest of the notebook.
# next(iter(...)) raises StopIteration on an empty loader, whereas the for/break form
# would silently leave X and y undefined (or stale from an earlier kernel run).
X, y = next(iter(loaders['train']))

In [112]:
X.shape

torch.Size([128, 3, 32, 32])

In [113]:
3*32*32

3072

In [114]:
# Reference predictions of the unmodified model: index of the largest logit per sample.
y_pred = model(X).argmax(dim=-1)

In [115]:
y

tensor([4, 9, 9, 9, 3, 7, 0, 6, 4, 0, 3, 1, 0, 6, 6, 2, 3, 4, 8, 0, 6, 7, 5, 5,
        7, 1, 7, 8, 7, 6, 2, 0, 9, 9, 9, 7, 1, 2, 3, 1, 2, 4, 0, 9, 6, 5, 9, 9,
        1, 1, 1, 6, 9, 5, 2, 2, 1, 0, 1, 4, 8, 2, 4, 7, 2, 4, 5, 8, 7, 8, 6, 1,
        3, 6, 4, 7, 8, 0, 5, 7, 1, 2, 1, 2, 2, 7, 2, 8, 7, 2, 5, 7, 3, 7, 5, 7,
        0, 5, 0, 8, 5, 7, 4, 2, 0, 6, 6, 9, 5, 8, 8, 6, 5, 6, 9, 7, 3, 6, 7, 7,
        7, 3, 8, 7, 5, 3, 7, 7])

In [116]:
y_pred

tensor([4, 9, 9, 9, 3, 7, 0, 6, 4, 0, 3, 1, 0, 6, 6, 2, 3, 4, 8, 0, 6, 7, 5, 5,
        7, 1, 7, 8, 7, 6, 2, 0, 9, 9, 9, 7, 1, 2, 3, 1, 2, 4, 0, 9, 6, 5, 9, 9,
        1, 1, 1, 6, 9, 5, 2, 2, 1, 0, 1, 4, 8, 2, 4, 7, 2, 4, 5, 7, 7, 8, 6, 1,
        3, 6, 4, 7, 8, 0, 5, 7, 1, 2, 1, 2, 2, 7, 2, 8, 7, 2, 5, 7, 3, 7, 5, 7,
        0, 5, 0, 8, 5, 7, 4, 2, 0, 6, 6, 9, 5, 8, 8, 6, 5, 6, 9, 7, 3, 6, 7, 7,
        7, 3, 8, 7, 5, 3, 7, 7])

In [117]:
# Count the samples whose prediction differs from the label. Tensor.sum() is used
# instead of the built-in sum(), which adds the uint8 comparison results one element
# at a time in uint8 (see the recorded dtype) and would wrap around past 255.
(y_pred != y).sum()

tensor(1, dtype=torch.uint8)

## Rescale

In [22]:
def rescale(l, scale, net=None, scale_bias=True):
    """Multiply the parameters of one layer by a constant factor, in place.

    The parameters are scaled in place (under ``torch.no_grad()``) rather than
    replaced by new ``Parameter`` objects, so anything already holding a
    reference to them (for example an optimizer) keeps seeing the same tensors.

    Args:
        l: Index into ``list(net.modules())`` selecting the layer; negative
            indices count from the end. The layer must expose ``weight``.
        scale: Multiplicative factor.
        net: Network to modify. Defaults to the notebook-level ``model``.
        scale_bias: When True (the default, and the historical behaviour) the
            bias is scaled together with the weight. Pass False for a layer
            that compensates an upstream rescale: with a positively
            homogeneous activation such as ReLU in between, scaling layer A's
            weight and bias by ``s`` and only layer B's weight by ``1 / s``
            leaves the network output unchanged.

    Returns:
        None. The selected layer is modified in place.
    """
    # Fall back to the notebook global only when no network is passed explicitly.
    net = model if net is None else net
    layer = list(net.modules())[l]
    with torch.no_grad():
        layer.weight.mul_(scale)
        # Layers built with bias=False carry ``bias = None``; skip them.
        bias = getattr(layer, "bias", None)
        if scale_bias and bias is not None:
            bias.mul_(scale)

In [34]:
list(model.modules())[-5]

Linear(in_features=1152, out_features=1000, bias=True)

In [35]:
rescale(-3, 10)

In [36]:
rescale(-5, 0.1)

In [37]:
# Predictions after the two rescale calls above: index of the largest logit per sample.
y_pred_r = model(X).argmax(dim=-1)

In [38]:
y_pred_r

tensor([8, 4, 6, 6, 9, 3, 1, 3, 4, 4, 4, 4, 1, 0, 4, 2, 9, 6, 6, 0, 7, 9, 4, 2,
        1, 4, 4, 4, 1, 1, 3, 4, 5, 6, 2, 4, 0, 4, 4, 5, 4, 6, 6, 9, 5, 8, 5, 4,
        5, 7, 4, 4, 9, 4, 6, 5, 6, 7, 6, 4, 9, 5, 5, 2, 7, 4, 4, 0, 4, 7, 6, 4,
        9, 0, 5, 7, 1, 3, 7, 9, 1, 6, 8, 7, 4, 8, 4, 9, 7, 4, 4, 5, 4, 9, 4, 4,
        5, 4, 4, 4, 4, 3, 3, 6, 3, 4, 0, 6, 3, 0, 8, 9, 9, 4, 8, 3, 2, 2, 6, 4,
        1, 4, 0, 5, 4, 4, 8, 1])

In [39]:
# Count the predictions changed by rescaling. Tensor.sum() is used instead of the
# built-in sum(), which adds the uint8 comparison results one element at a time in
# uint8 (see the recorded dtype) and would wrap around past 255.
(y_pred != y_pred_r).sum()

tensor(33, dtype=torch.uint8)

## Node

In [118]:
# def change_node(l1, l2, i, j):
    
#     c = copy.deepcopy(torch.nn.Parameter(list(model.modules())[l1].weight[j]))
#     list(model.modules())[l1].weight[j]  = list(model.modules())[l1].weight[i] 
#     list(model.modules())[l1].weight[i] = c
    
#     c = copy.deepcopy(torch.nn.Parameter(list(model.modules())[l2].weight.transpose(0,1)[j]))
#     list(model.modules())[l2].weight.transpose(0,1)[j]  = list(model.modules())[l2].weight.transpose(0,1)[i]
#     list(model.modules())[l2].weight.transpose(0,1)[i] = c
    
 

In [119]:
def change_node(l1, l2, i, j, net=None):
    """Swap hidden units ``i`` and ``j`` between two consecutive layers, in place.

    Rows ``i`` and ``j`` of layer ``l1``'s weight (and the matching bias
    entries, when the layer has a bias) are exchanged, together with columns
    ``i`` and ``j`` of layer ``l2``'s weight. When ``l2`` consumes the outputs
    of ``l1`` directly, or through an element-wise activation, this permutation
    leaves the network function unchanged.

    NOTE(review): assumes both layers store ``weight`` as
    ``(out_features, in_features)`` and that unit ``k`` of ``l1`` feeds input
    column ``k`` of ``l2`` - confirm for the layers being permuted.

    Args:
        l1: Index into ``list(net.modules())`` of the layer producing the units.
        l2: Index into ``list(net.modules())`` of the layer consuming them.
        i: Index of the first unit.
        j: Index of the second unit.
        net: Network to modify. Defaults to the notebook-level ``model``.

    Returns:
        None. The two layers are modified in place.
    """
    # Fall back to the notebook global only when no network is passed explicitly.
    net = model if net is None else net
    layers = list(net.modules())
    producer, consumer = layers[l1], layers[l2]

    # Parameters are leaf tensors that require grad; in-place writes are only
    # permitted while autograd tracking is disabled.
    with torch.no_grad():
        saved_row = producer.weight[j].clone()
        producer.weight[j] = producer.weight[i]
        producer.weight[i] = saved_row

        # The bias belongs to the same output units and has to move with them,
        # otherwise the permuted network computes a different function.
        bias = getattr(producer, "bias", None)
        if bias is not None:
            saved_bias = bias[j].clone()
            bias[j] = bias[i]
            bias[i] = saved_bias

        saved_col = consumer.weight[:, j].clone()
        consumer.weight[:, j] = consumer.weight[:, i]
        consumer.weight[:, i] = saved_col

In [120]:
# model 65 878 442 vs 15 245 130

In [136]:
list(model.modules())[-8].weight.shape

torch.Size([1152, 6144])

In [122]:
# list(model.modules())[17].bias.shape

In [129]:
# Permute hidden units between two consecutive Linear layers: unit i is swapped with
# unit 2*i (i = 0 swaps a unit with itself, which is a no-op).
# The first layer index was missing (a bare "-", which is a syntax error). -5 is used
# because the module at index -5 is the Linear layer with 1000 outputs shown earlier.
# NOTE(review): confirm that the layer at index -3 consumes those 1000 outputs.
for i in range(200):
    change_node(-5, -3, i, 2*i)

In [130]:
# Predictions after the node swaps: index of the largest logit per sample.
y_pred_n = model(X).argmax(dim=-1)

In [131]:
y_pred

tensor([4, 9, 9, 9, 3, 7, 0, 6, 4, 0, 3, 1, 0, 6, 6, 2, 3, 4, 8, 0, 6, 7, 5, 5,
        7, 1, 7, 8, 7, 6, 2, 0, 9, 9, 9, 7, 1, 2, 3, 1, 2, 4, 0, 9, 6, 5, 9, 9,
        1, 1, 1, 6, 9, 5, 2, 2, 1, 0, 1, 4, 8, 2, 4, 7, 2, 4, 5, 7, 7, 8, 6, 1,
        3, 6, 4, 7, 8, 0, 5, 7, 1, 2, 1, 2, 2, 7, 2, 8, 7, 2, 5, 7, 3, 7, 5, 7,
        0, 5, 0, 8, 5, 7, 4, 2, 0, 6, 6, 9, 5, 8, 8, 6, 5, 6, 9, 7, 3, 6, 7, 7,
        7, 3, 8, 7, 5, 3, 7, 7])

In [132]:
y_pred_n

tensor([4, 9, 9, 9, 3, 7, 0, 6, 4, 0, 3, 1, 0, 6, 6, 2, 3, 4, 8, 0, 6, 7, 5, 5,
        7, 1, 7, 8, 7, 6, 2, 0, 9, 9, 9, 7, 1, 2, 3, 1, 2, 4, 0, 9, 6, 5, 9, 9,
        1, 1, 1, 6, 9, 5, 2, 2, 1, 0, 1, 4, 8, 2, 4, 7, 2, 4, 5, 7, 7, 8, 6, 1,
        3, 6, 4, 7, 8, 0, 5, 7, 1, 2, 1, 2, 2, 7, 2, 8, 7, 2, 5, 7, 3, 7, 5, 7,
        0, 5, 0, 8, 5, 7, 4, 2, 0, 6, 6, 9, 5, 8, 8, 6, 5, 6, 9, 7, 3, 6, 7, 7,
        7, 3, 8, 7, 5, 3, 7, 7])

In [133]:
# Show the per-sample mismatch mask (non-zero where the predictions differ) and its total.
print("eq ", y_pred_n != y_pred, (y_pred_n != y_pred).sum())


eq  tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8) tensor(0)


In [134]:
# Count the predictions changed by the node permutation. Tensor.sum() is used instead
# of the built-in sum(), which adds the uint8 comparison results one element at a time
# in uint8 (see the recorded dtype) and would wrap around past 255.
(y_pred != y_pred_n).sum()

tensor(0, dtype=torch.uint8)

## Saving

In [71]:
#  torch.load("curve/checkpoint-50.pt")

In [74]:
# Scratch check: walk over (path, index) pairs and print each component.
for path, k in zip(("Aaa", "Bbb"), (0, 4 - 1)):
    print('p ', path)
    print('k', k)
    

p  Aaa
k 0
p  Bbb
k 3


In [77]:
# Scratch check: a step-3 slice keeps every third element, starting at index 0.
a = [*range(10)]
a[0::3]

[0, 3, 6, 9]

In [89]:
# Switch to the VGG16 model family and select the polygonal-chain curve class.
# Note: this rebinds `architecture`, which earlier cells set to the Linear family.
architecture = models.VGG16
curve = curves.PolyChain

In [90]:
curve

curves.PolyChain

In [86]:
architecture.curve

models.vgg.VGGCurve

In [87]:
architecture.kwargs

{'batch_norm': False, 'depth': 16}

In [113]:
# Build a curve network with the project-local curves.CurveNet.
# Note: this rebinds `model`, so the network loaded from the checkpoint earlier is no
# longer reachable under that name.
# NOTE(review): most arguments are positional; the meanings noted below are inferred
# from the values only - confirm against the CurveNet signature in curves.py.
model = curves.CurveNet(
        10,  # presumably the number of classes
        curve,  # curve class selected in an earlier cell
        architecture.curve,  # curve-aware variant of the selected architecture
        3,  # presumably the number of bends / control points - TODO confirm
        True,  # presumably keeps one endpoint fixed - TODO confirm
        True,  # presumably keeps the other endpoint fixed - TODO confirm
        architecture_kwargs=architecture.kwargs,
    )

In [114]:
model

CurveNet(
  (coeff_layer): PolyChain()
  (net): VGGCurve(
    (layer_blocks): ModuleList(
      (0): ModuleList(
        (0): Conv2d()
        (1): Conv2d()
      )
      (1): ModuleList(
        (0): Conv2d()
        (1): Conv2d()
      )
      (2): ModuleList(
        (0): Conv2d()
        (1): Conv2d()
        (2): Conv2d()
      )
      (3): ModuleList(
        (0): Conv2d()
        (1): Conv2d()
        (2): Conv2d()
      )
      (4): ModuleList(
        (0): Conv2d()
        (1): Conv2d()
        (2): Conv2d()
      )
    )
    (activation_blocks): ModuleList(
      (0): ModuleList(
        (0): ReLU(inplace)
        (1): ReLU(inplace)
      )
      (1): ModuleList(
        (0): ReLU(inplace)
        (1): ReLU(inplace)
      )
      (2): ModuleList(
        (0): ReLU(inplace)
        (1): ReLU(inplace)
        (2): ReLU(inplace)
      )
      (3): ModuleList(
        (0): ReLU(inplace)
        (1): ReLU(inplace)
        (2): ReLU(inplace)
      )
      (4): ModuleList(
        (

In [94]:
pred = model(X)

In [95]:
pred.shape

torch.Size([128, 10])

In [96]:
from torch.nn import Module

In [98]:
m = Module()

In [105]:
m.register_buffer('range', torch.arange(0, float(3)))

In [106]:
m.range

tensor([ 0.,  1.,  2.])

In [122]:
# Evaluate the triangular ("hat") weights max(0, 1 - |t_n - k|) for every index k stored
# in the m.range buffer, at the scaled position t_n.
t_n = 2 / 3 * (3 - 1)
(1.0 - (t_n - m.range).abs()).clamp(min=0.0)

tensor([ 0.0000,  0.6667,  0.3333])