In [1]:
import os
import sys
import random
import time
from copy import deepcopy
from pathlib import Path

import h5py
import numpy as np
from tqdm import tqdm
import nibabel as nib
from monai import data, transforms as mt

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# target/crop shape for the images and masks when training
tar_shape = (256, 256)
crop_shape = (224, 224)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dict_args = {
    "random_seed": 42,
    "checkpoint": "output_model",
    "predict_mode": False,
    "workers": 4,
    "batch_size": 2,
    "max_epoch": 3,
    "lr": 1e-3,
    "decay": 1e-3,
    "lr_factor": 0.5,
    "min_lr": 5e-7,
    "lr_scheduler": "ReduceLROnPlateau"
    }

def normalize(data):
    data = (data - data.mean()) / data.std()
    return data



In [3]:


class Attention(nn.Module):
    def __init__(self,
                 channels,
                 num_heads,
                 proj_drop=0.0,
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv="same",
                 padding_q="same",
                 attention_bias=True
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.num_heads = num_heads
        self.proj_drop = proj_drop
        
        self.conv_q = nn.Conv2d(channels, channels, kernel_size, stride_q, padding_q, bias=attention_bias, groups=channels)
        self.layernorm_q = nn.LayerNorm(channels, eps=1e-5)
        self.conv_k = nn.Conv2d(channels, channels, kernel_size, stride_kv, stride_kv, bias=attention_bias, groups=channels)
        self.layernorm_k = nn.LayerNorm(channels, eps=1e-5)
        self.conv_v = nn.Conv2d(channels, channels, kernel_size, stride_kv, stride_kv, bias=attention_bias, groups=channels)
        self.layernorm_v = nn.LayerNorm(channels, eps=1e-5)
        
        self.attention = nn.MultiheadAttention(embed_dim=channels, 
                                               bias=attention_bias, 
                                               batch_first=True,
                                               # dropout = 0.0,
                                               num_heads=1)#num_heads=self.num_heads)

    def _build_projection(self, x, qkv):

        
        if qkv == "q":
            x1 = F.relu(self.conv_q(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_q(x1)
            proj = x1.permute(0, 3, 1, 2)
        elif qkv == "k":
            x1 = F.relu(self.conv_k(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_k(x1)
            proj = x1.permute(0, 3, 1, 2)            
        elif qkv == "v":
            x1 = F.relu(self.conv_v(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_v(x1)
            proj = x1.permute(0, 3, 1, 2)        

        return proj

    def forward_conv(self, x):
        q = self._build_projection(x, "q")
        k = self._build_projection(x, "k")
        v = self._build_projection(x, "v")

        return q, k, v

    def forward(self, x):
        q, k, v = self.forward_conv(x)
        q = q.view(x.shape[0], x.shape[1], x.shape[2]*x.shape[3])
        k = k.view(x.shape[0], x.shape[1], x.shape[2]*x.shape[3])
        v = v.view(x.shape[0], x.shape[1], x.shape[2]*x.shape[3])
        q = q.permute(0, 2, 1)
        k = k.permute(0, 2, 1)
        v = v.permute(0, 2, 1)
        x1 = self.attention(query=q, value=v, key=k, need_weights=False)
        
        x1 = x1[0].permute(0, 2, 1)
        x1 = x1.view(x1.shape[0], x1.shape[1], np.sqrt(x1.shape[2]).astype(int), np.sqrt(x1.shape[2]).astype(int))
        x1 = F.dropout(x1, self.proj_drop)

        return x1
 


In [4]:
class Transformer(nn.Module):

    def __init__(self,
                 # in_channels,
                 out_channels,
                 num_heads,
                 dpr,
                 proj_drop=0.0,
                 attention_bias=True,
                 padding_q="same",
                 padding_kv="same",
                 stride_kv=1,
                 stride_q=1):
        super().__init__()
        
        self.attention_output = Attention(channels=out_channels,
                                         num_heads=num_heads,
                                         proj_drop=proj_drop,
                                         padding_q=padding_q,
                                         padding_kv=padding_kv,
                                         stride_kv=stride_kv,
                                         stride_q=stride_q,
                                         attention_bias=attention_bias,
                                         )

        self.conv1 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
        self.layernorm = nn.LayerNorm(self.conv1.out_channels, eps=1e-5)
        self.wide_focus = Wide_Focus(out_channels, out_channels)

    def forward(self, x):
        x1 = self.attention_output(x)
        x1 = self.conv1(x1)
        x2 = torch.add(x1, x)
        x3 = x2.permute(0, 2, 3, 1)
        x3 = self.layernorm(x3)
        x3 = x3.permute(0, 3, 1, 2)
        x3 = self.wide_focus(x3)
        x3 = torch.add(x2, x3)
        return x3




In [5]:
    
class Wide_Focus(nn.Module): 
    """
    Wide-Focus module.
    """
    def __init__(self,
                 in_channels,
                 out_channels):
        super().__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same", dilation=2)
        self.conv3 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same", dilation=3)
        self.conv4 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")


    def forward(self, x):
        x1 = self.conv1(x)
        x1 = F.gelu(x1)
        x1 = F.dropout(x1, 0.1)
        x2 = self.conv2(x)
        x2 = F.gelu(x2)
        x2 = F.dropout(x2, 0.1)
        x3 = self.conv3(x)
        x3 = F.gelu(x3)
        x3 = F.dropout(x3, 0.1)
        added = torch.add(x1, x2)
        added = torch.add(added, x3)
        x_out = self.conv4(added)
        x_out = F.gelu(x_out)
        x_out = F.dropout(x_out, 0.1)
        return x_out



In [6]:



class Block_decoder(nn.Module):
    def __init__(self, in_channels, out_channels, att_heads, dpr):
        super().__init__()
        self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")
        self.conv2 = nn.Conv2d(out_channels*2, out_channels, 3, 1, padding="same")
        self.conv3 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
        self.trans = Transformer(out_channels, att_heads, dpr)
        
    def forward(self, x, skip):
        x1 = x.permute(0, 2, 3, 1)
        x1 = self.layernorm(x1)
        x1 = x1.permute(0, 3, 1, 2)
        x1 = self.upsample(x1)
        x1 = F.relu(self.conv1(x1))
        x1 = torch.cat((skip, x1), axis=1)
        x1 = F.relu(self.conv2(x1))
        x1 = F.relu(self.conv3(x1))
        x1 = F.dropout(x1, 0.3)
        out = self.trans(x1)
        return out




In [7]:


class DS_out(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, 1, padding="same"),
            nn.ReLU()
        ) 
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, 1, padding="same"),
            nn.ReLU()
        ) 
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, padding="same"),
            nn.Sigmoid()
        ) 

    def forward(self, x):
        x1 = self.upsample(x)
        x1 = x1.permute(0, 2, 3, 1)
        x1 = self.layernorm(x1)
        x1 = x1.permute(0, 3, 1, 2)
        x1 = self.conv1(x1)
        x1 = self.conv2(x1)
        out = self.conv3(x1)
        
        return out



In [8]:


class Block_encoder_bottleneck(nn.Module):
    def __init__(self, blk, in_channels, out_channels, att_heads, dpr):
        super().__init__()
        self.blk = blk
        if ((self.blk=="first") or (self.blk=="bottleneck")):
            self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")
            self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
            self.trans = Transformer(out_channels, att_heads, dpr)
        elif ((self.blk=="second") or (self.blk=="third") or (self.blk=="fourth")):
            self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
            self.conv1 = nn.Conv2d(1, in_channels, 3, 1, padding="same")
            self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
            self.conv3 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
            self.trans = Transformer(out_channels, att_heads, dpr)


    def forward(self, x, scale_img="none"):
        if ((self.blk=="first") or (self.blk=="bottleneck")):
            x1 = x.permute(0, 2, 3, 1)
            x1 = self.layernorm(x1)
            x1 = x1.permute(0, 3, 1, 2)
            x1 = F.relu(self.conv1(x1))
            x1 = F.relu(self.conv2(x1))
            x1 = F.dropout(x1, 0.3)
            x1 = F.max_pool2d(x1, (2,2))
            out = self.trans(x1)
            # without skip
        elif ((self.blk=="second") or (self.blk=="third") or (self.blk=="fourth")):
            x1 = x.permute(0, 2, 3, 1)
            x1 = self.layernorm(x1)
            x1 = x1.permute(0, 3, 1, 2)
            x1 = torch.cat((F.relu(self.conv1(scale_img)), x1), axis=1)
            x1 = F.relu(self.conv2(x1))
            x1 = F.relu(self.conv3(x1))
            x1 = F.dropout(x1, 0.3)
            x1 = F.max_pool2d(x1, (2,2))
            out = self.trans(x1)
            # with skip
        return out




In [9]:
class FCT_Head(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        h_attent_head = [2, 2, 2, 2, 2]
        filters = [8, 16, 32, 64, 128,]
        # number of blocks used in the model
        blocks = len(filters)

        stochastic_depth_rate = 0.0

        #probability for each block
        dpr = [x for x in np.linspace(0, stochastic_depth_rate, blocks)]

        # Multi-scale input
        self.scale_img = nn.AvgPool2d(2,2)   

        # model
        self.block_1 = Block_encoder_bottleneck("first", 1, filters[0], h_attent_head[0], dpr[0])
        self.block_2 = Block_encoder_bottleneck("second", filters[0], filters[1], h_attent_head[1], dpr[1])
        self.block_3 = Block_encoder_bottleneck("third", filters[1], filters[2], h_attent_head[2], dpr[2])
        self.block_4 = Block_encoder_bottleneck("fourth", filters[2], filters[3], h_attent_head[3], dpr[3])
        self.block_5 = Block_encoder_bottleneck("bottleneck", filters[3], filters[4], h_attent_head[4], dpr[4])
    
    def forward(self, x):
        # Multi-scale input
        scale_img_2 = self.scale_img(x)
        scale_img_3 = self.scale_img(scale_img_2)
        scale_img_4 = self.scale_img(scale_img_3)  

        x = self.block_1(x)
        print(f"Block 1 out -> {list(x.size())}")
        skip1 = x
        x = self.block_2(x, scale_img_2)
        print(f"Block 2 out -> {list(x.size())}")
        skip2 = x
        x = self.block_3(x, scale_img_3)
        print(f"Block 3 out -> {list(x.size())}")
        skip3 = x
        x = self.block_4(x, scale_img_4)
        print(f"Block 4 out -> {list(x.size())}")
        skip4 = x

        return {
            "skip1": skip1.cpu().detach().numpy(), 
            "skip2": skip2.cpu().detach().numpy(), 
            "skip3": skip3.cpu().detach().numpy(), 
            "skip4": skip4.cpu().detach().numpy(),
           }
 

In [10]:


   

class FCT_Body(nn.Module):
    def __init__(self, ) -> None:
        super().__init__()

        b_attent_body = [2, 2, 2, 2, 2]
        filters = [64, 128, 64, 32, 16, 8] 
        # number of blocks used in the model
        blocks = len(b_attent_body)

        stochastic_depth_rate = 0.0

        #probability for each block
        dpr = [x for x in np.linspace(0, stochastic_depth_rate, blocks)]


        # model
        self.block_5 = Block_encoder_bottleneck("bottleneck", filters[0], filters[1], b_attent_body[0], dpr[0])
        self.block_6 = Block_decoder(filters[1], filters[2], b_attent_body[1], dpr[1])
        self.block_7 = Block_decoder(filters[2], filters[3], b_attent_body[2], dpr[2])
        self.block_8 = Block_decoder(filters[3], filters[4], b_attent_body[3], dpr[3])
        self.block_9 = Block_decoder(filters[4], filters[5], b_attent_body[4], dpr[4])
    
    def forward(self, skip1, skip2, skip3, skip4):
        
        x = self.block_5(skip4)
        print(f"Block 5 out -> {list(x.size())}")
        x = self.block_6(x, skip4)
        print(f"Block 6 out -> {list(x.size())}")
        x = self.block_7(x, skip3)
        print(f"Block 7 out -> {list(x.size())}")
        skip7 = x
        x = self.block_8(x, skip2)
        print(f"Block 8 out -> {list(x.size())}")
        skip8 = x
        x = self.block_9(x, skip1)
        print(f"Block 9 out -> {list(x.size())}")
        skip9 = x

        return {
        #     "skip7": skip7.cpu().detach().numpy(), 
        #     "skip8": skip8.cpu().detach().numpy(), 
            "skip9": skip9.cpu().detach().numpy(),
           }



        

In [11]:


class FCT_Tail(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        filters = [32, 16, 8] 
        # number of blocks used in the model

        # self.ds7 = DS_out(filters[0], 1)
        # self.ds8 = DS_out(filters[1], 1)
        # self.ds9 = DS_out(filters[2], 1)
        self.ds10 = DS_out(filters[2], 1)
    
    def forward(self, # skip7, skip8, skip9):
                skip9):
        
        # out7 = self.ds7(skip7)
        # print(f"DS 7 out -> {list(out7.size())}")
        # out8 = self.ds8(skip8)
        # print(f"DS 8 out -> {list(out8.size())}")
        # out9 = self.ds9(skip9)
        # print(f"DS 9 out -> {list(out9.size())}")
        out10 = self.ds10(skip9)
        print(f"DS 10 out -> {list(out10.size())}")

        return out10

        

In [12]:
class ACDC_2D(Dataset):
    def __init__(self, source, ind, Transform=None):
        # basic transforms
        self.loader = mt.LoadImaged(keys=["image", "mask"])
        self.add_channel = mt.EnsureChannelFirstd(keys=["image", "mask"])
        self.spatial_pad = mt.SpatialPadD(keys=["image", "mask"], spatial_size=tar_shape, mode="edge")
        self.spacing = mt.Spacingd(keys=["image", "mask"], pixdim=(1.25, 1.25, -1.0), mode=("nearest", "nearest"))
        # index
        self.ind = ind
        # transform
        if Transform is not None:
            self.transform = Transform
        else:
            self.transform = mt.Compose([
                mt.SpatialPadD(keys=["image", "mask"], spatial_size=tar_shape, mode="edge"),
                mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False)
            ])

        # take the images
        source = Path(source)
        # dirs = os.listdir(str(source))  # stores patient name
        all_data_ed = []
        all_data_ed_mask = []
        all_data_es = []
        all_data_es_mask = []
        for filenames in source.iterdir():
            if filenames.is_dir():
                # patient_path = Path(str(source), filenames)  # individual patient path
                patient_info = str(filenames / "Info.cfg")  # patient information
                file = open(patient_info, 'r').readlines()
                ED_frame = int(file[0].split(":")[1])
                ES_frame = int(file[1].split(":")[1])
                ED = (filenames / f"{filenames.name}_frame{ED_frame:02d}.nii.gz")
                ES = (filenames / f"{filenames.name}_frame{ES_frame:02d}.nii.gz")
                ED_gt = (filenames / f"{filenames.name}_frame{ED_frame:02d}_gt.nii.gz")
                ES_gt = (filenames / f"{filenames.name}_frame{ES_frame:02d}_gt.nii.gz")
                all_data_ed.append(ED)
                all_data_ed_mask.append(ED_gt)
                all_data_es.append(ES)
                all_data_es_mask.append(ES_gt)

        if self.ind is not None:
            all_data_ed = [all_data_ed[i] for i in self.ind]
            all_data_ed_mask = [all_data_ed_mask[i] for i in self.ind]
            all_data_es = [all_data_es[i] for i in self.ind]
            all_data_es_mask = [all_data_es_mask[i] for i in self.ind]

        self.data = [all_data_ed, all_data_ed_mask, all_data_es, all_data_es_mask]

    def __len__(self):
        return len(self.data[0])

    def __getitem__(self, idx):
        ED_img, ED_mask, ES_img, ES_mask = self.data[0][idx], self.data[1][idx], self.data[2][idx], self.data[3][idx]
        # data dict
        ED_data_dict = {"image": ED_img,
                        "mask": ED_mask}
        ES_data_dict = {"image": ES_img,
                        "mask": ES_mask}
        # instead of returning both ED and ES, I have to return just a random choice between ED and ES(image and mask)
        datalist = [ED_data_dict, ES_data_dict]
        data_return = np.random.choice(datalist)
        data_return = self.loader(data_return)
        data_return = self.add_channel(data_return)
        data_return = self.spacing(data_return)
        data_return["image"] = normalize(data_return["image"])
        num_slice = data_return["image"].shape[3]
        random_slice = random.randint(0, num_slice - 1)
        data_return["image"] = data_return["image"][:, :, :, random_slice]
        data_return["image"] = normalize(data_return["image"])
        data_return["mask"] = data_return["mask"][:, :, :, random_slice]
        data_return = self.transform(data_return)
        return data_return

In [13]:
def train_loader_ACDC(train_index, data_path=r"../dataset/train_rosfl", transform=None):
    train_loader = ACDC_2D(source=data_path, Transform=transform, ind=train_index)
    return train_loader


def val_loader_ACDC(val_index, data_path=r"../dataset/train_rosfl", transform=None):
    val_loader = ACDC_2D(source=data_path, Transform=transform, ind=val_index)
    return val_loader


def test_loader_ACDC(test_index, data_path=r"../dataset/testing", transform=None):
    test_loader = ACDC_2D(source=data_path, Transform=transform, ind=test_index)
    return test_loader

In [14]:
""" To test if the dataloader works """

train_compose = mt.Compose(
    [mt.SpatialPadD(keys=["image", "mask"], spatial_size=tar_shape, mode="edge"),
     mt.RandSpatialCropD(keys=["image", "mask"], roi_size=crop_shape, random_center=True, random_size=False),
     mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False),
     ]
)

val_compose = mt.Compose(
    [
        mt.SpatialPadD(keys=["image", "mask"], spatial_size=tar_shape, mode="edge"),
        mt.RandSpatialCropD(keys=["image", "mask"], roi_size=crop_shape, random_center=True, random_size=False),
        mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False),
    ]
)

test_compose = mt.Compose(
    [
        mt.SpatialPadD(keys=["image", "mask"], spatial_size=tar_shape, mode="edge"),
        mt.RandSpatialCropD(keys=["image", "mask"], roi_size=crop_shape, random_center=True, random_size=False),
        mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False),
    ]
)

