In [None]:
import tqdm
import numpy as np
import torch
import pandas as pd
import pickle
import time
import datetime
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)
pd.set_option('mode.chained_assignment',  None)
import os

import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import random
from scipy.stats import pearsonr

import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader


In [None]:
# Load the pre-built model inputs: data, std and mask tensors plus the
# per-sample ID and ID/year lists.
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
with open('TF_in_data_all.pickle', 'rb') as f:
    TF_in_data, TF_in_std, TF_in_mask, ID_list, ID_year_list = pickle.load(f)

# Detach the three tensors from any autograd graph and keep them on the CPU;
# batches are moved to the training device later.
TF_in_data, TF_in_std, TF_in_mask = (
    loaded.detach().cpu() for loaded in (TF_in_data, TF_in_std, TF_in_mask)
)

In [None]:
# Output locations.
loss_dir = './loss'
model_dir = './model_save'

# Source spreadsheets.
# NOTE(review): absolute, machine-specific path - consider a configurable data directory.
table_loc_master = "/data/Storage_DAS02/jaeyoon/230101_BreastCancer_Project/data/230629_BC_new/"
table_data_master = "230714_Table_merged.xlsx"
table_type_master = "230714_BC_new_variable_type.xlsx"

# Variable-type sheet: drop rows whose "Distribution" is "-" and rows whose
# "In VAE" is "X".
Table_type = pd.read_excel(table_loc_master + table_type_master)
_keep_rows = (Table_type["Distribution"] != "-") & (Table_type["In VAE"] != "X")
Table_type = Table_type[_keep_rows]

# Merged data table, indexed by (new_ID, Year_num).
Table_1 = pd.read_excel(table_loc_master + table_data_master)
Table_1 = Table_1.reset_index(drop=True)
Table_1 = Table_1.set_index(['new_ID', 'Year_num'])

# Single-column frame holding the "Test_date_normed" column.
Table_Date = Table_1[["Test_date_normed"]]


In [None]:
# Read the test / train ID sets: the second column (index 1) of each sheet is
# taken as a NumPy array. The dtype of that array follows whatever
# DataFrame.values produces for the whole sheet.
# NOTE(review): presumably column 0 is a saved index column - confirm against the files.
ID_test = pd.read_excel("Test_ID_set.xlsx").values[:,1]
ID_train = pd.read_excel("Train_ID_set.xlsx").values[:,1]
# Disabled alternative kept from the original notebook: it would take the text
# before the first comma of each entry, strip characters, and de-duplicate.
# ID_test = np.unique([i.split(',')[0][2:-1] for i in ID_test])
# ID_train = np.unique([i.split(',')[0][2:-1] for i in ID_train])

In [None]:
# DataLoader
class BART_pretrain_Dataset(Dataset):
    """Subset of the pre-built sequences restricted to a given collection of IDs.

    Args:
        inputs: a 5-item sequence ``(TF_in_data, TF_in_mask, ID_list,
            ID_year_list, ID_use)``. The first four are aligned on their first
            axis; ``ID_use`` is the collection of IDs to keep.
            The original author's comments give the tensor shapes as
            (N, 10, 50) for the data and (N, 10, 1) for the mask -
            TODO confirm against the pickle.

    Each item is a dict with keys ``'data'``, ``'mask'``, ``'idx'`` (the ID)
    and ``'data_idx'`` (the matching ``ID_year_list`` entry).
    """

    def __init__(self, inputs):
        TF_in_data, TF_in_mask, ID_list, ID_year_list, ID_use = inputs
        # Positions (ascending) of the samples whose ID belongs to ID_use.
        idx_sub = [n for n, i in enumerate(ID_list) if i in ID_use]
        self.data = TF_in_data[idx_sub]
        self.mask = TF_in_mask[idx_sub]
        self.ID_list = np.array(ID_list)[idx_sub]
        # Direct indexing instead of testing `n in idx_sub` for every element,
        # which was quadratic in the number of samples. It stays a plain list
        # because the entries may have different lengths.
        self.ID_year_list = [ID_year_list[n] for n in idx_sub]

    def __len__(self):
        """Number of samples kept after filtering by ID."""
        return len(self.ID_list)

    def __getitem__(self, idx_sub):
        """Return the sample at position ``idx_sub`` as a dict."""
        return {'data' : self.data[idx_sub],
                'mask' : self.mask[idx_sub],
                'idx' : self.ID_list[idx_sub],
                'data_idx' : self.ID_year_list[idx_sub]}
