In [1]:
import glob
import pandas as pd
import numpy as np

In [2]:
import matplotlib.pyplot as plt

In [3]:
import torchvision
import torch

In [4]:
import numpy as np

In [5]:
class ModelTackOn(torch.nn.Module):
    """Wrap a feature-extractor backbone with fully connected "tack-on" stacks.

    Three stacks of fully connected layers are built on top of ``base_model``:

    * pre-head   -- ``Linear`` + ``ReLU`` for every entry (ReLU also after the last
      layer); turns the backbone output into the shared "head" representation.
    * post-head  -- ``Linear`` layers separated by ``ReLU`` (none after the last);
      maps the head to the latent returned by :meth:`forward_latent`.
    * classifier -- optional; ``Linear`` layers separated by ``ReLU`` (none after the
      last); maps the head to the logits returned by :meth:`forward_classifier`.

    Layers are registered with ``add_module`` under ``PreHead_{i}``, ``PostHead_{i}``
    and ``Classifier_{i}`` (ReLUs as ``..._{i}_NonLinearity``), so those names are
    the ``state_dict`` keys. The layers are additionally kept in plain Python lists
    that record application order. ``torch.nn.ModuleList`` is deliberately not used:
    it would register the layers under extra keys and change the checkpoint format.

    The class defines no ``forward``; call :meth:`forward_latent` or
    :meth:`forward_classifier` explicitly.
    """

    def __init__(self, base_model, un_modified_model, pre_head_fc_sizes=(100,), post_head_fc_sizes=(100,), classifier_fc_sizes=None):
        """
        Args:
            base_model: backbone applied first to the input. Its output is fed
                straight into the first pre-head ``Linear``, so it must already be
                flat (e.g. end with ``torch.nn.Flatten``).
            un_modified_model: the original, unchopped model. Only its last child is
                inspected, to infer the input width of the first pre-head layer.
            pre_head_fc_sizes: output width of each pre-head ``Linear``.
            post_head_fc_sizes: output width of each post-head ``Linear``.
            classifier_fc_sizes: output width of each classifier ``Linear``, or
                ``None`` to build no classifier.
        """
        super().__init__()
        self.base_model = base_model
        final_base_layer = list(un_modified_model.children())[-1]

        # Application-order bookkeeping; registration happens via add_module.
        self.pre_head_fc_lst = []
        self.post_head_fc_lst = []
        self.classifier_fc_lst = []

        # Width of the head shared by the post-head and the classifier.
        head_size = self.init_prehead(final_base_layer, pre_head_fc_sizes)
        self.init_posthead(head_size, post_head_fc_sizes)
        if classifier_fc_sizes is not None:
            self.init_classifier(head_size, classifier_fc_sizes)

    @staticmethod
    def _base_out_features(prv_layer, default=1280):
        """Infer the feature width expected by ``prv_layer`` (the original model's last child).

        Tries, in order:

        1. ``prv_layer.in_features`` (e.g. a bare ``Linear`` output layer);
        2. ``in_features`` of the first ``torch.nn.Linear`` nested inside it (e.g. a
           ``Sequential(Dropout, Linear)`` classifier head);
        3. ``default`` (1280, the value previously hard-coded here).
        """
        if hasattr(prv_layer, 'in_features'):
            return prv_layer.in_features
        for module in prv_layer.modules():
            if isinstance(module, torch.nn.Linear):
                return module.in_features
        return default

    def _build_fc_stack(self, prefix, in_features, fc_sizes, layer_lst, relu_after_last):
        """Create, register and append a chain of ``Linear`` (+ ``ReLU``) layers.

        The ``Linear`` with index ``i`` is registered as ``{prefix}_{i}``. A ``ReLU``
        registered as ``{prefix}_{i}_NonLinearity`` follows every ``Linear`` except the
        last one; when ``relu_after_last`` is True it also follows the last one.

        Returns:
            The output width of the chain (``in_features`` when ``fc_sizes`` is empty).
        """
        n_layers = len(fc_sizes)
        for i, out_features in enumerate(fc_sizes):
            fc_layer = torch.nn.Linear(in_features=in_features, out_features=out_features)
            self.add_module(f'{prefix}_{i}', fc_layer)
            layer_lst.append(fc_layer)

            if relu_after_last or i < n_layers - 1:
                non_linearity = torch.nn.ReLU()
                self.add_module(f'{prefix}_{i}_NonLinearity', non_linearity)
                layer_lst.append(non_linearity)
            in_features = out_features
        return in_features

    def init_prehead(self, prv_layer, pre_head_fc_sizes):
        """Build the pre-head: ``Linear -> ReLU`` for every entry of ``pre_head_fc_sizes``.

        Args:
            prv_layer: last child of the original model, used to infer the input width.
            pre_head_fc_sizes: output width of each ``Linear``.

        Returns:
            The head width: the last pre-head size, or the inferred input width when
            ``pre_head_fc_sizes`` is empty.
        """
        return self._build_fc_stack('PreHead', self._base_out_features(prv_layer),
                                    pre_head_fc_sizes, self.pre_head_fc_lst, relu_after_last=True)

    def init_posthead(self, prv_size, post_head_fc_sizes):
        """Build the post-head (ReLU only between layers) on top of a head of width ``prv_size``.

        Returns the latent width.
        """
        # The inter-layer ReLUs now go into post_head_fc_lst. They previously went into
        # pre_head_fc_lst, which left the post-head purely linear.
        return self._build_fc_stack('PostHead', prv_size, post_head_fc_sizes,
                                    self.post_head_fc_lst, relu_after_last=False)

    def init_classifier(self, prv_size, classifier_fc_sizes):
        """Build the classifier (ReLU only between layers) on top of a head of width ``prv_size``.

        Returns the number of logits.
        """
        # One Linear per entry. Previously only the last entry produced a layer,
        # with a mismatched in_features.
        return self._build_fc_stack('Classifier', prv_size, classifier_fc_sizes,
                                    self.classifier_fc_lst, relu_after_last=False)

    def reinit_classifier(self):
        """Re-draw the classifier weights; parameter-free layers (ReLU) are skipped."""
        for layer in self.classifier_fc_lst:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

    def forward_classifier(self, X):
        """Backbone -> pre-head -> classifier; returns logits."""
        return self.classify(self.get_head(self.base_model(X)))

    def forward_latent(self, X):
        """Backbone -> pre-head -> post-head; returns the latent representation."""
        return self.get_latent(self.get_head(self.base_model(X)))

    @staticmethod
    def _run_stack(x, layers):
        """Apply ``layers`` to ``x`` in order."""
        for layer in layers:
            x = layer(x)
        return x

    def get_head(self, base_out):
        """Apply the pre-head layers to the backbone output."""
        return self._run_stack(base_out, self.pre_head_fc_lst)

    def get_latent(self, head):
        """Apply the post-head layers to the head."""
        return self._run_stack(head, self.post_head_fc_lst)

    def classify(self, head):
        """Apply the classifier layers to the head (identity if no classifier was built)."""
        return self._run_stack(head, self.classifier_fc_lst)

    @staticmethod
    def _set_requires_grad(layers, requires_grad):
        """Set ``requires_grad`` on every parameter of every layer in ``layers``."""
        for layer in layers:
            for param in layer.parameters():
                param.requires_grad = requires_grad

    def set_pre_head_grad(self, requires_grad=True):
        """Enable/disable gradients for the pre-head parameters."""
        self._set_requires_grad(self.pre_head_fc_lst, requires_grad)

    def set_post_head_grad(self, requires_grad=True):
        """Enable/disable gradients for the post-head parameters."""
        self._set_requires_grad(self.post_head_fc_lst, requires_grad)

    def set_classifier_grad(self, requires_grad=True):
        """Enable/disable gradients for the classifier parameters."""
        self._set_requires_grad(self.classifier_fc_lst, requires_grad)

    def prep_contrast(self):
        """Train pre-head + post-head, freeze the classifier (base_model is left untouched)."""
        self.set_pre_head_grad(requires_grad=True)
        self.set_post_head_grad(requires_grad=True)
        self.set_classifier_grad(requires_grad=False)

    def prep_classifier(self):
        """Freeze pre-head + post-head, train the classifier (base_model is left untouched)."""
        self.set_pre_head_grad(requires_grad=False)
        self.set_post_head_grad(requires_grad=False)
        self.set_classifier_grad(requires_grad=True)