In [15]:
splits = KFold(n_splits=3, shuffle=True, random_state=42)

concatenated_dataset = train_loader_ACDC(transform=None, train_index=None)

In [16]:
for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(concatenated_dataset)))):

    print("--------------------------", "Fold", fold + 1, "--------------------------")

    # training dataset
    training_data = DataLoader(train_loader_ACDC(transform=train_compose, train_index=train_idx), batch_size=2,
                               shuffle=False)
    print("train from here", len(training_data))
    # for dic in training_data:
    #     images = dic["image"]
    #     masks = dic["mask"]
    #     print(images.shape, masks.shape)
    #     image, label = dic["image"], dic["mask"]
    #     plt.figure("visualise", (8, 4))
    #     plt.subplot(1, 2, 1)
    #     plt.title("image")
    #     plt.imshow(image[0, 0, :, :], cmap="gray")
    #     plt.subplot(1, 2, 2)
    #     plt.title("mask")
    #     plt.imshow(label[0, 0, :, :], cmap="gray")
    #     plt.show()
    #     break

    # validation dataset
    validation_data = DataLoader(val_loader_ACDC(transform=val_compose, val_index=val_idx), batch_size=1,
                                 shuffle=False)
    print("val from here", len(validation_data))
    # for dic in validation_data:
    #     images = dic["image"]
    #     masks = dic["mask"]
    #     print(images.shape, masks.shape)
    #     image, label = dic["image"], dic["mask"]
    #     plt.figure("visualise", (8, 4))
    #     plt.subplot(1, 2, 1)
    #     plt.title("image")
    #     plt.imshow(image[0, 0, :, :], cmap="gray")
    #     plt.subplot(1, 2, 2)
    #     plt.title("mask")
    #     plt.imshow(label[0, 0, :, :], cmap="gray")
    #     plt.show()
    #     break

    # test dataset
    # ========================== TEST Data ===================
    # test_data = DataLoader(test_loader_ACDC(transform=test_compose, test_index=None), batch_size=1, shuffle=False)
    # ========================== TEST Data ===================
    # print("test from here")
    # for dic in test_data:
    #     images = dic["image"]
    #     masks = dic["mask"]
    #     print(images.shape, masks.shape)
    #     image, label = dic["image"], dic["mask"]
    #     plt.figure("visualise", (8, 4))
    #     plt.subplot(1, 2, 1)
    #     plt.title("image")
    #     plt.imshow(image[0, 0, :, :], cmap="gray")
    #     plt.subplot(1, 2, 2)
    #     plt.title("mask")
    #     plt.imshow(label[0, 0, :, :], cmap="gray")
    #     plt.show()
    #     break

-------------------------- Fold 1 --------------------------
train from here 20
val from here 20
-------------------------- Fold 2 --------------------------
train from here 20
val from here 20
-------------------------- Fold 3 --------------------------
train from here 20
val from here 20


In [17]:
from monai.losses.dice import GeneralizedDiceLoss

def init_weights(m):
    """
    Initialize the weights
    """
    if isinstance(m, nn.Conv2d):
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

loss_fn = nn.BCELoss()

dic_loss_fn = GeneralizedDiceLoss(to_onehot_y=True, softmax=True)

In [18]:
# =======================================================================
#                                HEAD
# =======================================================================

model_head = FCT_Head()
model_head.apply(init_weights)

optimizer_head = torch.optim.Adam(model_head.parameters(), lr=dict_args['lr'],weight_decay=dict_args['decay'])

scheduler_head = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_head,
            mode='min',
            factor=dict_args['lr_factor'],
            verbose=True,
            threshold=1e-6,
            patience=10,
            min_lr=dict_args['min_lr'])

model_head.to(device)



# =======================================================================
#                                BODY
# =======================================================================

model_body = FCT_Body()
model_body.apply(init_weights)

optimizer_body = torch.optim.AdamW(model_body.parameters(), lr=dict_args['lr'],weight_decay=dict_args['decay'])