def BART_collate_fn(samples, p=0.3):
    """Collate dataset items into encoder/decoder inputs with random step masking.

    Args:
        samples: list of dicts as produced by ``BART_pretrain_Dataset``
            (keys ``'data'``, ``'mask'``, ``'idx'``, ``'data_idx'``).
            ``'data'`` is stacked to (N, L, D) and ``'mask'`` to (N, L, 1);
            a true mask value marks a step that is zeroed out / ignored.
        p: fraction of each sample's candidate steps that is additionally
            masked for the encoder (at least one step when any candidate
            exists). Defaults to 0.3, the value previously hard-coded.

    Returns:
        Tuple ``(Enc_in, Enc_mask, Dec_in, Dec_mask, Out_mask, Out_data, IDs)``:
        ``Enc_in`` is the data with masked steps zeroed, ``Enc_mask`` the union
        of the given mask and the random mask, ``Dec_in`` the data shifted one
        step to the right with a zero first step, ``Dec_mask`` the equally
        shifted mask repeated to (N, L, L), ``Out_mask`` the negated input
        mask, ``Out_data`` the unmodified data and ``IDs`` the sample IDs.
    """
    data = torch.stack([sample['data'] for sample in samples])
    mask = torch.stack([sample['mask'] for sample in samples]).bool()
    IDs = [sample['idx'] for sample in samples]
    n_batch, seq_len, n_feat = data.shape[0], data.shape[1], data.shape[2]

    Rand_mask = torch.zeros_like(mask)
    for n, sample in enumerate(samples):
        # Candidate steps: the sample's 'data_idx' entries with the value 1 removed.
        # NOTE(review): why step 1 is excluded is not visible here - confirm intent.
        idx_sub = np.setdiff1d(sample['data_idx'], [1])
        if idx_sub.shape[0] != 0:
            n_pick = min(len(idx_sub), max(1, int(len(idx_sub) * p)))
            idx_MLM = np.random.choice(idx_sub, n_pick, replace=False)
            for idx in idx_MLM:
                Rand_mask[n, idx] = True

    Enc_mask = mask | Rand_mask
    Enc_in = data * (~Enc_mask)
    Out_mask = ~mask
    Out_data = data

    # Shift right by one step: the decoder sees step t-1 when producing step t.
    start_step = torch.zeros(n_batch, 1, n_feat, dtype=torch.float32, device=data.device)
    Dec_in = torch.cat([start_step, data[:, :-1, :]], dim=1)

    # The shifted mask is repeated along the last axis up to the sequence
    # length (previously hard-coded to 10) to form an (N, L, L) mask.
    start_mask = torch.zeros(n_batch, 1, 1, dtype=torch.bool, device=mask.device)
    Dec_mask = torch.cat([start_mask, mask[:, :-1, :]], dim=1).repeat(1, 1, seq_len)

    return Enc_in, Enc_mask, Dec_in, Dec_mask, Out_mask, Out_data, IDs

