In [1]:

import time
from pathlib import Path

import h5py
import numpy as np
from tqdm import tqdm


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset

from model import Block_encoder_bottleneck, Block_decoder, device, dict_args, init_weights

  warn(f"Failed to load image Python extension: {e}")


In [2]:


class FCT_Body(nn.Module):
    """Body stage of the network: one bottleneck block followed by four decoder blocks.

    The module consumes the four skip tensors produced by an upstream stage
    and pushes them through ``block_5`` .. ``block_9``. The widths handed to
    the blocks go 64 -> 128 for the bottleneck and then 128 -> 64 -> 32 ->
    16 -> 8 across the decoders.

    NOTE(review): the meaning of the third constructor argument (``2`` for
    every block) is defined in ``model.py`` and is not visible here --
    presumably an attention-related setting; confirm against that module.
    """

    def __init__(self, ) -> None:
        super().__init__()

        # One entry per block; forwarded unchanged as the blocks' third/fourth
        # positional argument.
        attn_per_block = [2, 2, 2, 2, 2]
        # widths[i] -> widths[i + 1] are the in/out sizes of consecutive blocks.
        widths = [64, 128, 64, 32, 16, 8]
        n_blocks = len(attn_per_block)

        # Per-block probability, spaced linearly from 0 to the maximum rate.
        # With a maximum of 0.0 every entry is 0.
        max_drop_rate = 0.0
        drop_rates = list(np.linspace(0, max_drop_rate, n_blocks))

        # Construction order and attribute names are kept stable so that
        # parameter ordering and state-dict keys do not change.
        self.block_5 = Block_encoder_bottleneck(
            "bottleneck", widths[0], widths[1], attn_per_block[0], drop_rates[0]
        )
        self.block_6 = Block_decoder(
            widths[1], widths[2], attn_per_block[1], drop_rates[1]
        )
        self.block_7 = Block_decoder(
            widths[2], widths[3], attn_per_block[2], drop_rates[2]
        )
        self.block_8 = Block_decoder(
            widths[3], widths[4], attn_per_block[3], drop_rates[3]
        )
        self.block_9 = Block_decoder(
            widths[4], widths[5], attn_per_block[4], drop_rates[4]
        )

    def forward(self, skip1, skip2, skip3, skip4):
        """Run the bottleneck and the four decoders, printing each output size.

        Args:
            skip1, skip2, skip3, skip4: skip tensors from the upstream stage.
                ``skip4`` is used twice: as the bottleneck input and as the
                skip input of ``block_6``. ``skip3``, ``skip2`` and ``skip1``
                feed ``block_7``, ``block_8`` and ``block_9`` respectively.

        Returns:
            dict with the single key ``"skip9"`` holding the output of
            ``block_9`` as a NumPy array (moved to CPU and detached from the
            autograd graph).
        """
        x = self.block_5(skip4)
        print(f"Block 5 out -> {list(x.size())}")

        x = self.block_6(x, skip4)
        print(f"Block 6 out -> {list(x.size())}")

        x = self.block_7(x, skip3)
        print(f"Block 7 out -> {list(x.size())}")

        x = self.block_8(x, skip2)
        print(f"Block 8 out -> {list(x.size())}")

        x = self.block_9(x, skip1)
        print(f"Block 9 out -> {list(x.size())}")

        # Only the last decoder output is exported; the outputs of block_7 and
        # block_8 are intentionally not part of the returned dict.
        skip9 = x.cpu().detach().numpy()
        return {"skip9": skip9}



        

In [3]:
# =======================================================================
#                                BODY
# =======================================================================

# Build the body network, initialise its weights, and move it to the target
# device *before* the optimizer is created: the torch.optim documentation
# advises constructing optimizers only after the model is on its final device.
model_body = FCT_Body()
model_body.apply(init_weights)
model_body.to(device)

# AdamW with learning rate and weight decay taken from the shared config dict.
optimizer_body = torch.optim.AdamW(
    model_body.parameters(),
    lr=dict_args['lr'],
    weight_decay=dict_args['decay'],
)

# Multiply the learning rate by `lr_factor` once the monitored (minimised)
# metric has not improved by more than `threshold` for `patience` scheduler
# steps, never going below `min_lr`.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm
# against the installed version before removing it.
scheduler_body = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_body,
    mode='min',
    factor=dict_args['lr_factor'],
    verbose=True,
    threshold=1e-6,
    patience=10,
    min_lr=dict_args['min_lr'],
)

# Last expression of the cell: display the module structure.
model_body

FCT_Body(
  (block_5): Block_encoder_bottleneck(
    (layernorm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (trans): Transformer(
      (attention_output): Attention(
        (conv_q): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same, groups=128)
        (layernorm_q): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (conv_k): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
        (layernorm_k): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (conv_v): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
        (layernorm_v): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=Tru

In [5]:
# Forward propagation in body model
#
# For every group stored in the head's forward-pass file, load the four skip
# arrays, run them through the body model, and store the returned arrays in a
# group of the same name in the body's forward-pass file.

model_body.train()

# Both files are managed by `with` blocks so their handles are always closed,
# even when an iteration fails. A handle left open in the kernel is what makes
# a later open in 'w' mode fail with
# "unable to truncate a file which is already open".
# The input is opened first so that a missing/unreadable head file does not
# truncate an existing output file. Errors are allowed to propagate instead of
# being printed and swallowed, so later cells never run on a partial output.
with h5py.File('params_and_grads/head_forward_pass.hdf5', 'r') as head_fwd:
    with h5py.File('params_and_grads/body_forward_pass.hdf5', 'w') as body_fwd:
        for key, grp in tqdm(head_fwd.items(), total=len(head_fwd)):
            # `[:]` reads the full dataset into memory as a NumPy array.
            skip_1 = torch.from_numpy(grp['skip1'][:]).to(device)
            skip_2 = torch.from_numpy(grp['skip2'][:]).to(device)
            skip_3 = torch.from_numpy(grp['skip3'][:]).to(device)
            skip_4 = torch.from_numpy(grp['skip4'][:]).to(device)

            # The model returns a dict of NumPy arrays keyed by layer name.
            bd_layer_data = model_body(skip_1, skip_2, skip_3, skip_4)

            bgrp = body_fwd.create_group(key)
            for k, v in bd_layer_data.items():
                bgrp.create_dataset(k, data=v)


OSError: Unable to synchronously create file (unable to truncate a file which is already open)