scheduler_body = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_body,
            mode='min',
            factor=dict_args['lr_factor'],
            verbose=True,
            threshold=1e-6,
            patience=10,
            min_lr=dict_args['min_lr'])

model_body.to(device)

# =======================================================================
#                                TAIL
# =======================================================================

model_tail = FCT_Tail()
model_tail.apply(init_weights)

optimizer_tail = torch.optim.AdamW(model_tail.parameters(), lr=dict_args['lr'],weight_decay=dict_args['decay'])

scheduler_tail = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_tail,
            mode='min',
            factor=dict_args['lr_factor'],
            verbose=True,
            threshold=1e-6,
            patience=10,
            min_lr=dict_args['min_lr'])

model_tail.to(device)

print("Initialized ....")

Initialized ....


In [19]:
model_head.train()

try:
    fh5_head = h5py.File('params_and_grads/h5_head_values.hdf5', 'w') 
    fh5_label = h5py.File('params_and_grads/h5_train_label.hdf5', 'w')
    for index, train_dict in tqdm(enumerate(training_data), total=len(training_data)):
        print("index value is ", index)
        X_train = train_dict["image"]
        y_train = train_dict["mask"]
        X_train = X_train.to(device)

        layer_data = model_head(X_train)

        grp_head = fh5_head.create_group(f'IterKey_{index}')
        for k, v in layer_data.items():
            grp_head.create_dataset(k, data=v)
        
        grp_label = fh5_label.create_group(f'IterKey_{index}')
        grp_label.create_dataset("tlabel", data=y_train.cpu().detach().numpy())
