In [2]:

import os
gpu_ids = "1,7"
#     gpu_ids = "4,5,6,7"
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids  
import pandas as pd
from audtorch.metrics.functional import pearsonr
from mydata1D import dataGenerator,get_Dataframe_Data
from mydata1T import dataGenerator as dataG
from torch import nn
from tqdm import tqdm
from torch.utils.data import Dataset,ConcatDataset,DataLoader
import warnings
import torch
from torch.optim.lr_scheduler import LambdaLR,CosineAnnealingLR
from model.au128 import UNetWithTransformerEncoder
import logging  
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from accelerate import Accelerator
from utils.utils_datasetDisease import *
import lmdb
from noise import UNetn
import pytorch_warmup as warmup
from model.SQET import SQET



class NoiseModel(nn.Module):
    def __init__(
        self,
        util_model,
        learning_rate=1e-3,
        noise_coeff=15,
        min_scale=0,
        max_scale=1,
        batch_size=16,
        pretrained=None
    ):
        super().__init__()
        self.util_model = util_model
        # for layer in self.util_model.layers:
        #     layer.trainable = False
        for param in self.util_model.parameters():
            param.requires_grad = False

        self.noise_model = UNetn()


        self.normal = torch.distributions.normal.Normal(0, 1)
        self.learning_rate = learning_rate
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.noise_coeff = noise_coeff
        self.criterion = nn.MSELoss()
        self.batch_size = batch_size

    def forward(self, x):

        B = torch.sigmoid(self.noise_model(x))

        # sample from normal  distribution
#         epsilon = self.normal.sample(B.shape).type_as(B)

        # reparametiation trick
        # print('B',torch.mean(B),torch.max(B),torch.min(B))
        # print('epsilon',torch.mean(epsilon),torch.max(epsilon),torch.min(epsilon))
        noise = B * 10
        age_pred = self.util_model((x + noise).float()).squeeze()

        return B,noise,age_pred

def train(model,criterion,L1,optimizer,train_loader,epoch,epochs,lr,k,scheduler,warmup_scheduler,device):
    model.train()
    Loss = 0
    BS = 0
    MSE = 0
    bloss = 10000  
    with torch.no_grad():
        for imgs,age in tqdm(train_loader):
            optimizer.zero_grad()
            imgs, age = imgs.to(device), age.to(device)
            age_pred = model(imgs)['final']


            mse = criterion(age_pred.float(), age.float())
            MSE += mse.data


    #         print(f'loss:{loss}-----mse:{mse}---------bs:{bs}')

        MSE = MSE / len(train_loader)
        print(f'{epoch}/{epochs}------MSE:{MSE}')
#     if Loss < bloss:
#         torch.save(model.module.state_dict(), f"unoise{k}.pth")
#             torch.save(model.state_dict(),'net_params.pth.')
#         torch.save(model.state_dict(), f"unoise{k}.pth")
        
    with warmup_scheduler.dampening():
        if epoch < 10:
            pass
        else:
            scheduler.step()



def main(dataframe_paths, env, batch_size, epochs,k):
    accelerator = Accelerator(split_batches=True)

    # 数据准备

    
    # 实例化dataGenerator
    # train_dataset = dataGenerator(train_IXI, lmdb_path)
    # val_dataset = dataGenerator(val_IXI, lmdb_path)
    train_dataset = my_dataset(dataframe_paths[0], env, True, 0)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, pin_memory=True,drop_last=True)    # nw 一般为0


    # train_dataset = dataGenerator(train_IXI, lmdb_path)
    # val_dataset = dataGenerator(val_IXI, lmdb_path)

    # train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, pin_memory=True,drop_last=True)
    # val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=0, pin_memory=True,drop_last=True)
    # 模型配置
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dict = {'dim': (128, 256, 512),
            'depth':(2, 8, 2),
            'global_window_size': (8, 4, 2),
            'use_se': True,
            'use_pse': False,
            'use_sequence_pooling': True,
            'input_shape': [128, 128, 128]
            }


    # TODO:实例化模型
    model = SQET(
        dim=dict['dim'],  # dimension at each stage
        depth=dict['depth'],  # depth of transformer at each stage
        global_window_size=dict['global_window_size'],  # global window sizes at each stage
        use_se=dict['use_se'],
        use_pse=dict['use_pse'],
        use_sequence_pooling=dict['use_sequence_pooling'],
        input_shape=dict['input_shape'],
        attn_dropout=0.,
        ff_dropout=0.3
    ).to(device)
    model_path = "./results_checkpoint/SQET/checkpoint_best.tar"
    
#     util_model = UNetWithTransformerEncoder(32, 8, 6, 128, 0.1).to(device)
    checkpoint1 = torch.load(model_path)
    model.load_state_dict(checkpoint1['model_state_dict'], strict=False)
