In [1]:
import torch
import os
# Configure PyTorch's CUDA allocator through its environment variable
# (max_split_size_mb set to 256).
# NOTE(review): this is assigned after `import torch`; confirm the setting is
# still picked up, or move it above the import to be safe.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'

In [2]:
# Sanity check: display whether CUDA is usable and how many devices are visible.
torch.cuda.is_available(), torch.cuda.device_count()

(True, 1)

In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models, transforms
import matplotlib.pyplot as plt
import time
import copy
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from copy import deepcopy
from pathlib import Path

dirname = 'mnist-infl-per-layer'

In [2]:
def reset_all_weights(model: nn.Module) -> None:
    """
    Re-initialise, in place, every submodule of ``model`` that exposes a
    callable ``reset_parameters`` method. Submodules without one are left
    untouched.

    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    def _reinitialise(module: nn.Module):
        # Guard clause: only modules that know how to reset themselves are touched.
        resetter = getattr(module, "reset_parameters", None)
        if not callable(resetter):
            return
        # Re-initialisation must not be recorded by autograd.
        with torch.no_grad():
            module.reset_parameters()

    # nn.Module.apply visits every submodule recursively, then the module itself.
    model.apply(fn=_reinitialise)

In [3]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

In [4]:
import torchvision

def _gray_tensor_transform():
    """Build the preprocessing pipeline used by both splits: Grayscale, then ToTensor."""
    return transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])


# Training and test splits are downloaded into separate root folders.
train = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    download=True,
    transform=_gray_tensor_transform(),
)
test = torchvision.datasets.MNIST(
    root='./mnist-test',
    train=False,
    download=True,
    transform=_gray_tensor_transform(),
)

In [5]:
def train_model_opt(model, criterion, optimizer, scheduler, dataloader, num_epochs=25, position=0, leave=True, stats=True):
    """Train ``model`` in place for ``num_epochs`` passes over ``dataloader``.

    Args:
        model: module to optimise; it is switched to training mode.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer whose ``zero_grad``/``step`` are called once per batch.
        scheduler: LR scheduler; ``step()`` is called once per *batch*.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        num_epochs: number of passes over ``dataloader``.
        position, leave: forwarded to the ``tqdm`` progress bar.
        stats: when true, the progress-bar description shows the mean loss and
            accuracy of each epoch.

    Returns:
        The same ``model`` object that was passed in.

    NOTE(review): because the scheduler is stepped per batch, a StepLR
    ``step_size`` counts optimizer steps rather than epochs — confirm this is
    intended before changing it, since existing hyper-parameters depend on it.
    """
    model.train()  # Set model to training mode
    for _ in (q := tqdm(range(num_epochs), unit='Epoch', position=position, desc='Training Epochs', leave=leave)):
        running_correct = 0
        running_loss = 0.0
        samples_seen = 0
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if stats:
                with torch.no_grad():
                    n_batch = inputs.size(0)
                    # criterion output is weighted by batch size so the epoch
                    # figure is a per-sample mean even with a ragged last batch.
                    running_loss += loss.item() * n_batch
                    _, preds = torch.max(outputs, 1)
                    running_correct += int(torch.sum(preds == labels))
                    samples_seen += n_batch

            # backward + optimize
            loss.backward()

            optimizer.step()
            scheduler.step()

        # Normalise by the number of samples actually iterated this epoch
        # (not by the size of the global training set), and skip the report
        # entirely if the loader yielded nothing.
        if stats and samples_seen:
            epoch_loss = running_loss / samples_seen
            epoch_acc = running_correct / samples_seen
            q.set_description(f'Epoch Loss: {epoch_loss:.5f}. Epoch Accuracy: {epoch_acc:.2f}')

    return model

In [6]:
epochs = 5

# Start from an un-pretrained ResNet-18.
model_ft = models.resnet18()

# Input stem: the stock network expects 3 input channels, our images have 1.
# Stock definition: nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model_ft.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)

# Classification head: 10 output classes instead of the stock 1000.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(in_features=num_ftrs, out_features=10)

criterion = nn.CrossEntropyLoss()

# Move the assembled network to the selected device.
model_ft = model_ft.to(device)

In [7]:
def make_model(model_ft, subset, batch_size=128, reset_weights=True, position=0, leave=True, lr=.1, gamma=.1, step_size=5):
    """Train ``model_ft`` on ``subset`` and return it.

    A shuffling DataLoader is built over ``subset``; the model is optionally
    re-initialised; Adam is created over the parameters that currently have
    ``requires_grad=True``; and a StepLR scheduler (``step_size``, ``gamma``)
    is attached. Training runs for the module-level ``epochs`` using the
    module-level ``criterion``.
    """
    loader = DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=8,
    )

    if reset_weights:
        reset_all_weights(model_ft)

    # Frozen parameters are excluded from the optimizer.
    trainable_params = (p for p in model_ft.parameters() if p.requires_grad)
    optimizer_ft = optim.Adam(trainable_params, lr=lr)

    # Multiply the LR by `gamma` every `step_size` scheduler steps.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=step_size, gamma=gamma)

    trained = train_model_opt(
        model_ft,
        criterion,
        optimizer_ft,
        exp_lr_scheduler,
        loader,
        num_epochs=epochs,
        position=position,
        leave=leave,
    )
    return trained

In [8]:
# Fixed-order (unshuffled) loaders used for evaluation; both share the same settings.
_eval_loader_kwargs = dict(batch_size=128 * 8, shuffle=False, num_workers=8)
train_loader = DataLoader(train, **_eval_loader_kwargs)
test_loader = DataLoader(test, **_eval_loader_kwargs)

def get_probs(model, only_correct=True, loader=train_loader):
    """Run ``model`` over ``loader`` and collect softmax probabilities on the CPU.

    Args:
        model: network to evaluate; it is moved to ``device`` and left in
            evaluation mode.
        only_correct: when true, keep only the probability assigned to each
            sample's true label (one value per sample); otherwise keep the full
            per-class probability row.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            module-level ``train_loader``.

    Returns:
        A CPU tensor with the per-batch results concatenated along dim 0, in
        the order ``loader`` yields them.
    """
    prob_list = []

    model.to(device)
    model.eval()

    with torch.no_grad():
        # Iterate the loader that was passed in (previously the argument was
        # ignored and the global train_loader was always used).
        for images, labels in tqdm(loader, leave=False):
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            probs = torch.softmax(logits, dim=1)
            if only_correct:
                # Pick, for every row, the column of its ground-truth label.
                prob_list.append(probs[torch.arange(probs.size(0)), labels].cpu())
            else:
                prob_list.append(probs.cpu())
    return torch.cat(prob_list, dim=0)

In [9]:
# Draw the index set for the master model.
# NOTE(review): torch.randint samples WITH replacement, so even at proportion 1
# the subset contains duplicates and omits some examples — confirm that a
# bootstrap-style sample (rather than a permutation) is what is wanted here.
subset_proportion = 1
subset_size = int(len(train) * subset_proportion)
subset_idx = torch.randint(low=0, high=len(train), size=(subset_size,))
trainset = torch.utils.data.Subset(train, subset_idx)

In [14]:
epochs = 5

# Train the master model on `trainset` (weights are re-initialised first,
# since reset_weights keeps its default).
model_ft = make_model(
    model_ft,
    trainset,
    batch_size=128 * 8,
    lr=1e-3,
    gamma=5e-3,
    step_size=35,
)

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

In [126]:
#model_ft.load_state_dict(deepcopy(model_state_base))

<All keys matched successfully>

In [15]:
# Score the master model and persist weights, probabilities and the sampled
# indices together in one checkpoint file.
probs = get_probs(model_ft).cpu()

model_state = model_ft.state_dict()

# The output directory is not created anywhere before this cell, so make sure
# it exists; otherwise torch.save fails on a fresh run.
Path(dirname).mkdir(parents=True, exist_ok=True)

torch.save({
    'model_state': model_state,
    'probs': probs,
    'indexes': subset_idx
}, Path(dirname) / 'master')

  0%|          | 0/59 [00:00<?, ?it/s]

In [10]:
# Reload the master checkpoint and keep an independent copy of its weights as
# the baseline that later cells restore from.
d = torch.load(Path(dirname) / 'master')
model_state_base = copy.deepcopy(d['model_state'])

In [11]:
@torch.no_grad()
def prepare_grad_for_only_training_layer(model_ft, layer_num):
    """Restore the baseline weights, then leave exactly one top-level child trainable.

    The module-level ``model_state_base`` is loaded into ``model_ft``. The
    child at index ``layer_num`` (in ``model_ft.children()`` order) is
    re-initialised and its parameters get ``requires_grad=True``; the
    parameters of every other child are frozen. Returns ``model_ft``.
    """
    model_ft.load_state_dict(deepcopy(model_state_base))
    for idx, child in enumerate(model_ft.children()):
        is_target = idx == layer_num
        if is_target:
            reset_all_weights(child)
        # One flag drives both the "unfreeze target" and "freeze the rest" cases.
        for p in child.parameters():
            p.requires_grad = is_target
    return model_ft

In [12]:
import os

# One output directory per top-level child of the model. exist_ok=True keeps
# the cell re-runnable: without it a second execution raises FileExistsError.
for layer in range(len(list(model_ft.children()))):
    os.makedirs(Path(dirname) / f'layer_{layer}', exist_ok=True)

In [13]:
# Restore the baseline weights and mark only top-level child 8 as trainable.
# NOTE(review): for torchvision's ResNet-18 index 8 is presumably the average
# pooling child, which has no parameters — confirm this is the intended index.
model_ft = prepare_grad_for_only_training_layer(model_ft, 8)

In [14]:
def reset_weights_of_layer(layer_num):
    """Re-initialise the ``layer_num``-th top-level child of the global ``model_ft``."""
    top_level_children = tuple(model_ft.children())
    target = top_level_children[layer_num]
    reset_all_weights(target)

In [22]:
# Trial run on a single layer before launching the full sweep.

layer = 7

model_ft = prepare_grad_for_only_training_layer(model_ft, layer)
children = list(model_ft.children())
nn_layer = children[layer]

for i in tqdm(range(150), position=1, desc='Models', unit='Model'):
    # Fresh random index sample (drawn with replacement) for this model.
    subset_proportion = .7
    subset_size = int(len(train) * subset_proportion)
    subset_idx = torch.randint(low=0, high=len(train), size=(subset_size,))
    trainset = torch.utils.data.Subset(train, subset_idx)

    # Only the chosen layer is re-initialised; make_model must not reset the rest.
    reset_weights_of_layer(layer)
    model_ft = make_model(
        model_ft,
        trainset,
        batch_size=128 * 8,
        reset_weights=False,
        position=2,
        leave=False,
        lr=1e-3,
        gamma=5e-3,
        step_size=35,
    )
    probs = get_probs(model_ft).cpu()

    model_state = model_ft.state_dict()

    # One checkpoint per trained model: weights, probabilities, sampled indices.
    checkpoint = {
        'model_state': model_state,
        'probs': probs,
        'indexes': subset_idx
    }
    torch.save(checkpoint, Path(dirname) / f'layer_{layer}' / f'{i}')

Models:   0%|          | 0/10 [00:00<?, ?Model/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Generate all of the Models

In [None]:
epochs = 5

# Sweep over every top-level child of the network; for each one that has
# parameters, train models 140..199 with only that child trainable.
# NOTE(review): the baseline state is reloaded once per layer, not once per
# model, so any state that training changes outside the re-initialised child
# (e.g. normalisation running statistics, if present) carries over between
# models of the same layer — confirm this is acceptable.
for layer in tqdm(range(len(list(model_ft.children()))), desc='Layer', position=0):
    model_ft = prepare_grad_for_only_training_layer(model_ft, layer)
    children = list(model_ft.children())
    nn_layer = children[layer]
    has_params = len(list(nn_layer.parameters())) > 0
    if not has_params:
        print(f'skipping {layer} bc no params')
        continue

    for i in tqdm(range(140, 200), position=1, desc='Models', unit='Model'):
        # Fresh random index sample (drawn with replacement) for this model.
        subset_proportion = .7
        subset_size = int(len(train) * subset_proportion)
        subset_idx = torch.randint(low=0, high=len(train), size=(subset_size,))
        trainset = torch.utils.data.Subset(train, subset_idx)

        # Re-initialise only the layer under study, then train it.
        reset_weights_of_layer(layer)
        model_ft = make_model(
            model_ft,
            trainset,
            batch_size=128 * 8,
            reset_weights=False,
            position=2,
            leave=False,
            lr=1e-2,
            gamma=5e-3,
            step_size=35,
        )

        probs = get_probs(model_ft).cpu()

        model_state = model_ft.state_dict()
        path = Path(dirname) / f'layer_{layer}' / f'{i}'

        checkpoint = {
            'model_state': model_state,
            'probs': probs,
            'indexes': subset_idx
        }
        torch.save(checkpoint, path)
        # Read the file straight back; the result is discarded, so this only
        # verifies that the checkpoint just written can be loaded.
        torch.load(path)

Layer:   0%|          | 0/10 [00:00<?, ?it/s]

Models:   0%|          | 0/60 [00:00<?, ?Model/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Models:   0%|          | 0/60 [00:00<?, ?Model/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]

  0%|          | 0/59 [00:00<?, ?it/s]

Training Epochs:   0%|          | 0/5 [00:00<?, ?Epoch/s]