In [10]:
import torch
from torch import nn
import torchvision.models as models
from PIL import Image
from torchvision import transforms
from torchvision.models import vgg16, VGG16_Weights

In [11]:
import sys
sys.path.append('../..')  # Navigate up to GitHub folder
from LeNet_5.classes.convolution import ConvolutionLayer

In [65]:
from math import sqrt

def reshape_output(out_tensor):
    batch_size, features = out_tensor.shape
    height = width = int(sqrt(features//512))
    return out_tensor.reshape(
        batch_size, 
        512, 
        height,
        width,
        )

In [15]:
conv = nn.Conv2d(512, 4096, 7)
lin = nn.Linear(in_features=100, out_features=100)
lin.weight.shape

torch.Size([100, 100])

In [77]:
class SkipConnection(nn.Module):
    '''
    applies a convolution, upsamples 2x, and stores the resulting output,
    but passes the raw output forward to the next layer

    '''
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # pool4 has 512 output channels
        # 30 output channels for each class
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=num_classes, kernel_size=1, stride=1)
        # upsample layer
        self.val = None
        self.last_in = None
        # initialize weights with 0
        nn.init.zeros_(self.conv.weight) 

    def forward(self, x):
        '''
        returns the input as is but stores the result of a convolution.
        
        args
        - x: tensor(batch_size, 512, img_height/16, img_width/16)
        
        '''
        # apply convolution, upsample 2x, and store
        self.val = self.conv(x)
        # apply nothing to x
        self.last_in = x
        return x

In [114]:
def to_tensor(img):
    transform = transforms.Compose([
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor()
    ])
    return transform(img)

In [173]:
cat_img = Image.open('../data/custom/cat108.jpg')
lion_img = Image.open('../data/custom/lion.jpg')

cat_tensor = to_tensor(cat_img)[:3,:,:].unsqueeze(0)
lion_tensor = to_tensor(lion_img)[:3,:,:].unsqueeze(0)
batch = torch.cat([cat_tensor, cat_tensor], dim=0)
batch.shape

batch = nn.functional.pad(input=batch, pad=(2, 2, 2, 2), mode='constant', value=0)
batch.shape

torch.Size([2, 3, 228, 228])

In [116]:

# get the vgg-16 pretrained model used in paper
model = vgg16(weights=VGG16_Weights.DEFAULT)

# get weights and biases from last 3 fully connected layers in VGG-16
linear1_weights = model.classifier[0].weight.data
linear1_bias = model.classifier[0].bias.data
linear2_weights = model.classifier[3].weight.data
linear2_bias = model.classifier[3].bias.data

# reshape to fit convolution layer dimensions
reshaped_weights1 = linear1_weights.reshape(4096, 512, 7, 7)
reshaped_weights2 = linear2_weights.reshape(4096, 4096, 1, 1)

# use pytorch's convolution layer class
conv1 = nn.Conv2d(in_channels=512, out_channels=4096, kernel_size=7)
conv2 = nn.Conv2d(in_channels=4096, out_channels=4096, kernel_size=1)
conv3 = nn.Conv2d(in_channels=4096, out_channels=30, kernel_size=1)

# replace default weights with trained reshaped weights
conv1.weight.data = reshaped_weights1
conv2.weight.data = reshaped_weights2

# replace biases
conv1.bias.data = linear1_bias
conv2.bias.data = linear2_bias

# replace linear layers with convlayers
model.classifier[0] = conv1
model.classifier[3] = conv2
model.classifier[6] = conv3

# remove average pooling layer
model.avgpool = nn.Identity()

# initialize post-pool 1x1 conv predictions
pool3_conv = SkipConnection(in_channels=256, num_classes=30)
pool4_conv = SkipConnection(in_channels=512, num_classes=30)
size_check_1 = SkipConnection(in_channels=128, num_classes=30)
size_check_2 = SkipConnection(in_channels=128, num_classes=30)
size_check_3 = SkipConnection(in_channels=512, num_classes=30)

# insert conv_wrap in sequential after pool3 and pool4
model.features.insert(17, pool3_conv)
model.features.insert(25, pool4_conv)
model.features.insert(6, size_check_1)
model.features.insert(11, size_check_2)
model.features.insert(35, size_check_3)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): SkipConnection(
    (conv): Conv2d(128, 30, kernel_size=(1, 1), stride=(1, 1))
  )
  (7): ReLU(inplace=True)
  (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (11): SkipConnection(
    (conv): Conv2d(128, 30, kernel_size=(1, 1), stride=(1, 1))
  )
  (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(

In [130]:
pred = model.features(batch)
print(pred.shape)
pred = model.classifier(pred)

torch.Size([2, 512, 7, 7])


In [127]:

upsample = nn.ConvTranspose2d(in_channels=30, out_channels=30, kernel_size=4,stride=2, bias=False)
# pred2 = nn.functional.interpolate(pred, (1024, 1024), mode='bilinear')
pred2 = upsample(pred)
pred2.shape
pred.shape
# pred2 = upsample3(pred2)
model.features[19].val.shape


torch.Size([2, 30, 28, 28])

solving equations below to determine stride and kernel size for 2x upsampling

output = (input - 1) * stride + kernel
output = 2*input

2*i = (i - 1) * s + k

(k-2)/(2-s)


In [154]:
for name, child in vgg.named_children():
vgg = models.vgg16()
    print(name, child)

features Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (17): Conv2d(256, 512, 

In [1]:
def print_shape(name):
    def hook(module, input, output):
        print(f"{name}: {output.shape}")
    return hook

# Register hooks
for name, layer in vgg.features.named_children():
    layer.register_forward_hook(print_shape(f"Layer {name}"))

for name, layer in vgg.avgpool.named_children():
    layer.register_forward_hook(print_shape(f"Layer {name}"))

for name, layer in vgg.classifier.named_children():
    layer.register_forward_hook(print_shape(f"Layer {name}"))

NameError: name 'vgg' is not defined

In [156]:
# vgg.classifier = nn.Identity()
# vgg.avgpool = nn.Identity()
# for i in range(23, len(vgg.features)):
#     vgg.features[i] = nn.Identity()


In [163]:
pred_vgg_f = vgg.features(batch)
pred_vgg_f = vgg.avgpool(pred_vgg_f)
pred_vgg = vgg.classifier(pred_vgg_f)
pred_vgg.shape

Layer 0: torch.Size([2, 64, 224, 224])
Layer 0: torch.Size([2, 64, 224, 224])
Layer 1: torch.Size([2, 64, 224, 224])
Layer 1: torch.Size([2, 64, 224, 224])
Layer 2: torch.Size([2, 64, 224, 224])
Layer 2: torch.Size([2, 64, 224, 224])
Layer 3: torch.Size([2, 64, 224, 224])
Layer 3: torch.Size([2, 64, 224, 224])
Layer 4: torch.Size([2, 64, 112, 112])
Layer 4: torch.Size([2, 64, 112, 112])
Layer 5: torch.Size([2, 128, 112, 112])
Layer 5: torch.Size([2, 128, 112, 112])
Layer 6: torch.Size([2, 128, 112, 112])
Layer 6: torch.Size([2, 128, 112, 112])
Layer 7: torch.Size([2, 128, 112, 112])
Layer 7: torch.Size([2, 128, 112, 112])
Layer 8: torch.Size([2, 128, 112, 112])
Layer 8: torch.Size([2, 128, 112, 112])
Layer 9: torch.Size([2, 128, 56, 56])
Layer 9: torch.Size([2, 128, 56, 56])
Layer 10: torch.Size([2, 256, 56, 56])
Layer 10: torch.Size([2, 256, 56, 56])
Layer 11: torch.Size([2, 256, 56, 56])
Layer 11: torch.Size([2, 256, 56, 56])
Layer 12: torch.Size([2, 256, 56, 56])
Layer 12: torch.Siz

RuntimeError: mat1 and mat2 shapes cannot be multiplied (7168x7 and 25088x4096)

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [2, 25088]