#     model_path = "./results_checkpoint/Unet/checkpoint_best.tar"

#     model = UNetWithTransformerEncoder(32, 8, 6, 128, 0.1).to(device)
#     checkpoint1 = torch.load(model_path)
#     model.load_state_dict(checkpoint1['model_state_dict'], strict=False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-2)
    
    # optimizer = torch.optim.SGD(params=model.parameters(),lr = 0.00001, momentum=0.5)
    #     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=2, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=1e-8, eps=1e-08)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=2, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=1e-6, eps=1e-08)
    scheduler =  CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-5)
    # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2, eta_min=7e-6)
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
    #     optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate,
    #                                   weight_decay=args.weight_decay)
    #     # TODO 学习率
    #     scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=500, T_mult=2, eta_min=2e-5)
    #     # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
    warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=4)
    # accelerator.print('training...')
    criterion = nn.MSELoss()
    L1 = nn.L1Loss()
    
    # 训练循环
    for epoch in range(epochs):
        # ... 训练和验证过程 ...
        # 记录日志、保存模型等
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        train(model,criterion,L1,optimizer,train_loader,epoch,epochs,lr,k,scheduler,warmup_scheduler,device)
        


            
            
if __name__ == '__main__':
    
    
    k = 'SZ'
    dataframe_paths = [f'csvs/{k}.csv', f'csvs/{k}.csv']
    # dataframe_paths = ['5-fold/train_2.csv', '5-fold/val_2.csv']
    batch_size = 1
    epochs = 1
    env = lmdb.open("/home/caojiaxiang/disease_data", readonly=True, lock=False, readahead=False,
                        meminit=False)
    main(dataframe_paths, env, batch_size, epochs,k)


Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
100%|██████████| 72/72 [00:11<00:00,  6.11it/s]

0/1------MSE:102.31681060791016





In [1]:
#SZ
import os
gpu_ids = "1,7"
#     gpu_ids = "4,5,6,7"
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids  
import pandas as pd
from audtorch.metrics.functional import pearsonr
from mydata1D import dataGenerator,get_Dataframe_Data
from mydata1T import dataGenerator as dataG
from torch import nn
from tqdm import tqdm
from torch.utils.data import Dataset,ConcatDataset,DataLoader
import warnings
import torch
from torch.optim.lr_scheduler import LambdaLR,CosineAnnealingLR
from model.au128 import UNetWithTransformerEncoder
import logging  
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from accelerate import Accelerator
from utils.utils_datasetDisease import *
import lmdb
from noise import UNetn
import pytorch_warmup as warmup
from model.SQET import SQET

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.allow_tf32 = True


class NoiseModel(nn.Module):
    def __init__(
        self,
        util_model,
        learning_rate=1e-3,
        noise_coeff=15,
        min_scale=0,
        max_scale=1,
        batch_size=16,
        pretrained=None
    ):
        super().__init__()
        self.util_model = util_model
        # for layer in self.util_model.layers:
        #     layer.trainable = False
        for param in self.util_model.parameters():
            param.requires_grad = False

        self.noise_model = UNetn()


        self.normal = torch.distributions.normal.Normal(0, 1)
        self.learning_rate = learning_rate
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.noise_coeff = noise_coeff
        self.criterion = nn.MSELoss()
        self.batch_size = batch_size

    def forward(self, x):

        B = torch.sigmoid(self.noise_model(x))

        # sample from normal  distribution
#         epsilon = self.normal.sample(B.shape).type_as(B)

        # reparametiation trick
        # print('B',torch.mean(B),torch.max(B),torch.min(B))
        # print('epsilon',torch.mean(epsilon),torch.max(epsilon),torch.min(epsilon))
        noise = B * 10
        age_pred = self.util_model((x + noise).float()).squeeze()

        return B,noise,age_pred

def train(model,criterion,L1,optimizer,train_loader,epoch,epochs,lr,k,scheduler,warmup_scheduler,device):
    model.train()
    Loss = 0
    BS = 0
    MSE = 0
    bloss = 10000    
    for imgs,age in tqdm(train_loader):
        optimizer.zero_grad()
        imgs, age = imgs.to(device), age.to(device)
        B,noise,age_pred = model(imgs)
        bs = torch.mean(B.log())
        
        mse = criterion(age_pred.float(), age.float())
        loss =  mse * 1.5 - torch.var(B) * 10 ** 2 - 75 * bs
        Loss += loss.data
        BS += bs.data
        MSE += mse.data
        
        loss.backward()
        optimizer.step()
