In [None]:
%cd /Users/Pracioppo/Desktop/VPTR/VPTR/

In [None]:
import torch
import torch.nn as nn

import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

from pathlib import Path
import random
from datetime import datetime

import scipy.io as sio

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm # Loading bar
from PIL import Image

from model import VPTREnc, VPTRDec, VPTRDisc, init_weights
from model import GDL, MSELoss, L1Loss, GANLoss
from utils import get_dataloader
from utils import VidCenterCrop, VidPad, VidResize, VidNormalize, VidReNormalize, VidCrop, VidRandomHorizontalFlip, VidRandomVerticalFlip, VidToTensor
from utils import visualize_batch_clips, save_ckpt, load_ckpt, set_seed, AverageMeters, init_loss_dict, write_summary, resume_training
from utils import set_seed
from utils import f_rescale_dataset, f_Residuals, f_reshape_training_data, f_rotate, f_rotate_and_zoom, f_random_crop, f_rotate_and_zoom_all, f_crop_all, f_flip_all, f_augment_dataset2

from utils import KTHDataset, BAIRDataset, MovingMNISTDataset

set_seed(2024)
torch.manual_seed(2024)

import argparse

import cv2
from tabulate import tabulate

import datetime
import os

In [None]:
# Output locations for model checkpoints and TensorBoard event files.
# NOTE(review): both are absolute paths on a mapped drive letter, so this cell
# only works on a machine that exposes the same Y: share - consider deriving
# them from a single configurable base directory.
ckpt_save_dir = Path('Y:/FAZ/Healthy/all_Healthy/Models/FAZ_ckpt')
tensorboard_save_dir = Path('Y:/FAZ/Healthy/all_Healthy/Models/FAZ_tensorboard')

# Epoch index initialised to zero.
start_epoch = 0

# No checkpoint is selected for resuming in this cell.
resume_ckpt = None

# Side effect: constructing the writer opens the TensorBoard log directory
# given above (passed as an absolute POSIX-style string).
summary_writer = SummaryWriter(tensorboard_save_dir.absolute().as_posix())

In [None]:
# Use an argparse namespace as a plain mutable settings container. Parsing an
# explicit empty list (args=[]) means the kernel's own command line is not read;
# every setting is attached to `args` as an attribute in later cells.
parser = argparse.ArgumentParser('SINDy NN')
args = parser.parse_args(args=[])

In [None]:
# Experiment settings, all stored on the shared `args` namespace.
args.num_past_frames = 1
args.num_future_frames = 1
# Encoder feature-map height, width and channel count.
args.encH, args.encW, args.encC = 16, 16, 128 # Increased dimensions for 512 x 512 images
args.img_channels = 1 # 3 channels for BAIR dataset
args.epochs = 10

args.N = 10

# Learning rate (named for the autoencoder, per the `AE_` prefix).
args.AE_lr = 2e-4

# The batch size is tied to args.N; several modules below reshape with
# args.batch_size, so the two must stay in sync.
args.batch_size = args.N
# NOTE(review): these two are module-level names, not attributes of `args`.
test_past_frames = 1
test_future_frames = 1

In [None]:
import torch
print(torch.__version__)

In [None]:
torch.cuda.is_available()

In [None]:
# Select the compute device used by every module that calls `.to(args.device)`.
# NOTE(review): 'cuda:0' is hard-coded; on a machine without CUDA the modules
# built below will fail - switch to the commented CPU line in that case.
args.device = torch.device('cuda:0')
# args.device = torch.device('cpu')
print(f"Using {args.device} device")

