In [1]:
import os 
import numpy as np 
import random 
import time 
import logging 
from tqdm import tqdm


import torch 
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F 
import torch.nn as nn 
from torch.utils.data import DataLoader
from torch.nn.parameter import Parameter


from utils.params import params 
from utils.losses import DiceLoss, softmax_mse_loss, softmax_kl_loss, l_correlation_cos_mean
from networks.utils import BCP_net, get_current_consistency_weight, update_ema_variable
from dataset.basedataset import BaseDataset
from dataset.utils import TwoStreamBatchSampler, RandomGenerator, patients_to_slices

from tensorboardX import SummaryWriter

#### 1. Linear Coefficients (Gb)

In [2]:
class Linear_vector(nn.Module):
    """
    Learnable square coefficient matrix (Gb) attached to one UNet conv weight.

    Holds an ``(n_dim, n_dim)`` parameter whose entries are drawn from
    N(0, init_ratio**2). The forward pass left-multiplies the input by it.

    Args:
        n_dim: Side length of the coefficient matrix. In this notebook it is
            the output-channel count (``shape[0]``) of the paired conv weight.
        init_ratio: Standard deviation of the normal initialisation. Defaults
            to ``1e-3``, the value that was previously hard-coded.
    """

    def __init__(self, n_dim, init_ratio=1e-3):
        super().__init__()
        self.n_dim = n_dim
        self.init_ratio = init_ratio
        # Uninitialised storage is fine: initialize() overwrites every entry
        # immediately below.
        self.params = Parameter(torch.empty(self.n_dim, self.n_dim))
        self.initialize()

    def initialize(self):
        """(Re)draw every entry of ``self.params`` in place from N(0, init_ratio**2)."""
        # A single vectorised draw replaces the per-row Python loop.
        # nn.init.normal_ runs under no_grad, so autograd does not record it.
        nn.init.normal_(self.params, mean=0.0, std=self.init_ratio)

    def forward(self, x):
        """Return ``self.params @ x``.

        ``x`` must be a 2-D tensor whose first dimension equals ``n_dim``
        (a ``torch.mm`` requirement). The result has shape
        ``(n_dim, x.shape[1])``.
        """
        return torch.mm(self.params, x)

In [11]:
# Where it is used: one pair of learnable coefficient matrices (Gb) per 4-D
# convolution weight of the network.
model = BCP_net(in_chns=1, class_num=4)

# Learning rate shared by both coefficient optimizers.
LINEAR_LR = 2e-2
# Fall back to CPU so the cell still runs on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output-channel count (shape[0]) of every 4-D parameter whose name contains
# both 'conv' and 'weight', in model.named_parameters() order.
conv_out_dims = [
    parameter.shape[0]
    for name, parameter in model.named_parameters()
    if 'conv' in name and 'weight' in name and parameter.dim() == 4
]

# Two independent coefficient sets, index-aligned with conv_out_dims.
# Both are filled in the same loop so the random-init order matches the
# original (set 1, then set 2, per conv layer).
linear_coffies1 = nn.ModuleList()
linear_coffies2 = nn.ModuleList()
for out_dim in conv_out_dims:
    linear_coffies1.append(Linear_vector(out_dim))
    linear_coffies2.append(Linear_vector(out_dim))
linear_coffies1 = linear_coffies1.to(device)
linear_coffies2 = linear_coffies2.to(device)

# One Adam optimizer per coefficient set.
linear_optimizer1 = torch.optim.Adam(linear_coffies1.parameters(), lr=LINEAR_LR)
linear_optimizer2 = torch.optim.Adam(linear_coffies2.parameters(), lr=LINEAR_LR)

#### 2. Loss function 

##### 2.1 Supervised loss = Cross-Entropy (CE) + Dice

In [None]:
# Fresh network instance for the loss section. It uses the same configuration
# (1 input channel, 4 classes) as the one that sized the coefficient matrices.
model = BCP_net(class_num=4, in_chns=1)