In [None]:
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head attention with learned query/key/value/output projections.

    Args:
        hidden_dim: embedding width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout_ratio: dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_dim, n_heads, dropout_ratio):
        super().__init__()

        assert hidden_dim % n_heads == 0
        self.hidden_dim = hidden_dim  # embedding width
        self.n_heads = n_heads  # number of heads
        self.head_dim = hidden_dim // n_heads  # width handled by each head

        # Projections applied to the query, key and value inputs.
        self.fc_q = nn.Linear(hidden_dim, hidden_dim)
        self.fc_k = nn.Linear(hidden_dim, hidden_dim)
        self.fc_v = nn.Linear(hidden_dim, hidden_dim)

        # Projection applied to the concatenated head outputs.
        self.fc_o = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout_ratio)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, len, hidden_dim) into (batch, n_heads, len, head_dim)."""
        per_head = projected.view(batch_size, -1, self.n_heads, self.head_dim)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask = None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query, key, value: tensors whose last dimension is ``hidden_dim``
                and whose first dimension is the batch.
            mask: optional 3-D boolean tensor; true entries are excluded from
                attention. It is expanded over the head dimension.

        Returns:
            ``(x, attention)``: the projected output of shape
            (batch, query_len, hidden_dim) and the attention weights of shape
            (batch, n_heads, query_len, key_len).
        """
        batch_size = query.shape[0]

        Q = self._split_heads(self.fc_q(query), batch_size)
        K = self._split_heads(self.fc_k(key), batch_size)
        V = self._split_heads(self.fc_v(value), batch_size)

        # Attention scores.
        # NOTE(review): the scores are scaled by sqrt(hidden_dim), not by the
        # per-head width; left unchanged because altering it would change the
        # outputs of weights trained with this scaling.
        energy = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.hidden_dim)
        energy = energy.clamp(min = -1e6, max = 1e6)

        if mask is not None:
            # Same mask for every head; masked positions get a very negative score.
            head_mask = mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
            energy = energy.masked_fill(head_mask, -1e10)

        attention = torch.softmax(energy, dim=-1)

        # Weighted sum of the values, then merge the heads back together.
        context = torch.matmul(self.dropout(attention), V)
        context = context.permute(0, 2, 1, 3).contiguous()
        merged = context.view(batch_size, -1, self.hidden_dim)
        x = self.fc_o(merged)

        return x, attention
    
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    Args:
        hidden_dim: input and output width.
        pf_dim: width of the intermediate layer.
        dropout_ratio: dropout probability applied after the ReLU.
    """

    def __init__(self, hidden_dim, pf_dim, dropout_ratio):
        super().__init__()

        self.fc_1 = nn.Linear(hidden_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, x):
        """Apply the block to the last dimension of ``x``."""
        expanded = self.fc_1(x)
        activated = torch.relu(expanded)
        regularised = self.dropout(activated)
        return self.fc_2(regularised)
    
class EncoderLayer(nn.Module):
    """Encoder block: self-attention and feed-forward, each followed by a
    residual connection and layer normalisation.

    Args:
        hidden_dim: embedding width.
        n_heads: number of attention heads.
        pf_dim: intermediate width of the feed-forward block.
        dropout_ratio: dropout probability used on both residual branches.
    """

    def __init__(self, hidden_dim, n_heads, pf_dim, dropout_ratio):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hidden_dim)
        self.ff_layer_norm = nn.LayerNorm(hidden_dim)
        self.self_attention = MultiHeadAttentionLayer(hidden_dim, n_heads, dropout_ratio)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hidden_dim, pf_dim, dropout_ratio)
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, src, src_mask):
        """Run one encoder block.

        The same tensor ``src`` is used as query, key and value.

        Returns:
            ``(src, attn)``: the transformed sequence and the self-attention weights.
        """
        # Self-attention branch with residual + layer norm.
        attended, attn = self.self_attention(src, src, src, src_mask)
        after_attention = self.self_attn_layer_norm(src + self.dropout(attended))

        # Position-wise feed-forward branch with residual + layer norm.
        fed_forward = self.positionwise_feedforward(after_attention)
        src = self.ff_layer_norm(after_attention + self.dropout(fed_forward))

        return src, attn
class Encoder(nn.Module):
    """Stack of ``EncoderLayer`` blocks over linearly embedded feature vectors.

    Args:
        input_dim: number of features per time step (input width of the
            linear "token" embedding).
        hidden_dim: embedding width.
        n_layers: number of encoder blocks.
        n_heads: number of attention heads per block.
        pf_dim: intermediate width of the feed-forward blocks.
        dropout_ratio: dropout probability.
        max_length: number of positions in the learned positional table.
            Defaults to ``input_dim``, which is what the table was previously
            always sized by, so existing call sites and saved weights keep
            the same parameter shapes.
    """

    def __init__(self, input_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout_ratio, max_length = None):
        super().__init__()

        n_positions = input_dim if max_length is None else max_length

        self.hidden_dim = hidden_dim
        self.tok_embedding = nn.Linear(input_dim, hidden_dim)
        self.pos_embedding = nn.Embedding(n_positions, hidden_dim)

        self.layers = nn.ModuleList([EncoderLayer(hidden_dim, n_heads, pf_dim, dropout_ratio) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, src, src_mask):
        """Encode ``src``.

        Args:
            src: tensor whose first two dimensions are (batch, length) and
                whose last dimension matches the embedding's input width.
            src_mask: attention mask forwarded unchanged to every block.

        Returns:
            ``(src, attns)``: the last block's output and a list with the
            attention weights of every block.

        Raises:
            ValueError: if the sequence is longer than the positional table.
        """
        src_len = src.shape[1]
        device = src.device

        # Fail with a readable message instead of an out-of-range embedding
        # lookup (which surfaces as a device-side assert on CUDA).
        n_positions = self.pos_embedding.num_embeddings
        if src_len > n_positions:
            raise ValueError(
                f"sequence length {src_len} exceeds the positional table size {n_positions}"
            )

        # One row of position ids; the addition below broadcasts it over the
        # batch. The module is expected to already live on the input's device,
        # so submodules are no longer moved from inside forward().
        pos = torch.arange(src_len, device=device).unsqueeze(0)

        # Scaled feature embedding plus positional embedding.
        src = self.dropout((self.tok_embedding(src) * math.sqrt(self.hidden_dim)) + self.pos_embedding(pos))

        # Run every encoder block in order, collecting the attention weights.
        attns = []
        for layer in self.layers:
            src, attn = layer(src, src_mask)
            attns.append(attn)

        return src, attns
    
class DecoderLayer(nn.Module):
    """Decoder block: self-attention, attention over the encoder output and a
    feed-forward block, each followed by a residual connection and layer
    normalisation.

    Args:
        hidden_dim: embedding width.
        n_heads: number of attention heads.
        pf_dim: intermediate width of the feed-forward block.
        dropout_ratio: dropout probability used on the residual branches.
    """

    def __init__(self, hidden_dim, n_heads, pf_dim, dropout_ratio):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hidden_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hidden_dim)
        self.ff_layer_norm = nn.LayerNorm(hidden_dim)
        self.self_attention = MultiHeadAttentionLayer(hidden_dim, n_heads, dropout_ratio)
        self.encoder_attention = MultiHeadAttentionLayer(hidden_dim, n_heads, dropout_ratio)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hidden_dim, pf_dim, dropout_ratio)
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Run one decoder block.

        Args:
            trg: decoder sequence.
            enc_src: encoder output used as key and value of the second attention.
            trg_mask: mask for the self-attention.
            src_mask: mask for the attention over ``enc_src``.

        Returns:
            ``(trg, attn_1, attn_2)``: the transformed sequence, the
            self-attention weights and the encoder-attention weights.
        """
        # Self-attention over the decoder sequence.
        self_attended, attn_1 = self.self_attention(trg, trg, trg, trg_mask)
        after_self = self.self_attn_layer_norm(trg + self.dropout(self_attended))

        # Attention from the decoder sequence to the encoder output.
        cross_attended, attn_2 = self.encoder_attention(after_self, enc_src, enc_src, src_mask)
        after_cross = self.enc_attn_layer_norm(after_self + self.dropout(cross_attended))

        # Position-wise feed-forward.
        fed_forward = self.positionwise_feedforward(after_cross)
        trg = self.ff_layer_norm(after_cross + self.dropout(fed_forward))

        return trg, attn_1, attn_2