In [None]:
def num_params(model):
    """Count the scalar entries of every trainable parameter of `model`.

    Parameters whose `requires_grad` flag is False (frozen weights) are
    ignored. Returns 0 for a module without trainable parameters.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
def get_train_test_splits(imgs,split=4):
    """Randomly partition the sample indices of `imgs` into `split` folds.

    The indices 0..len-1 are shuffled with NumPy's global RNG and cut into
    `split` consecutive chunks. Fold sizes differ by at most one: the first
    `num_data % split` folds receive one extra index, so every index lands in
    exactly one fold.

    Args:
        imgs: object exposing `.size()` whose first entry is the number of
            samples (e.g. a torch tensor).
        split: number of folds to produce.

    Returns:
        A list of `split` lists of Python ints, each sorted ascending. Folds
        may be empty when there are fewer samples than folds.
    """
    num_data = int(imgs.size()[0])
    split = int(split)

    shuffled = np.arange(num_data)
    np.random.shuffle(shuffled)

    # The remainder must be taken with respect to the number of folds, not the
    # fold size; otherwise trailing samples are dropped for small datasets
    # (e.g. 10 samples / 4 folds) and 0-sized folds trigger a modulo by zero.
    base_size, extra = divmod(num_data, split)

    splits = []
    start = 0
    for fold in range(split):
        size = base_size + (1 if fold < extra else 0)
        splits.append(np.sort(shuffled[start:start + size]).tolist())
        start += size

    return splits

def flatten_list(xss):
    """Concatenate the items of every inner iterable of `xss` into one list (one level deep)."""
    flat = []
    for xs in xss:
        flat.extend(xs)
    return flat

def train_test_split(imgs,masks,splits,test_r):
    """Split images and masks into train/test parts using pre-computed folds.

    Fold number `test_r` of `splits` becomes the test set (in the order stored
    in the fold); all remaining folds are merged and sorted ascending to form
    the training set. `imgs` and `masks` are indexed with the same index
    arrays, so they stay aligned.

    Returns:
        (train_imgs, test_imgs, train_masks, test_masks)
    """
    held_out = np.array(splits[test_r]).flatten().astype('int')

    other_folds = splits[0:test_r] + splits[test_r+1:]
    kept = np.sort(np.array(flatten_list(other_folds)).astype('int'))

    train_imgs, test_imgs = imgs[kept], imgs[held_out]
    train_masks, test_masks = masks[kept], masks[held_out]

    return train_imgs, test_imgs, train_masks, test_masks

In [None]:
class GaussianNoise:
    """Callable transform that adds Gaussian noise to a tensor.

    The noise has the given `mean` and standard deviation `std` and is drawn
    with torch's RNG in the shape/dtype of the input. When `clip` is truthy the
    noisy result is clamped to the range [0, 1].
    """

    def __init__(self, mean, std, clip):
        self.mean, self.std, self.clip = mean, std, clip

    def __call__(self, img):
        perturbed = img + (torch.randn_like(img) * self.std + self.mean)
        return torch.clamp(perturbed, 0, 1) if self.clip else perturbed

In [None]:
def rearrange_by_index(tensor):
    """Reorder along dim 0 so entries at even positions come first, followed by those at odd positions."""
    return torch.cat([tensor[offset::2] for offset in (0, 1)])

In [None]:
def f_augment_all_single(imgs, masks, p_rot = 0.8, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, p_noise = 0.5, trans=0, N = 10, mode=0):
    """Create `N` (optionally augmented) copies of every image/mask pair.

    For each of the `N` passes and each of the `len(imgs)` samples a chain of
    random transforms is applied (skipped entirely when `mode == 1`) and the
    result is written into pre-allocated output tensors. Copy `j` of sample
    `i` is stored at row `i*N + j`, so the copies of one sample are contiguous.
    Geometric transforms are applied identically to image and mask; colour
    jitter and noise touch the image only.

    Args:
        imgs, masks: equally long indexable collections; each item is written
            into a 256x256 slot (assumes 256x256 2-D tensors - TODO confirm).
        p_rot: probability of rotating by an angle drawn uniformly from
            [0, 360) via `f_rotate`.
        p_flip: probability used independently for the horizontal and for the
            vertical flip.
        p_jitter: probability of brightness/contrast jitter on the image; only
            evaluated inside the shear branch.
        p_shear: probability of applying a shear drawn uniformly from
            [-15, 15] (same value for both shear components).
        p_noise: probability of adding clipped Gaussian noise with sigma drawn
            uniformly from [0, 0.1]; only evaluated inside the shear branch.
        trans: maximum absolute translation; only drawn when shear is applied
            and the rotation dice roll exceeded `p_rot`.
        N: number of copies produced per sample.
        mode: `1` disables all augmentation and reorders the outputs with
            `rearrange_by_index`; any other value augments.

    Returns:
        (Imgs, Masks): two tensors of shape (N*len(imgs), 256, 256).

    Randomness comes from NumPy's global RNG (dice rolls and magnitudes) and
    torch's RNG (noise); the order of the draws below determines the result
    for a given seed.
    """
    
    L = len(imgs)
#     print(L)

    # Flips are built with probability 1: whether to flip is decided by the
    # explicit dice rolls below so image and mask always get the same decision.
    VF = torchvision.transforms.RandomVerticalFlip(1)
    HF = torchvision.transforms.RandomHorizontalFlip(1)
    CJ = torchvision.transforms.ColorJitter(brightness = 0.8, contrast = 0.8)

    # NOTE(review): the 256x256 output size is hard-coded; inputs of another
    # size will fail on assignment below.
    Imgs = torch.zeros(N*L,256,256)
    Masks = torch.zeros(N*L,256,256)

    for j in np.arange(N):

        for i in np.arange(L):

            img = imgs[i]
            mask = masks[i]

#             if j > 0 and mode != 1:
            if mode != 1:
            
                h_flip_dice = np.random.uniform(0,1)
                if h_flip_dice < p_flip:
                    img = HF.forward(img)
                    mask = HF.forward(mask)

                rot_dice = np.random.uniform(0,1)
                if rot_dice < p_rot:
                    angle = np.random.uniform(0,360)
#                     angle = np.random.uniform(0,1)**2*45
#                     img = f_rotate_and_zoom(img, angle)
#                     mask = f_rotate_and_zoom(mask, angle)
                    img = f_rotate(img, angle)
                    mask = f_rotate(mask, angle)

                v_flip_dice = np.random.uniform(0,1)
                if v_flip_dice < p_flip:
                    img = VF.forward(img)
                    mask = VF.forward(mask)

                shear_dice = np.random.uniform(0,1)
                if shear_dice < p_shear:
                    shr = np.random.uniform(-15,15)

                    # Translation is only drawn when this sample was NOT rotated.
                    if rot_dice > p_rot:
                        transx = np.random.uniform(-trans,trans)
                        transy = np.random.uniform(-trans,trans)
                    else:
                        transx = 0
                        transy = 0

                    # A leading dimension is added for the affine call and is
                    # not removed here; img/mask keep it from this point on.
                    img = torchvision.transforms.functional.affine(img.unsqueeze(0),angle=0,scale=1,translate=(transx,transy),shear=(shr,shr))
                    mask = torchvision.transforms.functional.affine(mask.unsqueeze(0),angle=0,scale=1,translate=(transx,transy),shear=(shr,shr))

                    # NOTE(review): jitter and noise are nested inside the shear
                    # branch, so they never run without shear - confirm that this
                    # indentation is intended.
                    jitter_dice = np.random.uniform(0,1)
                    if jitter_dice < p_jitter:
                        img = CJ.forward(img.unsqueeze(0)).squeeze()
                        
                    v_flip_noise = np.random.uniform(0,1)
                    if v_flip_noise < p_noise:
                        sigma = np.random.uniform(low=0.0, high=0.1)
                        GN = GaussianNoise(0, sigma, True)
                        img = GN(img)
                        # NOTE(review): this divides by max(img), not by
                        # (max - min), so it is not a true min-max rescale and
                        # yields NaN/inf when max(img) == 0 - confirm intent.
                        img = (img - torch.min(img))/torch.max(img)

            Imgs[i*N + j,:,:] = img
            Masks[i*N + j,:,:] = mask
            
    if mode == 1:
        Imgs = rearrange_by_index(Imgs)
        Masks = rearrange_by_index(Masks)

    return Imgs, Masks

### Domain Adapter

In [None]:
class Sum_Domain_Adapter(nn.Module):
    """Additive domain adapter: shifts every spatial feature vector by a task embedding.

    The one-hot task code is mapped through a small MLP
    (num_tasks -> dim/2 -> dim, followed by a dim -> 2*dim -> dim residual
    branch) and the resulting vector is added to the channel vector at every
    spatial position of `feats`.
    """

    def __init__(self, num_tasks, dim, args):
        super().__init__()

        half_dim, full_dim, wide_dim = int(dim/2), dim, 2*dim

        # Layers are created in the original order so seeded initialisation is unchanged.
        self.l1 = nn.Linear(num_tasks, half_dim).to(args.device)
        self.l2 = nn.Linear(half_dim, full_dim).to(args.device)
        self.l3 = nn.Linear(full_dim, wide_dim).to(args.device)
        self.l4 = nn.Linear(wide_dim, full_dim).to(args.device)

        self.ReLU = nn.ReLU()

    def forward(self, feats, one_hot_enc):
        """Return `feats` (same shape) with the task embedding added along the channel axis.

        `feats` is flattened from dim 2 onwards, so it is treated as
        (batch, channels, *spatial). The embedding must broadcast against the
        (batch, positions, channels) view - presumably `one_hot_enc` is
        (batch, 1, num_tasks); verify against the caller.
        """
        hidden = self.ReLU(self.l2(self.ReLU(self.l1(one_hot_enc))))
        # Residual branch around the wide bottleneck.
        task_shift = self.l4(self.ReLU(self.l3(hidden))) + hidden

        # (batch, channels, positions) -> (batch, positions, channels)
        tokens = feats.flatten(start_dim=2).transpose(1, 2)
        shifted = (task_shift + tokens).transpose(1, 2)

        return shifted.reshape(feats.size())

In [None]:
class Head_Domain_Adapter(nn.Module):
    """Domain adapter that gates features per attention head before self-attention.

    A task one-hot code is mapped by a four-layer MLP to `dim` values, which
    are reshaped to (batch, nhead, head_dim) and soft-maxed across the head
    axis. These gates scale the (position-encoded) feature tokens, which are
    then passed through a hand-unrolled multi-head self-attention that reuses
    the projection weights of `self.mha1`.

    Args:
        num_tasks: length of the one-hot task code.
        dim: channel dimension of the features (must be divisible by nhead).
        nhead: number of attention heads.
        args: namespace providing `device`.
    """

    def __init__(self, num_tasks, dim, nhead, args):
        super(Head_Domain_Adapter, self).__init__()

        self.nhead = nhead

        self.l1 = nn.Linear(num_tasks, int(dim/4)).to(args.device)
        self.l2 = nn.Linear(int(dim/4), int(dim/2)).to(args.device)
        self.l3 = nn.Linear(int(dim/2), int(dim/2)).to(args.device)
        self.l4 = nn.Linear(int(dim/2), dim).to(args.device)

        # Only the projection weights/biases of this module are used (see forward).
        self.mha1 = torch.nn.MultiheadAttention(embed_dim=dim, num_heads=nhead, batch_first=True).to(args.device)

        self.ReLU = nn.ReLU()
        self.SM = nn.Softmax(dim=1)
        self.Sigmoid = nn.Sigmoid()

    def forward(self, feats, one_hot_enc):
        """Return the attended features, reshaped to `feats.size()`.

        `feats` is read as (batch, dim, H, W): dim 1 is the channel size and
        the last two dims are multiplied to get the token count.
        """
        # Take the batch size from the input instead of the global
        # args.batch_size, so a final partial batch (or any other batch size)
        # does not break the reshapes below.
        batch_size = feats.size()[0]
        dim = feats.size()[1]
        num_embed = feats.size()[-1]*feats.size()[-2]
        head_dim = int(dim/self.nhead)

        # Task code -> per-head gates, normalised across the head axis (dim=1).
        a1 = self.l1(one_hot_enc)
        a2 = self.ReLU(a1)
        a3 = self.l2(a2)
        a4 = self.ReLU(a3)
        a5 = self.l3(a4)
        a6 = self.ReLU(a5)
        a7 = self.l4(a6)
        a8 = a7.reshape(batch_size,self.nhead,head_dim)
        a9 = self.SM(a8)

        # Query, key, and value matrices
        q_w, k_w, v_w = torch.split(self.mha1.in_proj_weight, [dim, dim, dim])
        q_b, k_b, v_b = torch.split(self.mha1.in_proj_bias, [dim, dim, dim])

        q_w = q_w.T
        k_w = k_w.T
        v_w = v_w.T

        # (batch, dim, H, W) -> (batch, tokens, dim)
        x_a = feats.flatten(start_dim=2)
        x_b = torch.swapaxes(x_a,1,2)
        # NOTE(review): `args` here is the notebook-level namespace (not the
        # constructor argument) and f_pos_encoding_additive is defined
        # elsewhere; its output shape is not visible from this block.
        x_f = f_pos_encoding_additive(x_b,args).squeeze()

        x_r = x_f.reshape(batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)
        x_p = a9.unsqueeze(2)*x_r
        # NOTE(review): x_p is laid out (batch, head, token, head_dim) and is
        # reshaped straight to x_f's (batch, token, dim) layout without
        # transposing back, which mixes the token and head axes - confirm.
        x_pp = x_p.reshape(x_f.size())

        qp = x_pp @ q_w + q_b
        kp = x_pp @ k_w + k_b
        vp = x_pp @ v_w + v_b

        qpp = qp.reshape(batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)
        kpp = kp.reshape(batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)
        vpp = vp.reshape(batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)

        # NOTE(review): self.SM normalises over dim=1, i.e. across heads, not
        # across keys as in standard attention (Sparse_MHA uses dim=-1) -
        # left unchanged to preserve trained behaviour; confirm intent.
        ap = self.SM(qpp @ kpp.transpose(-1, -2) / np.sqrt(head_dim))
        pp = ap @ vpp

        pp_t = pp.transpose(1, 2).reshape(batch_size, num_embed, dim)

        op = self.mha1.out_proj(pp_t).transpose(1, 2)

        return op.reshape(feats.size())

In [None]:
# def f_construct_kernel(N,gamma=1):

#     L = np.sqrt(N)
#     adj = np.zeros((N,N))

#     for i in np.arange(N):
#         for j in np.arange(N):
# #             adj[i,j] = np.exp(-gamma*(np.sqrt(np.mod(np.abs(i-j),L)**2 + (np.floor(np.abs(i-j)/L))**2)))
#             adj[i,j] = (np.sqrt(np.mod(np.abs(i-j),L)**2 + (np.floor(np.abs(i-j)/L))**2))
    
#     return adj

class KernelAttn(nn.Module):
    """Learnable distance-decay kernel over an N-token sequence.

    Token indices are interpreted on a sqrt(N)-wide grid. For every token pair
    a horizontal and a vertical offset is derived from the index difference,
    weighted by one learnable non-negative scale per direction (left, right,
    up, down), turned into exp(-sqrt(0.001 * weighted squared offset)) and
    multiplied element-wise by sigmoid(M), where M is a learnable (N, N)
    matrix.

    Args:
        N: number of tokens.
        args: namespace providing `device`; all tensors are created there.
    """

    def __init__(self, N, args):
        super(KernelAttn, self).__init__()
        self.N = N

        device = args.device

        # Move the data first, THEN wrap it in nn.Parameter. Calling
        # `.to(device)` on an nn.Parameter returns a plain tensor whenever a
        # copy is made (e.g. CPU -> CUDA), which would leave these weights
        # unregistered: absent from parameters()/state_dict() and never
        # optimised. Values are still drawn on the CPU so seeded
        # initialisation is unchanged.
        # NOTE: on CUDA this adds keys to state_dict() that the previous
        # implementation did not save.
        self.weight_left = nn.Parameter(torch.tensor(0.01).to(device))    # Weight for left
        self.weight_right = nn.Parameter(torch.tensor(0.01).to(device))   # Weight for right
        self.weight_up = nn.Parameter(torch.tensor(0.01).to(device))      # Weight for up
        self.weight_down = nn.Parameter(torch.tensor(0.01).to(device))    # Weight for down

        self.M = nn.Parameter((torch.rand(N,N)*0.01).to(device))

        # Index grids follow the module across devices but are not saved in
        # checkpoints (persistent=False keeps state_dict keys unchanged).
        self.register_buffer('i', torch.arange(self.N).view(-1, 1).to(device), persistent=False)  # Shape (N, 1)
        self.register_buffer('j', torch.arange(self.N).view(1, -1).to(device), persistent=False)  # Shape (1, N)

        self.Sigmoid = nn.Sigmoid()
        self.ReLU = nn.ReLU()

    def forward(self):
        """Return the (N, N) kernel matrix for the current parameter values."""
        i = self.i
        j = self.j
        L = np.sqrt(self.N)  # grid side length

        # Calculate horizontal and vertical distances
        h_dist = (torch.sign(i - j)*(torch.abs(i - j) % L)).float() # Horizontal distances (Shape: (N, N))
        v_dist = (torch.sign(i - j)*torch.floor(torch.abs(i - j) / L)).float() # Vertical distances (Shape: (N, N))

        # Direction-weighted squared distance; the 0.001 floor keeps every
        # direction weight strictly positive. Built on the module's own device
        # rather than on a notebook-global `args.device`.
        adj = torch.zeros((self.N, self.N), dtype=torch.float32, device=self.M.device)
        adj = adj + (h_dist**2 * (h_dist > 0) * (0.001 + self.ReLU(self.weight_right)))  # Right direction
        adj = adj + (h_dist**2 * (h_dist < 0) * (0.001 + self.ReLU(self.weight_left)))   # Left direction
        adj = adj + (v_dist**2 * (v_dist > 0) * (0.001 + self.ReLU(self.weight_down)))   # Down direction
        adj = adj + (v_dist**2 * (v_dist < 0) * (0.001 + self.ReLU(self.weight_up)))     # Up direction

        # sqrt has an infinite derivative at 0 (e.g. on the diagonal), which
        # turns the gradients of the direction weights into NaN. Evaluate the
        # sqrt only on positive entries; forward values are identical.
        scaled = 0.001 * adj
        positive = scaled > 0
        safe = torch.where(positive, scaled, torch.ones_like(scaled))
        dist = torch.where(positive, torch.sqrt(safe), torch.zeros_like(scaled))

        return torch.exp(-dist)*self.Sigmoid(self.M)

# Visual check: build a kernel for 256 tokens and display its freshly
# initialised (untrained) matrix as an image.
KA1 = KernelAttn(256, args).to(args.device) 
adj = KA1().detach().cpu().numpy()
plt.imshow(adj)

In [None]:
class Sparse_MHA(nn.Module):
    """Multi-head attention whose weights are biased by learnable (token x token) terms.

    The scaled dot-product attention is unrolled by hand using the projection
    weights of `self.mha1`. After the softmax, the output of a `KernelAttn`
    module and a free learnable matrix `M` (both n_embed x n_embed) are added
    to the attention weights of every head before they are applied to the
    values.

    Args:
        dim: embedding dimension (must be divisible by nhead).
        nhead: number of heads.
        n_embed: number of tokens per sequence.
        args: namespace providing `device`.
    """

    def __init__(self, dim, nhead, n_embed, args):
        super(Sparse_MHA, self).__init__()

        self.nhead = nhead
        self.dim = dim
        self.n_embed = n_embed
        self.head_dim = int(self.dim/self.nhead)

        # Placed on args.device like every other sub-module in this notebook.
        self.mha1 = torch.nn.MultiheadAttention(embed_dim=dim, num_heads=nhead, batch_first=True).to(args.device)

        self.KA1 = KernelAttn(n_embed, args).to(args.device)
        # Move first, then wrap: `nn.Parameter(...).to(device)` yields a plain
        # tensor when a copy is made (CPU -> CUDA), so M would not be
        # registered, trained or saved. Values are still drawn on the CPU so
        # seeded initialisation is unchanged.
        self.M = nn.Parameter(torch.rand(n_embed,n_embed).to(args.device))

        self.Softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        """Attend `q` over `k`/`v`; all three are (batch, n_embed, dim).

        Returns a (batch, n_embed, dim) tensor. The biased attention weights
        of the last call are kept on `self.ap`.
        """
        # Batch size comes from the input rather than the global
        # args.batch_size, so partial batches work.
        batch_size = q.size(0)

        # Query, key, and value matrices
        q_w, k_w, v_w = torch.split(self.mha1.in_proj_weight, [self.dim, self.dim, self.dim])
        q_b, k_b, v_b = torch.split(self.mha1.in_proj_bias,   [self.dim, self.dim, self.dim])

        qp = q @ q_w.T + q_b
        kp = k @ k_w.T + k_b
        vp = v @ v_w.T + v_b

        # (batch, tokens, dim) -> (batch, heads, tokens, head_dim)
        qpp = qp.reshape(batch_size, self.n_embed, self.nhead, self.head_dim).transpose(1, 2)
        kpp = kp.reshape(batch_size, self.n_embed, self.nhead, self.head_dim).transpose(1, 2)
        vpp = vp.reshape(batch_size, self.n_embed, self.nhead, self.head_dim).transpose(1, 2)

        ap1 = self.Softmax(qpp @ kpp.transpose(-1, -2) / np.sqrt(self.head_dim))
        # The additive terms are applied after the softmax, so rows of self.ap
        # no longer sum to one.
        self.ap = ap1 + self.KA1() + self.M

        pp = self.ap @ vpp

        pp_t = pp.transpose(1, 2).reshape(batch_size, self.n_embed, self.dim)

        op = self.mha1.out_proj(pp_t)
        return op

In [None]:
class Sparse_AE(nn.Module):
    """Two-layer bottleneck MLP: dim -> k -> dim with a ReLU in between."""

    def __init__(self, dim, k, args):
        super().__init__()

        device = args.device
        self.l1 = nn.Linear(dim, k).to(device)
        self.l2 = nn.Linear(k, dim).to(device)

        self.ReLU = nn.ReLU()

    def forward(self, feats):
        """Encode `feats` (last dim = dim) to k units and decode back; no output activation."""
        return self.l2(self.ReLU(self.l1(feats)))

#####
    
class Sparse_AE_bn(nn.Module):
    """Four-layer MLP: dim -> dim/k1 -> k2*dim -> dim/k1 -> dim, ReLU between layers.

    Layer widths are truncated to integers with int().
    """

    def __init__(self, dim, k1, k2, args):
        super().__init__()

        narrow, wide = int(dim/k1), int(k2*dim)
        device = args.device

        self.l1 = nn.Linear(dim, narrow).to(device)
        self.l2 = nn.Linear(narrow, wide).to(device)
        self.l3 = nn.Linear(wide, narrow).to(device)
        self.l4 = nn.Linear(narrow, dim).to(device)

        self.ReLU = nn.ReLU()

    def forward(self, feats):
        """Apply the four linear layers in order; the last one has no activation."""
        hidden = feats
        for layer in (self.l1, self.l2, self.l3):
            hidden = self.ReLU(layer(hidden))
        return self.l4(hidden)

In [None]:
class SparseTransformerEncoder(nn.Module):
    """One encoder block: Sparse_MHA -> add & LayerNorm -> Sparse_AE_bn -> LayerNorm."""

    def __init__(self, dim_embed, nhead, n_embed, k1, k2, args):
        super().__init__()
        self.smha = Sparse_MHA(dim_embed, nhead, n_embed, args)
        self.layer_norm1 = nn.LayerNorm(dim_embed)
        self.SAE = Sparse_AE_bn(dim_embed, k1, k2, args)
        self.layer_norm2 = nn.LayerNorm(dim_embed)

    def forward(self, q, k, v):
        """Run the block; the residual connection uses `q`."""
        # Attention with a residual connection from the query, then normalise.
        attended = self.layer_norm1(q + self.smha(q, k, v))
        # Feed-forward, then normalise.
        # NOTE(review): unlike a standard encoder there is no residual add
        # around the feed-forward network here - confirm this is intended.
        return self.layer_norm2(self.SAE(attended))

In [None]:
# ## See https://stackoverflow.com/questions/75189397/getting-the-query-key-and-value-matrices-from-pytorch-with-self-attn-in-proj-we
# class Head_Domain_Adapter_v2(nn.Module):
#     def __init__(self, num_tasks, dim, nhead, args):
#         super(Head_Domain_Adapter_v2, self).__init__()
        
#         self.nhead = nhead

#         self.l1 = nn.Linear(num_tasks, int(dim/4)).to(args.device)
#         self.l2 = nn.Linear(int(dim/4), int(dim/2)).to(args.device)
#         self.l3 = nn.Linear(int(dim/2), dim).to(args.device)
#         self.l4 = nn.Linear(dim, dim*3).to(args.device)

#         self.mha1 = torch.nn.MultiheadAttention(embed_dim=dim, num_heads=nhead, batch_first=True).to(args.device)

#         self.ReLU = nn.ReLU()
#         self.SM = nn.Softmax(dim=1)
#         self.Sigmoid = nn.Sigmoid()
        
#     def forward(self, feats, one_hot_enc):
        
#         dim = feats.size()[1]
#         num_embed = feats.size()[-1]*feats.size()[-2]
#         head_dim = int(dim/self.nhead)

#         a1 = self.l1(one_hot_enc)
#         a2 = self.ReLU(a1)
#         a3 = self.l2(a2)
#         a4 = self.ReLU(a3)
#         a5 = self.l3(a4)
#         a6 = self.ReLU(a5)
#         a7 = self.l4(a6)
# #         a8 = a7.reshape(args.batch_size,self.nhead,head_dim,3)
#         a8 = self.Sigmoid(a7).reshape(args.batch_size,self.nhead,head_dim,3)
#         a9 = self.SM(a8)
#         aq = a9[:,:,:,0]
#         ak = a9[:,:,:,1]
#         av = a9[:,:,:,2]

#         # Query, key, and value matrices
#         q_w, k_w, v_w = torch.split(self.mha1.in_proj_weight, [dim, dim, dim])
#         q_b, k_b, v_b = torch.split(self.mha1.in_proj_bias, [dim, dim, dim])

#         q_w = q_w.T
#         k_w = k_w.T
#         v_w = v_w.T

#         x_a = feats.flatten(start_dim=2)
#         x_b = torch.swapaxes(x_a,1,2)
#         x_f = f_pos_encoding_additive(x_b,args).squeeze()

#         qp = x_f @ q_w + q_b
#         kp = x_f @ k_w + k_b
#         vp = x_f @ v_w + v_b

#         qpp = qp.reshape(args.batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)
#         kpp = kp.reshape(args.batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)
#         vpp = vp.reshape(args.batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)
        
#         qpp_a = aq.unsqueeze(2)*qpp
#         kpp_a = ak.unsqueeze(2)*kpp
#         vpp_a = av.unsqueeze(2)*vpp

#         ap = self.SM(qpp_a @ kpp_a.transpose(-1, -2) / np.sqrt(head_dim))
#         pp = ap @ vpp_a

#         pp_t = pp.transpose(1, 2).reshape(args.batch_size, num_embed, dim)
#         op = self.mha1.out_proj(pp_t).transpose(1, 2)

#         return op.reshape(feats.size())

In [None]:
# class Dot_Product_Domain_Adapter(nn.Module):
#     def __init__(self, num_tasks, dim, nhead, args):
#         super(Dot_Product_Domain_Adapter, self).__init__()
        
#         self.l1 = nn.Linear(num_tasks, int(dim/2)).to(args.device)
#         self.l2 = nn.Linear(int(dim/2), int(dim/2)).to(args.device)
#         self.l3 = nn.Linear(int(dim/2), dim).to(args.device)
        
#         self.mha1 = torch.nn.MultiheadAttention(embed_dim=dim, num_heads=nhead, batch_first=True).to(args.device)
        
#         self.ReLU = nn.ReLU()
        
#     def forward(self, feats, one_hot_enc):
        
#         a1 = self.l1(one_hot_enc)
#         a2 = self.ReLU(a1)
#         a3 = self.l2(a2)
#         a4 = self.ReLU(a3)
#         mask = self.l3(a4)

#         # Query, key, and value matrices
#         q_w, k_w, v_w = torch.split(self.mha1.in_proj_weight, [dim, dim, dim])
#         q_b, k_b, v_b = torch.split(self.mha1.in_proj_bias, [dim, dim, dim])

#         q_w = q_w.T
#         k_w = k_w.T
#         v_w = v_w.T

#         x_a = feats.flatten(start_dim=2)
#         x_b = torch.swapaxes(x_a,1,2)
#         x_f = f_pos_encoding_additive(x_b,args).squeeze()

#         qp = mask @ q_w + q_b
#         kp = x_f @ k_w + k_b
#         vp = x_f @ v_w + v_b

#         qpp = qp.reshape(args.batch_size, 1, self.nhead, head_dim).transpose(1, 2)
#         kpp = kp.reshape(args.batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)
#         vpp = vp.reshape(args.batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)
        
#         ap = self.SM(qpp @ kpp.transpose(-1, -2) / np.sqrt(head_dim))
#         pp = ap @ vpp
        
# #         prod = mask.unsqueeze(2)*pp
# #         prod_t = prod.transpose(1, 2).reshape(args.batch_size, num_embed, dim)
# #         op = self.mha1.out_proj(prod_t).transpose(1, 2)
        
#         return out.reshape(feats.size())

In [None]:
class Full_Domain_Adapter(nn.Module):
    """Container for the domain adapters; currently only the head adapter is applied.

    Both a Sum_Domain_Adapter (`DA_s`) and a Head_Domain_Adapter (`DA_h`) are
    constructed, so both contribute parameters, but `forward` routes the
    features through `DA_h` only.
    """

    def __init__(self, num_tasks, dim, nhead, args):
        super().__init__()

        # Construction order is kept so seeded initialisation is unchanged.
        self.DA_s = Sum_Domain_Adapter(num_tasks, dim, args)
        self.DA_h = Head_Domain_Adapter(num_tasks, dim, nhead, args)

    def forward(self, feats, one_hot_enc):
        """Return `DA_h(feats, one_hot_enc)`; `DA_s` is not used here."""
        return self.DA_h(feats, one_hot_enc)

In [None]:
class Domain_Adapter_lin_v1(nn.Module):
    """Multiplicative domain adapter with an unbounded (linear) mask.

    The squeezed one-hot task code goes through a three-layer MLP
    (num_tasks -> dim/2 -> dim/2 -> dim) and the result multiplies `feats`.
    """

    def __init__(self, num_tasks, dim, args):
        super().__init__()

        hidden = int(dim/2)
        self.l1 = nn.Linear(num_tasks, hidden).to(args.device)
        self.l2 = nn.Linear(hidden, hidden).to(args.device)
        self.l3 = nn.Linear(hidden, dim).to(args.device)

        self.ReLU = nn.ReLU()

    def forward(self, feats, one_hot_enc):
        """Return `mask * feats`; the mask must broadcast against `feats`.

        Presumably the last dim of `feats` is `dim` - verify against the caller.
        """
        task_vec = one_hot_enc.squeeze()
        for layer in (self.l1, self.l2):
            task_vec = self.ReLU(layer(task_vec))

        return self.l3(task_vec)*feats

In [None]:
class Domain_Adapter_lin_v2(nn.Module):
    """Multiplicative domain adapter with a sigmoid-bounded mask.

    The squeezed one-hot task code goes through a three-layer MLP
    (num_tasks -> dim/2 -> dim/2 -> dim) followed by a sigmoid, giving a gate
    in (0, 1) that multiplies `feats`.

    Args:
        num_tasks: length of the one-hot task code.
        dim: size of the produced mask.
        args: namespace providing `device`.
    """

    def __init__(self, num_tasks, dim, args):
        super(Domain_Adapter_lin_v2, self).__init__()

        self.l1 = nn.Linear(num_tasks, int(dim/2)).to(args.device)
        self.l2 = nn.Linear(int(dim/2), int(dim/2)).to(args.device)
        self.l3 = nn.Linear(int(dim/2), dim).to(args.device)

        self.ReLU = nn.ReLU()
        # `Sigmoid` alone is not a defined name in this notebook (only `nn` is
        # imported), so the constructor used to raise NameError.
        self.Sigmoid = nn.Sigmoid()
#         self.Tanh = nn.Tanh()

    def forward(self, feats, one_hot_enc):
        """Return `(mask * feats, mask)`.

        The mask must broadcast against `feats` - presumably the last dim of
        `feats` is `dim`; verify against the caller.
        """
        a1 = self.l1(one_hot_enc.squeeze())
        a2 = self.ReLU(a1)
        a3 = self.l2(a2)
        a4 = self.ReLU(a3)
        a5 = self.l3(a4)
        mask = self.Sigmoid(a5)

        return mask*feats, mask

In [None]:
# Train_Data = Healthy_Imgs_SVC_aug
# Target_Data = Healthy_RealMasks_SVC_aug

# enc_base = Enc_base

# args.num_batches = int(len(Train_Data)/args.N)

# ea = nn.Conv2d(128, 64, kernel_size=3, padding=1).to(args.device) # output: 136x136x256

# random_idxs = np.arange(len(Train_Data))
# np.random.shuffle(random_idxs)
# train_data_shuffle = Train_Data[random_idxs]
# target_data_shuffle = Target_Data[random_idxs]

# epoch_losses = np.zeros(args.num_batches)

# it = 0

# in_frames = train_data_shuffle[it*args.batch_size:(it+1)*args.batch_size,:,:,:,:]
# target_frames = target_data_shuffle[it*args.batch_size:(it+1)*args.batch_size,:,:,:,:]

# feats0u, feats1u, feats2u, feats3u_f, feats4u_f, x4_sz = enc_base(in_frames, one_hot_enc)

In [None]:
# dim0 = 1
# dim1 = 32
# dim2 = 64
# dim3 = 128

# a0 = nn.Conv2d(dim0, dim1, kernel_size=3, stride = 1, padding=1).to(args.device)

# a1 = nn.Conv2d(dim1, dim1, kernel_size=3, stride = 1, padding=1).to(args.device)
# a2 = nn.Conv2d(dim1, dim1, kernel_size=3, stride = 1, padding=1).to(args.device)
# a3 = nn.Conv2d(dim1, dim1, kernel_size=3, stride = 1, padding=1).to(args.device)

# b0 = nn.Conv2d(dim1, dim2, kernel_size=2, stride = 2, padding=0).to(args.device)

# c0 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1).to(args.device)
# c1 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1).to(args.device)
# c2 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1).to(args.device)
# c3 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1).to(args.device)

# d0 = nn.Conv2d(dim2, dim3, kernel_size=2, stride = 2, padding=0).to(args.device)

# e0 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1).to(args.device)
# e1 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1).to(args.device)
# e2 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1).to(args.device)
# e3 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1).to(args.device)

# l0 = nn.Linear(args.num_tasks, int(dim1/2)).to(args.device)
# l1 = nn.Linear(int(dim1/2), dim1).to(args.device)
# l2 = nn.Linear(dim1, dim1).to(args.device)
# l3 = nn.Linear(dim1, dim2).to(args.device)
# l4 = nn.Linear(dim2, dim2).to(args.device)
# l5 = nn.Linear(dim2, dim3).to(args.device)
# l6 = nn.Linear(dim3, dim3).to(args.device)

# ReLU = nn.ReLU()

# q0 = l0(one_hot_enc.squeeze())   
# q01 = ReLU(q0)                 
# q1 = l1(q01)
# q2 = ReLU(q1)
# q3 = l2(q2)
# u1 = q3
# q4 = l3(q2)
# q5 = ReLU(q4)
# q6 = l4(q5)
# u2 = q6
# q7 = l5(q5)                 
# q8 = ReLU(q7)
# q9 = l6(q8)    
# u3 = q9

# x1 = a0(in_frames.squeeze().unsqueeze(1))
# x2 = ReLU(x1)
# x2u = x2 + u1.unsqueeze(2).unsqueeze(3)

# x3 = a1(x2u)
# x4 = ReLU(x3) + x2u
# x5 = a2(x4)
# x6 = ReLU(x5) + x4
# x7 = a3(x6)
# x8 = ReLU(x7) + x6
# # x8u = x8 + u1.unsqueeze(2).unsqueeze(3)

# x9 = self.b0(x8)
# x10 = self.ReLU(x9) + x9
# x10u = x10 + u2.unsqueeze(2).unsqueeze(3)

# x11 = self.c0(x10u)
# x12 = self.ReLU(x11) + x10u
# x13 = self.c1(x12)
# x14 = self.ReLU(x13) + x12
# x15 = self.c2(x14)
# x16 = self.ReLU(x15) + x14
# x17 = self.c3(x16)
# x18 = self.ReLU(x17) + x16
# # x18u = x18 + u2.unsqueeze(2).unsqueeze(3)

# x19 = self.d0(x18)
# x20 = self.ReLU(x19) + x19
# x20u = x20 + u3.unsqueeze(2).unsqueeze(3)

# x21 = self.e0(x20u)
# x22 = self.ReLU(x21) + x20u
# x23 = self.e1(x22)
# x24 = self.ReLU(x23) + x22
# x25 = self.e2(x24)
# x26 = self.ReLU(x25) + x24
# x27 = self.e3(x26)
# out = self.ReLU(x27) + x26


In [None]:
# feats = feats2u
# dim = feats.size()[1]
# num_t = feats.size()[2]*feats.size()[3]
# dim_tot = dim*num_t
# num_tasks = 2

# head_dim = int(dim/nhead)

# l1 = nn.Linear(num_tasks, int(dim/4)).to(args.device)
# l2 = nn.Linear(int(dim/4), int(dim/2)).to(args.device)
# l3 = nn.Linear(int(dim/2), int(dim/2)).to(args.device)
# l4 = nn.Linear(int(dim/2), dim).to(args.device)

# mha1 = torch.nn.MultiheadAttention(embed_dim=dim, num_heads=nhead, batch_first=True).to(args.device)

# ReLU = nn.ReLU()
# SM = nn.Softmax(dim=1)

# num_embed = feats.size()[-1]*feats.size()[-2]
# head_dim = int(dim/nhead)

# a1 = l1(one_hot_enc)
# a2 = ReLU(a1)
# a3 = l2(a2)
# a4 = ReLU(a3)
# a5 = l3(a4)
# a6 = ReLU(a5)
# a7 = l4(a6)
# a8 = a7.reshape(args.batch_size,nhead,head_dim)
# a9 = SM(a8)

# # Query, key, and value matrices
# q_w, k_w, v_w = torch.split(mha1.in_proj_weight, [dim, dim, dim])
# q_b, k_b, v_b = torch.split(mha1.in_proj_bias, [dim, dim, dim])

# q_w = q_w.T
# k_w = k_w.T
# v_w = v_w.T

# x_a = feats.flatten(start_dim=2)
# x_b = torch.swapaxes(x_a,1,2)
# x_f = f_pos_encoding_additive(x_b,args).squeeze()

# x_r = x_f.reshape(args.batch_size, num_embed, nhead, head_dim).transpose(1, 2)
# x_p = a9.unsqueeze(2)*x_r
# x_pp = x_p.reshape(x_f.size())

# qp = x_pp @ q_w + q_b
# kp = x_pp @ k_w + k_b
# vp = x_pp @ v_w + v_b

# qpp = qp.reshape(args.batch_size, num_embed, nhead, head_dim).transpose(1, 2)
# kpp = kp.reshape(args.batch_size, num_embed, nhead, head_dim).transpose(1, 2)
# vpp = vp.reshape(args.batch_size, num_embed, nhead, head_dim).transpose(1, 2)

# qpp_a = aq.unsqueeze(2)*qpp
# kpp_a = ak.unsqueeze(2)*kpp
# vpp_a = av.unsqueeze(2)*vpp

# ap = SM(qpp @ kpp.transpose(-1, -2) / np.sqrt(head_dim))
# pp = ap @ vpp

# pp_t = pp.transpose(1, 2).reshape(args.batch_size, num_embed, dim)

# op = mha1.out_proj(pp_t).transpose(1, 2)

# op.reshape(feats.size())

### Net Architectures

In [None]:
# class CNN_Enc(torch.nn.Module):

#     # Initialize the network and specify input/output dimensions:
#     def __init__(self, args, dim0, dim1, dim2, dim3):
#         super(CNN_Enc, self).__init__()
    
#         self.a0 = nn.Conv2d(dim0, dim1, kernel_size=3, stride = 1, padding=1)
#         self.a1 = nn.Conv2d(dim1, dim1, kernel_size=3, stride = 1, padding=1)
#         self.a2 = nn.Conv2d(dim1, dim1, kernel_size=3, stride = 1, padding=1)
#         self.a3 = nn.Conv2d(dim1, dim1, kernel_size=3, stride = 1, padding=1)

#         self.b0 = nn.Conv2d(dim1, dim2, kernel_size=2, stride = 2, padding=0)

#         self.c0 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1)
#         self.c1 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1)
#         self.c2 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1)
#         self.c3 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1)

#         self.d0 = nn.Conv2d(dim2, dim3, kernel_size=2, stride = 2, padding=0)

#         self.e0 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1)
#         self.e1 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1)
#         self.e2 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1)
#         self.e3 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1)

#         self.ReLU = nn.ReLU()

#      # Build the network:
#     def forward(self, in_frames):

#         x1 = self.a0(in_frames)
#         x2 = self.ReLU(x1)
#         x3 = self.a1(x2)
#         x4 = self.ReLU(x3) + x2
#         x5 = self.a2(x4)
#         x6 = self.ReLU(x5) + x4
#         x7 = self.a3(x6)
#         x8 = self.ReLU(x7) + x6

#         x9 = self.b0(x8)
#         x10 = self.ReLU(x9) + x9

#         x11 = self.c0(x10)
#         x12 = self.ReLU(x11) + x10
#         x13 = self.c1(x12)
#         x14 = self.ReLU(x13) + x12
#         x15 = self.c2(x14)
#         x16 = self.ReLU(x15) + x14
#         x17 = self.c3(x16)
#         x18 = self.ReLU(x17) + x16

#         x19 = self.d0(x18)
#         x20 = self.ReLU(x19) + x19

#         x21 = self.e0(x20)
#         x22 = self.ReLU(x21) + x20
#         x23 = self.e1(x22)
#         x24 = self.ReLU(x23) + x22
#         x25 = self.e2(x24)
#         x26 = self.ReLU(x25) + x24
#         x27 = self.e3(x26)
#         out = self.ReLU(x27) + x26

#         return out

In [None]:
class CNN_Enc(torch.nn.Module):
    """Residual convolutional encoder with optional task conditioning.

    The network has three stages of 3x3 convolutions (``a*``, ``c*``, ``e*``)
    separated by two 2x2 stride-2 convolutions (``b0`` and ``d0``), so the
    output has ``dim3`` channels at one quarter of the input height and width
    (for inputs whose sides are divisible by 4).

    When ``mode == 'base'`` a small MLP (``l0`` .. ``l6``) maps the task
    encoding to one per-channel offset per stage; each offset is broadcast over
    the spatial axes and added to the first feature map of its stage. For any
    other ``mode`` the offsets are zero and ``one_hot_enc`` is ignored.

    Args:
        args: Namespace; ``args.num_tasks`` and ``args.device`` are read, and
            only when ``mode == 'base'``.
        dim0: Number of input channels.
        dim1: Channel width of the first (full-resolution) stage.
        dim2: Channel width of the second stage.
        dim3: Channel width of the third stage (= output channels).
        mode: ``'base'`` enables task conditioning; anything else disables it.
    """

    # Initialize the network and specify input/output dimensions:
    def __init__(self, args, dim0, dim1, dim2, dim3, mode):
        super(CNN_Enc, self).__init__()

        self.mode = mode

        # Stage 1: full resolution, dim1 channels.
        self.a0 = nn.Conv2d(dim0, dim1, kernel_size=3, stride=1, padding=1)
        self.a1 = nn.Conv2d(dim1, dim1, kernel_size=3, stride=1, padding=1)
        self.a2 = nn.Conv2d(dim1, dim1, kernel_size=3, stride=1, padding=1)
        self.a3 = nn.Conv2d(dim1, dim1, kernel_size=3, stride=1, padding=1)

        # Learned 2x downsampling.
        self.b0 = nn.Conv2d(dim1, dim2, kernel_size=2, stride=2, padding=0)

        # Stage 2: half resolution, dim2 channels.
        self.c0 = nn.Conv2d(dim2, dim2, kernel_size=3, stride=1, padding=1)
        self.c1 = nn.Conv2d(dim2, dim2, kernel_size=3, stride=1, padding=1)
        self.c2 = nn.Conv2d(dim2, dim2, kernel_size=3, stride=1, padding=1)
        self.c3 = nn.Conv2d(dim2, dim2, kernel_size=3, stride=1, padding=1)

        # Learned 2x downsampling.
        self.d0 = nn.Conv2d(dim2, dim3, kernel_size=2, stride=2, padding=0)

        # Stage 3: quarter resolution, dim3 channels.
        self.e0 = nn.Conv2d(dim3, dim3, kernel_size=3, stride=1, padding=1)
        self.e1 = nn.Conv2d(dim3, dim3, kernel_size=3, stride=1, padding=1)
        self.e2 = nn.Conv2d(dim3, dim3, kernel_size=3, stride=1, padding=1)
        self.e3 = nn.Conv2d(dim3, dim3, kernel_size=3, stride=1, padding=1)

        # Task-conditioning MLP; only exists in 'base' mode.
        if self.mode == 'base':
            self.l0 = nn.Linear(args.num_tasks, int(dim1/2)).to(args.device)
            self.l1 = nn.Linear(int(dim1/2), dim1).to(args.device)
            self.l2 = nn.Linear(dim1, dim1).to(args.device)
            self.l3 = nn.Linear(dim1, dim2).to(args.device)
            self.l4 = nn.Linear(dim2, dim2).to(args.device)
            self.l5 = nn.Linear(dim2, dim3).to(args.device)
            self.l6 = nn.Linear(dim3, dim3).to(args.device)

        self.ReLU = nn.ReLU()

    def _task_offsets(self, one_hot_enc):
        """Return the per-stage offsets (u1, u2, u3); zeros unless mode is 'base'."""
        if self.mode != 'base':
            return 0, 0, 0

        # Collapse any singleton axes but always keep an explicit batch axis.
        # A bare squeeze() would also drop the batch axis when batch == 1 and
        # make the unsqueeze(2) calls below fail.
        task_vec = one_hot_enc.reshape(-1, self.l0.in_features)

        h1 = self.ReLU(self.l1(self.ReLU(self.l0(task_vec))))
        u1 = self.l2(h1).unsqueeze(2).unsqueeze(3)      # (batch, dim1, 1, 1)
        h2 = self.ReLU(self.l3(h1))
        u2 = self.l4(h2).unsqueeze(2).unsqueeze(3)      # (batch, dim2, 1, 1)
        h3 = self.ReLU(self.l5(h2))
        u3 = self.l6(h3).unsqueeze(2).unsqueeze(3)      # (batch, dim3, 1, 1)
        return u1, u2, u3

    def _residual(self, conv, x):
        """Apply ``ReLU(conv(x)) + x``."""
        return self.ReLU(conv(x)) + x

    # Build the network:
    def forward(self, in_frames, one_hot_enc):
        """Encode a batch of frames.

        Args:
            in_frames: Tensor of shape (batch, dim0, H, W).
            one_hot_enc: Task encoding holding ``args.num_tasks`` values per
                sample (extra singleton axes are tolerated). Only read in
                'base' mode.

        Returns:
            Tensor with ``dim3`` channels and the spatial size halved twice.
        """
        u1, u2, u3 = self._task_offsets(one_hot_enc)

        # Stage 1 (the first conv has no skip: channel count changes).
        x = self.ReLU(self.a0(in_frames)) + u1
        for conv in (self.a1, self.a2, self.a3):
            x = self._residual(conv, x)

        # Downsample, then stage 2.
        x = self.b0(x)
        x = self.ReLU(x) + x + u2
        for conv in (self.c0, self.c1, self.c2, self.c3):
            x = self._residual(conv, x)

        # Downsample, then stage 3.
        x = self.d0(x)
        x = self.ReLU(x) + x + u3
        for conv in (self.e0, self.e1, self.e2, self.e3):
            x = self._residual(conv, x)

        return x

In [None]:
# class CNN_Dec(torch.nn.Module):

#     # Initialize the network and specify input/output dimensions:
#     def __init__(self, args, dim3, dim2, dim1, dim0):
#         super(CNN_Dec, self).__init__()
#         self.e0 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1)
#         self.e1 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1)
#         self.e2 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1)
#         self.e3 = nn.Conv2d(dim3, dim3, kernel_size=3, stride = 1, padding=1)

#         self.d0 = nn.ConvTranspose2d(dim3, dim2, kernel_size=2, stride=2)

#         self.c0 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1)
#         self.c1 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1)
#         self.c2 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1)
#         self.c3 = nn.Conv2d(dim2, dim2, kernel_size=3, stride = 1, padding=1)

#         self.b0 = nn.ConvTranspose2d(dim2, dim1, kernel_size=2, stride=2)

#         self.a0 = nn.Conv2d(dim1, dim1, kernel_size=3, stride = 1, padding=1)
#         self.a1 = nn.Conv2d(dim1, dim1, kernel_size=3, stride = 1, padding=1)
#         self.a2 = nn.Conv2d(dim1, dim1, kernel_size=3, stride = 1, padding=1)
#         self.a3 = nn.Conv2d(dim1, dim0, kernel_size=3, stride = 1, padding=1)

#         self.ReLU = nn.ReLU()
#         self.Sigmoid = nn.Sigmoid()
    
#     # Build the network:
#     def forward(self, x0):

#         x1 = self.e0(x0)
#         x2 = self.ReLU(x1) + x0
#         x3 = self.e1(x2)
#         x4 = self.ReLU(x3) + x2
#         x5 = self.e2(x4)
#         x6 = self.ReLU(x5) + x4
#         x7 = self.e3(x6)
#         x8 = self.ReLU(x7) + x6

#         x9 = self.d0(x8)
#         x10 = self.ReLU(x9) + x9

#         x11 = self.c0(x10)
#         x12 = self.ReLU(x11) + x10
#         x13 = self.c1(x12)
#         x14 = self.ReLU(x13) + x12
#         x15 = self.c2(x14)
#         x16 = self.ReLU(x15) + x14
#         x17 = self.c3(x16)
#         x18 = self.ReLU(x17) + x16

#         x19 = self.b0(x18)
#         x20 = self.ReLU(x19) + x19

#         x21 = self.a0(x20)
#         x22 = self.ReLU(x21) + x20
#         x23 = self.a1(x22)
#         x24 = self.ReLU(x23) + x22
#         x25 = self.a2(x24)
#         x26 = self.ReLU(x25) + x24
#         x27 = self.a3(x26)
#         out = self.Sigmoid(x27)

#         return out

In [None]:
class CNN_Dec(torch.nn.Module):
    """Residual convolutional decoder with optional task conditioning.

    This is the reverse of the encoder layout: a ``dim3``-channel stage, a 2x
    transposed-conv upsampling to ``dim2`` channels, a second stage, another 2x
    upsampling to ``dim1`` channels, and a final stage whose last convolution
    maps to ``dim0`` channels followed by a sigmoid (outputs lie in (0, 1)).

    When ``mode == 'base'`` a small MLP (``l0`` .. ``l6``) maps the task
    encoding to one per-channel offset per stage. For any other ``mode``
    (e.g. ``'aux'``) the offsets are zero and ``one_hot_enc`` is ignored.

    Args:
        args: Namespace; ``args.num_tasks`` and ``args.device`` are read, and
            only when ``mode == 'base'``.
        dim3: Number of input channels (coarsest stage).
        dim2: Channel width of the middle stage.
        dim1: Channel width of the finest stage.
        dim0: Number of output channels.
        mode: ``'base'`` enables task conditioning; anything else disables it.
    """

    # Initialize the network and specify input/output dimensions:
    def __init__(self, args, dim3, dim2, dim1, dim0, mode):
        super(CNN_Dec, self).__init__()

        self.mode = mode

        # Coarsest stage: dim3 channels.
        self.e0 = nn.Conv2d(dim3, dim3, kernel_size=3, stride=1, padding=1)
        self.e1 = nn.Conv2d(dim3, dim3, kernel_size=3, stride=1, padding=1)
        self.e2 = nn.Conv2d(dim3, dim3, kernel_size=3, stride=1, padding=1)
        self.e3 = nn.Conv2d(dim3, dim3, kernel_size=3, stride=1, padding=1)

        # Learned 2x upsampling.
        self.d0 = nn.ConvTranspose2d(dim3, dim2, kernel_size=2, stride=2)

        # Middle stage: dim2 channels.
        self.c0 = nn.Conv2d(dim2, dim2, kernel_size=3, stride=1, padding=1)
        self.c1 = nn.Conv2d(dim2, dim2, kernel_size=3, stride=1, padding=1)
        self.c2 = nn.Conv2d(dim2, dim2, kernel_size=3, stride=1, padding=1)
        self.c3 = nn.Conv2d(dim2, dim2, kernel_size=3, stride=1, padding=1)

        # Learned 2x upsampling.
        self.b0 = nn.ConvTranspose2d(dim2, dim1, kernel_size=2, stride=2)

        # Finest stage: dim1 channels, last conv projects to dim0.
        self.a0 = nn.Conv2d(dim1, dim1, kernel_size=3, stride=1, padding=1)
        self.a1 = nn.Conv2d(dim1, dim1, kernel_size=3, stride=1, padding=1)
        self.a2 = nn.Conv2d(dim1, dim1, kernel_size=3, stride=1, padding=1)
        self.a3 = nn.Conv2d(dim1, dim0, kernel_size=3, stride=1, padding=1)

        # Task-conditioning MLP; only exists in 'base' mode.
        if self.mode == 'base':
            self.l0 = nn.Linear(args.num_tasks, int(dim1/2)).to(args.device)
            self.l1 = nn.Linear(int(dim1/2), dim1).to(args.device)
            self.l2 = nn.Linear(dim1, dim1).to(args.device)
            self.l3 = nn.Linear(dim1, dim2).to(args.device)
            self.l4 = nn.Linear(dim2, dim2).to(args.device)
            self.l5 = nn.Linear(dim2, dim3).to(args.device)
            self.l6 = nn.Linear(dim3, dim3).to(args.device)

        self.ReLU = nn.ReLU()
        self.Sigmoid = nn.Sigmoid()

    def _task_offsets(self, one_hot_enc):
        """Return the per-stage offsets (u1, u2, u3); zeros unless mode is 'base'.

        Every non-'base' mode yields zeros (matching CNN_Enc), so an
        unexpected mode string can no longer leave the offsets undefined.
        """
        if self.mode != 'base':
            return 0, 0, 0

        # Collapse any singleton axes but always keep an explicit batch axis.
        # A bare squeeze() would also drop the batch axis when batch == 1 and
        # make the unsqueeze(2) calls below fail.
        task_vec = one_hot_enc.reshape(-1, self.l0.in_features)

        h1 = self.ReLU(self.l1(self.ReLU(self.l0(task_vec))))
        u1 = self.l2(h1).unsqueeze(2).unsqueeze(3)      # (batch, dim1, 1, 1)
        h2 = self.ReLU(self.l3(h1))
        u2 = self.l4(h2).unsqueeze(2).unsqueeze(3)      # (batch, dim2, 1, 1)
        h3 = self.ReLU(self.l5(h2))
        u3 = self.l6(h3).unsqueeze(2).unsqueeze(3)      # (batch, dim3, 1, 1)
        return u1, u2, u3

    def _residual(self, conv, x):
        """Apply ``ReLU(conv(x)) + x``."""
        return self.ReLU(conv(x)) + x

    # Build the network:
    def forward(self, x0, one_hot_enc):
        """Decode a feature map into an image.

        Args:
            x0: Tensor of shape (batch, dim3, h, w).
            one_hot_enc: Task encoding holding ``args.num_tasks`` values per
                sample (extra singleton axes are tolerated). Only read in
                'base' mode.

        Returns:
            Tensor of shape (batch, dim0, 4 * h, 4 * w) with values in (0, 1).
        """
        u1, u2, u3 = self._task_offsets(one_hot_enc)

        # Coarsest stage: the offset is injected after the first residual block.
        x = self._residual(self.e0, x0) + u3
        for conv in (self.e1, self.e2, self.e3):
            x = self._residual(conv, x)

        # Upsample, then middle stage.
        x = self.d0(x)
        x = self.ReLU(x) + x + u2
        for conv in (self.c0, self.c1, self.c2, self.c3):
            x = self._residual(conv, x)

        # Upsample, then finest stage.
        x = self.b0(x)
        x = self.ReLU(x) + x + u1
        for conv in (self.a0, self.a1, self.a2):
            x = self._residual(conv, x)

        # Output projection: no skip connection, squashed to (0, 1).
        return self.Sigmoid(self.a3(x))

In [None]:
class Compress_Im_feats(torch.nn.Module):
    """Four-layer MLP that maps image features from ``dim1`` to ``dim3``.

    Layout: Linear(dim1, dim2) -> ReLU -> Linear(dim2, dim2) -> ReLU ->
    Linear(dim2, dim2) -> ReLU -> Linear(dim2, dim3). The last layer has no
    activation. All linear layers are created on ``args.device``.
    """

    def __init__(self, args, dim1, dim2, dim3):
        super().__init__()

        device = args.device
        self.h_enc1 = nn.Linear(dim1, dim2).to(device)
        self.h_enc2 = nn.Linear(dim2, dim2).to(device)
        self.h_enc3 = nn.Linear(dim2, dim2).to(device)
        self.h_enc4 = nn.Linear(dim2, dim3).to(device)

        self.ReLU = nn.ReLU()

    def forward(self, in_feats):
        """Return the compressed features for ``in_feats`` (last axis = ``dim1``)."""
        hidden = in_feats
        # Three hidden layers, each followed by a ReLU.
        for layer in (self.h_enc1, self.h_enc2, self.h_enc3):
            hidden = self.ReLU(layer(hidden))
        # Linear read-out, no activation.
        return self.h_enc4(hidden)

In [None]:
class Compress_Im_feats_v2(torch.nn.Module):
    """Task-gated MLP that maps image features from ``dim1`` to ``dim3``.

    The feature path is the same four-layer MLP as ``Compress_Im_feats``
    (``h_enc1`` .. ``h_enc4``). In addition, a task branch (``l01``, ``l02``)
    embeds the task encoding, and three heads (``l1``, ``l2``, ``l3``) turn
    that embedding into sigmoid gates in (0, 1) that multiply the three hidden
    activations element-wise.

    Args:
        args: Namespace providing ``num_tasks`` and ``device``.
        dim1: Size of the input feature axis.
        dim2: Hidden width (also the width of each gate).
        dim3: Size of the output feature axis.
    """

    # Initialize the network and specify input/output dimensions:
    def __init__(self, args, dim1, dim2, dim3):
        super(Compress_Im_feats_v2, self).__init__()

        # Feature path.
        self.h_enc1 = nn.Linear(dim1, dim2).to(args.device)
        self.h_enc2 = nn.Linear(dim2, dim2).to(args.device)
        self.h_enc3 = nn.Linear(dim2, dim2).to(args.device)
        self.h_enc4 = nn.Linear(dim2, dim3).to(args.device)

        # Task embedding followed by one gate head per hidden layer.
        self.l01 = nn.Linear(args.num_tasks, int(dim2/2)).to(args.device)
        self.l02 = nn.Linear(int(dim2/2), dim2).to(args.device)
        self.l1 = nn.Linear(dim2, dim2).to(args.device)
        self.l2 = nn.Linear(dim2, dim2).to(args.device)
        self.l3 = nn.Linear(dim2, dim2).to(args.device)

        self.ReLU = nn.ReLU()
        self.Sigmoid = nn.Sigmoid()

    # Build the network:
    def forward(self, in_feats, one_hot_enc):
        """Compress ``in_feats`` under task-dependent gating.

        Args:
            in_feats: Tensor whose last axis has size ``dim1``.
            one_hot_enc: Task encoding holding ``args.num_tasks`` values per
                sample; extra singleton axes are tolerated.

        Returns:
            Tensor whose last axis has size ``dim3``.
        """
        # Flatten to (batch, num_tasks). A bare squeeze() would also remove
        # the task axis when num_tasks == 1 (and the batch axis when
        # batch == 1), which breaks the first linear layer.
        task_vec = one_hot_enc.reshape(-1, self.l01.in_features)
        task_emb = self.ReLU(self.l02(self.ReLU(self.l01(task_vec))))

        gate1 = self.Sigmoid(self.l1(task_emb))
        gate2 = self.Sigmoid(self.l2(task_emb))
        gate3 = self.Sigmoid(self.l3(task_emb))

        # Each hidden activation is scaled by its gate before the next layer.
        hidden = gate1 * self.ReLU(self.h_enc1(in_feats))
        hidden = gate2 * self.ReLU(self.h_enc2(hidden))
        hidden = gate3 * self.ReLU(self.h_enc3(hidden))

        return self.h_enc4(hidden)

In [None]:
# class Head_Domain_Adapter(nn.Module):
#     def __init__(self, num_tasks, dim, nhead, args):
#         super(Head_Domain_Adapter, self).__init__()
        
#         self.nhead = nhead
        
#         self.mha1 = torch.nn.MultiheadAttention(embed_dim=dim, num_heads=nhead, batch_first=True).to(args.device)

#         self.ReLU = nn.ReLU()
#         self.SM = nn.Softmax(dim=1)
#         self.Sigmoid = nn.Sigmoid()
        
#     def forward(self, feats):
        
#         dim = feats.size()[1]
#         num_embed = feats.size()[-1]*feats.size()[-2]
#         head_dim = int(dim/self.nhead)
        

#         # Query, key, and value matrices
#         q_w, k_w, v_w = torch.split(self.mha1.in_proj_weight, [dim, dim, dim])
#         q_b, k_b, v_b = torch.split(self.mha1.in_proj_bias, [dim, dim, dim])

#         q_w = q_w.T
#         k_w = k_w.T
#         v_w = v_w.T

#         x_a = feats.flatten(start_dim=2)
#         x_b = torch.swapaxes(x_a,1,2)
#         x_f = f_pos_encoding_additive(x_b,args).squeeze()

#         x_r = x_f.reshape(args.batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)
#         x_p = a9.unsqueeze(2)*x_r
#         x_pp = x_p.reshape(x_f.size())

#         qp = x_pp @ q_w + q_b
#         kp = x_pp @ k_w + k_b
#         vp = x_pp @ v_w + v_b

#         qpp = qp.reshape(args.batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)
#         kpp = kp.reshape(args.batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)
#         vpp = vp.reshape(args.batch_size, num_embed, self.nhead, head_dim).transpose(1, 2)

#         ap = self.SM(qpp @ kpp.transpose(-1, -2) / np.sqrt(head_dim))
#         pp = ap @ vpp

#         pp_t = pp.transpose(1, 2).reshape(args.batch_size, num_embed, dim)

#         op = self.mha1.out_proj(pp_t).transpose(1, 2)

#         return op.reshape(feats.size())

In [None]:
class Unet_Enc_base(nn.Module):
    """Multi-scale encoder: CNN stem, three attention levels and a linear bottleneck.

    Pipeline of ``forward``:

    * level 0 - ``Enc1`` (an unconditioned ``CNN_Enc`` with 128 output
      channels, so ``args.d_attn0`` must be 128) followed by a residual
      attention branch that runs on a spatially compressed token set
      (``l1`` maps 64**2 positions to 8**2, ``l2`` maps back);
    * level 1 - 4x strided conv ``ea0`` to ``args.d_attn1`` channels, three
      residual convs, domain adapter ``DA1``, residual attention branch;
    * level 2 - conv ``e0`` to ``args.d_attn2`` channels + 4x max-pool, three
      residual convs, domain adapter ``DA2``, residual attention branch;
    * bottleneck - two conv + 2x max-pool steps, flatten, ``DA_l3`` and the
      task-gated MLP ``C1``.

    Args:
        nhead: Number of attention heads passed to every SparseTransformerEncoder.
        args: Namespace providing ``d_attn0``, ``d_attn1``, ``d_attn2``,
            ``num_tasks`` and ``device`` (plus whatever
            ``f_pos_encoding_additive`` reads). It is stored on the module so
            that ``forward`` does not depend on a notebook-global ``args``.
    """

    def __init__(self, nhead, args):
        super(Unet_Enc_base, self).__init__()

        # Keep the configuration with the module instead of reading a global.
        self.args = args

        self.Enc1 = CNN_Enc(args, 1, 32, 64, 128,mode='none')

        # Level 0 -> level 1: 4x strided conv, then three same-size convs.
        self.ea0 = nn.Conv2d(args.d_attn0, args.d_attn1, kernel_size=4, stride=4, padding=0)
        self.ea1 = nn.Conv2d(args.d_attn1, args.d_attn1, kernel_size=3, stride = 1, padding=1)
        self.ea2 = nn.Conv2d(args.d_attn1, args.d_attn1, kernel_size=3, stride = 1, padding=1)
        self.ea3 = nn.Conv2d(args.d_attn1, args.d_attn1, kernel_size=3, stride = 1, padding=1)

        # Level-0 attention runs on 8**2 tokens obtained by linearly mixing
        # the 64**2 spatial positions; l2 maps the result back.
        self.l1 = nn.Linear(64**2,8**2).to(args.device)
        self.te0u_1 = SparseTransformerEncoder(args.d_attn0, nhead, 8**2, 4, 50, args)
        self.te0u_2 = SparseTransformerEncoder(args.d_attn0, nhead, 8**2, 4, 50, args)
        self.l2 = nn.Linear(8**2,64**2).to(args.device)

        # Level 1 -> level 2: channel change, 4x max-pool, three same-size convs.
        self.e0 = nn.Conv2d(args.d_attn1, args.d_attn2, kernel_size=3, padding=1).to(args.device)
        self.e0a = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device)
        self.e0b = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device)
        self.e0c = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device)

        self.pool0 = nn.MaxPool2d(kernel_size=4, stride=4).to(args.device)

        self.te1u_1 = SparseTransformerEncoder(args.d_attn1, nhead, 16**2, 8, 50, args)
        self.te1u_2 = SparseTransformerEncoder(args.d_attn1, nhead, 16**2, 8, 50, args)

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2).to(args.device)

        self.te2u_1 = SparseTransformerEncoder(args.d_attn2, nhead, 4**2, 16, 50, args)
        self.te2u_2 = SparseTransformerEncoder(args.d_attn2, nhead, 4**2, 16, 50, args)

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2).to(args.device)

        # Bottleneck convolutions (all keep args.d_attn2 channels).
        # NOTE: e1b is created but not used in forward(); it is kept so that
        # existing checkpoints still load.
        self.e1 = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1)
        self.e1a = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1)
        self.e1b = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1)
        self.e2 = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1)
        self.e2a = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1)
        self.e2b = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1)

        self.C1 = Compress_Im_feats_v2(args, dim1=args.d_attn2, dim2 = args.d_attn1, dim3 = args.d_attn1).to(args.device)

        self.DA1 = Sum_Domain_Adapter(args.num_tasks, args.d_attn1, args)
        self.DA2 = Sum_Domain_Adapter(args.num_tasks, args.d_attn2, args)

        self.DA_l3 = Domain_Adapter_lin_v1(args.num_tasks, args.d_attn2, args)

        self.ReLU = nn.ReLU()

    def forward(self, in_frames, one_hot_enc):
        """Encode ``in_frames`` into features at every level of the hierarchy.

        Args:
            in_frames: Input frames. All singleton axes are squeezed and a
                channel axis is re-inserted at position 1.
                NOTE(review): this presumably expects single-channel frames
                and a batch size > 1 (squeeze() would also drop a batch axis
                of size 1) - confirm against the caller.
            one_hot_enc: Task encoding forwarded to the domain adapters and
                to ``C1``.

        Returns:
            Tuple ``(feats0u, feats1u, feats2u, feats3u_f, feats4u_f, size)``
            where the first three are the level 0/1/2 feature maps, ``feats3u_f``
            is the flattened, domain-adapted bottleneck, ``feats4u_f`` is its
            compressed version and ``size`` is the ``torch.Size`` of the
            bottleneck map before flattening.
        """
        args = self.args

        feats0u = self.Enc1(in_frames.squeeze().unsqueeze(1), one_hot_enc)

        # ---- Level 0: attention over a compressed token set ----
        feats0u_a = feats0u.flatten(start_dim=2)            # (B, C, H*W)
        feats0u_a = torch.swapaxes(feats0u_a,1,2)           # (B, H*W, C)
        feats0u_a = f_pos_encoding_additive(feats0u_a,args).squeeze()
        # l1 acts on the position axis, hence the swap before and after.
        x1 = torch.swapaxes(self.l1(torch.swapaxes(feats0u_a,1,2)),1,2)
        q1 = self.te0u_1(x1, x1, x1)
        q2 = self.te0u_2(q1, q1, q1)
        q3 = torch.swapaxes(q2,1,2)
        q4 = self.l2(q3)
        feats0u = feats0u + q4.reshape(feats0u.size())

        # ---- Level 1 ----
        feats1u = self.ea0(feats0u)
        feats1ua = self.ea1(feats1u)
        feats1ub = self.ReLU(feats1ua) + feats1u
        feats1uc = self.ea2(feats1ub)
        feats1ud = self.ReLU(feats1uc) + feats1ub
        feats1ue = self.ea3(feats1ud)
        feats1uf = self.ReLU(feats1ue) + feats1ud

        feats1u = self.DA1(feats1uf,one_hot_enc)

        feats1u_a = feats1u.flatten(start_dim=2)
        feats1u_a = torch.swapaxes(feats1u_a,1,2)

        x1 = f_pos_encoding_additive(feats1u_a,args).squeeze()
        q1 = self.te1u_1(x1, x1, x1)
        q2 = self.te1u_2(q1, q1, q1)
        q2 = torch.swapaxes(q2,1,2)                         # back to (B, C, tokens)

        feats1u = feats1u + q2.reshape(feats1u.size())

        # ---- Level 2 ----
        feats2u = self.pool0(self.e0(feats1u))
        feats2ua = self.e0a(feats2u)
        feats2ub = self.ReLU(feats2ua) + feats2u
        feats2uc = self.e0b(feats2ub)
        feats2ud = self.ReLU(feats2uc) + feats2ub
        feats2ue = self.e0c(feats2ud)
        feats2uf = self.ReLU(feats2ue) + feats2ud

        feats2u = self.DA2(feats2uf,one_hot_enc)

        feats2u_a = feats2u.flatten(start_dim=2)
        feats2u_a = torch.swapaxes(feats2u_a,1,2)

        x1 = f_pos_encoding_additive(feats2u_a,args).squeeze()
        q1 = self.te2u_1(x1, x1, x1)
        q2 = self.te2u_2(q1,q1,q1)
        q2 = torch.swapaxes(q2,1,2)
        feats2u = feats2u + q2.reshape(feats2u.size())

        # ---- Bottleneck: two conv + 2x pool steps ----
        x1 = self.e1(feats2u)
        x2 = self.pool1(x1)
        x2a = self.e1a(x2)
        x2b = self.ReLU(x2a) + x2

        x3 = self.e2(x2b)
        x4 = self.pool2(x3)
        x4a = self.e2a(x4)
        x4b = self.ReLU(x4a) + x4
        x4c = self.e2b(x4b)
        x4d = self.ReLU(x4c) + x4b

        feats3u_f = x4d.flatten(start_dim=1)

        feats3u_f = self.DA_l3(feats3u_f, one_hot_enc)
        feats4u_f = self.C1(feats3u_f, one_hot_enc)

        # x4 and x4d have the same size (the convs in between preserve it).
        return feats0u, feats1u, feats2u, feats3u_f, feats4u_f, x4.size()

In [None]:
class Unet_Dec_v3(nn.Module):
    """Multi-scale decoder that mirrors ``Unet_Enc_base``.

    Starting from the compressed bottleneck code, the decoder expands back to
    image resolution. At each level it (a) refines its own feature map with
    two self-attention encoders, (b) cross-attends to the encoder feature of
    the same level, and (c) adds that encoder feature as a skip connection.
    The ``mode`` argument of ``forward`` controls how many encoder features
    are let through (the others are multiplied by zero).

    Args:
        nhead: Number of attention heads passed to every SparseTransformerEncoder.
        args: Namespace providing ``d_attn0``, ``d_attn1``, ``d_attn2``,
            ``fac``, ``num_tasks``, ``batch_size`` and ``device`` (plus
            whatever ``f_pos_encoding_additive`` reads). It is stored on the
            module so that ``forward`` does not depend on a notebook-global
            ``args``.
    """

    def __init__(self, nhead, args):
        super(Unet_Dec_v3, self).__init__()

        # Keep the configuration with the module instead of reading a global.
        self.args = args

        # Bottleneck expansion: task-gated MLP, then two 2x transposed convs.
        # NOTE: q0b and q1b are created but not used in forward(); they are
        # kept so that existing checkpoints still load.
        self.D1 = Compress_Im_feats_v2(args, dim1=args.d_attn1, dim2 = args.d_attn1, dim3 = args.d_attn2).to(args.device)

        self.m1 = nn.ConvTranspose2d(args.d_attn2, args.d_attn2, kernel_size=2, stride=2).to(args.device)
        self.q0a = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device)
        self.q0b = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device)
        self.m2 = nn.ConvTranspose2d(args.d_attn2, args.d_attn2, kernel_size=2, stride=2).to(args.device)
        self.q1a = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device)
        self.q1b = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device)

        # Projections that reduce incoming encoder features from
        # (channels * args.fac) to the decoder's channel count.
        self.ed1 = nn.Conv2d(args.d_attn1*args.fac, args.d_attn1, kernel_size=1, padding=0).to(args.device)
        self.ed2 = nn.Conv2d(args.d_attn2*args.fac, args.d_attn2, kernel_size=1, padding=0).to(args.device)
        self.d3 = nn.Linear(args.d_attn2*args.fac, args.d_attn2).to(args.device)

        # Level 2: self-attention (te2d_*) and cross-attention (mha2).
        self.te2d_1 = SparseTransformerEncoder(args.d_attn2, nhead, 4**2, 16, 50, args)
        self.te2d_2 = SparseTransformerEncoder(args.d_attn2, nhead, 4**2, 16, 50, args)
        self.mha2 = SparseTransformerEncoder(args.d_attn2, nhead, 4**2, 16, 50, args)

        self.m3 = nn.ConvTranspose2d(args.d_attn2, args.d_attn1, kernel_size=4, stride=4).to(args.device)

        # Level 1: self-attention (te1d_*) and cross-attention (mha1).
        self.te1d_1 = SparseTransformerEncoder(args.d_attn1, nhead, 16**2, 8, 50, args)
        self.te1d_2 = SparseTransformerEncoder(args.d_attn1, nhead, 16**2, 8, 50, args)
        self.mha1 = SparseTransformerEncoder(args.d_attn1, nhead, 16**2, 8, 50, args)

        # Level 0: attention on 8**2 tokens obtained by linearly mixing the
        # 64**2 spatial positions; l2 maps the result back.
        self.l1 = nn.Linear(64**2,8**2).to(args.device)
        self.te0d_1 = SparseTransformerEncoder(args.d_attn0, nhead, 8**2, 4, 50, args)
        self.te0d_2 = SparseTransformerEncoder(args.d_attn0, nhead, 8**2, 4, 50, args)
        self.l2 = nn.Linear(8**2,64**2).to(args.device)

        self.mha0 = SparseTransformerEncoder(args.d_attn0, nhead, 8**2, 4, 50, args)

        self.m4 = nn.ConvTranspose2d(args.d_attn1, args.d_attn0*args.fac, kernel_size=4, stride=4).to(args.device)

        # Final CNN decoder back to a single-channel image (no task conditioning).
        self.Dec1 = CNN_Dec(args, args.d_attn0*args.fac, 64, 32, 1,'aux').to(args.device)

        self.ReLU = nn.ReLU()

    def forward(self, feats0u, feats1u, feats2u, feats3u_f, feats4u_f, sz4, mode):
        """Decode encoder features into output frames.

        Args:
            feats0u, feats1u, feats2u: Encoder feature maps for levels 0-2.
            feats3u_f: Flattened bottleneck features.
            feats4u_f: Compressed bottleneck code.
            sz4: Size of the encoder's bottleneck map; only ``sz4[1]`` (the
                channel count) is read.
            mode: Number compared against fixed thresholds to decide which
                encoder features are passed through (see below).

        Returns:
            Tuple ``(out_frames, feats0_n, feats1_n, feats2_n, feats3_n, feats4_n)``:
            the decoded frames (with an extra axis inserted at position 1)
            and the gated - and, for levels 1-3, projected - encoder features.
        """
        args = self.args

        # Gate the encoder features: a feature is zeroed when `mode` is below
        # its threshold.
        # NOTE(review): levels 1 and 0 share the threshold 3; confirm that
        # level 0 is not meant to use `mode >= 4`.
        feats4_n = feats4u_f*(mode>=0)
        feats3_n = feats3u_f*(mode>=1)
        feats2_n = feats2u*(mode>=2)
        feats1_n = feats1u*(mode>=3)
        feats0_n = feats0u*(mode>=3)

        feats3_n = self.d3(feats3_n)
        feats2_n = self.ed2(feats2_n)
        feats1_n = self.ed1(feats1_n)

        # ---- Bottleneck ----
        # D1 requires a task encoding; the decoder feeds it a constant
        # all-ones placeholder.
        place_holder = torch.ones(args.batch_size,args.num_tasks).to(args.device)
        feats3d_f = self.D1(feats4_n, place_holder)

        feats3d_f = feats3_n + feats3d_f

        # Restore a (channels, 1, 1) map so the transposed convs can upsample it.
        feats3d = torch.unflatten(feats3d_f,dim=1,sizes = [int(sz4[1]/args.fac), 1, 1])

        y2 = self.m1(feats3d)
        y2a = self.q0a(y2)
        y2b = self.ReLU(y2a) + y2

        y4 = self.m2(y2b)
        y4a = self.q1a(y4)
        feats2d = self.ReLU(y4a) + y4

        # ---- Level 2 ----
        feats2d_a = feats2d.squeeze().flatten(start_dim=2)      # (B, C, tokens)
        feats2d_a = torch.swapaxes(feats2d_a,1,2)               # (B, tokens, C)
        x1 = f_pos_encoding_additive(feats2d_a,args).squeeze()
        a1 = self.te2d_1(x1, x1, x1)
        a2 = self.te2d_2(a1, a1, a1)
        a3 = torch.swapaxes(a2,1,2)
        # Reshape to the map this residual is added to (previously the size
        # was taken from feats2_n, which merely happens to match).
        feats2d = feats2d + a3.reshape(feats2d.size())

        feats2n_a = feats2_n.flatten(start_dim=2)
        feats2n_a = torch.swapaxes(feats2n_a,1,2)
        feats2n_a = f_pos_encoding_additive(feats2n_a,args).squeeze()

        # Cross-attention: decoder tokens query the encoder tokens.
        attn_output2 = self.mha2(a2, feats2n_a, feats2n_a)
        attn_output2 = torch.swapaxes(attn_output2,1,2)
        feats2d = feats2d + attn_output2.reshape(feats2d.size())

        feats1d = self.m3(feats2_n + feats2d)

        # ---- Level 1 ----
        feats1d_a = feats1d.squeeze().flatten(start_dim=2)
        feats1d_a = torch.swapaxes(feats1d_a,1,2)
        x1 = f_pos_encoding_additive(feats1d_a,args).squeeze()
        b1 = self.te1d_1(x1, x1, x1)
        b2 = self.te1d_2(b1, b1, b1)
        b3 = torch.swapaxes(b2,1,2)
        feats1d = feats1d + b3.reshape(feats1d.size())

        feats1n_a = feats1_n.flatten(start_dim=2)
        feats1n_a = torch.swapaxes(feats1n_a,1,2)
        feats1n_a = f_pos_encoding_additive(feats1n_a,args).squeeze()
        attn_output1 = self.mha1(b2, feats1n_a, feats1n_a)
        # Fix: move channels back in front of the token axis before restoring
        # the (B, C, H, W) layout, exactly as done for attn_output2 and b3.
        # Reshaping the token-major tensor directly mixed channels with
        # spatial positions.
        attn_output1 = torch.swapaxes(attn_output1,1,2)
        feats1d = feats1d + attn_output1.reshape(feats1d.size())

        feats0d = self.m4(feats1_n + feats1d)

        # ---- Level 0: attention over a compressed token set ----
        feats0d_a = feats0d.flatten(start_dim=2)
        feats0d_a = torch.swapaxes(feats0d_a,1,2)
        feats0d_a = f_pos_encoding_additive(feats0d_a,args).squeeze()
        # l1 acts on the position axis, hence the swap before and after.
        x1 = torch.swapaxes(self.l1(torch.swapaxes(feats0d_a,1,2)),1,2)
        q1 = self.te0d_1(x1, x1, x1)
        q2 = self.te0d_2(q1, q1, q1)

        feats0n_a = feats0_n.flatten(start_dim=2)
        feats0n_a = torch.swapaxes(feats0n_a,1,2)
        feats0n_a = f_pos_encoding_additive(feats0n_a,args).squeeze()
        xa = torch.swapaxes(self.l1(torch.swapaxes(feats0n_a,1,2)),1,2)
        # NOTE(review): unlike mha2/mha1, which are called as
        # (decoder, encoder, encoder), this call passes (q2, q2, xa).
        # Left unchanged - confirm that this argument pattern is intended.
        q3 = q2 + self.mha0(q2, q2, xa)
        q4 = self.l2(torch.swapaxes(q3,1,2))
        feats0d = feats0d + q4.reshape(feats0d.size())

        # Dec1 runs in 'aux' mode and ignores its second argument.
        out_frames = self.Dec1(feats0d + feats0_n, []).unsqueeze(1)

        return out_frames, feats0_n, feats1_n, feats2_n, feats3_n, feats4_n

In [None]:
# dec_base
class Unet_Dec_base(nn.Module):
    """Decoder ("base" variant) of the U-Net: bottleneck features -> output frame.

    The decoder works upwards through three feature levels (2, 1, 0). At each
    level the upsampled features are passed through two
    ``SparseTransformerEncoder`` blocks (the same tensor is supplied three
    times), and a third ``SparseTransformerEncoder`` block is then called with
    the encoder skip features of that level. Adapter modules
    (``Sum_Domain_Adapter`` / ``Domain_Adapter_lin_v1``) receive
    ``one_hot_enc`` together with the features they adapt.

    Args:
        nhead: forwarded unchanged to every ``SparseTransformerEncoder``.
        args: configuration namespace. This class reads ``device``,
            ``d_attn0``, ``d_attn1``, ``d_attn2`` and ``num_tasks`` from it and
            hands the namespace on to the sub-modules it builds.
    """

    def __init__(self, nhead, args):
        super(Unet_Dec_base, self).__init__()

        # Keep the configuration on the module so that forward() uses the same
        # namespace the module was built with instead of a notebook-level
        # global that happens to be called `args`.
        self.args = args

        # NOTE: the construction order below is deliberately left untouched;
        # it determines the order in which parameters are initialised and the
        # attribute names are the state_dict keys of saved checkpoints.

        # Bottleneck -> level-3 features, followed by two 2x transposed-conv
        # upsampling stages, each with two residual 3x3 convolutions.
        self.D1 = Compress_Im_feats_v2(args, dim1=args.d_attn1, dim2 = args.d_attn1, dim3 = args.d_attn2).to(args.device)
        self.m1 = nn.ConvTranspose2d(args.d_attn2, args.d_attn2, kernel_size=2, stride=2).to(args.device)
        self.q0a = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device)
        self.q0b = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device)
        self.m2 = nn.ConvTranspose2d(args.d_attn2, args.d_attn2, kernel_size=2, stride=2).to(args.device)
        self.q1a = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device)
        self.q1b = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device)

        # Level-2 attention. Earlier revisions used torch.nn.TransformerEncoderLayer
        # and torch.nn.MultiheadAttention under the same attribute names. The
        # meaning of the trailing positional arguments is defined by
        # SparseTransformerEncoder, which is not part of this class.
        self.te2d_1 = SparseTransformerEncoder(args.d_attn2, nhead, 4**2, 16, 50, args)
        self.te2d_2 = SparseTransformerEncoder(args.d_attn2, nhead, 4**2, 16, 50, args)
        self.mha2 = SparseTransformerEncoder(args.d_attn2, nhead, 4**2, 16, 50, args)

        # Level 2 -> level 1: 4x transposed-conv upsampling + residual convs.
        self.m3 = nn.ConvTranspose2d(args.d_attn2, args.d_attn1, kernel_size=4, stride=4).to(args.device)
        self.q2a = nn.Conv2d(args.d_attn1, args.d_attn1, kernel_size=3, padding=1).to(args.device)
        self.q2b = nn.Conv2d(args.d_attn1, args.d_attn1, kernel_size=3, padding=1).to(args.device)

        # Level-1 attention.
        self.te1d_1 = SparseTransformerEncoder(args.d_attn1, nhead, 16**2, 8, 50, args)
        self.te1d_2 = SparseTransformerEncoder(args.d_attn1, nhead, 16**2, 8, 50, args)
        self.mha1 = SparseTransformerEncoder(args.d_attn1, nhead, 16**2, 8, 50, args)

        # Level 1 -> level 0: 4x transposed-conv upsampling.
        self.m4 = nn.ConvTranspose2d(args.d_attn1, args.d_attn0, kernel_size=4, stride=4).to(args.device)

        # Level-0 attention runs on a shortened sequence: l1 maps an axis of
        # length 64**2 down to 8**2 and l2 maps it back (see forward()).
        self.l1 = nn.Linear(64**2,8**2).to(args.device)
        self.te0d_1 = SparseTransformerEncoder(args.d_attn0, nhead, 8**2, 4, 50, args)
        self.te0d_2 = SparseTransformerEncoder(args.d_attn0, nhead, 8**2, 4, 50, args)
        self.l2 = nn.Linear(8**2,64**2).to(args.device)

        self.mha0 = SparseTransformerEncoder(args.d_attn0, nhead, 8**2, 4, 50, args)

        # Final convolutional decoder that produces the frame.
        self.Dec1 = CNN_Dec(args, 128, 64, 32, 1,'base').to(args.device)

        # Adapters for the skip features ("u") and the decoder features ("d")
        # at levels 1 and 2. Level 0 has no adapter in this revision.
        self.DA1u = Sum_Domain_Adapter(args.num_tasks, args.d_attn1, args)
        self.DA2u = Sum_Domain_Adapter(args.num_tasks, args.d_attn2, args)
        self.DA1d = Sum_Domain_Adapter(args.num_tasks, args.d_attn1, args)
        self.DA2d = Sum_Domain_Adapter(args.num_tasks, args.d_attn2, args)

        # Adapters for the two flattened feature vectors.
        self.DA_l3u = Domain_Adapter_lin_v1(args.num_tasks, args.d_attn2, args)
        self.DA_l4u = Domain_Adapter_lin_v1(args.num_tasks, args.d_attn1, args)
        self.DA_l3d = Domain_Adapter_lin_v1(args.num_tasks, args.d_attn2, args)

        self.ReLU = nn.ReLU()

    def forward(self, feats0u, feats1u, feats2u, feats3u_f, feats4u_f, x4_sz, one_hot_enc, mode):
        """Decode encoder features into a frame.

        Args:
            feats0u, feats1u, feats2u: encoder skip features for levels 0-2.
                They are flattened from dim 2 onwards and added to 4-D conv
                outputs, so they are presumably (batch, channels, H, W) --
                TODO confirm against the encoder.
            feats3u_f, feats4u_f: the two deepest encoder features. The level-3
                sum is unflattened along dim 1, so these are presumably
                (batch, features) -- TODO confirm.
            x4_sz: size whose entries after the first give the shape the
                level-3 features are unflattened to.
            one_hot_enc: passed to every adapter, to ``D1`` and to ``Dec1``.
            mode: integer gate for the skip connections. Level-4 features are
                kept when ``mode >= 0``, level 3 when ``mode >= 1``, level 2
                when ``mode >= 2``, level 1 when ``mode >= 3`` and level 0 when
                ``mode >= 4``; features that are not kept are multiplied by
                ``False`` (i.e. zeroed).

        Returns:
            ``(out_frames, feats0_n, feats1_n, feats2_n, feats3_n, feats4_n)``
            where ``out_frames`` is the ``Dec1`` output with a singleton axis
            inserted at dim 1 and the remaining entries are the gated (and,
            except for level 0, adapted) skip features.
        """
        # Configuration captured at construction time (no global lookup).
        args = self.args

        # Gate the skip connections according to `mode`.
        feats4_n = feats4u_f*(mode>=0)
        feats3_n = feats3u_f*(mode>=1)
        feats2_n = feats2u*(mode>=2)
        feats1_n = feats1u*(mode>=3)
        feats0_n = feats0u*(mode>=4)

        feats4_n = self.DA_l4u(feats4_n,one_hot_enc)
        feats3_n = self.DA_l3u(feats3_n, one_hot_enc)
        feats2_n = self.DA2u(feats2_n,one_hot_enc)
        feats1_n = self.DA1u(feats1_n,one_hot_enc)

        # Bottleneck -> level 3, merged with the level-3 skip features.
        feats3d_f = self.D1(feats4_n, one_hot_enc)
        feats3d_f = self.DA_l3d(feats3d_f, one_hot_enc)
        feats3d_f = feats3_n + feats3d_f

        feats3d = torch.unflatten(feats3d_f,dim=1,sizes = x4_sz[1:])

        # Two upsampling stages, each followed by two residual conv blocks.
        y2 = self.m1(feats3d)
        y2a = self.q0a(y2)
        y2b = self.ReLU(y2a) + y2
        y2c = self.q0b(y2b)
        y2d = self.ReLU(y2c) + y2b

        y4 = self.m2(y2d)
        y4a = self.q1a(y4)
        y4b = self.ReLU(y4a) + y4
        y4c = self.q1b(y4b)
        feats2d = self.ReLU(y4c) + y4b

        feats2d = self.DA2d(feats2d,one_hot_enc)

        # ---- level 2 ----
        # NOTE(review): the argument-less .squeeze() calls in this method drop
        # every size-1 axis, which would include the batch axis for a batch of
        # one -- confirm that batch size 1 is never used.
        feats2d_a = feats2d.squeeze().flatten(start_dim=2)
        feats2d_a = torch.swapaxes(feats2d_a,1,2)
        x1 = f_pos_encoding_additive(feats2d_a,args).squeeze()
        a1 = self.te2d_1(x1, x1, x1)
        a2 = self.te2d_2(a1, a1, a1)
        a3 = torch.swapaxes(a2,1,2)
        feats2d = feats2d + a3.reshape(feats2u.size())

        feats2n_a = feats2_n.flatten(start_dim=2)
        feats2n_a = torch.swapaxes(feats2n_a,1,2)
        feats2n_a = f_pos_encoding_additive(feats2n_a,args).squeeze()

        at_2 = self.mha2(a2, feats2n_a, feats2n_a)
        at_2 = torch.swapaxes(at_2,1,2)
        feats2d = feats2d + at_2.reshape(feats2d.size())

        # ---- level 2 -> level 1 ----
        feats1d = self.m3(feats2_n + feats2d)
        za = self.q2a(feats1d)
        zb = self.ReLU(za) + feats1d
        zc = self.q2b(zb)
        feats1d = self.ReLU(zc) + zb

        feats1d = self.DA1d(feats1d,one_hot_enc)

        feats1d_a = feats1d.squeeze().flatten(start_dim=2)
        feats1d_a = torch.swapaxes(feats1d_a,1,2)
        x1 = f_pos_encoding_additive(feats1d_a,args).squeeze()
        b1 = self.te1d_1(x1, x1, x1)
        b2 = self.te1d_2(b1, b1, b1)
        b3 = torch.swapaxes(b2,1,2)
        feats1d = feats1d + b3.reshape(feats1d.size())

        feats1n_a = feats1_n.flatten(start_dim=2)
        feats1n_a = torch.swapaxes(feats1n_a,1,2)
        feats1n_a = f_pos_encoding_additive(feats1n_a,args).squeeze()
        at_1 = self.mha1(b2, feats1n_a, feats1n_a)
        # NOTE(review): unlike a3, at_2 and b3, at_1 is reshaped WITHOUT first
        # swapping axes 1 and 2 back. If mha1 returns the same layout as
        # te1d_2 this mixes the two axes -- confirm whether a swapaxes is
        # missing here. Left unchanged because existing checkpoints were
        # trained with this exact computation.
        feats1d = feats1d + at_1.reshape(feats1d.size())

        # ---- level 1 -> level 0 ----
        feats0d = self.m4(feats1_n + feats1d)

        feats0d_a = feats0d.flatten(start_dim=2)
        feats0d_a = torch.swapaxes(feats0d_a,1,2)
        feats0d_a = f_pos_encoding_additive(feats0d_a,args).squeeze()
        # l1 acts on the last axis, so swap, project 64**2 -> 8**2, swap back.
        x1 = torch.swapaxes(self.l1(torch.swapaxes(feats0d_a,1,2)),1,2)
        q1 = self.te0d_1(x1, x1, x1)
        q2 = self.te0d_2(q1, q1, q1)

        feats0n_a = feats0_n.flatten(start_dim=2)
        feats0n_a = torch.swapaxes(feats0n_a,1,2)
        feats0n_a = f_pos_encoding_additive(feats0n_a,args).squeeze()
        xa = torch.swapaxes(self.l1(torch.swapaxes(feats0n_a,1,2)),1,2)
        # Here the skip features are the third argument only; levels 1 and 2
        # pass them as both second and third argument.
        q3 = q2 + self.mha0(q2, q2, xa)
        # l2 restores the 64**2 axis before reshaping to the feature map.
        q4 = self.l2(torch.swapaxes(q3,1,2))
        feats0d = feats0d + q4.reshape(feats0d.size())

        out_frames = self.Dec1(feats0_n + feats0d, one_hot_enc).unsqueeze(1)

        return out_frames, feats0_n, feats1_n, feats2_n, feats3_n, feats4_n

In [None]:
class U_Net_F(nn.Module):
    """Encoder/decoder pair (``Unet_Enc`` + ``Unet_Dec``) wrapped as one module."""

    def __init__(self, args):
        super().__init__()

        # Both halves are built from the same configuration namespace.
        self.E1 = Unet_Enc(args)
        self.D1 = Unet_Dec(args)

    def forward(self, in_frames):
        """Encode ``in_frames`` and decode the result.

        The encoder must return five values and the decoder must return five
        values; the decoder's outputs are returned unchanged as a tuple.
        """
        enc_skip1, enc_skip2, enc_flat4, enc_flat5, bottleneck_size = self.E1(in_frames)
        decoded = self.D1(enc_skip1, enc_skip2, enc_flat4, enc_flat5, bottleneck_size)
        out_frames, feats1n, feats2n, feats4n, feats5n = decoded

        return out_frames, feats1n, feats2n, feats4n, feats5n

### Training Functionality

In [None]:
class Dice_loss:
    """Soft Dice loss plus binary cross-entropy, usable as ``loss(a, b)``.

    Caution on argument names: ``forward`` hands ``y_true`` to
    ``binary_cross_entropy`` as its FIRST argument and ``y_pred`` as its
    second. PyTorch treats the first argument as the predicted probabilities
    and the second as the target, so the first argument of this loss has to be
    the prediction in [0, 1]. The Dice term is symmetric in its two arguments.
    NOTE(review): the training helpers in this notebook call
    ``loss(reconstruction, target)``, which matches that order if this class is
    the loss in use -- confirm.
    """

    def __call__(self, y_true, y_pred):
        return self.forward(y_true, y_pred)

    def dsc(self, y_true, y_pred):
        """Return the smoothed Dice coefficient over all elements of both tensors."""
        smooth = 1.
        first_flat = y_true.flatten()
        second_flat = y_pred.flatten()
        overlap = (first_flat * second_flat).sum()
        numerator = 2. * overlap + smooth
        denominator = first_flat.sum() + second_flat.sum() + smooth
        return numerator / denominator

    def forward(self, y_true, y_pred):
        """Return ``(1 - dsc) + binary_cross_entropy(y_true, y_pred)``."""
        dice_term = 1 - self.dsc(y_true, y_pred)
        bce_term = nn.functional.binary_cross_entropy(y_true, y_pred)
        return dice_term + bce_term

In [None]:
def f_pos_encoding_additive(features,args):
    """Add the ramp ``0/n, 1/n, ..., (n-1)/n`` along the last axis of ``features``.

    ``n`` is the size of the last axis; the ramp is moved to ``args.device``
    and broadcast over all leading axes.

    NOTE(review): the decoder passes tensors built with
    ``flatten(start_dim=2)`` followed by ``swapaxes(1, 2)``; the ramp therefore
    varies along the final axis of that layout and is identical for every
    index of axis 1 -- confirm that this is the intended encoding.
    """
    last_size = features.shape[-1]
    ramp = torch.arange(last_size) / last_size
    return features + ramp.to(args.device)

In [None]:
def single_iter_AE(Enc, Dec, optimizer, loss, loss2, in_frames, target_frames, one_hot_enc, args, train_flag = True, w_p = 1, params_list = None):
    """Run one optimisation step of the encoder/decoder pair on a single batch.

    The batch is encoded, decoded with all skip connections enabled
    (``mode=4``), compared with ``target_frames`` through ``loss``, and an
    ``f_L1`` penalty on the attention blocks of both networks is added before
    back-propagation, gradient-norm clipping (max norm 1) and the optimizer
    step.

    Args:
        Enc, Dec: encoder and decoder. ``Enc`` must expose ``te{0,1,2}u_{1,2}``
            and ``Dec`` must expose ``te{0,1,2}d_{1,2}`` and ``mha{0,1,2}``.
        optimizer: optimizer whose gradients are zeroed and which is stepped.
        loss: callable ``loss(reconstruction, target)``.
        loss2: unused.
        in_frames, target_frames: input batch and its target.
        one_hot_enc: forwarded to ``Enc`` and ``Dec``.
        args: unused. Call sites that pass something else positionally in
            this slot keep working because the value is ignored.
        train_flag, w_p: unused -- a step is always taken.
            NOTE(review): confirm whether ``train_flag=False`` was meant to
            skip the update.
        params_list: parameters whose joint gradient norm is clipped. When
            omitted, the parameters registered in ``optimizer`` are used.
            (This function previously read a notebook-level global named
            ``params_list`` here.)

    Returns:
        The total loss (reconstruction + penalty) as a detached tensor.
    """
    if params_list is None:
        # Clip exactly what the optimizer is about to update instead of
        # depending on a global variable defined in some other cell.
        params_list = [p for group in optimizer.param_groups for p in group['params']]

    optimizer.zero_grad() # Zero out gradients

    feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz = Enc(in_frames, one_hot_enc)

    # Only the full reconstruction (mode=4) contributes to the loss; the
    # weighted multi-mode loss over modes 0-3 is disabled in this revision.
    rec_frames_4, _, _, _, _, _ = Dec(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=4)
    loss_i = loss(rec_frames_4, target_frames)

    # Penalty weights per level (level 0 weighted most, level 2 least).
    lambda0 = 1
    lambda1 = 1/2
    lambda2 = 1/4

    p1 = f_L1(Enc.te0u_1.parameters()) + f_L1(Enc.te0u_2.parameters())
    p2 = f_L1(Enc.te1u_1.parameters()) + f_L1(Enc.te1u_2.parameters())
    p3 = f_L1(Enc.te2u_1.parameters()) + f_L1(Enc.te2u_2.parameters())
    p4 = f_L1(Dec.te0d_1.parameters()) + f_L1(Dec.te0d_2.parameters()) + f_L1(Dec.mha0.parameters())
    p5 = f_L1(Dec.te1d_1.parameters()) + f_L1(Dec.te1d_2.parameters()) + f_L1(Dec.mha1.parameters())
    p6 = f_L1(Dec.te2d_1.parameters()) + f_L1(Dec.te2d_2.parameters()) + f_L1(Dec.mha2.parameters())

    penalty = lambda0*(p1+p4) + lambda1*(p2+p5) + lambda2*(p3+p6)
    loss_i = loss_i + penalty

    loss_i.backward()
    torch.nn.utils.clip_grad_norm_(params_list, 1)
    optimizer.step()

    # Detach so callers that store the value do not keep the autograd graph.
    return loss_i.detach()

In [None]:
def single_epoch_AE(Train_Data, Target_Data, Enc, Dec, one_hot_enc, optimizer, loss, loss2, params_list, args):
    """Train for one epoch: shuffle the data and run ``args.num_batches`` steps.

    Args:
        Train_Data, Target_Data: inputs and targets; both are indexed with the
            same shuffled index array and then sliced along the first of five
            axes in chunks of ``args.batch_size``.
        Enc, Dec, optimizer, loss, loss2, one_hot_enc: forwarded to
            ``single_iter_AE``.
        params_list: accepted for interface compatibility but not forwarded.
            The ninth positional slot of ``single_iter_AE`` is ``args``;
            passing the list there (as this function used to) bound the list
            to ``args`` and the namespace to ``train_flag``.
        args: namespace providing ``num_batches`` and ``batch_size``.

    Returns:
        NumPy array of length ``args.num_batches`` with one loss per batch.
    """
    random_idxs = np.arange(len(Train_Data))
    np.random.shuffle(random_idxs)
    train_data_shuffle = Train_Data[random_idxs]
    target_data_shuffle = Target_Data[random_idxs]

    epoch_losses = np.zeros(args.num_batches)

    for it in range(args.num_batches):
        start, stop = it*args.batch_size, (it+1)*args.batch_size

        in_frames = train_data_shuffle[start:stop,:,:,:,:]
        target_frames = target_data_shuffle[start:stop,:,:,:,:]

        batch_loss = single_iter_AE(Enc, Dec, optimizer, loss, loss2, in_frames, target_frames, one_hot_enc, args)

        # Convert explicitly: writing a tensor straight into a NumPy array
        # relies on implicit coercion, which can fail for tensors that require
        # grad or live on an accelerator.
        epoch_losses[it] = float(batch_loss)

    return epoch_losses

In [None]:
# def single_iter_AE(enc_base, dec_base, optimizer, loss, loss2, in_frames, target_frames, one_hot_enc, args, train_flag = True, w_p = 1):

#     optimizer.zero_grad() # Zero out gradients

#     feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz = enc_base(in_frames, one_hot_enc)
    
# #     rec_frames_base_0, _, _, _, _, _ = dec_base(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=0)
# #     rec_frames_base_1, _, _, _, _, _ = dec_base(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=1)
# #     rec_frames_base_2, _, _, _, _, _ = dec_base(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=2)
# #     rec_frames_base_3, _, _, _, _, _ = dec_base(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=3)
#     rec_frames_base_4, _, _, _, _, _ = dec_base(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=4)

# #     loss_base_target_0 = loss(rec_frames_base_0, target_frames)
# #     loss_base_target_1 = loss(rec_frames_base_1, target_frames)
# #     loss_base_target_2 = loss(rec_frames_base_2, target_frames)
# #     loss_base_target_3 = loss(rec_frames_base_3, target_frames)
#     loss_base_target_4 = loss(rec_frames_base_4, target_frames)
# #     loss_base_target = loss_base_target_0/12 + loss_base_target_1/8 + loss_base_target_2/4 + loss_base_target_3/2 + loss_base_target_4
#     loss_base_target = loss_base_target_4

#     loss_i = loss_base_target
    
#     loss_i.backward()
# #     params_list = list(Enc_base.parameters()) + list(Dec_base.parameters())
#     torch.nn.utils.clip_grad_norm_(params_list, 1)
#     optimizer.step()
    
#     return loss_i 

In [None]:
# def single_epoch_AE(Train_Data, Target_Data, enc_base, dec_base, one_hot_enc, optimizer, loss, loss2, args):
    
#     random_idxs = np.arange(len(Train_Data))
#     np.random.shuffle(random_idxs)
    
#     num_batches = int(len(Train_Data)/args.batch_size)
    
#     train_data_shuffle = Train_Data[random_idxs]
#     target_data_shuffle = Target_Data[random_idxs]
    
#     epoch_losses = np.zeros(num_batches)
#     for it in np.arange(num_batches):

#         in_frames = train_data_shuffle[it*args.batch_size:(it+1)*args.batch_size,:,:,:,:]
#         target_frames = target_data_shuffle[it*args.batch_size:(it+1)*args.batch_size,:,:,:,:]

#         epoch_losses[it] = single_iter_AE(enc_base, dec_base, optimizer, loss, loss2, in_frames, target_frames, one_hot_enc, args)
        
#     return epoch_losses

In [None]:
def single_iter_FAZ(enc_base, dec_base, dec_aux, optimizer, loss, loss2, in_frames, target_frames, one_hot_enc, args, params_list, train_flag = True, w_p = 1):
    """Run one optimisation step with a base decoder and an auxiliary decoder.

    Both decoders reconstruct from the same encoder features with all skip
    connections enabled (``mode=4``). The objective is

        loss(base, target) + 0.5 * loss(aux, target) + 0.25 * loss(base, aux.detach())

    plus an ``f_L1`` penalty on the attention blocks of ``enc_base`` and
    ``dec_base``. Gradients are clipped to a joint norm of 1 over
    ``params_list`` before the optimizer step.

    Earlier revisions of this function also ran both decoders for modes 0-3
    although the multi-mode loss terms were disabled; those eight forward
    passes never influenced the loss and have been removed.

    Args:
        enc_base: encoder, called as ``enc_base(in_frames, one_hot_enc)``; must
            expose ``te{0,1,2}u_{1,2}``.
        dec_base: decoder called with ``one_hot_enc``; must expose
            ``te{0,1,2}d_{1,2}`` and ``mha{0,1,2}``.
        dec_aux: decoder called without ``one_hot_enc``.
        optimizer: optimizer whose gradients are zeroed and which is stepped.
        loss: callable ``loss(prediction, target)``.
        loss2: unused.
        in_frames, target_frames: input batch and its target.
        one_hot_enc: forwarded to ``enc_base`` and ``dec_base``.
        args: unused.
        params_list: parameters whose joint gradient norm is clipped.
        train_flag, w_p: unused -- a step is always taken.
            NOTE(review): confirm whether ``train_flag=False`` was meant to
            skip the update.

    Returns:
        The total loss as a detached tensor.
    """
    optimizer.zero_grad() # Zero out gradients

    feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz = enc_base(in_frames, one_hot_enc)

    # Full reconstructions from both decoders.
    rec_frames_base_4, _, _, _, _, _ = dec_base(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=4)
    rec_frames_aux_4, _, _, _, _, _ = dec_aux(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, mode=4)

    loss_base_target = loss(rec_frames_base_4, target_frames)
    loss_aux_target = loss(rec_frames_aux_4, target_frames)
    # The auxiliary output is detached here, so this term only produces
    # gradients through the base decoder's output.
    loss_base_aux = loss(rec_frames_base_4, rec_frames_aux_4.detach())

    loss_i = loss_base_target + 0.5*loss_aux_target + 0.25*loss_base_aux

    # Penalty weights per level (level 0 weighted most, level 2 least).
    lambda0 = 1
    lambda1 = 1/2
    lambda2 = 1/4

    p1 = f_L1(enc_base.te0u_1.parameters()) + f_L1(enc_base.te0u_2.parameters())
    p2 = f_L1(enc_base.te1u_1.parameters()) + f_L1(enc_base.te1u_2.parameters())
    p3 = f_L1(enc_base.te2u_1.parameters()) + f_L1(enc_base.te2u_2.parameters())
    p4 = f_L1(dec_base.te0d_1.parameters()) + f_L1(dec_base.te0d_2.parameters()) + f_L1(dec_base.mha0.parameters())
    p5 = f_L1(dec_base.te1d_1.parameters()) + f_L1(dec_base.te1d_2.parameters()) + f_L1(dec_base.mha1.parameters())
    p6 = f_L1(dec_base.te2d_1.parameters()) + f_L1(dec_base.te2d_2.parameters()) + f_L1(dec_base.mha2.parameters())

    penalty = lambda0*(p1+p4) + lambda1*(p2+p5) + lambda2*(p3+p6)
    loss_i = loss_i + penalty

    loss_i.backward()
    torch.nn.utils.clip_grad_norm_(params_list, 1)
    optimizer.step()

    # Detach so callers that store the value do not keep the autograd graph.
    return loss_i.detach()

In [None]:
def single_epoch_FAZ(Train_Data, Target_Data, enc_base, dec_base, dec_aux, one_hot_enc, optimizer, loss, loss2, args, params_list):
    """Train for one epoch over every full batch of the shuffled data.

    Args:
        Train_Data, Target_Data: inputs and targets; both are indexed with the
            same shuffled index array and then sliced along the first of five
            axes in chunks of ``args.batch_size``.
        enc_base, dec_base, dec_aux, one_hot_enc, optimizer, loss, loss2,
        params_list: forwarded unchanged to ``single_iter_FAZ``.
        args: namespace providing ``batch_size``; also forwarded.

    Returns:
        NumPy array with one loss per batch. Its length is
        ``len(Train_Data) // args.batch_size``; trailing samples that do not
        fill a whole batch are not used in this epoch.
    """
    random_idxs = np.arange(len(Train_Data))
    np.random.shuffle(random_idxs)

    num_batches = len(Train_Data) // args.batch_size

    train_data_shuffle = Train_Data[random_idxs]
    target_data_shuffle = Target_Data[random_idxs]

    epoch_losses = np.zeros(num_batches)
    for it in range(num_batches):
        start, stop = it*args.batch_size, (it+1)*args.batch_size

        in_frames = train_data_shuffle[start:stop,:,:,:,:]
        target_frames = target_data_shuffle[start:stop,:,:,:,:]

        batch_loss = single_iter_FAZ(enc_base, dec_base, dec_aux, optimizer, loss, loss2, in_frames, target_frames, one_hot_enc, args, params_list)

        # Convert explicitly: writing a tensor straight into a NumPy array
        # relies on implicit coercion, which can fail for tensors that require
        # grad or live on an accelerator.
        epoch_losses[it] = float(batch_loss)

    return epoch_losses

In [None]:
# random_idxs = np.arange(len(Train_Data_aug))
# np.random.shuffle(random_idxs)
# train_data_shuffle = Train_Data_aug[random_idxs]
# nhead = 16

# epoch_losses = np.zeros(args.num_batches)

# it = 0
# in_frames = train_data_shuffle[it*args.batch_size:(it+1)*args.batch_size,:,:,:,:]

# Enc1 = CNN_Enc(args, 1, 32, 64, 128,mode='non').to(args.device)

# ea0 = nn.Conv2d(args.d_attn0, args.d_attn1, kernel_size=4, stride=4, padding=0).to(args.device)
# ea1 = nn.Conv2d(args.d_attn1, args.d_attn1, kernel_size=3, stride = 1, padding=1).to(args.device)
# ea2 = nn.Conv2d(args.d_attn1, args.d_attn1, kernel_size=3, stride = 1, padding=1).to(args.device)
# ea3 = nn.Conv2d(args.d_attn1, args.d_attn1, kernel_size=3, stride = 1, padding=1).to(args.device)

# e0 = nn.Conv2d(args.d_attn1, args.d_attn2, kernel_size=3, padding=1).to(args.device) # output: 136x136x256
# e0a = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device) # output: 136x136x256
# e0b = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device) # output: 136x136x256
# e0c = nn.Conv2d(args.d_attn2, args.d_attn2, kernel_size=3, padding=1).to(args.device) # output: 136x136x256

# pool0 = nn.MaxPool2d(kernel_size=4, stride=4).to(args.device) # output: 68x68x256

# te1u_1 = torch.nn.TransformerEncoderLayer(d_model=args.d_attn1, nhead=nhead, dim_feedforward=args.d_attn1, dropout=0, layer_norm_eps=1e-05, batch_first=True).to(args.device)
# te1u_2 = torch.nn.TransformerEncoderLayer(d_model=args.d_attn1, nhead=nhead, dim_feedforward=args.d_attn1, dropout=0, layer_norm_eps=1e-05, batch_first=True).to(args.device)

# DA1 = Sum_Domain_Adapter(args.num_tasks, args.d_attn1, args).to(args.device)

# dim = 256
# SAE1 = Sparse_AE(dim,4,50,args)

# ReLU = nn.ReLU()

# feats0u = Enc1(in_frames.squeeze().unsqueeze(1), one_hot_enc)

# feats1u = ea0(feats0u)        
# feats1ua = ea1(feats1u)
# feats1ub = ReLU(feats1ua) + feats1u
# feats1uc = ea2(feats1ub)
# feats1ud = ReLU(feats1uc) + feats1ub
# feats1ue = ea3(feats1ud)  
# feats1uf = ReLU(feats1ue) + feats1ud

# feats1u = DA1(feats1uf,one_hot_enc)

# feats1u_a = feats1u.flatten(start_dim=2)
# feats1u_a = torch.swapaxes(feats1u_a,1,2)
# feats1u_a = f_pos_encoding_additive(feats1u_a,args).squeeze()
# q1 = te1u_1(feats1u_a)
# q2 = te1u_2(q1)

# SAE1(q2).size()

In [None]:
# # num_params(te1u_1)
# SAE1 = Sparse_AE(256,4,50,args)
# print(num_params(SAE1))

# SAE2 = Sparse_AE(512,8,50,args)
# num_params(SAE2)

# 64*50*2
# 54*512

In [None]:
def display_AE(train_data, target_data, enc_base, dec_base, one_hot_enc, args, i=0):
    """Show input, target and reconstruction for one sample of a random batch.

    Args:
        train_data: object whose ``.data`` attribute holds the inputs; it must
            support ``len()`` and indexing with an index array.
        target_data: targets, indexed directly with the same index array.
        enc_base, dec_base: encoder and decoder; the decoder is run with
            ``mode=4``.
        one_hot_enc: forwarded to both networks.
        args: namespace providing ``batch_size``.
        i: position within the batch of the sample to display.
    """
    order = np.arange(len(train_data.data))
    np.random.shuffle(order)

    # Only the first batch of the shuffled data is used.
    first_batch = slice(0, args.batch_size)
    in_frames = train_data.data[order][first_batch,:,:,:,:]
    target_frames = target_data[order][first_batch,:,:,:,:]

    with torch.no_grad():
        f0, f1, f2, f3_flat, f4_flat, bottleneck_size = enc_base(in_frames, one_hot_enc)
        reconstruction, _, _, _, _, _ = dec_base(f0, f1, f2, f3_flat, f4_flat, bottleneck_size, one_hot_enc, mode=4)

    # One figure each, in this order: input, target, reconstruction.
    for frames in (in_frames, target_frames, reconstruction):
        plt.imshow(frames[i,0,0,:,:].detach().cpu().numpy())
        plt.show()

In [None]:
def display_FAZ(train_data, target_data, enc_base, dec_base, one_hot_enc, args, i=0):
    """Show input, target and full reconstruction for one sample of a random batch.

    Earlier revisions also ran the decoder for modes 0-3, but only the
    ``mode=4`` reconstruction was ever plotted; the four unused decoder passes
    have been removed.

    Args:
        train_data: object whose ``.data`` attribute holds the inputs; it must
            support ``len()`` and indexing with an index array.
        target_data: targets, indexed directly with the same index array.
        enc_base, dec_base: encoder and decoder; the decoder is run with
            ``mode=4``.
        one_hot_enc: forwarded to both networks.
        args: namespace providing ``batch_size``.
        i: position within the batch of the sample to display.
    """
    random_idxs = np.arange(len(train_data.data))
    np.random.shuffle(random_idxs)
    train_data_shuffle = train_data.data[random_idxs]
    target_data_shuffle = target_data[random_idxs]

    # Only the first batch of the shuffled data is used.
    in_frames = train_data_shuffle[0:args.batch_size,:,:,:,:]
    target_frames = target_data_shuffle[0:args.batch_size,:,:,:,:]

    with torch.no_grad():
        feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz = enc_base(in_frames, one_hot_enc)
        rec_frames_base_4, _, _, _, _, _ = dec_base(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=4)

    # One figure each, in this order: input, target, reconstruction.
    for frames in (in_frames, target_frames, rec_frames_base_4):
        plt.imshow(frames[i,0,0,:,:].detach().cpu().numpy())
        plt.show()

In [None]:
def f_display_all(data, target_data, enc_base, dec_base, one_hot_enc, test_split, args, test_folder):
    """Score the mode-4 decoder output against the targets and report Dice statistics.

    The samples are visited in their given order, in full batches of
    ``args.batch_size`` (a trailing partial batch is not scored).  Target and
    prediction are both binarised at 0.5 before the per-sample Dice is computed.

    Side effects: plots the per-sample Dice curve and prints, in this order,
    the mean Dice, the area-weighted mean Dice, the Dice vector, its median,
    the boolean "above median" vector and the mean Dice of the above-median
    samples.

    Args:
        data, target_data: 5-D tensors indexed as [sample, 0, 0, :, :].
        enc_base, dec_base: encoder / decoder called under ``torch.no_grad()``.
        one_hot_enc: task encoding forwarded to both networks.
        test_split: iterable paired element-wise with the per-sample Dice values.
        args: namespace providing ``batch_size``.
        test_folder: not used by the current implementation.

    Returns:
        np.ndarray holding the entries of ``test_split`` whose sample scored
        strictly above the median Dice.
    """
    bs = args.batch_size
    n_batches = int(len(data)/bs)

    dice_v, dices_v, target_area_v = [], [], []

    for it in range(n_batches):
        lo, hi = it*bs, (it+1)*bs
        in_frames = data[lo:hi,:,:,:,:]
        target_frames = target_data[lo:hi,:,:,:,:]

        with torch.no_grad():
            feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz = enc_base(in_frames, one_hot_enc)
            rec_frames, _, _, _, _, _ = dec_base(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=4)

        for i in range(bs):
            target_frame = target_frames[i,0,0,:,:].detach().cpu().numpy()
            rec_frame = rec_frames[i,0,0,:,:].detach().cpu().numpy()

            # Hard 0/1 masks for scoring.
            target_frame_d = 1.*(target_frame>0.5)
            rec_frame_d = 1.*(rec_frame>0.5)

            dice_i, _ = f_dice_jac(target_frame_d,rec_frame_d)
            dice_s, _ = f_dice_jac_scaled(target_frame_d,rec_frame_d)

            dice_v.append(dice_i)
            dices_v.append(dice_s)
            target_area_v.append(np.sum(target_frame_d))

    dice_arr = np.array(dice_v)

    plt.plot(dice_v)
    plt.grid()
    print('Mean Dice =', np.mean(dice_v))
    print('Mean Dice (scaled by area) =', np.sum(dices_v)/np.sum(target_area_v))
    plt.show()

    print(dice_arr)

    med = np.median(dice_arr)
    print(med)
    above_med = dice_arr > med
    print(above_med)

    # Keep the test_split entries whose sample beat the median.
    kept = [entry for entry, flag in zip(test_split, above_med) if flag]

    print(np.sum(above_med*dice_arr)/np.sum(above_med))

    return np.array(kept)

In [None]:
def f_encode(feats1u,args):
    """Flatten a feature map into a token sequence and add the positional encoding.

    All dimensions from index 2 onward are flattened into one, axes 1 and 2
    are swapped, and the result is passed through ``f_pos_encoding_additive``;
    singleton dimensions of that output are squeezed away.
    # presumably (B, C, H, W) -> (B, H*W, C) -- TODO confirm against callers
    """
    tokens = torch.swapaxes(feats1u.flatten(start_dim=2), 1, 2)
    return f_pos_encoding_additive(tokens, args).squeeze()

### Evaluation functionality

In [None]:
def f_dice_jac(mask1,mask2):
    """Return ``(dice, jaccard)`` for two masks given as NumPy arrays.

    dice    = 2*sum(mask1*mask2) / (sum(mask1) + sum(mask2))
    jaccard = sum(mask1*mask2) / count((mask1 + mask2) > 0)

    No guard is applied for two empty masks (the ratios are then 0/0).
    """
    overlap = np.sum(mask1*mask2)
    either = np.sum(1*((mask1 + mask2) > 0))
    total = np.sum(mask1) + np.sum(mask2)

    return 2*overlap/total, overlap/either

In [None]:
def f_dice_jac_scaled(mask1,mask2):
    """Return ``(dice, jaccard)`` each multiplied by ``sum(mask1)``.

    The Dice and Jaccard values are computed exactly as in ``f_dice_jac``;
    multiplying by the area of ``mask1`` lets callers form an area-weighted
    average by dividing the summed scores by the summed areas.
    """
    overlap = np.sum(mask1*mask2)
    either = np.sum(1*((mask1 + mask2) > 0))
    total = np.sum(mask1) + np.sum(mask2)
    area1 = np.sum(mask1)

    dice = 2*overlap/total
    jaccard = overlap/either

    return dice*area1, jaccard*area1

In [None]:
def f_Tot_Dice(Data, Target_Data, Enc_base, Dec_base, one_hot_enc, args):
    """Area-weighted mean Dice of the mode-4 decoder output over a data set.

    Targets and predictions are binarised at 0.5.  Samples are processed in
    full batches of ``args.batch_size`` (a trailing partial batch is ignored)
    and samples whose binarised target is empty are skipped.

    Args:
        Data: input tensor, sliced along its first dimension into batches.
        Target_Data: target tensor aligned with ``Data``.
        Enc_base, Dec_base: encoder / decoder, called under ``torch.no_grad()``.
        one_hot_enc: task encoding forwarded to both networks.
        args: namespace providing ``batch_size``.

    Returns:
        ``sum_i(dice_i * area_i) / sum_i(area_i)`` where ``area_i`` is the
        pixel count of the binarised target.  If no sample has a non-empty
        target the ratio is 0/0 (NaN), as before.
    """
    Target_Data = 1.*(Target_Data > 0.5)

    bs = args.batch_size
    num_batches = int(len(Data)/bs)

    dices_v = []
    target_area_v = []

    for it in range(num_batches):
        in_frames = Data[it*bs:(it+1)*bs]
        target_frames = Target_Data[it*bs:(it+1)*bs]

        with torch.no_grad():
            feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz = Enc_base(in_frames, one_hot_enc)
            rec_frames, _, _, _, _, _ = Dec_base(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=4)

        rec_frames = 1.*(rec_frames > 0.5)

        # One device-to-host copy per batch instead of one per sample.
        target_np = target_frames.detach().cpu().numpy()
        rec_np = rec_frames.detach().cpu().numpy()

        for i in range(bs):
            # Select the sample first and squeeze afterwards: squeezing the
            # whole batch would also drop the batch axis when batch_size == 1,
            # and [i] would then pick an image row instead of a sample.
            target_frame = target_np[i].squeeze()
            if np.sum(target_frame) > 0:
                rec_frame = rec_np[i].squeeze()

                dice_is, _ = f_dice_jac_scaled(target_frame,rec_frame)
                dices_v.append(dice_is)
                target_area_v.append(np.sum(target_frame))

    return np.sum(dices_v)/np.sum(target_area_v)

### Autoencoder Training Loop

In [None]:
# #####################Training loop ###########################

# # args.mode = 1

# for task_idx in tqdm(np.arange(args.num_tasks)):

#     Enc = Enc_base

#     Dec = Dec_aux_list[task_idx]
#     Dec_name = Dec_name_list[task_idx]

#     Train_Data  = Train_Data_list[task_idx]
#     Train_Target_Data = Target_Data_list[task_idx]
#     args.num_batches = int(len(Train_Data)/args.N)

#     args.training_mode = 1
#     args.num_batches = int(len(Train_Data)/args.N)

#     # RM_path_save = resume_AE_ckpt.joinpath('FAZ_RM_ckpt')

#     args.num_epochs = 100
#     args.show_example_epochs = 1
#     args.save_epochs = 1

#     log_mean_epoch_losses = np.zeros(args.num_epochs)

#     loss = nn.BCELoss() # <-- Use this one
#     loss2 = nn.L1Loss()

#     lr = 1E-4
#     params_list = list(Enc.parameters()) + list(Dec.parameters())
#     optimizer = torch.optim.Adam(params_list, lr=lr, betas=(0.9, 0.999))
    
#     one_hot_enc = torch.zeros(args.num_tasks).to(args.device)
#     one_hot_enc[task_idx] = 1

#     # train_dice_v = []
#     # train_dice_v2 = []
#     # test_dice_v = []
#     # test_dice_v2 = []

#     # Train for a maximum of max_epochs:
#     for epoch in tqdm(np.arange(args.num_epochs), desc="Training progress..."):
        
#         epoch_losses = single_epoch_AE(Train_Data, Target_Data, Enc, Dec, one_hot_enc, optimizer, loss, loss2, args)
#         log_mean_epoch_losses[epoch] = np.log(np.mean(epoch_losses))
#     #     coord_errors[epoch] = f_coord_error(train_data,train_data_traj,VPTR_Enc, C1, args)

#         if np.mod(epoch,args.show_example_epochs) == 0:
#             display_AE(Train_Data, Target_Data, Enc, Dec, one_hot_enc, args)

#             plt.plot(log_mean_epoch_losses[0:epoch])
#             plt.grid()
#             plt.show()

#     #         train_dice = f_Tot_Dice(Train_Data,args)
#     #         print(train_dice)
#     #         train_dice_v.append(train_dice)
#     #         plt.plot(train_dice_v)
#     #         plt.grid()
#     #         plt.show()

#     #         test_dice, test_dice2 = f_Tot_Dice(Test_Data,args)
#     #         print(test_dice)
#     #         print(test_dice2)
#     #         test_dice_v.append(test_dice)
#     #         test_dice_v2.append(test_dice2)
#     #         plt.plot(test_dice_v)
#     #         plt.plot(test_dice_v2)
#     #         plt.grid()
#     #         plt.show()

#     #         plt.plot(coord_errors[0:epoch])
#     #         plt.grid()
#     #         plt.show()
#     #     if np.mod(epoch,args.save_epochs) == 0:
#     #         torch.save(model.state_dict(), RM_path_save)

# #     Dec_aux_load = resume_AE_ckpt.joinpath(Dec_name)
# #     torch.save(Dec.state_dict(), Dec_aux_load)

In [None]:
def save_models(Enc_base, Dec_base, Dec_aux_list, Dec_name_list, text_idx, model_folder):
    """Save the state dicts of the base encoder/decoder and any auxiliary decoders.

    Files are written to ``<model_folder>/test_idx_<idx>/`` and named
    ``<model_name>__aya__<year>_<month>_<day>`` using today's date
    (month and day are not zero-padded).

    Args:
        Enc_base, Dec_base: modules saved as 'Enc_base' and 'Dec_base'.
        Dec_aux_list: auxiliary modules to save (may be empty).
        Dec_name_list: file-name stems, one per entry of ``Dec_aux_list``.
        text_idx: cross-validation fold index used in the folder name.  The
            spelling is kept so existing keyword callers keep working; the
            body previously read an undefined/global ``test_idx`` instead of
            this argument.
        model_folder: root folder (string) under which the fold folder is created.
    """
    test_idx = text_idx

    fold_dir = model_folder + '/test_idx_' + str(test_idx)
    os.makedirs(fold_dir, exist_ok=True)

    date = datetime.date.today()
    date_str = str(date.year) + '_' + str(date.month) + '_' + str(date.day)

    models = [Enc_base, Dec_base] + [Dec_aux_list[i] for i in range(len(Dec_aux_list))]
    names = ['Enc_base', 'Dec_base'] + [Dec_name_list[i] for i in range(len(Dec_aux_list))]

    for model, model_name in zip(models, names):
        save_path = fold_dir + '/' + model_name + '__' + 'aya' + '__' + date_str
        torch.save(model.state_dict(), save_path)

    print('Done')

In [None]:
def load_models_AE(Enc_base, Dec_base, test_idx, model_folder, date_str):
    """Load saved weights into the base encoder and decoder, in place.

    Checkpoints are read from
    ``<model_folder>/test_idx_<test_idx>/<name>__aya__<date_str>`` with
    ``<name>`` being 'Enc_base' and 'Dec_base' respectively.
    """
    fold_prefix = model_folder + '/test_idx_' + str(test_idx) + '/'

    for model, model_name in ((Enc_base, 'Enc_base'), (Dec_base, 'Dec_base')):
        load_path = fold_prefix + model_name + '__' + 'aya' + '__' + date_str
        model.load_state_dict(torch.load(load_path))

    print('Done')

In [None]:
def load_models_FAZ(Enc_base, Dec_base, Dec_aux_list, Dec_name_list, test_idx, model_folder, date_str):
    """Load saved weights into the base encoder/decoder and every auxiliary decoder.

    Each checkpoint is read from
    ``<model_folder>/test_idx_<test_idx>/<name>__aya__<date_str>``.  The base
    models use the names 'Enc_base' and 'Dec_base'; the i-th auxiliary decoder
    uses ``Dec_name_list[i]``.  All modules are updated in place.
    """
    fold_prefix = model_folder + '/test_idx_' + str(test_idx) + '/'

    def _restore(model, model_name):
        # Build the checkpoint path for this name and load it into the module.
        load_path = fold_prefix + model_name + '__' + 'aya' + '__' + date_str
        model.load_state_dict(torch.load(load_path))

    _restore(Enc_base, 'Enc_base')
    _restore(Dec_base, 'Dec_base')

    for i in range(len(Dec_aux_list)):
        _restore(Dec_aux_list[i], Dec_name_list[i])

    print('Done')

In [None]:
# # Total Variation (TV) regularization helps promote spatial smoothness and connectivity.
# def total_variation_loss(mask):
#     dx = np.diff(mask, axis=0)  # Differences along the x-axis
#     dy = np.diff(mask, axis=1)  # Differences along the y-axis

## TRAIN

In [None]:
def f_normalize(x):
    """Divide every image in a stack by that image's own maximum value.

    Args:
        x: tensor whose first dimension indexes the images (e.g. (N, H, W)).

    Returns:
        Tensor of the same shape as ``x``.  An image whose maximum is exactly
        zero is returned unchanged (all zeros) instead of becoming NaN through
        a 0/0 division.
    """
    # Per-image maximum over all non-batch dimensions.
    peak, _ = torch.max(x.flatten(start_dim=1), dim=1)
    # Avoid 0/0 for all-zero images (e.g. an empty mask).
    peak = torch.where(peak == 0, torch.ones_like(peak), peak)
    # Explicit broadcast shape instead of the deprecated `.T` on an N-D tensor.
    return x / peak.view(-1, *([1]*(x.dim() - 1)))

################## LOAD IN DATA ######################
# Loads the Photocoagulation, Healthy, ALZ, AMD and DR image/mask stacks from
# .mat files into float32 torch tensors and rescales each with f_normalize.
# Every section first changes the working directory with %cd and then reads
# its files by relative name, so the statements must run in this order.
# The names imgs_SVC, masks_SVC_aya, masks_SVC_nathan and masks_SVC are
# temporaries that are overwritten by each cohort section.

############
a_n_mode = 0 # Use Aya data
# a_n_mode = 1 # Use Nathan data
############
# NOTE(review): a_n_mode only selects `masks_SVC` in each section below, and
# `masks_SVC` is not read again in this cell -- the cohort variables always take
# the aya masks as primary and the nathan masks as `_alt`.

## Photocoagulation Data

%cd Y:/FAZ/Photocoagulation/Laser and Mask - All-orig-extractedImages/mats

Photocoag_Imgs = torch.from_numpy(sio.loadmat('resized.mat')["Imgs_rsz"].astype(np.float32))
Photocoag_Masks = torch.from_numpy(sio.loadmat('real_masks.mat')["Imgs_mask"].astype(np.float32))

# Photocoag_Imgs /= torch.max(Photocoag_Imgs)
# Photocoag_Masks /= torch.max(Photocoag_Masks)

Photocoag_Imgs = f_normalize(Photocoag_Imgs)
Photocoag_Masks = f_normalize(Photocoag_Masks)

# L = len(Photocoag_Masks)
# for i in np.arange(L):
#     x = (Photocoag_Imgs + Photocoag_Masks)[i].detach().cpu().numpy().squeeze()
#     plt.imshow(x)
#     plt.show()

## Healthy Data

%cd Y:\FAZ\FAZ-cropped-SVC-graded\Healthy\all_Healthy\PR\mats

# Images plus one mask stack per annotator (aya / nathan), all stored under key "Imgs".
imgs_SVC = torch.from_numpy(sio.loadmat('SVC_resized.mat')["Imgs"].astype(np.float32))
masks_SVC_aya = torch.from_numpy(sio.loadmat('SVC_real_masks_aya.mat')["Imgs"].astype(np.float32))
masks_SVC_nathan = torch.from_numpy(sio.loadmat('SVC_real_masks_nathan.mat')["Imgs"].astype(np.float32))

# imgs_SVC /= torch.max(imgs_SVC)
# masks_SVC_aya /= torch.max(masks_SVC_aya)
# masks_SVC_nathan /= torch.max(masks_SVC_nathan)
imgs_SVC = f_normalize(imgs_SVC)
masks_SVC_aya = f_normalize(masks_SVC_aya)
masks_SVC_nathan = f_normalize(masks_SVC_nathan)
# masks_SVC_aya = 1.0*(masks_SVC_aya > 0.1)
# masks_SVC_nathan = 1.0*(masks_SVC_nathan > 0.1)

# masks_SVC = (masks_SVC_aya + masks_SVC_nathan)/2
if a_n_mode == 0:
    masks_SVC = masks_SVC_aya
else:
    masks_SVC = masks_SVC_nathan

# Healthy_Imgs_SVC = imgs_SVC
# Healthy_Masks_SVC = masks_SVC

# Primary masks = aya annotation, alternative masks = nathan annotation.
Healthy_Imgs_SVC = imgs_SVC
Healthy_Masks_SVC = masks_SVC_aya
Healthy_Masks_SVC_alt = masks_SVC_nathan

# L = len(masks_SVC)
# for i in np.arange(L):
#     x = masks_SVC[i].detach().cpu().numpy().squeeze()
#     plt.imshow(x)
#     plt.show()
    
## ALZ Data

%cd Y:\FAZ\FAZ-cropped-SVC-graded\Alz\all_Alz\PR\mats

# Same file names and processing as the Healthy section, read from the ALZ folder.
imgs_SVC = torch.from_numpy(sio.loadmat('SVC_resized.mat')["Imgs"].astype(np.float32))
masks_SVC_aya = torch.from_numpy(sio.loadmat('SVC_real_masks_aya.mat')["Imgs"].astype(np.float32))
masks_SVC_nathan = torch.from_numpy(sio.loadmat('SVC_real_masks_nathan.mat')["Imgs"].astype(np.float32))

# imgs_SVC /= torch.max(imgs_SVC)
# masks_SVC_aya /= torch.max(masks_SVC_aya)
# masks_SVC_nathan /= torch.max(masks_SVC_nathan)
imgs_SVC = f_normalize(imgs_SVC)
masks_SVC_aya = f_normalize(masks_SVC_aya)
masks_SVC_nathan = f_normalize(masks_SVC_nathan)
# masks_SVC_aya = 1.0*(masks_SVC_aya > 0.1)
# masks_SVC_nathan = 1.0*(masks_SVC_nathan > 0.1)

# masks_SVC = (masks_SVC_aya + masks_SVC_nathan)/2
if a_n_mode == 0:
    masks_SVC = masks_SVC_aya
else:
    masks_SVC = masks_SVC_nathan

# ALZ_Imgs_SVC = imgs_SVC
# ALZ_Masks_SVC = masks_SVC

ALZ_Imgs_SVC = imgs_SVC
ALZ_Masks_SVC = masks_SVC_aya
ALZ_Masks_SVC_alt = masks_SVC_nathan

# L = len(masks_SVC)
# for i in np.arange(L):
#     x = masks_SVC[i].detach().cpu().numpy().squeeze()
#     plt.imshow(x)
#     plt.show()
    
## AMD Data

%cd Y:\FAZ\FAZ-cropped-SVC-graded\AMD\all_AMD\PR\mats

# Same file names and processing as the Healthy section, read from the AMD folder.
imgs_SVC = torch.from_numpy(sio.loadmat('SVC_resized.mat')["Imgs"].astype(np.float32))
masks_SVC_aya = torch.from_numpy(sio.loadmat('SVC_real_masks_aya.mat')["Imgs"].astype(np.float32))
masks_SVC_nathan = torch.from_numpy(sio.loadmat('SVC_real_masks_nathan.mat')["Imgs"].astype(np.float32))

# imgs_SVC /= torch.max(imgs_SVC)
# masks_SVC_aya /= torch.max(masks_SVC_aya)
# masks_SVC_nathan /= torch.max(masks_SVC_nathan)
imgs_SVC = f_normalize(imgs_SVC)
masks_SVC_aya = f_normalize(masks_SVC_aya)
masks_SVC_nathan = f_normalize(masks_SVC_nathan)
# masks_SVC_aya = 1.0*(masks_SVC_aya > 0.1)
# masks_SVC_nathan = 1.0*(masks_SVC_nathan > 0.1)

# masks_SVC = (masks_SVC_aya + masks_SVC_nathan)/2
# masks_SVC = masks_SVC_aya
if a_n_mode == 0:
    masks_SVC = masks_SVC_aya
else:
    masks_SVC = masks_SVC_nathan

# AMD_Imgs_SVC = imgs_SVC
# AMD_Masks_SVC = masks_SVC

AMD_Imgs_SVC = imgs_SVC
AMD_Masks_SVC = masks_SVC_aya
AMD_Masks_SVC_alt = masks_SVC_nathan

# L = len(masks_SVC)
# for i in np.arange(L):
#     x = masks_SVC[i].detach().cpu().numpy().squeeze()
#     plt.imshow(x)
#     plt.show()
    
## DR Data

%cd Y:\FAZ\FAZ-cropped-SVC-graded\DR\all_DR\PR\mats
    
# Same file names and processing as the Healthy section, read from the DR folder.
imgs_SVC = torch.from_numpy(sio.loadmat('SVC_resized.mat')["Imgs"].astype(np.float32))
masks_SVC_aya = torch.from_numpy(sio.loadmat('SVC_real_masks_aya.mat')["Imgs"].astype(np.float32))
masks_SVC_nathan = torch.from_numpy(sio.loadmat('SVC_real_masks_nathan.mat')["Imgs"].astype(np.float32))

# imgs_SVC /= torch.max(imgs_SVC)
# masks_SVC_aya /= torch.max(masks_SVC_aya)
# masks_SVC_nathan /= torch.max(masks_SVC_nathan)
imgs_SVC = f_normalize(imgs_SVC)
masks_SVC_aya = f_normalize(masks_SVC_aya)
masks_SVC_nathan = f_normalize(masks_SVC_nathan)
# masks_SVC_aya = 1.0*(masks_SVC_aya > 0.1)
# masks_SVC_nathan = 1.0*(masks_SVC_nathan > 0.1)

# masks_SVC = (masks_SVC_aya + masks_SVC_nathan)/2
if a_n_mode == 0:
    masks_SVC = masks_SVC_aya
else:
    masks_SVC = masks_SVC_nathan

# DR_Imgs_SVC = imgs_SVC
# DR_Masks_SVC = masks_SVC

DR_Imgs_SVC = imgs_SVC
DR_Masks_SVC = masks_SVC_aya
DR_Masks_SVC_alt = masks_SVC_nathan

# masks_SVC = 1.*(masks_SVC>0.5)

# L = len(masks_SVC)
# for i in np.arange(L):
#     x = masks_SVC[i].detach().cpu().numpy().squeeze()
#     plt.imshow(x)
#     plt.show()

# Report the tensor size of each SVC cohort's image stack.
print('Healthy Size:')
print(Healthy_Imgs_SVC.size())
print('ALZ Size:')
print(ALZ_Imgs_SVC.size())
print('AMD Size:')
print(AMD_Imgs_SVC.size())
print('DR Size:')
print(DR_Imgs_SVC.size())

### Autoencoder training loop

In [None]:
def f_L1(params):
    """Mean absolute value over every element of the tensors in ``params``."""
    flat = torch.cat([p.view(-1) for p in params])
    return flat.abs().mean()

In [None]:
def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
    """Build a 2-D Gaussian kernel normalised to sum to one.

    The Gaussian is centred on index ``size // 2`` along both axes.

    Returns:
        Tensor of shape (1, 1, size, size).
    """
    centre = size // 2
    amplitude = 1 / (2 * np.pi * sigma ** 2)
    two_var = 2 * sigma ** 2

    rows = []
    for x in range(size):
        row = []
        for y in range(size):
            dist_sq = (x - centre) ** 2 + (y - centre) ** 2
            row.append(amplitude * np.exp(-dist_sq / two_var))
        rows.append(row)

    kernel = torch.tensor(rows)
    # Rescale so the weights sum to one.
    kernel /= kernel.sum()
    return kernel[None, None]  # Shape (1, 1, size, size)

In [None]:
def f_interpolate(inmask1, inmask2, args):
    """Randomly pick one of two mask stacks, or blend them, and build the matching task code.

    One ``beta ~ Beta(0.5, 0.5)`` and one ``pflip ~ U(0, 1)`` are drawn per call:
      * ``pflip < 0.25``        -> ``inmask1``, weight 1 on the first index;
      * ``0.25 <= pflip < 0.5`` -> ``inmask2``, weight 1 on the second index;
      * otherwise               -> ``beta*inmask1 + (1-beta)*inmask2`` with
        weights ``beta`` / ``1 - beta`` on the first / second index.

    The first index is ``task_idx`` and the second is
    ``task_idx + int(args.num_tasks/2) - 1``.  ``task_idx`` is NOT a parameter:
    it is read from the enclosing (notebook-global) scope.
    # NOTE(review): for task_idx == 0 and num_tasks == 4 the second index is 1,
    # i.e. another task's slot -- confirm the "- 1" is intended.

    Returns:
        (out, one_hot_enc): the selected/blended masks with dimension 1
        squeezed if it has size one, and the task vector broadcast to shape
        (args.batch_size, 1, args.num_tasks) on ``args.device``.
    """
    # Float copies with an extra dimension at position 1; used for the blend only.
    mask1 = 1.*inmask1.unsqueeze(1)
    mask2 = 1.*inmask2.unsqueeze(1)

    # Draw order (beta first, then uniform) is kept so RNG consumption is unchanged.
    beta = np.random.beta(0.5, 0.5)
    pflip = np.random.uniform(0,1)

    one_hot = torch.zeros(args.num_tasks).to(args.device)
    alt_idx = task_idx+int(args.num_tasks/2)-1

    if pflip < 0.25:
        out = inmask1
        one_hot[task_idx] = 1
    elif pflip < 0.5:
        out = inmask2
        one_hot[alt_idx] = 1
    else:
        one_hot[task_idx] = beta
        one_hot[alt_idx] = 1 - beta
        out = beta*mask1 + (1-beta)*mask2

    one_hot_enc = torch.ones(args.batch_size,1,args.num_tasks).to(args.device)*one_hot

    return out.squeeze(1), one_hot_enc

# task_idx = 0

# Train_Target_Data = Train_Target_Data_list[task_idx]

# # Train_Target_Data = Train_Target_Data > 0.5

# Train_Target_Data_alt = Train_Target_Data_list_alt[task_idx]

# Train_Target_Data, one_hot_enc = f_interpolate(Train_Target_Data, Train_Target_Data_alt, args)

In [None]:
# %cd Y:\FAZ\Models\Models New\Models New New\AE

# filt_idxs0 = np.loadtxt('filt_v_test_idx_0.csv', delimiter=',')
# filt_idxs1 = np.loadtxt('filt_v_test_idx_1.csv', delimiter=',')

# filt_idxs = np.sort(np.concatenate((filt_idxs0,filt_idxs1)))

# L_PC = len(Photocoag_Imgs)
# bad_idxs = [x for x in range(L_PC) if x not in filt_idxs]

# print(len(bad_idxs))
# print(len(filt_idxs))
# print(L_PC)

# # Photocoag_Imgs = Photocoag_Imgs[bad_idxs]
# # Photocoag_Masks = Photocoag_Masks[bad_idxs]
# Photocoag_Imgs = Photocoag_Imgs[filt_idxs]
# Photocoag_Masks = Photocoag_Masks[filt_idxs]

In [None]:
########################################## jump

## AE Loop

## DEFINE MODELS

# Length of the one-hot task vector built inside the loop below.
args.num_tasks = 4*1
# Seed NumPy's global RNG (the loop below re-seeds before splitting and resets to an unseeded state afterwards).
np.random.seed(2024)
num_splits = 2 # Two fold cross validation
# Passed as N to f_augment_all_single for the non-Photocoag cohorts.
args.num_aug = 2

args.num_epochs = 1000
# Example reconstructions and the loss curve are shown every this-many epochs.
args.show_example_epochs = 1
# presumably a checkpoint interval (one fifth of the run) -- its use is not visible in this part of the cell
args.save_epochs = int(args.num_epochs/5)

# Cohort names indexed by task_idx; also used to build each task's model folder.
# NOTE(review): five names, but the loop below iterates task_idx over np.arange(4), so 'Photocoag' is not trained here.
task_names = ['Healthy', 'ALZ', 'AMD', 'DR', 'Photocoag']

## TRAIN TEST SPLIT
for test_idx in tqdm(np.arange(num_splits)):
# for test_idx in [1]:

#     for task_idx in tqdm(np.arange(args.num_tasks)):
#     for task_idx in tqdm(np.arange(int(args.num_tasks/2))):
    for task_idx in np.arange(4):
#     for task_idx in [4]:

#         model_folder = 'Y:/FAZ/Models/Models New/Models New New/AE/' + task_names[task_idx]
#         model_folder = 'Y:/FAZ/Models/Models New/Models New New/AE_with_2_annotations, with pretraining/' + task_names[task_idx]

        model_folder = 'Y:/FAZ/Models/Models New/Models New New/AE2/' + task_names[task_idx]
#         model_folder = 'Y:/FAZ/Models/Models New/Models New New/AE filtered 2_annotations/' + task_names[task_idx]

#         if task_idx == 0:
#             %cd Y:\FAZ\Models\Models New\AE\Healthy
#         if task_idx == 1:
#             %cd Y:\FAZ\Models\Models New\AE\AMD
#         if task_idx == 2:
#             %cd Y:\FAZ\Models\Models New\AE\ALZ
#         if task_idx == 3:
#             %cd Y:\FAZ\Models\Models New\AE\DR
#         if task_idx == 4:
#             %cd Y:\FAZ\Models\Models New\AE\PC

        save_name = '_test_idx_' + str(test_idx)
        test_folder = model_folder + '/Test_Results_' + str(test_idx)
        if not os.path.exists(model_folder):
            os.mkdir(model_folder)
        if not os.path.exists(test_folder):
            os.mkdir(test_folder)
        
        np.random.seed(2024)

#         splits = get_train_test_splits(Photocoag_Imgs,split=num_splits)
#         Photocoag_Imgs_train, Photocoag_Imgs_test, Photocoag_Masks_train, Photocoag_Masks_test = train_test_split(Photocoag_Imgs, Photocoag_Masks,splits,test_idx)
        Photocoag_Imgs_train = Photocoag_Imgs
        Photocoag_Imgs_test = Photocoag_Imgs
        Photocoag_Masks_train = Photocoag_Masks
        Photocoag_Masks_test = Photocoag_Masks
        
        splits = get_train_test_splits(Healthy_Imgs_SVC,split=num_splits)
        Healthy_Imgs_SVC_train, Healthy_Imgs_SVC_test, Healthy_RealMasks_SVC_train, Healthy_RealMasks_SVC_test = train_test_split(Healthy_Imgs_SVC, Healthy_Masks_SVC,splits,test_idx)
        Healthy_Imgs_SVC_train_alt, Healthy_Imgs_SVC_test_alt, Healthy_RealMasks_SVC_train_alt, Healthy_RealMasks_SVC_test_alt = train_test_split(Healthy_Imgs_SVC, Healthy_Masks_SVC_alt,splits,test_idx)

        splits = get_train_test_splits(ALZ_Imgs_SVC,split=num_splits)
        ALZ_Imgs_SVC_train, ALZ_Imgs_SVC_test, ALZ_RealMasks_SVC_train, ALZ_RealMasks_SVC_test = train_test_split(ALZ_Imgs_SVC, ALZ_Masks_SVC,splits,test_idx)
        ALZ_Imgs_SVC_train_alt, ALZ_Imgs_SVC_test_alt, ALZ_RealMasks_SVC_train_alt, ALZ_RealMasks_SVC_test_alt = train_test_split(ALZ_Imgs_SVC, ALZ_Masks_SVC_alt,splits,test_idx)

        splits = get_train_test_splits(AMD_Imgs_SVC,split=num_splits)
        AMD_Imgs_SVC_train, AMD_Imgs_SVC_test, AMD_RealMasks_SVC_train, AMD_RealMasks_SVC_test = train_test_split(AMD_Imgs_SVC, AMD_Masks_SVC,splits,test_idx)
        AMD_Imgs_SVC_train_alt, AMD_Imgs_SVC_test_alt, AMD_RealMasks_SVC_train_alt, AMD_RealMasks_SVC_test_alt = train_test_split(AMD_Imgs_SVC, AMD_Masks_SVC_alt,splits,test_idx)

        splits = get_train_test_splits(DR_Imgs_SVC,split=num_splits)
        DR_Imgs_SVC_train, DR_Imgs_SVC_test, DR_RealMasks_SVC_train, DR_RealMasks_SVC_test = train_test_split(DR_Imgs_SVC, DR_Masks_SVC,splits,test_idx)
        DR_Imgs_SVC_train_alt, DR_Imgs_SVC_test_alt, DR_RealMasks_SVC_train_alt, DR_RealMasks_SVC_test_alt = train_test_split(DR_Imgs_SVC, DR_Masks_SVC_alt,splits,test_idx)

        np.random.seed(None)

        # (For the test data, this is to bring everything to the appropriate length, it doesn't actually augment it):
    #         Photocoag_Imgs_test, Photocoag_Masks_test = f_augment_all_single(Photocoag_Imgs_test, Photocoag_Masks_test, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=0, N = 1, mode=1)
        Healthy_Imgs_SVC_test, Healthy_RealMasks_SVC_test = f_augment_all_single(Healthy_Imgs_SVC_test, Healthy_RealMasks_SVC_test, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
        ALZ_Imgs_SVC_test, ALZ_RealMasks_SVC_test = f_augment_all_single(ALZ_Imgs_SVC_test, ALZ_RealMasks_SVC_test, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
        AMD_Imgs_SVC_test, AMD_RealMasks_SVC_test = f_augment_all_single(AMD_Imgs_SVC_test, AMD_RealMasks_SVC_test, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
        DR_Imgs_SVC_test, DR_RealMasks_SVC_test = f_augment_all_single(DR_Imgs_SVC_test, DR_RealMasks_SVC_test, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)

        _, Healthy_RealMasks_SVC_test_alt = f_augment_all_single(Healthy_Imgs_SVC_test_alt, Healthy_RealMasks_SVC_test_alt, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
        _, ALZ_RealMasks_SVC_test_alt = f_augment_all_single(ALZ_Imgs_SVC_test_alt, ALZ_RealMasks_SVC_test_alt, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
        _, AMD_RealMasks_SVC_test_alt = f_augment_all_single(AMD_Imgs_SVC_test_alt, AMD_RealMasks_SVC_test_alt, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
        _, DR_RealMasks_SVC_test_alt = f_augment_all_single(DR_Imgs_SVC_test_alt, DR_RealMasks_SVC_test_alt, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)

        Photocoag_Imgs_test = Photocoag_Imgs_test.unsqueeze(1).unsqueeze(1).to(args.device)
        Photocoag_Masks_test = Photocoag_Masks_test.unsqueeze(1).unsqueeze(1).to(args.device)
        Healthy_Imgs_SVC_test = Healthy_Imgs_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
        Healthy_RealMasks_SVC_test = Healthy_RealMasks_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
        ALZ_Imgs_SVC_test = ALZ_Imgs_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
        ALZ_RealMasks_SVC_test = ALZ_RealMasks_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
        AMD_Imgs_SVC_test = AMD_Imgs_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
        AMD_RealMasks_SVC_test = AMD_RealMasks_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
        DR_Imgs_SVC_test = DR_Imgs_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
        DR_RealMasks_SVC_test = DR_RealMasks_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)

        Healthy_RealMasks_SVC_test_alt = Healthy_RealMasks_SVC_test_alt.unsqueeze(1).unsqueeze(1).to(args.device)
        ALZ_RealMasks_SVC_test_alt = ALZ_RealMasks_SVC_test_alt.unsqueeze(1).unsqueeze(1).to(args.device)
        AMD_RealMasks_SVC_test_alt = AMD_RealMasks_SVC_test_alt.unsqueeze(1).unsqueeze(1).to(args.device)
        DR_RealMasks_SVC_test_alt = DR_RealMasks_SVC_test_alt.unsqueeze(1).unsqueeze(1).to(args.device)

        Train_Data_list = [Healthy_Imgs_SVC_train, ALZ_Imgs_SVC_train, AMD_Imgs_SVC_train, DR_Imgs_SVC_train, Photocoag_Imgs_train]
        Train_Target_Data_list = [Healthy_RealMasks_SVC_train, ALZ_RealMasks_SVC_train, AMD_RealMasks_SVC_train, DR_RealMasks_SVC_train, Photocoag_Masks_train]
        Test_Data_list = [Healthy_Imgs_SVC_test, ALZ_Imgs_SVC_test, AMD_Imgs_SVC_test, DR_Imgs_SVC_test, Photocoag_Imgs_test]
        Test_Target_Data_list = [Healthy_RealMasks_SVC_test, ALZ_RealMasks_SVC_test, AMD_RealMasks_SVC_test, DR_RealMasks_SVC_test, Photocoag_Masks_test]
    #         Dec_aux_list = [Dec_aux_healthy_mask, Dec_aux_ALZ_mask, Dec_aux_AMD_mask, Dec_aux_DR_mask]
    #         Dec_name_list = ['Dec_aux_healthy_mask', 'Dec_aux_ALZ_mask', 'Dec_aux_AMD_mask', 'Dec_aux_DR_mask']

        Train_Target_Data_list_alt = [Healthy_RealMasks_SVC_train_alt, ALZ_RealMasks_SVC_train_alt, AMD_RealMasks_SVC_train_alt, DR_RealMasks_SVC_train_alt, Photocoag_Masks_train]
        Test_Target_Data_list_alt = [Healthy_RealMasks_SVC_test_alt, ALZ_RealMasks_SVC_test_alt, AMD_RealMasks_SVC_test_alt, DR_RealMasks_SVC_test_alt, Photocoag_Masks_test]

        ##################################

        # args.fac = 2

        args.d_attn0 = 128
        args.d_attn1 = 256
        args.d_attn2 = 512

    #         args.d_attn0 = 64
    #         args.d_attn1 = 128
    #         args.d_attn2 = 256

        nhead_base = 8
        # nhead = int(nhead_base/args.fac)
        # args.nhead = nhead

        Enc_base = Unet_Enc_base(nhead_base, args).to(args.device)
        Dec_base = Unet_Dec_base(nhead_base, args).to(args.device)
        
#         #################
#         # Load weights
# #         load_folder = 'Y:\FAZ\Models\Models New\Models New New\AE\Photocoag'
# #         load_folder = 'Y:\FAZ\Models\Models New\Models New New\PC Pretrain'
# #         load_folder = 'Y:\FAZ\Models\Models New\Models New New\PC Pretrain filtered'
# #         load_folder = 'Y:\FAZ\Models\Models New\Models New New\PC Pretrain bad_idxs'
#         load_folder = 'Y:\FAZ\Models\Models New\Models New New\PC Pretrain two annotations'
#         load_folder = load_folder.replace('\\', '/')
#         date_str = '2024_11_18'
# #         load_models_AE(Enc_base, Dec_base, test_idx, load_folder, date_str)
#         load_models_AE(Enc_base, Dec_base, 0, load_folder, date_str)
#         #################

        Train_Data = Train_Data_list[task_idx]
        Train_Target_Data = Train_Target_Data_list[task_idx]
        Test_Data_aug = Test_Data_list[task_idx]
        Test_Target_Data_aug = Test_Target_Data_list[task_idx]

        Train_Target_Data_alt = Train_Target_Data_list_alt[task_idx]
        Test_Target_Data_aug_alt = Test_Target_Data_list_alt[task_idx]

        params_list = list(Enc_base.parameters()) + list(Dec_base.parameters())
        lr = 1E-4
        optimizer = torch.optim.Adam(params_list, lr=lr, betas=(0.9, 0.999))

        one_hot = torch.zeros(args.num_tasks).to(args.device)
        one_hot[task_idx] = 1
        ones = torch.ones(args.batch_size,1,args.num_tasks).to(args.device)
        one_hot_enc = ones*one_hot

        #####################Training loop ###########################

        mean_epoch_losses = np.zeros(args.num_epochs)
        log_mean_epoch_losses = np.zeros(args.num_epochs)

        train_dice = np.zeros(args.num_epochs)
        test_dice = np.zeros(args.num_epochs)

        loss = nn.BCELoss()
        # loss_b = Dice_loss()
        loss2 = nn.L1Loss()

        # Train for a maximum of max_epochs:
        for epoch in tqdm(np.arange(args.num_epochs), desc="Training progress..."):
            #################################

            Train_Target = Train_Target_Data
#             Train_Target, one_hot_enc = f_interpolate(Train_Target_Data, Train_Target_Data_alt, args)

            if task_names[task_idx] != 'Photocoag':
                trans = 50
                p_shear = 0.5
                p_rot = 0.5
                num_aug = args.num_aug
            else:
                trans = 0
                p_shear = 0
                p_rot = 0.5
                num_aug = 1
    #                 Train_Data_aug = Train_Data
    #                 Train_Target_Data_aug = Train_Target_Data
            Train_Data_aug, Train_Target_Data_aug = f_augment_all_single(Train_Data, Train_Target, p_rot = p_rot, p_flip = 0.5, p_jitter = 0.8, p_shear = p_shear, trans = trans, N = num_aug)

            Train_Data_aug = Train_Data_aug.unsqueeze(1).unsqueeze(1).to(args.device)
            Train_Target_Data_aug = Train_Target_Data_aug.unsqueeze(1).unsqueeze(1).to(args.device)

            args.num_batches = int(len(Train_Data_aug)/args.N)

            #################################

            epoch_losses = single_epoch_AE(Train_Data_aug, Train_Target_Data_aug, Enc_base, Dec_base, one_hot_enc, optimizer, loss, loss2, params_list, args)

            mean_epoch_loss = np.mean(epoch_losses)
            mean_epoch_losses[epoch] = mean_epoch_loss
            log_mean_epoch_losses[epoch] = np.log(mean_epoch_loss)

            train_dice[epoch] = f_Tot_Dice(Train_Data_aug, Train_Target_Data_aug,  Enc_base, Dec_base, one_hot_enc, args)
            test_dice[epoch] = f_Tot_Dice(Test_Data_aug, Test_Target_Data_aug,  Enc_base, Dec_base, one_hot_enc, args)

            if np.mod(epoch,args.show_example_epochs) == args.show_example_epochs-1:
                print("Task:", task_names[task_idx])
                
                display_AE(Train_Data_aug, Train_Target_Data_aug, Enc_base, Dec_base, one_hot_enc, args)

                fig1, ax = plt.subplots(nrows=1, ncols=1 )
                plt.plot(log_mean_epoch_losses[0:epoch])
                plt.grid()
                plt.show()

                colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'aquamarine', 'mediumseagreen']
                fig2, ax = plt.subplots(nrows=1, ncols=1 )

                plt.plot(train_dice[0:epoch],color=colors[task_idx])
    #             plt.grid()
    #             plt.show()

                plt.plot(test_dice[0:epoch],color=colors[task_idx],linestyle='--')
                plt.grid()
                plt.show()

            if np.mod(epoch,args.save_epochs) == args.save_epochs-1:    

                np.savetxt(model_folder + '/' + 'task_idx'+ str(task_idx) + '_mean_epoch_losses' + '_epoch_' + str(epoch+1) + '_' + save_name + '.txt', mean_epoch_losses, fmt='%.2f')
                np.savetxt(model_folder + '/' +'task_idx'+ str(task_idx) + '_train_dice' + '_epoch_' + str(epoch+1) + '_' + save_name + '.txt', train_dice, fmt='%.2f')
                np.savetxt(model_folder + '/' +'task_idx'+ str(task_idx) + '_test_dice' + '_epoch_' + str(epoch+1) + '_' + save_name + '.txt', test_dice, fmt='%.2f')

                fig1.savefig(model_folder + '/' +'task_idx'+ str(task_idx) + '_loss' + save_name + '.png')
                fig2.savefig(model_folder + '/' +'task_idx'+ str(task_idx) + '_dice' + save_name + '.png')

                save_models(Enc_base, Dec_base, [], [], test_idx, model_folder)

        np.savetxt(model_folder + '/' +'task_idx'+ str(task_idx) + '_mean_epoch_losses' + '_epoch_' + str(epoch+1) + '_' + save_name + '.txt', mean_epoch_losses, fmt='%.2f')
        np.savetxt(model_folder + '/' +'task_idx'+ str(task_idx) + '_train_dice' + '_epoch_' + str(epoch+1) + '_' + save_name + '.txt', train_dice, fmt='%.2f')
        np.savetxt(model_folder + '/' +'task_idx'+ str(task_idx) + '_test_dice' + '_epoch_' + str(epoch+1) + '_' + save_name + '.txt', test_dice, fmt='%.2f')

        fig1.savefig(model_folder + '/' +'task_idx'+ str(task_idx) + '_loss' + save_name + '.png')
        fig2.savefig(model_folder + '/' +'task_idx'+ str(task_idx) + '_dice' + save_name + '.png')

        save_models(Enc_base, Dec_base, [], [], test_idx, model_folder)
        
        f_display_all(Test_Data_aug, Test_Target_Data_aug, Enc_base, Dec_base, one_hot_enc, splits[test_idx], args, test_folder)

In [None]:
# Stand-alone evaluation setup: pick one task and one fold, then split the
# photocoagulation images/masks into train and test subsets.

# Index into task_names; later cells use it to choose the weights folder
# (task_names[task_idx]) and the active slot of the one-hot task vector.
task_idx = 4

# Fold index handed to train_test_split together with `splits`.
test_idx = 1

# Re-seed NumPy's global RNG immediately before the split is drawn.
# NOTE(review): presumably get_train_test_splits samples from np.random, so
# this is what makes the folds repeatable — confirm against its definition.
np.random.seed(2024)

# NOTE(review): num_splits is not assigned in this cell; it has to exist in
# the kernel already (it is assigned in the "FAZ Loop" cell further down).
splits = get_train_test_splits(Photocoag_Imgs,split=num_splits)
Photocoag_Imgs_train, Photocoag_Imgs_test, Photocoag_Masks_train, Photocoag_Masks_test = train_test_split(Photocoag_Imgs, Photocoag_Masks,splits,test_idx)

# splits = get_train_test_splits(Healthy_Imgs_SVC,split=num_splits)
# Healthy_Imgs_SVC_train, Healthy_Imgs_SVC_test, Healthy_RealMasks_SVC_train, Healthy_RealMasks_SVC_test = train_test_split(Healthy_Imgs_SVC, Healthy_Masks_SVC,splits,test_idx)
# Healthy_Imgs_SVC_train_alt, Healthy_Imgs_SVC_test_alt, Healthy_RealMasks_SVC_train_alt, Healthy_RealMasks_SVC_test_alt = train_test_split(Healthy_Imgs_SVC, Healthy_Masks_SVC_alt,splits,test_idx)

# splits = get_train_test_splits(ALZ_Imgs_SVC,split=num_splits)
# ALZ_Imgs_SVC_train, ALZ_Imgs_SVC_test, ALZ_RealMasks_SVC_train, ALZ_RealMasks_SVC_test = train_test_split(ALZ_Imgs_SVC, ALZ_Masks_SVC,splits,test_idx)
# ALZ_Imgs_SVC_train_alt, ALZ_Imgs_SVC_test_alt, ALZ_RealMasks_SVC_train_alt, ALZ_RealMasks_SVC_test_alt = train_test_split(ALZ_Imgs_SVC, ALZ_Masks_SVC_alt,splits,test_idx)

# splits = get_train_test_splits(AMD_Imgs_SVC,split=num_splits)
# AMD_Imgs_SVC_train, AMD_Imgs_SVC_test, AMD_RealMasks_SVC_train, AMD_RealMasks_SVC_test = train_test_split(AMD_Imgs_SVC, AMD_Masks_SVC,splits,test_idx)
# AMD_Imgs_SVC_train_alt, AMD_Imgs_SVC_test_alt, AMD_RealMasks_SVC_train_alt, AMD_RealMasks_SVC_test_alt = train_test_split(AMD_Imgs_SVC, AMD_Masks_SVC_alt,splits,test_idx)

# splits = get_train_test_splits(DR_Imgs_SVC,split=num_splits)
# DR_Imgs_SVC_train, DR_Imgs_SVC_test, DR_RealMasks_SVC_train, DR_RealMasks_SVC_test = train_test_split(DR_Imgs_SVC, DR_Masks_SVC,splits,test_idx)
# DR_Imgs_SVC_train_alt, DR_Imgs_SVC_test_alt, DR_RealMasks_SVC_train_alt, DR_RealMasks_SVC_test_alt = train_test_split(DR_Imgs_SVC, DR_Masks_SVC_alt,splits,test_idx)

# print(Healthy_Imgs_SVC_train.size())
# print(Healthy_Imgs_SVC_test.size())
# Sanity check: report the sizes of the train and test image sets produced by the split.
print(Photocoag_Imgs_train.size())
print(Photocoag_Imgs_test.size())

# print(splits)

In [None]:
# np.random.seed(2024)

# splits = get_train_test_splits(Photocoag_Imgs,split=num_splits)
# Photocoag_Imgs_train, Photocoag_Imgs_test, Photocoag_Masks_train, Photocoag_Masks_test = train_test_split(Photocoag_Imgs, Photocoag_Masks,splits,test_idx)

In [None]:
# Length of the one-hot task vector (see torch.zeros(args.num_tasks) in the
# data-preparation cell). NOTE(review): the "FAZ Loop" cell sets this to 4*1;
# the value here must match whatever the checkpoint being loaded was trained with — confirm.
args.num_tasks = 5*2

# Fresh encoder/decoder instances on args.device; weights are loaded in the next cell.
# NOTE(review): nhead_base is not defined in this cell and must already exist in the kernel.
Enc_base = Unet_Enc_base(nhead_base, args).to(args.device)
Dec_base = Unet_Dec_base(nhead_base, args).to(args.device)

In [None]:
date_str = '2024_11_6'

# # test_idx = 1
# # date_str = '2024_10_20'

#################
# Load weights
%cd Y:\FAZ\Models\Models New\Models New New\AE
load_folder = 'Y:\FAZ\Models\Models New\Models New New\AE' + '\\' + task_names[task_idx]
load_folder = load_folder.replace('\\', '/')
load_models_AE(Enc_base, Dec_base, test_idx, load_folder, date_str)
#################

In [None]:
# Build the tensors used to evaluate the loaded autoencoder on the photocoagulation split.

# Train_Data = Train_Data_list[task_idx]
# Train_Target_Data = Train_Target_Data_list[task_idx]

Train_Data = Photocoag_Imgs_train
Train_Target_Data = Photocoag_Masks_train
# Test tensors: insert two singleton axes at position 1, then move to args.device.
Test_Data = Photocoag_Imgs_test.unsqueeze(1).unsqueeze(1).to(args.device)
Test_Target_Data = Photocoag_Masks_test.unsqueeze(1).unsqueeze(1).to(args.device)

# NOTE(review): p_rot, p_shear and trans are not assigned in this cell; they are
# whatever values an earlier cell (e.g. the training loop) last left in the kernel.
Train_Data_aug, Train_Target_Data_aug = f_augment_all_single(Train_Data, Train_Target_Data, p_rot = p_rot, p_flip = 0.5, p_jitter = 0.8, p_shear = p_shear, trans = trans, N = 1, mode=1)

# Test_Data_aug = Test_Data_list[task_idx]
# Test_Target_Data_aug = Test_Target_Data_list[task_idx]
# Built with the same expression as Test_Data / Test_Target_Data above (no
# augmentation is applied); only these *_aug copies are reordered at the end of the cell.
Test_Data_aug = Photocoag_Imgs_test.unsqueeze(1).unsqueeze(1).to(args.device)
Test_Target_Data_aug = Photocoag_Masks_test.unsqueeze(1).unsqueeze(1).to(args.device)

# Same two-singleton-axes + device move for the augmented training tensors.
Train_Data_aug = Train_Data_aug.unsqueeze(1).unsqueeze(1).to(args.device)
Train_Target_Data_aug = Train_Target_Data_aug.unsqueeze(1).unsqueeze(1).to(args.device)

# One-hot task code: a length-args.num_tasks vector with a 1 at task_idx,
# broadcast against ones of shape (args.batch_size, 1, args.num_tasks).
one_hot = torch.zeros(args.num_tasks).to(args.device)
one_hot[task_idx] = 1
ones = torch.ones(args.batch_size,1,args.num_tasks).to(args.device)
one_hot_enc = ones*one_hot

# Sanity check of the tensor sizes before evaluation.
print(Train_Data_aug.size())
print(Test_Data_aug.size())
print(Train_Target_Data_aug.size())
print(Test_Target_Data_aug.size())

# NOTE(review): rearrange_by_index is defined elsewhere; presumably it reorders
# samples along the first axis — confirm, and confirm it does not modify its input in place.
Test_Data_aug = rearrange_by_index(Test_Data_aug)
Test_Target_Data_aug = rearrange_by_index(Test_Target_Data_aug)

In [None]:
# Run the model over the held-out split. f_display_all returns `filt_v`, which is
# persisted below and later re-loaded as indices into Photocoag_Imgs.
filt_v = f_display_all(Test_Data, Test_Target_Data, Enc_base, Dec_base, one_hot_enc, splits[test_idx], args, model_folder)
print(filt_v)
print(len(filt_v))

# The Dice score used to be computed here and thrown away (a bare call in the
# middle of a cell is never displayed); report it explicitly instead.
print("Test Dice:", f_Tot_Dice(Test_Data_aug, Test_Target_Data_aug,  Enc_base, Dec_base, one_hot_enc, args))

# Written relative to the current working directory (set by the preceding %cd).
np.savetxt('filt_v_test_idx_' + str(test_idx) + ".csv", filt_v, delimiter=",")

In [None]:
# x = Photocoag_Imgs[filt_v].unsqueeze(1).unsqueeze(1).to(args.device)
# y =  Photocoag_Masks[filt_v].unsqueeze(1).unsqueeze(1).to(args.device)

# filt_v = f_display_all(x,y, Enc_base, Dec_base, one_hot_enc, splits[test_idx], args, model_folder)
# print(filt_v)
# print(len(filt_v))
# f_Tot_Dice(Test_Data_aug, Test_Target_Data_aug,  Enc_base, Dec_base, one_hot_enc, args)

# # np.savetxt('filt_v_test_idx_' + str(test_idx) + ".csv", filt_v, delimiter=",")

In [None]:
# Re-load the per-fold `filt_v` files written by the evaluation cell.
# np.loadtxt returns float64 by default, and a one-row file comes back as a 0-d
# array; coerce to 1-D integer arrays so the values can be iterated, concatenated
# and used for indexing.
filt_idxs0 = np.atleast_1d(np.loadtxt('filt_v_test_idx_0.csv', delimiter=',')).astype(int)
filt_idxs1 = np.atleast_1d(np.loadtxt('filt_v_test_idx_1.csv', delimiter=',')).astype(int)

# Indices present in both folds (expected to be empty when folds are disjoint).
set1 = set(filt_idxs0.tolist())
set2 = set(filt_idxs1.tolist())
print(sorted(set1.intersection(set2)))

print(filt_idxs0)
print(filt_idxs1)

# Union of both folds, in ascending order.
filt_idxs = np.sort(np.concatenate((filt_idxs0, filt_idxs1)))
print(len(filt_idxs))

In [None]:
# i1 = Photocoag_Imgs[filt_idxs].unsqueeze(1).unsqueeze(2).to(args.device)
# i2 = Photocoag_Masks[filt_idxs].unsqueeze(1).unsqueeze(2).to(args.device)

# f_display_all(i1, i2, Enc_base, Dec_base, one_hot_enc, splits[1-test_idx], args, model_folder)


In [None]:
# Same display/filter pass as on the test split, but over the augmented training tensors.
# This overwrites the `filt_v` obtained from the test split.
# NOTE(review): the other f_display_all calls pass splits[test_idx]; here the whole
# `splits` object is passed — confirm this is intended for the training data.
filt_v = f_display_all(Train_Data_aug, Train_Target_Data_aug, Enc_base, Dec_base, one_hot_enc, splits, args, model_folder)
print(filt_v)

### Differentiable Mask Generator

In [None]:
# class SimpleConvNet(nn.Module):
#     def __init__(self, sigma, smooth_kernel_size):
#         super(SimpleConvNet, self).__init__()

#         self.conv1 = nn.Conv2d(1, 16, kernel_size=smooth_kernel_size, padding=smooth_kernel_size // 2, bias=False)
#         self.relu1 = nn.ReLU()
#         self.conv2 = nn.Conv2d(16, 16, kernel_size=smooth_kernel_size, padding=smooth_kernel_size // 2, bias=False)
#         self.relu2 = nn.ReLU()
#         self.conv3 = nn.Conv2d(16, 1, kernel_size=smooth_kernel_size, padding=smooth_kernel_size // 2, bias=False)

#         # Initialize the convolutional layers with standard initialization
#         self.initialize_weights()

#     def initialize_weights(self):
#         """Use Xavier initialization and scale the output."""
#         nn.init.xavier_uniform_(self.conv1.weight)
#         self.conv1.weight.data *= 0.1  # Scale down by a factor of 0.1

#         nn.init.xavier_uniform_(self.conv2.weight)
#         self.conv2.weight.data *= 0.1  # Scale down by a factor of 0.1

#         nn.init.xavier_uniform_(self.conv3.weight)
#         self.conv3.weight.data *= 0.1  # Scale down by a factor of 0.1

#     def forward(self, x):
#         # Apply the first convolution and residual connection
#         out1 = self.conv1(x)
#         out1_relu = self.relu1(out1)
#         out1_residual = out1 + x  # Residual connection

#         # Apply the second convolution and residual connection
#         out2 = self.conv2(out1_residual)
#         out2_relu = self.relu2(out2)
#         out2_residual = out2 + out1_residual  # Residual connection

#         # Apply the third convolution
#         return self.conv3(out2_residual)

# class CNN(nn.Module):
#     def __init__(self):
#         super(CNN, self).__init__()

#         # Define a symmetric 5-layer CNN with padding to preserve input size
#         self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)  # (N, 1, H, W) -> (N, 16, H, W)
#         self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1) # (N, 16, H, W) -> (N, 32, H, W)
#         self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1) # (N, 32, H, W) -> (N, 64, H, W)
#         self.conv4 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)  # (N, 64, H, W) -> (N, 32, H, W)
#         self.conv5 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, padding=1)   # (N, 32, H, W) -> (N, 1, H, W)

#         # Activation function
#         self.relu = nn.ReLU()

#     def forward(self, x):
#         # Pass through the layers with ReLU activations
#         x = self.relu(self.conv1(x))  # (N, 16, H, W)
#         x = self.relu(self.conv2(x))  # (N, 32, H, W)
#         x = self.relu(self.conv3(x))  # (N, 64, H, W)
#         x = self.relu(self.conv4(x))  # (N, 32, H, W)
#         x = self.conv5(x)              # (N, 1, H, W)
#         return x
    
# class DifferentiableMaskGenerator(nn.Module):
#     def __init__(self, sigma=1.0, smooth_kernel_size=5, initial_threshold=5.0):
#         super(DifferentiableMaskGenerator, self).__init__()

#         # Replace smoothing convolution with SimpleConvNet
#         self.smoothing_conv = SimpleConvNet(sigma, smooth_kernel_size)
#         self.smoothing_conv1 = nn.Conv2d(1, 1, kernel_size=smooth_kernel_size, padding=smooth_kernel_size // 2, bias=False)
#         self.init_smoothing_kernel(sigma)

#         # Initialize a learnable threshold parameter
#         self.c = nn.Parameter(torch.tensor(1.0))
#         self.threshold = nn.Parameter(torch.tensor(initial_threshold))
        
#         self.Sigmoid = nn.Sigmoid()
# #         self.ReLU = nn.ReLU()
        
#     def init_smoothing_kernel(self, sigma):
#         """Initialize the smoothing kernel to behave like a Gaussian blur with the given sigma."""
#         def gaussian_kernel(size, sigma):
#             coords = torch.arange(size).float()
#             x_grid, y_grid = torch.meshgrid(coords, coords, indexing='ij')
#             center = size // 2
#             gaussian_kernel = 1 / (2.0 * np.pi * sigma ** 2) * torch.exp(-((x_grid - center) ** 2 + (y_grid - center) ** 2) / (2 * sigma ** 2))
#             gaussian_kernel /= gaussian_kernel.sum()
#             return gaussian_kernel

#         size = self.smoothing_conv1.kernel_size[0]
#         with torch.no_grad():
#             self.smoothing_conv1.weight.copy_(gaussian_kernel(size, sigma).view(1, 1, size, size))

#     def forward(self, c_img):
#         # Apply Gaussian smoothing using SimpleConvNet
#         smoothed_img = self.smoothing_conv(c_img) + self.smoothing_conv1(c_img) # [N, 1, H, W]

#         # Soft thresholding
#         soft_mask = self.c / (1 + torch.exp(smoothed_img * 255 - self.threshold))

#         return 1 - soft_mask  # Return the opened mask

# class StackedMaskGenerator(nn.Module):
#     def __init__(self, layer_params):
#         super(StackedMaskGenerator, self).__init__()
#         self.layers = nn.ModuleList()
#         for params in layer_params:
#             # Unpack the parameters for each layer
#             sigma, smooth_kernel_size, initial_threshold = params
#             self.layers.append(DifferentiableMaskGenerator(sigma, smooth_kernel_size, initial_threshold))
        
#         # Replace dilation convolution with SimpleConvNet
#         self.dilation_conv = SimpleConvNet(sigma=1.0, smooth_kernel_size=5)
#         self.dilation_conv1 = nn.Conv2d(1, 1, kernel_size=5, padding=2, bias=False)  # Adjust kernel size as needed
#         nn.init.ones_(self.dilation_conv1.weight)  # Initialize the weights to ones for dilation effect

#         self.scaling_factor1 = nn.Parameter(torch.tensor(1.0))
#         self.scaling_factor2 = nn.Parameter(torch.tensor(1.0))
        
#         self.CNN_enc = CNN()
#         self.CNN_dec = CNN()
        
#         self.Sigmoid = nn.Sigmoid()
        
#     def forward(self, c_img):
#         # Pass the input through all layers
#         c_img = self.CNN_enc(c_img)
#         output_mask = c_img
#         first_layer_output = None  # Variable to store output of the first layer

#         for idx, layer in enumerate(self.layers):
#             output_mask = layer(output_mask)  # Feed the output of the previous layer into the next
#             if idx == 0:
#                 first_layer_output = 1 - output_mask  # Store the output of the first layer

#         # If there's a second layer, apply the learnable dilation convolution
#         if len(self.layers) > 1:
#             second_layer_output = 1 - output_mask

#             # Use the learnable dilation convolution (now using SimpleConvNet)
#             dilated_mask = self.dilation_conv(second_layer_output) + self.dilation_conv1(second_layer_output)  # Dilation using SimpleConvNet

# #             # Ensure the mask values are between 0 and 1
# #             dilated_mask = 1.0*(dilated_mask > 0)  # Convert to binary
#             dilated_mask = self.Sigmoid(dilated_mask * self.scaling_factor1)

#             # Multiply the dilated mask with the first layer's output
#             output_mask = first_layer_output * dilated_mask
        
#         output_mask = self.CNN_dec(output_mask)
#         output_mask = self.Sigmoid(output_mask * self.scaling_factor2)

#         return output_mask  # Return the output mask if only one layer is present

In [None]:
# # Hyperparameters
# num_epochs = 100
# learning_rate = 1E-3

# layer_params = [
#     (1.0, 5, 5.0), # Parameters for layer 1
#     (1.0, 5, 5.0)  # Parameters for layer 2
# ]
# model = StackedMaskGenerator(layer_params).to(args.device)
# model.train()  # Set the model to training mode

# loss = nn.BCELoss()
# optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train_Data = Train_Data_list[0]
# Target_Data = Train_Target_Data_list[0]

# Train_Data_aug, Train_Target_Data_aug = f_augment_all_single(Train_Data, Train_Target_Data, p_rot = p_rot, p_flip = 0.5, p_jitter = 0.8, p_shear = p_shear, trans = trans, N = num_aug)

# Train_Data_aug = Train_Data_aug.unsqueeze(1).unsqueeze(1).to(args.device)
# Train_Target_Data_aug = Train_Target_Data_aug.unsqueeze(1).unsqueeze(1).to(args.device)

# def single_iter(model, optimizer, loss, in_frames, target_frames):

#     optimizer.zero_grad() # Zero out gradients

#     out_frames = model(in_frames)
#     loss_i = loss(out_frames,target_frames)

#     loss_i.backward()
#     torch.nn.utils.clip_grad_norm_(params_list, 1)
#     optimizer.step()
    
#     return loss_i

# def single_epoch(Train_Data, Target_Data, model, optimizer, loss, args):
    
#     random_idxs = np.arange(len(Train_Data))
#     np.random.shuffle(random_idxs)
#     train_data_shuffle = Train_Data[random_idxs]
#     target_data_shuffle = Target_Data[random_idxs]

#     epoch_losses = np.zeros(args.num_batches)

#     for it in range(args.num_batches):

#         in_frames = train_data_shuffle[it*args.batch_size:(it+1)*args.batch_size,:,:,:,:]
#         target_frames = target_data_shuffle[it*args.batch_size:(it+1)*args.batch_size,:,:,:,:]
        
#         in_frames = in_frames.squeeze(2)
#         target_frames = target_frames.squeeze(2)

#         epoch_losses[it] = single_iter(model, optimizer, loss, in_frames, target_frames)
        
#     return epoch_losses

# # Training loop
# mean_epoch_losses = np.zeros(num_epochs)
# for epoch in range(num_epochs):

#     epoch_losses = single_epoch(Train_Data_aug, Train_Target_Data_aug, model, optimizer, loss, args)
    
#     mean_epoch_losses[epoch] = np.mean(epoch_losses)
    
#     plt.plot(mean_epoch_losses)
#     plt.show()
    
#     in_frames = Train_Data_aug[0]
#     out_frames = Train_Target_Data_aug[0]
#     out_frames = model(in_frames)
#     plt.imshow(in_frames + out_frames)
#     plt.show()

In [None]:
# print(num_params(Enc_base.Enc1))
# print(num_params(Enc_base.ea0))
# print(num_params(Enc_base.ea1))
# print(num_params(Enc_base.ea2))
# print(num_params(Enc_base.ea3))
# print(num_params(Enc_base.l1))
# print(num_params(Enc_base.l2))
# print(num_params(Enc_base.te0u_1))
# print(num_params(Enc_base.te0u_2))
# print(num_params(Enc_base.te1u_1))
# print(num_params(Enc_base.te1u_2))
# print(num_params(Enc_base.te2u_1))
# print(num_params(Enc_base.te2u_2))
# print(num_params(Enc_base.e0))
# print(num_params(Enc_base.e0a))
# print(num_params(Enc_base.e0b))
# print(num_params(Enc_base.e0c))
# print(num_params(Enc_base.e1))
# print(num_params(Enc_base.e1a))
# print(num_params(Enc_base.e1b))
# print(num_params(Enc_base.e2))
# print(num_params(Enc_base.e2a))
# print(num_params(Enc_base.e2b))
# print(num_params(Enc_base.C1))
# print(num_params(Enc_base.DA1))
# print(num_params(Enc_base.DA2))

# print(num_params(Enc_base))
# print(num_params(Dec_base))

### Train FAZ Loop

In [None]:
%cd Y:\FAZ\Models\Models New\Models New New\AE

filt_idxs0 = np.loadtxt('filt_v_test_idx_0.csv', delimiter=',')
filt_idxs1 = np.loadtxt('filt_v_test_idx_1.csv', delimiter=',')

filt_idxs = np.sort(np.concatenate((filt_idxs0,filt_idxs1)))

L_PC = len(Photocoag_Imgs)
bad_idxs = [x for x in range(L_PC) if x not in filt_idxs]

print(len(bad_idxs))
print(len(filt_idxs))
print(L_PC)

# Photocoag_Imgs = Photocoag_Imgs[bad_idxs]
# Photocoag_Masks = Photocoag_Masks[bad_idxs]
Photocoag_Imgs = Photocoag_Imgs[filt_idxs]
Photocoag_Masks = Photocoag_Masks[filt_idxs]

In [None]:
########################################## jump

# FAZ Loop

## TRAIN TEST SPLIT

%cd Y:\FAZ\Models
# model_folder = 'Y:/FAZ/Models'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\FAZ 4 tasks, 2 annotations, no pretraining'
model_folder = 'Y:\FAZ\Models\Models New\Models New New\Multitask_PC_pretrain_filtered'

# args.training_mode = 1

args.num_tasks = 4*1

args.num_epochs = 500
#     args.show_example_epochs = args.num_tasks
args.show_example_epochs = 1
args.save_epochs = int(args.num_epochs/5)
# args.augment_epochs = 4*args.num_tasks

# task_names = ['Healthy', 'ALZ', 'AMD', 'DR']
task_names = ['Healthy', 'ALZ', 'AMD', 'DR', 'Photocoag']
# task_names = ['Healthy', 'ALZ', 'AMD', 'DR', 'Healthy_AE', 'ALZ_AE', 'AMD_AE', 'DR_AE']

epoch_per_task = 1

args.num_aug = 2

np.random.seed(2024)
num_splits = 1 # N-fold cross validation

# for test_idx in np.arange(num_splits):
for test_idx in [0]:

    save_name = '_test_idx_' + str(test_idx)    

#     splits = get_train_test_splits(Photocoag_Imgs,split=num_splits)
#     Photocoag_Imgs_train, Photocoag_Imgs_test, Photocoag_Masks_train, Photocoag_Masks_test = train_test_split(Photocoag_Imgs, Photocoag_Masks,splits,test_idx)
    Photocoag_Imgs_train = Photocoag_Imgs
    Photocoag_Masks_train = Photocoag_Masks
    Photocoag_Imgs_test = Photocoag_Imgs
    Photocoag_Masks_test = Photocoag_Masks
    
    splits = get_train_test_splits(Healthy_Imgs_SVC,split=num_splits)
    Healthy_Imgs_SVC_train, Healthy_Imgs_SVC_test, Healthy_RealMasks_SVC_train, Healthy_RealMasks_SVC_test = train_test_split(Healthy_Imgs_SVC, Healthy_Masks_SVC,splits,test_idx)
    Healthy_Imgs_SVC_train_alt, Healthy_Imgs_SVC_test_alt, Healthy_RealMasks_SVC_train_alt, Healthy_RealMasks_SVC_test_alt = train_test_split(Healthy_Imgs_SVC, Healthy_Masks_SVC_alt,splits,test_idx)
    
    splits = get_train_test_splits(ALZ_Imgs_SVC,split=num_splits)
    ALZ_Imgs_SVC_train, ALZ_Imgs_SVC_test, ALZ_RealMasks_SVC_train, ALZ_RealMasks_SVC_test = train_test_split(ALZ_Imgs_SVC, ALZ_Masks_SVC,splits,test_idx)
    ALZ_Imgs_SVC_train_alt, ALZ_Imgs_SVC_test_alt, ALZ_RealMasks_SVC_train_alt, ALZ_RealMasks_SVC_test_alt = train_test_split(ALZ_Imgs_SVC, ALZ_Masks_SVC_alt,splits,test_idx)

    splits = get_train_test_splits(AMD_Imgs_SVC,split=num_splits)
    AMD_Imgs_SVC_train, AMD_Imgs_SVC_test, AMD_RealMasks_SVC_train, AMD_RealMasks_SVC_test = train_test_split(AMD_Imgs_SVC, AMD_Masks_SVC,splits,test_idx)
    AMD_Imgs_SVC_train_alt, AMD_Imgs_SVC_test_alt, AMD_RealMasks_SVC_train_alt, AMD_RealMasks_SVC_test_alt = train_test_split(AMD_Imgs_SVC, AMD_Masks_SVC_alt,splits,test_idx)

    splits = get_train_test_splits(DR_Imgs_SVC,split=num_splits)
    DR_Imgs_SVC_train, DR_Imgs_SVC_test, DR_RealMasks_SVC_train, DR_RealMasks_SVC_test = train_test_split(DR_Imgs_SVC, DR_Masks_SVC,splits,test_idx)
    DR_Imgs_SVC_train_alt, DR_Imgs_SVC_test_alt, DR_RealMasks_SVC_train_alt, DR_RealMasks_SVC_test_alt = train_test_split(DR_Imgs_SVC, DR_Masks_SVC_alt,splits,test_idx)

    np.random.seed(None)

#     Photocoag_Imgs_test, Photocoag_Masks_test = f_augment_all_single(Photocoag_Imgs_test, Photocoag_Masks_test, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=0, N = 2, mode=1)
    Healthy_Imgs_SVC_test, Healthy_RealMasks_SVC_test = f_augment_all_single(Healthy_Imgs_SVC_test, Healthy_RealMasks_SVC_test, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
    ALZ_Imgs_SVC_test, ALZ_RealMasks_SVC_test = f_augment_all_single(ALZ_Imgs_SVC_test, ALZ_RealMasks_SVC_test, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
    AMD_Imgs_SVC_test, AMD_RealMasks_SVC_test = f_augment_all_single(AMD_Imgs_SVC_test, AMD_RealMasks_SVC_test, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
    DR_Imgs_SVC_test, DR_RealMasks_SVC_test = f_augment_all_single(DR_Imgs_SVC_test, DR_RealMasks_SVC_test, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)

    _, Healthy_RealMasks_SVC_test_alt = f_augment_all_single(Healthy_Imgs_SVC_test_alt, Healthy_RealMasks_SVC_test_alt, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
    _, ALZ_RealMasks_SVC_test_alt = f_augment_all_single(ALZ_Imgs_SVC_test_alt, ALZ_RealMasks_SVC_test_alt, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
    _, AMD_RealMasks_SVC_test_alt = f_augment_all_single(AMD_Imgs_SVC_test_alt, AMD_RealMasks_SVC_test_alt, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)
    _, DR_RealMasks_SVC_test_alt = f_augment_all_single(DR_Imgs_SVC_test_alt, DR_RealMasks_SVC_test_alt, p_rot = 0.5, p_flip = 0.5, p_jitter = 0.8, p_shear = 0.5, trans=50, N = args.num_aug, mode=1)    
    
    Photocoag_Imgs_test = Photocoag_Imgs_test.unsqueeze(1).unsqueeze(1).to(args.device)
    Photocoag_Masks_test = Photocoag_Masks_test.unsqueeze(1).unsqueeze(1).to(args.device)
    Healthy_Imgs_SVC_test = Healthy_Imgs_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
    Healthy_RealMasks_SVC_test = Healthy_RealMasks_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
    ALZ_Imgs_SVC_test = ALZ_Imgs_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
    ALZ_RealMasks_SVC_test = ALZ_RealMasks_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
    AMD_Imgs_SVC_test = AMD_Imgs_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
    AMD_RealMasks_SVC_test = AMD_RealMasks_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
    DR_Imgs_SVC_test = DR_Imgs_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)
    DR_RealMasks_SVC_test = DR_RealMasks_SVC_test.unsqueeze(1).unsqueeze(1).to(args.device)

    Healthy_RealMasks_SVC_test_alt = Healthy_RealMasks_SVC_test_alt.unsqueeze(1).unsqueeze(1).to(args.device)
    ALZ_RealMasks_SVC_test_alt = ALZ_RealMasks_SVC_test_alt.unsqueeze(1).unsqueeze(1).to(args.device)
    AMD_RealMasks_SVC_test_alt = AMD_RealMasks_SVC_test_alt.unsqueeze(1).unsqueeze(1).to(args.device)
    DR_RealMasks_SVC_test_alt = DR_RealMasks_SVC_test_alt.unsqueeze(1).unsqueeze(1).to(args.device)
    
    ##################################

    ## DEFINE MODELS

    args.d_attn0 = 128
    args.d_attn1 = 256
    args.d_attn2 = 512

#     args.d_attn0 = 64
#     args.d_attn1 = 128
#     args.d_attn2 = 256
    
    args.fac = 1 # Relative size of base and aux decoders

    nhead_base = 8
    nhead = int(nhead_base/args.fac)
    args.nhead = nhead

    #################
    # Define model
    Enc_base = Unet_Enc_base(nhead_base, args).to(args.device)
    Dec_base = Unet_Dec_base(nhead_base, args).to(args.device)

    args.d_attn0 = int(args.d_attn0/args.fac)
    args.d_attn1 = int(args.d_attn1/args.fac)
    args.d_attn2 = int(args.d_attn2/args.fac)

    Dec_aux_healthy_mask = Unet_Dec_v3(nhead, args).to(args.device)
    Dec_aux_ALZ_mask = Unet_Dec_v3(nhead, args).to(args.device)
    Dec_aux_AMD_mask = Unet_Dec_v3(nhead, args).to(args.device)
    Dec_aux_DR_mask = Unet_Dec_v3(nhead, args).to(args.device)
    #     Dec_aux_Photocoag_mask = Unet_Dec_v3(nhead, args).to(args.device)

    Dec_aux_healthy_mask_alt = Unet_Dec_v3(nhead, args).to(args.device)
    Dec_aux_ALZ_mask_alt = Unet_Dec_v3(nhead, args).to(args.device)
    Dec_aux_AMD_mask_alt = Unet_Dec_v3(nhead, args).to(args.device)
    Dec_aux_DR_mask_alt = Unet_Dec_v3(nhead, args).to(args.device)
    
    ###############################################################

    ## Four Tasks (photocoag only)
    Train_Data_list = [Photocoag_Imgs_train, Photocoag_Imgs_train, Photocoag_Imgs_train, Photocoag_Imgs_train]
    Train_Target_Data_list = [Photocoag_Masks_train, Photocoag_Masks_train, Photocoag_Masks_train, Photocoag_Masks_train]
    Test_Data_list = [Photocoag_Imgs_test, Photocoag_Imgs_test, Photocoag_Imgs_test, Photocoag_Imgs_test]
    Test_Target_Data_list = [Photocoag_Masks_test, Photocoag_Masks_test, Photocoag_Masks_test, Photocoag_Masks_test]
    Dec_aux_list = [Dec_aux_healthy_mask, Dec_aux_ALZ_mask, Dec_aux_AMD_mask, Dec_aux_DR_mask]
    Dec_name_list = ['Dec_aux_healthy_mask', 'Dec_aux_ALZ_mask', 'Dec_aux_AMD_mask', 'Dec_aux_DR_mask']

#     ## Four Tasks (masks)
#     Train_Data_list = [Healthy_Imgs_SVC_train, ALZ_Imgs_SVC_train, AMD_Imgs_SVC_train, DR_Imgs_SVC_train]
#     Train_Target_Data_list = [Healthy_RealMasks_SVC_train, ALZ_RealMasks_SVC_train, AMD_RealMasks_SVC_train, DR_RealMasks_SVC_train]
#     Test_Data_list = [Healthy_Imgs_SVC_test, ALZ_Imgs_SVC_test, AMD_Imgs_SVC_test, DR_Imgs_SVC_test]
#     Test_Target_Data_list = [Healthy_RealMasks_SVC_test, ALZ_RealMasks_SVC_test, AMD_RealMasks_SVC_test, DR_RealMasks_SVC_test]
#     Dec_aux_list = [Dec_aux_healthy_mask, Dec_aux_ALZ_mask, Dec_aux_AMD_mask, Dec_aux_DR_mask]
#     Dec_name_list = ['Dec_aux_healthy_mask', 'Dec_aux_ALZ_mask', 'Dec_aux_AMD_mask', 'Dec_aux_DR_mask']

    Train_Target_Data_list_alt = [Healthy_RealMasks_SVC_train_alt, ALZ_RealMasks_SVC_train_alt, AMD_RealMasks_SVC_train_alt, DR_RealMasks_SVC_train_alt, Photocoag_Masks_train]
    Test_Target_Data_list_alt = [Healthy_RealMasks_SVC_test_alt, ALZ_RealMasks_SVC_test_alt, AMD_RealMasks_SVC_test_alt, DR_RealMasks_SVC_test_alt, Photocoag_Masks_test]
    
#         ## Five Tasks (masks)
#         Train_Data_list = [Photocoag_Imgs_train, Healthy_Imgs_SVC_train, ALZ_Imgs_SVC_train, AMD_Imgs_SVC_train, DR_Imgs_SVC_train]
#         Train_Target_Data_list = [Photocoag_Masks_train, Healthy_RealMasks_SVC_train, ALZ_RealMasks_SVC_train, AMD_RealMasks_SVC_train, DR_RealMasks_SVC_train]
#         Test_Data_list = [Photocoag_Imgs_test, Healthy_Imgs_SVC_test, ALZ_Imgs_SVC_test, AMD_Imgs_SVC_test, DR_Imgs_SVC_test]
#         Test_Target_Data_list = [Photocoag_Masks_test, Healthy_RealMasks_SVC_test, ALZ_RealMasks_SVC_test, AMD_RealMasks_SVC_test, DR_RealMasks_SVC_test]
#         Dec_aux_list = [Dec_aux_Photocoag_mask, Dec_aux_healthy_mask, Dec_aux_ALZ_mask, Dec_aux_AMD_mask, Dec_aux_DR_mask]
#         Dec_name_list = ['Dec_aux_Photocoag_mask', 'Dec_aux_healthy_mask', 'Dec_aux_ALZ_mask', 'Dec_aux_AMD_mask', 'Dec_aux_DR_mask']

#     # Eight Tasks (masks)
#     Train_Data_list = [Healthy_Imgs_SVC_train, ALZ_Imgs_SVC_train, AMD_Imgs_SVC_train, DR_Imgs_SVC_train, Healthy_Imgs_SVC, ALZ_Imgs_SVC, AMD_Imgs_SVC, DR_Imgs_SVC]
#     Train_Target_Data_list = [Healthy_RealMasks_SVC_train, ALZ_RealMasks_SVC_train, AMD_RealMasks_SVC_train, DR_RealMasks_SVC_train, Healthy_Masks_SVC_alt, ALZ_Masks_SVC_alt, AMD_Masks_SVC_alt, DR_Masks_SVC_alt]
#     Test_Data_list = [Healthy_Imgs_SVC_test, ALZ_Imgs_SVC_test, AMD_Imgs_SVC_test, DR_Imgs_SVC_test, 0*Healthy_Imgs_SVC_test, 0*ALZ_Imgs_SVC_test, 0*AMD_Imgs_SVC_test, 0*DR_Imgs_SVC_test]
#     Test_Target_Data_list = [Healthy_RealMasks_SVC_test, ALZ_RealMasks_SVC_test, AMD_RealMasks_SVC_test, DR_RealMasks_SVC_test, 0*Healthy_RealMasks_SVC_test, 0*ALZ_RealMasks_SVC_test, 0*AMD_RealMasks_SVC_test, 0*DR_RealMasks_SVC_test]
#     Dec_aux_list = [Dec_aux_healthy_mask, Dec_aux_ALZ_mask, Dec_aux_AMD_mask, Dec_aux_DR_mask, Dec_aux_healthy_mask_alt, Dec_aux_ALZ_mask_alt, Dec_aux_AMD_mask_alt, Dec_aux_DR_mask_alt]
#     Dec_name_list = ['Dec_aux_healthy_mask', 'Dec_aux_ALZ_mask', 'Dec_aux_AMD_mask', 'Dec_aux_DR_mask', 'Dec_aux_healthy_mask_alt', 'Dec_aux_ALZ_mask_alt', 'Dec_aux_AMD_mask_alt', 'Dec_aux_DR_mask_alt']

    #################
    # Load weights
#     load_folder = 'Y:\FAZ\Models\Models New\Models New New\FAZ 4 tasks with PC pretraining'
#     load_folder = 'Y:\FAZ\Models\Models New\Models New New\PC Pretrain all (800 epochs)'
    load_folder = 'Y:\FAZ\Models\Models New\Models New New\Multitask_PC_pretrain_filtered_bad'
    load_folder = load_folder.replace('\\', '/')
    date_str = '2024_11_20'
    load_models_FAZ(Enc_base, Dec_base, Dec_aux_list, Dec_name_list, test_idx, load_folder, date_str)
    #################
    
    sz_min = np.inf
    for i in np.arange(len(Train_Data_list)):
        sz_i = Train_Data_list[i].size()
        if sz_i[0] < sz_min:
            sz_min = sz_i[0]
    args.sz_min = sz_min*args.num_aug

    params_list = list(Enc_base.parameters()) + list(Dec_base.parameters())
    optimizers = []
    lr = 1E-4
    for i in np.arange(len(Dec_aux_list)):
        params_list_i = params_list + list(Dec_aux_list[i].parameters())
        optimizer_i = torch.optim.Adam(params_list_i, lr=lr, betas=(0.9, 0.999))
        optimizers.append(optimizer_i)

    one_hot_enc_v = []
    for task_idx in np.arange(args.num_tasks):
        one_hot = torch.zeros(args.num_tasks).to(args.device)
        one_hot[task_idx] = 1
        ones = torch.ones(args.batch_size,1,args.num_tasks).to(args.device)
        one_hot_enc = ones*one_hot

        one_hot_enc_v.append(one_hot_enc)

    #####################Training loop ###########################

    mean_epoch_losses = np.zeros((args.num_tasks,int(args.num_epochs/args.num_tasks)))
    log_mean_epoch_losses = np.zeros((args.num_tasks,int(args.num_epochs/args.num_tasks)))

    train_dice = np.zeros((args.num_tasks,int(args.num_epochs/args.num_tasks)))
    test_dice = np.zeros((args.num_tasks,int(args.num_epochs/args.num_tasks)))

    loss = nn.BCELoss()
    # loss_b = Dice_loss()
    loss2 = nn.L1Loss()

    # Train for a maximum of max_epochs:
    for epoch in tqdm(np.arange(args.num_epochs), desc="Training progress..."):

        #################################
    #     task_idx = np.random.choice(np.arange(args.num_tasks))
        task_idx = int(np.floor(np.mod(epoch,epoch_per_task*args.num_tasks)/epoch_per_task))
        print(task_idx)

    #     if task_idx < (args.num_tasks/2):
    #         loss2 = loss_b
    #     else:
    #         loss2 = loss

        Dec_i = Dec_aux_list[task_idx]
        params_list_i = params_list + list(Dec_aux_list[task_idx].parameters())
        optimizer = optimizers[task_idx]
#         one_hot_enc = one_hot_enc_v[task_idx]

        Train_Data = Train_Data_list[task_idx]
        Train_Target_Data = Train_Target_Data_list[task_idx]
        Test_Data_aug = Test_Data_list[task_idx]
        Test_Target_Data_aug = Test_Target_Data_list[task_idx]
        
        Train_Target_Data_alt = Train_Target_Data_list_alt[task_idx]
        Test_Target_Data_aug_alt = Test_Target_Data_list_alt[task_idx]
        
        Train_Target = Train_Target_Data
#         Train_Target, one_hot_enc = f_interpolate(Train_Target_Data, Train_Target_Data_alt, args)

        # Params for Photocoag data augmentation
        num_aug = args.num_aug
        trans = 0
        p_shear = 0
        p_rot = 0.5
        
#         # Params for FAZ data augmention:
#         num_aug = args.num_aug
#         trans = 50
#         p_shear = 0.5
#         p_rot = 0.5
        
        Train_Data_aug, Train_Target_Data_aug = f_augment_all_single(Train_Data, Train_Target, p_rot = p_rot, p_flip = 0.5, p_jitter = 0.8, p_shear = p_shear, trans = trans, N = num_aug)

        Train_Data_aug = Train_Data_aug.unsqueeze(1).unsqueeze(1).to(args.device)
        Train_Target_Data_aug = Train_Target_Data_aug.unsqueeze(1).unsqueeze(1).to(args.device)

        args.num_batches = int(len(Train_Data_aug)/args.N)

        #################################

        epoch_losses = single_epoch_FAZ(Train_Data_aug, Train_Target_Data_aug, Enc_base, Dec_base, Dec_i, one_hot_enc, optimizer, loss, loss2, args, params_list_i)
        mean_epoch_loss = np.mean(epoch_losses)
        mean_epoch_losses[task_idx,int(epoch/args.num_tasks)] = mean_epoch_loss
        log_mean_epoch_losses[task_idx,int(epoch/args.num_tasks)] = np.log(mean_epoch_loss)

        if task_idx < 4:
            train_dice[task_idx,int(epoch/args.num_tasks)] = f_Tot_Dice(Train_Data_aug, Train_Target_Data_aug, Enc_base, Dec_base, one_hot_enc, args)
            test_dice[task_idx,int(epoch/args.num_tasks)] = f_Tot_Dice(Test_Data_aug, Test_Target_Data_aug, Enc_base, Dec_base, one_hot_enc, args)

            if np.mod(epoch,args.show_example_epochs) == 0:
        #         if np.mod(epoch,args.show_example_epochs) == args.show_example_epochs - 1:
                display_FAZ(Train_Data_aug, Train_Target_Data_aug, Enc_base, Dec_base, one_hot_enc, args)

                fig1, ax = plt.subplots(nrows=1, ncols=1 )
                for task in np.arange(args.num_tasks):
        #             plt.plot(mean_epoch_losses[task,0:int(epoch/args.num_tasks)])
                    plt.plot(log_mean_epoch_losses[task,0:int(epoch/args.num_tasks)])
                plt.grid()
                plt.show()

                colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'aquamarine', 'mediumseagreen']
                fig2, ax = plt.subplots(nrows=1, ncols=1 )
                for i in np.arange(len(train_dice)):
                    plt.plot(train_dice[i,0:int(epoch/args.num_tasks)+1],color=colors[i])
    #             plt.grid()
    #             plt.show()
    #             fig3, ax = plt.subplots(nrows=1, ncols=1 )
                for i in np.arange(4):
                    plt.plot(test_dice[i,0:int(epoch/args.num_tasks)+1],color=colors[i],linestyle='--')
                plt.grid()
                plt.show()

            if np.mod(epoch,args.save_epochs) == args.save_epochs-1:    

                np.savetxt('mean_epoch_losses' + '_epoch_' + str(epoch) + '_' + save_name + '.txt', mean_epoch_losses, fmt='%.2f')
                np.savetxt('train_dice' + '_epoch_' + str(epoch) + '_' + save_name + '.txt', train_dice, fmt='%.2f')
                np.savetxt('test_dice' + '_epoch_' + str(epoch) + '_' + save_name + '.txt', test_dice, fmt='%.2f')

                fig1.savefig('loss' + save_name + '.png')
                fig2.savefig('dice' + save_name + '.png')

                save_models(Enc_base, Dec_base, Dec_aux_list, Dec_name_list, test_idx, model_folder)

    np.savetxt(model_folder + '/' + 'mean_epoch_losses' + '_epoch_' + str(epoch) + '_' + save_name + '.txt', mean_epoch_losses, fmt='%.2f')
    np.savetxt(model_folder + '/' + 'train_dice' + '_epoch_' + str(epoch) + '_' + save_name + '.txt', train_dice, fmt='%.2f')
    np.savetxt(model_folder + '/' + 'test_dice' + '_epoch_' + str(epoch) + '_' + save_name + '.txt', test_dice, fmt='%.2f')

    fig1.savefig('loss' + save_name + '.png')
    fig2.savefig('dice' + save_name + '.png')

    save_models(Enc_base, Dec_base, Dec_aux_list, Dec_name_list, test_idx, model_folder)

In [None]:
# Run f_display_all on the test tensors left in scope by the training cell above.
# Test_Data_aug / Test_Target_Data_aug were last assigned inside the training loop, so they
# belong to whichever task was selected on its final iteration.
# NOTE(review): what f_display_all renders and what it returns in filt_v is not visible
# here -- confirm against its definition.
filt_v = f_display_all(Test_Data_aug, Test_Target_Data_aug, Enc_base, Dec_base, one_hot_enc, args, model_folder)


In [None]:
# Data = Test_Data_aug
# Target_Data = 1.*(Test_Target_Data_aug > 0.5)

# num_batches = int(len(Data)/args.batch_size)

# dice_tot = 0
# for it in np.arange(num_batches):

#     in_frames = Data[it*args.batch_size:(it+1)*args.batch_size,:,:,:,:]
#     target_frames = Target_Data[it*args.batch_size:(it+1)*args.batch_size,:,:,:,:]

#     with torch.no_grad():
#         feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz = Enc_base(in_frames, one_hot_enc)
#         rec_frames, _, _, _, _, _ = Dec_base(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=4)

# #     rec_frames = 1.*(rec_frames > 0.5)

#     for i in np.arange(args.batch_size):
#         in_frame =  in_frames.detach().cpu().numpy().squeeze()[i]
#         target_frame = target_frames.detach().cpu().numpy().squeeze()[i]
#         rec_frame = rec_frames.detach().cpu().numpy().squeeze()[i]
        
#         plt.imshow(in_frame)
#         plt.show()
#         plt.imshow(target_frame)
#         plt.show()
#         plt.imshow(rec_frame)
#         plt.show()

#         dice_i, _ = f_dice_jac(target_frame,rec_frame)
#         print(dice_i)
#         dice_tot += dice_i

In [None]:
# task_idx = 0

# Dec_i = Dec_aux_list[task_idx]
# params_list_i = params_list + list(Dec_aux_list[task_idx].parameters())
# optimizer = optimizers[task_idx]
# one_hot_enc = one_hot_enc_v[task_idx]

# Train_Data = Train_Data_list[task_idx]
# Train_Target_Data = Train_Target_Data_list[task_idx]
# Test_Data_aug = Test_Data_list[task_idx]
# Test_Target_Data_aug = Test_Target_Data_list[task_idx]

# if Dec_name_list[task_idx] != 'Dec_aux_Photocoag_mask' and Dec_name_list[task_idx] != 'Dec_aux_Photocoag_auto':
#     num_aug = args.num_aug
#     trans = 50
#     p_shear = 0.5
#     p_rot = 0.5
# else:
#     num_aug = 2
#     trans = 0
#     p_shear = 0
#     p_rot = 0
# Train_Data_aug, Train_Target_Data_aug = f_augment_all_single(Train_Data, Train_Target_Data, p_rot = p_rot, p_flip = 0.5, p_jitter = 0.8, p_shear = p_shear, trans = trans, N = num_aug)

# Train_Data_aug = Train_Data_aug.unsqueeze(1).unsqueeze(1).to(args.device)
# Train_Target_Data_aug = Train_Target_Data_aug.unsqueeze(1).unsqueeze(1).to(args.device)

# args.num_batches = int(len(Train_Data_aug)/args.N)

# #################################

# # epoch_losses = single_epoch_FAZ(Train_Data_aug, Train_Target_Data_aug, Enc_base, Dec_base, Dec_i, one_hot_enc, optimizer, loss, loss2, args, params_list_i)
# # mean_epoch_loss = np.mean(epoch_losses)
# # mean_epoch_losses[task_idx,int(epoch/args.num_tasks)] = mean_epoch_loss
# # log_mean_epoch_losses[task_idx,int(epoch/args.num_tasks)] = np.log(mean_epoch_loss)

# # train_dice[task_idx,int(epoch/args.num_tasks)] = f_Tot_Dice(Train_Data_aug, Train_Target_Data_aug, one_hot_enc, args)
# # test_dice[task_idx,int(epoch/args.num_tasks)] = f_Tot_Dice(Test_Data_aug, Test_Target_Data_aug, one_hot_enc, args)

# display_FAZ(Train_Data_aug, Train_Target_Data_aug, Enc_base, Dec_base, one_hot_enc, args)

In [None]:
# train_data = Train_Data_aug
# target_data = Train_Target_Data_aug

# random_idxs = np.arange(len(train_data))
# np.random.shuffle(random_idxs)
# train_data_shuffle = train_data[random_idxs]
# target_data_shuffle = target_data[random_idxs]
# it = 0
# in_frames = train_data_shuffle[it*args.batch_size:(it+1)*args.batch_size,:,:,:,:]
# target_frames = target_data_shuffle[it*args.batch_size:(it+1)*args.batch_size,:,:,:,:]

# with torch.no_grad():
#     feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz = Enc_base(in_frames, one_hot_enc)
# #         rec_frames, feats1n_base, feats2n_base, feats4n_base, feats5n_base = dec_base(feats1u, feats2u, feats4u_f, feats5_f, x4_sz, one_hot_enc)
# #         rec_frames_aux, feats1n_aux, feats2n_aux, feats4n_aux, feats5n_aux      = dec_aux(feats1u, feats2u, feats4u_f, feats5_f, x4_sz)

#     rec_frames_base_4, _, _, _, _, _ = Dec_base(feats0u, feats1u, feats2u, feats3u_f, feats4_f, x4_sz, one_hot_enc, mode=4)


# plt.imshow(in_frames[i,0,0,:,:].detach().cpu().numpy())
# plt.show()
# plt.imshow(target_frames[i,0,0,:,:].detach().cpu().numpy())
# plt.show()
# plt.imshow(rec_frames_base_4[i,0,0,:,:].detach().cpu().numpy())
# plt.show()

## Evaluate

In [None]:
import re
import matplotlib.pyplot as plt
import numpy as np

##################################

def txt_to_array(txt_name):
    """Load a delimited numeric text file (e.g. written by ``np.savetxt``) into an array.

    Every non-blank line of the file becomes one row; values on a line may be
    separated by any run of whitespace and/or commas.

    Parameters
    ----------
    txt_name : str or path-like
        Path of the text file to read.

    Returns
    -------
    numpy.ndarray
        Float array with one row per non-blank line. An empty file yields an
        empty array.

    Raises
    ------
    OSError
        If the file cannot be opened.
    ValueError
        If a token cannot be parsed as a float.
    """
    rows = []
    # The context manager closes the handle even when parsing raises.
    with open(txt_name, "r") as text_file:
        # Iterating the file keeps the final row even without a trailing newline.
        for line in text_file:
            tokens = [tok for tok in re.split(r"[\s,]+", line) if tok]
            if not tokens:
                continue  # skip blank lines instead of appending a ragged empty row
            # float() replaces eval(): it never executes file contents as code and it
            # accepts the 'nan' / 'inf' tokens that np.savetxt emits for non-finite values.
            rows.append([float(tok) for tok in tokens])
    return np.array(rows)

In [None]:
# import re
# import matplotlib.pyplot as plt
# import numpy as np

# ##################################

# # %cd Y:/FAZ/Models/Aya_with_photocoag_without_inner_losses
# # %cd Y:/FAZ/Models/Aya, without photocoag, no inner losses

# # epoch = 599
# epoch = 499
# colors = ['b', 'g', 'r', 'c', 'y']
# # labels = ['Healthy', 'ALZ', 'AMD', 'DR']
# labels = ['Photocoag', 'Healthy', 'ALZ', 'AMD', 'DR']

# task_idx = 0
# model_folder = 'Y:/FAZ/Models/AE2/' + task_names[task_idx]

# ##################################

# def txt_to_array(txt_name):
#     text_file = open(txt_name, "r")
#     lines = text_file.read().split('\n')[0:-1]
#     S = []
#     for i in range(len(lines)):
#         s = '[' + lines[i] + ']'
#         s2 = re.sub("\s+", ",", s.strip())
#         s3 = eval(s2)
#         S.append(s3)
#     return np.array(S)

# ##########################

# train_dice = 0
# for test_idx in range(4):
# #     epoch = args.num_epochs-1
#     txt_name = 'task_idx' + str(task_idx) + 'train_dice' + '_epoch_' + str(epoch) + '__' + 'test_idx_' + str(test_idx) + '.txt'
#     train_dice += txt_to_array(txt_name)
# train_dice /= 4

# for i in range(len(train_dice)):
#     label = labels[i]
#     plt.plot(train_dice[i],color=colors[i],label=label)
# plt.grid()
# plt.title('Train Dice')
# # plt.ylim((0.85,1))
# plt.legend()
# plt.show()

# ##########################

# test_dice = 0
# for test_idx in range(4):
# #     epoch = args.num_epochs-1
#     txt_name = 'test_dice' + '_epoch_' + str(epoch) + '__' + 'test_idx_' + str(test_idx) + '.txt'
#     test_dice += txt_to_array(txt_name)
# test_dice /= 4

# for i in range(len(test_dice)):
#     label = labels[i]
#     plt.plot(test_dice[i],color=colors[i],label=label)
# plt.grid()
# plt.title('Test Dice')
# # plt.ylim((0.7,1))
# plt.legend()
# plt.show()

# test_dice_final = np.round(np.mean(test_dice[:,-100:],axis=1)*100,1)

# print('Dice Index:')
# for i in range(len(test_dice_final)):
#     print(labels[i] + ': ' + str(test_dice_final[i]) + '%')

In [None]:
# ---- Per-task evaluation: Dice curves and final scores of one single-task model ----
# Loads the per-split train/test Dice histories saved as text files in the task's model
# folder, plots them, and reports the mean Dice over the last N_avg_epochs epochs (in %).
# Requires txt_to_array from the cell above.
task_idx = 0
splits = [0, 1]
# splits = [0]
final_epoch = 1000
N_epochs = final_epoch
# N_epochs = 100
colors = ['b', 'g', 'r', 'c', 'y']
task_names = ['Healthy', 'ALZ', 'AMD', 'DR', 'Photocoag']
test_names = ['split 1', 'split 2']
N_splits = len(splits)
N_avg_epochs = 100  # number of trailing epochs averaged for the final score

# Alternative result folders (uncomment exactly one assignment):
# model_folder = 'Y:/FAZ/Models/Models New/Models New New/AE'
# model_folder = 'Y:/FAZ/Models/Models New/Models New New/AE_with_2_annotations'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\FAZ 4 tasks, no pretraining test ixd0'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\FAZ 4 tasks, no pretraining test idx1'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\FAZ 4 tasks, with PC Pretrained, Idx0'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\FAZ 4 tasks, with PC Pretrained, Idx1'
# model_folder = 'Y:\FAZ\Models\Models New\Aya, without photocoag, no inner losses'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\AE filtered 2_annotations'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\AE_with_2_annotations'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\AE_filtered'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\PC Pretrain two-fold'
model_folder = 'Y:/FAZ/Models/Models New/Models New New/AE2/'

# model_folder = model_folder + '\\'
model_folder = model_folder + task_names[task_idx] + '\\'

# One row per split, one column per epoch.
train_dice = np.zeros((N_splits,N_epochs))
test_dice = np.zeros((N_splits,N_epochs))

for test_idx in range(N_splits):
    split_idx = splits[test_idx]
    # File name pattern: task_idx<k>_<train|test>_dice_epoch_<final_epoch>__test_idx_<split>.txt
    for dice_name, dice_curves in (('train_dice', train_dice), ('test_dice', test_dice)):
        txt_name = model_folder + 'task_idx' + str(task_idx) + '_' + dice_name + '_epoch_' + str(final_epoch) + '__' + 'test_idx_' + str(split_idx) + '.txt'
        # In-place add writes into the train_dice / test_dice arrays defined above.
        dice_curves[test_idx,:] += np.squeeze(txt_to_array(txt_name))

##########################

for title, dice_curves in (('Train Dice', train_dice), ('Test Dice', test_dice)):
    for i in range(len(dice_curves)):
        plt.plot(dice_curves[i],color=colors[i],label=test_names[i])
    plt.grid()
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Dice')
    plt.legend()
    plt.show()

# Mean Dice over the trailing window, converted from a fraction to a percentage.
# The scale factor is 100 regardless of the window length (it used to be N_avg_epochs,
# which only gave percentages while the window happened to be 100 epochs long).
train_dice_final = np.round(np.mean(train_dice[:,-N_avg_epochs:],axis=1)*100,1)
test_dice_final = np.round(np.mean(test_dice[:,-N_avg_epochs:],axis=1)*100,1)

print('TASK:', task_names[task_idx])

print('')

print('Train Dice Index:')
for i in range(len(train_dice_final)):
    print(test_names[i] + ': ' + str(train_dice_final[i]) + '%')
print('Avg Train Dice:', np.mean(train_dice_final), '%')

print('')

print('Test Dice Index:')
for i in range(len(test_dice_final)):
    print(test_names[i] + ': ' + str(test_dice_final[i]) + '%')
print('Avg Test Dice:',np.mean(test_dice_final), '%')

In [None]:
# Stand-alone re-assignment of final_epoch. The evaluation cells directly before and after
# this one each assign the same value themselves, so this cell is redundant on a
# top-to-bottom run.
final_epoch = 1000

In [None]:
from scipy.stats import ttest_rel

def cohens_d(set1, set2):
    """Cohen's d effect size for paired samples.

    Computes ``mean(set2 - set1) / std(set2 - set1)`` with the sample standard
    deviation (``ddof=1``). The result is positive when ``set2`` is larger than
    ``set1`` on average.

    Parameters
    ----------
    set1, set2 : array-like
        Paired observations of equal length. Lists and tuples are accepted as
        well as NumPy arrays.

    Returns
    -------
    float
        The effect size. It is ``nan`` or ``inf`` (with a NumPy runtime warning)
        when the differences have zero variance or fewer than two pairs are given.
    """
    # asarray lets plain Python sequences be subtracted element-wise.
    diff = np.asarray(set2) - np.asarray(set1)
    return np.mean(diff) / np.std(diff, ddof=1)

# ---- Paired comparison of two model folders on per-task test Dice ----
# For each of the first N_tasks tasks, the test-Dice curve of every model is averaged over
# the cross-validation splits and truncated to its last N_avg_epochs epochs. The two
# models' curves are then compared with a paired t-test and Cohen's d.
# Requires txt_to_array, cohens_d and ttest_rel from the cells above.
task_idx = 0
splits = [0, 1]
# splits = [0]
final_epoch = 1000
N_epochs = final_epoch
# N_epochs = 100
colors = ['b', 'g', 'r', 'c', 'y']
task_names = ['Healthy', 'ALZ', 'AMD', 'DR', 'Photocoag']
test_names = ['split 1', 'split 2']
N_splits = len(splits)
N_avg_epochs = 100
N_tasks = 4  # only the first four entries of task_names are compared

# model_folder = 'Y:/FAZ/Models/Models New/Models New New/AE'
# model_folder = 'Y:/FAZ/Models/Models New/Models New New/AE_with_2_annotations'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\FAZ 4 tasks, no pretraining test ixd0'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\FAZ 4 tasks, no pretraining test idx1'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\FAZ 4 tasks, with PC Pretrained, Idx0'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\FAZ 4 tasks, with PC Pretrained, Idx1'
# model_folder = 'Y:\FAZ\Models\Models New\Aya, without photocoag, no inner losses'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\AE filtered 2_annotations'
# model_folder = 'Y:\FAZ\Models\Models New\Models New New\PC Pretrain two-fold'

# Raw strings: same characters as before, without relying on invalid escape sequences.
main_folder1 = r'Y:\FAZ\Models\Models New\Models New New\AE2'
# main_folder2 = 'Y:\FAZ\Models\Models New\Models New New\AE_filtered'
main_folder2 = r'Y:\FAZ\Models\Models New\Models New New\AE_with_2_annotations'
# main_folder2 = 'Y:\FAZ\Models\Models New\Models New New\Multitask 4 tasks, 2 annotations, no pretraining'

folders = [main_folder1, main_folder2]
# Per-folder layout flag:
#   0        -> one sub-folder per task, each file holding a single Dice curve
#   non-zero -> files sit directly in the folder and are sliced as 2-D (rows x epochs)
ms = [0, 0]

# model_folder = model_folder + '\\'

# Axes: (model, task, epoch within the trailing averaging window).
Test_Dices = np.zeros((len(folders),N_tasks,N_avg_epochs))

for m_idx in range(len(folders)):

    main_folder = folders[m_idx] + '\\'

    if ms[m_idx] == 0:
        final_epoch = 1000
        for task_idx in range(N_tasks):
            train_dice = np.zeros((N_avg_epochs))
            test_dice = np.zeros((N_avg_epochs))
            model_folder = main_folder + task_names[task_idx] + '\\'
            for test_idx in range(N_splits):
                # Files are named after the split id, not its position in `splits`
                # (the two only coincide while splits == [0, 1, ...]).
                split_idx = splits[test_idx]
                txt_name = model_folder + 'task_idx' + str(task_idx) + '_' + 'train_dice' + '_epoch_' + str(final_epoch) + '__' + 'test_idx_' + str(split_idx) + '.txt'
                # Divide by the number of splits so the sum is their mean.
                train_dice += np.squeeze(txt_to_array(txt_name))[-N_avg_epochs:]/N_splits

                txt_name = model_folder + 'task_idx' + str(task_idx) + '_' + 'test_dice' + '_epoch_' + str(final_epoch) + '__' + 'test_idx_' + str(split_idx) + '.txt'
                test_dice += np.squeeze(txt_to_array(txt_name))[-N_avg_epochs:]/N_splits
            Test_Dices[m_idx,task_idx,:] = test_dice

    else:
        model_folder = main_folder  # already ends with a path separator
        f_epoch = 479
        for test_idx in range(N_splits):
            split_idx = splits[test_idx]
            # Load the *test* Dice history: it is accumulated into Test_Dices and compared
            # against the other model's test Dice below (this used to read 'train_dice').
            txt_name = model_folder + 'test_dice' + '_epoch_' + str(f_epoch) + '__' + 'test_idx_' + str(split_idx) + '.txt'
            # NOTE(review): assumes the file has exactly N_tasks rows -- confirm.
            Test_Dices[m_idx] += np.squeeze(txt_to_array(txt_name))[:,-N_avg_epochs:]/N_splits


for task_idx in range(N_tasks):
    print('TASK:', task_names[task_idx])
    t_stat, p_value = ttest_rel(Test_Dices[0,task_idx,:], Test_Dices[1,task_idx,:]) # Perform a paired t-test
    print('p value:', p_value)
    cd = cohens_d(Test_Dices[0,task_idx,:], Test_Dices[1,task_idx,:])
    print('Cohens d:', cd)

# for i in range(len(train_dice)):
#     label = test_names[i]
#     plt.plot(train_dice[i],color=colors[i],label=label)
# plt.grid()
# plt.title('Train Dice')
# # plt.ylim((0.85,1))
# plt.legend()
# plt.show()

# for i in range(len(test_dice)):
#     label = test_names[i]
#     plt.plot(test_dice[i],color=colors[i],label=label)
# plt.grid()
# plt.title('Test Dice')
# # plt.ylim((0.7,1))
# plt.legend()
# plt.show()

In [None]:
# Debug inspection: echo the model index left over from the comparison loop above.
m_idx

In [None]:
# Debug inspection of the state left over from the comparison loop above.
# A notebook cell only echoes its last expression, so both values are returned as one
# tuple (a bare `model_folder` on its own line was evaluated and silently discarded).
model_folder, m_idx

In [None]:
# # %cd Y:/FAZ/Models/Aya
# # %cd Y:/FAZ/Models/Nathan
# # %cd Y:/FAZ/Models/Aya no inner losses
# # %cd Y:/FAZ/Models/Nathan no inner losses

# ##########################

# train_dice = 0
# for test_idx in range(4):
# #     epoch = args.num_epochs-1
#     epoch = 1999
#     txt_name = 'train_dice' + '_epoch_' + str(epoch) + '__' + 'test_idx_' + str(test_idx) + '.txt'
#     train_dice += txt_to_array(txt_name)
# train_dice /= len(train_dice)

# colors = ['b', 'g', 'r', 'c']
# labels = ['Healthy', 'ALZ', 'AMD', 'DR']
# for i in range(len(test_dice)):
#     label = labels[i]
#     plt.plot(train_dice[i],color=colors[i],label=label)
# plt.grid()
# plt.title('Train Dice')
# plt.ylim((0.85,1))
# plt.legend()
# plt.show()

# ##########################

# test_dice = 0
# for test_idx in range(4):
# #     epoch = args.num_epochs-1
#     epoch = 1999
#     txt_name = 'test_dice' + '_epoch_' + str(epoch) + '__' + 'test_idx_' + str(test_idx) + '.txt'
#     test_dice += txt_to_array(txt_name)
# test_dice /= len(test_dice)

# colors = ['b', 'g', 'r', 'c']
# labels = ['Healthy', 'ALZ', 'AMD', 'DR']
# for i in range(len(test_dice)):
#     label = labels[i]
#     plt.plot(test_dice[i],color=colors[i],label=label)
# plt.grid()
# plt.title('Test Dice')
# plt.ylim((0.7,1))
# plt.legend()
# plt.show()

# test_dice_final = np.round(np.mean(test_dice[:,-100:],axis=1)*100,1)

# print('Dice Index:')
# for i in range(len(test_dice_final)):
#     print(labels[i] + ': ' + str(test_dice_final[i]) + '%')