In [6]:
import torchvision.models

# Backbone used throughout: torchvision EfficientNet-B0 with pretrained weights.
# Earlier experiments used torchvision.models.resnet18 / resnet50 / resnet101 and
# wide_resnet50_2 instead.
# No instance is built here. Every model cell below constructs its own fresh
# `base_model_frozen`, so an instance created in this cell would be discarded unused.

In [7]:
# Model 1: EfficientNet-B0 backbone (original model minus its last child, then
# flattened) + ModelTackOn head, restored from the "...=efficient2" checkpoint.
# Nothing is frozen or unfrozen here. Every key, backbone included, comes from the checkpoint.
base_model_frozen = torchvision.models.efficientnet_b0(pretrained=True)
model_chopped = torch.nn.Sequential(
    *list(base_model_frozen.children())[:-1],
    torch.nn.Flatten(),
)
model = ModelTackOn(
    model_chopped,
    base_model_frozen,
    pre_head_fc_sizes=[1024, 512],
    post_head_fc_sizes=[64],
    classifier_fc_sizes=[4],
)
model.train();

base_dir = '/Users/josh/Documents'  # NOTE(review): machine-specific absolute path
model_file_name = 'ResNet18_simCLR_model_202112078_EOD_transfmod=efficient2'
checkpoint_path = f'{base_dir}/github_repos/GCaMP_ROI_classifier/new_stuff/models/{model_file_name}.pth'
# torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [8]:
# Model 2: same architecture as model 1, restored from the "...=efficient" checkpoint.
# Nothing is frozen or unfrozen here. Every key, backbone included, comes from the checkpoint.
base_model_frozen = torchvision.models.efficientnet_b0(pretrained=True)
model_chopped = torch.nn.Sequential(
    *list(base_model_frozen.children())[:-1],
    torch.nn.Flatten(),
)
model2 = ModelTackOn(
    model_chopped,
    base_model_frozen,
    pre_head_fc_sizes=[1024, 512],
    post_head_fc_sizes=[64],
    classifier_fc_sizes=[4],
)
model2.train();