class Decoder(nn.Module):
    """Stack of ``DecoderLayer`` blocks followed by a two-layer linear output head.

    Args:
        output_dim: number of features per time step (input width of the
            linear "token" embedding and width of the final output).
        hidden_dim: embedding width.
        n_layers: number of decoder blocks.
        n_heads: number of attention heads per block.
        pf_dim: intermediate width of the feed-forward blocks.
        dropout_ratio: dropout probability.
        max_length: number of positions in the learned positional table.
            Defaults to ``output_dim``, which is what the table was previously
            always sized by, so existing call sites and saved weights keep
            the same parameter shapes.
    """

    def __init__(self, output_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout_ratio, max_length = None):
        super().__init__()

        n_positions = output_dim if max_length is None else max_length

        self.hidden_dim = hidden_dim
        self.tok_embedding = nn.Linear(output_dim, hidden_dim)
        self.pos_embedding = nn.Embedding(n_positions, hidden_dim)

        self.layers = nn.ModuleList([DecoderLayer(hidden_dim, n_heads, pf_dim, dropout_ratio) for _ in range(n_layers)])

        # NOTE(review): the two output layers are applied back to back with no
        # activation in between - confirm whether a non-linearity was intended.
        self.fc_out_1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_out_2 = nn.Linear(hidden_dim, output_dim)

        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode ``trg`` while attending to ``enc_src``.

        Args:
            trg: tensor whose first two dimensions are (batch, length) and
                whose last dimension matches the embedding's input width.
            enc_src: encoder output.
            trg_mask: mask for the decoder self-attention.
            src_mask: mask for the attention over the encoder output.

        Returns:
            ``(output, attns_1, attns_2)``: the output-head result plus the
            per-block self-attention and encoder-attention weights.

        Raises:
            ValueError: if the sequence is longer than the positional table.
        """
        device = trg.device
        trg_len = trg.shape[1]

        # Fail with a readable message instead of an out-of-range embedding
        # lookup (which surfaces as a device-side assert on CUDA).
        n_positions = self.pos_embedding.num_embeddings
        if trg_len > n_positions:
            raise ValueError(
                f"sequence length {trg_len} exceeds the positional table size {n_positions}"
            )

        # One row of position ids; the addition below broadcasts it over the batch.
        pos = torch.arange(trg_len, device=device).unsqueeze(0)

        trg = self.dropout((self.tok_embedding(trg) * math.sqrt(self.hidden_dim)) + self.pos_embedding(pos))

        # trg: [batch_size, trg_len, hidden_dim]
        attns_1 = []
        attns_2 = []
        for layer in self.layers:
            trg, attn_1, attn_2 = layer(trg, enc_src, trg_mask, src_mask)
            attns_1.append(attn_1)
            attns_2.append(attn_2)

        trg = self.fc_out_1(trg)
        output = self.fc_out_2(trg)

        return output, attns_1, attns_2
    
class Transformer(nn.Module):
    """Encoder-decoder wrapper that assembles the attention masks.

    Args:
        encoder: module called as ``encoder(Enc_in, Enc_mask)`` returning
            ``(enc_src, attention_list)``.
        decoder: module called as ``decoder(Dec_in, enc_src, Dec_mask, Enc_mask)``
            returning ``(output, attention_list, attention_list)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        # Kept for backward compatibility with code that reads this attribute;
        # forward() builds the look-ahead mask for the actual sequence length.
        self.trg_sub_mask = torch.tril(torch.ones((10, 10)),diagonal=-1).bool().T.unsqueeze(0)

    def forward(self, inputs):
        """Run the full model.

        Args:
            inputs: sequence ``(Enc_in, Enc_mask, Dec_in, Dec_mask)``. Both
                masks are transposed on their last two dimensions before use;
                true entries mark positions excluded from attention.

        Returns:
            ``(output, attentions)`` where ``attentions`` concatenates the
            encoder, decoder-self and decoder-encoder attention lists.
        """
        Enc_in, Enc_mask, Dec_in, Dec_mask = inputs

        device = Enc_in.device
        trg_len = Dec_in.shape[1]

        # Look-ahead mask: position i may not attend to any position j > i.
        # Built per call so that sequence lengths other than 10 work as well.
        look_ahead = torch.ones(trg_len, trg_len, device=device).triu(diagonal=1).bool().unsqueeze(0)

        Dec_mask = Dec_mask.transpose(1,2).bool() | look_ahead
        Enc_mask = Enc_mask.transpose(1,2)

        enc_src, enc_attns = self.encoder(Enc_in, Enc_mask)

        output, dec_attns_1, dec_attns_2 = self.decoder(Dec_in, enc_src, Dec_mask, Enc_mask)

        return output, enc_attns + dec_attns_1 + dec_attns_2

