In [2]:
import torch
import os
import pandas as pd
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time
import sys
import copy

import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torch.optim as optim
import torch.nn.functional as tfunc
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as func

from sklearn.metrics.ranking import roc_auc_score

from torch.utils.data import Dataset
from PIL import Image
from models.chexnet.DensenetModels import DenseNet121
from models.models import ResNet18
from tensorboardX import SummaryWriter

In [2]:
checkpoint = './forward/m-37050_0.897.pth.tar'

In [3]:
nnClassCount = 9
model = DenseNet121(nnClassCount, False).cuda()
model = torch.nn.DataParallel(model).cuda()
modelCheckpoint = torch.load(checkpoint)
model.load_state_dict(modelCheckpoint['state_dict'])
optimizer = optim.Adam (model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
optimizer.load_state_dict(modelCheckpoint['optimizer'])

In [4]:
modelCheckpoint['state_dict'].keys()

odict_keys(['module.densenet121.features.conv0.weight', 'module.densenet121.features.norm0.weight', 'module.densenet121.features.norm0.bias', 'module.densenet121.features.norm0.running_mean', 'module.densenet121.features.norm0.running_var', 'module.densenet121.features.norm0.num_batches_tracked', 'module.densenet121.features.denseblock1.denselayer1.norm1.weight', 'module.densenet121.features.denseblock1.denselayer1.norm1.bias', 'module.densenet121.features.denseblock1.denselayer1.norm1.running_mean', 'module.densenet121.features.denseblock1.denselayer1.norm1.running_var', 'module.densenet121.features.denseblock1.denselayer1.norm1.num_batches_tracked', 'module.densenet121.features.denseblock1.denselayer1.conv1.weight', 'module.densenet121.features.denseblock1.denselayer1.norm2.weight', 'module.densenet121.features.denseblock1.denselayer1.norm2.bias', 'module.densenet121.features.denseblock1.denselayer1.norm2.running_mean', 'module.densenet121.features.denseblock1.denselayer1.norm2.runni

In [5]:
W, b = (modelCheckpoint['state_dict']['module.densenet121.classifier.weight'],
        modelCheckpoint['state_dict']['module.densenet121.classifier.bias'])

In [6]:
print(W.shape, b.shape)

torch.Size([9, 1024]) torch.Size([9])


In [7]:
new_rows = (1/np.sqrt(1024)) * np.random.randn(2, 1024)
new_W_data = np.vstack([W.data[0, :].cpu().numpy(), new_rows, W.data[1:, :].cpu().numpy()])
new_biases = np.zeros((2,))
print(np.array(b.data[1:].cpu().numpy()))
print(np.array([b.data[0].cpu().numpy()]))
print(new_biases)
new_b_data = np.concatenate([np.array([b.data[0].cpu().numpy()]), new_biases, np.array(b.data[1:].cpu().numpy())])

[-0.02240922 -0.02525789  0.0368581   0.02283076  0.00329747  0.0016001
  0.03601192 -0.05278656]
[0.00723741]
[0. 0.]


In [8]:
new_W = torch.from_numpy(new_W_data).cuda()
new_b = torch.from_numpy(new_b_data).cuda()
print(b, new_b)

tensor([ 0.0072, -0.0224, -0.0253,  0.0369,  0.0228,  0.0033,  0.0016,  0.0360,
        -0.0528], device='cuda:0') tensor([ 0.0072,  0.0000,  0.0000, -0.0224, -0.0253,  0.0369,  0.0228,  0.0033,
         0.0016,  0.0360, -0.0528], device='cuda:0', dtype=torch.float64)


In [9]:
new_state_dict = copy.deepcopy(modelCheckpoint['state_dict'])
new_state_dict['module.densenet121.classifier.weight'] = new_W
new_state_dict['module.densenet121.classifier.bias'] = new_b

In [10]:
print(modelCheckpoint['state_dict'].keys())

odict_keys(['module.densenet121.features.conv0.weight', 'module.densenet121.features.norm0.weight', 'module.densenet121.features.norm0.bias', 'module.densenet121.features.norm0.running_mean', 'module.densenet121.features.norm0.running_var', 'module.densenet121.features.norm0.num_batches_tracked', 'module.densenet121.features.denseblock1.denselayer1.norm1.weight', 'module.densenet121.features.denseblock1.denselayer1.norm1.bias', 'module.densenet121.features.denseblock1.denselayer1.norm1.running_mean', 'module.densenet121.features.denseblock1.denselayer1.norm1.running_var', 'module.densenet121.features.denseblock1.denselayer1.norm1.num_batches_tracked', 'module.densenet121.features.denseblock1.denselayer1.conv1.weight', 'module.densenet121.features.denseblock1.denselayer1.norm2.weight', 'module.densenet121.features.denseblock1.denselayer1.norm2.bias', 'module.densenet121.features.denseblock1.denselayer1.norm2.running_mean', 'module.densenet121.features.denseblock1.denselayer1.norm2.runni

In [11]:
nnClassCount = 11
model = DenseNet121(nnClassCount, False).cuda()
model = torch.nn.DataParallel(model).cuda()
model.load_state_dict(new_state_dict)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [12]:
print(optimizer.state_dict())

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [10]:
counter = modelCheckpoint['counter']
aurocMean = modelCheckpoint['valAUROC']

torch.save({'counter' : counter, 'state_dict': model.state_dict(), 'valAUROC' : aurocMean ,
            'optimizer' : optimizer.state_dict()},
           'm-0.897_starter' + str(counter) + '_' + str(round(aurocMean, 3)) + '.pth.tar')

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [3]:
checkpoint = 'm-0.897_starter37050_0.897.pth.tar'
nnClassCount = 11
model = DenseNet121(nnClassCount, False).cuda()
model = torch.nn.DataParallel(model).cuda()
modelCheckpoint = torch.load(checkpoint)
model.load_state_dict(modelCheckpoint['state_dict'])
optimizer = optim.Adam (model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
optimizer.load_state_dict(modelCheckpoint['optimizer'])

TypeError: state_dict() missing 1 required positional argument: 'self'