base_dir = '/Users/josh/Documents'  # NOTE(review): machine-specific absolute path
model_file_name = 'ResNet18_simCLR_model_202112078_EOD_transfmod=efficient'
checkpoint_path = f'{base_dir}/github_repos/GCaMP_ROI_classifier/new_stuff/models/{model_file_name}.pth'
# torch.load unpickles the file; only load checkpoints you trust.
model2.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [9]:
# Model 3: same architecture as model 1, restored from the
# "...=efficient2-only7unfrozen" checkpoint. The file name suggests only block 7 was
# trained; that is a guess from the name, so confirm against the training notebook.
# This cell itself freezes/unfreezes nothing.
base_model_frozen = torchvision.models.efficientnet_b0(pretrained=True)
model_chopped = torch.nn.Sequential(
    *list(base_model_frozen.children())[:-1],
    torch.nn.Flatten(),
)
model3 = ModelTackOn(
    model_chopped,
    base_model_frozen,
    pre_head_fc_sizes=[1024, 512],
    post_head_fc_sizes=[64],
    classifier_fc_sizes=[4],
)
model3.train();

base_dir = '/Users/josh/Documents'  # NOTE(review): machine-specific absolute path
model_file_name = 'ResNet18_simCLR_model_202112078_EOD_transfmod=efficient2-only7unfrozen'
checkpoint_path = f'{base_dir}/github_repos/GCaMP_ROI_classifier/new_stuff/models/{model_file_name}.pth'
# torch.load unpickles the file; only load checkpoints you trust.
model3.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [10]:
# Model 4 (reference): same architecture, but no checkpoint is loaded. The backbone
# keeps torchvision's pretrained weights and the tacked-on layers keep their fresh
# random initialisation. The parameter comparison below uses it as the baseline.
base_model_frozen = torchvision.models.efficientnet_b0(pretrained=True)
model_chopped = torch.nn.Sequential(
    *list(base_model_frozen.children())[:-1],
    torch.nn.Flatten(),
)
model4 = ModelTackOn(
    model_chopped,
    base_model_frozen,
    pre_head_fc_sizes=[1024, 512],
    post_head_fc_sizes=[64],
    classifier_fc_sizes=[4],
)
model4.train();