In [None]:
# Use every visible GPU starting at index 2.
# NOTE(review): presumably GPUs 0 and 1 are reserved for other work - confirm.
devices = [torch.device(f"cuda:{i}") for i in range(2, torch.cuda.device_count())]
if not devices:
    # Fewer than three GPUs are visible: fall back to the first GPU, or to the
    # CPU, instead of failing with an IndexError on devices[0].
    devices = [torch.device("cuda:0")] if torch.cuda.is_available() else [torch.device("cpu")]
device = devices[0]
dropout_ratio = 0.1

# Model / optimiser hyper-parameters.
# NOTE(review): input_dim must equal the last dimension of TF_in_data; other
# comments in this notebook mention 50 and 26 features - confirm that 100 is right.
input_dim = 100
hidden_dim = 512
num_layers = 3
num_head = 8
ff_dim = 2048
lr = 1e-4
enc = Encoder(input_dim, hidden_dim, num_layers, num_head, pf_dim = ff_dim, dropout_ratio= dropout_ratio)
dec = Decoder(input_dim, hidden_dim, num_layers, num_head, pf_dim = ff_dim, dropout_ratio= dropout_ratio)
BART_model = Transformer(enc, dec).to(device)
optimizer = torch.optim.AdamW(BART_model.parameters(), lr=lr)
if len(devices) > 1:
    print("Let's use", len(devices), "GPUs!")
    # DataParallel splits each batch along dim 0 across the listed devices,
    # e.g. a batch of 30 on 3 GPUs becomes three chunks of 10.
    BART_model = nn.DataParallel(BART_model, device_ids = devices)


