In [8]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPLayer(nn.Module):
    def __init__(self, input_dim, hidden_features, out_features, dropout_rate=0.1):
        super(MLPLayer, self).__init__()
        self.dense1 = nn.Linear(input_dim, hidden_features, bias=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.activation = nn.GELU()
        self.dense2 = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, x):
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return x


class CrossAttentionLayer(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(CrossAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, batch_first=True)

    def forward(self, query, key, value, key_padding_mask=None):
        # query = [batch size, query len, features]
        # key, value = [batch size, key/value len, features]
        attn_output, _ = self.multihead_attn(query, key, value, key_padding_mask=key_padding_mask)
        return attn_output

class TransformerNoduleBimodalClassifier(nn.Module):
    def __init__(self, input_dim,
                 mlp_ratio_ct, mlp_ratio_pet,
                 num_heads_ct, num_heads_pet,
                 num_layers_ct, num_layers_pet,
                 num_classes):
        super(TransformerNoduleBimodalClassifier, self).__init__()

        self.transformer_encoder_ct = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=input_dim,
                                                                                       dim_feedforward=int(mlp_ratio_ct*input_dim),
                                                                                       nhead=num_heads_ct,
                                                                                       activation="gelu",
                                                                                       batch_first=True,
                                                                                       dropout=0.5),
                                                            num_layers=num_layers_ct)
        self.transformer_encoder_pet = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=input_dim,
                                                                                        dim_feedforward=int(mlp_ratio_pet*input_dim),
                                                                                        nhead=num_heads_pet,
                                                                                        activation="gelu",
                                                                                        batch_first=True,
                                                                                        dropout=0.5),
                                                             num_layers=num_layers_pet)

        self.norm_ct = nn.LayerNorm(input_dim)
        self.norm_pet = nn.LayerNorm(input_dim)

        self.cls_token_ct = nn.Parameter(torch.randn(1, 1, input_dim))
        self.cls_token_pet = nn.Parameter(torch.randn(1, 1, input_dim))

        self.classifier_ct = MLPLayer(input_dim, input_dim*2, num_classes, dropout_rate=0.1)
        self.classifier_pet = MLPLayer(input_dim, input_dim*2, num_classes, dropout_rate=0.1)

        self.projection_petct = MLPLayer(input_dim*2, input_dim, input_dim, dropout_rate=0.1)

        self.cross_attention_ct = CrossAttentionLayer(input_dim, num_heads_ct)
        self.cross_attention_pet = CrossAttentionLayer(input_dim, num_heads_ct)
        self.classifier_petct = MLPLayer(input_dim, input_dim*2, num_classes, dropout_rate=0.1)

    def forward(self, x_ct=None, x_pet=None):
        use_ct = x_ct is not None
        use_pet = x_pet is not None
        assert use_ct or use_pet, "At least one modality should be used"
        # add cls token and norm to each pet ct seq
        if use_ct:
            batch, seq_len, feature_dim = x_ct.shape
            x_ct = torch.cat([self.cls_token_ct.repeat(batch, 1, 1), x_ct], dim=1)
            x_ct = self.norm_ct(x_ct)
            x_ct = self.transformer_encoder_ct(x_ct)
            ct_cls_token = x_ct[:, 0, :]
        else:
            ct_cls_token = self.cls_token_ct.repeat(1, 1, 1)

        if use_pet:
            batch, seq_len, feature_dim = x_pet.shape
            x_pet = torch.cat([self.cls_token_pet.repeat(batch, 1, 1), x_pet], dim=1)
            x_pet = self.norm_pet(x_pet)
            x_pet = self.transformer_encoder_pet(x_pet)
            pet_cls_token = x_pet[:, 0, :]
        else:
            pet_cls_token = self.cls_token_pet.repeat(1, 1, 1)

        # cross attention between pet-ct and ct-pet
        if use_ct and use_pet:
            x_ct_attn = self.cross_attention_ct(query=x_ct, key=x_pet, value=x_pet)
            x_pet_attn = self.cross_attention_pet(query=x_pet, key=x_ct, value=x_ct)
            ct_cls_token = x_ct_attn[:, 0, :]
            pet_cls_token = x_pet_attn[:, 0, :]

            logits_ct = self.classifier_ct(ct_cls_token)
            logits_pet = self.classifier_pet(pet_cls_token)

            petct_cls_token = torch.cat([ct_cls_token, pet_cls_token], dim=1)
            petct_cls_token = self.projection_petct(petct_cls_token)
            logits_petct = self.classifier_petct(petct_cls_token)

        elif use_ct:
            logits_ct = self.classifier_ct(ct_cls_token)
            logits_pet = logits_ct
            logits_petct = logits_ct
            petct_cls_token = ct_cls_token
        else:
            logits_pet = self.classifier_pet(pet_cls_token)
            logits_ct = logits_pet
            logits_petct = logits_pet
            petct_cls_token = pet_cls_token

        return logits_petct, petct_cls_token, logits_ct, logits_pet