In [11]:
model_params = list(model.named_parameters())
model2_params = list(model2.named_parameters())
model3_params = list(model3.named_parameters())
model4_params = list(model4.named_parameters())

# For each parameter, report whether models 1-3 hold exactly the same values as the
# un-loaded reference model 4. All four share one architecture, so parameters line
# up by position; the asserts below make that assumption explicit.
assert len(model_params) == len(model2_params) == len(model3_params) == len(model4_params)

# Loop variables param_1..param_4 are kept because the next cell reads them
# (they end up bound to the last parameter of each model).
for i, ((name, param_1), (name_2, param_2), (name_3, param_3), (name_4, param_4)) in enumerate(
        zip(model_params, model2_params, model3_params, model4_params)):
    assert name == name_2 == name_3 == name_4, (i, name, name_2, name_3, name_4)
    ref = param_4.detach().numpy()
    # array_equal also returns False on a shape mismatch instead of broadcasting.
    print(i, '1=4', name, np.array_equal(param_1.detach().numpy(), ref))
    print(i, '2=4', name, np.array_equal(param_2.detach().numpy(), ref))
    print(i, '3=4', name, np.array_equal(param_3.detach().numpy(), ref))

0 1=4 base_model.0.0.0.weight True
0 2=4 base_model.0.0.0.weight True
0 3=4 base_model.0.0.0.weight True
1 1=4 base_model.0.0.1.weight True
1 2=4 base_model.0.0.1.weight True
1 3=4 base_model.0.0.1.weight True
2 1=4 base_model.0.0.1.bias True
2 2=4 base_model.0.0.1.bias True
2 3=4 base_model.0.0.1.bias True
3 1=4 base_model.0.1.0.block.0.0.weight True
3 2=4 base_model.0.1.0.block.0.0.weight True
3 3=4 base_model.0.1.0.block.0.0.weight True
4 1=4 base_model.0.1.0.block.0.1.weight True
4 2=4 base_model.0.1.0.block.0.1.weight True
4 3=4 base_model.0.1.0.block.0.1.weight True
5 1=4 base_model.0.1.0.block.0.1.bias True
5 2=4 base_model.0.1.0.block.0.1.bias True
5 3=4 base_model.0.1.0.block.0.1.bias True
6 1=4 base_model.0.1.0.block.1.fc1.weight True
6 2=4 base_model.0.1.0.block.1.fc1.weight True
6 3=4 base_model.0.1.0.block.1.fc1.weight True
7 1=4 base_model.0.1.0.block.1.fc1.bias True
7 2=4 base_model.0.1.0.block.1.fc1.bias True
7 3=4 base_model.0.1.0.block.1.fc1.bias True
8 1=4 base_model