In [None]:
batch_size = 64

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(BART_pretrain_Dataset([TF_in_data, TF_in_mask, ID_list, ID_year_list, ID_train]),
                                batch_size=batch_size, shuffle=True, collate_fn=BART_collate_fn)

# Evaluation batches keep a fixed order. The loss is normalised per batch, so
# a fixed batch composition keeps validation numbers comparable between epochs
# and makes the collected predictions come out in a stable order.
# (BART_collate_fn still draws its random step mask for every batch.)
test_dataloader = DataLoader(BART_pretrain_Dataset([TF_in_data, TF_in_mask, ID_list, ID_year_list, ID_test]),
                                batch_size=batch_size, shuffle=False, collate_fn=BART_collate_fn)

In [None]:
n_epochs = 200
timer_use = tqdm.tqdm(range(n_epochs))


def _masked_sq_error(pred, target, valid_mask):
    """Sum of squared errors over steps where ``valid_mask`` is true, divided
    by the number of such steps (at least 1, so an all-masked batch yields 0
    instead of NaN)."""
    n_valid = valid_mask.sum().clamp(min=1)
    return (valid_mask * (pred - target)).pow(2).sum() / n_valid


def _run_epoch(dataloader, train):
    """Run one pass over ``dataloader``.

    When ``train`` is true the model is put in training mode and the optimiser
    is stepped after every batch; otherwise the model is put in eval mode and
    gradient tracking is disabled, so no autograd graph is kept.

    Returns:
        ``(mean_loss, n_samples)`` where the mean weights each batch loss by
        its batch size.
    """
    BART_model.train(train)
    loss_sum, n_samples = 0.0, 0
    with torch.set_grad_enabled(train):
        for Enc_in, Enc_mask, Dec_in, Dec_mask, Out_mask, Out_data, _ids in dataloader:
            Enc_in = Enc_in.to(device)
            Dec_in = Dec_in.to(device)
            Enc_mask = Enc_mask.to(device)
            Dec_mask = Dec_mask.to(device)
            Out_mask = Out_mask.to(device)
            Out_data = Out_data.to(device)

            pred, _attn = BART_model([Enc_in, Enc_mask, Dec_in, Dec_mask])
            # NOTE(review): the error is taken over every step where Out_mask
            # is true, not only over the randomly masked ones.
            loss = _masked_sq_error(pred, Out_data, Out_mask)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            n_batch = Enc_in.shape[0]
            loss_sum += loss.item() * n_batch
            n_samples += n_batch

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return loss_sum / max(n_samples, 1), n_samples


for epoch in timer_use:
    loss_train_all, len_train = _run_epoch(train_dataloader, train=True)
    loss_test_all, len_test = _run_epoch(test_dataloader, train=False)

    discript = f"model B | E {epoch} |"
    discript += f"Loss :"
    discript += f"Train : {loss_train_all:.3f} # {len_train} Validate : {loss_test_all:.3f} # {len_test}"
    timer_use.set_description(discript)

In [None]:
model_dir = "model_save"
# save_name = "230913_pretrained_BART"
save_name = "250320_pretrained_BART"

# Create the target directory first so that a fresh checkout can save.
os.makedirs(model_dir, exist_ok=True)

# The whole model object is pickled, including the nn.DataParallel wrapper
# when several GPUs were used. Loading it back needs the same class
# definitions to be importable.
# NOTE(review): consider also saving a state_dict for portability.
torch.save(BART_model, f"{model_dir}/{save_name}.pth")

In [None]:

# Predicted vs. true value for up to the first 50 features, one panel each,
# restricted to the steps selected by O_mask_all.
# NOTE(review): this cell duplicates the scatter-plot cell further down and
# sits before the cell that builds A_all_val / O_data_all / O_mask_all.
# It is therefore skipped (instead of raising NameError) until those arrays exist.
_required = ("A_all_val", "O_data_all", "O_mask_all")
if all(name in globals() and len(globals()[name]) > 0 for name in _required):
    valid_flat = O_mask_all.reshape(-1)
    plt.figure(figsize=(10,10))
    for i in range(min(50, A_all_val.shape[2])):
        A_sub = A_all_val[:,:,i].reshape(-1)[valid_flat]
        Data_sub = O_data_all[:,:,i].reshape(-1)[valid_flat]
        if A_sub.size == 0:
            # Nothing selected: np.max / np.min would fail on an empty array.
            continue
        plt.subplot(8,8,i+1)
        max_l = max(np.max(A_sub),np.max(Data_sub))
        min_l = min(np.min(A_sub),np.min(Data_sub))
        plt.scatter(Data_sub, A_sub, s = 1, alpha = 0.3)
        # Limits go (low, high); the reverse order would invert both axes.
        plt.xlim(min_l-0.2,max_l+0.2)
        plt.ylim(min_l-0.2,max_l+0.2)
else:
    print("Prediction arrays are not available yet - run the collection cell first.")

In [None]:
# Collect model outputs, targets and output masks over the evaluation loader.
# After the loop, A / B / Enc_mask still refer to the last batch; the
# attention plot below relies on that.
BART_model.eval()  # set explicitly rather than relying on the training cell
_pred_parts, _target_parts, _mask_parts = [], [], []
with torch.no_grad():  # inference only: no autograd graph is needed
    for Enc_in, Enc_mask, Dec_in, Dec_mask, Out_mask, Out_data, IDs in test_dataloader:
        Enc_in = Enc_in.to(device)
        Dec_in = Dec_in.to(device)
        Enc_mask = Enc_mask.to(device)
        Dec_mask = Dec_mask.to(device)
        Out_mask = Out_mask.to(device)
        Out_data = Out_data.to(device)
        A,B = BART_model([Enc_in, Enc_mask, Dec_in, Dec_mask])
        _pred_parts.append(A.detach().cpu().numpy())
        _target_parts.append(Out_data.detach().cpu().numpy())
        _mask_parts.append(Out_mask.detach().cpu().numpy())

# Concatenate once at the end instead of re-copying the arrays on every batch.
# With an empty loader the results stay empty lists, as before.
A_all_val = np.concatenate(_pred_parts, axis = 0) if _pred_parts else []
O_data_all = np.concatenate(_target_parts, axis = 0) if _target_parts else []
O_mask_all = np.concatenate(_mask_parts, axis = 0) if _mask_parts else []

In [None]:

# Predicted vs. true value for up to the first 50 features, one panel each,
# restricted to the steps selected by O_mask_all.
valid_flat = O_mask_all.reshape(-1)
plt.figure(figsize=(10,10))
for i in range(min(50, A_all_val.shape[2])):
    A_sub = A_all_val[:,:,i].reshape(-1)[valid_flat]
    Data_sub = O_data_all[:,:,i].reshape(-1)[valid_flat]
    if A_sub.size == 0:
        # Nothing selected: np.max / np.min would fail on an empty array.
        continue
    plt.subplot(8,8,i+1)
    max_l = max(np.max(A_sub),np.max(Data_sub))
    min_l = min(np.min(A_sub),np.min(Data_sub))
    plt.scatter(Data_sub, A_sub, s = 1, alpha = 0.3)
    # Limits go (low, high); the reverse order would invert both axes.
    plt.xlim(min_l-0.2,max_l+0.2)
    plt.ylim(min_l-0.2,max_l+0.2)

In [None]:
# Show every attention map for the sample of the last evaluated batch that
# has the fewest masked encoder steps. B is the list of attention tensors
# returned by the model for that batch; each entry is indexed as
# [sample, head, row, column].
idx = np.argmin(Enc_mask.sum(dim = (1,2)).detach().cpu().numpy())
for j in range(len(B)):
    print(j)
    # Derive the head count from the tensor instead of assuming 8 heads;
    # with 8 heads this is the same 2 x 4 grid as before.
    n_heads_shown = B[j].shape[1]
    n_cols = 4
    n_rows = max(1, math.ceil(n_heads_shown / n_cols))
    for i in range(n_heads_shown):
        plt.subplot(n_rows,n_cols,i+1)
        plt.imshow(B[j][idx,i].detach().cpu().numpy(), vmin = 0, vmax = 1)
    plt.show()