except Exception as ex:
    import traceback
    print("+=" * 25)
    print("Error encountered as :", ex)
    print("+=" * 25)
    traceback.print_exc()

finally:
    fh5_head.close()
    fh5_label.close()

  0%|          | 0/20 [00:00<?, ?it/s]

index value is  0
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]


  5%|▌         | 1/20 [00:02<00:53,  2.80s/it]

Block 4 out -> [2, 64, 14, 14]
index value is  1


 10%|█         | 2/20 [00:05<00:46,  2.59s/it]

Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  2
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]


 15%|█▌        | 3/20 [00:07<00:43,  2.57s/it]

Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  3
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]


 20%|██        | 4/20 [00:10<00:41,  2.57s/it]

Block 4 out -> [2, 64, 14, 14]
index value is  4
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]


 25%|██▌       | 5/20 [00:12<00:38,  2.57s/it]

Block 4 out -> [2, 64, 14, 14]
index value is  5
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]


 30%|███       | 6/20 [00:15<00:35,  2.54s/it]

index value is  6
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]


 35%|███▌      | 7/20 [00:17<00:32,  2.53s/it]

index value is  7
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]


 40%|████      | 8/20 [00:20<00:30,  2.54s/it]

index value is  8
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]


 45%|████▌     | 9/20 [00:23<00:27,  2.54s/it]

