In [None]:
import os
import sys
import random
import time
from pathlib import Path

import h5py
import numpy as np
from tqdm import tqdm


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset

from model import DS_out, device, dict_args, init_weights

In [None]:
from torch.nn import BCELoss
from monai.losses.dice import GeneralizedDiceLoss

# Training criteria for the tail model:
#   * loss_fn     - binary cross-entropy; torch's BCELoss expects predictions
#                   that are already probabilities in [0, 1].
#   * dic_loss_fn - MONAI generalized Dice loss with one-hot encoded targets
#                   and a softmax over the prediction channels.
# NOTE(review): the tail head is built as DS_out(..., 1), which presumably
# yields a single channel; confirm softmax/one-hot are meaningful there.
loss_fn, dic_loss_fn = (
    BCELoss(),
    GeneralizedDiceLoss(to_onehot_y=True, softmax=True),
)

In [None]:


class FCT_Tail(nn.Module):
    """Tail part of the model: a single ``DS_out`` output head (``ds10``).

    The head is built as ``DS_out(filters[2], 1)`` with ``filters[2] == 8``;
    presumably that means 8 input channels and 1 output channel (TODO
    confirm against ``DS_out``). Additional deep-supervision heads
    (ds7-ds9 on ``filters[0..2]``) are currently disabled.

    Args:
        verbose: When True, ``forward`` prints the output shape of the head
            (debug aid). Defaults to False so training loops are not
            flooded with one line per batch.
    """

    def __init__(self, verbose: bool = False) -> None:
        super().__init__()

        # Channel widths of the decoder stages; only filters[2] feeds the
        # active head.
        filters = [32, 16, 8]
        self.ds10 = DS_out(filters[2], 1)
        self.verbose = verbose

    def forward(self, skip9: torch.Tensor) -> torch.Tensor:
        """Run the ``ds10`` head on ``skip9`` and return its output.

        Args:
            skip9: Activation coming from the body model (the ``skip9``
                dataset in the notebook).

        Returns:
            Output tensor of ``self.ds10``.
        """
        out10 = self.ds10(skip9)
        if self.verbose:
            print(f"DS 10 out -> {list(out10.size())}")
        return out10

        

In [None]:
# =======================================================================
#                                TAIL
# =======================================================================

model_tail = FCT_Tail()
model_tail.apply(init_weights)
# Move the model to its target device *before* constructing the optimizer,
# as the torch.optim documentation recommends, so the optimizer is built
# around the parameters that will actually be trained.
model_tail.to(device)

optimizer_tail = torch.optim.AdamW(
    model_tail.parameters(),
    lr=dict_args['lr'],
    weight_decay=dict_args['decay'],
)

# Multiply the LR by `lr_factor` once the monitored metric has not improved
# (by more than `threshold`) for `patience` consecutive step(metric) calls,
# never going below `min_lr`.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases, and this
# scheduler is only effective if a later cell calls scheduler_tail.step(metric).
scheduler_tail = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_tail,
    mode='min',
    factor=dict_args['lr_factor'],
    verbose=True,
    threshold=1e-6,
    patience=10,
    min_lr=dict_args['min_lr'],
)

print("Initialized ....")

In [None]:
# Forward propagation in Tail model
#
# For every sample stored by the body model: run the tail forward pass,
# compute the BCE loss against the matching label and take one AdamW step.

model_tail.train()
train_loss_list = []
grads_dict = {}
abs_grads_dict = {}

# Context managers guarantee both HDF5 files are closed, including when
# opening the second file fails or the loop raises.
with h5py.File('params_and_grads/body_forward_pass.hdf5', 'r') as body_fwd, \
        h5py.File('params_and_grads/train_values.hdf5', 'r') as train_label:
    try:
        # NOTE(review): zip pairs groups purely by iteration order and stops
        # at the shorter file; a key mismatch is reported but not fatal.
        for (key, grp), (lkey, lgrp) in zip(body_fwd.items(), train_label.items()):

            if str(key) != str(lkey):
                print(f"Not the same key tail:: {key} and label:: {lkey}, data could be different ")

            # Create the activation directly on `device` so it stays a leaf
            # tensor; calling `.to(device)` after `requires_grad=True` returns
            # a non-leaf copy whose `.grad` is never populated.
            # NOTE(review): dtype comes from the HDF5 data as-is and must match
            # the model parameters (and the label dtype for BCELoss) — confirm.
            skip_9 = torch.tensor(grp['skip9'][:], device=device, requires_grad=True)
            y_mask = torch.from_numpy(lgrp['tlabel'][:]).to(device)

            # Clear the previous sample's gradients; without this every step()
            # would apply the running sum of all earlier gradients.
            optimizer_tail.zero_grad()

            tl_output_data = model_tail(skip_9)

            loss = loss_fn(tl_output_data, y_mask)
            # Store a detached copy so the list does not retain autograd history.
            train_loss_list.append(loss.detach())
            loss.backward()
            optimizer_tail.step()

    except Exception as ex:
        import traceback
        print("+=" * 25)
        print("Error encountered as :", ex)
        print("+=" * 25)
        traceback.print_exc()


In [None]:
# Collect the gradients left on the tail's ``ds10`` parameters by the last
# backward pass of the training loop.
grads_dict = {}
mean_grads_dict = {}

for name, params in model_tail.named_parameters():
    # Only ds10 parameters get an entry; skipping the rest up front avoids a
    # KeyError when a non-ds10 parameter has a gradient.
    if "ds10" not in name:
        continue
    grads_dict.setdefault(name, [])
    mean_grads_dict.setdefault(name, [])
    if params.grad is not None:
        # Snapshot the gradient so later zero_grad()/step() calls cannot
        # change what was recorded here.
        grads_dict[name].append(params.grad.detach().clone())
        # NOTE(review): despite its name, this records the gradient *shape*,
        # not a mean — confirm what is intended.
        mean_grads_dict[name].append(params.grad.shape)


print("grads_dict : \n", grads_dict)
print("=+" * 15)
print("mean_grads_dict : \n", mean_grads_dict)

In [None]:
# Persist the gradients of the ``ds10.conv1.0`` parameters to HDF5, one
# dataset per parameter name.
try:
    with h5py.File('params_and_grads/tail_back_prop.hdf5', 'w') as tail_grad:
        for name, params in model_tail.named_parameters():
            if not params.requires_grad or "ds10.conv1.0" not in name:
                continue
            if params.grad is None:
                # create_dataset(data=None) without a shape would raise.
                print(f"No gradient for {name}; skipping")
                continue
            # h5py needs a host array; a CUDA tensor cannot be converted implicitly.
            tail_grad.create_dataset(name, data=params.grad.detach().cpu().numpy())
except Exception:
    # The `with` block already closed the file; just report the error.
    import traceback
    traceback.print_exc()


In [None]:
for name, params in model_tail.named_parameters():
    if params.requires_grad:
        if "ds10.conv1.0" in name: