In [1]:
import torch
import requests
import os
from pathlib import Path
# Checkpoint URL for torchvision's pretrained ResNet-18.
# NOTE(review): the same URL is listed again in `zoo` in the ResNet transfer cell, and this
# constant is not referenced by any of the visible cells -- candidate for removal.
URL = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'

In [2]:
# Enable IPython's autoreload extension so edits to imported modules (e.g. the local
# `glasses` package under development) are picked up without restarting the kernel.
%load_ext autoreload

# Mode 2: reload all modules (not just those marked with %aimport) before executing code.
%autoreload 2

In [4]:
from torchvision.models import densenet121
from torchsummary import summary

from glasses.nn.models.classification.densenet import DenseNet

# Print a layer-by-layer summary (output shapes, parameter counts) of glasses'
# DenseNet-121 for a single 3x224x224 input.
# Fall back to CPU so the cell also runs on machines without a GPU. torchsummary
# creates its dummy input on `device`, so the model and that input must agree.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
summary(DenseNet.densenet121().to(device), (3, 224, 224), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
       BatchNorm2d-5           [-1, 64, 56, 56]             128
              ReLU-6           [-1, 64, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           8,192
       BatchNorm2d-8          [-1, 128, 56, 56]             256
              ReLU-9          [-1, 128, 56, 56]               0
           Conv2d-10           [-1, 32, 56, 56]          36,864
DenseBottleNeckBlock-11           [-1, 96, 56, 56]               0
      BatchNorm2d-12           [-1, 96, 56, 56]             192
             ReLU-13           [-1, 96, 56, 56]               0
           Conv2d-14          [-1, 1

In [30]:
from torchvision.models import densenet201

from glasses.utils.ModuleTransfer import ModuleTransfer

# Copy torchvision's pretrained DenseNet-201 weights into glasses' DenseNet-201.
src = densenet201(pretrained=True)
dst = DenseNet.densenet201()

# Put both models in eval mode BEFORE the transfer, as the ResNet cell does.
# ModuleTransfer is invoked with an input tensor (presumably tracing both models
# with a forward pass). In train mode, that pass would update src's pretrained
# BatchNorm running statistics with random data before they are copied.
src.eval()
dst.eval()

x = torch.rand((1, 3, 224, 224))

ModuleTransfer(src, dst)(x)

# Verify the transfer: the same input must now give (numerically) the same output.
with torch.no_grad():
    assert torch.allclose(src(x), dst(x), atol=1e-5), \
        'densenet201: outputs differ after ModuleTransfer'


Downloading: "https://download.pytorch.org/models/densenet201-c1103571.pth" to /home/zuppif/.cache/torch/checkpoints/densenet201-c1103571.pth


HBox(children=(FloatProgress(value=0.0, max=81131730.0), HTML(value='')))




In [35]:
from glasses.utils.PretrainedWeightsProvider import PretrainedWeightsProvider

provider = PretrainedWeightsProvider()

# Index the weights provider by model name.
# NOTE(review): the saved output of this cell is a bare `AssertionError`, so this lookup
# currently fails -- presumably 'densenet169' is not (yet) known to the provider.
# TODO confirm, then either register the weights or remove this cell so
# Restart & Run All completes.
provider['densenet169']

AssertionError: 

In [6]:
import io  # NOTE(review): unused in this cell
# Import through the `glasses` package, like the other cells. Mixing cwd-relative
# (`utils.`, `nn.`) and package imports yields two distinct module objects and rebinds
# `ModuleTransfer` to a different class than the DenseNet cell used.
from glasses.utils.ModuleTransfer import ModuleTransfer
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
from glasses.nn.models.classification.resnet import ResNet
from tqdm.autonotebook import tqdm

# torchvision reference model -> checkpoint URL of its pretrained weights.
zoo = {
    resnet18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50: 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101: 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152: 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

# torchvision reference model -> glasses constructor of the equivalent architecture.
zoo_models_mapping = {
    resnet18: ResNet.resnet18,
    resnet34: ResNet.resnet34,
    resnet50: ResNet.resnet50,
    resnet101: ResNet.resnet101,
    resnet152: ResNet.resnet152,
}

# Scratch file the checkpoints are streamed into (overwritten for every model).
WEIGHTS_PATH = '../temp.pth'


def _download(url, path, chunk_size=1024):
    """Stream `url` into the file at `path`, showing a tqdm progress bar in chunks.

    Raises requests.HTTPError on a non-2xx response, so an error page is never
    written to `path` and later handed to torch.load.
    """
    # `with` closes the streamed connection even if writing fails.
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        content_length = r.headers.get('content-length')
        # Content-Length is optional; without it the bar simply has no total.
        total = int(content_length) / chunk_size if content_length is not None else None
        with open(path, 'wb') as f:
            for chunk in tqdm(r.iter_content(chunk_size=chunk_size), total=total):
                if chunk:  # skip empty chunks
                    f.write(chunk)


for src_def, dst_def in zoo_models_mapping.items():
    url = zoo[src_def]
    print(f"Getting weights at url={url}")
    _download(url, WEIGHTS_PATH)

    src = src_def(pretrained=False)
    # map_location='cpu' keeps loading independent of whether a GPU is present.
    src.load_state_dict(torch.load(WEIGHTS_PATH, map_location='cpu'))
    dst = dst_def()

    x = torch.rand((1, 3, 224, 224))

    # Eval mode so BatchNorm uses (and does not update) running statistics.
    src.eval()
    dst.eval()

    with torch.no_grad():
        a = src(x)
        b = dst(x)
    # Sanity check: a freshly initialised dst must not already match src.
    assert not torch.equal(a, b)

    ModuleTransfer(src, dst)(x)

    with torch.no_grad():
        a = src(x)
        b = dst(x)
    assert torch.equal(a, b), f'{src_def.__name__}: outputs differ after ModuleTransfer'

Getting weights at url=https://download.pytorch.org/models/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=45730.0), HTML(value='')))


Getting weights at url=https://download.pytorch.org/models/resnet34-333f7ec4.pth


HBox(children=(FloatProgress(value=0.0, max=85260.0), HTML(value='')))


Getting weights at url=https://download.pytorch.org/models/resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=100100.0), HTML(value='')))


Getting weights at url=https://download.pytorch.org/models/resnet101-5d3b4d8f.pth


HBox(children=(FloatProgress(value=0.0, max=174540.0), HTML(value='')))


Getting weights at url=https://download.pytorch.org/models/resnet152-b121ed2d.pth


HBox(children=(FloatProgress(value=0.0, max=235870.0), HTML(value='')))




In [4]:
# List the working directory (hidden entries excluded, like `ls`). Uses pathlib instead
# of the `!ls` shell escape so the cell also works where `ls` is unavailable (e.g. Windows).
cwd_entries = sorted(p.name for p in Path.cwd().iterdir() if not p.name.startswith('.'))
cwd_entries

data	     interpretability  nn		 __pycache__
__init__.py  models.ipynb      playground.ipynb  utils