Block 4 out -> [2, 64, 14, 14]
index value is  9


 50%|█████     | 10/20 [00:25<00:25,  2.56s/it]

Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  10


 55%|█████▌    | 11/20 [00:28<00:22,  2.52s/it]

Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  11


 60%|██████    | 12/20 [00:30<00:20,  2.51s/it]

Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  12
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]


 65%|██████▌   | 13/20 [00:33<00:17,  2.50s/it]

index value is  13
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]


 70%|███████   | 14/20 [00:35<00:14,  2.49s/it]

index value is  14
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]


 75%|███████▌  | 15/20 [00:37<00:12,  2.50s/it]

Block 4 out -> [2, 64, 14, 14]
index value is  15
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]


 80%|████████  | 16/20 [00:40<00:10,  2.51s/it]

index value is  16


 85%|████████▌ | 17/20 [00:43<00:07,  2.51s/it]

Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]
index value is  17
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]


 90%|█████████ | 18/20 [00:45<00:05,  2.56s/it]

Block 4 out -> [2, 64, 14, 14]
index value is  18
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]


 95%|█████████▌| 19/20 [00:48<00:02,  2.52s/it]

index value is  19
Block 1 out -> [2, 8, 112, 112]
Block 2 out -> [2, 16, 56, 56]
Block 3 out -> [2, 32, 28, 28]
Block 4 out -> [2, 64, 14, 14]


100%|██████████| 20/20 [00:50<00:00,  2.53s/it]


In [20]:
model_body.train()

fh5_body = h5py.File('params_and_grads/h5_body_values.hdf5', 'w')

try:
    with h5py.File('params_and_grads/h5_head_values.hdf5', 'r') as f_head:
        for key, grp in tqdm(f_head.items(), total=len(f_head)):
            skip_1 = torch.from_numpy(grp['skip1'][:]).to(device)
            skip_2 = torch.from_numpy(grp['skip2'][:]).to(device)
            skip_3 = torch.from_numpy(grp['skip3'][:]).to(device)     
            skip_4 = torch.from_numpy(grp['skip4'][:]).to(device)

            bd_layer_data = model_body(skip_1, skip_2, skip_3, skip_4)

            bgrp = fh5_body.create_group(key)
            for k,v in bd_layer_data.items():
                bgrp.create_dataset(k, data=v)