#         print(f'loss:{loss}-----mse:{mse}---------bs:{bs}')
    Loss = Loss / len(train_loader)
    BS = BS / len(train_loader)
    MSE = MSE / len(train_loader)
    print(f'{epoch}/{epochs}-------Loss:{Loss}----BS:{BS}----MSE:{MSE}')
    if Loss < bloss:
#         torch.save(model.module.state_dict(), f"unoise{k}.pth")
#             torch.save(model.state_dict(),'net_params.pth.')
        torch.save(model.state_dict(), f"unoise{k}.pth")
        
    with warmup_scheduler.dampening():
        if epoch < 10:
            pass
        else:
            scheduler.step()



def main(dataframe_paths, env, batch_size, epochs,k):
    accelerator = Accelerator(split_batches=True)

    # 数据准备

    
    # 实例化dataGenerator
    # train_dataset = dataGenerator(train_IXI, lmdb_path)
    # val_dataset = dataGenerator(val_IXI, lmdb_path)
    train_dataset = my_dataset(dataframe_paths[0], env, True, 0)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, pin_memory=True,drop_last=True)    # nw 一般为0


    # train_dataset = dataGenerator(train_IXI, lmdb_path)
    # val_dataset = dataGenerator(val_IXI, lmdb_path)

    # train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, pin_memory=True,drop_last=True)
    # val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=0, pin_memory=True,drop_last=True)
    # 模型配置
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    model_path = "./results_checkpoint/Unet/checkpoint_best.tar"
    
    util_model = UNetWithTransformerEncoder(32, 8, 6, 128, 0.1).to(device)
    checkpoint1 = torch.load(model_path)
    util_model.load_state_dict(checkpoint1['model_state_dict'], strict=False)
    model = NoiseModel(util_model).to(device)
#     model = nn.DataParallel(model).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
    
    # optimizer = torch.optim.SGD(params=model.parameters(),lr = 0.00001, momentum=0.5)
    #     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=2, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=1e-8, eps=1e-08)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=2, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=1e-6, eps=1e-08)
    scheduler =  CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-5)
    # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2, eta_min=7e-6)
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
    #     optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate,
    #                                   weight_decay=args.weight_decay)
    #     # TODO 学习率
    #     scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=500, T_mult=2, eta_min=2e-5)
    #     # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
    warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=4)
    # accelerator.print('training...')
    criterion = nn.MSELoss()
    L1 = nn.L1Loss()
    
    # 训练循环
    for epoch in range(epochs):
        # ... 训练和验证过程 ...
        # 记录日志、保存模型等
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        train(model,criterion,L1,optimizer,train_loader,epoch,epochs,lr,k,scheduler,warmup_scheduler,device)
        


            
            
if __name__ == '__main__':
    
    
    k = 'SZ'
    dataframe_paths = [f'csvs/{k}.csv', f'csvs/{k}.csv']
    # dataframe_paths = ['5-fold/train_2.csv', '5-fold/val_2.csv']
    batch_size = 1
    epochs = 200
    env = lmdb.open("/home/caojiaxiang/disease_data", readonly=True, lock=False, readahead=False,
                        meminit=False)
    main(dataframe_paths, env, batch_size, epochs,k)


dataloader_config = DataLoaderConfiguration(split_batches=True)
Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  ret = func(*args, **kwargs)
100%|██████████| 72/72 [00:39<00:00,  1.81it/s]


0/200-------Loss:295.7616882324219----BS:-0.8573681116104126----MSE:154.72476196289062


100%|██████████| 72/72 [00:32<00:00,  2.20it/s]


1/200-------Loss:208.47500610351562----BS:-0.8367462158203125----MSE:97.7495346069336


100%|██████████| 72/72 [00:33<00:00,  2.16it/s]


2/200-------Loss:272.3033142089844----BS:-0.794295608997345----MSE:142.5313720703125


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


3/200-------Loss:298.59552001953125----BS:-0.7611318826675415----MSE:162.9511260986328


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


4/200-------Loss:219.5318603515625----BS:-0.7569014430046082----MSE:111.18446350097656


100%|██████████| 72/72 [00:33<00:00,  2.16it/s]


5/200-------Loss:151.51113891601562----BS:-0.7153610587120056----MSE:67.63530731201172


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


6/200-------Loss:125.76615142822266----BS:-0.6699857115745544----MSE:52.5391960144043


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


7/200-------Loss:105.5341796875----BS:-0.6404867768287659----MSE:40.44851303100586


100%|██████████| 72/72 [00:33<00:00,  2.14it/s]


8/200-------Loss:82.35935974121094----BS:-0.6094357371330261----MSE:26.436569213867188


100%|██████████| 72/72 [00:33<00:00,  2.14it/s]