class TransformerNoduleClassifier(nn.Module):
    def __init__(self, input_dim, dim_feedforward, num_heads, num_classes, num_layers,):
        super(TransformerNoduleClassifier, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim,
                                                   dim_feedforward=dim_feedforward,
                                                   nhead=num_heads,
                                                   activation="gelu",
                                                   batch_first=True,
                                                   dropout=0.1)
        self.norm = nn.LayerNorm(input_dim)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.cls_token = nn.Parameter(torch.randn(1, 1, input_dim))
        self.cross_attention = CrossAttentionLayer(input_dim, num_heads) #added
        self.classifier_prev = MLPLayer(input_dim, input_dim*2, num_classes, dropout_rate=0.1)
        self.classifier_final = MLPLayer(input_dim, input_dim*2, num_classes, dropout_rate=0.1)
        
    def forward(self, x):
        batch, seq_len, feature_dim = x.shape
        cls_token = self.cls_token.repeat(batch, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.norm(x)
        x = self.transformer_encoder(x)
        ###added
        cls_token = x[:, 0, :]    
        logits_prev = self.classifier_prev(cls_token)

        x_attn = self.cross_attention(query=x, key=x, value=x)
                
        logits_final = self.classifier_final(x_attn[:, 0, :])
        ###added
        return logits_final, logits_prev

In [24]:
num_classes=2
feature_dim=768
mlp_ratio = 4
num_heads = 4
num_layers = 2
dim_feedforward = int(feature_dim*mlp_ratio)
model=TransformerNoduleClassifier(input_dim=feature_dim,
                                dim_feedforward=dim_feedforward,
                                num_heads=num_heads,
                                num_classes=num_classes,
                                num_layers=num_layers)

In [15]:
ls ../models/petct_online_rad_dino_cross_ct/rad_dino_transformer_stanford/ct/kfold_0/model_epoch_0018.pth

losses.html           test_metrics_27.json   train_metrics_19.json
model_epoch_0000.pth  test_metrics_28.json   train_metrics_1.json
model_epoch_0009.pth  test_metrics_29.json   train_metrics_20.json
model_epoch_0010.pth  test_metrics_2.json    train_metrics_21.json
model_epoch_0011.pth  test_metrics_30.json   train_metrics_22.json
model_epoch_0018.pth  test_metrics_31.json   train_metrics_23.json
test_metrics_0.json   test_metrics_32.json   train_metrics_24.json
test_metrics_10.json  test_metrics_33.json   train_metrics_25.json
test_metrics_11.json  test_metrics_3.json    train_metrics_26.json
test_metrics_12.json  test_metrics_4.json    train_metrics_27.json
test_metrics_13.json  test_metrics_5.json    train_metrics_28.json
test_metrics_14.json  test_metrics_6.json    train_metrics_29.json
test_metrics_15.json  test_metrics_7.json    train_metrics_2.json
test_metrics_16.json  test_metrics_8.json    train_metrics_30.json
test_metrics_17.json  test_metrics_9.json    train_metrics_31.js

In [16]:
model_path="../models/petct_online_rad_dino_cross_ct/rad_dino_transformer_stanford/ct/kfold_0/model_epoch_0018.pth"

In [26]:
model.load_state_dict(torch.load(model_path, map_location="cuda:0",weights_only=True))

<All keys matched successfully>

In [20]:
w=torch.load(model_path, map_location="cpu",weights_only=True)

In [110]:
num_classes=2
feature_dim=768
mlp_ratio_ct = 4
num_heads_ct = 4
num_layers_ct = 2

mlp_ratio_pet = 4
num_heads_pet = 4
num_layers_pet = 2
bimodel = TransformerNoduleBimodalClassifier(feature_dim,
                                           mlp_ratio_ct, mlp_ratio_pet,
                                           num_heads_ct, num_heads_pet,
                                           num_layers_ct, num_layers_pet,
                                           num_classes=num_classes)

In [34]:
ls ../models/petct_online_rad_dino_focalloss_thr0/rad_dino_transformer_santa_maria/petct/kfold_2/model_epoch_0008.pth

losses.html           test_metrics_22.json   train_metrics_17.json
model_epoch_0000.pth  test_metrics_23.json   train_metrics_18.json
model_epoch_0002.pth  test_metrics_2.json    train_metrics_19.json
model_epoch_0008.pth  test_metrics_3.json    train_metrics_1.json
test_metrics_0.json   test_metrics_4.json    train_metrics_20.json
test_metrics_10.json  test_metrics_5.json    train_metrics_21.json
test_metrics_11.json  test_metrics_6.json    train_metrics_22.json
test_metrics_12.json  test_metrics_7.json    train_metrics_23.json
test_metrics_13.json  test_metrics_8.json    train_metrics_2.json
test_metrics_14.json  test_metrics_9.json    train_metrics_3.json
test_metrics_15.json  train_metrics_0.json   train_metrics_4.json
test_metrics_16.json  train_metrics_10.json  train_metrics_5.json
test_metrics_17.json  train_metrics_11.json  train_metrics_6.json
test_metrics_18.json  train_metrics_12.json  train_metrics_7.json
test_metrics_19.json  train_metrics_13.json  train_metrics_8.json
tes

In [38]:
bimodel_path="../models/petct_online_rad_dino_focalloss_thr0/rad_dino_transformer_santa_maria/petct/kfold_2/model_epoch_0008.pth"
bw=torch.load(bimodel_path, map_location="cuda:0",weights_only=True)
bimodel.load_state_dict(bw)

<All keys matched successfully>

In [22]:
w.keys()

odict_keys(['cls_token', 'norm.weight', 'norm.bias', 'transformer_encoder.layers.0.self_attn.in_proj_weight', 'transformer_encoder.layers.0.self_attn.in_proj_bias', 'transformer_encoder.layers.0.self_attn.out_proj.weight', 'transformer_encoder.layers.0.self_attn.out_proj.bias', 'transformer_encoder.layers.0.linear1.weight', 'transformer_encoder.layers.0.linear1.bias', 'transformer_encoder.layers.0.linear2.weight', 'transformer_encoder.layers.0.linear2.bias', 'transformer_encoder.layers.0.norm1.weight', 'transformer_encoder.layers.0.norm1.bias', 'transformer_encoder.layers.0.norm2.weight', 'transformer_encoder.layers.0.norm2.bias', 'transformer_encoder.layers.1.self_attn.in_proj_weight', 'transformer_encoder.layers.1.self_attn.in_proj_bias', 'transformer_encoder.layers.1.self_attn.out_proj.weight', 'transformer_encoder.layers.1.self_attn.out_proj.bias', 'transformer_encoder.layers.1.linear1.weight', 'transformer_encoder.layers.1.linear1.bias', 'transformer_encoder.layers.1.linear2.weigh

In [40]:
bw.keys()

odict_keys(['cls_token_ct', 'cls_token_pet', 'transformer_encoder_ct.layers.0.self_attn.in_proj_weight', 'transformer_encoder_ct.layers.0.self_attn.in_proj_bias', 'transformer_encoder_ct.layers.0.self_attn.out_proj.weight', 'transformer_encoder_ct.layers.0.self_attn.out_proj.bias', 'transformer_encoder_ct.layers.0.linear1.weight', 'transformer_encoder_ct.layers.0.linear1.bias', 'transformer_encoder_ct.layers.0.linear2.weight', 'transformer_encoder_ct.layers.0.linear2.bias', 'transformer_encoder_ct.layers.0.norm1.weight', 'transformer_encoder_ct.layers.0.norm1.bias', 'transformer_encoder_ct.layers.0.norm2.weight', 'transformer_encoder_ct.layers.0.norm2.bias', 'transformer_encoder_ct.layers.1.self_attn.in_proj_weight', 'transformer_encoder_ct.layers.1.self_attn.in_proj_bias', 'transformer_encoder_ct.layers.1.self_attn.out_proj.weight', 'transformer_encoder_ct.layers.1.self_attn.out_proj.bias', 'transformer_encoder_ct.layers.1.linear1.weight', 'transformer_encoder_ct.layers.1.linear1.bias

In [49]:
num_classes=2
input_dim=768
mlp_ratio_ct = 4
num_heads_ct = 4
num_layers_ct = 2

mlp_ratio_pet = 4
num_heads_pet = 4
num_layers_pet = 2

transformer_encoder_ct = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=input_dim,
                                                                                       dim_feedforward=int(mlp_ratio_ct*input_dim),
                                                                                       nhead=num_heads_ct,
                                                                                       activation="gelu",
                                                                                       batch_first=True,
                                                                                       dropout=0.5),
                                                            num_layers=num_layers_ct)
transformer_encoder_pet = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=input_dim,
                                                                                dim_feedforward=int(mlp_ratio_pet*input_dim),
                                                                                nhead=num_heads_pet,
                                                                                activation="gelu",
                                                                                batch_first=True,
                                                                                dropout=0.5),
                                                     num_layers=num_layers_pet)

norm_ct = nn.LayerNorm(input_dim)
norm_pet = nn.LayerNorm(input_dim)

cls_token_ct = nn.Parameter(torch.randn(1, 1, input_dim))
cls_token_pet = nn.Parameter(torch.randn(1, 1, input_dim))

classifier_ct = MLPLayer(input_dim, input_dim*2, num_classes, dropout_rate=0.1)
classifier_pet = MLPLayer(input_dim, input_dim*2, num_classes, dropout_rate=0.1)

projection_petct = MLPLayer(input_dim*2, input_dim, input_dim, dropout_rate=0.1)

cross_attention_ct = CrossAttentionLayer(input_dim, num_heads_ct)
cross_attention_pet = CrossAttentionLayer(input_dim, num_heads_ct)
classifier_petct = MLPLayer(input_dim, input_dim*2, num_classes, dropout_rate=0.1)

In [73]:
def add_mode(param,mode):
    words = param.split(".")  
    words[0] = words[0]+f"_{mode}"
    new_name = ".".join(words)
    return new_name

In [74]:
params_ct = {add_mode(k,"ct"):v for k, v in bw.items() if k.startswith('cls_token_ct')}
params_ct = {add_mode(k,"ct"):v for k, v in bw.items() if k.startswith('cls_token_ct')}

In [52]:
cls_token_ct.load_state_dict(trained_trim , strict=False)

AttributeError: 'Parameter' object has no attribute 'load_state_dict'

In [48]:
cls_token_ct

Parameter containing:
tensor([[[ 4.2889e-02,  3.0930e-02,  3.6204e-01,  4.2075e-01,  8.0185e-01,
           4.5093e-01, -1.3860e-03,  4.1037e-01,  6.5379e-01, -5.2713e-01,
           1.7773e+00,  2.9679e-01, -2.9481e-01,  8.5147e-01,  5.7022e-01,
          -1.3976e+00,  4.6590e-02,  1.6242e+00, -1.5761e+00,  2.4232e+00,
           1.3479e+00, -1.2183e+00, -3.5285e-01, -1.4980e+00,  8.9015e-01,
          -1.8051e-01, -8.7205e-01,  8.4085e-01,  1.1092e+00,  1.4244e+00,
          -6.4386e-01, -1.6729e+00, -4.7799e-01,  1.8768e-01, -1.1927e+00,
           2.7242e-02,  8.3766e-01, -2.6364e-01, -8.0378e-02,  5.6323e-01,
           2.9465e-01, -1.2101e+00,  1.1094e-01, -5.8797e-01, -7.6733e-01,
           1.4253e+00, -7.0368e-01,  2.5992e-01,  1.8720e+00, -1.7000e+00,
           1.8712e-01,  1.6055e+00,  9.3516e-01, -5.7682e-01,  1.1219e+00,
          -5.1189e-01, -6.6195e-01,  1.1972e+00,  2.4028e+00,  1.1868e+00,
           2.0569e+00,  6.6520e-01,  5.9752e-01,  1.3007e+00,  5.1760e-01,
   

In [53]:
type(bw)

collections.OrderedDict

In [62]:
for k,v in w.items():
    print(k)

cls_token
norm.weight
norm.bias
transformer_encoder.layers.0.self_attn.in_proj_weight
transformer_encoder.layers.0.self_attn.in_proj_bias
transformer_encoder.layers.0.self_attn.out_proj.weight
transformer_encoder.layers.0.self_attn.out_proj.bias
transformer_encoder.layers.0.linear1.weight
transformer_encoder.layers.0.linear1.bias
transformer_encoder.layers.0.linear2.weight
transformer_encoder.layers.0.linear2.bias
transformer_encoder.layers.0.norm1.weight
transformer_encoder.layers.0.norm1.bias
transformer_encoder.layers.0.norm2.weight
transformer_encoder.layers.0.norm2.bias
transformer_encoder.layers.1.self_attn.in_proj_weight
transformer_encoder.layers.1.self_attn.in_proj_bias
transformer_encoder.layers.1.self_attn.out_proj.weight
transformer_encoder.layers.1.self_attn.out_proj.bias
transformer_encoder.layers.1.linear1.weight
transformer_encoder.layers.1.linear1.bias
transformer_encoder.layers.1.linear2.weight
transformer_encoder.layers.1.linear2.bias
transformer_encoder.layers.1.nor

In [86]:
xs={k:v for k, v in bw.items() if 'ct' in k}
for k,v in xs.items():
    print(k)

cls_token_ct
transformer_encoder_ct.layers.0.self_attn.in_proj_weight
transformer_encoder_ct.layers.0.self_attn.in_proj_bias
transformer_encoder_ct.layers.0.self_attn.out_proj.weight
transformer_encoder_ct.layers.0.self_attn.out_proj.bias
transformer_encoder_ct.layers.0.linear1.weight
transformer_encoder_ct.layers.0.linear1.bias
transformer_encoder_ct.layers.0.linear2.weight
transformer_encoder_ct.layers.0.linear2.bias
transformer_encoder_ct.layers.0.norm1.weight
transformer_encoder_ct.layers.0.norm1.bias
transformer_encoder_ct.layers.0.norm2.weight
transformer_encoder_ct.layers.0.norm2.bias
transformer_encoder_ct.layers.1.self_attn.in_proj_weight
transformer_encoder_ct.layers.1.self_attn.in_proj_bias
transformer_encoder_ct.layers.1.self_attn.out_proj.weight
transformer_encoder_ct.layers.1.self_attn.out_proj.bias
transformer_encoder_ct.layers.1.linear1.weight
transformer_encoder_ct.layers.1.linear1.bias
transformer_encoder_ct.layers.1.linear2.weight
transformer_encoder_ct.layers.1.line

In [84]:
"bab" in "aabbaa"

False

In [76]:
trained_trim = {add_mode(k,"ct"):v for k, v in w.items() if k.startswith(('cls_token','transformer_encoder','norm'))}
for k,v in trained_trim.items():
    print(k)

cls_token_ct
norm_ct.weight
norm_ct.bias
transformer_encoder_ct.layers.0.self_attn.in_proj_weight
transformer_encoder_ct.layers.0.self_attn.in_proj_bias
transformer_encoder_ct.layers.0.self_attn.out_proj.weight
transformer_encoder_ct.layers.0.self_attn.out_proj.bias
transformer_encoder_ct.layers.0.linear1.weight
transformer_encoder_ct.layers.0.linear1.bias
transformer_encoder_ct.layers.0.linear2.weight
transformer_encoder_ct.layers.0.linear2.bias
transformer_encoder_ct.layers.0.norm1.weight
transformer_encoder_ct.layers.0.norm1.bias
transformer_encoder_ct.layers.0.norm2.weight
transformer_encoder_ct.layers.0.norm2.bias
transformer_encoder_ct.layers.1.self_attn.in_proj_weight
transformer_encoder_ct.layers.1.self_attn.in_proj_bias
transformer_encoder_ct.layers.1.self_attn.out_proj.weight
transformer_encoder_ct.layers.1.self_attn.out_proj.bias
transformer_encoder_ct.layers.1.linear1.weight
transformer_encoder_ct.layers.1.linear1.bias
transformer_encoder_ct.layers.1.linear2.weight
transfor

In [72]:
a = "transformer_encoder.layers.1.norm2.weight"

# Split into words
words = a.split(".")  
words[0]=words[0]+"_ct"
# Join with a hyphen
b = ".".join(words) 
print(a)
b

transformer_encoder.layers.1.norm2.weight


'transformer_encoder_ct.layers.1.norm2.weight'

In [129]:
ct_path="../models/petct_online_rad_dino_cross_ct/rad_dino_transformer_santa_maria/ct/kfold_0/model_epoch_0002.pth"
pet_path="../models/petct_online_rad_dino_cross_pet/rad_dino_transformer_santa_maria/pet/kfold_0/model_epoch_0015.pth"
ctp=torch.load(ct_path, map_location="cuda:0",weights_only=True)
petp=torch.load(pet_path, map_location="cuda:0",weights_only=True)

  ctp=torch.load(ct_path, map_location="cuda:0")


In [130]:
for param in ctp.parameters():
    param.requires_grad = False

AttributeError: 'collections.OrderedDict' object has no attribute 'parameters'

In [128]:
ctp

OrderedDict([('cls_token',
              tensor([[[ 2.5748e+00,  7.1870e-01, -7.1926e-02, -6.0352e-02, -5.3397e-01,
                        -9.0280e-01,  9.2894e-01,  9.9280e-01,  2.4266e+00, -4.8779e-02,
                        -8.1178e-02, -7.9223e-01, -9.6928e-02, -8.6002e-02,  5.6291e-02,
                         4.4645e-01, -5.7181e-01, -3.7841e-01, -1.2197e+00, -1.8756e+00,
                        -7.1101e-01, -1.2137e+00,  1.0869e+00,  4.5417e-01, -1.1013e+00,
                         5.0657e-01,  3.8715e-01, -1.3928e+00, -8.4868e-01, -3.5549e-01,
                        -2.6235e-01,  1.2615e+00, -1.7479e+00,  5.4597e-01,  7.6204e-01,
                        -1.5567e+00, -8.7593e-01,  1.0196e+00,  2.3928e+00,  1.5347e+00,
                        -1.9434e+00, -5.6123e-01,  8.5640e-01,  2.2427e+00, -2.3693e-01,
                         4.8252e-01,  5.5330e-01, -5.3863e-02,  5.4210e-01,  1.1532e-01,
                        -2.7465e-01,  7.6971e-01, -2.7006e-01,  1.8909e-02,  6.3825

In [105]:
params_ct = {add_mode(k,"ct"):v for k, v in ctp.items() if k.startswith(('cls_token','transformer_encoder','norm'))}
params_pet = {add_mode(k,"pet"):v for k, v in petp.items() if k.startswith(('cls_token','transformer_encoder','norm'))}
params_ct.update(params_pet)


In [114]:
bimodel.load_state_dict(params_ct, strict=False)

_IncompatibleKeys(missing_keys=['classifier_ct.dense1.weight', 'classifier_ct.dense1.bias', 'classifier_ct.dense2.weight', 'classifier_ct.dense2.bias', 'classifier_pet.dense1.weight', 'classifier_pet.dense1.bias', 'classifier_pet.dense2.weight', 'classifier_pet.dense2.bias', 'projection_petct.dense1.weight', 'projection_petct.dense1.bias', 'projection_petct.dense2.weight', 'projection_petct.dense2.bias', 'cross_attention_ct.multihead_attn.in_proj_weight', 'cross_attention_ct.multihead_attn.in_proj_bias', 'cross_attention_ct.multihead_attn.out_proj.weight', 'cross_attention_ct.multihead_attn.out_proj.bias', 'cross_attention_pet.multihead_attn.in_proj_weight', 'cross_attention_pet.multihead_attn.in_proj_bias', 'cross_attention_pet.multihead_attn.out_proj.weight', 'cross_attention_pet.multihead_attn.out_proj.bias', 'classifier_petct.dense1.weight', 'classifier_petct.dense1.bias', 'classifier_petct.dense2.weight', 'classifier_petct.dense2.bias'], unexpected_keys=[])

In [109]:
elpp = {k:v for k, v in params_ct.items() if "norm" in k}
for k,v in elpp.items():
    print(k)

norm_ct.weight
norm_ct.bias
transformer_encoder_ct.layers.0.norm1.weight
transformer_encoder_ct.layers.0.norm1.bias
transformer_encoder_ct.layers.0.norm2.weight
transformer_encoder_ct.layers.0.norm2.bias
transformer_encoder_ct.layers.1.norm1.weight
transformer_encoder_ct.layers.1.norm1.bias
transformer_encoder_ct.layers.1.norm2.weight
transformer_encoder_ct.layers.1.norm2.bias
norm_pet.weight
norm_pet.bias
transformer_encoder_pet.layers.0.norm1.weight
transformer_encoder_pet.layers.0.norm1.bias
transformer_encoder_pet.layers.0.norm2.weight
transformer_encoder_pet.layers.0.norm2.bias
transformer_encoder_pet.layers.1.norm1.weight
transformer_encoder_pet.layers.1.norm1.bias
transformer_encoder_pet.layers.1.norm2.weight
transformer_encoder_pet.layers.1.norm2.bias


In [100]:
ls ../models/petct_online_rad_dino_cross_ct/rad_dino_transformer_santa_maria/ct/kfold_0/model_epoch_0002.pth

losses.html           test_metrics_2.json    train_metrics_15.json
model_epoch_0000.pth  test_metrics_3.json    train_metrics_16.json
model_epoch_0001.pth  test_metrics_4.json    train_metrics_17.json
model_epoch_0002.pth  test_metrics_5.json    train_metrics_1.json
test_metrics_0.json   test_metrics_6.json    train_metrics_2.json
test_metrics_10.json  test_metrics_7.json    train_metrics_3.json
test_metrics_11.json  test_metrics_8.json    train_metrics_4.json
test_metrics_12.json  test_metrics_9.json    train_metrics_5.json
test_metrics_13.json  train_metrics_0.json   train_metrics_6.json
test_metrics_14.json  train_metrics_10.json  train_metrics_7.json
test_metrics_15.json  train_metrics_11.json  train_metrics_8.json
test_metrics_16.json  train_metrics_12.json  train_metrics_9.json
test_metrics_17.json  train_metrics_13.json
test_metrics_1.json   train_metrics_14.json


In [115]:
type(True)

bool

In [116]:
def abc(a,b):
    a=a+5
    return b

In [117]:
aa=10
bb=200
abc(aa,bb)

200

In [119]:
bb

200

In [122]:
bimodel.transformer_encoder_ct

TransformerEncoder(
  (layers): ModuleList(
    (0-1): 2 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (linear1): Linear(in_features=768, out_features=3072, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
      (linear2): Linear(in_features=3072, out_features=768, bias=True)
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (dropout2): Dropout(p=0.5, inplace=False)
    )
  )
)

In [123]:
for param in bimodel.transformer_encoder_ct.parameters():
    param.requires_grad = False

In [124]:
for name, param in bimodel.named_parameters():
    print("{} {}".format(name, param.requires_grad))


cls_token_ct True
cls_token_pet True
transformer_encoder_ct.layers.0.self_attn.in_proj_weight False
transformer_encoder_ct.layers.0.self_attn.in_proj_bias False
transformer_encoder_ct.layers.0.self_attn.out_proj.weight False
transformer_encoder_ct.layers.0.self_attn.out_proj.bias False
transformer_encoder_ct.layers.0.linear1.weight False
transformer_encoder_ct.layers.0.linear1.bias False
transformer_encoder_ct.layers.0.linear2.weight False
transformer_encoder_ct.layers.0.linear2.bias False
transformer_encoder_ct.layers.0.norm1.weight False
transformer_encoder_ct.layers.0.norm1.bias False
transformer_encoder_ct.layers.0.norm2.weight False
transformer_encoder_ct.layers.0.norm2.bias False
transformer_encoder_ct.layers.1.self_attn.in_proj_weight False
transformer_encoder_ct.layers.1.self_attn.in_proj_bias False
transformer_encoder_ct.layers.1.self_attn.out_proj.weight False
transformer_encoder_ct.layers.1.self_attn.out_proj.bias False
transformer_encoder_ct.layers.1.linear1.weight False
tr