except Exception as e:
    import traceback
    traceback.print_exc()
    fh5_body.close()


  0%|          | 0/20 [00:00<?, ?it/s]

Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


  5%|▌         | 1/20 [00:02<00:42,  2.26s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 10%|█         | 2/20 [00:04<00:39,  2.19s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 15%|█▌        | 3/20 [00:06<00:37,  2.20s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 20%|██        | 4/20 [00:08<00:35,  2.22s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 25%|██▌       | 5/20 [00:10<00:32,  2.19s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 30%|███       | 6/20 [00:13<00:30,  2.15s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 35%|███▌      | 7/20 [00:15<00:27,  2.14s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 40%|████      | 8/20 [00:17<00:25,  2.14s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 45%|████▌     | 9/20 [00:19<00:23,  2.15s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 50%|█████     | 10/20 [00:21<00:21,  2.18s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 55%|█████▌    | 11/20 [00:23<00:19,  2.17s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 60%|██████    | 12/20 [00:26<00:17,  2.16s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 65%|██████▌   | 13/20 [00:28<00:15,  2.16s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 70%|███████   | 14/20 [00:30<00:12,  2.15s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 75%|███████▌  | 15/20 [00:32<00:10,  2.13s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 80%|████████  | 16/20 [00:34<00:08,  2.13s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 85%|████████▌ | 17/20 [00:36<00:06,  2.14s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 90%|█████████ | 18/20 [00:38<00:04,  2.12s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


 95%|█████████▌| 19/20 [00:41<00:02,  2.16s/it]

Block 9 out -> [2, 8, 112, 112]
Block 5 out -> [2, 128, 7, 7]
Block 6 out -> [2, 64, 14, 14]
Block 7 out -> [2, 32, 28, 28]
Block 8 out -> [2, 16, 56, 56]


100%|██████████| 20/20 [00:43<00:00,  2.16s/it]

Block 9 out -> [2, 8, 112, 112]





In [21]:
model_tail.train()
train_loss_list = []
grads_dict = {}
abs_grads_dict = {}


fh5_body = h5py.File('params_and_grads/h5_body_values.hdf5', 'r')
fh5_label = h5py.File('params_and_grads/h5_train_label.hdf5', 'r')

try:
    for (key, grp), (lkey, lgrp) in zip(fh5_body.items(), fh5_label.items()):

        if str(key) != str(lkey):
            print(f"Not the same key tail:: {key} and label:: {lkey}, data could be different ")
        
        # skip_7 = torch.tensor(grp['skip7'][:], requires_grad=True).to(device)
        # skip_8 = torch.tensor(grp['skip8'][:], requires_grad=True).to(device)
        skip_9 = torch.tensor(grp['skip9'][:], requires_grad=True).to(device)
        
        y_mask = torch.from_numpy(lgrp['tlabel'][:]).to(device)

        tl_output_data = model_tail(# skip_7, skip_8, 
            skip_9
            )
        
        loss = loss_fn(tl_output_data, y_mask)
        train_loss_list.append(loss)
        loss.backward()
        optimizer_tail.step()

except Exception as ex:
    import traceback
    print("+=" * 25)
    print("Error encountered as :", ex)
    print("+=" * 25)
    traceback.print_exc()

finally:
    fh5_body.close()
    fh5_label.close()


DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]
DS 10 out -> [2, 1, 224, 224]


In [23]:
grads_dict = {}
mean_grads_dict = {}

for name, params in model_tail.named_parameters():
    if (name not in grads_dict) and ("ds10" in name):
        grads_dict[name] = []
        mean_grads_dict[name] = []
    if params.grad is not None:
        grads_dict[name].append(params.grad)
        mean_grads_dict[name].append(params.grad.mean())


print("grads_dict : \n", grads_dict)
print("=+" * 15)
print("abs_grads_dict : \n", mean_grads_dict)


grads_dict : 
 {'ds10.layernorm.weight': [tensor([ 0.4994,  0.1910, -0.0348,  0.3572,  0.2148,  0.0237,  0.0675,  0.8519])], 'ds10.layernorm.bias': [tensor([ 0.2581, -0.3783,  0.1922, -0.5930, -0.2444,  0.1632,  0.1884,  0.4015])], 'ds10.conv1.0.weight': [tensor([[[[ 1.9865e-01,  1.4809e-01,  1.5886e-01],
          [ 2.2671e-01,  1.7270e-01,  1.9018e-01],
          [ 8.6830e-02,  4.7224e-02,  1.0767e-01]],

         [[ 1.4213e-02,  2.9643e-02, -1.2107e-02],
          [-2.3407e-02,  1.8475e-02, -8.3898e-03],
          [ 4.5466e-02,  5.7744e-02, -1.2270e-02]],

         [[-8.5668e-02, -8.7973e-02, -8.0554e-02],
          [-9.0638e-02, -1.0873e-01, -1.0334e-01],
          [-1.0130e-01, -1.0615e-01, -8.9239e-02]],

         [[-1.0082e-01, -9.3206e-02, -6.5363e-02],
          [-1.0830e-01, -1.0939e-01, -8.9939e-02],
          [-8.0069e-02, -7.0748e-02, -5.1862e-02]],

         [[-9.6369e-02, -9.8693e-02, -1.0131e-01],
          [-8.6684e-02, -9.7869e-02, -1.0381e-01],
          [-1.0483e-01

In [24]:
tail_weight_grad = model_tail.ds10.layernorm.weight.grad
tail_bias_grad = model_tail.ds10.layernorm.bias.grad

In [26]:
model_body_output = []
with h5py.File('params_and_grads/h5_body_values.hdf5', 'r') as fbody:
    print(f"there are around {len(fbody.items())} tensors")
    for key, group_val in fbody.items():
        blayer_tensor = torch.tensor(group_val['skip9'][:], requires_grad=True).to(device)
        print(blayer_tensor.shape)
        print(blayer_tensor.unsqueeze)
        

there are around 20 tensors
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])
torch.Size([2, 8, 112, 112])


In [28]:
torch.stack(model_body_output).shap

tensor([[[[[ 2.7019e+00,  2.4791e+00,  8.8414e-01,  ...,  3.9019e+00,
             4.7333e+00,  2.0789e+00],
           [ 2.0556e+00,  3.9889e+00,  3.6516e+00,  ...,  8.0017e+00,
             5.7907e+00,  3.0306e+00],
           [ 2.7455e+00,  4.9532e+00,  4.3155e+00,  ...,  5.0719e+00,
             4.0301e+00,  4.8240e+00],
           ...,
           [ 1.4396e+00,  4.3965e+00,  2.3380e+00,  ...,  3.5045e+00,
             4.7522e-01,  1.6790e-01],
           [ 2.7240e+00,  6.1247e+00,  8.9059e+00,  ...,  3.9794e+00,
             5.6375e+00,  5.7376e-02],
           [ 1.9277e+00,  1.9695e+00,  1.3725e+00,  ...,  7.2240e-01,
             1.9424e+00,  9.7450e-01]],

          [[-1.9814e-02,  4.7274e-01, -6.3021e-02,  ..., -1.5852e-01,
            -6.9555e-02,  2.8797e-01],
           [-3.6659e-01,  6.3267e-02,  1.0207e+00,  ...,  4.4546e+00,
             8.6559e-01,  9.5499e-01],
           [ 9.6840e-02,  1.9394e+00,  3.6090e+00,  ..., -1.7593e-01,
             9.1913e-01,  5.7595e+00],
 

In [56]:
tail_weight_grad.mean(dim=0, keepdim=True)

tensor([-0.0483])

In [75]:
model_body.block_9.trans.wide_focus.conv4.weight.data.T.shape
# model_body.block_9.trans.wide_focus.conv4.bias

  model_body.block_9.trans.wide_focus.conv4.weight.data.T.shape


torch.Size([3, 3, 8, 8])

In [103]:
with torch.no_grad():
    optimizer_body.zero_grad()
    hidden_grad = torch.matmul(nn.functional.relu(model_body.block_9.trans.wide_focus.conv4.weight).T, tail_weight_grad)
    model_body.block_9.trans.wide_focus.conv4.weight.grad = hidden_grad.T#.mean(dim=0, keepdim=True)

RuntimeError: assigned grad has data of a different size

In [120]:
model_body.block_9.trans.wide_focus.conv4.weight.backward(hidden_grad.T)

RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([8, 3, 3]) and output[0] has a shape of torch.Size([8, 8, 3, 3]).

In [119]:
# model_body.block_9.trans.wide_focus.conv4.weight.grad.data.zero_()
model_body.block_9.trans.wide_focus.conv4.bias.shape

torch.Size([8])

In [113]:
torch.matmul(tail_weight_grad, model_body.block_9.trans.wide_focus.conv4.weight.T).T.shape

torch.Size([8, 3, 3])

In [92]:
for name, param in model_body.named_parameters():
    if param.requires_grad:
        print(name)

block_5.layernorm.weight
block_5.layernorm.bias
block_5.conv1.weight
block_5.conv1.bias
block_5.conv2.weight
block_5.conv2.bias
block_5.trans.attention_output.conv_q.weight
block_5.trans.attention_output.conv_q.bias
block_5.trans.attention_output.layernorm_q.weight
block_5.trans.attention_output.layernorm_q.bias
block_5.trans.attention_output.conv_k.weight
block_5.trans.attention_output.conv_k.bias
block_5.trans.attention_output.layernorm_k.weight
block_5.trans.attention_output.layernorm_k.bias
block_5.trans.attention_output.conv_v.weight
block_5.trans.attention_output.conv_v.bias
block_5.trans.attention_output.layernorm_v.weight
block_5.trans.attention_output.layernorm_v.bias
block_5.trans.attention_output.attention.in_proj_weight
block_5.trans.attention_output.attention.in_proj_bias
block_5.trans.attention_output.attention.out_proj.weight
block_5.trans.attention_output.attention.out_proj.bias
block_5.trans.conv1.weight
block_5.trans.conv1.bias
block_5.trans.layernorm.weight
block_5.t