160 2=4 base_model.0.6.1.block.1.1.weight False
160 3=4 base_model.0.6.1.block.1.1.weight True
161 1=4 base_model.0.6.1.block.1.1.bias False
161 2=4 base_model.0.6.1.block.1.1.bias False
161 3=4 base_model.0.6.1.block.1.1.bias True
162 1=4 base_model.0.6.1.block.2.fc1.weight False
162 2=4 base_model.0.6.1.block.2.fc1.weight False
162 3=4 base_model.0.6.1.block.2.fc1.weight True
163 1=4 base_model.0.6.1.block.2.fc1.bias False
163 2=4 base_model.0.6.1.block.2.fc1.bias False
163 3=4 base_model.0.6.1.block.2.fc1.bias True
164 1=4 base_model.0.6.1.block.2.fc2.weight False
164 2=4 base_model.0.6.1.block.2.fc2.weight False
164 3=4 base_model.0.6.1.block.2.fc2.weight True
165 1=4 base_model.0.6.1.block.2.fc2.bias False
165 2=4 base_model.0.6.1.block.2.fc2.bias False
165 3=4 base_model.0.6.1.block.2.fc2.bias True
166 1=4 base_model.0.6.1.block.3.0.weight False
166 2=4 base_model.0.6.1.block.3.0.weight False
166 3=4 base_model.0.6.1.block.3.0.weight True
167 1=4 base_model.0.6.1.block.3.1.weight

In [12]:
# Pairwise equality (1 vs 2, 1 vs 3, 2 vs 3) of the last parameter of models 1-3.
# The parameter is looked up explicitly instead of relying on loop variables
# leaked from the previous cell.
last_1, last_2, last_3 = (params[-1][1].detach().numpy() for params in (model_params, model2_params, model3_params))
np.array_equal(last_1, last_2), np.array_equal(last_1, last_3), np.array_equal(last_2, last_3)

(False, False, False)

In [13]:
# Model 1, parameter #195 (position in named_parameters() order), with size-1 axes
# dropped. NOTE(review): magic index; model_params[195][0] gives its name.
model_params[195][1].detach().numpy().squeeze()

array([[-0.09176432, -0.00588252,  0.02027349, ..., -0.03092107,
         0.06761403, -0.01461071],
       [-0.07848272, -0.01371442, -0.07482359, ...,  0.10687242,
        -0.04747317, -0.08791421],
       [-0.10090791, -0.05378744, -0.04679393, ...,  0.02166761,
        -0.05217503, -0.03194081],
       ...,
       [-0.02420462, -0.00365353,  0.22248024, ...,  0.01756833,
        -0.01050397, -0.03010276],
       [-0.06661663,  0.08000631, -0.0047844 , ..., -0.17939425,
         0.07883891, -0.1561394 ],
       [-0.01726327, -0.13413514,  0.12521721, ...,  0.10406896,
        -0.07013278,  0.0010998 ]], dtype=float32)

In [14]:
# Model 2, parameter #195 (position in named_parameters() order), with size-1 axes
# dropped, for side-by-side comparison with model 1's.
# NOTE(review): magic index; model2_params[195][0] gives its name.
model2_params[195][1].detach().numpy().squeeze()

array([[-0.09259737, -0.00148659,  0.02181471, ..., -0.02394226,
         0.06473046,  0.00086494],
       [-0.06993268, -0.02882577, -0.06502995, ...,  0.10163853,
        -0.05011335, -0.08235805],
       [-0.1077586 , -0.05161716, -0.04124809, ...,  0.01477882,
        -0.03953496, -0.04590347],
       ...,
       [-0.02840404, -0.00700152,  0.22568905, ...,  0.00865884,
        -0.0146779 , -0.04499444],
       [-0.05971367,  0.08413228, -0.00977129, ..., -0.18800034,
         0.08246574, -0.16589299],
       [-0.02095276, -0.13889514,  0.12529647, ...,  0.09240452,
        -0.06731603,  0.00355882]], dtype=float32)