9/200-------Loss:63.21118927001953----BS:-0.5811030268669128----MSE:14.870614051818848


100%|██████████| 72/72 [00:33<00:00,  2.14it/s]


10/200-------Loss:74.00394439697266----BS:-0.5506337881088257----MSE:23.55191993713379


100%|██████████| 72/72 [00:33<00:00,  2.13it/s]


11/200-------Loss:67.7399673461914----BS:-0.5409255623817444----MSE:19.80045509338379


100%|██████████| 72/72 [00:33<00:00,  2.14it/s]


12/200-------Loss:67.65064239501953----BS:-0.5140187740325928----MSE:20.936927795410156


100%|██████████| 72/72 [00:33<00:00,  2.13it/s]


13/200-------Loss:61.48275375366211----BS:-0.49984583258628845----MSE:17.430936813354492


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


14/200-------Loss:59.0590705871582----BS:-0.47946134209632874----MSE:16.733867645263672


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


15/200-------Loss:60.38368606567383----BS:-0.46166977286338806----MSE:18.452743530273438


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


16/200-------Loss:46.3214111328125----BS:-0.4322388470172882----MSE:10.416333198547363


100%|██████████| 72/72 [00:33<00:00,  2.16it/s]


17/200-------Loss:41.06838607788086----BS:-0.4093448519706726----MSE:7.997816562652588


100%|██████████| 72/72 [00:33<00:00,  2.16it/s]


18/200-------Loss:35.55615997314453----BS:-0.3866596519947052----MSE:5.409740924835205


100%|██████████| 72/72 [00:33<00:00,  2.16it/s]


19/200-------Loss:34.561180114746094----BS:-0.36274051666259766----MSE:5.8009514808654785


100%|██████████| 72/72 [00:33<00:00,  2.17it/s]


20/200-------Loss:33.03429412841797----BS:-0.3413909375667572----MSE:5.770875930786133


100%|██████████| 72/72 [00:33<00:00,  2.17it/s]


21/200-------Loss:32.35728454589844----BS:-0.31872689723968506----MSE:6.35360050201416


100%|██████████| 72/72 [00:33<00:00,  2.17it/s]


22/200-------Loss:36.14576721191406----BS:-0.3063472807407379----MSE:9.439452171325684


100%|██████████| 72/72 [00:33<00:00,  2.17it/s]


23/200-------Loss:40.0413703918457----BS:-0.2871796488761902----MSE:13.008686065673828


100%|██████████| 72/72 [00:33<00:00,  2.17it/s]


24/200-------Loss:30.12384033203125----BS:-0.2754049301147461----MSE:6.943163871765137


100%|██████████| 72/72 [00:33<00:00,  2.17it/s]


25/200-------Loss:28.065881729125977----BS:-0.2618549168109894----MSE:6.196703910827637


100%|██████████| 72/72 [00:33<00:00,  2.18it/s]


26/200-------Loss:31.755889892578125----BS:-0.2499004751443863----MSE:9.21010684967041


100%|██████████| 72/72 [00:33<00:00,  2.18it/s]


27/200-------Loss:28.75111961364746----BS:-0.23832488059997559----MSE:7.7594146728515625


100%|██████████| 72/72 [00:33<00:00,  2.17it/s]


28/200-------Loss:25.08924102783203----BS:-0.22598235309123993----MSE:5.897013187408447


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


29/200-------Loss:23.724843978881836----BS:-0.21146471798419952----MSE:5.663022994995117


100%|██████████| 72/72 [00:33<00:00,  2.16it/s]


30/200-------Loss:24.696657180786133----BS:-0.20093758404254913----MSE:6.820357799530029


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


31/200-------Loss:24.508968353271484----BS:-0.18981964886188507----MSE:7.2150444984436035


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


32/200-------Loss:19.210006713867188----BS:-0.1823180764913559----MSE:4.050852298736572


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


33/200-------Loss:17.7780704498291----BS:-0.1711208075284958----MSE:3.628092050552368


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


34/200-------Loss:14.852238655090332----BS:-0.16496486961841583----MSE:1.9834572076797485


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


35/200-------Loss:14.408146858215332----BS:-0.1545296162366867----MSE:2.1759884357452393


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


36/200-------Loss:13.93306827545166----BS:-0.14693784713745117----MSE:2.218299627304077


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


37/200-------Loss:12.671856880187988----BS:-0.14097140729427338----MSE:1.6731027364730835


100%|██████████| 72/72 [00:33<00:00,  2.15it/s]


38/200-------Loss:12.396400451660156----BS:-0.13497261703014374----MSE:1.7739055156707764


 90%|█████████ | 65/72 [00:30<00:03,  2.12it/s]


KeyboardInterrupt: 